3131import dataclasses
3232import re
3333from itertools import chain
34- from typing import (cast , Dict , FrozenSet , List , Mapping , Optional ,
34+ from typing import (cast , Dict , FrozenSet , Iterable , List , Mapping , Optional ,
3535 Sequence , Set , Union )
3636
3737from google .api import annotations_pb2 # type: ignore
@@ -225,7 +225,6 @@ def __hash__(self):
225225
226226 @utils .cached_property
227227 def field_types (self ) -> Sequence [Union ['MessageType' , 'EnumType' ]]:
228- """Return all composite fields used in this proto's messages."""
229228 answer = tuple (
230229 field .type
231230 for field in self .fields .values ()
@@ -234,6 +233,23 @@ def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
234233
235234 return answer
236235
236+ @utils .cached_property
237+ def recursive_field_types (self ) -> Sequence [
238+ Union ['MessageType' , 'EnumType' ]
239+ ]:
240+ """Return all composite fields used in this proto's messages."""
241+ types : List [Union ['MessageType' , 'EnumType' ]] = []
242+ stack = [iter (self .fields .values ())]
243+ while stack :
244+ fields_iter = stack .pop ()
245+ for field in fields_iter :
246+ if field .message and field .type not in types :
247+ stack .append (iter (field .message .fields .values ()))
248+ if not field .is_primitive :
249+ types .append (field .type )
250+
251+ return tuple (types )
252+
237253 @property
238254 def map (self ) -> bool :
239255 """Return True if the given message is a map, False otherwise."""
@@ -654,19 +670,30 @@ def paged_result_field(self) -> Optional[Field]:
654670
655671 @utils .cached_property
656672 def ref_types (self ) -> Sequence [Union [MessageType , EnumType ]]:
673+ return self ._ref_types (True )
674+
675+ @utils .cached_property
676+ def flat_ref_types (self ) -> Sequence [Union [MessageType , EnumType ]]:
677+ return self ._ref_types (False )
678+
679+ def _ref_types (self , recursive : bool ) -> Sequence [Union [MessageType , EnumType ]]:
657680 """Return types referenced by this method."""
658681 # Begin with the input (request) and output (response) messages.
659- answer = [self .input ]
682+ answer : List [Union [MessageType , EnumType ]] = [self .input ]
683+ types : Iterable [Union [MessageType , EnumType ]] = (
684+ self .input .recursive_field_types if recursive
685+ else (
686+ f .type
687+ for f in self .flattened_fields .values ()
688+ if f .message or f .enum
689+ )
690+ )
691+ answer .extend (types )
692+
660693 if not self .void :
661694 answer .append (self .client_output )
662695 answer .extend (self .client_output .field_types )
663696
664- answer .extend (
665- field .type
666- for field in self .flattened_fields .values ()
667- if field .message or field .enum
668- )
669-
670697 # If this method has LRO, it is possible (albeit unlikely) that
671698 # the LRO messages reside in a different module.
672699 if self .lro :
0 commit comments