diff --git a/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs b/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs index 8b163a50..221fd80a 100644 --- a/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs +++ b/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs @@ -18,7 +18,7 @@ type ObjectListFilter = | In of FieldFilter | StartsWith of FieldFilter | EndsWith of FieldFilter - | Contains of FieldFilter + | Contains of FieldFilter | OfTypes of Type list | FilterField of FieldFilter @@ -142,25 +142,27 @@ module ObjectListFilter = let private StringStartsWithMethod = stringType.GetMethod ("StartsWith", [| stringType |]) let private StringEndsWithMethod = stringType.GetMethod ("EndsWith", [| stringType |]) let private StringContainsMethod = stringType.GetMethod ("Contains", [| stringType |]) - let private getEnumerableContainsMethod (memberType : Type) = + let private getCollectionInstanceContainsMethod (memberType : Type) = + memberType + .GetMethods(BindingFlags.Instance ||| BindingFlags.Public) + .FirstOrDefault (fun m -> m.Name = "Contains" && m.GetParameters().Length = 1) + |> ValueOption.ofObj + let private getEnumerableContainsMethod (itemType : Type) = match typeof .GetMethods(BindingFlags.Static ||| BindingFlags.Public) .FirstOrDefault (fun m -> m.Name = "Contains" && m.GetParameters().Length = 2) with | null -> raise (MissingMemberException "Static 'Contains' method with 2 parameters not found on 'Enumerable' class") - | containsGenericStaticMethod -> - if - memberType.IsGenericType - && memberType.GenericTypeArguments.Length = 1 - then - containsGenericStaticMethod.MakeGenericMethod (memberType.GenericTypeArguments) - else - let ienumerable = - memberType - .GetInterfaces() - .First (fun i -> i.FullName.StartsWith "System.Collections.Generic.IEnumerable`1") - containsGenericStaticMethod.MakeGenericMethod ([| ienumerable.GenericTypeArguments[0] |]) + | containsGenericStaticMethod -> containsGenericStaticMethod.MakeGenericMethod ([| itemType |]) + let private getEnumerableCastMethod (itemType : Type) = + match + typeof + .GetMethods(BindingFlags.Static ||| BindingFlags.Public) + .FirstOrDefault (fun m -> m.Name = "Cast" && m.GetParameters().Length = 1) + with + | null -> raise (MissingMemberException "Static 'Cast' method with 1 parameter not found on 'Enumerable' class") + | castGenericStaticMethod -> castGenericStaticMethod.MakeGenericMethod ([| itemType |]) let getField (param : ParameterExpression) fieldName = Expression.PropertyOrField (param, fieldName) @@ -204,38 +206,38 @@ module ObjectListFilter = && memberType .GetInterfaces() .Any (fun i -> i.FullName.StartsWith "System.Collections.Generic.IEnumerable`1") + let callContains memberType = + let itemType = + if ``member``.Type.IsArray then ``member``.Type.GetElementType() + else ``member``.Type.GetGenericArguments()[0] + let valueType = + match f.Value with + | null -> itemType + | value -> value.GetType() + let castedMember = + if itemType = valueType then ``member`` :> Expression + else + let castMethod = getEnumerableCastMethod valueType + Expression.Call (castMethod, ``member``) + match getCollectionInstanceContainsMethod memberType with + | ValueNone -> + let enumerableContains = getEnumerableContainsMethod valueType + Expression.Call (enumerableContains, castedMember, Expression.Constant (f.Value)) + | ValueSome instanceContainsMethod -> + Expression.Call (castedMember, instanceContainsMethod, Expression.Constant (f.Value)) match ``member``.Member with - | :? PropertyInfo as prop when prop.PropertyType |> isEnumerable -> - match - prop.PropertyType - .GetMethods(BindingFlags.Instance ||| BindingFlags.Public) - .FirstOrDefault (fun m -> m.Name = "Contains" && m.GetParameters().Length = 1) - with - | null -> - Expression.Call ( - getEnumerableContainsMethod prop.PropertyType, - Expression.PropertyOrField (param, f.FieldName), - Expression.Constant (f.Value) - ) - | instanceContainsMethod -> - Expression.Call (Expression.PropertyOrField (param, f.FieldName), instanceContainsMethod, Expression.Constant (f.Value)) - | :? FieldInfo as field when field.FieldType |> isEnumerable -> - Expression.Call ( - getEnumerableContainsMethod field.FieldType, - Expression.PropertyOrField (param, f.FieldName), - Expression.Constant (f.Value) - ) + | :? PropertyInfo as prop when prop.PropertyType |> isEnumerable -> callContains prop.PropertyType + | :? FieldInfo as field when field.FieldType |> isEnumerable -> callContains field.FieldType | _ -> if ``member``.Type = stringType then Expression.Call (``member``, StringContainsMethod, Expression.Constant (f.Value)) else Expression.Call (Expression.Convert (``member``, stringType), StringContainsMethod, Expression.Constant (f.Value)) - | In f -> + | In f when not (f.Value.IsEmpty) -> let ``member`` = Expression.PropertyOrField (param, f.FieldName) - f.Value - |> Seq.map (fun v -> Expression.Equal (``member``, Expression.Constant (v))) - |> Seq.reduce (fun acc expr -> Expression.OrElse (acc, expr)) - :> Expression + let enumerableContains = getEnumerableContainsMethod typeof + Expression.Call (enumerableContains, Expression.Constant (f.Value), Expression.Convert (``member``, typeof)) + | In f -> Expression.Constant (true) | OfTypes types -> types |> Seq.map (fun t -> buildTypeDiscriminatorCheck param t) diff --git a/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs b/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs index 29431240..7dc5a1e5 100644 --- a/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs +++ b/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs @@ -79,6 +79,8 @@ type ValidIntObject = type FakeEntity = { ValidStringStruct : ValidStringStruct ValidStringObject : ValidStringObject + ValidStringStructList : ValidStringStruct list + ValidStringObjectList : ValidStringObject list string : string ValidIntStruct : ValidIntStruct ValidIntObject : ValidIntObject @@ -125,6 +127,23 @@ let ``ObjectListFilter works with Contains operator for ValidStringStruct`` () = let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE CONTAINS(root["validStringStruct"], "athan")""" +[] +let ``ObjectListFilter works with Contains operator for ValidStringStruct list`` () = + let queryable = container.GetItemLinqQueryable () + let filter = Contains { FieldName = "validStringStructList"; Value = "athan" } + let filterQuery = queryable.Apply (filter, filterOptions) + let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery + equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE ARRAY_CONTAINS(root["validStringStructList"], "athan")""" + +[] +let ``ObjectListFilter works with In operator for ValidStringStruct list`` () = + let queryable = container.GetItemLinqQueryable () + let filter = In { FieldName = "validStringStruct"; Value = [ "athan"; "gaja" ] } + let filterQuery = queryable.Apply (filter, filterOptions) + let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery + equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE ARRAY_CONTAINS([ "athan", "gaja" ], root["validStringStruct"])""" + + [] let ``ObjectListFilter works with Equals operator for ValidStringObject`` () = let filter = Equals { FieldName = "validStringObject"; Value = ValidStringObject "Jonathan" } @@ -133,6 +152,7 @@ let ``ObjectListFilter works with Equals operator for ValidStringObject`` () = let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE (root["validStringObject"] = "Jonathan")""" + [] let ``ObjectListFilter works with GreaterThan operator for ValidIntStruct`` () = let queryable = container.GetItemLinqQueryable () diff --git a/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqTests.fs b/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqTests.fs index d5a6bf82..f1c06fa8 100644 --- a/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqTests.fs +++ b/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqTests.fs @@ -189,6 +189,27 @@ let ``ObjectListFilter works with IN operator for int type field`` () = result.Contact |> equals { Email = "b.adams@gmail.com" } result.Friends |> equals [ { Email = "j.abrams@gmail.com" }; { Email = "l.trif@gmail.com" } ] +[] +let ``ObjectListFilter works with Contains operator for array type field`` () = + let filter = Contains { FieldName = "friends"; Value = { Email = "j.abrams@gmail.com" } } + let queryable = data.AsQueryable () + let filteredData = queryable.Apply (filter) |> Seq.toList + List.length filteredData |> equals 2 + do + let result = List.head filteredData + result.ID |> equals 4 + result.FirstName |> equals "Ben" + result.LastName |> equals "Adams" + result.Contact |> equals { Email = "b.adams@gmail.com" } + result.Friends |> equals [ { Email = "j.abrams@gmail.com" }; { Email = "l.trif@gmail.com" } ] + do + let result = List.last filteredData + result.ID |> equals 7 + result.FirstName |> equals "Jeneffer" + result.LastName |> equals "Trif" + result.Contact |> equals { Email = "j.trif@gmail.com" } + result.Friends |> equals [ { Email = "j.abrams@gmail.com" } ] + [] let ``ObjectListFilter works with FilterField operator`` () = let filter =