collect_default_updates should ignore unused rng inputs to OpFromGraph #7657

Closed
@jessegrabowski

Description

Consider the following graph:

import pytensor
import pytensor.tensor.random as ptr
import pytensor.tensor as pt
from pytensor.compile import shared
from pytensor.compile.builders import OpFromGraph

import numpy as np
import pymc as pm

X = pt.tensor('X', shape=(100,))
rng = shared(np.random.default_rng())
new_rng, mask = ptr.bernoulli(p=0.5, size=(100,), rng=rng).owner.outputs

X_masked, _ = OpFromGraph([X, rng], [X * mask, new_rng])(X, rng)
g = pytensor.grad(X_masked.sum(), X)
g.dprint()

OpFromGraph{inline=False} [id A]
 ├─ X [id B]
 ├─ RNG(<Generator(PCG64) at 0x7FC57FB550E0>) [id C]
 ├─ OpFromGraph{inline=False}.0 [id D]
 │  ├─ X [id B]
 │  └─ RNG(<Generator(PCG64) at 0x7FC57FB550E0>) [id C]
 └─ Second [id E]
    ├─ OpFromGraph{inline=False}.0 [id D]
    │  └─ ···
    └─ ExpandDims{axis=0} [id F]
       └─ Second [id G]
          ├─ Sum{axes=None} [id H]
          │  └─ OpFromGraph{inline=False}.0 [id D]
          │     └─ ···
          └─ 1.0 [id I]

Inner graphs:

OpFromGraph{inline=False} [id A]
 ← Mul [id J]
    ├─ *3-<Vector(float64, shape=(100,))> [id K]
    └─ bernoulli_rv{"()->()"}.1 [id L]
       ├─ *1-<RandomGeneratorType> [id M]
       ├─ [100] [id N]
       └─ ExpandDims{axis=0} [id O]
          └─ 0.5 [id P]

OpFromGraph{inline=False} [id D]
 ← Mul [id Q]
    ├─ *0-<Vector(float64, shape=(100,))> [id R]
    └─ bernoulli_rv{"()->()"}.1 [id S]
       ├─ *1-<RandomGeneratorType> [id M]
       ├─ [100] [id N]
       └─ ExpandDims{axis=0} [id T]
          └─ 0.5 [id P]
 ← bernoulli_rv{"()->()"}.0 [id S]
    └─ ···

This gradient does not depend on the rng input, since it is evaluated at whatever draw of mask was realized. OFG handles this by passing in all inputs and then using only the ones it needs, which is visible in the dprint: *0 and *2 are missing from the first inner graph. As a consequence, collect_default_updates raises an error here:

pm.pytensorf.collect_default_updates(g)
collect_default_updates error
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:915, in collect_default_updates.<locals>.find_default_update(clients, rng)
    914 try:
--> 915     next_rng = collect_default_updates_inner_fgraph(client)[rng]
    916 except (ValueError, KeyError):

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:799, in collect_default_updates_inner_fgraph(node)
    798 inp_idx = op.inner_inputs.index(rng)
--> 799 out_idx = op.inner_outputs.index(update)
    800 updates[node.inputs[inp_idx]] = node.outputs[out_idx]

ValueError: bernoulli_rv{"()->()"}.0 is not in list

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[29], line 1
----> 1 pm.pytensorf.collect_default_updates(g)

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:947, in collect_default_updates(outputs, inputs, must_be_shared)
    936 # Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
    937 for input_rng in (
    938     inp
    939     for inp in graph_inputs(outs, blockers=inputs)
   (...)
    945     # Even if an explicit default update is provided, we call it to
    946     # issue any warnings about invalid random graphs.
--> 947     default_update = find_default_update(clients, input_rng)
    949     # Respect default update if provided
    950     if hasattr(input_rng, "default_update") and input_rng.default_update is not None:

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:863, in collect_default_updates.<locals>.find_default_update(clients, rng)
    858     return rng
    860 if len(rng_clients) > 1:
    861     # Multiple clients are techincally fine if they are used in identical operations
    862     # We check if the default_update of each client would be the same
--> 863     update, *other_updates = (
    864         find_default_update(
    865             # Pass version of clients that includes only one the RNG clients at a time
    866             clients | {rng: [rng_client]},
    867             rng,
    868         )
    869         for rng_client in rng_clients
    870     )
    871     if all(equal_computations([update], [other_update]) for other_update in other_updates):
    872         return update

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:864, in <genexpr>(.0)
    858     return rng
    860 if len(rng_clients) > 1:
    861     # Multiple clients are techincally fine if they are used in identical operations
    862     # We check if the default_update of each client would be the same
    863     update, *other_updates = (
--> 864         find_default_update(
    865             # Pass version of clients that includes only one the RNG clients at a time
    866             clients | {rng: [rng_client]},
    867             rng,
    868         )
    869         for rng_client in rng_clients
    870     )
    871     if all(equal_computations([update], [other_update]) for other_update in other_updates):
    872         return update

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:917, in collect_default_updates.<locals>.find_default_update(clients, rng)
    915         next_rng = collect_default_updates_inner_fgraph(client)[rng]
    916     except (ValueError, KeyError):
--> 917         raise ValueError(
    918             f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
    919             "You can use `pytensorf.collect_default_updates` and include those updates as outputs."
    920         )
    921 else:
    922     # We don't know how this RNG should be updated. The user should provide an update manually
    923     return None

ValueError: No update found for at least one RNG used in OpFromGraph Op OpFromGraph{inline=False}.
You can use `pytensorf.collect_default_updates` and include those updates as outputs.

After talking to @ricardoV94, it seems it should always be safe to ignore unused RNG inputs to OFGs, which would fix this error.
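
To make the proposal concrete, here is a minimal sketch of one way "unused" could be decided: an inner input counts as unused when no inner output can be reached from it. This is a toy model in plain Python, not PyTensor or PyMC code. The `Var` class and the `ancestors` and `unused_rng_inputs` helpers are hypothetical names invented for illustration, and the real check in collect_default_updates may need to look different.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so Var can live in sets
class Var:
    """Hypothetical stand-in for a graph variable (not a PyTensor class)."""

    name: str
    parents: tuple = ()
    is_rng: bool = False


def ancestors(outputs):
    """Collect every variable reachable from `outputs`, including the outputs."""
    seen = set()
    stack = list(outputs)
    while stack:
        var = stack.pop()
        if var in seen:
            continue
        seen.add(var)
        stack.extend(var.parents)
    return seen


def unused_rng_inputs(inner_inputs, inner_outputs):
    """Return the RNG inputs that no inner output depends on."""
    used = ancestors(inner_outputs)
    return [inp for inp in inner_inputs if inp.is_rng and inp not in used]


# Toy inner graph: the RNG is passed in, but the only output ignores it.
x = Var("x")
rng = Var("rng", is_rng=True)
out_grad = Var("out_grad")
product = Var("mul", parents=(x, out_grad))
skipped = unused_rng_inputs([x, rng, out_grad], [product])

# Toy inner graph where the RNG feeds a draw that the output uses.
draw = Var("draw", parents=(rng,))
masked = Var("mul", parents=(x, draw))
not_skipped = unused_rng_inputs([x, rng], [masked])
```

In the first toy graph `skipped` contains only `rng`, so an update collector following this idea could skip that input instead of raising. In the second, `not_skipped` is empty, and the RNG would still need an update among the outputs.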
