Skip to content

Commit a74592a

Browse files
michaelosthegericardoV94
authored andcommitted
Add Model.set_dim method for safer dims resizing
1 parent 0e7cd1f commit a74592a

File tree

2 files changed

+74
-10
lines changed

2 files changed

+74
-10
lines changed

pymc/model.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,25 +1067,26 @@ def add_coord(
10671067
raise ValueError(
10681068
f"Either `values` or `length` must be specified for the '{name}' dimension."
10691069
)
1070-
if isinstance(length, int):
1071-
length = at.constant(length)
1072-
elif length is not None and not isinstance(length, Variable):
1073-
raise ValueError(
1074-
f"The `length` passed for the '{name}' coord must be an Aesara Variable or None."
1075-
)
10761070
if values is not None:
10771071
# Conversion to a tuple ensures that the coordinate values are immutable.
10781072
# Also unlike numpy arrays the's tuple.index(...) which is handy to work with.
10791073
values = tuple(values)
10801074
if name in self.coords:
10811075
if not np.array_equal(values, self.coords[name]):
10821076
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
1083-
else:
1077+
if length is not None and not isinstance(length, (int, Variable)):
1078+
raise ValueError(
1079+
f"The `length` passed for the '{name}' coord must be an int, Aesara Variable or None."
1080+
)
1081+
if length is None:
1082+
length = len(values)
1083+
if not isinstance(length, Variable):
10841084
if mutable:
1085-
self._dim_lengths[name] = length or aesara.shared(len(values))
1085+
length = aesara.shared(length)
10861086
else:
1087-
self._dim_lengths[name] = length or aesara.tensor.constant(len(values))
1088-
self._coords[name] = values
1087+
length = aesara.tensor.constant(length)
1088+
self._dim_lengths[name] = length
1089+
self._coords[name] = values
10891090

10901091
def add_coords(
10911092
self,
@@ -1101,6 +1102,36 @@ def add_coords(
11011102
for name, values in coords.items():
11021103
self.add_coord(name, values, length=lengths.get(name, None))
11031104

1105+
def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] = None):
1106+
"""Update a mutable dimension.
1107+
1108+
Parameters
1109+
----------
1110+
name
1111+
Name of the dimension.
1112+
new_length
1113+
New length of the dimension.
1114+
coord_values
1115+
Optional sequence of coordinate values.
1116+
"""
1117+
if not isinstance(self.dim_lengths[name], ScalarSharedVariable):
1118+
raise ValueError(f"The dimension '{name}' is immutable.")
1119+
if coord_values is None and self.coords.get(name, None) is not None:
1120+
raise ValueError(
1121+
f"'{name}' has coord values. Pass `set_dim(..., coord_values=...)` to update them."
1122+
)
1123+
if coord_values is not None:
1124+
len_cvals = len(coord_values)
1125+
if len_cvals != new_length:
1126+
raise ShapeError(
1127+
f"Length of new coordinate values does not match the new dimension length.",
1128+
actual=len_cvals,
1129+
expected=new_length,
1130+
)
1131+
self._coords[name] = tuple(coord_values)
1132+
self.dim_lengths[name].set_value(new_length)
1133+
return
1134+
11041135
def set_data(
11051136
self,
11061137
name: str,

pymc/tests/test_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,39 @@ def test_add_coord_mutable_kwarg():
765765
assert isinstance(m._dim_lengths["mutable2"], TensorVariable)
766766

767767

768+
def test_set_dim():
769+
"""Test the concious re-sizing of dims created through add_coord()."""
770+
with pm.Model() as pmodel:
771+
pmodel.add_coord("fdim", mutable=False, length=1)
772+
pmodel.add_coord("mdim", mutable=True, length=2)
773+
a = pm.Normal("a", dims="mdim")
774+
assert a.eval().shape == (2,)
775+
776+
with pytest.raises(ValueError, match="is immutable"):
777+
pmodel.set_dim("fdim", 3)
778+
779+
pmodel.set_dim("mdim", 3)
780+
assert a.eval().shape == (3,)
781+
782+
783+
def test_set_dim_with_coords():
784+
"""Test the concious re-sizing of dims created through add_coord() with coord value."""
785+
with pm.Model() as pmodel:
786+
pmodel.add_coord("mdim", mutable=True, length=2, values=["A", "B"])
787+
a = pm.Normal("a", dims="mdim")
788+
assert len(pmodel.coords["mdim"]) == 2
789+
790+
with pytest.raises(ValueError, match="has coord values"):
791+
pmodel.set_dim("mdim", new_length=3)
792+
793+
with pytest.raises(ShapeError, match="does not match"):
794+
pmodel.set_dim("mdim", new_length=3, coord_values=["A", "B"])
795+
796+
pmodel.set_dim("mdim", 3, ["A", "B", "C"])
797+
assert a.eval().shape == (3,)
798+
assert pmodel.coords["mdim"] == ("A", "B", "C")
799+
800+
768801
@pytest.mark.parametrize("jacobian", [True, False])
769802
def test_model_logp(jacobian):
770803
with pm.Model() as m:

0 commit comments

Comments
 (0)