From 7992b1679d7365dcc3bd9c01201b02ce32256fa0 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 1 May 2025 14:36:32 -0700 Subject: [PATCH] Make key optional in ipex.llm.functional.rotary_embedding --- .../llm/functional/fusions.py | 7 ++-- .../llm/modules/mha_fusion.py | 14 ++++--- .../models/cpu/fusions/mha_fusion.py | 40 ++++++++++++------- tests/cpu/test_ipex_llm_module.py | 12 ++++-- 4 files changed, 46 insertions(+), 27 deletions(-) diff --git a/intel_extension_for_pytorch/llm/functional/fusions.py b/intel_extension_for_pytorch/llm/functional/fusions.py index 093d013ee..cd4a23b0a 100644 --- a/intel_extension_for_pytorch/llm/functional/fusions.py +++ b/intel_extension_for_pytorch/llm/functional/fusions.py @@ -13,7 +13,7 @@ def rotary_embedding( query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor], sin: torch.Tensor, cos: torch.Tensor, rotary_dim: int, @@ -25,9 +25,10 @@ def rotary_embedding( on the `query ` or `key` before their multi-head attention computation. Args: - query, key (torch.Tensor) : inputs to be applied with position embeddings, + query (torch.Tensor), key (Optional[torch.Tensor]): inputs to be applied with position embeddings, taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape). + `key` may be `None`, e.g. in case of cross-layer KV sharing. sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key. rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama. @@ -42,7 +43,7 @@ def rotary_embedding( The according position_ids for the input. The shape should be [batch size, sequence length]. Return - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim] + query (torch.Tensor), key (Optional[torch.Tensor]): [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim]. """ diff --git a/intel_extension_for_pytorch/llm/modules/mha_fusion.py b/intel_extension_for_pytorch/llm/modules/mha_fusion.py index 2102eec38..d1de49de4 100644 --- a/intel_extension_for_pytorch/llm/modules/mha_fusion.py +++ b/intel_extension_for_pytorch/llm/modules/mha_fusion.py @@ -49,14 +49,15 @@ class RotaryEmbedding(nn.Module): [Direct function call] This module also provides a `.apply_function` function call to be used on query and key at the same time without initializing the module - (assume rotary embedding sin/cos values are provided). + (assume rotary embedding sin/cos values are provided). `key` is optional for `.apply_function` call. `apply_function()` Args: - query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of + query (torch.Tensor), key (Optional[torch.Tensor]) : inputs to be applied with position embeddings, taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape). + `key` may be None, e.g. in case of cross-layer KV sharing. sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key. rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama. head_dim (int) : head dim from the input shape. @@ -68,7 +69,7 @@ class RotaryEmbedding(nn.Module): for the input. The shape should be [batch size, sequence length]. 
Return: - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim] + query (torch.Tensor), key (Optional[torch.Tensor]): [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim]. """ @@ -137,14 +138,17 @@ def forward( def apply_function( cls, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor], sin: torch.Tensor, cos: torch.Tensor, rotary_dim: int, rotary_half: bool, position_ids: torch.Tensor = None, ): - # query, key (in/out shape) torch.Tensor : + # query: torch.Tensor with in/out shape: + # 4D: [batch, seqlen, num_head/num_kv_head, head_dim] + # 3D: [num_tokens, num_head/num_kv_head, head_dim] + # key (optional) None or torch.Tensor with in/out shape: # 4D: [batch, seqlen, num_head/num_kv_head, head_dim] # 3D: [num_tokens, num_head/num_kv_head, head_dim] # sin, cos: torch.Tensor [num_tokens, rotary_dim] diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py index 67a681a09..bc3cbc5aa 100644 --- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py +++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py @@ -65,23 +65,28 @@ def forward( def rotary_embedding( cls, query, key, sin, cos, rotary_dim, rotary_half, position_ids=None ): - # query, key (in/out shape) torch.Tensor : + # query: torch.Tensor with in/out shape: + # 4D: [bs, seqlen, num_head/num_kv_head, head_dim] + # 3D: [num_tokens, num_head/num_kv_head, head_dim] + # key (optional) None or torch.Tensor with in/out shape: # 4D: [bs, seqlen, num_head/num_kv_head, head_dim] # 3D: [num_tokens, num_head/num_kv_head, head_dim] # sin, cos: torch.Tensor [num_tokens, rotary_dim] # position_ids (optional): torch.Tensor [bs, seqlen] head_dim = query.size(-1) num_head = query.size(-2) - num_kv_head = key.size(-2) + num_kv_head = key.size(-2) if key is not None else 0 input_3d = False assert ( - key.dim() == query.dim() and query.dim() == 3 or query.dim() == 4 + (key is None or key.dim() == query.dim()) + and query.dim() == 3 + or query.dim() == 4 ), "rotary embedding query/key dim == 3 or 4" if query.dim() == 3: input_3d = True query_ = query.unsqueeze(0) - key_ = key.unsqueeze(0) + key_ = key.unsqueeze(0) if key is not None else None else: query_ = query key_ = key @@ -124,21 +129,26 @@ def rotary_embedding( rotary_dim, ) - key_, _, _ = torch.ops.torch_ipex.rotary_position_embedding( - key_, - sin_cos, - position_ids, - num_kv_head, - head_dim, - offset, - rotary_dim, - ) + if key is not None: + key_, _, _ = torch.ops.torch_ipex.rotary_position_embedding( + key_, + sin_cos, + position_ids, + num_kv_head, + head_dim, + offset, + rotary_dim, + ) if input_3d: query_ = query_.view([-1, num_head, head_dim]) - key_ = key_.view([-1, num_kv_head, head_dim]) + if key_ is not None: + key_ = key_.view([-1, num_kv_head, head_dim]) # keep the inplace context as used in TGI query.copy_(query_) - key.copy_(key_) + + if key is not None: + key.copy_(key_) + return query, key diff --git a/tests/cpu/test_ipex_llm_module.py b/tests/cpu/test_ipex_llm_module.py index c505ea30e..e60caafa1 100644 --- a/tests/cpu/test_ipex_llm_module.py +++ b/tests/cpu/test_ipex_llm_module.py @@ -884,23 +884,27 @@ def test_rotary_embedding_tgi(self): (1, 32, 128), (32, 32, 128), ] - for size in test_tensor_size: + for size, use_key in itertools.product(test_tensor_size, [True, False]): q = torch.randn(size).float() - k = 
torch.randn(size).float() + k = torch.randn(size).float() if use_key else None rotary_dim = size[-1] seqlen = size[0] position_ids = torch.arange(size[0]) sin, cos = get_sin_cos(position_ids, rotary_dim, 10000, seqlen, q.dtype) ref_q = apply(q, cos, sin) - ref_k = apply(k, cos, sin) + ref_k = apply(k, cos, sin) if use_key else None ipex_q, ipex_k = ipex.llm.functional.rotary_embedding( q, k, sin, cos, rotary_dim, True ) self.assertEqual(ipex_q, ref_q) - self.assertEqual(ref_k, ipex_k) + if use_key: + self.assertEqual(ref_k, ipex_k) + else: + self.assertIsNone(ipex_k) + self.assertIsNone(ref_k) def test_add_layernorm(self): for add_back in [True, False]:
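
Usage sketch (editorial addition, not part of the patch): the snippet below illustrates the functional API after this change, following the calling convention used in test_rotary_embedding_tgi above. The tensor shapes and the GPT-NeoX-style sin/cos construction are illustrative assumptions, not taken from the patch; real callers take sin/cos from their model's rotary-embedding cache. The fused op updates query/key in place (the TGI-style behavior noted in the CPU fusion), hence the clones below.

import torch
import intel_extension_for_pytorch as ipex

# Illustrative 3D layout [num_tokens, num_head, head_dim]; the 4D layout
# [batch, seqlen, num_head, head_dim] is accepted as well.
num_tokens, num_head, head_dim = 8, 4, 64
rotary_dim = head_dim  # full head size, LLaMA/GPT-NeoX style

query = torch.randn(num_tokens, num_head, head_dim).float()
key = torch.randn(num_tokens, num_head, head_dim).float()

# Assumed GPT-NeoX-style sin/cos of shape [num_tokens, rotary_dim].
position_ids = torch.arange(num_tokens)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = position_ids.float()[:, None] * inv_freq[None, :]
emb = torch.cat((freqs, freqs), dim=-1)
sin, cos = emb.sin(), emb.cos()

# Existing behavior: both query and key are rotated.
q1, k1 = query.clone(), key.clone()
q_rot, k_rot = ipex.llm.functional.rotary_embedding(q1, k1, sin, cos, rotary_dim, True)

# New with this patch: key may be None (e.g. cross-layer KV sharing);
# only query is rotated and None comes back in key's place.
q2 = query.clone()
q_only, k_none = ipex.llm.functional.rotary_embedding(q2, None, sin, cos, rotary_dim, True)
assert k_none is None
# The query rotation is unaffected by the missing key.
assert torch.allclose(q_only, q_rot)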