Skip to content

Commit d13f863

Browse files
committed
adding more scalar tests
1 parent 299568b commit d13f863

File tree

3 files changed

+111
-5
lines changed

3 files changed

+111
-5
lines changed

quaddtype/quaddtype/src/ops.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ quad_negative(Sleef_quad *op, Sleef_quad *out)
1010
return 0;
1111
}
1212

13+
static int
14+
quad_positive(Sleef_quad *op, Sleef_quad *out)
15+
{
16+
*out = *op;
17+
return 0;
18+
}
19+
1320
static inline int
1421
quad_absolute(Sleef_quad *op, Sleef_quad *out)
1522
{

quaddtype/quaddtype/src/scalar_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ PyNumberMethods quad_as_scalar = {
165165
.nb_true_divide = (binaryfunc)quad_binary_func<quad_div>,
166166
.nb_power = (ternaryfunc)quad_binary_func<quad_pow>,
167167
.nb_negative = (unaryfunc)quad_unary_func<quad_negative>,
168-
.nb_positive = (unaryfunc)quad_unary_func<quad_absolute>,
168+
.nb_positive = (unaryfunc)quad_unary_func<quad_positive>,
169169
.nb_absolute = (unaryfunc)quad_unary_func<quad_absolute>,
170170
.nb_bool = (inquiry)quad_nonzero,
171171
};

quaddtype/tests/test_quaddtype.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,108 @@
11
import pytest
2-
2+
import sys
33
import numpy as np
4+
import operator
5+
46
from quaddtype import QuadPrecDType, QuadPrecision
57

68

7-
def test_dtype():
8-
a = QuadPrecision("1.63")
9-
assert f"{np.array([a], dtype=QuadPrecDType).dtype}" == "QuadPrecDType()"
9+
def test_create_scalar_simple():
10+
assert isinstance(QuadPrecision("12.0"), QuadPrecision)
11+
assert isinstance(QuadPrecision(1.63), QuadPrecision)
12+
assert isinstance(QuadPrecision(1), QuadPrecision)
13+
14+
15+
def test_basic_equality():
16+
assert QuadPrecision("12") == QuadPrecision(
17+
"12.0") == QuadPrecision("12.00")
18+
19+
20+
@pytest.mark.parametrize("val", ["123532.543", "12893283.5"])
21+
def test_scalar_repr(val):
22+
expected = f"QuadPrecision('{str(QuadPrecision(val))}')"
23+
assert repr(QuadPrecision(val)) == expected
24+
25+
26+
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv", "pow"])
27+
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"])
28+
def test_binary_ops(op, other):
29+
op_func = getattr(operator, op)
30+
quad_a = QuadPrecision("12.5")
31+
quad_b = QuadPrecision(other)
32+
float_a = 12.5
33+
float_b = float(other)
34+
35+
quad_result = op_func(quad_a, quad_b)
36+
float_result = op_func(float_a, float_b)
37+
38+
assert np.abs(np.float64(quad_result) - float_result) < 1e-10
39+
40+
41+
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
42+
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"])
43+
def test_comparisons(op, other):
44+
op_func = getattr(operator, op)
45+
quad_a = QuadPrecision("12.5")
46+
quad_b = QuadPrecision(other)
47+
float_a = 12.5
48+
float_b = float(other)
49+
50+
assert op_func(quad_a, quad_b) == op_func(float_a, float_b)
51+
52+
53+
@pytest.mark.parametrize("op, val, expected", [
54+
("neg", "3.0", "-3.0"),
55+
("neg", "-3.0", "3.0"),
56+
("pos", "3.0", "3.0"),
57+
("pos", "-3.0", "-3.0"),
58+
("abs", "3.0", "3.0"),
59+
("abs", "-3.0", "3.0"),
60+
("neg", "12.5", "-12.5"),
61+
("pos", "100.0", "100.0"),
62+
("abs", "-25.5", "25.5"),
63+
])
64+
def test_unary_ops(op, val, expected):
65+
quad_val = QuadPrecision(val)
66+
expected_val = QuadPrecision(expected)
67+
68+
if op == "neg":
69+
result = -quad_val
70+
elif op == "pos":
71+
result = +quad_val
72+
elif op == "abs":
73+
result = abs(quad_val)
74+
else:
75+
raise ValueError(f"Unsupported operation: {op}")
76+
77+
assert result == expected_val, f"{op}({val}) should be {expected}, but got {result}"
78+
79+
80+
def test_nan_and_inf():
81+
assert (QuadPrecision("nan") != QuadPrecision("nan")) == (
82+
QuadPrecision("nan") == QuadPrecision("nan"))
83+
assert QuadPrecision("inf") > QuadPrecision("1e1000")
84+
assert QuadPrecision("-inf") < QuadPrecision("-1e1000")
85+
86+
87+
def test_dtype_creation():
88+
dtype = QuadPrecDType()
89+
assert isinstance(dtype, np.dtype)
90+
assert dtype.name == 'QuadPrecDType128'
91+
92+
93+
def test_array_creation():
94+
arr = np.array([1, 2, 3], dtype=QuadPrecDType())
95+
assert arr.dtype.name == 'QuadPrecDType128'
96+
assert all(isinstance(x, QuadPrecision) for x in arr)
97+
98+
99+
def test_array_operations():
100+
arr1 = np.array(
101+
[QuadPrecision("1.5"), QuadPrecision("2.5"), QuadPrecision("3.5")])
102+
arr2 = np.array(
103+
[QuadPrecision("0.5"), QuadPrecision("1.0"), QuadPrecision("1.5")])
104+
105+
result = arr1 + arr2
106+
expected = np.array(
107+
[QuadPrecision("2.0"), QuadPrecision("3.5"), QuadPrecision("5.0")])
108+
assert np.all(result == expected)

0 commit comments

Comments
 (0)