From 74396ff9d4de6d2904f2577426632076ec724fb0 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:43:44 -0700 Subject: [PATCH 1/8] lowbit init --- config/data/custom.json | 6 ++++++ config/data/custom_existing_int4.json | 6 ++++++ quantization/quantize.py | 29 +++++++++++++++++++++++++++ 3 files changed, 41 insertions(+) create mode 100644 config/data/custom.json create mode 100644 config/data/custom_existing_int4.json diff --git a/config/data/custom.json b/config/data/custom.json new file mode 100644 index 000000000..c964a5686 --- /dev/null +++ b/config/data/custom.json @@ -0,0 +1,6 @@ +{ + "executor": {"accelerator": "cpu"}, + "precision": {"dtype": "fp32"}, + "embedding": {"bitwidth": 4, "groupsize" : 32}, + "_custom": {} +} diff --git a/config/data/custom_existing_int4.json b/config/data/custom_existing_int4.json new file mode 100644 index 000000000..c569aa64e --- /dev/null +++ b/config/data/custom_existing_int4.json @@ -0,0 +1,6 @@ +{ + "executor": {"accelerator": "cpu"}, + "precision": {"dtype": "fp32"}, + "embedding": {"bitwidth": 4, "groupsize" : 32}, + "linear:int4": {"groupsize": 256} +} diff --git a/quantization/quantize.py b/quantization/quantize.py index 8efc4fa08..3a3c24585 100644 --- a/quantization/quantize.py +++ b/quantization/quantize.py @@ -564,6 +564,34 @@ def quantized_model(self) -> nn.Module: return self.quantize(self.model_) +class CustomHandler(QuantHandler): + def __init__(self, model: nn.Module, device="cpu", tokenizer=None): + self.model_ = model + self.device = device + self.tokenizer = tokenizer + + def create_quantized_state_dict(self) -> Dict: # "StateDict" + pass + + def convert_for_runtime(self) -> nn.Module: + pass + + def quantized_model(self) -> nn.Module: + self.model_ = self.model_.to("cpu") + + import importlib.util + import sys + spec = importlib.util.spec_from_file_location( + "torch_custom_op", + "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op_v2.py" + ) + torch_custom_op = importlib.util.module_from_spec(spec) + sys.modules["torch_custom_op"] = torch_custom_op + spec.loader.exec_module(torch_custom_op) + + torch_custom_op.replace_linear_with_quantized_linear(self.model_, kwargs={"group_size": 256, "nbit": 4, "squeeze_unsqueeze_dim0": True}) + return self.model_ + ########################################################################## ### quantization dictionary ### @@ -575,6 +603,7 @@ def quantized_model(self) -> nn.Module: "linear:int8": WeightOnlyInt8QuantHandler, "precision": PrecisionHandler, "executor": ExecutorHandler, + "_custom": CustomHandler, } ao_quantizer_class_dict = { From 95c66b87c02a017401353fc7978a25cbe97e1081 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 27 Aug 2024 22:26:23 -0700 Subject: [PATCH 2/8] mods --- generate.py | 2 ++ quantization/quantize.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/generate.py b/generate.py index fc48375d2..272831b68 100644 --- a/generate.py +++ b/generate.py @@ -30,6 +30,8 @@ from cli import add_arguments_for_verb, arg_init, check_args from utils.device_info import get_device_info +torch.ops.load_library("/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.dylib") +torch.set_num_threads(6) # 6 threads is better perf on my machine and is what I used for ET tests too class _ChatFormatter(ABC): def __init__(self, tokenizer): diff --git a/quantization/quantize.py 
b/quantization/quantize.py index 3a3c24585..e4b005642 100644 --- a/quantization/quantize.py +++ b/quantization/quantize.py @@ -583,7 +583,7 @@ def quantized_model(self) -> nn.Module: import sys spec = importlib.util.spec_from_file_location( "torch_custom_op", - "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op_v2.py" + "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py" ) torch_custom_op = importlib.util.module_from_spec(spec) sys.modules["torch_custom_op"] = torch_custom_op From 3b7e17c779814181d4330c665a6476c71c00d178 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 27 Aug 2024 22:30:07 -0700 Subject: [PATCH 3/8] updates --- config/data/custom.json | 6 ------ config/data/custom_existing_int4.json | 6 ------ 2 files changed, 12 deletions(-) delete mode 100644 config/data/custom.json delete mode 100644 config/data/custom_existing_int4.json diff --git a/config/data/custom.json b/config/data/custom.json deleted file mode 100644 index c964a5686..000000000 --- a/config/data/custom.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "executor": {"accelerator": "cpu"}, - "precision": {"dtype": "fp32"}, - "embedding": {"bitwidth": 4, "groupsize" : 32}, - "_custom": {} -} diff --git a/config/data/custom_existing_int4.json b/config/data/custom_existing_int4.json deleted file mode 100644 index c569aa64e..000000000 --- a/config/data/custom_existing_int4.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "executor": {"accelerator": "cpu"}, - "precision": {"dtype": "fp32"}, - "embedding": {"bitwidth": 4, "groupsize" : 32}, - "linear:int4": {"groupsize": 256} -} From 008dd9024dcfd2476d396199a267c7f469a64819 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 29 Aug 2024 09:44:43 -0700 Subject: [PATCH 4/8] Add timers to generate.py that better capture torch.compile perf --- generate.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/generate.py b/generate.py index 272831b68..35b693f3b 100644 --- a/generate.py +++ b/generate.py @@ -397,6 +397,13 @@ def decode_n_tokens( ) input_pos += 1 break + if _i == 1: + t0 = time.time() + if _i == num_new_tokens - 2: + t1 = time.time() + print(f"\nTime to generate {num_new_tokens-2} tokens: {t1-t0}") + print(f"\nTokens/sec to generate {num_new_tokens-2} tokens: {(num_new_tokens-2) / (t1-t0)}") + if not encountered_eos: eos_token = torch.tensor( From 1406a3c151d25594a0141070c4de38601ce1c826 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:55:59 -0700 Subject: [PATCH 5/8] updates --- generate.py | 3 - quantization/quantize.py | 133 ++++++++++++++++++++++++++++++--------- runner/aoti.cmake | 3 + 3 files changed, 106 insertions(+), 33 deletions(-) diff --git a/generate.py b/generate.py index 35b693f3b..b32ecba0d 100644 --- a/generate.py +++ b/generate.py @@ -30,9 +30,6 @@ from cli import add_arguments_for_verb, arg_init, check_args from utils.device_info import get_device_info -torch.ops.load_library("/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.dylib") -torch.set_num_threads(6) # 6 threads is better perf on my machine and is what I used for ET tests too - class _ChatFormatter(ABC): def __init__(self, tokenizer): self.tokenizer = tokenizer diff --git a/quantization/quantize.py b/quantization/quantize.py index e4b005642..fdbc7d5ee 100644 --- a/quantization/quantize.py +++ 
b/quantization/quantize.py @@ -22,7 +22,11 @@ # from __future__ import annotations +# torchao_experimental +import importlib.util + import json +import sys # from functools import reduce # from math import gcd @@ -45,6 +49,22 @@ ) from torchao.utils import unwrap_tensor_subclass +torchao_experimental_spec = importlib.util.spec_from_file_location( + "torchao_experimental", + "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py", +) +torchao_experimental = importlib.util.module_from_spec(torchao_experimental_spec) +sys.modules["torchao_experimental"] = torchao_experimental +torchao_experimental_spec.loader.exec_module(torchao_experimental) + +import glob + +libs = glob.glob("/tmp/cmake-out/torchao/libtorch_custom_op.*") +libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) +torch.ops.load_library(libs[0]) + +from torchao_experimental import Int8DynActLowbitWeightQuantizer + ######################################################################### ### torchchat quantization API ### @@ -92,9 +112,20 @@ def quantize_model( try: # Easier to ask forgiveness than permission - quant_handler = ao_quantizer_class_dict[quantizer]( - groupsize=q_kwargs["groupsize"], device=device, precision=precision - ) + if quantizer == "linear:a8wlow": + quant_handler = ao_quantizer_class_dict[quantizer]( + device=device, + precision=precision, + bitwidth=q_kwargs.get("bitwidth", 4), + groupsize=q_kwargs.get("groupsize", 128), + has_weight_zeros=q_kwargs.get("has_weight_zeros", False), + ) + else: + quant_handler = ao_quantizer_class_dict[quantizer]( + groupsize=q_kwargs["groupsize"], + device=device, + precision=precision, + ) except TypeError as e: if "unexpected keyword argument 'device'" in str(e): quant_handler = ao_quantizer_class_dict[quantizer]( @@ -564,33 +595,75 @@ def quantized_model(self) -> nn.Module: return self.quantize(self.model_) -class CustomHandler(QuantHandler): - def __init__(self, model: nn.Module, device="cpu", tokenizer=None): - self.model_ = model - self.device = device - self.tokenizer = tokenizer +# class A8WLowHandler(QuantHandler): +# def __init__( +# self, +# model: nn.Module, +# device="cpu", +# tokenizer=None, +# *, +# bitwidth: Optional[int] = None, +# groupsize: Optional[int] = None, +# has_weight_zeros: Optional[bool] = None, +# ): +# self.model_ = model +# self.device = device +# self.tokenizer = tokenizer + +# if bitwidth is None: +# self.bitwidth = 4 +# print(f"Warning: bitwidth not specified, defaulting to {self.bitwidth}.") +# else: +# self.bitwidth = bitwidth + +# if groupsize is None: +# self.groupsize = 256 +# print(f"Warning: groupsize not specified, defaulting to {self.groupsize}.") +# else: +# self.groupsize = groupsize + +# if has_weight_zeros is None: +# self.has_weight_zeros = False +# print(f"Warning: has_weight_zeros not specified, defaulting to {self.has_weight_zeros}.") +# else: +# self.has_weight_zeros = has_weight_zeros + +# print("Quantizing with:") +# print(f"\tbitwidth: {self.bitwidth}") +# print(f"\tgroupsize: {self.groupsize}") +# print(f"\thas_weight_zeros: {self.has_weight_zeros}") + +# def create_quantized_state_dict(self) -> Dict: # "StateDict" +# pass + +# def convert_for_runtime(self) -> nn.Module: +# pass + +# def quantized_model(self) -> nn.Module: +# self.model_ = self.model_.to("cpu") + +# import importlib.util +# import sys + +# spec = importlib.util.spec_from_file_location( +# "torch_custom_op", +# 
"/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py", +# ) +# torch_custom_op = importlib.util.module_from_spec(spec) +# sys.modules["torch_custom_op"] = torch_custom_op +# spec.loader.exec_module(torch_custom_op) + +# torch_custom_op.replace_linear_with_quantized_linear( +# self.model_, +# kwargs={ +# "group_size": self.groupsize, +# "nbit": self.bitwidth, +# "has_weight_zeros": self.has_weight_zeros, +# "squeeze_unsqueeze_dim0": True, +# }, +# ) +# return self.model_ - def create_quantized_state_dict(self) -> Dict: # "StateDict" - pass - - def convert_for_runtime(self) -> nn.Module: - pass - - def quantized_model(self) -> nn.Module: - self.model_ = self.model_.to("cpu") - - import importlib.util - import sys - spec = importlib.util.spec_from_file_location( - "torch_custom_op", - "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py" - ) - torch_custom_op = importlib.util.module_from_spec(spec) - sys.modules["torch_custom_op"] = torch_custom_op - spec.loader.exec_module(torch_custom_op) - - torch_custom_op.replace_linear_with_quantized_linear(self.model_, kwargs={"group_size": 256, "nbit": 4, "squeeze_unsqueeze_dim0": True}) - return self.model_ ########################################################################## ### quantization dictionary ### @@ -603,10 +676,10 @@ def quantized_model(self) -> nn.Module: "linear:int8": WeightOnlyInt8QuantHandler, "precision": PrecisionHandler, "executor": ExecutorHandler, - "_custom": CustomHandler, } ao_quantizer_class_dict = { "linear:int4": Int4WeightOnlyQuantizer, "linear:a8w4dq": Int8DynActInt4WeightQuantizer, + "linear:a8wlow": Int8DynActLowbitWeightQuantizer, } diff --git a/runner/aoti.cmake b/runner/aoti.cmake index 156e9bcce..798b9da46 100644 --- a/runner/aoti.cmake +++ b/runner/aoti.cmake @@ -28,3 +28,6 @@ if(Torch_FOUND) target_link_libraries(aoti_run "${TORCH_LIBRARIES}" m) set_property(TARGET aoti_run PROPERTY CXX_STANDARD 17) endif() + + +target_link_libraries(aoti_run "/tmp/cmake-out/torchao/libtorch_custom_op.dylib") From 334bc5b8e6b97f317d091c312c13adb349915715 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 9 Sep 2024 11:43:54 -0700 Subject: [PATCH 6/8] update --- quantization/quantize.py | 71 ---------------------------------------- 1 file changed, 71 deletions(-) diff --git a/quantization/quantize.py b/quantization/quantize.py index fdbc7d5ee..0f4c06cc9 100644 --- a/quantization/quantize.py +++ b/quantization/quantize.py @@ -594,77 +594,6 @@ def quantize(self, module): def quantized_model(self) -> nn.Module: return self.quantize(self.model_) - -# class A8WLowHandler(QuantHandler): -# def __init__( -# self, -# model: nn.Module, -# device="cpu", -# tokenizer=None, -# *, -# bitwidth: Optional[int] = None, -# groupsize: Optional[int] = None, -# has_weight_zeros: Optional[bool] = None, -# ): -# self.model_ = model -# self.device = device -# self.tokenizer = tokenizer - -# if bitwidth is None: -# self.bitwidth = 4 -# print(f"Warning: bitwidth not specified, defaulting to {self.bitwidth}.") -# else: -# self.bitwidth = bitwidth - -# if groupsize is None: -# self.groupsize = 256 -# print(f"Warning: groupsize not specified, defaulting to {self.groupsize}.") -# else: -# self.groupsize = groupsize - -# if has_weight_zeros is None: -# self.has_weight_zeros = False -# print(f"Warning: has_weight_zeros not specified, defaulting to 
{self.has_weight_zeros}.") -# else: -# self.has_weight_zeros = has_weight_zeros - -# print("Quantizing with:") -# print(f"\tbitwidth: {self.bitwidth}") -# print(f"\tgroupsize: {self.groupsize}") -# print(f"\thas_weight_zeros: {self.has_weight_zeros}") - -# def create_quantized_state_dict(self) -> Dict: # "StateDict" -# pass - -# def convert_for_runtime(self) -> nn.Module: -# pass - -# def quantized_model(self) -> nn.Module: -# self.model_ = self.model_.to("cpu") - -# import importlib.util -# import sys - -# spec = importlib.util.spec_from_file_location( -# "torch_custom_op", -# "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py", -# ) -# torch_custom_op = importlib.util.module_from_spec(spec) -# sys.modules["torch_custom_op"] = torch_custom_op -# spec.loader.exec_module(torch_custom_op) - -# torch_custom_op.replace_linear_with_quantized_linear( -# self.model_, -# kwargs={ -# "group_size": self.groupsize, -# "nbit": self.bitwidth, -# "has_weight_zeros": self.has_weight_zeros, -# "squeeze_unsqueeze_dim0": True, -# }, -# ) -# return self.model_ - - ########################################################################## ### quantization dictionary ### From b8ee9da13ea3b795aec00811bcc562918b1cddac Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:32:32 -0700 Subject: [PATCH 7/8] add torchao build scripts --- .pins/torchao-pin.txt | 1 + quantization/quantize.py | 47 ++++++++++++++++---------- runner/aoti.cmake | 2 +- runner/et.cmake | 3 ++ scripts/build_native.sh | 8 +++++ scripts/build_torchao_custom_ops.sh | 23 +++++++++++++ scripts/install_utils.sh | 51 +++++++++++++++++++++++++++++ 7 files changed, 116 insertions(+), 19 deletions(-) create mode 100644 .pins/torchao-pin.txt create mode 100644 scripts/build_torchao_custom_ops.sh diff --git a/.pins/torchao-pin.txt b/.pins/torchao-pin.txt new file mode 100644 index 000000000..a3402d40c --- /dev/null +++ b/.pins/torchao-pin.txt @@ -0,0 +1 @@ +85d03de43160328eaf350e7ec3877d3d7b57da50 diff --git a/quantization/quantize.py b/quantization/quantize.py index 0f4c06cc9..0b54d73aa 100644 --- a/quantization/quantize.py +++ b/quantization/quantize.py @@ -49,23 +49,6 @@ ) from torchao.utils import unwrap_tensor_subclass -torchao_experimental_spec = importlib.util.spec_from_file_location( - "torchao_experimental", - "/Users/scroy/fbsource/fbcode/pytorch/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py", -) -torchao_experimental = importlib.util.module_from_spec(torchao_experimental_spec) -sys.modules["torchao_experimental"] = torchao_experimental -torchao_experimental_spec.loader.exec_module(torchao_experimental) - -import glob - -libs = glob.glob("/tmp/cmake-out/torchao/libtorch_custom_op.*") -libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -from torchao_experimental import Int8DynActLowbitWeightQuantizer - - ######################################################################### ### torchchat quantization API ### @@ -119,6 +102,7 @@ def quantize_model( bitwidth=q_kwargs.get("bitwidth", 4), groupsize=q_kwargs.get("groupsize", 128), has_weight_zeros=q_kwargs.get("has_weight_zeros", False), + squeeze_unsqueeze_dim0=True, ) else: quant_handler = ao_quantizer_class_dict[quantizer]( @@ -610,5 +594,32 @@ def quantized_model(self) -> nn.Module: ao_quantizer_class_dict = { "linear:int4": Int4WeightOnlyQuantizer, 
"linear:a8w4dq": Int8DynActInt4WeightQuantizer, - "linear:a8wlow": Int8DynActLowbitWeightQuantizer, } + +try: + import os + torchao_build_path = f"{os.getcwd()}/torchao-build" + + # Load quantizer + torchao_experimental_spec = importlib.util.spec_from_file_location( + "torchao_experimental", + f"{torchao_build_path}/src/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py", + ) + torchao_experimental = importlib.util.module_from_spec(torchao_experimental_spec) + sys.modules["torchao_experimental"] = torchao_experimental + torchao_experimental_spec.loader.exec_module(torchao_experimental) + from torchao_experimental import Int8DynActLowbitWeightQuantizer + ao_quantizer_class_dict["linear:a8wlow"] = Int8DynActLowbitWeightQuantizer + + # Try loading custom op + try: + import glob + libs = glob.glob(f"{torchao_build_path}/cmake-out/liblowbit_op_aten.*") + libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) + torch.ops.load_library(libs[0]) + except Exception as e: + print("Failed to load custom ops : ", e) + print("Slow fallback kernels will be used.") + +except Exception as e: + print(f"Failed to use torchao_experimental kernels: {e}") diff --git a/runner/aoti.cmake b/runner/aoti.cmake index 798b9da46..24ec6e505 100644 --- a/runner/aoti.cmake +++ b/runner/aoti.cmake @@ -30,4 +30,4 @@ if(Torch_FOUND) endif() -target_link_libraries(aoti_run "/tmp/cmake-out/torchao/libtorch_custom_op.dylib") +target_link_libraries(aoti_run "${TORCHCHAT_ROOT}/torchao-build/cmake-out/liblowbit_op_aten${CMAKE_SHARED_LIBRARY_SUFFIX}") diff --git a/runner/et.cmake b/runner/et.cmake index 7fc16b1f2..5fe852abd 100644 --- a/runner/et.cmake +++ b/runner/et.cmake @@ -106,6 +106,7 @@ if(executorch_FOUND) target_link_libraries(et_run PRIVATE "$") + # This one is needed for cpuinfo where it uses android specific log lib if(ANDROID) target_link_libraries(et_run PRIVATE log) @@ -129,3 +130,5 @@ if(executorch_FOUND) else() MESSAGE(WARNING "ExecuTorch package not found") endif() + +target_link_libraries(et_run PRIVATE "${TORCHCHAT_ROOT}/torchao-build/cmake-out/liblowbit_op_executorch${CMAKE_SHARED_LIBRARY_SUFFIX}") diff --git a/scripts/build_native.sh b/scripts/build_native.sh index 6ceea0aee..7e7574aed 100755 --- a/scripts/build_native.sh +++ b/scripts/build_native.sh @@ -60,6 +60,10 @@ if [ -z "${ET_BUILD_DIR}" ]; then ET_BUILD_DIR="et-build" fi +if [ -z "${TORCHAO_BUILD_DIR}" ]; then + TORCHAO_BUILD_DIR="torchao-build" +fi + source "$TORCHCHAT_ROOT/scripts/install_utils.sh" pushd ${TORCHCHAT_ROOT} @@ -70,6 +74,10 @@ if [[ "$TARGET" == "et" ]]; then install_pip_dependencies clone_executorch install_executorch_libs false + + EXECUTORCH_INCLUDE_DIRS=${TORCHCHAT_ROOT}/et-build/src + EXECUTORCH_LIBRARIES=${TORCHCHAT_ROOT}/et-build/install/lib/libexecutorch_no_prim_ops.a + install_torchao_custom_executorch_ops fi popd diff --git a/scripts/build_torchao_custom_ops.sh b/scripts/build_torchao_custom_ops.sh new file mode 100644 index 000000000..cf0626afb --- /dev/null +++ b/scripts/build_torchao_custom_ops.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +if [ -z "${TORCHCHAT_ROOT}" ]; then + # Get the absolute path of the current script + SCRIPT_PATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" + # Get the absolute path of the parent directory + TORCHCHAT_ROOT="$(dirname "$SCRIPT_PATH")" +fi + +if [ -z "${TORCHAO_BUILD_DIR}" ]; then + TORCHAO_BUILD_DIR="torchao-build" +fi + +source "$TORCHCHAT_ROOT/scripts/install_utils.sh" + +find_cmake_prefix_path +clone_torchao +install_torchao_custom_aten_ops diff --git a/scripts/install_utils.sh b/scripts/install_utils.sh index 3b3ad4926..a7badf720 100644 --- a/scripts/install_utils.sh +++ b/scripts/install_utils.sh @@ -124,3 +124,54 @@ install_executorch_libs() { install_executorch_python_libs $1 } + +clone_torchao() { + echo "Cloning torchao to ${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src" + rm -rf ${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src + mkdir -p ${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src + pushd ${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src + echo $pwd + + cp -R /Users/scroy/fbsource/fbcode/pytorch/ao . + # git clone https://github.com/pytorch/ao.git + # cd ao + # git checkout $(cat ${TORCHCHAT_ROOT}/.pins/torchao-pin.txt) + + popd +} + +install_torchao_custom_aten_ops() { + echo "Installing custom torchao ops" + pushd ${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op + export TORCHAO_INCLUDE_DIRS=${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src/ao + + if [ "${CMAKE_OUT_DIR}" == "" ]; then + CMAKE_OUT_DIR="${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/cmake-out" + fi + + cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ + -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \ + -DPLATFORM="ATEN" \ + -S . \ + -B ${CMAKE_OUT_DIR} -G Ninja + cmake --build ${CMAKE_OUT_DIR} +} + +install_torchao_custom_executorch_ops() { + echo "Installing custom torchao ops" + pushd ${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src/ao/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op + export TORCHAO_INCLUDE_DIRS=${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/src/ao + + if [ "${CMAKE_OUT_DIR}" == "" ]; then + CMAKE_OUT_DIR="${TORCHCHAT_ROOT}/${TORCHAO_BUILD_DIR}/cmake-out" + fi + + cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ + -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \ + -DEXECUTORCH_INCLUDE_DIRS=${EXECUTORCH_INCLUDE_DIRS} \ + -DEXECUTORCH_LIBRARIES=${EXECUTORCH_LIBRARIES} \ + -DPLATFORM="EXECUTORCH" \ + -S . 
\ + -B ${CMAKE_OUT_DIR} -G Ninja + cmake --build ${CMAKE_OUT_DIR} +} From a71ec16ed4c8f6c39662bde947f99c1d098c78ab Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:38:08 -0700 Subject: [PATCH 8/8] formatting fixes --- generate.py | 1 + quantization/quantize.py | 8 ++++---- runner/et.cmake | 1 - 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/generate.py b/generate.py index b32ecba0d..57004b057 100644 --- a/generate.py +++ b/generate.py @@ -30,6 +30,7 @@ from cli import add_arguments_for_verb, arg_init, check_args from utils.device_info import get_device_info + class _ChatFormatter(ABC): def __init__(self, tokenizer): self.tokenizer = tokenizer diff --git a/quantization/quantize.py b/quantization/quantize.py index 0b54d73aa..a1232327e 100644 --- a/quantization/quantize.py +++ b/quantization/quantize.py @@ -22,11 +22,7 @@ # from __future__ import annotations -# torchao_experimental -import importlib.util - import json -import sys # from functools import reduce # from math import gcd @@ -49,6 +45,7 @@ ) from torchao.utils import unwrap_tensor_subclass + ######################################################################### ### torchchat quantization API ### @@ -578,6 +575,7 @@ def quantize(self, module): def quantized_model(self) -> nn.Module: return self.quantize(self.model_) + ########################################################################## ### quantization dictionary ### @@ -597,6 +595,8 @@ def quantized_model(self) -> nn.Module: } try: + import importlib.util + import sys import os torchao_build_path = f"{os.getcwd()}/torchao-build" diff --git a/runner/et.cmake b/runner/et.cmake index 5fe852abd..6eeb8f59b 100644 --- a/runner/et.cmake +++ b/runner/et.cmake @@ -106,7 +106,6 @@ if(executorch_FOUND) target_link_libraries(et_run PRIVATE "$") - # This one is needed for cpuinfo where it uses android specific log lib if(ANDROID) target_link_libraries(et_run PRIVATE log)
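
---

Appendix: sketches of the patterns used in this series. The code below is
illustrative only, written against the patches above; helper names, file
paths, and anything not shown in the diffs are assumptions, not torchchat
or torchao API.

The custom.json configs added in patch 1 (and removed again in patch 3)
show the shape of a quantization recipe: each top-level key names a handler
in the dispatch dictionaries, and its value is that handler's kwargs. A
sketch of how such a recipe is walked, assuming that mapping:

import json

recipe = json.loads("""
{
  "executor": {"accelerator": "cpu"},
  "precision": {"dtype": "fp32"},
  "embedding": {"bitwidth": 4, "groupsize": 32}
}
""")

for quantizer, q_kwargs in recipe.items():
    # In torchchat, quantize_model looks up the handler class for
    # `quantizer` and constructs it with the kwargs dict `q_kwargs`.
    print(f"apply {quantizer} with {q_kwargs}")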
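The timers added to generate.py in patch 4 deliberately exclude the first
generated tokens so that torch.compile's one-time compilation cost does not
distort the tokens/sec figure. The same idea as a simplified standalone
sketch; decode_step is a stand-in for the real per-token decode:

import time

def timed_decode(decode_step, num_new_tokens):
    t0 = None
    for i in range(num_new_tokens):
        decode_step()
        if i == 0:
            # Start the clock only after the first token, whose latency may
            # be dominated by torch.compile compilation rather than decoding.
            t0 = time.time()
    measured = num_new_tokens - 1  # tokens generated while the clock ran
    if t0 is None or measured <= 0:
        return
    elapsed = time.time() - t0
    print(f"Tokens/sec over {measured} tokens: {measured / elapsed:.2f}")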
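Several patches load the experimental torchao quantizer straight from a
source checkout with importlib rather than from an installed package. The
pattern in isolation, with the file path as a parameter:

import importlib.util
import sys

def load_module_from_path(module_name, file_path):
    # Build a module spec from an explicit file path instead of searching
    # sys.path, as the spec_from_file_location calls in the patches do.
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register in sys.modules before executing so imports inside the module
    # (and later `from module_name import ...` statements) resolve.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module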
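The compiled custom op is discovered at runtime the same way. A sketch of
the glob-and-load idiom from the final version of quantization/quantize.py,
with the build directory as a parameter:

import glob

import torch

def try_load_custom_ops(build_dir):
    # Find the platform-specific shared library (.so on Linux, .dylib on
    # macOS) and register its operators with PyTorch.
    libs = [
        lib
        for lib in glob.glob(f"{build_dir}/liblowbit_op_aten.*")
        if lib.endswith(("so", "dylib"))
    ]
    if not libs:
        print("Custom op library not found; slow fallback kernels will be used.")
        return False
    torch.ops.load_library(libs[0])
    return True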
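quantize_model special-cases "linear:a8wlow" because its constructor takes
bitwidth and has_weight_zeros in addition to groupsize. A simplified sketch
of that dispatch, using the defaults (4, 128, False) from patch 5:

def make_quant_handler(quantizer, q_kwargs, classes, device, precision):
    cls = classes[quantizer]
    if quantizer == "linear:a8wlow":
        # The low-bit quantizer takes extra knobs, each with a default.
        return cls(
            device=device,
            precision=precision,
            bitwidth=q_kwargs.get("bitwidth", 4),
            groupsize=q_kwargs.get("groupsize", 128),
            has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
        )
    # Other ao quantizers take only a group size.
    return cls(groupsize=q_kwargs["groupsize"], device=device, precision=precision)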
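Finally, patch 7 makes the whole integration optional: the quantizer is
registered under "linear:a8wlow" only if the torchao sources import cleanly,
and a missing shared library merely downgrades to slower fallback kernels.
The guarded-registration idiom in isolation; the dict and loader below are
stand-ins for ao_quantizer_class_dict and the import above:

quantizer_class_dict = {}

def register_optional(name, loader):
    # Try to produce the quantizer class; on any failure, log and leave the
    # dispatch table unchanged so the remaining quantizers keep working.
    try:
        quantizer_class_dict[name] = loader()
    except Exception as e:
        print(f"Failed to register {name}: {e}")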