
local_sum_make_vector rewrite can introduce forbidden float64 operations at the graph level  #653

Closed
@tvwenger

Description


Describe the issue:

This may be related to pymc-devs/pymc#6779

Description:

When creating a model with floatX="float32" that includes a single distribution, the floatX assignment is respected. When the model includes two distributions, however, the floatX assignment is NOT respected, and the problem only appears upon sampling. This is a weird bug.

Expected Behavior

The model should respect floatX in all cases.

Actual Behavior

When the model includes two distributions, the graph includes float64 operations despite the request that floatX="float32".

Minimum Working Example

In the following MWE, I create four models. The first has one Dirichlet distribution, the second has one Normal distribution, and the remaining two include a Dirichlet distribution and then either a Normal or HalfCauchy distribution.

The first two models sample without issue, and floatX is respected.

The third and fourth models raise float64 errors during sampling. The error appears after model.point_logps(), which is all that was being checked in pymc-devs/pymc#6779.

The output (with truncated error messages) is appended below:

pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX =  float64

test_dirichlet
pytensor.config.floatX =  float32
foo float32
{'foo': -1.5}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

test_normal
pytensor.config.floatX =  float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

test_dirichlet_normal
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -1.5, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
 ...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.

test_dirichlet_halfcauchy
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -1.5, 'bar': -1.14}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.

Note that print(model.point_logps()) succeeds in every case, which demonstrates that the error occurs later, during sampling.
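
If the failure really comes from a graph rewrite, it should be reproducible without drawing any samples. A minimal sketch of that check (my addition, not part of the original report; it assumes model.compile_logp() compiles the joint log-probability and therefore runs the same rewrite passes that pm.sample() triggers):

import pytensor
import pytensor.tensor as pt
import pymc as pm

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    with pm.Model() as model:
        pm.Dirichlet("foo", a=pt.ones(3))
        pm.Normal("bar", mu=0.0, sigma=1.0)
    print(model.point_logps())  # passes: per-variable logps evaluate cleanly
    model.compile_logp()        # expected to log the local_sum_make_vector failure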

Reproducible code example:

import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)

print("pytensor.config.floatX = ", pytensor.config.floatX)
print()


def test_dirichlet():
    print("test_dirichlet")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_normal():
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_normal():
    print("test_dirichlet_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_halfcauchy():
    print("test_dirichlet_halfcauchy")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.HalfCauchy("bar", beta=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_halfcauchy()

Error message:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1082, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in local_sum_make_vector
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in <listcomp>
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 763, in cast
    return _cast_mapping[dtype_name](x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 484, in make_node
    outputs = [
              ^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 485, in <listcomp>
    TensorType(dtype=dtype, shape=shape)()
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 228, in __call__
    return utils.add_tag_trace(self.make_variable(name))
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 200, in make_variable
    return self.variable_type(self, None, name=name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/variable.py", line 900, in __init__
    raise Exception(msg)
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
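
The traceback points at the cast(value, acc_dtype) call inside local_sum_make_vector: for a float32 Sum the accumulation dtype is upcast to float64 by default, so building that cast trips warn_float64="raise". Since PyMC's joint logp sums the per-variable logp terms, the graph plausibly contains a Sum over a MakeVector only once there is more than one free variable, which would explain why the single-distribution models are unaffected. A PyTensor-only sketch (my addition; variable names are illustrative) that should build the same Sum{axes=None}(MakeVector{dtype='float32'}) pattern and exercise the rewrite when the function is compiled:

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    a = pt.scalar("a", dtype="float32")
    b = pt.scalar("b", dtype="float32")
    # Stacking scalars produces a MakeVector{dtype='float32'}; summing it gives
    # the Sum{axes=None}(MakeVector...) node that local_sum_make_vector matches.
    total = pt.sum(pt.stack([a, b]))
    # Rewrites run during compilation, so the rewrite-failure ERROR should be
    # logged here even though the compiled function itself still works.
    f = pytensor.function([a, b], total)
    print(f(1.0, 2.0))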

PyMC version information:

pytensor version: 2.18.6
pymc version: 5.10.4

Context for the issue:

Models that include a Dirichlet distribution as well as any other distribution cannot use float32.
