Skip to content

Commit 8d883e7

Browse files
authored
PikaTransport: Enable Transactions (#92)
Make every subscription run on a separate channel so that transactions can be used on channels. So if a message is to be send as part of a transaction then it can be sent on the relevant channel, and if a message is sent outside of a transaction it can be sent via the default channel. This disables confirm mode on all channels for now. The confirm mode currently does not have any effect as we are not observing the callback. However, confirm and transaction modes are mutually exclusive, so transactions can't be used if we use confirm mode on subscription channels. For the default channel we may eventually want to have one in normal mode and one in confirm mode, but this is another issue for another day. Add optional subscription ID to transaction calls so that a transaction can be associated with the correct channel where needed.
1 parent 35aa025 commit 8d883e7

File tree

6 files changed

+148
-39
lines changed

6 files changed

+148
-39
lines changed

src/workflows/services/sample_transaction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ class SampleTxn(CommonService):
1717
def initializing(self):
1818
"""Subscribe to a channel. Received messages must be acknowledged."""
1919
self.subid = self._transport.subscribe(
20-
"transient.transaction", self.receive_message, acknowledgement=True
20+
"transient.transaction",
21+
self.receive_message,
22+
acknowledgement=True,
23+
prefetch_count=1000,
2124
)
2225

2326
@staticmethod

src/workflows/transport/common_transport.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,24 @@ def nack(self, message, subscription_id: Optional[int] = None, **kwargs):
288288
)
289289
self._nack(message_id, subscription_id=subscription_id, **kwargs)
290290

291-
def transaction_begin(self, **kwargs) -> int:
291+
def transaction_begin(self, subscription_id: Optional[int] = None, **kwargs) -> int:
292292
"""Start a new transaction.
293-
:param **kwargs: Further parameters for the transport layer. For example
293+
:param **kwargs: Further parameters for the transport layer.
294294
:return: A transaction ID that can be passed to other functions.
295295
"""
296296
self.__transaction_id += 1
297297
self.__transactions.add(self.__transaction_id)
298-
self.log.debug("Starting transaction with ID %d", self.__subscription_id)
299-
self._transaction_begin(self.__transaction_id, **kwargs)
298+
if subscription_id:
299+
self.log.debug(
300+
"Starting transaction with ID %d on subscription %d",
301+
self.__transaction_id,
302+
subscription_id,
303+
)
304+
else:
305+
self.log.debug("Starting transaction with ID %d", self.__transaction_id)
306+
self._transaction_begin(
307+
self.__transaction_id, subscription_id=subscription_id, **kwargs
308+
)
300309
return self.__transaction_id
301310

302311
def transaction_abort(self, transaction_id: int, **kwargs):
@@ -405,21 +414,23 @@ def _nack(self, message_id, subscription_id, **kwargs):
405414
"""
406415
raise NotImplementedError("Transport interface not implemented")
407416

408-
def _transaction_begin(self, transaction_id, **kwargs):
417+
def _transaction_begin(
418+
self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs
419+
) -> None:
409420
"""Start a new transaction.
410421
:param transaction_id: ID for this transaction in the transport layer.
411422
:param **kwargs: Further parameters for the transport layer.
412423
"""
413424
raise NotImplementedError("Transport interface not implemented")
414425

415-
def _transaction_abort(self, transaction_id, **kwargs):
426+
def _transaction_abort(self, transaction_id: int, **kwargs) -> None:
416427
"""Abort a transaction and roll back all operations.
417428
:param transaction_id: ID of transaction to be aborted.
418429
:param **kwargs: Further parameters for the transport layer.
419430
"""
420431
raise NotImplementedError("Transport interface not implemented")
421432

422-
def _transaction_commit(self, transaction_id, **kwargs):
433+
def _transaction_commit(self, transaction_id: int, **kwargs) -> None:
423434
"""Commit a transaction.
424435
:param transaction_id: ID of transaction to be committed.
425436
:param **kwargs: Further parameters for the transport layer.

src/workflows/transport/pika_transport.py

Lines changed: 124 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -513,26 +513,26 @@ def _broadcast(
513513
mandatory=False,
514514
).result()
515515

516-
def _transaction_begin(self, **kwargs):
517-
"""Enter transaction mode.
518-
:param **kwargs: Further parameters for the transport layer.
516+
def _transaction_begin(
517+
self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs
518+
) -> None:
519+
"""Start a new transaction.
520+
:param transaction_id: ID for this transaction in the transport layer.
521+
:param subscription_id: Tie the transaction to a specific channel containing this subscription.
519522
"""
520-
raise NotImplementedError()
521-
# self._channel.tx_select()
523+
self._pika_thread.tx_select(transaction_id, subscription_id)
522524

523-
def _transaction_abort(self, **kwargs):
525+
def _transaction_abort(self, transaction_id: int, **kwargs) -> None:
524526
"""Abort a transaction and roll back all operations.
525-
:param **kwargs: Further parameters for the transport layer.
527+
:param transaction_id: ID of transaction to be aborted.
526528
"""
527-
raise NotImplementedError()
528-
# self._channel.tx_rollback()
529+
self._pika_thread.tx_rollback(transaction_id)
529530

530-
def _transaction_commit(self, **kwargs):
531+
def _transaction_commit(self, transaction_id: int, **kwargs) -> None:
531532
"""Commit a transaction.
532-
:param **kwargs: Further parameters for the transport layer.
533+
:param transaction_id: ID of transaction to be committed.
533534
"""
534-
raise NotImplementedError()
535-
# self._channel.tx_commit()
535+
self._pika_thread.tx_commit(transaction_id)
536536

537537
def _ack(
538538
self, message_id, subscription_id: int, *, multiple: bool = False, **_kwargs
@@ -694,9 +694,12 @@ def __init__(
694694
self._subscriptions: Dict[int, _PikaSubscription] = {}
695695
# The pika connection object
696696
self._connection: Optional[pika.BlockingConnection] = None
697-
# Per-subscription channels. May be pointing to the shared channel
697+
# Index of per-subscription channels.
698698
self._pika_channels: Dict[int, BlockingChannel] = {}
699-
# A common, shared channel, used for non-QoS subscriptions
699+
# Bidirectional index of all ongoing transactions. May include the shared channel
700+
self._transactions_by_id: Dict[int, BlockingChannel] = {}
701+
self._transactions_by_channel: Dict[BlockingChannel, int] = {}
702+
# A common, shared channel, used for sending messages outside of transactions.
700703
self._pika_shared_channel: Optional[BlockingChannel]
701704
# Are we allowed to reconnect. Can only be turned off, never on
702705
self._reconnection_allowed: bool = True
@@ -907,6 +910,11 @@ def _unsubscribe():
907910
logger.debug("Closing channel that is now unused")
908911
channel.close()
909912

913+
# Forget about any ongoing transactions on the channel
914+
if channel in self._transactions_by_channel:
915+
transaction_id = self._transactions_by_channel.pop(channel)
916+
self._transactions_by_id.pop(transaction_id)
917+
910918
result.set_result(None)
911919
except BaseException as e:
912920
result.set_exception(e)
@@ -974,6 +982,100 @@ def nack(
974982
lambda: channel.basic_nack(delivery_tag, multiple=multiple, requeue=requeue)
975983
)
976984

985+
def tx_select(
986+
self, transaction_id: int, subscription_id: Optional[int]
987+
) -> Future[None]:
988+
"""Set a channel to transaction mode. Thread-safe.
989+
:param transaction_id: ID for this transaction in the transport layer.
990+
:param subscription_id: Tie the transaction to a specific channel containing this subscription.
991+
"""
992+
993+
if not self._connection:
994+
raise RuntimeError("Cannot transact on unstarted connection")
995+
996+
future: Future[None] = Future()
997+
998+
def _tx_select():
999+
if future.set_running_or_notify_cancel():
1000+
try:
1001+
if subscription_id:
1002+
if subscription_id not in self._pika_channels:
1003+
raise KeyError(
1004+
f"Could not find subscription {subscription_id} to begin transaction"
1005+
)
1006+
channel = self._pika_channels[subscription_id]
1007+
else:
1008+
channel = self._get_shared_channel()
1009+
if channel in self._transactions_by_channel:
1010+
raise KeyError(
1011+
f"Channel {channel} is already running transaction {self._transactions_by_channel[channel]}, so can't start transaction {transaction_id}"
1012+
)
1013+
channel.tx_select()
1014+
self._transactions_by_channel[channel] = transaction_id
1015+
self._transactions_by_id[transaction_id] = channel
1016+
1017+
future.set_result(None)
1018+
except BaseException as e:
1019+
future.set_exception(e)
1020+
raise
1021+
1022+
self._connection.add_callback_threadsafe(_tx_select)
1023+
return future
1024+
1025+
def tx_rollback(self, transaction_id: int) -> Future[None]:
1026+
"""Abort a transaction and roll back all operations. Thread-safe.
1027+
:param transaction_id: ID of transaction to be aborted.
1028+
"""
1029+
if not self._connection:
1030+
raise RuntimeError("Cannot transact on unstarted connection")
1031+
1032+
future: Future[None] = Future()
1033+
1034+
def _tx_rollback():
1035+
if future.set_running_or_notify_cancel():
1036+
try:
1037+
channel = self._transactions_by_id.pop(transaction_id, None)
1038+
if not channel:
1039+
raise KeyError(
1040+
f"Could not find transaction {transaction_id} to roll back"
1041+
)
1042+
self._transactions_by_channel.pop(channel)
1043+
channel.tx_rollback()
1044+
future.set_result(None)
1045+
except BaseException as e:
1046+
future.set_exception(e)
1047+
raise
1048+
1049+
self._connection.add_callback_threadsafe(_tx_rollback)
1050+
return future
1051+
1052+
def tx_commit(self, transaction_id: int) -> Future[None]:
1053+
"""Commit a transaction.
1054+
:param transaction_id: ID of transaction to be committed. Thread-safe..
1055+
"""
1056+
if not self._connection:
1057+
raise RuntimeError("Cannot transact on unstarted connection")
1058+
1059+
future: Future[None] = Future()
1060+
1061+
def _tx_commit():
1062+
if future.set_running_or_notify_cancel():
1063+
try:
1064+
channel = self._transactions_by_id.pop(transaction_id, None)
1065+
if not channel:
1066+
raise KeyError(
1067+
f"Could not find transaction {transaction_id} to commit"
1068+
)
1069+
self._transactions_by_channel.pop(channel)
1070+
channel.tx_commit()
1071+
future.set_result(None)
1072+
except BaseException as e:
1073+
future.set_exception(e)
1074+
raise
1075+
1076+
self._connection.add_callback_threadsafe(_tx_commit)
1077+
return future
1078+
9771079
@property
9781080
def connection_alive(self) -> bool:
9791081
"""
@@ -989,7 +1091,7 @@ def connection_alive(self) -> bool:
9891091
)
9901092

9911093
# NOTE: With reconnection lifecycle this probably doesn't make sense
992-
# on it's own. It might make sense to add this returning a
1094+
# on its own. It might make sense to add this returning a
9931095
# connection-specific 'token' - presumably the user might want
9941096
# to ensure that a connection is still the same connection
9951097
# and thus adhering to various within-connection guarantees.
@@ -1017,7 +1119,7 @@ def _get_shared_channel(self) -> BlockingChannel:
10171119

10181120
if not self._pika_shared_channel:
10191121
self._pika_shared_channel = self._connection.channel()
1020-
self._pika_shared_channel.confirm_delivery()
1122+
##### self._pika_shared_channel.confirm_delivery()
10211123
return self._pika_shared_channel
10221124

10231125
def _recreate_subscriptions(self):
@@ -1050,22 +1152,18 @@ def _add_subscription(self, subscription_id: int, subscription: _PikaSubscriptio
10501152
f"Subscription {subscription_id} to '{subscription.destination}' is not reconnectable. Turning reconnection off."
10511153
)
10521154

1053-
# Either open a channel (if prefetch) or use the shared one
1054-
if subscription.prefetch_count == 0:
1055-
channel = self._get_shared_channel()
1056-
else:
1057-
channel = self._connection.channel()
1058-
channel.confirm_delivery()
1059-
channel.basic_qos(prefetch_count=subscription.prefetch_count)
1155+
# Open a dedicated channel for this subscription
1156+
channel = self._connection.channel()
1157+
channel.basic_qos(prefetch_count=subscription.prefetch_count)
10601158

1061-
if subscription.kind == _PikaSubscriptionKind.FANOUT:
1159+
if subscription.kind is _PikaSubscriptionKind.FANOUT:
10621160
# If a FANOUT subscription, then we need to create and bind
10631161
# a temporary queue to receive messages from the exchange
10641162
queue = channel.queue_declare("", exclusive=True).method.queue
10651163
assert queue is not None
10661164
channel.queue_bind(queue, subscription.destination)
10671165
subscription.queue = queue
1068-
elif subscription.kind == _PikaSubscriptionKind.DIRECT:
1166+
elif subscription.kind is _PikaSubscriptionKind.DIRECT:
10691167
subscription.queue = subscription.destination
10701168
else:
10711169
raise NotImplementedError(f"Unknown subscription kind: {subscription.kind}")

src/workflows/transport/stomp_transport.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,21 +410,18 @@ def _broadcast(
410410
def _transaction_begin(self, transaction_id, **kwargs):
411411
"""Start a new transaction.
412412
:param transaction_id: ID for this transaction in the transport layer.
413-
:param **kwargs: Further parameters for the transport layer.
414413
"""
415414
self._conn.begin(transaction=transaction_id)
416415

417416
def _transaction_abort(self, transaction_id, **kwargs):
418417
"""Abort a transaction and roll back all operations.
419418
:param transaction_id: ID of transaction to be aborted.
420-
:param **kwargs: Further parameters for the transport layer.
421419
"""
422420
self._conn.abort(transaction_id)
423421

424422
def _transaction_commit(self, transaction_id, **kwargs):
425423
"""Commit a transaction.
426424
:param transaction_id: ID of transaction to be committed.
427-
:param **kwargs: Further parameters for the transport layer.
428425
"""
429426
self._conn.commit(transaction_id)
430427

tests/services/test_sample_transaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_txnservice_subscribes_to_channel():
5252
p.initializing()
5353

5454
mock_transport.subscribe.assert_called_once_with(
55-
mock.ANY, p.receive_message, acknowledgement=True
55+
mock.ANY, p.receive_message, acknowledgement=True, prefetch_count=1000
5656
)
5757

5858

tests/transport/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def test_create_and_destroy_transactions():
248248
t = ct.transaction_begin()
249249

250250
assert t
251-
ct._transaction_begin.assert_called_once_with(t)
251+
ct._transaction_begin.assert_called_once_with(t, subscription_id=None)
252252

253253
ct.transaction_abort(t)
254254
with pytest.raises(workflows.Error):

0 commit comments

Comments
 (0)