Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,40 @@

logger = logging.getLogger(__name__)

# For Torch 2.12 and later overloads.
_QDQ_TORCH_OVERLOADS = (
("quantize_per_tensor", ("tensor", "tensor2", "default")),
("dequantize_per_tensor", ("tensor", "tensor2", "default")),
("quantize_per_channel", ("default",)),
("dequantize_per_channel", ("default",)),
)

# For backward compatibility with Torch versions older than 2.12.
_QDQ_BACKWARD_COMPAT_OVERLOADS = (
("quantize_per_tensor", ("out",)),
("dequantize_per_tensor", ("out",)),
("quantize_per_channel", ("out",)),
("dequantize_per_channel", ("out",)),
)


def _get_qdq_memory_format_ops() -> tuple[object, ...]:
qdq_ops = []
backward_compat = dict(_QDQ_BACKWARD_COMPAT_OVERLOADS)
ns = torch.ops.quantized_decomposed
for op_name, overload_names in _QDQ_TORCH_OVERLOADS:
op_packet = getattr(ns, op_name, None)
if op_packet is None:
continue
for overload_name in overload_names + backward_compat[op_name]:
if hasattr(op_packet, overload_name):
qdq_ops.append(getattr(op_packet, overload_name))

return tuple(qdq_ops)


_QDQ_MEMORY_FORMAT_OPS = _get_qdq_memory_format_ops()

# Copied from PyTorch.
# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict
# To avoid a dependency on _internal stuff.
Expand Down Expand Up @@ -358,12 +392,7 @@ def __torch_function__(self, func, types, args=..., kwargs=None):

# This is a hack since Q/DQ ops does not handle channels last input correctly: the simplest and most robust
# workaround is to simply run them in channels first format and then convert back to channels last.
if func in (
torch.ops.quantized_decomposed.quantize_per_tensor.out,
torch.ops.quantized_decomposed.dequantize_per_tensor.out,
torch.ops.quantized_decomposed.quantize_per_channel.out,
torch.ops.quantized_decomposed.dequantize_per_channel.out,
):
if func in _QDQ_MEMORY_FORMAT_OPS:

input_dim_order = args[0].dim_order()
if input_dim_order in (NHWC_ORDER, NNHWC_ORDER):
Expand Down
Loading