diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 82d2ff1dbe0..7f56a6e8dbe 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -176,8 +176,8 @@ def f(x: torch.Tensor) -> torch.Tensor: ) return (f(effective_codes).to(dtype=torch.int8), 0) + @staticmethod def generate_16_bit_table_values( - self, torch_op: Callable[[torch.Tensor], torch.Tensor], in_quantargs: QuantArgs, out_quantargs: QuantArgs, @@ -224,6 +224,15 @@ def f(x: torch.Tensor) -> torch.Tensor: # but due to signedness this is a negative number! So we need to shift it one more bit. # Note: for out_quantargs.dtype=torch.int16, rshift == 0 and rescale_lshift = -7. rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16 + # When the table values use fewer than 16 bits (e.g. a sigmoid output + # quantized with a small scale, so the max table value is well below + # 2**15), the formula above yields a negative rshift. The values already + # fit in signed int16, and a negative right-shift is undefined (on host it + # masks the shift count and zeroes the table, giving a degenerate + # step-function LUT on device). Clamp to 0 so no shift is applied; this is + # the documented int16 case (rshift == 0, rescale_lshift == -7) and keeps + # rescale_lshift consistent with the shift actually performed below. + rshift = max(rshift, 0) # The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do. rescale_lshift = rshift - 7 lut_values = lut_values >> rshift diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 29fc90b67fc..707f1bae06f 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -175,3 +177,43 @@ def test_sigmoid_u85_INT_add_sigmoid(test_data): ) configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() + + +def test_int16_table_small_output_range_is_not_degenerate(): + """Regression for the int16 TABLE negative-rshift bug. + + A TABLE op whose output uses fewer than 16 bits -- e.g. a sigmoid output in + [0, 1] quantized with a small scale, so the max table value (here 4096) is + well below 2**15 -- yields ``rshift < 0``. The generator then did + ``lut_values >> rshift``, an undefined negative right-shift that zeroed the + whole table on the host, turning the on-device activation into a constant. + The table must remain a non-degenerate, monotonic ramp. + + """ + # qparams for a small-output-range sigmoid (input spans ~+-22.4, + # output in [0, 1] quantized at 1/4096 -> max table value 4096 = 13 bits). + in_quantargs = QuantArgs( + scale=0.0006833122461102903, zp=0, qmin=-32767, qmax=32767, dtype=torch.int16 + ) + out_quantargs = QuantArgs( + scale=1.0 / 4096, zp=0, qmin=-32767, qmax=32767, dtype=torch.int16 + ) + + # generate_16_bit_table_values is a @staticmethod; call it on the class. + lut, rescale_lshift = InsertTableOpsPass.generate_16_bit_table_values( + torch.sigmoid, in_quantargs, out_quantargs + ) + + assert ( + torch.unique(lut).numel() > 1 + ), "int16 sigmoid table collapsed to a constant (negative-rshift bug)" + assert bool( + (lut[1:] >= lut[:-1]).all() + ), "sigmoid table must be monotonically non-decreasing" + assert int(lut.min()) == 0 and int(lut.max()) == 4096, ( + f"unexpected table range [{int(lut.min())}, {int(lut.max())}], " + "expected a full [0, 4096] sigmoid ramp" + ) + # Values already fit in int16, so no shift is applied: this is the documented + # int16 case (rshift == 0 -> rescale_lshift == -7), not the buggy -10. + assert rescale_lshift == -7, f"expected rescale_lshift == -7, got {rescale_lshift}"