diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 0d6f43bc83..a51c0d7cc1 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -248,11 +248,6 @@ def apply(self, fgraph):
 # misc special cases for speed that break canonicalization
 optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)

-# misc special cases for speed that are dependent on the device.
-optdb.register(
-    "specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
-)  # must be after gpu stuff at 48.5
-
 # especially constant merge
 optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)

diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py
index 67ad98994a..a5958d7f4f 100644
--- a/pytensor/configdefaults.py
+++ b/pytensor/configdefaults.py
@@ -640,16 +640,6 @@ def add_tensor_configvars():
         in_c_key=False,
     )

-    config.add(
-        "tensor__local_elemwise_fusion",
-        (
-            "Enable or not in fast_run mode(fast_run optimization) the elemwise "
-            "fusion optimization"
-        ),
-        BoolParam(True),
-        in_c_key=False,
-    )
-
     # http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx
     config.add(
         "lib__amblibm",
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index c3bb653870..58a5918c12 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -205,25 +205,6 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
         return node_rewriter


-def register_specialize_device(
-    node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs
-):
-    if isinstance(node_rewriter, str):
-
-        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
-            return register_specialize_device(
-                inner_rewriter, node_rewriter, *tags, **kwargs
-            )
-
-        return register
-    else:
-        name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
-        compile.optdb["specialize_device"].register(
-            name, node_rewriter, "fast_run", *tags, **kwargs
-        )
-        return node_rewriter
-
-
 @register_canonicalize
 @register_specialize
 @node_rewriter([TensorFromScalar])
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index b76b048e72..e42543f9e6 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -1085,38 +1085,10 @@ def print_profile(stream, prof, level=0):
             print(blanc, " time_toposort", prof[7], file=stream)


-if config.tensor__local_elemwise_fusion:
-    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
-    fuse_seqopt = SequenceDB()
-    fuse_seqopt.register(
-        "local_add_mul_fusion",
-        EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
-        "fast_run",
-        "fusion",
-        position=0,
-    )
-    fuse_seqopt.register(
-        "composite_elemwise_fusion",
-        FusionOptimizer(),
-        "fast_run",
-        "fusion",
-        position=1,
-    )
-    compile.optdb.register(
-        "elemwise_fusion",
-        fuse_seqopt,
-        "fast_run",
-        "fusion",
-        "local_elemwise_fusion",
-        "FusionOptimizer",
-        position=49,
-    )
-
-
 @register_canonicalize
 @register_specialize
 @node_rewriter([Elemwise])
-def local_useless_composite(fgraph, node):
+def local_useless_composite_outputs(fgraph, node):
     """Remove inputs and outputs of Composite Ops that are not used anywhere."""
     if not isinstance(node.op, Elemwise) or not isinstance(
         node.op.scalar_op, aes.Composite
@@ -1150,11 +1122,20 @@ def local_careduce_fusion(fgraph, node):
     """Fuse a `CAReduce` applied to an `Elemwise`."""

     (car_input,) = node.inputs
+    car_scalar_op = node.op.scalar_op
+
+    # FIXME: This check is needed because of the faulty logic in the FIXME below!
+    #  Right now, rewrite only works for `Sum`/`Prod`
+    if not isinstance(car_scalar_op, (aes.Add, aes.Mul)):
+        return None
+
     elm_node = car_input.owner

     if elm_node is None or not isinstance(elm_node.op, Elemwise):
         return False

+    elm_scalar_op = elm_node.op.scalar_op
+
     elm_inputs = elm_node.inputs
     elm_outputs = elm_node.outputs

@@ -1166,21 +1147,15 @@ def local_careduce_fusion(fgraph, node):
         return False

     # Don't form the fusion when the target language is Python
-    elm_scalar_op = elm_node.op.scalar_op
-    car_scalar_op = node.op.scalar_op
-
     if get_target_language() == ("py",):
         return False

-    try:
-        elm_scalar_op.c_code(
-            elm_node,
-            "test_presence_of_c_code",
-            ["x" for x in elm_inputs],
-            ["z" for z in elm_outputs],
-            {"fail": "%(fail)s"},
-        )
+    if not elm_scalar_op.supports_c_code(elm_inputs, elm_outputs):
+        return None

+    # FIXME: This fails with Ops like `Max` whose `c_code` always expects two inputs!
+    #  Should implement a `CAReduce.supports_c_code`?
+    try:
         car_scalar_op.c_code(
             node,
             "test_presence_of_c_code",
@@ -1191,18 +1166,24 @@
     except (NotImplementedError, MethodNotDefined):
         return False

-    car_axis = node.op.axis
+    car_op = node.op
+    car_acc_dtype = node.op.acc_dtype

     scalar_elm_inputs = [
         aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
     ]
+
     elm_output = elm_scalar_op(*scalar_elm_inputs)
+
     # This input represents the previous value in the `CAReduce` binary reduction
-    carried_car_input = elm_output.type()
-    scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]
+    carried_car_input = aes.get_scalar_type(car_acc_dtype).make_variable()
+
+    scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
+    if scalar_fused_output.type.dtype != car_acc_dtype:
+        scalar_fused_output = aes.cast(scalar_fused_output, car_acc_dtype)

     fused_scalar_op = aes.Composite(
-        inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
+        inputs=[carried_car_input] + scalar_elm_inputs, outputs=[scalar_fused_output]
     )

     # The fused `Op` needs to look and behave like a `BinaryScalarOp`
@@ -1211,16 +1192,56 @@
     fused_scalar_op.nin = 2
     fused_scalar_op.nout = 1

-    new_car_op = CAReduce(fused_scalar_op, car_axis)
+    new_car_op = CAReduce(
+        scalar_op=fused_scalar_op,
+        axis=car_op.axis,
+        acc_dtype=car_acc_dtype,
+        dtype=car_op.dtype,
+        upcast_discrete_output=car_op.upcast_discrete_output,
+    )

     return [new_car_op(*elm_inputs)]


+# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
+fuse_seqopt = SequenceDB()
 compile.optdb.register(
+    "elemwise_fusion",
+    fuse_seqopt,
+    "fast_run",
+    "fusion",
+    "local_elemwise_fusion",
+    "FusionOptimizer",
+    position=49,
+)
+
+fuse_seqopt.register(
+    "local_add_mul_fusion",
+    EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
+    "fast_run",
+    "fusion",
+    position=0,
+)
+fuse_seqopt.register(
+    "composite_elemwise_fusion",
+    FusionOptimizer(),
+    "fast_run",
+    "fusion",
+    position=1,
+)
+fuse_seqopt.register(
+    "local_useless_composite_outputs",
+    in2out(local_useless_composite_outputs),
+    "fast_run",
+    "fusion",
+    position=2,
+)
+fuse_seqopt.register(
     "local_careduce_fusion",
     in2out(local_careduce_fusion),
+    "fast_run",
     "fusion",
-    position=49,
+    position=10,
 )
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index b0d124f1c8..3c7f563271 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -88,7 +88,6 @@
     local_fill_sink,
     register_canonicalize,
     register_specialize,
-    register_specialize_device,
     register_stabilize,
     register_uncanonicalize,
     register_useless,
@@ -2078,12 +2077,14 @@ def local_pow_specialize(fgraph, node):
         return False


-@register_specialize_device
+@register_specialize
 @node_rewriter([at_pow])
-def local_pow_specialize_device(fgraph, node):
-    """
-    This rewrite is not the same on all device. We do it only on cpu here.
+def local_pow_to_nested_squaring(fgraph, node):
+    """Convert a large power exponent to multiple squaring operations.
+
+    Note: This sounds like the kind of thing any half-decent compiler can do by itself?
     """
+
     if node.op == at_pow:
         # the idea here is that we have pow(x, y)
         odtype = node.outputs[0].dtype
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index 2f45f63192..ead6575eb3 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -1177,8 +1177,24 @@ def test_test_values(self, test_value):
         )

     @pytest.mark.parametrize("linker", ["cvm", "py"])
+    @pytest.mark.parametrize("inp_dtype", ("floatX", "int32"))
     @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
-    def test_CAReduce_single_input(self, linker, axis):
+    @pytest.mark.parametrize(
+        "careduce_op, numpy_op",
+        [
+            (at_sum, np.sum),
+            pytest.param(
+                at_all,
+                np.all,
+                marks=pytest.mark.xfail(
+                    reason="Rewrite logic does not support all CAReduce"
+                ),
+            ),
+        ],
+    )
+    def test_CAReduce_single_input(
+        self, linker, inp_dtype, axis, careduce_op, numpy_op
+    ):
         """Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""
         mode = Mode(linker=linker)

@@ -1188,8 +1204,8 @@ def test_CAReduce_single_input(self, linker, axis):
             "inplace",
         )

-        x = tensor(dtype="floatX", shape=(None, None, None), name="x")
-        out = exp(x).sum(axis=axis)
+        x = tensor(dtype=inp_dtype, shape=(None, None, None), name="x")
+        out = careduce_op(exp(x), axis=axis)

         out_fn = function([x], out, mode=mode)

@@ -1198,9 +1214,9 @@ def test_CAReduce_single_input(self, linker, axis):
         assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite)

         rng = np.random.default_rng(2320)
-        x_val = rng.random((4, 3, 2), dtype=config.floatX)
+        x_val = rng.random((4, 3, 2)).astype(x.type.dtype)

-        exp_res = np.exp(x_val).sum(axis=axis)
+        exp_res = numpy_op(np.exp(x_val), axis=axis)

         out_val = out_fn(x_val)
         assert out_val.shape == exp_res.shape
@@ -1216,7 +1232,7 @@ def test_CAReduce_single_input(self, linker, axis):
         # `Elemwise`s with more than one client shouldn't be rewritten
         x = tensor(dtype="floatX", shape=(None, None, None), name="x")
         exp_x = exp(x)
-        out = exp_x.sum(axis=axis) + exp(x)
+        out = careduce_op(exp_x, axis=axis) + exp(x)

         out_fn = function([x], out, mode=mode)
         out_nodes = out_fn.maker.fgraph.toposort()
@@ -1409,39 +1425,40 @@ def test_nested_composite(self):
         fval = f([1, 2, 3])
         assert np.all(fval == [6, 12, 18])

-    def test_local_useless_composite(self):
-        x = aes.float32()
-        y = aes.float32()
-        z = aes.float32()
-        c = aes.Composite([x, y, z], [x + 1, y - 1])
-        X = matrix("X")
-        Y = matrix("Y")
-        Z = matrix("Z")
-        o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
-        mode = get_default_mode().including("local_useless_composite")
-
-        f = function([X, Y, Z], [o1, o2], mode=mode)
-        topo = f.maker.fgraph.toposort()
-        assert len(topo) == 1
-        assert len(topo[0].inputs) == 2
-        assert len(topo[0].outputs) == 2
-        res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
-        utt.assert_allclose(res1, [[2.0]])
-        utt.assert_allclose(res2, [[0.0]])
-
-        f = function([X, Y, Z], o1, mode=mode)
-        topo = f.maker.fgraph.toposort()
-        assert len(topo) == 1
-        assert len(topo[0].inputs) == 1
-        assert len(topo[0].outputs) == 1
-        utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
-        f = function([X, Y, Z], o2, mode=mode)
-        topo = f.maker.fgraph.toposort()
-        assert len(topo) == 1
-        assert len(topo[0].inputs) == 1
-        assert len(topo[0].outputs) == 1
-        utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
+def test_local_useless_composite_outputs():
+    x = aes.float32()
+    y = aes.float32()
+    z = aes.float32()
+    c = aes.Composite([x, y, z], [x + 1, y - 1])
+    X = matrix("X")
+    Y = matrix("Y")
+    Z = matrix("Z")
+    o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
+    mode = get_default_mode().including("local_useless_composite")
+
+    f = function([X, Y, Z], [o1, o2], mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert len(topo) == 1
+    assert len(topo[0].inputs) == 2
+    assert len(topo[0].outputs) == 2
+    res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
+    utt.assert_allclose(res1, [[2.0]])
+    utt.assert_allclose(res2, [[0.0]])
+
+    f = function([X, Y, Z], o1, mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert len(topo) == 1
+    assert len(topo[0].inputs) == 1
+    assert len(topo[0].outputs) == 1
+    utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
+
+    f = function([X, Y, Z], o2, mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert len(topo) == 1
+    assert len(topo[0].inputs) == 1
+    assert len(topo[0].outputs) == 1
+    utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])


 def test_local_useless_dimshuffle_makevector():
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index f69879a51d..81fb55e09b 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -1672,12 +1672,12 @@ def test_local_pow_specialize():
     utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))


-def test_local_pow_specialize_device_more_aggressive_on_cpu():
+def test_local_pow_to_nested_squaring():
     mode = config.mode
     if mode == "FAST_COMPILE":
         mode = "FAST_RUN"
     mode = get_mode(mode)
-    mode = mode.excluding("fusion").excluding("gpu")
+    mode = mode.excluding("fusion")
     v = vector()
     val = np.arange(10, dtype=config.floatX)
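
For reference, the rewrite renamed above to `local_pow_to_nested_squaring` is described by its docstring as converting a large power exponent into repeated squarings. The standalone NumPy sketch below (not part of the patch; the helper name `pow_by_squaring` is made up for illustration) shows the arithmetic that such a nested-squaring decomposition performs for a non-negative integer exponent.

import numpy as np


def pow_by_squaring(x, n):
    # Compute x**n for a non-negative integer n by walking the binary
    # expansion of n: square the running base each step and multiply it
    # into the result whenever the corresponding bit is set.
    result = np.ones_like(x)
    base = x
    while n > 0:
        if n & 1:
            result = result * base
        base = base * base
        n >>= 1
    return result


x = np.arange(10, dtype="float64")
# x**6 decomposes into (x**2) * ((x**2)**2), i.e. two squarings and one multiply.
assert np.allclose(pow_by_squaring(x, 6), x**6)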