Skip to content

Commit 2ac210d

Browse files
author
Chris Rossi
authored
fix: fix a connection leak in RedisCache (#556)
1 parent eba1f19 commit 2ac210d

6 files changed

Lines changed: 144 additions & 5 deletions

File tree

packages/google-cloud-ndb/google/cloud/ndb/_cache.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,37 @@ def make_call(self):
305305

306306
def future_info(self, key):
307307
"""Generate info string for Future."""
308-
return "GlobalWatch.delete({})".format(key)
308+
return "GlobalCache.watch({})".format(key)
309+
310+
311+
def global_unwatch(key):
312+
"""End optimistic transaction with global cache.
313+
314+
Indicates that value for the key wasn't found in the database, so there will not be
315+
a future call to :func:`global_compare_and_swap`, and we no longer need to watch
316+
this key.
317+
318+
Args:
319+
key (bytes): The key to unwatch.
320+
321+
Returns:
322+
tasklets.Future: Eventual result will be ``None``.
323+
"""
324+
batch = _batch.get_batch(_GlobalCacheUnwatchBatch)
325+
return batch.add(key)
326+
327+
328+
class _GlobalCacheUnwatchBatch(_GlobalCacheWatchBatch):
329+
"""Batch for global cache unwatch requests. """
330+
331+
def make_call(self):
332+
"""Call :method:`GlobalCache.unwatch`."""
333+
cache = context_module.get_context().global_cache
334+
return cache.unwatch(self.keys)
335+
336+
def future_info(self, key):
337+
"""Generate info string for Future."""
338+
return "GlobalCache.unwatch({})".format(key)
309339

310340

311341
def global_compare_and_swap(key, value, expires=None):

packages/google-cloud-ndb/google/cloud/ndb/_datastore_api.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,15 @@ def lookup(key, options):
154154
entity_pb = yield batch.add(key)
155155

156156
# Do not cache misses
157-
if use_global_cache and not key_locked and entity_pb is not _NOT_FOUND:
158-
expires = context._global_cache_timeout(key, options)
159-
serialized = entity_pb.SerializeToString()
160-
yield _cache.global_compare_and_swap(cache_key, serialized, expires=expires)
157+
if use_global_cache and not key_locked:
158+
if entity_pb is not _NOT_FOUND:
159+
expires = context._global_cache_timeout(key, options)
160+
serialized = entity_pb.SerializeToString()
161+
yield _cache.global_compare_and_swap(
162+
cache_key, serialized, expires=expires
163+
)
164+
else:
165+
yield _cache.global_unwatch(cache_key)
161166

162167
raise tasklets.Return(entity_pb)
163168

packages/google-cloud-ndb/google/cloud/ndb/global_cache.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,19 @@ def watch(self, keys):
9292
"""
9393
raise NotImplementedError
9494

95+
@abc.abstractmethod
96+
def unwatch(self, keys):
97+
"""End an optimistic transaction for the given keys.
98+
99+
Indicates that value for the key wasn't found in the database, so there will not
100+
be a future call to :meth:`compare_and_swap`, and we no longer need to watch
101+
this key.
102+
103+
Arguments:
104+
keys (List[bytes]): The keys to watch.
105+
"""
106+
raise NotImplementedError
107+
95108
@abc.abstractmethod
96109
def compare_and_swap(self, items, expires=None):
97110
"""Like :meth:`set` but using an optimistic transaction.
@@ -160,6 +173,11 @@ def watch(self, keys):
160173
for key in keys:
161174
self._watch_keys[key] = self.cache.get(key)
162175

176+
def unwatch(self, keys):
177+
"""Implements :meth:`GlobalCache.unwatch`."""
178+
for key in keys:
179+
self._watch_keys.pop(key, None)
180+
163181
def compare_and_swap(self, items, expires=None):
164182
"""Implements :meth:`GlobalCache.compare_and_swap`."""
165183
if expires:
@@ -239,6 +257,13 @@ def watch(self, keys):
239257
for key in keys:
240258
self.pipes[key] = holder
241259

260+
def unwatch(self, keys):
261+
"""Implements :meth:`GlobalCache.watch`."""
262+
for key in keys:
263+
holder = self.pipes.pop(key, None)
264+
if holder:
265+
holder.pipe.reset()
266+
242267
def compare_and_swap(self, items, expires=None):
243268
"""Implements :meth:`GlobalCache.compare_and_swap`."""
244269
pipes = {}
@@ -391,6 +416,13 @@ def watch(self, keys):
391416
for key, (value, caskey) in self.client.gets_many(keys).items():
392417
caskeys[key] = caskey
393418

419+
def unwatch(self, keys):
420+
"""Implements :meth:`GlobalCache.unwatch`."""
421+
keys = [self._key(key) for key in keys]
422+
caskeys = self.caskeys
423+
for key in keys:
424+
caskeys.pop(key, None)
425+
394426
def compare_and_swap(self, items, expires=None):
395427
"""Implements :meth:`GlobalCache.compare_and_swap`."""
396428
caskeys = self.caskeys

packages/google-cloud-ndb/tests/unit/test__cache.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,31 @@ def test_add_and_idle_and_done_callbacks(in_context):
284284
assert future2.result() is None
285285

286286

287+
@mock.patch("google.cloud.ndb._cache._batch")
288+
def test_global_unwatch(_batch):
289+
batch = _batch.get_batch.return_value
290+
assert _cache.global_unwatch(b"key") is batch.add.return_value
291+
_batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch)
292+
batch.add.assert_called_once_with(b"key")
293+
294+
295+
class Test_GlobalCacheUnwatchBatch:
296+
@staticmethod
297+
def test_add_and_idle_and_done_callbacks(in_context):
298+
cache = mock.Mock()
299+
300+
batch = _cache._GlobalCacheUnwatchBatch({})
301+
future1 = batch.add(b"foo")
302+
future2 = batch.add(b"bar")
303+
304+
with in_context.new(global_cache=cache).use():
305+
batch.idle_callback()
306+
307+
cache.unwatch.assert_called_once_with([b"foo", b"bar"])
308+
assert future1.result() is None
309+
assert future2.result() is None
310+
311+
287312
class Test_global_compare_and_swap:
288313
@staticmethod
289314
@mock.patch("google.cloud.ndb._cache._batch")

packages/google-cloud-ndb/tests/unit/test__datastore_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ class SomeKind(model.Model):
346346
assert future.result() is _api._NOT_FOUND
347347

348348
assert global_cache.get([cache_key]) == [_cache._LOCKED]
349+
assert len(global_cache._watch_keys) == 0
349350

350351

351352
class Test_LookupBatch:

packages/google-cloud-ndb/tests/unit/test_global_cache.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def delete(self, keys):
3838
def watch(self, keys):
3939
return super(MockImpl, self).watch(keys)
4040

41+
def unwatch(self, keys):
42+
return super(MockImpl, self).unwatch(keys)
43+
4144
def compare_and_swap(self, items, expires=None):
4245
return super(MockImpl, self).compare_and_swap(items, expires=expires)
4346

@@ -63,6 +66,11 @@ def test_watch(self):
6366
with pytest.raises(NotImplementedError):
6467
cache.watch(b"foo")
6568

69+
def test_unwatch(self):
70+
cache = self.make_one()
71+
with pytest.raises(NotImplementedError):
72+
cache.unwatch(b"foo")
73+
6674
def test_compare_and_swap(self):
6775
cache = self.make_one()
6876
with pytest.raises(NotImplementedError):
@@ -147,6 +155,16 @@ def test_watch_compare_and_swap_with_expires(time):
147155
result = cache.get([b"one", b"two", b"three"])
148156
assert result == [None, b"hamburgers", None]
149157

158+
@staticmethod
159+
def test_watch_unwatch():
160+
cache = global_cache._InProcessGlobalCache()
161+
result = cache.watch([b"one", b"two", b"three"])
162+
assert result is None
163+
164+
result = cache.unwatch([b"one", b"two", b"three"])
165+
assert result is None
166+
assert cache._watch_keys == {}
167+
150168

151169
class TestRedisCache:
152170
@staticmethod
@@ -225,6 +243,23 @@ def test_watch(uuid):
225243
"bar": global_cache._Pipeline(pipe, "abc123"),
226244
}
227245

246+
@staticmethod
247+
def test_unwatch():
248+
redis = mock.Mock(spec=())
249+
cache = global_cache.RedisCache(redis)
250+
pipe1 = mock.Mock(spec=("reset",))
251+
pipe2 = mock.Mock(spec=("reset",))
252+
cache._pipes.pipes = {
253+
"ay": global_cache._Pipeline(pipe1, "abc123"),
254+
"be": global_cache._Pipeline(pipe1, "abc123"),
255+
"see": global_cache._Pipeline(pipe2, "def456"),
256+
"dee": global_cache._Pipeline(pipe2, "def456"),
257+
"whatevs": global_cache._Pipeline(None, "himom!"),
258+
}
259+
260+
cache.unwatch(["ay", "be", "see", "dee", "nuffin"])
261+
assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")}
262+
228263
@staticmethod
229264
def test_compare_and_swap():
230265
redis = mock.Mock(spec=())
@@ -450,6 +485,17 @@ def test_watch():
450485
key2: b"1",
451486
}
452487

488+
@staticmethod
489+
def test_unwatch():
490+
client = mock.Mock(spec=())
491+
cache = global_cache.MemcacheCache(client)
492+
key2 = cache._key(b"two")
493+
cache.caskeys[key2] = b"5"
494+
cache.caskeys["whatevs"] = b"6"
495+
cache.unwatch([b"one", b"two"])
496+
497+
assert cache.caskeys == {"whatevs": b"6"}
498+
453499
@staticmethod
454500
def test_compare_and_swap():
455501
client = mock.Mock(spec=("cas",))

0 commit comments

Comments
 (0)