diff --git a/src/JsonApiDotNetCore/Data/DefaultEntityRepository.cs b/src/JsonApiDotNetCore/Data/DefaultEntityRepository.cs index 5b5b0d1a55..45627fd1a6 100644 --- a/src/JsonApiDotNetCore/Data/DefaultEntityRepository.cs +++ b/src/JsonApiDotNetCore/Data/DefaultEntityRepository.cs @@ -58,28 +58,12 @@ public virtual IQueryable Get() public virtual IQueryable Filter(IQueryable entities, FilterQuery filterQuery) { - if (filterQuery == null) - return entities; - - if (filterQuery.IsAttributeOfRelationship) - return entities.Filter(new RelatedAttrFilterQuery(_jsonApiContext, filterQuery)); - - return entities.Filter(new AttrFilterQuery(_jsonApiContext, filterQuery)); + return entities.Filter(_jsonApiContext, filterQuery); } public virtual IQueryable Sort(IQueryable entities, List sortQueries) { - if (sortQueries == null || sortQueries.Count == 0) - return entities; - - var orderedEntities = entities.Sort(sortQueries[0]); - - if (sortQueries.Count <= 1) return orderedEntities; - - for (var i = 1; i < sortQueries.Count; i++) - orderedEntities = orderedEntities.Sort(sortQueries[i]); - - return orderedEntities; + return entities.Sort(sortQueries); } public virtual async Task GetAsync(TId id) @@ -156,26 +140,36 @@ public virtual IQueryable Include(IQueryable entities, string public virtual async Task> PageAsync(IQueryable entities, int pageSize, int pageNumber) { - if (pageSize > 0) + if (pageNumber >= 0) { - if (pageNumber == 0) - pageNumber = 1; - - if (pageNumber > 0) - return await entities - .Skip((pageNumber - 1) * pageSize) - .Take(pageSize) - .ToListAsync(); - else // page from the end of the set - return (await entities - .OrderByDescending(t => t.Id) - .Skip((Math.Abs(pageNumber) - 1) * pageSize) - .Take(pageSize) - .ToListAsync()) - .OrderBy(t => t.Id) - .ToList(); + return await entities.PageForward(pageSize, pageNumber).ToListAsync(); } + // since EntityFramework does not support IQueryable.Reverse(), we need to know the number of queried entities + int numberOfEntities = await this.CountAsync(entities); + + // may be negative + int virtualFirstIndex = numberOfEntities - pageSize * Math.Abs(pageNumber); + int numberOfElementsInPage = Math.Min(pageSize, virtualFirstIndex + pageSize); + + return await entities + .Skip(virtualFirstIndex) + .Take(numberOfElementsInPage) + .ToListAsync(); + } + + public async Task CountAsync(IQueryable entities) + { + return await entities.CountAsync(); + } + + public Task FirstOrDefaultAsync(IQueryable entities) + { + return entities.FirstOrDefaultAsync(); + } + + public async Task> ToListAsync(IQueryable entities) + { return await entities.ToListAsync(); } } diff --git a/src/JsonApiDotNetCore/Data/IEntityReadRepository.cs b/src/JsonApiDotNetCore/Data/IEntityReadRepository.cs index aad16a9efc..a86b7334a9 100644 --- a/src/JsonApiDotNetCore/Data/IEntityReadRepository.cs +++ b/src/JsonApiDotNetCore/Data/IEntityReadRepository.cs @@ -27,5 +27,11 @@ public interface IEntityReadRepository Task GetAsync(TId id); Task GetAndIncludeAsync(TId id, string relationshipName); + + Task CountAsync(IQueryable entities); + + Task FirstOrDefaultAsync(IQueryable entities); + + Task> ToListAsync(IQueryable entities); } } diff --git a/src/JsonApiDotNetCore/Extensions/IQueryableExtensions.cs b/src/JsonApiDotNetCore/Extensions/IQueryableExtensions.cs index ee6b4451c5..6bbb1115b2 100644 --- a/src/JsonApiDotNetCore/Extensions/IQueryableExtensions.cs +++ b/src/JsonApiDotNetCore/Extensions/IQueryableExtensions.cs @@ -2,15 +2,30 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Reflection; using JsonApiDotNetCore.Internal; using JsonApiDotNetCore.Internal.Query; +using JsonApiDotNetCore.Services; namespace JsonApiDotNetCore.Extensions { // ReSharper disable once InconsistentNaming public static class IQueryableExtensions { + public static IQueryable Sort(this IQueryable source, List sortQueries) + { + if (sortQueries == null || sortQueries.Count == 0) + return source; + + var orderedEntities = source.Sort(sortQueries[0]); + + if (sortQueries.Count <= 1) return orderedEntities; + + for (var i = 1; i < sortQueries.Count; i++) + orderedEntities = orderedEntities.Sort(sortQueries[i]); + + return orderedEntities; + } + public static IOrderedQueryable Sort(this IQueryable source, SortQuery sortQuery) { return sortQuery.Direction == SortDirection.Descending @@ -62,6 +77,17 @@ private static IOrderedQueryable CallGenericOrderMethod(IQuery return (IOrderedQueryable)result; } + public static IQueryable Filter(this IQueryable source, IJsonApiContext jsonApiContext, FilterQuery filterQuery) + { + if (filterQuery == null) + return source; + + if (filterQuery.IsAttributeOfRelationship) + return source.Filter(new RelatedAttrFilterQuery(jsonApiContext, filterQuery)); + + return source.Filter(new AttrFilterQuery(jsonApiContext, filterQuery)); + } + public static IQueryable Filter(this IQueryable source, AttrFilterQuery filterQuery) { if (filterQuery == null) @@ -201,5 +227,21 @@ public static IQueryable Select(this IQueryable sourc Expression.Call(typeof(Queryable), "Select", new[] { sourceType, resultType }, source.Expression, Expression.Quote(selector))); } + + public static IQueryable PageForward(this IQueryable source, int pageSize, int pageNumber) + { + if (pageSize > 0) + { + if (pageNumber == 0) + pageNumber = 1; + + if (pageNumber > 0) + return source + .Skip((pageNumber - 1) * pageSize) + .Take(pageSize); + } + + return source; + } } } diff --git a/src/JsonApiDotNetCore/Services/EntityResourceService.cs b/src/JsonApiDotNetCore/Services/EntityResourceService.cs index cc1ba897e1..c0b4847f13 100644 --- a/src/JsonApiDotNetCore/Services/EntityResourceService.cs +++ b/src/JsonApiDotNetCore/Services/EntityResourceService.cs @@ -5,7 +5,6 @@ using JsonApiDotNetCore.Extensions; using JsonApiDotNetCore.Internal; using JsonApiDotNetCore.Models; -using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging; namespace JsonApiDotNetCore.Services @@ -50,7 +49,7 @@ public virtual async Task> GetAsync() entities = IncludeRelationships(entities, _jsonApiContext.QuerySet.IncludedRelationships); if (_jsonApiContext.Options.IncludeTotalRecordCount) - _jsonApiContext.PageManager.TotalRecords = await entities.CountAsync(); + _jsonApiContext.PageManager.TotalRecords = await _entities.CountAsync(entities); // pagination should be done last since it will execute the query var pagedEntities = await ApplyPageQueryAsync(entities); @@ -72,12 +71,12 @@ private bool ShouldIncludeRelationships() private async Task GetWithRelationshipsAsync(TId id) { - var query = _entities.Get(); + var query = _entities.Get().Where(e => e.Id.Equals(id)); _jsonApiContext.QuerySet.IncludedRelationships.ForEach(r => { query = _entities.Include(query, r); }); - return await query.FirstOrDefaultAsync(e => e.Id.Equals(id)); + return await _entities.FirstOrDefaultAsync(query); } public virtual async Task GetRelationshipsAsync(TId id, string relationshipName) @@ -166,7 +165,7 @@ private async Task> ApplyPageQueryAsync(IQueryable entities) { var pageManager = _jsonApiContext.PageManager; if (!pageManager.IsPaginated) - return entities; + return await _entities.ToListAsync(entities); _logger?.LogInformation($"Applying paging query. Fetching page {pageManager.CurrentPage} with {pageManager.PageSize} entities"); diff --git a/test/JsonApiDotNetCoreExampleTests/Acceptance/Spec/PagingTests.cs b/test/JsonApiDotNetCoreExampleTests/Acceptance/Spec/PagingTests.cs index 02dbd019e0..7d8401f78d 100644 --- a/test/JsonApiDotNetCoreExampleTests/Acceptance/Spec/PagingTests.cs +++ b/test/JsonApiDotNetCoreExampleTests/Acceptance/Spec/PagingTests.cs @@ -1,7 +1,9 @@ +using System.Collections.Generic; using System.Linq; using System.Net; using System.Threading.Tasks; using Bogus; +using JsonApiDotNetCore.Models; using JsonApiDotNetCore.Serialization; using JsonApiDotNetCoreExample; using JsonApiDotNetCoreExample.Models; @@ -50,7 +52,7 @@ public async Task Can_Paginate_TodoItems_From_Start() { const int expectedEntitiesPerPage = 2; var totalCount = expectedEntitiesPerPage * 2; var person = new Person(); - var todoItems = _todoItemFaker.Generate(totalCount); + var todoItems = _todoItemFaker.Generate(totalCount).ToList(); foreach (var todoItem in todoItems) todoItem.Owner = person; @@ -70,12 +72,8 @@ public async Task Can_Paginate_TodoItems_From_Start() { var body = await response.Content.ReadAsStringAsync(); var deserializedBody = GetService().DeserializeList(body); - Assert.NotEmpty(deserializedBody); - Assert.Equal(expectedEntitiesPerPage, deserializedBody.Count); - - var expectedTodoItems = Context.TodoItems.Take(2); - foreach (var todoItem in expectedTodoItems) - Assert.NotNull(deserializedBody.SingleOrDefault(t => t.Id == todoItem.Id)); + var expectedTodoItems = new[] { todoItems[0], todoItems[1] }; + Assert.Equal(expectedTodoItems, deserializedBody, new IdComparer()); } [Fact] @@ -84,7 +82,7 @@ public async Task Can_Paginate_TodoItems_From_End() { const int expectedEntitiesPerPage = 2; var totalCount = expectedEntitiesPerPage * 2; var person = new Person(); - var todoItems = _todoItemFaker.Generate(totalCount); + var todoItems = _todoItemFaker.Generate(totalCount).ToList(); foreach (var todoItem in todoItems) todoItem.Owner = person; @@ -104,18 +102,16 @@ public async Task Can_Paginate_TodoItems_From_End() { var body = await response.Content.ReadAsStringAsync(); var deserializedBody = GetService().DeserializeList(body); - Assert.NotEmpty(deserializedBody); - Assert.Equal(expectedEntitiesPerPage, deserializedBody.Count); + var expectedTodoItems = new[] { todoItems[totalCount - 2], todoItems[totalCount - 1] }; + Assert.Equal(expectedTodoItems, deserializedBody, new IdComparer()); + } - var expectedTodoItems = Context.TodoItems - .OrderByDescending(t => t.Id) - .Take(2) - .ToList() - .OrderBy(t => t.Id) - .ToList(); + private class IdComparer : IEqualityComparer + where T : IIdentifiable + { + public bool Equals(T x, T y) => x?.StringId == y?.StringId; - for (int i = 0; i < expectedEntitiesPerPage; i++) - Assert.Equal(expectedTodoItems[i].Id, deserializedBody[i].Id); + public int GetHashCode(T obj) => obj?.StringId?.GetHashCode() ?? 0; } } -} \ No newline at end of file +} diff --git a/test/UnitTests/Data/DefaultEntityRepository_Tests.cs b/test/UnitTests/Data/DefaultEntityRepository_Tests.cs index 50d24409f8..a8ec56fe9c 100644 --- a/test/UnitTests/Data/DefaultEntityRepository_Tests.cs +++ b/test/UnitTests/Data/DefaultEntityRepository_Tests.cs @@ -13,7 +13,8 @@ using Microsoft.Extensions.Logging; using JsonApiDotNetCore.Services; using System.Threading.Tasks; - +using System.Linq; + namespace UnitTests.Data { public class DefaultEntityRepository_Tests : JsonApiControllerMixin @@ -93,6 +94,83 @@ private DefaultEntityRepository GetRepository() _loggFactoryMock.Object, _jsonApiContextMock.Object, _contextResolverMock.Object); + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + [InlineData(-10)] + public async Task Page_When_PageSize_Is_NonPositive_Does_Nothing(int pageSize) + { + var todoItems = DbSetMock.Create(TodoItems(2, 3, 1)).Object; + var repository = GetRepository(); + + var result = await repository.PageAsync(todoItems, pageSize, 3); + + Assert.Equal(TodoItems(2, 3, 1), result, new IdComparer()); + } + + [Fact] + public async Task Page_When_PageNumber_Is_Zero_Pretends_PageNumber_Is_One() + { + var todoItems = DbSetMock.Create(TodoItems(2, 3, 1)).Object; + var repository = GetRepository(); + + var result = await repository.PageAsync(todoItems, 1, 0); + + Assert.Equal(TodoItems(2), result, new IdComparer()); + } + + [Fact] + public async Task Page_When_PageNumber_Of_PageSize_Does_Not_Exist_Return_Empty_Queryable() + { + var todoItems = DbSetMock.Create(TodoItems(2, 3, 1)).Object; + var repository = GetRepository(); + + var result = await repository.PageAsync(todoItems, 2, 3); + + Assert.Empty(result); + } + + [Theory] + [InlineData(3, 2, new[] { 4, 5, 6 })] + [InlineData(8, 2, new[] { 9 })] + [InlineData(20, 1, new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 })] + public async Task Page_When_PageNumber_Is_Positive_Returns_PageNumberTh_Page_Of_Size_PageSize(int pageSize, int pageNumber, int[] expectedResult) + { + var todoItems = DbSetMock.Create(TodoItems(1, 2, 3, 4, 5, 6, 7, 8, 9)).Object; + var repository = GetRepository(); + + var result = await repository.PageAsync(todoItems, pageSize, pageNumber); + + Assert.Equal(TodoItems(expectedResult), result, new IdComparer()); + } + + [Theory] + [InlineData(6, -1, new[] { 4, 5, 6, 7, 8, 9 })] + [InlineData(6, -2, new[] { 1, 2, 3 })] + [InlineData(20, -1, new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 })] + public async Task Page_When_PageNumber_Is_Negative_Returns_PageNumberTh_Page_From_End(int pageSize, int pageNumber, int[] expectedIds) + { + var todoItems = DbSetMock.Create(TodoItems(1, 2, 3, 4, 5, 6, 7, 8, 9)).Object; + var repository = GetRepository(); + + var result = await repository.PageAsync(todoItems, pageSize, pageNumber); + + Assert.Equal(TodoItems(expectedIds), result, new IdComparer()); + } + + private static TodoItem[] TodoItems(params int[] ids) + { + return ids.Select(id => new TodoItem { Id = id }).ToArray(); + } + + private class IdComparer : IEqualityComparer + where T : IIdentifiable + { + public bool Equals(T x, T y) => x?.StringId == y?.StringId; + + public int GetHashCode(T obj) => obj?.StringId?.GetHashCode() ?? 0; } } }