Skip to content

Commit 96c7800

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
EdgeProgramManager passes (pytorch#16986)
Summary: - Adds support to run to run passes on ExportedPrograms and EdgeProgramManager - Creates an EdgeProgramManagerPassManager Reviewed By: larryliu0820 Differential Revision: D91725222
1 parent 994f1fe commit 96c7800

10 files changed

Lines changed: 933 additions & 78 deletions

File tree

backends/nxp/tests/ir/edge_passes/test_edge_passes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ def test_remove_additional_quantize_dequantize_nodes_pass(self):
317317
example_input,
318318
)
319319

320-
edge_program_manager = edge_program_manager.transform(NeutronEdgePassManager())
320+
edge_program_manager = edge_program_manager.transform(
321+
NeutronEdgePassManager().passes
322+
)
321323

322324
compile_spec = generate_neutron_compile_spec(target)
323325
partitioner = NeutronPartitioner(

examples/nxp/aot_neutron_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def get_model_and_inputs_from_name(model_name: str, use_random_dataset: bool):
351351
)
352352

353353
edge_program_manager = edge_program_manager.transform(
354-
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
354+
[RemoveAdditionalQDQClustersPass()]
355355
)
356356

357357
logging.debug(f"Lowered graph:\n{edge_program_manager.exported_program().graph}")

exir/BUCK

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,9 @@ fbcode_target(_kind = runtime.python_library,
267267
deps = [
268268
"fbsource//third-party/pypi/typing-extensions:typing-extensions",
269269
":error",
270+
":pass_base",
270271
"//caffe2:torch",
272+
"//executorch/exir/program:edge_program_manager_pass_base",
271273
],
272274
)
273275

exir/pass_manager.py

Lines changed: 176 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,49 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
9-
from typing import Callable, List, Optional, Union
8+
import inspect
9+
import logging
10+
from typing import Callable, List, Optional, TYPE_CHECKING, TypeAlias, Union
1011

1112
import torch
1213
import torch.fx.passes.infra.pass_manager as fx
1314
import torch.utils._pytree as pytree
1415
from executorch.exir.error import ExportError, ExportErrorType
16+
1517
from torch.fx.passes.infra.pass_base import PassResult
16-
from typing_extensions import TypeAlias
1718

18-
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
19+
if TYPE_CHECKING:
20+
from executorch.exir.program._program import EdgeProgramManager
21+
from executorch.exir.program.edge_program_manager_pass_base import (
22+
EdgeProgramManagerPassBase,
23+
EdgeProgramManagerPassResult,
24+
ExportedProgramPassBase,
25+
)
26+
27+
logger = logging.getLogger(__name__)
28+
logger.setLevel(logging.WARNING)
29+
30+
PassType: TypeAlias = Union[
31+
"EdgeProgramManagerPassBase",
32+
"ExportedProgramPassBase",
33+
Callable[[torch.fx.GraphModule], Optional[PassResult]],
34+
]
35+
36+
37+
def _get_pass_name(fn: PassType) -> str:
38+
from executorch.exir.program.edge_program_manager_pass_base import (
39+
ExportedProgramToEdgeProgramManagerPassWrapper,
40+
GraphModuleBackedExportedProgramPassWrapper,
41+
GraphModuleToEdgeProgramManagerPassWrapper,
42+
)
43+
44+
if isinstance(fn, ExportedProgramToEdgeProgramManagerPassWrapper):
45+
return _get_pass_name(fn._pass)
46+
if isinstance(fn, GraphModuleToEdgeProgramManagerPassWrapper):
47+
return _get_pass_name(fn._inner._pass)
48+
if isinstance(fn, GraphModuleBackedExportedProgramPassWrapper):
49+
return _get_pass_name(fn._pass)
50+
return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
1951

2052

2153
class PassManager(fx.PassManager):
@@ -27,6 +59,8 @@ class PassManager(fx.PassManager):
2759
* **passes**: A list of callable passes
2860
* **params**: An instance of PassManagerParams containing the result of the
2961
flags set in the constructor.
62+
63+
Note: This class is deprecated. Please use EdgeProgramManagerPassManager instead.
3064
"""
3165

3266
def __init__(
@@ -41,6 +75,9 @@ def __init__(
4175
enable_debug_pass: set to true to enable the debug passes
4276
run_checks_after_each_pass: whether to run checks and linting after each pass
4377
"""
78+
logger.warning(
79+
"PassManager is deprecated. Please use EdgeProgramManagerPassManager instead."
80+
)
4481

4582
# Flatten the passes to a list of callables
4683
passes = passes if passes else []
@@ -76,3 +113,138 @@ def check(self, module: torch.nn.Module) -> None:
76113
ExportErrorType.NOT_SUPPORTED,
77114
f"call_method `{node}` is not supported except for backend delegate.",
78115
)
116+
117+
118+
class EdgeProgramManagerPassManager:
119+
"""
120+
Runs multiple passes on an EdgeProgramManager.
121+
122+
This PassManager accepts passes at three levels of abstraction:
123+
- EdgeProgramManagerPassBase: operates on the full EdgeProgramManager
124+
- ExportedProgramPassBase: operates on individual ExportedPrograms
125+
- Callable[[GraphModule], PassResult]: operates on GraphModules
126+
127+
Lower-level passes are automatically wrapped to operate at the EPM level.
128+
The iteration over methods within an EPM is handled by the wrapper passes,
129+
not by this pass manager.
130+
"""
131+
132+
def __init__(
133+
self,
134+
passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
135+
constraints: Optional[List[Callable[[Callable, Callable], bool]]] = None,
136+
run_checks_after_each_pass: bool = False,
137+
suppress_exported_program_pre_verification: bool = True,
138+
steps: int = 1,
139+
) -> None:
140+
from executorch.exir.program.edge_program_manager_pass_base import (
141+
EdgeProgramManagerPassBase,
142+
ExportedProgramPassBase,
143+
ExportedProgramToEdgeProgramManagerPassWrapper,
144+
GraphModuleToEdgeProgramManagerPassWrapper,
145+
)
146+
147+
wrapped_passes: List[EdgeProgramManagerPassBase] = []
148+
for fn in pytree.tree_flatten(passes)[0] if passes else []:
149+
if isinstance(fn, EdgeProgramManagerPassBase):
150+
wrapped_passes.append(fn)
151+
elif isinstance(fn, ExportedProgramPassBase):
152+
wrapped_passes.append(
153+
ExportedProgramToEdgeProgramManagerPassWrapper(fn)
154+
)
155+
else:
156+
wrapped_passes.append(
157+
GraphModuleToEdgeProgramManagerPassWrapper(fn)
158+
)
159+
160+
if suppress_exported_program_pre_verification:
161+
logger.warning(
162+
"Pre-verification of exported program is suppressed. This means that the exported program may pass validation prior to running the pass manager."
163+
)
164+
165+
self.passes: List["EdgeProgramManagerPassBase"] = wrapped_passes
166+
self.constraints = constraints
167+
self.run_checks_after_each_pass = run_checks_after_each_pass
168+
self.suppress_check_failures = suppress_exported_program_pre_verification
169+
self.steps = steps
170+
self._validated = False
171+
172+
def solve_constraints(self) -> None:
173+
"""Placeholder for constraint solving. Marks the pass manager as validated."""
174+
self._validated = True
175+
176+
def check(self, epm: "EdgeProgramManager") -> None:
177+
"""
178+
Runs exported program validation on each method in the EdgeProgramManager.
179+
"""
180+
for name, program in epm._edge_programs.items(): # noqa: B007
181+
if not self.suppress_check_failures:
182+
program.validate()
183+
184+
module = program.graph_module
185+
module.recompile()
186+
module.graph.lint()
187+
188+
for node in module.graph.nodes:
189+
if node.op == "call_method":
190+
raise ExportError(
191+
ExportErrorType.NOT_SUPPORTED,
192+
f"call_method `{node}` is not supported except for backend delegate.",
193+
)
194+
195+
def __call__(
196+
self, epm: "EdgeProgramManager"
197+
) -> "EdgeProgramManagerPassResult":
198+
"""
199+
Runs passes on an EdgeProgramManager.
200+
201+
Args:
202+
epm: The EdgeProgramManager to transform.
203+
204+
Returns:
205+
EdgeProgramManagerPassResult containing the transformed EPM and whether
206+
or not it was modified.
207+
"""
208+
from executorch.exir.program.edge_program_manager_pass_base import (
209+
EdgeProgramManagerPassResult,
210+
)
211+
212+
if not self._validated:
213+
self.solve_constraints()
214+
215+
self.check(epm)
216+
217+
overall_modified = False
218+
219+
for _ in range(self.steps):
220+
step_modified = False
221+
222+
for i, fn in enumerate(self.passes):
223+
try:
224+
result = fn(epm)
225+
if result.modified:
226+
logger.debug(
227+
"EPM after pass '%s'",
228+
_get_pass_name(fn),
229+
)
230+
231+
epm = result.edge_program_manager
232+
step_modified = step_modified or result.modified
233+
234+
if self.run_checks_after_each_pass:
235+
self.check(epm)
236+
237+
except Exception as e:
238+
prev_names = [_get_pass_name(p) for p in self.passes[:i]]
239+
msg = (
240+
f"An error occurred when running the '{_get_pass_name(fn)}' pass "
241+
f"after the following passes: {prev_names}\n"
242+
f"Original error: {e}"
243+
)
244+
raise type(e)(msg) from e
245+
246+
overall_modified = overall_modified or step_modified
247+
if not step_modified:
248+
break
249+
250+
return EdgeProgramManagerPassResult(epm, overall_modified)

exir/program/BUCK

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
22
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3-
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
43

54
oncall("executorch")
65

@@ -15,6 +14,17 @@ fbcode_target(_kind = runtime.python_library,
1514
],
1615
)
1716

17+
fbcode_target(_kind = runtime.python_library,
18+
name = "edge_program_manager_pass_base",
19+
srcs = [
20+
"edge_program_manager_pass_base.py",
21+
],
22+
deps = [
23+
"//caffe2:torch",
24+
"//executorch/exir:pass_base",
25+
],
26+
)
27+
1828
fbcode_target(_kind = runtime.python_library,
1929
name = "program",
2030
srcs = [
@@ -47,6 +57,7 @@ fbcode_target(_kind = runtime.python_library,
4757
"//executorch/exir/passes:spec_prop_pass",
4858
"//executorch/exir/passes:weights_to_outputs_pass",
4959
"//executorch/exir/passes:convert_constant_dim_order_pass",
60+
"//executorch/exir/program:edge_program_manager_pass_base",
5061
"//executorch/exir/verification:verifier",
5162
"//executorch/extension/flat_tensor/serialize:serialize",
5263
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])

0 commit comments

Comments
 (0)