ENH: add new function one_hot #306

Open: NeilGirdhar wants to merge 3 commits into main from the onehot branch.

Conversation

@NeilGirdhar (Contributor) commented May 29, 2025

Fixes #305

Questions:

  • What should the default output dtype be? Currently, it's the default float type on JAX/PyTorch, or else float64. Should I try to get the default float type for the namespace, i.e. xp.empty(()).dtype?

@lucascolley changed the title from "Add one_hot" to "ENH: add new function one_hot" on May 29, 2025
@lucascolley added the enhancement and new function labels on May 29, 2025
@lucascolley added this to the 0.8.0 milestone on May 29, 2025
@NeilGirdhar force-pushed the onehot branch 3 times, most recently from 9b5f393 to 02544de, on May 30, 2025
@lucascolley (Member) commented:

RE dtype, I think data-apis/array-api#848 will give us something slightly cleaner down the line.

In SciPy we have been using https://github.com/scipy/scipy/blob/main/scipy/_lib/_array_api.py#L399 with force_floating=True. Taking xp.empty(()).dtype is probably fine for now though.

@NeilGirdhar (Contributor, Author) commented:

Yeah, I'm not 100% sure how those links will help in this case? We're not casting one thing to another kind, or promoting to float. We just want the default float dtype irrespective of what was passed in. You may want to consider giving names to:

xp.asarray(1j).dtype  # Default complex
xp.asarray(1).dtype  # Default int
xp.empty(()).dtype  # Default float

I'm not sure though, and these are one-liners.
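
For illustration, here is a hypothetical helper (not part of this PR) that bundles those one-liners under names; the function name and dictionary keys are placeholders:

def default_dtypes(xp):
    # Probe the namespace's default dtypes via cheap array creations.
    return {
        "complex": xp.asarray(1j).dtype,  # default complex
        "int": xp.asarray(1).dtype,       # default int
        "float": xp.empty(()).dtype,      # default float
    }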

@lucascolley (Member) commented:

I'm not 100% sure how those links will help in this case?

the point is that instead of writing

dtype = xp.empty(()).dtype
x = xp.zeros(..., dtype=dtype)

we would have

x = xp.zeros(...)
x = xp.astype(x, 'real floating')
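
For comparison, here is a rough sketch of how that kind-based cast could be approximated with today's standard API (xp.isdtype plus the default-float trick); this is only an illustration, not the astype overload proposed in data-apis/array-api#848:

def astype_default_float(xp, x):
    # Leave x alone if it already has a real floating dtype; otherwise cast it
    # to the namespace's default real floating dtype.
    if xp.isdtype(x.dtype, "real floating"):
        return x
    return xp.astype(x, xp.empty(()).dtype)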

@NeilGirdhar (Contributor, Author) commented May 30, 2025

Ah, right. Okay. I guess you mean askind, right?

Also, could you help me solve the Dask errors? This is all foreign to me.

And how do I make an array with a non-concrete size? (x.size=None)

@lucascolley (Member) commented:

I guess you mean askind, right?

Nope, data-apis/array-api#848 suggests overloading the dtype parameter of astype for this. Feel free to comment over there if you disagree!

@NeilGirdhar (Contributor, Author) commented May 30, 2025

suggests overloading the dtype parameter of astype for this. Feel free to comment over there if you disagree!

Oh, no worries, I don't have a strong opinion. Just trying to keep up with all the planned changes 😄

Did you see my edits? I could use some guidance with the Dask errors.

@crusaderky (Contributor) commented:

  • What should the default output dtype be? Currently, it's the default float type on JAX/PyTorch, or else float64. Should I try to get the default float type for the namespace, i.e. xp.empty(()).dtype?

to me this doesn't make much sense. Why shouldn't it be bool?

@NeilGirdhar (Contributor, Author) commented May 30, 2025

  • What should the default output dtype be? Currently, it's the default float type on JAX/PyTorch, or else float64. Should I try to get the default float type for the namespace, i.e. xp.empty(()).dtype?

to me this doesn't make much sense. Why shouldn't it be bool?

I'd rather follow what the libraries are doing than force double conversion for delegated code. If you do that, most people would end up having to write their own one_hot method in order to avoid it.

In general, the reason it's not bool is because these values often serve as the inputs to machine learning algorithms.
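
As a concrete illustration of that choice, here is a minimal one-hot sketch, not the PR's implementation (which also handles delegation, devices, and backend quirks), defaulting to the namespace's default float dtype:

def one_hot_sketch(xp, x, num_classes, *, dtype=None):
    # x holds integer class indices; the result has shape (*x.shape, num_classes).
    if dtype is None:
        dtype = xp.empty(()).dtype  # the namespace's default real floating dtype
    classes = xp.arange(num_classes, dtype=x.dtype)
    # Broadcasting the comparison builds a boolean one-hot mask, which is then
    # cast to the requested floating dtype.
    return xp.astype(x[..., None] == classes, dtype)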

@NeilGirdhar (Contributor, Author) commented May 31, 2025

@lucascolley Ready for your review

Comment on lines +399 to +406
if supports_fancy_indexing:
    out = at(out)[xp.arange(x_size), x_flattened].set(1)
else:
    for i in range(x_size):
        x_i = x_flattened[i]
        if not supports_array_indexing:
            x_i = int(x_i)
        out = at(out)[i, x_i].set(1)
A Member commented:

@crusaderky have we encountered patterns like this before?

Comment on lines +456 to +458
@pytest.mark.skip_xp_backend(
    Backend.DASK, reason="backend does not yet support indexed assignment"
)
A Member commented:

Is this consistent with our skips elsewhere for the same reason @crusaderky ? (I'm on mobile, not super easy to check)

Labels: enhancement, new function
Projects: none yet
Development: successfully merging this pull request may close issue #305 (ENH: new function one_hot)
3 participants