From 8cdec02d3cad58d01cfdb4912a7026dbe1c75c7c Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 19 Sep 2024 15:04:20 -0700 Subject: [PATCH 01/20] Add BSON Binary Subtype 9 support for vector storage and retrieval. - Implement encoding and decoding methods for new BSON binary subtype 9. - Add support for INT8, FLOAT32, and PACKED_BIT dtypes with padding. - Provide API methods for encoding vectors to BSON binary and decoding BSON binary to vectors. JAVA-5544 --- bson/src/main/org/bson/BsonBinary.java | 32 +++ bson/src/main/org/bson/BsonBinarySubType.java | 22 +- bson/src/main/org/bson/Vector.java | 268 ++++++++++++++++++ .../bson/internal/vector/VectorHelper.java | 150 ++++++++++ .../resources/bson-binary-vector/README.md | 40 +++ .../resources/bson-binary-vector/float32.json | 50 ++++ .../resources/bson-binary-vector/int8.json | 56 ++++ .../bson-binary-vector/packed_bit.json | 97 +++++++ bson/src/test/resources/bson/binary.json | 37 ++- bson/src/test/unit/org/bson/BsonHelper.java | 21 ++ .../test/unit/org/bson/GenericBsonTest.java | 27 +- bson/src/test/unit/org/bson/VectorTest.java | 181 ++++++++++++ .../internal/vector/VectorHelperTest.java | 184 ++++++++++++ .../bson/vector/VectorGenericBsonTest.java | 265 +++++++++++++++++ 14 files changed, 1399 insertions(+), 31 deletions(-) create mode 100644 bson/src/main/org/bson/Vector.java create mode 100644 bson/src/main/org/bson/internal/vector/VectorHelper.java create mode 100644 bson/src/test/resources/bson-binary-vector/README.md create mode 100644 bson/src/test/resources/bson-binary-vector/float32.json create mode 100644 bson/src/test/resources/bson-binary-vector/int8.json create mode 100644 bson/src/test/resources/bson-binary-vector/packed_bit.json create mode 100644 bson/src/test/unit/org/bson/VectorTest.java create mode 100644 bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java create mode 100644 bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java diff --git 
a/bson/src/main/org/bson/BsonBinary.java b/bson/src/main/org/bson/BsonBinary.java index d5d07273cea..9aae40377d3 100644 --- a/bson/src/main/org/bson/BsonBinary.java +++ b/bson/src/main/org/bson/BsonBinary.java @@ -18,10 +18,13 @@ import org.bson.assertions.Assertions; import org.bson.internal.UuidHelper; +import org.bson.internal.vector.VectorHelper; import java.util.Arrays; import java.util.UUID; +import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; + /** * A representation of the BSON Binary type. Note that for performance reasons instances of this class are not immutable, * so care should be taken to only modify the underlying byte array if you know what you're doing, or else make a defensive copy. @@ -89,6 +92,20 @@ public BsonBinary(final UUID uuid) { this(uuid, UuidRepresentation.STANDARD); } + /** + * Construct a Type 9 BsonBinary from the given Vector. + * + * @param vector the {@link Vector} + * @since BINARY_VECTOR + */ + public BsonBinary(final Vector vector) { + if (vector == null) { + throw new IllegalArgumentException("Vector must not be null"); + } + this.data = encodeVectorToBinary(vector); + type = BsonBinarySubType.VECTOR.getValue(); + } + /** * Construct a new instance from the given UUID and UuidRepresentation * @@ -127,6 +144,21 @@ public UUID asUuid() { return UuidHelper.decodeBinaryToUuid(this.data.clone(), this.type, UuidRepresentation.STANDARD); } + /** + * Returns the binary as a {@link Vector}. The binary type must be 9. + * + * @return the vector + * @throws BsonInvalidOperationException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. + * @since BINARY_VECTOR + */ + public Vector asVector() { + if (!BsonBinarySubType.isVector(type)) { + throw new BsonInvalidOperationException("type must be a Vector subtype."); + } + + return VectorHelper.decodeBinaryToVector(this.data); + } + /** * Returns the binary as a UUID.
* diff --git a/bson/src/main/org/bson/BsonBinarySubType.java b/bson/src/main/org/bson/BsonBinarySubType.java index fb1b8d0dfbe..a01fe672afb 100644 --- a/bson/src/main/org/bson/BsonBinarySubType.java +++ b/bson/src/main/org/bson/BsonBinarySubType.java @@ -17,7 +17,7 @@ package org.bson; /** - * The Binary subtype + * The Binary subtype. * * @since 3.0 */ @@ -60,12 +60,20 @@ public enum BsonBinarySubType { ENCRYPTED((byte) 0x06), /** - * Columnar data + * Columnar data. * * @since 4.4 */ COLUMN((byte) 0x07), + /** + * Vector data. + * + * @since BINARY_VECTOR + * @see Vector + */ + VECTOR((byte) 0x09), + /** * User defined binary data. */ @@ -74,16 +82,20 @@ public enum BsonBinarySubType { private final byte value; /** - * Returns true if the given value is a UUID subtype + * Returns true if the given value is a UUID subtype. * - * @param value the subtype value as a byte - * @return true if value is a UUID subtype + * @param value the subtype value as a byte. + * @return true if value is a UUID subtype. * @since 3.4 */ public static boolean isUuid(final byte value) { return value == UUID_LEGACY.getValue() || value == UUID_STANDARD.getValue(); } + /** + * Returns true if the given value is equal to the {@link #VECTOR} subtype value. + * + * @param value the subtype value as a byte. + * @return true if value is the {@link #VECTOR} subtype. + * @since BINARY_VECTOR + */ + public static boolean isVector(final byte value) { + return value == VECTOR.getValue(); + } + BsonBinarySubType(final byte value) { this.value = value; } diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java new file mode 100644 index 00000000000..38d1125c55a --- /dev/null +++ b/bson/src/main/org/bson/Vector.java @@ -0,0 +1,268 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + + +import java.util.Arrays; +import java.util.Objects; + +import static org.bson.assertions.Assertions.assertNotNull; +import static org.bson.assertions.Assertions.isTrue; +import static org.bson.assertions.Assertions.isTrueArgument; +import static org.bson.assertions.Assertions.notNull; + +/** + * Represents a vector that is stored and retrieved using the BSON Binary Subtype 9 format. + * This class supports multiple vector {@link Dtype}'s and provides static methods to create + * vectors and methods to retrieve their underlying data. + *

+ * Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently + * in BSON using a binary format. + * + * @mongodb.server.release 6.0 + * @see BsonBinary + * @since BINARY_VECTOR + */ +public final class Vector { + private final byte padding; + private byte[] vectorData; + private float[] floatVectorData; + private final Dtype vectorType; + + Vector(final byte padding, final byte[] vectorData, final Dtype vectorType) { + this.padding = padding; + this.vectorData = assertNotNull(vectorData); + this.vectorType = assertNotNull(vectorType); + } + + Vector(final byte[] vectorData, final Dtype vectorType) { + this((byte) 0, vectorData, vectorType); + } + + Vector(final float[] vectorData) { + this.padding = 0; + this.floatVectorData = assertNotNull(vectorData); + this.vectorType = Dtype.FLOAT32; + } + + /** + * Creates a vector with the {@link Dtype#PACKED_BIT} data type. + *

+ * A {@link Dtype#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte + * can hold up to 8 bits (vector elements). The padding parameter is used to specify how many bits in the final byte should be ignored.

+ * + *

For example, a vector with two bytes and a padding of 4 would have the following structure:

+ *
+     * Byte 1: 238 (binary: 11101110)
+     * Byte 2: 224 (binary: 11100000)
+     * Padding: 4 (ignore the last 4 bits in Byte 2)
+     * Resulting vector: 12 bits: 111011101110
+     * 
+ * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected in the created {@link Vector} instance. + * + * @param vectorData The byte array representing the packed bit vector data. Each byte can store 8 bits. + * @param padding The number of bits (0 to 7) to ignore in the final byte of the vector data. + * @return A Vector instance with the {@link Dtype#PACKED_BIT} data type. + * @throws IllegalArgumentException If the padding value is not between 0 and 7, or if it is non-zero while the vector data is empty. + */ + public static Vector packedBitVector(final byte[] vectorData, final byte padding) { + isTrueArgument("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); + notNull("Vector data", vectorData); + isTrueArgument("Padding must be 0 if vector is empty", padding == 0 || vectorData.length > 0); + return new Vector(padding, vectorData, Dtype.PACKED_BIT); + } + + /** + * Creates a vector with the {@link Dtype#INT8} data type. + * + *

A {@link Dtype#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector, + * with values in the range [-128, 127].

+ *

+ * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected in the created {@link Vector} instance. + * + * @param vectorData The byte array representing the {@link Dtype#INT8} vector data. + * @return A Vector instance with the {@link Dtype#INT8} data type. + */ + public static Vector int8Vector(final byte[] vectorData) { + notNull("vectorData", vectorData); + return new Vector(vectorData, Dtype.INT8); + } + + /** + * Creates a vector with the {@link Dtype#FLOAT32} data type. + * + *

A {@link Dtype#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

+ *

+ * NOTE: The float array `vectorData` is not copied; changes to the provided array will be reflected in the created {@link Vector} instance. + * + * @param vectorData The float array representing the {@link Dtype#FLOAT32} vector data. + * @return A Vector instance with the {@link Dtype#FLOAT32} data type. + */ + public static Vector floatVector(final float[] vectorData) { + notNull("vectorData", vectorData); + return new Vector(vectorData); + } + + /** + * Returns the {@link Dtype#PACKED_BIT} vector data as a byte array. + * + *

+ * This method is used to retrieve the underlying byte array representing the {@link Dtype#PACKED_BIT} vector, where + * each bit represents an element of the vector (either 0 or 1). + * + * @return the packed bit vector data. + * @throws IllegalStateException if this vector is not of type {@link Dtype#PACKED_BIT}. Use {@link #getDataType()} to check the vector type before + * calling this method. + * @see #getPadding() getPadding() specifies how many least-significant bits in the final byte should be ignored. + */ + public byte[] asPackedBitVectorData() { + if (this.vectorType != Dtype.PACKED_BIT) { + throw new IllegalStateException("Vector is not binary quantized"); + } + return assertNotNull(vectorData); + } + + /** + * Returns the {@link Dtype#INT8} vector data as a byte array. + * + *

This method is used to retrieve the underlying byte array representing the {@link Dtype#INT8} vector, where each byte represents + * an element of a vector.

+ * + * @return the {@link Dtype#INT8} vector data. + * @throws IllegalStateException if this vector is not of type {@link Dtype#INT8}. Use {@link #getDataType()} to check the vector + * type before calling this method. + */ + public byte[] asInt8VectorData() { + if (this.vectorType != Dtype.INT8) { + throw new IllegalStateException("Vector is not INT8"); + } + return assertNotNull(vectorData); + } + + /** + * Returns the {@link Dtype#FLOAT32} vector data as a float array. + * + *

This method is used to retrieve the underlying float array representing the {@link Dtype#FLOAT32} vector, where each float + * represents an element of a vector.

+ * + * @return the float array representing the FLOAT32 vector. + * @throws IllegalStateException if this vector is not of type {@link Dtype#FLOAT32}. Use {@link #getDataType()} to check the vector + * type before calling this method. + */ + public float[] asFloatVectorData() { + if (this.vectorType != Dtype.FLOAT32) { + throw new IllegalStateException("Vector is not FLOAT32"); + } + + return assertNotNull(floatVectorData); + } + + /** + * Returns the padding value for this vector. + * + *

Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving the vector data, as not + * all {@link Dtype}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.

+ * + * @return the padding value (between 0 and 7). + */ + public byte getPadding() { + return this.padding; + } + + + /** + * Returns {@link Dtype} of the vector. + * + * @return the data type of the vector. + */ + public Dtype getDataType() { + return this.vectorType; + } + + + @Override + public String toString() { + return "Vector{" + + "padding=" + padding + ", " + + "vectorData=" + (vectorData == null ? Arrays.toString(floatVectorData) : Arrays.toString(vectorData)) + + ", vectorType=" + vectorType + + '}'; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Vector)) { + return false; + } + + Vector vector = (Vector) o; + return padding == vector.padding && Arrays.equals(vectorData, vector.vectorData) + && Arrays.equals(floatVectorData, vector.floatVectorData) && vectorType == vector.vectorType; + } + + @Override + public int hashCode() { + int result = padding; + result = 31 * result + Arrays.hashCode(vectorData); + result = 31 * result + Arrays.hashCode(floatVectorData); + result = 31 * result + Objects.hashCode(vectorType); + return result; + } + + /** + * Represents the data type (dtype) of a vector. + *

+ * Each dtype determines how the data in the vector is stored, including how many bits are used to represent each element + * in the vector. + */ + public enum Dtype { + /** + * An INT8 vector is a vector of 8-bit signed integers. The vector is stored as an array of bytes, where each byte + * represents a signed integer in the range [-128, 127]. + */ + INT8((byte) 0x03), + /** + * A FLOAT32 vector is a vector of 32-bit floating-point numbers, where each element in the vector is a float. + */ + FLOAT32((byte) 0x27), + /** + * A PACKED_BIT vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). + * Each byte can hold up to 8 bits (vector elements). + */ + PACKED_BIT((byte) 0x10); + + private final byte value; + + Dtype(final byte value) { + this.value = value; + } + + /** + * Returns the byte value associated with this {@link Dtype}. + * + *

This value is used in the BSON binary format to indicate the data type of the vector.

+ * + * @return the byte value representing the {@link Dtype}. + */ + public byte getValue() { + return value; + } + } +} + diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java new file mode 100644 index 00000000000..12cabef0d95 --- /dev/null +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -0,0 +1,150 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.internal.vector; + +import org.bson.BsonBinary; +import org.bson.Vector; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + +import static org.bson.assertions.Assertions.isTrue; + +/** + * Helper class for encoding and decoding vectors to and from binary. + * + *

+ * This class is not part of the public API and may be removed or changed at any time. + * + * @see Vector + * @see BsonBinary#asVector() + * @see BsonBinary#BsonBinary(Vector) + */ +public final class VectorHelper { + + private static final ByteOrder STORED_BYTE_ORDER = ByteOrder.LITTLE_ENDIAN; + + private VectorHelper() { + //NOP + } + + private static final int METADATA_SIZE = 2; + private static final int FLOAT_SIZE = 4; + + /** + * Encodes the given vector as a byte array: the dtype byte, then the padding byte, then the vector data. + * + * @param vector the vector to encode. + * @return the encoded bytes. + */ + public static byte[] encodeVectorToBinary(final Vector vector) { + Vector.Dtype dtype = vector.getDataType(); + byte padding = vector.getPadding(); + switch (dtype) { + case INT8: + return writeVector(dtype.getValue(), padding, vector.asInt8VectorData()); + case PACKED_BIT: + return writeVector(dtype.getValue(), padding, vector.asPackedBitVectorData()); + case FLOAT32: + return writeVector(dtype.getValue(), padding, vector.asFloatVectorData()); + + default: + throw new AssertionError("Unknown vector dtype: " + dtype); + } + } + + /** + * Decodes a vector from its encoded form, where byte 0 is the dtype, byte 1 is the padding and the remaining bytes are the vector data. + * Padding is only carried through for PACKED_BIT vectors, so encoded INT8 and FLOAT32 data must have a padding of 0; otherwise the + * padding would be silently dropped and the vector would not round-trip to the same bytes. + * + * @param encodedVector the encoded vector, at least 2 bytes long. + * @return the decoded vector. + */ + public static Vector decodeBinaryToVector(final byte[] encodedVector) { + isTrue("Vector encoded array length must be at least 2.", encodedVector.length >= METADATA_SIZE); + + Vector.Dtype dtype = determineVectorDType(encodedVector[0]); + byte padding = encodedVector[1]; + switch (dtype) { + case INT8: + isTrue("Padding must be 0 for INT8 dtype.", padding == 0); + byte[] int8Vector = getVectorBytesWithoutMetadata(encodedVector); + return Vector.int8Vector(int8Vector); + case PACKED_BIT: + byte[] packedBitVector = getVectorBytesWithoutMetadata(encodedVector); + return Vector.packedBitVector(packedBitVector, padding); + case FLOAT32: + isTrue("Padding must be 0 for FLOAT32 dtype.", padding == 0); + isTrue("Byte array length must be a multiple of 4 for FLOAT32 dtype.", + (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); + return Vector.floatVector(readLittleEndianFloats(encodedVector)); + + default: + throw new AssertionError("Unknown vector dtype: " + dtype); + } + } + + private static byte[] getVectorBytesWithoutMetadata(final byte[] encodedVector) { + int vectorDataLength; + byte[] vectorData; + vectorDataLength = 
encodedVector.length - METADATA_SIZE; + vectorData = new byte[vectorDataLength]; + System.arraycopy(encodedVector, METADATA_SIZE, vectorData, 0, vectorDataLength); + return vectorData; + } + + + public static byte[] writeVector(final byte dtype, final byte padding, final byte[] vectorData) { + final byte[] bytes = new byte[vectorData.length + METADATA_SIZE]; + bytes[0] = dtype; + bytes[1] = padding; + System.arraycopy(vectorData, 0, bytes, METADATA_SIZE, vectorData.length); + return bytes; + } + + public static byte[] writeVector(final byte dtype, final byte padding, final float[] vectorData) { + final byte[] bytes = new byte[vectorData.length * FLOAT_SIZE + METADATA_SIZE]; + + bytes[0] = dtype; + bytes[1] = padding; + + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.order(STORED_BYTE_ORDER); + buffer.position(METADATA_SIZE); + + FloatBuffer floatBuffer = buffer.asFloatBuffer(); + + // The JVM may optimize this operation internally, potentially using intrinsics + // or platform-specific optimizations (such as SIMD). If the byte order matches the underlying system's + // native order, the operation may involve a direct memory copy. + floatBuffer.put(vectorData); + + return bytes; + } + + private static float[] readLittleEndianFloats(final byte[] encodedVector) { + int vectorSize = encodedVector.length - METADATA_SIZE; + + int numFloats = vectorSize / FLOAT_SIZE; + float[] floatArray = new float[numFloats]; + + ByteBuffer buffer = ByteBuffer.wrap(encodedVector, METADATA_SIZE, vectorSize); + buffer.order(STORED_BYTE_ORDER); + + // The JVM may optimize this operation internally, potentially using intrinsics + // or platform-specific optimizations (such as SIMD). If the byte order matches the underlying system's + // native order, the operation may involve a direct memory copy. 
+ buffer.asFloatBuffer().get(floatArray); + return floatArray; + } + + public static Vector.Dtype determineVectorDType(final byte dtype) { + Vector.Dtype[] values = Vector.Dtype.values(); + for (Vector.Dtype value : values) { + if (value.getValue() == dtype) { + return value; + } + } + throw new IllegalStateException("Unknown vector dtype: " + dtype); + } +} diff --git a/bson/src/test/resources/bson-binary-vector/README.md b/bson/src/test/resources/bson-binary-vector/README.md new file mode 100644 index 00000000000..73a5f0a9f33 --- /dev/null +++ b/bson/src/test/resources/bson-binary-vector/README.md @@ -0,0 +1,40 @@ +# Testing Binary subtype 9: Vector + +The JSON files in this directory tree are platform-independent tests that drivers can use to prove their conformance to +the specification. + +These tests focus on the roundtrip of the list numbers as input/output, along with their data type and byte padding. + +Additional tests exist in `bson_corpus/tests/binary.json` but do not sufficiently test the end-to-end process of Vector +to BSON. For this reason, drivers must create a bespoke test runner for the vector subtype. + +Each test case here pertains to a single vector. The inputs required to create the Binary BSON object are defined, and +when valid, the Canonical BSON and Extended JSON representations are included for comparison. + +## Version + +Files in the "specifications" repository have no version scheme. They are not tied to a MongoDB server version. + +## Format + +#### Top level keys + +Each JSON file contains three top-level keys. + +- `description`: human-readable description of what is in the file +- `test_key`: Field name used when decoding/encoding a BSON document containing the single BSON Binary for the test + case. Applies to *every* case. +- `tests`: array of test case objects, each of which have the following keys. Valid cases will also contain additional + binary and json encoding values. 
+ +#### Keys of tests objects + +- `description`: string describing the test. +- `valid`: boolean indicating if the vector, dtype, and padding should be considered a valid input. +- `vector`: list of numbers +- `dtype_hex`: string defining the data type in hex (e.g. "0x10", "0x27") +- `dtype_alias`: (optional) string defining the data dtype, perhaps as Enum. +- `padding`: (optional) integer for byte padding. Defaults to 0. +- `canonical_bson`: (required if valid is true) an (uppercase) big-endian hex representation of a BSON byte string. +- `canonical_extjson`: (required if valid is true) string containing a Canonical Extended JSON document. Because this is + itself embedded as a *string* inside a JSON document, characters like quote and backslash are escaped. \ No newline at end of file diff --git a/bson/src/test/resources/bson-binary-vector/float32.json b/bson/src/test/resources/bson-binary-vector/float32.json new file mode 100644 index 00000000000..e1d142c184b --- /dev/null +++ b/bson/src/test/resources/bson-binary-vector/float32.json @@ -0,0 +1,50 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector FLOAT32", + "valid": true, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" + }, + { + "description": "Vector with decimals and negative value FLOAT32", + "valid": true, + "vector": [127.7, -7.7], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927006666FF426666F6C000" + }, + { + "description": "Empty Vector FLOAT32", + "valid": true, + "vector": [], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009270000" + }, + { + "description": "Infinity Vector FLOAT32", + "valid": true, + "vector": ["-inf", 0.0, 
"inf"], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00" + }, + { + "description": "FLOAT32 with padding", + "valid": false, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 3 + } + ] +} \ No newline at end of file diff --git a/bson/src/test/resources/bson-binary-vector/int8.json b/bson/src/test/resources/bson-binary-vector/int8.json new file mode 100644 index 00000000000..c10c1b7d4e2 --- /dev/null +++ b/bson/src/test/resources/bson-binary-vector/int8.json @@ -0,0 +1,56 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype INT8", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector INT8", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000903007F0700" + }, + { + "description": "Empty Vector INT8", + "valid": true, + "vector": [], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009030000" + }, + { + "description": "Overflow Vector INT8", + "valid": false, + "vector": [128], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "Underflow Vector INT8", + "valid": false, + "vector": [-129], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "INT8 with padding", + "valid": false, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 3 + }, + { + "description": "INT8 with float inputs", + "valid": false, + "vector": [127.77, 7.77], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + } + ] +} \ No newline at end of file diff --git a/bson/src/test/resources/bson-binary-vector/packed_bit.json b/bson/src/test/resources/bson-binary-vector/packed_bit.json new file mode 100644 index 00000000000..69fb3948335 --- 
/dev/null +++ b/bson/src/test/resources/bson-binary-vector/packed_bit.json @@ -0,0 +1,97 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", + "test_key": "vector", + "tests": [ + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Simple Vector PACKED_BIT", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000910007F0700" + }, + { + "description": "Empty Vector PACKED_BIT", + "valid": true, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009100000" + }, + { + "description": "PACKED_BIT with padding", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000910037F0700" + }, + { + "description": "Overflow Vector PACKED_BIT", + "valid": false, + "vector": [256], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Underflow Vector PACKED_BIT", + "valid": false, + "vector": [-1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Exceeding maximum padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 8 + }, + { + "description": "Negative padding PACKED_BIT", + "valid": false, + "vector": 
[1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": -1 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + } + ] +} \ No newline at end of file diff --git a/bson/src/test/resources/bson/binary.json b/bson/src/test/resources/bson/binary.json index d3c57ec1081..29d88471afe 100644 --- a/bson/src/test/resources/bson/binary.json +++ b/bson/src/test/resources/bson/binary.json @@ -55,6 +55,11 @@ "canonical_bson": "1D000000057800100000000773FFD26444B34C6990E8E7D1DFC035D400", "canonical_extjson": "{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"07\"}}}" }, + { + "description": "subtype 0x08", + "canonical_bson": "1D000000057800100000000873FFD26444B34C6990E8E7D1DFC035D400", + "canonical_extjson": "{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"08\"}}}" + }, { "description": "subtype 0x80", "canonical_bson": "0F0000000578000200000080FFFF00", @@ -69,6 +74,36 @@ "description": "$type query operator (conflicts with legacy $binary form with $type field)", "canonical_bson": "180000000378001000000010247479706500020000000000", "canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}" + }, + { + "description": "subtype 0x09 Vector FLOAT32", + "canonical_bson": "170000000578000A0000000927000000FE420000E04000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector INT8", + "canonical_bson": "11000000057800040000000903007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector PACKED_BIT", + "canonical_bson": "11000000057800040000000910007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": 
"subtype 0x09 Vector (Zero-length) FLOAT32", + "canonical_bson": "0F0000000578000200000009270000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) INT8", + "canonical_bson": "0F0000000578000200000009030000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) PACKED_BIT", + "canonical_bson": "0F0000000578000200000009100000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}" } ], "decodeErrors": [ @@ -115,4 +150,4 @@ "string": "{\"x\" : { \"$uuid\" : \"----d264-44b3-4--9-90e8-e7d1dfc0----\"}}" } ] -} +} \ No newline at end of file diff --git a/bson/src/test/unit/org/bson/BsonHelper.java b/bson/src/test/unit/org/bson/BsonHelper.java index 985e398b1ca..59fdba474a2 100644 --- a/bson/src/test/unit/org/bson/BsonHelper.java +++ b/bson/src/test/unit/org/bson/BsonHelper.java @@ -17,10 +17,12 @@ package org.bson; import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.DecoderContext; import org.bson.codecs.EncoderContext; import org.bson.io.BasicOutputBuffer; import org.bson.types.Decimal128; import org.bson.types.ObjectId; +import util.Hex; import java.nio.ByteBuffer; import java.util.Date; @@ -109,4 +111,23 @@ public static ByteBuffer toBson(final BsonDocument document) { private BsonHelper() { } + + public static BsonDocument decodeToDocument(final String subjectHex, final String description) { + ByteBuffer byteBuffer = ByteBuffer.wrap(Hex.decode(subjectHex)); + BsonDocument actualDecodedDocument = new BsonDocumentCodec().decode(new BsonBinaryReader(byteBuffer), + DecoderContext.builder().build()); + + if (byteBuffer.hasRemaining()) { + throw new BsonSerializationException(format("Should have consumed all bytes, but " + byteBuffer.remaining() + + " still remain in the buffer for document with description ", + 
description)); + } + return actualDecodedDocument; + } + + public static String encodeToHex(final BsonDocument decodedDocument) { + BasicOutputBuffer outputBuffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), decodedDocument, EncoderContext.builder().build()); + return Hex.encode(outputBuffer.toByteArray()); + } } diff --git a/bson/src/test/unit/org/bson/GenericBsonTest.java b/bson/src/test/unit/org/bson/GenericBsonTest.java index 2f50bcd7f61..6ba2c6ae382 100644 --- a/bson/src/test/unit/org/bson/GenericBsonTest.java +++ b/bson/src/test/unit/org/bson/GenericBsonTest.java @@ -16,10 +16,6 @@ package org.bson; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.DecoderContext; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import org.bson.json.JsonMode; import org.bson.json.JsonParseException; import org.bson.json.JsonWriterSettings; @@ -27,7 +23,6 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import util.Hex; import util.JsonPoweredTestHelper; import java.io.File; @@ -35,7 +30,6 @@ import java.io.StringReader; import java.io.StringWriter; import java.net.URISyntaxException; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -44,6 +38,8 @@ import static java.lang.String.format; import static org.bson.BsonDocument.parse; +import static org.bson.BsonHelper.decodeToDocument; +import static org.bson.BsonHelper.encodeToHex; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -207,25 +203,6 @@ private boolean shouldEscapeCharacter(final char escapedChar) { } } - private BsonDocument decodeToDocument(final String subjectHex, final String description) { - ByteBuffer 
byteBuffer = ByteBuffer.wrap(Hex.decode(subjectHex)); - BsonDocument actualDecodedDocument = new BsonDocumentCodec().decode(new BsonBinaryReader(byteBuffer), - DecoderContext.builder().build()); - - if (byteBuffer.hasRemaining()) { - throw new BsonSerializationException(format("Should have consumed all bytes, but " + byteBuffer.remaining() - + " still remain in the buffer for document with description ", - description)); - } - return actualDecodedDocument; - } - - private String encodeToHex(final BsonDocument decodedDocument) { - BasicOutputBuffer outputBuffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(outputBuffer), decodedDocument, EncoderContext.builder().build()); - return Hex.encode(outputBuffer.toByteArray()); - } - private void runDecodeError(final BsonDocument testCase) { try { String description = testCase.getString("description").getValue(); diff --git a/bson/src/test/unit/org/bson/VectorTest.java b/bson/src/test/unit/org/bson/VectorTest.java new file mode 100644 index 00000000000..c1ea00ca0d7 --- /dev/null +++ b/bson/src/test/unit/org/bson/VectorTest.java @@ -0,0 +1,181 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class VectorTest { + + @Test + void shouldCreateInt8Vector() { + // given + byte[] data = {1, 2, 3, 4, 5}; + + // when + Vector vector = Vector.int8Vector(data); + + // then + assertNotNull(vector); + assertEquals(Vector.Dtype.INT8, vector.getDataType()); + assertArrayEquals(data, vector.asInt8VectorData()); + assertEquals(0, vector.getPadding()); + } + + @Test + void shouldThrowExceptionWhenCreatingInt8VectorWithNullData() { + // given + byte[] data = null; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.int8Vector(data)); + assertEquals("vectorData can not be null", exception.getMessage()); + } + + @Test + void shouldCreateFloat32Vector() { + // given + float[] data = {1.0f, 2.0f, 3.0f}; + + // when + Vector vector = Vector.floatVector(data); + + // then + assertNotNull(vector); + assertEquals(Vector.Dtype.FLOAT32, vector.getDataType()); + assertArrayEquals(data, vector.asFloatVectorData()); + assertEquals(0, vector.getPadding()); + } + + @Test + void shouldThrowExceptionWhenCreatingFloat32VectorWithNullData() { + // given + float[] data = null; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.floatVector(data)); + assertEquals("vectorData can not be null", exception.getMessage()); + } + + + @ParameterizedTest(name = "{index}: validPadding={0}") + @ValueSource(bytes = {0, 1, 2, 3, 4, 5, 6, 7}) + void shouldCreatePackedBitVector(final byte validPadding) { + // given + byte[] data = {(byte) 0b10101010, (byte) 
0b01010101}; + + // when + Vector vector = Vector.packedBitVector(data, validPadding); + + // then + assertNotNull(vector); + assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType()); + assertArrayEquals(data, vector.asPackedBitVectorData()); + assertEquals(validPadding, vector.getPadding()); + } + + @ParameterizedTest(name = "{index}: invalidPadding={0}") + @ValueSource(bytes = {-1, 8}) + void shouldThrowExceptionWhenPackedBitVectorHasInvalidPadding(final byte invalidPadding) { + // given + byte[] data = {(byte) 0b10101010}; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + Vector.packedBitVector(data, invalidPadding)); + assertEquals("state should be: Padding must be between 0 and 7 bits.", exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenPackedBitVectorIsCreatedWithNullData() { + // given + byte[] data = null; + byte padding = 0; + + // when & Then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + Vector.packedBitVector(data, padding)); + assertEquals("Vector data can not be null", exception.getMessage()); + } + + @Test + void shouldCreatePackedBitVectorWithZeroPaddingAndEmptyData() { + // given + byte[] data = new byte[0]; + byte padding = 0; + + // when + Vector vector = Vector.packedBitVector(data, padding); + + // then + assertNotNull(vector); + assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType()); + assertArrayEquals(data, vector.asPackedBitVectorData()); + assertEquals(padding, vector.getPadding()); + } + + @Test + void shouldThrowExceptionWhenPackedBitVectorWithNonZeroPaddingAndEmptyData() { + // given + byte[] data = new byte[0]; + byte padding = 1; + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> + Vector.packedBitVector(data, padding)); + assertEquals("state should be: Padding must be 0 if vector is empty", exception.getMessage()); + } + + @Test + void 
shouldThrowExceptionWhenRetrievingInt8DataFromNonInt8Vector() { + // given + float[] data = {1.0f, 2.0f}; + Vector vector = Vector.floatVector(data); + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asInt8VectorData); + assertEquals("Vector is not INT8", exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenRetrievingFloat32DataFromNonFloat32Vector() { + // given + byte[] data = {1, 2, 3}; + Vector vector = Vector.int8Vector(data); + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asFloatVectorData); + assertEquals("Vector is not FLOAT32", exception.getMessage()); + } + + @Test + void shouldThrowExceptionWhenRetrievingPackedBitDataFromNonPackedBitVector() { + // given + float[] data = {1.0f, 2.0f}; + Vector vector = Vector.floatVector(data); + + // when & Then + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asPackedBitVectorData); + assertEquals("Vector is not binary quantized", exception.getMessage()); + } +} diff --git a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java new file mode 100644 index 00000000000..ffc6abc5cdf --- /dev/null +++ b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java @@ -0,0 +1,184 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.internal.vector; + +import org.bson.Vector; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class VectorHelperTest { + private static final byte FLOAT32_DTYPE = Vector.Dtype.FLOAT32.getValue(); + private static final byte INT8_DTYPE = Vector.Dtype.INT8.getValue(); + private static final byte PACKED_BIT_DTYPE = Vector.Dtype.PACKED_BIT.getValue(); + public static final int ZERO_PADDING = 0; + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideFloatVectors") + void shouldEncodeFloatVector(final Vector actualFloat32Vector, final byte[] expectedBsonEncodedVector) { + // when + byte[] actualBsonEncodedVector = VectorHelper.encodeVectorToBinary(actualFloat32Vector); + + //Then + assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideFloatVectors") + void shouldDecodeFloatVector(final Vector expectedFloatVector, final byte[] bsonEncodedVector) { + // when + Vector decodedVector = VectorHelper.decodeBinaryToVector(bsonEncodedVector); + + // then + assertEquals(Vector.Dtype.FLOAT32, decodedVector.getDataType()); + assertEquals(0, decodedVector.getPadding()); + assertArrayEquals(expectedFloatVector.asFloatVectorData(), decodedVector.asFloatVectorData()); + } + + private static Stream provideFloatVectors() { + return Stream.of( + new Object[]{ + Vector.floatVector( + new float[]{1.1f, 2.2f, 3.3f, -1.0f, Float.MAX_VALUE, Float.MIN_VALUE, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY}), + new byte[]{FLOAT32_DTYPE, 
ZERO_PADDING, + (byte) 205, (byte) 204, (byte) 140, (byte) 63, // 1.1f in little-endian + (byte) 205, (byte) 204, (byte) 12, (byte) 64, // 2.2f in little-endian + (byte) 51, (byte) 51, (byte) 83, (byte) 64, // 3.3f in little-endian + (byte) 0, (byte) 0, (byte) 128, (byte) 191, // -1.0f in little-endian + (byte) 255, (byte) 255, (byte) 127, (byte) 127, // Float.MAX_VALUE in little-endian + (byte) 1, (byte) 0, (byte) 0, (byte) 0, // Float.MIN_VALUE in little-endian + (byte) 0, (byte) 0, (byte) 128, (byte) 127, // Float.POSITIVE_INFINITY in little-endian + (byte) 0, (byte) 0, (byte) 128, (byte) 255, // Float.NEGATIVE_INFINITY in little-endian + }}, + new Object[]{ + Vector.floatVector(new float[]{0.0f}), + new byte[]{FLOAT32_DTYPE, ZERO_PADDING, + (byte) 0, (byte) 0, (byte) 0, (byte) 0 // 0.0f in little-endian + }}, + new Object[]{ + Vector.floatVector(new float[]{}), + new byte[]{FLOAT32_DTYPE, ZERO_PADDING, + }} + ); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideInt8Vectors") + void shouldEncodeInt8Vector(final Vector actualInt8Vector, final byte[] expectedBsonEncodedVector) { + // when + byte[] actualBsonEncodedVector = VectorHelper.encodeVectorToBinary(actualInt8Vector); + + // then + assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("provideInt8Vectors") + void shouldDecodeInt8Vector(final Vector expectedInt8Vector, final byte[] bsonEncodedVector) { + // when + Vector decodedVector = VectorHelper.decodeBinaryToVector(bsonEncodedVector); + + // then + assertEquals(Vector.Dtype.INT8, decodedVector.getDataType()); + assertArrayEquals(expectedInt8Vector.asInt8VectorData(), decodedVector.asInt8VectorData()); + } + + private static Stream provideInt8Vectors() { + return Stream.of( + new Object[]{ + Vector.int8Vector(new byte[]{Byte.MAX_VALUE, 1, 2, 3, 4, Byte.MIN_VALUE}), + new byte[]{INT8_DTYPE, ZERO_PADDING, Byte.MAX_VALUE, 1, 2, 3, 4, Byte.MIN_VALUE + 
}}, + new Object[]{Vector.int8Vector(new byte[]{}), + new byte[]{INT8_DTYPE, ZERO_PADDING} + } + ); + } + + @ParameterizedTest + @MethodSource("providePackedBitVectors") + void shouldEncodePackedBitVector(final Vector actualPackedBitVector, final byte[] expectedBsonEncodedVector) { + // when + byte[] actualBsonEncodedVector = VectorHelper.encodeVectorToBinary(actualPackedBitVector); + + // then + assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); + } + + @ParameterizedTest + @MethodSource("providePackedBitVectors") + void shouldDecodePackedBitVector(final Vector expectedPackedBitVector, final byte[] bsonEncodedVector) { + // when + Vector decodedVector = VectorHelper.decodeBinaryToVector(bsonEncodedVector); + + // then + assertEquals(Vector.Dtype.PACKED_BIT, decodedVector.getDataType()); + assertArrayEquals(expectedPackedBitVector.asPackedBitVectorData(), decodedVector.asPackedBitVectorData()); + assertEquals(expectedPackedBitVector.getPadding(), decodedVector.getPadding()); + } + + private static Stream providePackedBitVectors() { + return Stream.of( + new Object[]{ + Vector.packedBitVector(new byte[]{(byte) 15, (byte) 240}, (byte) 2), + new byte[]{PACKED_BIT_DTYPE, 2, (byte) 15, (byte) 240} + }, + new Object[]{ + Vector.packedBitVector(new byte[]{(byte) 170}, (byte) 4), + new byte[]{PACKED_BIT_DTYPE, 4, (byte) 170} + } + ); + } + + @Test + void shouldThrowExceptionForInvalidFloatArrayLengthWhenDecode() { + // given: an encoded vector with an invalid length (not a multiple of 4) + byte[] invalidData = {FLOAT32_DTYPE, 0, 10, 20, 30}; + + // when & Then + IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { + VectorHelper.decodeBinaryToVector(invalidData); + }); + assertEquals("state should be: Byte array length must be a multiple of 4 for FLOAT32 dtype.", thrown.getMessage()); + } + + @Test + void shouldDetermineVectorDType() { + // given + Vector.Dtype[] values = Vector.Dtype.values(); + + for (Vector.Dtype value : 
values) { + // when + byte dtype = value.getValue(); + Vector.Dtype actual = VectorHelper.determineVectorDType(dtype); + + // then + assertEquals(value, actual); + } + } + + @Test + void shouldThrowWhenUnknownVectorDType() { + assertThrows(IllegalStateException.class, () -> VectorHelper.determineVectorDType((byte) 0)); + } +} diff --git a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java new file mode 100644 index 00000000000..c268bb751be --- /dev/null +++ b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java @@ -0,0 +1,265 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.vector; + +import org.bson.BsonArray; +import org.bson.BsonBinary; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.BsonValue; +import org.bson.Vector; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import util.JsonPoweredTestHelper; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import static java.lang.String.format; +import static org.bson.BsonHelper.decodeToDocument; +import static org.bson.BsonHelper.encodeToHex; +import static org.bson.internal.vector.VectorHelper.determineVectorDType; +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeFalse; + +// BSON tests powered by language-agnostic JSON-based tests included in test resources +class VectorGenericBsonTest { + + private static final List TEST_NAMES_TO_IGNORE = Arrays.asList( + //NO API to set padding for Floats available + "FLOAT32 with padding", + //NO API to set padding for Floats available + "INT8 with padding", + //It is impossible to provide float inputs for INT8 in the API + "INT8 with float inputs", + //It is impossible to provide float inputs for INT8 + "Underflow Vector PACKED_BIT", + //It is impossible to provide float inputs for PACKED_BIT in the API + "Vector with float values PACKED_BIT", + //It is impossible to provide float inputs for INT8 + "Overflow Vector PACKED_BIT", + "Overflow Vector INT8", + // It is impossible to provide -129 for byte. 
+ "Underflow Vector INT8"); + + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTestCases") + void shouldPassAllOutcomes(@SuppressWarnings("unused") final String description, + final BsonDocument testDefinition, final BsonDocument testCase) { + assumeFalse(TEST_NAMES_TO_IGNORE.contains(testCase.get("description").asString().getValue())); + + String testKey = testDefinition.getString("test_key").getValue(); + boolean isValidVector = testCase.getBoolean("valid").getValue(); + if (isValidVector) { + runValidTestCase(testKey, testCase); + } else { + runInvalidTestCase(testCase); + } + } + + private void runInvalidTestCase(final BsonDocument testCase) { + BsonArray arrayVector = testCase.getArray("vector"); + byte expectedPadding = (byte) testCase.getInt32("padding").getValue(); + byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); + Vector.Dtype expectedDType = determineVectorDType(dtypeByte); + + switch (expectedDType) { + case INT8: + byte[] expectedVectorData = toByteArray(arrayVector); + assertValidationException(assertThrows(RuntimeException.class, + () -> Vector.int8Vector(expectedVectorData))); + break; + case PACKED_BIT: + byte[] expectedVectorPackedBitData = toByteArray(arrayVector); + assertValidationException(assertThrows(RuntimeException.class, + () -> Vector.packedBitVector(expectedVectorPackedBitData, expectedPadding))); + break; + case FLOAT32: + float[] expectedFloatVector = toFloatArray(arrayVector); + assertValidationException(assertThrows(RuntimeException.class, () -> Vector.floatVector(expectedFloatVector))); + break; + default: + throw new IllegalArgumentException("Unsupported vector data type: " + expectedDType); + } + } + + private void runValidTestCase(final String testKey, final BsonDocument testCase) { + String description = testCase.getString("description").getValue(); + byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); + + byte expectedPadding = (byte) 
testCase.getInt32("padding").getValue(); + Vector.Dtype expectedDType = determineVectorDType(dtypeByte); + String expectedCanonicalBsonHex = testCase.getString("canonical_bson").getValue().toUpperCase(); + + BsonArray arrayVector = testCase.getArray("vector"); + BsonDocument actualDecodedDocument = decodeToDocument(expectedCanonicalBsonHex, description); + Vector actualVector = actualDecodedDocument.getBinary("vector").asVector(); + + switch (expectedDType) { + case INT8: + byte[] expectedVectorData = toByteArray(arrayVector); + byte[] actualVectorData = actualVector.asInt8VectorData(); + assertVectorDecoding( + expectedCanonicalBsonHex, expectedVectorData, + expectedDType, expectedPadding, actualDecodedDocument, + actualVectorData, actualVector); + + assertThatVectorCreationResultsInCorrectBinary(Vector.int8Vector(expectedVectorData), + testKey, + actualDecodedDocument, + expectedCanonicalBsonHex, description); + break; + case PACKED_BIT: + byte[] actualVectorPackedBitData = actualVector.asPackedBitVectorData(); + byte[] expectedVectorPackedBitData = toByteArray(arrayVector); + assertVectorDecoding( + expectedCanonicalBsonHex, expectedVectorPackedBitData, + expectedDType, expectedPadding, actualDecodedDocument, + actualVectorPackedBitData, actualVector); + + assertThatVectorCreationResultsInCorrectBinary( + Vector.packedBitVector(expectedVectorPackedBitData, expectedPadding), + testKey, + actualDecodedDocument, + expectedCanonicalBsonHex, + description); + break; + case FLOAT32: + float[] actualFloatVector = actualVector.asFloatVectorData(); + float[] expectedFloatVector = toFloatArray(arrayVector); +// assertVectorDecoding( +// expectedCanonicalBsonHex, expectedFloatVector, +// expectedDType, expectedPadding, actualDecodedDocument, +// actualFloatVector, actualVector); + assertThatVectorCreationResultsInCorrectBinary( + Vector.floatVector(expectedFloatVector), + testKey, + actualDecodedDocument, + expectedCanonicalBsonHex, + description); + break; + default: + 
throw new IllegalArgumentException("Unsupported vector data type: " + expectedDType); + } + } + + private static void assertValidationException(final RuntimeException runtimeException) { + assertTrue(runtimeException instanceof IllegalArgumentException || runtimeException instanceof IllegalStateException); + } + + private static void assertThatVectorCreationResultsInCorrectBinary(final Vector expectedVectorData, + final String testKey, + final BsonDocument actualDecodedDocument, + final String expectedCanonicalBsonHex, + final String description) { + BsonDocument documentToEncode = new BsonDocument(testKey, new BsonBinary(expectedVectorData)); + assertEquals(documentToEncode, actualDecodedDocument); + assertEquals(expectedCanonicalBsonHex, encodeToHex(documentToEncode), + format("Failed to create expected BSON for document with description '%s'", description)); + } + + private void assertVectorDecoding(final String expectedCanonicalBsonHex, + final byte[] expectedVectorData, + final Vector.Dtype expectedDType, + final byte expectedPadding, + final BsonDocument actualDecodedDocument, + final byte[] actualVectorData, + final Vector actualVector) { + assertEquals(expectedCanonicalBsonHex, encodeToHex(actualDecodedDocument)); + Assertions.assertArrayEquals(actualVectorData, expectedVectorData, + () -> "Actual: " + Arrays.toString(actualVectorData) + " != Expected:" + Arrays.toString(expectedVectorData)); + assertEquals(expectedDType, actualVector.getDataType()); + assertEquals(expectedPadding, actualVector.getPadding()); + } + + private void assertVectorDecoding(final String expectedCanonicalBsonHex, + final float[] expectedVectorData, + final Vector.Dtype expectedDType, + final byte expectedPadding, + final BsonDocument actualDecodedDocument, + final float[] actualVectorData, + final Vector actualVector) { + assertEquals(expectedCanonicalBsonHex, encodeToHex(actualDecodedDocument)); + Assertions.assertArrayEquals(actualVectorData, expectedVectorData, + () -> "Actual: 
" + Arrays.toString(actualVectorData) + " != Expected:" + Arrays.toString(expectedVectorData)); + assertEquals(expectedDType, actualVector.getDataType()); + assertEquals(expectedPadding, actualVector.getPadding()); + } + + private byte[] toByteArray(final BsonArray arrayVector) { + byte[] bytes = new byte[arrayVector.size()]; + for (int i = 0; i < arrayVector.size(); i++) { + bytes[i] = (byte) arrayVector.get(i).asInt32().getValue(); + } + return bytes; + } + + private float[] toFloatArray(final BsonArray arrayVector) { + float[] floats = new float[arrayVector.size()]; + for (int i = 0; i < arrayVector.size(); i++) { + BsonValue bsonValue = arrayVector.get(i); + if (bsonValue.isString()) { + floats[i] = parseFloat(bsonValue.asString()); + } else { + floats[i] = (float) arrayVector.get(i).asDouble().getValue(); + } + } + return floats; + } + + private static float parseFloat(final BsonString bsonValue) { + String floatValue = bsonValue.getValue(); + switch (floatValue) { + case "-inf": + return Float.NEGATIVE_INFINITY; + case "inf": + return Float.POSITIVE_INFINITY; + default: + return Float.parseFloat(floatValue); + } + } + + private static Stream provideTestCases() throws URISyntaxException, IOException { + List data = new ArrayList<>(); + for (File file : JsonPoweredTestHelper.getTestFiles("/bson-binary-vector")) { + BsonDocument testDocument = JsonPoweredTestHelper.getTestDocument(file); + for (BsonValue curValue : testDocument.getArray("tests", new BsonArray())) { + BsonDocument testCaseDocument = curValue.asDocument(); + data.add(Arguments.of(createTestCaseDescription(testDocument, testCaseDocument), testDocument, testCaseDocument)); + } + } + return data.stream(); + } + + private static String createTestCaseDescription(final BsonDocument testDocument, + final BsonDocument testCaseDocument) { + boolean isValidTestCase = testCaseDocument.getBoolean("valid").getValue(); + String testSuiteDescription = testDocument.getString("description").getValue(); + String 
testCaseDescription = testCaseDocument.getString("description").getValue(); + return "[Valid input: " + isValidTestCase + "] " + testSuiteDescription + ": " + testCaseDescription; + } +} From 5237979e39d08b4423847235be77ce137596afe9 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 15 Oct 2024 19:41:04 -0700 Subject: [PATCH 02/20] Add Vector subtypes. JAVA-5544 --- bson/src/main/org/bson/Float32Vector.java | 84 ++++++++++ bson/src/main/org/bson/Int8Vector.java | 84 ++++++++++ bson/src/main/org/bson/PackedBitVector.java | 105 ++++++++++++ bson/src/main/org/bson/Vector.java | 156 +++++------------- .../bson/internal/vector/VectorHelper.java | 16 +- bson/src/main/org/bson/types/Binary.java | 34 ++++ .../org/bson/BsonBinarySpecification.groovy | 11 +- .../BsonBinarySubTypeSpecification.groovy | 1 + .../test/unit/org/bson/BsonBinaryTest.java | 117 +++++++++++++ .../unit/org/bson/BsonBinaryWriterTest.java | 12 +- bson/src/test/unit/org/bson/VectorTest.java | 30 ++-- .../internal/vector/VectorHelperTest.java | 22 +-- .../test/unit/org/bson/types/BinaryTest.java | 123 ++++++++++++++ .../bson/vector/VectorGenericBsonTest.java | 61 ++++--- 14 files changed, 688 insertions(+), 168 deletions(-) create mode 100644 bson/src/main/org/bson/Float32Vector.java create mode 100644 bson/src/main/org/bson/Int8Vector.java create mode 100644 bson/src/main/org/bson/PackedBitVector.java create mode 100644 bson/src/test/unit/org/bson/BsonBinaryTest.java create mode 100644 bson/src/test/unit/org/bson/types/BinaryTest.java diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java new file mode 100644 index 00000000000..ad2b5973bfa --- /dev/null +++ b/bson/src/main/org/bson/Float32Vector.java @@ -0,0 +1,84 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import org.bson.types.Binary; + +import java.util.Arrays; + +import static org.bson.assertions.Assertions.assertNotNull; + +/** + * Represents a vector of 32-bit floating-point numbers, where each element in the vector is a float. + *

+ * The {@link Float32Vector} is used to store and retrieve data efficiently using the BSON Binary Subtype 9 format. + * + * @mongodb.server.release 6.0 + * @see Vector#floatVector(float[]) + * @see BsonBinary#BsonBinary(Vector) + * @see BsonBinary#asVector() + * @see Binary#Binary(Vector) + * @see Binary#asVector() + * @since BINARY_VECTOR + */ +public class Float32Vector extends Vector { + + private final float[] vectorData; + + Float32Vector(final float[] vectorData) { + super(Dtype.FLOAT32); + this.vectorData = assertNotNull(vectorData); + } + + /** + * Retrieve the underlying float array representing this {@link Float32Vector}, where each float + * represents an element of a vector. + *

+ * NOTE: The underlying float array is not copied; changes to the returned array will be reflected in this instance. + * + * @return the underlying float array representing this {@link Float32Vector} vector. + */ + public float[] getVectorArray() { + return assertNotNull(vectorData); + } + + @Override + public final boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Float32Vector)) { + return false; + } + + Float32Vector that = (Float32Vector) o; + return Arrays.equals(vectorData, that.vectorData); + } + + @Override + public int hashCode() { + return Arrays.hashCode(vectorData); + } + + @Override + public String toString() { + return "Float32Vector{" + + "vectorData=" + Arrays.toString(vectorData) + + ", vectorType=" + getDataType() + + '}'; + } +} diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java new file mode 100644 index 00000000000..56520b4de49 --- /dev/null +++ b/bson/src/main/org/bson/Int8Vector.java @@ -0,0 +1,84 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import org.bson.types.Binary; + +import java.util.Arrays; + +import static org.bson.assertions.Assertions.assertNotNull; + +/** + * Represents a vector of 8-bit signed integers, where each element in the vector is a byte. + *

+ * The {@link Int8Vector} is used to store and retrieve data efficiently using the BSON Binary Subtype 9 format. + * + * @mongodb.server.release 6.0 + * @see Vector#int8Vector(byte[]) + * @see BsonBinary#BsonBinary(Vector) + * @see BsonBinary#asVector() + * @see Binary#Binary(Vector) + * @see Binary#asVector() + * @since BINARY_VECTOR + */ +public class Int8Vector extends Vector { + + private byte[] vectorData; + + Int8Vector(final byte[] vectorData) { + super(Dtype.INT8); + this.vectorData = assertNotNull(vectorData); + } + + /** + * Retrieve the underlying byte array representing this {@link Int8Vector} vector, where each byte represents + * an element of a vector. + *

+ * NOTE: The underlying byte array is not copied; changes to the returned array will be reflected in this instance. + * + * @return the underlying byte array representing this {@link Int8Vector} vector. + */ + public byte[] getVectorArray() { + return assertNotNull(vectorData); + } + + @Override + public final boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Int8Vector)) { + return false; + } + + Int8Vector that = (Int8Vector) o; + return Arrays.equals(vectorData, that.vectorData); + } + + @Override + public int hashCode() { + return Arrays.hashCode(vectorData); + } + + @Override + public String toString() { + return "Int8Vector{" + + "vectorData=" + Arrays.toString(vectorData) + + ", vectorType=" + getDataType() + + '}'; + } +} diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java new file mode 100644 index 00000000000..37362ba50fa --- /dev/null +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -0,0 +1,105 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import org.bson.types.Binary; + +import java.util.Arrays; + +import static org.bson.assertions.Assertions.assertNotNull; + +/** + * Represents a packed bit vector, where each element of the vector is represented by a single bit (0 or 1). + *

+ * The {@link PackedBitVector} is used to store data efficiently using the BSON Binary Subtype 9 format. + * + * @mongodb.server.release 6.0 + * @see Vector#packedBitVector(byte[], byte) + * @see BsonBinary#BsonBinary(Vector) + * @see BsonBinary#asVector() + * @see Binary#Binary(Vector) + * @see Binary#asVector() + * @since BINARY_VECTOR + */ +public class PackedBitVector extends Vector { + + private final byte padding; + private final byte[] vectorData; + + PackedBitVector(final byte[] vectorData, final byte padding) { + super(Dtype.PACKED_BIT); + this.vectorData = assertNotNull(vectorData); + this.padding = padding; + } + + /** + * Retrieve the underlying byte array representing this {@link PackedBitVector} vector, where + * each bit represents an element of the vector (either 0 or 1). + *

+ * Note that the {@linkplain #getPadding() padding value} should be considered when interpreting the final byte of the array, + * as it indicates how many least-significant bits are to be ignored. + * + * @return the underlying byte array representing this {@link PackedBitVector} vector. + * @see #getPadding() + */ + public byte[] getVectorArray() { + return assertNotNull(vectorData); + } + + /** + * Returns the padding value for this vector. + * + *

Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving the vector data, as not + * all {@link Dtype}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.

+ *

+ * NOTE: The padding applies to the final byte of the array returned by {@link #getVectorArray()}; that array is not copied, so changes to it will be reflected in this instance. + * + * @return the padding value (between 0 and 7). + */ + public byte getPadding() { + return this.padding; + } + + @Override + public final boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PackedBitVector)) { + return false; + } + + PackedBitVector that = (PackedBitVector) o; + return padding == that.padding && Arrays.equals(vectorData, that.vectorData); + } + + @Override + public int hashCode() { + int result = padding; + result = 31 * result + Arrays.hashCode(vectorData); + return result; + } + + @Override + public String toString() { + return "PackedBitVector{" + + "padding=" + padding + + ", vectorData=" + Arrays.toString(vectorData) + + ", vectorType=" + getDataType() + + '}'; + } +} diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index 38d1125c55a..e66fb313cdf 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -17,10 +17,6 @@ package org.bson; -import java.util.Arrays; -import java.util.Objects; - -import static org.bson.assertions.Assertions.assertNotNull; import static org.bson.assertions.Assertions.isTrue; import static org.bson.assertions.Assertions.isTrueArgument; import static org.bson.assertions.Assertions.notNull; @@ -28,7 +24,7 @@ /** * Represents a vector that is stored and retrieved using the BSON Binary Subtype 9 format. * This class supports multiple vector {@link Dtype}'s and provides static methods to create - * vectors and methods to retrieve their underlying data. + * vectors. *

* Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently * in BSON using a binary format. @@ -37,26 +33,12 @@ * @see BsonBinary * @since BINARY_VECTOR */ -public final class Vector { - private final byte padding; - private byte[] vectorData; - private float[] floatVectorData; - private final Dtype vectorType; - - Vector(final byte padding, final byte[] vectorData, final Dtype vectorType) { - this.padding = padding; - this.vectorData = assertNotNull(vectorData); - this.vectorType = assertNotNull(vectorType); - } - Vector(final byte[] vectorData, final Dtype vectorType) { - this((byte) 0, vectorData, vectorType); - } +public class Vector { + private final Dtype vectorType; - Vector(final float[] vectorData) { - this.padding = 0; - this.floatVectorData = assertNotNull(vectorData); - this.vectorType = Dtype.FLOAT32; + Vector(final Dtype vectorType) { + this.vectorType = vectorType; } /** @@ -72,18 +54,20 @@ public final class Vector { * Padding: 4 (ignore the last 4 bits in Byte 2) * Resulting vector: 12 bits: 111011101110 * - * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected in the created {@link Vector} instance. + *

+ * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected + * in the created {@link PackedBitVector} instance. * * @param vectorData The byte array representing the packed bit vector data. Each byte can store 8 bits. * @param padding The number of bits (0 to 7) to ignore in the final byte of the vector data. - * @return A Vector instance with the {@link Dtype#PACKED_BIT} data type. + * @return A {@link PackedBitVector} instance with the {@link Dtype#PACKED_BIT} data type. * @throws IllegalArgumentException If the padding value is greater than 7. */ - public static Vector packedBitVector(final byte[] vectorData, final byte padding) { + public static PackedBitVector packedBitVector(final byte[] vectorData, final byte padding) { isTrueArgument("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); notNull("Vector data", vectorData); isTrue("Padding must be 0 if vector is empty", padding == 0 || vectorData.length > 0); - return new Vector(padding, vectorData, Dtype.PACKED_BIT); + return new PackedBitVector(vectorData, padding); } /** @@ -92,97 +76,69 @@ public static Vector packedBitVector(final byte[] vectorData, final byte padding *

A {@link Dtype#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector, * with values in the range [-128, 127].

*

- * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected in the created {@link Vector} instance. + * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected + * in the created {@link Int8Vector} instance. * * @param vectorData The byte array representing the {@link Dtype#INT8} vector data. - * @return A Vector instance with the {@link Dtype#INT8} data type. + * @return A {@link Int8Vector} instance with the {@link Dtype#INT8} data type. */ - public static Vector int8Vector(final byte[] vectorData) { + public static Int8Vector int8Vector(final byte[] vectorData) { notNull("vectorData", vectorData); - return new Vector(vectorData, Dtype.INT8); + return new Int8Vector(vectorData); } /** * Creates a vector with the {@link Dtype#FLOAT32} data type. - * - *

A {@link Dtype#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

*

- * NOTE: The float array `vectorData` is not copied; changes to the provided array will be reflected in the created {@link Vector} instance. + * A {@link Dtype#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

+ *

+ * NOTE: The float array `vectorData` is not copied; changes to the provided array will be reflected + * in the created {@link Float32Vector} instance. * * @param vectorData The float array representing the {@link Dtype#FLOAT32} vector data. - * @return A Vector instance with the {@link Dtype#FLOAT32} data type. + * @return A {@link Float32Vector} instance with the {@link Dtype#FLOAT32} data type. */ - public static Vector floatVector(final float[] vectorData) { + public static Float32Vector floatVector(final float[] vectorData) { notNull("vectorData", vectorData); - return new Vector(vectorData); + return new Float32Vector(vectorData); } /** - * Returns the {@link Dtype#PACKED_BIT} vector data as a byte array. - * - *

This method is used to retrieve the underlying underlying byte array representing the {@link Dtype#PACKED_BIT} vector, where - * each bit represents an element of the vector (either 0 or 1). + * Returns the {@link PackedBitVector}. * - * @return the packed bit vector data. - * @throws IllegalStateException if this vector is not of type {@link Dtype#PACKED_BIT}. Use {@link #getDataType()} to check the vector type before - * calling this method. - * @see #getPadding() getPadding() specifies how many least-significant bits in the final byte should be ignored. + * @return {@link PackedBitVector}. + * @throws IllegalStateException if this vector is not of type {@link Dtype#PACKED_BIT}. Use {@link #getDataType()} to check the vector + * type before calling this method. */ - public byte[] asPackedBitVectorData() { - if (this.vectorType != Dtype.PACKED_BIT) { - throw new IllegalStateException("Vector is not binary quantized"); - } - return assertNotNull(vectorData); + public PackedBitVector asPackedBitVector() { + ensureType(Dtype.PACKED_BIT); + return (PackedBitVector) this; } /** - * Returns the {@link Dtype#INT8} vector data as a byte array. - * - *

This method is used to retrieve the underlying byte array representing the {@link Dtype#INT8} vector, where each byte represents - * an element of a vector.

+ * Returns the {@link Int8Vector}. * - * @return the {@link Dtype#INT8} vector data. + * @return {@link Int8Vector}. * @throws IllegalStateException if this vector is not of type {@link Dtype#INT8}. Use {@link #getDataType()} to check the vector * type before calling this method. */ - public byte[] asInt8VectorData() { - if (this.vectorType != Dtype.INT8) { - throw new IllegalStateException("Vector is not INT8"); - } - return assertNotNull(vectorData); + public Int8Vector asInt8Vector() { + ensureType(Dtype.INT8); + return (Int8Vector) this; } /** - * Returns the {@link Dtype#FLOAT32} vector data as a float array. - * - *

This method is used to retrieve the underlying float array representing the {@link Dtype#FLOAT32} vector, where each float - * represents an element of a vector.

+ * Returns the {@link Float32Vector}. * - * @return the float array representing the FLOAT32 vector. + * @return {@link Float32Vector}. * @throws IllegalStateException if this vector is not of type {@link Dtype#FLOAT32}. Use {@link #getDataType()} to check the vector * type before calling this method. */ - public float[] asFloatVectorData() { - if (this.vectorType != Dtype.FLOAT32) { - throw new IllegalStateException("Vector is not FLOAT32"); - } - - return assertNotNull(floatVectorData); - } - - /** - * Returns the padding value for this vector. - * - *

Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving the vector data, as not - * all {@link Dtype}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.

- * - * @return the padding value (between 0 and 7). - */ - public byte getPadding() { - return this.padding; + public Float32Vector asFloat32Vector() { + ensureType(Dtype.FLOAT32); + return (Float32Vector) this; } - /** * Returns {@link Dtype} of the vector. * @@ -193,36 +149,10 @@ public Dtype getDataType() { } - @Override - public String toString() { - return "Vector{" - + "padding=" + padding + ", " - + "vectorData=" + (vectorData == null ? Arrays.toString(floatVectorData) : Arrays.toString(vectorData)) - + ", vectorType=" + vectorType - + '}'; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (!(o instanceof Vector)) { - return false; + private void ensureType(final Dtype expected) { + if (this.vectorType != expected) { + throw new IllegalStateException("Expected vector type " + expected + " but found " + this.vectorType); } - - Vector vector = (Vector) o; - return padding == vector.padding && Arrays.equals(vectorData, vector.vectorData) - && Arrays.equals(floatVectorData, vector.floatVectorData) && vectorType == vector.vectorType; - } - - @Override - public int hashCode() { - int result = padding; - result = 31 * result + Arrays.hashCode(vectorData); - result = 31 * result + Arrays.hashCode(floatVectorData); - result = 31 * result + Objects.hashCode(vectorType); - return result; } /** diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index 12cabef0d95..a5e9bf8adb5 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -17,6 +17,7 @@ package org.bson.internal.vector; import org.bson.BsonBinary; +import org.bson.PackedBitVector; import org.bson.Vector; import java.nio.ByteBuffer; @@ -26,7 +27,7 @@ import static org.bson.assertions.Assertions.isTrue; /** - * Helper class for encoding and decoding vectors to and from binary. 
+ * Helper class for encoding and decoding vectors to and from {@link BsonBinary}. * *

* This class is not part of the public API and may be removed or changed at any time. @@ -48,20 +49,25 @@ private VectorHelper() { public static byte[] encodeVectorToBinary(final Vector vector) { Vector.Dtype dtype = vector.getDataType(); - byte padding = vector.getPadding(); switch (dtype) { case INT8: - return writeVector(dtype.getValue(), padding, vector.asInt8VectorData()); + return writeVector(dtype.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray()); case PACKED_BIT: - return writeVector(dtype.getValue(), padding, vector.asPackedBitVectorData()); + PackedBitVector packedBitVector = vector.asPackedBitVector(); + return writeVector(dtype.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray()); case FLOAT32: - return writeVector(dtype.getValue(), padding, vector.asFloatVectorData()); + return writeVector(dtype.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray()); default: throw new AssertionError("Unknown vector dtype: " + dtype); } } + /** + * Decodes a vector from a binary representation. + *

+ * encodedVector is not mutated nor stored in the returned {@link Vector}. + */ public static Vector decodeBinaryToVector(final byte[] encodedVector) { isTrue("Vector encoded array length must be at least 2.", encodedVector.length >= METADATA_SIZE); diff --git a/bson/src/main/org/bson/types/Binary.java b/bson/src/main/org/bson/types/Binary.java index 5ba482ccc41..f1c8db03562 100644 --- a/bson/src/main/org/bson/types/Binary.java +++ b/bson/src/main/org/bson/types/Binary.java @@ -17,10 +17,15 @@ package org.bson.types; import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.Vector; +import org.bson.internal.vector.VectorHelper; import java.io.Serializable; import java.util.Arrays; +import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; + /** * Generic binary holder. */ @@ -67,6 +72,35 @@ public Binary(final byte type, final byte[] data) { this.data = data.clone(); } + /** + * Construct a Type 9 Binary from the given Vector. + * + * @param vector the {@link Vector} + * @since BINARY_VECTOR + */ + public Binary(final Vector vector) { + if (vector == null) { + throw new IllegalArgumentException("Vector must not be null"); + } + this.data = encodeVectorToBinary(vector); + type = BsonBinarySubType.VECTOR.getValue(); + } + + /** + * Returns the binary as a {@link Vector}. The binary type must be 9. + * + * @return the vector + * @throws BsonInvalidOperationException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. + * @since BINARY_VECTOR + */ + public Vector asVector() { + if (!BsonBinarySubType.isVector(type)) { + throw new BsonInvalidOperationException("type must be a Vector subtype."); + } + + return VectorHelper.decodeBinaryToVector(this.data); + } + /** * Get the binary sub type as a byte.
* diff --git a/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy b/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy index e51094e964f..503440daa04 100644 --- a/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy +++ b/bson/src/test/unit/org/bson/BsonBinarySpecification.groovy @@ -48,9 +48,14 @@ class BsonBinarySpecification extends Specification { data == bsonBinary.getData() where: - subType << [BsonBinarySubType.BINARY, BsonBinarySubType.FUNCTION, BsonBinarySubType.MD5, - BsonBinarySubType.OLD_BINARY, BsonBinarySubType.USER_DEFINED, BsonBinarySubType.UUID_LEGACY, - BsonBinarySubType.UUID_STANDARD] + subType << [BsonBinarySubType.BINARY, + BsonBinarySubType.FUNCTION, + BsonBinarySubType.MD5, + BsonBinarySubType.OLD_BINARY, + BsonBinarySubType.USER_DEFINED, + BsonBinarySubType.UUID_LEGACY, + BsonBinarySubType.UUID_STANDARD, + BsonBinarySubType.VECTOR] } @Unroll diff --git a/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy b/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy index f26d1ad00d9..a849f1fc595 100644 --- a/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy +++ b/bson/src/test/unit/org/bson/BsonBinarySubTypeSpecification.groovy @@ -33,5 +33,6 @@ class BsonBinarySubTypeSpecification extends Specification { 5 | false 6 | false 7 | false + 9 | false } } diff --git a/bson/src/test/unit/org/bson/BsonBinaryTest.java b/bson/src/test/unit/org/bson/BsonBinaryTest.java new file mode 100644 index 00000000000..62d54116276 --- /dev/null +++ b/bson/src/test/unit/org/bson/BsonBinaryTest.java @@ -0,0 +1,117 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.bson.assertions.Assertions.fail; +import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class BsonBinaryTest { + + static Stream provideVectors() { + return Stream.of( + Vector.floatVector(new float[]{1.5f, 2.1f, 3.1f}), + Vector.int8Vector(new byte[]{10, 20, 30}), + Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010000}, (byte) 3) + ); + } + + @ParameterizedTest + @MethodSource("provideVectors") + void shouldCreateBsonBinaryFromVector(final Vector vector) { + // when + BsonBinary bsonBinary = new BsonBinary(vector); + + // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), bsonBinary.getType(), "The subtype must be VECTOR"); + assertNotNull(bsonBinary.getData(), "BsonBinary data should not be null"); + assertArrayEquals(encodeVectorToBinary(vector), bsonBinary.getData()); + } + + @Test + void shouldThrowExceptionWhenCreatingBsonBinaryWithNullVector() { + // given + Vector vector = null; + + // when & then + IllegalArgumentException exception = 
assertThrows(IllegalArgumentException.class, () -> new BsonBinary(vector)); + assertEquals("Vector must not be null", exception.getMessage()); + } + + @ParameterizedTest + @MethodSource("provideVectors") + void shouldConvertBsonBinaryToVector(final Vector actualVector) { + // given + BsonBinary bsonBinary = new BsonBinary(actualVector); + + // when + Vector decodedVector = bsonBinary.asVector(); + + // then + assertNotNull(decodedVector); + assertEquals(actualVector.getDataType(), decodedVector.getDataType()); + assertVectorDataDecoding(actualVector, decodedVector); + } + + private static void assertVectorDataDecoding(final Vector actualVector, final Vector decodedVector) { + switch (actualVector.getDataType()) { + case FLOAT32: + Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); + Float32Vector decodedFloat32Vector1 = decodedVector.asFloat32Vector(); + assertArrayEquals(actualFloat32Vector.getVectorArray(), decodedFloat32Vector1.getVectorArray(), + "Float vector data should match after decoding"); + break; + case INT8: + Int8Vector actualInt8Vector = actualVector.asInt8Vector(); + Int8Vector decodedInt8Vector1 = decodedVector.asInt8Vector(); + assertArrayEquals(actualInt8Vector.getVectorArray(), decodedInt8Vector1.getVectorArray(), + "Int8 vector data should match after decoding"); + break; + case PACKED_BIT: + PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); + PackedBitVector decodedPackedBitVector = decodedVector.asPackedBitVector(); + assertArrayEquals(actualPackedBitVector.getVectorArray(), decodedPackedBitVector.getVectorArray(), + "Packed bit vector data should match after decoding"); + assertEquals(actualPackedBitVector.getPadding(), decodedPackedBitVector.getPadding(), "Padding should match after decoding"); + break; + default: + fail("Unexpected vector type: " + actualVector.getDataType()); + } + } + + @ParameterizedTest + @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = 
{"VECTOR"}) + void shouldThrowExceptionWhenBsonBinarySubTypeIsNotVector(final BsonBinarySubType bsonBinarySubType) { + // given + byte[] data = new byte[]{1, 2, 3, 4}; + BsonBinary bsonBinary = new BsonBinary(bsonBinarySubType.getValue(), data); + + // when & then + BsonInvalidOperationException exception = assertThrows(BsonInvalidOperationException.class, bsonBinary::asVector); + assertEquals("type must be a Vector subtype.", exception.getMessage()); + } +} diff --git a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java index 15e27065ba2..91fbd3dbf1f 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java @@ -40,6 +40,9 @@ public class BsonBinaryWriterTest { + private static final byte FLOAT32_DTYPE = Vector.Dtype.FLOAT32.getValue(); + private static final int ZERO_PADDING = 0; + private BsonBinaryWriter writer; private BasicOutputBuffer buffer; @@ -299,12 +302,19 @@ public void testWriteBinary() { writer.writeBinaryData("b1", new BsonBinary(new byte[]{0, 0, 0, 0, 0, 0, 0, 0})); writer.writeBinaryData("b2", new BsonBinary(BsonBinarySubType.OLD_BINARY, new byte[]{1, 1, 1, 1, 1})); writer.writeBinaryData("b3", new BsonBinary(BsonBinarySubType.FUNCTION, new byte[]{})); + writer.writeBinaryData("b4", new BsonBinary(BsonBinarySubType.VECTOR, new byte[]{FLOAT32_DTYPE, ZERO_PADDING, + (byte) 205, (byte) 204, (byte) 140, (byte) 63})); + writer.writeEndDocument(); byte[] expectedValues = {49, 0, 0, 0, 5, 98, 49, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 98, 50, 0, 9, 0, - 0, 0, 2, 5, 0, 0, 0, 1, 1, 1, 1, 1, 5, 98, 51, 0, 0, 0, 0, 0, 1, 0}; + 0, 0, 2, 5, 0, 0, 0, 1, 1, 1, 1, 1, + 5, 98, 51, 0, 0, 0, 0, 0, 1, 0, + 6, BsonBinarySubType.VECTOR.getValue(), FLOAT32_DTYPE, ZERO_PADDING, (byte) 205, (byte) 204, (byte) 140, 63, + }; + assertArrayEquals(expectedValues, buffer.toByteArray()); } diff --git a/bson/src/test/unit/org/bson/VectorTest.java 
b/bson/src/test/unit/org/bson/VectorTest.java index c1ea00ca0d7..f16373a0777 100644 --- a/bson/src/test/unit/org/bson/VectorTest.java +++ b/bson/src/test/unit/org/bson/VectorTest.java @@ -33,13 +33,12 @@ void shouldCreateInt8Vector() { byte[] data = {1, 2, 3, 4, 5}; // when - Vector vector = Vector.int8Vector(data); + Int8Vector vector = Vector.int8Vector(data); // then assertNotNull(vector); assertEquals(Vector.Dtype.INT8, vector.getDataType()); - assertArrayEquals(data, vector.asInt8VectorData()); - assertEquals(0, vector.getPadding()); + assertArrayEquals(data, vector.getVectorArray()); } @Test @@ -58,13 +57,12 @@ void shouldCreateFloat32Vector() { float[] data = {1.0f, 2.0f, 3.0f}; // when - Vector vector = Vector.floatVector(data); + Float32Vector vector = Vector.floatVector(data); // then assertNotNull(vector); assertEquals(Vector.Dtype.FLOAT32, vector.getDataType()); - assertArrayEquals(data, vector.asFloatVectorData()); - assertEquals(0, vector.getPadding()); + assertArrayEquals(data, vector.getVectorArray()); } @Test @@ -85,12 +83,12 @@ void shouldCreatePackedBitVector(final byte validPadding) { byte[] data = {(byte) 0b10101010, (byte) 0b01010101}; // when - Vector vector = Vector.packedBitVector(data, validPadding); + PackedBitVector vector = Vector.packedBitVector(data, validPadding); // then assertNotNull(vector); assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType()); - assertArrayEquals(data, vector.asPackedBitVectorData()); + assertArrayEquals(data, vector.getVectorArray()); assertEquals(validPadding, vector.getPadding()); } @@ -125,12 +123,12 @@ void shouldCreatePackedBitVectorWithZeroPaddingAndEmptyData() { byte padding = 0; // when - Vector vector = Vector.packedBitVector(data, padding); + PackedBitVector vector = Vector.packedBitVector(data, padding); // then assertNotNull(vector); assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType()); - assertArrayEquals(data, vector.asPackedBitVectorData()); + assertArrayEquals(data, 
vector.getVectorArray()); assertEquals(padding, vector.getPadding()); } @@ -153,8 +151,8 @@ void shouldThrowExceptionWhenRetrievingInt8DataFromNonInt8Vector() { Vector vector = Vector.floatVector(data); // when & Then - IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asInt8VectorData); - assertEquals("Vector is not INT8", exception.getMessage()); + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asInt8Vector); + assertEquals("Expected vector type INT8 but found FLOAT32", exception.getMessage()); } @Test @@ -164,8 +162,8 @@ void shouldThrowExceptionWhenRetrievingFloat32DataFromNonFloat32Vector() { Vector vector = Vector.int8Vector(data); // when & Then - IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asFloatVectorData); - assertEquals("Vector is not FLOAT32", exception.getMessage()); + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asFloat32Vector); + assertEquals("Expected vector type FLOAT32 but found INT8", exception.getMessage()); } @Test @@ -175,7 +173,7 @@ void shouldThrowExceptionWhenRetrievingPackedBitDataFromNonPackedBitVector() { Vector vector = Vector.floatVector(data); // when & Then - IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asPackedBitVectorData); - assertEquals("Vector is not binary quantized", exception.getMessage()); + IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asPackedBitVector); + assertEquals("Expected vector type PACKED_BIT but found FLOAT32", exception.getMessage()); } } diff --git a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java index ffc6abc5cdf..4edf6b489f9 100644 --- a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java +++ b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java @@ -16,6 +16,9 @@ package 
org.bson.internal.vector; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; import org.bson.Vector; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -45,14 +48,13 @@ void shouldEncodeFloatVector(final Vector actualFloat32Vector, final byte[] expe @ParameterizedTest(name = "{index}: {0}") @MethodSource("provideFloatVectors") - void shouldDecodeFloatVector(final Vector expectedFloatVector, final byte[] bsonEncodedVector) { + void shouldDecodeFloatVector(final Float32Vector expectedFloatVector, final byte[] bsonEncodedVector) { // when - Vector decodedVector = VectorHelper.decodeBinaryToVector(bsonEncodedVector); + Float32Vector decodedVector = (Float32Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); // then assertEquals(Vector.Dtype.FLOAT32, decodedVector.getDataType()); - assertEquals(0, decodedVector.getPadding()); - assertArrayEquals(expectedFloatVector.asFloatVectorData(), decodedVector.asFloatVectorData()); + assertArrayEquals(expectedFloatVector.getVectorArray(), decodedVector.getVectorArray()); } private static Stream provideFloatVectors() { @@ -94,13 +96,13 @@ void shouldEncodeInt8Vector(final Vector actualInt8Vector, final byte[] expected @ParameterizedTest(name = "{index}: {0}") @MethodSource("provideInt8Vectors") - void shouldDecodeInt8Vector(final Vector expectedInt8Vector, final byte[] bsonEncodedVector) { + void shouldDecodeInt8Vector(final Int8Vector expectedInt8Vector, final byte[] bsonEncodedVector) { // when - Vector decodedVector = VectorHelper.decodeBinaryToVector(bsonEncodedVector); + Int8Vector decodedVector = (Int8Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); // then assertEquals(Vector.Dtype.INT8, decodedVector.getDataType()); - assertArrayEquals(expectedInt8Vector.asInt8VectorData(), decodedVector.asInt8VectorData()); + assertArrayEquals(expectedInt8Vector.getVectorArray(), decodedVector.getVectorArray()); } private static Stream 
provideInt8Vectors() { @@ -127,13 +129,13 @@ void shouldEncodePackedBitVector(final Vector actualPackedBitVector, final byte[ @ParameterizedTest @MethodSource("providePackedBitVectors") - void shouldDecodePackedBitVector(final Vector expectedPackedBitVector, final byte[] bsonEncodedVector) { + void shouldDecodePackedBitVector(final PackedBitVector expectedPackedBitVector, final byte[] bsonEncodedVector) { // when - Vector decodedVector = VectorHelper.decodeBinaryToVector(bsonEncodedVector); + PackedBitVector decodedVector = (PackedBitVector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); // then assertEquals(Vector.Dtype.PACKED_BIT, decodedVector.getDataType()); - assertArrayEquals(expectedPackedBitVector.asPackedBitVectorData(), decodedVector.asPackedBitVectorData()); + assertArrayEquals(expectedPackedBitVector.getVectorArray(), decodedVector.getVectorArray()); assertEquals(expectedPackedBitVector.getPadding(), decodedVector.getPadding()); } diff --git a/bson/src/test/unit/org/bson/types/BinaryTest.java b/bson/src/test/unit/org/bson/types/BinaryTest.java new file mode 100644 index 00000000000..c1defa310c9 --- /dev/null +++ b/bson/src/test/unit/org/bson/types/BinaryTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.types; + +import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; +import org.bson.Vector; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.bson.assertions.Assertions.fail; +import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class BinaryTest { + static Stream provideVectors() { + return Stream.of( + Vector.floatVector(new float[]{1.5f, 2.1f, 3.1f}), + Vector.int8Vector(new byte[]{10, 20, 30}), + Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010000}, (byte) 3) + ); + } + + @ParameterizedTest + @MethodSource("provideVectors") + void shouldCreateBinaryFromVector(final Vector vector) { + // when + Binary binary = new Binary(vector); + + // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), binary.getType(), "The subtype must be VECTOR"); + assertNotNull(binary.getData(), "Binary data should not be null"); + assertArrayEquals(encodeVectorToBinary(vector), binary.getData()); + } + + @Test + void shouldThrowExceptionWhenCreatingBinaryWithNullVector() { + // given + Vector vector = null; + + // when & then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new Binary(vector)); + assertEquals("Vector must not be null", exception.getMessage()); + } + + @ParameterizedTest + @MethodSource("provideVectors") + void shouldConvertBinaryToVector(final Vector actualVector) { + // given + Binary binary 
= new Binary(actualVector); + + // when + Vector decodedVector = binary.asVector(); + + // then + assertNotNull(decodedVector); + assertEquals(actualVector.getDataType(), decodedVector.getDataType()); + assertVectorDataDecoding(actualVector, decodedVector); + } + + private static void assertVectorDataDecoding(final Vector actualVector, final Vector decodedVector) { + switch (actualVector.getDataType()) { + case FLOAT32: + Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); + Float32Vector decodedFloat32Vector1 = decodedVector.asFloat32Vector(); + assertArrayEquals(actualFloat32Vector.getVectorArray(), decodedFloat32Vector1.getVectorArray(), + "Float vector data should match after decoding"); + break; + case INT8: + Int8Vector actualInt8Vector = actualVector.asInt8Vector(); + Int8Vector decodedInt8Vector1 = decodedVector.asInt8Vector(); + assertArrayEquals(actualInt8Vector.getVectorArray(), decodedInt8Vector1.getVectorArray(), + "Int8 vector data should match after decoding"); + break; + case PACKED_BIT: + PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); + PackedBitVector decodedPackedBitVector = decodedVector.asPackedBitVector(); + assertArrayEquals(actualPackedBitVector.getVectorArray(), decodedPackedBitVector.getVectorArray(), + "Packed bit vector data should match after decoding"); + assertEquals(actualPackedBitVector.getPadding(), decodedPackedBitVector.getPadding(), + "Padding should match after decoding"); + break; + default: + fail("Unexpected vector type: " + actualVector.getDataType()); + } + } + + @ParameterizedTest + @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) + void shouldThrowExceptionWhenBinarySubTypeIsNotVector(final BsonBinarySubType binarySubType) { + // given + byte[] data = new byte[]{1, 2, 3, 4}; + Binary binary = new Binary(binarySubType.getValue(), data); + + // when & then + BsonInvalidOperationException exception = 
assertThrows(BsonInvalidOperationException.class, binary::asVector); + assertEquals("type must be a Vector subtype.", exception.getMessage()); + } +} diff --git a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java index c268bb751be..dc1310137ec 100644 --- a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java +++ b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java @@ -21,6 +21,8 @@ import org.bson.BsonDocument; import org.bson.BsonString; import org.bson.BsonValue; +import org.bson.Float32Vector; +import org.bson.PackedBitVector; import org.bson.Vector; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.params.ParameterizedTest; @@ -122,24 +124,29 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) switch (expectedDType) { case INT8: byte[] expectedVectorData = toByteArray(arrayVector); - byte[] actualVectorData = actualVector.asInt8VectorData(); + byte[] actualVectorData = actualVector.asInt8Vector().getVectorArray(); assertVectorDecoding( - expectedCanonicalBsonHex, expectedVectorData, - expectedDType, expectedPadding, actualDecodedDocument, - actualVectorData, actualVector); + expectedCanonicalBsonHex, + expectedVectorData, + expectedDType, + actualDecodedDocument, + actualVectorData, + actualVector); assertThatVectorCreationResultsInCorrectBinary(Vector.int8Vector(expectedVectorData), testKey, actualDecodedDocument, - expectedCanonicalBsonHex, description); + expectedCanonicalBsonHex, + description); break; case PACKED_BIT: - byte[] actualVectorPackedBitData = actualVector.asPackedBitVectorData(); + PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); byte[] expectedVectorPackedBitData = toByteArray(arrayVector); assertVectorDecoding( expectedCanonicalBsonHex, expectedVectorPackedBitData, - expectedDType, expectedPadding, actualDecodedDocument, - actualVectorPackedBitData, actualVector); + 
expectedDType, expectedPadding, + actualDecodedDocument, + actualPackedBitVector); assertThatVectorCreationResultsInCorrectBinary( Vector.packedBitVector(expectedVectorPackedBitData, expectedPadding), @@ -149,12 +156,14 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) description); break; case FLOAT32: - float[] actualFloatVector = actualVector.asFloatVectorData(); + Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); float[] expectedFloatVector = toFloatArray(arrayVector); -// assertVectorDecoding( -// expectedCanonicalBsonHex, expectedFloatVector, -// expectedDType, expectedPadding, actualDecodedDocument, -// actualFloatVector, actualVector); + assertVectorDecoding( + expectedCanonicalBsonHex, + expectedFloatVector, + expectedDType, + actualDecodedDocument, + actualFloat32Vector); assertThatVectorCreationResultsInCorrectBinary( Vector.floatVector(expectedFloatVector), testKey, @@ -185,7 +194,6 @@ private static void assertThatVectorCreationResultsInCorrectBinary(final Vector private void assertVectorDecoding(final String expectedCanonicalBsonHex, final byte[] expectedVectorData, final Vector.Dtype expectedDType, - final byte expectedPadding, final BsonDocument actualDecodedDocument, final byte[] actualVectorData, final Vector actualVector) { @@ -193,21 +201,34 @@ private void assertVectorDecoding(final String expectedCanonicalBsonHex, Assertions.assertArrayEquals(actualVectorData, expectedVectorData, () -> "Actual: " + Arrays.toString(actualVectorData) + " != Expected:" + Arrays.toString(expectedVectorData)); assertEquals(expectedDType, actualVector.getDataType()); + } + + private void assertVectorDecoding(final String expectedCanonicalBsonHex, + final byte[] expectedVectorData, + final Vector.Dtype expectedDType, + final byte expectedPadding, + final BsonDocument actualDecodedDocument, + final PackedBitVector actualVector) { + byte[] actualVectorData = actualVector.getVectorArray(); + 
assertVectorDecoding(expectedCanonicalBsonHex, + expectedVectorData, + expectedDType, + actualDecodedDocument, + actualVectorData, + actualVector); assertEquals(expectedPadding, actualVector.getPadding()); } private void assertVectorDecoding(final String expectedCanonicalBsonHex, final float[] expectedVectorData, final Vector.Dtype expectedDType, - final byte expectedPadding, final BsonDocument actualDecodedDocument, - final float[] actualVectorData, - final Vector actualVector) { + final Float32Vector actualVector) { assertEquals(expectedCanonicalBsonHex, encodeToHex(actualDecodedDocument)); - Assertions.assertArrayEquals(actualVectorData, expectedVectorData, - () -> "Actual: " + Arrays.toString(actualVectorData) + " != Expected:" + Arrays.toString(expectedVectorData)); + float[] actualVectorArray = actualVector.getVectorArray(); + Assertions.assertArrayEquals(actualVectorArray, expectedVectorData, + () -> "Actual: " + Arrays.toString(actualVectorArray) + " != Expected:" + Arrays.toString(expectedVectorData)); assertEquals(expectedDType, actualVector.getDataType()); - assertEquals(expectedPadding, actualVector.getPadding()); } private byte[] toByteArray(final BsonArray arrayVector) { From c68bd7081fcae33ab401d3ebc07b4cc72f1077e3 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 15 Oct 2024 19:56:05 -0700 Subject: [PATCH 03/20] Add Vector codecs. 
JAVA-5544 --- .../org/bson/codecs/Float32VectorCodec.java | 63 +++++++ .../main/org/bson/codecs/Int8VectorCodec.java | 63 +++++++ .../org/bson/codecs/PackedBitVectorCodec.java | 64 ++++++++ .../org/bson/codecs/ValueCodecProvider.java | 6 + .../unit/org/bson/codecs/CodecTestCase.java | 4 +- .../ValueCodecProviderSpecification.groovy | 3 + .../unit/org/bson/codecs/VectorCodecTest.java | 155 ++++++++++++++++++ 7 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 bson/src/main/org/bson/codecs/Float32VectorCodec.java create mode 100644 bson/src/main/org/bson/codecs/Int8VectorCodec.java create mode 100644 bson/src/main/org/bson/codecs/PackedBitVectorCodec.java create mode 100644 bson/src/test/unit/org/bson/codecs/VectorCodecTest.java diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java new file mode 100644 index 00000000000..da39e5b8abf --- /dev/null +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java @@ -0,0 +1,63 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs; + +import org.bson.BSONException; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Float32Vector; + +/** + * Encodes and decodes {@link Float32Vector} objects. 
+ * + * @since BINARY_VECTOR + */ +final class Float32VectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final Float32Vector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public Float32Vector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BSONException("Unexpected BsonBinarySubType"); + } + + return reader.readBinaryData() + .asBinary() + .asVector() + .asFloat32Vector(); + } + + + @Override + public Class getEncoderClass() { + return Float32Vector.class; + } + + @Override + public String toString() { + return "Float32VectorCodec{}"; + } +} + diff --git a/bson/src/main/org/bson/codecs/Int8VectorCodec.java b/bson/src/main/org/bson/codecs/Int8VectorCodec.java new file mode 100644 index 00000000000..7bf79d1b87d --- /dev/null +++ b/bson/src/main/org/bson/codecs/Int8VectorCodec.java @@ -0,0 +1,63 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs; + +import org.bson.BSONException; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Int8Vector; + +/** + * Encodes and decodes {@link Int8Vector} objects. 
+ * + * @since BINARY_VECTOR + */ +final class Int8VectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final Int8Vector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public Int8Vector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BSONException("Unexpected BsonBinarySubType"); + } + + return reader.readBinaryData() + .asBinary() + .asVector() + .asInt8Vector(); + } + + + @Override + public Class getEncoderClass() { + return Int8Vector.class; + } + + @Override + public String toString() { + return "Int8VectorCodec{}"; + } +} + diff --git a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java new file mode 100644 index 00000000000..3bc3cfe19c1 --- /dev/null +++ b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java @@ -0,0 +1,64 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs; + +import org.bson.BSONException; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.PackedBitVector; + +/** + * Encodes and decodes {@link PackedBitVector} objects. 
+ * + * @since BINARY_VECTOR + */ +final class PackedBitVectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final PackedBitVector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public PackedBitVector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BSONException("Unexpected BsonBinarySubType"); + } + + return reader.readBinaryData() + .asBinary() + .asVector() + .asPackedBitVector(); + } + + + @Override + public Class getEncoderClass() { + return PackedBitVector.class; + } + + @Override + public String toString() { + return "PackedBitVectorCodec{}"; + } +} + + diff --git a/bson/src/main/org/bson/codecs/ValueCodecProvider.java b/bson/src/main/org/bson/codecs/ValueCodecProvider.java index 80ec5e6f18d..a9b7c300f9d 100644 --- a/bson/src/main/org/bson/codecs/ValueCodecProvider.java +++ b/bson/src/main/org/bson/codecs/ValueCodecProvider.java @@ -42,6 +42,9 @@ *

<li>{@link org.bson.codecs.StringCodec}</li>
 * <li>{@link org.bson.codecs.SymbolCodec}</li>
 * <li>{@link org.bson.codecs.UuidCodec}</li>
+ * <li>{@link Float32VectorCodec}</li>
+ * <li>{@link Int8VectorCodec}</li>
+ * <li>{@link PackedBitVectorCodec}</li>
 * <li>{@link org.bson.codecs.ByteCodec}</li>
 * <li>{@link org.bson.codecs.ShortCodec}</li>
 * <li>{@link org.bson.codecs.ByteArrayCodec}</li>
  • @@ -86,6 +89,9 @@ private void addCodecs() { addCodec(new StringCodec()); addCodec(new SymbolCodec()); addCodec(new OverridableUuidRepresentationUuidCodec()); + addCodec(new Float32VectorCodec()); + addCodec(new Int8VectorCodec()); + addCodec(new PackedBitVectorCodec()); addCodec(new ByteCodec()); addCodec(new PatternCodec()); diff --git a/bson/src/test/unit/org/bson/codecs/CodecTestCase.java b/bson/src/test/unit/org/bson/codecs/CodecTestCase.java index 17768d0d133..b092121eb9d 100644 --- a/bson/src/test/unit/org/bson/codecs/CodecTestCase.java +++ b/bson/src/test/unit/org/bson/codecs/CodecTestCase.java @@ -85,14 +85,14 @@ public void roundTrip(final Document input, final Document expected) { roundTrip(input, result -> assertEquals(expected, result)); } - OutputBuffer encode(final Codec codec, final T value) { + protected OutputBuffer encode(final Codec codec, final T value) { OutputBuffer buffer = new BasicOutputBuffer(); BsonWriter writer = new BsonBinaryWriter(buffer); codec.encode(writer, value, EncoderContext.builder().build()); return buffer; } - T decode(final Codec codec, final OutputBuffer buffer) { + protected T decode(final Codec codec, final OutputBuffer buffer) { BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(buffer.toByteArray())))); return codec.decode(reader, DecoderContext.builder().build()); } diff --git a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy index c20299715e0..a4054f664ac 100644 --- a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy +++ b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy @@ -17,6 +17,7 @@ package org.bson.codecs import org.bson.Document +import org.bson.Vector import org.bson.codecs.configuration.CodecRegistries import org.bson.types.Binary import org.bson.types.Code @@ -32,6 +33,7 @@ import 
java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicLong import java.util.regex.Pattern +@SuppressWarnings("VectorIsObsolete") class ValueCodecProviderSpecification extends Specification { private final provider = new ValueCodecProvider() private final registry = CodecRegistries.fromProviders(provider) @@ -56,6 +58,7 @@ class ValueCodecProviderSpecification extends Specification { provider.get(Short, registry) instanceof ShortCodec provider.get(byte[], registry) instanceof ByteArrayCodec provider.get(Float, registry) instanceof FloatCodec + provider.get(Vector, registry) instanceof Float32VectorCodec provider.get(Binary, registry) instanceof BinaryCodec provider.get(MinKey, registry) instanceof MinKeyCodec diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java new file mode 100644 index 00000000000..8dd0af1ba65 --- /dev/null +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.bson.codecs; + +import org.bson.BSONException; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Document; +import org.bson.Vector; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.io.OutputBuffer; +import org.bson.types.Binary; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; + +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +class VectorCodecTest extends CodecTestCase { + + private static final CodecRegistry CODEC_REGISTRIES = fromProviders(asList(new ValueCodecProvider(), new DocumentCodecProvider())); + + private static Stream provideVectorsAndCodecsForRoundTrip() { + return Stream.of( + arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new Float32VectorCodec()), + arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new Int8VectorCodec()), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new PackedBitVectorCodec()) + ); + } + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecsForRoundTrip") + void shouldRoundTripVectors(final Vector vectorToEncode) { + //given + 
Document expectedDocument = new Document("vector", vectorToEncode); + + //when + Codec codec = CODEC_REGISTRIES.get(Document.class); + OutputBuffer buffer = encode(codec, expectedDocument); + Document actualDecodedDocument = decode(codec, buffer); + + //then + Binary binaryVector = (Binary) actualDecodedDocument.get("vector"); + assertNotEquals(actualDecodedDocument, expectedDocument); + Vector actualVector = binaryVector.asVector(); + assertEquals(actualVector, vectorToEncode); + } + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecsForRoundTrip") + void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorCodec) { + // given + BsonWriter mockWriter = Mockito.mock(BsonWriter.class); + + // when + vectorCodec.encode(mockWriter, vectorToEncode, EncoderContext.builder().build()); + + // then + verify(mockWriter, times(1)).writeBinaryData(new BsonBinary(vectorToEncode)); + verifyNoMoreInteractions(mockWriter); + } + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecsForRoundTrip") + void shouldDecodeVector(final Vector vectorToDecode, final Codec vectorCodec) { + // given + BsonReader mockReader = Mockito.mock(BsonReader.class); + BsonBinary bsonBinary = new BsonBinary(vectorToDecode); + when(mockReader.peekBinarySubType()).thenReturn(BsonBinarySubType.VECTOR.getValue()); + when(mockReader.readBinaryData()).thenReturn(bsonBinary); + + // when + Vector decodedVector = vectorCodec.decode(mockReader, DecoderContext.builder().build()); + + // then + assertNotNull(decodedVector); + assertEquals(vectorToDecode, decodedVector); + } + + + @ParameterizedTest + @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) + void shouldThrowExceptionForInvalidSubType(final BsonBinarySubType subType) { + // given + BsonReader mockReader = Mockito.mock(BsonReader.class); + when(mockReader.peekBinarySubType()).thenReturn(subType.getValue()); + + Stream.of(new Float32VectorCodec(), new Int8VectorCodec(), new 
PackedBitVectorCodec()) + .forEach(codec -> { + // when & then + BSONException exception = assertThrows(BSONException.class, () -> + codec.decode(mockReader, DecoderContext.builder().build())); + assertEquals("Unexpected BsonBinarySubType", exception.getMessage()); + }); + } + + + @ParameterizedTest + @MethodSource("provideVectorsAndCodecsForRoundTrip") + void shouldReturnCorrectEncoderClass(final Vector vector, final Codec codec) { + // when + Class encoderClass = codec.getEncoderClass(); + + // then + assertEquals(vector.getClass(), encoderClass); + } + + @ParameterizedTest + @MethodSource("provideVectorsCodec") + void shouldConvertToString(final Codec codec) { + // when + String result = codec.toString(); + + // then + assertEquals(codec.getClass().getSimpleName() + "{}", result); + } + + private static Stream> provideVectorsCodec() { + return Stream.of( + new Float32VectorCodec(), + new Int8VectorCodec(), + new PackedBitVectorCodec() + ); + } +} From a50708fea4a75fcaf5ebe00948a44d51c948d288 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 15 Oct 2024 21:57:31 -0700 Subject: [PATCH 04/20] Make subclasses final. 
JAVA-5544 --- bson/src/main/org/bson/Float32Vector.java | 2 +- bson/src/main/org/bson/Int8Vector.java | 2 +- bson/src/main/org/bson/PackedBitVector.java | 2 +- bson/src/main/org/bson/Vector.java | 3 +-- bson/src/main/org/bson/internal/vector/VectorHelper.java | 3 ++- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java index ad2b5973bfa..e3718d77280 100644 --- a/bson/src/main/org/bson/Float32Vector.java +++ b/bson/src/main/org/bson/Float32Vector.java @@ -35,7 +35,7 @@ * @see Binary#asVector() * @since BINARY_VECTOR */ -public class Float32Vector extends Vector { +public final class Float32Vector extends Vector { private final float[] vectorData; diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java index 56520b4de49..c3972d98a31 100644 --- a/bson/src/main/org/bson/Int8Vector.java +++ b/bson/src/main/org/bson/Int8Vector.java @@ -35,7 +35,7 @@ * @see Binary#asVector() * @since BINARY_VECTOR */ -public class Int8Vector extends Vector { +public final class Int8Vector extends Vector { private byte[] vectorData; diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java index 37362ba50fa..43dfe6a060f 100644 --- a/bson/src/main/org/bson/PackedBitVector.java +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -35,7 +35,7 @@ * @see Binary#asVector() * @since BINARY_VECTOR */ -public class PackedBitVector extends Vector { +public final class PackedBitVector extends Vector { private final byte padding; private final byte[] vectorData; diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index e66fb313cdf..1719cc77b4f 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -33,8 +33,7 @@ * @see BsonBinary * @since BINARY_VECTOR */ - -public class Vector { +public abstract class Vector { private final Dtype vectorType; Vector(final 
Dtype vectorType) { diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index a5e9bf8adb5..e926d938116 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -19,6 +19,7 @@ import org.bson.BsonBinary; import org.bson.PackedBitVector; import org.bson.Vector; +import org.bson.types.Binary; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -27,7 +28,7 @@ import static org.bson.assertions.Assertions.isTrue; /** - * Helper class for encoding and decoding vectors to and from {@link BsonBinary}. + * Helper class for encoding and decoding vectors to and from {@link BsonBinary}/{@link Binary}. * *

    * This class is not part of the public API and may be removed or changed at any time. From 7eb252b406308b96b3a6f5ca9f0de74c2edb3619 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 16 Oct 2024 15:23:30 -0700 Subject: [PATCH 05/20] Add functional tests. JAVA-5544 --- .../bson/codecs/macrocodecs/MacroCodec.scala | 1 + .../scala/bson/codecs/MacrosSpec.scala | 1 + bson/src/main/org/bson/Float32Vector.java | 2 +- bson/src/main/org/bson/Int8Vector.java | 2 +- bson/src/main/org/bson/PackedBitVector.java | 2 +- bson/src/main/org/bson/Vector.java | 2 +- .../org/bson/codecs/Float32VectorCodec.java | 1 - .../org/bson/codecs/ValueCodecProvider.java | 1 + .../src/main/org/bson/codecs/VectorCodec.java | 62 +++ .../bson/internal/vector/VectorHelper.java | 28 +- .../unit/org/bson/codecs/VectorCodecTest.java | 17 +- .../internal/vector/VectorHelperTest.java | 65 +++- .../client/vector/VectorFunctionalTest.java | 30 ++ .../vector/VectorAbstractFunctionalTest.java | 359 ++++++++++++++++++ .../client/vector/VectorFunctionalTest.java | 28 ++ 15 files changed, 572 insertions(+), 29 deletions(-) create mode 100644 bson/src/main/org/bson/codecs/VectorCodec.java create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java diff --git a/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala b/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala index 090d066223c..e284647af87 100644 --- a/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala +++ b/bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/MacroCodec.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.bson._ import 
org.bson.codecs.configuration.{ CodecRegistries, CodecRegistry } import org.bson.codecs.{ Codec, DecoderContext, Encoder, EncoderContext } +import scala.collection.immutable.Vector import org.mongodb.scala.bson.BsonNull diff --git a/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala b/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala index 95d7533cc87..c16215a16e8 100644 --- a/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala +++ b/bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala @@ -30,6 +30,7 @@ import org.mongodb.scala.bson.annotations.{ BsonIgnore, BsonProperty } import org.mongodb.scala.bson.codecs.Macros.{ createCodecProvider, createCodecProviderIgnoreNone } import org.mongodb.scala.bson.codecs.Registry.DEFAULT_CODEC_REGISTRY import org.mongodb.scala.bson.collection.immutable.Document +import scala.collection.immutable.Vector import scala.collection.JavaConverters._ import scala.reflect.ClassTag diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java index e3718d77280..367b75830b1 100644 --- a/bson/src/main/org/bson/Float32Vector.java +++ b/bson/src/main/org/bson/Float32Vector.java @@ -57,7 +57,7 @@ public float[] getVectorArray() { } @Override - public final boolean equals(final Object o) { + public boolean equals(final Object o) { if (this == o) { return true; } diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java index c3972d98a31..f3948b57cb1 100644 --- a/bson/src/main/org/bson/Int8Vector.java +++ b/bson/src/main/org/bson/Int8Vector.java @@ -57,7 +57,7 @@ public byte[] getVectorArray() { } @Override - public final boolean equals(final Object o) { + public boolean equals(final Object o) { if (this == o) { return true; } diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java index 43dfe6a060f..450ae80f354 100644 --- 
a/bson/src/main/org/bson/PackedBitVector.java +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -75,7 +75,7 @@ public byte getPadding() { } @Override - public final boolean equals(final Object o) { + public boolean equals(final Object o) { if (this == o) { return true; } diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index 1719cc77b4f..f0a32456ba9 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -63,8 +63,8 @@ public abstract class Vector { * @throws IllegalArgumentException If the padding value is greater than 7. */ public static PackedBitVector packedBitVector(final byte[] vectorData, final byte padding) { - isTrueArgument("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); notNull("Vector data", vectorData); + isTrueArgument("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); isTrue("Padding must be 0 if vector is empty", padding == 0 || vectorData.length > 0); return new PackedBitVector(vectorData, padding); } diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java index da39e5b8abf..d0297662313 100644 --- a/bson/src/main/org/bson/codecs/Float32VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java @@ -49,7 +49,6 @@ public Float32Vector decode(final BsonReader reader, final DecoderContext decode .asFloat32Vector(); } - @Override public Class getEncoderClass() { return Float32Vector.class; diff --git a/bson/src/main/org/bson/codecs/ValueCodecProvider.java b/bson/src/main/org/bson/codecs/ValueCodecProvider.java index a9b7c300f9d..43716faac5b 100644 --- a/bson/src/main/org/bson/codecs/ValueCodecProvider.java +++ b/bson/src/main/org/bson/codecs/ValueCodecProvider.java @@ -89,6 +89,7 @@ private void addCodecs() { addCodec(new StringCodec()); addCodec(new SymbolCodec()); addCodec(new OverridableUuidRepresentationUuidCodec()); + addCodec(new VectorCodec()); 
addCodec(new Float32VectorCodec()); addCodec(new Int8VectorCodec()); addCodec(new PackedBitVectorCodec()); diff --git a/bson/src/main/org/bson/codecs/VectorCodec.java b/bson/src/main/org/bson/codecs/VectorCodec.java new file mode 100644 index 00000000000..f847b222bf3 --- /dev/null +++ b/bson/src/main/org/bson/codecs/VectorCodec.java @@ -0,0 +1,62 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.bson.codecs; + +import org.bson.BSONException; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.BsonReader; +import org.bson.BsonWriter; +import org.bson.Vector; + +/** + * Encodes and decodes {@code Vector} objects. 
+ * + * @since BINARY_VECTOR + */ + final class VectorCodec implements Codec { + + @Override + public void encode(final BsonWriter writer, final Vector vectorToEncode, final EncoderContext encoderContext) { + writer.writeBinaryData(new BsonBinary(vectorToEncode)); + } + + @Override + public Vector decode(final BsonReader reader, final DecoderContext decoderContext) { + byte subType = reader.peekBinarySubType(); + + if (subType != BsonBinarySubType.VECTOR.getValue()) { + throw new BSONException("Unexpected BsonBinarySubType"); + } + + return reader.readBinaryData() + .asBinary() + .asVector(); + } + + @Override + public Class getEncoderClass() { + return Vector.class; + } + + @Override + public String toString() { + return "VectorCodec{}"; + } +} + + diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index e926d938116..1b4324595c1 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -76,43 +76,45 @@ public static Vector decodeBinaryToVector(final byte[] encodedVector) { byte padding = encodedVector[1]; switch (dtype) { case INT8: + isTrue("Padding must be 0 for INT8 data type.", padding == 0); byte[] int8Vector = getVectorBytesWithoutMetadata(encodedVector); return Vector.int8Vector(int8Vector); case PACKED_BIT: byte[] packedBitVector = getVectorBytesWithoutMetadata(encodedVector); + isTrue("Padding must be 0 if vector is empty.", padding == 0 || packedBitVector.length > 0); + isTrue("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); return Vector.packedBitVector(packedBitVector, padding); case FLOAT32: - isTrue("Byte array length must be a multiple of 4 for FLOAT32 dtype.", + isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); + isTrue("Padding must be 0 for FLOAT32 data type.", padding == 0); return 
Vector.floatVector(readLittleEndianFloats(encodedVector)); default: - throw new AssertionError("Unknown vector dtype: " + dtype); + throw new AssertionError("Unknown vector data type: " + dtype); } } private static byte[] getVectorBytesWithoutMetadata(final byte[] encodedVector) { - int vectorDataLength; - byte[] vectorData; - vectorDataLength = encodedVector.length - METADATA_SIZE; - vectorData = new byte[vectorDataLength]; + int vectorDataLength = encodedVector.length - METADATA_SIZE; + byte[] vectorData = new byte[vectorDataLength]; System.arraycopy(encodedVector, METADATA_SIZE, vectorData, 0, vectorDataLength); return vectorData; } - public static byte[] writeVector(final byte dtype, final byte padding, final byte[] vectorData) { + public static byte[] writeVector(final byte dType, final byte padding, final byte[] vectorData) { final byte[] bytes = new byte[vectorData.length + METADATA_SIZE]; - bytes[0] = dtype; + bytes[0] = dType; bytes[1] = padding; System.arraycopy(vectorData, 0, bytes, METADATA_SIZE, vectorData.length); return bytes; } - public static byte[] writeVector(final byte dtype, final byte padding, final float[] vectorData) { + public static byte[] writeVector(final byte dType, final byte padding, final float[] vectorData) { final byte[] bytes = new byte[vectorData.length * FLOAT_SIZE + METADATA_SIZE]; - bytes[0] = dtype; + bytes[0] = dType; bytes[1] = padding; ByteBuffer buffer = ByteBuffer.wrap(bytes); @@ -145,13 +147,13 @@ private static float[] readLittleEndianFloats(final byte[] encodedVector) { return floatArray; } - public static Vector.Dtype determineVectorDType(final byte dtype) { + public static Vector.Dtype determineVectorDType(final byte dType) { Vector.Dtype[] values = Vector.Dtype.values(); for (Vector.Dtype value : values) { - if (value.getValue() == dtype) { + if (value.getValue() == dType) { return value; } } - throw new IllegalStateException("Unknown vector dtype: " + dtype); + throw new IllegalStateException("Unknown vector data 
type: " + dType); } } diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java index 8dd0af1ba65..d3532aedecb 100644 --- a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -22,6 +22,9 @@ import org.bson.BsonReader; import org.bson.BsonWriter; import org.bson.Document; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; import org.bson.Vector; import org.bson.codecs.configuration.CodecRegistry; import org.bson.io.OutputBuffer; @@ -52,9 +55,12 @@ class VectorCodecTest extends CodecTestCase { private static Stream provideVectorsAndCodecsForRoundTrip() { return Stream.of( - arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new Float32VectorCodec()), - arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new Int8VectorCodec()), - arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new PackedBitVectorCodec()) + arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new Float32VectorCodec(), Float32Vector.class), + arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new Int8VectorCodec(), Int8Vector.class), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new PackedBitVectorCodec(), PackedBitVector.class), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new VectorCodec(), Vector.class), + arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new VectorCodec(), Vector.class), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3), new VectorCodec(), Vector.class) ); } @@ -127,12 +133,12 @@ void shouldThrowExceptionForInvalidSubType(final BsonBinarySubType subType) { @ParameterizedTest @MethodSource("provideVectorsAndCodecsForRoundTrip") - void shouldReturnCorrectEncoderClass(final Vector 
vector, final Codec codec) { + void shouldReturnCorrectEncoderClass(final Vector vector, final Codec codec, final Class expectedEncoderClass) { // when Class encoderClass = codec.getEncoderClass(); // then - assertEquals(vector.getClass(), encoderClass); + assertEquals(expectedEncoderClass, encoderClass); } @ParameterizedTest @@ -147,6 +153,7 @@ void shouldConvertToString(final Codec codec) { private static Stream> provideVectorsCodec() { return Stream.of( + new VectorCodec(), new Float32VectorCodec(), new Int8VectorCodec(), new PackedBitVectorCodec() diff --git a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java index 4edf6b489f9..48cdecbf797 100644 --- a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java +++ b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import java.util.stream.Stream; @@ -142,26 +143,78 @@ void shouldDecodePackedBitVector(final PackedBitVector expectedPackedBitVector, private static Stream providePackedBitVectors() { return Stream.of( new Object[]{ - Vector.packedBitVector(new byte[]{(byte) 15, (byte) 240}, (byte) 2), - new byte[]{PACKED_BIT_DTYPE, 2, (byte) 15, (byte) 240} + Vector.packedBitVector(new byte[]{(byte) 0, (byte) 255, (byte) 10}, (byte) 2), + new byte[]{PACKED_BIT_DTYPE, 2, (byte) 0, (byte) 255, (byte) 10} }, new Object[]{ - Vector.packedBitVector(new byte[]{(byte) 170}, (byte) 4), - new byte[]{PACKED_BIT_DTYPE, 4, (byte) 170} + Vector.packedBitVector(new byte[0], (byte) 0), + new byte[]{PACKED_BIT_DTYPE, 0} } ); } @Test void shouldThrowExceptionForInvalidFloatArrayLengthWhenDecode() { - // given: an encoded vector with an invalid length (not a multiple of 4) + // given byte[] invalidData = 
{FLOAT32_DTYPE, 0, 10, 20, 30}; // when & Then IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { VectorHelper.decodeBinaryToVector(invalidData); }); - assertEquals("state should be: Byte array length must be a multiple of 4 for FLOAT32 dtype.", thrown.getMessage()); + assertEquals("state should be: Byte array length must be a multiple of 4 for FLOAT32 data type.", thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidFloatArrayPaddingWhenDecode(final byte invalidPadding) { + // given + byte[] invalidData = {FLOAT32_DTYPE, invalidPadding, 10, 20, 30, 20}; + + // when & Then + IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { + VectorHelper.decodeBinaryToVector(invalidData); + }); + assertEquals("state should be: Padding must be 0 for FLOAT32 data type.", thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidInt8ArrayPaddingWhenDecode(final byte invalidPadding) { + // given + byte[] invalidData = {INT8_DTYPE, invalidPadding, 10, 20, 30, 20}; + + // when & Then + IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { + VectorHelper.decodeBinaryToVector(invalidData); + }); + assertEquals("state should be: Padding must be 0 for INT8 data type.", thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 8}) + void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecode(final byte invalidPadding) { + // given + byte[] invalidData = {PACKED_BIT_DTYPE, invalidPadding, 10, 20, 30, 20}; + + // when & Then + IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { + VectorHelper.decodeBinaryToVector(invalidData); + }); + assertEquals("state should be: Padding must be between 0 and 7 bits.", thrown.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1, 2, 3, 4, 5, 6, 7, 8}) + void 
shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(final byte invalidPadding) { + // given + byte[] invalidData = {PACKED_BIT_DTYPE, invalidPadding}; + + // when & Then + IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { + VectorHelper.decodeBinaryToVector(invalidData); + }); + assertEquals("state should be: Padding must be 0 if vector is empty", thrown.getMessage()); } @Test diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java new file mode 100644 index 00000000000..32bd5385b37 --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mongodb.reactivestreams.client.vector; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.client.vector.VectorAbstractFunctionalTest; +import com.mongodb.reactivestreams.client.MongoClients; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; + +public class VectorFunctionalTest extends VectorAbstractFunctionalTest { + @Override + protected MongoClient getMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(MongoClients.create(settings)); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java new file mode 100644 index 00000000000..e87838d7e00 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java @@ -0,0 +1,359 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mongodb.client.vector; + +import com.mongodb.MongoClientSettings; +import com.mongodb.ReadConcern; +import com.mongodb.ReadPreference; +import com.mongodb.WriteConcern; +import com.mongodb.client.Fixture; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.OperationTest; +import org.bson.BsonBinary; +import org.bson.BsonBinarySubType; +import org.bson.Document; +import org.bson.Float32Vector; +import org.bson.Int8Vector; +import org.bson.PackedBitVector; +import org.bson.Vector; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.codecs.pojo.PojoCodecProvider; +import org.bson.types.Binary; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry; +import static org.bson.Vector.Dtype.FLOAT32; +import static org.bson.Vector.Dtype.INT8; +import static org.bson.Vector.Dtype.PACKED_BIT; +import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; + +public abstract class VectorAbstractFunctionalTest extends OperationTest { + + private static final byte VECTOR_SUBTYPE = BsonBinarySubType.VECTOR.getValue(); + private static final String FIELD_VECTOR = "vector"; + private static final CodecRegistry CODEC_REGISTRY = fromRegistries(getDefaultCodecRegistry(), + fromProviders(PojoCodecProvider + .builder() + .automatic(true).build())); + private MongoCollection documentCollection; + + private MongoClient mongoClient; + + @BeforeEach + public void setUp() { + 
super.beforeEach(); + mongoClient = getMongoClient(getMongoClientSettingsBuilder() + .codecRegistry(CODEC_REGISTRY) + .build()); + documentCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()); + } + + @AfterEach + public void afterEach() { + super.afterEach(); + if (mongoClient != null) { + mongoClient.close(); + } + } + + private MongoClientSettings.Builder getMongoClientSettingsBuilder() { + return Fixture.getMongoClientSettingsBuilder() + .readConcern(ReadConcern.MAJORITY) + .writeConcern(WriteConcern.MAJORITY) + .readPreference(ReadPreference.primary()); + } + + protected abstract MongoClient getMongoClient(MongoClientSettings settings); + + @ParameterizedTest + @ValueSource(bytes = {-1, 1, 2, 3, 4, 5, 6, 7, 8}) + void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(final byte invalidPadding) { + //given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{PACKED_BIT.getValue(), invalidPadding}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + Binary invalidVectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); + Assertions.assertEquals("state should be: Padding must be 0 if vector is empty.", exception.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidFloat32Padding(final byte invalidPadding) { + // given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{FLOAT32.getValue(), invalidPadding, 10, 20, 30, 40}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + Binary invalidVectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, 
invalidVectorBinary::asVector); + Assertions.assertEquals("state should be: Padding must be 0 for FLOAT32 data type.", exception.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 1}) + void shouldThrowExceptionForInvalidInt8Padding(final byte invalidPadding) { + // given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{INT8.getValue(), invalidPadding, 10, 20, 30, 40}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + Binary invalidVectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); + Assertions.assertEquals("state should be: Padding must be 0 for INT8 data type.", exception.getMessage()); + } + + @ParameterizedTest + @ValueSource(bytes = {-1, 8}) + void shouldThrowExceptionForInvalidPackedBitPadding(final byte invalidPadding) { + // given + Binary invalidVector = new Binary(VECTOR_SUBTYPE, new byte[]{PACKED_BIT.getValue(), invalidPadding, 10, 20, 30, 40}); + documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); + + // when & then + Binary invalidVectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); + Assertions.assertEquals("state should be: Padding must be between 0 and 7 bits.", exception.getMessage()); + } + + private static Stream provideValidVectors() { + return Stream.of( + Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), + Vector.int8Vector(new byte[]{10, 20, 30, 40}), + Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3) + ); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVectorWithCodec(final Vector actualVector) { + // Given + Document documentToInsert = new 
Document(FIELD_VECTOR, actualVector); + documentCollection.insertOne(documentToInsert); + + // when & then + Binary vectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + Assertions.assertEquals(actualVector, vectorBinary.asVector()); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVectorWithBinary(final Vector actualVector) { + // given + Document documentToInsert = new Document(FIELD_VECTOR, new Binary(actualVector)); + documentCollection.insertOne(documentToInsert); + + // when & then + Binary vectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + Assertions.assertEquals(actualVector, vectorBinary.asVector()); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVectorWithBsonBinary(final Vector actualVector) { + // Given + Document documentToInsert = new Document(FIELD_VECTOR, new BsonBinary(actualVector)); + documentCollection.insertOne(documentToInsert); + + // when & then + Binary vectorBinary = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Binary.class); + + Assertions.assertEquals(actualVector, vectorBinary.asVector()); + } + + @Test + void shouldStoreAndRetrieveValidVectorWithFloatVectorPojo() { + // given + MongoCollection floatVectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(FloatVectorPojo.class); + Float32Vector vector = Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}); + + // whe + floatVectorPojoMongoCollection.insertOne(new FloatVectorPojo(vector)); + FloatVectorPojo floatVectorPojo = floatVectorPojoMongoCollection.find().first(); + + // then + Assertions.assertNotNull(floatVectorPojo); + Assertions.assertEquals(vector, floatVectorPojo.getVector()); + } + + @Test + void shouldStoreAndRetrieveValidVectorWithInt8VectorPojo() { + // given + MongoCollection floatVectorPojoMongoCollection = 
mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(Int8VectorPojo.class); + Int8Vector vector = Vector.int8Vector(new byte[]{10, 20, 30, 40}); + + // when + floatVectorPojoMongoCollection.insertOne(new Int8VectorPojo(vector)); + Int8VectorPojo int8VectorPojo = floatVectorPojoMongoCollection.find().first(); + + // then + Assertions.assertNotNull(int8VectorPojo); + Assertions.assertEquals(vector, int8VectorPojo.getVector()); + } + + @Test + void shouldStoreAndRetrieveValidVectorWithPackedBitVectorPojo() { + // given + MongoCollection floatVectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(PackedBitVectorPojo.class); + + PackedBitVector vector = Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3); + + // when + floatVectorPojoMongoCollection.insertOne(new PackedBitVectorPojo(vector)); + PackedBitVectorPojo packedBitVectorPojo = floatVectorPojoMongoCollection.find().first(); + + // then + Assertions.assertNotNull(packedBitVectorPojo); + Assertions.assertEquals(vector, packedBitVectorPojo.getVector()); + } + + @ParameterizedTest + @MethodSource("provideValidVectors") + void shouldStoreAndRetrieveValidVectorWithGenericVectorPojo(final Vector actualVector) { + // given + MongoCollection floatVectorPojoMongoCollection = mongoClient + .getDatabase(getDatabaseName()) + .getCollection(getCollectionName()).withDocumentClass(VectorPojo.class); + + // when + floatVectorPojoMongoCollection.insertOne(new VectorPojo(actualVector)); + VectorPojo vectorPojo = floatVectorPojoMongoCollection.find().first(); + + //then + Assertions.assertNotNull(vectorPojo); + Assertions.assertEquals(actualVector, vectorPojo.getVector()); + } + + private Document findExactlyOne(final MongoCollection collection) { + List documents = new ArrayList<>(); + collection.find().into(documents); + if (documents.size() != 1) { + throw new 
IllegalStateException("Expected exactly one document, but found: " + documents.size()); + } + return documents.get(0); + } + + public static class VectorPojo { + private Vector vector; + + public VectorPojo() { + } + + public VectorPojo(final Vector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final Vector vector) { + this.vector = vector; + } + } + + public static class Int8VectorPojo { + private Int8Vector vector; + + public Int8VectorPojo() { + } + + public Int8VectorPojo(final Int8Vector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final Int8Vector vector) { + this.vector = vector; + } + } + + public static class PackedBitVectorPojo { + private PackedBitVector vector; + + public PackedBitVectorPojo() { + } + + public PackedBitVectorPojo(final PackedBitVector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final PackedBitVector vector) { + this.vector = vector; + } + } + + public static class FloatVectorPojo { + private Float32Vector vector; + + public FloatVectorPojo() { + } + + public FloatVectorPojo(final Float32Vector vector) { + this.vector = vector; + } + + public Vector getVector() { + return vector; + } + + public void setVector(final Float32Vector vector) { + this.vector = vector; + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java new file mode 100644 index 00000000000..a0cddb6dbca --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java @@ -0,0 +1,28 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.vector; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; + +public class VectorFunctionalTest extends VectorAbstractFunctionalTest { + @Override + protected MongoClient getMongoClient(final MongoClientSettings settings) { + return MongoClients.create(settings); + } +} From e5bd2b700a8fc955d85dededdc59c5f9782cef72 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 17 Oct 2024 13:56:17 -0700 Subject: [PATCH 06/20] Fix tests. JAVA-5544 --- .../unit/org/bson/BsonBinaryWriterTest.java | 33 +++++++++++++++---- .../ValueCodecProviderSpecification.groovy | 8 ++++- .../internal/vector/VectorHelperTest.java | 2 +- .../bson/vector/VectorGenericBsonTest.java | 24 +++----------- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java index 91fbd3dbf1f..750350d3d98 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java @@ -305,14 +305,33 @@ public void testWriteBinary() { writer.writeBinaryData("b4", new BsonBinary(BsonBinarySubType.VECTOR, new byte[]{FLOAT32_DTYPE, ZERO_PADDING, (byte) 205, (byte) 204, (byte) 140, (byte) 63})); - writer.writeEndDocument(); - - byte[] expectedValues = {49, 0, 0, 0, 5, 98, 49, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 98, 50, 0, - 9, 0, - 0, 0, 2, 5, 0, 0, 0, 1, 1, 1, 1, 1, - 5, 98, 51, 0, 0, 0, 0, 0, 1, 0, - 6, 
BsonBinarySubType.VECTOR.getValue(), FLOAT32_DTYPE, ZERO_PADDING, (byte) 205, (byte) 204, (byte) 140, 63, + byte[] expectedValues = new byte[]{ + 64, // total document length + 0, 0, 0, + + //Binary + (byte) BsonType.BINARY.getValue(), + 98, 49, 0, // name "b1" + 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + // Old binary + (byte) BsonType.BINARY.getValue(), + 98, 50, 0, // name "b2" + 9, 0, 0, 0, 2, 5, 0, 0, 0, 1, 1, 1, 1, 1, + + // Function binary + (byte) BsonType.BINARY.getValue(), + 98, 51, 0, // name "b3" + 0, 0, 0, 0, 1, + + //Vector binary + (byte) BsonType.BINARY.getValue(), + 98, 52, 0, // name "b4" + 6, 0, 0, 0, // total length, int32 (little endian) + BsonBinarySubType.VECTOR.getValue(), FLOAT32_DTYPE, ZERO_PADDING, (byte) 205, (byte) 204, (byte) 140, 63, + + 0 //end of document }; assertArrayEquals(expectedValues, buffer.toByteArray()); diff --git a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy index a4054f664ac..872e4dd6142 100644 --- a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy +++ b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy @@ -17,6 +17,9 @@ package org.bson.codecs import org.bson.Document +import org.bson.Float32Vector +import org.bson.Int8Vector +import org.bson.PackedBitVector import org.bson.Vector import org.bson.codecs.configuration.CodecRegistries import org.bson.types.Binary @@ -58,7 +61,10 @@ class ValueCodecProviderSpecification extends Specification { provider.get(Short, registry) instanceof ShortCodec provider.get(byte[], registry) instanceof ByteArrayCodec provider.get(Float, registry) instanceof FloatCodec - provider.get(Vector, registry) instanceof Float32VectorCodec + provider.get(Vector, registry) instanceof VectorCodec + provider.get(Float32Vector, registry) instanceof Float32VectorCodec + provider.get(Int8Vector, registry) instanceof Int8VectorCodec + 
provider.get(PackedBitVector, registry) instanceof PackedBitVectorCodec provider.get(Binary, registry) instanceof BinaryCodec provider.get(MinKey, registry) instanceof MinKeyCodec diff --git a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java index 48cdecbf797..8087c039bdb 100644 --- a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java +++ b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java @@ -214,7 +214,7 @@ void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(fi IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { VectorHelper.decodeBinaryToVector(invalidData); }); - assertEquals("state should be: Padding must be 0 if vector is empty", thrown.getMessage()); + assertEquals("state should be: Padding must be 0 if vector is empty.", thrown.getMessage()); } @Test diff --git a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java index dc1310137ec..a0da9178294 100644 --- a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java +++ b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java @@ -126,10 +126,8 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) byte[] expectedVectorData = toByteArray(arrayVector); byte[] actualVectorData = actualVector.asInt8Vector().getVectorArray(); assertVectorDecoding( - expectedCanonicalBsonHex, expectedVectorData, expectedDType, - actualDecodedDocument, actualVectorData, actualVector); @@ -143,9 +141,8 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); byte[] expectedVectorPackedBitData = toByteArray(arrayVector); assertVectorDecoding( - expectedCanonicalBsonHex, expectedVectorPackedBitData, + expectedVectorPackedBitData, expectedDType, 
expectedPadding, - actualDecodedDocument, actualPackedBitVector); assertThatVectorCreationResultsInCorrectBinary( @@ -159,10 +156,8 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); float[] expectedFloatVector = toFloatArray(arrayVector); assertVectorDecoding( - expectedCanonicalBsonHex, expectedFloatVector, expectedDType, - actualDecodedDocument, actualFloat32Vector); assertThatVectorCreationResultsInCorrectBinary( Vector.floatVector(expectedFloatVector), @@ -191,40 +186,31 @@ private static void assertThatVectorCreationResultsInCorrectBinary(final Vector format("Failed to create expected BSON for document with description '%s'", description)); } - private void assertVectorDecoding(final String expectedCanonicalBsonHex, - final byte[] expectedVectorData, + private void assertVectorDecoding(final byte[] expectedVectorData, final Vector.Dtype expectedDType, - final BsonDocument actualDecodedDocument, final byte[] actualVectorData, final Vector actualVector) { - assertEquals(expectedCanonicalBsonHex, encodeToHex(actualDecodedDocument)); Assertions.assertArrayEquals(actualVectorData, expectedVectorData, () -> "Actual: " + Arrays.toString(actualVectorData) + " != Expected:" + Arrays.toString(expectedVectorData)); assertEquals(expectedDType, actualVector.getDataType()); } - private void assertVectorDecoding(final String expectedCanonicalBsonHex, - final byte[] expectedVectorData, + private void assertVectorDecoding(final byte[] expectedVectorData, final Vector.Dtype expectedDType, final byte expectedPadding, - final BsonDocument actualDecodedDocument, final PackedBitVector actualVector) { byte[] actualVectorData = actualVector.getVectorArray(); - assertVectorDecoding(expectedCanonicalBsonHex, + assertVectorDecoding( expectedVectorData, expectedDType, - actualDecodedDocument, actualVectorData, actualVector); assertEquals(expectedPadding, actualVector.getPadding()); } - 
private void assertVectorDecoding(final String expectedCanonicalBsonHex, - final float[] expectedVectorData, + private void assertVectorDecoding(final float[] expectedVectorData, final Vector.Dtype expectedDType, - final BsonDocument actualDecodedDocument, final Float32Vector actualVector) { - assertEquals(expectedCanonicalBsonHex, encodeToHex(actualDecodedDocument)); float[] actualVectorArray = actualVector.getVectorArray(); Assertions.assertArrayEquals(actualVectorArray, expectedVectorData, () -> "Actual: " + Arrays.toString(actualVectorArray) + " != Expected:" + Arrays.toString(expectedVectorData)); From 8de3743b0ce76490aed2c7bc2ef4bd735eb887a4 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 17 Oct 2024 14:13:59 -0700 Subject: [PATCH 07/20] Rename Type to DataType. JAVA-5544 --- bson/src/main/org/bson/Float32Vector.java | 2 +- bson/src/main/org/bson/Int8Vector.java | 2 +- bson/src/main/org/bson/PackedBitVector.java | 4 +- bson/src/main/org/bson/Vector.java | 54 +++++++++---------- .../bson/internal/vector/VectorHelper.java | 24 ++++----- .../unit/org/bson/BsonBinaryWriterTest.java | 2 +- bson/src/test/unit/org/bson/VectorTest.java | 8 +-- .../internal/vector/VectorHelperTest.java | 18 +++---- .../bson/vector/VectorGenericBsonTest.java | 10 ++-- .../vector/VectorAbstractFunctionalTest.java | 6 +-- 10 files changed, 65 insertions(+), 65 deletions(-) diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java index 367b75830b1..e456539d272 100644 --- a/bson/src/main/org/bson/Float32Vector.java +++ b/bson/src/main/org/bson/Float32Vector.java @@ -40,7 +40,7 @@ public final class Float32Vector extends Vector { private final float[] vectorData; Float32Vector(final float[] vectorData) { - super(Dtype.FLOAT32); + super(DataType.FLOAT32); this.vectorData = assertNotNull(vectorData); } diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java index f3948b57cb1..5a1669ecae1 100644 --- 
a/bson/src/main/org/bson/Int8Vector.java +++ b/bson/src/main/org/bson/Int8Vector.java @@ -40,7 +40,7 @@ public final class Int8Vector extends Vector { private byte[] vectorData; Int8Vector(final byte[] vectorData) { - super(Dtype.INT8); + super(DataType.INT8); this.vectorData = assertNotNull(vectorData); } diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java index 450ae80f354..6119a7d9427 100644 --- a/bson/src/main/org/bson/PackedBitVector.java +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -41,7 +41,7 @@ public final class PackedBitVector extends Vector { private final byte[] vectorData; PackedBitVector(final byte[] vectorData, final byte padding) { - super(Dtype.PACKED_BIT); + super(DataType.PACKED_BIT); this.vectorData = assertNotNull(vectorData); this.padding = padding; } @@ -64,7 +64,7 @@ public byte[] getVectorArray() { * Returns the padding value for this vector. * *

    Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving the vector data, as not - * all {@link Dtype}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.

    + * all {@link DataType}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.

    *

    * NOTE: The underlying byte array is not copied; changes to the returned array will be reflected in this instance. * diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index f0a32456ba9..40dc25004c7 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -23,7 +23,7 @@ /** * Represents a vector that is stored and retrieved using the BSON Binary Subtype 9 format. - * This class supports multiple vector {@link Dtype}'s and provides static methods to create + * This class supports multiple vector {@link DataType}'s and provides static methods to create * vectors. *

    * Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently @@ -34,16 +34,16 @@ * @since BINARY_VECTOR */ public abstract class Vector { - private final Dtype vectorType; + private final DataType vectorType; - Vector(final Dtype vectorType) { + Vector(final DataType vectorType) { this.vectorType = vectorType; } /** - * Creates a vector with the {@link Dtype#PACKED_BIT} data type. + * Creates a vector with the {@link DataType#PACKED_BIT} data type. *

    - * A {@link Dtype#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte + * A {@link DataType#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte * can hold up to 8 bits (vector elements). The padding parameter is used to specify how many bits in the final byte should be ignored.

    * *

    For example, a vector with two bytes and a padding of 4 would have the following structure:

    @@ -59,7 +59,7 @@ public abstract class Vector { * * @param vectorData The byte array representing the packed bit vector data. Each byte can store 8 bits. * @param padding The number of bits (0 to 7) to ignore in the final byte of the vector data. - * @return A {@link PackedBitVector} instance with the {@link Dtype#PACKED_BIT} data type. + * @return A {@link PackedBitVector} instance with the {@link DataType#PACKED_BIT} data type. * @throws IllegalArgumentException If the padding value is greater than 7. */ public static PackedBitVector packedBitVector(final byte[] vectorData, final byte padding) { @@ -70,16 +70,16 @@ public static PackedBitVector packedBitVector(final byte[] vectorData, final byt } /** - * Creates a vector with the {@link Dtype#INT8} data type. + * Creates a vector with the {@link DataType#INT8} data type. * - *

    A {@link Dtype#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector, + *

    A {@link DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector, * with values in the range [-128, 127].

    *

    * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected * in the created {@link Int8Vector} instance. * - * @param vectorData The byte array representing the {@link Dtype#INT8} vector data. - * @return A {@link Int8Vector} instance with the {@link Dtype#INT8} data type. + * @param vectorData The byte array representing the {@link DataType#INT8} vector data. + * @return A {@link Int8Vector} instance with the {@link DataType#INT8} data type. */ public static Int8Vector int8Vector(final byte[] vectorData) { notNull("vectorData", vectorData); @@ -87,15 +87,15 @@ public static Int8Vector int8Vector(final byte[] vectorData) { } /** - * Creates a vector with the {@link Dtype#FLOAT32} data type. + * Creates a vector with the {@link DataType#FLOAT32} data type. *

    - * A {@link Dtype#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

    + * A {@link DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

    *

    * NOTE: The float array `vectorData` is not copied; changes to the provided array will be reflected * in the created {@link Float32Vector} instance. * - * @param vectorData The float array representing the {@link Dtype#FLOAT32} vector data. - * @return A {@link Float32Vector} instance with the {@link Dtype#FLOAT32} data type. + * @param vectorData The float array representing the {@link DataType#FLOAT32} vector data. + * @return A {@link Float32Vector} instance with the {@link DataType#FLOAT32} data type. */ public static Float32Vector floatVector(final float[] vectorData) { notNull("vectorData", vectorData); @@ -106,11 +106,11 @@ public static Float32Vector floatVector(final float[] vectorData) { * Returns the {@link PackedBitVector}. * * @return {@link PackedBitVector}. - * @throws IllegalStateException if this vector is not of type {@link Dtype#PACKED_BIT}. Use {@link #getDataType()} to check the vector + * @throws IllegalStateException if this vector is not of type {@link DataType#PACKED_BIT}. Use {@link #getDataType()} to check the vector * type before calling this method. */ public PackedBitVector asPackedBitVector() { - ensureType(Dtype.PACKED_BIT); + ensureType(DataType.PACKED_BIT); return (PackedBitVector) this; } @@ -118,11 +118,11 @@ public PackedBitVector asPackedBitVector() { * Returns the {@link Int8Vector}. * * @return {@link Int8Vector}. - * @throws IllegalStateException if this vector is not of type {@link Dtype#INT8}. Use {@link #getDataType()} to check the vector + * @throws IllegalStateException if this vector is not of type {@link DataType#INT8}. Use {@link #getDataType()} to check the vector * type before calling this method. */ public Int8Vector asInt8Vector() { - ensureType(Dtype.INT8); + ensureType(DataType.INT8); return (Int8Vector) this; } @@ -130,25 +130,25 @@ public Int8Vector asInt8Vector() { * Returns the {@link Float32Vector}. * * @return {@link Float32Vector}. 
- * @throws IllegalStateException if this vector is not of type {@link Dtype#FLOAT32}. Use {@link #getDataType()} to check the vector + * @throws IllegalStateException if this vector is not of type {@link DataType#FLOAT32}. Use {@link #getDataType()} to check the vector * type before calling this method. */ public Float32Vector asFloat32Vector() { - ensureType(Dtype.FLOAT32); + ensureType(DataType.FLOAT32); return (Float32Vector) this; } /** - * Returns {@link Dtype} of the vector. + * Returns {@link DataType} of the vector. * * @return the data type of the vector. */ - public Dtype getDataType() { + public DataType getDataType() { return this.vectorType; } - private void ensureType(final Dtype expected) { + private void ensureType(final DataType expected) { if (this.vectorType != expected) { throw new IllegalStateException("Expected vector type " + expected + " but found " + this.vectorType); } @@ -160,7 +160,7 @@ private void ensureType(final Dtype expected) { * Each dtype determines how the data in the vector is stored, including how many bits are used to represent each element * in the vector. */ - public enum Dtype { + public enum DataType { /** * An INT8 vector is a vector of 8-bit signed integers. The vector is stored as an array of bytes, where each byte * represents a signed integer in the range [-128, 127]. @@ -178,16 +178,16 @@ public enum Dtype { private final byte value; - Dtype(final byte value) { + DataType(final byte value) { this.value = value; } /** - * Returns the byte value associated with this {@link Dtype}. + * Returns the byte value associated with this {@link DataType}. * *

    This value is used in the BSON binary format to indicate the data type of the vector.

    * - * @return the byte value representing the {@link Dtype}. + * @return the byte value representing the {@link DataType}. */ public byte getValue() { return value; diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index 1b4324595c1..268ae081fcd 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -49,18 +49,18 @@ private VectorHelper() { private static final int FLOAT_SIZE = 4; public static byte[] encodeVectorToBinary(final Vector vector) { - Vector.Dtype dtype = vector.getDataType(); - switch (dtype) { + Vector.DataType dataType = vector.getDataType(); + switch (dataType) { case INT8: - return writeVector(dtype.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray()); + return writeVector(dataType.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray()); case PACKED_BIT: PackedBitVector packedBitVector = vector.asPackedBitVector(); - return writeVector(dtype.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray()); + return writeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray()); case FLOAT32: - return writeVector(dtype.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray()); + return writeVector(dataType.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray()); default: - throw new AssertionError("Unknown vector dtype: " + dtype); + throw new AssertionError("Unknown vector dtype: " + dataType); } } @@ -72,9 +72,9 @@ public static byte[] encodeVectorToBinary(final Vector vector) { public static Vector decodeBinaryToVector(final byte[] encodedVector) { isTrue("Vector encoded array length must be at least 2.", encodedVector.length >= METADATA_SIZE); - Vector.Dtype dtype = determineVectorDType(encodedVector[0]); + Vector.DataType dataType = determineVectorDType(encodedVector[0]); byte padding = encodedVector[1]; - 
switch (dtype) { + switch (dataType) { case INT8: isTrue("Padding must be 0 for INT8 data type.", padding == 0); byte[] int8Vector = getVectorBytesWithoutMetadata(encodedVector); @@ -91,7 +91,7 @@ public static Vector decodeBinaryToVector(final byte[] encodedVector) { return Vector.floatVector(readLittleEndianFloats(encodedVector)); default: - throw new AssertionError("Unknown vector data type: " + dtype); + throw new AssertionError("Unknown vector data type: " + dataType); } } @@ -147,9 +147,9 @@ private static float[] readLittleEndianFloats(final byte[] encodedVector) { return floatArray; } - public static Vector.Dtype determineVectorDType(final byte dType) { - Vector.Dtype[] values = Vector.Dtype.values(); - for (Vector.Dtype value : values) { + public static Vector.DataType determineVectorDType(final byte dType) { + Vector.DataType[] values = Vector.DataType.values(); + for (Vector.DataType value : values) { if (value.getValue() == dType) { return value; } diff --git a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java index 750350d3d98..c9e22fcce7a 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryWriterTest.java @@ -40,7 +40,7 @@ public class BsonBinaryWriterTest { - private static final byte FLOAT32_DTYPE = Vector.Dtype.FLOAT32.getValue(); + private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue(); private static final int ZERO_PADDING = 0; private BsonBinaryWriter writer; diff --git a/bson/src/test/unit/org/bson/VectorTest.java b/bson/src/test/unit/org/bson/VectorTest.java index f16373a0777..66b0914097b 100644 --- a/bson/src/test/unit/org/bson/VectorTest.java +++ b/bson/src/test/unit/org/bson/VectorTest.java @@ -37,7 +37,7 @@ void shouldCreateInt8Vector() { // then assertNotNull(vector); - assertEquals(Vector.Dtype.INT8, vector.getDataType()); + assertEquals(Vector.DataType.INT8, vector.getDataType()); 
assertArrayEquals(data, vector.getVectorArray()); } @@ -61,7 +61,7 @@ void shouldCreateFloat32Vector() { // then assertNotNull(vector); - assertEquals(Vector.Dtype.FLOAT32, vector.getDataType()); + assertEquals(Vector.DataType.FLOAT32, vector.getDataType()); assertArrayEquals(data, vector.getVectorArray()); } @@ -87,7 +87,7 @@ void shouldCreatePackedBitVector(final byte validPadding) { // then assertNotNull(vector); - assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType()); + assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType()); assertArrayEquals(data, vector.getVectorArray()); assertEquals(validPadding, vector.getPadding()); } @@ -127,7 +127,7 @@ void shouldCreatePackedBitVectorWithZeroPaddingAndEmptyData() { // then assertNotNull(vector); - assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType()); + assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType()); assertArrayEquals(data, vector.getVectorArray()); assertEquals(padding, vector.getPadding()); } diff --git a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java index 8087c039bdb..111b4a6f5ba 100644 --- a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java +++ b/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java @@ -32,9 +32,9 @@ import static org.junit.jupiter.api.Assertions.assertThrows; class VectorHelperTest { - private static final byte FLOAT32_DTYPE = Vector.Dtype.FLOAT32.getValue(); - private static final byte INT8_DTYPE = Vector.Dtype.INT8.getValue(); - private static final byte PACKED_BIT_DTYPE = Vector.Dtype.PACKED_BIT.getValue(); + private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue(); + private static final byte INT8_DTYPE = Vector.DataType.INT8.getValue(); + private static final byte PACKED_BIT_DTYPE = Vector.DataType.PACKED_BIT.getValue(); public static final int ZERO_PADDING = 0; @ParameterizedTest(name = "{index}: {0}") @@ -54,7 +54,7 @@ 
void shouldDecodeFloatVector(final Float32Vector expectedFloatVector, final byte Float32Vector decodedVector = (Float32Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); // then - assertEquals(Vector.Dtype.FLOAT32, decodedVector.getDataType()); + assertEquals(Vector.DataType.FLOAT32, decodedVector.getDataType()); assertArrayEquals(expectedFloatVector.getVectorArray(), decodedVector.getVectorArray()); } @@ -102,7 +102,7 @@ void shouldDecodeInt8Vector(final Int8Vector expectedInt8Vector, final byte[] bs Int8Vector decodedVector = (Int8Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); // then - assertEquals(Vector.Dtype.INT8, decodedVector.getDataType()); + assertEquals(Vector.DataType.INT8, decodedVector.getDataType()); assertArrayEquals(expectedInt8Vector.getVectorArray(), decodedVector.getVectorArray()); } @@ -135,7 +135,7 @@ void shouldDecodePackedBitVector(final PackedBitVector expectedPackedBitVector, PackedBitVector decodedVector = (PackedBitVector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); // then - assertEquals(Vector.Dtype.PACKED_BIT, decodedVector.getDataType()); + assertEquals(Vector.DataType.PACKED_BIT, decodedVector.getDataType()); assertArrayEquals(expectedPackedBitVector.getVectorArray(), decodedVector.getVectorArray()); assertEquals(expectedPackedBitVector.getPadding(), decodedVector.getPadding()); } @@ -220,12 +220,12 @@ void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(fi @Test void shouldDetermineVectorDType() { // given - Vector.Dtype[] values = Vector.Dtype.values(); + Vector.DataType[] values = Vector.DataType.values(); - for (Vector.Dtype value : values) { + for (Vector.DataType value : values) { // when byte dtype = value.getValue(); - Vector.Dtype actual = VectorHelper.determineVectorDType(dtype); + Vector.DataType actual = VectorHelper.determineVectorDType(dtype); // then assertEquals(value, actual); diff --git a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java 
b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java index a0da9178294..d2689ff762f 100644 --- a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java +++ b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java @@ -87,7 +87,7 @@ private void runInvalidTestCase(final BsonDocument testCase) { BsonArray arrayVector = testCase.getArray("vector"); byte expectedPadding = (byte) testCase.getInt32("padding").getValue(); byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); - Vector.Dtype expectedDType = determineVectorDType(dtypeByte); + Vector.DataType expectedDType = determineVectorDType(dtypeByte); switch (expectedDType) { case INT8: @@ -114,7 +114,7 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); byte expectedPadding = (byte) testCase.getInt32("padding").getValue(); - Vector.Dtype expectedDType = determineVectorDType(dtypeByte); + Vector.DataType expectedDType = determineVectorDType(dtypeByte); String expectedCanonicalBsonHex = testCase.getString("canonical_bson").getValue().toUpperCase(); BsonArray arrayVector = testCase.getArray("vector"); @@ -187,7 +187,7 @@ private static void assertThatVectorCreationResultsInCorrectBinary(final Vector } private void assertVectorDecoding(final byte[] expectedVectorData, - final Vector.Dtype expectedDType, + final Vector.DataType expectedDType, final byte[] actualVectorData, final Vector actualVector) { Assertions.assertArrayEquals(actualVectorData, expectedVectorData, @@ -196,7 +196,7 @@ private void assertVectorDecoding(final byte[] expectedVectorData, } private void assertVectorDecoding(final byte[] expectedVectorData, - final Vector.Dtype expectedDType, + final Vector.DataType expectedDType, final byte expectedPadding, final PackedBitVector actualVector) { byte[] actualVectorData = actualVector.getVectorArray(); @@ -209,7 +209,7 @@ private void assertVectorDecoding(final 
byte[] expectedVectorData, } private void assertVectorDecoding(final float[] expectedVectorData, - final Vector.Dtype expectedDType, + final Vector.DataType expectedDType, final Float32Vector actualVector) { float[] actualVectorArray = actualVector.getVectorArray(); Assertions.assertArrayEquals(actualVectorArray, expectedVectorData, diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java index e87838d7e00..2afd544b883 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java @@ -47,9 +47,9 @@ import java.util.stream.Stream; import static com.mongodb.MongoClientSettings.getDefaultCodecRegistry; -import static org.bson.Vector.Dtype.FLOAT32; -import static org.bson.Vector.Dtype.INT8; -import static org.bson.Vector.Dtype.PACKED_BIT; +import static org.bson.Vector.DataType.FLOAT32; +import static org.bson.Vector.DataType.INT8; +import static org.bson.Vector.DataType.PACKED_BIT; import static org.bson.codecs.configuration.CodecRegistries.fromProviders; import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; From d5c3fe99cdaa45d0531eb40938a3d3f501b631e3 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 17 Oct 2024 22:50:28 -0700 Subject: [PATCH 08/20] Add javadoc. Rename methods. 
JAVA-5544 --- bson/src/main/org/bson/BsonBinary.java | 4 +- bson/src/main/org/bson/BsonBinarySubType.java | 9 +++- bson/src/main/org/bson/Float32Vector.java | 2 +- bson/src/main/org/bson/Int8Vector.java | 2 +- bson/src/main/org/bson/PackedBitVector.java | 8 +-- bson/src/main/org/bson/Vector.java | 7 +-- .../org/bson/codecs/Float32VectorCodec.java | 4 +- .../main/org/bson/codecs/Int8VectorCodec.java | 2 +- .../org/bson/codecs/PackedBitVectorCodec.java | 2 +- .../org/bson/codecs/ValueCodecProvider.java | 1 + .../src/main/org/bson/codecs/VectorCodec.java | 4 +- .../bson/internal/vector/VectorHelper.java | 54 +++++++++++-------- bson/src/main/org/bson/types/Binary.java | 4 +- .../test/unit/org/bson/types/BinaryTest.java | 8 +-- 14 files changed, 66 insertions(+), 45 deletions(-) diff --git a/bson/src/main/org/bson/BsonBinary.java b/bson/src/main/org/bson/BsonBinary.java index 9aae40377d3..05c20af94ab 100644 --- a/bson/src/main/org/bson/BsonBinary.java +++ b/bson/src/main/org/bson/BsonBinary.java @@ -96,7 +96,7 @@ public BsonBinary(final UUID uuid) { * Construct a Type 9 BsonBinary from the given Vector. * * @param vector the {@link Vector} - * @since BINARY_VECTOR + * @since 5.3 */ public BsonBinary(final Vector vector) { if (vector == null) { @@ -149,7 +149,7 @@ public UUID asUuid() { * * @return the vector * @throws IllegalArgumentException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. - * @since BINARY_VECTOR + * @since 5.3 */ public Vector asVector() { if (!BsonBinarySubType.isVector(type)) { diff --git a/bson/src/main/org/bson/BsonBinarySubType.java b/bson/src/main/org/bson/BsonBinarySubType.java index bc59541636e..efc3971e5f0 100644 --- a/bson/src/main/org/bson/BsonBinarySubType.java +++ b/bson/src/main/org/bson/BsonBinarySubType.java @@ -76,7 +76,7 @@ public enum BsonBinarySubType { /** * Vector data. 
* - * @since BINARY_VECTOR + * @since 5.3 * @see Vector */ VECTOR((byte) 0x09), @@ -99,6 +99,13 @@ public static boolean isUuid(final byte value) { return value == UUID_LEGACY.getValue() || value == UUID_STANDARD.getValue(); } + /** + * Returns true if the given value is a {@link #VECTOR} subtype. + * + * @param value the subtype value as a byte. + * @return true if value is a {@link #VECTOR} subtype. + * @since 5.3 + */ public static boolean isVector(final byte value) { return value == VECTOR.getValue(); } diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java index e456539d272..8bea1a8fc4b 100644 --- a/bson/src/main/org/bson/Float32Vector.java +++ b/bson/src/main/org/bson/Float32Vector.java @@ -33,7 +33,7 @@ * @see BsonBinary#asVector() * @see Binary#Binary(Vector) * @see Binary#asVector() - * @since BINARY_VECTOR + * @since 5.3 */ public final class Float32Vector extends Vector { diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java index 5a1669ecae1..218eab566c4 100644 --- a/bson/src/main/org/bson/Int8Vector.java +++ b/bson/src/main/org/bson/Int8Vector.java @@ -33,7 +33,7 @@ * @see BsonBinary#asVector() * @see Binary#Binary(Vector) * @see Binary#asVector() - * @since BINARY_VECTOR + * @since 5.3 */ public final class Int8Vector extends Vector { diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java index 6119a7d9427..bd4f1f3ee50 100644 --- a/bson/src/main/org/bson/PackedBitVector.java +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -33,7 +33,7 @@ * @see BsonBinary#asVector() * @see Binary#Binary(Vector) * @see Binary#asVector() - * @since BINARY_VECTOR + * @since 5.3 */ public final class PackedBitVector extends Vector { @@ -63,9 +63,11 @@ public byte[] getVectorArray() { /** * Returns the padding value for this vector. * - *

    Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving the vector data, as not - * all {@link DataType}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.

    + *

    Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving + * {@linkplain #getVectorArray() the vector array}. For instance, if the padding value is 3, this means that the last byte contains + * 3 least-significant unused bits, which should be disregarded during operations.

    *

    + * * NOTE: The underlying byte array is not copied; changes to the returned array will be reflected in this instance. * * @return the padding value (between 0 and 7). diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index 40dc25004c7..bf1ea52ba30 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -31,7 +31,7 @@ * * @mongodb.server.release 6.0 * @see BsonBinary - * @since BINARY_VECTOR + * @since 5.3 */ public abstract class Vector { private final DataType vectorType; @@ -44,7 +44,8 @@ public abstract class Vector { * Creates a vector with the {@link DataType#PACKED_BIT} data type. *

    * A {@link DataType#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte - * can hold up to 8 bits (vector elements). The padding parameter is used to specify how many bits in the final byte should be ignored.

    + * can hold up to 8 bits (vector elements). The padding parameter is used to specify how many least-significant bits in the final byte + * should be ignored.

    * *

    For example, a vector with two bytes and a padding of 4 would have the following structure:

    *
    @@ -58,7 +59,7 @@ public abstract class Vector {
          * in the created {@link PackedBitVector} instance.
          *
          * @param vectorData The byte array representing the packed bit vector data. Each byte can store 8 bits.
    -     * @param padding    The number of bits (0 to 7) to ignore in the final byte of the vector data.
    +     * @param padding    The number of least-significant bits (0 to 7) to ignore in the final byte of the vector data.
          * @return A {@link PackedBitVector} instance with the {@link DataType#PACKED_BIT} data type.
          * @throws IllegalArgumentException If the padding value is greater than 7.
          */
    diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java
    index d0297662313..a596cb51f5d 100644
    --- a/bson/src/main/org/bson/codecs/Float32VectorCodec.java
    +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java
    @@ -24,9 +24,9 @@
     import org.bson.Float32Vector;
     
     /**
    - * Encodes and decodes {@code Vector} objects.
    + * Encodes and decodes {@link Float32Vector} objects.
      *
    - * @since BINARY_VECTOR
    + * @since 5.3
      */
     final class Float32VectorCodec implements Codec {
     
    diff --git a/bson/src/main/org/bson/codecs/Int8VectorCodec.java b/bson/src/main/org/bson/codecs/Int8VectorCodec.java
    index 7bf79d1b87d..2c548d88f71 100644
    --- a/bson/src/main/org/bson/codecs/Int8VectorCodec.java
    +++ b/bson/src/main/org/bson/codecs/Int8VectorCodec.java
    @@ -26,7 +26,7 @@
     /**
      * Encodes and decodes {@link Int8Vector} objects.
      *
    - * @since BINARY_VECTOR
    + * @since 5.3
      */
     final class Int8VectorCodec implements Codec {
     
    diff --git a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java
    index 3bc3cfe19c1..8ecaba4c396 100644
    --- a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java
    +++ b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java
    @@ -26,7 +26,7 @@
     /**
      * Encodes and decodes {@link PackedBitVector} objects.
      *
    - * @since BINARY_VECTOR
    + * @since 5.3
      */
     final class PackedBitVectorCodec implements Codec {
     
    diff --git a/bson/src/main/org/bson/codecs/ValueCodecProvider.java b/bson/src/main/org/bson/codecs/ValueCodecProvider.java
    index 43716faac5b..3a921c1b08a 100644
    --- a/bson/src/main/org/bson/codecs/ValueCodecProvider.java
    +++ b/bson/src/main/org/bson/codecs/ValueCodecProvider.java
    @@ -42,6 +42,7 @@
      *     
  • {@link org.bson.codecs.StringCodec}
  • *
  • {@link org.bson.codecs.SymbolCodec}
  • *
  • {@link org.bson.codecs.UuidCodec}
  • + *
  • {@link VectorCodec}
  • *
  • {@link Float32VectorCodec}
  • *
  • {@link Int8VectorCodec}
  • *
  • {@link PackedBitVectorCodec}
  • diff --git a/bson/src/main/org/bson/codecs/VectorCodec.java b/bson/src/main/org/bson/codecs/VectorCodec.java index f847b222bf3..60b2a0553b3 100644 --- a/bson/src/main/org/bson/codecs/VectorCodec.java +++ b/bson/src/main/org/bson/codecs/VectorCodec.java @@ -24,9 +24,9 @@ import org.bson.Vector; /** - * Encodes and decodes {@code Vector} objects. + * Encodes and decodes {@link Vector} objects. * - * @since BINARY_VECTOR + * @since 5.3 */ final class VectorCodec implements Codec { diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index 268ae081fcd..4880d92ca69 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -17,6 +17,8 @@ package org.bson.internal.vector; import org.bson.BsonBinary; +import org.bson.Float32Vector; +import org.bson.Int8Vector; import org.bson.PackedBitVector; import org.bson.Vector; import org.bson.types.Binary; @@ -52,13 +54,12 @@ public static byte[] encodeVectorToBinary(final Vector vector) { Vector.DataType dataType = vector.getDataType(); switch (dataType) { case INT8: - return writeVector(dataType.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray()); + return encodeVector(dataType.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray()); case PACKED_BIT: PackedBitVector packedBitVector = vector.asPackedBitVector(); - return writeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray()); + return encodeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray()); case FLOAT32: - return writeVector(dataType.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray()); - + return encodeVector(dataType.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray()); default: throw new AssertionError("Unknown vector dtype: " + dataType); } @@ -71,39 +72,48 @@ public static byte[] 
encodeVectorToBinary(final Vector vector) { */ public static Vector decodeBinaryToVector(final byte[] encodedVector) { isTrue("Vector encoded array length must be at least 2.", encodedVector.length >= METADATA_SIZE); - Vector.DataType dataType = determineVectorDType(encodedVector[0]); byte padding = encodedVector[1]; switch (dataType) { case INT8: - isTrue("Padding must be 0 for INT8 data type.", padding == 0); - byte[] int8Vector = getVectorBytesWithoutMetadata(encodedVector); - return Vector.int8Vector(int8Vector); + return decodeInt8Vector(encodedVector, padding); case PACKED_BIT: - byte[] packedBitVector = getVectorBytesWithoutMetadata(encodedVector); - isTrue("Padding must be 0 if vector is empty.", padding == 0 || packedBitVector.length > 0); - isTrue("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); - return Vector.packedBitVector(packedBitVector, padding); + return decodePackedBitVector(encodedVector, padding); case FLOAT32: - isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", - (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); - isTrue("Padding must be 0 for FLOAT32 data type.", padding == 0); - return Vector.floatVector(readLittleEndianFloats(encodedVector)); - + return decodeFloat32Vector(encodedVector, padding); default: throw new AssertionError("Unknown vector data type: " + dataType); } } - private static byte[] getVectorBytesWithoutMetadata(final byte[] encodedVector) { + private static Float32Vector decodeFloat32Vector(final byte[] encodedVector, final byte padding) { + isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", + (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); + isTrue("Padding must be 0 for FLOAT32 data type.", padding == 0); + return Vector.floatVector(decodeLittleEndianFloats(encodedVector)); + } + + private static PackedBitVector decodePackedBitVector(final byte[] encodedVector, final byte padding) { + byte[] packedBitVector = 
extractVectorData(encodedVector); + isTrue("Padding must be 0 if vector is empty.", padding == 0 || packedBitVector.length > 0); + isTrue("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); + return Vector.packedBitVector(packedBitVector, padding); + } + + private static Int8Vector decodeInt8Vector(final byte[] encodedVector, final byte padding) { + isTrue("Padding must be 0 for INT8 data type.", padding == 0); + byte[] int8Vector = extractVectorData(encodedVector); + return Vector.int8Vector(int8Vector); + } + + private static byte[] extractVectorData(final byte[] encodedVector) { int vectorDataLength = encodedVector.length - METADATA_SIZE; byte[] vectorData = new byte[vectorDataLength]; System.arraycopy(encodedVector, METADATA_SIZE, vectorData, 0, vectorDataLength); return vectorData; } - - public static byte[] writeVector(final byte dType, final byte padding, final byte[] vectorData) { + public static byte[] encodeVector(final byte dType, final byte padding, final byte[] vectorData) { final byte[] bytes = new byte[vectorData.length + METADATA_SIZE]; bytes[0] = dType; bytes[1] = padding; @@ -111,7 +121,7 @@ public static byte[] writeVector(final byte dType, final byte padding, final byt return bytes; } - public static byte[] writeVector(final byte dType, final byte padding, final float[] vectorData) { + public static byte[] encodeVector(final byte dType, final byte padding, final float[] vectorData) { final byte[] bytes = new byte[vectorData.length * FLOAT_SIZE + METADATA_SIZE]; bytes[0] = dType; @@ -131,7 +141,7 @@ public static byte[] writeVector(final byte dType, final byte padding, final flo return bytes; } - private static float[] readLittleEndianFloats(final byte[] encodedVector) { + private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { int vectorSize = encodedVector.length - METADATA_SIZE; int numFloats = vectorSize / FLOAT_SIZE; diff --git a/bson/src/main/org/bson/types/Binary.java 
b/bson/src/main/org/bson/types/Binary.java index f1c8db03562..186d7544a9d 100644 --- a/bson/src/main/org/bson/types/Binary.java +++ b/bson/src/main/org/bson/types/Binary.java @@ -76,7 +76,7 @@ public Binary(final byte type, final byte[] data) { * Construct a Type 9 BsonBinary from the given Vector. * * @param vector the {@link Vector} - * @since BINARY_VECTOR + * @since 5.3 */ public Binary(final Vector vector) { if (vector == null) { @@ -91,7 +91,7 @@ public Binary(final Vector vector) { * * @return the vector * @throws IllegalArgumentException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. - * @since BINARY_VECTOR + * @since 5.3 */ public Vector asVector() { if (!BsonBinarySubType.isVector(type)) { diff --git a/bson/src/test/unit/org/bson/types/BinaryTest.java b/bson/src/test/unit/org/bson/types/BinaryTest.java index c1defa310c9..cef524be3b3 100644 --- a/bson/src/test/unit/org/bson/types/BinaryTest.java +++ b/bson/src/test/unit/org/bson/types/BinaryTest.java @@ -86,14 +86,14 @@ private static void assertVectorDataDecoding(final Vector actualVector, final Ve switch (actualVector.getDataType()) { case FLOAT32: Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); - Float32Vector decodedFloat32Vector1 = decodedVector.asFloat32Vector(); - assertArrayEquals(actualFloat32Vector.getVectorArray(), decodedFloat32Vector1.getVectorArray(), + Float32Vector decodedFloat32Vector = decodedVector.asFloat32Vector(); + assertArrayEquals(actualFloat32Vector.getVectorArray(), decodedFloat32Vector.getVectorArray(), "Float vector data should match after decoding"); break; case INT8: Int8Vector actualInt8Vector = actualVector.asInt8Vector(); - Int8Vector decodedInt8Vector1 = decodedVector.asInt8Vector(); - assertArrayEquals(actualInt8Vector.getVectorArray(), decodedInt8Vector1.getVectorArray(), + Int8Vector decodedInt8Vector = decodedVector.asInt8Vector(); + assertArrayEquals(actualInt8Vector.getVectorArray(), decodedInt8Vector.getVectorArray(), "Int8 
vector data should match after decoding"); break; case PACKED_BIT: From d3c37898bc499134e5a9ff022b11bbc4a2e44077 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 17 Oct 2024 22:56:52 -0700 Subject: [PATCH 09/20] Move validation to specific method. JAVA-5544 --- bson/src/main/org/bson/internal/vector/VectorHelper.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index 4880d92ca69..bb0c5d28084 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -87,8 +87,6 @@ public static Vector decodeBinaryToVector(final byte[] encodedVector) { } private static Float32Vector decodeFloat32Vector(final byte[] encodedVector, final byte padding) { - isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", - (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); isTrue("Padding must be 0 for FLOAT32 data type.", padding == 0); return Vector.floatVector(decodeLittleEndianFloats(encodedVector)); } @@ -142,6 +140,9 @@ public static byte[] encodeVector(final byte dType, final byte padding, final fl } private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { + isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", + (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); + int vectorSize = encodedVector.length - METADATA_SIZE; int numFloats = vectorSize / FLOAT_SIZE; From a71779ea877945572aaeecac47b75ad8627738a1 Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Wed, 23 Oct 2024 18:49:25 -0700 Subject: [PATCH 10/20] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- bson/src/main/org/bson/BsonBinary.java | 6 +++--- bson/src/main/org/bson/BsonBinarySubType.java | 1 + bson/src/main/org/bson/Vector.java | 7 +++++-- bson/src/main/org/bson/codecs/Float32VectorCodec.java | 6 
------ bson/src/main/org/bson/codecs/Int8VectorCodec.java | 5 ----- bson/src/main/org/bson/codecs/VectorCodec.java | 5 ----- bson/src/main/org/bson/internal/vector/VectorHelper.java | 6 +++--- .../client/vector/VectorAbstractFunctionalTest.java | 4 ++-- 8 files changed, 14 insertions(+), 26 deletions(-) diff --git a/bson/src/main/org/bson/BsonBinary.java b/bson/src/main/org/bson/BsonBinary.java index 05c20af94ab..d2592df8cdb 100644 --- a/bson/src/main/org/bson/BsonBinary.java +++ b/bson/src/main/org/bson/BsonBinary.java @@ -93,7 +93,7 @@ public BsonBinary(final UUID uuid) { } /** - * Construct a Type 9 BsonBinary from the given Vector. + * Construct a {@linkplain BsonBinarySubType#VECTOR subtype 9} {@link BsonBinary} from the given {@link @Vector}. * * @param vector the {@link Vector} * @since 5.3 @@ -145,10 +145,10 @@ public UUID asUuid() { } /** - * Returns the binary as a {@link Vector}. The binary type must be 9. + * Returns the binary as a {@link Vector}. The {@linkplain #getType() subtype} must be {@linkplain BsonBinarySubType#VECTOR 9}. * * @return the vector - * @throws IllegalArgumentException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. + * @throws BsonInvalidOperationException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. * @since 5.3 */ public Vector asVector() { diff --git a/bson/src/main/org/bson/BsonBinarySubType.java b/bson/src/main/org/bson/BsonBinarySubType.java index efc3971e5f0..541cd224999 100644 --- a/bson/src/main/org/bson/BsonBinarySubType.java +++ b/bson/src/main/org/bson/BsonBinarySubType.java @@ -76,6 +76,7 @@ public enum BsonBinarySubType { /** * Vector data. 
* + * @mongodb.server.release 6.0 * @since 5.3 * @see Vector */ diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index bf1ea52ba30..1a85757dd45 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -64,9 +64,9 @@ public abstract class Vector { * @throws IllegalArgumentException If the padding value is greater than 7. */ public static PackedBitVector packedBitVector(final byte[] vectorData, final byte padding) { - notNull("Vector data", vectorData); + notNull("vectorData", vectorData); isTrueArgument("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); - isTrue("Padding must be 0 if vector is empty", padding == 0 || vectorData.length > 0); + isTrueArgument("Padding must be 0 if vector is empty", padding == 0 || vectorData.length > 0); return new PackedBitVector(vectorData, padding); } @@ -160,6 +160,9 @@ private void ensureType(final DataType expected) { *

    * Each dtype determines how the data in the vector is stored, including how many bits are used to represent each element * in the vector. + * + * @mongodb.server.release 6.0 + * @since 5.3 */ public enum DataType { /** diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java index a596cb51f5d..9182ce43fc0 100644 --- a/bson/src/main/org/bson/codecs/Float32VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java @@ -54,9 +54,3 @@ public Class getEncoderClass() { return Float32Vector.class; } - @Override - public String toString() { - return "Float32VectorCodec{}"; - } -} - diff --git a/bson/src/main/org/bson/codecs/Int8VectorCodec.java b/bson/src/main/org/bson/codecs/Int8VectorCodec.java index 2c548d88f71..3adeb599040 100644 --- a/bson/src/main/org/bson/codecs/Int8VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Int8VectorCodec.java @@ -54,10 +54,5 @@ public Int8Vector decode(final BsonReader reader, final DecoderContext decoderCo public Class getEncoderClass() { return Int8Vector.class; } - - @Override - public String toString() { - return "Int8VectorCodec{}"; - } } diff --git a/bson/src/main/org/bson/codecs/VectorCodec.java b/bson/src/main/org/bson/codecs/VectorCodec.java index 60b2a0553b3..a91b009760e 100644 --- a/bson/src/main/org/bson/codecs/VectorCodec.java +++ b/bson/src/main/org/bson/codecs/VectorCodec.java @@ -52,11 +52,6 @@ public Vector decode(final BsonReader reader, final DecoderContext decoderContex public Class getEncoderClass() { return Vector.class; } - - @Override - public String toString() { - return "VectorCodec{}"; - } } diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index bb0c5d28084..1bed4c97abd 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -111,7 +111,7 @@ private static byte[] 
extractVectorData(final byte[] encodedVector) { return vectorData; } - public static byte[] encodeVector(final byte dType, final byte padding, final byte[] vectorData) { + private static byte[] encodeVector(final byte dType, final byte padding, final byte[] vectorData) { final byte[] bytes = new byte[vectorData.length + METADATA_SIZE]; bytes[0] = dType; bytes[1] = padding; @@ -119,7 +119,7 @@ public static byte[] encodeVector(final byte dType, final byte padding, final by return bytes; } - public static byte[] encodeVector(final byte dType, final byte padding, final float[] vectorData) { + private static byte[] encodeVector(final byte dType, final byte padding, final float[] vectorData) { final byte[] bytes = new byte[vectorData.length * FLOAT_SIZE + METADATA_SIZE]; bytes[0] = dType; @@ -158,7 +158,7 @@ private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { return floatArray; } - public static Vector.DataType determineVectorDType(final byte dType) { + private static Vector.DataType determineVectorDType(final byte dType) { Vector.DataType[] values = Vector.DataType.values(); for (Vector.DataType value : values) { if (value.getValue() == dType) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java index 2afd544b883..2ad55bbed60 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java @@ -84,7 +84,7 @@ public void afterEach() { } } - private MongoClientSettings.Builder getMongoClientSettingsBuilder() { + private static MongoClientSettings.Builder getMongoClientSettingsBuilder() { return Fixture.getMongoClientSettingsBuilder() .readConcern(ReadConcern.MAJORITY) .writeConcern(WriteConcern.MAJORITY) @@ -163,7 +163,7 @@ private static Stream 
provideValidVectors() { @ParameterizedTest @MethodSource("provideValidVectors") - void shouldStoreAndRetrieveValidVectorWithCodec(final Vector actualVector) { + void shouldStoreAndRetrieveValidVector(final Vector actualVector) { // Given Document documentToInsert = new Document(FIELD_VECTOR, actualVector); documentCollection.insertOne(documentToInsert); From ec7cafe7e1b9306086935cd211d9138ca216319f Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Thu, 24 Oct 2024 11:02:24 -0700 Subject: [PATCH 11/20] Update bson/src/main/org/bson/PackedBitVector.java Co-authored-by: Valentin Kovalenko --- bson/src/main/org/bson/PackedBitVector.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java index bd4f1f3ee50..3f9b2dc84e1 100644 --- a/bson/src/main/org/bson/PackedBitVector.java +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -91,9 +91,7 @@ public boolean equals(final Object o) { @Override public int hashCode() { - int result = padding; - result = 31 * result + Arrays.hashCode(vectorData); - return result; + return Objects.hash(padding, Arrays.hashCode(vectorData)); } @Override From d3a2287c13c913859851ff11767dba2d529ca32c Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Fri, 25 Oct 2024 15:35:55 -0700 Subject: [PATCH 12/20] Add decoding to Vector in DocumentCodec. Add javadocs. Use state verification instead of behaviour verification in tests. 
JAVA-5544 --- bson/src/main/org/bson/BsonBinary.java | 4 +- bson/src/main/org/bson/BsonBinarySubType.java | 11 -- bson/src/main/org/bson/Float32Vector.java | 23 +-- bson/src/main/org/bson/Int8Vector.java | 26 ++- bson/src/main/org/bson/PackedBitVector.java | 28 ++- bson/src/main/org/bson/Vector.java | 51 +++--- .../org/bson/codecs/ContainerCodecHelper.java | 65 ++++--- .../org/bson/codecs/Float32VectorCodec.java | 4 +- .../main/org/bson/codecs/Int8VectorCodec.java | 2 +- .../org/bson/codecs/PackedBitVectorCodec.java | 8 +- .../src/main/org/bson/codecs/VectorCodec.java | 3 +- .../bson/internal/vector/VectorHelper.java | 45 ++--- bson/src/main/org/bson/types/Binary.java | 34 ---- .../resources/bson-binary-vector/README.md | 40 ----- .../test/unit/org/bson/BsonBinaryTest.java | 117 ------------- ...perTest.java => BsonBinaryVectorTest.java} | 160 ++++++++++-------- bson/src/test/unit/org/bson/VectorTest.java | 26 +-- .../unit/org/bson/codecs/CodecTestCase.java | 8 +- .../org/bson/codecs/DocumentCodecTest.java | 20 ++- .../unit/org/bson/codecs/VectorCodecTest.java | 118 ++++++------- .../test/unit/org/bson/types/BinaryTest.java | 123 -------------- .../bson/vector/VectorGenericBsonTest.java | 46 ++--- .../client/vector/VectorFunctionalTest.java | 4 +- ...java => AbstractVectorFunctionalTest.java} | 98 +++++------ .../client/vector/VectorFunctionalTest.java | 2 +- 25 files changed, 390 insertions(+), 676 deletions(-) delete mode 100644 bson/src/test/resources/bson-binary-vector/README.md delete mode 100644 bson/src/test/unit/org/bson/BsonBinaryTest.java rename bson/src/test/unit/org/bson/{internal/vector/VectorHelperTest.java => BsonBinaryVectorTest.java} (55%) delete mode 100644 bson/src/test/unit/org/bson/types/BinaryTest.java rename driver-sync/src/test/functional/com/mongodb/client/vector/{VectorAbstractFunctionalTest.java => AbstractVectorFunctionalTest.java} (76%) diff --git a/bson/src/main/org/bson/BsonBinary.java b/bson/src/main/org/bson/BsonBinary.java index 
d2592df8cdb..8590c2920be 100644 --- a/bson/src/main/org/bson/BsonBinary.java +++ b/bson/src/main/org/bson/BsonBinary.java @@ -93,7 +93,7 @@ public BsonBinary(final UUID uuid) { } /** - * Construct a {@linkplain BsonBinarySubType#VECTOR subtype 9} {@link BsonBinary} from the given {@link @Vector}. + * Constructs a {@linkplain BsonBinarySubType#VECTOR subtype 9} {@link BsonBinary} from the given {@link Vector}. * * @param vector the {@link Vector} * @since 5.3 @@ -152,7 +152,7 @@ public UUID asUuid() { * @since 5.3 */ public Vector asVector() { - if (!BsonBinarySubType.isVector(type)) { + if (type != BsonBinarySubType.VECTOR.getValue()) { throw new BsonInvalidOperationException("type must be a Vector subtype."); } diff --git a/bson/src/main/org/bson/BsonBinarySubType.java b/bson/src/main/org/bson/BsonBinarySubType.java index 541cd224999..7b5948b4efc 100644 --- a/bson/src/main/org/bson/BsonBinarySubType.java +++ b/bson/src/main/org/bson/BsonBinarySubType.java @@ -100,17 +100,6 @@ public static boolean isUuid(final byte value) { return value == UUID_LEGACY.getValue() || value == UUID_STANDARD.getValue(); } - /** - * Returns true if the given value is a {@link #VECTOR} subtype. - * - * @param value the subtype value as a byte. - * @return true if value is a {@link #VECTOR} subtype. 
- * @since 5.3 - */ - public static boolean isVector(final byte value) { - return value == VECTOR.getValue(); - } - BsonBinarySubType(final byte value) { this.value = value; } diff --git a/bson/src/main/org/bson/Float32Vector.java b/bson/src/main/org/bson/Float32Vector.java index 8bea1a8fc4b..9678003b72f 100644 --- a/bson/src/main/org/bson/Float32Vector.java +++ b/bson/src/main/org/bson/Float32Vector.java @@ -16,8 +16,6 @@ package org.bson; -import org.bson.types.Binary; - import java.util.Arrays; import static org.bson.assertions.Assertions.assertNotNull; @@ -31,17 +29,15 @@ * @see Vector#floatVector(float[]) * @see BsonBinary#BsonBinary(Vector) * @see BsonBinary#asVector() - * @see Binary#Binary(Vector) - * @see Binary#asVector() * @since 5.3 */ public final class Float32Vector extends Vector { - private final float[] vectorData; + private final float[] data; Float32Vector(final float[] vectorData) { super(DataType.FLOAT32); - this.vectorData = assertNotNull(vectorData); + this.data = assertNotNull(vectorData); } /** @@ -52,8 +48,8 @@ public final class Float32Vector extends Vector { * * @return the underlying float array representing this {@link Float32Vector} vector. 
*/ - public float[] getVectorArray() { - return assertNotNull(vectorData); + public float[] getData() { + return assertNotNull(data); } @Override @@ -61,24 +57,23 @@ public boolean equals(final Object o) { if (this == o) { return true; } - if (!(o instanceof Float32Vector)) { + if (o == null || getClass() != o.getClass()) { return false; } - Float32Vector that = (Float32Vector) o; - return Arrays.equals(vectorData, that.vectorData); + return Arrays.equals(data, that.data); } @Override public int hashCode() { - return Arrays.hashCode(vectorData); + return Arrays.hashCode(data); } @Override public String toString() { return "Float32Vector{" - + "vectorData=" + Arrays.toString(vectorData) - + ", vectorType=" + getDataType() + + "data=" + Arrays.toString(data) + + ", dataType=" + getDataType() + '}'; } } diff --git a/bson/src/main/org/bson/Int8Vector.java b/bson/src/main/org/bson/Int8Vector.java index 218eab566c4..b61e6bfee55 100644 --- a/bson/src/main/org/bson/Int8Vector.java +++ b/bson/src/main/org/bson/Int8Vector.java @@ -16,9 +16,8 @@ package org.bson; -import org.bson.types.Binary; - import java.util.Arrays; +import java.util.Objects; import static org.bson.assertions.Assertions.assertNotNull; @@ -31,17 +30,15 @@ * @see Vector#int8Vector(byte[]) * @see BsonBinary#BsonBinary(Vector) * @see BsonBinary#asVector() - * @see Binary#Binary(Vector) - * @see Binary#asVector() * @since 5.3 */ public final class Int8Vector extends Vector { - private byte[] vectorData; + private byte[] data; - Int8Vector(final byte[] vectorData) { + Int8Vector(final byte[] data) { super(DataType.INT8); - this.vectorData = assertNotNull(vectorData); + this.data = assertNotNull(data); } /** @@ -52,8 +49,8 @@ public final class Int8Vector extends Vector { * * @return the underlying byte array representing this {@link Int8Vector} vector. 
*/ - public byte[] getVectorArray() { - return assertNotNull(vectorData); + public byte[] getData() { + return assertNotNull(data); } @Override @@ -61,24 +58,23 @@ public boolean equals(final Object o) { if (this == o) { return true; } - if (!(o instanceof Int8Vector)) { + if (o == null || getClass() != o.getClass()) { return false; } - Int8Vector that = (Int8Vector) o; - return Arrays.equals(vectorData, that.vectorData); + return Objects.deepEquals(data, that.data); } @Override public int hashCode() { - return Arrays.hashCode(vectorData); + return Arrays.hashCode(data); } @Override public String toString() { return "Int8Vector{" - + "vectorData=" + Arrays.toString(vectorData) - + ", vectorType=" + getDataType() + + "data=" + Arrays.toString(data) + + ", dataType=" + getDataType() + '}'; } } diff --git a/bson/src/main/org/bson/PackedBitVector.java b/bson/src/main/org/bson/PackedBitVector.java index 3f9b2dc84e1..a5dd8f4dcdf 100644 --- a/bson/src/main/org/bson/PackedBitVector.java +++ b/bson/src/main/org/bson/PackedBitVector.java @@ -16,9 +16,8 @@ package org.bson; -import org.bson.types.Binary; - import java.util.Arrays; +import java.util.Objects; import static org.bson.assertions.Assertions.assertNotNull; @@ -31,18 +30,16 @@ * @see Vector#packedBitVector(byte[], byte) * @see BsonBinary#BsonBinary(Vector) * @see BsonBinary#asVector() - * @see Binary#Binary(Vector) - * @see Binary#asVector() * @since 5.3 */ public final class PackedBitVector extends Vector { private final byte padding; - private final byte[] vectorData; + private final byte[] data; - PackedBitVector(final byte[] vectorData, final byte padding) { + PackedBitVector(final byte[] data, final byte padding) { super(DataType.PACKED_BIT); - this.vectorData = assertNotNull(vectorData); + this.data = assertNotNull(data); this.padding = padding; } @@ -56,15 +53,15 @@ public final class PackedBitVector extends Vector { * @return the underlying byte array representing this {@link PackedBitVector} vector. 
* @see #getPadding() */ - public byte[] getVectorArray() { - return assertNotNull(vectorData); + public byte[] getData() { + return assertNotNull(data); } /** * Returns the padding value for this vector. * *

    Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving - * {@linkplain #getVectorArray() the vector array). For instance, if the padding value is 3, this means that the last byte contains + * {@linkplain #getData() the vector array}. For instance, if the padding value is 3, this means that the last byte contains * 3 least-significant unused bits, which should be disregarded during operations.

    *

    * @@ -81,25 +78,24 @@ public boolean equals(final Object o) { if (this == o) { return true; } - if (!(o instanceof PackedBitVector)) { + if (o == null || getClass() != o.getClass()) { return false; } - PackedBitVector that = (PackedBitVector) o; - return padding == that.padding && Arrays.equals(vectorData, that.vectorData); + return padding == that.padding && Arrays.equals(data, that.data); } @Override public int hashCode() { - return Objects.hash(padding, Arrays.hashCode(vectorData)); + return Objects.hash(padding, Arrays.hashCode(data)); } @Override public String toString() { return "PackedBitVector{" + "padding=" + padding - + ", vectorData=" + Arrays.toString(vectorData) - + ", vectorType=" + getDataType() + + ", data=" + Arrays.toString(data) + + ", dataType=" + getDataType() + '}'; } } diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index 1a85757dd45..8b1548efd7f 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -16,8 +16,6 @@ package org.bson; - -import static org.bson.assertions.Assertions.isTrue; import static org.bson.assertions.Assertions.isTrueArgument; import static org.bson.assertions.Assertions.notNull; @@ -28,16 +26,19 @@ *

    * Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently * in BSON using a binary format. + *

    + * NOTE: This class is intended to be treated as sealed. Any subclasses added outside the library are not guaranteed to + * function correctly in the current and future releases. * * @mongodb.server.release 6.0 * @see BsonBinary * @since 5.3 */ public abstract class Vector { - private final DataType vectorType; + private final DataType dataType; - Vector(final DataType vectorType) { - this.vectorType = vectorType; + Vector(final DataType dataType) { + this.dataType = dataType; } /** @@ -55,19 +56,19 @@ public abstract class Vector { * Resulting vector: 12 bits: 111011101110 *

    *

    - * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected + * NOTE: The byte array `data` is not copied; changes to the provided array will be reflected * in the created {@link PackedBitVector} instance. * - * @param vectorData The byte array representing the packed bit vector data. Each byte can store 8 bits. + * @param data The byte array representing the packed bit vector data. Each byte can store 8 bits. * @param padding The number of least-significant bits (0 to 7) to ignore in the final byte of the vector data. * @return A {@link PackedBitVector} instance with the {@link DataType#PACKED_BIT} data type. * @throws IllegalArgumentException If the padding value is greater than 7. */ - public static PackedBitVector packedBitVector(final byte[] vectorData, final byte padding) { - notNull("vectorData", vectorData); - isTrueArgument("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); - isTrueArgument("Padding must be 0 if vector is empty", padding == 0 || vectorData.length > 0); - return new PackedBitVector(vectorData, padding); + public static PackedBitVector packedBitVector(final byte[] data, final byte padding) { + notNull("data", data); + isTrueArgument("Padding must be between 0 and 7 bits. Provided padding: " + padding, padding >= 0 && padding <= 7); + isTrueArgument("Padding must be 0 if vector is empty. Provided padding: " + padding, padding == 0 || data.length > 0); + return new PackedBitVector(data, padding); } /** @@ -76,15 +77,15 @@ public static PackedBitVector packedBitVector(final byte[] vectorData, final byt *

    A {@link DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector, * with values in the range [-128, 127].

    *

    - * NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected + * NOTE: The byte array `data` is not copied; changes to the provided array will be reflected * in the created {@link Int8Vector} instance. * - * @param vectorData The byte array representing the {@link DataType#INT8} vector data. + * @param data The byte array representing the {@link DataType#INT8} vector data. * @return A {@link Int8Vector} instance with the {@link DataType#INT8} data type. */ - public static Int8Vector int8Vector(final byte[] vectorData) { - notNull("vectorData", vectorData); - return new Int8Vector(vectorData); + public static Int8Vector int8Vector(final byte[] data) { + notNull("data", data); + return new Int8Vector(data); } /** @@ -92,15 +93,15 @@ public static Int8Vector int8Vector(final byte[] vectorData) { *

    * A {@link DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.

    *

    - * NOTE: The float array `vectorData` is not copied; changes to the provided array will be reflected + * NOTE: The float array `data` is not copied; changes to the provided array will be reflected * in the created {@link Float32Vector} instance. * - * @param vectorData The float array representing the {@link DataType#FLOAT32} vector data. + * @param data The float array representing the {@link DataType#FLOAT32} vector data. * @return A {@link Float32Vector} instance with the {@link DataType#FLOAT32} data type. */ - public static Float32Vector floatVector(final float[] vectorData) { - notNull("vectorData", vectorData); - return new Float32Vector(vectorData); + public static Float32Vector floatVector(final float[] data) { + notNull("data", data); + return new Float32Vector(data); } /** @@ -145,13 +146,13 @@ public Float32Vector asFloat32Vector() { * @return the data type of the vector. */ public DataType getDataType() { - return this.vectorType; + return this.dataType; } private void ensureType(final DataType expected) { - if (this.vectorType != expected) { - throw new IllegalStateException("Expected vector type " + expected + " but found " + this.vectorType); + if (this.dataType != expected) { + throw new IllegalStateException("Expected vector data type " + expected + ", but found " + this.dataType); } } diff --git a/bson/src/main/org/bson/codecs/ContainerCodecHelper.java b/bson/src/main/org/bson/codecs/ContainerCodecHelper.java index 5969763546b..827858d11cb 100644 --- a/bson/src/main/org/bson/codecs/ContainerCodecHelper.java +++ b/bson/src/main/org/bson/codecs/ContainerCodecHelper.java @@ -16,10 +16,12 @@ package org.bson.codecs; +import org.bson.BsonBinarySubType; import org.bson.BsonReader; import org.bson.BsonType; import org.bson.Transformer; import org.bson.UuidRepresentation; +import org.bson.Vector; import org.bson.codecs.configuration.CodecConfigurationException; import org.bson.codecs.configuration.CodecRegistry; @@ -42,28 +44,53 @@ static Object 
readValue(final BsonReader reader, final DecoderContext decoderCon reader.readNull(); return null; } else { - Codec codec = bsonTypeCodecMap.get(bsonType); + Codec currentCodec = bsonTypeCodecMap.get(bsonType); - if (bsonType == BsonType.BINARY && reader.peekBinarySize() == 16) { - switch (reader.peekBinarySubType()) { - case 3: - if (uuidRepresentation == UuidRepresentation.JAVA_LEGACY - || uuidRepresentation == UuidRepresentation.C_SHARP_LEGACY - || uuidRepresentation == UuidRepresentation.PYTHON_LEGACY) { - codec = registry.get(UUID.class); - } - break; - case 4: - if (uuidRepresentation == UuidRepresentation.STANDARD) { - codec = registry.get(UUID.class); - } - break; - default: - break; - } + if (bsonType == BsonType.BINARY) { + byte binarySubType = reader.peekBinarySubType(); + currentCodec = getBinarySubTypeCodec(reader, + uuidRepresentation, + registry, binarySubType, + currentCodec); + } + + return valueTransformer.transform(currentCodec.decode(reader, decoderContext)); + } + } + + private static Codec getBinarySubTypeCodec(final BsonReader reader, + final UuidRepresentation uuidRepresentation, + final CodecRegistry registry, + final byte binarySubType, + final Codec currentTypeCodec) { + + if (binarySubType == BsonBinarySubType.VECTOR.getValue()) { + Codec vectorCodec = registry.get(Vector.class, registry); + if (vectorCodec != null) { + return vectorCodec; } - return valueTransformer.transform(codec.decode(reader, decoderContext)); } + + if (reader.peekBinarySize() == 16) { + switch (binarySubType) { + case 3: + if (uuidRepresentation == UuidRepresentation.JAVA_LEGACY + || uuidRepresentation == UuidRepresentation.C_SHARP_LEGACY + || uuidRepresentation == UuidRepresentation.PYTHON_LEGACY) { + return registry.get(UUID.class); + } + break; + case 4: + if (uuidRepresentation == UuidRepresentation.STANDARD) { + return registry.get(UUID.class); + } + break; + default: + break; + } + } + + return currentTypeCodec; } static Codec getCodec(final CodecRegistry 
codecRegistry, final Type type) { diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java index 9182ce43fc0..0933a00590a 100644 --- a/bson/src/main/org/bson/codecs/Float32VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java @@ -26,7 +26,6 @@ /** * Encodes and decodes {@link Float32Vector} objects. * - * @since 5.3 */ final class Float32VectorCodec implements Codec { @@ -40,7 +39,7 @@ public Float32Vector decode(final BsonReader reader, final DecoderContext decode byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Unexpected BsonBinarySubType"); + throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); } return reader.readBinaryData() @@ -53,4 +52,5 @@ public Float32Vector decode(final BsonReader reader, final DecoderContext decode public Class getEncoderClass() { return Float32Vector.class; } +} diff --git a/bson/src/main/org/bson/codecs/Int8VectorCodec.java b/bson/src/main/org/bson/codecs/Int8VectorCodec.java index 3adeb599040..dc99877dd1a 100644 --- a/bson/src/main/org/bson/codecs/Int8VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Int8VectorCodec.java @@ -40,7 +40,7 @@ public Int8Vector decode(final BsonReader reader, final DecoderContext decoderCo byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Unexpected BsonBinarySubType"); + throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); } return reader.readBinaryData() diff --git a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java index 8ecaba4c396..1fb4deb5e20 100644 --- a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java +++ 
b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java @@ -26,7 +26,6 @@ /** * Encodes and decodes {@link PackedBitVector} objects. * - * @since 5.3 */ final class PackedBitVectorCodec implements Codec { @@ -40,7 +39,7 @@ public PackedBitVector decode(final BsonReader reader, final DecoderContext deco byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Unexpected BsonBinarySubType"); + throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); } return reader.readBinaryData() @@ -54,11 +53,6 @@ public PackedBitVector decode(final BsonReader reader, final DecoderContext deco public Class getEncoderClass() { return PackedBitVector.class; } - - @Override - public String toString() { - return "PackedBitVectorCodec{}"; - } } diff --git a/bson/src/main/org/bson/codecs/VectorCodec.java b/bson/src/main/org/bson/codecs/VectorCodec.java index a91b009760e..4f4c1cf010d 100644 --- a/bson/src/main/org/bson/codecs/VectorCodec.java +++ b/bson/src/main/org/bson/codecs/VectorCodec.java @@ -26,7 +26,6 @@ /** * Encodes and decodes {@link Vector} objects. 
* - * @since 5.3 */ final class VectorCodec implements Codec { @@ -40,7 +39,7 @@ public Vector decode(final BsonReader reader, final DecoderContext decoderContex byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Unexpected BsonBinarySubType"); + throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found " + subType); } return reader.readBinaryData() diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index 1bed4c97abd..a634e7b7810 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -17,18 +17,18 @@ package org.bson.internal.vector; import org.bson.BsonBinary; +import org.bson.BsonInvalidOperationException; import org.bson.Float32Vector; import org.bson.Int8Vector; import org.bson.PackedBitVector; import org.bson.Vector; +import org.bson.assertions.Assertions; import org.bson.types.Binary; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; -import static org.bson.assertions.Assertions.isTrue; - /** * Helper class for encoding and decoding vectors to and from {@link BsonBinary}/{@link Binary}. 
* @@ -42,26 +42,27 @@ public final class VectorHelper { private static final ByteOrder STORED_BYTE_ORDER = ByteOrder.LITTLE_ENDIAN; + private static final String ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE = "Unknown vector data type: "; + private static final byte ZERO_PADDING = 0; private VectorHelper() { //NOP } private static final int METADATA_SIZE = 2; - private static final int FLOAT_SIZE = 4; public static byte[] encodeVectorToBinary(final Vector vector) { Vector.DataType dataType = vector.getDataType(); switch (dataType) { case INT8: - return encodeVector(dataType.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray()); + return encodeVector(dataType.getValue(), ZERO_PADDING, vector.asInt8Vector().getData()); case PACKED_BIT: PackedBitVector packedBitVector = vector.asPackedBitVector(); - return encodeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray()); + return encodeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getData()); case FLOAT32: - return encodeVector(dataType.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray()); + return encodeVector(dataType.getValue(), vector.asFloat32Vector().getData()); default: - throw new AssertionError("Unknown vector dtype: " + dataType); + throw Assertions.fail(ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE + dataType); } } @@ -82,24 +83,24 @@ public static Vector decodeBinaryToVector(final byte[] encodedVector) { case FLOAT32: return decodeFloat32Vector(encodedVector, padding); default: - throw new AssertionError("Unknown vector data type: " + dataType); + throw Assertions.fail(ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE + dataType); } } private static Float32Vector decodeFloat32Vector(final byte[] encodedVector, final byte padding) { - isTrue("Padding must be 0 for FLOAT32 data type.", padding == 0); + isTrue("Padding must be 0 for FLOAT32 data type, but found: " + padding, padding == 0); return 
Vector.floatVector(decodeLittleEndianFloats(encodedVector)); } private static PackedBitVector decodePackedBitVector(final byte[] encodedVector, final byte padding) { byte[] packedBitVector = extractVectorData(encodedVector); - isTrue("Padding must be 0 if vector is empty.", padding == 0 || packedBitVector.length > 0); - isTrue("Padding must be between 0 and 7 bits.", padding >= 0 && padding <= 7); + isTrue("Padding must be 0 if vector is empty, but found: " + padding, padding == 0 || packedBitVector.length > 0); + isTrue("Padding must be between 0 and 7 bits, but found: " + padding, padding >= 0 && padding <= 7); return Vector.packedBitVector(packedBitVector, padding); } private static Int8Vector decodeInt8Vector(final byte[] encodedVector, final byte padding) { - isTrue("Padding must be 0 for INT8 data type.", padding == 0); + isTrue("Padding must be 0 for INT8 data type, but found: " + padding, padding == 0); byte[] int8Vector = extractVectorData(encodedVector); return Vector.int8Vector(int8Vector); } @@ -119,11 +120,11 @@ private static byte[] encodeVector(final byte dType, final byte padding, final b return bytes; } - private static byte[] encodeVector(final byte dType, final byte padding, final float[] vectorData) { - final byte[] bytes = new byte[vectorData.length * FLOAT_SIZE + METADATA_SIZE]; + private static byte[] encodeVector(final byte dType, final float[] vectorData) { + final byte[] bytes = new byte[vectorData.length * Float.BYTES + METADATA_SIZE]; bytes[0] = dType; - bytes[1] = padding; + bytes[1] = ZERO_PADDING; ByteBuffer buffer = ByteBuffer.wrap(bytes); buffer.order(STORED_BYTE_ORDER); @@ -141,11 +142,11 @@ private static byte[] encodeVector(final byte dType, final byte padding, final f private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", - (encodedVector.length - METADATA_SIZE) % FLOAT_SIZE == 0); + (encodedVector.length - METADATA_SIZE) % 
Float.BYTES == 0); int vectorSize = encodedVector.length - METADATA_SIZE; - int numFloats = vectorSize / FLOAT_SIZE; + int numFloats = vectorSize / Float.BYTES; float[] floatArray = new float[numFloats]; ByteBuffer buffer = ByteBuffer.wrap(encodedVector, METADATA_SIZE, vectorSize); @@ -158,13 +159,19 @@ private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { return floatArray; } - private static Vector.DataType determineVectorDType(final byte dType) { + public static Vector.DataType determineVectorDType(final byte dType) { Vector.DataType[] values = Vector.DataType.values(); for (Vector.DataType value : values) { if (value.getValue() == dType) { return value; } } - throw new IllegalStateException("Unknown vector data type: " + dType); + throw new BsonInvalidOperationException(ERROR_MESSAGE_UNKNOWN_VECTOR_DATA_TYPE + dType); + } + + private static void isTrue(final String message, final boolean condition) { + if (!condition) { + throw new BsonInvalidOperationException(message); + } } } diff --git a/bson/src/main/org/bson/types/Binary.java b/bson/src/main/org/bson/types/Binary.java index 186d7544a9d..5ba482ccc41 100644 --- a/bson/src/main/org/bson/types/Binary.java +++ b/bson/src/main/org/bson/types/Binary.java @@ -17,15 +17,10 @@ package org.bson.types; import org.bson.BsonBinarySubType; -import org.bson.BsonInvalidOperationException; -import org.bson.Vector; -import org.bson.internal.vector.VectorHelper; import java.io.Serializable; import java.util.Arrays; -import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; - /** * Generic binary holder. */ @@ -72,35 +67,6 @@ public Binary(final byte type, final byte[] data) { this.data = data.clone(); } - /** - * Construct a Type 9 BsonBinary from the given Vector. 
- * - * @param vector the {@link Vector} - * @since 5.3 - */ - public Binary(final Vector vector) { - if (vector == null) { - throw new IllegalArgumentException("Vector must not be null"); - } - this.data = encodeVectorToBinary(vector); - type = BsonBinarySubType.VECTOR.getValue(); - } - - /** - * Returns the binary as a {@link Vector}. The binary type must be 9. - * - * @return the vector - * @throws IllegalArgumentException if the binary subtype is not {@link BsonBinarySubType#VECTOR}. - * @since 5.3 - */ - public Vector asVector() { - if (!BsonBinarySubType.isVector(type)) { - throw new BsonInvalidOperationException("type must be a Vector subtype."); - } - - return VectorHelper.decodeBinaryToVector(this.data); - } - /** * Get the binary sub type as a byte. * diff --git a/bson/src/test/resources/bson-binary-vector/README.md b/bson/src/test/resources/bson-binary-vector/README.md deleted file mode 100644 index 73a5f0a9f33..00000000000 --- a/bson/src/test/resources/bson-binary-vector/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Testing Binary subtype 9: Vector - -The JSON files in this directory tree are platform-independent tests that drivers can use to prove their conformance to -the specification. - -These tests focus on the roundtrip of the list numbers as input/output, along with their data type and byte padding. - -Additional tests exist in `bson_corpus/tests/binary.json` but do not sufficiently test the end-to-end process of Vector -to BSON. For this reason, drivers must create a bespoke test runner for the vector subtype. - -Each test case here pertains to a single vector. The inputs required to create the Binary BSON object are defined, and -when valid, the Canonical BSON and Extended JSON representations are included for comparison. - -## Version - -Files in the "specifications" repository have no version scheme. They are not tied to a MongoDB server version. - -## Format - -#### Top level keys - -Each JSON file contains three top-level keys. 
- -- `description`: human-readable description of what is in the file -- `test_key`: Field name used when decoding/encoding a BSON document containing the single BSON Binary for the test - case. Applies to *every* case. -- `tests`: array of test case objects, each of which have the following keys. Valid cases will also contain additional - binary and json encoding values. - -#### Keys of tests objects - -- `description`: string describing the test. -- `valid`: boolean indicating if the vector, dtype, and padding should be considered a valid input. -- `vector`: list of numbers -- `dtype_hex`: string defining the data type in hex (e.g. "0x10", "0x27") -- `dtype_alias`: (optional) string defining the data dtype, perhaps as Enum. -- `padding`: (optional) integer for byte padding. Defaults to 0. -- `canonical_bson`: (required if valid is true) an (uppercase) big-endian hex representation of a BSON byte string. -- `canonical_extjson`: (required if valid is true) string containing a Canonical Extended JSON document. Because this is - itself embedded as a *string* inside a JSON document, characters like quote and backslash are escaped. \ No newline at end of file diff --git a/bson/src/test/unit/org/bson/BsonBinaryTest.java b/bson/src/test/unit/org/bson/BsonBinaryTest.java deleted file mode 100644 index 62d54116276..00000000000 --- a/bson/src/test/unit/org/bson/BsonBinaryTest.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.bson; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; -import org.junit.jupiter.params.provider.MethodSource; - -import java.util.stream.Stream; - -import static org.bson.assertions.Assertions.fail; -import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class BsonBinaryTest { - - static Stream provideVectors() { - return Stream.of( - Vector.floatVector(new float[]{1.5f, 2.1f, 3.1f}), - Vector.int8Vector(new byte[]{10, 20, 30}), - Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010000}, (byte) 3) - ); - } - - @ParameterizedTest - @MethodSource("provideVectors") - void shouldCreateBsonBinaryFromVector(final Vector vector) { - // when - BsonBinary bsonBinary = new BsonBinary(vector); - - // then - assertEquals(BsonBinarySubType.VECTOR.getValue(), bsonBinary.getType(), "The subtype must be VECTOR"); - assertNotNull(bsonBinary.getData(), "BsonBinary data should not be null"); - assertArrayEquals(encodeVectorToBinary(vector), bsonBinary.getData()); - } - - @Test - void shouldThrowExceptionWhenCreatingBsonBinaryWithNullVector() { - // given - Vector vector = null; - - // when & then - IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new BsonBinary(vector)); - assertEquals("Vector must not be null", exception.getMessage()); - } - - @ParameterizedTest - @MethodSource("provideVectors") - void shouldConvertBsonBinaryToVector(final Vector actualVector) { - // given - BsonBinary bsonBinary = new BsonBinary(actualVector); - - // when 
- Vector decodedVector = bsonBinary.asVector(); - - // then - assertNotNull(decodedVector); - assertEquals(actualVector.getDataType(), decodedVector.getDataType()); - assertVectorDataDecoding(actualVector, decodedVector); - } - - private static void assertVectorDataDecoding(final Vector actualVector, final Vector decodedVector) { - switch (actualVector.getDataType()) { - case FLOAT32: - Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); - Float32Vector decodedFloat32Vector1 = decodedVector.asFloat32Vector(); - assertArrayEquals(actualFloat32Vector.getVectorArray(), decodedFloat32Vector1.getVectorArray(), - "Float vector data should match after decoding"); - break; - case INT8: - Int8Vector actualInt8Vector = actualVector.asInt8Vector(); - Int8Vector decodedInt8Vector1 = decodedVector.asInt8Vector(); - assertArrayEquals(actualInt8Vector.getVectorArray(), decodedInt8Vector1.getVectorArray(), - "Int8 vector data should match after decoding"); - break; - case PACKED_BIT: - PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); - PackedBitVector decodedPackedBitVector = decodedVector.asPackedBitVector(); - assertArrayEquals(actualPackedBitVector.getVectorArray(), decodedPackedBitVector.getVectorArray(), - "Packed bit vector data should match after decoding"); - assertEquals(actualPackedBitVector.getPadding(), decodedPackedBitVector.getPadding(), "Padding should match after decoding"); - break; - default: - fail("Unexpected vector type: " + actualVector.getDataType()); - } - } - - @ParameterizedTest - @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) - void shouldThrowExceptionWhenBsonBinarySubTypeIsNotVector(final BsonBinarySubType bsonBinarySubType) { - // given - byte[] data = new byte[]{1, 2, 3, 4}; - BsonBinary bsonBinary = new BsonBinary(bsonBinarySubType.getValue(), data); - - // when & then - BsonInvalidOperationException exception = 
assertThrows(BsonInvalidOperationException.class, bsonBinary::asVector); - assertEquals("type must be a Vector subtype.", exception.getMessage()); - } -} diff --git a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java b/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java similarity index 55% rename from bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java rename to bson/src/test/unit/org/bson/BsonBinaryVectorTest.java index 111b4a6f5ba..485448a5bc4 100644 --- a/bson/src/test/unit/org/bson/internal/vector/VectorHelperTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java @@ -14,14 +14,12 @@ * limitations under the License. */ -package org.bson.internal.vector; +package org.bson; -import org.bson.Float32Vector; -import org.bson.Int8Vector; -import org.bson.PackedBitVector; -import org.bson.Vector; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; @@ -30,20 +28,46 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +class BsonBinaryVectorTest { -class VectorHelperTest { private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue(); private static final byte INT8_DTYPE = Vector.DataType.INT8.getValue(); private static final byte PACKED_BIT_DTYPE = Vector.DataType.PACKED_BIT.getValue(); public static final int ZERO_PADDING = 0; + @Test + void shouldThrowExceptionWhenCreatingBsonBinaryWithNullVector() { + // given + Vector vector = null; + + // when & then + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new 
BsonBinary(vector)); + assertEquals("Vector must not be null", exception.getMessage()); + } + + @ParameterizedTest + @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) + void shouldThrowExceptionWhenBsonBinarySubTypeIsNotVector(final BsonBinarySubType bsonBinarySubType) { + // given + byte[] data = new byte[]{1, 2, 3, 4}; + BsonBinary bsonBinary = new BsonBinary(bsonBinarySubType.getValue(), data); + + // when & then + BsonInvalidOperationException exception = assertThrows(BsonInvalidOperationException.class, bsonBinary::asVector); + assertEquals("type must be a Vector subtype.", exception.getMessage()); + } + @ParameterizedTest(name = "{index}: {0}") @MethodSource("provideFloatVectors") void shouldEncodeFloatVector(final Vector actualFloat32Vector, final byte[] expectedBsonEncodedVector) { // when - byte[] actualBsonEncodedVector = VectorHelper.encodeVectorToBinary(actualFloat32Vector); + BsonBinary actualBsonBinary = new BsonBinary(actualFloat32Vector); + byte[] actualBsonEncodedVector = actualBsonBinary.getData(); - //Then + // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), actualBsonBinary.getType(), "The subtype must be VECTOR"); assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); } @@ -51,18 +75,17 @@ void shouldEncodeFloatVector(final Vector actualFloat32Vector, final byte[] expe @MethodSource("provideFloatVectors") void shouldDecodeFloatVector(final Float32Vector expectedFloatVector, final byte[] bsonEncodedVector) { // when - Float32Vector decodedVector = (Float32Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); + Float32Vector decodedVector = (Float32Vector) new BsonBinary(BsonBinarySubType.VECTOR, bsonEncodedVector).asVector(); // then - assertEquals(Vector.DataType.FLOAT32, decodedVector.getDataType()); - assertArrayEquals(expectedFloatVector.getVectorArray(), decodedVector.getVectorArray()); + assertEquals(expectedFloatVector, decodedVector); } - private static 
Stream provideFloatVectors() { + private static Stream provideFloatVectors() { return Stream.of( - new Object[]{ - Vector.floatVector( - new float[]{1.1f, 2.2f, 3.3f, -1.0f, Float.MAX_VALUE, Float.MIN_VALUE, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY}), + arguments( + Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f, -1.0f, Float.MAX_VALUE, Float.MIN_VALUE, Float.POSITIVE_INFINITY, + Float.NEGATIVE_INFINITY}), new byte[]{FLOAT32_DTYPE, ZERO_PADDING, (byte) 205, (byte) 204, (byte) 140, (byte) 63, // 1.1f in little-endian (byte) 205, (byte) 204, (byte) 12, (byte) 64, // 2.2f in little-endian @@ -71,17 +94,19 @@ private static Stream provideFloatVectors() { (byte) 255, (byte) 255, (byte) 127, (byte) 127, // Float.MAX_VALUE in little-endian (byte) 1, (byte) 0, (byte) 0, (byte) 0, // Float.MIN_VALUE in little-endian (byte) 0, (byte) 0, (byte) 128, (byte) 127, // Float.POSITIVE_INFINITY in little-endian - (byte) 0, (byte) 0, (byte) 128, (byte) 255, // Float.NEGATIVE_INFINITY in little-endian - }}, - new Object[]{ + (byte) 0, (byte) 0, (byte) 128, (byte) 255 // Float.NEGATIVE_INFINITY in little-endian + } + ), + arguments( Vector.floatVector(new float[]{0.0f}), new byte[]{FLOAT32_DTYPE, ZERO_PADDING, (byte) 0, (byte) 0, (byte) 0, (byte) 0 // 0.0f in little-endian - }}, - new Object[]{ + } + ), + arguments( Vector.floatVector(new float[]{}), - new byte[]{FLOAT32_DTYPE, ZERO_PADDING, - }} + new byte[]{FLOAT32_DTYPE, ZERO_PADDING} + ) ); } @@ -89,9 +114,11 @@ private static Stream provideFloatVectors() { @MethodSource("provideInt8Vectors") void shouldEncodeInt8Vector(final Vector actualInt8Vector, final byte[] expectedBsonEncodedVector) { // when - byte[] actualBsonEncodedVector = VectorHelper.encodeVectorToBinary(actualInt8Vector); + BsonBinary actualBsonBinary = new BsonBinary(actualInt8Vector); + byte[] actualBsonEncodedVector = actualBsonBinary.getData(); // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), actualBsonBinary.getType(), "The subtype must be 
VECTOR"); assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); } @@ -99,22 +126,21 @@ void shouldEncodeInt8Vector(final Vector actualInt8Vector, final byte[] expected @MethodSource("provideInt8Vectors") void shouldDecodeInt8Vector(final Int8Vector expectedInt8Vector, final byte[] bsonEncodedVector) { // when - Int8Vector decodedVector = (Int8Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); + Int8Vector decodedVector = (Int8Vector) new BsonBinary(BsonBinarySubType.VECTOR, bsonEncodedVector).asVector(); // then - assertEquals(Vector.DataType.INT8, decodedVector.getDataType()); - assertArrayEquals(expectedInt8Vector.getVectorArray(), decodedVector.getVectorArray()); + assertEquals(expectedInt8Vector, decodedVector); } - private static Stream provideInt8Vectors() { + private static Stream provideInt8Vectors() { return Stream.of( - new Object[]{ + arguments( Vector.int8Vector(new byte[]{Byte.MAX_VALUE, 1, 2, 3, 4, Byte.MIN_VALUE}), new byte[]{INT8_DTYPE, ZERO_PADDING, Byte.MAX_VALUE, 1, 2, 3, 4, Byte.MIN_VALUE - }}, - new Object[]{Vector.int8Vector(new byte[]{}), + }), + arguments(Vector.int8Vector(new byte[]{}), new byte[]{INT8_DTYPE, ZERO_PADDING} - } + ) ); } @@ -122,9 +148,11 @@ private static Stream provideInt8Vectors() { @MethodSource("providePackedBitVectors") void shouldEncodePackedBitVector(final Vector actualPackedBitVector, final byte[] expectedBsonEncodedVector) { // when - byte[] actualBsonEncodedVector = VectorHelper.encodeVectorToBinary(actualPackedBitVector); + BsonBinary actualBsonBinary = new BsonBinary(actualPackedBitVector); + byte[] actualBsonEncodedVector = actualBsonBinary.getData(); // then + assertEquals(BsonBinarySubType.VECTOR.getValue(), actualBsonBinary.getType(), "The subtype must be VECTOR"); assertArrayEquals(expectedBsonEncodedVector, actualBsonEncodedVector); } @@ -132,25 +160,22 @@ void shouldEncodePackedBitVector(final Vector actualPackedBitVector, final byte[ @MethodSource("providePackedBitVectors") 
void shouldDecodePackedBitVector(final PackedBitVector expectedPackedBitVector, final byte[] bsonEncodedVector) { // when - PackedBitVector decodedVector = (PackedBitVector) VectorHelper.decodeBinaryToVector(bsonEncodedVector); + PackedBitVector decodedVector = (PackedBitVector) new BsonBinary(BsonBinarySubType.VECTOR, bsonEncodedVector).asVector(); // then - assertEquals(Vector.DataType.PACKED_BIT, decodedVector.getDataType()); - assertArrayEquals(expectedPackedBitVector.getVectorArray(), decodedVector.getVectorArray()); - assertEquals(expectedPackedBitVector.getPadding(), decodedVector.getPadding()); + assertEquals(expectedPackedBitVector, decodedVector); } - private static Stream providePackedBitVectors() { + private static Stream providePackedBitVectors() { return Stream.of( - new Object[]{ + arguments( Vector.packedBitVector(new byte[]{(byte) 0, (byte) 255, (byte) 10}, (byte) 2), new byte[]{PACKED_BIT_DTYPE, 2, (byte) 0, (byte) 255, (byte) 10} - }, - new Object[]{ + ), + arguments( Vector.packedBitVector(new byte[0], (byte) 0), new byte[]{PACKED_BIT_DTYPE, 0} - } - ); + )); } @Test @@ -159,10 +184,10 @@ void shouldThrowExceptionForInvalidFloatArrayLengthWhenDecode() { byte[] invalidData = {FLOAT32_DTYPE, 0, 10, 20, 30}; // when & Then - IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { - VectorHelper.decodeBinaryToVector(invalidData); + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); }); - assertEquals("state should be: Byte array length must be a multiple of 4 for FLOAT32 data type.", thrown.getMessage()); + assertEquals("Byte array length must be a multiple of 4 for FLOAT32 data type.", thrown.getMessage()); } @ParameterizedTest @@ -172,10 +197,10 @@ void shouldThrowExceptionForInvalidFloatArrayPaddingWhenDecode(final byte invali byte[] invalidData = {FLOAT32_DTYPE, invalidPadding, 10, 20, 30, 20}; // when & Then 
- IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { - VectorHelper.decodeBinaryToVector(invalidData); + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); }); - assertEquals("state should be: Padding must be 0 for FLOAT32 data type.", thrown.getMessage()); + assertEquals("Padding must be 0 for FLOAT32 data type, but found: " + invalidPadding, thrown.getMessage()); } @ParameterizedTest @@ -185,10 +210,10 @@ void shouldThrowExceptionForInvalidInt8ArrayPaddingWhenDecode(final byte invalid byte[] invalidData = {INT8_DTYPE, invalidPadding, 10, 20, 30, 20}; // when & Then - IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { - VectorHelper.decodeBinaryToVector(invalidData); + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); }); - assertEquals("state should be: Padding must be 0 for INT8 data type.", thrown.getMessage()); + assertEquals("Padding must be 0 for INT8 data type, but found: " + invalidPadding, thrown.getMessage()); } @ParameterizedTest @@ -197,11 +222,11 @@ void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecode(final byte in // given byte[] invalidData = {PACKED_BIT_DTYPE, invalidPadding, 10, 20, 30, 20}; - // when & Then - IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { - VectorHelper.decodeBinaryToVector(invalidData); + // when & then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); }); - assertEquals("state should be: Padding must be between 0 and 7 bits.", thrown.getMessage()); + assertEquals("Padding must be between 0 and 7 bits, but found: " + invalidPadding, thrown.getMessage()); } @ParameterizedTest @@ 
-211,29 +236,16 @@ void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(fi byte[] invalidData = {PACKED_BIT_DTYPE, invalidPadding}; // when & Then - IllegalStateException thrown = assertThrows(IllegalStateException.class, () -> { - VectorHelper.decodeBinaryToVector(invalidData); + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); }); - assertEquals("state should be: Padding must be 0 if vector is empty.", thrown.getMessage()); - } - - @Test - void shouldDetermineVectorDType() { - // given - Vector.DataType[] values = Vector.DataType.values(); - - for (Vector.DataType value : values) { - // when - byte dtype = value.getValue(); - Vector.DataType actual = VectorHelper.determineVectorDType(dtype); - - // then - assertEquals(value, actual); - } + assertEquals("Padding must be 0 if vector is empty, but found: " + invalidPadding, thrown.getMessage()); } @Test void shouldThrowWhenUnknownVectorDType() { - assertThrows(IllegalStateException.class, () -> VectorHelper.determineVectorDType((byte) 0)); + // when + BsonBinary bsonBinary = new BsonBinary(BsonBinarySubType.VECTOR, new byte[]{(byte) 0}); + assertThrows(BsonInvalidOperationException.class, bsonBinary::asVector); } } diff --git a/bson/src/test/unit/org/bson/VectorTest.java b/bson/src/test/unit/org/bson/VectorTest.java index 66b0914097b..36cc7156db6 100644 --- a/bson/src/test/unit/org/bson/VectorTest.java +++ b/bson/src/test/unit/org/bson/VectorTest.java @@ -38,7 +38,7 @@ void shouldCreateInt8Vector() { // then assertNotNull(vector); assertEquals(Vector.DataType.INT8, vector.getDataType()); - assertArrayEquals(data, vector.getVectorArray()); + assertArrayEquals(data, vector.getData()); } @Test @@ -48,7 +48,7 @@ void shouldThrowExceptionWhenCreatingInt8VectorWithNullData() { // when & Then IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> 
Vector.int8Vector(data)); - assertEquals("vectorData can not be null", exception.getMessage()); + assertEquals("data can not be null", exception.getMessage()); } @Test @@ -62,7 +62,7 @@ void shouldCreateFloat32Vector() { // then assertNotNull(vector); assertEquals(Vector.DataType.FLOAT32, vector.getDataType()); - assertArrayEquals(data, vector.getVectorArray()); + assertArrayEquals(data, vector.getData()); } @Test @@ -72,7 +72,7 @@ void shouldThrowExceptionWhenCreatingFloat32VectorWithNullData() { // when & Then IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.floatVector(data)); - assertEquals("vectorData can not be null", exception.getMessage()); + assertEquals("data can not be null", exception.getMessage()); } @@ -88,7 +88,7 @@ void shouldCreatePackedBitVector(final byte validPadding) { // then assertNotNull(vector); assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType()); - assertArrayEquals(data, vector.getVectorArray()); + assertArrayEquals(data, vector.getData()); assertEquals(validPadding, vector.getPadding()); } @@ -101,7 +101,7 @@ void shouldThrowExceptionWhenPackedBitVectorHasInvalidPadding(final byte invalid // when & Then IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.packedBitVector(data, invalidPadding)); - assertEquals("state should be: Padding must be between 0 and 7 bits.", exception.getMessage()); + assertEquals("state should be: Padding must be between 0 and 7 bits. 
Provided padding: " + invalidPadding, exception.getMessage()); } @Test @@ -113,7 +113,7 @@ void shouldThrowExceptionWhenPackedBitVectorIsCreatedWithNullData() { // when & Then IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.packedBitVector(data, padding)); - assertEquals("Vector data can not be null", exception.getMessage()); + assertEquals("data can not be null", exception.getMessage()); } @Test @@ -128,7 +128,7 @@ void shouldCreatePackedBitVectorWithZeroPaddingAndEmptyData() { // then assertNotNull(vector); assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType()); - assertArrayEquals(data, vector.getVectorArray()); + assertArrayEquals(data, vector.getData()); assertEquals(padding, vector.getPadding()); } @@ -139,9 +139,9 @@ void shouldThrowExceptionWhenPackedBitVectorWithNonZeroPaddingAndEmptyData() { byte padding = 1; // when & Then - IllegalStateException exception = assertThrows(IllegalStateException.class, () -> + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Vector.packedBitVector(data, padding)); - assertEquals("state should be: Padding must be 0 if vector is empty", exception.getMessage()); + assertEquals("state should be: Padding must be 0 if vector is empty. 
Provided padding: " + padding, exception.getMessage()); } @Test @@ -152,7 +152,7 @@ void shouldThrowExceptionWhenRetrievingInt8DataFromNonInt8Vector() { // when & Then IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asInt8Vector); - assertEquals("Expected vector type INT8 but found FLOAT32", exception.getMessage()); + assertEquals("Expected vector data type INT8, but found FLOAT32", exception.getMessage()); } @Test @@ -163,7 +163,7 @@ void shouldThrowExceptionWhenRetrievingFloat32DataFromNonFloat32Vector() { // when & Then IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asFloat32Vector); - assertEquals("Expected vector type FLOAT32 but found INT8", exception.getMessage()); + assertEquals("Expected vector data type FLOAT32, but found INT8", exception.getMessage()); } @Test @@ -174,6 +174,6 @@ void shouldThrowExceptionWhenRetrievingPackedBitDataFromNonPackedBitVector() { // when & Then IllegalStateException exception = assertThrows(IllegalStateException.class, vector::asPackedBitVector); - assertEquals("Expected vector type PACKED_BIT but found FLOAT32", exception.getMessage()); + assertEquals("Expected vector data type PACKED_BIT, but found FLOAT32", exception.getMessage()); } } diff --git a/bson/src/test/unit/org/bson/codecs/CodecTestCase.java b/bson/src/test/unit/org/bson/codecs/CodecTestCase.java index b092121eb9d..52b21e1e8db 100644 --- a/bson/src/test/unit/org/bson/codecs/CodecTestCase.java +++ b/bson/src/test/unit/org/bson/codecs/CodecTestCase.java @@ -85,14 +85,18 @@ public void roundTrip(final Document input, final Document expected) { roundTrip(input, result -> assertEquals(expected, result)); } - protected OutputBuffer encode(final Codec codec, final T value) { + OutputBuffer encode(final Codec codec, final T value) { OutputBuffer buffer = new BasicOutputBuffer(); BsonWriter writer = new BsonBinaryWriter(buffer); codec.encode(writer, value, EncoderContext.builder().build()); return 
buffer; } - protected T decode(final Codec codec, final OutputBuffer buffer) { + void encode(final Codec codec, final T value, final BsonWriter writer) { + codec.encode(writer, value, EncoderContext.builder().build()); + } + + T decode(final Codec codec, final OutputBuffer buffer) { BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(buffer.toByteArray())))); return codec.decode(reader, DecoderContext.builder().build()); } diff --git a/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java b/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java index 79c65573556..d407df31d37 100644 --- a/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java @@ -23,6 +23,7 @@ import org.bson.BsonObjectId; import org.bson.ByteBufNIO; import org.bson.Document; +import org.bson.Vector; import org.bson.io.BasicOutputBuffer; import org.bson.io.BsonInput; import org.bson.io.ByteBufferBsonInput; @@ -36,6 +37,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -43,11 +47,13 @@ import java.util.Date; import java.util.HashSet; import java.util.List; +import java.util.stream.Stream; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.params.provider.Arguments.arguments; public class DocumentCodecTest { private BasicOutputBuffer buffer; @@ -64,8 +70,9 @@ public void tearDown() { writer.close(); } - @Test - public void testPrimitiveBSONTypeCodecs() throws IOException { + 
@ParameterizedTest + @MethodSource("provideVectorsForRoundTrip") + public void testPrimitiveBSONTypeCodecs(final Vector vector) throws IOException { DocumentCodec documentCodec = new DocumentCodec(); Document doc = new Document(); doc.put("oid", new ObjectId()); @@ -80,6 +87,7 @@ public void testPrimitiveBSONTypeCodecs() throws IOException { doc.put("code", new Code("var i = 0")); doc.put("minkey", new MinKey()); doc.put("maxkey", new MaxKey()); + doc.put("vector", vector); // doc.put("pattern", Pattern.compile("^hello")); // TODO: Pattern doesn't override equals method! doc.put("null", null); @@ -90,6 +98,14 @@ public void testPrimitiveBSONTypeCodecs() throws IOException { assertEquals(doc, decodedDocument); } + private static Stream provideVectorsForRoundTrip() { + return Stream.of( + arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f})), + arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40})), + arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3)) + ); + } + @Test public void testIterableEncoding() throws IOException { DocumentCodec documentCodec = new DocumentCodec(); diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java index d3532aedecb..5e0ea495f75 100644 --- a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -18,41 +18,39 @@ import org.bson.BSONException; import org.bson.BsonBinary; +import org.bson.BsonBinaryReader; import org.bson.BsonBinarySubType; -import org.bson.BsonReader; +import org.bson.BsonBinaryWriter; +import org.bson.BsonDocument; +import org.bson.BsonType; import org.bson.BsonWriter; -import org.bson.Document; +import org.bson.ByteBufNIO; import org.bson.Float32Vector; import org.bson.Int8Vector; import org.bson.PackedBitVector; import org.bson.Vector; -import org.bson.codecs.configuration.CodecRegistry; +import org.bson.io.BasicOutputBuffer; +import 
org.bson.io.ByteBufferBsonInput; import org.bson.io.OutputBuffer; -import org.bson.types.Binary; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mockito; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; import java.util.stream.Stream; -import static java.util.Arrays.asList; -import static org.bson.codecs.configuration.CodecRegistries.fromProviders; +import static org.bson.BsonHelper.toBson; +import static org.bson.assertions.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; class VectorCodecTest extends CodecTestCase { - private static final CodecRegistry CODEC_REGISTRIES = fromProviders(asList(new ValueCodecProvider(), new DocumentCodecProvider())); - private static Stream provideVectorsAndCodecsForRoundTrip() { return Stream.of( arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new Float32VectorCodec(), Float32Vector.class), @@ -66,47 +64,56 @@ private static Stream provideVectorsAndCodecsForRoundTrip() { @ParameterizedTest @MethodSource("provideVectorsAndCodecsForRoundTrip") - void shouldRoundTripVectors(final Vector vectorToEncode) { - //given - Document expectedDocument = new Document("vector", vectorToEncode); - - //when - Codec codec = CODEC_REGISTRIES.get(Document.class); - 
OutputBuffer buffer = encode(codec, expectedDocument); - Document actualDecodedDocument = decode(codec, buffer); - - //then - Binary binaryVector = (Binary) actualDecodedDocument.get("vector"); - assertNotEquals(actualDecodedDocument, expectedDocument); - Vector actualVector = binaryVector.asVector(); - assertEquals(actualVector, vectorToEncode); - } - - @ParameterizedTest - @MethodSource("provideVectorsAndCodecsForRoundTrip") - void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorCodec) { + void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorCodec) throws IOException { // given - BsonWriter mockWriter = Mockito.mock(BsonWriter.class); + BsonBinary bsonBinary = new BsonBinary(vectorToEncode); + byte[] encodedVector = bsonBinary.getData(); + ByteArrayOutputStream expectedStream = new ByteArrayOutputStream(); + // Start of document with total length of 4 bytes (little-endian format) + byte totalDocumentLength = (byte) (14 + encodedVector.length); + expectedStream.write(new byte[]{totalDocumentLength, 0, 0, 0}); + // Bson type for vector + expectedStream.write((byte) BsonType.BINARY.getValue()); + // Field name "b4" (ASCII for 'b', '4', null terminator) + expectedStream.write(new byte[]{98, 52, 0}); + // Total length of binary data (little-endian format) + expectedStream.write(new byte[]{(byte) encodedVector.length, 0, 0, 0}); + // Vector binary subtype + expectedStream.write(BsonBinarySubType.VECTOR.getValue()); + // Actual BSON binary data + expectedStream.write(encodedVector); + // End of document + expectedStream.write(0); + + OutputBuffer buffer = new BasicOutputBuffer(); + BsonWriter writer = new BsonBinaryWriter(buffer); + writer.writeStartDocument(); + writer.writeName("b4"); // when - vectorCodec.encode(mockWriter, vectorToEncode, EncoderContext.builder().build()); + vectorCodec.encode(writer, vectorToEncode, EncoderContext.builder().build()); + writer.writeEndDocument(); // then - verify(mockWriter, 
times(1)).writeBinaryData(new BsonBinary(vectorToEncode)); - verifyNoMoreInteractions(mockWriter); + assertArrayEquals(expectedStream.toByteArray(), buffer.toByteArray()); } @ParameterizedTest @MethodSource("provideVectorsAndCodecsForRoundTrip") void shouldDecodeVector(final Vector vectorToDecode, final Codec vectorCodec) { // given - BsonReader mockReader = Mockito.mock(BsonReader.class); - BsonBinary bsonBinary = new BsonBinary(vectorToDecode); - when(mockReader.peekBinarySubType()).thenReturn(BsonBinarySubType.VECTOR.getValue()); - when(mockReader.readBinaryData()).thenReturn(bsonBinary); + OutputBuffer buffer = new BasicOutputBuffer(); + BsonWriter writer = new BsonBinaryWriter(buffer); + writer.writeStartDocument(); + writer.writeName("vector"); + writer.writeBinaryData(new BsonBinary(vectorToDecode)); + writer.writeEndDocument(); + + BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(buffer.toByteArray())))); + reader.readStartDocument(); // when - Vector decodedVector = vectorCodec.decode(mockReader, DecoderContext.builder().build()); + Vector decodedVector = vectorCodec.decode(reader, DecoderContext.builder().build()); // then assertNotNull(decodedVector); @@ -118,22 +125,25 @@ void shouldDecodeVector(final Vector vectorToDecode, final Codec vectorC @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) void shouldThrowExceptionForInvalidSubType(final BsonBinarySubType subType) { // given - BsonReader mockReader = Mockito.mock(BsonReader.class); - when(mockReader.peekBinarySubType()).thenReturn(subType.getValue()); + BsonDocument document = new BsonDocument("name", new BsonBinary(subType.getValue(), new byte[]{})); + BsonBinaryReader reader = new BsonBinaryReader(toBson(document)); + reader.readStartDocument(); + // when & then Stream.of(new Float32VectorCodec(), new Int8VectorCodec(), new PackedBitVectorCodec()) .forEach(codec -> { - // when & then BSONException 
exception = assertThrows(BSONException.class, () -> - codec.decode(mockReader, DecoderContext.builder().build())); - assertEquals("Unexpected BsonBinarySubType", exception.getMessage()); + codec.decode(reader, DecoderContext.builder().build())); + assertEquals("Expected vector binary subtype 9 but found: " + subType.getValue(), exception.getMessage()); }); } @ParameterizedTest @MethodSource("provideVectorsAndCodecsForRoundTrip") - void shouldReturnCorrectEncoderClass(final Vector vector, final Codec codec, final Class expectedEncoderClass) { + void shouldReturnCorrectEncoderClass(final Vector vector, + final Codec codec, + final Class expectedEncoderClass) { // when Class encoderClass = codec.getEncoderClass(); @@ -141,16 +151,6 @@ void shouldReturnCorrectEncoderClass(final Vector vector, final Codec codec) { - // when - String result = codec.toString(); - - // then - assertEquals(codec.getClass().getSimpleName() + "{}", result); - } - private static Stream> provideVectorsCodec() { return Stream.of( new VectorCodec(), diff --git a/bson/src/test/unit/org/bson/types/BinaryTest.java b/bson/src/test/unit/org/bson/types/BinaryTest.java deleted file mode 100644 index cef524be3b3..00000000000 --- a/bson/src/test/unit/org/bson/types/BinaryTest.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.bson.types; - -import org.bson.BsonBinarySubType; -import org.bson.BsonInvalidOperationException; -import org.bson.Float32Vector; -import org.bson.Int8Vector; -import org.bson.PackedBitVector; -import org.bson.Vector; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; -import org.junit.jupiter.params.provider.MethodSource; - -import java.util.stream.Stream; - -import static org.bson.assertions.Assertions.fail; -import static org.bson.internal.vector.VectorHelper.encodeVectorToBinary; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class BinaryTest { - static Stream provideVectors() { - return Stream.of( - Vector.floatVector(new float[]{1.5f, 2.1f, 3.1f}), - Vector.int8Vector(new byte[]{10, 20, 30}), - Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010000}, (byte) 3) - ); - } - - @ParameterizedTest - @MethodSource("provideVectors") - void shouldCreateBinaryFromVector(final Vector vector) { - // when - Binary binary = new Binary(vector); - - // then - assertEquals(BsonBinarySubType.VECTOR.getValue(), binary.getType(), "The subtype must be VECTOR"); - assertNotNull(binary.getData(), "Binary data should not be null"); - assertArrayEquals(encodeVectorToBinary(vector), binary.getData()); - } - - @Test - void shouldThrowExceptionWhenCreatingBinaryWithNullVector() { - // given - Vector vector = null; - - // when & then - IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new Binary(vector)); - assertEquals("Vector must not be null", exception.getMessage()); - } - - @ParameterizedTest - @MethodSource("provideVectors") - void shouldConvertBinaryToVector(final Vector actualVector) { - // given - Binary binary 
= new Binary(actualVector); - - // when - Vector decodedVector = binary.asVector(); - - // then - assertNotNull(decodedVector); - assertEquals(actualVector.getDataType(), decodedVector.getDataType()); - assertVectorDataDecoding(actualVector, decodedVector); - } - - private static void assertVectorDataDecoding(final Vector actualVector, final Vector decodedVector) { - switch (actualVector.getDataType()) { - case FLOAT32: - Float32Vector actualFloat32Vector = actualVector.asFloat32Vector(); - Float32Vector decodedFloat32Vector = decodedVector.asFloat32Vector(); - assertArrayEquals(actualFloat32Vector.getVectorArray(), decodedFloat32Vector.getVectorArray(), - "Float vector data should match after decoding"); - break; - case INT8: - Int8Vector actualInt8Vector = actualVector.asInt8Vector(); - Int8Vector decodedInt8Vector = decodedVector.asInt8Vector(); - assertArrayEquals(actualInt8Vector.getVectorArray(), decodedInt8Vector.getVectorArray(), - "Int8 vector data should match after decoding"); - break; - case PACKED_BIT: - PackedBitVector actualPackedBitVector = actualVector.asPackedBitVector(); - PackedBitVector decodedPackedBitVector = decodedVector.asPackedBitVector(); - assertArrayEquals(actualPackedBitVector.getVectorArray(), decodedPackedBitVector.getVectorArray(), - "Packed bit vector data should match after decoding"); - assertEquals(actualPackedBitVector.getPadding(), decodedPackedBitVector.getPadding(), - "Padding should match after decoding"); - break; - default: - fail("Unexpected vector type: " + actualVector.getDataType()); - } - } - - @ParameterizedTest - @EnumSource(value = BsonBinarySubType.class, mode = EnumSource.Mode.EXCLUDE, names = {"VECTOR"}) - void shouldThrowExceptionWhenBinarySubTypeIsNotVector(final BsonBinarySubType binarySubType) { - // given - byte[] data = new byte[]{1, 2, 3, 4}; - Binary binary = new Binary(binarySubType.getValue(), data); - - // when & then - BsonInvalidOperationException exception = 
assertThrows(BsonInvalidOperationException.class, binary::asVector); - assertEquals("type must be a Vector subtype.", exception.getMessage()); - } -} diff --git a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java index d2689ff762f..64e84f6afc8 100644 --- a/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java +++ b/bson/src/test/unit/org/bson/vector/VectorGenericBsonTest.java @@ -47,24 +47,28 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeFalse; -// BSON tests powered by language-agnostic JSON-based tests included in test resources +/** + * See + * JSON-based tests that included in test resources. + */ class VectorGenericBsonTest { private static final List TEST_NAMES_TO_IGNORE = Arrays.asList( - //NO API to set padding for Floats available + //NO API to set padding for floats available. "FLOAT32 with padding", - //NO API to set padding for Floats available + //NO API to set padding for floats available. "INT8 with padding", - //It is impossible to provide float inputs for INT8 in the API + //It is impossible to provide float inputs for INT8 in the API. "INT8 with float inputs", - //It is impossible to provide float inputs for INT8 + //It is impossible to provide float inputs for INT8. "Underflow Vector PACKED_BIT", - //It is impossible to provide float inputs for PACKED_BIT in the API + //It is impossible to provide float inputs for PACKED_BIT in the API. "Vector with float values PACKED_BIT", - //It is impossible to provide float inputs for INT8 + //It is impossible to provide float inputs for INT8. "Overflow Vector PACKED_BIT", + //It is impossible to overflow byte with values higher than 127 in the API. "Overflow Vector INT8", - // It is impossible to provide -129 for byte. + //It is impossible to underflow byte with values lower than -128 in the API. 
"Underflow Vector INT8"); @@ -83,7 +87,7 @@ void shouldPassAllOutcomes(@SuppressWarnings("unused") final String description, } } - private void runInvalidTestCase(final BsonDocument testCase) { + private static void runInvalidTestCase(final BsonDocument testCase) { BsonArray arrayVector = testCase.getArray("vector"); byte expectedPadding = (byte) testCase.getInt32("padding").getValue(); byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); @@ -109,7 +113,7 @@ private void runInvalidTestCase(final BsonDocument testCase) { } } - private void runValidTestCase(final String testKey, final BsonDocument testCase) { + private static void runValidTestCase(final String testKey, final BsonDocument testCase) { String description = testCase.getString("description").getValue(); byte dtypeByte = Byte.decode(testCase.getString("dtype_hex").getValue()); @@ -124,7 +128,7 @@ private void runValidTestCase(final String testKey, final BsonDocument testCase) switch (expectedDType) { case INT8: byte[] expectedVectorData = toByteArray(arrayVector); - byte[] actualVectorData = actualVector.asInt8Vector().getVectorArray(); + byte[] actualVectorData = actualVector.asInt8Vector().getData(); assertVectorDecoding( expectedVectorData, expectedDType, @@ -186,7 +190,7 @@ private static void assertThatVectorCreationResultsInCorrectBinary(final Vector format("Failed to create expected BSON for document with description '%s'", description)); } - private void assertVectorDecoding(final byte[] expectedVectorData, + private static void assertVectorDecoding(final byte[] expectedVectorData, final Vector.DataType expectedDType, final byte[] actualVectorData, final Vector actualVector) { @@ -195,11 +199,11 @@ private void assertVectorDecoding(final byte[] expectedVectorData, assertEquals(expectedDType, actualVector.getDataType()); } - private void assertVectorDecoding(final byte[] expectedVectorData, + private static void assertVectorDecoding(final byte[] expectedVectorData, final 
Vector.DataType expectedDType, final byte expectedPadding, final PackedBitVector actualVector) { - byte[] actualVectorData = actualVector.getVectorArray(); + byte[] actualVectorData = actualVector.getData(); assertVectorDecoding( expectedVectorData, expectedDType, @@ -208,16 +212,16 @@ private void assertVectorDecoding(final byte[] expectedVectorData, assertEquals(expectedPadding, actualVector.getPadding()); } - private void assertVectorDecoding(final float[] expectedVectorData, + private static void assertVectorDecoding(final float[] expectedVectorData, final Vector.DataType expectedDType, final Float32Vector actualVector) { - float[] actualVectorArray = actualVector.getVectorArray(); + float[] actualVectorArray = actualVector.getData(); Assertions.assertArrayEquals(actualVectorArray, expectedVectorData, () -> "Actual: " + Arrays.toString(actualVectorArray) + " != Expected:" + Arrays.toString(expectedVectorData)); assertEquals(expectedDType, actualVector.getDataType()); } - private byte[] toByteArray(final BsonArray arrayVector) { + private static byte[] toByteArray(final BsonArray arrayVector) { byte[] bytes = new byte[arrayVector.size()]; for (int i = 0; i < arrayVector.size(); i++) { bytes[i] = (byte) arrayVector.get(i).asInt32().getValue(); @@ -225,7 +229,7 @@ private byte[] toByteArray(final BsonArray arrayVector) { return bytes; } - private float[] toFloatArray(final BsonArray arrayVector) { + private static float[] toFloatArray(final BsonArray arrayVector) { float[] floats = new float[arrayVector.size()]; for (int i = 0; i < arrayVector.size(); i++) { BsonValue bsonValue = arrayVector.get(i); @@ -265,8 +269,8 @@ private static Stream provideTestCases() throws URISyntaxException, I private static String createTestCaseDescription(final BsonDocument testDocument, final BsonDocument testCaseDocument) { boolean isValidTestCase = testCaseDocument.getBoolean("valid").getValue(); - String testSuiteDescription = testDocument.getString("description").getValue(); - 
String testCaseDescription = testCaseDocument.getString("description").getValue(); - return "[Valid input: " + isValidTestCase + "] " + testSuiteDescription + ": " + testCaseDescription; + String fileDescription = testDocument.getString("description").getValue(); + String testDescription = testCaseDocument.getString("description").getValue(); + return "[Valid input: " + isValidTestCase + "] " + fileDescription + ": " + testDescription; } } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java index 32bd5385b37..f5b8e63f8c3 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/vector/VectorFunctionalTest.java @@ -18,11 +18,11 @@ import com.mongodb.MongoClientSettings; import com.mongodb.client.MongoClient; -import com.mongodb.client.vector.VectorAbstractFunctionalTest; +import com.mongodb.client.vector.AbstractVectorFunctionalTest; import com.mongodb.reactivestreams.client.MongoClients; import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; -public class VectorFunctionalTest extends VectorAbstractFunctionalTest { +public class VectorFunctionalTest extends AbstractVectorFunctionalTest { @Override protected MongoClient getMongoClient(final MongoClientSettings settings) { return new SyncMongoClient(MongoClients.create(settings)); diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java similarity index 76% rename from driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java rename to 
driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java index 2ad55bbed60..bfecbc60680 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorAbstractFunctionalTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java @@ -26,6 +26,7 @@ import com.mongodb.client.model.OperationTest; import org.bson.BsonBinary; import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; import org.bson.Document; import org.bson.Float32Vector; import org.bson.Int8Vector; @@ -52,8 +53,9 @@ import static org.bson.Vector.DataType.PACKED_BIT; import static org.bson.codecs.configuration.CodecRegistries.fromProviders; import static org.bson.codecs.configuration.CodecRegistries.fromRegistries; +import static org.junit.jupiter.api.Assertions.assertEquals; -public abstract class VectorAbstractFunctionalTest extends OperationTest { +public abstract class AbstractVectorFunctionalTest extends OperationTest { private static final byte VECTOR_SUBTYPE = BsonBinarySubType.VECTOR.getValue(); private static final String FIELD_VECTOR = "vector"; @@ -77,11 +79,12 @@ public void setUp() { } @AfterEach + @SuppressWarnings("try") public void afterEach() { super.afterEach(); - if (mongoClient != null) { - mongoClient.close(); - } + try (MongoClient ignore = mongoClient) { + //NOOP + } } private static MongoClientSettings.Builder getMongoClientSettingsBuilder() { @@ -101,11 +104,11 @@ void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(fi documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); // when & then - Binary invalidVectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); - - IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); - Assertions.assertEquals("state should be: Padding must be 0 if vector is empty.", exception.getMessage()); + 
BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be 0 if vector is empty, but found: " + invalidPadding, exception.getMessage()); } @ParameterizedTest @@ -116,11 +119,11 @@ void shouldThrowExceptionForInvalidFloat32Padding(final byte invalidPadding) { documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); // when & then - Binary invalidVectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); - - IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); - Assertions.assertEquals("state should be: Padding must be 0 for FLOAT32 data type.", exception.getMessage()); + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be 0 for FLOAT32 data type, but found: " + invalidPadding, exception.getMessage()); } @ParameterizedTest @@ -131,11 +134,11 @@ void shouldThrowExceptionForInvalidInt8Padding(final byte invalidPadding) { documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); // when & then - Binary invalidVectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); - - IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); - Assertions.assertEquals("state should be: Padding must be 0 for INT8 data type.", exception.getMessage()); + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be 0 for INT8 data type, but found: " + invalidPadding, exception.getMessage()); } 
@ParameterizedTest @@ -146,11 +149,11 @@ void shouldThrowExceptionForInvalidPackedBitPadding(final byte invalidPadding) { documentCollection.insertOne(new Document(FIELD_VECTOR, invalidVector)); // when & then - Binary invalidVectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); - - IllegalStateException exception = Assertions.assertThrows(IllegalStateException.class, invalidVectorBinary::asVector); - Assertions.assertEquals("state should be: Padding must be between 0 and 7 bits.", exception.getMessage()); + BsonInvalidOperationException exception = Assertions.assertThrows(BsonInvalidOperationException.class, ()-> { + findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); + }); + assertEquals("Padding must be between 0 and 7 bits, but found: " + invalidPadding, exception.getMessage()); } private static Stream provideValidVectors() { @@ -163,44 +166,31 @@ private static Stream provideValidVectors() { @ParameterizedTest @MethodSource("provideValidVectors") - void shouldStoreAndRetrieveValidVector(final Vector actualVector) { + void shouldStoreAndRetrieveValidVector(final Vector expectedVector) { // Given - Document documentToInsert = new Document(FIELD_VECTOR, actualVector); + Document documentToInsert = new Document(FIELD_VECTOR, expectedVector) + .append("otherField", 1); // to test that the next field is not affected documentCollection.insertOne(documentToInsert); // when & then - Binary vectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); + Vector actualVector = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); - Assertions.assertEquals(actualVector, vectorBinary.asVector()); + assertEquals(expectedVector, actualVector); } @ParameterizedTest @MethodSource("provideValidVectors") - void shouldStoreAndRetrieveValidVectorWithBinary(final Vector actualVector) { - // given - Document documentToInsert = new Document(FIELD_VECTOR, new Binary(actualVector)); - 
documentCollection.insertOne(documentToInsert); - - // when & then - Binary vectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); - - Assertions.assertEquals(actualVector, vectorBinary.asVector()); - } - - @ParameterizedTest - @MethodSource("provideValidVectors") - void shouldStoreAndRetrieveValidVectorWithBsonBinary(final Vector actualVector) { + void shouldStoreAndRetrieveValidVectorWithBsonBinary(final Vector expectedVector) { // Given - Document documentToInsert = new Document(FIELD_VECTOR, new BsonBinary(actualVector)); + Document documentToInsert = new Document(FIELD_VECTOR, new BsonBinary(expectedVector)); documentCollection.insertOne(documentToInsert); // when & then - Binary vectorBinary = findExactlyOne(documentCollection) - .get(FIELD_VECTOR, Binary.class); + Vector actualVector = findExactlyOne(documentCollection) + .get(FIELD_VECTOR, Vector.class); - Assertions.assertEquals(actualVector, vectorBinary.asVector()); + assertEquals(actualVector, actualVector); } @Test @@ -217,7 +207,7 @@ void shouldStoreAndRetrieveValidVectorWithFloatVectorPojo() { // then Assertions.assertNotNull(floatVectorPojo); - Assertions.assertEquals(vector, floatVectorPojo.getVector()); + assertEquals(vector, floatVectorPojo.getVector()); } @Test @@ -234,7 +224,7 @@ void shouldStoreAndRetrieveValidVectorWithInt8VectorPojo() { // then Assertions.assertNotNull(int8VectorPojo); - Assertions.assertEquals(vector, int8VectorPojo.getVector()); + assertEquals(vector, int8VectorPojo.getVector()); } @Test @@ -252,7 +242,7 @@ void shouldStoreAndRetrieveValidVectorWithPackedBitVectorPojo() { // then Assertions.assertNotNull(packedBitVectorPojo); - Assertions.assertEquals(vector, packedBitVectorPojo.getVector()); + assertEquals(vector, packedBitVectorPojo.getVector()); } @ParameterizedTest @@ -269,15 +259,13 @@ void shouldStoreAndRetrieveValidVectorWithGenericVectorPojo(final Vector actualV //then Assertions.assertNotNull(vectorPojo); - 
Assertions.assertEquals(actualVector, vectorPojo.getVector()); + assertEquals(actualVector, vectorPojo.getVector()); } private Document findExactlyOne(final MongoCollection collection) { List documents = new ArrayList<>(); collection.find().into(documents); - if (documents.size() != 1) { - throw new IllegalStateException("Expected exactly one document, but found: " + documents.size()); - } + assertEquals(1, documents.size(), "Expected exactly one document, but found: " + documents.size()); return documents.get(0); } diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java index a0cddb6dbca..63d756a8f35 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/VectorFunctionalTest.java @@ -20,7 +20,7 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; -public class VectorFunctionalTest extends VectorAbstractFunctionalTest { +public class VectorFunctionalTest extends AbstractVectorFunctionalTest { @Override protected MongoClient getMongoClient(final MongoClientSettings settings) { return MongoClients.create(settings); From 45d21edf243c738fa65298cc1e9c8d011150403e Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 28 Oct 2024 16:19:03 -0700 Subject: [PATCH 13/20] Change exception type. 
JAVA-5544 --- bson/src/main/org/bson/codecs/Float32VectorCodec.java | 4 ++-- bson/src/main/org/bson/codecs/Int8VectorCodec.java | 4 ++-- bson/src/main/org/bson/codecs/PackedBitVectorCodec.java | 5 +++-- bson/src/main/org/bson/codecs/VectorCodec.java | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/bson/src/main/org/bson/codecs/Float32VectorCodec.java b/bson/src/main/org/bson/codecs/Float32VectorCodec.java index 0933a00590a..a6df27e3f87 100644 --- a/bson/src/main/org/bson/codecs/Float32VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Float32VectorCodec.java @@ -16,9 +16,9 @@ package org.bson.codecs; -import org.bson.BSONException; import org.bson.BsonBinary; import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; import org.bson.BsonReader; import org.bson.BsonWriter; import org.bson.Float32Vector; @@ -39,7 +39,7 @@ public Float32Vector decode(final BsonReader reader, final DecoderContext decode byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); + throw new BsonInvalidOperationException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); } return reader.readBinaryData() diff --git a/bson/src/main/org/bson/codecs/Int8VectorCodec.java b/bson/src/main/org/bson/codecs/Int8VectorCodec.java index dc99877dd1a..a9a70f53746 100644 --- a/bson/src/main/org/bson/codecs/Int8VectorCodec.java +++ b/bson/src/main/org/bson/codecs/Int8VectorCodec.java @@ -16,9 +16,9 @@ package org.bson.codecs; -import org.bson.BSONException; import org.bson.BsonBinary; import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; import org.bson.BsonReader; import org.bson.BsonWriter; import org.bson.Int8Vector; @@ -40,7 +40,7 @@ public Int8Vector decode(final BsonReader reader, final DecoderContext decoderCo 
byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); + throw new BsonInvalidOperationException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); } return reader.readBinaryData() diff --git a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java index 1fb4deb5e20..6fcb9552955 100644 --- a/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java +++ b/bson/src/main/org/bson/codecs/PackedBitVectorCodec.java @@ -16,9 +16,9 @@ package org.bson.codecs; -import org.bson.BSONException; import org.bson.BsonBinary; import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; import org.bson.BsonReader; import org.bson.BsonWriter; import org.bson.PackedBitVector; @@ -39,7 +39,8 @@ public PackedBitVector decode(final BsonReader reader, final DecoderContext deco byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); + throw new BsonInvalidOperationException( + "Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found: " + subType); } return reader.readBinaryData() diff --git a/bson/src/main/org/bson/codecs/VectorCodec.java b/bson/src/main/org/bson/codecs/VectorCodec.java index 4f4c1cf010d..87d847664dc 100644 --- a/bson/src/main/org/bson/codecs/VectorCodec.java +++ b/bson/src/main/org/bson/codecs/VectorCodec.java @@ -16,9 +16,9 @@ package org.bson.codecs; -import org.bson.BSONException; import org.bson.BsonBinary; import org.bson.BsonBinarySubType; +import org.bson.BsonInvalidOperationException; import org.bson.BsonReader; import org.bson.BsonWriter; import org.bson.Vector; @@ -39,7 +39,7 
@@ public Vector decode(final BsonReader reader, final DecoderContext decoderContex byte subType = reader.peekBinarySubType(); if (subType != BsonBinarySubType.VECTOR.getValue()) { - throw new BSONException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found " + subType); + throw new BsonInvalidOperationException("Expected vector binary subtype " + BsonBinarySubType.VECTOR.getValue() + " but found " + subType); } return reader.readBinaryData() From b2952b7fa3b75493b732743f577ebf1bc9de1c48 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 28 Oct 2024 16:29:23 -0700 Subject: [PATCH 14/20] Update tests. JAVA-5544 --- bson/src/test/unit/org/bson/codecs/VectorCodecTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java index 5e0ea495f75..036066f6984 100644 --- a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -16,7 +16,7 @@ package org.bson.codecs; -import org.bson.BSONException; +import org.bson.BsonInvalidOperationException; import org.bson.BsonBinary; import org.bson.BsonBinaryReader; import org.bson.BsonBinarySubType; @@ -132,7 +132,7 @@ void shouldThrowExceptionForInvalidSubType(final BsonBinarySubType subType) { // when & then Stream.of(new Float32VectorCodec(), new Int8VectorCodec(), new PackedBitVectorCodec()) .forEach(codec -> { - BSONException exception = assertThrows(BSONException.class, () -> + BsonInvalidOperationException exception = assertThrows(BsonInvalidOperationException.class, () -> codec.decode(reader, DecoderContext.builder().build())); assertEquals("Expected vector binary subtype 9 but found: " + subType.getValue(), exception.getMessage()); }); From 097bf1390051b2e84d356dfab14e7966e2497c2e Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 29 Oct 2024 15:57:42 -0700 Subject: [PATCH 15/20] 
Change javadoc. JAVA-5544 --- bson/src/main/org/bson/Vector.java | 3 +-- .../org/bson/codecs/ValueCodecProviderSpecification.groovy | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bson/src/main/org/bson/Vector.java b/bson/src/main/org/bson/Vector.java index 8b1548efd7f..d267387d727 100644 --- a/bson/src/main/org/bson/Vector.java +++ b/bson/src/main/org/bson/Vector.java @@ -27,8 +27,7 @@ * Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently * in BSON using a binary format. *

    - * NOTE: This class is intended to be treated as sealed. Any subclasses added outside the library are not guaranteed to - * function correctly in the current and future releases. + * NOTE: This class should be treated as sealed: it must not be extended or implemented by consumers of the library. * * @mongodb.server.release 6.0 * @see BsonBinary diff --git a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy index 872e4dd6142..23c46fb7b0b 100644 --- a/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy +++ b/bson/src/test/unit/org/bson/codecs/ValueCodecProviderSpecification.groovy @@ -36,6 +36,7 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicLong import java.util.regex.Pattern +//Codenarc @SuppressWarnings("VectorIsObsolete") class ValueCodecProviderSpecification extends Specification { private final provider = new ValueCodecProvider() From e65817c44019e98199e38f4ecc41d8b923583743 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 29 Oct 2024 16:04:02 -0700 Subject: [PATCH 16/20] Remove redundant method. 
JAVA-5544 --- bson/src/test/unit/org/bson/codecs/CodecTestCase.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bson/src/test/unit/org/bson/codecs/CodecTestCase.java b/bson/src/test/unit/org/bson/codecs/CodecTestCase.java index 52b21e1e8db..17768d0d133 100644 --- a/bson/src/test/unit/org/bson/codecs/CodecTestCase.java +++ b/bson/src/test/unit/org/bson/codecs/CodecTestCase.java @@ -92,10 +92,6 @@ OutputBuffer encode(final Codec codec, final T value) { return buffer; } - void encode(final Codec codec, final T value, final BsonWriter writer) { - codec.encode(writer, value, EncoderContext.builder().build()); - } - T decode(final Codec codec, final OutputBuffer buffer) { BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(new ByteBufNIO(ByteBuffer.wrap(buffer.toByteArray())))); return codec.decode(reader, DecoderContext.builder().build()); From 0a168d92feec02121afdc23903cd874f993597e7 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 29 Oct 2024 18:25:39 -0700 Subject: [PATCH 17/20] Change validation message. 
JAVA-5544 --- bson/src/main/org/bson/internal/vector/VectorHelper.java | 2 +- bson/src/test/unit/org/bson/BsonBinaryVectorTest.java | 3 ++- .../client/vector/AbstractVectorFunctionalTest.java | 7 +++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index a634e7b7810..f637a9b8a68 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -141,7 +141,7 @@ private static byte[] encodeVector(final byte dType, final float[] vectorData) { } private static float[] decodeLittleEndianFloats(final byte[] encodedVector) { - isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type.", + isTrue("Byte array length must be a multiple of 4 for FLOAT32 data type, but found: " + encodedVector.length, (encodedVector.length - METADATA_SIZE) % Float.BYTES == 0); int vectorSize = encodedVector.length - METADATA_SIZE; diff --git a/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java b/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java index 485448a5bc4..fc42bfc093e 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java @@ -187,7 +187,8 @@ void shouldThrowExceptionForInvalidFloatArrayLengthWhenDecode() { BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); }); - assertEquals("Byte array length must be a multiple of 4 for FLOAT32 data type.", thrown.getMessage()); + assertEquals("Byte array length must be a multiple of 4 for FLOAT32 data type, but found: " + invalidData.length, + thrown.getMessage()); } @ParameterizedTest diff --git a/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java 
b/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java index bfecbc60680..c3edf6983da 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/vector/AbstractVectorFunctionalTest.java @@ -81,10 +81,9 @@ public void setUp() { @AfterEach @SuppressWarnings("try") public void afterEach() { - super.afterEach(); - try (MongoClient ignore = mongoClient) { - //NOOP - } + try (MongoClient ignore = mongoClient) { + super.afterEach(); + } } private static MongoClientSettings.Builder getMongoClientSettingsBuilder() { From cbfae8fd0dc4d326598c1b507ee3e0850cb6dc70 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 29 Oct 2024 21:13:20 -0700 Subject: [PATCH 18/20] Remove redundant test. JAVA-5544 --- ...aryVectorTest.java => BsonBinaryTest.java} | 2 +- .../org/bson/codecs/DocumentCodecTest.java | 22 +++++-------------- .../unit/org/bson/codecs/VectorCodecTest.java | 17 ++++---------- 3 files changed, 10 insertions(+), 31 deletions(-) rename bson/src/test/unit/org/bson/{BsonBinaryVectorTest.java => BsonBinaryTest.java} (99%) diff --git a/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java b/bson/src/test/unit/org/bson/BsonBinaryTest.java similarity index 99% rename from bson/src/test/unit/org/bson/BsonBinaryVectorTest.java rename to bson/src/test/unit/org/bson/BsonBinaryTest.java index fc42bfc093e..6ab4b0202ca 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryVectorTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryTest.java @@ -30,7 +30,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.params.provider.Arguments.arguments; -class BsonBinaryVectorTest { +class BsonBinaryTest { private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue(); private static final byte INT8_DTYPE = Vector.DataType.INT8.getValue(); diff --git 
a/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java b/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java index d407df31d37..67c6b561aa5 100644 --- a/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/DocumentCodecTest.java @@ -37,9 +37,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -47,13 +44,11 @@ import java.util.Date; import java.util.HashSet; import java.util.List; -import java.util.stream.Stream; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.params.provider.Arguments.arguments; public class DocumentCodecTest { private BasicOutputBuffer buffer; @@ -70,9 +65,8 @@ public void tearDown() { writer.close(); } - @ParameterizedTest - @MethodSource("provideVectorsForRoundTrip") - public void testPrimitiveBSONTypeCodecs(final Vector vector) throws IOException { + @Test + public void testPrimitiveBSONTypeCodecs() throws IOException { DocumentCodec documentCodec = new DocumentCodec(); Document doc = new Document(); doc.put("oid", new ObjectId()); @@ -87,7 +81,9 @@ public void testPrimitiveBSONTypeCodecs(final Vector vector) throws IOException doc.put("code", new Code("var i = 0")); doc.put("minkey", new MinKey()); doc.put("maxkey", new MaxKey()); - doc.put("vector", vector); + doc.put("vectorFloat", Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f})); + doc.put("vectorInt8", Vector.int8Vector(new byte[]{10, 20, 30, 40})); + doc.put("vectorPackedBit", Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 
0b01010101}, (byte) 3)); // doc.put("pattern", Pattern.compile("^hello")); // TODO: Pattern doesn't override equals method! doc.put("null", null); @@ -98,14 +94,6 @@ public void testPrimitiveBSONTypeCodecs(final Vector vector) throws IOException assertEquals(doc, decodedDocument); } - private static Stream provideVectorsForRoundTrip() { - return Stream.of( - arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f})), - arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40})), - arguments(Vector.packedBitVector(new byte[]{(byte) 0b10101010, (byte) 0b01010101}, (byte) 3)) - ); - } - @Test public void testIterableEncoding() throws IOException { DocumentCodec documentCodec = new DocumentCodec(); diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java index 036066f6984..150a42898c5 100644 --- a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -51,7 +51,7 @@ class VectorCodecTest extends CodecTestCase { - private static Stream provideVectorsAndCodecsForRoundTrip() { + private static Stream provideVectorsAndCodecs() { return Stream.of( arguments(Vector.floatVector(new float[]{1.1f, 2.2f, 3.3f}), new Float32VectorCodec(), Float32Vector.class), arguments(Vector.int8Vector(new byte[]{10, 20, 30, 40}), new Int8VectorCodec(), Int8Vector.class), @@ -63,7 +63,7 @@ private static Stream provideVectorsAndCodecsForRoundTrip() { } @ParameterizedTest - @MethodSource("provideVectorsAndCodecsForRoundTrip") + @MethodSource("provideVectorsAndCodecs") void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorCodec) throws IOException { // given BsonBinary bsonBinary = new BsonBinary(vectorToEncode); @@ -99,7 +99,7 @@ void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorC } @ParameterizedTest - @MethodSource("provideVectorsAndCodecsForRoundTrip") + @MethodSource("provideVectorsAndCodecs") void 
shouldDecodeVector(final Vector vectorToDecode, final Codec vectorCodec) { // given OutputBuffer buffer = new BasicOutputBuffer(); @@ -140,7 +140,7 @@ void shouldThrowExceptionForInvalidSubType(final BsonBinarySubType subType) { @ParameterizedTest - @MethodSource("provideVectorsAndCodecsForRoundTrip") + @MethodSource("provideVectorsAndCodecs") void shouldReturnCorrectEncoderClass(final Vector vector, final Codec codec, final Class expectedEncoderClass) { @@ -150,13 +150,4 @@ void shouldReturnCorrectEncoderClass(final Vector vector, // then assertEquals(expectedEncoderClass, encoderClass); } - - private static Stream> provideVectorsCodec() { - return Stream.of( - new VectorCodec(), - new Float32VectorCodec(), - new Int8VectorCodec(), - new PackedBitVectorCodec() - ); - } } From 952ce3555cb00f5a35c53f29cd75c12f1833cea9 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Tue, 29 Oct 2024 21:40:09 -0700 Subject: [PATCH 19/20] Clarify comments, adjust assertion. JAVA-5544 --- .../unit/org/bson/codecs/VectorCodecTest.java | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java index 150a42898c5..bf33af90cae 100644 --- a/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java +++ b/bson/src/test/unit/org/bson/codecs/VectorCodecTest.java @@ -16,12 +16,12 @@ package org.bson.codecs; -import org.bson.BsonInvalidOperationException; import org.bson.BsonBinary; import org.bson.BsonBinaryReader; import org.bson.BsonBinarySubType; import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; +import org.bson.BsonInvalidOperationException; import org.bson.BsonType; import org.bson.BsonWriter; import org.bson.ByteBufNIO; @@ -45,6 +45,7 @@ import static org.bson.BsonHelper.toBson; import static org.bson.assertions.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static 
org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -69,12 +70,12 @@ void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorC BsonBinary bsonBinary = new BsonBinary(vectorToEncode); byte[] encodedVector = bsonBinary.getData(); ByteArrayOutputStream expectedStream = new ByteArrayOutputStream(); - // Start of document with total length of 4 bytes (little-endian format) - byte totalDocumentLength = (byte) (14 + encodedVector.length); - expectedStream.write(new byte[]{totalDocumentLength, 0, 0, 0}); - // Bson type for vector + // Total length of a Document (int 32). It is 0, because we do not expect + // codec to write the end of the document (that is when we back-patch the length of the document). + expectedStream.write(new byte[]{0, 0, 0, 0}); + // Bson type expectedStream.write((byte) BsonType.BINARY.getValue()); - // Field name "b4" (ASCII for 'b', '4', null terminator) + // Field name "b4" expectedStream.write(new byte[]{98, 52, 0}); // Total length of binary data (little-endian format) expectedStream.write(new byte[]{(byte) encodedVector.length, 0, 0, 0}); @@ -82,8 +83,6 @@ void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorC expectedStream.write(BsonBinarySubType.VECTOR.getValue()); // Actual BSON binary data expectedStream.write(encodedVector); - // End of document - expectedStream.write(0); OutputBuffer buffer = new BasicOutputBuffer(); BsonWriter writer = new BsonBinaryWriter(buffer); @@ -92,7 +91,6 @@ void shouldEncodeVector(final Vector vectorToEncode, final Codec vectorC // when vectorCodec.encode(writer, vectorToEncode, EncoderContext.builder().build()); - writer.writeEndDocument(); // then assertArrayEquals(expectedStream.toByteArray(), buffer.toByteArray()); @@ -116,6 +114,7 @@ void shouldDecodeVector(final Vector 
vectorToDecode, final Codec vectorC Vector decodedVector = vectorCodec.decode(reader, DecoderContext.builder().build()); // then + assertDoesNotThrow(reader::readEndDocument); assertNotNull(decodedVector); assertEquals(vectorToDecode, decodedVector); } From 3422dcdd0408645e83868729e1bf67b97c2cf8c3 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 10:44:08 -0700 Subject: [PATCH 20/20] Add tests and change validation message. JAVA-5544 --- .../org/bson/codecs/ContainerCodecHelper.java | 18 ++++++++---------- .../src/main/org/bson/internal/UuidHelper.java | 6 ++++++ .../org/bson/internal/vector/VectorHelper.java | 2 +- .../src/test/unit/org/bson/BsonBinaryTest.java | 14 ++++++++++++++ 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/bson/src/main/org/bson/codecs/ContainerCodecHelper.java b/bson/src/main/org/bson/codecs/ContainerCodecHelper.java index 827858d11cb..b454206d5e8 100644 --- a/bson/src/main/org/bson/codecs/ContainerCodecHelper.java +++ b/bson/src/main/org/bson/codecs/ContainerCodecHelper.java @@ -30,6 +30,8 @@ import java.util.Arrays; import java.util.UUID; +import static org.bson.internal.UuidHelper.isLegacyUUID; + /** * Helper methods for Codec implementations for containers, e.g. {@code Map} and {@code Iterable}. 
*/ @@ -48,7 +50,8 @@ static Object readValue(final BsonReader reader, final DecoderContext decoderCon if (bsonType == BsonType.BINARY) { byte binarySubType = reader.peekBinarySubType(); - currentCodec = getBinarySubTypeCodec(reader, + currentCodec = getBinarySubTypeCodec( + reader, uuidRepresentation, registry, binarySubType, currentCodec); @@ -62,21 +65,17 @@ private static Codec getBinarySubTypeCodec(final BsonReader reader, final UuidRepresentation uuidRepresentation, final CodecRegistry registry, final byte binarySubType, - final Codec currentTypeCodec) { + final Codec binaryTypeCodec) { if (binarySubType == BsonBinarySubType.VECTOR.getValue()) { Codec vectorCodec = registry.get(Vector.class, registry); if (vectorCodec != null) { return vectorCodec; } - } - - if (reader.peekBinarySize() == 16) { + } else if (reader.peekBinarySize() == 16) { switch (binarySubType) { case 3: - if (uuidRepresentation == UuidRepresentation.JAVA_LEGACY - || uuidRepresentation == UuidRepresentation.C_SHARP_LEGACY - || uuidRepresentation == UuidRepresentation.PYTHON_LEGACY) { + if (isLegacyUUID(uuidRepresentation)) { return registry.get(UUID.class); } break; @@ -90,7 +89,7 @@ private static Codec getBinarySubTypeCodec(final BsonReader reader, } } - return currentTypeCodec; + return binaryTypeCodec; } static Codec getCodec(final CodecRegistry codecRegistry, final Type type) { @@ -104,7 +103,6 @@ static Codec getCodec(final CodecRegistry codecRegistry, final Type type) { } } - private ContainerCodecHelper() { } } diff --git a/bson/src/main/org/bson/internal/UuidHelper.java b/bson/src/main/org/bson/internal/UuidHelper.java index efe3d5b5812..9c46614b56e 100644 --- a/bson/src/main/org/bson/internal/UuidHelper.java +++ b/bson/src/main/org/bson/internal/UuidHelper.java @@ -124,6 +124,12 @@ public static UUID decodeBinaryToUuid(final byte[] data, final byte type, final return new UUID(readLongFromArrayBigEndian(localData, 0), readLongFromArrayBigEndian(localData, 8)); } + public static 
boolean isLegacyUUID(final UuidRepresentation uuidRepresentation) { + return uuidRepresentation == UuidRepresentation.JAVA_LEGACY + || uuidRepresentation == UuidRepresentation.C_SHARP_LEGACY + || uuidRepresentation == UuidRepresentation.PYTHON_LEGACY; + } + private UuidHelper() { } } diff --git a/bson/src/main/org/bson/internal/vector/VectorHelper.java b/bson/src/main/org/bson/internal/vector/VectorHelper.java index f637a9b8a68..9dbf583d2b0 100644 --- a/bson/src/main/org/bson/internal/vector/VectorHelper.java +++ b/bson/src/main/org/bson/internal/vector/VectorHelper.java @@ -72,7 +72,7 @@ public static byte[] encodeVectorToBinary(final Vector vector) { * encodedVector is not mutated nor stored in the returned {@link Vector}. */ public static Vector decodeBinaryToVector(final byte[] encodedVector) { - isTrue("Vector encoded array length must be at least 2.", encodedVector.length >= METADATA_SIZE); + isTrue("Vector encoded array length must be at least 2, but found: " + encodedVector.length, encodedVector.length >= METADATA_SIZE); Vector.DataType dataType = determineVectorDType(encodedVector[0]); byte padding = encodedVector[1]; switch (dataType) { diff --git a/bson/src/test/unit/org/bson/BsonBinaryTest.java b/bson/src/test/unit/org/bson/BsonBinaryTest.java index 6ab4b0202ca..029c611c594 100644 --- a/bson/src/test/unit/org/bson/BsonBinaryTest.java +++ b/bson/src/test/unit/org/bson/BsonBinaryTest.java @@ -191,6 +191,20 @@ void shouldThrowExceptionForInvalidFloatArrayLengthWhenDecode() { thrown.getMessage()); } + @ParameterizedTest + @ValueSource(ints = {0, 1}) + void shouldThrowExceptionWhenEncodedVectorLengthIsLessThenMetadataLength(final int encodedVectorLength) { + // given + byte[] invalidData = new byte[encodedVectorLength]; + + // when & Then + BsonInvalidOperationException thrown = assertThrows(BsonInvalidOperationException.class, () -> { + new BsonBinary(BsonBinarySubType.VECTOR, invalidData).asVector(); + }); + assertEquals("Vector encoded array length must 
be at least 2, but found: " + encodedVectorLength, + thrown.getMessage()); + } + @ParameterizedTest @ValueSource(bytes = {-1, 1}) void shouldThrowExceptionForInvalidFloatArrayPaddingWhenDecode(final byte invalidPadding) {