@@ -279,8 +279,6 @@ class ScalarType(CType, HasDataType, HasShape):
     Analogous to TensorType, but for zero-dimensional objects.
     Maps directly to C primitives.
 
-    TODO: refactor to be named ScalarType for consistency with TensorType.
-
     """
 
     __props__ = ("dtype",)
@@ -350,11 +348,14 @@ def c_element_type(self):
         return self.dtype_specs()[1]
 
     def c_headers(self, c_compiler=None, **kwargs):
-        l = ["<math.h>"]
-        # These includes are needed by ScalarType and TensorType,
-        # we declare them here and they will be re-used by TensorType
-        l.append("<numpy/arrayobject.h>")
-        l.append("<numpy/arrayscalars.h>")
+        l = [
+            "<math.h>",
+            # These includes are needed by ScalarType and TensorType,
+            # we declare them here and they will be re-used by TensorType
+            "<numpy/arrayobject.h>",
+            "<numpy/arrayscalars.h>",
+            "<numpy/npy_2_complexcompat.h>",
+        ]
         if config.lib__amblibm and c_compiler.supports_amdlibm:
             l += ["<amdlibm.h>"]
         return l
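
Note on the include added above: on NumPy 2.x, npy_complex64 and npy_complex128 are typedefs for native C complex types rather than structs with public .real/.imag members, so component writes go through setter macros; numpy/npy_2_complexcompat.h defines those macros (NPY_CSETREALF, NPY_CSETIMAGF, NPY_CSETREAL, NPY_CSETIMAG) as plain member assignments on NumPy 1.x as well, so the same generated C compiles against either major version. A minimal C sketch of the intended usage, not part of this diff; it assumes the translation unit is compiled with the Python and NumPy headers on the include path, as PyTensor's generated modules are, and make_complex64 / complex64_abs2 are illustrative helper names only:

    /* Hypothetical helpers, for illustration only. */
    #include <Python.h>
    #include <numpy/npy_math.h>             /* npy_complex64, npy_crealf, npy_cimagf */
    #include <numpy/npy_2_complexcompat.h>  /* NPY_CSETREALF, NPY_CSETIMAGF, ...     */

    static npy_complex64 make_complex64(float re, float im)
    {
        npy_complex64 z;
        NPY_CSETREALF(&z, re);  /* portable write of the real part      */
        NPY_CSETIMAGF(&z, im);  /* portable write of the imaginary part */
        return z;
    }

    static float complex64_abs2(npy_complex64 z)
    {
        /* reads go through the npy_math getters, available on 1.x and 2.x */
        float re = npy_crealf(z);
        float im = npy_cimagf(z);
        return re * re + im * im;
    }
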
@@ -396,8 +397,8 @@ def dtype_specs(self):
                 "float16": (np.float16, "npy_float16", "Float16"),
                 "float32": (np.float32, "npy_float32", "Float32"),
                 "float64": (np.float64, "npy_float64", "Float64"),
-                "complex128": (np.complex128, "pytensor_complex128", "Complex128"),
-                "complex64": (np.complex64, "pytensor_complex64", "Complex64"),
+                "complex128": (np.complex128, "npy_complex128", "Complex128"),
+                "complex64": (np.complex64, "npy_complex64", "Complex64"),
                 "bool": (np.bool_, "npy_bool", "Bool"),
                 "uint8": (np.uint8, "npy_uint8", "UInt8"),
                 "int8": (np.int8, "npy_int8", "Int8"),
@@ -506,171 +507,11 @@ def c_sync(self, name, sub):
     def c_cleanup(self, name, sub):
         return ""
 
-    def c_support_code(self, **kwargs):
-        if self.dtype.startswith("complex"):
-            cplx_types = ["pytensor_complex64", "pytensor_complex128"]
-            real_types = [
-                "npy_int8",
-                "npy_int16",
-                "npy_int32",
-                "npy_int64",
-                "npy_float32",
-                "npy_float64",
-            ]
-            # If the 'int' C type is not exactly the same as an existing
-            # 'npy_intX', some C code may not compile, e.g. when assigning
-            # the value 0 (cast to 'int' in C) to an PyTensor_complex64.
-            if np.dtype("intc").num not in [np.dtype(d[4:]).num for d in real_types]:
-                # In that case we add the 'int' type to the real types.
-                real_types.append("int")
-
-            template = """
-            struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s
-            {
-                typedef pytensor_complex%(nbits)s complex_type;
-                typedef npy_float%(half_nbits)s scalar_type;
-
-                complex_type operator +(const complex_type &y) const {
-                    complex_type ret;
-                    ret.real = this->real + y.real;
-                    ret.imag = this->imag + y.imag;
-                    return ret;
-                }
-
-                complex_type operator -() const {
-                    complex_type ret;
-                    ret.real = -this->real;
-                    ret.imag = -this->imag;
-                    return ret;
-                }
-                bool operator ==(const complex_type &y) const {
-                    return (this->real == y.real) && (this->imag == y.imag);
-                }
-                bool operator ==(const scalar_type &y) const {
-                    return (this->real == y) && (this->imag == 0);
-                }
-                complex_type operator -(const complex_type &y) const {
-                    complex_type ret;
-                    ret.real = this->real - y.real;
-                    ret.imag = this->imag - y.imag;
-                    return ret;
-                }
-                complex_type operator *(const complex_type &y) const {
-                    complex_type ret;
-                    ret.real = this->real * y.real - this->imag * y.imag;
-                    ret.imag = this->real * y.imag + this->imag * y.real;
-                    return ret;
-                }
-                complex_type operator /(const complex_type &y) const {
-                    complex_type ret;
-                    scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
-                    ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
-                    ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
-                    return ret;
-                }
-                template <typename T>
-                complex_type& operator =(const T& y);
-
-                pytensor_complex%(nbits)s() {}
-
-                template <typename T>
-                pytensor_complex%(nbits)s(const T& y) { *this = y; }
-
-                template <typename TR, typename TI>
-                pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
-            };
-            """
-
-            def operator_eq_real(mytype, othertype):
-                return f"""
-                template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
-                {{ this->real=y; this->imag=0; return *this; }}
-                """
-
-            def operator_eq_cplx(mytype, othertype):
-                return f"""
-                template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y)
-                {{ this->real=y.real; this->imag=y.imag; return *this; }}
-                """
-
-            operator_eq = "".join(
-                operator_eq_real(ctype, rtype)
-                for ctype in cplx_types
-                for rtype in real_types
-            ) + "".join(
-                operator_eq_cplx(ctype1, ctype2)
-                for ctype1 in cplx_types
-                for ctype2 in cplx_types
-            )
-
-            # We are not using C++ generic templating here, because this would
-            # generate two different functions for adding a complex64 and a
-            # complex128, one returning a complex64, the other a complex128,
-            # and the compiler complains it is ambiguous.
-            # Instead, we generate code for known and safe types only.
-
-            def operator_plus_real(mytype, othertype):
-                return f"""
-                const {mytype} operator+(const {mytype} &x, const {othertype} &y)
-                {{ return {mytype}(x.real+y, x.imag); }}
-
-                const {mytype} operator+(const {othertype} &y, const {mytype} &x)
-                {{ return {mytype}(x.real+y, x.imag); }}
-                """
-
-            operator_plus = "".join(
-                operator_plus_real(ctype, rtype)
-                for ctype in cplx_types
-                for rtype in real_types
-            )
-
-            def operator_minus_real(mytype, othertype):
-                return f"""
-                const {mytype} operator-(const {mytype} &x, const {othertype} &y)
-                {{ return {mytype}(x.real-y, x.imag); }}
-
-                const {mytype} operator-(const {othertype} &y, const {mytype} &x)
-                {{ return {mytype}(y-x.real, -x.imag); }}
-                """
-
-            operator_minus = "".join(
-                operator_minus_real(ctype, rtype)
-                for ctype in cplx_types
-                for rtype in real_types
-            )
-
-            def operator_mul_real(mytype, othertype):
-                return f"""
-                const {mytype} operator*(const {mytype} &x, const {othertype} &y)
-                {{ return {mytype}(x.real*y, x.imag*y); }}
-
-                const {mytype} operator*(const {othertype} &y, const {mytype} &x)
-                {{ return {mytype}(x.real*y, x.imag*y); }}
-                """
-
-            operator_mul = "".join(
-                operator_mul_real(ctype, rtype)
-                for ctype in cplx_types
-                for rtype in real_types
-            )
-
-            return (
-                template % dict(nbits=64, half_nbits=32)
-                + template % dict(nbits=128, half_nbits=64)
-                + operator_eq
-                + operator_plus
-                + operator_minus
-                + operator_mul
-            )
-
-        else:
-            return ""
-
     def c_init_code(self, **kwargs):
         return ["import_array();"]
 
     def c_code_cache_version(self):
-        return (13, np.__version__)
+        return (14, np.__version__)
 
     def get_shape_info(self, obj):
         return obj.itemsize
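
One practical consequence of removing c_support_code above: the pytensor_complex64/pytensor_complex128 wrapper structs gave generated C++ overloaded arithmetic, comparison, and assignment operators for complex scalars, so emitted code could simply write z = x + y. With plain npy_complex64/npy_complex128, those conveniences are gone and complex arithmetic has to be spelled out component-wise using the npy_math getters and the compat setters introduced earlier. A hedged sketch of what that looks like, not taken from this PR; complex64_add is an illustrative helper name only:

    #include <Python.h>
    #include <numpy/npy_math.h>
    #include <numpy/npy_2_complexcompat.h>

    /* Component-wise equivalent of the removed operator+(complex_type, complex_type),
       written against plain npy_complex64 values. */
    static npy_complex64 complex64_add(npy_complex64 x, npy_complex64 y)
    {
        npy_complex64 z;
        NPY_CSETREALF(&z, npy_crealf(x) + npy_crealf(y));
        NPY_CSETIMAGF(&z, npy_cimagf(x) + npy_cimagf(y));
        return z;
    }
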