Skip to content

Commit bf5aa36

Browse files
Andrew Grebenisanmeta-codesync[bot]
authored andcommitted
EdgeProgramManager passes (pytorch#16986)
Summary: Pull Request resolved: pytorch#16986 - Adds support to run to run passes on ExportedPrograms and EdgeProgramManager - EdgeProgramManager transform behaves basically like a pass manager Reviewed By: larryliu0820, ethansfng Differential Revision: D91725222
1 parent 1533e55 commit bf5aa36

6 files changed

Lines changed: 797 additions & 55 deletions

File tree

exir/BUCK

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,17 @@ fbcode_target(_kind = runtime.python_library,
259259
],
260260
)
261261

262+
fbcode_target(_kind = runtime.python_library,
263+
name = "edge_program_manager_pass_base",
264+
srcs = [
265+
"edge_program_manager_pass_base.py",
266+
],
267+
deps = [
268+
"//caffe2:torch",
269+
"//executorch/exir:pass_base",
270+
],
271+
)
272+
262273
fbcode_target(_kind = runtime.python_library,
263274
name = "pass_manager",
264275
srcs = [
@@ -267,6 +278,7 @@ fbcode_target(_kind = runtime.python_library,
267278
deps = [
268279
"fbsource//third-party/pypi/typing-extensions:typing-extensions",
269280
":error",
281+
":pass_base",
270282
"//caffe2:torch",
271283
],
272284
)
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import copy
10+
from abc import ABC, abstractmethod
11+
from dataclasses import dataclass
12+
from typing import Callable, Dict, Optional, Sequence, TYPE_CHECKING, Union
13+
14+
import torch
15+
from torch.export import ExportedProgram
16+
from torch.fx.passes.infra.pass_base import PassResult
17+
18+
if TYPE_CHECKING:
19+
from executorch.exir.program._program import EdgeProgramManager
20+
21+
22+
@dataclass(frozen=True)
23+
class ExportedProgramPassResult:
24+
"""Result of running a pass on an ExportedProgram."""
25+
26+
exported_program: ExportedProgram
27+
modified: bool
28+
29+
30+
class ExportedProgramPassBase(ABC):
31+
"""
32+
Base interface for implementing passes that operate on ExportedProgram.
33+
"""
34+
35+
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
36+
"""
37+
Runs the precondition check, the pass itself, and the postcondition check.
38+
"""
39+
40+
self.requires(exported_program)
41+
res = self.call(exported_program)
42+
self.ensures(exported_program)
43+
return res
44+
45+
@abstractmethod
46+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
47+
"""
48+
The pass that is run through the given exported program. To implement a
49+
pass, it is required to implement this function.
50+
51+
Args:
52+
exported_program: The exported program we will run a pass on
53+
"""
54+
55+
def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027
56+
"""
57+
This function will be called before the pass is run and will check that
58+
the given exported program contains the preconditions needed to run the
59+
pass. It is not required to implement this function.
60+
61+
Args:
62+
exported_program: The exported program we will run checks on
63+
"""
64+
65+
def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027
66+
"""
67+
This function will be called after the pass is run and will check that
68+
the given exported program contains the postconditions needed to run the
69+
pass. It is not required to implement this function.
70+
71+
Args:
72+
exported_program: The exported program we will run checks on
73+
"""
74+
75+
76+
@dataclass(frozen=True)
77+
class EdgeProgramManagerPassResult:
78+
"""Result of running a pass on an EdgeProgramManager."""
79+
80+
edge_program_manager: "EdgeProgramManager"
81+
modified: bool
82+
83+
84+
class EdgeProgramManagerPassBase(ABC):
85+
"""
86+
Base interface for implementing passes that operate on EdgeProgramManager.
87+
88+
This is the highest-level pass abstraction. Passes at this level can:
89+
- Transform individual ExportedPrograms within the manager
90+
- Modify constant methods
91+
- Split one program into multiple programs
92+
- Add or remove programs from the manager
93+
94+
Lower-level passes (ExportedProgramPassBase, GraphModule callables) can be
95+
lifted to this level using the provided wrapper classes.
96+
"""
97+
98+
def __call__(
99+
self, epm: "EdgeProgramManager"
100+
) -> EdgeProgramManagerPassResult:
101+
"""
102+
Runs the precondition check, the pass itself, and the postcondition check.
103+
"""
104+
self.requires(epm)
105+
res = self.call(epm)
106+
self.ensures(res.edge_program_manager)
107+
return res
108+
109+
@abstractmethod
110+
def call(
111+
self, epm: "EdgeProgramManager"
112+
) -> EdgeProgramManagerPassResult:
113+
"""
114+
The pass that is run on the given EdgeProgramManager. To implement a
115+
pass, it is required to implement this function.
116+
117+
Args:
118+
epm: The EdgeProgramManager to transform
119+
"""
120+
121+
def requires(self, epm: "EdgeProgramManager") -> None: # noqa: B027
122+
"""
123+
This function will be called before the pass is run and will check that
124+
the given EdgeProgramManager contains the preconditions needed to run the
125+
pass. It is not required to implement this function.
126+
127+
Args:
128+
epm: The EdgeProgramManager we will run checks on
129+
"""
130+
131+
def ensures(self, epm: "EdgeProgramManager") -> None: # noqa: B027
132+
"""
133+
This function will be called after the pass is run and will check that
134+
the given EdgeProgramManager contains the postconditions needed to run the
135+
pass. It is not required to implement this function.
136+
137+
Args:
138+
epm: The EdgeProgramManager we will run checks on
139+
"""
140+
141+
142+
class GraphModuleBackedExportedProgramPassWrapper(ExportedProgramPassBase):
143+
"""
144+
Wrapper that adapts a GraphModule pass to work as an ExportedProgramPassBase.
145+
146+
This wrapper takes a pass that operates on GraphModule and makes it compatible
147+
with ExportedProgramPassBase by extracting the graph module, running the pass,
148+
and updating the ExportedProgram in-place.
149+
"""
150+
151+
def __init__(
152+
self,
153+
graph_module_pass: Callable[[torch.fx.GraphModule], PassResult],
154+
) -> None:
155+
super().__init__()
156+
self._pass = graph_module_pass
157+
158+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
159+
from executorch.exir.program._program import (
160+
_get_updated_graph_signature,
161+
_get_updated_range_constraints,
162+
)
163+
164+
result = self._pass(exported_program.graph_module)
165+
166+
if result.modified:
167+
# Cannot use _update_exported_program_graph_module because it
168+
# runs verification, and it is not the responsibility of the
169+
# pass to run verification. EdgeProgram manager can
170+
# optionally run verification after a pass.
171+
result.graph_module.recompile()
172+
exported_program = copy.copy(exported_program) # bypasses __init__ and _validate()
173+
174+
exported_program._graph_module = result.graph_module
175+
exported_program._graph_signature = _get_updated_graph_signature(
176+
exported_program.graph_signature, result.graph_module
177+
)
178+
exported_program._range_constraints = _get_updated_range_constraints(
179+
result.graph_module
180+
)
181+
exported_program._module_call_graph = copy.deepcopy(
182+
exported_program._module_call_graph
183+
)
184+
exported_program._graph_module.meta.update(exported_program.graph_module.meta)
185+
186+
187+
return ExportedProgramPassResult(exported_program, result.modified)
188+
189+
190+
class ExportedProgramToEdgeProgramManagerPassWrapper(EdgeProgramManagerPassBase):
191+
"""
192+
Adapts an ExportedProgramPassBase to run on every method in an EdgeProgramManager.
193+
194+
This wrapper takes a pass that operates on a single ExportedProgram and applies it
195+
to every method in the EdgeProgramManager, collecting results into a new EPM.
196+
This is where the iteration over methods lives -- not in the pass manager, and not
197+
in EdgeProgramManager.transform().
198+
"""
199+
200+
def __init__(self, ep_pass: ExportedProgramPassBase) -> None:
201+
super().__init__()
202+
self._pass = ep_pass
203+
204+
def call(
205+
self, epm: "EdgeProgramManager"
206+
) -> EdgeProgramManagerPassResult:
207+
new_epm = copy.copy(epm)
208+
new_epm._edge_programs = dict(epm._edge_programs)
209+
210+
overall_modified = False
211+
for name, program in epm._edge_programs.items():
212+
result = self._pass(program)
213+
new_epm._edge_programs[name] = result.exported_program
214+
overall_modified = overall_modified or result.modified
215+
216+
new_epm._config_methods = epm._config_methods
217+
return EdgeProgramManagerPassResult(new_epm, overall_modified)
218+
219+
220+
PassType = Union[
221+
EdgeProgramManagerPassBase,
222+
ExportedProgramPassBase,
223+
Callable[[torch.fx.GraphModule], Optional[PassResult]],
224+
]
225+
226+
# Passes that operate on a single method (ExportedProgram or GraphModule level).
227+
# Excludes EdgeProgramManagerPassBase, which operates on the whole EdgeProgramManager.
228+
# Use this for per-method pass specifications (e.g. Dict[str, Sequence[MethodPassType]]).
229+
MethodPassType = Union[
230+
ExportedProgramPassBase,
231+
Callable[[torch.fx.GraphModule], Optional[PassResult]],
232+
]
233+
234+
235+
def _get_pass_name(fn: PassType) -> str:
236+
"""Unwraps wrapper chain to get the underlying pass name."""
237+
import inspect
238+
239+
if isinstance(fn, ExportedProgramToEdgeProgramManagerPassWrapper):
240+
return _get_pass_name(fn._pass)
241+
if isinstance(fn, GraphModuleBackedExportedProgramPassWrapper):
242+
return _get_pass_name(fn._pass)
243+
return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
244+
245+
246+
def wrap_passes(
247+
passes: Sequence[PassType],
248+
) -> list[EdgeProgramManagerPassBase]:
249+
"""
250+
Wraps a list of mixed-level passes up to the EdgeProgramManager level.
251+
252+
Accepts passes at three levels:
253+
- EdgeProgramManagerPassBase: used as-is
254+
- ExportedProgramPassBase: wrapped with ExportedProgramToEdgeProgramManagerPassWrapper
255+
- GraphModule callables: wrapped with GraphModuleBackedExportedProgramPassWrapper
256+
then ExportedProgramToEdgeProgramManagerPassWrapper
257+
258+
Args:
259+
passes: A sequence of passes at any level.
260+
261+
Returns:
262+
A list of EdgeProgramManagerPassBase passes.
263+
"""
264+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
265+
266+
wrapped: list[EdgeProgramManagerPassBase] = []
267+
for fn in passes:
268+
if isinstance(fn, EdgeProgramManagerPassBase):
269+
wrapped.append(fn)
270+
elif isinstance(fn, ExportedProgramPassBase):
271+
wrapped.append(
272+
ExportedProgramToEdgeProgramManagerPassWrapper(fn)
273+
)
274+
else:
275+
assert callable(fn)
276+
ep_pass = GraphModuleBackedExportedProgramPassWrapper(
277+
pass_result_wrapper(fn)
278+
)
279+
wrapped.append(
280+
ExportedProgramToEdgeProgramManagerPassWrapper(ep_pass)
281+
)
282+
return wrapped
283+
284+
285+
class MethodFilteredEdgeProgramManagerPass(EdgeProgramManagerPassBase):
286+
"""
287+
Applies different passes to different methods in an EdgeProgramManager.
288+
289+
Converts the Dict[str, Sequence[MethodPassType]] pattern (previously handled inline
290+
in EdgeProgramManager.transform) into a proper pass. Used by
291+
to_edge_transform_and_lower to handle the dict case.
292+
"""
293+
294+
def __init__(self, passes_dict: Dict[str, Sequence[MethodPassType]]) -> None:
295+
super().__init__()
296+
self._passes_dict = passes_dict
297+
298+
def call(
299+
self, epm: "EdgeProgramManager"
300+
) -> EdgeProgramManagerPassResult:
301+
from executorch.exir.program._program import _transform
302+
303+
new_epm = copy.copy(epm)
304+
new_epm._edge_programs = dict(epm._edge_programs)
305+
306+
overall_modified = False
307+
for name, program in epm._edge_programs.items():
308+
if name in self._passes_dict:
309+
new_program = _transform(program, *self._passes_dict[name])
310+
new_epm._edge_programs[name] = new_program
311+
overall_modified = True
312+
313+
return EdgeProgramManagerPassResult(new_epm, overall_modified)

exir/program/BUCK

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
22
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3-
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
43

54
oncall("executorch")
65

@@ -47,6 +46,7 @@ fbcode_target(_kind = runtime.python_library,
4746
"//executorch/exir/passes:spec_prop_pass",
4847
"//executorch/exir/passes:weights_to_outputs_pass",
4948
"//executorch/exir/passes:convert_constant_dim_order_pass",
49+
"//executorch/exir:edge_program_manager_pass_base",
5050
"//executorch/exir/verification:verifier",
5151
"//executorch/extension/flat_tensor/serialize:serialize",
5252
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])

0 commit comments

Comments
 (0)