Consider the following graph:
```python
import pytensor
import pytensor.tensor.random as ptr
import pytensor.tensor as pt
from pytensor.compile import shared
from pytensor.compile.builders import OpFromGraph
import numpy as np
import pymc as pm

X = pt.tensor('X', shape=(100,))
rng = shared(np.random.default_rng())
new_rng, mask = ptr.bernoulli(p=0.5, size=(100,), rng=rng).owner.outputs
X_masked, _ = OpFromGraph([X, rng], [X * mask, new_rng])(X, rng)

g = pytensor.grad(X_masked.sum(), X)
g.dprint()
```
```
OpFromGraph{inline=False} [id A]
 ├─ X [id B]
 ├─ RNG(<Generator(PCG64) at 0x7FC57FB550E0>) [id C]
 ├─ OpFromGraph{inline=False}.0 [id D]
 │  ├─ X [id B]
 │  └─ RNG(<Generator(PCG64) at 0x7FC57FB550E0>) [id C]
 └─ Second [id E]
    ├─ OpFromGraph{inline=False}.0 [id D]
    │  └─ ···
    └─ ExpandDims{axis=0} [id F]
       └─ Second [id G]
          ├─ Sum{axes=None} [id H]
          │  └─ OpFromGraph{inline=False}.0 [id D]
          │     └─ ···
          └─ 1.0 [id I]

Inner graphs:

OpFromGraph{inline=False} [id A]
 ← Mul [id J]
    ├─ *3-<Vector(float64, shape=(100,))> [id K]
    └─ bernoulli_rv{"()->()"}.1 [id L]
       ├─ *1-<RandomGeneratorType> [id M]
       ├─ [100] [id N]
       └─ ExpandDims{axis=0} [id O]
          └─ 0.5 [id P]

OpFromGraph{inline=False} [id D]
 ← Mul [id Q]
    ├─ *0-<Vector(float64, shape=(100,))> [id R]
    └─ bernoulli_rv{"()->()"}.1 [id S]
       ├─ *1-<RandomGeneratorType> [id M]
       ├─ [100] [id N]
       └─ ExpandDims{axis=0} [id T]
          └─ 0.5 [id P]
 ← bernoulli_rv{"()->()"}.0 [id S]
    └─ ···
```
This gradient does not depend on the `rng` input, since it is evaluated at whatever the draw of `mask` happened to be. The way `OpFromGraph` handles this is to pass in all inputs, but then only use the ones it needs. This is visible from the missing `*0` and `*2` placeholders in the first inner graph of the dprint. A consequence of this is that `collect_default_updates` raises an error here:
```python
pm.pytensorf.collect_default_updates(g)
```
collect_default_updates error
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:915, in collect_default_updates.<locals>.find_default_update(clients, rng)
    914 try:
--> 915     next_rng = collect_default_updates_inner_fgraph(client)[rng]
    916 except (ValueError, KeyError):

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:799, in collect_default_updates_inner_fgraph(node)
    798 inp_idx = op.inner_inputs.index(rng)
--> 799 out_idx = op.inner_outputs.index(update)
    800 updates[node.inputs[inp_idx]] = node.outputs[out_idx]

ValueError: bernoulli_rv{"()->()"}.0 is not in list

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[29], line 1
----> 1 pm.pytensorf.collect_default_updates(g)

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:947, in collect_default_updates(outputs, inputs, must_be_shared)
    936 # Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
    937 for input_rng in (
    938     inp
    939     for inp in graph_inputs(outs, blockers=inputs)
   (...)
    945     # Even if an explicit default update is provided, we call it to
    946     # issue any warnings about invalid random graphs.
--> 947     default_update = find_default_update(clients, input_rng)
    949     # Respect default update if provided
    950     if hasattr(input_rng, "default_update") and input_rng.default_update is not None:

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:863, in collect_default_updates.<locals>.find_default_update(clients, rng)
    858     return rng
    860 if len(rng_clients) > 1:
    861     # Multiple clients are techincally fine if they are used in identical operations
    862     # We check if the default_update of each client would be the same
--> 863     update, *other_updates = (
    864         find_default_update(
    865             # Pass version of clients that includes only one the RNG clients at a time
    866             clients | {rng: [rng_client]},
    867             rng,
    868         )
    869         for rng_client in rng_clients
    870     )
    871     if all(equal_computations([update], [other_update]) for other_update in other_updates):
    872         return update

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:864, in <genexpr>(.0)
    858     return rng
    860 if len(rng_clients) > 1:
    861     # Multiple clients are techincally fine if they are used in identical operations
    862     # We check if the default_update of each client would be the same
    863     update, *other_updates = (
--> 864         find_default_update(
    865             # Pass version of clients that includes only one the RNG clients at a time
    866             clients | {rng: [rng_client]},
    867             rng,
    868         )
    869         for rng_client in rng_clients
    870     )
    871     if all(equal_computations([update], [other_update]) for other_update in other_updates):
    872         return update

File ~/mambaforge/envs/pytensor-ml/lib/python3.12/site-packages/pymc/pytensorf.py:917, in collect_default_updates.<locals>.find_default_update(clients, rng)
    915     next_rng = collect_default_updates_inner_fgraph(client)[rng]
    916 except (ValueError, KeyError):
--> 917     raise ValueError(
    918         f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
    919         "You can use `pytensorf.collect_default_updates` and include those updates as outputs."
    920     )
    921 else:
    922     # We don't know how this RNG should be updated. The user should provide an update manually
    923     return None

ValueError: No update found for at least one RNG used in OpFromGraph Op OpFromGraph{inline=False}.
You can use `pytensorf.collect_default_updates` and include those updates as outputs.
```
Talking to @ricardoV94, it seems it should always be safe to ignore unused RNG inputs to OFGs, which would fix this error.
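To make the intended behavior concrete, here is a minimal, library-free sketch of that policy. Everything in it (`Var`, `reachable`, `collect_inner_updates`) is a toy stand-in, not the real PyTensor/PyMC API: an RNG input that some inner output actually depends on must have a known update, while an RNG input the inner graph never reaches is simply skipped instead of raising.

```python
class Var:
    """Toy graph variable: just a name plus the variables it is computed from."""
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)

def reachable(outputs):
    """Collect every variable some output transitively depends on."""
    seen, stack = set(), list(outputs)
    while stack:
        v = stack.pop()
        if v not in seen:
            seen.add(v)
            stack.extend(v.inputs)
    return seen

def collect_inner_updates(inner_outputs, rng_inputs, next_rngs):
    """Map each *used* RNG input to its update; skip RNGs the outputs never touch."""
    used = reachable(inner_outputs)
    updates = {}
    for rng in rng_inputs:
        if rng not in used:
            # The proposed fix: an unused RNG input needs no update, so
            # it should not trigger a ValueError.
            continue
        if rng not in next_rngs:
            raise ValueError(f"No update found for RNG {rng.name}")
        updates[rng] = next_rngs[rng]
    return updates

# Forward-style inner graph: the rng feeds the draw and its update is an output.
rng = Var("rng")
x = Var("x")
mask = Var("mask", inputs=[rng])
new_rng = Var("new_rng", inputs=[rng])
masked = Var("x*mask", inputs=[x, mask])
assert collect_inner_updates([masked, new_rng], [rng], {rng: new_rng}) == {rng: new_rng}

# Gradient-style inner graph: rng is passed in but no output reaches it,
# so it is skipped rather than raising.
assert collect_inner_updates([Var("g", inputs=[x])], [rng], {}) == {}
```

Under this policy the gradient graph above would yield an empty update set for the unused RNG instead of the `ValueError`.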