From 6074185a36b8bd8c4b8685e45d989d484d8de1f9 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 19 Mar 2025 15:08:06 -0400 Subject: [PATCH] install torchao mps ops by default when running on Apple Silicon --- .github/workflows/pull.yml | 5 ----- install/install_torchao.sh | 9 +++++++- torchchat/utils/quantize.py | 42 +++++-------------------------------- 3 files changed, 13 insertions(+), 43 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index e44d9d037..744818e1a 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -1253,11 +1253,6 @@ jobs: ./install/install_requirements.sh pip3 list python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")' - - name: Install torchao-ops-mps - id: install-torchao-ops-mps - run: | - bash torchchat/utils/scripts/clone_torchao.sh - bash torchchat/utils/scripts/build_torchao_ops.sh mps - name: Run inference run: | python torchchat.py download stories110M diff --git a/install/install_torchao.sh b/install/install_torchao.sh index 84974040a..9ab7cca9f 100644 --- a/install/install_torchao.sh +++ b/install/install_torchao.sh @@ -31,9 +31,16 @@ else fi echo "Using pip executable: $PIP_EXECUTABLE" +if [[ $(uname -s) == "Darwin" && $(uname -m) == "arm64" ]]; then + echo "Building torchao experimental mps ops (Apple Silicon detected)" + APPLE_SILICON_DETECTED=1 +else + echo "NOT building torchao experimental mps ops (Apple Silicon NOT detected)" + APPLE_SILICON_DETECTED=0 +fi export TORCHAO_PIN=$(cat install/.pins/torchao-pin.txt) ( set -x - USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN} + USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=${APPLE_SILICON_DETECTED} $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN} ) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 6246f1c05..102d36657 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -56,6 +56,7 @@ from torchao.experimental.quant_api import ( int8_dynamic_activation_intx_weight, IntxWeightEmbeddingQuantizer, + UIntxWeightOnlyLinearQuantizer, ) from torchao.quantization.granularity import ( PerGroup, @@ -137,12 +138,12 @@ def quantize_model( group_size = q_kwargs["groupsize"] bit_width = q_kwargs["bitwidth"] has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerRow() if group_size == -1 else PerGroup(group_size) + granularity = PerRow() if group_size == -1 else PerGroup(group_size) weight_dtype = getattr(torch, f"int{bit_width}") try: quantize_( - model, + model, int8_dynamic_activation_intx_weight( weight_dtype=weight_dtype, granularity=granularity, @@ -154,7 +155,7 @@ def quantize_model( print("Encountered error during quantization: {e}") print("Trying with PlainLayout") quantize_( - model, + model, int8_dynamic_activation_intx_weight( weight_dtype=weight_dtype, granularity=granularity, @@ -946,38 +947,5 @@ def quantized_model(self) -> nn.Module: "linear:int4": Int4WeightOnlyQuantizer, "linear:a8wxdq": None, # uses quantize_ API "linear:a8w4dq": Int8DynActInt4WeightQuantizer, + "linear:afpwx": UIntxWeightOnlyLinearQuantizer, } - -try: - import importlib.util - import os - import sys - - torchao_build_path = f"{os.getcwd()}/torchao-build" - - # Try loading quantizer - torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location( - "torchao_experimental_quant_api", - f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py", - ) - torchao_experimental_quant_api = importlib.util.module_from_spec( - torchao_experimental_quant_api_spec - ) - sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api - torchao_experimental_quant_api_spec.loader.exec_module( - torchao_experimental_quant_api - ) - from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer - quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer - - # Try loading custom op - try: - libname = "libtorchao_ops_mps_aten.dylib" - libpath = f"{torchao_build_path}/cmake-out/lib/{libname}" - torch.ops.load_library(libpath) - print("Loaded torchao mps ops.") - except Exception as e: - print("Unable to load torchao mps ops library.") - -except Exception as e: - print("Unable to import torchao experimental quant_api with error: ", e)