@@ -890,30 +890,9 @@ def connection_pattern(self, node):
890
890
if self ._connection_pattern is not None :
891
891
return self ._connection_pattern
892
892
893
- inp_len = len (self .inner_inputs )
894
- out_len = len (self .inner_outputs )
895
- cpmat_self = io_connection_pattern (self .inner_inputs , self .inner_outputs )
896
-
897
- lop_op = self .get_lop_op ()
898
- cpmat_grad = io_connection_pattern (
899
- lop_op .inner_inputs [inp_len :], lop_op .inner_outputs
900
- )
901
-
902
- # cpmat_self |= cpmat_grad.T
903
- # cpmat_self &= out_is_disconnected
904
- for i , t in enumerate (self ._lop_op_stypes_l ):
905
- if t is not None :
906
- if isinstance (t .type , DisconnectedType ):
907
- for o in range (out_len ):
908
- cpmat_self [i ][o ] = False
909
- for o in range (out_len ):
910
- cpmat_self [i ][o ] |= cpmat_grad [o ][i ]
911
-
912
- # TODO in case DisconnectedType is implemented for R_op,
913
- # self._rop_op_stypes_l self._rop_op should considered for
914
- # connection_pattern
915
-
916
- return list (map (list , cpmat_self ))
893
+ ret = io_connection_pattern (self .inner_inputs , self .inner_outputs )
894
+ self ._connection_pattern = ret
895
+ return ret
917
896
918
897
def infer_shape (self , fgraph , node , shapes ):
919
898
# TODO: Use `fgraph.shape_feature` to do this instead.
0 commit comments