Skip to content

Commit 3bbf8ee

Browse files
michaelosthegericardoV94
authored andcommitted
Update Aesara to 2.8.8
This includes several changes related to static shape handling. `test_var_replacement` is marked XFAIL because of incompatibility with the new static shape handling.
1 parent 2c946fa commit 3bbf8ee

14 files changed

+34
-20
lines changed

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies
8-
- aesara=2.8.7
8+
- aesara=2.8.8
99
- arviz>=0.13.0
1010
- blas
1111
- cachetools>=4.2.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies
8-
- aesara=2.8.7
8+
- aesara=2.8.8
99
- arviz>=0.13.0
1010
- blas
1111
- cachetools>=4.2.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies (see install guide for Windows)
8-
- aesara=2.8.7
8+
- aesara=2.8.8
99
- arviz>=0.13.0
1010
- blas
1111
- cachetools>=4.2.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ channels:
55
- defaults
66
dependencies:
77
# Base dependencies (see install guide for Windows)
8-
- aesara=2.8.7
8+
- aesara=2.8.8
99
- arviz>=0.13.0
1010
- blas
1111
- cachetools>=4.2.1

pymc/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,15 @@ def softmax(x, axis=None):
284284
# drops that warning
285285
with warnings.catch_warnings():
286286
warnings.simplefilter("ignore", UserWarning)
287-
return at.nnet.softmax(x, axis=axis)
287+
return at.special.softmax(x, axis=axis)
288288

289289

290290
def log_softmax(x, axis=None):
291291
# Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
292292
# drops that warning
293293
with warnings.catch_warnings():
294294
warnings.simplefilter("ignore", UserWarning)
295-
return at.nnet.logsoftmax(x, axis=axis)
295+
return at.special.log_softmax(x, axis=axis)
296296

297297

298298
def logbern(log_p):

pymc/model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,15 +1212,29 @@ def set_data(
12121212
"or define it via a `pm.MutableData` variable."
12131213
)
12141214
elif length_tensor.owner is not None:
1215-
# The dimension was created from a model variable.
1215+
# The dimension was created from another variable:
1216+
length_tensor_origin = length_tensor.owner.inputs[0]
12161217
# Get a handle on the tensor from which this dimension length was
12171218
# obtained by doing subindexing on the shape as in `.shape[i]`.
1218-
# Needed to check if it was another shared variable.
1219+
if isinstance(length_tensor_origin, TensorConstant):
1220+
raise ShapeError(
1221+
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
1222+
f"because the dimension length is tied to a {length_tensor_origin}. "
1223+
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
1224+
f"for example by another model variable.",
1225+
actual=new_length,
1226+
expected=old_length,
1227+
)
1228+
1229+
# The shape entry this dimension is tied to is not a TensorConstant.
1230+
# Whether the dimension can be resized depends on the kind of Variable the shape belongs to.
12191231
# TODO: Consider checking the graph is what we are assuming it is
12201232
# isinstance(length_tensor.owner.op, Subtensor)
12211233
# isinstance(length_tensor.owner.inputs[0].owner.op, Shape)
1222-
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
1234+
length_belongs_to = length_tensor_origin.owner.inputs[0]
1235+
12231236
if length_belongs_to is shared_object:
1237+
# This is the shared variable that's being updated!
12241238
# No surprise it's changing.
12251239
pass
12261240
elif isinstance(length_belongs_to, SharedVariable):

pymc/tests/distributions/test_truncated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
9797

9898
xt = Truncated.dist(x, lower=lower, upper=upper)
9999
assert isinstance(xt.owner.op, TruncatedRV)
100-
assert xt.type == x.type
100+
assert xt.type.dtype == x.type.dtype
101101

102102
xt_draws = draw(xt, draws=5)
103103
assert np.all(xt_draws >= lower)
@@ -162,7 +162,7 @@ def test_truncation_discrete_random(op_type, lower, upper):
162162
x = geometric_op(p, name="x", size=500)
163163
xt = Truncated.dist(x, lower=lower, upper=upper)
164164
assert isinstance(xt.owner.op, TruncatedRV)
165-
assert xt.type == x.type
165+
assert xt.type.dtype == x.type.dtype
166166

167167
xt_draws = draw(xt)
168168
assert np.all(xt_draws >= lower)

pymc/tests/test_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,6 @@ def test_free_rv(self):
714714
with pm.Model() as model5:
715715
n = pm.Normal("n", total_size=[2, Ellipsis, 2], size=(2, 2))
716716
p5 = model5.compile_fn(model5.logp(), point_fn=False)
717-
assert p4() == p5(pm.floatX([[1]]))
718717
assert p4() == p5(pm.floatX([[1, 1], [1, 1]]))
719718

720719

pymc/tests/test_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@ def test_invlogit_deprecation_warning():
280280
@pytest.mark.parametrize(
281281
"aesara_function, pymc_wrapper",
282282
[
283-
(at.nnet.softmax, softmax),
284-
(at.nnet.logsoftmax, log_softmax),
283+
(at.special.softmax, softmax),
284+
(at.special.log_softmax, log_softmax),
285285
],
286286
)
287287
def test_softmax_logsoftmax_no_warnings(aesara_function, pymc_wrapper):

pymc/tests/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def test_observed_type(self):
216216
x1 = pm.Normal("x1", observed=X_)
217217
x2 = pm.Normal("x2", observed=X)
218218

219-
assert x1.type == X.type
220-
assert x2.type == X.type
219+
assert x1.type.dtype == X.type.dtype
220+
assert x2.type.dtype == X.type.dtype
221221

222222

223223
def test_duplicate_vars():
@@ -935,7 +935,7 @@ def test_set_data_constant_shape_error():
935935
pmodel.add_coord("weekday", length=x.shape[0])
936936
pm.MutableData("y", np.arange(7), dims="weekday")
937937

938-
msg = "because the dimension was initialized from 'x' which is not a shared variable"
938+
msg = "because the dimension length is tied to a TensorConstant"
939939
with pytest.raises(ShapeError, match=msg):
940940
pmodel.set_data("y", np.arange(10))
941941

pymc/tests/variational/test_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def test_remove_scan_op():
307307
buff.close()
308308

309309

310+
@pytest.mark.xfail(reason="Broke from static shape handling with Aesara 2.8.8")
310311
def test_var_replacement():
311312
X_mean = pm.floatX(np.linspace(0, 10, 10))
312313
y = pm.floatX(np.random.normal(X_mean * 4, 0.05))

pymc/variational/updates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@ def total_norm_constraint(tensor_vars, max_norm, epsilon=1e-7, return_norm=False
10521052
>>> x = at.matrix()
10531053
>>> y = at.ivector()
10541054
>>> l_in = InputLayer((5, 10))
1055-
>>> l1 = DenseLayer(l_in, num_units=7, nonlinearity=at.nnet.softmax)
1055+
>>> l1 = DenseLayer(l_in, num_units=7, nonlinearity=at.special.softmax)
10561056
>>> output = lasagne.layers.get_output(l1, x)
10571057
>>> cost = at.mean(at.nnet.categorical_crossentropy(output, y))
10581058
>>> all_params = lasagne.layers.get_all_params(l1)

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify.
22
# See that file for comments about the need/usage of each dependency.
33

4-
aesara==2.8.7
4+
aesara==2.8.8
55
arviz>=0.13.0
66
cachetools>=4.2.1
77
cloudpickle

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
aesara==2.8.7
1+
aesara==2.8.8
22
arviz>=0.13.0
33
cachetools>=4.2.1
44
cloudpickle

0 commit comments

Comments
 (0)