Update ptxla training #9864
Conversation
Cc: @yiyixuxu could you review the changes made to
@entrpn can you use a custom attention processor instead? (without updating our default attention processor)
Hi @yiyixuxu, we wrapped the flash attention kernel call under condition
I'm just wondering if it makes sense for Flash Attention to have its own attention processor, since this one is meant for SDPA. cc @DN6 here too
Hi @yiyixuxu, what if we create another AttnProcessor with flash attention, in parallel with
@zpcore this way the user can explicitly choose to use flash attention if they want to
@yiyixuxu - can you please help me understand why wrapping the flash attention kernel call under condition
is it not possible that XLA is available but the user does not want to use flash attention?
Thanks for the review feedback. We split the XLA flash attention processor out of AttnProcessor2_0 as requested in the review. PTAL
Thanks for working on this and for being patient with our feedback.
I left a few minor comments.
The other reviewer, @yiyixuxu, will review this soon. Please allow for some time because of the Thanksgiving week.
if len(args) > 0 or kwargs.get("scale", None) is not None:
    deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
    deprecate("scale", "1.0.0", deprecation_message)
Since this is a new attention processor, I think we can safely remove this.
Removed.
@@ -2750,6 +2763,117 @@ def __call__(
        return hidden_states


class XLAFlashAttnProcessor2_0:
So, this will be automatically used when using the compatible models under an XLA environment, right?
Yes, AttnProcessor2_0 will be replaced with XLAFlashAttnProcessor2_0 if the XLA version condition is satisfied.
if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention
Does this need to go through any version check guards too, i.e., a minimum version known to have flash_attention?
Introduced the version check function is_torch_xla_version in import_utils.py and added the version check for torch_xla here.
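For reference, a minimal sketch of what such a helper could look like in import_utils.py, mirroring the is_torch_version pattern quoted further down in this thread (compare_versions, parse, is_torch_xla_available, and _torch_xla_version are the existing module-level helpers assumed here; the merged implementation may differ, e.g. it could raise instead of returning False when torch_xla is missing):

```python
def is_torch_xla_version(operation: str, version: str):
    """
    Compares the installed torch_xla version against `version` using `operation`,
    e.g. is_torch_xla_version(">", "2.2"). Mirrors is_torch_version above.
    """
    if not is_torch_xla_available():
        # torch_xla is not installed, so no version can satisfy the check.
        return False
    return compare_versions(parse(_torch_xla_version), operation, version)
```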
    AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
    if is_torch_xla_available:
Same here. Does this need to be guarded with a version check too?
Added the version check for torch_xla here too.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I think @yiyixuxu's point here is valid:
IMO it's better to use a similar API to xformers to enable the XLA processor.
OK, now I get it! We have added functions like
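For illustration, opting in with an xformers-style call might look like the sketch below. The method name enable_xla_flash_attention and the partition_spec value are assumptions based on this discussion, mirroring enable_xformers_memory_efficient_attention; check the final PR for the exact API.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-base", torch_dtype=torch.bfloat16
)

# Explicit opt-in, analogous to pipe.enable_xformers_memory_efficient_attention().
# Method name and partition_spec argument are assumed here for the sketch.
pipe.unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
```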
thanks! I think we can merge this soon!
if (
    use_xla_flash_attention
    and is_torch_xla_available
    and is_torch_xla_version('>', '2.2')
    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
):
    processor = XLAFlashAttnProcessor2_0(partition_spec)
Suggested change:
if use_xla_flash_attention:
    if is_torch_xla_version("<", "2.3"):
        raise ...
    elif is_spmd() and is_torch_xla_version("<", "2.4"):
        raise ...
    else:
        processor = XLAFlashAttnProcessor2_0(partition_spec)
If the user explicitly sets xla_flash_attention, we want to give a very explicit warning/error message when the condition isn't met, so they can take action accordingly - we don't want to silently switch to something else just because it wasn't installed.
Sounds good, updated!
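Spelled out, the suggested explicit checks could look roughly like this inside the setter that picks the processor. The error messages below are illustrative placeholders, not the PR's exact wording, and the not-installed check is added here for completeness.

```python
if use_xla_flash_attention:
    if not is_torch_xla_available():
        raise ValueError("torch_xla is not installed; cannot use XLA flash attention.")
    elif is_torch_xla_version("<", "2.3"):
        # The Pallas flash attention kernel needs a sufficiently recent torch_xla.
        raise ValueError("XLA flash attention requires torch_xla version 2.3 or newer.")
    elif is_spmd() and is_torch_xla_version("<", "2.4"):
        # SPMD-sharded flash attention needs an even newer torch_xla.
        raise ValueError("XLA flash attention under SPMD requires torch_xla version 2.4 or newer.")
    else:
        processor = XLAFlashAttnProcessor2_0(partition_spec)
```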
    partition_spec = self.partition_spec if is_spmd() else None
    hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
    hidden_states = F.scaled_dot_product_attention(
we don't need to support SDPA in this XLA flash attention processor! - we can remove all the logic related to it!
There is a constraint when using the Pallas kernel: we need all(tensor.shape[2] >= 4096 for tensor in [query, key, value]), or XLA will error out. However, we added a new message when it falls back to scaled_dot_product_attention, to avoid silently skipping the kernel.
ok thank you for explaining to me!
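To make the constraint concrete, the relevant part of the processor roughly does the following. This is a sketch pieced together from the snippets quoted in this thread, not a verbatim copy of the PR; the fallback message wording and the use of a module logger are assumptions.

```python
# Inside XLAFlashAttnProcessor2_0.__call__, after query/key/value have been
# projected and reshaped to (batch, heads, seq_len, head_dim):
if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
    # Pallas flash attention kernel; partition_spec only matters under SPMD.
    partition_spec = self.partition_spec if is_spmd() else None
    hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
    # The Pallas kernel requires sequence lengths of at least 4096, so fall back
    # to SDPA and say so explicitly rather than silently skipping the kernel.
    logger.warning(
        "Sequence length is shorter than 4096; falling back to "
        "F.scaled_dot_product_attention instead of the XLA flash attention kernel."
    )
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
```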
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
I think we don't need to support SDPA in the XLA flash attention processor! Let's remove all the logic related to that to simplify things a bit!
Please check my comment above for why we still keep it here. Thanks
@@ -700,6 +700,19 @@ def is_torch_version(operation: str, version: str):
    return compare_versions(parse(_torch_version), operation, version)


def is_torch_xla_version(operation: str, version: str):
can we make sure we can call is_torch_xla_version() when torch_xla is not installed? currently, I think you have to run it together with is_torch_xla_available(), because _torch_xla_version is not defined otherwise.
we can do it like this:
_torch_version = "N/A"
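Concretely, the guard being suggested could look something like this near the top of import_utils.py. This is a sketch following the existing `_torch_version = "N/A"` pattern; the merged code may detect the package differently.

```python
import importlib.metadata
import importlib.util

# Defined unconditionally so is_torch_xla_version() is safe to call
# even when torch_xla is not installed.
_torch_xla_version = "N/A"
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None

if _torch_xla_available:
    try:
        _torch_xla_version = importlib.metadata.version("torch_xla")
    except importlib.metadata.PackageNotFoundError:
        # The spec exists but the distribution metadata is missing.
        _torch_xla_available = False
```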
Nice catch, updated.
thank you!
Thank you all!
* update ptxla example

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Pei Zhang <zpcore@gmail.com>
Co-authored-by: Pei Zhang <piz@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pei Zhang <pei@Peis-MacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
@sayakpaul can you please review? This new PR supersedes the other one I had opened a while back, which I just closed. Thank you.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.