
Commit 3758029

Warmup gaudi backend (#3172)
* clean cuda/rocm code in hpu backend, enable flat_hpu
* fix TP in pageattn
* adjust block table in hpu to improve performance
* enable all the models (not tested yet)
* use tensor cache in hpu graph to avoid replay issue
* add moe support, fix qwen/mistral/mixtral crash
* fix phimoe issue
* gpt_bigcode could also go pageattn
* enable dbrx, remove some unused code
* multi-modality initial PR
* adjust warmup and enable vlm
* fix incorrect output in qwen2 idefics if hpu graph is used
* remove unused quantization code and enable awq/gptq int4
* fix gptq issue
* enable fp8
* warmup prefill: remove models where pageattn is not used, set block table to None since it's not used
* add warmup_decode
* warmup decode
* remove block_tables and prefill_cache_indices, which would lead to dynamic shapes
* fix comment
* missing gptj change
* fix some issues
* remove torch.where to fix incorrect output in hpu graph model
* LLM warmup logic
* multi-modality warmup
* optimize code
* refine log and fix some issues
* fix warmup issue for mllama
* pingpong optimization
* match the latest vllm_extension ops
* work with the latest vllm extension ops
* remove block_scales, which is not needed anymore
* improve performance
* prefill bypasses graph
* fix pingpong optimization issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

1 parent 02715dc commit 3758029
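
The thrust of the commit is to warm up the Gaudi (HPU) backend before serving, so that prefill and decode only ever run with shapes that were already compiled/captured and no graph recompilation happens on live traffic. As a rough illustration of that idea only (the bucket grid, the model call signature, and the warmup helper below are assumptions for this note, not the backend's actual code):

import itertools

import torch

# Hypothetical bucket grid; the real backend derives its buckets from the
# configured batch sizes, block size, and maximum sequence lengths.
PREFILL_BUCKETS = [(1, 128), (1, 512), (4, 128), (4, 512)]  # (batch, prompt_len)
DECODE_BUCKETS = [(1, 1), (8, 1), (32, 1)]                  # (batch, new_tokens)

def warmup(model, device: str = "hpu") -> None:
    # Run one dummy forward per bucket so every shape the server can later
    # see is captured ahead of time. The keyword arguments are illustrative.
    for batch, seq_len in itertools.chain(PREFILL_BUCKETS, DECODE_BUCKETS):
        input_ids = torch.zeros((batch, seq_len), dtype=torch.int64, device=device)
        position_ids = torch.arange(seq_len, device=device).expand(batch, -1)
        with torch.no_grad():
            model(input_ids=input_ids, position_ids=position_ids)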

File tree

7 files changed: +925, -563 lines


backends/gaudi/server/text_generation_server/layers/attention/common.py

Lines changed: 0 additions & 2 deletions
@@ -13,7 +13,6 @@ class HPUPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
     attn_bias: Optional[torch.Tensor]


@@ -66,7 +65,6 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
             "block_list",
             "block_mapping",
             "block_usage",
-            "block_scales",
             "block_groups",
             "attn_bias",
         ],
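
Taken together, the two hunks above mean the metadata passed through HPU graph capture no longer carries block_scales. A minimal sketch of the resulting structure, using only the field names visible in the diff; trim_attn_metadata_sketch is a simplified, dict-returning stand-in written for this note, not the backend's real trim helper:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class HPUPagedAttentionMetadata:
    # Paged-attention bookkeeping that survives this change; block_scales is gone.
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]

def trim_attn_metadata_sketch(metadata: HPUPagedAttentionMetadata) -> dict:
    # Simplified stand-in: keep only the tensors the HPU graph actually replays,
    # so cached graphs always see the same, stable set of inputs.
    keep = ("block_list", "block_mapping", "block_usage", "block_groups", "attn_bias")
    return {name: getattr(metadata, name) for name in keep}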

backends/gaudi/server/text_generation_server/layers/attention/hpu.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ def paged_attention(
         block_list=hpu_attention_meta.block_list,
         block_mapping=hpu_attention_meta.block_mapping,
         block_bias=hpu_attention_meta.attn_bias,
-        block_scales=hpu_attention_meta.block_scales,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
         matmul_qk_op=Matmul(),
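
For orientation, the keyword arguments above are all that the paged-attention op still receives from the metadata. A hedged sketch of that mapping; build_pa_kwargs is a hypothetical helper written for this note, and the op's name and remaining parameters are not shown in the diff:

from typing import Any, Dict

def build_pa_kwargs(hpu_attention_meta, softmax_scale: float) -> Dict[str, Any]:
    # Collect the metadata fields that still feed HPU paged attention now
    # that block_scales has been dropped from the call.
    return dict(
        block_list=hpu_attention_meta.block_list,
        block_mapping=hpu_attention_meta.block_mapping,
        block_bias=hpu_attention_meta.attn_bias,
        block_groups=hpu_attention_meta.block_groups,
        scale=softmax_scale,
    )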

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py

Lines changed: 5 additions & 53 deletions
@@ -681,11 +681,10 @@ def forward(
         # bsz, q_len, _ = hidden_states.size()
         (
             cross_attention_states,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            cross_attention_len,
             indices,
         ) = cross_attention_states
-        bs = cu_seqlen_q.size(0) - 1
+        bs = cross_attention_len.size(0)
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
         query_states = self.q_norm(query_states)
@@ -814,8 +813,6 @@ def forward(

         indices = cross_attention_states[-1]
         out_hidden_states = hidden_states[:]
-        if len(indices) > 0:
-            assert max(indices) < hidden_states.shape[0]
         hidden_states = hidden_states[indices]
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -914,59 +911,14 @@ def forward(
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
-        # XXX: Putting these as optional so that the cuda warmup calls can go through.
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        indices=None,
+        cross_attention_len: Optional[torch.Tensor] = None,
     ):
         if cross_attention_states is not None:
-            seqlen_q = len(image_indices)
-            n_images = cross_attention_states.shape[0]
-            seqlen_k = cross_attention_states.shape[1]
-            device = cross_attention_states.device
-            if cu_seqlen_prefill is not None:
-                offset = 0
-                cu_q = []
-                indices = []
-                for index in image_indices:
-                    cu_q.append(offset)
-                    length = seqlen.input_lengths[index].item()
-                    assert index < seqlen.cu_seqlen_q.shape[0]
-                    input_ids_offset = seqlen.cu_seqlen_q[index]
-                    indices.extend(range(input_ids_offset, input_ids_offset + length))
-                    offset += length
-                cu_q.append(offset)
-                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
-
-                assert max(indices) < input_ids.shape[0]
-
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-            else:
-                cu_seqlen_q = torch.arange(
-                    seqlen_q + 1, device=device, dtype=torch.int32
-                )
-                seqlen_k = cross_attention_states.shape[1]
-                n_images = cross_attention_states.shape[0]
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-                indices = image_indices[:]
-
             cross_attention_states = (
                 cross_attention_states,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                cross_attention_len,
                 indices,
             )

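
Net effect of the flash_mllama.py changes: the ragged cu_seqlen_q/cu_seqlen_k bookkeeping is gone, and the cross-attention path instead carries a per-image length tensor plus precomputed indices, which keeps shapes static for HPU graph replay. A small sketch of the new tuple contract, with made-up sizes purely for illustration:

import torch

# Hypothetical example: 2 images, 16 vision tokens each, hidden size 64.
cross_attention_states = torch.randn(2, 16, 64)
cross_attention_len = torch.tensor([16, 16], dtype=torch.int32)  # tokens per image
indices = torch.tensor([3, 40])  # rows of hidden_states that see an image

# Packed the way the model's forward now does it:
packed = (cross_attention_states, cross_attention_len, indices)

# Unpacked inside the cross-attention layer; the batch size is now simply the
# number of images instead of cu_seqlen_q.size(0) - 1:
states, lengths, idx = packed
bs = lengths.size(0)  # == 2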