Skip to content

Commit 9fcb647

Browse files
aseeringolavloite
andauthored
feat(spanner): add Client Context support to options (#1499)
Re-opening #1495 due to permissions issues. Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-spanner/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕 --------- Co-authored-by: Knut Olav Løite <koloite@gmail.com>
1 parent 12773d7 commit 9fcb647

21 files changed

+785
-76
lines changed

google/cloud/spanner_v1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .types.spanner import BatchWriteRequest
3939
from .types.spanner import BatchWriteResponse
4040
from .types.spanner import BeginTransactionRequest
41+
from .types.spanner import ClientContext
4142
from .types.spanner import CommitRequest
4243
from .types.spanner import CreateSessionRequest
4344
from .types.spanner import DeleteSessionRequest
@@ -110,6 +111,7 @@
110111
"BatchWriteRequest",
111112
"BatchWriteResponse",
112113
"BeginTransactionRequest",
114+
"ClientContext",
113115
"CommitRequest",
114116
"CommitResponse",
115117
"CreateSessionRequest",

google/cloud/spanner_v1/_helpers.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from google.cloud._helpers import _date_from_iso8601_date
3535
from google.cloud.spanner_v1.types import ExecuteSqlRequest
3636
from google.cloud.spanner_v1.types import TransactionOptions
37+
from google.cloud.spanner_v1.types import ClientContext
38+
from google.cloud.spanner_v1.types import RequestOptions
3739
from google.cloud.spanner_v1.data_types import JsonObject, Interval
3840
from google.cloud.spanner_v1.request_id_header import (
3941
with_request_id,
@@ -172,15 +174,15 @@ def _merge_query_options(base, merge):
172174
If the resultant object only has empty fields, returns None.
173175
"""
174176
combined = base or ExecuteSqlRequest.QueryOptions()
175-
if type(combined) is dict:
177+
if isinstance(combined, dict):
176178
combined = ExecuteSqlRequest.QueryOptions(
177179
optimizer_version=combined.get("optimizer_version", ""),
178180
optimizer_statistics_package=combined.get(
179181
"optimizer_statistics_package", ""
180182
),
181183
)
182184
merge = merge or ExecuteSqlRequest.QueryOptions()
183-
if type(merge) is dict:
185+
if isinstance(merge, dict):
184186
merge = ExecuteSqlRequest.QueryOptions(
185187
optimizer_version=merge.get("optimizer_version", ""),
186188
optimizer_statistics_package=merge.get("optimizer_statistics_package", ""),
@@ -191,6 +193,95 @@ def _merge_query_options(base, merge):
191193
return combined
192194

193195

196+
def _merge_client_context(base, merge):
197+
"""Merge higher precedence ClientContext with current ClientContext.
198+
199+
:type base: :class:`~google.cloud.spanner_v1.types.ClientContext`
200+
or :class:`dict` or None
201+
:param base: The current ClientContext that is intended for use.
202+
203+
:type merge: :class:`~google.cloud.spanner_v1.types.ClientContext`
204+
or :class:`dict` or None
205+
:param merge:
206+
The ClientContext that has a higher priority than base. These options
207+
should overwrite the fields in base.
208+
209+
:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
210+
or None
211+
:returns:
212+
ClientContext object formed by merging the two given ClientContexts.
213+
"""
214+
if base is None and merge is None:
215+
return None
216+
217+
# Avoid in-place modification of base
218+
combined_pb = ClientContext()._pb
219+
if base:
220+
base_pb = ClientContext(base)._pb if isinstance(base, dict) else base._pb
221+
combined_pb.MergeFrom(base_pb)
222+
if merge:
223+
merge_pb = ClientContext(merge)._pb if isinstance(merge, dict) else merge._pb
224+
combined_pb.MergeFrom(merge_pb)
225+
226+
combined = ClientContext(combined_pb)
227+
228+
if not combined.secure_context:
229+
return None
230+
return combined
231+
232+
233+
def _validate_client_context(client_context):
234+
"""Validate and convert client_context.
235+
236+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
237+
or :class:`dict`
238+
:param client_context: (Optional) Client context to use.
239+
240+
:rtype: :class:`~google.cloud.spanner_v1.types.ClientContext`
241+
:returns: Validated ClientContext object or None.
242+
:raises TypeError: if client_context is not a ClientContext or a dict.
243+
"""
244+
if client_context is not None:
245+
if isinstance(client_context, dict):
246+
client_context = ClientContext(client_context)
247+
elif not isinstance(client_context, ClientContext):
248+
raise TypeError("client_context must be a ClientContext or a dict")
249+
return client_context
250+
251+
252+
def _merge_request_options(request_options, client_context):
253+
"""Merge RequestOptions and ClientContext.
254+
255+
:type request_options: :class:`~google.cloud.spanner_v1.types.RequestOptions`
256+
or :class:`dict` or None
257+
:param request_options: The current RequestOptions that is intended for use.
258+
259+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
260+
or :class:`dict` or None
261+
:param client_context:
262+
The ClientContext to merge into request_options.
263+
264+
:rtype: :class:`~google.cloud.spanner_v1.types.RequestOptions`
265+
or None
266+
:returns:
267+
RequestOptions object formed by merging the given ClientContext.
268+
"""
269+
if request_options is None and client_context is None:
270+
return None
271+
272+
if request_options is None:
273+
request_options = RequestOptions()
274+
elif isinstance(request_options, dict):
275+
request_options = RequestOptions(request_options)
276+
277+
if client_context:
278+
request_options.client_context = _merge_client_context(
279+
client_context, request_options.client_context
280+
)
281+
282+
return request_options
283+
284+
194285
def _assert_numeric_precision_and_scale(value):
195286
"""
196287
Asserts that input numeric field is within Spanner supported range.

google/cloud/spanner_v1/batch.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
_metadata_with_prefix,
2929
_metadata_with_leader_aware_routing,
3030
_merge_Transaction_Options,
31+
_merge_client_context,
32+
_merge_request_options,
33+
_validate_client_context,
3134
AtomicCounter,
3235
)
3336
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
@@ -37,6 +40,7 @@
3740
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
3841
from google.api_core.exceptions import InternalServerError
3942
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
43+
from google.cloud.spanner_v1.types import ClientContext
4044
import time
4145

4246
DEFAULT_RETRY_TIMEOUT_SECS = 30
@@ -47,9 +51,14 @@ class _BatchBase(_SessionWrapper):
4751
4852
:type session: :class:`~google.cloud.spanner_v1.session.Session`
4953
:param session: the session used to perform the commit
54+
55+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
56+
or :class:`dict`
57+
:param client_context: (Optional) Client context to use for all requests made
58+
by this batch.
5059
"""
5160

52-
def __init__(self, session):
61+
def __init__(self, session, client_context=None):
5362
super(_BatchBase, self).__init__(session)
5463

5564
self._mutations: List[Mutation] = []
@@ -58,6 +67,7 @@ def __init__(self, session):
5867
self.committed = None
5968
"""Timestamp at which the batch was successfully committed."""
6069
self.commit_stats: Optional[CommitResponse.CommitStats] = None
70+
self._client_context = _validate_client_context(client_context)
6171

6272
def insert(self, table, columns, values):
6373
"""Insert one or more new table rows.
@@ -227,10 +237,14 @@ def commit(
227237
txn_options,
228238
)
229239

240+
client_context = _merge_client_context(
241+
database._instance._client._client_context, self._client_context
242+
)
243+
request_options = _merge_request_options(request_options, client_context)
244+
230245
if request_options is None:
231246
request_options = RequestOptions()
232-
elif type(request_options) is dict:
233-
request_options = RequestOptions(request_options)
247+
234248
request_options.transaction_tag = self.transaction_tag
235249

236250
# Request tags are not supported for commit requests.
@@ -317,13 +331,25 @@ class MutationGroups(_SessionWrapper):
317331
318332
:type session: :class:`~google.cloud.spanner_v1.session.Session`
319333
:param session: the session used to perform the commit
334+
335+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
336+
or :class:`dict`
337+
:param client_context: (Optional) Client context to use for all requests made
338+
by this mutation group.
320339
"""
321340

322-
def __init__(self, session):
341+
def __init__(self, session, client_context=None):
323342
super(MutationGroups, self).__init__(session)
324343
self._mutation_groups: List[MutationGroup] = []
325344
self.committed: bool = False
326345

346+
if client_context is not None:
347+
if isinstance(client_context, dict):
348+
client_context = ClientContext(client_context)
349+
elif not isinstance(client_context, ClientContext):
350+
raise TypeError("client_context must be a ClientContext or a dict")
351+
self._client_context = client_context
352+
327353
def group(self):
328354
"""Returns a new `MutationGroup` to which mutations can be added."""
329355
mutation_group = BatchWriteRequest.MutationGroup()
@@ -365,10 +391,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
365391
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
366392
)
367393

394+
client_context = _merge_client_context(
395+
database._instance._client._client_context, self._client_context
396+
)
397+
request_options = _merge_request_options(request_options, client_context)
398+
368399
if request_options is None:
369400
request_options = RequestOptions()
370-
elif type(request_options) is dict:
371-
request_options = RequestOptions(request_options)
372401

373402
with trace_call(
374403
name="CloudSpanner.batch_write",

google/cloud/spanner_v1/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
_merge_query_options,
5656
)
5757
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
58+
from google.cloud.spanner_v1._helpers import _validate_client_context
5859
from google.cloud.spanner_v1.instance import Instance
5960
from google.cloud.spanner_v1.metrics.constants import (
6061
METRIC_EXPORT_INTERVAL_MS,
@@ -228,6 +229,10 @@ class Client(ClientWithProject):
228229
:param disable_builtin_metrics: (Optional) Default False. Set to True to disable
229230
the Spanner built-in metrics collection and exporting.
230231
232+
:type client_context: :class:`~google.cloud.spanner_v1.types.RequestOptions.ClientContext`
233+
or :class:`dict`
234+
:param client_context: (Optional) Client context to use for all requests made by this client.
235+
231236
:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
232237
and ``admin`` are :data:`True`
233238
@@ -278,6 +283,7 @@ def __init__(
278283
default_transaction_options: Optional[DefaultTransactionOptions] = None,
279284
experimental_host=None,
280285
disable_builtin_metrics=False,
286+
client_context=None,
281287
use_plain_text=False,
282288
ca_certificate=None,
283289
client_certificate=None,
@@ -324,6 +330,7 @@ def __init__(
324330

325331
# Environment flag config has higher precedence than application config.
326332
self._query_options = _merge_query_options(query_options, env_query_options)
333+
self._client_context = _validate_client_context(client_context)
327334

328335
if self._emulator_host is not None and (
329336
"http://" in self._emulator_host or "https://" in self._emulator_host

0 commit comments

Comments
 (0)