Skip to content

Commit 7d34d33

Browse files
Make complex scalars work with numpy 2.0
This is done using C++ generic functions to get/set the real/imag parts of complex numbers. This gives us an easy way to support Numpy v < 2.0, and allows the type underlying the bit width types, like pytensor_complex128, to be correctly inferred from the numpy complex types they inherit from. Updated pytensor_complex struct to use get/set real/imag aliases defined above. Also updated operators such as `Abs` to use get_real, get_imag. Macros have been added to ensure compatibility with numpy < 2.0 Note: redefining the complex arithmetic here means that we aren't treating NaNs and infinities as carefully as the C99 standard suggets (see Appendix G of the standard). The code has been like this since it was added to Theano, so we're keeping the existing behavior.
1 parent c2171fb commit 7d34d33

File tree

1 file changed

+161
-64
lines changed

1 file changed

+161
-64
lines changed

pytensor/scalar/basic.py

Lines changed: 161 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def c_headers(self, c_compiler=None, **kwargs):
349349
# we declare them here and they will be re-used by TensorType
350350
l.append("<numpy/arrayobject.h>")
351351
l.append("<numpy/arrayscalars.h>")
352+
l.append("<numpy/npy_math.h>")
353+
352354
if config.lib__amdlibm and c_compiler.supports_amdlibm:
353355
l += ["<amdlibm.h>"]
354356
return l
@@ -517,73 +519,167 @@ def c_support_code(self, **kwargs):
517519
# In that case we add the 'int' type to the real types.
518520
real_types.append("int")
519521

522+
# Macros for backwards compatibility with numpy < 2.0
523+
#
524+
# In numpy 2.0+, these are defined in npy_math.h, but
525+
# for early versions, they must be vendored by users (e.g. PyTensor)
526+
backwards_compat_macros = """
527+
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
528+
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
529+
530+
#include <numpy/npy_math.h>
531+
532+
#ifndef NPY_CSETREALF
533+
#define NPY_CSETREALF(c, r) (c)->real = (r)
534+
#endif
535+
#ifndef NPY_CSETIMAGF
536+
#define NPY_CSETIMAGF(c, i) (c)->imag = (i)
537+
#endif
538+
#ifndef NPY_CSETREAL
539+
#define NPY_CSETREAL(c, r) (c)->real = (r)
540+
#endif
541+
#ifndef NPY_CSETIMAG
542+
#define NPY_CSETIMAG(c, i) (c)->imag = (i)
543+
#endif
544+
#ifndef NPY_CSETREALL
545+
#define NPY_CSETREALL(c, r) (c)->real = (r)
546+
#endif
547+
#ifndef NPY_CSETIMAGL
548+
#define NPY_CSETIMAGL(c, i) (c)->imag = (i)
549+
#endif
550+
551+
#endif
552+
"""
553+
554+
def _make_get_set_real_imag(scalar_type: str) -> str:
555+
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
556+
557+
The functions called by these getter/setter functions are defining in npy_math.h, or
558+
in the `backward_compat_macros` defined above.
559+
560+
Args:
561+
scalar_type: float, double, or longdouble
562+
563+
Returns:
564+
C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
565+
given type.
566+
"""
567+
complex_type = "npy_c" + scalar_type
568+
suffix = "" if scalar_type == "double" else scalar_type[0]
569+
570+
if scalar_type == "longdouble":
571+
scalar_type = "npy_" + scalar_type
572+
573+
return_type = scalar_type
574+
575+
template = f"""
576+
static inline {return_type} get_real(const {complex_type} z)
577+
{{
578+
return npy_creal{suffix}(z);
579+
}}
580+
581+
static inline void set_real({complex_type} *z, const {scalar_type} r)
582+
{{
583+
NPY_CSETREAL{suffix.upper()}(z, r);
584+
}}
585+
586+
static inline {return_type} get_imag(const {complex_type} z)
587+
{{
588+
return npy_cimag{suffix}(z);
589+
}}
590+
591+
static inline void set_imag({complex_type} *z, const {scalar_type} i)
592+
{{
593+
NPY_CSETIMAG{suffix.upper()}(z, i);
594+
}}
595+
"""
596+
return template
597+
598+
get_set_aliases = "\n".join(
599+
_make_get_set_real_imag(stype)
600+
for stype in ["float", "double", "longdouble"]
601+
)
602+
603+
get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases
604+
605+
# Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
606+
#
607+
# The npy_complex64, npy_complex128 types are aliases defined at run time based on
608+
# the size of floats and doubles on the machine. This means that both types are
609+
# not necessarily defined on every machine, but a machine with 32-bit floats and
610+
# 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
611+
# as an alias of npy_complex128.
612+
#
613+
# In any case, the get/set real/imag functions defined above will always work for
614+
# npy_complex64 and npy_complex128.
520615
template = """
521-
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
522-
{
523-
typedef pytensor_complex%(nbits)s complex_type;
524-
typedef npy_float%(half_nbits)s scalar_type;
525-
526-
complex_type operator +(const complex_type &y) const {
527-
complex_type ret;
528-
ret.real = this->real + y.real;
529-
ret.imag = this->imag + y.imag;
530-
return ret;
531-
}
532-
533-
complex_type operator -() const {
534-
complex_type ret;
535-
ret.real = -this->real;
536-
ret.imag = -this->imag;
537-
return ret;
538-
}
539-
bool operator ==(const complex_type &y) const {
540-
return (this->real == y.real) && (this->imag == y.imag);
541-
}
542-
bool operator ==(const scalar_type &y) const {
543-
return (this->real == y) && (this->imag == 0);
544-
}
545-
complex_type operator -(const complex_type &y) const {
546-
complex_type ret;
547-
ret.real = this->real - y.real;
548-
ret.imag = this->imag - y.imag;
549-
return ret;
550-
}
551-
complex_type operator *(const complex_type &y) const {
552-
complex_type ret;
553-
ret.real = this->real * y.real - this->imag * y.imag;
554-
ret.imag = this->real * y.imag + this->imag * y.real;
555-
return ret;
556-
}
557-
complex_type operator /(const complex_type &y) const {
558-
complex_type ret;
559-
scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
560-
ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
561-
ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
562-
return ret;
563-
}
564-
template <typename T>
565-
complex_type& operator =(const T& y);
566-
567-
pytensor_complex%(nbits)s() {}
568-
569-
template <typename T>
570-
pytensor_complex%(nbits)s(const T& y) { *this = y; }
571-
572-
template <typename TR, typename TI>
573-
pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
616+
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
617+
typedef pytensor_complex%(nbits)s complex_type;
618+
typedef npy_float%(half_nbits)s scalar_type;
619+
620+
complex_type operator+(const complex_type &y) const {
621+
complex_type ret;
622+
set_real(&ret, get_real(*this) + get_real(y));
623+
set_imag(&ret, get_imag(*this) + get_imag(y));
624+
return ret;
625+
}
626+
627+
complex_type operator-() const {
628+
complex_type ret;
629+
set_real(&ret, -get_real(*this));
630+
set_imag(&ret, -get_imag(*this));
631+
return ret;
632+
}
633+
bool operator==(const complex_type &y) const {
634+
return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
635+
}
636+
bool operator==(const scalar_type &y) const {
637+
return (get_real(*this) == y) && (get_real(*this) == 0);
638+
}
639+
complex_type operator-(const complex_type &y) const {
640+
complex_type ret;
641+
set_real(&ret, get_real(*this) - get_real(y));
642+
set_imag(&ret, get_imag(*this) - get_imag(y));
643+
return ret;
644+
}
645+
complex_type operator*(const complex_type &y) const {
646+
complex_type ret;
647+
set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
648+
set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
649+
return ret;
650+
}
651+
complex_type operator/(const complex_type &y) const {
652+
complex_type ret;
653+
scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
654+
set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
655+
set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
656+
return ret;
657+
}
658+
template <typename T> complex_type &operator=(const T &y);
659+
660+
661+
pytensor_complex%(nbits)s() {}
662+
663+
template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
664+
665+
template <typename TR, typename TI>
666+
pytensor_complex%(nbits)s(const TR &r, const TI &i) {
667+
set_real(this, r);
668+
set_imag(this, i);
669+
}
574670
};
575671
"""
576672

577673
def operator_eq_real(mytype, othertype):
578674
return f"""
579675
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
580-
{{ this->real=y; this->imag=0; return *this; }}
676+
{{ set_real(this, y); set_imag(this, 0); return *this; }}
581677
"""
582678

583679
def operator_eq_cplx(mytype, othertype):
584680
return f"""
585681
template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
586-
{{ this->real=y.real; this->imag=y.imag; return *this; }}
682+
{{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }}
587683
"""
588684

589685
operator_eq = "".join(
@@ -605,10 +701,10 @@ def operator_eq_cplx(mytype, othertype):
605701
def operator_plus_real(mytype, othertype):
606702
return f"""
607703
const {mytype} operator+(const {mytype} &x, const {othertype} &y)
608-
{{ return {mytype}(x.real+y, x.imag); }}
704+
{{ return {mytype}(get_real(x) + y, get_imag(x)); }}
609705
610706
const {mytype} operator+(const {othertype} &y, const {mytype} &x)
611-
{{ return {mytype}(x.real+y, x.imag); }}
707+
{{ return {mytype}(get_real(x) + y, get_imag(x)); }}
612708
"""
613709

614710
operator_plus = "".join(
@@ -620,10 +716,10 @@ def operator_plus_real(mytype, othertype):
620716
def operator_minus_real(mytype, othertype):
621717
return f"""
622718
const {mytype} operator-(const {mytype} &x, const {othertype} &y)
623-
{{ return {mytype}(x.real-y, x.imag); }}
719+
{{ return {mytype}(get_real(x) - y, get_imag(x)); }}
624720
625721
const {mytype} operator-(const {othertype} &y, const {mytype} &x)
626-
{{ return {mytype}(y-x.real, -x.imag); }}
722+
{{ return {mytype}(y - get_real(x), -get_imag(x)); }}
627723
"""
628724

629725
operator_minus = "".join(
@@ -635,10 +731,10 @@ def operator_minus_real(mytype, othertype):
635731
def operator_mul_real(mytype, othertype):
636732
return f"""
637733
const {mytype} operator*(const {mytype} &x, const {othertype} &y)
638-
{{ return {mytype}(x.real*y, x.imag*y); }}
734+
{{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
639735
640736
const {mytype} operator*(const {othertype} &y, const {mytype} &x)
641-
{{ return {mytype}(x.real*y, x.imag*y); }}
737+
{{ return {mytype}(get_real(x) * y, get_imag(x) * y); }}
642738
"""
643739

644740
operator_mul = "".join(
@@ -648,7 +744,8 @@ def operator_mul_real(mytype, othertype):
648744
)
649745

650746
return (
651-
template % dict(nbits=64, half_nbits=32)
747+
get_set_aliases
748+
+ template % dict(nbits=64, half_nbits=32)
652749
+ template % dict(nbits=128, half_nbits=64)
653750
+ operator_eq
654751
+ operator_plus
@@ -663,7 +760,7 @@ def c_init_code(self, **kwargs):
663760
return ["import_array();"]
664761

665762
def c_code_cache_version(self):
666-
return (13, np.__version__)
763+
return (14, np.__version__)
667764

668765
def get_shape_info(self, obj):
669766
return obj.itemsize
@@ -2567,7 +2664,7 @@ def c_code(self, node, name, inputs, outputs, sub):
25672664
if type in float_types:
25682665
return f"{z} = fabs({x});"
25692666
if type in complex_types:
2570-
return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);"
2667+
return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));"
25712668
if node.outputs[0].type == bool:
25722669
return f"{z} = ({x}) ? 1 : 0;"
25732670
if type in uint_types:

0 commit comments

Comments
 (0)