Install torchao mps ops by default when running on Apple Silicon #1512

Open · wants to merge 1 commit into main
5 changes: 0 additions & 5 deletions .github/workflows/pull.yml
@@ -1253,11 +1253,6 @@ jobs:
           ./install/install_requirements.sh
           pip3 list
           python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
-      - name: Install torchao-ops-mps
-        id: install-torchao-ops-mps
-        run: |
-          bash torchchat/utils/scripts/clone_torchao.sh
-          bash torchchat/utils/scripts/build_torchao_ops.sh mps
       - name: Run inference
         run: |
           python torchchat.py download stories110M
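With install_torchao.sh now building the mps ops automatically, the dedicated CI step removed above is redundant. A minimal sketch of the CI flow before and after this change, using only the scripts that appear in this diff:

# Before: CI built the torchao mps ops as an explicit extra step.
bash torchchat/utils/scripts/clone_torchao.sh
bash torchchat/utils/scripts/build_torchao_ops.sh mps

# After: the standard install path is sufficient, because
# install_torchao.sh detects Apple Silicon and builds the mps ops itself.
./install/install_requirements.sh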
9 changes: 8 additions & 1 deletion install/install_torchao.sh
@@ -31,9 +31,16 @@ else
 fi
 echo "Using pip executable: $PIP_EXECUTABLE"
 
+if [[ $(uname -s) == "Darwin" && $(uname -m) == "arm64" ]]; then
+  echo "Building torchao experimental mps ops (Apple Silicon detected)"
+  APPLE_SILICON_DETECTED=1
+else
+  echo "NOT building torchao experimental mps ops (Apple Silicon NOT detected)"
+  APPLE_SILICON_DETECTED=0
+fi
 
 export TORCHAO_PIN=$(cat install/.pins/torchao-pin.txt)
 (
   set -x
-  USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN}
+  USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=${APPLE_SILICON_DETECTED} $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN}
 )
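The detection relies on uname alone: Darwin identifies macOS and arm64 identifies Apple Silicon, so Intel Macs and Linux hosts fall through to the plain build. A standalone sketch of the same check; that TORCHAO_BUILD_EXPERIMENTAL_MPS=0 behaves as a no-op in torchao's build is an assumption:

#!/usr/bin/env bash
# uname -s reports the kernel (Darwin on macOS); uname -m reports the
# CPU architecture (arm64 on Apple Silicon, x86_64 on Intel Macs).
if [[ $(uname -s) == "Darwin" && $(uname -m) == "arm64" ]]; then
  APPLE_SILICON_DETECTED=1
else
  APPLE_SILICON_DETECTED=0
fi
# The flag is passed inline to pip on every platform; only its value
# changes, so a single install command covers all hosts.
echo "TORCHAO_BUILD_EXPERIMENTAL_MPS=${APPLE_SILICON_DETECTED}"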
42 changes: 5 additions & 37 deletions torchchat/utils/quantize.py
@@ -56,6 +56,7 @@
 from torchao.experimental.quant_api import (
     int8_dynamic_activation_intx_weight,
     IntxWeightEmbeddingQuantizer,
+    UIntxWeightOnlyLinearQuantizer,
 )
 from torchao.quantization.granularity import (
     PerGroup,
@@ -137,12 +138,12 @@ def quantize_model(
             group_size = q_kwargs["groupsize"]
             bit_width = q_kwargs["bitwidth"]
             has_weight_zeros = q_kwargs["has_weight_zeros"]
-            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
             weight_dtype = getattr(torch, f"int{bit_width}")

             try:
                 quantize_(
-                    model,
+                    model,
                     int8_dynamic_activation_intx_weight(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -154,7 +155,7 @@
                 print("Encountered error during quantization: {e}")
                 print("Trying with PlainLayout")
                 quantize_(
-                    model,
+                    model,
                     int8_dynamic_activation_intx_weight(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -946,38 +947,5 @@ def quantized_model(self) -> nn.Module:
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8wxdq": None,  # uses quantize_ API
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+    "linear:afpwx": UIntxWeightOnlyLinearQuantizer,
 }
-
-try:
-    import importlib.util
-    import os
-    import sys
-
-    torchao_build_path = f"{os.getcwd()}/torchao-build"
-
-    # Try loading quantizer
-    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
-        "torchao_experimental_quant_api",
-        f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
-    )
-    torchao_experimental_quant_api = importlib.util.module_from_spec(
-        torchao_experimental_quant_api_spec
-    )
-    sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
-    torchao_experimental_quant_api_spec.loader.exec_module(
-        torchao_experimental_quant_api
-    )
-    from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer
-    quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer
-
-    # Try loading custom op
-    try:
-        libname = "libtorchao_ops_mps_aten.dylib"
-        libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
-        torch.ops.load_library(libpath)
-        print("Loaded torchao mps ops.")
-    except Exception as e:
-        print("Unable to load torchao mps ops library.")
-
-except Exception as e:
-    print("Unable to import torchao experimental quant_api with error: ", e)