Skip to content

Commit b41805c

Browse files
committed
added print_value function
1 parent 2de0c23 commit b41805c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pymc/model/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@
4040
from pytensor.compile import DeepCopyOp, get_mode
4141
from pytensor.compile.sharedvalue import SharedVariable
4242
from pytensor.graph.basic import Constant, Variable, graph_inputs
43+
from pytensor.printing import Print
4344
from pytensor.scalar import Cast
4445
from pytensor.tensor.elemwise import Elemwise
4546
from pytensor.tensor.random.op import RandomVariable
4647
from pytensor.tensor.random.type import RandomType
4748
from pytensor.tensor.sharedvar import ScalarSharedVariable
4849
from pytensor.tensor.variable import TensorConstant, TensorVariable
49-
from pytensor.printing import Print
5050
from typing_extensions import Self
5151

5252
from pymc.blocking import DictToArrayBijection, RaveledVars
@@ -2247,10 +2247,11 @@ def normal_logp(value, mu, sigma):
22472247

22482248
return var
22492249

2250+
22502251
def print_value(var, name=None):
22512252
"""Print value of variable when it is computed during sampling.
22522253
This is likely to affect sampling performance.
22532254
"""
22542255
if name is None:
22552256
name = var.name
2256-
return Print(name)(var)
2257+
return Print(name)(var)

0 commit comments

Comments
 (0)