|
1 | 1 | import pytest
|
2 |
| - |
| 2 | +import sys |
3 | 3 | import numpy as np
|
| 4 | +import operator |
| 5 | + |
4 | 6 | from quaddtype import QuadPrecDType, QuadPrecision
|
5 | 7 |
|
6 | 8 |
|
7 |
| -def test_dtype(): |
8 |
| - a = QuadPrecision("1.63") |
9 |
| - assert f"{np.array([a], dtype=QuadPrecDType).dtype}" == "QuadPrecDType()" |
| 9 | +def test_create_scalar_simple(): |
| 10 | + assert isinstance(QuadPrecision("12.0"), QuadPrecision) |
| 11 | + assert isinstance(QuadPrecision(1.63), QuadPrecision) |
| 12 | + assert isinstance(QuadPrecision(1), QuadPrecision) |
| 13 | + |
| 14 | + |
| 15 | +def test_basic_equality(): |
| 16 | + assert QuadPrecision("12") == QuadPrecision( |
| 17 | + "12.0") == QuadPrecision("12.00") |
| 18 | + |
| 19 | + |
| 20 | +@pytest.mark.parametrize("val", ["123532.543", "12893283.5"]) |
| 21 | +def test_scalar_repr(val): |
| 22 | + expected = f"QuadPrecision('{str(QuadPrecision(val))}')" |
| 23 | + assert repr(QuadPrecision(val)) == expected |
| 24 | + |
| 25 | + |
| 26 | +@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv", "pow"]) |
| 27 | +@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"]) |
| 28 | +def test_binary_ops(op, other): |
| 29 | + op_func = getattr(operator, op) |
| 30 | + quad_a = QuadPrecision("12.5") |
| 31 | + quad_b = QuadPrecision(other) |
| 32 | + float_a = 12.5 |
| 33 | + float_b = float(other) |
| 34 | + |
| 35 | + quad_result = op_func(quad_a, quad_b) |
| 36 | + float_result = op_func(float_a, float_b) |
| 37 | + |
| 38 | + assert np.abs(np.float64(quad_result) - float_result) < 1e-10 |
| 39 | + |
| 40 | + |
| 41 | +@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"]) |
| 42 | +@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"]) |
| 43 | +def test_comparisons(op, other): |
| 44 | + op_func = getattr(operator, op) |
| 45 | + quad_a = QuadPrecision("12.5") |
| 46 | + quad_b = QuadPrecision(other) |
| 47 | + float_a = 12.5 |
| 48 | + float_b = float(other) |
| 49 | + |
| 50 | + assert op_func(quad_a, quad_b) == op_func(float_a, float_b) |
| 51 | + |
| 52 | + |
| 53 | +@pytest.mark.parametrize("op, val, expected", [ |
| 54 | + ("neg", "3.0", "-3.0"), |
| 55 | + ("neg", "-3.0", "3.0"), |
| 56 | + ("pos", "3.0", "3.0"), |
| 57 | + ("pos", "-3.0", "-3.0"), |
| 58 | + ("abs", "3.0", "3.0"), |
| 59 | + ("abs", "-3.0", "3.0"), |
| 60 | + ("neg", "12.5", "-12.5"), |
| 61 | + ("pos", "100.0", "100.0"), |
| 62 | + ("abs", "-25.5", "25.5"), |
| 63 | +]) |
| 64 | +def test_unary_ops(op, val, expected): |
| 65 | + quad_val = QuadPrecision(val) |
| 66 | + expected_val = QuadPrecision(expected) |
| 67 | + |
| 68 | + if op == "neg": |
| 69 | + result = -quad_val |
| 70 | + elif op == "pos": |
| 71 | + result = +quad_val |
| 72 | + elif op == "abs": |
| 73 | + result = abs(quad_val) |
| 74 | + else: |
| 75 | + raise ValueError(f"Unsupported operation: {op}") |
| 76 | + |
| 77 | + assert result == expected_val, f"{op}({val}) should be {expected}, but got {result}" |
| 78 | + |
| 79 | + |
| 80 | +def test_nan_and_inf(): |
| 81 | + assert (QuadPrecision("nan") != QuadPrecision("nan")) == ( |
| 82 | + QuadPrecision("nan") == QuadPrecision("nan")) |
| 83 | + assert QuadPrecision("inf") > QuadPrecision("1e1000") |
| 84 | + assert QuadPrecision("-inf") < QuadPrecision("-1e1000") |
| 85 | + |
| 86 | + |
| 87 | +def test_dtype_creation(): |
| 88 | + dtype = QuadPrecDType() |
| 89 | + assert isinstance(dtype, np.dtype) |
| 90 | + assert dtype.name == 'QuadPrecDType128' |
| 91 | + |
| 92 | + |
| 93 | +def test_array_creation(): |
| 94 | + arr = np.array([1, 2, 3], dtype=QuadPrecDType()) |
| 95 | + assert arr.dtype.name == 'QuadPrecDType128' |
| 96 | + assert all(isinstance(x, QuadPrecision) for x in arr) |
| 97 | + |
| 98 | + |
| 99 | +def test_array_operations(): |
| 100 | + arr1 = np.array( |
| 101 | + [QuadPrecision("1.5"), QuadPrecision("2.5"), QuadPrecision("3.5")]) |
| 102 | + arr2 = np.array( |
| 103 | + [QuadPrecision("0.5"), QuadPrecision("1.0"), QuadPrecision("1.5")]) |
| 104 | + |
| 105 | + result = arr1 + arr2 |
| 106 | + expected = np.array( |
| 107 | + [QuadPrecision("2.0"), QuadPrecision("3.5"), QuadPrecision("5.0")]) |
| 108 | + assert np.all(result == expected) |
0 commit comments