Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def f(x: torch.Tensor) -> torch.Tensor:
)
return (f(effective_codes).to(dtype=torch.int8), 0)

@staticmethod
def generate_16_bit_table_values(
self,
torch_op: Callable[[torch.Tensor], torch.Tensor],
in_quantargs: QuantArgs,
out_quantargs: QuantArgs,
Expand Down Expand Up @@ -224,6 +224,15 @@ def f(x: torch.Tensor) -> torch.Tensor:
# but due to signedness this is a negative number! So we need to shift it one more bit.
# Note: for out_quantargs.dtype=torch.int16, rshift == 0 and rescale_lshift = -7.
rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
# When the table values use fewer than 16 bits (e.g. a sigmoid output
# quantized with a small scale, so the max table value is well below
# 2**15), the formula above yields a negative rshift. The values already
# fit in signed int16, and a negative right-shift is undefined (on host it
# masks the shift count and zeroes the table, giving a degenerate
# step-function LUT on device). Clamp to 0 so no shift is applied; this is
# the documented int16 case (rshift == 0, rescale_lshift == -7) and keeps
# rescale_lshift consistent with the shift actually performed below.
rshift = max(rshift, 0)
# The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
rescale_lshift = rshift - 7
lut_values = lut_values >> rshift
Expand Down
42 changes: 42 additions & 0 deletions backends/arm/test/ops/test_sigmoid_32bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
Expand Down Expand Up @@ -175,3 +177,43 @@ def test_sigmoid_u85_INT_add_sigmoid(test_data):
)
configure_32bit_sigmoid_quantizer(pipeline)
pipeline.run()


def test_int16_table_small_output_range_is_not_degenerate():
    """Guard against the int16 TABLE negative-rshift regression.

    When a TABLE op's output needs fewer than 16 bits (for instance a sigmoid
    whose [0, 1] output is quantized with scale 1/4096, so the largest table
    entry is 4096, far below 2**15), the shift computation produces
    ``rshift < 0``. Applying ``lut_values >> rshift`` with that negative
    count is undefined and, on the host, wiped the table to zeros, so the
    on-device activation became a constant. This test checks that the
    generated table is still a proper monotonic ramp over the full range and
    that the reported rescale shift matches the no-shift int16 case.

    """
    # Quantization parameters shared by the int16 input and output.
    int16_qparams = dict(zp=0, qmin=-32767, qmax=32767, dtype=torch.int16)

    # Input covers roughly +-22.4; output is [0, 1] at a step of 1/4096, so
    # the maximum table value is 4096 (13 bits).
    input_qargs = QuantArgs(scale=0.0006833122461102903, **int16_qparams)
    output_qargs = QuantArgs(scale=1.0 / 4096, **int16_qparams)

    # The generator is a @staticmethod, hence invoked directly on the class.
    result = InsertTableOpsPass.generate_16_bit_table_values(
        torch.sigmoid, input_qargs, output_qargs
    )
    lut, rescale_lshift = result

    # A collapsed table would contain a single distinct value.
    distinct_entries = torch.unique(lut).numel()
    assert (
        distinct_entries > 1
    ), "int16 sigmoid table collapsed to a constant (negative-rshift bug)"

    # Each entry must be at least as large as its predecessor.
    is_non_decreasing = bool((lut[:-1] <= lut[1:]).all())
    assert is_non_decreasing, "sigmoid table must be monotonically non-decreasing"

    # The ramp has to span exactly [0, 4096].
    table_min = int(lut.min())
    table_max = int(lut.max())
    assert table_min == 0 and table_max == 4096, (
        f"unexpected table range [{table_min}, {table_max}], "
        "expected a full [0, 4096] sigmoid ramp"
    )

    # The values fit in int16 already, so nothing is shifted: this is the
    # documented int16 case (rshift == 0 -> rescale_lshift == -7) rather
    # than the buggy -10.
    assert rescale_lshift == -7, f"expected rescale_lshift == -7, got {rescale_lshift}"
Loading