Skip to content

Commit e5a316a

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
ExportedProgram passes (#16986)
Summary: - Adds support to _ExportPassBase to run passes on ExportedPrograms. Ensures that we can only run either call or call_exported_program, not both - Updates exir.PassManager to add a new pass manager which operates on exported programs. This is done to ensure backwards compatibility, while allowing _program transformations to use either pass manager Differential Revision: D91725222
1 parent 50c170c commit e5a316a

File tree

4 files changed

+510
-47
lines changed

4 files changed

+510
-47
lines changed

exir/pass_base.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
9+
from abc import ABC, abstractmethod
1010
import operator
1111
import traceback
1212
from contextlib import nullcontext
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -27,16 +28,15 @@
2728

2829
import torch
2930
from executorch.exir import memory
30-
3131
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
32-
3332
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3433
from executorch.exir.error import ExportError, ExportErrorType
3534
from torch import fx
3635
from torch._dispatch.python import enable_python_dispatcher
3736
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
3837
from torch._subclasses.fake_tensor import FakeTensor
3938
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
39+
from torch.export import ExportedProgram
4040
from torch.fx import traceback as fx_traceback
4141
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
4242
from torch.fx.graph import CodeGen
@@ -157,6 +157,56 @@ class ExportPassBaseError(RuntimeError):
157157
pass
158158

159159

160+
@dataclass(frozen=True)
161+
class ExportedProgramPassResult:
162+
exported_program: ExportedProgram
163+
modified: bool
164+
165+
class ExportedProgramPassBase(ABC):
166+
"""
167+
Base interface for implementing passes that operate on ExportedProgram.
168+
"""
169+
170+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
171+
"""
172+
Runs the precondition check, the pass itself, and the postcondition check.
173+
"""
174+
175+
self.requires(exported_program)
176+
res = self.call(exported_program)
177+
self.ensures(exported_program)
178+
return res
179+
180+
@abstractmethod
181+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
182+
"""
183+
The pass that is run through the given exported program. To implement a
184+
pass, it is required to implement this function.
185+
186+
Args:
187+
exported_program: The exported program we will run a pass on
188+
"""
189+
190+
def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027
191+
"""
192+
This function will be called before the pass is run and will check that
193+
the given exported program contains the preconditions needed to run the
194+
pass. It is not required to implement this function.
195+
196+
Args:
197+
exported_program: The exported program we will run checks on
198+
"""
199+
200+
def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027
201+
"""
202+
This function will be called after the pass is run and will check that
203+
the given exported program contains the postconditions needed to run the
204+
pass. It is not required to implement this function.
205+
206+
Args:
207+
exported_program: The exported program we will run checks on
208+
"""
209+
160210
class _ExportPassBase(PassBase):
161211
"""
162212
Interpreter-based pass class to help users maintain the IR spec while writing

exir/pass_manager.py

Lines changed: 181 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,60 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
9-
from typing import Callable, List, Optional, Union
8+
import copy
9+
import inspect
10+
import logging
11+
from typing import Callable, List, Optional, TypeAlias, Union
1012

1113
import torch
1214
import torch.fx.passes.infra.pass_manager as fx
1315
import torch.utils._pytree as pytree
1416
from executorch.exir.error import ExportError, ExportErrorType
17+
from executorch.exir.pass_base import (
18+
ExportedProgramPassResult,
19+
ExportedProgramPassBase,
20+
)
21+
from torch.export import ExportedProgram
1522
from torch.fx.passes.infra.pass_base import PassResult
16-
from typing_extensions import TypeAlias
23+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
24+
25+
logger = logging.getLogger(__name__)
26+
logger.setLevel(logging.WARNING)
27+
28+
PassType: TypeAlias = Union[
29+
ExportedProgramPassBase, Callable[[torch.fx.GraphModule], Optional[PassResult]]
30+
]
31+
1732

18-
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
33+
def _get_pass_name(fn: PassType) -> str:
34+
"""Returns a human-readable name for a pass."""
35+
return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
1936

2037

38+
def _is_graph_module_pass(fn: PassType) -> bool:
39+
"""Returns True if the pass operates on GraphModule (not ExportedProgram)."""
40+
return not isinstance(fn, ExportedProgramPassBase)
41+
2142
class PassManager(fx.PassManager):
2243
"""
23-
Class to run multiple passes on a given graph module. The PassManager is
24-
callable so to run it, we can just call the PassManager instance.
44+
Runs multiple passes on a GraphModule.
45+
46+
This is the legacy PassManager that extends torch.fx.passes.infra.pass_manager.PassManager.
47+
Use this when you need to run passes on a GraphModule directly.
2548
26-
Private Attributes:
27-
* **passes**: A list of callable passes
28-
* **params**: An instance of PassManagerParams containing the result of the
29-
flags set in the constructor.
49+
For running passes on ExportedProgram, use ExportedProgramPassManager instead.
3050
"""
3151

3252
def __init__(
3353
self,
3454
passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
3555
run_checks_after_each_pass: bool = False,
3656
suppress_check_failures: bool = False,
57+
steps: int = 1,
3758
) -> None:
38-
r"""
39-
Args:
40-
passes: A list of passes
41-
enable_debug_pass: set to true to enable the debug passes
42-
run_checks_after_each_pass: whether to run checks and linting after each pass
43-
"""
44-
59+
logger.warning(
60+
"PassManager is deprecated. Please use ExportedProgramPassManager instead."
61+
)
4562
# Flatten the passes to a list of callables
4663
passes = passes if passes else []
4764
flattened_passes = [
@@ -52,6 +69,7 @@ def __init__(
5269
flattened_passes,
5370
run_checks_after_each_pass=run_checks_after_each_pass,
5471
suppress_check_failures=suppress_check_failures,
72+
steps=steps,
5573
)
5674

5775
def check(self, module: torch.nn.Module) -> None:
@@ -65,14 +83,158 @@ def check(self, module: torch.nn.Module) -> None:
6583
node's spec field is a tuple)
6684
- Ensure that the graph module has type torch.fx.GraphModule
6785
"""
68-
assert isinstance(module, fx.GraphModule)
86+
assert isinstance(module, torch.fx.GraphModule)
87+
module.recompile()
88+
module.graph.lint()
89+
90+
for node in module.graph.nodes:
91+
if node.op == "call_method":
92+
raise ExportError(
93+
ExportErrorType.NOT_SUPPORTED,
94+
f"call_method `{node}` is not supported except for backend delegate.",
95+
)
96+
97+
class ExportedProgramPassManager(fx.PassManager):
98+
"""
99+
Runs multiple passes on an ExportedProgram.
100+
101+
This PassManager is specifically designed for ExportedProgram and supports
102+
both GraphModule-only passes and ExportedProgram-aware passes.
103+
104+
For running passes on GraphModule directly, use PassManager instead.
105+
"""
106+
107+
def __init__(
108+
self,
109+
passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
110+
constraints: Optional[List[Callable[[Callable, Callable], bool]]] = None,
111+
run_checks_after_each_pass: bool = False,
112+
suppress_check_failures: bool = False,
113+
steps: int = 1,
114+
) -> None:
115+
wrapped_passes = (
116+
[
117+
fn if isinstance(fn, ExportedProgramPassBase) else pass_result_wrapper(fn)
118+
for fn in pytree.tree_flatten(passes)[0]
119+
]
120+
if passes
121+
else []
122+
)
123+
124+
super().__init__(
125+
wrapped_passes,
126+
constraints=constraints,
127+
run_checks_after_each_pass=run_checks_after_each_pass,
128+
suppress_check_failures=suppress_check_failures,
129+
steps=steps,
130+
)
131+
132+
def check(self, module: torch.fx.GraphModule) -> None:
133+
"""Validates graph module invariants."""
69134
module.recompile()
70135
module.graph.lint()
71-
# TODO(qihan): use verifier.check_is_exir
72136

73137
for node in module.graph.nodes:
74138
if node.op == "call_method":
75139
raise ExportError(
76140
ExportErrorType.NOT_SUPPORTED,
77141
f"call_method `{node}` is not supported except for backend delegate.",
78142
)
143+
144+
def _run_graph_module_pass(
145+
self,
146+
fn: PassType,
147+
graph_module: torch.fx.GraphModule,
148+
) -> PassResult:
149+
"""Runs a pass that operates on GraphModule."""
150+
res = fn(graph_module)
151+
152+
if res is None:
153+
raise TypeError(
154+
f"The result of pass {_get_pass_name(fn)} should be type PassResult. "
155+
"Please wrap it with pass_result_wrapper()"
156+
)
157+
158+
if res.modified:
159+
logger.debug(
160+
"Graph after pass '%s': %s", _get_pass_name(fn), res.graph_module.graph
161+
)
162+
res.graph_module.recompile()
163+
164+
return res
165+
166+
def _run_exported_program_pass(
167+
self,
168+
fn: ExportedProgramPassBase,
169+
exported_program: ExportedProgram,
170+
) -> ExportedProgramPassResult:
171+
"""Runs a pass that operates on ExportedProgram."""
172+
res = fn(exported_program)
173+
174+
if res.modified:
175+
logger.debug(
176+
"Graph after pass '%s': %s",
177+
_get_pass_name(fn),
178+
res.exported_program.graph_module.graph,
179+
)
180+
res.exported_program.graph_module.recompile()
181+
182+
return res
183+
184+
# pyre-ignore[14]: Intentionally overriding with different signature for ExportedProgram
185+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
186+
"""
187+
Runs passes on an ExportedProgram.
188+
189+
Handles both GraphModule-only passes and ExportedProgram-aware passes.
190+
191+
Args:
192+
exported_program: The exported program to transform.
193+
194+
Returns:
195+
ExportedProgramPassResult containing the transformed program.
196+
"""
197+
if not self._validated:
198+
self.solve_constraints()
199+
200+
exported_program = copy.copy(exported_program)
201+
202+
# Check graph invariants before running passes
203+
self.check(exported_program.graph_module)
204+
205+
overall_modified = False
206+
207+
for _ in range(self.steps):
208+
step_modified = False
209+
210+
for i, fn in enumerate(self.passes):
211+
try:
212+
if _is_graph_module_pass(fn):
213+
result = self._run_graph_module_pass(
214+
fn, exported_program.graph_module
215+
)
216+
exported_program._graph_module = result.graph_module
217+
step_modified = step_modified or result.modified
218+
219+
if self.run_checks_after_each_pass:
220+
self.check(result.graph_module)
221+
else:
222+
assert isinstance(fn, ExportedProgramPassBase)
223+
result = self._run_exported_program_pass(fn, exported_program)
224+
exported_program = result.exported_program
225+
step_modified = step_modified or result.modified
226+
227+
if self.run_checks_after_each_pass:
228+
exported_program.validate()
229+
self.check(exported_program.graph_module)
230+
231+
except Exception as e:
232+
prev_names = [_get_pass_name(p) for p in self.passes[:i]]
233+
msg = f"An error occurred when running the '{_get_pass_name(fn)}' pass after the following passes: {prev_names}"
234+
raise Exception(msg) from e # noqa: TRY002
235+
236+
overall_modified = overall_modified or step_modified
237+
if not step_modified:
238+
break
239+
240+
return ExportedProgramPassResult(exported_program, overall_modified)

0 commit comments

Comments
 (0)