From 2bf3bb2f2033624e8cd0ab77ca3cbfcd4e44eec2 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Wed, 1 Jul 2026 14:45:32 -0700 Subject: [PATCH 1/3] Add sym_ite prim op for symbolic if-then-else Implements the ternary sym_ite(condition, true_val, false_val) op, needed by torch.export with Dim.AUTO when models contain conditional symbolic shape logic. --- exir/emit/test/test_emit.py | 39 +++++++++++++++++++ exir/passes/executorch_prim_ops_registry.py | 8 ++++ kernels/prim_ops/register_prim_ops.cpp | 42 +++++++++++++++++++++ kernels/prim_ops/test/prim_ops_test.cpp | 41 ++++++++++++++++++++ 4 files changed, 130 insertions(+) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index b2b473cc8c9..89834c36ea2 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2466,6 +2466,45 @@ def forward(self, x): torch.randn(3, 4), torch.randn(8, 4), ] + ] + reference_outputs = [] + with torch.no_grad(): + for inp in test_inputs: + reference_outputs.append(model(inp)) + + batch_dim = Dim("batch", min=1, max=20) + dynamic_shapes = {"x": {0: batch_dim}} + exported_program = torch.export.export( + model, (test_inputs[0],), dynamic_shapes=dynamic_shapes + ) + + edge_program = to_edge( + exported_program, + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + et_program = edge_program.to_executorch() + program_buffer = et_program.buffer + et_module = _load_for_executorch_from_buffer(program_buffer) + for inp, expected in zip(test_inputs, reference_outputs): + et_output = et_module.forward([inp]) + et_result = et_output[0] + self.assertTrue(expected.shape == et_result.shape) + self.assertTrue(torch.allclose(expected, et_result)) + + def test_emit_sym_ite(self) -> None: + class SymIteModel(nn.Module): + def forward(self, x): + n = x.shape[0] + cond = n > 5 + val = torch.sym_ite(cond, n, 6) + return torch.zeros(val, dtype=x.dtype, device=x.device) + + model = SymIteModel() + model.eval() + test_inputs = [ + torch.randn(3, 4), # n<=5: ite(False,3,6)=6 + torch.randn(8, 4), # n>5: ite(True,8,6)=8 + ] reference_outputs = [] with torch.no_grad(): for inp in test_inputs: diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index c2235ae34ad..f382627db0e 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -134,6 +134,13 @@ def sym_not(a: _SymScalar) -> bool: return not a # pyre-ignore +@bind_pattern_to_op( + executorch_prims_lib, "sym_ite.Scalar(Scalar b, Scalar t, Scalar f) -> Scalar" +) +def sym_ite(b: _SymScalar, t: _SymScalar, f: _SymScalar) -> _SymScalar: + return t if b else f # pyre-ignore + + _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = { builtins.round: ops.backend.executorch_prim.round.Scalar, math.ceil: ops.backend.executorch_prim.ceil.Scalar, @@ -155,6 +162,7 @@ def sym_not(a: _SymScalar) -> bool: torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar, torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar, torch.sym_not: ops.backend.executorch_prim.sym_not.Scalar, + torch.sym_ite: ops.backend.executorch_prim.sym_ite.Scalar, } diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 468411a706b..483b7c82873 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -541,6 +541,48 @@ static Kernel prim_ops[] = { }), #endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_ITE_SCALAR) + // executorch_prim::sym_ite.Scalar(bool b, Scalar t, Scalar f) -> Scalar + Kernel( + "executorch_prim::sym_ite.Scalar", + [](KernelRuntimeContext& context, Span stack) { + ET_KERNEL_CHECK_MSG( + context, + stack.size() == 4, + InvalidProgram, + /* void */, + "Expected %zu args, got %zu", + (size_t)4, + stack.size()); + EValue& b = *stack[0]; + EValue& out = *stack[3]; + ET_KERNEL_CHECK_MSG( + context, + b.isBool(), + InvalidType, + /* void */, + "sym_ite condition must be bool, got %zu", + (size_t)b.tag); + EValue& selected = b.toBool() ? *stack[1] : *stack[2]; + if (selected.isInt()) { + out = EValue(selected.toInt()); + } else if (selected.isDouble()) { + out = EValue(selected.toDouble()); + } else if (selected.isBool()) { + out = EValue(selected.toBool()); + } else { + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "sym_ite value must be int, double, or bool, got %zu", + (size_t)selected.tag); + } + }), +#endif + #if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_INT) // executorch_prim::floordiv.int(int, int) -> int diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index c2ad20da7ec..fac62411d6c 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -42,6 +42,7 @@ TEST_F(RegisterPrimOpsTest, OpRegistered) { EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar")); EXPECT_TRUE(hasOpsFn("executorch_prim::sym_not.Scalar")); EXPECT_TRUE(hasOpsFn("executorch_prim::sym_int.Scalar")); + EXPECT_TRUE(hasOpsFn("executorch_prim::sym_ite.Scalar")); } TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) { @@ -241,6 +242,46 @@ TEST_F(RegisterPrimOpsTest, SymIntReturnsCorrectValue) { EXPECT_EQ(stack[1]->toInt(), 0); } +TEST_F(RegisterPrimOpsTest, SymIteReturnsCorrectValue) { + EValue values[4]; + EValue* stack[4]; + for (size_t i = 0; i < 4; i++) { + stack[i] = &values[i]; + } + + // true branch selects t (int) + values[0] = EValue(true); + values[1] = EValue((int64_t)42); + values[2] = EValue((int64_t)99); + values[3] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[3]->toInt(), 42); + + // false branch selects f (int) + values[0] = EValue(false); + values[1] = EValue((int64_t)42); + values[2] = EValue((int64_t)99); + values[3] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[3]->toInt(), 99); + + // true branch selects t (double) + values[0] = EValue(true); + values[1] = EValue(3.14); + values[2] = EValue(2.72); + values[3] = EValue(0.0); + getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span(stack)); + EXPECT_FLOAT_EQ(stack[3]->toDouble(), 3.14); + + // false branch selects f (bool) + values[0] = EValue(false); + values[1] = EValue(true); + values[2] = EValue(false); + values[3] = EValue(false); + getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[3]->toBool(), false); +} + TEST_F(RegisterPrimOpsTest, TestAlgebraOps) { EValue values[3]; int64_t a = 3; From 93d5aefd7dc3dfc85396bdb40e5e6ce106acb1a0 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Wed, 1 Jul 2026 19:09:38 -0700 Subject: [PATCH 2/3] Fix sym_ite test: both branches must have matching SymInt types torch.sym_ite enforces type(t) == type(f). The original test passed a SymInt (n) and a plain int (6), which fails the type check. Use a second dynamic shape dimension (m) so both branches are SymInt. --- exir/emit/test/test_emit.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 89834c36ea2..3da64307d89 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2495,15 +2495,16 @@ def test_emit_sym_ite(self) -> None: class SymIteModel(nn.Module): def forward(self, x): n = x.shape[0] + m = x.shape[1] cond = n > 5 - val = torch.sym_ite(cond, n, 6) + val = torch.sym_ite(cond, n, m) return torch.zeros(val, dtype=x.dtype, device=x.device) model = SymIteModel() model.eval() test_inputs = [ - torch.randn(3, 4), # n<=5: ite(False,3,6)=6 - torch.randn(8, 4), # n>5: ite(True,8,6)=8 + torch.randn(3, 6), # n<=5: ite(False,3,6)=6 + torch.randn(8, 4), # n>5: ite(True,8,4)=8 ] reference_outputs = [] with torch.no_grad(): @@ -2511,7 +2512,8 @@ def forward(self, x): reference_outputs.append(model(inp)) batch_dim = Dim("batch", min=1, max=20) - dynamic_shapes = {"x": {0: batch_dim}} + feat_dim = Dim("feat", min=1, max=20) + dynamic_shapes = {"x": {0: batch_dim, 1: feat_dim}} exported_program = torch.export.export( model, (test_inputs[0],), dynamic_shapes=dynamic_shapes ) From bba09315e1a8027e1dfe24c4f6023be73af57281 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Wed, 1 Jul 2026 21:47:01 -0700 Subject: [PATCH 3/3] Remove stray bracket from rebase conflict resolution --- exir/emit/test/test_emit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 3da64307d89..03cee2caff0 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2466,7 +2466,6 @@ def forward(self, x): torch.randn(3, 4), torch.randn(8, 4), ] - ] reference_outputs = [] with torch.no_grad(): for inp in test_inputs: