
Update example in "Adding Jax support for Ops" #654

Closed
@HarshvirSandhu

Description


Issue with current documentation:

While reproducing the Eye example from this guide, I got the following error when calling test_jax_Eye():

AssertionError: (Eye{dtype='float64'}(<Scalar(float64, shape=())>, <Scalar(float64, shape=())>, 0), 'float64')

Complete code to reproduce the error:

import jax.numpy as jnp

from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt

@jax_funcify.register(Eye)
def jax_funcify_Eye(op):

    dtype = op.dtype

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""

    # pt.scalar() uses dtype floatX (float64 here), so the size N is a float input
    x_at = pt.scalar()
    eye_var = pt.eye(x_at)

    out_fg = FunctionGraph(outputs=[eye_var])

    compare_jax_and_py(out_fg, [3])


test_jax_Eye()  # raises the AssertionError shown above
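The error message shows that Eye received Scalar(float64) inputs, which is what pt.scalar() produces. The sketch below is a PyTensor-free illustration in plain JAX. It does not diagnose where the AssertionError itself is raised. It assumes the inner eye function from jax_funcify_Eye is eventually called by JAX, and shows that jnp.eye rejects two kinds of size: a float, and an integer passed as a traced jax.jit argument. JAX needs array shapes to be concrete at trace time, which is why the second case fails. The helpers eye_like_dispatch and fails_with are hypothetical. The dtype is fixed to float32 because JAX disables float64 unless jax_enable_x64 is set.

```python
import jax
import jax.numpy as jnp


def eye_like_dispatch(N, M, k):
    # Hypothetical stand-in for the inner `eye` returned by jax_funcify_Eye;
    # dtype is fixed to float32 because JAX disables float64 by default.
    return jnp.eye(N, M, k, dtype=jnp.float32)


def fails_with(fn, *args):
    """Return the exception type raised by fn(*args), or None if it succeeds."""
    try:
        fn(*args)
    except Exception as err:  # report whatever JAX raises
        return type(err)
    return None


# 1. A float size, like the float64 scalar from pt.scalar(), is not a valid shape.
float_error = fails_with(eye_like_dispatch, 3.0, 3.0, 0)

# 2. An integer size that is a traced jit argument is not concrete either.
traced_error = fails_with(jax.jit(eye_like_dispatch), 3, 3, 0)

# Concrete Python integers work outside jit.
ok_shape = eye_like_dispatch(3, 3, 0).shape

print(float_error, traced_error, ok_shape)
```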

Idea or request for content:

Instead of passing the symbolic float scalar x_at to pt.eye, the example can pass a plain integer, so the graph has no inputs at all:

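import pytensor.tensor.basic as ptb
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py

def test_jax_eye():
    """Tests jaxification of the Eye operator"""
    out = ptb.eye(3)
    out_fg = FunctionGraph([], [out])
    compare_jax_and_py(out_fg, [])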
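Passing 3 directly makes the size a constant of the graph rather than a runtime input. Below is a rough plain-JAX analogue. It assumes, without confirming this from PyTensor's source, that the JAX backend hands constants to jnp.eye as concrete values. The sketch closes over the sizes so that jax.jit traces no shape arguments. It then checks the result against NumPy, loosely in the spirit of the comparison compare_jax_and_py performs. make_constant_eye is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_constant_eye(N, M, k):
    """Close over concrete sizes so jax.jit traces no shape arguments.

    Hypothetical helper: it assumes, without confirming PyTensor's internals,
    that a constant size such as ptb.eye(3) reaches jnp.eye as a concrete value
    rather than as a traced argument.
    """

    def eye():
        return jnp.eye(N, M, k, dtype=jnp.float32)

    return jax.jit(eye)


jax_out = make_constant_eye(3, 3, 0)()
np_out = np.eye(3, 3, 0, dtype=np.float32)

# Compare the jitted JAX result with a NumPy reference.
np.testing.assert_allclose(np.asarray(jax_out), np_out)
print(jax_out.shape)
```

Because N, M and k are not arguments of the jitted function, JAX can fix the output shape at trace time.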
