Description
Describe the issue
Hi IPEX team,
I have an application where I want to serve multiple models concurrently and share weights across the concurrent instances. I normally do this with torch.load(path, mmap=True). However, calling ipex.llm.optimize interferes with the weight sharing, because ipex manipulates the weights in memory (it does a deep copy, from what I understand). I would like to instead save the ipex-optimized model and then load it with something like torch.load(ipex_model, mmap=True). However, I can't figure out how to do this and was hoping you could provide an example.
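For context, this is roughly the weight-sharing pattern I rely on today (the file name below is just illustrative): each instance memory-maps the same checkpoint, so the weights are backed by the OS page cache instead of being duplicated per instance.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_config(config)
# mmap=True keeps the loaded tensors backed by the checkpoint file;
# assign=True keeps those mmapped tensors instead of copying them into
# freshly allocated parameters
state_dict = torch.load("zephyr_weights.pt", mmap=True, map_location="cpu")
model.load_state_dict(state_dict, assign=True)
model.eval()
What I am after is an equivalent of this for the ipex-optimized model.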
How to reproduce:
The package list from my miniconda env is below. pip install -r requirements.txt may not work from it directly, but you can recreate the env easily with conda create -n ipex_issue python=3.10 && conda activate ipex_issue, followed by the install instructions here and pip install transformers==4.38.1. I am using Python 3.10 on an AWS c7i.2xlarge instance.
certifi==2024.7.4
charset-normalizer==3.3.2
filelock==3.13.1
fsspec==2024.2.0
huggingface-hub==0.24.5
idna==3.7
intel_extension_for_pytorch==2.3.100
Jinja2==3.1.3
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.3
oneccl-bind-pt==2.3.0+cpu
packaging==24.1
pillow==10.2.0
psutil==6.0.0
PyYAML==6.0.2
regex==2024.7.24
requests==2.32.3
safetensors==0.4.4
sympy==1.12
tokenizers==0.15.2
torch==2.3.0+cpu
torchaudio==2.3.0+cpu
torchvision==0.18.0+cpu
tqdm==4.66.5
transformers==4.38.1
typing_extensions==4.9.0
urllib3==2.2.2
Here are the things I have tried:
import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import intel_extension_for_pytorch as ipex
config = AutoConfig.from_pretrained("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_ACCESS_TOKEN"))
# Just use a tiny model with random weights so it uses less mem, faster to test
config_micro = dict(
    hidden_size=768,
    intermediate_size=int(768 * 3.5),
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=8,
)
for k, v in config_micro.items():
    setattr(config, k, v)
# NOTE: this is where we want to load the model with mmap=True to enable weight sharing
# EG model = AutoModelForCausalLM.from_pretrained(path, mmap=True)
model = AutoModelForCausalLM.from_config(config)
model.eval()
model.to(torch.bfloat16)
# NOTE: optimize deepcopies the model, breaks weights sharing / memory mapping
model = ipex.llm.optimize(
    model, dtype=torch.bfloat16, inplace=True, deployment_mode=True
)
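# (Aside) One way to check the deep-copy claim, sketched only as a comment so it
# does not disturb the repro: record the set of p.data_ptr() values over
# model.parameters() just before the ipex.llm.optimize call above, then compare
# against the optimized model's parameters here. If no pointer matches, the
# optimized weights live in new storages, which is what would break the
# mmap-based sharing.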
###
# Several attempts to save / load
###
# 0) model.save_pretrained("save_pretrained")
model.save_pretrained("save_pretrained")
# error: RecursionError: maximum recursion depth exceeded while calling a Python object
# 1) model.save
model.save("model_save")
# error: AttributeError: 'MistralForCausalLM' object has no attribute 'save'
# 2) torch.save
torch.save(model, "torch_save_model.pt")
# error: RuntimeError: Tried to serialize object __torch__.transformers.models.mistral.modeling_mistral.___torch_mangle_65.MistralForCausalLM which does not have a __getstate__ method defined!
# 3) model.trace_graph.save()
model.trace_graph.save("model_trace_graph")
m3 = torch.jit.load("model_trace_graph")
inputs = torch.randint(low=500, high=1_000, size=(1, 16), dtype=torch.int64)
m3(inputs)
# error: RuntimeError: forward() is missing value for argument 'attention_mask'.
# 4) save jit traced
with torch.no_grad():
    traced_model = torch.jit.trace(model, inputs)
# error: RecursionError: maximum recursion depth exceeded in comparison
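If I'm reading the repo's LLM example scripts correctly, my best guess at an intended load path is something like the sketch below: load and freeze the saved trace graph, re-run ipex.llm.optimize on a fresh (ideally mmapped) copy of the model with deployment_mode=False so the generation wiring is set up without re-tracing, then attach the deserialized graph. I'm not sure the underscore-prefixed helper is meant for external use, or whether this would avoid the in-memory copy at all, so please correct me if there is a better way:
loaded_graph = torch.jit.load("model_trace_graph")
loaded_graph = torch.jit.freeze(loaded_graph.eval())
# fresh copy of the model; this is where I would want the mmapped load
fresh_model = AutoModelForCausalLM.from_config(config)
fresh_model.eval()
fresh_model.to(torch.bfloat16)
fresh_model = ipex.llm.optimize(
    fresh_model, dtype=torch.bfloat16, inplace=True, deployment_mode=False
)
# swap in the graph loaded from disk instead of one traced in this process
ipex._set_optimized_model_for_generation(fresh_model, optimized_model=loaded_graph)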
As a side note, I understand the usual recommendation is to deploy multiple concurrent models via subprocesses, but that is not an option in my case, because the logic that decides how and when to fork processes is separate from the part of the code that loads the model.
At one point I think I got option 0) above to work, but the loaded model was a vanilla transformers model without the ipex optimizations, and in any case I can't reproduce that behavior in this env.
Any help would be much appreciated.