55# LICENSE file in the root directory of this source tree.
66
77# pyre-strict
8-
9- from typing import Callable , List , Optional , Union
8+ import inspect
9+ import logging
10+ from typing import Callable , List , Optional , TYPE_CHECKING , TypeAlias , Union
1011
1112import torch
1213import torch .fx .passes .infra .pass_manager as fx
1314import torch .utils ._pytree as pytree
1415from executorch .exir .error import ExportError , ExportErrorType
16+
1517from torch .fx .passes .infra .pass_base import PassResult
16- from typing_extensions import TypeAlias
1718
18- PassType : TypeAlias = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
19+ if TYPE_CHECKING :
20+ from executorch .exir .program ._program import EdgeProgramManager
21+ from executorch .exir .program .edge_program_manager_pass_base import (
22+ EdgeProgramManagerPassBase ,
23+ EdgeProgramManagerPassResult ,
24+ ExportedProgramPassBase ,
25+ )
26+
27+ logger = logging .getLogger (__name__ )
28+ logger .setLevel (logging .WARNING )
29+
30+ PassType : TypeAlias = Union [
31+ "EdgeProgramManagerPassBase" ,
32+ "ExportedProgramPassBase" ,
33+ Callable [[torch .fx .GraphModule ], Optional [PassResult ]],
34+ ]
35+
36+
37+ def _get_pass_name (fn : PassType ) -> str :
38+ from executorch .exir .program .edge_program_manager_pass_base import (
39+ ExportedProgramToEdgeProgramManagerPassWrapper ,
40+ GraphModuleBackedExportedProgramPassWrapper ,
41+ GraphModuleToEdgeProgramManagerPassWrapper ,
42+ )
43+
44+ if isinstance (fn , ExportedProgramToEdgeProgramManagerPassWrapper ):
45+ return _get_pass_name (fn ._pass )
46+ if isinstance (fn , GraphModuleToEdgeProgramManagerPassWrapper ):
47+ return _get_pass_name (fn ._inner ._pass )
48+ if isinstance (fn , GraphModuleBackedExportedProgramPassWrapper ):
49+ return _get_pass_name (fn ._pass )
50+ return fn .__name__ if inspect .isfunction (fn ) else type (fn ).__name__
1951
2052
2153class PassManager (fx .PassManager ):
@@ -27,6 +59,8 @@ class PassManager(fx.PassManager):
2759 * **passes**: A list of callable passes
2860 * **params**: An instance of PassManagerParams containing the result of the
2961 flags set in the constructor.
62+
63+ Note: This class is deprecated. Please use EdgeProgramManagerPassManager instead.
3064 """
3165
3266 def __init__ (
@@ -41,6 +75,9 @@ def __init__(
4175 enable_debug_pass: set to true to enable the debug passes
4276 run_checks_after_each_pass: whether to run checks and linting after each pass
4377 """
78+ logger .warning (
79+ "PassManager is deprecated. Please use EdgeProgramManagerPassManager instead."
80+ )
4481
4582 # Flatten the passes to a list of callables
4683 passes = passes if passes else []
@@ -76,3 +113,138 @@ def check(self, module: torch.nn.Module) -> None:
76113 ExportErrorType .NOT_SUPPORTED ,
77114 f"call_method `{ node } ` is not supported except for backend delegate." ,
78115 )
116+
117+
118+ class EdgeProgramManagerPassManager :
119+ """
120+ Runs multiple passes on an EdgeProgramManager.
121+
122+ This PassManager accepts passes at three levels of abstraction:
123+ - EdgeProgramManagerPassBase: operates on the full EdgeProgramManager
124+ - ExportedProgramPassBase: operates on individual ExportedPrograms
125+ - Callable[[GraphModule], PassResult]: operates on GraphModules
126+
127+ Lower-level passes are automatically wrapped to operate at the EPM level.
128+ The iteration over methods within an EPM is handled by the wrapper passes,
129+ not by this pass manager.
130+ """
131+
132+ def __init__ (
133+ self ,
134+ passes : Optional [Union [List [PassType ], List [List [PassType ]]]] = None ,
135+ constraints : Optional [List [Callable [[Callable , Callable ], bool ]]] = None ,
136+ run_checks_after_each_pass : bool = False ,
137+ suppress_exported_program_pre_verification : bool = True ,
138+ steps : int = 1 ,
139+ ) -> None :
140+ from executorch .exir .program .edge_program_manager_pass_base import (
141+ EdgeProgramManagerPassBase ,
142+ ExportedProgramPassBase ,
143+ ExportedProgramToEdgeProgramManagerPassWrapper ,
144+ GraphModuleToEdgeProgramManagerPassWrapper ,
145+ )
146+
147+ wrapped_passes : List [EdgeProgramManagerPassBase ] = []
148+ for fn in pytree .tree_flatten (passes )[0 ] if passes else []:
149+ if isinstance (fn , EdgeProgramManagerPassBase ):
150+ wrapped_passes .append (fn )
151+ elif isinstance (fn , ExportedProgramPassBase ):
152+ wrapped_passes .append (
153+ ExportedProgramToEdgeProgramManagerPassWrapper (fn )
154+ )
155+ else :
156+ wrapped_passes .append (
157+ GraphModuleToEdgeProgramManagerPassWrapper (fn )
158+ )
159+
160+ if suppress_exported_program_pre_verification :
161+ logger .warning (
162+ "Pre-verification of exported program is suppressed. This means that the exported program may pass validation prior to running the pass manager."
163+ )
164+
165+ self .passes : List ["EdgeProgramManagerPassBase" ] = wrapped_passes
166+ self .constraints = constraints
167+ self .run_checks_after_each_pass = run_checks_after_each_pass
168+ self .suppress_check_failures = suppress_exported_program_pre_verification
169+ self .steps = steps
170+ self ._validated = False
171+
172+ def solve_constraints (self ) -> None :
173+ """Placeholder for constraint solving. Marks the pass manager as validated."""
174+ self ._validated = True
175+
176+ def check (self , epm : "EdgeProgramManager" ) -> None :
177+ """
178+ Runs exported program validation on each method in the EdgeProgramManager.
179+ """
180+ for name , program in epm ._edge_programs .items (): # noqa: B007
181+ if not self .suppress_check_failures :
182+ program .validate ()
183+
184+ module = program .graph_module
185+ module .recompile ()
186+ module .graph .lint ()
187+
188+ for node in module .graph .nodes :
189+ if node .op == "call_method" :
190+ raise ExportError (
191+ ExportErrorType .NOT_SUPPORTED ,
192+ f"call_method `{ node } ` is not supported except for backend delegate." ,
193+ )
194+
195+ def __call__ (
196+ self , epm : "EdgeProgramManager"
197+ ) -> "EdgeProgramManagerPassResult" :
198+ """
199+ Runs passes on an EdgeProgramManager.
200+
201+ Args:
202+ epm: The EdgeProgramManager to transform.
203+
204+ Returns:
205+ EdgeProgramManagerPassResult containing the transformed EPM and whether
206+ or not it was modified.
207+ """
208+ from executorch .exir .program .edge_program_manager_pass_base import (
209+ EdgeProgramManagerPassResult ,
210+ )
211+
212+ if not self ._validated :
213+ self .solve_constraints ()
214+
215+ self .check (epm )
216+
217+ overall_modified = False
218+
219+ for _ in range (self .steps ):
220+ step_modified = False
221+
222+ for i , fn in enumerate (self .passes ):
223+ try :
224+ result = fn (epm )
225+ if result .modified :
226+ logger .debug (
227+ "EPM after pass '%s'" ,
228+ _get_pass_name (fn ),
229+ )
230+
231+ epm = result .edge_program_manager
232+ step_modified = step_modified or result .modified
233+
234+ if self .run_checks_after_each_pass :
235+ self .check (epm )
236+
237+ except Exception as e :
238+ prev_names = [_get_pass_name (p ) for p in self .passes [:i ]]
239+ msg = (
240+ f"An error occurred when running the '{ _get_pass_name (fn )} ' pass "
241+ f"after the following passes: { prev_names } \n "
242+ f"Original error: { e } "
243+ )
244+ raise type (e )(msg ) from e
245+
246+ overall_modified = overall_modified or step_modified
247+ if not step_modified :
248+ break
249+
250+ return EdgeProgramManagerPassResult (epm , overall_modified )
0 commit comments