@@ -1067,25 +1067,26 @@ def add_coord(
1067
1067
raise ValueError (
1068
1068
f"Either `values` or `length` must be specified for the '{ name } ' dimension."
1069
1069
)
1070
- if isinstance (length , int ):
1071
- length = at .constant (length )
1072
- elif length is not None and not isinstance (length , Variable ):
1073
- raise ValueError (
1074
- f"The `length` passed for the '{ name } ' coord must be an Aesara Variable or None."
1075
- )
1076
1070
if values is not None :
1077
1071
# Conversion to a tuple ensures that the coordinate values are immutable.
1078
1072
# Also unlike numpy arrays the's tuple.index(...) which is handy to work with.
1079
1073
values = tuple (values )
1080
1074
if name in self .coords :
1081
1075
if not np .array_equal (values , self .coords [name ]):
1082
1076
raise ValueError (f"Duplicate and incompatible coordinate: { name } ." )
1083
- else :
1077
+ if length is not None and not isinstance (length , (int , Variable )):
1078
+ raise ValueError (
1079
+ f"The `length` passed for the '{ name } ' coord must be an int, Aesara Variable or None."
1080
+ )
1081
+ if length is None :
1082
+ length = len (values )
1083
+ if not isinstance (length , Variable ):
1084
1084
if mutable :
1085
- self . _dim_lengths [ name ] = length or aesara .shared (len ( values ) )
1085
+ length = aesara .shared (length )
1086
1086
else :
1087
- self ._dim_lengths [name ] = length or aesara .tensor .constant (len (values ))
1088
- self ._coords [name ] = values
1087
+ length = aesara .tensor .constant (length )
1088
+ self ._dim_lengths [name ] = length
1089
+ self ._coords [name ] = values
1089
1090
1090
1091
def add_coords (
1091
1092
self ,
@@ -1101,6 +1102,36 @@ def add_coords(
1101
1102
for name , values in coords .items ():
1102
1103
self .add_coord (name , values , length = lengths .get (name , None ))
1103
1104
1105
+ def set_dim (self , name : str , new_length : int , coord_values : Optional [Sequence ] = None ):
1106
+ """Update a mutable dimension.
1107
+
1108
+ Parameters
1109
+ ----------
1110
+ name
1111
+ Name of the dimension.
1112
+ new_length
1113
+ New length of the dimension.
1114
+ coord_values
1115
+ Optional sequence of coordinate values.
1116
+ """
1117
+ if not isinstance (self .dim_lengths [name ], ScalarSharedVariable ):
1118
+ raise ValueError (f"The dimension '{ name } ' is immutable." )
1119
+ if coord_values is None and self .coords .get (name , None ) is not None :
1120
+ raise ValueError (
1121
+ f"'{ name } ' has coord values. Pass `set_dim(..., coord_values=...)` to update them."
1122
+ )
1123
+ if coord_values is not None :
1124
+ len_cvals = len (coord_values )
1125
+ if len_cvals != new_length :
1126
+ raise ShapeError (
1127
+ f"Length of new coordinate values does not match the new dimension length." ,
1128
+ actual = len_cvals ,
1129
+ expected = new_length ,
1130
+ )
1131
+ self ._coords [name ] = tuple (coord_values )
1132
+ self .dim_lengths [name ].set_value (new_length )
1133
+ return
1134
+
1104
1135
def set_data (
1105
1136
self ,
1106
1137
name : str ,
0 commit comments