diff --git a/bson/binary.py b/bson/binary.py index aab59cccbc..eefc5697b6 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -14,7 +14,6 @@ from __future__ import annotations import struct -from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, Union, overload from uuid import UUID @@ -227,7 +226,6 @@ class BinaryVectorDtype(Enum): PACKED_BIT = b"\x10" -@dataclass class BinaryVector: """Vector of numbers along with metadata for binary interoperability. .. versionadded:: 4.10 @@ -247,6 +245,16 @@ def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, paddin self.dtype = dtype self.padding = padding + def __repr__(self) -> str: + return f"BinaryVector(dtype={self.dtype}, padding={self.padding}, data={self.data})" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, BinaryVector): + return False + return ( + self.dtype == other.dtype and self.padding == other.padding and self.data == other.data + ) + class Binary(bytes): """Representation of BSON binary data. diff --git a/test/test_bson.py b/test/test_bson.py index e704efe451..6f26856b00 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -809,6 +809,64 @@ def test_vector(self): dtype=BinaryVectorDtype.PACKED_BIT, ) # type: ignore[call-overload] + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + def test_binaryvector_repr(self): + """Tests of repr(BinaryVector)""" + + data = [1 / 127, -7 / 6] + one = BinaryVector(data, BinaryVectorDtype.FLOAT32) + self.assertEqual( + repr(one), f"BinaryVector(dtype=BinaryVectorDtype.FLOAT32, padding=0, data={data})" + ) + self.assertRepr(one) + + data = [127, 7] + two = BinaryVector(data, BinaryVectorDtype.INT8) + self.assertEqual( + repr(two), f"BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data={data})" + ) + self.assertRepr(two) + + three = BinaryVector(data, BinaryVectorDtype.INT8, padding=0) + self.assertEqual( + repr(three), f"BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data={data})" + ) + self.assertRepr(three) + + four = BinaryVector(data, BinaryVectorDtype.PACKED_BIT, padding=3) + self.assertEqual( + repr(four), f"BinaryVector(dtype=BinaryVectorDtype.PACKED_BIT, padding=3, data={data})" + ) + self.assertRepr(four) + + zero = BinaryVector([], BinaryVectorDtype.INT8) + self.assertEqual( + repr(zero), "BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data=[])" + ) + self.assertRepr(zero) + + def test_binaryvector_equality(self): + """Tests of == __eq__""" + self.assertEqual( + BinaryVector([1.2, 1 - 1 / 3], BinaryVectorDtype.FLOAT32, 0), + BinaryVector([1.2, 1 - 1.0 / 3.0], BinaryVectorDtype.FLOAT32, 0), + ) + self.assertNotEqual( + BinaryVector([1.2, 1 - 1 / 3], BinaryVectorDtype.FLOAT32, 0), + BinaryVector([1.2, 6.0 / 9.0], BinaryVectorDtype.FLOAT32, 0), + ) + self.assertEqual( + BinaryVector([], BinaryVectorDtype.FLOAT32, 0), + BinaryVector([], BinaryVectorDtype.FLOAT32, 0), + ) + self.assertNotEqual( + BinaryVector([1], BinaryVectorDtype.INT8), BinaryVector([2], BinaryVectorDtype.INT8) + ) + def test_unicode_regex(self): """Tests we do not get a segfault for C extension on unicode RegExs. This had been happening.