@@ -513,26 +513,26 @@ def _broadcast(
513513 mandatory = False ,
514514 ).result ()
515515
516- def _transaction_begin (self , ** kwargs ):
517- """Enter transaction mode.
518- :param **kwargs: Further parameters for the transport layer.
516+ def _transaction_begin (
517+ self , transaction_id : int , * , subscription_id : Optional [int ] = None , ** kwargs
518+ ) -> None :
519+ """Start a new transaction.
520+ :param transaction_id: ID for this transaction in the transport layer.
521+ :param subscription_id: Tie the transaction to a specific channel containing this subscription.
519522 """
520- raise NotImplementedError ()
521- # self._channel.tx_select()
523+ self ._pika_thread .tx_select (transaction_id , subscription_id )
522524
523- def _transaction_abort (self , ** kwargs ):
525+ def _transaction_abort (self , transaction_id : int , ** kwargs ) -> None :
524526 """Abort a transaction and roll back all operations.
525- :param **kwargs: Further parameters for the transport layer .
527+ :param transaction_id: ID of transaction to be aborted .
526528 """
527- raise NotImplementedError ()
528- # self._channel.tx_rollback()
529+ self ._pika_thread .tx_rollback (transaction_id )
529530
530- def _transaction_commit (self , ** kwargs ):
531+ def _transaction_commit (self , transaction_id : int , ** kwargs ) -> None :
531532 """Commit a transaction.
532- :param **kwargs: Further parameters for the transport layer .
533+ :param transaction_id: ID of transaction to be committed .
533534 """
534- raise NotImplementedError ()
535- # self._channel.tx_commit()
535+ self ._pika_thread .tx_commit (transaction_id )
536536
537537 def _ack (
538538 self , message_id , subscription_id : int , * , multiple : bool = False , ** _kwargs
@@ -694,9 +694,12 @@ def __init__(
694694 self ._subscriptions : Dict [int , _PikaSubscription ] = {}
695695 # The pika connection object
696696 self ._connection : Optional [pika .BlockingConnection ] = None
697- # Per -subscription channels. May be pointing to the shared channel
697+ # Index of per -subscription channels.
698698 self ._pika_channels : Dict [int , BlockingChannel ] = {}
699- # A common, shared channel, used for non-QoS subscriptions
699+ # Bidirectional index of all ongoing transactions. May include the shared channel
700+ self ._transactions_by_id : Dict [int , BlockingChannel ] = {}
701+ self ._transactions_by_channel : Dict [BlockingChannel , int ] = {}
702+ # A common, shared channel, used for sending messages outside of transactions.
700703 self ._pika_shared_channel : Optional [BlockingChannel ]
701704 # Are we allowed to reconnect. Can only be turned off, never on
702705 self ._reconnection_allowed : bool = True
@@ -907,6 +910,11 @@ def _unsubscribe():
907910 logger .debug ("Closing channel that is now unused" )
908911 channel .close ()
909912
913+ # Forget about any ongoing transactions on the channel
914+ if channel in self ._transactions_by_channel :
915+ transaction_id = self ._transactions_by_channel .pop (channel )
916+ self ._transactions_by_id .pop (transaction_id )
917+
910918 result .set_result (None )
911919 except BaseException as e :
912920 result .set_exception (e )
@@ -974,6 +982,100 @@ def nack(
974982 lambda : channel .basic_nack (delivery_tag , multiple = multiple , requeue = requeue )
975983 )
976984
985+ def tx_select (
986+ self , transaction_id : int , subscription_id : Optional [int ]
987+ ) -> Future [None ]:
988+ """Set a channel to transaction mode. Thread-safe.
989+ :param transaction_id: ID for this transaction in the transport layer.
990+ :param subscription_id: Tie the transaction to a specific channel containing this subscription.
991+ """
992+
993+ if not self ._connection :
994+ raise RuntimeError ("Cannot transact on unstarted connection" )
995+
996+ future : Future [None ] = Future ()
997+
998+ def _tx_select ():
999+ if future .set_running_or_notify_cancel ():
1000+ try :
1001+ if subscription_id :
1002+ if subscription_id not in self ._pika_channels :
1003+ raise KeyError (
1004+ f"Could not find subscription { subscription_id } to begin transaction"
1005+ )
1006+ channel = self ._pika_channels [subscription_id ]
1007+ else :
1008+ channel = self ._get_shared_channel ()
1009+ if channel in self ._transactions_by_channel :
1010+ raise KeyError (
1011+ f"Channel { channel } is already running transaction { self ._transactions_by_channel [channel ]} , so can't start transaction { transaction_id } "
1012+ )
1013+ channel .tx_select ()
1014+ self ._transactions_by_channel [channel ] = transaction_id
1015+ self ._transactions_by_id [transaction_id ] = channel
1016+
1017+ future .set_result (None )
1018+ except BaseException as e :
1019+ future .set_exception (e )
1020+ raise
1021+
1022+ self ._connection .add_callback_threadsafe (_tx_select )
1023+ return future
1024+
1025+ def tx_rollback (self , transaction_id : int ) -> Future [None ]:
1026+ """Abort a transaction and roll back all operations. Thread-safe.
1027+ :param transaction_id: ID of transaction to be aborted.
1028+ """
1029+ if not self ._connection :
1030+ raise RuntimeError ("Cannot transact on unstarted connection" )
1031+
1032+ future : Future [None ] = Future ()
1033+
1034+ def _tx_rollback ():
1035+ if future .set_running_or_notify_cancel ():
1036+ try :
1037+ channel = self ._transactions_by_id .pop (transaction_id , None )
1038+ if not channel :
1039+ raise KeyError (
1040+ f"Could not find transaction { transaction_id } to roll back"
1041+ )
1042+ self ._transactions_by_channel .pop (channel )
1043+ channel .tx_rollback ()
1044+ future .set_result (None )
1045+ except BaseException as e :
1046+ future .set_exception (e )
1047+ raise
1048+
1049+ self ._connection .add_callback_threadsafe (_tx_rollback )
1050+ return future
1051+
1052+ def tx_commit (self , transaction_id : int ) -> Future [None ]:
1053+ """Commit a transaction.
1054+ :param transaction_id: ID of transaction to be committed. Thread-safe..
1055+ """
1056+ if not self ._connection :
1057+ raise RuntimeError ("Cannot transact on unstarted connection" )
1058+
1059+ future : Future [None ] = Future ()
1060+
1061+ def _tx_commit ():
1062+ if future .set_running_or_notify_cancel ():
1063+ try :
1064+ channel = self ._transactions_by_id .pop (transaction_id , None )
1065+ if not channel :
1066+ raise KeyError (
1067+ f"Could not find transaction { transaction_id } to commit"
1068+ )
1069+ self ._transactions_by_channel .pop (channel )
1070+ channel .tx_commit ()
1071+ future .set_result (None )
1072+ except BaseException as e :
1073+ future .set_exception (e )
1074+ raise
1075+
1076+ self ._connection .add_callback_threadsafe (_tx_commit )
1077+ return future
1078+
9771079 @property
9781080 def connection_alive (self ) -> bool :
9791081 """
@@ -989,7 +1091,7 @@ def connection_alive(self) -> bool:
9891091 )
9901092
9911093 # NOTE: With reconnection lifecycle this probably doesn't make sense
992- # on it's own. It might make sense to add this returning a
1094+ # on its own. It might make sense to add this returning a
9931095 # connection-specific 'token' - presumably the user might want
9941096 # to ensure that a connection is still the same connection
9951097 # and thus adhering to various within-connection guarantees.
@@ -1017,7 +1119,7 @@ def _get_shared_channel(self) -> BlockingChannel:
10171119
10181120 if not self ._pika_shared_channel :
10191121 self ._pika_shared_channel = self ._connection .channel ()
1020- self ._pika_shared_channel .confirm_delivery ()
1122+ ##### self._pika_shared_channel.confirm_delivery()
10211123 return self ._pika_shared_channel
10221124
10231125 def _recreate_subscriptions (self ):
@@ -1050,22 +1152,18 @@ def _add_subscription(self, subscription_id: int, subscription: _PikaSubscriptio
10501152 f"Subscription { subscription_id } to '{ subscription .destination } ' is not reconnectable. Turning reconnection off."
10511153 )
10521154
1053- # Either open a channel (if prefetch) or use the shared one
1054- if subscription .prefetch_count == 0 :
1055- channel = self ._get_shared_channel ()
1056- else :
1057- channel = self ._connection .channel ()
1058- channel .confirm_delivery ()
1059- channel .basic_qos (prefetch_count = subscription .prefetch_count )
1155+ # Open a dedicated channel for this subscription
1156+ channel = self ._connection .channel ()
1157+ channel .basic_qos (prefetch_count = subscription .prefetch_count )
10601158
1061- if subscription .kind == _PikaSubscriptionKind .FANOUT :
1159+ if subscription .kind is _PikaSubscriptionKind .FANOUT :
10621160 # If a FANOUT subscription, then we need to create and bind
10631161 # a temporary queue to receive messages from the exchange
10641162 queue = channel .queue_declare ("" , exclusive = True ).method .queue
10651163 assert queue is not None
10661164 channel .queue_bind (queue , subscription .destination )
10671165 subscription .queue = queue
1068- elif subscription .kind == _PikaSubscriptionKind .DIRECT :
1166+ elif subscription .kind is _PikaSubscriptionKind .DIRECT :
10691167 subscription .queue = subscription .destination
10701168 else :
10711169 raise NotImplementedError (f"Unknown subscription kind: { subscription .kind } " )
0 commit comments