diff --git a/samples/UnitOfWork.Host/Controllers/ValuesController.cs b/samples/UnitOfWork.Host/Controllers/ValuesController.cs index 5fcfe36..83d0052 100644 --- a/samples/UnitOfWork.Host/Controllers/ValuesController.cs +++ b/samples/UnitOfWork.Host/Controllers/ValuesController.cs @@ -13,7 +13,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Host.Controllers public class ValuesController : Controller { private readonly IUnitOfWork _unitOfWork; - private ILogger _logger; + private readonly ILogger _logger; // 1. IRepositoryFactory used for readonly scenario; // 2. IUnitOfWork used for read/write scenario; @@ -27,135 +27,137 @@ public ValuesController(IUnitOfWork unitOfWork, ILogger logger var repo = _unitOfWork.GetRepository(hasCustomRepository: true); if (repo.Count() == 0) { - repo.Insert(new Blog - { - Id = 1, - Url = "/a/" + 1, - Title = $"a{1}", - Posts = new List{ - new Post - { - Id = 1, - Title = "A", - Content = "A's content", - Comments = new List - { - new Comment - { - Id = 1, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 2, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 3, - Title = "c", - Content = "c's content", - } - }, - }, - new Post - { - Id = 2, - Title = "B", - Content = "B's content", - Comments = new List - { - new Comment - { - Id = 4, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 5, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 6, - Title = "c", - Content = "c's content", - } - }, - }, - new Post - { - Id = 3, - Title = "C", - Content = "C's content", - Comments = new List - { - new Comment - { - Id = 7, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 8, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 9, - Title = "c", - Content = "c's content", - } - }, - }, - new Post - { - Id = 4, - Title = "D", - Content = "D's content", - Comments = new List - { - new Comment - { - Id = 10, - Title = "A", - Content = "A's content", - }, - new Comment - { - Id = 11, - Title = "b", - Content = "b's content", - }, - new Comment - { - Id = 12, - Title = "c", - Content = "c's content", - } - }, - } - }, - }); + SeedInitialEntities(repo); _unitOfWork.SaveChanges(); } } + private static void SeedInitialEntities(IRepository repo) + => repo.Insert(new Blog + { + Id = 1, + Url = "/a/" + 1, + Title = $"a{1}", + Posts = new List{ + new Post + { + Id = 1, + Title = "A", + Content = "A's content", + Comments = new List + { + new Comment + { + Id = 1, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 2, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 3, + Title = "c", + Content = "c's content", + } + }, + }, + new Post + { + Id = 2, + Title = "B", + Content = "B's content", + Comments = new List + { + new Comment + { + Id = 4, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 5, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 6, + Title = "c", + Content = "c's content", + } + }, + }, + new Post + { + Id = 3, + Title = "C", + Content = "C's content", + Comments = new List + { + new Comment + { + Id = 7, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 8, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 9, + Title = "c", + Content = "c's content", + } + }, + }, + new Post + { + Id = 4, + Title = "D", + Content = "D's content", + Comments = new List + { + new Comment + { + Id = 10, + Title = "A", + Content = "A's content", + }, + new Comment + { + Id = 11, + Title = "b", + Content = "b's content", + }, + new Comment + { + Id = 12, + Title = "c", + Content = "c's content", + } + }, + } + }, + }); + // GET api/values [HttpGet] - public async Task> Get() - { - return await _unitOfWork.GetRepository().GetAllAsync(include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); - } + public async Task> Get() + => await _unitOfWork.GetRepository() + .GetAllAsync(include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); // GET api/values/Page/5/10 - [HttpGet("Page/{pageIndex}/{pageSize}")] + [HttpGet("Page/{pageIndex:int}/{pageSize:int}")] public async Task> Get(int pageIndex, int pageSize) { // projection @@ -170,11 +172,11 @@ public async Task> Get(string term) { _logger.LogInformation("demo about first or default with include"); - var item = _unitOfWork.GetRepository().GetFirstOrDefault(predicate: x => x.Title.Contains(term), include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); + var item = await _unitOfWork.GetRepository().GetFirstOrDefaultAsync(predicate: x => x.Title.Contains(term), include: source => source.Include(blog => blog.Posts).ThenInclude(post => post.Comments)); _logger.LogInformation("demo about first or default without include"); - item = _unitOfWork.GetRepository().GetFirstOrDefault(predicate: x => x.Title.Contains(term), orderBy: source => source.OrderByDescending(b => b.Id)); + item = await _unitOfWork.GetRepository().GetFirstOrDefaultAsync(predicate: x => x.Title.Contains(term), orderBy: source => source.OrderByDescending(b => b.Id)); _logger.LogInformation("demo about first or default with projection"); @@ -184,11 +186,9 @@ public async Task> Get(string term) } // GET api/values/4 - [HttpGet("{id}")] - public async Task Get(int id) - { - return await _unitOfWork.GetRepository().FindAsync(id); - } + [HttpGet("{id:int}")] + public async Task Get(int id) + => await _unitOfWork.GetRepository().FindAsync(id); // POST api/values [HttpPost] diff --git a/samples/UnitOfWork.Host/Models/BlogggingContext.cs b/samples/UnitOfWork.Host/Models/BlogggingContext.cs index 1099928..e5fc31d 100644 --- a/samples/UnitOfWork.Host/Models/BlogggingContext.cs +++ b/samples/UnitOfWork.Host/Models/BlogggingContext.cs @@ -12,10 +12,7 @@ public BloggingContext(DbContextOptions options) public DbSet Blogs { get; set; } public DbSet Posts { get; set; } - protected override void OnModelCreating(ModelBuilder modelBuilder) - { - modelBuilder.EnableAutoHistory(null); - } + protected override void OnModelCreating(ModelBuilder modelBuilder) => modelBuilder.EnableAutoHistory(null); } public class Blog diff --git a/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs b/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs index b33e932..571954d 100644 --- a/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs +++ b/samples/UnitOfWork.Host/Models/CustomBlogRepository.cs @@ -1,14 +1,9 @@ -using Arch.EntityFrameworkCore.UnitOfWork; -using Arch.EntityFrameworkCore.UnitOfWork.Host.Models; -using Microsoft.EntityFrameworkCore; - -namespace Host.Models +namespace Arch.EntityFrameworkCore.UnitOfWork.Host.Models { - public class CustomBlogRepository : Repository, IRepository + public class CustomBlogRepository : Repository { public CustomBlogRepository(BloggingContext dbContext) : base(dbContext) { - } } } diff --git a/samples/UnitOfWork.Host/Startup.cs b/samples/UnitOfWork.Host/Startup.cs index 081fce3..a29ae19 100644 --- a/samples/UnitOfWork.Host/Startup.cs +++ b/samples/UnitOfWork.Host/Startup.cs @@ -6,7 +6,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Hosting; -using Host.Models; namespace Arch.EntityFrameworkCore.UnitOfWork.Host { diff --git a/samples/UnitOfWork.Host/UnitOfWork.Host.csproj b/samples/UnitOfWork.Host/UnitOfWork.Host.csproj index 61060bd..03e55c7 100644 --- a/samples/UnitOfWork.Host/UnitOfWork.Host.csproj +++ b/samples/UnitOfWork.Host/UnitOfWork.Host.csproj @@ -1,23 +1,26 @@  + - net5.0 + net6.0 + Arch.EntityFrameworkCore.UnitOfWork.Host true Exe + - - - - - - - - - + + + + + + + + + diff --git a/src/Microsoft.EntityFrameworkCore.UnitOfWork/IRepositoryFactory.cs b/src/Microsoft.EntityFrameworkCore.UnitOfWork/IRepositoryFactory.cs deleted file mode 100644 index 19b808b..0000000 --- a/src/Microsoft.EntityFrameworkCore.UnitOfWork/IRepositoryFactory.cs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Arch team. All rights reserved. - -namespace Microsoft.EntityFrameworkCore -{ - /// - /// Defines the interfaces for interfaces. - /// - public interface IRepositoryFactory - { - /// - /// Gets the specified repository for the . - /// - /// True if providing custom repositry - /// The type of the entity. - /// An instance of type inherited from interface. - IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class; - } -} diff --git a/src/UnitOfWork/Collections/IEnumerablePagedListExtensions.cs b/src/UnitOfWork/Collections/EnumerablePagedListExtensions.cs similarity index 97% rename from src/UnitOfWork/Collections/IEnumerablePagedListExtensions.cs rename to src/UnitOfWork/Collections/EnumerablePagedListExtensions.cs index 01c8ef5..42697f7 100644 --- a/src/UnitOfWork/Collections/IEnumerablePagedListExtensions.cs +++ b/src/UnitOfWork/Collections/EnumerablePagedListExtensions.cs @@ -8,7 +8,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Collections /// /// Provides some extension methods for to provide paging capability. /// - public static class IEnumerablePagedListExtensions + public static class EnumerablePagedListExtensions { /// /// Converts the specified source to by the specified and . diff --git a/src/UnitOfWork/Collections/PagedList.cs b/src/UnitOfWork/Collections/PagedList.cs index 7dca0db..13d9c40 100644 --- a/src/UnitOfWork/Collections/PagedList.cs +++ b/src/UnitOfWork/Collections/PagedList.cs @@ -70,15 +70,15 @@ internal PagedList(IEnumerable source, int pageIndex, int pageSize, int index throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex"); } - if (source is IQueryable querable) + if (source is IQueryable queryable) { PageIndex = pageIndex; PageSize = pageSize; IndexFrom = indexFrom; - TotalCount = querable.Count(); + TotalCount = queryable.Count(); TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize); - Items = querable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList(); + Items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList(); } else { @@ -95,7 +95,7 @@ internal PagedList(IEnumerable source, int pageIndex, int pageSize, int index /// /// Initializes a new instance of the class. /// - internal PagedList() => Items = new T[0]; + internal PagedList() => Items = Array.Empty(); } @@ -165,15 +165,15 @@ public PagedList(IEnumerable source, Func, IEnumer throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex"); } - if (source is IQueryable querable) + if (source is IQueryable queryable) { PageIndex = pageIndex; PageSize = pageSize; IndexFrom = indexFrom; - TotalCount = querable.Count(); + TotalCount = queryable.Count(); TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize); - var items = querable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray(); + var items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray(); Items = new List(converter(items)); } diff --git a/src/UnitOfWork/Collections/IQueryablePageListExtensions.cs b/src/UnitOfWork/Collections/QueryablePageListExtensions.cs similarity index 84% rename from src/UnitOfWork/Collections/IQueryablePageListExtensions.cs rename to src/UnitOfWork/Collections/QueryablePageListExtensions.cs index 9578aae..3d198cf 100644 --- a/src/UnitOfWork/Collections/IQueryablePageListExtensions.cs +++ b/src/UnitOfWork/Collections/QueryablePageListExtensions.cs @@ -6,7 +6,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Collections { - public static class IQueryablePageListExtensions + public static class QueryablePageListExtensions { /// /// Converts the specified source to by the specified and . @@ -20,7 +20,7 @@ public static class IQueryablePageListExtensions /// /// The start index value. /// An instance of the inherited from interface. - public static async Task> ToPagedListAsync(this IQueryable source, int pageIndex, int pageSize, int indexFrom = 0, CancellationToken cancellationToken = default(CancellationToken)) + public static async Task> ToPagedListAsync(this IQueryable source, int pageIndex, int pageSize, int indexFrom = 0, CancellationToken cancellationToken = default) { if (indexFrom > pageIndex) { @@ -28,8 +28,11 @@ public static class IQueryablePageListExtensions } var count = await source.CountAsync(cancellationToken).ConfigureAwait(false); - var items = await source.Skip((pageIndex - indexFrom) * pageSize) - .Take(pageSize).ToListAsync(cancellationToken).ConfigureAwait(false); + var items = await source + .Skip((pageIndex - indexFrom) * pageSize) + .Take(pageSize) + .ToListAsync(cancellationToken) + .ConfigureAwait(false); var pagedList = new PagedList() { diff --git a/src/UnitOfWork/IRepository.cs b/src/UnitOfWork/IRepository.cs index b4f94b3..6f96596 100644 --- a/src/UnitOfWork/IRepository.cs +++ b/src/UnitOfWork/IRepository.cs @@ -33,7 +33,7 @@ public interface IRepository where TEntity : class void ChangeTable(string table); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -44,16 +44,17 @@ public interface IRepository where TEntity : class /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - IPagedList GetPagedList(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false); + IPagedList GetPagedList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -67,17 +68,18 @@ IPagedList GetPagedList(Expression> predicate = nul /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - Task> GetPagedListAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false); + Task> GetPagedListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -89,17 +91,18 @@ Task> GetPagedListAsync(Expression> pred /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - IPagedList GetPagedList(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false) where TResult : class; + IPagedList GetPagedList( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false) where TResult : class; /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -114,18 +117,19 @@ IPagedList GetPagedList(Expression> sel /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - Task> GetPagedListAsync(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false) where TResult : class; + Task> GetPagedListAsync( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false) where TResult : class; /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -134,14 +138,15 @@ Task> GetPagedListAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method defaults to a read-only, no-tracking query. - TEntity GetFirstOrDefault(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + TEntity GetFirstOrDefault( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -151,15 +156,16 @@ TEntity GetFirstOrDefault(Expression> predicate = null, /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method defaults to a read-only, no-tracking query. - TResult GetFirstOrDefault(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + TResult GetFirstOrDefault( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -177,7 +183,7 @@ Task GetFirstOrDefaultAsync(Expression> bool ignoreQueryFilters = false); /// - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method defaults to a read-only, no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method defaults to a read-only, no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -238,11 +244,12 @@ Task GetFirstOrDefaultAsync(Expression> predicate = /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - IQueryable GetAll(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + IQueryable GetAll( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// /// Gets all entities. This method is not recommended @@ -278,11 +285,12 @@ IQueryable GetAll(Expression> selector, /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - Task> GetAllAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + Task> GetAllAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// /// Gets all entities. This method is not recommended @@ -333,66 +341,74 @@ Task> GetAllAsync(Expression> sel /// /// Gets the max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - T Max(Expression> predicate = null, Expression> selector = null); + T Max(Expression> selector, Expression> predicate = null); /// /// Gets the async max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - Task MaxAsync(Expression> predicate = null, Expression> selector = null); + Task MaxAsync(Expression> selector, Expression> predicate = null); /// /// Gets the min based on a predicate. /// - /// /// + /// /// decimal - T Min(Expression> predicate = null, Expression> selector = null); + T Min(Expression> selector, Expression> predicate = null); /// /// Gets the async min based on a predicate. /// - /// /// + /// /// decimal - Task MinAsync(Expression> predicate = null, Expression> selector = null); + Task MinAsync(Expression> selector, Expression> predicate = null); /// /// Gets the average based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - decimal Average (Expression> predicate = null, Expression> selector = null); + decimal Average(Expression> selector, Expression> predicate = null); /// - /// Gets the async average based on a predicate. - /// - /// - /// /// - /// decimal - Task AverageAsync(Expression> predicate = null, Expression> selector = null); + /// Gets the async average based on a predicate. + /// + /// + /// + /// /// + /// decimal + Task AverageAsync(Expression> selector, + Expression> predicate = null); /// /// Gets the sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - decimal Sum (Expression> predicate = null, Expression> selector = null); + decimal Sum(Expression> selector, Expression> predicate = null); /// /// Gets the async sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - Task SumAsync (Expression> predicate = null, Expression> selector = null); + Task SumAsync(Expression> selector, + Expression> predicate = null); /// /// Gets the Exists record based on a predicate. @@ -431,7 +447,7 @@ Task> GetAllAsync(Expression> sel /// The entity to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken)); + ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default); /// /// Inserts a range of entities asynchronously. @@ -446,7 +462,7 @@ Task> GetAllAsync(Expression> sel /// The entities to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default(CancellationToken)); + Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default); /// /// Updates the specified entity. @@ -511,7 +527,7 @@ Task> GetAllAsync(Expression> sel /// - /// Gets the based on a predicate, orderby delegate. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -523,15 +539,16 @@ Task> GetAllAsync(Expression> sel /// /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - Task> GetListAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false, - CancellationToken cancellationToken = default(CancellationToken)); + Task> GetListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false, + CancellationToken cancellationToken = default); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -540,11 +557,12 @@ Task> GetListAsync(Expression> predicate = nul /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - List GetList(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false); + List GetList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false); /// diff --git a/src/UnitOfWork/IRepositoryFactory.cs b/src/UnitOfWork/IRepositoryFactory.cs index c5f017f..885ef48 100644 --- a/src/UnitOfWork/IRepositoryFactory.cs +++ b/src/UnitOfWork/IRepositoryFactory.cs @@ -14,7 +14,7 @@ public interface IRepositoryFactory /// /// Gets the specified repository for the . /// - /// True if providing custom repositry + /// True if providing custom repository /// The type of the entity. /// An instance of type inherited from interface. IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class; diff --git a/src/UnitOfWork/IUnitOfWork.cs b/src/UnitOfWork/IUnitOfWork.cs index 4e7a1c0..06fb276 100644 --- a/src/UnitOfWork/IUnitOfWork.cs +++ b/src/UnitOfWork/IUnitOfWork.cs @@ -16,7 +16,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork /// /// Defines the interface(s) for unit of work. /// - public interface IUnitOfWork : IDisposable + public interface IUnitOfWork : IDisposable, IAsyncDisposable { /// /// Changes the database name. This require the databases in the same machine. NOTE: This only work for MySQL right now. @@ -30,7 +30,7 @@ public interface IUnitOfWork : IDisposable /// /// Gets the specified repository for the . /// - /// True if providing custom repositry + /// True if providing custom repository /// The type of the entity. /// An instance of type inherited from interface. IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class; @@ -38,7 +38,7 @@ public interface IUnitOfWork : IDisposable /// /// Saves all changes made in this context to the database. /// - /// True if sayve changes ensure auto record the change history. + /// True if save changes ensure auto record the change history. /// The number of state entries written to the database. int SaveChanges(bool ensureAutoHistory = false); @@ -79,11 +79,11 @@ public interface IUnitOfWork : IDisposable /// Uses TrakGrap Api to attach disconnected entities /// /// Root entity - /// Delegate to convert Object's State properities to Entities entry state. + /// Delegate to convert Object's State properties to Entities entry state. void TrackGraph(object rootEntity, Action callback); /// - /// Starts Databaselevel Transaction + /// Starts DatabaseLevel Transaction /// /// The IsolationLevel /// Transaction Context diff --git a/src/UnitOfWork/IUnitOfWorkOfT.cs b/src/UnitOfWork/IUnitOfWorkOfT.cs index 4ce2e68..cadb711 100644 --- a/src/UnitOfWork/IUnitOfWorkOfT.cs +++ b/src/UnitOfWork/IUnitOfWorkOfT.cs @@ -10,7 +10,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork /// /// Defines the interface(s) for generic unit of work. /// - public interface IUnitOfWork : IUnitOfWork where TContext : DbContext { + public interface IUnitOfWork : IUnitOfWork where TContext : DbContext { /// /// Gets the db context. /// diff --git a/src/UnitOfWork/Repository.cs b/src/UnitOfWork/Repository.cs index 802af90..2e35e3a 100644 --- a/src/UnitOfWork/Repository.cs +++ b/src/UnitOfWork/Repository.cs @@ -22,8 +22,8 @@ namespace Arch.EntityFrameworkCore.UnitOfWork /// The type of the entity. public class Repository : IRepository where TEntity : class { - protected readonly DbContext _dbContext; - protected readonly DbSet _dbSet; + protected readonly DbContext DbContext; + protected readonly DbSet DbSet; /// /// Initializes a new instance of the class. @@ -31,8 +31,8 @@ public class Repository : IRepository where TEntity : class /// The database context. public Repository(DbContext dbContext) { - _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext)); - _dbSet = _dbContext.Set(); + DbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext)); + DbSet = DbContext.Set(); } /// @@ -44,7 +44,7 @@ public Repository(DbContext dbContext) /// public virtual void ChangeTable(string table) { - if (_dbContext.Model.FindEntityType(typeof(TEntity)) is IConventionEntityType relational) + if (DbContext.Model.FindEntityType(typeof(TEntity)) is IConventionEntityType relational) { relational.SetTableName(table); } @@ -54,10 +54,7 @@ public virtual void ChangeTable(string table) /// Gets all entities. This method is not recommended /// /// The . - public IQueryable GetAll() - { - return _dbSet; - } + public IQueryable GetAll() => DbSet; /// /// Gets all entities. This method is not recommended @@ -74,7 +71,7 @@ public IQueryable GetAll( Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -100,10 +97,8 @@ public IQueryable GetAll( { return orderBy(query); } - else - { - return query; - } + + return query; } /// @@ -117,12 +112,13 @@ public IQueryable GetAll( /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - public IQueryable GetAll(Expression> selector, + public IQueryable GetAll( + Expression> selector, Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -148,14 +144,12 @@ public IQueryable GetAll(Expression> sel { return orderBy(query).Select(selector); } - else - { - return query.Select(selector); - } + + return query.Select(selector); } /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -166,15 +160,16 @@ public IQueryable GetAll(Expression> sel /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual IPagedList GetPagedList(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual IPagedList GetPagedList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -200,14 +195,12 @@ public virtual IPagedList GetPagedList(Expression> { return orderBy(query).ToPagedList(pageIndex, pageSize); } - else - { - return query.ToPagedList(pageIndex, pageSize); - } + + return query.ToPagedList(pageIndex, pageSize); } /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -221,16 +214,17 @@ public virtual IPagedList GetPagedList(Expression> /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual Task> GetPagedListAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false) + public virtual Task> GetPagedListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -256,14 +250,12 @@ public virtual Task> GetPagedListAsync(Expression - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -275,17 +267,18 @@ public virtual Task> GetPagedListAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual IPagedList GetPagedList(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual IPagedList GetPagedList( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + bool ignoreQueryFilters = false) where TResult : class { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -311,14 +304,12 @@ public virtual IPagedList GetPagedList(Expression - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -333,18 +324,19 @@ public virtual IPagedList GetPagedList(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual Task> GetPagedListAsync(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - int pageIndex = 0, - int pageSize = 20, - bool disableTracking = true, - CancellationToken cancellationToken = default(CancellationToken), - bool ignoreQueryFilters = false) + public virtual Task> GetPagedListAsync( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + int pageIndex = 0, + int pageSize = 20, + bool disableTracking = true, + CancellationToken cancellationToken = default, + bool ignoreQueryFilters = false) where TResult : class { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -370,14 +362,12 @@ public virtual Task> GetPagedListAsync(Expression - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method default no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -386,13 +376,14 @@ public virtual Task> GetPagedListAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual TEntity GetFirstOrDefault(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual TEntity GetFirstOrDefault( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -418,21 +409,20 @@ public virtual TEntity GetFirstOrDefault(Expression> predica { return orderBy(query).FirstOrDefault(); } - else - { - return query.FirstOrDefault(); - } + + return query.FirstOrDefault(); } /// - public virtual async Task GetFirstOrDefaultAsync(Expression> predicate = null, + public virtual async Task GetFirstOrDefaultAsync( + Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -458,14 +448,12 @@ public virtual async Task GetFirstOrDefaultAsync(Expression - /// Gets the first or default entity based on a predicate, orderby delegate and include delegate. This method default no-tracking query. + /// Gets the first or default entity based on a predicate, orderBy delegate and include delegate. This method default no-tracking query. /// /// The selector for projection. /// A function to test each element for a condition. @@ -475,14 +463,15 @@ public virtual async Task GetFirstOrDefaultAsync(ExpressionIgnore query filters /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual TResult GetFirstOrDefault(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual TResult GetFirstOrDefault( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -508,20 +497,19 @@ public virtual TResult GetFirstOrDefault(Expression - public virtual async Task GetFirstOrDefaultAsync(Expression> selector, - Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, bool ignoreQueryFilters = false) + public virtual async Task GetFirstOrDefaultAsync( + Expression> selector, + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -547,10 +535,8 @@ public virtual async Task GetFirstOrDefaultAsync(Expression @@ -559,21 +545,24 @@ public virtual async Task GetFirstOrDefaultAsync(ExpressionThe raw SQL. /// The parameters. /// An that contains elements that satisfy the condition specified by raw SQL. - public virtual IQueryable FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters); + public virtual IQueryable FromSql(string sql, params object[] parameters) + => DbSet.FromSqlRaw(sql, parameters); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. /// /// The values of the primary key for the entity to be found. /// The found entity or null. - public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues); + public virtual TEntity Find(params object[] keyValues) + => DbSet.Find(keyValues); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. /// /// The values of the primary key for the entity to be found. /// A that represents the asynchronous insert operation. - public virtual ValueTask FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues); + public virtual ValueTask FindAsync(params object[] keyValues) + => DbSet.FindAsync(keyValues); /// /// Finds an entity with the given primary key values. If found, is attached to the context and returned. If no entity is found, then null is returned. @@ -581,240 +570,162 @@ public virtual async Task GetFirstOrDefaultAsync(ExpressionThe values of the primary key for the entity to be found. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous find operation. The task result contains the found entity or null. - public virtual ValueTask FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken); + public virtual ValueTask FindAsync(object[] keyValues, CancellationToken cancellationToken) => DbSet.FindAsync(keyValues, cancellationToken); /// /// Gets the count based on a predicate. /// /// /// - public virtual int Count(Expression> predicate = null) - { - if (predicate == null) - { - return _dbSet.Count(); - } - else - { - return _dbSet.Count(predicate); - } - } + public virtual int Count(Expression> predicate = null) + => predicate == null ? DbSet.Count() : DbSet.Count(predicate); /// /// Gets async the count based on a predicate. /// /// /// - public virtual async Task CountAsync(Expression> predicate = null) - { - if (predicate == null) - { - return await _dbSet.CountAsync(); - } - else - { - return await _dbSet.CountAsync(predicate); - } - } + public virtual async Task CountAsync(Expression> predicate = null) + => predicate == null ? await DbSet.CountAsync() : await DbSet.CountAsync(predicate); /// /// Gets the long count based on a predicate. /// /// /// - public virtual long LongCount(Expression> predicate = null) - { - if (predicate == null) - { - return _dbSet.LongCount(); - } - else - { - return _dbSet.LongCount(predicate); - } - } + public virtual long LongCount(Expression> predicate = null) + => predicate == null ? DbSet.LongCount() : DbSet.LongCount(predicate); /// /// Gets async the long count based on a predicate. /// /// /// - public virtual async Task LongCountAsync(Expression> predicate = null) - { - if (predicate == null) - { - return await _dbSet.LongCountAsync(); - } - else - { - return await _dbSet.LongCountAsync(predicate); - } - } + public virtual async Task LongCountAsync(Expression> predicate = null) + => predicate == null ? await DbSet.LongCountAsync() : await DbSet.LongCountAsync(predicate); /// /// Gets the max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual T Max(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Max(selector); - else - return _dbSet.Where(predicate).Max(selector); - } + public virtual T Max(Expression> selector, Expression> predicate = null) + => predicate == null ? DbSet.Max(selector) : DbSet.Where(predicate).Max(selector); /// /// Gets the async max based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task MaxAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.MaxAsync(selector); - else - return await _dbSet.Where(predicate).MaxAsync(selector); - } + public virtual async Task MaxAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.MaxAsync(selector) : await DbSet.Where(predicate).MaxAsync(selector); /// /// Gets the min based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual T Min(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Min(selector); - else - return _dbSet.Where(predicate).Min(selector); - } + public virtual T Min(Expression> selector, Expression> predicate = null) + => predicate == null ? DbSet.Min(selector) : DbSet.Where(predicate).Min(selector); /// /// Gets the async min based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task MinAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.MinAsync(selector); - else - return await _dbSet.Where(predicate).MinAsync(selector); - } + public virtual async Task MinAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.MinAsync(selector) : await DbSet.Where(predicate).MinAsync(selector); /// /// Gets the average based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual decimal Average(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Average(selector); - else - return _dbSet.Where(predicate).Average(selector); - } + public virtual decimal Average(Expression> selector, + Expression> predicate = null) + => predicate == null ? DbSet.Average(selector) : DbSet.Where(predicate).Average(selector); /// /// Gets the async average based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task AverageAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.AverageAsync(selector); - else - return await _dbSet.Where(predicate).AverageAsync(selector); - } + public virtual async Task AverageAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.AverageAsync(selector) : await DbSet.Where(predicate).AverageAsync(selector); /// /// Gets the sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual decimal Sum(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return _dbSet.Sum(selector); - else - return _dbSet.Where(predicate).Sum(selector); - } + public virtual decimal Sum(Expression> selector, + Expression> predicate = null) + => predicate == null ? DbSet.Sum(selector) : DbSet.Where(predicate).Sum(selector); /// /// Gets the async sum based on a predicate. /// + /// /// - /// /// + /// /// /// decimal - public virtual async Task SumAsync(Expression> predicate = null, Expression> selector = null) - { - if (predicate == null) - return await _dbSet.SumAsync(selector); - else - return await _dbSet.Where(predicate).SumAsync(selector); - } + public virtual async Task SumAsync(Expression> selector, + Expression> predicate = null) + => predicate == null ? await DbSet.SumAsync(selector) : await DbSet.Where(predicate).SumAsync(selector); /// /// Gets the exists based on a predicate. /// /// /// - public bool Exists(Expression> selector = null) - { - if (selector == null) - { - return _dbSet.Any(); - } - else - { - return _dbSet.Any(selector); - } - } + public bool Exists(Expression> selector = null) + => selector == null ? DbSet.Any() : DbSet.Any(selector); + /// /// Gets the async exists based on a predicate. /// /// /// - public async Task ExistsAsync(Expression> selector = null) - { - if (selector == null) - { - return await _dbSet.AnyAsync(); - } - else - { - return await _dbSet.AnyAsync(selector); - } - } + public async Task ExistsAsync(Expression> selector = null) + => selector == null ? await DbSet.AnyAsync() : await DbSet.AnyAsync(selector); + /// /// Inserts a new entity synchronously. /// /// The entity to insert. - public virtual TEntity Insert(TEntity entity) - { - return _dbSet.Add(entity).Entity; - } + public virtual TEntity Insert(TEntity entity) + => DbSet.Add(entity).Entity; /// /// Inserts a range of entities synchronously. /// /// The entities to insert. - public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities); + public virtual void Insert(params TEntity[] entities) + => DbSet.AddRange(entities); /// /// Inserts a range of entities synchronously. /// /// The entities to insert. - public virtual void Insert(IEnumerable entities) => _dbSet.AddRange(entities); + public virtual void Insert(IEnumerable entities) + => DbSet.AddRange(entities); /// /// Inserts a new entity asynchronously. @@ -822,24 +733,21 @@ public virtual TEntity Insert(TEntity entity) /// The entity to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - public virtual ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken)) - //public virtual Task InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken)) - { - return _dbSet.AddAsync(entity, cancellationToken); - - // Shadow properties? - //var property = _dbContext.Entry(entity).Property("Created"); - //if (property != null) { - //property.CurrentValue = DateTime.Now; - //} - } - + public virtual ValueTask> InsertAsync(TEntity entity, CancellationToken cancellationToken = default) + => DbSet.AddAsync(entity, cancellationToken); + + // Shadow properties? + //var property = _dbContext.Entry(entity).Property("Created"); + //if (property != null) { + //property.CurrentValue = DateTime.Now; + //} /// /// Inserts a range of entities asynchronously. /// /// The entities to insert. /// A that represents the asynchronous insert operation. - public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities); + public virtual Task InsertAsync(params TEntity[] entities) + => DbSet.AddRangeAsync(entities); /// /// Inserts a range of entities asynchronously. @@ -847,44 +755,43 @@ public virtual TEntity Insert(TEntity entity) /// The entities to insert. /// A to observe while waiting for the task to complete. /// A that represents the asynchronous insert operation. - public virtual Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken); + public virtual Task InsertAsync(IEnumerable entities, CancellationToken cancellationToken = default) + => DbSet.AddRangeAsync(entities, cancellationToken); /// /// Updates the specified entity. /// /// The entity. - public virtual void Update(TEntity entity) - { - _dbSet.Update(entity); - } + public virtual void Update(TEntity entity) + => DbSet.Update(entity); /// /// Updates the specified entity. /// /// The entity. - public virtual void UpdateAsync(TEntity entity) - { - _dbSet.Update(entity); - - } + public virtual void UpdateAsync(TEntity entity) + => DbSet.Update(entity); /// /// Updates the specified entities. /// /// The entities. - public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities); + public virtual void Update(params TEntity[] entities) + => DbSet.UpdateRange(entities); /// /// Updates the specified entities. /// /// The entities. - public virtual void Update(IEnumerable entities) => _dbSet.UpdateRange(entities); + public virtual void Update(IEnumerable entities) + => DbSet.UpdateRange(entities); /// /// Deletes the specified entity. /// /// The entity to delete. - public virtual void Delete(TEntity entity) => _dbSet.Remove(entity); + public virtual void Delete(TEntity entity) + => DbSet.Remove(entity); /// /// Deletes the entity by the specified primary key. @@ -894,17 +801,17 @@ public virtual void Delete(object id) { // using a stub entity to mark for deletion var typeInfo = typeof(TEntity).GetTypeInfo(); - var key = _dbContext.Model.FindEntityType(typeInfo).FindPrimaryKey().Properties.FirstOrDefault(); + var key = DbContext.Model.FindEntityType(typeInfo).FindPrimaryKey().Properties.FirstOrDefault(); var property = typeInfo.GetProperty(key?.Name); if (property != null) { var entity = Activator.CreateInstance(); property.SetValue(entity, id); - _dbContext.Entry(entity).State = EntityState.Deleted; + DbContext.Entry(entity).State = EntityState.Deleted; } else { - var entity = _dbSet.Find(id); + var entity = DbSet.Find(id); if (entity != null) { Delete(entity); @@ -916,22 +823,23 @@ public virtual void Delete(object id) /// Deletes the specified entities. /// /// The entities. - public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities); + public virtual void Delete(params TEntity[] entities) + => DbSet.RemoveRange(entities); /// /// Deletes the specified entities. /// /// The entities. - public virtual void Delete(IEnumerable entities) => _dbSet.RemoveRange(entities); + public virtual void Delete(IEnumerable entities) + => DbSet.RemoveRange(entities); /// /// Gets all entities. This method is not recommended /// /// The . - public async Task> GetAllAsync() - { - return await _dbSet.ToListAsync(); - } + public async Task> GetAllAsync() + => await DbSet.ToListAsync(); + /// /// Gets all entities. This method is not recommended /// @@ -942,12 +850,13 @@ public async Task> GetAllAsync() /// Ignore query filters /// An that contains elements that satisfy the condition specified by . /// Ex: This method defaults to a read-only, no-tracking query. - public async Task> GetAllAsync(Expression> predicate = null, + public async Task> GetAllAsync( + Expression> predicate = null, Func, IOrderedQueryable> orderBy = null, Func, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -981,14 +890,14 @@ public async Task> GetAllAsync(Expression> pr private bool ExistsUpdateTimestamp(TEntity entity, out TEntity entityForUpdate) { - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); object[] objArr = key.Properties.Select(q => entity.GetType().GetProperty(q.Name).GetValue(entity, null)).ToArray(); - TEntity obj = _dbSet.Find(objArr); + TEntity obj = DbSet.Find(objArr); if (obj != null && obj.GetType().GetProperty("Timestamp") != null) { entity.GetType().GetProperty("Timestamp").SetValue(entity, obj.GetType().GetProperty("Timestamp").GetValue(obj, null)); - _dbContext.Entry(obj).State = EntityState.Detached; + DbContext.Entry(obj).State = EntityState.Detached; } entityForUpdate = entity; @@ -997,33 +906,29 @@ private bool ExistsUpdateTimestamp(TEntity entity, out TEntity entityForUpdate) } public virtual bool Exists(TEntity entity) { - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); object[] objArr = key.Properties.Select(q => entity.GetType().GetProperty(q.Name).GetValue(entity, null)).ToArray(); - TEntity obj = _dbSet.Find(objArr); - if (obj != null) _dbContext.Entry(obj).State = EntityState.Detached; + TEntity obj = DbSet.Find(objArr); + if (obj != null) DbContext.Entry(obj).State = EntityState.Detached; return obj != null; } public virtual void InsertOrUpdate(TEntity entity) { - TEntity entityForUpdate = null; - if (ExistsUpdateTimestamp(entity, out entityForUpdate)) { - //if (Exists(entity)) { + if (ExistsUpdateTimestamp(entity, out var entityForUpdate)) { Update(entityForUpdate); - } else { + } + else { Insert(entity); } } - public virtual void InsertOrUpdate(IEnumerable entities) - { - //foreach (TEntity entity in entities) InsertOrUpdate(entity); - _dbContext.BulkInsertOrUpdate(entities.ToList()); - } + public virtual void InsertOrUpdate(IEnumerable entities) + => DbContext.BulkInsertOrUpdate(entities.ToList()); /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -1032,13 +937,14 @@ public virtual void InsertOrUpdate(IEnumerable entities) /// /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual List GetList(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false) + public virtual List GetList( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { query = query.AsNoTracking(); @@ -1063,14 +969,12 @@ public virtual List GetList(Expression> predicate = { return orderBy(query).ToList(); } - else - { - return query.ToList(); - } + + return query.ToList(); } /// - /// Gets the based on a predicate, orderby delegate and page information. This method default no-tracking query. + /// Gets the based on a predicate, orderBy delegate and page information. This method default no-tracking query. /// /// A function to test each element for a condition. /// A function to order elements. @@ -1082,14 +986,15 @@ public virtual List GetList(Expression> predicate = /// /// An that contains elements that satisfy the condition specified by . /// This method default no-tracking query. - public virtual Task> GetListAsync(Expression> predicate = null, - Func, IOrderedQueryable> orderBy = null, - Func, IIncludableQueryable> include = null, - bool disableTracking = true, - bool ignoreQueryFilters = false, - CancellationToken cancellationToken = default(CancellationToken)) + public virtual Task> GetListAsync( + Expression> predicate = null, + Func, IOrderedQueryable> orderBy = null, + Func, IIncludableQueryable> include = null, + bool disableTracking = true, + bool ignoreQueryFilters = false, + CancellationToken cancellationToken = default) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { query = query.AsNoTracking(); @@ -1114,10 +1019,8 @@ public virtual Task> GetListAsync(Expression> { return orderBy(query).ToListAsync(cancellationToken); } - else - { - return query.ToListAsync(cancellationToken); - } + + return query.ToListAsync(cancellationToken); } @@ -1128,14 +1031,14 @@ public virtual Task> GetListAsync(Expression> /// The found entity or null. public virtual TEntity GetNextById(params object[] keyValues) { - TEntity res = _dbSet.Find(IncrementKey(keyValues)); + TEntity res = DbSet.Find(IncrementKey(keyValues)); if (res != null) { return res; } //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); //var ordByExp = GetOrderBy(keyColums[0],"asc"); @@ -1175,14 +1078,14 @@ public virtual TEntity GetNextById(params object[] keyValues) /// The found entity or null. public virtual Task GetNextByIdAsync(params object[] keyValues) { - TEntity res = _dbSet.Find(IncrementKey(keyValues)); + TEntity res = DbSet.Find(IncrementKey(keyValues)); if (res != null) { return Task.Factory.StartNew(() => res); } //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); //var ordByExp = GetOrderBy(keyColums[0],"asc"); @@ -1220,14 +1123,14 @@ public virtual Task GetNextByIdAsync(params object[] keyValues) /// The found entity or null. public virtual TEntity GetPreviousById(params object[] keyValues) { - TEntity res = _dbSet.Find(DecrementKey(keyValues)); + TEntity res = DbSet.Find(DecrementKey(keyValues)); if (res != null) { return res; } //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); //var ordByExp = GetOrderBy(keyColums[0],"asc"); @@ -1267,14 +1170,14 @@ public virtual TEntity GetPreviousById(params object[] keyValues) /// The found entity or null. public virtual Task GetPreviousByIdAsync(params object[] keyValues) { - TEntity res = _dbSet.Find(DecrementKey(keyValues)); + TEntity res = DbSet.Find(DecrementKey(keyValues)); if (res != null) { return Task.Factory.StartNew(() => res); } //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); //var ordByExp = GetOrderBy(keyColums[0],"asc"); @@ -1314,7 +1217,7 @@ public virtual Task GetPreviousByIdAsync(params object[] keyValues) public virtual TEntity GetFirst() { //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); var ordByExp = GetOrderByExpression(keyColums); @@ -1323,7 +1226,7 @@ public virtual TEntity GetFirst() if (lstObjs != null && lstObjs.Count > 0) { - return lstObjs.FirstOrDefault(); + return lstObjs.FirstOrDefault(); } else { @@ -1338,7 +1241,7 @@ public virtual TEntity GetFirst() public virtual Task GetFirstAsync() { //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); var ordByExp = GetOrderByExpression(keyColums); @@ -1347,7 +1250,7 @@ public virtual Task GetFirstAsync() if (lstObjs != null && lstObjs.Count > 0) { - return Task.Factory.StartNew(() => lstObjs.FirstOrDefault()); + return Task.Factory.StartNew(() => lstObjs.FirstOrDefault()); } else { @@ -1362,7 +1265,7 @@ public virtual Task GetFirstAsync() public virtual TEntity GetLast() { //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); var ordByExp = GetOrderByExpression(keyColums,true); @@ -1371,7 +1274,7 @@ public virtual TEntity GetLast() if (lstObjs != null && lstObjs.Count > 0) { - return lstObjs.FirstOrDefault(); + return lstObjs.FirstOrDefault(); } else { @@ -1386,7 +1289,7 @@ public virtual TEntity GetLast() public virtual Task GetLastAsync() { //No Result Found. So Order the Entity with key column and select next Entity - IEntityType entityType = _dbContext.Model.FindEntityType(typeof(TEntity).ToString()); + IEntityType entityType = DbContext.Model.FindEntityType(typeof(TEntity).ToString()); IKey key = entityType.FindPrimaryKey(); List keyColums = key.Properties.Select(q => q.Name).ToList(); var ordByExp = GetOrderByExpression(keyColums,true); @@ -1395,7 +1298,7 @@ public virtual Task GetLastAsync() if (lstObjs != null && lstObjs.Count > 0) { - return Task.Factory.StartNew(() => lstObjs.FirstOrDefault()); + return Task.Factory.StartNew(() => lstObjs.FirstOrDefault()); } else { @@ -1458,7 +1361,7 @@ public static Func, IOrderedQueryable> GetOrderBy(string ord ParameterExpression argQueryable = Expression.Parameter(typeQueryable, "p"); var outerExpression = Expression.Lambda(argQueryable, argQueryable); string[] props = orderColumn.Split('.'); - IQueryable query = new List().AsQueryable(); + IQueryable query = new List().AsQueryable(); Type type = typeof(T); ParameterExpression arg = Expression.Parameter(type, "x"); @@ -1693,96 +1596,6 @@ public static Expression> GetWhereConditionExpression(IKey key, } #endregion - /* - #region Logic33 - public static IEnumerable BuildOrderBys( - this IEnumerable source, - IEnumerable properties) - { - if (properties == null || properties.Count() == 0) return source; - - var typeOfT = typeof(T); - - Type t = typeOfT; - - IOrderedEnumerable result = null; - var thenBy = false; - - foreach (var item in properties) - { - var oExpr = Expression.Parameter(typeOfT, "o"); - - MemberExpression prop = GetMemberExpression(oExpr, item); - var propertyInfo = (PropertyInfo)prop.Member; - var propertyType = propertyInfo.PropertyType; - var isAscending = true; - - if (thenBy) - { - var prevExpr = Expression.Parameter(typeof(IOrderedEnumerable), "prevExpr"); - var expr1 = Expression.Lambda, IOrderedEnumerable>>( - Expression.Call( - (isAscending ? thenByMethod : thenByDescendingMethod).MakeGenericMethod(typeOfT, propertyType), - prevExpr, - Expression.Lambda( - typeof(Func<,>).MakeGenericType(typeOfT, propertyType), - Expression.MakeMemberAccess(oExpr, propertyInfo), - oExpr) - ), - prevExpr) - .Compile(); - result = expr1(result); - } - else - { - var prevExpr = Expression.Parameter(typeof(IEnumerable), "prevExpr"); - var expr1 = Expression.Lambda, IOrderedEnumerable>>( - Expression.Call( - (isAscending ? orderByMethod : orderByDescendingMethod).MakeGenericMethod(typeOfT, propertyType), - prevExpr, - Expression.Lambda( - typeof(Func<,>).MakeGenericType(typeOfT, propertyType), - Expression.MakeMemberAccess(oExpr, propertyInfo), - oExpr) - ), - prevExpr) - .Compile(); - result = expr1(source); - thenBy = true; - } - } - return result; - } - - - private static MethodInfo orderByMethod = - MethodOf(() => Enumerable.OrderBy(default(IEnumerable), default(Func))) - .GetGenericMethodDefinition(); - - private static MethodInfo orderByDescendingMethod = - MethodOf(() => Enumerable.OrderByDescending(default(IEnumerable), default(Func))) - .GetGenericMethodDefinition(); - - private static MethodInfo thenByMethod = - MethodOf(() => Enumerable.ThenBy(default(IOrderedEnumerable), default(Func))) - .GetGenericMethodDefinition(); - - private static MethodInfo thenByDescendingMethod = - MethodOf(() => Enumerable.ThenByDescending(default(IOrderedEnumerable), default(Func))) - .GetGenericMethodDefinition(); - - public static MethodInfo MethodOf(Expression> method) - { - MethodCallExpression mce = (MethodCallExpression)method.Body; - MethodInfo mi = mce.Method; - return mi; - } - - - #endregion - - */ - /// /// Gets all entities. This method is not recommended /// @@ -1800,7 +1613,7 @@ public async Task> GetAllAsync(Expression, IIncludableQueryable> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { - IQueryable query = _dbSet; + IQueryable query = DbSet; if (disableTracking) { @@ -1839,12 +1652,9 @@ public async Task> GetAllAsync(Expression /// The entity. /// /// The entity state. - public void ChangeEntityState(TEntity entity, EntityState state) - { - _dbContext.Entry(entity).State = state; - } + public void ChangeEntityState(TEntity entity, EntityState state) => DbContext.Entry(entity).State = state; - ValueTask IRepository.FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues); - ValueTask IRepository.FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken); + ValueTask IRepository.FindAsync(params object[] keyValues) => DbSet.FindAsync(keyValues); + ValueTask IRepository.FindAsync(object[] keyValues, CancellationToken cancellationToken) => DbSet.FindAsync(keyValues, cancellationToken); } } diff --git a/src/UnitOfWork/UnitOfWork.cs b/src/UnitOfWork/UnitOfWork.cs index 07114df..87da947 100644 --- a/src/UnitOfWork/UnitOfWork.cs +++ b/src/UnitOfWork/UnitOfWork.cs @@ -20,26 +20,22 @@ namespace Arch.EntityFrameworkCore.UnitOfWork /// Represents the default implementation of the and interface. /// /// The type of the db context. - public class UnitOfWork : IRepositoryFactory, IUnitOfWork, IUnitOfWork where TContext : DbContext + public class UnitOfWork : IRepositoryFactory, IUnitOfWork where TContext : DbContext { - private readonly TContext _context; - private bool disposed = false; - private Dictionary repositories; + private bool _disposed; + private Dictionary _repositories; /// /// Initializes a new instance of the class. /// /// The context. - public UnitOfWork(TContext context) - { - _context = context ?? throw new ArgumentNullException(nameof(context)); - } + public UnitOfWork(TContext context) => DbContext = context ?? throw new ArgumentNullException(nameof(context)); /// /// Gets the db context. /// /// The instance of type . - public TContext DbContext => _context; + public TContext DbContext { get; } /// /// Changes the database name. This require the databases in the same machine. NOTE: This only work for MySQL right now. @@ -50,7 +46,7 @@ public UnitOfWork(TContext context) /// public void ChangeDatabase(string database) { - var connection = _context.Database.GetDbConnection(); + var connection = DbContext.Database.GetDbConnection(); if (connection.State.HasFlag(ConnectionState.Open)) { connection.ChangeDatabase(database); @@ -62,7 +58,7 @@ public void ChangeDatabase(string database) } // Following code only working for mysql. - var items = _context.Model.GetEntityTypes(); + var items = DbContext.Model.GetEntityTypes(); foreach (var item in items) { if (item is IConventionEntityType entityType) @@ -75,20 +71,20 @@ public void ChangeDatabase(string database) /// /// Gets the specified repository for the . /// - /// True if providing custom repositry + /// True if providing custom repository /// The type of the entity. /// An instance of type inherited from interface. public IRepository GetRepository(bool hasCustomRepository = false) where TEntity : class { - if (repositories == null) + if (_repositories == null) { - repositories = new Dictionary(); + _repositories = new Dictionary(); } - // what's the best way to support custom reposity? + // what's the best way to support custom repository? if (hasCustomRepository) { - var customRepo = _context.GetService>(); + var customRepo = DbContext.GetService>(); if (customRepo != null) { return customRepo; @@ -96,12 +92,12 @@ public IRepository GetRepository(bool hasCustomRepository = fa } var type = typeof(TEntity); - if (!repositories.ContainsKey(type)) + if (!_repositories.ContainsKey(type)) { - repositories[type] = new Repository(_context); + _repositories[type] = new Repository(DbContext); } - return (IRepository)repositories[type]; + return (IRepository)_repositories[type]; } /// @@ -110,7 +106,7 @@ public IRepository GetRepository(bool hasCustomRepository = fa /// The raw SQL. /// The parameters. /// The number of state entities written to database. - public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters); + public int ExecuteSqlCommand(string sql, params object[] parameters) => DbContext.Database.ExecuteSqlRaw(sql, parameters); /// /// Executes the specified raw SQL command. @@ -120,7 +116,7 @@ public IRepository GetRepository(bool hasCustomRepository = fa /// The DataTable. public DataTable ExecuteDtSqlCommand(string sql, params object[] parameters) { - SqlConnection conn = (SqlConnection) _context.Database.GetDbConnection(); + SqlConnection conn = (SqlConnection) DbContext.Database.GetDbConnection(); SqlCommand cmd = new SqlCommand(sql, conn); cmd.CommandTimeout = 0; @@ -153,17 +149,14 @@ public DataTable ExecuteDtSqlCommand(string sql, params object[] parameters) /// The raw SQL. /// The parameters. /// An that contains elements that satisfy the condition specified by raw SQL. - public IQueryable FromSql(string sql, params object[] parameters) where TEntity : class => _context.Set().FromSqlRaw(sql, parameters); + public IQueryable FromSql(string sql, params object[] parameters) where TEntity : class => DbContext.Set().FromSqlRaw(sql, parameters); /// /// Starts Databaselevel Transaction /// /// The IsolationLevel /// Transaction Context - public IDbContextTransaction BeginTransaction(System.Data.IsolationLevel isolation = System.Data.IsolationLevel.ReadCommitted) - { - return _context.Database.BeginTransaction(isolation); - } + public IDbContextTransaction BeginTransaction(System.Data.IsolationLevel isolation = System.Data.IsolationLevel.ReadCommitted) => DbContext.Database.BeginTransaction(isolation); /// /// Saves all changes made in this context to the database. @@ -174,10 +167,10 @@ public int SaveChanges(bool ensureAutoHistory = false) { if (ensureAutoHistory) { - _context.EnsureAutoHistory(); + DbContext.EnsureAutoHistory(); } - return _context.SaveChanges(); + return DbContext.SaveChanges(); } /// @@ -189,10 +182,10 @@ public async Task SaveChangesAsync(bool ensureAutoHistory = false) { if (ensureAutoHistory) { - _context.EnsureAutoHistory(); + DbContext.EnsureAutoHistory(); } - return await _context.SaveChangesAsync(); + return await DbContext.SaveChangesAsync(); } /// @@ -203,20 +196,18 @@ public async Task SaveChangesAsync(bool ensureAutoHistory = false) /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. public async Task SaveChangesAsync(bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks) { - using (var ts = new TransactionScope(TransactionScopeAsyncFlowOption.Enabled)) + using var ts = new TransactionScope(TransactionScopeAsyncFlowOption.Enabled); + var count = 0; + foreach (var unitOfWork in unitOfWorks) { - var count = 0; - foreach (var unitOfWork in unitOfWorks) - { - count += await unitOfWork.SaveChangesAsync(ensureAutoHistory).ConfigureAwait(false); - } + count += await unitOfWork.SaveChangesAsync(ensureAutoHistory).ConfigureAwait(false); + } - count += await SaveChangesAsync(ensureAutoHistory); + count += await SaveChangesAsync(ensureAutoHistory); - ts.Complete(); + ts.Complete(); - return count; - } + return count; } @@ -229,23 +220,23 @@ public async Task SaveChangesAsync(bool ensureAutoHistory = false, params I /// A that represents the asynchronous save operation. The task result contains the number of state entities written to database. public async Task SaveChangesAsync(Transaction transaction, bool ensureAutoHistory = false, params IUnitOfWork[] unitOfWorks) { - - using (var ts = new TransactionScope(transaction)) + using var ts = new TransactionScope(transaction); + var count = 0; + foreach (var unitOfWork in unitOfWorks) { - var count = 0; - foreach (var unitOfWork in unitOfWorks) - { - count += await unitOfWork.SaveChangesAsync(ensureAutoHistory); - } + count += await unitOfWork.SaveChangesAsync(ensureAutoHistory); + } - count += await SaveChangesAsync(ensureAutoHistory); + count += await SaveChangesAsync(ensureAutoHistory); - ts.Complete(); + ts.Complete(); - return count; - } + return count; } + + public void TrackGraph(object rootEntity, Action callback) => DbContext.ChangeTracker.TrackGraph(rootEntity, callback); + IDbContextTransaction IUnitOfWork.BeginTransaction(System.Data.IsolationLevel isolation) => DbContext.Database.BeginTransaction(isolation); /// @@ -264,33 +255,50 @@ public void Dispose() /// The disposing. protected virtual void Dispose(bool disposing) { - if (!disposed) + if (!_disposed) { if (disposing) { // clear repositories - if (repositories != null) - { - repositories.Clear(); - } + _repositories?.Clear(); // dispose the db context. - _context.Dispose(); + DbContext.Dispose(); } } - disposed = true; + _disposed = true; } - - public void TrackGraph(object rootEntity, Action callback) + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + public async ValueTask DisposeAsync() { - _context.ChangeTracker.TrackGraph(rootEntity, callback); + await DisposeAsync(true); + + GC.SuppressFinalize(this); } - IDbContextTransaction IUnitOfWork.BeginTransaction(System.Data.IsolationLevel isolation) + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + /// The disposing. + protected virtual async ValueTask DisposeAsync(bool disposing) { - return _context.Database.BeginTransaction(isolation); - } + if (!_disposed) + { + if (disposing) + { + // clear repositories + _repositories?.Clear(); + + // dispose the db context. + await DbContext.DisposeAsync(); + } + } + _disposed = true; + } } } diff --git a/src/UnitOfWork/UnitOfWork.csproj b/src/UnitOfWork/UnitOfWork.csproj index cf485ae..d5880ec 100644 --- a/src/UnitOfWork/UnitOfWork.csproj +++ b/src/UnitOfWork/UnitOfWork.csproj @@ -3,10 +3,11 @@ A plugin for Microsoft.EntityFrameworkCore to support repository, unit of work patterns, and multiple database with distributed transaction supported. 3.1.0 rigofunc;rigofunc@outlook.com; - net5.0 + net6.0 $(NoWarn);CS1591 true true + Arch.EntityFrameworkCore.UnitOfWork Microsoft.EntityFrameworkCore.UnitOfWork Microsoft.EntityFrameworkCore.UnitOfWork Entity Framework Core;entity-framework-core;EF;Data;O/RM;unitofwork;Unit Of Work;unit-of-work @@ -19,10 +20,10 @@ snupkg - - - - + + + + diff --git a/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs b/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs index 433ea00..6335145 100644 --- a/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs +++ b/src/UnitOfWork/UnitOfWorkServiceCollectionExtensions.cs @@ -22,7 +22,7 @@ public static class UnitOfWorkServiceCollectionExtensions public static IServiceCollection AddUnitOfWork(this IServiceCollection services) where TContext : DbContext { services.AddScoped>(); - // Following has a issue: IUnitOfWork cannot support multiple dbcontext/database, + // Following has a issue: IUnitOfWork cannot support multiple DbContext/Database, // that means cannot call AddUnitOfWork multiple times. // Solution: check IUnitOfWork whether or null services.AddScoped>(); @@ -104,7 +104,7 @@ public static IServiceCollection AddUnitOfWork. /// /// The type of the entity. - /// The type of the custom repositry. + /// The type of the custom repository. /// The to add services to. /// The same service collection so that multiple calls can be chained. public static IServiceCollection AddCustomRepository(this IServiceCollection services) diff --git a/test/UnitOfWork.Tests/Entities/City.cs b/test/UnitOfWork.Tests/Entities/City.cs index f632058..10539d2 100644 --- a/test/UnitOfWork.Tests/Entities/City.cs +++ b/test/UnitOfWork.Tests/Entities/City.cs @@ -3,7 +3,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class City + public record City { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/Entities/Country.cs b/test/UnitOfWork.Tests/Entities/Country.cs index d4d03d8..f23d2b0 100644 --- a/test/UnitOfWork.Tests/Entities/Country.cs +++ b/test/UnitOfWork.Tests/Entities/Country.cs @@ -2,7 +2,7 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class Country + public record Country { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/Entities/Customer.cs b/test/UnitOfWork.Tests/Entities/Customer.cs index d2188d5..094881a 100644 --- a/test/UnitOfWork.Tests/Entities/Customer.cs +++ b/test/UnitOfWork.Tests/Entities/Customer.cs @@ -1,6 +1,6 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class Customer + public record Customer { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/Entities/Town.cs b/test/UnitOfWork.Tests/Entities/Town.cs index aca070a..16ee910 100644 --- a/test/UnitOfWork.Tests/Entities/Town.cs +++ b/test/UnitOfWork.Tests/Entities/Town.cs @@ -1,6 +1,6 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests.Entities { - public class Town + public record Town { public int Id { get; set; } public string Name { get; set; } diff --git a/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs b/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs index 5554b51..ecebdec 100644 --- a/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs +++ b/test/UnitOfWork.Tests/IQueryablePageListExtensionsTests.cs @@ -13,40 +13,35 @@ public class IQueryablePageListExtensionsTests [Fact] public async Task ToPagedListAsyncTest() { - using (var db = new InMemoryContext()) - { - var testItems = TestItems(); - await db.AddRangeAsync(testItems); - db.SaveChanges(); + await using var db = new InMemoryContext(); + var testItems = TestItems(); + await db.AddRangeAsync(testItems); + await db.SaveChangesAsync(); - var items = db.Customers.Where(t => t.Age > 1); + var items = db.Customers.Where(t => t.Age > 1); - var page = await items.ToPagedListAsync(1, 2); - Assert.NotNull(page); + var page = await items.ToPagedListAsync(1, 2); + Assert.NotNull(page); - Assert.Equal(4, page.TotalCount); - Assert.Equal(2, page.Items.Count); - Assert.Equal("E", page.Items[0].Name); + Assert.Equal(4, page.TotalCount); + Assert.Equal(2, page.Items.Count); + Assert.Equal("E", page.Items[0].Name); - page = await items.ToPagedListAsync(0, 2); - Assert.NotNull(page); - Assert.Equal(4, page.TotalCount); - Assert.Equal(2, page.Items.Count); - Assert.Equal("C", page.Items[0].Name); - } + page = await items.ToPagedListAsync(0, 2); + Assert.NotNull(page); + Assert.Equal(4, page.TotalCount); + Assert.Equal(2, page.Items.Count); + Assert.Equal("C", page.Items[0].Name); } - public List TestItems() - { - return new List() + private static IEnumerable TestItems() => new List() { - new Customer(){Name="A", Age=1}, - new Customer(){Name="B", Age=1}, - new Customer(){Name="C", Age=2}, - new Customer(){Name="D", Age=3}, - new Customer(){Name="E", Age=4}, - new Customer(){Name="F", Age=5}, + new(){Name="A", Age=1}, + new(){Name="B", Age=1}, + new(){Name="C", Age=2}, + new(){Name="D", Age=3}, + new(){Name="E", Age=4}, + new(){Name="F", Age=5}, }; - } } } diff --git a/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs b/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs index 504a3d2..e95ee93 100644 --- a/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs +++ b/test/UnitOfWork.Tests/IRepositoryGetPagedListTest.cs @@ -8,23 +8,23 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests { public class IRepositoryGetPagedListTest { - private static readonly InMemoryContext db; + private static readonly InMemoryContext Db; static IRepositoryGetPagedListTest() { - db = new InMemoryContext(); + Db = new InMemoryContext(); - db.AddRange(TestCountries); - db.AddRange(TestCities); - db.AddRange(TestTowns); + Db.AddRange(TestCountries); + Db.AddRange(TestCities); + Db.AddRange(TestTowns); - db.SaveChanges(); + Db.SaveChanges(); } [Fact] public void GetPagedList() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = repository.GetPagedList(predicate: t => t.Name == "C", include: source => source.Include(t => t.Country), pageSize: 1); @@ -39,7 +39,7 @@ public void GetPagedList() [Fact] public async Task GetPagedListAsync() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = await repository.GetPagedListAsync(predicate: t => t.Name == "C", include: source => source.Include(t => t.Country), pageSize: 1); @@ -54,7 +54,7 @@ public async Task GetPagedListAsync() [Fact] public async Task GetPagedListWithIncludingMultipleLevelsAsync() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = await repository.GetPagedListAsync(predicate: t => t.Name == "A", include: country => country.Include(c => c.Cities).ThenInclude(city => city.Towns), pageSize: 1); @@ -67,7 +67,7 @@ public async Task GetPagedListWithIncludingMultipleLevelsAsync() [Fact] public void GetPagedListWithoutInclude() { - var repository = new Repository(db); + var repository = new Repository(Db); var page = repository.GetPagedList(pageIndex: 0, pageSize: 1); @@ -75,30 +75,30 @@ public void GetPagedListWithoutInclude() Assert.Null(page.Items[0].Country); } - protected static List TestCountries => new List + private static IEnumerable TestCountries => new List { - new Country {Id = 1, Name = "A"}, - new Country {Id = 2, Name = "B"} + new() {Id = 1, Name = "A"}, + new() {Id = 2, Name = "B"} }; - public static List TestCities => new List + private static IEnumerable TestCities => new List { - new City { Id = 1, Name = "A", CountryId = 1}, - new City { Id = 2, Name = "B", CountryId = 2}, - new City { Id = 3, Name = "C", CountryId = 1}, - new City { Id = 4, Name = "D", CountryId = 2}, - new City { Id = 5, Name = "E", CountryId = 1}, - new City { Id = 6, Name = "F", CountryId = 2}, + new() { Id = 1, Name = "A", CountryId = 1}, + new() { Id = 2, Name = "B", CountryId = 2}, + new() { Id = 3, Name = "C", CountryId = 1}, + new() { Id = 4, Name = "D", CountryId = 2}, + new() { Id = 5, Name = "E", CountryId = 1}, + new() { Id = 6, Name = "F", CountryId = 2}, }; - public static List TestTowns => new List + private static IEnumerable TestTowns => new List { - new Town { Id = 1, Name="A", CityId = 1 }, - new Town { Id = 2, Name="B", CityId = 2 }, - new Town { Id = 3, Name="C", CityId = 3 }, - new Town { Id = 4, Name="D", CityId = 4 }, - new Town { Id = 5, Name="E", CityId = 5 }, - new Town { Id = 6, Name="F", CityId = 6 }, + new() { Id = 1, Name="A", CityId = 1 }, + new() { Id = 2, Name="B", CityId = 2 }, + new() { Id = 3, Name="C", CityId = 3 }, + new() { Id = 4, Name="D", CityId = 4 }, + new() { Id = 5, Name="E", CityId = 5 }, + new() { Id = 6, Name="F", CityId = 6 }, }; } } diff --git a/test/UnitOfWork.Tests/InMemoryContext.cs b/test/UnitOfWork.Tests/InMemoryContext.cs index cf60438..e0c4099 100644 --- a/test/UnitOfWork.Tests/InMemoryContext.cs +++ b/test/UnitOfWork.Tests/InMemoryContext.cs @@ -5,12 +5,9 @@ namespace Arch.EntityFrameworkCore.UnitOfWork.Tests { public class InMemoryContext : DbContext { - public DbSet Countries { get; set; } - public DbSet Customers { get; set; } + public DbSet Countries => Set(); + public DbSet Customers => Set(); - protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) - { - optionsBuilder.UseInMemoryDatabase("test"); - } + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) => optionsBuilder.UseInMemoryDatabase("test"); } } diff --git a/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs b/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs index ded0b57..a178593 100644 --- a/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs +++ b/test/UnitOfWork.Tests/TestGetFirstOrDefaultAsync.cs @@ -52,30 +52,30 @@ public async void TestGetFirstOrDefaultAsyncCanInclude() } - protected static List TestCountries => new List + private static IEnumerable TestCountries => new List { - new Country {Id = 1, Name = "A"}, - new Country {Id = 2, Name = "B"} + new() {Id = 1, Name = "A"}, + new() {Id = 2, Name = "B"} }; - public static List TestCities => new List + private static IEnumerable TestCities => new List { - new City { Id = 1, Name = "A", CountryId = 1}, - new City { Id = 2, Name = "B", CountryId = 2}, - new City { Id = 3, Name = "C", CountryId = 1}, - new City { Id = 4, Name = "D", CountryId = 2}, - new City { Id = 5, Name = "E", CountryId = 1}, - new City { Id = 6, Name = "F", CountryId = 2}, + new() { Id = 1, Name = "A", CountryId = 1}, + new() { Id = 2, Name = "B", CountryId = 2}, + new() { Id = 3, Name = "C", CountryId = 1}, + new() { Id = 4, Name = "D", CountryId = 2}, + new() { Id = 5, Name = "E", CountryId = 1}, + new() { Id = 6, Name = "F", CountryId = 2}, }; - public static List TestTowns => new List + private static IEnumerable TestTowns => new List { - new Town { Id = 1, Name="TownA", CityId = 1 }, - new Town { Id = 2, Name="TownB", CityId = 2 }, - new Town { Id = 3, Name="TownC", CityId = 3 }, - new Town { Id = 4, Name="TownD", CityId = 4 }, - new Town { Id = 5, Name="TownE", CityId = 5 }, - new Town { Id = 6, Name="TownF", CityId = 6 }, + new() { Id = 1, Name="TownA", CityId = 1 }, + new() { Id = 2, Name="TownB", CityId = 2 }, + new() { Id = 3, Name="TownC", CityId = 3 }, + new() { Id = 4, Name="TownD", CityId = 4 }, + new() { Id = 5, Name="TownE", CityId = 5 }, + new() { Id = 6, Name="TownF", CityId = 6 }, }; } } diff --git a/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj b/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj index f1dbf4d..72d5c91 100644 --- a/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj +++ b/test/UnitOfWork.Tests/UnitOfWork.Tests.csproj @@ -1,12 +1,14 @@  - net5.0 + net6.0 + + Arch.EntityFrameworkCore.UnitOfWork.Tests - - + + all