diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py
index c8eb48a20b7..15e79f7c31a 100644
--- a/server/text_generation_server/adapters/lora.py
+++ b/server/text_generation_server/adapters/lora.py
@@ -11,9 +11,8 @@
 from peft import LoraConfig as _LoraConfig
 from torch.distributed import ProcessGroup
 from text_generation_server.utils.log import log_master
-
-from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
@@ -128,17 +127,27 @@ def __init__(
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

-        self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
         self._is_transposed = False

+        if SYSTEM == "ipex":
+            self._use_cutlass_shrink = False
+            # [num_layers, r, hidden_size]
+            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
+            self._weights_a = torch.stack(weights_a)
+
+            # [num_layers, hidden_size, r]
+            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
+            self._weights_b = torch.stack(weights_b)
+        else:
+            self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
+            # [num_layers, hidden_size, r]
+            weights_a = [
+                punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
+                for w in weights_a
+            ]
+            self._weights_a = torch.stack(weights_a)
-        # [num_layers, hidden_size, r]
-        weights_a = [
-            punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
-        ]
-        self._weights_a = torch.stack(weights_a)
-
-        # [num_layers, r, hidden_size]
-        self._weights_b = torch.stack(weights_b)
+            # [num_layers, r, hidden_size]
+            self._weights_b = torch.stack(weights_b)

         self.adapter_config = adapter_config

@@ -175,7 +184,10 @@ def _transpose_weights(self):

     @classmethod
     def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
-        return [BatchLoraWeights]
+        if SYSTEM == "ipex":
+            return [IPEXBatchLoraWeights]
+        else:
+            return [BatchLoraWeights]

     # prepare pre-loaded lora weights for use in the model.
     #
@@ -245,17 +257,20 @@ def prepare_weights(
             lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

         # pad lora ranks to be compatible with sgmv
-        lora_a_list = [
-            punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
-        ]
-        lora_b_list = [
-            punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
-        ]
-
-        if lora_a_list:
-            # update rank if it was padded
-            padded_rank = lora_a_list[0].size(1)
-            config.r = padded_rank
+        if SYSTEM != "ipex":
+            lora_a_list = [
+                punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
+                for w in lora_a_list
+            ]
+            lora_b_list = [
+                punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
+                for w in lora_b_list
+            ]
+
+            if lora_a_list:
+                # update rank if it was padded
+                padded_rank = lora_a_list[0].size(1)
+                config.r = padded_rank

         return LoraWeights(
             *shard_lora_weights(
@@ -471,6 +486,115 @@ def load(
         )

+
+@dataclass
+class IPEXBatchLoraWeights(BatchLoraWeights):
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = next(iter(adapter_weights.values()))
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        if len(lora_a) != 0:
+            lora_a_ptr = torch.stack(list(lora_a.values()))
+        if len(lora_b) != 0:
+            lora_b_ptr = torch.stack(list(lora_b.values()))
+
+        use_sgmv = True if prefill else False
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        segment_starts = None
+        segment_ends = None
+        if use_sgmv:
+            segment_starts = meta.adapter_segments[:-1]
+            segment_ends = meta.adapter_segments[1:]
+            if prefill_head_indices is not None:
+                segment_starts = prefill_head_segment_starts[:-1]
+                segment_ends = prefill_head_segment_ends[1:]
+        batch_indices = [
+            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+        ]
+        for rank, indices in rank_indices.items():
+            adapters_indices = []
+            lora_a_keys = list(lora_a.keys())
+            for segment_idx in batch_indices:
+                if segment_idx in indices:
+                    adapters_indices.append(
+                        lora_a_keys.index(segment_indices[segment_idx])
+                    )
+                else:
+                    adapters_indices.append(-1)
+            adapters_indices = torch.tensor(
+                adapters_indices, dtype=torch.int64, device=device
+            )
+            if use_sgmv:
+                adapters_indices = adapters_indices[segment_starts]
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=None,
+                tmp_expand=None,
+                lora_a_ptr=lora_a_ptr,
+                lora_b_ptr=lora_b_ptr,
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=adapters_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
 def get_scaling_factor(
     lora_alpha: int,
     r: int,
diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py
index abfb097ddcd..5e5a737bb7c 100644
--- a/server/text_generation_server/layers/lora.py
+++ b/server/text_generation_server/layers/lora.py
@@ -4,8 +4,8 @@
 import torch.distributed
 from torch import nn
 from torch.distributed import ProcessGroup
-
 from text_generation_server.utils.import_utils import SYSTEM
+
 from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "cuda":
@@ -121,6 +121,91 @@ def forward_layer_type(
                         self.layer_id,
                     )

+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
+        elif SYSTEM == "ipex" and data is not None:
+            from intel_extension_for_pytorch.llm.functional import (
+                bgmv_expand,
+                bgmv_shrink,
+                sgmv_expand,
+                sgmv_shrink,
+            )
+
+            # In IPEX, we provide the same API for sgmv and bgmv
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
+
+            for r, rank_segments in data.rank_data.items():
+                lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
+                lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()
+
+                if lora_a_ptr is None or lora_b_ptr is None:
+                    raise ValueError("LoRA data is missing")
+
+                if data.use_sgmv:
+                    # Use SGMV for prefill
+                    seq_len_tensor = (
+                        rank_segments.segment_ends - rank_segments.segment_starts
+                    ).to(torch.int64)
+                    b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
+                    total_tokens = seq_len_tensor.sum()
+                    v = torch.zeros(
+                        (total_tokens, r), dtype=input.dtype, device=input.device
+                    )
+                    bs = seq_len_tensor.shape[0]
+                    sgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    sgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        b_seq_start_loc,
+                        seq_len_tensor,
+                        rank_segments.indices,
+                        bs,
+                        seq_len_tensor.max().item(),
+                        add_inputs=True,
+                    )
+                else:
+                    # Use BGMV for decode
+                    v = torch.zeros(
+                        (input.size(0), r), dtype=input.dtype, device=input.device
+                    )
+                    # TODO: error with [-1, 0], but not [0, -1]
+                    bgmv_shrink(
+                        input,
+                        lora_a_ptr,
+                        v,
+                        rank_segments.indices,
+                        1.0,
+                    )
+
+                    if self.process_group.size() > 1:
+                        v = self.collect_lora_a(v)
+
+                    bgmv_expand(
+                        v,
+                        lora_b_ptr,
+                        proj,
+                        rank_segments.indices,
+                        add_inputs=True,
+                    )
+
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
         else:
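
For reference, the decode-time `bgmv_shrink`/`bgmv_expand` pair in the patch reduces, per token, to the usual LoRA update `y += scale * (x @ Aᵀ) @ Bᵀ` gathered by adapter index. The sketch below is an illustrative pure-PyTorch equivalent of that computation, not part of the patch and not the IPEX kernel API; the function name and shapes are hypothetical, assuming the kernels follow vLLM-style shrink/expand semantics and the `[num_adapters, r, hidden]` A / `[num_adapters, hidden, r]` B layout that `LoraWeights` builds for the IPEX path.

```python
import torch


def lora_bgmv_reference(
    x: torch.Tensor,        # [num_tokens, hidden_in] input activations
    lora_a: torch.Tensor,   # [num_adapters, r, hidden_in] per-adapter A, one layer slice
    lora_b: torch.Tensor,   # [num_adapters, hidden_out, r] per-adapter B, one layer slice
    indices: torch.Tensor,  # [num_tokens] adapter id per token, -1 means "no adapter"
    out: torch.Tensor,      # [num_tokens, hidden_out] accumulated in place
    scale: float = 1.0,
) -> torch.Tensor:
    # Per-token shrink (hidden_in -> r) followed by expand (r -> hidden_out),
    # accumulated into `out`, mirroring shrink/expand with add_inputs=True.
    for i, idx in enumerate(indices.tolist()):
        if idx < 0:
            continue  # token is not routed to any adapter
        v = x[i] @ lora_a[idx].transpose(0, 1)                # shrink: [r]
        out[i] += scale * (v @ lora_b[idx].transpose(0, 1))   # expand: [hidden_out]
    return out
```

The prefill-time SGMV path performs the same shrink/expand computation, but over contiguous token segments that share an adapter rather than token by token.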