@@ -51,7 +51,10 @@ use n0_future::{
51
51
time:: { self , Duration , Instant , MissedTickBehavior } ,
52
52
FuturesUnorderedBounded , SinkExt , StreamExt ,
53
53
} ;
54
- use tokio:: sync:: { mpsc, oneshot} ;
54
+ use tokio:: sync:: {
55
+ mpsc:: { self , OwnedPermit } ,
56
+ oneshot,
57
+ } ;
55
58
use tokio_util:: sync:: CancellationToken ;
56
59
use tracing:: { debug, error, event, info_span, instrument, trace, warn, Instrument , Level } ;
57
60
use url:: Url ;
@@ -159,6 +162,20 @@ struct ActiveRelayActor {
159
162
/// Token indicating the [`ActiveRelayActor`] should stop.
160
163
stop_token : CancellationToken ,
161
164
metrics : Arc < MagicsockMetrics > ,
165
+ /// Received relay packets that could not yet be forwarded to the magicsocket.
166
+ pending_received : Option < PendingRecv > ,
167
+ }
168
+
169
+ #[ derive( Debug ) ]
170
+ struct PendingRecv {
171
+ packet_iter : PacketSplitIter ,
172
+ blocked_on : RecvPath ,
173
+ }
174
+
175
+ #[ derive( Debug ) ]
176
+ enum RecvPath {
177
+ Data ,
178
+ Disco ,
162
179
}
163
180
164
181
#[ derive( Debug ) ]
@@ -263,6 +280,7 @@ impl ActiveRelayActor {
263
280
inactive_timeout : Box :: pin ( time:: sleep ( RELAY_INACTIVE_CLEANUP_TIME ) ) ,
264
281
stop_token,
265
282
metrics,
283
+ pending_received : None ,
266
284
}
267
285
}
268
286
@@ -612,7 +630,8 @@ impl ActiveRelayActor {
612
630
let fut = client_sink. send_all( & mut packet_stream) ;
613
631
self . run_sending( fut, & mut state, & mut client_stream) . await ?;
614
632
}
615
- msg = client_stream. next( ) => {
633
+ _ = forward_pending( & mut self . pending_received, & self . relay_datagrams_recv, & mut self . relay_disco_recv) , if self . pending_received. is_some( ) => { }
634
+ msg = client_stream. next( ) , if self . pending_received. is_none( ) => {
616
635
let Some ( msg) = msg else {
617
636
break Err ( anyhow!( "Stream closed by server." ) ) ;
618
637
} ;
@@ -659,33 +678,14 @@ impl ActiveRelayActor {
659
678
state. last_packet_src = Some ( remote_node_id) ;
660
679
state. nodes_present . insert ( remote_node_id) ;
661
680
}
662
- for datagram in PacketSplitIter :: new ( self . url . clone ( ) , remote_node_id, data) {
663
- let Ok ( datagram) = datagram else {
664
- warn ! ( "Invalid packet split" ) ;
665
- break ;
666
- } ;
667
- match crate :: disco:: source_and_box_bytes ( & datagram. buf ) {
668
- Some ( ( source, sealed_box) ) => {
669
- if remote_node_id != source {
670
- // TODO: return here?
671
- warn ! ( "Received relay disco message from connection for {}, but with message from {}" , remote_node_id. fmt_short( ) , source. fmt_short( ) ) ;
672
- }
673
- let message = RelayDiscoMessage {
674
- source,
675
- sealed_box,
676
- relay_url : datagram. url . clone ( ) ,
677
- relay_remote_node_id : datagram. src ,
678
- } ;
679
- if let Err ( err) = self . relay_disco_recv . try_send ( message) {
680
- warn ! ( "Dropping received relay disco packet: {err:#}" ) ;
681
- }
682
- }
683
- None => {
684
- if let Err ( err) = self . relay_datagrams_recv . try_send ( datagram) {
685
- warn ! ( "Dropping received relay data packet: {err:#}" ) ;
686
- }
687
- }
688
- }
681
+ let packet_iter = PacketSplitIter :: new ( self . url . clone ( ) , remote_node_id, data) ;
682
+ if let Some ( pending) = handle_received_packet_iter (
683
+ packet_iter,
684
+ None ,
685
+ & self . relay_datagrams_recv ,
686
+ & mut self . relay_disco_recv ,
687
+ ) {
688
+ self . pending_received = Some ( pending) ;
689
689
}
690
690
}
691
691
ReceivedMessage :: NodeGone ( node_id) => {
@@ -769,7 +769,8 @@ impl ActiveRelayActor {
769
769
break Err ( anyhow!( "Ping timeout" ) ) ;
770
770
}
771
771
// No need to read the inbox or datagrams to send.
772
- msg = client_stream. next( ) => {
772
+ _ = forward_pending( & mut self . pending_received, & self . relay_datagrams_recv, & mut self . relay_disco_recv) , if self . pending_received. is_some( ) => { }
773
+ msg = client_stream. next( ) , if self . pending_received. is_none( ) => {
773
774
let Some ( msg) = msg else {
774
775
break Err ( anyhow!( "Stream closed by server." ) ) ;
775
776
} ;
@@ -788,6 +789,105 @@ impl ActiveRelayActor {
788
789
}
789
790
}
790
791
792
+ /// Forward pending received packets to their queues.
793
+ ///
794
+ /// If `maybe_pending` is not empty, this will wait for the path the last received item
795
+ /// is blocked on (via [`PendingRecv::blocked_on`]) to become unblocked. It will then forward
796
+ /// the pending items, until a queue is blocked again. In that case, the remaining items will
797
+ /// be put back into `maybe_pending`. If all items could be sent, `maybe_pending` will be set
798
+ /// to `None`.
799
+ ///
800
+ /// This function is cancellation-safe: If the future is dropped at any point, all items are guaranteed
801
+ /// to either be sent into their respective queues, or are still in `maybe_pending`.
802
+ async fn forward_pending (
803
+ maybe_pending : & mut Option < PendingRecv > ,
804
+ relay_datagrams_recv : & RelayDatagramRecvQueue ,
805
+ relay_disco_recv : & mut mpsc:: Sender < RelayDiscoMessage > ,
806
+ ) {
807
+ // We take a mutable reference onto the inner value.
808
+ // we're not `take`ing it here, because this would make the function not cancellation safe.
809
+ let Some ( ref mut pending) = maybe_pending else {
810
+ return ;
811
+ } ;
812
+ let disco_permit = match pending. blocked_on {
813
+ RecvPath :: Data => {
814
+ std:: future:: poll_fn ( |cx| relay_datagrams_recv. poll_send_ready ( cx) )
815
+ . await
816
+ . ok ( ) ;
817
+ None
818
+ }
819
+ RecvPath :: Disco => {
820
+ let Ok ( permit) = relay_disco_recv. clone ( ) . reserve_owned ( ) . await else {
821
+ return ;
822
+ } ;
823
+ Some ( permit)
824
+ }
825
+ } ;
826
+ // We now take the inner value by value. it is cancellation safe here because
827
+ // no further `await`s occur after here.
828
+ // The unwrap is guaranteed to be safe because we checked above that it is not none.
829
+ #[ allow( clippy:: unwrap_used, reason = "checked above" ) ]
830
+ let pending = maybe_pending. take ( ) . unwrap ( ) ;
831
+ if let Some ( pending) = handle_received_packet_iter (
832
+ pending. packet_iter ,
833
+ disco_permit,
834
+ relay_datagrams_recv,
835
+ relay_disco_recv,
836
+ ) {
837
+ * maybe_pending = Some ( pending) ;
838
+ }
839
+ }
840
+
841
+ fn handle_received_packet_iter (
842
+ mut packet_iter : PacketSplitIter ,
843
+ mut disco_permit : Option < OwnedPermit < RelayDiscoMessage > > ,
844
+ relay_datagrams_recv : & RelayDatagramRecvQueue ,
845
+ relay_disco_recv : & mut mpsc:: Sender < RelayDiscoMessage > ,
846
+ ) -> Option < PendingRecv > {
847
+ let remote_node_id = packet_iter. remote_node_id ( ) ;
848
+ for datagram in & mut packet_iter {
849
+ let Ok ( datagram) = datagram else {
850
+ warn ! ( "Invalid packet split" ) ;
851
+ return None ;
852
+ } ;
853
+ match crate :: disco:: source_and_box_bytes ( & datagram. buf ) {
854
+ Some ( ( source, sealed_box) ) => {
855
+ if remote_node_id != source {
856
+ // TODO: return here?
857
+ warn ! ( "Received relay disco message from connection for {}, but with message from {}" , remote_node_id. fmt_short( ) , source. fmt_short( ) ) ;
858
+ }
859
+ let message = RelayDiscoMessage {
860
+ source,
861
+ sealed_box,
862
+ relay_url : datagram. url . clone ( ) ,
863
+ relay_remote_node_id : datagram. src ,
864
+ } ;
865
+ if let Some ( permit) = disco_permit. take ( ) {
866
+ permit. send ( message) ;
867
+ } else if let Err ( err) = relay_disco_recv. try_send ( message) {
868
+ warn ! ( "Dropping received relay disco packet: {err:#}" ) ;
869
+ packet_iter. push_front ( datagram) ;
870
+ return Some ( PendingRecv {
871
+ packet_iter,
872
+ blocked_on : RecvPath :: Disco ,
873
+ } ) ;
874
+ }
875
+ }
876
+ None => {
877
+ if let Err ( err) = relay_datagrams_recv. try_send ( datagram) {
878
+ warn ! ( "Dropping received relay data packet: {err:#}" ) ;
879
+ packet_iter. push_front ( err. into_inner ( ) ) ;
880
+ return Some ( PendingRecv {
881
+ packet_iter,
882
+ blocked_on : RecvPath :: Data ,
883
+ } ) ;
884
+ }
885
+ }
886
+ }
887
+ }
888
+ None
889
+ }
890
+
791
891
/// Shared state when the [`ActiveRelayActor`] is connected to a relay server.
792
892
///
793
893
/// Common state between [`ActiveRelayActor::run_connected`] and
@@ -1270,12 +1370,22 @@ struct PacketSplitIter {
1270
1370
url : RelayUrl ,
1271
1371
src : NodeId ,
1272
1372
bytes : Bytes ,
1373
+ next : Option < RelayRecvDatagram > ,
1273
1374
}
1274
1375
1275
1376
impl PacketSplitIter {
1276
1377
/// Create a new PacketSplitIter from a packet.
1277
1378
fn new ( url : RelayUrl , src : NodeId , bytes : Bytes ) -> Self {
1278
- Self { url, src, bytes }
1379
+ Self {
1380
+ url,
1381
+ src,
1382
+ bytes,
1383
+ next : None ,
1384
+ }
1385
+ }
1386
+
1387
+ fn remote_node_id ( & self ) -> NodeId {
1388
+ self . src
1279
1389
}
1280
1390
1281
1391
fn fail ( & mut self ) -> Option < std:: io:: Result < RelayRecvDatagram > > {
@@ -1285,13 +1395,20 @@ impl PacketSplitIter {
1285
1395
"" ,
1286
1396
) ) )
1287
1397
}
1398
+
1399
+ fn push_front ( & mut self , item : RelayRecvDatagram ) {
1400
+ self . next = Some ( item) ;
1401
+ }
1288
1402
}
1289
1403
1290
1404
impl Iterator for PacketSplitIter {
1291
1405
type Item = std:: io:: Result < RelayRecvDatagram > ;
1292
1406
1293
1407
fn next ( & mut self ) -> Option < Self :: Item > {
1294
1408
use bytes:: Buf ;
1409
+ if let Some ( item) = self . next . take ( ) {
1410
+ return Some ( Ok ( item) ) ;
1411
+ }
1295
1412
if self . bytes . has_remaining ( ) {
1296
1413
if self . bytes . remaining ( ) < 2 {
1297
1414
return self . fail ( ) ;
0 commit comments