
several torch binary functions don't accept scalars for one argument #271

Open
@mdhaber

Description


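Per the 2024.12 array API standard, which allows a Python scalar in place of either argument (but not both) of these binary elementwise functions, the following calls should work, but don't: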

```python
from array_api_compat import torch as xp

xp.minimum(xp.asarray(1), 2)  # TypeError: minimum(): argument 'other' (position 2) must be Tensor, not int
xp.minimum(2, xp.asarray(1))  # TypeError: minimum(): argument 'input' (position 1) must be Tensor, not int
xp.copysign(xp.asarray(1.), 2)  # OK
xp.copysign(2, xp.asarray(1.))
# TypeError: copysign() received an invalid combination of arguments - got (int, Tensor), but expected one of:
#  * (Tensor input, Tensor other, *, Tensor out = None)
#  * (Tensor input, Number other, *, Tensor out = None)
```
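Until this is fixed, one possible workaround is to convert the scalar to an array before calling the function. The wrapper below is a hypothetical, library-agnostic sketch (`accept_one_scalar`, `StrictArray` and `strict_minimum` are illustrative names, not part of `array_api_compat` or PyTorch). It assumes the scalar should take the dtype of the array operand, which is how we read the standard's rule for mixing arrays with Python scalars; the toy `StrictArray` only mimics the "both arguments must be arrays" failure shown above.

```python
from typing import Any, Callable

# Python scalar types that the 2024.12 standard allows in place of one array argument.
_SCALAR_TYPES = (bool, int, float, complex)


def accept_one_scalar(
    func: Callable[[Any, Any], Any],
    is_array: Callable[[Any], bool],
    to_array_like: Callable[[Any, Any], Any],
) -> Callable[[Any, Any], Any]:
    """Hypothetical helper: let a two-array function accept one Python scalar.

    If exactly one argument is a Python scalar and the other is an array,
    the scalar is converted with ``to_array_like(scalar, array_operand)``
    (so it can take the array operand's dtype/device) before calling ``func``.
    """

    def wrapper(x1: Any, x2: Any) -> Any:
        x1_is_scalar = isinstance(x1, _SCALAR_TYPES)
        x2_is_scalar = isinstance(x2, _SCALAR_TYPES)
        if x1_is_scalar and x2_is_scalar:
            raise TypeError("at least one argument must be an array")
        if x1_is_scalar and is_array(x2):
            x1 = to_array_like(x1, x2)
        elif x2_is_scalar and is_array(x1):
            x2 = to_array_like(x2, x1)
        return func(x1, x2)

    return wrapper


# Toy demonstration (hypothetical stand-ins, not a real library).
class StrictArray:
    """Minimal 0-D 'array' whose dtype is a Python type such as int or float."""

    def __init__(self, value: Any, dtype: type) -> None:
        self.value = value
        self.dtype = dtype


def strict_minimum(a: Any, b: Any) -> StrictArray:
    # Mimics the reported failure: both operands must already be arrays.
    if not (isinstance(a, StrictArray) and isinstance(b, StrictArray)):
        raise TypeError("minimum(): both arguments must be arrays")
    return StrictArray(min(a.value, b.value), a.dtype)


minimum = accept_one_scalar(
    strict_minimum,
    is_array=lambda x: isinstance(x, StrictArray),
    # Cast the scalar to the array operand's dtype.
    to_array_like=lambda s, arr: StrictArray(arr.dtype(s), arr.dtype),
)

print(minimum(StrictArray(1, int), 2).value)  # scalar as x2 → 1
print(minimum(2, StrictArray(1, int)).value)  # scalar as x1 → 1
```

For PyTorch, one might pass `torch.is_tensor` as `is_array` and `lambda s, t: torch.as_tensor(s, dtype=t.dtype, device=t.device)` as `to_array_like`; that combination is untested here. Note that the sketch does not handle dtype-kind mismatches (for example a float scalar with an integer array, which the toy silently truncates).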

`maximum` appears to have the same problem. I haven't checked the other functions that now accept scalars, but I suspect some of them fail too.
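The remaining functions could be checked systematically by calling each one with the scalar first in the `x1` position and then in the `x2` position. The helper below is a hypothetical sketch (`probe_scalar_support` is not an existing API); it treats a `TypeError`, the error seen in the report, as a failure and omits functions that accept both placements.

```python
from typing import Any, Callable, Dict, List


def probe_scalar_support(
    funcs: Dict[str, Callable[[Any, Any], Any]],
    array: Any,
    scalar: Any = 2,
) -> Dict[str, List[str]]:
    """Hypothetical helper: report where each binary function rejects a scalar.

    Each function is called twice, with the scalar as x1 and then as x2.
    A TypeError marks that position as failing; functions that accept
    both placements are left out of the returned dict.
    """
    failures: Dict[str, List[str]] = {}
    for name, func in funcs.items():
        failed: List[str] = []
        for position, args in (("x1", (scalar, array)), ("x2", (array, scalar))):
            try:
                func(*args)
            except TypeError:
                failed.append(position)
        if failed:
            failures[name] = failed
    return failures
```

With `array_api_compat` installed, it might be called as `probe_scalar_support({"minimum": xp.minimum, "maximum": xp.maximum, "copysign": xp.copysign}, xp.asarray(1.))`, extending the dict with the other binary functions the standard lists as accepting scalars.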
