diff --git a/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala b/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala index 090d066223c..e284647af87 100644 --- a/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala +++ b/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.bson._ import org.bson.codecs.configuration.{ CodecRegistries, CodecRegistry } import org.bson.codecs.{ Codec, DecoderContext, Encoder, EncoderContext } +import scala.collection.immutable.Vector import org.mongodb.scala.bson.BsonNull diff --git a/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala b/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala index 95d7533cc87..c16215a16e8 100644 --- a/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala +++ b/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala @@ -30,6 +30,7 @@ import org.mongodb.scala.bson.annotations.{ BsonIgnore, BsonProperty } import org.mongodb.scala.bson.codecs.Macros.{ createCodecProvider, createCodecProviderIgnoreNone } import org.mongodb.scala.bson.codecs.Registry.DEFAULT_CODEC_REGISTRY import org.mongodb.scala.bson.collection.immutable.Document +import scala.collection.immutable.Vector import scala.collection.JavaConverters._ import scala.reflect.ClassTag diff --git a/bson/src/main/org/bson/BsonBinary.java b/bson/src/main/org/bson/BsonBinary.java index d5d07273cea..8590c2920be 100644 --- a/bson/src/main/org/bson/BsonBinary.java +++ b/bson/src/main/org/bson/BsonBinary.java @@ -18,10 +18,13 @@ import org.bson.assertions.Assertions; import org.bson.internal.UuidHelper; +import org.bson.internal.vector.VectorHelper; import java.util.Arrays; import java.util.UUID; +import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; + /** * A representation 
of the BSON Binary type. Note that for performance reasons instances of this class are not immutable, * so care should be taken to only modify the underlying byte array if you know what you're doing, or else make a defensive copy. @@ -89,6 +92,20 @@ public BsonBinary(final UUID uuid) { this(uuid, UuidRepresentation.STANDARD); } + /** + * Constructs a {@linkplain BsonBinarySubType#VECTOR subtype 9} {@link BsonBinary} from the given {@link Vector}. + * + * @param vector the {@link Vector} + * @since 5.3 + */ + public BsonBinary(final Vector vector) { + if (vector == null) { + throw new IllegalArgumentException("Vector must not be null"); + } + this.data = encodeVectorToBinary(vector); + type = BsonBinarySubType.VECTOR.getValue(); + } + /** * Construct a new instance from the given UUID and UuidRepresentation * @@ -127,6 +144,21 @@ public UUID asUuid() { return UuidHelper.decodeBinaryToUuid(this.data.clone(), this.type, UuidRepresentation.STANDARD); } + /** + * Returns the binary as a {@link Vector}. The {@linkplain #getType() subtype} must be {@linkplain BsonBinarySubType#VECTOR 9}. + * + * @return the vector + * @throws BsonInvalidOperationException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. + * @since 5.3 + */ + public Vector asVector() { + if (type != BsonBinarySubType.VECTOR.getValue()) { + throw new BsonInvalidOperationException("type must be a Vector subtype."); + } + + return VectorHelper.decodeBinaryToVector(this.data); + } + /** * Returns the binary as a UUID. * diff --git a/bson/src/main/org/bson/BsonBinarySubType.java b/bson/src/main/org/bson/BsonBinarySubType.java index 3c5f72813b6..7b5948b4efc 100644 --- a/bson/src/main/org/bson/BsonBinarySubType.java +++ b/bson/src/main/org/bson/BsonBinarySubType.java @@ -17,7 +17,7 @@ package org.bson; /** - * The Binary subtype + * The Binary subtype. * * @since 3.0 */ @@ -60,7 +60,7 @@ public enum BsonBinarySubType { ENCRYPTED((byte) 0x06), /** - * Columnar data + * Columnar data. 
* * @since 4.4 */ @@ -73,6 +73,15 @@ public enum BsonBinarySubType { */ SENSITIVE((byte) 0x08), + /** + * Vector data. + * + * @mongodb.server.release 6.0 + * @since 5.3 + * @see Vector + */ + VECTOR((byte) 0x09), + /** * User defined binary data. */ @@ -81,10 +90,10 @@ public enum BsonBinarySubType { private final byte value; /** - * Returns true if the given value is a UUID subtype + * Returns true if the given value is a UUID subtype. * - * @param value the subtype value as a byte - * @return true if value is a UUID subtype + * @param value the subtype value as a byte. + * @return true if value is a UUID subtype. * @since 3.4 */ public static boolean isUuid(final byte value) { diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java new file mode 100644 index 00000000000..9678003b72f --- /dev/null +++ b/bson/src/main/org/bson/Float32Vector.java @@ -0,0 +1,79 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import java.util.Arrays; + +import static org.bson.assertions.Assertions.assertNotNull; + +/** + * Represents a vector of 32-bit floating-point numbers, where each element in the vector is a float. + *

+ * The {@link Float32Vector} is used to store and retrieve data efficiently using the BSON Binary Subtype 9 format. + * + * @mongodb.server.release 6.0 + * @see Vector#floatVector(float[]) + * @see BsonBinary#BsonBinary(Vector) + * @see BsonBinary#asVector() + * @since 5.3 + */ +public final class Float32Vector extends Vector { + + private final float[] data; + + Float32Vector(final float[] vectorData) { + super(DataType.FLOAT32); + this.data = assertNotNull(vectorData); + } + + /** + * Retrieve the underlying float array representing this {@link Float32Vector}, where each float + * represents an element of a vector. + *

+ * NOTE: The underlying float array is not copied; changes to the returned array will be reflected in this instance. + * + * @return the underlying float array representing this {@link Float32Vector} vector. + */ + public float[] getData() { + return assertNotNull(data); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Float32Vector that = (Float32Vector) o; + return Arrays.equals(data, that.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } + + @Override + public String toString() { + return "Float32Vector{" + + "data=" + Arrays.toString(data) + + ", dataType=" + getDataType() + + '}'; + } +} diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java new file mode 100644 index 00000000000..b61e6bfee55 --- /dev/null +++ b/bson/src/main/org/bson/Int8Vector.java @@ -0,0 +1,80 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import java.util.Arrays; +import java.util.Objects; + +import static org.bson.assertions.Assertions.assertNotNull; + +/** + * Represents a vector of 8-bit signed integers, where each element in the vector is a byte. + *

+ * The {@link Int8Vector} is used to store and retrieve data efficiently using the BSON Binary Subtype 9 format. + * + * @mongodb.server.release 6.0 + * @see Vector#int8Vector(byte[]) + * @see BsonBinary#BsonBinary(Vector) + * @see BsonBinary#asVector() + * @since 5.3 + */ +public final class Int8Vector extends Vector { + + private byte[] data; + + Int8Vector(final byte[] data) { + super(DataType.INT8); + this.data = assertNotNull(data); + } + + /** + * Retrieve the underlying byte array representing this {@link Int8Vector} vector, where each byte represents + * an element of a vector. + *

+ * NOTE: The underlying byte array is not copied; changes to the returned array will be reflected in this instance. + * + * @return the underlying byte array representing this {@link Int8Vector} vector. + */ + public byte[] getData() { + return assertNotNull(data); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Int8Vector that = (Int8Vector) o; + return Objects.deepEquals(data, that.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } + + @Override + public String toString() { + return "Int8Vector{" + + "data=" + Arrays.toString(data) + + ", dataType=" + getDataType() + + '}'; + } +} diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java new file mode 100644 index 00000000000..a5dd8f4dcdf --- /dev/null +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -0,0 +1,101 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import java.util.Arrays; +import java.util.Objects; + +import static org.bson.assertions.Assertions.assertNotNull; + +/** + * Represents a packed bit vector, where each element of the vector is represented by a single bit (0 or 1). + *

+ * The {@link PackedBitVector} is used to store data efficiently using the BSON Binary Subtype 9 format. + * + * @mongodb.server.release 6.0 + * @see Vector#packedBitVector(byte[], byte) + * @see BsonBinary#BsonBinary(Vector) + * @see BsonBinary#asVector() + * @since 5.3 + */ +public final class PackedBitVector extends Vector { + + private final byte padding; + private final byte[] data; + + PackedBitVector(final byte[] data, final byte padding) { + super(DataType.PACKED_BIT); + this.data = assertNotNull(data); + this.padding = padding; + } + + /** + * Retrieve the underlying byte array representing this {@link PackedBitVector} vector, where + * each bit represents an element of the vector (either 0 or 1). + *

+ * Note that the {@linkplain #getPadding() padding value} should be considered when interpreting the final byte of the array, + * as it indicates how many least-significant bits are to be ignored. + * + * @return the underlying byte array representing this {@link PackedBitVector} vector. + * @see #getPadding() + */ + public byte[] getData() { + return assertNotNull(data); + } + + /** + * Returns the padding value for this vector. + * + *

Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving + * {@linkplain #getData() the vector array}. For instance, if the padding value is 3, this means that the last byte contains + * 3 least-significant unused bits, which should be disregarded during operations.

+ *

+ * + * NOTE: The byte array returned by {@link #getData()} is not copied; changes to that array will be reflected in this instance. + * + * @return the padding value (between 0 and 7). + */ + public byte getPadding() { + return this.padding; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PackedBitVector that = (PackedBitVector) o; + return padding == that.padding && Arrays.equals(data, that.data); + } + + @Override + public int hashCode() { + return Objects.hash(padding, Arrays.hashCode(data)); + } + + @Override + public String toString() { + return "PackedBitVector{" + + "padding=" + padding + + ", data=" + Arrays.toString(data) + + ", dataType=" + getDataType() + + '}'; + } +} diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java new file mode 100644 index 00000000000..d267387d727 --- /dev/null +++ b/bson/src/main/org/bson/Vector.java @@ -0,0 +1,201 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import static org.bson.assertions.Assertions.isTrueArgument; +import static org.bson.assertions.Assertions.notNull; + +/** + * Represents a vector that is stored and retrieved using the BSON Binary Subtype 9 format. + * This class supports multiple vector {@link DataType}'s and provides static methods to create + * vectors. + *

+ * Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently + * in BSON using a binary format. + *

+ * NOTE: This class should be treated as sealed: it must not be extended or implemented by consumers of the library. + * + * @mongodb.server.release 6.0 + * @see BsonBinary + * @since 5.3 + */ +public abstract class Vector { + private final DataType dataType; + + Vector(final DataType dataType) { + this.dataType = dataType; + } + + /** + * Creates a vector with the {@link DataType#PACKED_BIT} data type. + *

+ * A {@link DataType#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte + * can hold up to 8 bits (vector elements). The padding parameter is used to specify how many least-significant bits in the final byte + * should be ignored.

+ * + *

For example, a vector with two bytes and a padding of 4 would have the following structure:

+ *
+     * Byte 1: 238 (binary: 11101110)
+     * Byte 2: 224 (binary: 11100000)
+     * Padding: 4 (ignore the last 4 bits in Byte 2)
+     * Resulting vector: 12 bits: 111011101110
+     * 
+ *

+ * NOTE: The byte array `data` is not copied; changes to the provided array will be reflected + * in the created {@link PackedBitVector} instance. + * + * @param data The byte array representing the packed bit vector data. Each byte can store 8 bits. + * @param padding The number of least-significant bits (0 to 7) to ignore in the final byte of the vector data. + * @return A {@link PackedBitVector} instance with the {@link DataType#PACKED_BIT} data type. + * @throws IllegalArgumentException If the padding value is not between 0 and 7 (inclusive), or if it is non-zero while {@code data} is empty. + */ + public static PackedBitVector packedBitVector(final byte[] data, final byte padding) { + notNull("data", data); + isTrueArgument("Padding must be between 0 and 7 bits. Provided padding: " + padding, padding >= 0 && padding <= 7); + isTrueArgument("Padding must be 0 if vector is empty. Provided padding: " + padding, padding == 0 || data.length > 0); + return new PackedBitVector(data, padding); + } + + /** + * Creates a vector with the {@link DataType#INT8} data type. + * + *

A {@link DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector, + * with values in the range [-128, 127].

+ *

+ * NOTE: The byte array `data` is not copied; changes to the provided array will be reflected + * in the created {@link Int8Vector} instance. + * + * @param data The byte array representing the {@link DataType#INT8} vector data. + * @return A {@link Int8Vector} instance with the {@link DataType#INT8} data type. + */ + public static Int8Vector int8Vector(final byte[] data) { + notNull("data", data); + return new Int8Vector(data); + } + + /** + * Creates a vector with the {@link DataType#FLOAT32} data type. + *

+ * A {@link DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

+ *

+ * NOTE: The float array `data` is not copied; changes to the provided array will be reflected + * in the created {@link Float32Vector} instance. + * + * @param data The float array representing the {@link DataType#FLOAT32} vector data. + * @return A {@link Float32Vector} instance with the {@link DataType#FLOAT32} data type. + */ + public static Float32Vector floatVector(final float[] data) { + notNull("data", data); + return new Float32Vector(data); + } + + /** + * Returns the {@link PackedBitVector}. + * + * @return {@link PackedBitVector}. + * @throws IllegalStateException if this vector is not of type {@link DataType#PACKED_BIT}. Use {@link #getDataType()} to check the vector + * type before calling this method. + */ + public PackedBitVector asPackedBitVector() { + ensureType(DataType.PACKED_BIT); + return (PackedBitVector) this; + } + + /** + * Returns the {@link Int8Vector}. + * + * @return {@link Int8Vector}. + * @throws IllegalStateException if this vector is not of type {@link DataType#INT8}. Use {@link #getDataType()} to check the vector + * type before calling this method. + */ + public Int8Vector asInt8Vector() { + ensureType(DataType.INT8); + return (Int8Vector) this; + } + + /** + * Returns the {@link Float32Vector}. + * + * @return {@link Float32Vector}. + * @throws IllegalStateException if this vector is not of type {@link DataType#FLOAT32}. Use {@link #getDataType()} to check the vector + * type before calling this method. + */ + public Float32Vector asFloat32Vector() { + ensureType(DataType.FLOAT32); + return (Float32Vector) this; + } + + /** + * Returns {@link DataType} of the vector. + * + * @return the data type of the vector. 
+ */ + public DataType getDataType() { + return this.dataType; + } + + + private void ensureType(final DataType expected) { + if (this.dataType != expected) { + throw new IllegalStateException("Expected vector data type " + expected + ", but found " + this.dataType); + } + } + + /** + * Represents the data type (dtype) of a vector. + *

+ * Each dtype determines how the data in the vector is stored, including how many bits are used to represent each element + * in the vector. + * + * @mongodb.server.release 6.0 + * @since 5.3 + */ + public enum DataType { + /** + * An INT8 vector is a vector of 8-bit signed integers. The vector is stored as an array of bytes, where each byte + * represents a signed integer in the range [-128, 127]. + */ + INT8((byte) 0x03), + /** + * A FLOAT32 vector is a vector of 32-bit floating-point numbers, where each element in the vector is a float. + */ + FLOAT32((byte) 0x27), + /** + * A PACKED_BIT vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). + * Each byte can hold up to 8 bits (vector elements). + */ + PACKED_BIT((byte) 0x10); + + private final byte value; + + DataType(final byte value) { + this.value = value; + } + + /** + * Returns the byte value associated with this {@link DataType}. + * + *

This value is used in the BSON binary format to indicate the data type of the vector.

+ * + * @return the byte value representing the {@link DataType}. + */ + public byte getValue() { + return value; + } + } +} + diff --git a/bson/src/main/org/bson/codecs/ContainerCodecHelper.java b/bson/src/main/org/bson/codecs/ContainerCodecHelper.java index 5969763546b..b454206d5e8 100644 --- a/bson/src/main/org/bson/codecs/ContainerCodecHelper.java +++ b/bson/src/main/org/bson/codecs/ContainerCodecHelper.java @@ -16,10 +16,12 @@ package org.bson.codecs; +import org.bson.BsonBinarySubType; import org.bson.BsonReader; import org.bson.BsonType; import org.bson.Transformer; import org.bson.UuidRepresentation; +import org.bson.Vector; import org.bson.codecs.configuration.CodecConfigurationException; import org.bson.codecs.configuration.CodecRegistry; @@ -28,6 +30,8 @@ import java.util.Arrays; import java.util.UUID; +import static org.bson.internal.UuidHelper.isLegacyUUID; + /** * Helper methods for Codec implementations for containers, e.g. {@code Map} and {@code Iterable}. */ @@ -42,28 +46,50 @@ static Object readValue(final BsonReader reader, final DecoderContext decoderCon reader.readNull(); return null; } else { - Codec codec = bsonTypeCodecMap.get(bsonType); + Codec currentCodec = bsonTypeCodecMap.get(bsonType); + + if (bsonType == BsonType.BINARY) { + byte binarySubType = reader.peekBinarySubType(); + currentCodec = getBinarySubTypeCodec( + reader, + uuidRepresentation, + registry, binarySubType, + currentCodec); + } + + return valueTransformer.transform(currentCodec.decode(reader, decoderContext)); + } + } + + private static Codec getBinarySubTypeCodec(final BsonReader reader, + final UuidRepresentation uuidRepresentation, + final CodecRegistry registry, + final byte binarySubType, + final Codec binaryTypeCodec) { - if (bsonType == BsonType.BINARY && reader.peekBinarySize() == 16) { - switch (reader.peekBinarySubType()) { - case 3: - if (uuidRepresentation == UuidRepresentation.JAVA_LEGACY - || uuidRepresentation == UuidRepresentation.C_SHARP_LEGACY - || 
uuidRepresentation == UuidRepresentation.PYTHON_LEGACY) { - codec = registry.get(UUID.class); - } - break; - case 4: - if (uuidRepresentation == UuidRepresentation.STANDARD) { - codec = registry.get(UUID.class); - } - break; - default: - break; - } + if (binarySubType == BsonBinarySubType.VECTOR.getValue()) { + Codec vectorCodec = registry.get(Vector.class, registry); + if (vectorCodec != null) { + return vectorCodec; + } + } else if (reader.peekBinarySize() == 16) { + switch (binarySubType) { + case 3: + if (isLegacyUUID(uuidRepresentation)) { + return registry.get(UUID.class); + } + break; + case 4: + if (uuidRepresentation == UuidRepresentation.STANDARD) { + return registry.get(UUID.class); + } + break; + default: + break; } - return valueTransformer.transform(codec.decode(reader, decoderContext)); } + + return binaryTypeCodec; } static Codec getCodec(final CodecRegistry codecRegistry, final Type type) { @@ -77,7 +103,6 @@ static Codec getCodec(final CodecRegistry codecRegistry, final Type type) { } } - private ContainerCodecHelper() { } } diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java new file mode 100644 index 00000000000..a6df27e3f87 --- /dev/null +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java @@ -0,0 +1,56 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.codecs; + +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Float32Vector; + +/** + * Encodes and decodes {@link Float32Vector} objects. + * + */ +final class Float32VectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final Float32Vector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public Float32Vector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BsonInvalidOperationException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); + } + + return reader.readBinaryData() + .asBinary() + .asVector() + .asFloat32Vector(); + } + + @Override + public Class getEncoderClass() { + return Float32Vector.class; + } +} + diff --git a/bson/src/main/org/bson/codecs/Int8VectorCodec.java b/bson/src/main/org/bson/codecs/Int8VectorCodec.java new file mode 100644 index 00000000000..a9a70f53746 --- /dev/null +++ b/bson/src/main/org/bson/codecs/Int8VectorCodec.java @@ -0,0 +1,58 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.codecs; + +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Int8Vector; + +/** + * Encodes and decodes {@link Int8Vector} objects. + * + * @since 5.3 + */ +final class Int8VectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final Int8Vector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public Int8Vector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BsonInvalidOperationException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); + } + + return reader.readBinaryData() + .asBinary() + .asVector() + .asInt8Vector(); + } + + + @Override + public Class getEncoderClass() { + return Int8Vector.class; + } +} + diff --git a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java new file mode 100644 index 00000000000..6fcb9552955 --- /dev/null +++ b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java @@ -0,0 +1,59 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.codecs; + +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.PackedBitVector; + +/** + * Encodes and decodes {@link PackedBitVector} objects. + * + */ +final class PackedBitVectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final PackedBitVector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public PackedBitVector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BsonInvalidOperationException( + "Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); + } + + return reader.readBinaryData() + .asBinary() + .asVector() + .asPackedBitVector(); + } + + + @Override + public Class getEncoderClass() { + return PackedBitVector.class; + } +} + + diff --git a/bson/src/main/org/bson/codecs/ValueCodecProvider.java b/bson/src/main/org/bson/codecs/ValueCodecProvider.java index 80ec5e6f18d..3a921c1b08a 100644 --- a/bson/src/main/org/bson/codecs/ValueCodecProvider.java +++ b/bson/src/main/org/bson/codecs/ValueCodecProvider.java @@ -42,6 +42,10 @@ *
  • {@link org.bson.codecs.StringCodec}
  • *
  • {@link org.bson.codecs.SymbolCodec}
  • *
  • {@link org.bson.codecs.UuidCodec}
  • + *
  • {@link VectorCodec}
  • + *
  • {@link Float32VectorCodec}
  • + *
  • {@link Int8VectorCodec}
  • + *
  • {@link PackedBitVectorCodec}
  • *
  • {@link org.bson.codecs.ByteCodec}
  • *
  • {@link org.bson.codecs.ShortCodec}
  • *
  • {@link org.bson.codecs.ByteArrayCodec}
  • @@ -86,6 +90,10 @@ private void addCodecs() { addCodec(new StringCodec()); addCodec(new SymbolCodec()); addCodec(new OverridableUuidRepresentationUuidCodec()); + addCodec(new VectorCodec()); + addCodec(new Float32VectorCodec()); + addCodec(new Int8VectorCodec()); + addCodec(new PackedBitVectorCodec()); addCodec(new ByteCodec()); addCodec(new PatternCodec()); diff --git a/bson/src/main/org/bson/codecs/VectorCodec.java b/bson/src/main/org/bson/codecs/VectorCodec.java new file mode 100644 index 00000000000..87d847664dc --- /dev/null +++ b/bson/src/main/org/bson/codecs/VectorCodec.java @@ -0,0 +1,56 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs; + +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Vector; + +/** + * Encodes and decodes {@link Vector} objects. 
+ * + */ + final class VectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final Vector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public Vector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BsonInvalidOperationException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found " + subType); + } + + return reader.readBinaryData() + .asBinary() + .asVector(); + } + + @Override + public Class getEncoderClass() { + return Vector.class; + } +} + + diff --git a/bson/src/main/org/bson/internal/UuidHelper.java b/bson/src/main/org/bson/internal/UuidHelper.java index efe3d5b5812..9c46614b56e 100644 --- a/bson/src/main/org/bson/internal/UuidHelper.java +++ b/bson/src/main/org/bson/internal/UuidHelper.java @@ -124,6 +124,12 @@ public static UUID decodeBinaryToUuid(final byte[] data, final byte type, final return new UUID(readLongFromArrayBigEndian(localData, 0), readLongFromArrayBigEndian(localData, 8)); } + public static boolean isLegacyUUID(final UuidRepresentation uuidRepresentation) { + return uuidRepresentation == UuidRepresentation.JAVA_LEGACY + || uuidRepresentation == UuidRepresentation.C_SHARP_LEGACY + || uuidRepresentation == UuidRepresentation.PYTHON_LEGACY; + } + private UuidHelper() { } } diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java new file mode 100644 index 00000000000..9dbf583d2b0 --- /dev/null +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -0,0 +1,177 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.internal.vector; + +import org.bson.BsonBinary; +import org.bson.BsonInvalidOperationException; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; +import org.bson.Vector; +import org.bson.assertions.Assertions; +import org.bson.types.Binary; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + +/** + * Helper class for encoding and decoding vectors to and from {@link BsonBinary}/{@link Binary}. + * + *

    + * This class is not part of the public API and may be removed or changed at any time. + * + * @see Vector + * @see BsonBinary#asVector() + * @see BsonBinary#BsonBinary(Vector) + */ +public final class VectorHelper { + + private static final ByteOrder STORED_BYTE_ORDER = ByteOrder.LITTLE_ENDIAN; + private static final String ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE = "Unknown vector data type: "; + private static final byte ZERO_PADDING = 0; + + private VectorHelper() { + //NOP + } + + private static final int METADATA_SIZE = 2; + + public static byte[] encodeVectorToBinary(final Vector vector) { + Vector.DataType dataType = vector.getDataType(); + switch (dataType) { + case INT8: + return encodeVector(dataType.getValue(), ZERO_PADDING, vector.asInt8Vector().getData()); + case PACKED_BIT: + PackedBitVector packedBitVector = vector.asPackedBitVector(); + return encodeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getData()); + case FLOAT32: + return encodeVector(dataType.getValue(), vector.asFloat32Vector().getData()); + default: + throw Assertions.fail(ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE + dataType); + } + } + + /** + * Decodes a vector from a binary representation. + *

    + * encodedVector is not mutated nor stored in the returned {@link Vector}. + */ + public static Vector decodeBinaryToVector(final byte[] encodedVector) { + isTrue("Vector encoded array length must be at least 2, but found: " + encodedVector.length, encodedVector.length >= METADATA_SIZE); + Vector.DataType dataType = determineVectorDType(encodedVector[0]); + byte padding = encodedVector[1]; + switch (dataType) { + case INT8: + return decodeInt8Vector(encodedVector, padding); + case PACKED_BIT: + return decodePackedBitVector(encodedVector, padding); + case FLOAT32: + return decodeFloat32Vector(encodedVector, padding); + default: + throw Assertions.fail(ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE + dataType); + } + } + + private static Float32Vector decodeFloat32Vector(final byte[] encodedVector, final byte padding) { + isTrue("Padding must be 0 for FLOAT32 data type, but found: " + padding, padding == 0); + return Vector.floatVector(decodeLittleEndianFloats(encodedVector)); + } + + private static PackedBitVector decodePackedBitVector(final byte[] encodedVector, final byte padding) { + byte[] packedBitVector = extractVectorData(encodedVector); + isTrue("Padding must be 0 if vector is empty, but found: " + padding, padding == 0 || packedBitVector.length > 0); + isTrue("Padding must be between 0 and 7 bits, but found: " + padding, padding >= 0 && padding <= 7); + return Vector.packedBitVector(packedBitVector, padding); + } + + private static Int8Vector decodeInt8Vector(final byte[] encodedVector, final byte padding) { + isTrue("Padding must be 0 for INT8 data type, but found: " + padding, padding == 0); + byte[] int8Vector = extractVectorData(encodedVector); + return Vector.int8Vector(int8Vector); + } + + private static byte[] extractVectorData(final byte[] encodedVector) { + int vectorDataLength = encodedVector.length - METADATA_SIZE; + byte[] vectorData = new byte[vectorDataLength]; + System.arraycopy(encodedVector, METADATA_SIZE, vectorData, 0, vectorDataLength); + 
return vectorData; + } + + private static byte[] encodeVector(final byte dType, final byte padding, final byte[] vectorData) { + final byte[] bytes = new byte[vectorData.length + METADATA_SIZE]; + bytes[0] = dType; + bytes[1] = padding; + System.arraycopy(vectorData, 0, bytes, METADATA_SIZE, vectorData.length); + return bytes; + } + + private static byte[] encodeVector(final byte dType, final float[] vectorData) { + final byte[] bytes = new byte[vectorData.length * Float.BYTES + METADATA_SIZE]; + + bytes[0] = dType; + bytes[1] = ZERO_PADDING; + + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.order(STORED_BYTE_ORDER); + buffer.position(METADATA_SIZE); + + FloatBuffer floatBuffer = buffer.asFloatBuffer(); + + // The JVM may optimize this operation internally, potentially using intrinsics + // or platform-specific optimizations (such as SIMD). If the byte order matches the underlying system's + // native order, the operation may involve a direct memory copy. + floatBuffer.put(vectorData); + + return bytes; + } + + private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { + isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type, but found: " + encodedVector.length, + (encodedVector.length - METADATA_SIZE) % Float.BYTES == 0); + + int vectorSize = encodedVector.length - METADATA_SIZE; + + int numFloats = vectorSize / Float.BYTES; + float[] floatArray = new float[numFloats]; + + ByteBuffer buffer = ByteBuffer.wrap(encodedVector, METADATA_SIZE, vectorSize); + buffer.order(STORED_BYTE_ORDER); + + // The JVM may optimize this operation internally, potentially using intrinsics + // or platform-specific optimizations (such as SIMD). If the byte order matches the underlying system's + // native order, the operation may involve a direct memory copy. 
+ buffer.asFloatBuffer().get(floatArray); + return floatArray; + } + + public static Vector.DataType determineVectorDType(final byte dType) { + Vector.DataType[] values = Vector.DataType.values(); + for (Vector.DataType value : values) { + if (value.getValue() == dType) { + return value; + } + } + throw new BsonInvalidOperationException(ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE + dType); + } + + private static void isTrue(final String message, final boolean condition) { + if (!condition) { + throw new BsonInvalidOperationException(message); + } + } +} diff --git a/bson/src/test/resources/bson-binary-vector/float32.json b/bson/src/test/resources/bson-binary-vector/float32.json new file mode 100644 index 00000000000..e1d142c184b --- /dev/null +++ b/bson/src/test/resources/bson-binary-vector/float32.json @@ -0,0 +1,50 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector FLOAT32", + "valid": true, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" + }, + { + "description": "Vector with decimals and negative value FLOAT32", + "valid": true, + "vector": [127.7, -7.7], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927006666FF426666F6C000" + }, + { + "description": "Empty Vector FLOAT32", + "valid": true, + "vector": [], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009270000" + }, + { + "description": "Infinity Vector FLOAT32", + "valid": true, + "vector": ["-inf", 0.0, "inf"], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00" + }, + { + "description": "FLOAT32 with padding", + "valid": false, + "vector": 
[127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 3 + } + ] +} \ No newline at end of file diff --git a/bson/src/test/resources/bson-binary-vector/int8.json b/bson/src/test/resources/bson-binary-vector/int8.json new file mode 100644 index 00000000000..c10c1b7d4e2 --- /dev/null +++ b/bson/src/test/resources/bson-binary-vector/int8.json @@ -0,0 +1,56 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype INT8", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector INT8", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000903007F0700" + }, + { + "description": "Empty Vector INT8", + "valid": true, + "vector": [], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009030000" + }, + { + "description": "Overflow Vector INT8", + "valid": false, + "vector": [128], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "Underflow Vector INT8", + "valid": false, + "vector": [-129], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "INT8 with padding", + "valid": false, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 3 + }, + { + "description": "INT8 with float inputs", + "valid": false, + "vector": [127.77, 7.77], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + } + ] +} \ No newline at end of file diff --git a/bson/src/test/resources/bson-binary-vector/packed_bit.json b/bson/src/test/resources/bson-binary-vector/packed_bit.json new file mode 100644 index 00000000000..69fb3948335 --- /dev/null +++ b/bson/src/test/resources/bson-binary-vector/packed_bit.json @@ -0,0 +1,97 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", + "test_key": "vector", + "tests": [ + { + "description": "Padding 
specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Simple Vector PACKED_BIT", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000910007F0700" + }, + { + "description": "Empty Vector PACKED_BIT", + "valid": true, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009100000" + }, + { + "description": "PACKED_BIT with padding", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000910037F0700" + }, + { + "description": "Overflow Vector PACKED_BIT", + "valid": false, + "vector": [256], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Underflow Vector PACKED_BIT", + "valid": false, + "vector": [-1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Exceeding maximum padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 8 + }, + { + "description": "Negative padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": -1 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + 
"padding": 0 + } + ] +} \ No newline at end of file diff --git a/bson/src/test/resources/bson/binary.json b/bson/src/test/resources/bson/binary.json index 38a70d1fe0c..29d88471afe 100644 --- a/bson/src/test/resources/bson/binary.json +++ b/bson/src/test/resources/bson/binary.json @@ -74,6 +74,36 @@ "description": "$type query operator (conflicts with legacy $binary form with $type field)", "canonical_bson": "180000000378001000000010247479706500020000000000", "canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}" + }, + { + "description": "subtype 0x09 Vector FLOAT32", + "canonical_bson": "170000000578000A0000000927000000FE420000E04000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector INT8", + "canonical_bson": "11000000057800040000000903007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector PACKED_BIT", + "canonical_bson": "11000000057800040000000910007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) FLOAT32", + "canonical_bson": "0F0000000578000200000009270000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) INT8", + "canonical_bson": "0F0000000578000200000009030000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) PACKED_BIT", + "canonical_bson": "0F0000000578000200000009100000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}" } ], "decodeErrors": [ @@ -120,4 +150,4 @@ "string": "{\"x\" : { \"$uuid\" : \"----d264-44b3-4--9-90e8-e7d1dfc0----\"}}" } ] -} +} \ No newline at end of 
file diff --git a/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy b/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy index e51094e964f..503440daa04 100644 --- a/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy +++ b/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy @@ -48,9 +48,14 @@ class BsonBinarySpecification extends Specification { data == bsonBinary.getData() where: - subType << [BsonBinarySubType.BINARY, BsonBinarySubType.FUNCTION, BsonBinarySubType.MD5, - BsonBinarySubType.OLD_BINARY, BsonBinarySubType.USER_DEFINED, BsonBinarySubType.UUID_LEGACY, - BsonBinarySubType.UUID_STANDARD] + subType << [BsonBinarySubType.BINARY, + BsonBinarySubType.FUNCTION, + BsonBinarySubType.MD5, + BsonBinarySubType.OLD_BINARY, + BsonBinarySubType.USER_DEFINED, + BsonBinarySubType.UUID_LEGACY, + BsonBinarySubType.UUID_STANDARD, + BsonBinarySubType.VECTOR] } @Unroll diff --git a/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy b/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy index 8e502891095..448d63f23fd 100644 --- a/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy +++ b/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy @@ -34,5 +34,6 @@ class BsonBinarySubTypeSpecification extends Specification { 6 | false 7 | false 8 | false + 9 | false } } diff --git a/bson/src/test/unit/org/bson/BsonBinaryTest.java b/bson/src/test/unit/org/bson/BsonBinaryTest.java new file mode 100644 index 00000000000..029c611c594 --- /dev/null +++ b/bson/src/test/unit/org/bson/BsonBinaryTest.java @@ -0,0 +1,266 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +class BsonBinaryTest { + + private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue(); + private static final byte INT8_DTYPE = Vector.DataType.INT8.getValue(); + private static final byte PACKED_BIT_DTYPE = Vector.DataType.PACKED_BIT.getValue(); + public static final int ZERO_PADDING = 0; + + @Test + void shouldThrowExceptionWhenCreatingBsonBinaryWithNullVector() { + // given + Vector vector = null; + + // when & then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new BsonBinary(vector)); + assertEquals("Vector must not be null", exception.getMessage()); + } + + @ParameterizedTest + @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) + void shouldThrowExceptionWhenBsonBinarySubTypeIsNotVector(final BsonBinarySubType bsonBinarySubType) { + // given + byte[] data = new byte[]{1, 2, 3, 4}; + BsonBinary bsonBinary = new 
BsonBinary(bsonBinarySubType.getValue(), data); + + // when & then + BsonInvalidOperationException exception = assertThrows(BsonInvalidOperationException.class, bsonBinary::asVector); + assertEquals("type must be a Vector subtype.", exception.getMessage()); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideFloatVectors") + void shouldEncodeFloatVector(final Vector actualFloat32Vector, final byte[] expectedBsonEncodedVector) { + // when + BsonBinary actualBsonBinary = new BsonBinary(actualFloat32Vector); + byte[] actualBsonEncodedVector = actualBsonBinary.getData(); + + // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), actualBsonBinary.getType(), "The subtype must be VECTOR"); + assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideFloatVectors") + void shouldDecodeFloatVector(final Float32Vector expectedFloatVector, final byte[] bsonEncodedVector) { + // when + Float32Vector decodedVector = (Float32Vector) new BsonBinary(BsonBinarySubType.VECTOR, bsonEncodedVector).asVector(); + + // then + assertEquals(expectedFloatVector, decodedVector); + } + + private static Stream provideFloatVectors() { + return Stream.of( + arguments( + Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f, -1.0f, Float.MAX_VALUE, Float.MIN_VALUE, Float.POSITIVE_INFINITY, + Float.NEGATIVE_INFINITY}), + new byte[]{FLOAT32_DTYPE, ZERO_PADDING, + (byte) 205, (byte) 204, (byte) 140, (byte) 63, // 1.1f in little-endian + (byte) 205, (byte) 204, (byte) 12, (byte) 64, // 2.2f in little-endian + (byte) 51, (byte) 51, (byte) 83, (byte) 64, // 3.3f in little-endian + (byte) 0, (byte) 0, (byte) 128, (byte) 191, // -1.0f in little-endian + (byte) 255, (byte) 255, (byte) 127, (byte) 127, // Float.MAX_VALUE in little-endian + (byte) 1, (byte) 0, (byte) 0, (byte) 0, // Float.MIN_VALUE in little-endian + (byte) 0, (byte) 0, (byte) 128, (byte) 127, // Float.POSITIVE_INFINITY in little-endian 
+ (byte) 0, (byte) 0, (byte) 128, (byte) 255 // Float.NEGATIVE_INFINITY in little-endian + } + ), + arguments( + Vector.floatVector(new float[]{0.0f}), + new byte[]{FLOAT32_DTYPE, ZERO_PADDING, + (byte) 0, (byte) 0, (byte) 0, (byte) 0 // 0.0f in little-endian + } + ), + arguments( + Vector.floatVector(new float[]{}), + new byte[]{FLOAT32_DTYPE, ZERO_PADDING} + ) + ); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideInt8Vectors") + void shouldEncodeInt8Vector(final Vector actualInt8Vector, final byte[] expectedBsonEncodedVector) { + // when + BsonBinary actualBsonBinary = new BsonBinary(actualInt8Vector); + byte[] actualBsonEncodedVector = actualBsonBinary.getData(); + + // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), actualBsonBinary.getType(), "The subtype must be VECTOR"); + assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideInt8Vectors") + void shouldDecodeInt8Vector(final Int8Vector expectedInt8Vector, final byte[] bsonEncodedVector) { + // when + Int8Vector decodedVector = (Int8Vector) new BsonBinary(BsonBinarySubType.VECTOR, bsonEncodedVector).asVector(); + + // then + assertEquals(expectedInt8Vector, decodedVector); + } + + private static Stream provideInt8Vectors() { + return Stream.of( + arguments( + Vector.int8Vector(new byte[]{Byte.MAX_VALUE, 1, 2, 3, 4, Byte.MIN_VALUE}), + new byte[]{INT8_DTYPE, ZERO_PADDING, Byte.MAX_VALUE, 1, 2, 3, 4, Byte.MIN_VALUE + }), + arguments(Vector.int8Vector(new byte[]{}), + new byte[]{INT8_DTYPE, ZERO_PADDING} + ) + ); + } + + @ParameterizedTest + @MethodSource("providePackedBitVectors") + void shouldEncodePackedBitVector(final Vector actualPackedBitVector, final byte[] expectedBsonEncodedVector) { + // when + BsonBinary actualBsonBinary = new BsonBinary(actualPackedBitVector); + byte[] actualBsonEncodedVector = actualBsonBinary.getData(); + + // then + 
assertEquals(BsonBinarySubType.VECTOR.getValue(), actualBsonBinary.getType(), "The subtype must be VECTOR"); + assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); + } + + @ParameterizedTest + @MethodSource("providePackedBitVectors") + void shouldDecodePackedBitVector(final PackedBitVector expectedPackedBitVector, final byte[] bsonEncodedVector) { + // when + PackedBitVector decodedVector = (PackedBitVector) new BsonBinary(BsonBinarySubType.VECTOR, bsonEncodedVector).asVector(); + + // then + assertEquals(expectedPackedBitVector, decodedVector); + } + + private static Stream providePackedBitVectors() { + return Stream.of( + arguments( + Vector.packedBitVector(new byte[]{(byte) 0, (byte) 255, (byte) 10}, (byte) 2), + new byte[]{PACKED_BIT_DTYPE, 2, (byte) 0, (byte) 255, (byte) 10} + ), + arguments( + Vector.packedBitVector(new byte[0], (byte) 0), + new byte[]{PACKED_BIT_DTYPE, 0} + )); + } + + @Test + void shouldThrowExceptionForInvalidFloatArrayLengthWhenDecode() { + // given + byte[] invalidData = {FLOAT32_DTYPE, 0, 10, 20, 30}; + + // when & Then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); + }); + assertEquals("Byte array length must be a multiple of 4 for FLOAT32 data type, but found: " + invalidData.length, + thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(ints = {0, 1}) + void shouldThrowExceptionWhenEncodedVectorLengthIsLessThenMetadataLength(final int encodedVectorLength) { + // given + byte[] invalidData = new byte[encodedVectorLength]; + + // when & Then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); + }); + assertEquals("Vector encoded array length must be at least 2, but found: " + encodedVectorLength, + thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 
1}) + void shouldThrowExceptionForInvalidFloatArrayPaddingWhenDecode(final byte invalidPadding) { + // given + byte[] invalidData = {FLOAT32_DTYPE, invalidPadding, 10, 20, 30, 20}; + + // when & Then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); + }); + assertEquals("Padding must be 0 for FLOAT32 data type, but found: " + invalidPadding, thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidInt8ArrayPaddingWhenDecode(final byte invalidPadding) { + // given + byte[] invalidData = {INT8_DTYPE, invalidPadding, 10, 20, 30, 20}; + + // when & Then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); + }); + assertEquals("Padding must be 0 for INT8 data type, but found: " + invalidPadding, thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 8}) + void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecode(final byte invalidPadding) { + // given + byte[] invalidData = {PACKED_BIT_DTYPE, invalidPadding, 10, 20, 30, 20}; + + // when & then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); + }); + assertEquals("Padding must be between 0 and 7 bits, but found: " + invalidPadding, thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1, 2, 3, 4, 5, 6, 7, 8}) + void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(final byte invalidPadding) { + // given + byte[] invalidData = {PACKED_BIT_DTYPE, invalidPadding}; + + // when & Then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, 
invalidData).asVector(); + }); + assertEquals("Padding must be 0 if vector is empty, but found: " + invalidPadding, thrown.getMessage()); + } + + @Test + void shouldThrowWhenUnknownVectorDType() { + // when & then: two bytes (dtype, padding) so the length check passes and the unknown dtype 0 is what fails + BsonBinary bsonBinary = new BsonBinary(BsonBinarySubType.VECTOR, new byte[]{(byte) 0, (byte) 0}); + assertEquals("Unknown vector data type: 0", assertThrows(BsonInvalidOperationException.class, bsonBinary::asVector).getMessage()); + } +} diff --git a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java index 15e27065ba2..c9e22fcce7a 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java @@ -40,6 +40,9 @@ public class BsonBinaryWriterTest { + private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue(); + private static final int ZERO_PADDING = 0; + private BsonBinaryWriter writer; private BasicOutputBuffer buffer; @@ -299,12 +302,38 @@ public void testWriteBinary() { writer.writeBinaryData("b1", new BsonBinary(new byte[]{0, 0, 0, 0, 0, 0, 0, 0})); writer.writeBinaryData("b2", new BsonBinary(BsonBinarySubType.OLD_BINARY, new byte[]{1, 1, 1, 1, 1})); writer.writeBinaryData("b3", new BsonBinary(BsonBinarySubType.FUNCTION, new byte[]{})); + writer.writeBinaryData("b4", new BsonBinary(BsonBinarySubType.VECTOR, new byte[]{FLOAT32_DTYPE, ZERO_PADDING, + (byte) 205, (byte) 204, (byte) 140, (byte) 63})); writer.writeEndDocument(); + byte[] expectedValues = new byte[]{ + 64, // total document length + 0, 0, 0, + + //Binary + (byte) BsonType.BINARY.getValue(), + 98, 49, 0, // name "b1" + 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + // Old binary + (byte) BsonType.BINARY.getValue(), + 98, 50, 0, // name "b2" + 9, 0, 0, 0, 2, 5, 0, 0, 0, 1, 1, 1, 1, 1, + + // Function binary + (byte) BsonType.BINARY.getValue(), + 98, 51, 0, // name "b3" + 0, 0, 0, 0, 1, + + //Vector binary + (byte) BsonType.BINARY.getValue(), + 98, 52, 0, // name "b4" + 6, 0, 0, 0, // total length, int32 (little endian) +
BsonBinarySubType.VECTOR.getValue(), FLOAT32_DTYPE, ZERO_PADDING, (byte) 205, (byte) 204, (byte) 140, 63, + + 0 //end of document + }; - byte[] expectedValues = {49, 0, 0, 0, 5, 98, 49, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 98, 50, 0, - 9, 0, - 0, 0, 2, 5, 0, 0, 0, 1, 1, 1, 1, 1, 5, 98, 51, 0, 0, 0, 0, 0, 1, 0}; assertArrayEquals(expectedValues, buffer.toByteArray()); } diff --git a/bson/src/test/unit/org/bson/BsonHelper.java b/bson/src/test/unit/org/bson/BsonHelper.java index 985e398b1ca..59fdba474a2 100644 --- a/bson/src/test/unit/org/bson/BsonHelper.java +++ b/bson/src/test/unit/org/bson/BsonHelper.java @@ -17,10 +17,12 @@ package org.bson; import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.DecoderContext; import org.bson.codecs.EncoderContext; import org.bson.io.BasicOutputBuffer; import org.bson.types.Decimal128; import org.bson.types.ObjectId; +import util.Hex; import java.nio.ByteBuffer; import java.util.Date; @@ -109,4 +111,23 @@ public static ByteBuffer toBson(final BsonDocument document) { private BsonHelper() { } + + public static BsonDocument decodeToDocument(final String subjectHex, final String description) { + ByteBuffer byteBuffer = ByteBuffer.wrap(Hex.decode(subjectHex)); + BsonDocument actualDecodedDocument = new BsonDocumentCodec().decode(new BsonBinaryReader(byteBuffer), + DecoderContext.builder().build()); + + if (byteBuffer.hasRemaining()) { + throw new BsonSerializationException(format("Should have consumed all bytes, but " + byteBuffer.remaining() + + " still remain in the buffer for document with description %s", + description)); + } + return actualDecodedDocument; + } + + public static String encodeToHex(final BsonDocument decodedDocument) { + BasicOutputBuffer outputBuffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), decodedDocument, EncoderContext.builder().build()); + return Hex.encode(outputBuffer.toByteArray()); + } } diff --git
a/bson/src/test/unit/org/bson/GenericBsonTest.java b/bson/src/test/unit/org/bson/GenericBsonTest.java index 2f50bcd7f61..6ba2c6ae382 100644 --- a/bson/src/test/unit/org/bson/GenericBsonTest.java +++ b/bson/src/test/unit/org/bson/GenericBsonTest.java @@ -16,10 +16,6 @@ package org.bson; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.DecoderContext; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import org.bson.json.JsonMode; import org.bson.json.JsonParseException; import org.bson.json.JsonWriterSettings; @@ -27,7 +23,6 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import util.Hex; import util.JsonPoweredTestHelper; import java.io.File; @@ -35,7 +30,6 @@ import java.io.StringReader; import java.io.StringWriter; import java.net.URISyntaxException; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -44,6 +38,8 @@ import static java.lang.String.format; import static org.bson.BsonDocument.parse; +import static org.bson.BsonHelper.decodeToDocument; +import static org.bson.BsonHelper.encodeToHex; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -207,25 +203,6 @@ private boolean shouldEscapeCharacter(final char escapedChar) { } } - private BsonDocument decodeToDocument(final String subjectHex, final String description) { - ByteBuffer byteBuffer = ByteBuffer.wrap(Hex.decode(subjectHex)); - BsonDocument actualDecodedDocument = new BsonDocumentCodec().decode(new BsonBinaryReader(byteBuffer), - DecoderContext.builder().build()); - - if (byteBuffer.hasRemaining()) { - throw new BsonSerializationException(format("Should have consumed all bytes, but " + byteBuffer.remaining() - + " still remain in the buffer for 
document with description ", - description)); - } - return actualDecodedDocument; - } - - private String encodeToHex(final BsonDocument decodedDocument) { - BasicOutputBuffer outputBuffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), decodedDocument, EncoderContext.builder().build()); - return Hex.encode(outputBuffer.toByteArray()); - } - private void runDecodeError(final BsonDocument testCase) { try { String description = testCase.getString("description").getValue(); diff --git a/bson/src/test/unit/org/bson/VectorTest.java b/bson/src/test/unit/org/bson/VectorTest.java new file mode 100644 index 00000000000..36cc7156db6 --- /dev/null +++ b/bson/src/test/unit/org/bson/VectorTest.java @@ -0,0 +1,179 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class VectorTest { + + @Test + void shouldCreateInt8Vector() { + // given + byte[] data = {1, 2, 3, 4, 5}; + + // when + Int8Vector vector = Vector.int8Vector(data); + + // then + assertNotNull(vector); + assertEquals(Vector.DataType.INT8, vector.getDataType()); + assertArrayEquals(data, vector.getData()); + } + + @Test + void shouldThrowExceptionWhenCreatingInt8VectorWithNullData() { + // given + byte[] data = null; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.int8Vector(data)); + assertEquals("data can not be null", exception.getMessage()); + } + + @Test + void shouldCreateFloat32Vector() { + // given + float[] data = {1.0f, 2.0f, 3.0f}; + + // when + Float32Vector vector = Vector.floatVector(data); + + // then + assertNotNull(vector); + assertEquals(Vector.DataType.FLOAT32, vector.getDataType()); + assertArrayEquals(data, vector.getData()); + } + + @Test + void shouldThrowExceptionWhenCreatingFloat32VectorWithNullData() { + // given + float[] data = null; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.floatVector(data)); + assertEquals("data can not be null", exception.getMessage()); + } + + + @ParameterizedTest(name = "{index}: validPadding={0}") + @ValueSource(bytes = {0, 1, 2, 3, 4, 5, 6, 7}) + void shouldCreatePackedBitVector(final byte validPadding) { + // given + byte[] data = {(byte) 0b10101010, (byte) 0b01010101}; + + // when + PackedBitVector vector = Vector.packedBitVector(data, validPadding); 
+ + // then + assertNotNull(vector); + assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType()); + assertArrayEquals(data, vector.getData()); + assertEquals(validPadding, vector.getPadding()); + } + + @ParameterizedTest(name = "{index}: invalidPadding={0}") + @ValueSource(bytes = {-1, 8}) + void shouldThrowExceptionWhenPackedBitVectorHasInvalidPadding(final byte invalidPadding) { + // given + byte[] data = {(byte) 0b10101010}; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + Vector.packedBitVector(data, invalidPadding)); + assertEquals("state should be: Padding must be between 0 and 7 bits. Provided padding: " + invalidPadding, exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenPackedBitVectorIsCreatedWithNullData() { + // given + byte[] data = null; + byte padding = 0; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + Vector.packedBitVector(data, padding)); + assertEquals("data can not be null", exception.getMessage()); + } + + @Test + void shouldCreatePackedBitVectorWithZeroPaddingAndEmptyData() { + // given + byte[] data = new byte[0]; + byte padding = 0; + + // when + PackedBitVector vector = Vector.packedBitVector(data, padding); + + // then + assertNotNull(vector); + assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType()); + assertArrayEquals(data, vector.getData()); + assertEquals(padding, vector.getPadding()); + } + + @Test + void shouldThrowExceptionWhenPackedBitVectorWithNonZeroPaddingAndEmptyData() { + // given + byte[] data = new byte[0]; + byte padding = 1; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + Vector.packedBitVector(data, padding)); + assertEquals("state should be: Padding must be 0 if vector is empty. 
Provided padding: " + padding, exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenRetrievingInt8DataFromNonInt8Vector() { + // given + float[] data = {1.0f, 2.0f}; + Vector vector = Vector.floatVector(data); + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asInt8Vector); + assertEquals("Expected vector data type INT8, but found FLOAT32", exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenRetrievingFloat32DataFromNonFloat32Vector() { + // given + byte[] data = {1, 2, 3}; + Vector vector = Vector.int8Vector(data); + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asFloat32Vector); + assertEquals("Expected vector data type FLOAT32, but found INT8", exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenRetrievingPackedBitDataFromNonPackedBitVector() { + // given + float[] data = {1.0f, 2.0f}; + Vector vector = Vector.floatVector(data); + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asPackedBitVector); + assertEquals("Expected vector data type PACKED_BIT, but found FLOAT32", exception.getMessage()); + } +} diff --git a/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java b/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java index 79c65573556..67c6b561aa5 100644 --- a/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java @@ -23,6 +23,7 @@ import org.bson.BsonObjectId; import org.bson.ByteBufNIO; import org.bson.Document; +import org.bson.Vector; import org.bson.io.BasicOutputBuffer; import org.bson.io.BsonInput; import org.bson.io.ByteBufferBsonInput; @@ -80,6 +81,9 @@ public void testPrimitiveBSONTypeCodecs() throws IOException { doc.put("code", new Code("var i = 0")); doc.put("minkey", new MinKey()); doc.put("maxkey", new MaxKey()); + doc.put("vectorFloat", 
Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f})); + doc.put("vectorInt8", Vector.int8Vector(new byte[]{10, 20, 30, 40})); + doc.put("vectorPackedBit", Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3)); // doc.put("pattern", Pattern.compile("^hello")); // TODO: Pattern doesn't override equals method! doc.put("null", null); diff --git a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy index c20299715e0..23c46fb7b0b 100644 --- a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy +++ b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy @@ -17,6 +17,10 @@ package org.bson.codecs import org.bson.Document +import org.bson.Float32Vector +import org.bson.Int8Vector +import org.bson.PackedBitVector +import org.bson.Vector import org.bson.codecs.configuration.CodecRegistries import org.bson.types.Binary import org.bson.types.Code @@ -32,6 +36,8 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicLong import java.util.regex.Pattern +//Codenarc +@SuppressWarnings("VectorIsObsolete") class ValueCodecProviderSpecification extends Specification { private final provider = new ValueCodecProvider() private final registry = CodecRegistries.fromProviders(provider) @@ -56,6 +62,10 @@ class ValueCodecProviderSpecification extends Specification { provider.get(Short, registry) instanceof ShortCodec provider.get(byte[], registry) instanceof ByteArrayCodec provider.get(Float, registry) instanceof FloatCodec + provider.get(Vector, registry) instanceof VectorCodec + provider.get(Float32Vector, registry) instanceof Float32VectorCodec + provider.get(Int8Vector, registry) instanceof Int8VectorCodec + provider.get(PackedBitVector, registry) instanceof PackedBitVectorCodec provider.get(Binary, registry) instanceof BinaryCodec provider.get(MinKey, registry) instanceof MinKeyCodec 
diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java new file mode 100644 index 00000000000..bf33af90cae --- /dev/null +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs; + +import org.bson.BsonBinary; +import org.bson.BsonBinaryReader; +import org.bson.BsonBinarySubType; +import org.bson.BsonBinaryWriter; +import org.bson.BsonDocument; +import org.bson.BsonInvalidOperationException; +import org.bson.BsonType; +import org.bson.BsonWriter; +import org.bson.ByteBufNIO; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; +import org.bson.Vector; +import org.bson.io.BasicOutputBuffer; +import org.bson.io.ByteBufferBsonInput; +import org.bson.io.OutputBuffer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.stream.Stream; + +import static org.bson.BsonHelper.toBson; +import static org.bson.assertions.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static 
org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +class VectorCodecTest extends CodecTestCase { + + private static Stream provideVectorsAndCodecs() { + return Stream.of( + arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new Float32VectorCodec(), Float32Vector.class), + arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new Int8VectorCodec(), Int8Vector.class), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new PackedBitVectorCodec(), PackedBitVector.class), + arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new VectorCodec(), Vector.class), + arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new VectorCodec(), Vector.class), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new VectorCodec(), Vector.class) + ); + } + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecs") + void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorCodec) throws IOException { + // given + BsonBinary bsonBinary = new BsonBinary(vectorToEncode); + byte[] encodedVector = bsonBinary.getData(); + ByteArrayOutputStream expectedStream = new ByteArrayOutputStream(); + // Total length of a Document (int 32). It is 0, because we do not expect + // codec to write the end of the document (that is when we back-patch the length of the document).
+ expectedStream.write(new byte[]{0, 0, 0, 0}); + // Bson type + expectedStream.write((byte) BsonType.BINARY.getValue()); + // Field name "b4" + expectedStream.write(new byte[]{98, 52, 0}); + // Total length of binary data (little-endian format) + expectedStream.write(new byte[]{(byte) encodedVector.length, 0, 0, 0}); + // Vector binary subtype + expectedStream.write(BsonBinarySubType.VECTOR.getValue()); + // Actual BSON binary data + expectedStream.write(encodedVector); + + OutputBuffer buffer = new BasicOutputBuffer(); + BsonWriter writer = new BsonBinaryWriter(buffer); + writer.writeStartDocument(); + writer.writeName("b4"); + + // when + vectorCodec.encode(writer, vectorToEncode, EncoderContext.builder().build()); + + // then + assertArrayEquals(expectedStream.toByteArray(), buffer.toByteArray()); + } + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecs") + void shouldDecodeVector(final Vector vectorToDecode, final Codec vectorCodec) { + // given + OutputBuffer buffer = new BasicOutputBuffer(); + BsonWriter writer = new BsonBinaryWriter(buffer); + writer.writeStartDocument(); + writer.writeName("vector"); + writer.writeBinaryData(new BsonBinary(vectorToDecode)); + writer.writeEndDocument(); + + BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(buffer.toByteArray())))); + reader.readStartDocument(); + + // when + Vector decodedVector = vectorCodec.decode(reader, DecoderContext.builder().build()); + + // then + assertDoesNotThrow(reader::readEndDocument); + assertNotNull(decodedVector); + assertEquals(vectorToDecode, decodedVector); + } + + + @ParameterizedTest + @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) + void shouldThrowExceptionForInvalidSubType(final BsonBinarySubType subType) { + // given + BsonDocument document = new BsonDocument("name", new BsonBinary(subType.getValue(), new byte[]{})); + BsonBinaryReader reader = new 
BsonBinaryReader(toBson(document)); + reader.readStartDocument(); + + // when & then + Stream.of(new Float32VectorCodec(), new Int8VectorCodec(), new PackedBitVectorCodec()) + .forEach(codec -> { + BsonInvalidOperationException exception = assertThrows(BsonInvalidOperationException.class, () -> + codec.decode(reader, DecoderContext.builder().build())); + assertEquals("Expected vector binary subtype 9 but found: " + subType.getValue(), exception.getMessage()); + }); + } + + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecs") + void shouldReturnCorrectEncoderClass(final Vector vector, + final Codec codec, + final Class expectedEncoderClass) { + // when + Class encoderClass = codec.getEncoderClass(); + + // then + assertEquals(expectedEncoderClass, encoderClass); + } +} diff --git a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java new file mode 100644 index 00000000000..64e84f6afc8 --- /dev/null +++ b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java @@ -0,0 +1,276 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.vector; + +import org.bson.BsonArray; +import org.bson.BsonBinary; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.BsonValue; +import org.bson.Float32Vector; +import org.bson.PackedBitVector; +import org.bson.Vector; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import util.JsonPoweredTestHelper; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import static java.lang.String.format; +import static org.bson.BsonHelper.decodeToDocument; +import static org.bson.BsonHelper.encodeToHex; +import static org.bson.internal.vector.VectorHelper.determineVectorDType; +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeFalse; + +/** + * See the + * JSON-based tests that are included in the test resources. + */ +class VectorGenericBsonTest { + + private static final List TEST_NAMES_TO_IGNORE = Arrays.asList( + //NO API to set padding for floats available. + "FLOAT32 with padding", + //NO API to set padding for INT8 available. + "INT8 with padding", + //It is impossible to provide float inputs for INT8 in the API. + "INT8 with float inputs", + //It is impossible to provide float inputs for INT8. + "Underflow Vector PACKED_BIT", + //It is impossible to provide float inputs for PACKED_BIT in the API. + "Vector with float values PACKED_BIT", + //It is impossible to provide float inputs for INT8. + "Overflow Vector PACKED_BIT", + //It is impossible to overflow byte with values higher than 127 in the API.
+ "Overflow Vector INT8", + //It is impossible to underflow byte with values lower than -128 in the API. + "Underflow Vector INT8"); + + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTestCases") + void shouldPassAllOutcomes(@SuppressWarnings("unused") final String description, + final BsonDocument testDefinition, final BsonDocument testCase) { + assumeFalse(TEST_NAMES_TO_IGNORE.contains(testCase.get("description").asString().getValue())); + + String testKey = testDefinition.getString("test_key").getValue(); + boolean isValidVector = testCase.getBoolean("valid").getValue(); + if (isValidVector) { + runValidTestCase(testKey, testCase); + } else { + runInvalidTestCase(testCase); + } + } + + private static void runInvalidTestCase(final BsonDocument testCase) { + BsonArray arrayVector = testCase.getArray("vector"); + byte expectedPadding = (byte) testCase.getInt32("padding").getValue(); + byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); + Vector.DataType expectedDType = determineVectorDType(dtypeByte); + + switch (expectedDType) { + case INT8: + byte[] expectedVectorData = toByteArray(arrayVector); + assertValidationException(assertThrows(RuntimeException.class, + () -> Vector.int8Vector(expectedVectorData))); + break; + case PACKED_BIT: + byte[] expectedVectorPackedBitData = toByteArray(arrayVector); + assertValidationException(assertThrows(RuntimeException.class, + () -> Vector.packedBitVector(expectedVectorPackedBitData, expectedPadding))); + break; + case FLOAT32: + float[] expectedFloatVector = toFloatArray(arrayVector); + assertValidationException(assertThrows(RuntimeException.class, () -> Vector.floatVector(expectedFloatVector))); + break; + default: + throw new IllegalArgumentException("Unsupported vector data type: " + expectedDType); + } + } + + private static void runValidTestCase(final String testKey, final BsonDocument testCase) { + String description = testCase.getString("description").getValue(); + byte dtypeByte = 
Byte.decode(testCase.getString("dtype_hex").getValue()); + + byte expectedPadding = (byte) testCase.getInt32("padding").getValue(); + Vector.DataType expectedDType = determineVectorDType(dtypeByte); + String expectedCanonicalBsonHex = testCase.getString("canonical_bson").getValue().toUpperCase(); + + BsonArray arrayVector = testCase.getArray("vector"); + BsonDocument actualDecodedDocument = decodeToDocument(expectedCanonicalBsonHex, description); + Vector actualVector = actualDecodedDocument.getBinary("vector").asVector(); + + switch (expectedDType) { + case INT8: + byte[] expectedVectorData = toByteArray(arrayVector); + byte[] actualVectorData = actualVector.asInt8Vector().getData(); + assertVectorDecoding( + expectedVectorData, + expectedDType, + actualVectorData, + actualVector); + + assertThatVectorCreationResultsInCorrectBinary(Vector.int8Vector(expectedVectorData), + testKey, + actualDecodedDocument, + expectedCanonicalBsonHex, + description); + break; + case PACKED_BIT: + PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); + byte[] expectedVectorPackedBitData = toByteArray(arrayVector); + assertVectorDecoding( + expectedVectorPackedBitData, + expectedDType, expectedPadding, + actualPackedBitVector); + + assertThatVectorCreationResultsInCorrectBinary( + Vector.packedBitVector(expectedVectorPackedBitData, expectedPadding), + testKey, + actualDecodedDocument, + expectedCanonicalBsonHex, + description); + break; + case FLOAT32: + Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); + float[] expectedFloatVector = toFloatArray(arrayVector); + assertVectorDecoding( + expectedFloatVector, + expectedDType, + actualFloat32Vector); + assertThatVectorCreationResultsInCorrectBinary( + Vector.floatVector(expectedFloatVector), + testKey, + actualDecodedDocument, + expectedCanonicalBsonHex, + description); + break; + default: + throw new IllegalArgumentException("Unsupported vector data type: " + expectedDType); + } + } + + private 
static void assertValidationException(final RuntimeException runtimeException) { + assertTrue(runtimeException instanceof IllegalArgumentException || runtimeException instanceof IllegalStateException); + } + + private static void assertThatVectorCreationResultsInCorrectBinary(final Vector expectedVectorData, + final String testKey, + final BsonDocument actualDecodedDocument, + final String expectedCanonicalBsonHex, + final String description) { + BsonDocument documentToEncode = new BsonDocument(testKey, new BsonBinary(expectedVectorData)); + assertEquals(documentToEncode, actualDecodedDocument); + assertEquals(expectedCanonicalBsonHex, encodeToHex(documentToEncode), + format("Failed to create expected BSON for document with description '%s'", description)); + } + + private static void assertVectorDecoding(final byte[] expectedVectorData, + final Vector.DataType expectedDType, + final byte[] actualVectorData, + final Vector actualVector) { + Assertions.assertArrayEquals(expectedVectorData, actualVectorData, + () -> "Actual: " + Arrays.toString(actualVectorData) + " != Expected:" + Arrays.toString(expectedVectorData)); + assertEquals(expectedDType, actualVector.getDataType()); + } + + private static void assertVectorDecoding(final byte[] expectedVectorData, + final Vector.DataType expectedDType, + final byte expectedPadding, + final PackedBitVector actualVector) { + byte[] actualVectorData = actualVector.getData(); + assertVectorDecoding( + expectedVectorData, + expectedDType, + actualVectorData, + actualVector); + assertEquals(expectedPadding, actualVector.getPadding()); + } + + private static void assertVectorDecoding(final float[] expectedVectorData, + final Vector.DataType expectedDType, + final Float32Vector actualVector) { + float[] actualVectorArray = actualVector.getData(); + Assertions.assertArrayEquals(expectedVectorData, actualVectorArray, + () -> "Actual: " + Arrays.toString(actualVectorArray) + " != Expected:" + Arrays.toString(expectedVectorData)); +
assertEquals(expectedDType, actualVector.getDataType()); + } + + private static byte[] toByteArray(final BsonArray arrayVector) { + byte[] bytes = new byte[arrayVector.size()]; + for (int i = 0; i < arrayVector.size(); i++) { + bytes[i] = (byte) arrayVector.get(i).asInt32().getValue(); + } + return bytes; + } + + private static float[] toFloatArray(final BsonArray arrayVector) { + float[] floats = new float[arrayVector.size()]; + for (int i = 0; i < arrayVector.size(); i++) { + BsonValue bsonValue = arrayVector.get(i); + if (bsonValue.isString()) { + floats[i] = parseFloat(bsonValue.asString()); + } else { + floats[i] = (float) arrayVector.get(i).asDouble().getValue(); + } + } + return floats; + } + + private static float parseFloat(final BsonString bsonValue) { + String floatValue = bsonValue.getValue(); + switch (floatValue) { + case "-inf": + return Float.NEGATIVE_INFINITY; + case "inf": + return Float.POSITIVE_INFINITY; + default: + return Float.parseFloat(floatValue); + } + } + + private static Stream provideTestCases() throws URISyntaxException, IOException { + List data = new ArrayList<>(); + for (File file : JsonPoweredTestHelper.getTestFiles("/bson-binary-vector")) { + BsonDocument testDocument = JsonPoweredTestHelper.getTestDocument(file); + for (BsonValue curValue : testDocument.getArray("tests", new BsonArray())) { + BsonDocument testCaseDocument = curValue.asDocument(); + data.add(Arguments.of(createTestCaseDescription(testDocument, testCaseDocument), testDocument, testCaseDocument)); + } + } + return data.stream(); + } + + private static String createTestCaseDescription(final BsonDocument testDocument, + final BsonDocument testCaseDocument) { + boolean isValidTestCase = testCaseDocument.getBoolean("valid").getValue(); + String fileDescription = testDocument.getString("description").getValue(); + String testDescription = testCaseDocument.getString("description").getValue(); + return "[Valid input: " + isValidTestCase + "] " + fileDescription + ": " + 
testDescription; + } +} diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java new file mode 100644 index 00000000000..f5b8e63f8c3 --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mongodb.reactivestreams.client.vector; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.client.vector.AbstractVectorFunctionalTest; +import com.mongodb.reactivestreams.client.MongoClients; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; + +public class VectorFunctionalTest extends AbstractVectorFunctionalTest { + @Override + protected MongoClient getMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(MongoClients.create(settings)); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java new file mode 100644 index 00000000000..c3edf6983da --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java @@ -0,0 +1,346 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mongodb.client.vector; + +import com.mongodb.MongoClientSettings; +import com.mongodb.ReadConcern; +import com.mongodb.ReadPreference; +import com.mongodb.WriteConcern; +import com.mongodb.client.Fixture; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.OperationTest; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.Document; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; +import org.bson.Vector; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.codecs.pojo.PojoCodecProvider; +import org.bson.types.Binary; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry; +import static org.bson.Vector.DataType.FLOAT32; +import static org.bson.Vector.DataType.INT8; +import static org.bson.Vector.DataType.PACKED_BIT; +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public abstract class AbstractVectorFunctionalTest extends OperationTest { + + private static final byte VECTOR_SUBTYPE = BsonBinarySubType.VECTOR.getValue(); + private static final String FIELD_VECTOR = "vector"; + private static final CodecRegistry CODEC_REGISTRY = fromRegistries(getDefaultCodecRegistry(), + fromProviders(PojoCodecProvider + .builder() + .automatic(true).build())); + private 
MongoCollection documentCollection; + + private MongoClient mongoClient; + + @BeforeEach + public void setUp() { + super.beforeEach(); + mongoClient = getMongoClient(getMongoClientSettingsBuilder() + .codecRegistry(CODEC_REGISTRY) + .build()); + documentCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()); + } + + @AfterEach + @SuppressWarnings("try") + public void afterEach() { + try (MongoClient ignore = mongoClient) { + super.afterEach(); + } + } + + private static MongoClientSettings.Builder getMongoClientSettingsBuilder() { + return Fixture.getMongoClientSettingsBuilder() + .readConcern(ReadConcern.MAJORITY) + .writeConcern(WriteConcern.MAJORITY) + .readPreference(ReadPreference.primary()); + } + + protected abstract MongoClient getMongoClient(MongoClientSettings settings); + + @ParameterizedTest + @ValueSource(bytes = {-1, 1, 2, 3, 4, 5, 6, 7, 8}) + void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(final byte invalidPadding) { + //given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{PACKED_BIT.getValue(), invalidPadding}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be 0 if vector is empty, but found: " + invalidPadding, exception.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidFloat32Padding(final byte invalidPadding) { + // given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{FLOAT32.getValue(), invalidPadding, 10, 20, 30, 40}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> 
{ + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be 0 for FLOAT32 data type, but found: " + invalidPadding, exception.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidInt8Padding(final byte invalidPadding) { + // given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{INT8.getValue(), invalidPadding, 10, 20, 30, 40}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be 0 for INT8 data type, but found: " + invalidPadding, exception.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 8}) + void shouldThrowExceptionForInvalidPackedBitPadding(final byte invalidPadding) { + // given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{PACKED_BIT.getValue(), invalidPadding, 10, 20, 30, 40}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be between 0 and 7 bits, but found: " + invalidPadding, exception.getMessage()); + } + + private static Stream provideValidVectors() { + return Stream.of( + Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), + Vector.int8Vector(new byte[]{10, 20, 30, 40}), + Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3) + ); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVector(final Vector expectedVector) { + // Given + Document documentToInsert = new Document(FIELD_VECTOR, 
expectedVector) + .append("otherField", 1); // to test that the next field is not affected + documentCollection.insertOne(documentToInsert); + + // when & then + Vector actualVector = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + + assertEquals(expectedVector, actualVector); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVectorWithBsonBinary(final Vector expectedVector) { + // Given + Document documentToInsert = new Document(FIELD_VECTOR, new BsonBinary(expectedVector)); + documentCollection.insertOne(documentToInsert); + + // when & then + Vector actualVector = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + + assertEquals(expectedVector, actualVector); + } + + @Test + void shouldStoreAndRetrieveValidVectorWithFloatVectorPojo() { + // given + MongoCollection floatVectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(FloatVectorPojo.class); + Float32Vector vector = Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}); + + // when + floatVectorPojoMongoCollection.insertOne(new FloatVectorPojo(vector)); + FloatVectorPojo floatVectorPojo = floatVectorPojoMongoCollection.find().first(); + + // then + Assertions.assertNotNull(floatVectorPojo); + assertEquals(vector, floatVectorPojo.getVector()); + } + + @Test + void shouldStoreAndRetrieveValidVectorWithInt8VectorPojo() { + // given + MongoCollection int8VectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(Int8VectorPojo.class); + Int8Vector vector = Vector.int8Vector(new byte[]{10, 20, 30, 40}); + + // when + int8VectorPojoMongoCollection.insertOne(new Int8VectorPojo(vector)); + Int8VectorPojo int8VectorPojo = int8VectorPojoMongoCollection.find().first(); + + // then + Assertions.assertNotNull(int8VectorPojo); + assertEquals(vector, int8VectorPojo.getVector()); + } + +
@Test + void shouldStoreAndRetrieveValidVectorWithPackedBitVectorPojo() { + // given + MongoCollection floatVectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(PackedBitVectorPojo.class); + + PackedBitVector vector = Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3); + + // when + floatVectorPojoMongoCollection.insertOne(new PackedBitVectorPojo(vector)); + PackedBitVectorPojo packedBitVectorPojo = floatVectorPojoMongoCollection.find().first(); + + // then + Assertions.assertNotNull(packedBitVectorPojo); + assertEquals(vector, packedBitVectorPojo.getVector()); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVectorWithGenericVectorPojo(final Vector actualVector) { + // given + MongoCollection floatVectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(VectorPojo.class); + + // when + floatVectorPojoMongoCollection.insertOne(new VectorPojo(actualVector)); + VectorPojo vectorPojo = floatVectorPojoMongoCollection.find().first(); + + //then + Assertions.assertNotNull(vectorPojo); + assertEquals(actualVector, vectorPojo.getVector()); + } + + private Document findExactlyOne(final MongoCollection collection) { + List documents = new ArrayList<>(); + collection.find().into(documents); + assertEquals(1, documents.size(), "Expected exactly one document, but found: " + documents.size()); + return documents.get(0); + } + + public static class VectorPojo { + private Vector vector; + + public VectorPojo() { + } + + public VectorPojo(final Vector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final Vector vector) { + this.vector = vector; + } + } + + public static class Int8VectorPojo { + private Int8Vector vector; + + public Int8VectorPojo() { + } + + public Int8VectorPojo(final 
Int8Vector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final Int8Vector vector) { + this.vector = vector; + } + } + + public static class PackedBitVectorPojo { + private PackedBitVector vector; + + public PackedBitVectorPojo() { + } + + public PackedBitVectorPojo(final PackedBitVector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final PackedBitVector vector) { + this.vector = vector; + } + } + + public static class FloatVectorPojo { + private Float32Vector vector; + + public FloatVectorPojo() { + } + + public FloatVectorPojo(final Float32Vector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final Float32Vector vector) { + this.vector = vector; + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java new file mode 100644 index 00000000000..63d756a8f35 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java @@ -0,0 +1,28 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package com.mongodb.client.vector; + + import com.mongodb.MongoClientSettings; + import com.mongodb.client.MongoClient; + import com.mongodb.client.MongoClients; + + public class VectorFunctionalTest extends AbstractVectorFunctionalTest { /* concrete subclass of AbstractVectorFunctionalTest; its only member is the client factory override below */ + @Override + protected MongoClient getMongoClient(final MongoClientSettings settings) { /* returns a com.mongodb.client.MongoClient built from the given settings */ + return MongoClients.create(settings); + } +}