Skip to content

Commit 2b58be2

Browse files
committed
Remove custom Complex type
1 parent 5e69937 commit 2b58be2

File tree

4 files changed

+15
-176
lines changed

4 files changed

+15
-176
lines changed

pytensor/scalar/basic.py

Lines changed: 11 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,6 @@ class ScalarType(CType, HasDataType, HasShape):
279279
Analogous to TensorType, but for zero-dimensional objects.
280280
Maps directly to C primitives.
281281
282-
TODO: refactor to be named ScalarType for consistency with TensorType.
283-
284282
"""
285283

286284
__props__ = ("dtype",)
@@ -350,11 +348,14 @@ def c_element_type(self):
350348
return self.dtype_specs()[1]
351349

352350
def c_headers(self, c_compiler=None, **kwargs):
353-
l = ["<math.h>"]
354-
# These includes are needed by ScalarType and TensorType,
355-
# we declare them here and they will be re-used by TensorType
356-
l.append("<numpy/arrayobject.h>")
357-
l.append("<numpy/arrayscalars.h>")
351+
l = [
352+
"<math.h>",
353+
# These includes are needed by ScalarType and TensorType,
354+
# we declare them here and they will be re-used by TensorType
355+
"<numpy/arrayobject.h>",
356+
"<numpy/arrayscalars.h>",
357+
"<numpy/npy_2_complexcompat.h>",
358+
]
358359
if config.lib__amdlibm and c_compiler.supports_amdlibm:
359360
l += ["<amdlibm.h>"]
360361
return l
@@ -396,8 +397,8 @@ def dtype_specs(self):
396397
"float16": (np.float16, "npy_float16", "Float16"),
397398
"float32": (np.float32, "npy_float32", "Float32"),
398399
"float64": (np.float64, "npy_float64", "Float64"),
399-
"complex128": (np.complex128, "pytensor_complex128", "Complex128"),
400-
"complex64": (np.complex64, "pytensor_complex64", "Complex64"),
400+
"complex128": (np.complex128, "npy_complex128", "Complex128"),
401+
"complex64": (np.complex64, "npy_complex64", "Complex64"),
401402
"bool": (np.bool_, "npy_bool", "Bool"),
402403
"uint8": (np.uint8, "npy_uint8", "UInt8"),
403404
"int8": (np.int8, "npy_int8", "Int8"),
@@ -506,171 +507,11 @@ def c_sync(self, name, sub):
506507
def c_cleanup(self, name, sub):
507508
return ""
508509

509-
def c_support_code(self, **kwargs):
510-
if self.dtype.startswith("complex"):
511-
cplx_types = ["pytensor_complex64", "pytensor_complex128"]
512-
real_types = [
513-
"npy_int8",
514-
"npy_int16",
515-
"npy_int32",
516-
"npy_int64",
517-
"npy_float32",
518-
"npy_float64",
519-
]
520-
# If the 'int' C type is not exactly the same as an existing
521-
# 'npy_intX', some C code may not compile, e.g. when assigning
522-
# the value 0 (cast to 'int' in C) to a PyTensor_complex64.
523-
if np.dtype("intc").num not in [np.dtype(d[4:]).num for d in real_types]:
524-
# In that case we add the 'int' type to the real types.
525-
real_types.append("int")
526-
527-
template = """
528-
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
529-
{
530-
typedef pytensor_complex%(nbits)s complex_type;
531-
typedef npy_float%(half_nbits)s scalar_type;
532-
533-
complex_type operator +(const complex_type &y) const {
534-
complex_type ret;
535-
ret.real = this->real + y.real;
536-
ret.imag = this->imag + y.imag;
537-
return ret;
538-
}
539-
540-
complex_type operator -() const {
541-
complex_type ret;
542-
ret.real = -this->real;
543-
ret.imag = -this->imag;
544-
return ret;
545-
}
546-
bool operator ==(const complex_type &y) const {
547-
return (this->real == y.real) && (this->imag == y.imag);
548-
}
549-
bool operator ==(const scalar_type &y) const {
550-
return (this->real == y) && (this->imag == 0);
551-
}
552-
complex_type operator -(const complex_type &y) const {
553-
complex_type ret;
554-
ret.real = this->real - y.real;
555-
ret.imag = this->imag - y.imag;
556-
return ret;
557-
}
558-
complex_type operator *(const complex_type &y) const {
559-
complex_type ret;
560-
ret.real = this->real * y.real - this->imag * y.imag;
561-
ret.imag = this->real * y.imag + this->imag * y.real;
562-
return ret;
563-
}
564-
complex_type operator /(const complex_type &y) const {
565-
complex_type ret;
566-
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
567-
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
568-
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
569-
return ret;
570-
}
571-
template <typename T>
572-
complex_type& operator =(const T& y);
573-
574-
pytensor_complex%(nbits)s() {}
575-
576-
template <typename T>
577-
pytensor_complex%(nbits)s(const T& y) { *this = y; }
578-
579-
template <typename TR, typename TI>
580-
pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
581-
};
582-
"""
583-
584-
def operator_eq_real(mytype, othertype):
585-
return f"""
586-
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
587-
{{ this->real=y; this->imag=0; return *this; }}
588-
"""
589-
590-
def operator_eq_cplx(mytype, othertype):
591-
return f"""
592-
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
593-
{{ this->real=y.real; this->imag=y.imag; return *this; }}
594-
"""
595-
596-
operator_eq = "".join(
597-
operator_eq_real(ctype, rtype)
598-
for ctype in cplx_types
599-
for rtype in real_types
600-
) + "".join(
601-
operator_eq_cplx(ctype1, ctype2)
602-
for ctype1 in cplx_types
603-
for ctype2 in cplx_types
604-
)
605-
606-
# We are not using C++ generic templating here, because this would
607-
# generate two different functions for adding a complex64 and a
608-
# complex128, one returning a complex64, the other a complex128,
609-
# and the compiler complains it is ambiguous.
610-
# Instead, we generate code for known and safe types only.
611-
612-
def operator_plus_real(mytype, othertype):
613-
return f"""
614-
const {mytype} operator+(const {mytype} &x, const {othertype} &y)
615-
{{ return {mytype}(x.real+y, x.imag); }}
616-
617-
const {mytype} operator+(const {othertype} &y, const {mytype} &x)
618-
{{ return {mytype}(x.real+y, x.imag); }}
619-
"""
620-
621-
operator_plus = "".join(
622-
operator_plus_real(ctype, rtype)
623-
for ctype in cplx_types
624-
for rtype in real_types
625-
)
626-
627-
def operator_minus_real(mytype, othertype):
628-
return f"""
629-
const {mytype} operator-(const {mytype} &x, const {othertype} &y)
630-
{{ return {mytype}(x.real-y, x.imag); }}
631-
632-
const {mytype} operator-(const {othertype} &y, const {mytype} &x)
633-
{{ return {mytype}(y-x.real, -x.imag); }}
634-
"""
635-
636-
operator_minus = "".join(
637-
operator_minus_real(ctype, rtype)
638-
for ctype in cplx_types
639-
for rtype in real_types
640-
)
641-
642-
def operator_mul_real(mytype, othertype):
643-
return f"""
644-
const {mytype} operator*(const {mytype} &x, const {othertype} &y)
645-
{{ return {mytype}(x.real*y, x.imag*y); }}
646-
647-
const {mytype} operator*(const {othertype} &y, const {mytype} &x)
648-
{{ return {mytype}(x.real*y, x.imag*y); }}
649-
"""
650-
651-
operator_mul = "".join(
652-
operator_mul_real(ctype, rtype)
653-
for ctype in cplx_types
654-
for rtype in real_types
655-
)
656-
657-
return (
658-
template % dict(nbits=64, half_nbits=32)
659-
+ template % dict(nbits=128, half_nbits=64)
660-
+ operator_eq
661-
+ operator_plus
662-
+ operator_minus
663-
+ operator_mul
664-
)
665-
666-
else:
667-
return ""
668-
669510
def c_init_code(self, **kwargs):
670511
return ["import_array();"]
671512

672513
def c_code_cache_version(self):
673-
return (13, np.__version__)
514+
return (14, np.__version__)
674515

675516
def get_shape_info(self, obj):
676517
return obj.itemsize

pytensor/sparse/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class SparseTensorType(TensorType, HasDataType):
5959
"int32": (int, "npy_int32", "NPY_INT32"),
6060
"uint64": (int, "npy_uint64", "NPY_UINT64"),
6161
"int64": (int, "npy_int64", "NPY_INT64"),
62-
"complex128": (complex, "pytensor_complex128", "NPY_COMPLEX128"),
63-
"complex64": (complex, "pytensor_complex64", "NPY_COMPLEX64"),
62+
"complex128": (complex, "npy_complex128", "NPY_COMPLEX128"),
63+
"complex64": (complex, "npy_complex64", "NPY_COMPLEX64"),
6464
}
6565
ndim = 2
6666

pytensor/tensor/elemwise_cgen.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,6 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
166166
167167
"""
168168
type = dtype.upper()
169-
if type.startswith("PYTENSOR_COMPLEX"):
170-
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
171169
nd = len(loop_orders[0])
172170
init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
173171

pytensor/tensor/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
"int32": (int, "npy_int32", "NPY_INT32"),
5151
"uint64": (int, "npy_uint64", "NPY_UINT64"),
5252
"int64": (int, "npy_int64", "NPY_INT64"),
53-
"complex128": (complex, "pytensor_complex128", "NPY_COMPLEX128"),
54-
"complex64": (complex, "pytensor_complex64", "NPY_COMPLEX64"),
53+
"complex128": (complex, "npy_complex128", "NPY_COMPLEX128"),
54+
"complex64": (complex, "npy_complex64", "NPY_COMPLEX64"),
5555
}
5656

5757

0 commit comments

Comments
 (0)