diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index b2b473cc8c9..03cee2caff0 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -2490,6 +2490,51 @@ def forward(self, x):
             self.assertTrue(expected.shape == et_result.shape)
             self.assertTrue(torch.allclose(expected, et_result))
 
+    def test_emit_sym_ite(self) -> None:
+        """Compare eager and ExecuTorch outputs of a model built on torch.sym_ite.
+
+        The model is exported with both input dims dynamic, lowered with
+        to_edge/to_executorch, loaded from the program buffer, and run on
+        inputs that exercise each branch of the sym_ite condition.
+        """
+
+        class SymIteModel(nn.Module):
+            def forward(self, x):
+                rows, cols = x.shape[0], x.shape[1]
+                # Output length is rows when rows > 5, otherwise cols.
+                length = torch.sym_ite(rows > 5, rows, cols)
+                return torch.zeros(length, dtype=x.dtype, device=x.device)
+
+        model = SymIteModel()
+        model.eval()
+        samples = [
+            torch.randn(3, 6),  # rows<=5: ite(False,3,6)=6
+            torch.randn(8, 4),  # rows>5: ite(True,8,4)=8
+        ]
+        # Eager-mode results serve as the reference for the runtime outputs.
+        with torch.no_grad():
+            eager_outputs = [model(sample) for sample in samples]
+
+        dynamic_shapes = {
+            "x": {
+                0: Dim("batch", min=1, max=20),
+                1: Dim("feat", min=1, max=20),
+            }
+        }
+        exported_program = torch.export.export(
+            model, (samples[0],), dynamic_shapes=dynamic_shapes
+        )
+
+        et_program = to_edge(
+            exported_program,
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        ).to_executorch()
+        et_module = _load_for_executorch_from_buffer(et_program.buffer)
+        for sample, expected in zip(samples, eager_outputs):
+            et_result = et_module.forward([sample])[0]
+            self.assertTrue(expected.shape == et_result.shape)
+            self.assertTrue(torch.allclose(expected, et_result))
+
     def test_emit_channels_last_constant(self) -> None:
         """Test that channels-last constant tensors are emitted correctly.
 
diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py
index c2235ae34ad..f382627db0e 100644
--- a/exir/passes/executorch_prim_ops_registry.py
+++ b/exir/passes/executorch_prim_ops_registry.py
@@ -134,6 +134,14 @@ def sym_not(a: _SymScalar) -> bool:
     return not a  # pyre-ignore
 
 
+@bind_pattern_to_op(
+    executorch_prims_lib, "sym_ite.Scalar(Scalar b, Scalar t, Scalar f) -> Scalar"
+)
+def sym_ite(b: _SymScalar, t: _SymScalar, f: _SymScalar) -> _SymScalar:
+    """Return ``t`` when ``b`` is truthy, otherwise ``f``."""
+    return t if b else f  # pyre-ignore
+
+
 _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
     builtins.round: ops.backend.executorch_prim.round.Scalar,
     math.ceil: ops.backend.executorch_prim.ceil.Scalar,
@@ -155,6 +163,7 @@ def sym_not(a: _SymScalar) -> bool:
     torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar,
     torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar,
     torch.sym_not: ops.backend.executorch_prim.sym_not.Scalar,
+    torch.sym_ite: ops.backend.executorch_prim.sym_ite.Scalar,
 }
 
 
diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp
index 468411a706b..483b7c82873 100644
--- a/kernels/prim_ops/register_prim_ops.cpp
+++ b/kernels/prim_ops/register_prim_ops.cpp
@@ -541,6 +541,50 @@ static Kernel prim_ops[] = {
         }),
 #endif
 
+#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
+    defined(INCLUDE_EXECUTORCH_PRIM_SYM_ITE_SCALAR)
+    // executorch_prim::sym_ite.Scalar(Scalar b, Scalar t, Scalar f) -> Scalar
+    // b must hold a bool; out receives a copy of t (b true) or f (b false).
+    Kernel(
+        "executorch_prim::sym_ite.Scalar",
+        [](KernelRuntimeContext& context, Span<EValue*> stack) {
+          ET_KERNEL_CHECK_MSG(
+              context,
+              stack.size() == 4,
+              InvalidProgram,
+              /* void */,
+              "Expected %zu args, got %zu",
+              (size_t)4,
+              stack.size());
+          // Stack layout: [b, t, f, out].
+          EValue& b = *stack[0];
+          EValue& out = *stack[3];
+          ET_KERNEL_CHECK_MSG(
+              context,
+              b.isBool(),
+              InvalidType,
+              /* void */,
+              "sym_ite condition must be bool, got %zu",
+              (size_t)b.tag);
+          EValue& selected = b.toBool() ? *stack[1] : *stack[2];
+          if (selected.isInt()) {
+            out = EValue(selected.toInt());
+          } else if (selected.isDouble()) {
+            out = EValue(selected.toDouble());
+          } else if (selected.isBool()) {
+            out = EValue(selected.toBool());
+          } else {
+            ET_KERNEL_CHECK_MSG(
+                context,
+                false,
+                InvalidType,
+                /* void */,
+                "sym_ite value must be int, double, or bool, got %zu",
+                (size_t)selected.tag);
+          }
+        }),
+#endif
+
 #if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
     defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_INT)
     // executorch_prim::floordiv.int(int, int) -> int
diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp
index c2ad20da7ec..fac62411d6c 100644
--- a/kernels/prim_ops/test/prim_ops_test.cpp
+++ b/kernels/prim_ops/test/prim_ops_test.cpp
@@ -42,6 +42,7 @@ TEST_F(RegisterPrimOpsTest, OpRegistered) {
   EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar"));
   EXPECT_TRUE(hasOpsFn("executorch_prim::sym_not.Scalar"));
   EXPECT_TRUE(hasOpsFn("executorch_prim::sym_int.Scalar"));
+  EXPECT_TRUE(hasOpsFn("executorch_prim::sym_ite.Scalar"));
 }
 
 TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
@@ -241,6 +242,48 @@ TEST_F(RegisterPrimOpsTest, SymIntReturnsCorrectValue) {
   EXPECT_EQ(stack[1]->toInt(), 0);
 }
 
+TEST_F(RegisterPrimOpsTest, SymIteReturnsCorrectValue) {
+  // Every case seeds the output slot with a value that differs from the
+  // expected result, so a kernel that never writes the output fails.
+  EValue values[4];
+  EValue* stack[4];
+  for (size_t i = 0; i < 4; i++) {
+    stack[i] = &values[i];
+  }
+
+  // true branch selects t (int)
+  values[0] = EValue(true);
+  values[1] = EValue((int64_t)42);
+  values[2] = EValue((int64_t)99);
+  values[3] = EValue((int64_t)0);
+  getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
+  EXPECT_EQ(stack[3]->toInt(), 42);
+
+  // false branch selects f (int)
+  values[0] = EValue(false);
+  values[1] = EValue((int64_t)42);
+  values[2] = EValue((int64_t)99);
+  values[3] = EValue((int64_t)0);
+  getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
+  EXPECT_EQ(stack[3]->toInt(), 99);
+
+  // true branch selects t (double)
+  values[0] = EValue(true);
+  values[1] = EValue(3.14);
+  values[2] = EValue(2.72);
+  values[3] = EValue(0.0);
+  getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
+  EXPECT_DOUBLE_EQ(stack[3]->toDouble(), 3.14);
+
+  // false branch selects f (bool)
+  values[0] = EValue(false);
+  values[1] = EValue(false);
+  values[2] = EValue(true);
+  values[3] = EValue(false);
+  getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
+  EXPECT_TRUE(stack[3]->toBool());
+}
+
 TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
   EValue values[3];
   int64_t a = 3;