Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,6 +2386,77 @@ def forward(self, x, y):
return z5


class SimpleLLMDecoder(torch.nn.Module):
"""
Minimal transformer decoder mirroring how QNN LLM decoders are built:
a token embedding feeds a stack of decoder blocks whose linear projections
are expressed as 1x1 conv2d (see static_llama.py), grouped under a
``layers.N`` ModuleList. Takes token ids and an additive attention mask.
"""

class ConvAttention(torch.nn.Module):
def __init__(self, dim, n_heads):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.scale = self.head_dim**-0.5
self.wq_conv = torch.nn.Conv2d(dim, dim, 1, bias=False)
self.wk_conv = torch.nn.Conv2d(dim, dim, 1, bias=False)
self.wv_conv = torch.nn.Conv2d(dim, dim, 1, bias=False)
self.wo_conv = torch.nn.Conv2d(dim, dim, 1, bias=False)

def forward(self, x, atten_mask): # x: (b, dim, 1, seq)
b, dim, _, seq = x.shape
q = self.wq_conv(x).view(b, self.n_heads, self.head_dim, seq)
k = self.wk_conv(x).view(b, self.n_heads, self.head_dim, seq)
v = self.wv_conv(x).view(b, self.n_heads, self.head_dim, seq)
attn = torch.matmul(q.transpose(-2, -1), k) * self.scale
attn = torch.softmax(attn + atten_mask, dim=-1)
ctx = torch.matmul(v, attn.transpose(-2, -1))
ctx = ctx.reshape(b, dim, 1, seq)
return self.wo_conv(ctx)

class ConvFeedForward(torch.nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1_conv = torch.nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.w2_conv = torch.nn.Conv2d(hidden_dim, dim, 1, bias=False)
self.w3_conv = torch.nn.Conv2d(dim, hidden_dim, 1, bias=False)
self.act_fn = torch.nn.SiLU()

def forward(self, x):
return self.w2_conv(self.act_fn(self.w1_conv(x)) * self.w3_conv(x))

class DecoderLayer(torch.nn.Module):
def __init__(self, dim, hidden_dim, n_heads):
super().__init__()
self.attention = SimpleLLMDecoder.ConvAttention(dim, n_heads)
self.feed_forward = SimpleLLMDecoder.ConvFeedForward(dim, hidden_dim)

def forward(self, x, atten_mask):
x = x + self.attention(x, atten_mask)
x = x + self.feed_forward(x)
return x

def __init__(self, vocab_size=128, dim=32, hidden_dim=64, n_heads=4, n_layers=1):
super().__init__()
self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)
self.layers = torch.nn.ModuleList(
[self.DecoderLayer(dim, hidden_dim, n_heads) for _ in range(n_layers)]
)
self.output_conv = torch.nn.Conv2d(dim, dim, 1, bias=False)
self.eval()

def forward(self, input_ids, atten_mask): # input_ids: (b, seq)
x = self.tok_embeddings(input_ids) # (b, seq, dim)
b, seq, dim = x.shape
x = x.reshape(b, seq, 1, dim).transpose(1, 3) # (b, dim, 1, seq)
for layer in self.layers:
x = layer(x, atten_mask)
x = self.output_conv(x)
return x.transpose(1, 3).reshape(b, seq, dim)


class SkipBackToBack(torch.nn.Module):

def __init__(self):
Expand Down
42 changes: 33 additions & 9 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10812,19 +10812,40 @@ def test_analyzer_to_file_generation(self):
save_suggest_recipes,
)

module = SimpleModel() # noqa: F405
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
torch.manual_seed(8)
n_layers = 20
vocab_size, seq_len, n_heads = 128, 8, 4
module = SimpleLLMDecoder( # noqa: F405
vocab_size=vocab_size, n_heads=n_heads, n_layers=n_layers
)
input_ids = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.int32)
atten_mask = torch.triu(
torch.full((1, 1, seq_len, seq_len), float("-inf")), diagonal=1
)
sample_input = (input_ids, atten_mask)
fp32_gm = torch.export.export(module, sample_input, strict=True).module()
qdq_gm = self.get_qdq_module(
module, sample_input, quant_dtype=QuantDtype.use_8a4w
)

class DecoderInference:
def get_inputs(self, input_ids, attn_mask):
return (input_ids, attn_mask)

text_dataloader = [
{
"input_ids": input_ids,
"attention_mask": atten_mask,
}
]

num_sharding = 5
report = PerLayerSqnrAnalyzer(
model_name="simple_conv",
num_layers=4,
model_name="simple_llm_decoder",
num_layers=n_layers,
fp32_gm=fp32_gm,
qdq_gm=qdq_gm,
).analyze([sample_input], num_sharding=4)
).analyze(DecoderInference(), text_dataloader, num_sharding=num_sharding)

overrides = report.suggest_recipe_overrides(sqnr_threshold=22.0)

Expand All @@ -10833,10 +10854,13 @@ def test_analyzer_to_file_generation(self):
save_suggest_recipes(report, overrides, output_dir=tmp_dir)

# --- save_analysis_summary csv file ---
with open(f"{tmp_dir}/simple_conv_quantization_error.csv") as f:
with open(f"{tmp_dir}/simple_llm_decoder_quantization_error.csv") as f:
csv_content = f.read()
rows = list(csv.reader(csv_content.splitlines()))
self.assertEqual(len(rows), 5) # 1 header + 4 group rows
# 1 header + per-shard conv groups (7 projections each: wq/wk/wv/wo,
# w1/w2/w3) + the model-level output_conv. Layers are bucketed into
# num_sharding contiguous shards (n_layers >= num_sharding).
self.assertEqual(len(rows), 1 + num_sharding * 7 + 1)
self.assertEqual(
rows[0],
[
Expand All @@ -10852,11 +10876,11 @@ def test_analyzer_to_file_generation(self):

# --- save_suggest_recipes .py file (only written when sensitive layers exist) ---
if overrides:
with open(f"{tmp_dir}/simple_conv_suggest_recipe.py") as f:
with open(f"{tmp_dir}/simple_llm_decoder_suggest_recipe.py") as f:
py_content = f.read()
# generated file must be valid Python
try:
compile(py_content, "simple_conv_suggest_recipe.py", "exec")
compile(py_content, "simple_llm_decoder_suggest_recipe.py", "exec")
except SyntaxError as e:
self.fail(
f"Generated recipe file has syntax error: {e}\n{py_content}"
Expand Down
Loading