Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 36 additions & 63 deletions backends/qualcomm/builders/op_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_SCALE,
QCOM_SCALES,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)

from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING, QNN_QUANT_TYPE_MAP
Expand All @@ -40,6 +34,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
op_wrapper_list = []
weight_node = self.get_node(node.args[0])
is_pcq_embedding = QCOM_QUANT_ATTRS in weight_node.meta and weight_node.meta[
QCOM_QUANT_ATTRS
Expand All @@ -63,44 +58,53 @@ def define_node(
nodes_to_wrappers,
)

gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper]

output_tensor = self.get_tensor(node, node)
node_name = node.name
gather_input_tensors = []
if is_pcq_embedding:
node_quant_attrs = node.meta[QCOM_QUANT_ATTRS].copy()
intermediate_quant_attrs = node.meta[QCOM_QUANT_ATTRS].copy()
# Based on QNN HTP quantization constraints,
# we should set the scale to max of scales and per-tensor quantization for embedding op
intermediate_quant_attrs[QCOM_SCALE] = (
weight_node.meta[QCOM_QUANT_ATTRS][QCOM_SCALES].max().item()
act_quant_encoding, act_quant_configs = self.get_quant_encoding_conf(
node, node
)
act_dtype = (
torch.uint16
if act_quant_configs[QCOM_DTYPE] == torch.int32
else act_quant_configs[QCOM_DTYPE]
)
convert_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=node.name + "_convert",
tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
dtype=QNN_QUANT_TYPE_MAP[act_dtype],
quant_encoding=act_quant_encoding,
quant_configs=act_quant_configs,
dims=weight_tensor.size(),
tensor=weight_tensor,
is_fake_tensor=True,
nodes_to_wrappers=nodes_to_wrappers,
)
intermediate_quant_attrs[QCOM_ZERO_POINT] = (
weight_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINTS].max().item()
convert_op = PyQnnManager.PyQnnOpWrapper(
node.name + "_convert",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpConvert.op_name,
)
intermediate_quant_attrs[QCOM_DTYPE] = weight_node.meta[QCOM_QUANT_ATTRS][
QCOM_DTYPE
]
intermediate_quant_attrs[QCOM_QUANT_MAX] = weight_node.meta[
QCOM_QUANT_ATTRS
][QCOM_QUANT_MAX]
intermediate_quant_attrs[QCOM_QUANT_MIN] = weight_node.meta[
QCOM_QUANT_ATTRS
][QCOM_QUANT_MIN]
node.meta[QCOM_QUANT_ATTRS] = intermediate_quant_attrs
node_name += "_intermediate"
convert_op.AddInputTensors([weight_tensor_wrapper])
convert_op.AddOutputTensors([convert_tensor_wrapper])
op_wrapper_list.append(convert_op)
gather_input_tensors.append(convert_tensor_wrapper)
else:
gather_input_tensors.append(weight_tensor_wrapper)
gather_input_tensors.append(indices_tensor_wrapper)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
node_name=node_name,
node_name=node.name,
)
gather_output_tensors = [output_tensor_wrapper]

gather_op = PyQnnManager.PyQnnOpWrapper(
node_name,
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpGather.op_name,
)
Expand All @@ -113,37 +117,6 @@ def define_node(
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32,
{QCOM_DATA: np.int32(0)},
)

op_wrapper_list = [gather_op]

if is_pcq_embedding:
node.meta[QCOM_QUANT_ATTRS] = node_quant_attrs
act_quant_encoding, act_quant_configs = self.get_quant_encoding_conf(
node, node
)
act_dtype = (
torch.uint16
if act_quant_configs[QCOM_DTYPE] == torch.int32
else act_quant_configs[QCOM_DTYPE]
)
convert_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=node.name,
tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
dtype=QNN_QUANT_TYPE_MAP[act_dtype],
quant_encoding=act_quant_encoding,
quant_configs=act_quant_configs,
dims=output_tensor.size(),
tensor=output_tensor,
is_fake_tensor=True,
nodes_to_wrappers=nodes_to_wrappers,
)
convert_op = PyQnnManager.PyQnnOpWrapper(
node.name + "_convert",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpConvert.op_name,
)
convert_op.AddInputTensors(gather_output_tensors)
convert_op.AddOutputTensors([convert_tensor_wrapper])
op_wrapper_list.append(convert_op)
op_wrapper_list.append(gather_op)

return op_wrapper_list
20 changes: 11 additions & 9 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3860,18 +3860,20 @@ def test_qnn_backend_embedding(self):
)
self.lower_module_and_test_output(modules[i], sample_input)

# Gated on the QNN SDK version only: the skip message below records that this
# test passes from QNN 2.48 onwards, so no unconditional skip is stacked on top
# (that would make the version gate dead code and the test would never run).
@unittest.skipIf(is_qnn_sdk_version_less_than("2.48"), "UT pass after QNN 2.48")
def test_qnn_backend_embedding_per_channel(self):
    """Lower a per-channel-quantized ``Embedding`` and compare its output.

    The same module and sample input are quantized once per entry in
    ``quant_dtype`` with ``is_embedding_per_channel=True``, then passed to
    ``lower_module_and_test_output``. Each configuration runs in its own
    ``subTest`` so a failure in one does not hide the result of the other.
    """
    module = Embedding()  # noqa: F405
    sample_input = (torch.Tensor([1, 2, 4, 5]).to(torch.int32),)
    quant_dtype = [QuantDtype.use_16a8w, QuantDtype.use_16a4w]
    # use_16a8w is covered by the loop; it is not lowered separately beforehand,
    # which would only repeat the first iteration.
    for i, qdtype in enumerate(quant_dtype):
        with self.subTest(i=i):
            qdq_module = self.get_qdq_module(
                module,
                sample_input,
                quant_dtype=qdtype,
                is_embedding_per_channel=True,
            )
            self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_equal(self):
test_comb = [
Expand Down
11 changes: 9 additions & 2 deletions examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import types

from functools import partial
from typing import Dict, List
from typing import Any, Dict, List

import torch

Expand Down Expand Up @@ -783,10 +783,17 @@ def activation_override(quantized_node, unquantized_node):
activation_override(quantized_user, unquantized_user)

def parameter_override(quantized_node, unquantized_node):
# Parameter targets can be dotted paths (e.g. static_llama.tok_embedding.weight).
# A single getattr() call cannot resolve those, so the path is walked one
# component at a time.
def _get_attr(graph_module: torch.fx.GraphModule, target: str) -> Any:
    """Return the attribute reached by following dotted ``target`` from ``graph_module``.

    Args:
        graph_module: Object at which the lookup starts.
        target: Attribute path whose components are separated by ``"."``.

    Returns:
        The object obtained after applying ``getattr`` for every component.

    Raises:
        AttributeError: If any component (including an empty one) is missing.
    """
    resolved: Any = graph_module
    remaining = target
    while True:
        # Peel off the next component; ``dot`` is empty once the last one is reached.
        name, dot, remaining = remaining.partition(".")
        resolved = getattr(resolved, name)
        if not dot:
            return resolved

setattr(
unquantized_model,
unquantized_node.target,
getattr(quantized_model, quantized_node.target),
_get_attr(quantized_model, quantized_node.target),
)
# scale / zero point are part of op's attributes
if list(quantized_node.users)[0].target in ptq_target:
Expand Down
Loading