Skip to content

Commit 5f4e5aa

Browse files
support TensorConstant as dim length in set_data
1 parent 4ab748c commit 5f4e5aa

File tree

1 file changed

+43
-20
lines changed

1 file changed

+43
-20
lines changed

pymc/model.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from aesara.tensor.random.opt import local_subtensor_rv_lift
4747
from aesara.tensor.random.var import RandomStateSharedVariable
4848
from aesara.tensor.sharedvar import ScalarSharedVariable
49-
from aesara.tensor.var import TensorVariable
49+
from aesara.tensor.var import TensorVariable, TensorConstant
5050

5151
from pymc.aesaraf import (
5252
compile_pymc,
@@ -61,7 +61,7 @@
6161
from pymc.distributions import joint_logpt
6262
from pymc.distributions.logprob import _get_scaling
6363
from pymc.distributions.transforms import _default_transform
64-
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
64+
from pymc.exceptions import ImputationWarning, ShapeWarning, SamplingError, ShapeError
6565
from pymc.initial_point import make_initial_point_fn
6666
from pymc.math import flatten_list
6767
from pymc.util import (
@@ -1179,24 +1179,49 @@ def set_data(
11791179
# Reject resizing if we already know that it would create shape problems.
11801180
# NOTE: If there are multiple pm.MutableData containers sharing this dim, but the user only
11811181
# changes the values for one of them, they will run into shape problems nonetheless.
1182-
if original_coords is None:
1183-
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
1184-
if not isinstance(length_belongs_to, SharedVariable) and length_changed:
1182+
if length_changed:
1183+
if isinstance(length_tensor, TensorConstant):
11851184
raise ShapeError(
1186-
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
1187-
f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
1188-
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
1189-
f"for example by a model variable.",
1190-
actual=new_length,
1191-
expected=old_length,
1185+
f"Resizing dimension '{dname}' is impossible, because "
1186+
f"a 'TensorConstant' stores its length. To be able "
1187+
f"to change the dimension length, 'fixed' in "
1188+
f"'model.add_coord' must be passed as False."
11921189
)
1193-
if original_coords is not None and length_changed:
1194-
if length_changed and new_coords is None:
1195-
raise ValueError(
1196-
f"The '{name}' variable already had {len(original_coords)} coord values defined for"
1197-
f"its {dname} dimension. With the new values this dimension changes to length "
1198-
f"{new_length}, so new coord values for the {dname} dimension are required."
1190+
if length_tensor.owner is None:
1191+
# This is the case if the dimension was initialized
1192+
# from custom coords, but dimension length was not
1193+
# stored in TensorConstant, e.g. by 'fixed' set to False
1194+
1195+
warnings.warn(
1196+
f"You're changing the shape of a shared variable "
1197+
f"in the '{dname}' dimension which was initialized "
1198+
f"from coords. Make sure to update the corresponding "
1199+
f"coords, otherwise you'll get shape issues.",
1200+
ShapeWarning,
11991201
)
1202+
else:
1203+
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
1204+
if not isinstance(length_belongs_to, SharedVariable):
1205+
raise ShapeError(
1206+
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
1207+
f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
1208+
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
1209+
f"for example by a model variable.",
1210+
actual=new_length,
1211+
expected=old_length,
1212+
)
1213+
if original_coords is not None:
1214+
if new_coords is None:
1215+
raise ValueError(
1216+
f"The '{name}' variable already had {len(original_coords)} coord values defined for "
1217+
f"its {dname} dimension. With the new values this dimension changes to length "
1218+
f"{new_length}, so new coord values for the {dname} dimension are required."
1219+
)
1220+
if isinstance(length_tensor, ScalarSharedVariable):
1221+
# Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1222+
length_tensor.set_value(new_length)
1223+
1224+
12001225
if new_coords is not None:
12011226
# Update the registered coord values (also if they were None)
12021227
if len(new_coords) != new_length:
@@ -1206,9 +1231,7 @@ def set_data(
12061231
expected=new_length,
12071232
)
12081233
self._coords[dname] = new_coords
1209-
if isinstance(length_tensor, ScalarSharedVariable) and new_length != old_length:
1210-
# Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1211-
length_tensor.set_value(new_length)
1234+
12121235

12131236
shared_object.set_value(values)
12141237

0 commit comments

Comments
 (0)