diff --git a/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProvider.kt b/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProvider.kt index 6ec1e606141..1ae5353dbaa 100644 --- a/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProvider.kt +++ b/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProvider.kt @@ -15,6 +15,8 @@ */ package org.bson.codecs.kotlinx +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.modules.SerializersModule import org.bson.codecs.Codec import org.bson.codecs.configuration.CodecProvider import org.bson.codecs.configuration.CodecRegistry @@ -24,8 +26,12 @@ import org.bson.codecs.configuration.CodecRegistry * * The underlying class must be annotated with the `@Serializable`. */ -public class KotlinSerializerCodecProvider : CodecProvider { +@OptIn(ExperimentalSerializationApi::class) +public class KotlinSerializerCodecProvider( + private val serializersModule: SerializersModule = defaultSerializersModule, + private val bsonConfiguration: BsonConfiguration = BsonConfiguration() +) : CodecProvider { override fun get(clazz: Class, registry: CodecRegistry): Codec? = - KotlinSerializerCodec.create(clazz.kotlin) + KotlinSerializerCodec.create(clazz.kotlin, serializersModule, bsonConfiguration) } diff --git a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt index 0870e2033e9..8d4fa304bc8 100644 --- a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt +++ b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt @@ -20,6 +20,20 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.modules.SerializersModule +import kotlinx.serialization.modules.plus +import kotlinx.serialization.modules.polymorphic +import kotlinx.serialization.modules.subclass +import org.bson.BsonDocument +import org.bson.BsonDocumentReader +import org.bson.BsonDocumentWriter +import org.bson.codecs.DecoderContext +import org.bson.codecs.EncoderContext +import org.bson.codecs.kotlinx.samples.DataClassContainsOpen +import org.bson.codecs.kotlinx.samples.DataClassOpen +import org.bson.codecs.kotlinx.samples.DataClassOpenA +import org.bson.codecs.kotlinx.samples.DataClassOpenB import org.bson.codecs.kotlinx.samples.DataClassParameterized import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues import org.bson.conversions.Bson @@ -60,4 +74,37 @@ class KotlinSerializerCodecProviderTest { assertTrue { codec is KotlinSerializerCodec } assertEquals(DataClassWithSimpleValues::class.java, codec.encoderClass) } + + @OptIn(ExperimentalSerializationApi::class) + @Test + fun shouldAllowOverridingOfSerializersModuleAndBsonConfigurationInConstructor() { + + val serializersModule = + SerializersModule { + this.polymorphic(DataClassOpen::class) { + this.subclass(DataClassOpenA::class) + this.subclass(DataClassOpenB::class) + } + } + defaultSerializersModule + + val bsonConfiguration = BsonConfiguration(classDiscriminator = "__type") + val dataClassContainsOpenB = DataClassContainsOpen(DataClassOpenB(1)) + + val codec = + KotlinSerializerCodecProvider(serializersModule, bsonConfiguration) + .get(DataClassContainsOpen::class.java, Bson.DEFAULT_CODEC_REGISTRY)!! + + assertTrue { codec is KotlinSerializerCodec } + val encodedDocument = BsonDocument() + val writer = BsonDocumentWriter(encodedDocument) + codec.encode(writer, dataClassContainsOpenB, EncoderContext.builder().build()) + writer.flush() + + assertEquals( + BsonDocument.parse("""{"open": {"__type": "org.bson.codecs.kotlinx.samples.DataClassOpenB", "b": 1}}"""), + encodedDocument) + + assertEquals( + dataClassContainsOpenB, codec.decode(BsonDocumentReader(encodedDocument), DecoderContext.builder().build())) + } }