Skip to content

Commit e26e5e8

Browse files
authored
refactor: remove duplication in grpc transport init (#808)
* refactor: remove duplication in grpc transport init * refactor: let base transport set final host, creds, scopes * refactor: remove scopes or self.AUTH_SCOPES in grpc * docs: explain self._prep_wrapped_messages
1 parent 3268ba7 commit e26e5e8

File tree

4 files changed

+90
-126
lines changed

4 files changed

+90
-126
lines changed

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class {{ service.name }}Transport(abc.ABC):
8080
host += ':443'
8181
self._host = host
8282

83+
# Save the scopes.
84+
self._scopes = scopes or self.AUTH_SCOPES
85+
8386
# If no credentials are provided, then determine the appropriate
8487
# defaults.
8588
if credentials and credentials_file:
@@ -88,19 +91,16 @@ class {{ service.name }}Transport(abc.ABC):
8891
if credentials_file is not None:
8992
credentials, _ = auth.load_credentials_from_file(
9093
credentials_file,
91-
scopes=scopes,
94+
scopes=self._scopes,
9295
quota_project_id=quota_project_id
9396
)
9497

9598
elif credentials is None:
96-
credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)
99+
credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id)
97100

98101
# Save the credentials.
99102
self._credentials = credentials
100103

101-
# Lifted into its own function so it can be stubbed out during tests.
102-
self._prep_wrapped_messages(client_info)
103-
104104

105105
def _prep_wrapped_messages(self, client_info):
106106
# Precompute the wrapped methods.

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 43 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -101,91 +101,73 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
101101
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
102102
and ``credentials_file`` are passed.
103103
"""
104+
self._grpc_channel = None
104105
self._ssl_channel_credentials = ssl_channel_credentials
106+
self._stubs: Dict[str, Callable] = {}
107+
{%- if service.has_lro %}
108+
self._operations_client = None
109+
{%- endif %}
105110

106111
if api_mtls_endpoint:
107112
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
108113
if client_cert_source:
109114
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
110115

111116
if channel:
112-
# Sanity check: Ensure that channel and credentials are not both
113-
# provided.
117+
# Ignore credentials if a channel was passed.
114118
credentials = False
115-
116119
# If a channel was explicitly provided, set it.
117120
self._grpc_channel = channel
118121
self._ssl_channel_credentials = None
119-
elif api_mtls_endpoint:
120-
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
121-
122-
if credentials is None:
123-
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
124-
125-
# Create SSL credentials with client_cert_source or application
126-
# default SSL credentials.
127-
if client_cert_source:
128-
cert, key = client_cert_source()
129-
ssl_credentials = grpc.ssl_channel_credentials(
130-
certificate_chain=cert, private_key=key
131-
)
132-
else:
133-
ssl_credentials = SslCredentials().ssl_credentials
134-
135-
# create a new channel. The provided one is ignored.
136-
self._grpc_channel = type(self).create_channel(
137-
host,
138-
credentials=credentials,
139-
credentials_file=credentials_file,
140-
ssl_credentials=ssl_credentials,
141-
scopes=scopes or self.AUTH_SCOPES,
142-
quota_project_id=quota_project_id,
143-
options=[
144-
("grpc.max_send_message_length", -1),
145-
("grpc.max_receive_message_length", -1),
146-
],
147-
)
148-
self._ssl_channel_credentials = ssl_credentials
122+
149123
else:
150-
host = host if ":" in host else host + ":443"
151-
152-
if credentials is None:
153-
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
154-
155-
if client_cert_source_for_mtls and not ssl_channel_credentials:
156-
cert, key = client_cert_source_for_mtls()
157-
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
158-
certificate_chain=cert, private_key=key
159-
)
124+
if api_mtls_endpoint:
125+
host = api_mtls_endpoint
126+
127+
# Create SSL credentials with client_cert_source or application
128+
# default SSL credentials.
129+
if client_cert_source:
130+
cert, key = client_cert_source()
131+
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
132+
certificate_chain=cert, private_key=key
133+
)
134+
else:
135+
self._ssl_channel_credentials = SslCredentials().ssl_credentials
160136

161-
# create a new channel. The provided one is ignored.
137+
else:
138+
if client_cert_source_for_mtls and not ssl_channel_credentials:
139+
cert, key = client_cert_source_for_mtls()
140+
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
141+
certificate_chain=cert, private_key=key
142+
)
143+
144+
# The base transport sets the host, credentials and scopes
145+
super().__init__(
146+
host=host,
147+
credentials=credentials,
148+
credentials_file=credentials_file,
149+
scopes=scopes,
150+
quota_project_id=quota_project_id,
151+
client_info=client_info,
152+
)
153+
154+
if not self._grpc_channel:
162155
self._grpc_channel = type(self).create_channel(
163-
host,
164-
credentials=credentials,
156+
self._host,
157+
credentials=self._credentials,
165158
credentials_file=credentials_file,
159+
scopes=self._scopes,
166160
ssl_credentials=self._ssl_channel_credentials,
167-
scopes=scopes or self.AUTH_SCOPES,
168161
quota_project_id=quota_project_id,
169162
options=[
170163
("grpc.max_send_message_length", -1),
171164
("grpc.max_receive_message_length", -1),
172165
],
173166
)
174167

175-
self._stubs = {} # type: Dict[str, Callable]
176-
{%- if service.has_lro %}
177-
self._operations_client = None
178-
{%- endif %}
168+
# Wrap messages. This must be done after self._grpc_channel exists
169+
self._prep_wrapped_messages(client_info)
179170

180-
# Run the base constructor.
181-
super().__init__(
182-
host=host,
183-
credentials=credentials,
184-
credentials_file=credentials_file,
185-
scopes=scopes or self.AUTH_SCOPES,
186-
quota_project_id=quota_project_id,
187-
client_info=client_info,
188-
)
189171

190172
@classmethod
191173
def create_channel(cls,
@@ -197,7 +179,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
197179
**kwargs) -> grpc.Channel:
198180
"""Create and return a gRPC channel object.
199181
Args:
200-
address (Optional[str]): The host for the channel to use.
182+
host (Optional[str]): The host for the channel to use.
201183
credentials (Optional[~.Credentials]): The
202184
authorization credentials to attach to requests. These
203185
credentials identify this application to the service. If

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2

Lines changed: 41 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
5656
**kwargs) -> aio.Channel:
5757
"""Create and return a gRPC AsyncIO channel object.
5858
Args:
59-
address (Optional[str]): The host for the channel to use.
59+
host (Optional[str]): The host for the channel to use.
6060
credentials (Optional[~.Credentials]): The
6161
authorization credentials to attach to requests. These
6262
credentials identify this application to the service. If
@@ -145,91 +145,72 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
145145
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
146146
and ``credentials_file`` are passed.
147147
"""
148+
self._grpc_channel = None
148149
self._ssl_channel_credentials = ssl_channel_credentials
150+
self._stubs: Dict[str, Callable] = {}
151+
{%- if service.has_lro %}
152+
self._operations_client = None
153+
{%- endif %}
149154

150155
if api_mtls_endpoint:
151156
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
152157
if client_cert_source:
153158
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
154159

155160
if channel:
156-
# Sanity check: Ensure that channel and credentials are not both
157-
# provided.
161+
# Ignore credentials if a channel was passed.
158162
credentials = False
159-
160163
# If a channel was explicitly provided, set it.
161164
self._grpc_channel = channel
162165
self._ssl_channel_credentials = None
163-
elif api_mtls_endpoint:
164-
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
165-
166-
if credentials is None:
167-
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
168-
169-
# Create SSL credentials with client_cert_source or application
170-
# default SSL credentials.
171-
if client_cert_source:
172-
cert, key = client_cert_source()
173-
ssl_credentials = grpc.ssl_channel_credentials(
174-
certificate_chain=cert, private_key=key
175-
)
176-
else:
177-
ssl_credentials = SslCredentials().ssl_credentials
178166

179-
# create a new channel. The provided one is ignored.
180-
self._grpc_channel = type(self).create_channel(
181-
host,
182-
credentials=credentials,
183-
credentials_file=credentials_file,
184-
ssl_credentials=ssl_credentials,
185-
scopes=scopes or self.AUTH_SCOPES,
186-
quota_project_id=quota_project_id,
187-
options=[
188-
("grpc.max_send_message_length", -1),
189-
("grpc.max_receive_message_length", -1),
190-
],
191-
)
192-
self._ssl_channel_credentials = ssl_credentials
193167
else:
194-
host = host if ":" in host else host + ":443"
168+
if api_mtls_endpoint:
169+
host = api_mtls_endpoint
170+
171+
# Create SSL credentials with client_cert_source or application
172+
# default SSL credentials.
173+
if client_cert_source:
174+
cert, key = client_cert_source()
175+
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
176+
certificate_chain=cert, private_key=key
177+
)
178+
else:
179+
self._ssl_channel_credentials = SslCredentials().ssl_credentials
180+
181+
else:
182+
if client_cert_source_for_mtls and not ssl_channel_credentials:
183+
cert, key = client_cert_source_for_mtls()
184+
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
185+
certificate_chain=cert, private_key=key
186+
)
195187

196-
if credentials is None:
197-
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
188+
# The base transport sets the host, credentials and scopes
189+
super().__init__(
190+
host=host,
191+
credentials=credentials,
192+
credentials_file=credentials_file,
193+
scopes=scopes,
194+
quota_project_id=quota_project_id,
195+
client_info=client_info,
196+
)
198197

199-
if client_cert_source_for_mtls and not ssl_channel_credentials:
200-
cert, key = client_cert_source_for_mtls()
201-
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
202-
certificate_chain=cert, private_key=key
203-
)
204-
205-
# create a new channel. The provided one is ignored.
198+
if not self._grpc_channel:
206199
self._grpc_channel = type(self).create_channel(
207-
host,
208-
credentials=credentials,
200+
self._host,
201+
credentials=self._credentials,
209202
credentials_file=credentials_file,
203+
scopes=self._scopes,
210204
ssl_credentials=self._ssl_channel_credentials,
211-
scopes=scopes or self.AUTH_SCOPES,
212205
quota_project_id=quota_project_id,
213206
options=[
214207
("grpc.max_send_message_length", -1),
215208
("grpc.max_receive_message_length", -1),
216209
],
217210
)
218211

219-
# Run the base constructor.
220-
super().__init__(
221-
host=host,
222-
credentials=credentials,
223-
credentials_file=credentials_file,
224-
scopes=scopes or self.AUTH_SCOPES,
225-
quota_project_id=quota_project_id,
226-
client_info=client_info,
227-
)
228-
229-
self._stubs = {}
230-
{%- if service.has_lro %}
231-
self._operations_client = None
232-
{%- endif %}
212+
# Wrap messages. This must be done after self._grpc_channel exists
213+
self._prep_wrapped_messages(client_info)
233214

234215
@property
235216
def grpc_channel(self) -> aio.Channel:

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
9292
{%- endif %}
9393
if client_cert_source_for_mtls:
9494
self._session.configure_mtls_channel(client_cert_source_for_mtls)
95+
self._prep_wrapped_messages(client_info)
9596

9697
{%- if service.has_lro %}
9798

0 commit comments

Comments
 (0)