diff --git a/common/common.cpp b/common/common.cpp index 94f545f815c27..93c6bbba491b6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -10,6 +10,7 @@ #include "llama.h" #include +#include #include #include #include diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2bf97475f78dd..5a89c680e7a20 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -27,6 +27,7 @@ if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf +from gguf.tmac_utils import get_quantization_config, preprocess_for_t_mac, is_tmac_ftype, derive_ftype_from_quantization_config logger = logging.getLogger("hf-to-gguf") @@ -66,6 +67,7 @@ class Model: metadata_override: Path | None dir_model_card: Path remote_hf_model_id: str | None + enable_t_mac: bool # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -74,7 +76,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None): + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, + enable_t_mac: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -109,17 +112,27 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py + self.enable_t_mac = enable_t_mac + + # Load model quantization config + self.quantization_config: dict[str, Any] = get_quantization_config(self.dir_model) # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: - # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. - _, first_tensor = next(self.get_tensors()) - if first_tensor.dtype == torch.float16: - logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") - self.ftype = gguf.LlamaFileType.MOSTLY_F16 + if self.enable_t_mac: + ftype = derive_ftype_from_quantization_config(self.quantization_config) + logger.info(f"choosing --outtype {ftype} from quantization config") + if ftype is not None: + self.ftype = ftype else: - logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") - self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. + _, first_tensor = next(self.get_tensors()) + if first_tensor.dtype == torch.float16: + logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") + self.ftype = gguf.LlamaFileType.MOSTLY_F16 + else: + logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") + self.ftype = gguf.LlamaFileType.MOSTLY_BF16 # Configure GGUF Writer self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, @@ -280,6 +293,184 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + _gptq_quant_dict: dict[str, Tensor] | None = None + _t_mac_raw_shape: tuple[int, ...] | None = None + + # Repack and merge qweight, scales, and qzeros into a single tensor + # Currently, this logic is nearly impossible to be implemented in quants.py + def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Convert unsupported bfloat16 to float32 + if data_torch.dtype == torch.bfloat16: + data_torch = data_torch.to(torch.float32) + + if not self.enable_t_mac or isinstance(self, BitnetModel): + return self.modify_tensors(data_torch, name, bid) + + self._t_mac_raw_shape = None # reset to make sure old values don't leak into new tensors case + if self.quantization_config["quant_method"] == "gptq": # AutoGPTQ/GPTQModel + if name.endswith(".g_idx"): + return [] + + if name.endswith(".qweight") or name.endswith(".scales") or name.endswith(".qzeros"): + if self._gptq_quant_dict is None: + self._gptq_quant_dict = {} + suffix = "." + name.split(".")[-1] + base_name = name.replace(suffix, "") + self._gptq_quant_dict.setdefault(base_name, {})[suffix] = data_torch + if len(self._gptq_quant_dict[base_name]) < 3: + return [] + + # Get weight components: all [out_feature, in_feature] + qweight = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qweight"]).numpy() + scales = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".scales"]).numpy() + qzeros = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qzeros"]).numpy() + name = base_name + ".weight" + from gguf.tmac_utils import unpack_gptqv2 + w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros, "gptqmodel" in self.quantization_config["quantizer"]) + if bits != self.quantization_config["bits"] or group_size != self.quantization_config["group_size"]: + # Currently, we only support models that all weights are corresponding to the quantization config. + raise ValueError("Error while parsing weights for quantization_config: {}, but got bits={} and group_size={}".format( + self.quantization_config, bits, group_size)) + self._t_mac_raw_shape = w.shape + + # For permutation in, e.g., LlamaModel + w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy() + scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy() + zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy() + + if self.quantization_config["bits"] > 0: + if self.quantization_config["sym"]: + if not np.allclose(zeros, np.zeros_like(zeros)): + logger.warning("Although the quantized model claimed to be symmetric, the weights are asymmetric") + else: + zeros = None + data_torch = torch.from_numpy(preprocess_for_t_mac(w, scales, zeros, bits=bits)) + else: + # TODO: Here should not be reached? + old_shape = w.shape + w = w.astype("float32").reshape(-1, group_size) + scales = scales.astype("float32").reshape(-1, 1) + zeros = zeros.astype("float32").reshape(-1, 1) + data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales + data_torch = torch.from_numpy(data.reshape(old_shape)) + if self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_torch = data_torch.to(torch.float16) + + return [(self.map_tensor_name(name), data_torch)] + elif self.quantization_config["quant_method"] == "bitdistiller": + new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias")) + extra_f32 = any(self.match_model_tensor_name(new_name, key, bid) for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + )) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + data = data_torch.numpy() + n_dims = len(data.shape) + extra_f16 = any(cond for cond in ( + (name.endswith(".weight") and n_dims >= 2), + )) + + do_modify = False + if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: + if is_tmac_ftype(self.ftype) and any(self.match_model_tensor_name(new_name, key, bid) for key in [ + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_QKV, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.FFN_UP, + gguf.MODEL_TENSOR.FFN_DOWN, + gguf.MODEL_TENSOR.FFN_GATE, + ]): + do_modify = True + else: + do_modify = False + + # logger.debug(f"gguf: quantizing tensor {name} to {self.ftype.name}. \tbits = {self.quantization_config['bits']}," + + # f"\tgroup_size = {self.quantization_config['group_size']}, \tsym = {self.quantization_config['sym']}. \tdo_modify = {do_modify}") + + if do_modify: + bits = self.quantization_config["bits"] + group_size = self.quantization_config["group_size"] + w, scales, zeros = self._t_mac_quantize_tensor_bitdistiller( + LazyTorchTensor.to_eager(data_torch), + n_bit=bits, + zero_point=True, + q_group_size=group_size, + ) + self._t_mac_raw_shape = w.shape + + # For permutation in, e.g., LlamaModel + w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy() + scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy() + zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy() + + if is_tmac_ftype(self.ftype): + if self.quantization_config["sym"]: + if not np.allclose(zeros, np.zeros_like(zeros)): + logger.warning("Although the quantized model claimed to be symmetric, the weights are asymmetric") + else: + zeros = None + data_torch = torch.from_numpy(preprocess_for_t_mac(w, scales, zeros, bits=bits)) + else: + old_shape = w.shape + w = w.astype("float32").reshape(-1, group_size) + scales = scales.astype("float32").reshape(-1, 1) + zeros = zeros.astype("float32").reshape(-1, 1) + data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales + data_torch = torch.from_numpy(data.reshape(old_shape)) + if self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_torch = data_torch.to(torch.float16) + + return [(self.map_tensor_name(name), data_torch)] + + return self.modify_tensors(data_torch, name, bid) + + # Modified version of BitDistiller pseudo_quantize_tensor + # core quantization method (simulated quantization) + def _t_mac_quantize_tensor_bitdistiller(self, w, n_bit=8, zero_point=True, q_group_size=-1): + org_w_shape = w.shape + if q_group_size > 0: + assert org_w_shape[-1] % q_group_size == 0 + w = w.reshape(-1, q_group_size) + elif q_group_size == -1: + w = w.reshape(-1, w.shape[-1]) + assert w.dim() == 2 + if zero_point: + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2 ** n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + else: # we actually never used this + max_val = w.abs().amax(dim=1, keepdim=True) + max_val = max_val.clamp(min=1e-5) + max_int = 2 ** (n_bit - 1) - 1 + min_int = - 2 ** (n_bit - 1) + scales = max_val / max_int + zeros = 0 + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) + + w = w.reshape(org_w_shape).numpy() + scales = scales.numpy().reshape(w.shape[0], -1) + zeros = zeros.numpy().reshape(w.shape[0], -1) if zero_point else None + + if zero_point: + w = w.astype(np.uint8) + zeros = (zeros - (2 ** (n_bit - 1))) * scales + return w, scales, zeros + else: + w = (w - min_int).astype(np.uint8) + return w, scales, zeros + + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -300,7 +491,7 @@ def prepare_tensors(self): old_dtype = data_torch.dtype # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): + if data_torch.dtype not in (torch.float16, torch.float32) and not self.enable_t_mac: data_torch = data_torch.to(torch.float32) # use the first number-like part of the tensor name as the block id @@ -310,7 +501,13 @@ def prepare_tensors(self): bid = int(part) break - for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + for new_name, data_torch in (self._modify_tensors(data_torch, name, bid)): + # Some GPTQ models have empty bias tensors which are not in the model architecture. + # These tensors will cause tensor number check to fail, so we have to skip them. + if self.enable_t_mac and new_name.endswith(".bias") and np.all(LazyTorchTensor.to_eager(data_torch).numpy() == 0): + logger.info(f"Skipping empty bias tensor: {new_name}") + continue + # TODO: why do we squeeze here? # data = data_torch.squeeze().numpy() data = data_torch.numpy() @@ -364,6 +561,29 @@ def prepare_tensors(self): # TODO: use Q4_K and Q6_K data_qtype = gguf.GGMLQuantizationType.F16 + # If _t_mac_raw_shape is not None, the tensor is quantized by GPTQ + if self.enable_t_mac and self._t_mac_raw_shape is not None: + if self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_BN_0: + data_qtype = gguf.GGMLQuantizationType.TMAC_BN_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G64_0: + data_qtype = gguf.GGMLQuantizationType.TMAC_W2G64_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G64_1: + data_qtype = gguf.GGMLQuantizationType.TMAC_W2G64_1 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G128_0: + data_qtype = gguf.GGMLQuantizationType.TMAC_W2G128_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G128_1: + data_qtype = gguf.GGMLQuantizationType.TMAC_W2G128_1 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G64_0: + data_qtype = gguf.GGMLQuantizationType.TMAC_W4G64_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G64_1: + data_qtype = gguf.GGMLQuantizationType.TMAC_W4G64_1 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G128_0: + data_qtype = gguf.GGMLQuantizationType.TMAC_W4G128_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G128_1: + data_qtype = gguf.GGMLQuantizationType.TMAC_W4G128_1 + else: + raise ValueError(f"Unsupported ftype: {self.ftype}") + # No override (data_qtype is False), or wants to be quantized (data_qtype is True) if isinstance(data_qtype, bool): if self.ftype == gguf.LlamaFileType.ALL_F32: @@ -378,6 +598,12 @@ def prepare_tensors(self): data_qtype = gguf.GGMLQuantizationType.TQ1_0 elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0: data_qtype = gguf.GGMLQuantizationType.TQ2_0 + elif is_tmac_ftype(self.ftype): + # If the tensor is successfully quantized, data_qtype should be TMAC_* + # If data_qtype is still bool, then the tensor should not be quantized + # In practice, this tensor is `output.weight` for GPTQ models + # TODO: Consider quantizing it? + data_qtype = gguf.GGMLQuantizationType.F16 else: raise ValueError(f"Unknown file type: {self.ftype.name}") @@ -388,15 +614,17 @@ def prepare_tensors(self): data_qtype = gguf.GGMLQuantizationType.F16 data = gguf.quants.quantize(data, data_qtype) - shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + shape = self._t_mac_raw_shape or (gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape) # reverse shape to make it similar to the internal ggml dimension order shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" # n_dims is implicit in the shape - logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}, data = {data.shape}") - self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) + raw_shape = gguf.quant_shape_to_byte_shape(self._t_mac_raw_shape, data_qtype) if is_tmac_ftype(self.ftype) and self._t_mac_raw_shape else None + self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype, raw_shape=raw_shape) def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.MODEL) @@ -2046,6 +2274,7 @@ def weight_quant(self, weight: Tensor) -> Tensor: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: new_name = self.map_tensor_name(name) + self._t_mac_raw_shape = None if any(self.match_model_tensor_name(new_name, key, bid) for key in [ gguf.MODEL_TENSOR.ATTN_Q, gguf.MODEL_TENSOR.ATTN_K, @@ -2055,8 +2284,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter gguf.MODEL_TENSOR.FFN_DOWN, gguf.MODEL_TENSOR.FFN_GATE, ]): + # TODO: apply latest updates # transform weight into 1/0/-1 (in fp32) data_torch = self.weight_quant(data_torch) + from gguf.tmac_utils import is_tmac_ftype + if self.enable_t_mac and is_tmac_ftype(self.ftype): + # transform weight into TMAC_BN_0 format + data = LazyTorchTensor.to_eager(data_torch).numpy() + scale = np.max(np.abs(data)) + w = np.round(data / scale + 2).astype(np.uint8) + data_torch = torch.from_numpy(preprocess_for_t_mac(w, scale.reshape(1), bits=2)) + self.quantization_config["bits"] = 2 + self.quantization_config["group_size"] = -1 + self.quantization_config["sym"] = True + self.quantization_config["quant_method"] = "bitnet" + self._t_mac_raw_shape = w.shape yield (new_name, data_torch) @@ -5394,6 +5636,7 @@ class LazyTorchTensor(gguf.LazyBase): _dtype_map: dict[torch.dtype, type] = { torch.float16: np.float16, torch.float32: np.float32, + torch.bfloat16: np.float32, } # used for safetensors slices @@ -5469,8 +5712,11 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "tmac_bn_0", "tmac_w2g64_0", "tmac_w2g64_1", + "tmac_w2g128_0", "tmac_w2g128_1", "tmac_w4g64_0", "tmac_w4g64_1", "tmac_w4g128_0", + "tmac_w4g128_1", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, " + "and tmac_bn_0 for bitnet, tmac_wXgY_0/1 for GPTQ, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( "--bigendian", action="store_true", @@ -5525,6 +5771,10 @@ def parse_args() -> argparse.Namespace: "--remote", action="store_true", help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", ) + parser.add_argument( + "--enable-t-mac", action="store_true", + help="Enable T-MAC quantization format (disabled by default). Support TMAC_*, Q4_0, TQ types, and GPTQ, GPTQv2, BitNet and BitDistiller models." + ) args = parser.parse_args() if not args.print_supported_models and args.model is None: @@ -5584,6 +5834,15 @@ def main() -> None: "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, + "tmac_bn_0": gguf.LlamaFileType.MOSTLY_TMAC_BN_0, + "tmac_w2g64_0": gguf.LlamaFileType.MOSTLY_TMAC_W2G64_0, + "tmac_w2g64_1": gguf.LlamaFileType.MOSTLY_TMAC_W2G64_1, + "tmac_w2g128_0": gguf.LlamaFileType.MOSTLY_TMAC_W2G128_0, + "tmac_w2g128_1": gguf.LlamaFileType.MOSTLY_TMAC_W2G128_1, + "tmac_w4g64_0": gguf.LlamaFileType.MOSTLY_TMAC_W4G64_0, + "tmac_w4g64_1": gguf.LlamaFileType.MOSTLY_TMAC_W4G64_1, + "tmac_w4g128_0": gguf.LlamaFileType.MOSTLY_TMAC_W4G128_0, + "tmac_w4g128_1": gguf.LlamaFileType.MOSTLY_TMAC_W4G128_1, "auto": gguf.LlamaFileType.GUESSED, } @@ -5620,7 +5879,8 @@ def main() -> None: split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, small_first_shard=args.no_tensor_first_split, - remote_hf_model_id=str(args.model) if args.remote else None) + remote_hf_model_id=str(args.model) if args.remote else None, + enable_t_mac=args.enable_t_mac) if args.vocab_only: logger.info("Exporting model vocab...") diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index d33f843b417cf..91308ab93cd4b 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -208,6 +208,8 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING # toolchain for vulkan-shaders-gen set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") +option(GGML_TMAC "ggml: use TMAC" OFF) + # extra artifacts option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) @@ -217,6 +219,9 @@ option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) # set(CMAKE_C_STANDARD 11) +if (GGML_TMAC) + set(CMAKE_C_STANDARD 17) +endif() set(CMAKE_C_STANDARD_REQUIRED true) set(CMAKE_CXX_STANDARD 17) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index f5e11f1e10002..dfbd97adf4528 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -57,6 +57,8 @@ extern "C" { GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool); GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); + GGML_BACKEND_API void ggml_threadpool_atomic_store_explicit(struct ggml_threadpool * threadpool, int value); + GGML_BACKEND_API int ggml_threadpool_atomic_fetch_add_explicit(struct ggml_threadpool * threadpool, int value); // ggml_graph_plan() has to be called before ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data @@ -120,12 +122,12 @@ extern "C" { GGML_BACKEND_API void ggml_cpu_init(void); + GGML_BACKEND_API void ggml_cpu_tmac_init(const char * fname); + // // CPU backend // - GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void); - GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend); GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8fcc16df998be..b9640374260e3 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -388,7 +388,16 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_4 = 36, // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, - GGML_TYPE_COUNT = 39, + GGML_TYPE_TMAC_BN_0 = 39, + GGML_TYPE_TMAC_W2G64_0 = 40, + GGML_TYPE_TMAC_W2G64_1 = 41, + GGML_TYPE_TMAC_W2G128_0 = 42, + GGML_TYPE_TMAC_W2G128_1 = 43, + GGML_TYPE_TMAC_W4G64_0 = 44, + GGML_TYPE_TMAC_W4G64_1 = 45, + GGML_TYPE_TMAC_W4G128_0 = 46, + GGML_TYPE_TMAC_W4G128_1 = 47, + GGML_TYPE_COUNT = 48, }; // precision diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index f00700da71fcd..6f22e83da1446 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -196,6 +196,7 @@ add_library(ggml-base ggml.c ggml-alloc.c ggml-backend.cpp + ggml-common.h ggml-opt.cpp ggml-threading.cpp ggml-threading.h diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index e73a3b69b5da2..873a6e4cedb89 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -22,6 +22,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/amx/amx.h ggml-cpu/amx/mmq.cpp ggml-cpu/amx/mmq.h + ggml-cpu/tmac/tmac.cpp + ggml-cpu/tmac/tmac.h + ggml-cpu/tmac/lut_mul_mat.cpp + ggml-cpu/tmac/lut_mul_mat.h + ggml-cpu/tmac/lut_ctor.cpp + ggml-cpu/tmac/lut_ctor.h + ggml-cpu/tmac/tbl.cpp + ggml-cpu/tmac/tbl.h ggml-cpu/ggml-cpu-impl.h ggml-cpu/common.h ggml-cpu/binary-ops.h @@ -72,6 +80,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/llamafile/sgemm.h) endif() + if (GGML_TMAC) + target_compile_definitions(${GGML_CPU_NAME} PUBLIC GGML_USE_TMAC) + target_include_directories(${GGML_CPU_NAME} PUBLIC ggml-cpu/tmac) + get_target_property(cdefs ${GGML_CPU_NAME} COMPILE_DEFINITIONS) + message(STATUS "GGML_CPU_NAME: ${GGML_CPU_NAME} COMPILE_DEFINITIONS: ${cdefs}") + + if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR + (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) + message(FATAL_ERROR "Clang is required for T-MAC compilation") + endif() + + if (GGML_TMAC_RECHUNK) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE TMAC_RECHUNK) + endif() + endif() + if (GGML_CPU_HBM) find_library(memkind memkind REQUIRED) @@ -145,6 +169,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH}) endif() endif() + if (GGML_TMAC) + # ARM Windows with LLVM clang GNU interface + # We need fullfp16 for T-MAC + # TODO: check_cxx_source_compiles + list(APPEND ARCH_FLAGS -march=armv8.2a+fp16) + endif() # show enabled features if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") @@ -181,7 +211,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_NATIVE) include(ggml-cpu/cmake/FindSIMD.cmake) endif () - if (GGML_AVX512) + # Can't use GGML_AVX512 with T-MAC and Clang for MSVC + # with error: conflicting types for '_m_prefetchw + if (GGML_AVX512 AND (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") AND (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) list(APPEND ARCH_FLAGS /arch:AVX512) # /arch:AVX512 includes: __AVX512F__, __AVX512CD__, __AVX512BW__, __AVX512DQ__, and __AVX512VL__ # MSVC has no compile-time flags enabling specific @@ -323,6 +355,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE}) endif() endif() + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" AND GGML_TMAC) + # We need fullfp16 for T-MAC + # TODO: we need to simplify this logic through check_cxx_source_compiles or Presets? + check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8) + if (GGML_COMPILER_SUPPORT_MATMUL_INT8) + # Device with armv8.7a+ cpu, e.g., WSL on Surface Laptop 7 + # based on arm64-windows-llvm.cmake + list(APPEND ARCH_FLAGS -march=armv8.7-a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only) + add_compile_definitions(__ARM_FEATURE_MATMUL_INT8) + else () + # Jetson AGX Orin, Raspberry Pi 5 + list(APPEND ARCH_FLAGS -march=armv8.2a+fp16) + endif () elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") message(STATUS "loongarch64 detected") diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 50400328738ef..a918743df5e07 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_TMAC +#include "tmac.h" +#endif + #if defined(_MSC_VER) // disable "possible loss of data" to avoid hundreds of casts // we should just be careful :) @@ -373,6 +377,51 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_TMAC_BN_0] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W2G64_0] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W2G64_1] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W2G128_0] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W2G128_1] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W4G64_0] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W4G64_1] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W4G128_0] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_TMAC_W4G128_1] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { @@ -2639,6 +2688,14 @@ void ggml_threadpool_resume(struct ggml_threadpool * threadpool) { #endif } +void ggml_threadpool_atomic_store_explicit(struct ggml_threadpool * threadpool, int value) { + atomic_store_explicit(&threadpool->current_chunk, value, memory_order_relaxed); +} + +int ggml_threadpool_atomic_fetch_add_explicit(struct ggml_threadpool * threadpool, int value) { + return (int)atomic_fetch_add_explicit(&threadpool->current_chunk, value, memory_order_relaxed); +} + struct ggml_cplan ggml_graph_plan( const struct ggml_cgraph * cgraph, int n_threads, @@ -3406,6 +3463,10 @@ void ggml_cpu_init(void) { ggml_init_arm_arch_features(); #endif +#ifdef GGML_USE_TMAC + ggml_tmac_init(); +#endif + is_first_call = false; } diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 4b688a67eb23b..74e32c406fc9a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -5,6 +5,7 @@ #include "ggml-cpu-traits.h" #include "ggml-impl.h" #include "amx/amx.h" +#include "tmac/tmac.h" #include #include @@ -43,6 +44,12 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type } #endif +#ifdef GGML_USE_TMAC + if (ggml_backend_tmac_buffer_type()) { + bufts.push_back(ggml_backend_tmac_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_KLEIDIAI if (ggml_backend_cpu_kleidiai_buffer_type()) { bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6050147be70ac..e579ae69d9d11 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4978,6 +4978,15 @@ void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_TMAC_BN_0: + case GGML_TYPE_TMAC_W2G64_0: + case GGML_TYPE_TMAC_W2G64_1: + case GGML_TYPE_TMAC_W2G128_0: + case GGML_TYPE_TMAC_W2G128_1: + case GGML_TYPE_TMAC_W4G64_0: + case GGML_TYPE_TMAC_W4G64_1: + case GGML_TYPE_TMAC_W4G128_0: + case GGML_TYPE_TMAC_W4G128_1: case GGML_TYPE_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/tmac/lut_ctor.cpp b/ggml/src/ggml-cpu/tmac/lut_ctor.cpp new file mode 100644 index 0000000000000..c926624fc3c50 --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/lut_ctor.cpp @@ -0,0 +1,272 @@ +#include "lut_ctor.h" + +#include + +#if defined __AVX2__ +static inline float _mm256_addv_ps(const __m256 v) { + __m128 res = _mm256_extractf128_ps(v, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(v)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} +#endif + + +// Current implementation requires (K * 4) == act_group_size and K >= 8 +// s0 = -1, s1 = 1 +// TODO: loop K +// Still preserve FastAggregationK althougth it's unused for compatibility +template +inline int32_t lut_ctor_g4_int8_impl(int32_t act_k, int8_t* qlut, tmac_float_type* b, tmac_float_type* lut_scales, tmac_float_type* lut_biases) { +#ifdef __ARM_NEON + float16x8_t vec_lut[16]; + float16_t biases = 0.0; + float16_t scales = *lut_scales; + float16_t t_scales = scales ? 1.0 / scales : 0.0; + + for (int k = 0; k < act_k / 32; ++k) { + float16x8x4_t vec_bs = vld4q_f16(b + k * 32); + +#pragma unroll + for (int g = 1; g < 16; g += 2) { + vec_lut[g] = vec_bs.val[0]; + if (g & 0b0010) { + vec_lut[g] = vec_lut[g] + vec_bs.val[1]; + } else { + vec_lut[g] = vec_lut[g] - vec_bs.val[1]; + } + if (g & 0b0100) { + vec_lut[g] = vec_lut[g] + vec_bs.val[2]; + } else { + vec_lut[g] = vec_lut[g] - vec_bs.val[2]; + } + if (g & 0b1000) { + vec_lut[g] = vec_lut[g] + vec_bs.val[3]; + } else { + vec_lut[g] = vec_lut[g] - vec_bs.val[3]; + } + } +#pragma unroll + for (int g = 0; g < 16; g += 2) { + vec_lut[g] = -vec_lut[15 - g]; + } + + biases += vaddvq_f16(vec_lut[0]); +#undef vaddvq_f16 + +#pragma unroll + for (int g = 0; g < 16; ++g) { + vec_lut[g] = vmulq_n_f16(vec_lut[g], t_scales); + } + + int8x8_t vec_qlut[16]; +#pragma unroll + for (int g = 0; g < 16; ++g) { + vec_qlut[g] = vqmovn_s16(vcvtnq_s16_f16(vec_lut[g])); + } + +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + g, vec_qlut[g], 0); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 + g, vec_qlut[g], 1); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 * 2 + g, vec_qlut[g], 2); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 * 3 + g, vec_qlut[g], 3); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 * 4 + g, vec_qlut[g], 4); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 * 5 + g, vec_qlut[g], 5); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 * 6 + g, vec_qlut[g], 6); + } +#pragma unroll + for (int g = 0; g < 16; ++g) { + vst1_lane_s8(qlut + k * 8 * 16 + 16 * 7 + g, vec_qlut[g], 7); + } + } +#elif defined __AVX2__ + __m256 vec_lut[16]; + float biases = 0.0; + const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); + float scales = *lut_scales; + float t_scales = scales ? 1.0f / scales : 0.0f; + + for (int k = 0; k < act_k / 32; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1); + __m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1); + +#pragma unroll + for (int g = 1; g < 16; g += 2) { + vec_lut[g] = vec_b0; + if (g & 0b0010) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b1); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b1); + } + if (g & 0b0100) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b2); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b2); + } + if (g & 0b1000) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b3); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b3); + } + } +#pragma unroll + for (int g = 0; g < 16; g += 2) { + vec_lut[g] = -vec_lut[15 - g]; + } + + biases += _mm256_addv_ps(vec_lut[0]); + +#pragma unroll + for (int g = 0; g < 16; ++g) { + vec_lut[g] = _mm256_mul_ps(vec_lut[g], _mm256_set1_ps(t_scales)); + } + + __m256i vec_qlut[4]; + const __m256i shuf = _mm256_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); +#pragma unroll + for (int g = 0; g < 4; g += 1) { + __m256i i0 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 0], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i1 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 1], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i2 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 2], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i3 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 3], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + i0 = _mm256_packs_epi32(i0, i1); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32(i2, i3); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16(i0, i2); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + vec_qlut[g] = _mm256_shuffle_epi8(i0, shuf); // 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 + } + + int32_t* qlut_i32 = reinterpret_cast(qlut); +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 0 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 0); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 1 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 1); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 2 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 2); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 3 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 3); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 4 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 4); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 5 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 5); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 6 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 6); + } +#pragma unroll + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 7 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 7); + } + } +#endif + + *lut_scales = scales; + *lut_biases = biases; + + return 0; +} + + +#ifdef __cplusplus +extern "C" { +#endif + +int32_t partial_max_g4_int8_k8(void* lut_scales_, void* b_) { + tmac_float_type* lut_scales = (tmac_float_type*)lut_scales_; + tmac_float_type* b = (tmac_float_type*)b_; +#ifdef __ARM_NEON + float16x8x4_t vec_bs = vld4q_f16(b); + float16x8_t abssum = vabsq_f16(vec_bs.val[0]) + vabsq_f16(vec_bs.val[1]) + vabsq_f16(vec_bs.val[2]) + vabsq_f16(vec_bs.val[3]); + float16_t scales = vmaxvq_f16(abssum) / 127; + *lut_scales = std::max(*lut_scales, scales); +#elif defined __AVX2__ + const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); + __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1); + __m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0); + __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1); + __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2); + __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3); + __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3)); + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + float scales = _mm_cvtss_f32(max4) / 127; + *lut_scales = std::max(*lut_scales, scales); +#endif + + return 0; +} + +int32_t partial_max_reset(void* lut_scales_) { + tmac_float_type* lut_scales = (tmac_float_type*)lut_scales_; + *lut_scales = 0.0; + return 0; +} + +#ifdef __cplusplus +} +#endif + + +void lut_ctor_int8_g4(void* B, void* LUT_Scales, void* LUT_Biases, void* QLUT, int K, const struct tmac_kernel_config * const kernel_config) { + // TODO: handle bitnet here + + int act_group_size = kernel_config->act_group_size; + int bits = kernel_config->bits; + + int kk_outer_max = K / act_group_size; + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + partial_max_reset((&(((tmac_float_type*)LUT_Scales)[kk_outer]))); + for (int32_t k_outer = 0; k_outer < act_group_size / 32; ++k_outer) { + partial_max_g4_int8_k8((&(((tmac_float_type*)LUT_Scales)[kk_outer])), (&(((tmac_float_type*)B)[((kk_outer * act_group_size) + (k_outer * 32))]))); + } + } + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + if (bits == 2) { + lut_ctor_g4_int8_impl<0, 2>(act_group_size, (&(((int8_t*)QLUT)[(k_outer_1 * act_group_size * 4)])), (&(((tmac_float_type*)B)[(k_outer_1 * act_group_size)])), (&(((tmac_float_type*)LUT_Scales)[k_outer_1])), (&(((tmac_float_type*)LUT_Biases)[k_outer_1]))); + } else if (bits == 4) { + lut_ctor_g4_int8_impl<0, 4>(act_group_size, (&(((int8_t*)QLUT)[(k_outer_1 * act_group_size * 4)])), (&(((tmac_float_type*)B)[(k_outer_1 * act_group_size)])), (&(((tmac_float_type*)LUT_Scales)[k_outer_1])), (&(((tmac_float_type*)LUT_Biases)[k_outer_1]))); + } + } +} + diff --git a/ggml/src/ggml-cpu/tmac/lut_ctor.h b/ggml/src/ggml-cpu/tmac/lut_ctor.h new file mode 100644 index 0000000000000..3a9ec81c1c492 --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/lut_ctor.h @@ -0,0 +1,72 @@ +#pragma once + +/* Please do not include this header file outside ggml-cpu/tmac */ + +#ifndef INTRINSIC_TYPES_H +#define INTRINSIC_TYPES_H + +#ifdef __ARM_NEON +#include +#elif defined __AVX2__ +#include +#endif + +#ifdef __ARM_NEON +typedef float16_t tmac_float_type; +#else +#include +#include +typedef float tmac_float_type; +#endif + +#endif + + +#ifdef __ARM_NEON +#define vaddvq_f16(v) \ + ((v)[0] + (v)[1] + (v)[2] + (v)[3] + (v)[4] + (v)[5] + (v)[6] + (v)[7]) +#elif defined __AVX2__ +static inline float _mm256_addv_ps(const __m256 v); +#endif + +#define my_fputs(s) fputs(s, stderr); fflush(stderr); +#define my_fputsf(buf, s, ...) snprintf(buf, sizeof(buf), s, __VA_ARGS__); my_fputs(buf); + + +struct tmac_kernel_config { + int32_t g; + int32_t ngroups_per_elem; + int32_t q_group_size; + int32_t act_group_size; + + bool has_scale; + int kfactor; + int bits; + int actk; // should be equal to (act_group_size / g). + bool has_zero_point; + bool one_scale; + + int32_t bm; + uint32_t simd_n_in; + uint32_t simd_n_out; + + int32_t chunk_n; +}; + + + +#ifdef __cplusplus +extern "C" { +#endif + +int32_t partial_max_g4_int8_k8(void* lut_scales_, void* b_); + +int32_t partial_max_reset(void* lut_scales_); + +void lut_ctor_int8_g4(void* B, void* LUT_Scales, void* LUT_Biases, void* QLUT, int K, const struct tmac_kernel_config * const kernel_config); + +#ifdef __cplusplus +} +#endif + + diff --git a/ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp b/ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp new file mode 100644 index 0000000000000..7e64fcb013547 --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp @@ -0,0 +1,1222 @@ +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP +#include "ggml.h" +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "lut_mul_mat.h" + + +#if defined(GGML_USE_TMAC) + +namespace ggml::cpu::tmac { + bool tensor_traits::work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) { + if (ggml_tmac_can_mul_mat(op)) { + size = ggml_backend_tmac_desired_wsize(op); + return true; + } + return false; + } + + bool tensor_traits::compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) { + if (ggml_tmac_can_mul_mat(op)) { + ggml_backend_tmac_mul_mat(params, op); + return true; + }; + return false; + } +} // namespace ggml::cpu::tmac + + +/****** T-MAC properties ******/ +constexpr size_t kAllocAlignment = 64; +const int n_threads = 8; + +static tmac_tensor_extra * tmac_tensor_extras = nullptr; +static size_t tmac_tensor_extras_index = 0; + +struct tmac_run_single_kernel_settings { + int32_t test_time_ms; + int32_t M; + int32_t N; + int32_t K; + + int32_t n; + + struct tmac_kernel_config * kernel_config; +}; + +static bool initialized = false; +void tmac_init() { + if (initialized) { + return; + } + initialized = true; + + if (tmac_tensor_extras == nullptr) { + tmac_tensor_extras = new tmac_tensor_extra[GGML_TMAC_MAX_NODES]; + } + tmac_tensor_extras_index = 0; +} +void tmac_free() { + // TODO +} + +/****** T-MAC helper functions ******/ +static inline bool is_tmac_2bit_type(enum ggml_type type) { + return ( + type == GGML_TYPE_TMAC_BN_0 || + type == GGML_TYPE_TMAC_W2G64_0 || + type == GGML_TYPE_TMAC_W2G64_1 || + type == GGML_TYPE_TMAC_W2G128_0 || + type == GGML_TYPE_TMAC_W2G128_1 + ); +} + +static inline bool is_tmac_4bit_type(enum ggml_type type) { + return ( + type == GGML_TYPE_TMAC_W4G64_0 || + type == GGML_TYPE_TMAC_W4G64_1 || + type == GGML_TYPE_TMAC_W4G128_0 || + type == GGML_TYPE_TMAC_W4G128_1 + ); +} + +bool is_tmac_type(enum ggml_type type) { + return ( + is_tmac_2bit_type(type) || + is_tmac_4bit_type(type) + ); +} + +bool is_type_supported(enum ggml_type type) { + return ( + type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TQ1_0 || + type == GGML_TYPE_TQ2_0 || + is_tmac_2bit_type(type) || + is_tmac_4bit_type(type) + ); +} + +bool ggml_tmac_can_mul_mat(const struct ggml_tensor * dst) { + struct ggml_tensor * src0 = dst->src[0]; + struct ggml_tensor * src1 = dst->src[1]; + + if (dst->op == GGML_OP_MUL_MAT && + (is_type_supported(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + strcmp(src0->name, "token_embd.weight") && // means not equal + strcmp(src0->name, "output.weight")) { + return true; + } + return false; +} + +static inline int get_type_bits(enum ggml_type type) { + if (is_tmac_2bit_type(type) || type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0) { + return 2; + } else if (is_tmac_4bit_type(type) || type == GGML_TYPE_Q4_0) { + return 4; + } else { + return 0; + } +} + +static inline int get_type_group_size(enum ggml_type type) { + switch (type) { + case GGML_TYPE_TMAC_BN_0: + return -1; + case GGML_TYPE_TMAC_W2G64_0: + case GGML_TYPE_TMAC_W2G64_1: + case GGML_TYPE_TMAC_W4G64_0: + case GGML_TYPE_TMAC_W4G64_1: + return 64; + case GGML_TYPE_TMAC_W2G128_0: + case GGML_TYPE_TMAC_W2G128_1: + case GGML_TYPE_TMAC_W4G128_0: + case GGML_TYPE_TMAC_W4G128_1: + return 128; + case GGML_TYPE_Q4_0: + return 32; + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + return 256; + default: + return 0; + } +} + +static inline bool get_type_has_zero_point(enum ggml_type type) { + switch (type) { + case GGML_TYPE_TMAC_BN_0: + case GGML_TYPE_TMAC_W2G64_0: + case GGML_TYPE_TMAC_W4G64_0: + case GGML_TYPE_TMAC_W2G128_0: + case GGML_TYPE_TMAC_W4G128_0: + return false; + case GGML_TYPE_TMAC_W2G64_1: + case GGML_TYPE_TMAC_W4G64_1: + case GGML_TYPE_TMAC_W2G128_1: + case GGML_TYPE_TMAC_W4G128_1: + return true; + case GGML_TYPE_Q4_0: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + return false; + default: + return false; + } +} + +static inline bool get_type_is_one_scale(enum ggml_type type) { + switch (type) { + case GGML_TYPE_TMAC_BN_0: + return true; + default: + return false; + } +} + +static inline int ggml_tmac_get_scales_size(const struct tmac_kernel_config * kernel_config, int m, int k) { + int scales_size; + if (kernel_config->one_scale) { + scales_size = 1; + } else if (kernel_config->has_zero_point) { + scales_size = m * k / kernel_config->q_group_size * 2; + } else{ + scales_size = m * k / kernel_config->q_group_size; + } + return scales_size; +} + +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, kAllocAlignment); +#else + void * ptr = nullptr; + posix_memalign(&ptr, kAllocAlignment, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} + + +/****** T-MAC meta model info ******/ +static void init_tmac_kernel_config_from_tensor_type(enum ggml_type type, int M, struct tmac_kernel_config * kernel_config) { + kernel_config->bits = get_type_bits(type); + kernel_config->q_group_size = get_type_group_size(type); + kernel_config->has_zero_point = get_type_has_zero_point(type); + kernel_config->one_scale = get_type_is_one_scale(type); + + // Fixed features + kernel_config->has_scale = true; + kernel_config->g = 4; + kernel_config->ngroups_per_elem = 8 / kernel_config->g; + + // Decide q_group_size for BN_0 + if (kernel_config->q_group_size == -1) { + if (M % 256 == 0) { + kernel_config->q_group_size = 64; + } else if (M % 128 == 0) { + kernel_config->q_group_size = 64; + } else if (M % 64 == 0) { + kernel_config->q_group_size = 64; + } else if (M % 32 == 0) { + kernel_config->q_group_size = 32; + } else { + GGML_LOG_ERROR("Unsupported M value. Expected multiple of 32, got %d. Please check all of the model weight shapes.\n", M); + } + } + + if (kernel_config->q_group_size % 64 == 0) { + kernel_config->act_group_size = 64; + } else if (kernel_config->q_group_size % 32 == 0) { + kernel_config->act_group_size = 32; + } else { + GGML_LOG_ERROR("Unsupported activation group size: %d\n", kernel_config->q_group_size); + } + kernel_config->actk = kernel_config->act_group_size / kernel_config->g; + + // kfactor to be tuned + // bm to be tuned + kernel_config->simd_n_in = 16; + kernel_config->simd_n_out = 8; + + kernel_config->chunk_n = 8; +} + + +/****** T-MAC configurations ******/ +static std::unordered_map final_tmac_kernel_config; +static std::string get_tmac_kernel_config_key(int M, int K, int bits) { + return "M" + std::to_string(M) + "_K" + std::to_string(K) + "_b" + std::to_string(bits); +} +struct tmac_kernel_config * find_tmac_kernel_config(int M, int K, int bits) +{ + std::string key = get_tmac_kernel_config_key(M, K, bits); + if (final_tmac_kernel_config.count(key) == 0) { + return nullptr; + } + return &final_tmac_kernel_config[key]; +} +static void insert_or_assign_tmac_kernel_config(int M, int K, int bits, struct tmac_kernel_config kernel_config) +{ + std::string key = get_tmac_kernel_config_key(M, K, bits); + final_tmac_kernel_config.insert_or_assign(key, kernel_config); +} + + +static inline void ggml_tmac_forward_mul_mat( + void * A, void * B, void * C, void * QLUT, void * LUT_Scales, void * LUT_Biases, void * Scales, + int M, int N, int K, const struct tmac_kernel_config * kernel_config) { + // Currently, scale is a must. + assert(kernel_config->has_scale); + // Currently, one_scale and has_zero_point are mutually exclusive. + assert(!(kernel_config->one_scale && kernel_config->has_zero_point)); + + int bits = kernel_config->bits; + int bm = kernel_config->bm; + int act_group_size = kernel_config->act_group_size; + + lut_ctor_int8_g4(B, LUT_Scales, LUT_Biases, QLUT, K, kernel_config); + + const int m = bm / bits; + const int64_t chunk_size0 = m; + + for (int32_t chunk_outer = 0; chunk_outer < M/m; chunk_outer++) { + /* One Block */ + const int64_t w_offset = chunk_outer * m * K * bits / 8; + const int64_t scales_offset = kernel_config->one_scale ? 0 : ggml_tmac_get_scales_size(kernel_config, m, K) * chunk_outer; + + for (int32_t n_outer = 0; n_outer < N; n_outer++) { + const int64_t qlut_offset = K * n_outer * 4; + const int64_t lut_scales_offset = K / act_group_size * n_outer; + const int64_t dst_offset = M * n_outer + chunk_outer * chunk_size0; + + int8_t *lut = (int8_t *)QLUT + qlut_offset; + uint8_t *a = (uint8_t *)A + w_offset; + tmac_float_type *scales = (tmac_float_type *)Scales + scales_offset; + tmac_float_type *lut_scales = (tmac_float_type *)LUT_Scales + lut_scales_offset; + tmac_float_type *lut_biases = (tmac_float_type *)LUT_Biases + lut_scales_offset; + tmac_float_type *act_output = (tmac_float_type *)C + dst_offset; + + qgemm_lut_int8_g4(a, lut, scales, lut_scales, lut_biases, act_output, bm, K, N, kernel_config); + } + /* One Block */ + } +} + +static void ggml_tmac_tune_single_kernel_config(const struct tmac_run_single_kernel_settings * const settings, double & elapsed_time) { + if (settings->kernel_config->kfactor < settings->kernel_config->actk) { + return; + } + + const int test_time_ms = settings->test_time_ms; + const int M = settings->M; + const int N = settings->N; + const int K = settings->K; + const struct tmac_kernel_config * const kernel_config = settings->kernel_config; + const int bits = kernel_config->bits; + const int act_group_size = kernel_config->act_group_size; + const int bm = kernel_config->bm; + // const int m = bm / bits; + const int scales_size = ggml_tmac_get_scales_size(kernel_config, M, K); + + std::chrono::duration total_elapsed = std::chrono::duration::zero(); + GGML_LOG_DEBUG("Run single kernel config: M=%d, N=%d, K=%d, bm=%d, kfactor=%d, actk=%d\n", M, N, K, bm, kernel_config->kfactor, kernel_config->actk); + int n_try = 0; + while (total_elapsed.count() < test_time_ms / 1000.0) { + uint8_t *A = new uint8_t[M * K * bits / 8]; // quantized weight + tmac_float_type *B = new tmac_float_type[K * N]; // activation + tmac_float_type *C = new tmac_float_type[M * N]; // output + int8_t *QLUT = new int8_t[K * N * 4]; + tmac_float_type *LUT_Scales = new tmac_float_type[K * N / act_group_size]; + tmac_float_type *LUT_Biases = new tmac_float_type[K * N / act_group_size]; + tmac_float_type *Scales = new tmac_float_type[scales_size]; + + // multi-threading profiling + auto start = std::chrono::high_resolution_clock::now(); + ggml_tmac_forward_mul_mat(A, B, C, QLUT, LUT_Scales, LUT_Biases, Scales, + M, N, K, kernel_config); + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration elapsed = end - start; + total_elapsed += elapsed; + n_try++; + + delete[] A; + delete[] B; + delete[] C; + delete[] QLUT; + delete[] LUT_Scales; + delete[] LUT_Biases; + delete[] Scales; + } + + elapsed_time = total_elapsed.count() / n_try * 1000.0; // in ms +} + +static void ggml_tmac_tune_kernel_config(const struct ggml_tensor * tensor, int M, int K) { + const int bits = get_type_bits(tensor->type); + struct tmac_kernel_config * existing_kcfg = find_tmac_kernel_config(M, K, bits); + if (existing_kcfg != nullptr) { + return; + } + + struct tmac_kernel_config kernel_config; + init_tmac_kernel_config_from_tensor_type(tensor->type, M, &kernel_config); + + // TODO: add more choices for prefilling? + int N = 1; + + // search space + std::vector bms; + if (bits == 1 || bits == 2 || bits == 4) { + bms = {256, 512, 1024, 2048, 320, 640, 1280}; + } else if (bits == 3) { + bms = {192, 384, 576, 768}; + } + std::vector bns = {8, 16, 32, 64}; + std::vector kfactors = {8, 16}; + + + double min_time = 1e9; + struct tmac_kernel_config best_kcfg = kernel_config; + + auto profile_based = [&]() { + for (int bm: bms) { + if (M % (bm/bits) != 0 || bm % bits != 0) { + continue; + } + + kernel_config.bm = bm; + for (int n: bns) { + if ((N >= n && N % n != 0) || (N < n && n != bns[0])) { + continue; + } + + for (int kfactor: kfactors) { + if ((kfactor < kernel_config.actk) || (kfactor * kernel_config.g > kernel_config.q_group_size)) { + continue; + } + + kernel_config.kfactor = kfactor; + // insert to dict for finding + insert_or_assign_tmac_kernel_config(M, K, bits, kernel_config); + struct tmac_run_single_kernel_settings settings = { + /* .test_time_ms = */ 5000, + /* .M = */ M, + /* .N = */ N, + /* .K = */ K, + /* .n = */ n, + /* .kernel_config = */ &kernel_config + }; + double this_time; + ggml_tmac_tune_single_kernel_config(&settings, this_time); + if (this_time < min_time) { + min_time = this_time; + best_kcfg = kernel_config; + } + } + } + }; + }; + auto rule_based = [&]() { + float smallest_penalty = 1e9; + for (int bm: bms) { + if (M % (bm/bits) != 0 || bm % bits != 0) { + continue; + } + int num_tiles = M / (bm/bits); + int num_groups = (num_tiles + n_threads - 1) / n_threads; + float penalty = 0.1 * num_groups + (num_groups - 1.0 * num_tiles / n_threads) / num_groups; + if (penalty <= smallest_penalty) { + smallest_penalty = penalty; + best_kcfg.bm = bm; + } + } + + int largest_kfactor = 0; + for (int kfactor: kfactors) { + if ((kfactor < kernel_config.actk) || (kfactor * kernel_config.g > kernel_config.q_group_size)) { + continue; + } + if (kfactor > largest_kfactor) { + largest_kfactor = kfactor; + best_kcfg.kfactor = kfactor; + } + } + }; + rule_based(); + + // Save the results + insert_or_assign_tmac_kernel_config(M, K, bits, best_kcfg); + GGML_LOG_INFO("Tuned kernel config: M=%d, N=%d, K=%d, bm=%d, kfactor=%d, bits=%d, actk=%d, g=%d, ngroups_per_elem=%d, q_group_size=%d, act_group_size=%d\n", + M, N, K, best_kcfg.bm, best_kcfg.kfactor, bits, best_kcfg.actk, best_kcfg.g, best_kcfg.ngroups_per_elem, best_kcfg.q_group_size, best_kcfg.act_group_size); +} + + + +size_t ggml_backend_tmac_desired_wsize(const struct ggml_tensor * dst) { + struct ggml_tensor * src0 = dst->src[0]; + struct ggml_tensor * src1 = dst->src[1]; + + const size_t n = src0->ne[1]; // llama.cpp n + const size_t k = src1->ne[0]; // k + const size_t m = src1->ne[1]; // llama.cpp m + const int bits = get_type_bits(src0->type); + + struct tmac_kernel_config * kernel_config = find_tmac_kernel_config(n, k, bits); + if (kernel_config == nullptr) { + ggml_tmac_tune_kernel_config(src0, n, k); + kernel_config = find_tmac_kernel_config(n, k, bits); + } + const int lut_scales_size = k / kernel_config->act_group_size; + + size_t wsize = k * m * 4 * sizeof(int8_t) + lut_scales_size * m * 2 * sizeof(tmac_float_type); + if (sizeof(tmac_float_type) == 2) { + // Need fp32 to fp16 conversion + wsize += std::max(k, n) * m * sizeof(tmac_float_type); + } + wsize = ((wsize - 1) / kAllocAlignment + 1) * kAllocAlignment; + return wsize; +} + +size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor) { + if (is_tmac_type(tensor->type)) { + const int bits = get_type_bits(tensor->type); + + int k = tensor->ne[0]; + int m = tensor->ne[1]; // `n` in llama.cpp + + struct tmac_kernel_config * kernel_config = find_tmac_kernel_config(m, k, bits); + if (kernel_config == nullptr) { + ggml_tmac_tune_kernel_config(tensor, m, k); + kernel_config = find_tmac_kernel_config(m, k, bits); + } + + const int scales_size = ggml_tmac_get_scales_size(kernel_config, m, k); + // Currently, always uses float16 to store scales or zero points + size_t nbytes = k * m / 8 * bits + scales_size * sizeof(ggml_fp16_t); + nbytes = GGML_PAD(nbytes, GGUF_DEFAULT_ALIGNMENT); + return nbytes; + } else { + return ggml_nbytes(tensor); + } +} + + + + +/****** T-MAC convert tensor ******/ +static bool do_permutate(enum ggml_type type) { + return true; + // if (type == GGML_TYPE_I1 || + // type == GGML_TYPE_I2 || + // type == GGML_TYPE_I3 || + // type == GGML_TYPE_I4) { + // // Add additional args to decide if permuted I2 or naive I2 + // return false; + // } else { + // return true; + // } +} + +struct BlockQ40TypeAccessor { + using block_t = block_q4_0; + + static constexpr int BITS = 4; + static constexpr int SIMD_LEN = 16; + static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS; + static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs); + int internal_idx = idx % group_size; + const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN; + int simd_idx = internal_idx % simd_n_elem; + return simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS); + } + + static tmac_float_type get_scale(const void * data, int idx) { + ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&d); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(d); + } + } +}; + +struct BlockI2TypeAccessor { + static constexpr int BITS = 2; + static constexpr int n_elem = 8 / BITS; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) data; + int elem_idx = idx % n_elem; + return qs[idx / n_elem] >> ((n_elem - 1 - elem_idx) * BITS); + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + const ggml_fp16_t * ss = (const ggml_fp16_t *) data; + ggml_fp16_t s = ss[idx / group_size]; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&s); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(s); + } + } + + static tmac_float_type get_zero_point(const void * data, int idx, int group_size) { + const ggml_fp16_t * zs = (const ggml_fp16_t *) data; + ggml_fp16_t z = zs[idx / group_size]; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&z); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(z); + } + } +}; + +struct BlockI4TypeAccessor { + static constexpr int BITS = 4; + static constexpr int n_elem = 8 / BITS; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) data; + int elem_idx = idx % n_elem; + return qs[idx / n_elem] >> ((n_elem - 1 - elem_idx) * BITS); + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + const ggml_fp16_t * ss = (const ggml_fp16_t *) data; + ggml_fp16_t s = ss[idx / group_size]; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&s); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(s); + } + } + + static tmac_float_type get_zero_point(const void * data, int idx, int group_size) { + const ggml_fp16_t * zs = (const ggml_fp16_t *) data; + ggml_fp16_t z = zs[idx / group_size]; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&z); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(z); + } + } +}; + + +struct BlockTQ10TypeAccessor { + using block_t = block_tq1_0; + + static constexpr int elements_qs = 5; // 5 elements per byte + static constexpr int elements_qh = 4; // 4 elements per byte + static constexpr int BITS = 2; + static constexpr int group_size_qs = sizeof(((block_t *)0)->qs) * elements_qs; + static constexpr int group_size_qh = sizeof(((block_t *)0)->qh) * elements_qh; + static constexpr int group_size = group_size_qs + group_size_qh; + static constexpr int SIMD_LEN_qs_1 = 32; + static constexpr int SIMD_LEN_qs_2 = 16; + static constexpr int SIMD_LEN_qh = 4; + static constexpr int simd_n_elem_qs_1 = SIMD_LEN_qs_1 * elements_qs; // 160 + static constexpr int simd_n_elem_qs_2 = SIMD_LEN_qs_2 * elements_qs; // 80 + static constexpr int simd_n_elem_qh = SIMD_LEN_qh * elements_qh; // 16 + + static constexpr uint8_t pow3[5] = {1, 3, 9, 27, 81}; + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs); + uint8_t cur_qs; + uint8_t trit; + int internal_idx = idx % group_size; + + if (internal_idx < simd_n_elem_qs_1) { + const int internal_offset = 0; + const uint8_t * simd_qs = qs + internal_offset; + int simd_idx = internal_idx; + int simd_byte = simd_idx % SIMD_LEN_qs_1; + int simd_trit = simd_idx / SIMD_LEN_qs_1; + + cur_qs = simd_qs[simd_byte] * pow3[simd_trit]; + trit = ((uint16_t) cur_qs * 3) >> 8; + } + else if (internal_idx < simd_n_elem_qs_1 + simd_n_elem_qs_2) { + const int internal_offset = SIMD_LEN_qs_1; + const uint8_t * simd_qs = qs + internal_offset; + int simd_idx = internal_idx - simd_n_elem_qs_1; + int simd_byte = simd_idx % SIMD_LEN_qs_2; + int simd_trit = simd_idx / SIMD_LEN_qs_2; + + cur_qs = simd_qs[simd_byte] * pow3[simd_trit]; + trit = ((uint16_t) cur_qs * 3) >> 8; + } + else { + const int internal_offset = SIMD_LEN_qs_1 + SIMD_LEN_qs_2; + const uint8_t * simd_qs = qs + internal_offset; + int simd_idx = internal_idx - simd_n_elem_qs_1 - simd_n_elem_qs_2; + int simd_byte = simd_idx % SIMD_LEN_qh; + int simd_trit = simd_idx / SIMD_LEN_qh; + + cur_qs = simd_qs[simd_byte] * pow3[simd_trit]; + trit = ((uint16_t) cur_qs * 3) >> 8; + } + + return trit + 1; + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&d); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(d); + } + } +}; + +struct BlockTQ20TypeAccessor { + using block_t = block_tq2_0; + + static constexpr int BITS = 2; + static constexpr int SIMD_LEN = 32; + static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS; // 256 + static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS; // 128 + + static uint8_t get_q(const void * data, int idx) { + const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs); + int internal_idx = idx % group_size; + const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN; + int simd_idx = internal_idx % simd_n_elem; + return (simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS)) + 1; + } + + static tmac_float_type get_scale(const void * data, int idx, int group_size) { + ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d; + if constexpr (sizeof(tmac_float_type) == 2) { + tmac_float_type * fp16dp = reinterpret_cast(&d); + return *fp16dp; + } else { + return ggml_fp16_to_fp32(d); + } + } +}; + +static inline void ggml_tmac_transform_tensor(struct ggml_tensor * tensor, const void * origin_data) { + GGML_ASSERT(tensor->extra != nullptr); + struct ggml::cpu::tmac::tensor_traits * tensor_extra = (struct ggml::cpu::tmac::tensor_traits *) tensor->extra; + if (!(is_type_supported(tensor->type) && tensor_extra->get_tmac_tensor_extra(tensor->name) == nullptr)) { + return; + } + + const int bits = get_type_bits(tensor->type); + int k = tensor->ne[0]; + int m = tensor->ne[1]; // `n` in llama.cpp + + struct tmac_kernel_config * kernel_config = find_tmac_kernel_config(m, k, bits); + if (kernel_config == nullptr) { + ggml_tmac_tune_kernel_config(tensor, m, k); + kernel_config = find_tmac_kernel_config(m, k, bits); + } + + // Currently, scale is a must. + assert(kernel_config->has_scale); + // Currently, one_scale and has_zero_point are mutually exclusive. + assert(!(kernel_config->one_scale && kernel_config->has_zero_point)); + + const int g = kernel_config->g; + const int ngroups_per_elem = kernel_config->ngroups_per_elem; + const int bm = kernel_config->bm; + const int simd_n_in = kernel_config->simd_n_in; + const int simd_n_out = kernel_config->simd_n_out; + const int kfactor = kernel_config->kfactor; + const int group_size = kernel_config->q_group_size; + + const int act_group_size = kernel_config->act_group_size; + const int lut_scales_size = k / act_group_size; + const int scales_size = ggml_tmac_get_scales_size(kernel_config, m, k); + const int n_tile_num = m * bits / bm; + + GGML_LOG_DEBUG("Transforming tensor: %s (m: %d, k: %d, bits: %d)\n", tensor->name, m, k, bits); + GGML_LOG_DEBUG("kcfg (bm=%d, simd_n_in=%d, simd_n_out=%d, kfactor=%d, group_size=%d, lut_scales_size=%d, scales_size=%d, n_tile_num=%d)\n", + bm, simd_n_in, simd_n_out, kfactor, group_size, lut_scales_size, scales_size, n_tile_num); + if (bm == 0) { + if (!strcmp(tensor->name, "token_embd.weight") || !strcmp(tensor->name, "output.weight")) { + GGML_LOG_WARN("Do not find kcfg for %s. Consider compiling T-MAC kernel for it if vocab size is a multiply of 128 or 320, detected %lld.\n", tensor->name, tensor->ne[1]); + return; + } + else { + // TODO: Instead of fatal error, try to avoid using t-mac? + GGML_LOG_ERROR("Failed to find kcfg. Abort transforming\n"); + return; + } + } + + const int mgroup = ngroups_per_elem * simd_n_in; + m = m * bits; + + uint8_t * qweights; + tmac_float_type * scales; + + // TODO: if sizeof(tmac_float_type) <= sizeof(float), we can copy tensor->data to qweights and scales, + // and do permutation on tensor->data, finally aligned_free qweights and scales. + if (do_permutate(tensor->type)) { + scales = (tmac_float_type *) aligned_malloc(scales_size * sizeof(tmac_float_type)); + qweights = (uint8_t *) aligned_malloc(k * m / 8); + } else { + /* scales could be either float32 or float16, so inplace cast is feasible. */ + GGML_ASSERT(sizeof(tmac_float_type) <= sizeof(float)); + qweights = (uint8_t *) tensor->data; + scales = (tmac_float_type *) (qweights + k * m / 8); + float * i2_scales = (float * )(qweights + k * m / 8); + for (int i = 0; i < scales_size; i++) { + scales[i] = (tmac_float_type) i2_scales[i]; + } + } + + struct tmac_tensor_extra * cur_tensor_extra = new tmac_tensor_extra({ + /* .lut_scales_size = */ lut_scales_size, + /* .scales_size = */ scales_size, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }); + tensor_extra->set_tmac_tensor_extra(tensor->name, cur_tensor_extra); + + if (do_permutate(tensor->type)) { +// for fast testing +// #define TMAC_EMPTY_WEIGHTS +#ifndef TMAC_EMPTY_WEIGHTS + std::vector threads; + const int n_threads = std::thread::hardware_concurrency(); + + // TODO: optimize to accelerate weights loading + uint8_t * buf2 = new uint8_t[m * k / g]; + memset(buf2, 0, m * k / g); + + // # (M // bits, K, bits) + // w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1) + // # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) -> (M // bits, bits, K // g) + // w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g) + // w = sum([(w[:, :, :, ig] << ig) for ig in range(g)]) + threads.reserve(n_threads); + auto parallel_worker_buf2 = [&](size_t start_index, size_t end_index) { + for (int im = start_index; im < end_index; im++) { + for (int ik = 0; ik < k; ik++) { + uint8_t v; + if (tensor->type == GGML_TYPE_Q4_0) { + v = BlockQ40TypeAccessor::get_q(origin_data, im * k + ik); + } else if (is_tmac_2bit_type(tensor->type)) { + v = BlockI2TypeAccessor::get_q(origin_data, im * k + ik); + } else if (is_tmac_4bit_type(tensor->type)) { + v = BlockI4TypeAccessor::get_q(origin_data, im * k + ik); + } else if (tensor->type == GGML_TYPE_TQ1_0) { + v = BlockTQ10TypeAccessor::get_q(origin_data, im * k + ik); + } else if (tensor->type == GGML_TYPE_TQ2_0) { + v = BlockTQ20TypeAccessor::get_q(origin_data, im * k + ik); + } else { + GGML_LOG_ERROR("Unsupported type: %s\n", ggml_type_name(tensor->type)); + } + + for (int ib = 0; ib < bits; ib++) { + int new_im = im; + int new_ib = ib; + int new_ik = ik / g; + int shft_left = ik % g; + buf2[new_im * bits * k / g + new_ib * k / g + new_ik] += ((v >> ib) & 1) << shft_left; + } + } + } + }; + + size_t start_index = 0; + size_t chunk_size = m / bits / n_threads; + for (size_t i = 0; i < n_threads; ++i) { + size_t end_index = (i == n_threads - 1) ? m / bits : start_index + chunk_size; + + // Create and launch a thread + threads.emplace_back(parallel_worker_buf2, + start_index, + end_index); // Pass the mutex array by reference + + start_index = end_index; + } + // Wait for all threads to complete + for (std::thread& t : threads) { + t.join(); + } + threads.clear(); + + // # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 + // # for bits=3 + // # bit0: [0, 8), bit1: [8, 16), bit2: [16, 24), bit0: [24, 32) + // # (M // bits // simd_n_float16, bits, simd_n_float16, K // g) + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + // mgroup = ngroups_per_elem * simd_n_in + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + memset(qweights, 0, m * k / g / ngroups_per_elem); + + int c0_fac2 = k / g; + int c0_fac1 = simd_n_out * c0_fac2; + int c0_fac0 = bits * c0_fac1; + + int c1_nb2 = k / g; + int c1_nb1 = simd_n_in * c1_nb2; + int c1_nb0 = ngroups_per_elem * c1_nb1; + int c1_fac2 = k / g; + int c1_fac1 = ngroups_per_elem * c1_fac2; + int c1_fac0 = simd_n_in * c1_fac1; + + + int c2_nb4 = kfactor; + int c2_nb3 = k / g / kfactor * c2_nb4; + int c2_nb2 = ngroups_per_elem * c2_nb3; + int c2_nb1 = simd_n_in * c2_nb2; + int c2_nb0 = bm / mgroup * c2_nb1; + int c2_fac3 = simd_n_in * ngroups_per_elem; + int c2_fac2 = kfactor * c2_fac3; + int c2_fac1 = bm / mgroup * c2_fac2; + int c2_fac0 = k / g / kfactor * c2_fac1; + + threads.reserve(n_threads); + auto parallel_worker_qweights = [&](size_t start_index, size_t end_index) { + for (int im = start_index; im < end_index; im++) { + for (int ib = 0; ib < bits; ib++) { + for (int ik = 0; ik < k / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + int new_im = im / simd_n_out; + int new_isno = im % simd_n_out; + int new_ib = ib; + int new_ik = ik; + int new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + int new_ing = (new_idx % c1_nb0) / c1_nb1; + int new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + int new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + int new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + qweights[new_idx] += buf2[im * bits * k / g + ib * k / g + ik] << (new_ing * g); + } + } + } + }; + + start_index = 0; + chunk_size = m / bits / n_threads; + for (size_t i = 0; i < n_threads; ++i) { + size_t end_index = (i == n_threads - 1) ? m / bits : start_index + chunk_size; + + // Create and launch a thread + threads.emplace_back(parallel_worker_qweights, + start_index, + end_index); // Pass the mutex array by reference + + start_index = end_index; + } + // Wait for all threads to complete + for (std::thread& t : threads) { + t.join(); + } + threads.clear(); + + const ggml_fp16_t * int_n_scales = (const ggml_fp16_t * ) ((const uint8_t *) origin_data + k * m / 8); + const ggml_fp16_t * int_n_zero_points = int_n_scales + scales_size / 2; + + if (scales_size < m / bits) { // BitNet-like scale (m_groups,) + for (int i = 0; i < scales_size; i++) { + scales[i] = BlockI2TypeAccessor::get_scale(int_n_scales, i, 1); + } + } else { + // TODO: move if-else outside the loop + // scales = scales.reshape(M // bm, bm // bits, K // group_size).transpose(0, 2, 1) + for (int im = 0; im < m / bits; im += 1) { + for (int ik = 0; ik < k; ik += group_size) { + tmac_float_type scale; + int idx = im * k + ik; + if (tensor->type == GGML_TYPE_Q4_0) { + scale = BlockQ40TypeAccessor::get_scale(origin_data, idx); + } else if (is_tmac_2bit_type(tensor->type)) { + scale = BlockI2TypeAccessor::get_scale(int_n_scales, idx, group_size); + } else if (is_tmac_4bit_type(tensor->type)) { + scale = BlockI4TypeAccessor::get_scale(int_n_scales, idx, group_size); + } else if (tensor->type == GGML_TYPE_TQ1_0) { + scale = BlockTQ10TypeAccessor::get_scale(origin_data, idx, group_size); + } else if (tensor->type == GGML_TYPE_TQ2_0) { + scale = BlockTQ20TypeAccessor::get_scale(origin_data, idx, group_size); + } else { + GGML_LOG_ERROR("Unsupported type for get_scale: %s\n", ggml_type_name(tensor->type)); + } + + tmac_float_type zero_point; + if (get_type_has_zero_point(tensor->type)) { + if (is_tmac_2bit_type(tensor->type)) { + zero_point = BlockI2TypeAccessor::get_zero_point(int_n_zero_points, idx, group_size); + } else if (is_tmac_4bit_type(tensor->type)) { + zero_point = BlockI4TypeAccessor::get_zero_point(int_n_zero_points, idx, group_size); + } else { + GGML_LOG_ERROR("Unsupported type for get_zero_point: %s\n", ggml_type_name(tensor->type)); + } + } + + idx = idx / group_size; + int nb1 = k / group_size; + int nb0 = bm / bits * nb1; + int new_im = idx / nb0; + int new_ibm = (idx % nb0) / nb1; + int new_ik = (idx % nb1); + + if (get_type_has_zero_point(tensor->type)) { + int new_isimd = new_ibm % simd_n_out; + int new_idx_outer = new_im * bm / bits * k / group_size / simd_n_out + + new_ik * bm / bits / simd_n_out + + new_ibm / simd_n_out; + int new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + int new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + scales[new_idx_scale] = scale; + scales[new_idx_zero] = zero_point; + } else { + int new_idx = new_im * bm / bits * k / group_size + new_ik * bm / bits + new_ibm; + scales[new_idx] = scale; + } + } + } + } + + delete[] buf2; +#else + memset(qweights, 0x88, k * m / 8); + for (int i = 0; i < scales_size; i++) { + scales[i] = 1.0f; + } +#endif + } // if (do_permutate(tensor->type)) +} + +void ggml_backend_tmac_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0 && size == ggml_tmac_get_nbytes(tensor)); // only full tensor conversion is supported for now + ggml_tmac_transform_tensor(tensor, data); +} + + +/****** T-MAC compute ******/ + + +// m = batch_size +// n = output_dim +// t-mac llama.cpp n and m swapped +static inline void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits) { + struct tmac_kernel_config * kernel_config = find_tmac_kernel_config(n, k, bits); + if (kernel_config == nullptr) { + throw std::runtime_error("ggml_tmac_mul_mat_task_init: Failed to find kernel config for m" + std::to_string(n) + "_k" + std::to_string(k) + "_b" + std::to_string(bits)); + } + lut_ctor_int8_g4(src1, lut_scales, lut_biases, qlut, k, kernel_config); +} + +static inline void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits) { + struct tmac_kernel_config * kernel_config = find_tmac_kernel_config(n, k, bits); + if (kernel_config == nullptr) { + GGML_LOG_INFO("Failed to find kernel config for m%d_k%d_b%d\n", n, k, bits); + throw std::runtime_error("ggml_tmac_mul_mat_task_compute: Failed to find kernel config for m" + std::to_string(n) + "_k" + std::to_string(k) + "_b" + std::to_string(bits)); + } + qgemm_lut_int8_g4(src0, qlut, scales, lut_scales, lut_biases, dst, kernel_config->bm, k, m, kernel_config); +} + + +void ggml_backend_tmac_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + const int bits = get_type_bits(src0->type); + // src0: weight, ne00 = k, ne01 = n + // src1: activation, ne10 = k, ne11 = m + char * wdata = (char *) (params->wdata); + + struct tmac_tensor_extra * wt = ((struct ggml::cpu::tmac::tensor_traits *)src0->extra)->get_tmac_tensor_extra(src0->name); + char * cur_wdata = wdata; + tmac_float_type * tmac_f_ptr = (tmac_float_type *) wdata; + if (sizeof(tmac_float_type) == 2) { + cur_wdata = wdata + MAX(ne10, ne01) * ne11 * sizeof(tmac_float_type); + }; + int8_t * qlut = (int8_t *) cur_wdata; + tmac_float_type * lut_scales = (tmac_float_type *) (qlut + ne10 * ne11 * 4); + tmac_float_type * lut_biases = (tmac_float_type *) (lut_scales + wt->lut_scales_size * ne11); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + tmac_float_type * act_input; + if (sizeof(tmac_float_type) == 2) { + act_input = tmac_f_ptr; + } else { + act_input = (tmac_float_type *) src1->data; + } + + for (int ine11 = ith; ine11 < ne11; ine11 += nth) { + if constexpr (sizeof(tmac_float_type) == 2) { + // TODO: can we reuse the src1->data memory? + ggml_fp32_to_fp16_row((const float *) src1->data + ne10 * ine11, (ggml_fp16_t *) act_input + ne10 * ine11, ne10); + } + ggml_tmac_mul_mat_task_init(act_input + ne10 * ine11, + qlut + ne10 * ine11 * 4, + lut_scales + wt->lut_scales_size * ine11, + lut_biases + wt->lut_scales_size * ine11, + ne01, ne00, 1, bits); + } + + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_atomic_store_explicit(params->threadpool, nth); + // atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); + } + + ggml_barrier(params->threadpool); + + tmac_float_type * act_output = (tmac_float_type *) (dst->data); + if constexpr (sizeof(tmac_float_type) == 2) { + act_output = tmac_f_ptr; + } + + const int n_tile_num = wt->n_tile_num; + // Currently, T-MAC requires ne0 devisible by n_tile_num + GGML_ASSERT(ne0 % n_tile_num == 0); + + const int64_t w_size = ne00 * ne01 * bits / 8; + const int64_t w_chunk_size = w_size / n_tile_num; + + const int64_t nr0 = ne0; + const int64_t nr1 = ne1 * ne2 * ne3; + + // Adopt the same style with current llama.cpp impl + // But different chunk size for 0/1 dim. + // No scrap. + const int chunk_size0 = ne0 / n_tile_num; + const int chunk_size1 = 8; // TODO: tune in T-MAC + + // nchunk0 == n_tile_num + int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0; + int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1; + + int64_t dr0 = chunk_size0; + int64_t dr1 = chunk_size1; +#if defined(TMAC_RECHUNK) + // Rechunk + if ((nchunk1 == 1) && (nchunk0 > nth * 4)) { + // dr0 should be divisible by chunk_size0 + dr0 = (ne0 / (nth * 4) / chunk_size0) * chunk_size0; + nchunk0 = (nr0 + dr0 - 1) / dr0; + } +#endif + + int current_chunk = ith; + + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; + + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); + + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + // inline ggml_compute_forward_mul_mat_one_chunk here for simplicity + for (int64_t ichunk0 = ir0_start / chunk_size0; ichunk0 < ir0_end / chunk_size0; ichunk0++) { + const int64_t w_offset = ichunk0 * w_chunk_size; + const int64_t scales_offset = ichunk0 * wt->scales_size / n_tile_num; + + for (int64_t ine11 = ir1_start; ine11 < ir1_end; ine11++) { + const int64_t qlut_offset = ne10 * ine11 * 4; + const int64_t lut_scales_offset = wt->lut_scales_size * ine11; + const int64_t dst_offset = ne0 * ine11 + ichunk0 * chunk_size0; + + ggml_tmac_mul_mat_task_compute(wt->qweights + w_offset, + wt->scales + scales_offset, + qlut + qlut_offset, + lut_scales + lut_scales_offset, + lut_biases + lut_scales_offset, + act_output + dst_offset, + ne01, ne00, 1, bits); + if constexpr (sizeof(tmac_float_type) == 2) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *) act_output + dst_offset, (float *) dst->data + dst_offset, chunk_size0); + } + } + } + + if (nth >= nchunk0 * nchunk1) { + break; + } + + // current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); + current_chunk = ggml_threadpool_atomic_fetch_add_explicit(params->threadpool, 1); + } + return; +} + +#endif // GGML_USE_TMAC \ No newline at end of file diff --git a/ggml/src/ggml-cpu/tmac/lut_mul_mat.h b/ggml/src/ggml-cpu/tmac/lut_mul_mat.h new file mode 100644 index 0000000000000..5a94f3dba0f6d --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/lut_mul_mat.h @@ -0,0 +1,69 @@ +#pragma once + +/* Please do not include this header file outside ggml-cpu/tmac */ + +#include "lut_ctor.h" +#include "tbl.h" +#include "ggml-cpu-traits.h" + +#include + +static const int GGML_TMAC_MAX_NODES = 8192; +struct tmac_tensor_extra { + int lut_scales_size; + int scales_size; + int n_tile_num; + uint8_t * qweights; + tmac_float_type * scales; +}; + +namespace ggml::cpu::tmac { + class tensor_traits : public ggml::cpu::tensor_traits { + std::unordered_map tmac_tensor_extra; + // struct tmac_tensor_extra * tmac_tensor_extra = nullptr; + + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override; + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override; + +public: + struct tmac_tensor_extra * get_tmac_tensor_extra(std::string tensor_name) { + if (tmac_tensor_extra.find(tensor_name) == tmac_tensor_extra.end()) { + return nullptr; + } + return tmac_tensor_extra[tensor_name]; + } + void set_tmac_tensor_extra(std::string tensor_name, struct tmac_tensor_extra * extra) { + // if (tmac_tensor_extra.find(tensor_name) != tmac_tensor_extra.end()) { + // GGML_LOG_WARN("tmac_tensor_extra already exists for tensor %s. Overriding the data!\n", tensor_name.c_str()); + // } + tmac_tensor_extra[tensor_name] = extra; + } + }; +} // namespace ggml::cpu::tmac + + +#ifdef __cplusplus +extern "C" { +#endif + +void tmac_init(void); + +bool is_tmac_type(enum ggml_type type); + +bool is_type_supported(enum ggml_type type); + +size_t ggml_backend_tmac_desired_wsize(const struct ggml_tensor * dst); + +size_t ggml_backend_tmac_get_alloc_size(const struct ggml_tensor * tensor); + +size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor); + +void ggml_backend_tmac_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + +bool ggml_tmac_can_mul_mat(const struct ggml_tensor * dst); + +void ggml_backend_tmac_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/tmac/tbl.cpp b/ggml/src/ggml-cpu/tmac/tbl.cpp new file mode 100644 index 0000000000000..13f7881b6d0b7 --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/tbl.cpp @@ -0,0 +1,944 @@ +#include "tbl.h" +#include "lut_ctor.h" +#include "../../common/log.h" + +#include "string.h" +#include +#include +#include +#include +#include +#include +#include + + +#ifdef __ARM_NEON +template +struct SignedHalvingAdder { + SignedHalvingAdder adder; + int8x16_t lhs; + + inline void push(int8x16_t v, int k) { + if (k < N / 2) { + adder.push(v, k); + if (k == N / 2 - 1) { + lhs = adder.get(); + } + } else { + adder.push(v, k - N / 2); + if (k == N - 1) { + lhs = vrhaddq_s8(lhs, adder.get()); + } + } + } + + inline int8x16_t get() { + return lhs; + } + + inline int16x8_t get_low() { + return vmovl_s8(vget_low_s8(lhs)); + } + + inline int16x8_t get_high() { + return vmovl_high_s8(lhs); + } +}; + +template <> +struct SignedHalvingAdder<2> { + int8x16_t lhs; + + inline void push(int8x16_t v, int k) { + if (k == 0) { + lhs = v; + } else { + lhs = vrhaddq_s8(lhs, v); + } + } + + inline int8x16_t get() { + return lhs; + } + + inline int16x8_t get_low() { + return vmovl_s8(vget_low_s8(lhs)); + } + + inline int16x8_t get_high() { + return vmovl_high_s8(lhs); + } +}; + +struct SignedLongAdder { + int16x8_t lhs_low; + int16x8_t lhs_high; + int8x16_t lhs; + + inline void push(int8x16_t v, int k) { + if (k == 0) { + lhs = v; + } else { + lhs_low = vaddl_s8(vget_low_s8(lhs), vget_low_s8(v)); + lhs_high = vaddl_high_s8(lhs, v); + } + } + + inline int16x8_t get_low() { + return lhs_low; + } + + inline int16x8_t get_high() { + return lhs_high; + } +}; + +template +struct SignedWideningAdder { + SignedLongAdder adder; + int16x8_t lhs_low; + int16x8_t lhs_high; + + inline void push(int8x16_t v, int k) { + if (k % 2 == 0) { + adder.push(v, 0); + } else { + adder.push(v, 1); + if (k == 1) { + lhs_low = adder.get_low(); + lhs_high = adder.get_high(); + } else { + lhs_low += adder.get_low(); + lhs_high += adder.get_high(); + } + } + } + + inline int16x8_t get_low() { + return lhs_low; + } + + inline int16x8_t get_high() { + return lhs_high; + } +}; +#elif defined __AVX2__ +#define extract_low_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_castsi256_si128(v)) +#define extract_high_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_extracti128_si256(v, 1)) +#define extract_low_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_castsi256_si128(v)) +#define extract_high_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_extracti128_si256(v, 1)) + +template +struct SignedHalvingAdder { + SignedHalvingAdder adder; + __m256i lhs; + + inline void push(__m256i v, int k) { + if (k < N / 2) { + adder.push(v, k); + if (k == N / 2 - 1) { + lhs = adder.get(); + } + } else { + adder.push(v, k - N / 2); + if (k == N - 1) { + lhs = _mm256_avg_epu8(lhs, adder.get()); + } + } + } + + inline __m256i get() { + return lhs; + } + + inline __m256i get_low() { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() { + return extract_high_epi8_epi16(lhs); + } +}; + +template <> +struct SignedHalvingAdder<2> { + __m256i lhs; + + inline void push(__m256i v, int k) { + if (k == 0) { + lhs = v; + } else { + lhs = _mm256_avg_epu8(lhs, v); + } + } + + inline __m256i get() { + return lhs; + } + + inline __m256i get_low() { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() { + return extract_high_epi8_epi16(lhs); + } +}; + +template +struct SignedWideningAdder { + __m256i lhs_low; + __m256i lhs_high; + + inline void push(__m256i v, int k) { + if (k == 0) { + lhs_low = extract_low_epi8_epi16(v); + lhs_high = extract_high_epi8_epi16(v); + } else { + lhs_low = _mm256_add_epi16(lhs_low, extract_low_epi8_epi16(v)); + lhs_high = _mm256_add_epi16(lhs_high, extract_high_epi8_epi16(v)); + } + } + + inline __m256i get_low() { + return lhs_low; + } + + inline __m256i get_high() { + return lhs_high; + } +}; + +#endif + +template +using SignedAdder = typename std::conditional, SignedWideningAdder>::type; + + +template +struct mylog2 { + enum { + value = 1 + mylog2::value + }; +}; + +template <> +struct mylog2<0> { + enum { + value = -1 + }; +}; + + + +template +inline int32_t tbl_g4_float_float_update_impl(int32_t m, tmac_float_type* c, tmac_float_type* lut, uint8_t* a, tmac_float_type* scales) { +#ifdef __ARM_NEON + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + uint8x16x2_t vec_lut[K]; + +#pragma unroll + for (int k = 0; k < K; k++) { + vec_lut[k] = vld2q_u8(reinterpret_cast(lut + k * 16)); + } + + for (int i = 0; i < m / 2; i += 16) { + float16x8_t vec_c0 = vld1q_f16(c + i * 2); + float16x8_t vec_c1 = vld1q_f16(c + i * 2 + 8); + float16x8_t vec_c2 = vld1q_f16(c + i * 2 + 16); + float16x8_t vec_c3 = vld1q_f16(c + i * 2 + 24); + // Currently assume K * 4 weights share the same group of scale + float16x8_t vec_s0 = vld1q_f16(scales + i * 2); + float16x8_t vec_s1 = vld1q_f16(scales + i * 2 + 8); + float16x8_t vec_s2 = vld1q_f16(scales + i * 2 + 16); + float16x8_t vec_s3 = vld1q_f16(scales + i * 2 + 24); + +#pragma unroll + for (int k = 0; k < K; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + uint8x16_t vec_as = vld1q_u8(a + i * K + k * 16); + uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); + uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); + + uint8x16_t vec_v_bot_low = vqtbl1q_u8(vec_lut[k].val[0], vec_a_bot); + uint8x16_t vec_v_bot_high = vqtbl1q_u8(vec_lut[k].val[1], vec_a_bot); + uint8x16x2_t vec_v_bot = vzipq_u8(vec_v_bot_low, vec_v_bot_high); + + uint8x16_t vec_v_top_low = vqtbl1q_u8(vec_lut[k].val[0], vec_a_top); + uint8x16_t vec_v_top_high = vqtbl1q_u8(vec_lut[k].val[1], vec_a_top); + uint8x16x2_t vec_v_top = vzipq_u8(vec_v_top_low, vec_v_top_high); + + if (has_scale) { + // TODO: optimize scales + vec_c0 += vreinterpretq_f16_u8(vec_v_bot.val[0]) * vec_s0; + vec_c1 += vreinterpretq_f16_u8(vec_v_bot.val[1]) * vec_s1; + vec_c2 += vreinterpretq_f16_u8(vec_v_top.val[0]) * vec_s2; + vec_c3 += vreinterpretq_f16_u8(vec_v_top.val[1]) * vec_s3; + } else { + vec_c0 += vreinterpretq_f16_u8(vec_v_bot.val[0]); + vec_c1 += vreinterpretq_f16_u8(vec_v_bot.val[1]); + vec_c2 += vreinterpretq_f16_u8(vec_v_top.val[0]); + vec_c3 += vreinterpretq_f16_u8(vec_v_top.val[1]); + } + } + + vst1q_f16(c + i * 2, vec_c0); + vst1q_f16(c + i * 2 + 8, vec_c1); + vst1q_f16(c + i * 2 + 16, vec_c2); + vst1q_f16(c + i * 2 + 24, vec_c3); + } +#endif + + return 0; +} + +template +constexpr int get_bias_scale() { + // The bias scale will be added to the first bit + // 15 = (1/2 + 1 + 2 + 4) / (1/2) + // 7 = (1/2 + 1 + 2) / (1/2) + // 3 = (1/2 + 1) / (1/2) + // 1 = (1/2) / (1/2) + if constexpr (bits == 4) { + return 15; + } else if constexpr (bits == 3) { + return 7; + } else if constexpr (bits == 2) { + return 3; + } else if constexpr (bits == 1) { + return 1; + } else { + return 0; + } +} + + +// When FastAggregation is enabled, FastAggregationK = ActK +// zero_points is merged into scales to maintain API +template +inline int32_t tbl_g4_int8_float_update_impl(int32_t m, tmac_float_type* c, int8_t* lut, uint8_t* a, tmac_float_type* scales, tmac_float_type* lut_scales, tmac_float_type* lut_biases) { +#ifdef __ARM_NEON + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + int8x16_t vec_lut[K]; + +#pragma unroll + for (int k = 0; k < K; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + + SignedAdder adder_bot, adder_top; + for (int i = 0; i < m / 2; i += 16) { + float16x8_t vec_c0, vec_c1, vec_c2, vec_c3; + + tmac_float_type partial_sum = (tmac_float_type) -0.0f; +#pragma unroll + for (int kk = 0; kk < K; kk += ActK) { +#pragma unroll + for (int k = 0; k < ActK; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + uint8x16_t vec_as = vld1q_u8(a + i * K + (kk + k) * 16); + uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); + uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); + + int8x16_t vec_v_bot_tmp = vqtbl1q_s8(vec_lut[kk + k], vec_a_bot); + int8x16_t vec_v_top_tmp = vqtbl1q_s8(vec_lut[kk + k], vec_a_top); + adder_bot.push(vec_v_bot_tmp, k); + adder_top.push(vec_v_top_tmp, k); + } + + float16x8_t vec_v_bot_low = vcvtq_f16_s16(adder_bot.get_low()); + float16x8_t vec_v_bot_high = vcvtq_f16_s16(adder_bot.get_high()); + float16x8_t vec_v_top_low = vcvtq_f16_s16(adder_top.get_low()); + float16x8_t vec_v_top_high = vcvtq_f16_s16(adder_top.get_high()); + + tmac_float_type lut_s = lut_scales[kk / ActK]; + tmac_float_type lut_b = lut_biases[kk / ActK]; + + // lut_b = -sum(xi for i in range(ActK * 4)) + if (ZeroPoint) { + partial_sum += lut_b; + } + + // https://arxiv.org/pdf/2106.10860.pdf + // Fast aggregation bias: -FastAggregationK * log2(FastAggregationK) / 4 * (act_k / FastAggregationK) + if (FastAggregation) { + lut_s = lut_s * ActK; + lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + } + +#define lut_fma(vs, ib) \ + ((ib) % Bits) ? ((vs) * lut_s) \ + : ((vs) * lut_s + lut_b) + if (kk == 0) { + vec_c0 = lut_fma(vec_v_bot_low, (i / 4 )); + vec_c1 = lut_fma(vec_v_bot_high, (i / 4 + 1)); + vec_c2 = lut_fma(vec_v_top_low, (i / 4 + 2)); + vec_c3 = lut_fma(vec_v_top_high, (i / 4 + 3)); + } else { + vec_c0 += lut_fma(vec_v_bot_low, (i / 4 )); + vec_c1 += lut_fma(vec_v_bot_high, (i / 4 + 1)); + vec_c2 += lut_fma(vec_v_top_low, (i / 4 + 2)); + vec_c3 += lut_fma(vec_v_top_high, (i / 4 + 3)); + } +#undef lut_fma + } + + if (ZeroPoint) { + // OneScale mode is disabled for ZeroPoint = True + float16x8_t vec_s0 = vld1q_f16(scales + ((i / 4 ) / Bits) * 16); + float16x8_t vec_s1 = vld1q_f16(scales + ((i / 4 + 1) / Bits) * 16); + float16x8_t vec_s2 = vld1q_f16(scales + ((i / 4 + 2) / Bits) * 16); + float16x8_t vec_s3 = vld1q_f16(scales + ((i / 4 + 3) / Bits) * 16); + // default_zero = 2 ** (bits - 1) + // w = (w - default_zero - (zeros - default_zero)) * scales + vec_c0 = vld1q_f16(c + i * 2) + vec_c0 * vec_s0; + vec_c1 = vld1q_f16(c + i * 2 + 8) + vec_c1 * vec_s1; + vec_c2 = vld1q_f16(c + i * 2 + 16) + vec_c2 * vec_s2; + vec_c3 = vld1q_f16(c + i * 2 + 24) + vec_c3 * vec_s3; + float16x8_t vec_z0 = vld1q_f16(scales + ((i / 4 ) / Bits) * 16 + 8); + float16x8_t vec_z1 = vld1q_f16(scales + ((i / 4 + 1) / Bits) * 16 + 8); + float16x8_t vec_z2 = vld1q_f16(scales + ((i / 4 + 2) / Bits) * 16 + 8); + float16x8_t vec_z3 = vld1q_f16(scales + ((i / 4 + 3) / Bits) * 16 + 8); + partial_sum *= 2; +#define add_zero(cs, zs, ib) \ + ((ib) % Bits) ? ((cs)) \ + : ((cs) + zs * partial_sum) + vst1q_f16(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4 ))); + vst1q_f16(c + i * 2 + 8, add_zero(vec_c1, vec_z1, (i / 4 + 1))); + vst1q_f16(c + i * 2 + 16, add_zero(vec_c2, vec_z2, (i / 4 + 2))); + vst1q_f16(c + i * 2 + 24, add_zero(vec_c3, vec_z3, (i / 4 + 3))); +#undef add_zero + } else { + if (OneScale) { + tmac_float_type vec_s = scales[0]; + vst1q_f16(c + i * 2, vld1q_f16(c + i * 2 ) + vec_c0 * vec_s); + vst1q_f16(c + i * 2 + 8, vld1q_f16(c + i * 2 + 8 ) + vec_c1 * vec_s); + vst1q_f16(c + i * 2 + 16, vld1q_f16(c + i * 2 + 16) + vec_c2 * vec_s); + vst1q_f16(c + i * 2 + 24, vld1q_f16(c + i * 2 + 24) + vec_c3 * vec_s); + } else { + float16x8_t vec_s0 = vld1q_f16(scales + ((i / 4 ) / Bits) * 8); + float16x8_t vec_s1 = vld1q_f16(scales + ((i / 4 + 1) / Bits) * 8); + float16x8_t vec_s2 = vld1q_f16(scales + ((i / 4 + 2) / Bits) * 8); + float16x8_t vec_s3 = vld1q_f16(scales + ((i / 4 + 3) / Bits) * 8); + vst1q_f16(c + i * 2, vld1q_f16(c + i * 2 ) + vec_c0 * vec_s0); + vst1q_f16(c + i * 2 + 8, vld1q_f16(c + i * 2 + 8 ) + vec_c1 * vec_s1); + vst1q_f16(c + i * 2 + 16, vld1q_f16(c + i * 2 + 16) + vec_c2 * vec_s2); + vst1q_f16(c + i * 2 + 24, vld1q_f16(c + i * 2 + 24) + vec_c3 * vec_s3); + } + } + } +#elif defined __AVX2__ + const __m128i vec_mask = _mm_set1_epi8(0x0f); + __m128i vec_lut[K]; + +#pragma unroll + for (int k = 0; k < K; k++) { + vec_lut[k] = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 16)); + } + + SignedAdder adder; + for (int i = 0; i < m / 2; i += 16) { + __m256 vec_c0, vec_c1, vec_c2, vec_c3; + + tmac_float_type partial_sum = (tmac_float_type)-0.0f; +#pragma unroll + for (int kk = 0; kk < K; kk += ActK) { +#pragma unroll + for (int k = 0; k < ActK; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + __m128i vec_as = _mm_loadu_si128(reinterpret_cast<__m128i*>(a + i * K + (kk + k) * 16)); + __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); + __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); + + __m256i vec_lut_ = _mm256_set_m128i(vec_lut[kk + k], vec_lut[kk + k]); + __m256i vec_a = _mm256_set_m128i(vec_a_top, vec_a_bot); + __m256i vec_v = _mm256_shuffle_epi8(vec_lut_, vec_a); + adder.push(vec_v, k); + } + + __m256 vec_v_low_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_low())); + __m256 vec_v_low_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_low())); + __m256 vec_v_high_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_high())); + __m256 vec_v_high_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_high())); + + tmac_float_type lut_s = lut_scales[kk / ActK]; + tmac_float_type lut_b = lut_biases[kk / ActK]; + + partial_sum += lut_b; + + if (FastAggregation) { + lut_s = lut_s * ActK; + lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + } + +#define lut_fma(vs, ib) \ + ((ib) % Bits) ? (_mm256_mul_ps((vs), _mm256_set1_ps(lut_s))) \ + : (_mm256_fmadd_ps((vs), _mm256_set1_ps(lut_s), _mm256_set1_ps(lut_b))) + if (kk == 0) { + vec_c0 = lut_fma(vec_v_low_low, (i / 4 )); + vec_c1 = lut_fma(vec_v_low_high, (i / 4 + 1)); + vec_c2 = lut_fma(vec_v_high_low, (i / 4 + 2)); + vec_c3 = lut_fma(vec_v_high_high, (i / 4 + 3)); + } else { + vec_c0 = _mm256_add_ps(vec_c0, lut_fma(vec_v_low_low, (i / 4 ))); + vec_c1 = _mm256_add_ps(vec_c1, lut_fma(vec_v_low_high, (i / 4 + 1))); + vec_c2 = _mm256_add_ps(vec_c2, lut_fma(vec_v_high_low, (i / 4 + 2))); + vec_c3 = _mm256_add_ps(vec_c3, lut_fma(vec_v_high_high, (i / 4 + 3))); + } +#undef lut_fma + } + + if (ZeroPoint) { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4 ) / Bits) * 16); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16); + vec_c0 = _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2)); + vec_c1 = _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8)); + vec_c2 = _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16)); + vec_c3 = _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24)); + __m256 vec_z0 = _mm256_loadu_ps(scales + ((i / 4 ) / Bits) * 16 + 8); + __m256 vec_z1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16 + 8); + __m256 vec_z2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16 + 8); + __m256 vec_z3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16 + 8); + partial_sum *= 2; +#define add_zero(cs, zs, ib) \ + ((ib) % Bits) ? ((cs)) \ + : (_mm256_fmadd_ps((zs), _mm256_set1_ps(partial_sum), (cs))) + _mm256_storeu_ps(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4 ))); + _mm256_storeu_ps(c + i * 2 + 8, add_zero(vec_c1, vec_z1, (i / 4 + 1))); + _mm256_storeu_ps(c + i * 2 + 16, add_zero(vec_c2, vec_z2, (i / 4 + 2))); + _mm256_storeu_ps(c + i * 2 + 24, add_zero(vec_c3, vec_z3, (i / 4 + 3))); +#undef add_zero + } else if (OneScale) { + tmac_float_type single_scale = scales[0]; + __m256 vec_s = _mm256_set1_ps(single_scale); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s, _mm256_loadu_ps(c + i * 2 + 24))); + } else { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4 ) / Bits) * 8); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 8); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 8); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 8); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24))); + } + } +#endif + + return 0; +} + +// Unified scale +// TODO: implement fast aggregation for unified scale +template +inline int32_t tbl_g4_int8_int32_update_impl(int32_t m, int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + int8x16_t vec_lut[K]; + +#pragma unroll + for (int k = 0; k < K; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + + SignedAdder adder_bot, adder_top; + for (int i = 0; i < m / 2; i += 16) { +#pragma unroll + for (int k = 0; k < K; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + uint8x16_t vec_as = vld1q_u8(a + i * K + k * 16); + uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); + uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); + + int8x16_t vec_v_bot_tmp = vqtbl1q_s8(vec_lut[k], vec_a_bot); + int8x16_t vec_v_top_tmp = vqtbl1q_s8(vec_lut[k], vec_a_top); + adder_bot.push(vec_v_bot_tmp, k); + adder_top.push(vec_v_top_tmp, k); + } + + int16x8_t vec_v_bot_low = adder_bot.get_low(); + int16x8_t vec_v_bot_high = adder_bot.get_high(); + int16x8_t vec_v_top_low = adder_top.get_low(); + int16x8_t vec_v_top_high = adder_top.get_high(); + + int32x4_t vec_v_bot_low_low = vmovl_s16(vget_low_s16(vec_v_bot_low)); + int32x4_t vec_v_bot_low_high = vmovl_high_s16(vec_v_bot_low); + int32x4_t vec_v_bot_high_low = vmovl_s16(vget_low_s16(vec_v_bot_high)); + int32x4_t vec_v_bot_high_high = vmovl_high_s16(vec_v_bot_high); + int32x4_t vec_v_top_low_low = vmovl_s16(vget_low_s16(vec_v_top_low)); + int32x4_t vec_v_top_low_high = vmovl_high_s16(vec_v_top_low); + int32x4_t vec_v_top_high_low = vmovl_s16(vget_low_s16(vec_v_top_high)); + int32x4_t vec_v_top_high_high = vmovl_high_s16(vec_v_top_high); + + vst1q_s32(c + i * 2, vld1q_s32(c + i * 2 ) + vec_v_bot_low_low ); + vst1q_s32(c + i * 2 + 4, vld1q_s32(c + i * 2 + 4 ) + vec_v_bot_low_high ); + vst1q_s32(c + i * 2 + 8, vld1q_s32(c + i * 2 + 8 ) + vec_v_bot_high_low ); + vst1q_s32(c + i * 2 + 12, vld1q_s32(c + i * 2 + 12) + vec_v_bot_high_high); + vst1q_s32(c + i * 2 + 16, vld1q_s32(c + i * 2 + 16) + vec_v_top_low_low ); + vst1q_s32(c + i * 2 + 20, vld1q_s32(c + i * 2 + 20) + vec_v_top_low_high ); + vst1q_s32(c + i * 2 + 24, vld1q_s32(c + i * 2 + 24) + vec_v_top_high_low ); + vst1q_s32(c + i * 2 + 28, vld1q_s32(c + i * 2 + 28) + vec_v_top_high_high); + } + +#elif defined __AVX2__ + const __m128i vec_mask = _mm_set1_epi8(0x0f); + __m128i vec_lut[K]; + +#pragma unroll + for (int k = 0; k < K; k++) { + vec_lut[k] = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 16)); + } + + SignedAdder adder; + for (int i = 0; i < m / 2; i += 16) { +#pragma unroll + for (int k = 0; k < K; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + __m128i vec_as = _mm_loadu_si128(reinterpret_cast<__m128i*>(a + i * K + k * 16)); + __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); + __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); + + __m256i vec_lut_ = _mm256_set_m128i(vec_lut[k], vec_lut[k]); + __m256i vec_a = _mm256_set_m128i(vec_a_top, vec_a_bot); + __m256i vec_v = _mm256_shuffle_epi8(vec_lut_, vec_a); + adder.push(vec_v, k); + } + + __m256i vec_v_low_low = extract_low_epi16_epi32(adder.get_low()); + __m256i vec_v_low_high = extract_high_epi16_epi32(adder.get_low()); + __m256i vec_v_high_low = extract_low_epi16_epi32(adder.get_high()); + __m256i vec_v_high_high = extract_high_epi16_epi32(adder.get_high()); + __m256i vec_c0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2)); + __m256i vec_c1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 8)); + __m256i vec_c2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 16)); + __m256i vec_c3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 24)); + vec_c0 = _mm256_add_epi32(vec_c0, vec_v_low_low); + vec_c1 = _mm256_add_epi32(vec_c1, vec_v_low_high); + vec_c2 = _mm256_add_epi32(vec_c2, vec_v_high_low); + vec_c3 = _mm256_add_epi32(vec_c3, vec_v_high_high); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 ), vec_c0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 8 ), vec_c1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 16), vec_c2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 24), vec_c3); + } + +#endif + return 0; +} + +template +inline int32_t tbl_g4_int8_int16_update_impl(int32_t m, int16_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + int8x16_t vec_lut[K]; + +#pragma unroll + for (int k = 0; k < K; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + + SignedAdder adder_bot, adder_top; + for (int i = 0; i < m / 2; i += 16) { +#pragma unroll + for (int k = 0; k < K; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + uint8x16_t vec_as = vld1q_u8(a + i * K + k * 16); + uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4); + uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); + + int8x16_t vec_v_bot_tmp = vqtbl1q_s8(vec_lut[k], vec_a_bot); + int8x16_t vec_v_top_tmp = vqtbl1q_s8(vec_lut[k], vec_a_top); + adder_bot.push(vec_v_bot_tmp, k); + adder_top.push(vec_v_top_tmp, k); + } + + int16x8_t vec_v_bot_low = adder_bot.get_low(); + int16x8_t vec_v_bot_high = adder_bot.get_high(); + int16x8_t vec_v_top_low = adder_top.get_low(); + int16x8_t vec_v_top_high = adder_top.get_high(); + vst1q_s16(c + i * 2, vld1q_s16(c + i * 2 ) + vec_v_bot_low); + vst1q_s16(c + i * 2 + 8, vld1q_s16(c + i * 2 + 8 ) + vec_v_bot_high); + vst1q_s16(c + i * 2 + 16, vld1q_s16(c + i * 2 + 16) + vec_v_top_low); + vst1q_s16(c + i * 2 + 24, vld1q_s16(c + i * 2 + 24) + vec_v_top_high); + } +#elif defined __AVX2__ + // TODO: implement this +#endif +} + + +inline void tbl_g4_int8_float_gather_bit1_impl(int32_t m, tmac_float_type* C_global, tmac_float_type* CBits, tmac_float_type* C) { + constexpr int32_t bits = 1; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + #pragma unroll + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (tmac_float_type)5.000000e-01f); + + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + #pragma unroll + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + +inline void tbl_g4_int8_float_gather_bit2_impl(int32_t m, tmac_float_type* C_global, tmac_float_type* CBits, tmac_float_type* C) { + constexpr int32_t bits = 2; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + #pragma unroll + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (tmac_float_type)5.000000e-01f) + + (CBits[cse_var_2 + bit_offset_1]); + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + #pragma unroll + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + +inline void tbl_g4_int8_float_gather_bit3_impl(int32_t m, tmac_float_type* C_global, tmac_float_type* CBits, tmac_float_type* C) { + constexpr int32_t bits = 3; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + #pragma unroll + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + int32_t bit_offset_2 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 16; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (tmac_float_type)5.000000e-01f) + + (CBits[cse_var_2 + bit_offset_1]) + + (CBits[cse_var_2 + bit_offset_2] * (tmac_float_type)2.000000e+00f); + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + #pragma unroll + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + +inline void tbl_g4_int8_float_gather_bit4_impl(int32_t m, tmac_float_type* C_global, tmac_float_type* CBits, tmac_float_type* C) { + constexpr int32_t bits = 4; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + #pragma unroll + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + int32_t bit_offset_2 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 16; + int32_t bit_offset_3 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 24; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (tmac_float_type)5.000000e-01f) + + (CBits[cse_var_2 + bit_offset_1]) + + (CBits[cse_var_2 + bit_offset_2] * (tmac_float_type)2.000000e+00f) + + (CBits[cse_var_2 + bit_offset_3] * (tmac_float_type)4.000000e+00f); + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + #pragma unroll + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + + + + +#ifdef __cplusplus +extern "C" { +#endif + +int32_t tbl_int8_reset(int32_t m, int8_t* c) { + memset(c, 0, m); + return 0; +} + +int32_t tbl_float_reset(int32_t m, void* c) { + memset(c, 0, m * sizeof(tmac_float_type)); + return 0; +} + +int32_t tbl_int32_reset(int32_t m, int32_t* c) { + memset(c, 0, m * sizeof(int32_t)); + return 0; +} + +int32_t tbl_int16_reset(int32_t m, int16_t* c) { + memset(c, 0, m * sizeof(int16_t)); + return 0; +} + +#ifdef __cplusplus +} +#endif + + +void qgemm_lut_int8_g4( + void* A, void* LUT, void* Scales, void* LUT_Scales, void* LUT_Biases, void* C, + int bm, int K, int N, const struct tmac_kernel_config * const kernel_config) { + // TODO: support N > 1 + if (N != 1) { + throw std::runtime_error("N > 1 is not supported yet"); + } + + const int g = kernel_config->g; + const int ngroups_per_elem = 8 / g; + int q_group_size = kernel_config->q_group_size; + int act_group_size = kernel_config->act_group_size; + bool has_scale = kernel_config->has_scale; + int kfactor = kernel_config->kfactor; + int bits = kernel_config->bits; + int actk = kernel_config->actk; + bool has_zero_point = kernel_config->has_zero_point; + bool one_scale = kernel_config->one_scale; + int m = bm / bits; + + tmac_float_type *CBits = new tmac_float_type[bm]; + tmac_float_type *C_global = new tmac_float_type[m]; + tbl_int32_reset(bm * sizeof(tmac_float_type) / sizeof(int32_t), (&(((int32_t*)CBits)[0]))); + + int32_t k_outer_max = K / (kfactor * g); + int32_t scale_gs = q_group_size / (kfactor * g); + int32_t scale_idx_shfr = 0; + if (scale_gs == 1) { + scale_idx_shfr = 0; + } else if (scale_gs == 2) { + scale_idx_shfr = 1; + } else if (scale_gs == 4) { + scale_idx_shfr = 2; + } else if (scale_gs == 8) { + scale_idx_shfr = 3; + } else { + fprintf(stderr, "q_group_size=%d, kfactor=%d, g=%d\n", q_group_size, kfactor, g); + fprintf(stderr, "Unsupported scale group size over kfactor. Expected {1,2,4,8}, got %d.\n", scale_gs); + throw std::runtime_error(""); + } + + for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) { + uint8_t * a = ((uint8_t *)A) + k_outer * bm * kfactor / ngroups_per_elem; + tmac_float_type * scales = one_scale ? (tmac_float_type *)Scales : + has_zero_point ? ((tmac_float_type *)Scales) + (k_outer >> scale_idx_shfr) * m * 2: + ((tmac_float_type *)Scales) + (k_outer >> scale_idx_shfr) * m; + int8_t * lut = ((int8_t *)LUT) + k_outer * kfactor * int(pow(2, g)); + tmac_float_type * lut_scales = ((tmac_float_type *)LUT_Scales) + (k_outer * kfactor * g / act_group_size); + tmac_float_type * lut_biases = ((tmac_float_type *)LUT_Biases) + (k_outer * kfactor * g / act_group_size); + + if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } + + else if (has_scale && kfactor == 8 && bits == 4 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 4 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 4 && actk == 16 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 8 && bits == 4 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 4 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 4 && actk == 16 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 8 && bits == 4 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 4 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } else if (has_scale && kfactor == 16 && bits == 4 && actk == 16 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + (int32_t)bm, CBits, lut, a, scales, lut_scales, lut_biases); + } + } + // if (!(((uint8_t *)A)[0] == 0 && ((uint8_t *)A)[1] == 0 && ((uint8_t *)A)[2] == 0 && ((uint8_t *)A)[3] == 0 + // && ((uint8_t *)A)[4] == 0 && ((uint8_t *)A)[5] == 0 && ((uint8_t *)A)[6] == 0 && ((uint8_t *)A)[7] == 0)) { + // printf("\n\n\n\nCBits:\n\n\n"); + // for (int i = 0; i < bm; i++) { + // printf("%f ", CBits[i]); + // } + // printf("\n"); + // } + + if (bits == 1) { + tbl_g4_int8_float_gather_bit1_impl(m, C_global, CBits, (tmac_float_type *)C); + } else if (bits == 2) { + tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, (tmac_float_type *)C); + } else if (bits == 3) { + tbl_g4_int8_float_gather_bit3_impl(m, C_global, CBits, (tmac_float_type *)C); + } else if (bits == 4) { + tbl_g4_int8_float_gather_bit4_impl(m, C_global, CBits, (tmac_float_type *)C); + } else { + throw std::runtime_error("Unsupported bits"); + } + + delete[] C_global; + delete[] CBits; +} + diff --git a/ggml/src/ggml-cpu/tmac/tbl.h b/ggml/src/ggml-cpu/tmac/tbl.h new file mode 100644 index 0000000000000..304914504c582 --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/tbl.h @@ -0,0 +1,63 @@ +#pragma once + +/* Please do not include this header file outside ggml-cpu/tmac */ + +#ifndef INTRINSIC_TYPES_H +#define INTRINSIC_TYPES_H + +#ifdef __ARM_NEON +#include +#elif defined __AVX2__ +#include +#endif + +#ifdef __ARM_NEON +typedef float16_t tmac_float_type; +#else +#include +#include +typedef float tmac_float_type; +#endif + +#endif + + +#ifndef TMAC_HALF_TYPEDEF_H +#define TMAC_HALF_TYPEDEF_H + +#ifndef __AVX2__ +typedef _Float16 half; +#endif +#endif + +#include "lut_ctor.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +int32_t tbl_int8_reset(int32_t m, int8_t* c); + +int32_t tbl_float_reset(int32_t m, void* c); + +int32_t tbl_int32_reset(int32_t m, int32_t* c); + +int32_t tbl_int16_reset(int32_t m, int16_t* c); + + +void qgemm_lut_int8_g4( + void* A, void* LUT, void* Scales, void* LUT_Scales, void* LUT_Biases, void* C, + int bm, int K, int N, const struct tmac_kernel_config * const kernel_config); + +#ifdef __cplusplus +} +#endif + + + + + + + + diff --git a/ggml/src/ggml-cpu/tmac/tmac.cpp b/ggml/src/ggml-cpu/tmac/tmac.cpp new file mode 100644 index 0000000000000..2cda9c7ee695a --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/tmac.cpp @@ -0,0 +1,165 @@ + +#include +#include + +#include "ggml-backend-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-traits.h" +#include "lut_mul_mat.h" +#include "tmac.h" + +#if defined(GGML_USE_TMAC) +namespace ggml::cpu::tmac { + +static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { + static tensor_traits traits; + return &traits; +} + + +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + // auto is_contiguous = [](const struct ggml_tensor * t) { + // return ggml_is_contiguous(t); + // }; + + if (// ggml_is_contiguous(src0) && // src0 must be contiguous + // ggml_is_contiguous(src1) && // src1 must be contiguous + // op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_tmac_buffer_type() && + ggml_tmac_can_mul_mat(op)) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { // src1 must be host buffer + return false; + } + return true; + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer && + op->src[0]->buffer->buft == ggml_backend_tmac_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + + return nullptr; + } +}; + +} // namespace ggml::cpu::tmac + +void ggml_tmac_init() { + tmac_init(); +} + +static void ggml_backend_tmac_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_aligned_free(buffer->context, buffer->size); +} + +static void * ggml_backend_tmac_buffer_get_base(ggml_backend_buffer_t buffer) { + uintptr_t data = (uintptr_t)buffer->context; + + // align the buffer + if (data % TENSOR_ALIGNMENT != 0) { + data = GGML_PAD(data, TENSOR_ALIGNMENT); + } + + return (void *)data; +} + +static enum ggml_status ggml_backend_tmac_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) ggml::cpu::tmac::get_tensor_traits(buffer, tensor); + + GGML_UNUSED(buffer); + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_tmac_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_tmac_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + if (is_type_supported(tensor->type)) { + GGML_LOG_DEBUG("%s: tmac repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type)); + ggml_backend_tmac_convert_weight(tensor, data, offset, size); + } else { + memcpy((char *) tensor->data + offset, data, size); + } + + GGML_UNUSED(buffer); +} + +static void ggml_backend_tmac_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + + +static ggml_backend_buffer_i ggml_backend_tmac_buffer_interface = { + /* .free_buffer = */ ggml_backend_tmac_buffer_free_buffer, // same as ggml_backend_cpu_buffer_free_buffer + /* .get_base = */ ggml_backend_tmac_buffer_get_base, // same as ggml_backend_cpu_buffer_get_base + /* .init_tensor = */ ggml_backend_tmac_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_tmac_buffer_memset_tensor, // same as ggml_backend_cpu_buffer_memset_tensor + /* .set_tensor = */ ggml_backend_tmac_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_tmac_buffer_clear, // same as ggml_backend_cpu_buffer_clear + /* .reset = */ nullptr, +}; + + +// T-MAC backend buffer type +static const char * ggml_backend_tmac_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "TMAC"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_tmac_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * data = ggml_aligned_malloc(size); + if (data == NULL) { + fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, ggml_backend_tmac_buffer_interface, data, size); +} + +static size_t ggml_backend_tmac_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_tmac_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + // T-MAC version of ggml_nbytes + return ggml_tmac_get_nbytes(tensor); + + GGML_UNUSED(buft); +} + +static bool ggml_backend_tmac_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_tmac_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_tmac = { + /* .iface = */ { + /* .get_name = */ ggml_backend_tmac_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_tmac_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_tmac_buffer_type_get_alignment, // same as ggml_backend_cpu_* + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_tmac_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_tmac_buffer_type_is_host, // same as ggml_backend_cpu_* + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::tmac::extra_buffer_type(), + }; + + return &ggml_backend_buffer_type_tmac; +} + +#endif // GGML_USE_TMAC \ No newline at end of file diff --git a/ggml/src/ggml-cpu/tmac/tmac.h b/ggml/src/ggml-cpu/tmac/tmac.h new file mode 100644 index 0000000000000..c2986909e91ff --- /dev/null +++ b/ggml/src/ggml-cpu/tmac/tmac.h @@ -0,0 +1,20 @@ +#pragma once + +#include "ggml-backend.h" + +// GGML internal header + +#if defined(GGML_USE_TMAC) + +#ifdef __cplusplus +extern "C" { +#endif + +ggml_backend_buffer_type_t ggml_backend_tmac_buffer_type(void); +void ggml_tmac_init(void); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index ac918a60d9ece..6180f1a50ddea 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5227,6 +5227,17 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_I64: // nothing to validate break; + case GGML_TYPE_TMAC_BN_0: + case GGML_TYPE_TMAC_W2G64_0: + case GGML_TYPE_TMAC_W2G64_1: + case GGML_TYPE_TMAC_W2G128_0: + case GGML_TYPE_TMAC_W2G128_1: + case GGML_TYPE_TMAC_W4G64_0: + case GGML_TYPE_TMAC_W4G64_1: + case GGML_TYPE_TMAC_W4G128_0: + case GGML_TYPE_TMAC_W4G128_1: + // nothing to validate + break; default: { fprintf(stderr, "%s: invalid type %d\n", __func__, type); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 950772c75cb32..b85c921ed3464 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -570,6 +570,60 @@ static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp1 static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_TMAC_BN_0] = { + .type_name = "tmac_bn_0", + .blck_size = 64, + .type_size = 64 * 2 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W2G64_0] = { + .type_name = "tmac_w2g64_0", + .blck_size = 64, + .type_size = 2 + 64 * 2 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W2G64_1] = { + .type_name = "tmac_w2g64_1", + .blck_size = 64, + .type_size = 2 + 2 + 64 * 2 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W2G128_0] = { + .type_name = "tmac_w2g128_0", + .blck_size = 128, + .type_size = 2 + 128 * 2 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W2G128_1] = { + .type_name = "tmac_w2g128_1", + .blck_size = 128, + .type_size = 2 + 2 + 128 * 2 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W4G64_0] = { + .type_name = "tmac_w4g64_0", + .blck_size = 64, + .type_size = 2 + 64 * 4 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W4G64_1] = { + .type_name = "tmac_w4g64_1", + .blck_size = 64, + .type_size = 2 + 2 + 64 * 4 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W4G128_0] = { + .type_name = "tmac_w4g128_0", + .blck_size = 128, + .type_size = 2 + 128 * 4 / 8, + .is_quantized = false, + }, + [GGML_TYPE_TMAC_W4G128_1] = { + .type_name = "tmac_w4g128_1", + .blck_size = 128, + .type_size = 2 + 2 + 128 * 4 / 8, + .is_quantized = false, + }, [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -1170,6 +1224,12 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) { } } + if (tensor->type == GGML_TYPE_TMAC_BN_0) { + // One scale will not exceed one alignment boundary, so we can just add one alignment to the size. + nbytes += GGUF_DEFAULT_ALIGNMENT; + } + + return nbytes; } diff --git a/gguf-py/gguf/__init__.py b/gguf-py/gguf/__init__.py index 243defc4c1ca4..fac14655cc20b 100644 --- a/gguf-py/gguf/__init__.py +++ b/gguf-py/gguf/__init__.py @@ -7,3 +7,4 @@ from .vocab import * from .utility import * from .metadata import * +from .tmac_utils import * diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 162070e6e193a..8c1f50217fd43 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1887,6 +1887,15 @@ class GGMLQuantizationType(IntEnum): BF16 = 30 TQ1_0 = 34 TQ2_0 = 35 + TMAC_BN_0 = 39 + TMAC_W2G64_0 = 40 + TMAC_W2G64_1 = 41 + TMAC_W2G128_0 = 42 + TMAC_W2G128_1 = 43 + TMAC_W4G64_0 = 44 + TMAC_W4G64_1 = 45 + TMAC_W4G128_0 = 46 + TMAC_W4G128_1 = 47 class ExpertGatingFuncType(IntEnum): @@ -1938,6 +1947,15 @@ class LlamaFileType(IntEnum): # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack MOSTLY_TQ1_0 = 36 # except 1d tensors MOSTLY_TQ2_0 = 37 # except 1d tensors + MOSTLY_TMAC_BN_0 = 38 # except 1d tensors + MOSTLY_TMAC_W2G64_0 = 39 # except 1d tensors + MOSTLY_TMAC_W2G64_1 = 40 # except 1d tensors + MOSTLY_TMAC_W2G128_0 = 41 # except 1d tensors + MOSTLY_TMAC_W2G128_1 = 42 # except 1d tensors + MOSTLY_TMAC_W4G64_0 = 43 # except 1d tensors + MOSTLY_TMAC_W4G64_1 = 44 # except 1d tensors + MOSTLY_TMAC_W4G128_0 = 45 # except 1d tensors + MOSTLY_TMAC_W4G128_1 = 46 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -2013,6 +2031,19 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.BF16: (1, 2), GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ2_0: (256, 2 + 64), + # Currently, we use tricks here + # - Bitnet-style models have only one scale value for the whole tensor, + # - which is not compatible with the "blocking" philosophy of here. + # - During inference, the accurate nbytes info will be known through ggml_tmac_get_nbytes. + GGMLQuantizationType.TMAC_BN_0: (64, 64 * 2 // 8), + GGMLQuantizationType.TMAC_W2G64_0: (64, 2 + 64 * 2 // 8), + GGMLQuantizationType.TMAC_W2G64_1: (64, 2 + 2 + 64 * 2 // 8), + GGMLQuantizationType.TMAC_W2G128_0: (128, 2 + 128 * 2 // 8), + GGMLQuantizationType.TMAC_W2G128_1: (128, 2 + 2 + 128 * 2 // 8), + GGMLQuantizationType.TMAC_W4G64_0: (64, 2 + 64 * 4 // 8), + GGMLQuantizationType.TMAC_W4G64_1: (64, 2 + 2 + 64 * 4 // 8), + GGMLQuantizationType.TMAC_W4G128_0: (128, 2 + 128 * 4 // 8), + GGMLQuantizationType.TMAC_W4G128_1: (128, 2 + 2 + 128 * 4 // 8), } diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index 3c8ba82e19d3d..278d518a36762 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -54,12 +54,16 @@ class QuantError(Exception): ... def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + from gguf.tmac_utils import is_tmac_dtype if qtype == GGMLQuantizationType.F32: return data.astype(np.float32, copy=False) elif qtype == GGMLQuantizationType.F16: return data.astype(np.float16, copy=False) elif (q := _type_traits.get(qtype)) is not None: return q.quantize(data) + # Do nothing for I1/2/3/4, as they are already quantized + elif is_tmac_dtype(qtype): + return data else: raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented") diff --git a/gguf-py/gguf/tmac_utils.py b/gguf-py/gguf/tmac_utils.py new file mode 100644 index 0000000000000..42fd54e8934e3 --- /dev/null +++ b/gguf-py/gguf/tmac_utils.py @@ -0,0 +1,171 @@ +import json +import logging +import numpy as np +import os +from pathlib import Path +import sys +from typing import Optional, Tuple + +logger = logging.getLogger("tmac_utils") + + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + + +def is_tmac_w2_ftype(ftype: gguf.LlamaFileType): + return ftype == gguf.LlamaFileType.MOSTLY_TMAC_BN_0 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G64_0 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G64_1 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G128_0 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W2G128_1 + +def is_tmac_w4_ftype(ftype: gguf.LlamaFileType): + return ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G64_0 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G64_1 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G128_0 or \ + ftype == gguf.LlamaFileType.MOSTLY_TMAC_W4G128_1 + +def is_tmac_ftype(ftype: gguf.LlamaFileType): + return is_tmac_w2_ftype(ftype) or is_tmac_w4_ftype(ftype) + +def is_tmac_w2_dtype(dtype: gguf.GGMLQuantizationType): + return dtype == gguf.GGMLQuantizationType.TMAC_BN_0 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W2G64_0 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W2G64_1 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W2G128_0 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W2G128_1 + +def is_tmac_w4_dtype(dtype: gguf.GGMLQuantizationType): + return dtype == gguf.GGMLQuantizationType.TMAC_W4G64_0 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W4G64_1 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W4G128_0 or \ + dtype == gguf.GGMLQuantizationType.TMAC_W4G128_1 + +def is_tmac_dtype(dtype: gguf.GGMLQuantizationType): + return is_tmac_w2_dtype(dtype) or is_tmac_w4_dtype(dtype) + + +def parse_gptqv2(qweight: np.ndarray, scales: np.ndarray, qzeros: np.ndarray) -> Tuple: + bits = 32 // (scales.shape[1] // qzeros.shape[1]) + K = qweight.shape[0] * (32 // bits) + M = qweight.shape[1] + group_size = K // scales.shape[0] + + return K, M, bits, group_size + + +def unpack_gptqv2(qweight: np.ndarray, scales: np.ndarray, qzeros: np.ndarray, gptq_v2: bool = True): + """ + Unpack GPTQv2 + Return T-MAC biased uint8 weight [0, 2 ** bits), fp16 scales, biased fp16 zeros, bits, group_size + """ + assert qweight.dtype == "int32" + assert qzeros.dtype == "int32" + + K, M, bits, group_size = parse_gptqv2(qweight, scales, qzeros) + + # Unpack qweight + qweights = [(qweight >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)] + w = np.stack(qweights, axis=1).reshape(K, M).T.astype("uint8") + + scales = scales.T + + # Unpack qzeros + zeros = [(qzeros >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)] + zeros = np.stack(zeros, axis=-1).reshape(K // group_size, M).T.astype(scales.dtype) + if not gptq_v2: + # `zeros = zeros - 1` in AutoGPTQ + # Not in GPTQModel + zeros += 1 + zeros = (zeros - (2 ** (bits - 1))) * scales + + return w, scales, zeros, bits, group_size + + +def get_quantization_config(model_dir: str) -> dict: + try: + with open(model_dir / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + except FileNotFoundError: + logger.warning("config.json not found, using default empty quantization config") + hparams = {} + + # GPTQ + quantization_config = hparams.get("quantization_config", {}) + desc_act = quantization_config.get("desc_act", False) + assert not desc_act, "desc_act=True currently unsupported by T-MAC" + quantizer = quantization_config.get("meta", {}).get("quantizer", "") + group_size = quantization_config.get("group_size", 0) + bits = quantization_config.get("bits", 0) + sym = quantization_config.get("sym", False) + quant_method = quantization_config.get("quant_method", "") + # BitNet + weight_bits = hparams.get("weight_bits", 0) + + return { + "quantizer": quantizer, + "group_size": group_size, + "bits": bits, + "sym": sym, + "quant_method": quant_method, + "weight_bits": weight_bits, + } + + +def derive_ftype_from_quantization_config(quantization_config: dict) -> gguf.LlamaFileType | None: + # If bits > 0, the tensor is quantized by GPTQ + bits = quantization_config["bits"] + group_size = quantization_config["group_size"] + sym = quantization_config["sym"] + ftype = None + if quantization_config["quant_method"] in ["gptq", "bitdistiller"] and bits > 0: + if bits == 2 and group_size == -1: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_BN_0 + elif bits == 2 and group_size == 64 and sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W2G64_0 + elif bits == 2 and group_size == 64 and not sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W2G64_1 + elif bits == 2 and group_size == 128 and sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W2G128_0 + elif bits == 2 and group_size == 128 and not sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W2G128_1 + elif bits == 4 and group_size == 64 and sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W4G64_0 + elif bits == 4 and group_size == 64 and not sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W4G64_1 + elif bits == 4 and group_size == 128 and sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W4G128_0 + elif bits == 4 and group_size == 128 and not sym: + ftype = gguf.LlamaFileType.MOSTLY_TMAC_W4G128_1 + else: + raise ValueError(f"Unsupported number of (bits, group_size, sym): ({bits}, {group_size}, {sym})") + return ftype + + +def tighten_bit_array( + w: np.ndarray, + bits: int +) -> np.ndarray: + mask = (1 << bits) - 1 + tightened_array = w & mask + flattened_bits = np.unpackbits(tightened_array.astype(np.uint8)).reshape(-1, 8)[:, -bits:] + tightened_compact = np.packbits(flattened_bits) + return tightened_compact + + +def preprocess_for_t_mac( + w: np.ndarray, + scales: np.ndarray, + zeros: Optional[np.ndarray] = None, + bits: int = 2, + g: int = 4, +) -> np.ndarray: + + w_packed = tighten_bit_array(w, bits) + + if zeros is not None: + return np.concatenate([w_packed, scales.astype(np.float16).copy().view(np.uint8).flatten(), zeros.astype(np.float16).copy().view(np.uint8).flatten()]) + else: + return np.concatenate([w_packed, scales.astype(np.float16).copy().view(np.uint8).flatten()]) diff --git a/include/llama.h b/include/llama.h index 5657fbf0a703a..0535542fff43e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -185,6 +185,15 @@ extern "C" { //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors + LLAMA_FTYPE_MOSTLY_TMAC_BN_0 = 38, + LLAMA_FTYPE_MOSTLY_TMAC_W2G64_0 = 39, + LLAMA_FTYPE_MOSTLY_TMAC_W2G64_1 = 40, + LLAMA_FTYPE_MOSTLY_TMAC_W2G128_0 = 41, + LLAMA_FTYPE_MOSTLY_TMAC_W2G128_1 = 42, + LLAMA_FTYPE_MOSTLY_TMAC_W4G64_0 = 43, + LLAMA_FTYPE_MOSTLY_TMAC_W4G64_1 = 44, + LLAMA_FTYPE_MOSTLY_TMAC_W4G128_0 = 45, + LLAMA_FTYPE_MOSTLY_TMAC_W4G128_1 = 46, LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index ea73a8a7ba944..b6168d564e7df 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -2,6 +2,10 @@ #include "ggml.h" +#ifdef GGML_USE_TMAC + #include "tmac.h" +#endif + #include #include #include @@ -59,6 +63,15 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_BN_0: return "TMAC_BN_0"; + case LLAMA_FTYPE_MOSTLY_TMAC_W2G64_0: return "TMAC_W2G64_0 - 2.25 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W2G64_1: return "TMAC_W2G64_1 - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W2G128_0: return "TMAC_W2G128_0 - 2.125 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W2G128_1: return "TMAC_W2G128_1 - 2.25 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W4G64_0: return "TMAC_W4G64_0 - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W4G64_1: return "TMAC_W4G64_1 - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W4G128_0: return "TMAC_W4G128_0 - 4.125 bpw"; + case LLAMA_FTYPE_MOSTLY_TMAC_W4G128_1: return "TMAC_W4G128_1 - 4.25 bpw"; default: return "unknown, may not work"; } @@ -634,6 +647,15 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_TMAC_BN_0: ftype = LLAMA_FTYPE_MOSTLY_TMAC_BN_0; break; + case GGML_TYPE_TMAC_W2G64_0: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W2G64_0; break; + case GGML_TYPE_TMAC_W2G64_1: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W2G64_1; break; + case GGML_TYPE_TMAC_W2G128_0: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W2G128_0; break; + case GGML_TYPE_TMAC_W2G128_1: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W2G128_1; break; + case GGML_TYPE_TMAC_W4G64_0: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W4G64_0; break; + case GGML_TYPE_TMAC_W4G64_1: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W4G64_1; break; + case GGML_TYPE_TMAC_W4G128_0: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W4G128_0; break; + case GGML_TYPE_TMAC_W4G128_1: ftype = LLAMA_FTYPE_MOSTLY_TMAC_W4G128_1; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 7dc5422763118..4ee63288d17dc 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -813,7 +813,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } - + if (tensor->type == GGML_TYPE_TMAC_BN_0 || + tensor->type == GGML_TYPE_TMAC_W2G64_0 || + tensor->type == GGML_TYPE_TMAC_W2G64_1 || + tensor->type == GGML_TYPE_TMAC_W2G128_0 || + tensor->type == GGML_TYPE_TMAC_W2G128_1 || + tensor->type == GGML_TYPE_TMAC_W4G64_0 || + tensor->type == GGML_TYPE_TMAC_W4G64_1 || + tensor->type == GGML_TYPE_TMAC_W4G128_0 || + tensor->type == GGML_TYPE_TMAC_W4G128_1) { + // no need quantize for iN + new_type = tensor->type; + } + // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type;