22
22
from typing import Callable , Optional , Sequence , Tuple , Union
23
23
24
24
import numpy as np
25
+ import opcode
25
26
26
27
from aesara import tensor as at
27
28
from aesara .compile .builders import OpFromGraph
@@ -164,6 +165,45 @@ def fn(*args, **kwargs):
164
165
return fn
165
166
166
167
168
+ # Helper function from pyprob
169
+ def _extract_target_of_assignment (depth ):
170
+ frame = sys ._getframe (depth )
171
+ code = frame .f_code
172
+ next_instruction = code .co_code [frame .f_lasti + 2 ]
173
+ instruction_arg = code .co_code [frame .f_lasti + 3 ]
174
+ instruction_name = opcode .opname [next_instruction ]
175
+ if instruction_name == "STORE_FAST" :
176
+ return code .co_varnames [instruction_arg ]
177
+ elif instruction_name in ["STORE_NAME" , "STORE_GLOBAL" ]:
178
+ return code .co_names [instruction_arg ]
179
+ elif (
180
+ instruction_name in ["LOAD_FAST" , "LOAD_NAME" , "LOAD_GLOBAL" ]
181
+ and opcode .opname [code .co_code [frame .f_lasti + 4 ]] in ["LOAD_CONST" , "LOAD_FAST" ]
182
+ and opcode .opname [code .co_code [frame .f_lasti + 6 ]] == "STORE_SUBSCR"
183
+ ):
184
+ if instruction_name == "LOAD_FAST" :
185
+ base_name = code .co_varnames [instruction_arg ]
186
+ else :
187
+ base_name = code .co_names [instruction_arg ]
188
+
189
+ second_instruction = opcode .opname [code .co_code [frame .f_lasti + 4 ]]
190
+ second_arg = code .co_code [frame .f_lasti + 5 ]
191
+ if second_instruction == "LOAD_CONST" :
192
+ value = code .co_consts [second_arg ]
193
+ elif second_instruction == "LOAD_FAST" :
194
+ var_name = code .co_varnames [second_arg ]
195
+ value = frame .f_locals [var_name ]
196
+ else :
197
+ value = None
198
+ if value is not None :
199
+ index_name = repr (value )
200
+ return base_name + "[" + index_name + "]"
201
+ else :
202
+ return None
203
+ else :
204
+ return None
205
+
206
+
167
207
class SymbolicRandomVariable (OpFromGraph ):
168
208
"""Symbolic Random Variable
169
209
@@ -216,7 +256,6 @@ class Distribution(metaclass=DistributionMeta):
216
256
217
257
def __new__ (
218
258
cls ,
219
- name : str ,
220
259
* args ,
221
260
rng = None ,
222
261
dims : Optional [Dims ] = None ,
@@ -234,8 +273,6 @@ def __new__(
234
273
----------
235
274
cls : type
236
275
A PyMC distribution.
237
- name : str
238
- Name for the new model variable.
239
276
rng : optional
240
277
Random number generator to use with the RandomVariable.
241
278
dims : tuple, optional
@@ -277,6 +314,19 @@ def __new__(
277
314
"for a standalone distribution."
278
315
)
279
316
317
+ if "name" in kwargs :
318
+ name = kwargs .pop ("name" )
319
+ elif len (args ) > 0 and isinstance (args [0 ], string_types ):
320
+ name = args [0 ]
321
+ args = args [1 :]
322
+ else :
323
+ name = _extract_target_of_assignment (2 )
324
+ if name is None :
325
+ raise TypeError ("Name could not be inferred for variable" )
326
+
327
+ if not isinstance (name , string_types ):
328
+ raise TypeError (f"Name needs to be a string but got: { name } " )
329
+
280
330
if "testval" in kwargs :
281
331
initval = kwargs .pop ("testval" )
282
332
warnings .warn (
@@ -285,9 +335,6 @@ def __new__(
285
335
stacklevel = 2 ,
286
336
)
287
337
288
- if not isinstance (name , string_types ):
289
- raise TypeError (f"Name needs to be a string but got: { name } " )
290
-
291
338
dims = convert_dims (dims )
292
339
if observed is not None :
293
340
observed = convert_observed_data (observed )
0 commit comments