Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import os
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -79,6 +80,31 @@
from torchgen.selective_build.selector import SelectiveBuilder


DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME = "register_all_kernels"
MANUAL_REGISTRATION_LIB_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def get_manual_registration_function_name(
*,
manual_registration: bool,
manual_registration_lib_name: str | None,
) -> str:
if not manual_registration_lib_name:
return DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME

if not manual_registration:
raise ValueError(
"--manual-registration-lib-name requires --manual-registration"
)

if not MANUAL_REGISTRATION_LIB_NAME_PATTERN.fullmatch(manual_registration_lib_name):
raise ValueError(
"--manual-registration-lib-name must be a valid C++ identifier"
)

return f"register_{manual_registration_lib_name}_kernels"


def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
"""
A wrapper function to basically get `sig.decl(include_context=True)`.
Expand Down Expand Up @@ -329,6 +355,9 @@ def gen_unboxing(
use_aten_lib: bool,
kernel_index: ETKernelIndex,
manual_registration: bool,
manual_registration_function_name: str = (
DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME
),
add_exception_boundary: bool = False,
) -> None:
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
Expand Down Expand Up @@ -364,6 +393,9 @@ def key_func(
), # Only write header once
},
num_shards=1,
base_env={
"manual_registration_function_name": manual_registration_function_name,
},
sharded_keys={"unboxed_kernels", "fn_header"},
)

Expand Down Expand Up @@ -484,6 +516,9 @@ def gen_headers(
kernel_index: ETKernelIndex,
cpu_fm: FileManager,
use_aten_lib: bool,
manual_registration_function_name: str = (
DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME
),
) -> None:
"""Generate headers.

Expand Down Expand Up @@ -534,6 +569,7 @@ def gen_headers(
"RegisterKernels.h",
lambda: {
"generated_comment": "@" + "generated by torchgen/gen_executorch.py",
"manual_registration_function_name": manual_registration_function_name,
},
)
headers = {
Expand Down Expand Up @@ -958,7 +994,15 @@ def main() -> None:
"--manual-registration",
action="store_true",
help="a boolean flag to indicate whether we want to manually call"
"register_kernels() or rely on static init. ",
"register_all_kernels() or rely on static init. ",
)
parser.add_argument(
"--manual-registration-lib-name",
"--manual_registration_lib_name",
"--lib-name",
"--lib_name",
help="library name to include in the manual registration API name. "
"Requires --manual-registration.",
)
parser.add_argument(
"--generate",
Expand All @@ -977,6 +1021,13 @@ def main() -> None:
)
options = parser.parse_args()
assert options.tags_path, "tags.yaml is required by codegen yaml parsing."
try:
manual_registration_function_name = get_manual_registration_function_name(
manual_registration=options.manual_registration,
manual_registration_lib_name=options.manual_registration_lib_name,
)
except ValueError as exc:
parser.error(str(exc))

selector = get_custom_build_selector(
options.op_registration_whitelist,
Expand Down Expand Up @@ -1011,6 +1062,7 @@ def main() -> None:
kernel_index=kernel_index,
cpu_fm=cpu_fm,
use_aten_lib=options.use_aten_lib,
manual_registration_function_name=manual_registration_function_name,
)

if "sources" in options.generate:
Expand All @@ -1021,6 +1073,7 @@ def main() -> None:
use_aten_lib=options.use_aten_lib,
kernel_index=kernel_index,
manual_registration=options.manual_registration,
manual_registration_function_name=manual_registration_function_name,
add_exception_boundary=options.add_exception_boundary,
)
if custom_ops_native_functions:
Expand Down
6 changes: 3 additions & 3 deletions codegen/templates/RegisterKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/

// ${generated_comment}
// This implements register_all_kernels() API that is declared in
// This implements ${manual_registration_function_name}() API that is declared in
// RegisterKernels.h
#include "RegisterKernels.h"
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
Expand All @@ -16,14 +16,14 @@
namespace torch {
namespace executor {

Error register_all_kernels() {
Error ${manual_registration_function_name}() {
Kernel kernels_to_register[] = {
${unboxed_kernels} // Generated kernels
};
Error success_with_kernel_reg =
::executorch::runtime::register_kernels({kernels_to_register});
if (success_with_kernel_reg != Error::Ok) {
ET_LOG(Error, "Failed register all kernels");
ET_LOG(Error, "Failed to register kernels");
return success_with_kernel_reg;
}
return Error::Ok;
Expand Down
4 changes: 2 additions & 2 deletions codegen/templates/RegisterKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/

// ${generated_comment}
// Exposing an API for registering all kernels at once.
// Exposing an API for registering generated kernels at once.
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/operator_registry.h>
Expand All @@ -16,7 +16,7 @@
namespace torch {
namespace executor {

Error register_all_kernels();
Error ${manual_registration_function_name}();

} // namespace executor
} // namespace torch
103 changes: 103 additions & 0 deletions codegen/test/test_executorch_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from executorch.codegen.gen import (
ComputeCodegenUnboxedKernels,
gen_functions_declarations,
gen_headers,
gen_unboxing,
get_manual_registration_function_name,
parse_yaml_files,
translate_native_yaml,
)
Expand All @@ -29,6 +32,7 @@
OperatorName,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager


TEST_YAML = """
Expand Down Expand Up @@ -318,6 +322,105 @@ def tearDown(self) -> None:
pass


class TestManualRegistrationFunctionName(unittest.TestCase):
def test_default_function_name(self) -> None:
self.assertEqual(
get_manual_registration_function_name(
manual_registration=True,
manual_registration_lib_name=None,
),
"register_all_kernels",
)

def test_named_function_name(self) -> None:
self.assertEqual(
get_manual_registration_function_name(
manual_registration=True,
manual_registration_lib_name="portable_ops_lib",
),
"register_portable_ops_lib_kernels",
)

def test_named_function_requires_manual_registration(self) -> None:
with self.assertRaisesRegex(
ValueError, "--manual-registration-lib-name requires"
):
get_manual_registration_function_name(
manual_registration=False,
manual_registration_lib_name="portable_ops_lib",
)

def test_named_function_requires_cpp_identifier(self) -> None:
with self.assertRaisesRegex(ValueError, "valid C\\+\\+ identifier"):
get_manual_registration_function_name(
manual_registration=True,
manual_registration_lib_name="portable-ops-lib",
)


class TestManualRegistrationTemplates(unittest.TestCase):
def setUp(self) -> None:
self.template_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "templates"
)
self.function_name = "register_portable_ops_lib_kernels"

def test_register_kernels_header_uses_named_function(self) -> None:
with tempfile.TemporaryDirectory() as tempdir:
gen_headers(
native_functions=[],
gen_custom_ops_header=False,
custom_ops_native_functions=[],
selector=SelectiveBuilder.get_nop_selector(),
kernel_index=ETKernelIndex(index={}), # type: ignore[arg-type]
cpu_fm=FileManager(tempdir, self.template_dir, False),
use_aten_lib=False,
manual_registration_function_name=self.function_name,
)

with open(os.path.join(tempdir, "RegisterKernels.h")) as f:
header = f.read()

self.assertIn(f"Error {self.function_name}();", header)
self.assertNotIn("Error register_all_kernels();", header)

def test_register_kernels_cpp_uses_named_function(self) -> None:
with tempfile.TemporaryDirectory() as tempdir:
native_function, backend_index = NativeFunction.from_yaml(
{
"func": "custom_1::op_1() -> bool",
"dispatch": {"CPU": "kernel_1"},
},
loc=Location(__file__, 1),
valid_tags=set(),
)
backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}
BackendIndex.grow_index(backend_indices, backend_index)

gen_unboxing(
native_functions=[native_function],
cpu_fm=FileManager(tempdir, self.template_dir, False),
selector=SelectiveBuilder.from_yaml_dict(
{"include_all_operators": True}
),
use_aten_lib=False,
kernel_index=ETKernelIndex.from_backend_indices(backend_indices),
manual_registration=True,
manual_registration_function_name=self.function_name,
)

with open(os.path.join(tempdir, "RegisterKernelsEverything.cpp")) as f:
source = f.read()

self.assertIn(f"Error {self.function_name}() {{", source)
self.assertNotIn("Error register_all_kernels() {", source)


class TestGenFunctionsDeclarations(unittest.TestCase):
def setUp(self) -> None:
(
Expand Down
10 changes: 10 additions & 0 deletions shim_et/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def _prepare_genrule_and_lib(
custom_ops_yaml_path = None,
custom_ops_requires_runtime_registration = True,
manual_registration = False,
manual_registration_lib_name = None,
aten_mode = False,
support_exceptions = True):
"""
Expand Down Expand Up @@ -350,6 +351,10 @@ def _prepare_genrule_and_lib(
genrule_cmd = genrule_cmd + [
"--manual_registration",
]
if manual_registration_lib_name:
genrule_cmd = genrule_cmd + [
"--manual-registration-lib-name={}".format(manual_registration_lib_name),
]
if custom_ops_yaml_path:
genrule_cmd = genrule_cmd + [
"--custom_ops_yaml_path=" + custom_ops_yaml_path,
Expand Down Expand Up @@ -828,6 +833,7 @@ def executorch_generated_lib(
visibility = [],
aten_mode = False,
manual_registration = False,
manual_registration_lib_name = None,
use_default_aten_ops_lib = True,
deps = [],
xplat_deps = [],
Expand Down Expand Up @@ -888,6 +894,9 @@ def executorch_generated_lib(
xplat_deps: Additional xplat deps, can be used to provide custom operator library.
fbcode_deps: Additional fbcode deps, can be used to provide custom operator library.
compiler_flags: compiler_flags args to runtime.cxx_library
manual_registration_lib_name: Optional C++ identifier to use when
generating a named manual registration API. If omitted, manual
registration keeps using `register_all_kernels`.
dtype_selective_build: In additional to operator selection, dtype selective build
further selects the dtypes for each operator. Can be used with model or dict
selective build APIs, where dtypes can be specified.
Expand Down Expand Up @@ -999,6 +1008,7 @@ def executorch_generated_lib(
custom_ops_requires_runtime_registration = custom_ops_requires_runtime_registration,
aten_mode = aten_mode,
manual_registration = manual_registration,
manual_registration_lib_name = manual_registration_lib_name,
support_exceptions = support_exceptions,
)

Expand Down
Loading
Loading