diff --git a/src/Mapster.EFCore.Tests/EFCoreTest.cs b/src/Mapster.EFCore.Tests/EFCoreTest.cs index eec16e14..6fba5c15 100644 --- a/src/Mapster.EFCore.Tests/EFCoreTest.cs +++ b/src/Mapster.EFCore.Tests/EFCoreTest.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; namespace Mapster.EFCore.Tests { @@ -45,6 +46,34 @@ public void TestFindObject() first.Grade.ShouldBe(Grade.F); } + [TestMethod] + public async Task TestFindSingleObjectUsingProjectToType() + { + var options = new DbContextOptionsBuilder() + .UseInMemoryDatabase(Guid.NewGuid().ToString("N")) + .Options; + var context = new SchoolContext(options); + DbInitializer.Initialize(context); + + var mapsterInstance = new Mapper(); + + var query = context.Students.Where(s => s.ID == 1); + + async Task FirstExecute() => + await mapsterInstance.From(query) + .ProjectToType() + .FirstOrDefaultAsync(); + + await Should.NotThrowAsync(async () => + { + var first = await FirstExecute(); + + first.ShouldNotBeNull(); + first.ID.ShouldBe(1); + first.LastName.ShouldBe("Alexander"); + }); + } + [TestMethod] public void MapperInstance_From_OrderBy() { diff --git a/src/Mapster.EFCore/MapsterQueryable.cs b/src/Mapster.EFCore/MapsterQueryable.cs index 8d9ffb8a..cc4cd6be 100644 --- a/src/Mapster.EFCore/MapsterQueryable.cs +++ b/src/Mapster.EFCore/MapsterQueryable.cs @@ -82,11 +82,21 @@ public TResult ExecuteAsync(Expression expression, CancellationToken ca { var enumerable = ((IAsyncQueryProvider)_provider).ExecuteAsync(expression, cancellationToken); var enumerableType = typeof(TResult); + if (!IsAsyncEnumerableType(enumerableType)) + { + return enumerable; + } var elementType = enumerableType.GetGenericArguments()[0]; var wrapType = typeof(MapsterAsyncEnumerable<>).MakeGenericType(elementType); return (TResult) Activator.CreateInstance(wrapType, enumerable, _builder); } + private static bool IsAsyncEnumerableType(Type type) + { + return type.GetInterfaces() + .Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)); + } + public IAsyncEnumerable ExecuteEnumerableAsync(Expression expression, CancellationToken cancellationToken = default) { var enumerable = ((IAsyncQueryProvider)_provider).ExecuteAsync>(expression, cancellationToken);