
Handle more AdvancedSetSubtensor ops in numba #772

Closed
@aseyboldt

Description


This model runs partially in object mode when compiled with mode="NUMBA":

import pymc as pm
import numpy as np
import pytensor

with pm.Model() as model:
    values = np.array([np.nan, 0.])
    x = pm.Normal("x", observed=values)
    pm.Normal("y", mu=x, observed=np.ones(2))

logp = model.logp()
func = pm.compile_pymc(model.value_vars, logp, mode="NUMBA")
pytensor.dprint(func)
Composite{(i2 + (i1 * sqr(i0)) + i3)} [id A] '__logp' 7
 ├─ DropDims{axis=0} [id B] 6
 │  └─ x_unobserved [id C]
 ├─ -0.5 [id D]
 ├─ -1.8378770664093453 [id E]
 └─ Sum{axes=None} [id F] 5
    └─ SpecifyShape [id G] 'sigma > 0' 4
       ├─ Composite{((-0.5 * sqr((1.0 - i0))) - 0.9189385332046727)} [id H] 3
       │  └─ AdvancedSetSubtensor [id I] 2
       │     ├─ AdvancedSetSubtensor [id J] 1
       │     │  ├─ AllocEmpty{dtype='float64'} [id K] 0
       │     │  │  └─ 2 [id L]
       │     │  ├─ x_unobserved [id C]
       │     │  └─ [ True False] [id M]
       │     ├─ [0.] [id N]
       │     └─ [False  True] [id O]
       └─ 2 [id P]
...
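For readers less familiar with PyTensor graphs, the two chained AdvancedSetSubtensor nodes above (ids J and I) fill a freshly allocated length-2 buffer through two boolean masks: first the imputed value goes where the data was NaN, then the observed 0. goes into the other slot. A rough NumPy equivalent is sketched below. The value 0.3 standing in for x_unobserved is arbitrary and only for illustration.

```python
import numpy as np

# Illustrative stand-in for the imputed (unobserved) entry; any float works.
x_unobserved = np.array([0.3])

buf = np.empty(2)                              # AllocEmpty{dtype='float64'}(2)
buf[np.array([True, False])] = x_unobserved    # inner AdvancedSetSubtensor [id J]
buf[np.array([False, True])] = np.array([0.])  # outer AdvancedSetSubtensor [id I]

print(buf)
```

The result is the full "x" vector (imputed value plus observed value) that the downstream Composite uses as mu for y.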

The AdvancedSetSubtensor nodes currently fall back to object mode in numba, but I think we should be able to implement them without too much trouble.
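As a rough illustration of why this case looks tractable, here is a minimal sketch of the one pattern in this graph (1-D input, a single boolean mask, and a value that either has length 1 or one entry per True) written as a plain loop that numba's nopython mode can compile. This is not PyTensor's numba dispatch for AdvancedSetSubtensor; the function name set_boolean_mask is hypothetical, and the sketch falls back to plain Python if numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # run as plain Python if numba is unavailable
    def njit(func):
        return func


@njit
def set_boolean_mask(x, y, mask):
    """Hypothetical helper: return a copy of 1-D x with x[mask] = y.

    y must have length 1 (broadcast to every True entry) or exactly
    one element per True entry in mask. The input x is not modified.
    """
    if mask.shape[0] != x.shape[0]:
        raise ValueError("boolean mask length does not match x")
    # Count the positions that will be written.
    n_true = 0
    for i in range(mask.shape[0]):
        if mask[i]:
            n_true += 1
    if y.shape[0] != 1 and y.shape[0] != n_true:
        raise ValueError("shape mismatch between value and boolean mask")
    out = x.copy()
    j = 0  # position in y for the non-broadcast case
    for i in range(mask.shape[0]):
        if mask[i]:
            out[i] = y[0] if y.shape[0] == 1 else y[j]
            j += 1
    return out
```

Chaining two calls, set_boolean_mask(set_boolean_mask(buf, x_unobserved, mask_nan), observed, mask_obs), mirrors the pair of nodes in the graph. A real implementation would also need to handle multiple dimensions, integer index arrays, and the inplace variant of the op.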
