@@ -46,7 +46,7 @@
 from aesara.tensor.random.opt import local_subtensor_rv_lift
 from aesara.tensor.random.var import RandomStateSharedVariable
 from aesara.tensor.sharedvar import ScalarSharedVariable
-from aesara.tensor.var import TensorVariable
+from aesara.tensor.var import TensorVariable, TensorConstant
 
 from pymc.aesaraf import (
     compile_pymc,
@@ -61,7 +61,7 @@
 from pymc.distributions import joint_logpt
 from pymc.distributions.logprob import _get_scaling
 from pymc.distributions.transforms import _default_transform
-from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
+from pymc.exceptions import ImputationWarning, ShapeWarning, SamplingError, ShapeError
 from pymc.initial_point import make_initial_point_fn
 from pymc.math import flatten_list
 from pymc.util import (
@@ -1179,24 +1179,49 @@ def set_data(
             # Reject resizing if we already know that it would create shape problems.
             # NOTE: If there are multiple pm.MutableData containers sharing this dim, but the user only
             # changes the values for one of them, they will run into shape problems nonetheless.
-            if original_coords is None:
-                length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
-                if not isinstance(length_belongs_to, SharedVariable) and length_changed:
+            if length_changed:
+                if isinstance(length_tensor, TensorConstant):
                     raise ShapeError(
-                        f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
-                        f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
-                        f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
-                        f"for example by a model variable.",
-                        actual=new_length,
-                        expected=old_length,
+                        f"Resizing dimension '{dname}' is impossible, because "
+                        f"a 'TensorConstant' stores its length. To be able "
+                        f"to change the dimension length, pass fixed=False "
+                        f"to 'model.add_coord'."
                     )
-            if original_coords is not None and length_changed:
-                if length_changed and new_coords is None:
-                    raise ValueError(
-                        f"The '{name}' variable already had {len(original_coords)} coord values defined for"
-                        f"its {dname} dimension. With the new values this dimension changes to length "
-                        f"{new_length}, so new coord values for the {dname} dimension are required."
+                if length_tensor.owner is None:
+                    # This is the case if the dimension was initialized from
+                    # custom coords, but its length was not stored in a
+                    # TensorConstant, e.g. because 'fixed' was set to False.
+
+                    warnings.warn(
+                        f"You're changing the shape of a shared variable "
+                        f"in the '{dname}' dimension which was initialized "
+                        f"from coords. Make sure to update the corresponding "
+                        f"coords, otherwise you'll get shape issues.",
+                        ShapeWarning,
                     )
+                else:
+                    length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
+                    if not isinstance(length_belongs_to, SharedVariable):
+                        raise ShapeError(
+                            f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
+                            f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
+                            f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
+                            f"for example by a model variable.",
+                            actual=new_length,
+                            expected=old_length,
+                        )
+                if original_coords is not None:
+                    if new_coords is None:
+                        raise ValueError(
+                            f"The '{name}' variable already had {len(original_coords)} coord values defined for "
+                            f"its {dname} dimension. With the new values this dimension changes to length "
+                            f"{new_length}, so new coord values for the {dname} dimension are required."
+                        )
+                if isinstance(length_tensor, ScalarSharedVariable):
+                    # Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
+                    length_tensor.set_value(new_length)
+
+
             if new_coords is not None:
                 # Update the registered coord values (also if they were None)
                 if len(new_coords) != new_length:
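A minimal sketch of the behavior this hunk enforces. It assumes the 'fixed' kwarg on model.add_coord that the new error message refers to (the exact name and defaults may differ in the final API), so treat it as illustrative rather than exact:

import numpy as np
import pymc as pm
from pymc.exceptions import ShapeError

with pm.Model() as model:
    # Assumption: fixed=True stores the 'year' dim length in a TensorConstant.
    model.add_coord("year", [2020, 2021, 2022], fixed=True)
    x = pm.MutableData("x", [1.0, 2.0, 3.0], dims="year")

try:
    model.set_data("x", np.arange(5.0))  # 3 -> 5 triggers the new TensorConstant check
except ShapeError as err:
    print(err)  # "Resizing dimension 'year' is impossible, ..."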
@@ -1206,9 +1231,7 @@ def set_data(
                     expected=new_length,
                 )
             self._coords[dname] = new_coords
-            if isinstance(length_tensor, ScalarSharedVariable) and new_length != old_length:
-                # Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
-                length_tensor.set_value(new_length)
+
 
         shared_object.set_value(values)
 
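And the happy path: a hedged sketch of resizing a mutable dim, again assuming the 'fixed' kwarg and the coords= argument of Model.set_data that the diff's new_coords refers to. With fixed=False the dim length lives in a ScalarSharedVariable, so set_value(new_length) propagates the resize; passing matching coords avoids the new ValueError, while the ShapeWarning still reminds callers to keep coords in sync:

import numpy as np
import pymc as pm

with pm.Model() as model:
    # Assumption: fixed=False keeps the dim length in a mutable shared variable.
    model.add_coord("year", [2020, 2021, 2022], fixed=False)
    x = pm.MutableData("x", [1.0, 2.0, 3.0], dims="year")

# New coord values are required because 'year' already had three of them;
# omitting coords= here would raise the ValueError added above.
model.set_data("x", np.arange(5.0), coords={"year": np.arange(2020, 2025)})
assert model.dim_lengths["year"].eval() == 5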