Skip to content

Commit 53ec8e3

Browse files
committed
Do not try to infer artificial connection patterns in OpFromGraph
1 parent ab304cb commit 53ec8e3

File tree

2 files changed

+5
-24
lines changed

2 files changed

+5
-24
lines changed

pytensor/compile/builders.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -890,30 +890,9 @@ def connection_pattern(self, node):
890890
if self._connection_pattern is not None:
891891
return self._connection_pattern
892892

893-
inp_len = len(self.inner_inputs)
894-
out_len = len(self.inner_outputs)
895-
cpmat_self = io_connection_pattern(self.inner_inputs, self.inner_outputs)
896-
897-
lop_op = self.get_lop_op()
898-
cpmat_grad = io_connection_pattern(
899-
lop_op.inner_inputs[inp_len:], lop_op.inner_outputs
900-
)
901-
902-
# cpmat_self |= cpmat_grad.T
903-
# cpmat_self &= out_is_disconnected
904-
for i, t in enumerate(self._lop_op_stypes_l):
905-
if t is not None:
906-
if isinstance(t.type, DisconnectedType):
907-
for o in range(out_len):
908-
cpmat_self[i][o] = False
909-
for o in range(out_len):
910-
cpmat_self[i][o] |= cpmat_grad[o][i]
911-
912-
# TODO in case DisconnectedType is implemented for R_op,
913-
# self._rop_op_stypes_l self._rop_op should considered for
914-
# connection_pattern
915-
916-
return list(map(list, cpmat_self))
893+
ret = io_connection_pattern(self.inner_inputs, self.inner_outputs)
894+
self._connection_pattern = ret
895+
return ret
917896

918897
def infer_shape(self, fgraph, node, shapes):
919898
# TODO: Use `fgraph.shape_feature` to do this instead.

tests/compile/test_builders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def go2(inps, gs):
244244
[x, w, b],
245245
[x * w + b],
246246
grad_overrides=[go1, NullType()(), DisconnectedType()()],
247+
# This is a fake override, so a fake connection_pattern must be provided as well
248+
connection_pattern=[[True], [True], [False]],
247249
)
248250
zz2 = pt_sum(op_linear2(xx, ww, bb))
249251
dx2, dw2, db2 = grad(

0 commit comments

Comments
 (0)