diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index adcf94f8f21..9d50aea42bf 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -14,12 +14,6 @@ QCOM_DTYPE, QCOM_ENCODING, QCOM_QUANT_ATTRS, - QCOM_QUANT_MAX, - QCOM_QUANT_MIN, - QCOM_SCALE, - QCOM_SCALES, - QCOM_ZERO_POINT, - QCOM_ZERO_POINTS, ) from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING, QNN_QUANT_TYPE_MAP @@ -40,6 +34,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], ) -> PyQnnManager.PyQnnOpWrapper: + op_wrapper_list = [] weight_node = self.get_node(node.args[0]) is_pcq_embedding = QCOM_QUANT_ATTRS in weight_node.meta and weight_node.meta[ QCOM_QUANT_ATTRS @@ -63,44 +58,53 @@ def define_node( nodes_to_wrappers, ) - gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper] - - output_tensor = self.get_tensor(node, node) - node_name = node.name + gather_input_tensors = [] if is_pcq_embedding: - node_quant_attrs = node.meta[QCOM_QUANT_ATTRS].copy() - intermediate_quant_attrs = node.meta[QCOM_QUANT_ATTRS].copy() - # Based on QNN HTP quantization constraints, - # we should set the scale to max of scales and per-tensor quantization for embedding op - intermediate_quant_attrs[QCOM_SCALE] = ( - weight_node.meta[QCOM_QUANT_ATTRS][QCOM_SCALES].max().item() + act_quant_encoding, act_quant_configs = self.get_quant_encoding_conf( + node, node + ) + act_dtype = ( + torch.uint16 + if act_quant_configs[QCOM_DTYPE] == torch.int32 + else act_quant_configs[QCOM_DTYPE] + ) + convert_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_convert", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_QUANT_TYPE_MAP[act_dtype], + quant_encoding=act_quant_encoding, + quant_configs=act_quant_configs, + dims=weight_tensor.size(), + tensor=weight_tensor, + is_fake_tensor=True, + 
nodes_to_wrappers=nodes_to_wrappers, ) - intermediate_quant_attrs[QCOM_ZERO_POINT] = ( - weight_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINTS].max().item() + convert_op = PyQnnManager.PyQnnOpWrapper( + node.name + "_convert", + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpConvert.op_name, ) - intermediate_quant_attrs[QCOM_DTYPE] = weight_node.meta[QCOM_QUANT_ATTRS][ - QCOM_DTYPE - ] - intermediate_quant_attrs[QCOM_QUANT_MAX] = weight_node.meta[ - QCOM_QUANT_ATTRS - ][QCOM_QUANT_MAX] - intermediate_quant_attrs[QCOM_QUANT_MIN] = weight_node.meta[ - QCOM_QUANT_ATTRS - ][QCOM_QUANT_MIN] - node.meta[QCOM_QUANT_ATTRS] = intermediate_quant_attrs - node_name += "_intermediate" + convert_op.AddInputTensors([weight_tensor_wrapper]) + convert_op.AddOutputTensors([convert_tensor_wrapper]) + op_wrapper_list.append(convert_op) + gather_input_tensors.append(convert_tensor_wrapper) + else: + gather_input_tensors.append(weight_tensor_wrapper) + gather_input_tensors.append(indices_tensor_wrapper) + + output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, output_tensor, PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, - node_name=node_name, + node_name=node.name, ) gather_output_tensors = [output_tensor_wrapper] gather_op = PyQnnManager.PyQnnOpWrapper( - node_name, + node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGather.op_name, ) @@ -113,37 +117,6 @@ def define_node( PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, {QCOM_DATA: np.int32(0)}, ) - - op_wrapper_list = [gather_op] - - if is_pcq_embedding: - node.meta[QCOM_QUANT_ATTRS] = node_quant_attrs - act_quant_encoding, act_quant_configs = self.get_quant_encoding_conf( - node, node - ) - act_dtype = ( - torch.uint16 - if act_quant_configs[QCOM_DTYPE] == torch.int32 - else act_quant_configs[QCOM_DTYPE] - ) - convert_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name, - tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - 
dtype=QNN_QUANT_TYPE_MAP[act_dtype], - quant_encoding=act_quant_encoding, - quant_configs=act_quant_configs, - dims=output_tensor.size(), - tensor=output_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - convert_op = PyQnnManager.PyQnnOpWrapper( - node.name + "_convert", - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpConvert.op_name, - ) - convert_op.AddInputTensors(gather_output_tensors) - convert_op.AddOutputTensors([convert_tensor_wrapper]) - op_wrapper_list.append(convert_op) + op_wrapper_list.append(gather_op) return op_wrapper_list diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index fcb365292ee..2b64d78c9e1 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3860,18 +3860,20 @@ def test_qnn_backend_embedding(self): ) self.lower_module_and_test_output(modules[i], sample_input) - # TODO: Once the accuracy issue is fixed, enable this test. - @unittest.skip("Bad accuracy for HTP") + @unittest.skipIf(is_qnn_sdk_version_less_than("2.48"), "UT pass after QNN 2.48") def test_qnn_backend_embedding_per_channel(self): module = Embedding() # noqa: F405 sample_input = (torch.Tensor([1, 2, 4, 5]).to(torch.int32),) - qdq_module = self.get_qdq_module( - module, - sample_input, - quant_dtype=QuantDtype.use_16a8w, - is_embedding_per_channel=True, - ) - self.lower_module_and_test_output(qdq_module, sample_input) + quant_dtype = [QuantDtype.use_16a8w, QuantDtype.use_16a4w] + for i, qdtype in enumerate(quant_dtype): + with self.subTest(i=i): + qdq_module = self.get_qdq_module( + module, + sample_input, + quant_dtype=qdtype, + is_embedding_per_channel=True, + ) + self.lower_module_and_test_output(qdq_module, sample_input) def test_qnn_backend_equal(self): test_comb = [ diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index 9bab682eac8..1478994011b 100644 --- 
a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -13,7 +13,7 @@ import types from functools import partial -from typing import Dict, List +from typing import Any, Dict, List import torch @@ -783,10 +783,17 @@ def activation_override(quantized_node, unquantized_node): activation_override(quantized_user, unquantized_user) def parameter_override(quantized_node, unquantized_node): + # A parameter target may be a dotted path (e.g. static_llama.tok_embedding.weight), so resolve it one attribute at a time + def _get_attr(graph_module: torch.fx.GraphModule, target: str) -> Any: + attr: Any = graph_module + for target_atom in target.split("."): + attr = getattr(attr, target_atom) + return attr + setattr( unquantized_model, unquantized_node.target, - getattr(quantized_model, quantized_node.target), + _get_attr(quantized_model, quantized_node.target), ) # scale / zero point are part of op's attributes if list(quantized_node.users)[0].target in ptq_target: