Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,46 @@ def forward(self, x):
self.assertTrue(expected.shape == et_result.shape)
self.assertTrue(torch.allclose(expected, et_result))

def test_emit_sym_ite(self) -> None:
class SymIteModel(nn.Module):
def forward(self, x):
n = x.shape[0]
m = x.shape[1]
cond = n > 5
val = torch.sym_ite(cond, n, m)
return torch.zeros(val, dtype=x.dtype, device=x.device)

model = SymIteModel()
model.eval()
test_inputs = [
torch.randn(3, 6), # n<=5: ite(False,3,6)=6
torch.randn(8, 4), # n>5: ite(True,8,4)=8
]
reference_outputs = []
with torch.no_grad():
for inp in test_inputs:
reference_outputs.append(model(inp))

batch_dim = Dim("batch", min=1, max=20)
feat_dim = Dim("feat", min=1, max=20)
dynamic_shapes = {"x": {0: batch_dim, 1: feat_dim}}
exported_program = torch.export.export(
model, (test_inputs[0],), dynamic_shapes=dynamic_shapes
)

edge_program = to_edge(
exported_program,
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
et_program = edge_program.to_executorch()
program_buffer = et_program.buffer
et_module = _load_for_executorch_from_buffer(program_buffer)
for inp, expected in zip(test_inputs, reference_outputs):
et_output = et_module.forward([inp])
et_result = et_output[0]
self.assertTrue(expected.shape == et_result.shape)
self.assertTrue(torch.allclose(expected, et_result))

def test_emit_channels_last_constant(self) -> None:
"""Test that channels-last constant tensors are emitted correctly.

Expand Down
8 changes: 8 additions & 0 deletions exir/passes/executorch_prim_ops_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def sym_not(a: _SymScalar) -> bool:
return not a # pyre-ignore


@bind_pattern_to_op(
executorch_prims_lib, "sym_ite.Scalar(Scalar b, Scalar t, Scalar f) -> Scalar"
)
def sym_ite(b: _SymScalar, t: _SymScalar, f: _SymScalar) -> _SymScalar:
return t if b else f # pyre-ignore


_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
builtins.round: ops.backend.executorch_prim.round.Scalar,
math.ceil: ops.backend.executorch_prim.ceil.Scalar,
Expand All @@ -155,6 +162,7 @@ def sym_not(a: _SymScalar) -> bool:
torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar,
torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar,
torch.sym_not: ops.backend.executorch_prim.sym_not.Scalar,
torch.sym_ite: ops.backend.executorch_prim.sym_ite.Scalar,
}


Expand Down
42 changes: 42 additions & 0 deletions kernels/prim_ops/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,48 @@ static Kernel prim_ops[] = {
}),
#endif

#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
defined(INCLUDE_EXECUTORCH_PRIM_SYM_ITE_SCALAR)
// executorch_prim::sym_ite.Scalar(bool b, Scalar t, Scalar f) -> Scalar
Kernel(
"executorch_prim::sym_ite.Scalar",
[](KernelRuntimeContext& context, Span<EValue*> stack) {
ET_KERNEL_CHECK_MSG(
context,
stack.size() == 4,
InvalidProgram,
/* void */,
"Expected %zu args, got %zu",
(size_t)4,
stack.size());
EValue& b = *stack[0];
EValue& out = *stack[3];
ET_KERNEL_CHECK_MSG(
context,
b.isBool(),
InvalidType,
/* void */,
"sym_ite condition must be bool, got %zu",
(size_t)b.tag);
EValue& selected = b.toBool() ? *stack[1] : *stack[2];
if (selected.isInt()) {
out = EValue(selected.toInt());
} else if (selected.isDouble()) {
out = EValue(selected.toDouble());
} else if (selected.isBool()) {
out = EValue(selected.toBool());
} else {
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"sym_ite value must be int, double, or bool, got %zu",
(size_t)selected.tag);
}
}),
#endif

#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_INT)
// executorch_prim::floordiv.int(int, int) -> int
Expand Down
41 changes: 41 additions & 0 deletions kernels/prim_ops/test/prim_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ TEST_F(RegisterPrimOpsTest, OpRegistered) {
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar"));
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_not.Scalar"));
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_int.Scalar"));
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_ite.Scalar"));
}

TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
Expand Down Expand Up @@ -241,6 +242,46 @@ TEST_F(RegisterPrimOpsTest, SymIntReturnsCorrectValue) {
EXPECT_EQ(stack[1]->toInt(), 0);
}

TEST_F(RegisterPrimOpsTest, SymIteReturnsCorrectValue) {
EValue values[4];
EValue* stack[4];
for (size_t i = 0; i < 4; i++) {
stack[i] = &values[i];
}

// true branch selects t (int)
values[0] = EValue(true);
values[1] = EValue((int64_t)42);
values[2] = EValue((int64_t)99);
values[3] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[3]->toInt(), 42);

// false branch selects f (int)
values[0] = EValue(false);
values[1] = EValue((int64_t)42);
values[2] = EValue((int64_t)99);
values[3] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[3]->toInt(), 99);

// true branch selects t (double)
values[0] = EValue(true);
values[1] = EValue(3.14);
values[2] = EValue(2.72);
values[3] = EValue(0.0);
getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
EXPECT_FLOAT_EQ(stack[3]->toDouble(), 3.14);

// false branch selects f (bool)
values[0] = EValue(false);
values[1] = EValue(true);
values[2] = EValue(false);
values[3] = EValue(false);
getOpsFn("executorch_prim::sym_ite.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[3]->toBool(), false);
}

TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
EValue values[3];
int64_t a = 3;
Expand Down
Loading