diff --git a/src/Framework/Core/Controls/GridViewDataSetExtensions.cs b/src/Framework/Core/Controls/GridViewDataSetExtensions.cs index 0b4d04756b..e78e10bc2a 100644 --- a/src/Framework/Core/Controls/GridViewDataSetExtensions.cs +++ b/src/Framework/Core/Controls/GridViewDataSetExtensions.cs @@ -62,16 +62,7 @@ public static async Task LoadFromQueryableAsync(this IGridViewDataSet data var filtered = filteringOptions.ApplyToQueryable(queryable); var sorted = sortingOptions.ApplyToQueryable(filtered); var paged = pagingOptions.ApplyToQueryable(sorted); - if (paged is not IAsyncEnumerable asyncPaged) - { - throw new ArgumentException($"The specified IQueryable ({queryable.GetType().FullName}), does not support async enumeration. Please use the LoadFromQueryable method.", nameof(queryable)); - } - - var result = new List(); - await foreach (var item in asyncPaged.WithCancellation(cancellationToken)) - { - result.Add(item); - } + var result = (await AsyncQueryableImplementation.QueryableToListAsync(paged, cancellationToken)).ToList(); dataSet.Items = result; if (pagingOptions is IPagingOptionsLoadingPostProcessor pagingOptionsLoadingPostProcessor) diff --git a/src/Framework/Core/Controls/Options/AsyncQueryableImplementation.cs b/src/Framework/Core/Controls/Options/AsyncQueryableImplementation.cs new file mode 100644 index 0000000000..f46c710f33 --- /dev/null +++ b/src/Framework/Core/Controls/Options/AsyncQueryableImplementation.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace DotVVM.Framework.Controls +{ + public static class AsyncQueryableImplementation + { + public static async Task> QueryableToListAsync(IQueryable queryable, CancellationToken cancellationToken) + { + if (queryable is IAsyncEnumerable asyncPaged) + { + // use IAsyncEnumerable implementation + var result = new List(); + await foreach (var item in asyncPaged.WithCancellation(cancellationToken)) + { + result.Add(item); + } + return result; + } + + var queryableType = queryable.GetType(); + if (queryableType is { Namespace: "Marten.Linq", Name: "MartenLinqQueryable`1" }) + { + var result = await MartenToListAsync(queryable, queryableType, cancellationToken); + if (result is not null) + { + return result; + } + } + + throw new ArgumentException($"The specified IQueryable ({queryable.GetType().FullName}), does not support async enumeration. Please use the LoadFromQueryable method.", nameof(queryable)); + } + + + static MethodInfo? martenMethodCache; + private static Task?> MartenToListAsync(IQueryable queryable, Type queryableType, CancellationToken ct) + { + var toListAsyncMethod = martenMethodCache ?? queryableType.Assembly.GetType("Marten.QueryableExtensions")!.GetMethods().SingleOrDefault(m => m.Name == "ToListAsync" && m.GetParameters() is { Length: 2 } parameters && parameters[1].ParameterType == typeof(CancellationToken)); + if (toListAsyncMethod is null) + { + return Task.FromResult?>(null); + } + + if (martenMethodCache is null) + Interlocked.CompareExchange(ref martenMethodCache, toListAsyncMethod, null); + + var toListMethodGeneric = toListAsyncMethod.MakeGenericMethod(typeof(T)); + var result = toListMethodGeneric.Invoke(null, [queryable, ct])!; + return (Task?>)result; + } + } +} diff --git a/src/Framework/Core/Controls/Options/PagingImplementation.cs b/src/Framework/Core/Controls/Options/PagingImplementation.cs index db6f117ffa..988278e908 100644 --- a/src/Framework/Core/Controls/Options/PagingImplementation.cs +++ b/src/Framework/Core/Controls/Options/PagingImplementation.cs @@ -46,6 +46,7 @@ public static async Task QueryableAsyncCount(IQueryable queryable, Ca // CustomAsyncQueryableCountDelegate, and we do accept PRs adding new heuristics ;) ) return await ( EfCoreAsyncCountHack(queryable, queryableType, ct) ?? + MartenAsyncCountHack(queryable, queryableType, ct) ?? StandardAsyncCountHack(queryable, ct) ); } @@ -70,6 +71,23 @@ public static async Task QueryableAsyncCount(IQueryable queryable, Ca return (Task)countMethodGeneric.Invoke(null, new object[] { queryable, ct })!; } + static MethodInfo? martenMethodCache; + static Task? MartenAsyncCountHack(IQueryable queryable, Type queryableType, CancellationToken ct) + { + if (!(queryableType.Namespace == "Marten.Linq" && queryableType.Name == "MartenLinqQueryable`1")) + return null; + + var countMethod = martenMethodCache ?? queryableType.Assembly.GetType("Marten.QueryableExtensions")!.GetMethods().SingleOrDefault(m => m.Name == "CountAsync" && m.GetParameters() is { Length: 2 } parameters && parameters[1].ParameterType == typeof(CancellationToken)); + if (countMethod is null) + return null; + + if (martenMethodCache is null) + Interlocked.CompareExchange(ref martenMethodCache, countMethod, null); + + var countMethodGeneric = countMethod.MakeGenericMethod(typeof(T)); + return (Task)countMethodGeneric.Invoke(null, new object[] { queryable, ct })!; + } + static Task StandardAsyncCountHack(IQueryable queryable, CancellationToken ct) { #if NETSTANDARD2_1_OR_GREATER