Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions google/cloud/ndb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def context(
legacy_data (bool): Set to ``True`` (the default) to write data in
a way that can be read by the legacy version of NDB.
"""
context = context_module.get_context(False)
if context is not None:
raise RuntimeError("Context is already created for this thread.")

context = context_module.Context(
self,
cache_policy=cache_policy,
Expand Down
14 changes: 11 additions & 3 deletions google/cloud/ndb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,32 @@ def __init__(self):
_state = _LocalState()


def get_context():
def get_context(raise_context_error=True):
"""Get the current context.

This function should be called within a context established by
:meth:`google.cloud.ndb.client.Client.context`.

Args:
raise_context_error (bool): If set to :data:`True`, will raise an
exception if called outside of a context. Set this to :data:`False`
in order to have it just return :data:`None` if called outside of a
context. Default: :data:`True`

Returns:
Context: The current context.

Raises:
.ContextError: If called outside of a context
established by :meth:`google.cloud.ndb.client.Client.context`.
established by :meth:`google.cloud.ndb.client.Client.context` and
``raise_context_error`` is :data:`True`.
"""
context = _state.context
if context:
return context

raise exceptions.ContextError()
if raise_context_error:
raise exceptions.ContextError()


def _default_policy(attr_name, value_type):
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,19 @@ def test__http():
client._http

@staticmethod
def test__context():
def test_context():
with patch_credentials("testing"):
client = client_module.Client()

with client.context():
context = context_module.get_context()
assert context.client is client

@staticmethod
def test_context_double_jeopardy():
with patch_credentials("testing"):
client = client_module.Client()

with client.context():
with pytest.raises(RuntimeError):
client.context().__enter__()