From d92a7d3ad48b7184b62e62783d9d29658e216045 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 30 Jun 2026 13:44:52 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- .../source_transformation/pre_quantization.py | 150 +++++++++++++++++- 1 file changed, 143 insertions(+), 7 deletions(-) diff --git a/examples/models/llama/source_transformation/pre_quantization.py b/examples/models/llama/source_transformation/pre_quantization.py index bb905579875..8f6bcab8046 100644 --- a/examples/models/llama/source_transformation/pre_quantization.py +++ b/examples/models/llama/source_transformation/pre_quantization.py @@ -11,6 +11,7 @@ from typing import Any, Optional import torch +import torch.nn.functional as F from torch import nn from torchao.quantization.linear_quant_modules import ( @@ -18,10 +19,129 @@ Int8DynActInt4WeightLinear, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization.quant_primitives import dequantize_affine from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding +class WeightOnlyInt4Linear(torch.nn.Module): + """Weight-only int4 per-group linear that reuses the 8da4w checkpoint layout. + + Stores the SAME buffers as ``Int8DynActInt4WeightLinear`` (int8 ``weight`` holding + int4 values, per-group ``scales``/``zeros``) so a pre-quantized QAT/PTQ checkpoint + loads unchanged. The difference is forward: the activation is left in floating + point (no per-token dynamic quant), so the traced graph is + ``dequantize_affine(weight) -> F.linear`` with no activation quant. This lowers to + the ET-VK weight-only ``et_vk.linear_q4gsw`` path instead of the dynamic-activation + ``et_vk.linear_dq8ca_q4gsw`` path. + """ + + __constants__ = ["in_features", "out_features"] + + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: torch.device | None = None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + self.in_features = in_features + self.out_features = out_features + self.groupsize = groupsize + self.precision = precision + self.register_buffer( + "weight", + torch.zeros((out_features, in_features), dtype=torch.int8, device=device), + ) + self.register_buffer( + "scales", + torch.zeros( + (out_features, in_features // groupsize), + dtype=scales_precision, + device=device, + ), + ) + self.register_buffer( + "zeros", + torch.zeros( + (out_features, in_features // groupsize), + dtype=scales_precision, + device=device, + ), + ) + if bias: + self.register_buffer( + "bias", torch.zeros(out_features, dtype=precision, device=device) + ) + else: + self.bias = None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(self.precision) + n_bit = 4 + quant_min = -(2 ** (n_bit - 1)) + quant_max = 2 ** (n_bit - 1) - 1 + block_size = (1, self.groupsize) + w_dq = dequantize_affine( + self.weight, + block_size, + self.scales, + self.zeros, + torch.int8, + quant_min, + quant_max, + output_dtype=self.precision, + ) + return F.linear(input, w_dq, self.bias) + + +def _replace_linear_with_linear_int4_weight_only_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + group_size: int, + precision: torch.dtype, + scales_precision: torch.dtype, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + scales_key = f"{cur_fqn}.scales" + if isinstance(child, nn.Linear) and scales_key in checkpoint: + assert _check_linear_int4_k(child.in_features, group_size) + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == scales_precision + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_linear = WeightOnlyInt4Linear( + # pyre-fixme[6] + child.in_features, + # pyre-fixme[6] + child.out_features, + bias=child.bias is not None, + device=child.weight.device, + groupsize=group_size, + precision=precision, + scales_precision=scales_precision, + ) + # Symmetric int4: zero point is 0 (matches the 8da4w pre-quant path, which + # zeros this buffer rather than trusting the checkpoint's zeros key). + new_linear.zeros = torch.zeros_like(new_linear.zeros) + return new_linear + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + def _replace_linear_with_linear_8da4w_for_pre_quantization( module: torch.nn.Module, checkpoint: Any, @@ -65,24 +185,40 @@ def transform_linear_for_pre_quantization( checkpoint: Any, group_size: int, dtype: torch.dtype, + weight_only: bool = False, ) -> torch.nn.Module: """ Transform the model to be able to load pre-quantized checkpoints that are quantized with the given group size and quantization mode for linear layers. + + When ``weight_only`` is True, linears are swapped for ``WeightOnlyInt4Linear`` + (float activation x dequantized int4 weight) instead of the default + ``Int8DynActInt4WeightLinear`` (dynamic per-token int8 activation). The + checkpoint buffers (int4 weight, per-group scales/zeros) are identical, so the + same pre-quantized checkpoint loads either way. """ if group_size not in [32, 64, 128, 256]: raise ValueError( f"Group size {group_size} is not supported for pre-quantized checkpoint." ) - _replace_linear_with_linear_8da4w_for_pre_quantization( - module, - checkpoint, - group_size, - dtype, - dtype, - ) + if weight_only: + _replace_linear_with_linear_int4_weight_only_for_pre_quantization( + module, + checkpoint, + group_size, + dtype, + dtype, + ) + else: + _replace_linear_with_linear_8da4w_for_pre_quantization( + module, + checkpoint, + group_size, + dtype, + dtype, + ) return module