diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQuery.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQuery.cs index 9d9588299e6..67f5f6f58e7 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQuery.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQuery.cs @@ -13,14 +13,17 @@ * limitations under the License. */ +using System; using System.Linq; using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecutableQueryTranslators { @@ -141,17 +144,42 @@ public override string ToString() // private methods private BsonDocumentStagePipelineDefinition CreateCollectionPipelineDefinition(BsonDocument[] stages) { - return new BsonDocumentStagePipelineDefinition(stages, (IBsonSerializer)_pipeline.OutputSerializer); + var outputSerializer = GetOutputSerializer(); + return new BsonDocumentStagePipelineDefinition(stages, outputSerializer); } private BsonDocumentStagePipelineDefinition CreateDatabasePipelineDefinition(BsonDocument[] stages) { - return new BsonDocumentStagePipelineDefinition(stages, (IBsonSerializer)_pipeline.OutputSerializer); + var outputSerializer = GetOutputSerializer(); + return new BsonDocumentStagePipelineDefinition(stages, outputSerializer); } private BsonDocument[] RenderPipeline() { return _pipeline.Render().AsBsonArray.Cast().ToArray(); } + + private IBsonSerializer GetOutputSerializer() + { + var outputSerializer = _pipeline.OutputSerializer; + var outputType = outputSerializer.ValueType; + + if (outputType == typeof(TOutput)) + { + return (IBsonSerializer)outputSerializer; + } + + if (!typeof(TOutput).IsAssignableFrom(outputType)) + { + throw new NotSupportedException($"The type of the pipeline output is {outputType} which is not assignable to {typeof(TOutput)}."); + } + + if (typeof(TOutput).IsNullableOf(outputType)) + { + return (IBsonSerializer)NullableSerializer.Create(outputSerializer); + } + + return (IBsonSerializer)DowncastingSerializer.Create(typeof(TOutput), outputType, outputSerializer); + } } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQueryTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQueryTests.cs new file mode 100644 index 00000000000..bfafb2b05ca --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExecutableQueryTests.cs @@ -0,0 +1,149 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Linq; +using FluentAssertions; +using MongoDB.Driver.Linq; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToExecutableQueryTranslators +{ + public class ExecutableQueryTests : Linq3IntegrationTest + { + [Fact] + public void Cast_to_object_should_work() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable(); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var results = queryable2.ToList(); + + results.Should().HaveCount(5); + } + + [Fact] + public void Cast_aggregation_to_object_should_work() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable().GroupBy( + p => p.Type, + (k, p) => new ProductAggregation {Type = k, MaxPrice = p.Select(i => i.Price).Max()}); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var results = queryable2.ToList(); + + results.Should().HaveCount(2); + } + + [Fact] + public void Cast_int_to_object_should_work() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable().Select(p => p.Id); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var results = queryable2.ToList(); + + results.Should().HaveCount(5); + } + + [Fact] + public void Cast_to_nullable_should_work() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable().Select(p => p.Id); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var results = queryable2.ToList(); + + results.Should().HaveCount(5); + } + + [Fact] + public void Cast_to_incompatible_type_should_throw() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable(); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var exception = Record.Exception(() => queryable2.ToList()); + + exception.Should().BeOfType(); + exception.Message.Should().Contain($"The type of the pipeline output is {typeof(DerivedProduct)} which is not assignable to {typeof(ProductAggregation)}"); + } + + [Fact] + public void Cast_to_interface_should_work() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable(); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var results = queryable2.ToList(); + + results.Should().HaveCount(5); + } + + [Fact] + public void Cast_to_base_class_should_work() + { + var collection = GetCollection(); + var queryable1 = collection.AsQueryable(); + var queryable2 = queryable1.Provider.CreateQuery(queryable1.Expression); + + var results = queryable2.ToList(); + + results.Should().HaveCount(5); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection, + new DerivedProduct { Id = 1, Type = "a", Price = 1 }, + new DerivedProduct { Id = 2, Type = "a", Price = 5 }, + new DerivedProduct { Id = 3, Type = "a", Price = 12 }, + new DerivedProduct { Id = 4, Type = "b", Price = 2 }, + new DerivedProduct { Id = 5, Type = "b", Price = 7 }); + return collection; + } + + private interface IProduct + { + string Type { get; set; } + decimal Price { get; set; } + } + + private class ProductBase : IProduct + { + public int Id { get; set; } + public string Type { get; set; } + public decimal Price { get; set; } + } + + private class DerivedProduct : ProductBase + { + } + + private class ProductAggregation + { + public string Type { get; set; } + public decimal MaxPrice { get; set; } + } + } +}