2424//! The call site should, thus, look something like this:
2525//! ```
2626//! use tokio::sync::mpsc;
27- //! use tokio ::net::TcpStream;
27+ //! use std ::net::TcpStream;
2828//! use bitcoin::secp256k1::key::PublicKey;
2929//! use lightning::util::events::EventsProvider;
3030//! use std::net::SocketAddr;
@@ -86,6 +86,7 @@ use lightning::util::logger::Logger;
8686
8787use std:: { task, thread} ;
8888use std:: net:: SocketAddr ;
89+ use std:: net:: TcpStream as StdTcpStream ;
8990use std:: sync:: { Arc , Mutex , MutexGuard } ;
9091use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
9192use std:: time:: Duration ;
@@ -156,7 +157,7 @@ impl Connection {
156157 // In this case, we do need to call peer_manager.socket_disconnected() to inform
157158 // Rust-Lightning that the socket is gone.
158159 PeerDisconnected
159- } ;
160+ }
160161 let disconnect_type = loop {
161162 macro_rules! shutdown_socket {
162163 ( $err: expr, $need_disconnect: expr) => { {
@@ -218,7 +219,7 @@ impl Connection {
218219 }
219220 }
220221
221- fn new ( event_notify : mpsc:: Sender < ( ) > , stream : TcpStream ) -> ( io:: ReadHalf < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
222+ fn new ( event_notify : mpsc:: Sender < ( ) > , stream : StdTcpStream ) -> ( io:: ReadHalf < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
222223 // We only ever need a channel of depth 1 here: if we returned a non-full write to the
223224 // PeerManager, we will eventually get notified that there is room in the socket to write
224225 // new bytes, which will generate an event. That event will be popped off the queue before
@@ -229,7 +230,8 @@ impl Connection {
229230 // we shove a value into the channel which comes after we've reset the read_paused bool to
230231 // false.
231232 let ( read_waker, read_receiver) = mpsc:: channel ( 1 ) ;
232- let ( reader, writer) = io:: split ( stream) ;
233+ stream. set_nonblocking ( true ) . unwrap ( ) ;
234+ let ( reader, writer) = io:: split ( TcpStream :: from_std ( stream) . unwrap ( ) ) ;
233235
234236 ( reader, write_receiver, read_receiver,
235237 Arc :: new ( Mutex :: new ( Self {
@@ -248,7 +250,7 @@ impl Connection {
248250/// not need to poll the provided future in order to make progress.
249251///
250252/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
251- pub fn setup_inbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , stream : TcpStream ) -> impl std:: future:: Future < Output =( ) > where
253+ pub fn setup_inbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , stream : StdTcpStream ) -> impl std:: future:: Future < Output =( ) > where
252254 CMH : ChannelMessageHandler + ' static ,
253255 RMH : RoutingMessageHandler + ' static ,
254256 L : Logger + ' static + ?Sized {
@@ -290,7 +292,7 @@ pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<So
290292/// not need to poll the provided future in order to make progress.
291293///
292294/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
293- pub fn setup_outbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , their_node_id : PublicKey , stream : TcpStream ) -> impl std:: future:: Future < Output =( ) > where
295+ pub fn setup_outbound < CMH , RMH , L > ( peer_manager : Arc < peer_handler:: PeerManager < SocketDescriptor , Arc < CMH > , Arc < RMH > , Arc < L > > > , event_notify : mpsc:: Sender < ( ) > , their_node_id : PublicKey , stream : StdTcpStream ) -> impl std:: future:: Future < Output =( ) > where
294296 CMH : ChannelMessageHandler + ' static ,
295297 RMH : RoutingMessageHandler + ' static ,
296298 L : Logger + ' static + ?Sized {
@@ -366,7 +368,7 @@ pub async fn connect_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerM
366368 CMH : ChannelMessageHandler + ' static ,
367369 RMH : RoutingMessageHandler + ' static ,
368370 L : Logger + ' static + ?Sized {
369- if let Ok ( Ok ( stream) ) = time:: timeout ( Duration :: from_secs ( 10 ) , TcpStream :: connect ( & addr) ) . await {
371+ if let Ok ( Ok ( stream) ) = time:: timeout ( Duration :: from_secs ( 10 ) , async { TcpStream :: connect ( & addr) . await . map ( |s| s . into_std ( ) . unwrap ( ) ) } ) . await {
370372 Some ( setup_outbound ( peer_manager, event_notify, their_node_id, stream) )
371373 } else { None }
372374}
@@ -388,7 +390,7 @@ fn wake_socket_waker(orig_ptr: *const ()) {
388390}
389391fn wake_socket_waker_by_ref ( orig_ptr : * const ( ) ) {
390392 let sender_ptr = orig_ptr as * const mpsc:: Sender < ( ) > ;
391- let mut sender = unsafe { ( * sender_ptr) . clone ( ) } ;
393+ let sender = unsafe { ( * sender_ptr) . clone ( ) } ;
392394 let _ = sender. try_send ( ( ) ) ;
393395}
394396fn drop_socket_waker ( orig_ptr : * const ( ) ) {
@@ -512,6 +514,7 @@ mod tests {
512514 use tokio:: sync:: mpsc;
513515
514516 use std:: mem;
517+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
515518 use std:: sync:: { Arc , Mutex } ;
516519 use std:: time:: Duration ;
517520
@@ -526,6 +529,7 @@ mod tests {
526529 expected_pubkey : PublicKey ,
527530 pubkey_connected : mpsc:: Sender < ( ) > ,
528531 pubkey_disconnected : mpsc:: Sender < ( ) > ,
532+ disconnected_flag : AtomicBool ,
529533 msg_events : Mutex < Vec < MessageSendEvent > > ,
530534 }
531535 impl RoutingMessageHandler for MsgHandler {
@@ -559,6 +563,7 @@ mod tests {
559563 fn handle_announcement_signatures ( & self , _their_node_id : & PublicKey , _msg : & AnnouncementSignatures ) { }
560564 fn peer_disconnected ( & self , their_node_id : & PublicKey , _no_connection_possible : bool ) {
561565 if * their_node_id == self . expected_pubkey {
566+ self . disconnected_flag . store ( true , Ordering :: SeqCst ) ;
562567 self . pubkey_disconnected . clone ( ) . try_send ( ( ) ) . unwrap ( ) ;
563568 }
564569 }
@@ -591,6 +596,7 @@ mod tests {
591596 expected_pubkey : b_pub,
592597 pubkey_connected : a_connected_sender,
593598 pubkey_disconnected : a_disconnected_sender,
599+ disconnected_flag : AtomicBool :: new ( false ) ,
594600 msg_events : Mutex :: new ( Vec :: new ( ) ) ,
595601 } ) ;
596602 let a_manager = Arc :: new ( PeerManager :: new ( MessageHandler {
@@ -604,6 +610,7 @@ mod tests {
604610 expected_pubkey : a_pub,
605611 pubkey_connected : b_connected_sender,
606612 pubkey_disconnected : b_disconnected_sender,
613+ disconnected_flag : AtomicBool :: new ( false ) ,
607614 msg_events : Mutex :: new ( Vec :: new ( ) ) ,
608615 } ) ;
609616 let b_manager = Arc :: new ( PeerManager :: new ( MessageHandler {
@@ -624,27 +631,29 @@ mod tests {
624631 } else { panic ! ( "Failed to bind to v4 localhost on common ports" ) ; } ;
625632
626633 let ( sender, _receiver) = mpsc:: channel ( 2 ) ;
627- let fut_a = super :: setup_outbound ( Arc :: clone ( & a_manager) , sender. clone ( ) , b_pub, tokio :: net :: TcpStream :: from_std ( conn_a) . unwrap ( ) ) ;
628- let fut_b = super :: setup_inbound ( b_manager, sender, tokio :: net :: TcpStream :: from_std ( conn_b) . unwrap ( ) ) ;
634+ let fut_a = super :: setup_outbound ( Arc :: clone ( & a_manager) , sender. clone ( ) , b_pub, conn_a) ;
635+ let fut_b = super :: setup_inbound ( b_manager, sender, conn_b) ;
629636
630637 tokio:: time:: timeout ( Duration :: from_secs ( 10 ) , a_connected. recv ( ) ) . await . unwrap ( ) ;
631638 tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , b_connected. recv ( ) ) . await . unwrap ( ) ;
632639
633640 a_handler. msg_events . lock ( ) . unwrap ( ) . push ( MessageSendEvent :: HandleError {
634641 node_id : b_pub, action : ErrorAction :: DisconnectPeer { msg : None }
635642 } ) ;
636- assert ! ( a_disconnected . try_recv ( ) . is_err ( ) ) ;
637- assert ! ( b_disconnected . try_recv ( ) . is_err ( ) ) ;
643+ assert ! ( !a_handler . disconnected_flag . load ( Ordering :: SeqCst ) ) ;
644+ assert ! ( !b_handler . disconnected_flag . load ( Ordering :: SeqCst ) ) ;
638645
639646 a_manager. process_events ( ) ;
640647 tokio:: time:: timeout ( Duration :: from_secs ( 10 ) , a_disconnected. recv ( ) ) . await . unwrap ( ) ;
641648 tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , b_disconnected. recv ( ) ) . await . unwrap ( ) ;
649+ assert ! ( a_handler. disconnected_flag. load( Ordering :: SeqCst ) ) ;
650+ assert ! ( b_handler. disconnected_flag. load( Ordering :: SeqCst ) ) ;
642651
643652 fut_a. await ;
644653 fut_b. await ;
645654 }
646655
647- #[ tokio:: test( threaded_scheduler ) ]
656+ #[ tokio:: test( flavor = "multi_thread" ) ]
648657 async fn basic_threaded_connection_test ( ) {
649658 do_basic_connection_test ( ) . await ;
650659 }
0 commit comments