Skip to content

Commit f43b92f

Browse files
Fix shared variable comparisons in OpFromGraph.make_node
1 parent bc10e2b commit f43b92f

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

aesara/compile/builders.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -765,28 +765,30 @@ def make_node(self, *inputs):
765765
for inp, inp_t in zip(non_shared_inputs, self.input_types)
766766
]
767767

768-
shared_inputs = inputs[num_expected_inps:]
769-
local_shared_inputs = self.inner_inputs[num_expected_inps:]
770-
771-
inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))
768+
inner_and_input_shareds = list(
769+
zip(self.shared_inputs, inputs[num_expected_inps:])
770+
)
772771

773772
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
774773
# The shared variables are not equal to the original shared
775774
# variables, so we construct a new `Op` that uses the new shared
776-
# variables instead
777-
replace = {
778-
old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs)
779-
}
780-
replace.update(inner_and_input_shareds)
775+
# variables instead.
776+
# All this is really doing is making the unused (internally, at
777+
# least) `self.outputs` and `self.shared_inputs` consistent.
778+
# We could just as easily `copy` this `Op`, update
779+
# `self.shared_inputs`, and avoid cloning anything, but this is a
780+
# more "change-proof" approach, because it still work when/if those
781+
# attributes end up being used.
782+
replace = dict(inner_and_input_shareds)
781783

782784
# If the new shared variables are inconsistent with the inner-graph,
783785
# such errors should arise in this step
784786
new_outputs = clone_replace(
785-
self.inner_outputs, replace=replace, share_inputs=True
787+
self.outputs, replace=replace, share_inputs=True
786788
)
787789

788790
new_op = type(self)(
789-
inputs=non_shared_inputs,
791+
inputs=self.inputs,
790792
outputs=new_outputs,
791793
inline=self.is_inline,
792794
lop_overrides=self.lop_overrides,

tests/compile/test_builders.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,17 +480,29 @@ def test_make_node_shared(self):
480480
assert out_new.owner.op.shared_inputs == [y_clone]
481481

482482
out_fn = function([x], out_new)
483-
484483
assert np.array_equal(out_fn(1.0), 2.0)
485484

486485
y_clone.set_value(2.0)
487-
488486
assert np.array_equal(out_fn(1.0), 3.0)
489487

490488
# This should also work, because the containers are the same:
491489
# y.set_value(1.0)
492490
# assert np.array_equal(out_fn(1.0), 2.0)
493491

492+
def test_shared_with_constant_input(self):
493+
"""Make sure that a constant input can be given to an `OpFromGraph` instance."""
494+
x = at.scalar("x")
495+
y = shared(1.0, name="y")
496+
497+
test_ofg = OpFromGraph([x], [x + y])
498+
assert test_ofg.inputs == [x]
499+
assert test_ofg.shared_inputs == [y]
500+
501+
out = test_ofg(at.as_tensor(1.0, dtype=config.floatX))
502+
503+
out_fn = function([], out)
504+
assert np.array_equal(out_fn(), 2.0)
505+
494506

495507
def test_debugprint():
496508
x, y, z = matrices("xyz")

0 commit comments

Comments
 (0)