Skip to content

Commit 2de0c23

Browse files
committed
added print_value function
1 parent 627a8dd commit 2de0c23

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

pymc/model/core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pytensor.tensor.random.type import RandomType
4747
from pytensor.tensor.sharedvar import ScalarSharedVariable
4848
from pytensor.tensor.variable import TensorConstant, TensorVariable
49+
from pytensor.printing import Print
4950
from typing_extensions import Self
5051

5152
from pymc.blocking import DictToArrayBijection, RaveledVars
@@ -2245,3 +2246,11 @@ def normal_logp(value, mu, sigma):
22452246
)
22462247

22472248
return var
2249+
2250+
def print_value(var, name=None):
2251+
"""Print value of variable when it is computed during sampling.
2252+
This is likely to affect sampling performance.
2253+
"""
2254+
if name is None:
2255+
name = var.name
2256+
return Print(name)(var)

0 commit comments

Comments
 (0)