55# LICENSE file in the root directory of this source tree.
66
77# pyre-strict
8-
9- from typing import Callable , List , Optional , Union
8+ import copy
9+ import inspect
10+ import logging
11+ from typing import Callable , List , Optional , TypeAlias , Union
1012
1113import torch
1214import torch .fx .passes .infra .pass_manager as fx
1315import torch .utils ._pytree as pytree
1416from executorch .exir .error import ExportError , ExportErrorType
17+ from executorch .exir .pass_base import (
18+ ExportedProgramPassResult ,
19+ ExportedProgramPassBase ,
20+ )
21+ from torch .export import ExportedProgram
1522from torch .fx .passes .infra .pass_base import PassResult
16- from typing_extensions import TypeAlias
23+ from torch .fx .passes .infra .pass_manager import pass_result_wrapper
24+
25+ logger = logging .getLogger (__name__ )
26+ logger .setLevel (logging .WARNING )
27+
28+ PassType : TypeAlias = Union [
29+ ExportedProgramPassBase , Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
30+ ]
31+
1732
18- PassType : TypeAlias = Callable [[torch .fx .GraphModule ], Optional [PassResult ]]
33+ def _get_pass_name (fn : PassType ) -> str :
34+ """Returns a human-readable name for a pass."""
35+ return fn .__name__ if inspect .isfunction (fn ) else type (fn ).__name__
1936
2037
38+ def _is_graph_module_pass (fn : PassType ) -> bool :
39+ """Returns True if the pass operates on GraphModule (not ExportedProgram)."""
40+ return not isinstance (fn , ExportedProgramPassBase )
41+
2142class PassManager (fx .PassManager ):
2243 """
23- Class to run multiple passes on a given graph module. The PassManager is
24- callable so to run it, we can just call the PassManager instance.
44+ Runs multiple passes on a GraphModule.
45+
46+ This is the legacy PassManager that extends torch.fx.passes.infra.pass_manager.PassManager.
47+ Use this when you need to run passes on a GraphModule directly.
2548
26- Private Attributes:
27- * **passes**: A list of callable passes
28- * **params**: An instance of PassManagerParams containing the result of the
29- flags set in the constructor.
49+ For running passes on ExportedProgram, use ExportedProgramPassManager instead.
3050 """
3151
3252 def __init__ (
3353 self ,
3454 passes : Optional [Union [List [PassType ], List [List [PassType ]]]] = None ,
3555 run_checks_after_each_pass : bool = False ,
3656 suppress_check_failures : bool = False ,
57+ steps : int = 1 ,
3758 ) -> None :
38- r"""
39- Args:
40- passes: A list of passes
41- enable_debug_pass: set to true to enable the debug passes
42- run_checks_after_each_pass: whether to run checks and linting after each pass
43- """
44-
59+ logger .warning (
60+ "PassManager is deprecated. Please use ExportedProgramPassManager instead."
61+ )
4562 # Flatten the passes to a list of callables
4663 passes = passes if passes else []
4764 flattened_passes = [
@@ -52,6 +69,7 @@ def __init__(
5269 flattened_passes ,
5370 run_checks_after_each_pass = run_checks_after_each_pass ,
5471 suppress_check_failures = suppress_check_failures ,
72+ steps = steps ,
5573 )
5674
5775 def check (self , module : torch .nn .Module ) -> None :
@@ -65,14 +83,158 @@ def check(self, module: torch.nn.Module) -> None:
6583 node's spec field is a tuple)
6684 - Ensure that the graph module has type torch.fx.GraphModule
6785 """
68- assert isinstance (module , fx .GraphModule )
86+ assert isinstance (module , torch .fx .GraphModule )
87+ module .recompile ()
88+ module .graph .lint ()
89+
90+ for node in module .graph .nodes :
91+ if node .op == "call_method" :
92+ raise ExportError (
93+ ExportErrorType .NOT_SUPPORTED ,
94+ f"call_method `{ node } ` is not supported except for backend delegate." ,
95+ )
96+
97+ class ExportedProgramPassManager (fx .PassManager ):
98+ """
99+ Runs multiple passes on an ExportedProgram.
100+
101+ This PassManager is specifically designed for ExportedProgram and supports
102+ both GraphModule-only passes and ExportedProgram-aware passes.
103+
104+ For running passes on GraphModule directly, use PassManager instead.
105+ """
106+
107+ def __init__ (
108+ self ,
109+ passes : Optional [Union [List [PassType ], List [List [PassType ]]]] = None ,
110+ constraints : Optional [List [Callable [[Callable , Callable ], bool ]]] = None ,
111+ run_checks_after_each_pass : bool = False ,
112+ suppress_check_failures : bool = False ,
113+ steps : int = 1 ,
114+ ) -> None :
115+ wrapped_passes = (
116+ [
117+ fn if isinstance (fn , ExportedProgramPassBase ) else pass_result_wrapper (fn )
118+ for fn in pytree .tree_flatten (passes )[0 ]
119+ ]
120+ if passes
121+ else []
122+ )
123+
124+ super ().__init__ (
125+ wrapped_passes ,
126+ constraints = constraints ,
127+ run_checks_after_each_pass = run_checks_after_each_pass ,
128+ suppress_check_failures = suppress_check_failures ,
129+ steps = steps ,
130+ )
131+
132+ def check (self , module : torch .fx .GraphModule ) -> None :
133+ """Validates graph module invariants."""
69134 module .recompile ()
70135 module .graph .lint ()
71- # TODO(qihan): use verifier.check_is_exir
72136
73137 for node in module .graph .nodes :
74138 if node .op == "call_method" :
75139 raise ExportError (
76140 ExportErrorType .NOT_SUPPORTED ,
77141 f"call_method `{ node } ` is not supported except for backend delegate." ,
78142 )
143+
144+ def _run_graph_module_pass (
145+ self ,
146+ fn : PassType ,
147+ graph_module : torch .fx .GraphModule ,
148+ ) -> PassResult :
149+ """Runs a pass that operates on GraphModule."""
150+ res = fn (graph_module )
151+
152+ if res is None :
153+ raise TypeError (
154+ f"The result of pass { _get_pass_name (fn )} should be type PassResult. "
155+ "Please wrap it with pass_result_wrapper()"
156+ )
157+
158+ if res .modified :
159+ logger .debug (
160+ "Graph after pass '%s': %s" , _get_pass_name (fn ), res .graph_module .graph
161+ )
162+ res .graph_module .recompile ()
163+
164+ return res
165+
166+ def _run_exported_program_pass (
167+ self ,
168+ fn : ExportedProgramPassBase ,
169+ exported_program : ExportedProgram ,
170+ ) -> ExportedProgramPassResult :
171+ """Runs a pass that operates on ExportedProgram."""
172+ res = fn (exported_program )
173+
174+ if res .modified :
175+ logger .debug (
176+ "Graph after pass '%s': %s" ,
177+ _get_pass_name (fn ),
178+ res .exported_program .graph_module .graph ,
179+ )
180+ res .exported_program .graph_module .recompile ()
181+
182+ return res
183+
184+ # pyre-ignore[14]: Intentionally overriding with different signature for ExportedProgram
185+ def __call__ (self , exported_program : ExportedProgram ) -> ExportedProgramPassResult :
186+ """
187+ Runs passes on an ExportedProgram.
188+
189+ Handles both GraphModule-only passes and ExportedProgram-aware passes.
190+
191+ Args:
192+ exported_program: The exported program to transform.
193+
194+ Returns:
195+ ExportedProgramPassResult containing the transformed program.
196+ """
197+ if not self ._validated :
198+ self .solve_constraints ()
199+
200+ exported_program = copy .copy (exported_program )
201+
202+ # Check graph invariants before running passes
203+ self .check (exported_program .graph_module )
204+
205+ overall_modified = False
206+
207+ for _ in range (self .steps ):
208+ step_modified = False
209+
210+ for i , fn in enumerate (self .passes ):
211+ try :
212+ if _is_graph_module_pass (fn ):
213+ result = self ._run_graph_module_pass (
214+ fn , exported_program .graph_module
215+ )
216+ exported_program ._graph_module = result .graph_module
217+ step_modified = step_modified or result .modified
218+
219+ if self .run_checks_after_each_pass :
220+ self .check (result .graph_module )
221+ else :
222+ assert isinstance (fn , ExportedProgramPassBase )
223+ result = self ._run_exported_program_pass (fn , exported_program )
224+ exported_program = result .exported_program
225+ step_modified = step_modified or result .modified
226+
227+ if self .run_checks_after_each_pass :
228+ exported_program .validate ()
229+ self .check (exported_program .graph_module )
230+
231+ except Exception as e :
232+ prev_names = [_get_pass_name (p ) for p in self .passes [:i ]]
233+ msg = f"An error occurred when running the '{ _get_pass_name (fn )} ' pass after the following passes: { prev_names } "
234+ raise Exception (msg ) from e # noqa: TRY002
235+
236+ overall_modified = overall_modified or step_modified
237+ if not step_modified :
238+ break
239+
240+ return ExportedProgramPassResult (exported_program , overall_modified )
0 commit comments