@@ -349,6 +349,8 @@ def c_headers(self, c_compiler=None, **kwargs):
349
349
# we declare them here and they will be re-used by TensorType
350
350
l .append ("<numpy/arrayobject.h>" )
351
351
l .append ("<numpy/arrayscalars.h>" )
352
+ l .append ("<numpy/npy_math.h>" )
353
+
352
354
if config .lib__amdlibm and c_compiler .supports_amdlibm :
353
355
l += ["<amdlibm.h>" ]
354
356
return l
@@ -517,73 +519,167 @@ def c_support_code(self, **kwargs):
517
519
# In that case we add the 'int' type to the real types.
518
520
real_types .append ("int" )
519
521
522
+ # Macros for backwards compatibility with numpy < 2.0
523
+ #
524
+ # In numpy 2.0+, these are defined in npy_math.h, but
525
+ # for early versions, they must be vendored by users (e.g. PyTensor)
526
+ backwards_compat_macros = """
527
+ #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
528
+ #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
529
+
530
+ #include <numpy/npy_math.h>
531
+
532
+ #ifndef NPY_CSETREALF
533
+ #define NPY_CSETREALF(c, r) (c)->real = (r)
534
+ #endif
535
+ #ifndef NPY_CSETIMAGF
536
+ #define NPY_CSETIMAGF(c, i) (c)->imag = (i)
537
+ #endif
538
+ #ifndef NPY_CSETREAL
539
+ #define NPY_CSETREAL(c, r) (c)->real = (r)
540
+ #endif
541
+ #ifndef NPY_CSETIMAG
542
+ #define NPY_CSETIMAG(c, i) (c)->imag = (i)
543
+ #endif
544
+ #ifndef NPY_CSETREALL
545
+ #define NPY_CSETREALL(c, r) (c)->real = (r)
546
+ #endif
547
+ #ifndef NPY_CSETIMAGL
548
+ #define NPY_CSETIMAGL(c, i) (c)->imag = (i)
549
+ #endif
550
+
551
+ #endif
552
+ """
553
+
554
+ def _make_get_set_real_imag (scalar_type : str ) -> str :
555
+ """Make overloaded getter/setter functions for real/imag parts of numpy complex types.
556
+
557
+ The functions called by these getter/setter functions are defining in npy_math.h, or
558
+ in the `backward_compat_macros` defined above.
559
+
560
+ Args:
561
+ scalar_type: float, double, or longdouble
562
+
563
+ Returns:
564
+ C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the
565
+ given type.
566
+ """
567
+ complex_type = "npy_c" + scalar_type
568
+ suffix = "" if scalar_type == "double" else scalar_type [0 ]
569
+
570
+ if scalar_type == "longdouble" :
571
+ scalar_type = "npy_" + scalar_type
572
+
573
+ return_type = scalar_type
574
+
575
+ template = f"""
576
+ static inline { return_type } get_real(const { complex_type } z)
577
+ {{
578
+ return npy_creal{ suffix } (z);
579
+ }}
580
+
581
+ static inline void set_real({ complex_type } *z, const { scalar_type } r)
582
+ {{
583
+ NPY_CSETREAL{ suffix .upper ()} (z, r);
584
+ }}
585
+
586
+ static inline { return_type } get_imag(const { complex_type } z)
587
+ {{
588
+ return npy_cimag{ suffix } (z);
589
+ }}
590
+
591
+ static inline void set_imag({ complex_type } *z, const { scalar_type } i)
592
+ {{
593
+ NPY_CSETIMAG{ suffix .upper ()} (z, i);
594
+ }}
595
+ """
596
+ return template
597
+
598
+ get_set_aliases = "\n " .join (
599
+ _make_get_set_real_imag (stype )
600
+ for stype in ["float" , "double" , "longdouble" ]
601
+ )
602
+
603
+ get_set_aliases = backwards_compat_macros + "\n " + get_set_aliases
604
+
605
+ # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
606
+ #
607
+ # The npy_complex64, npy_complex128 types are aliases defined at run time based on
608
+ # the size of floats and doubles on the machine. This means that both types are
609
+ # not necessarily defined on every machine, but a machine with 32-bit floats and
610
+ # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
611
+ # as an alias of npy_complex128.
612
+ #
613
+ # In any case, the get/set real/imag functions defined above will always work for
614
+ # npy_complex64 and npy_complex128.
520
615
template = """
521
- struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
522
- {
523
- typedef pytensor_complex%(nbits)s complex_type;
524
- typedef npy_float%(half_nbits)s scalar_type;
525
-
526
- complex_type operator +(const complex_type &y) const {
527
- complex_type ret;
528
- ret.real = this->real + y.real;
529
- ret.imag = this->imag + y.imag;
530
- return ret;
531
- }
532
-
533
- complex_type operator -() const {
534
- complex_type ret;
535
- ret.real = -this->real;
536
- ret.imag = -this->imag;
537
- return ret;
538
- }
539
- bool operator ==(const complex_type &y) const {
540
- return (this->real == y.real) && (this->imag == y.imag);
541
- }
542
- bool operator ==(const scalar_type &y) const {
543
- return (this->real == y) && (this->imag == 0);
544
- }
545
- complex_type operator -(const complex_type &y) const {
546
- complex_type ret;
547
- ret.real = this->real - y.real;
548
- ret.imag = this->imag - y.imag;
549
- return ret;
550
- }
551
- complex_type operator *(const complex_type &y) const {
552
- complex_type ret;
553
- ret.real = this->real * y.real - this->imag * y.imag;
554
- ret.imag = this->real * y.imag + this->imag * y.real;
555
- return ret;
556
- }
557
- complex_type operator /(const complex_type &y) const {
558
- complex_type ret;
559
- scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
560
- ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
561
- ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
562
- return ret;
563
- }
564
- template <typename T>
565
- complex_type& operator =(const T& y);
566
-
567
- pytensor_complex%(nbits)s() {}
568
-
569
- template <typename T>
570
- pytensor_complex%(nbits)s(const T& y) { *this = y; }
571
-
572
- template <typename TR, typename TI>
573
- pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
616
+ struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
617
+ typedef pytensor_complex%(nbits)s complex_type;
618
+ typedef npy_float%(half_nbits)s scalar_type;
619
+
620
+ complex_type operator+(const complex_type &y) const {
621
+ complex_type ret;
622
+ set_real(&ret, get_real(*this) + get_real(y));
623
+ set_imag(&ret, get_imag(*this) + get_imag(y));
624
+ return ret;
625
+ }
626
+
627
+ complex_type operator-() const {
628
+ complex_type ret;
629
+ set_real(&ret, -get_real(*this));
630
+ set_imag(&ret, -get_imag(*this));
631
+ return ret;
632
+ }
633
+ bool operator==(const complex_type &y) const {
634
+ return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y));
635
+ }
636
+ bool operator==(const scalar_type &y) const {
637
+ return (get_real(*this) == y) && (get_real(*this) == 0);
638
+ }
639
+ complex_type operator-(const complex_type &y) const {
640
+ complex_type ret;
641
+ set_real(&ret, get_real(*this) - get_real(y));
642
+ set_imag(&ret, get_imag(*this) - get_imag(y));
643
+ return ret;
644
+ }
645
+ complex_type operator*(const complex_type &y) const {
646
+ complex_type ret;
647
+ set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y));
648
+ set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y));
649
+ return ret;
650
+ }
651
+ complex_type operator/(const complex_type &y) const {
652
+ complex_type ret;
653
+ scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y);
654
+ set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square);
655
+ set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square);
656
+ return ret;
657
+ }
658
+ template <typename T> complex_type &operator=(const T &y);
659
+
660
+
661
+ pytensor_complex%(nbits)s() {}
662
+
663
+ template <typename T> pytensor_complex%(nbits)s(const T &y) { *this = y; }
664
+
665
+ template <typename TR, typename TI>
666
+ pytensor_complex%(nbits)s(const TR &r, const TI &i) {
667
+ set_real(this, r);
668
+ set_imag(this, i);
669
+ }
574
670
};
575
671
"""
576
672
577
673
def operator_eq_real (mytype , othertype ):
578
674
return f"""
579
675
template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
580
- {{ this->real=y; this->imag=0 ; return *this; }}
676
+ {{ set_real( this, y); set_imag( this, 0) ; return *this; }}
581
677
"""
582
678
583
679
def operator_eq_cplx (mytype , othertype ):
584
680
return f"""
585
681
template <> { mytype } & { mytype } ::operator=<{ othertype } >(const { othertype } & y)
586
- {{ this->real=y.real; this->imag=y.imag ; return *this; }}
682
+ {{ set_real( this, get_real(y)); set_imag( this, get_imag(y)) ; return *this; }}
587
683
"""
588
684
589
685
operator_eq = "" .join (
@@ -605,10 +701,10 @@ def operator_eq_cplx(mytype, othertype):
605
701
def operator_plus_real (mytype , othertype ):
606
702
return f"""
607
703
const { mytype } operator+(const { mytype } &x, const { othertype } &y)
608
- {{ return { mytype } (x.real+ y, x.imag ); }}
704
+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
609
705
610
706
const { mytype } operator+(const { othertype } &y, const { mytype } &x)
611
- {{ return { mytype } (x.real+ y, x.imag ); }}
707
+ {{ return { mytype } (get_real(x) + y, get_imag(x) ); }}
612
708
"""
613
709
614
710
operator_plus = "" .join (
@@ -620,10 +716,10 @@ def operator_plus_real(mytype, othertype):
620
716
def operator_minus_real (mytype , othertype ):
621
717
return f"""
622
718
const { mytype } operator-(const { mytype } &x, const { othertype } &y)
623
- {{ return { mytype } (x.real- y, x.imag ); }}
719
+ {{ return { mytype } (get_real(x) - y, get_imag(x) ); }}
624
720
625
721
const { mytype } operator-(const { othertype } &y, const { mytype } &x)
626
- {{ return { mytype } (y-x.real , -x.imag ); }}
722
+ {{ return { mytype } (y - get_real(x) , -get_imag(x) ); }}
627
723
"""
628
724
629
725
operator_minus = "" .join (
@@ -635,10 +731,10 @@ def operator_minus_real(mytype, othertype):
635
731
def operator_mul_real (mytype , othertype ):
636
732
return f"""
637
733
const { mytype } operator*(const { mytype } &x, const { othertype } &y)
638
- {{ return { mytype } (x.real* y, x.imag* y); }}
734
+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
639
735
640
736
const { mytype } operator*(const { othertype } &y, const { mytype } &x)
641
- {{ return { mytype } (x.real* y, x.imag* y); }}
737
+ {{ return { mytype } (get_real(x) * y, get_imag(x) * y); }}
642
738
"""
643
739
644
740
operator_mul = "" .join (
@@ -648,7 +744,8 @@ def operator_mul_real(mytype, othertype):
648
744
)
649
745
650
746
return (
651
- template % dict (nbits = 64 , half_nbits = 32 )
747
+ get_set_aliases
748
+ + template % dict (nbits = 64 , half_nbits = 32 )
652
749
+ template % dict (nbits = 128 , half_nbits = 64 )
653
750
+ operator_eq
654
751
+ operator_plus
@@ -663,7 +760,7 @@ def c_init_code(self, **kwargs):
663
760
return ["import_array();" ]
664
761
665
762
def c_code_cache_version (self ):
666
- return (13 , np .__version__ )
763
+ return (14 , np .__version__ )
667
764
668
765
def get_shape_info (self , obj ):
669
766
return obj .itemsize
@@ -2567,7 +2664,7 @@ def c_code(self, node, name, inputs, outputs, sub):
2567
2664
if type in float_types :
2568
2665
return f"{ z } = fabs({ x } );"
2569
2666
if type in complex_types :
2570
- return f"{ z } = sqrt({ x } .real* { x } .real + { x } .imag* { x } .imag );"
2667
+ return f"{ z } = sqrt(get_real( { x } ) * get_real( { x } ) + get_imag( { x } ) * get_imag( { x } ) );"
2571
2668
if node .outputs [0 ].type == bool :
2572
2669
return f"{ z } = ({ x } ) ? 1 : 0;"
2573
2670
if type in uint_types :
0 commit comments