
Update ptxla training #9864


Merged
20 commits merged into huggingface:main on Dec 6, 2024

Conversation

@entrpn (Contributor) commented Nov 4, 2024

  • Updates TPU benchmark numbers.
  • Updates the ptxla training example code.
  • Adds flash attention to ptxla code running on TPUs.

@sayakpaul, can you please review? This new PR supersedes the other one I opened a while back, which I have now closed. Thank you.


@sayakpaul (Member)

Cc: @yiyixuxu could you review the changes made to attention_processor.py?

@yiyixuxu (Collaborator) commented Nov 5, 2024

@entrpn can you use a custom attention instead? (without updating our default attention processor)

@zpcore (Contributor) commented Nov 5, 2024

@entrpn can you use a custom attention instead? (without updating our default attention processor)

Hi @yiyixuxu, we wrapped the flash attention kernel call behind an if XLA_AVAILABLE condition, so this shouldn't touch the default attention processor behavior. Can you give more details about using a custom attention? Thanks
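
(For context, a minimal sketch of the guard pattern being described here, based on the torch_xla import shown later in this thread; it is not the exact PR diff.)

import torch.nn.functional as F

from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention

XLA_AVAILABLE = is_torch_xla_available()


def attention(query, key, value):
    # Use the pallas flash attention kernel only when torch_xla is present;
    # otherwise fall back to PyTorch's SDPA.
    if XLA_AVAILABLE:
        return flash_attention(query, key, value, causal=False)
    return F.scaled_dot_product_attention(query, key, value)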

@yiyixuxu (Collaborator) commented Nov 5, 2024

I'm just wondering if it makes sense for flash attention to have its own attention processor, since this one is meant for SDPA.

cc @DN6 here too

@entrpn (Contributor, Author) commented Nov 5, 2024

@yiyixuxu this makes sense.

@zpcore do you think you can implement it?

@zpcore (Contributor) commented Nov 5, 2024

@yiyixuxu this makes sense.

@zpcore do you think you can implement it?

Yes, I can follow up with the code change.

@zpcore (Contributor) commented Nov 5, 2024

Hi @yiyixuxu, what about creating another AttnProcessor with flash attention, in parallel with AttnProcessor2_0? My concern is that the majority of the code will be the same as AttnProcessor2_0.

@yiyixuxu (Collaborator) commented Nov 6, 2024

@zpcore
That should not be a problem. A lot of our attention processors share the majority of the same code, e.g. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L732 and https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L2443

This way users can explicitly opt in to flash attention if they want to.
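
(For illustration, a sketch of what that explicit opt-in could look like with a dedicated processor, assuming the XLAFlashAttnProcessor2_0 name discussed below and the existing set_attn_processor API; the model id is a placeholder.)

import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

# Placeholder model id; any pipeline whose UNet uses the standard attention processors works the same way.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.bfloat16)

# Explicitly switch every attention layer in the UNet to the XLA flash attention processor.
pipe.unet.set_attn_processor(XLAFlashAttnProcessor2_0())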

@miladm commented Nov 6, 2024

@yiyixuxu - to better understand: why does wrapping the flash attention kernel call under the XLA_AVAILABLE condition cause trouble? Do you want this functionality to be more generalized?

@yiyixuxu (Collaborator) commented Nov 6, 2024

Is it not possible that XLA is available but the user does not want to use flash attention?
Our attention processors are designed to be very easy to switch, and each one corresponds to a very specific method: it could be xformers, SDPA, or even a special method like fused attention, which has its own processor.

@sayakpaul (Member)

@miladm @zpcore a gentle ping

@zpcore (Contributor) commented Nov 28, 2024

Thanks for the review feedback. We split the XLA flash attention processor out of AttnProcessor2_0 as requested in the review. PTAL

@sayakpaul sayakpaul requested a review from yiyixuxu November 29, 2024 02:02
@sayakpaul (Member) left a comment

Thanks for working on this and for being patient with our feedback.

I left a few minor comments.

The other reviewer, @yiyixuxu, will review this soon. Please allow for some time because of Thanksgiving week.

Comment on lines 2787 to 2789
if len(args) > 0 or kwargs.get("scale", None) is not None:
    deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
    deprecate("scale", "1.0.0", deprecation_message)
Member

Since this is a new attention processor, I think we can safely remove this.

Contributor

Removed.

@@ -2750,6 +2763,117 @@ def __call__(
return hidden_states


class XLAFlashAttnProcessor2_0:
Member

So, this will be automatically used when using the compatible models under an XLA environment, right?

Contributor

Yes, AttnProcessor2_0 will be replaced with XLAFlashAttnProcessor2_0 if the XLA version condition is satisfied.

Comment on lines 39 to 40
if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention
Member

Does this need to go through any version-check guards too, i.e., a minimum version known to have flash_attention?

Contributor

Introduced the version check function is_torch_xla_version in import_utils.py. Added the version check for torch_xla here.

AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
if is_torch_xla_available:
Member

Same here too. Does this need to be guarded with a version check too?

Contributor

Added the version check for torch_xla here too.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6 (Collaborator) commented Dec 2, 2024

I think @yiyixuxu's point here is valid:

is it not possible that XLA_AVAILABLE but the user does not want to use flash attention?

IMO it's better to use a similar API to xformers to enable the XLA processor.

def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:

@zpcore (Contributor) commented Dec 3, 2024

I think @yiyixuxu's point here is valid:

is it not possible that XLA_AVAILABLE but the user does not want to use flash attention?

IMO it's better to use a similar API to xformers to enable the XLA processor.

def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:

OK, now I get it! We have added functions like enable_xla_flash_attention, similar to enable_xformers_memory_efficient_attention, to give users the option to enable XLA flash attention or not. In the example we provide (train_text_to_image_xla.py), we apply the kernel to the diffusion model's UNet. Thanks!
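
(A minimal usage sketch of the opt-in helper described above, mirroring the xformers API; anything beyond the enable_xla_flash_attention name stated in this thread is an assumption, and the model id is a placeholder.)

from diffusers import UNet2DConditionModel

# Placeholder model id; load only the UNet, as in the training example.
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")

# Opt in to the XLA flash attention kernel on TPU, analogous to
# enable_xformers_memory_efficient_attention(); skip this call to keep the default SDPA processor.
unet.enable_xla_flash_attention()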

@sayakpaul sayakpaul requested a review from DN6 December 4, 2024 01:27
@yiyixuxu (Collaborator) left a comment

Thanks! I think we can merge this soon!

Comment on lines 297 to 303
if (
    use_xla_flash_attention
    and is_torch_xla_available
    and is_torch_xla_version('>', '2.2')
    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
):
    processor = XLAFlashAttnProcessor2_0(partition_spec)
Collaborator

Suggested change

- if (
-     use_xla_flash_attention
-     and is_torch_xla_available
-     and is_torch_xla_version('>', '2.2')
-     and (not is_spmd() or is_torch_xla_version('>', '2.3'))
- ):
-     processor = XLAFlashAttnProcessor2_0(partition_spec)
+ if use_xla_flash_attention:
+     if is_torch_xla_version("<", "2.3"):
+         raise ...
+     elif is_spmd() and is_torch_xla_version("<", "2.4"):
+         raise ...
+     else:
+         processor = XLAFlashAttnProcessor2_0(partition_spec)

If the user explicitly sets xla_flash_attention, we want to give a very explicit warning/error message when the conditions aren't met so they can take action accordingly - we don't want to silently switch to something else just because the right version wasn't installed.

Contributor

Sounds good, updated!

    partition_spec = self.partition_spec if is_spmd() else None
    hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
    hidden_states = F.scaled_dot_product_attention(
Collaborator

We don't need to support SDPA in this XLA flash attention processor! We can remove all the logic related to it!

Contributor

There is a constraint when using the pallas kernel: we need all(tensor.shape[2] >= 4096 for tensor in [query, key, value]), or XLA will error out.

However, we added a new error message when it falls back to scaled_dot_product_attention, to avoid silently skipping the kernel.
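
(A sketch of the fallback being described, based on the quoted snippet above; names outside that snippet, such as the helper's name and the exact warning text, are assumptions.)

import logging

import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import flash_attention

logger = logging.getLogger(__name__)


def xla_flash_or_sdpa(query, key, value, partition_spec=None):
    # The pallas kernel requires a sequence length of at least 4096 on query/key/value;
    # below that, warn and fall back to SDPA rather than silently skipping the kernel.
    if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
        return flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
    logger.warning("Sequence length below 4096; falling back to F.scaled_dot_product_attention.")
    return F.scaled_dot_product_attention(query, key, value)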

Collaborator

OK, thank you for explaining it to me!

Comment on lines 2792 to 2793
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
Collaborator

Suggested change

- if not hasattr(F, "scaled_dot_product_attention"):
-     raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

I think we don't need to support SDPA in the XLA flash attention processor! Let's remove all the logic related to that to simplify things a bit!

Contributor

Please check my comment above for why we still keep it here. Thanks

@@ -700,6 +700,19 @@ def is_torch_version(operation: str, version: str):
    return compare_versions(parse(_torch_version), operation, version)


def is_torch_xla_version(operation: str, version: str):
Collaborator

Can we make sure we can call is_torch_xla_version() when torch_xla is not installed? Currently, I think you will have to run it together with is_torch_xla_available(), because _torch_xla_version is not defined otherwise.

We can do it like this, following the existing pattern:

_torch_version = "N/A"
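
(A sketch of that pattern, assuming the same approach used for _torch_version in import_utils.py; the real helper may differ in details, e.g. by reusing compare_versions.)

from packaging.version import parse

try:
    import torch_xla

    _torch_xla_version = torch_xla.__version__
except ImportError:
    # Mirror the _torch_version = "N/A" fallback so the helper is safe to call
    # even when torch_xla is not installed.
    _torch_xla_version = "N/A"


def is_torch_xla_version(operation: str, version: str) -> bool:
    import operator

    ops = {"<": operator.lt, "<=": operator.le, "==": operator.eq, "!=": operator.ne, ">": operator.gt, ">=": operator.ge}
    if _torch_xla_version == "N/A":
        # torch_xla is not installed, so no version comparison can succeed.
        return False
    return ops[operation](parse(_torch_xla_version), parse(version))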

Contributor

Nice catch, updated.

@yiyixuxu (Collaborator) left a comment

thank you!

@yiyixuxu yiyixuxu merged commit 3cb7b86 into huggingface:main Dec 6, 2024
15 checks passed
@entrpn (Contributor, Author) commented Dec 9, 2024

Thank you all!

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* update ptxla example

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Pei Zhang <zpcore@gmail.com>
Co-authored-by: Pei Zhang <piz@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pei Zhang <pei@Peis-MacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>