xpu lora support #3232

Open · wants to merge 1 commit into main
170 changes: 147 additions & 23 deletions server/text_generation_server/adapters/lora.py
@@ -11,9 +11,8 @@
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master

from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
@@ -128,17 +127,27 @@ def __init__(
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
if SYSTEM == "ipex":
self._use_cutlass_shrink = False
# [num_layers, r, hidden_size]
weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)

# [num_layers, hidden_size, r]
weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]
self._weights_b = torch.stack(weights_b)
else:
self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)
# [num_layers, hidden_size, r]
weights_a = [
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()
for w in weights_a
]
self._weights_a = torch.stack(weights_a)

# [num_layers, hidden_size, r]
weights_a = [
punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a
]
self._weights_a = torch.stack(weights_a)

# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)

self.adapter_config = adapter_config
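
For readers skimming this hunk, here is a minimal standalone sketch of the reorientation the IPEX branch performs, with made-up sizes (num_layers, hidden and r are placeholders, not values from the PR): per-layer LoRA A matrices arrive as [hidden, r] and B matrices as [r, hidden]; the IPEX path transposes and stacks them into [num_layers, r, hidden] and [num_layers, hidden, r], instead of going through punica_sgmv.orient_for_rank as on CUDA.

import torch

# Illustrative sizes only.
num_layers, hidden, r = 2, 16, 4
weights_a = [torch.randn(hidden, r) for _ in range(num_layers)]  # per-layer A: [hidden, r]
weights_b = [torch.randn(r, hidden) for _ in range(num_layers)]  # per-layer B: [r, hidden]

# IPEX layout: A -> [num_layers, r, hidden], B -> [num_layers, hidden, r]
ipex_a = torch.stack([w.transpose(0, 1).contiguous() for w in weights_a])
ipex_b = torch.stack([w.transpose(0, 1).contiguous() for w in weights_b])
assert ipex_a.shape == (num_layers, r, hidden)
assert ipex_b.shape == (num_layers, hidden, r)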

@@ -175,7 +184,10 @@ def _transpose_weights(self):

@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
if SYSTEM == "ipex":
return [IPEXBatchLoraWeights]
else:
return [BatchLoraWeights]

# prepare pre-loaded lora weights for use in the model.
#
@@ -245,17 +257,20 @@ def prepare_weights(
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

# pad lora ranks to be compatible with sgmv
lora_a_list = [
punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list
]
lora_b_list = [
punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list
]

if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
if SYSTEM != "ipex":
lora_a_list = [
punica_sgmv.pad_rank(w, dim=1, world_size=world_size)
for w in lora_a_list
]
lora_b_list = [
punica_sgmv.pad_rank(w, dim=0, world_size=world_size)
for w in lora_b_list
]

if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank

return LoraWeights(
*shard_lora_weights(
@@ -471,6 +486,115 @@ def load(
)


@dataclass
class IPEXBatchLoraWeights(BatchLoraWeights):
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None

first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices

lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
if len(lora_a) != 0:
lora_a_ptr = torch.stack(list(lora_a.values()))
if len(lora_b) != 0:
lora_b_ptr = torch.stack(list(lora_b.values()))

use_sgmv = prefill

adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1

rank_data = {}
segment_starts = None
segment_ends = None
if use_sgmv:
segment_starts = meta.adapter_segments[:-1]
segment_ends = meta.adapter_segments[1:]
if prefill_head_indices is not None:
segment_starts = prefill_head_segment_starts[:-1]
segment_ends = prefill_head_segment_ends[1:]
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
for rank, indices in rank_indices.items():
adapters_indices = []
lora_a_keys = list(lora_a.keys())
for segment_idx in batch_indices:
if segment_idx in indices:
adapters_indices.append(
lora_a_keys.index(segment_indices[segment_idx])
)
else:
adapters_indices.append(-1)
adapters_indices = torch.tensor(
adapters_indices, dtype=torch.int64, device=device
)
if use_sgmv:
adapters_indices = adapters_indices[segment_starts]
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=None,
tmp_expand=None,
lora_a_ptr=lora_a_ptr,
lora_b_ptr=lora_b_ptr,
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=adapters_indices,
)

return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
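
A rough shape walk-through of the stacked pointers built by this load method (a sketch with assumed sizes, not code from the PR): on the IPEX path each adapter's weights_a is [num_layers, r, hidden] and weights_b is [num_layers, hidden, r], so torch.stack over the adapters in the batch yields [num_adapters, num_layers, r, hidden] and [num_adapters, num_layers, hidden, r]; the per-layer slices taken later in layers/lora.py then have shape [num_adapters, r, hidden] and [num_adapters, hidden, r].

import torch

# Assumed sizes for illustration only.
num_adapters, num_layers, hidden, r = 3, 2, 16, 4

# Per-adapter weights_a on the IPEX path: [num_layers, r, hidden]
lora_a = {i: torch.randn(num_layers, r, hidden) for i in range(num_adapters)}
lora_a_ptr = torch.stack(list(lora_a.values()))  # [num_adapters, num_layers, r, hidden]

layer_id = 1
per_layer_a = lora_a_ptr[:, layer_id, :].contiguous()  # [num_adapters, r, hidden]
assert per_layer_a.shape == (num_adapters, r, hidden)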


def get_scaling_factor(
lora_alpha: int,
r: int,
87 changes: 86 additions & 1 deletion server/text_generation_server/layers/lora.py
@@ -4,8 +4,8 @@
import torch.distributed
from torch import nn
from torch.distributed import ProcessGroup

from text_generation_server.utils.import_utils import SYSTEM

from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
@@ -121,6 +121,91 @@ def forward_layer_type(
self.layer_id,
)

if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
elif SYSTEM == "ipex" and data is not None:
from intel_extension_for_pytorch.llm.functional import (
bgmv_expand,
bgmv_shrink,
sgmv_expand,
sgmv_shrink,
)

# In IPEX, we provide the same API for sgmv and bgmv
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result

for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr[:, self.layer_id, :].contiguous()
lora_b_ptr = rank_segments.lora_b_ptr[:, self.layer_id, :].contiguous()

if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")

if data.use_sgmv:
# Use SGMV for prefill
seq_len_tensor = (
rank_segments.segment_ends - rank_segments.segment_starts
).to(torch.int64)
b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)
total_tokens = seq_len_tensor.sum()
v = torch.zeros(
(total_tokens, r), dtype=input.dtype, device=input.device
)
bs = seq_len_tensor.shape[0]
sgmv_shrink(
input,
lora_a_ptr,
v,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
1.0,
)

if self.process_group.size() > 1:
v = self.collect_lora_a(v)

sgmv_expand(
v,
lora_b_ptr,
proj,
b_seq_start_loc,
seq_len_tensor,
rank_segments.indices,
bs,
seq_len_tensor.max().item(),
add_inputs=True,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
bgmv_shrink(
input,
lora_a_ptr,
v,
rank_segments.indices,
1.0,
)

if self.process_group.size() > 1:
v = self.collect_lora_a(v)

bgmv_expand(
v,
lora_b_ptr,
proj,
rank_segments.indices,
add_inputs=True,
)

if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
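
To make the shrink/expand pair above easier to check, the following plain-PyTorch reference reproduces the per-token (BGMV-style) decode path under the layout assumed in this PR: for one layer, lora_a is [num_adapters, r, hidden], lora_b is [num_adapters, hidden, r], and an index of -1 means the token has no adapter. It only illustrates the math; the IPEX kernels additionally accumulate into proj in place (add_inputs=True), and the SGMV prefill path computes the same thing per contiguous segment of tokens.

import torch

def bgmv_reference(x, lora_a, lora_b, indices, scale=1.0):
    # x:       [num_tokens, hidden]        activations entering the layer
    # lora_a:  [num_adapters, r, hidden]   per-layer LoRA A, IPEX layout
    # lora_b:  [num_adapters, hidden, r]   per-layer LoRA B, IPEX layout
    # indices: [num_tokens]                adapter index per token, -1 = no adapter
    out = torch.zeros_like(x)
    for t, idx in enumerate(indices.tolist()):
        if idx < 0:
            continue
        v = lora_a[idx] @ x[t] * scale  # shrink to rank r: [r]
        out[t] = lora_b[idx] @ v        # expand back to hidden: [hidden]
    return out

# Toy usage with assumed sizes; the caller adds the result to the base projection.
x = torch.randn(5, 16)
delta = bgmv_reference(
    x,
    torch.randn(2, 4, 16),
    torch.randn(2, 16, 4),
    torch.tensor([0, 1, -1, 0, 1]),
)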