
Commit 83cda09

sywangyi, mfuntowicz, and Narsil authored and committed
add intel xpu support for TGI (huggingface#1475)
# What does this PR do?

Adds Intel XPU support to Text Generation Inference (TGI).

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR: @OlivierDehaene or @Narsil.

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
1 parent 959b026 commit 83cda09

18 files changed: +433 −76 lines

.github/workflows/build.yaml

Lines changed: 93 additions & 0 deletions

@@ -274,12 +274,105 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
 
+  build-and-push-image-intel:
+    concurrency:
+      group: ${{ github.workflow }}-build-and-push-image-intel-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the main docker image to be built
+      - integration-tests # Wait for the main integration-tests
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    permissions:
+      contents: write
+      packages: write
+      # This is used to complete the identity challenge
+      # with sigstore/fulcio when running outside of PRs.
+      id-token: write
+      security-events: write
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Initialize Docker Buildx
+        uses: docker/setup-buildx-action@v2.0.0
+        with:
+          install: true
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Tailscale
+        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+      - name: Login to GitHub Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Login to internal Container Registry
+        uses: docker/login-action@v2.1.0
+        with:
+          username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }}
+          password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }}
+          registry: registry.internal.huggingface.tech
+      - name: Login to Azure Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2.1.0
+        with:
+          username: ${{ secrets.AZURE_DOCKER_USERNAME }}
+          password: ${{ secrets.AZURE_DOCKER_PASSWORD }}
+          registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io
+      # If pull request
+      - name: Extract metadata (tags, labels) for Docker
+        if: ${{ github.event_name == 'pull_request' }}
+        id: meta-pr
+        uses: docker/metadata-action@v4.3.0
+        with:
+          images: |
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+          tags: |
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
+      # If main, release or tag
+      - name: Extract metadata (tags, labels) for Docker
+        if: ${{ github.event_name != 'pull_request' }}
+        id: meta
+        uses: docker/metadata-action@v4.3.0
+        with:
+          flavor: |
+            latest=false
+          images: |
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+            ghcr.io/huggingface/text-generation-inference
+            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
+          tags: |
+            type=semver,pattern={{version}}-intel
+            type=semver,pattern={{major}}.{{minor}}-intel
+            type=raw,value=latest-intel,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel
+      - name: Build and push Docker image
+        id: build-and-push
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          file: Dockerfile_intel
+          push: true
+          platforms: 'linux/amd64'
+          build-args: |
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-intel
+          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min
+
   stop-runner:
     name: Stop self-hosted EC2 runner
     needs:
       - start-runner
       - build-and-push-image
       - build-and-push-image-rocm
+      - build-and-push-image-intel
       - integration-tests
     runs-on: ubuntu-latest
     env:

Dockerfile_intel

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
+WORKDIR /usr/src
+
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
+FROM chef as planner
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --release --recipe-path recipe.json
+
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY launcher launcher
+RUN cargo build --release
+
+
+# Text Generation Inference base image for Intel
+FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base
+
+USER root
+# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
+
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
+
+# Text Generation Inference base env
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
+
+
+WORKDIR /usr/src
+# Build pytorch and ipex
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b xpu_main origin/xpu-main
+RUN git clone https://github.com/pytorch/pytorch.git && cd pytorch && git checkout 209f2fa8ff86652f67d75c2f19bf9cb9942fd018 && git apply /usr/src/intel-extension-for-pytorch/torch_patches/00*.patch
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_cuda.txt && \
+    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+
+ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
+ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
+ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
+ENV DIAGUTIL_PATH=/opt/intel/oneapi/compiler/latest/etc/compiler/sys_check/sys_check.sh
+ENV CCL_CONFIGURATION=cpu_gpu_dpcpp
+ENV MANPATH=/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/compiler/latest/share/man
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CMPLR_ROOT=/opt/intel/oneapi/compiler/latest
+ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
+ENV OCL_ICD_FILENAMES=libintelocl_emu.so:libalteracl.so:/opt/intel/oneapi/compiler/latest/lib/libintelocl.so
+ENV CLASSPATH=/opt/intel/oneapi/mpi/latest/share/java/mpi.jar:/opt/intel/oneapi/mpi/latest/share/java/mpi.jar
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
+ENV MKLROOT=/opt/intel/oneapi/mkl/latest
+ENV NLSPATH=/opt/intel/oneapi/mkl/latest/share/locale/%l_%t/%N:/opt/intel/oneapi/compiler/latest/lib/locale/%l_%t/%N
+ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+ENV CCL_ZE_IPC_EXCHANGE=sockets
+
+
+RUN pip uninstall -y torch && cd pytorch && git submodule update --init --recursive && python setup.py install
+RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=ON BUILD_WITH_CPU=ON USE_XETLA=ON python setup.py install
+
+# Install benchmarker
+COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+# Final image
+FROM base
+
+ENTRYPOINT ["text-generation-launcher"]
+CMD ["--json-output"]

launcher/src/env_runtime.rs

Lines changed: 12 additions & 1 deletion

@@ -7,14 +7,17 @@ pub(crate) struct Env {
     git_sha: &'static str,
     docker_label: &'static str,
     nvidia_env: String,
+    xpu_env: String,
 }
 
 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
+        let xpu_env = xpu_smi();
 
         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
+            xpu_env: xpu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
         writeln!(f, "Cargo version: {}", self.cargo_version)?;
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
-        write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        write!(f, "xpu-smi:\n{}", self.xpu_env)?;
 
         Ok(())
     }
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
     let output = nvidia_smi.replace('\n', "\n ");
     Some(output.trim().to_string())
 }
+
+fn xpu_smi() -> Option<String> {
+    let output = Command::new("xpu-smi").arg("discovery").output().ok()?;
+    let xpu_smi = String::from_utf8(output.stdout).ok()?;
+    let output = xpu_smi.replace('\n', "\n ");
+    Some(output.trim().to_string())
+}

server/text_generation_server/models/cache_manager.py

Lines changed: 5 additions & 1 deletion

@@ -2,6 +2,7 @@
 import torch
 
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -24,7 +25,10 @@ def __init__(
         self.repeat_slots = repeat_slots
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = self.block_size // element_size
+        if IS_XPU_SYSTEM:
+            x = 1
+        else:
+            x = self.block_size // element_size
 
         self.kv_cache = [
             (
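The IS_XPU_SYSTEM flag imported here (and in the files below) comes from text_generation_server/utils/import_utils.py, which is not part of this diff. For orientation only, a minimal sketch of how such platform flags could be derived, assuming intel-extension-for-pytorch registers the torch.xpu namespace (the exact logic in the repository may differ):

# Hypothetical sketch of a platform-detection helper such as import_utils.py;
# not the code shipped in this commit.
import torch

# CUDA/ROCm builds of PyTorch expose these version attributes.
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_ROCM_SYSTEM = torch.version.hip is not None

try:
    # intel-extension-for-pytorch adds the torch.xpu namespace on import.
    import intel_extension_for_pytorch  # noqa: F401

    IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()
except ImportError:
    IS_XPU_SYSTEM = False

With x set to 1 on XPU, the key cache is effectively laid out without the innermost packing factor that the CUDA paged-attention kernels use.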

server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py

Lines changed: 3 additions & 1 deletion

@@ -21,8 +21,10 @@
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
-from vllm.model_executor.layers.fused_moe import fused_moe
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

Lines changed: 4 additions & 1 deletion

@@ -24,7 +24,10 @@
 import numpy as np
 
 from torch import nn
-from vllm.model_executor.layers.fused_moe import fused_moe
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 20 additions & 8 deletions

@@ -33,7 +33,7 @@
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -752,7 +752,10 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
 
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_XPU_SYSTEM:
+            torch.xpu.empty_cache()
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -772,20 +775,29 @@ def warmup(self, batch: FlashCausalLMBatch):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        torch.cuda.synchronize(self.device)
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.synchronize(self.device)
+        elif IS_XPU_SYSTEM:
+            torch.xpu.synchronize(self.device)
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        free_memory = max(
-            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-        )
+            free_memory = max(
+                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+            )
+        elif IS_XPU_SYSTEM:
+            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
+            free_memory = int(total_gpu_memory * 0.5)
+        else:
+            raise NotImplementedError("FlashModel is only available on GPU")
 
         num_blocks = (
             # Leave 5% for some wiggle room
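To make the difference concrete, here is a small worked example (Python, with made-up numbers) of how the free-memory budget computed in warmup() differs between the CUDA/ROCm branch and the new XPU branch. MEMORY_FRACTION is imported from text_generation_server.utils.dist in the file above; the value below is hypothetical.

# Illustration only: hypothetical device sizes, not measurements from this commit.
MEMORY_FRACTION = 0.9             # hypothetical value for this example
total_gpu_memory = 64 * 1024**3   # pretend the device has 64 GiB
total_free_memory = 60 * 1024**3  # pretend 60 GiB are currently free

# CUDA / ROCm branch: currently free memory minus the reserved fraction of the device.
free_memory_cuda = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)

# XPU branch: no free-memory query; the commit simply budgets half of device memory.
free_memory_xpu = int(total_gpu_memory * 0.5)

print(free_memory_cuda / 1024**3, free_memory_xpu / 1024**3)  # ~53.6 GiB vs 32.0 GiB

The number of KV-cache blocks is then derived from this budget and total_cache_size in the unchanged lines that follow the hunk.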

server/text_generation_server/models/flash_llama.py

Lines changed: 4 additions & 0 deletions

@@ -18,6 +18,7 @@
 
 tracer = trace.get_tracer(__name__)
 
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -33,6 +34,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 

server/text_generation_server/models/flash_mistral.py

Lines changed: 5 additions & 1 deletion

@@ -33,8 +33,9 @@
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 
 
 def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
@@ -316,6 +317,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
 
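As a quick sanity check on an XPU host, one can confirm that the xpu device type these model classes select is actually visible to PyTorch. A minimal sketch, assuming intel-extension-for-pytorch is installed as in Dockerfile_intel:

# Verify that torch can see an Intel XPU before launching text-generation-launcher.
import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the torch.xpu namespace)

if hasattr(torch, "xpu") and torch.xpu.is_available():
    for i in range(torch.xpu.device_count()):
        # get_device_properties is the same call warmup() uses to read total_memory.
        print(f"xpu:{i}", torch.xpu.get_device_properties(i))
else:
    print("No XPU visible; the FlashLlama/FlashMistral constructors above would raise NotImplementedError.")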
