Skip to content

Commit bcbcd23

Browse files
committed
refactor: apply backpressure without blocking the actor loop
1 parent 7c100fa commit bcbcd23

File tree

2 files changed

+177
-39
lines changed

2 files changed

+177
-39
lines changed

iroh/src/magicsock.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,15 +2104,17 @@ impl RelayDatagramSendChannelReceiver {
21042104
/// Queue of datagrams received over relay, with wakers for both directions.
///
/// Pushing an item wakes `recv_waker` (see `try_send`); popping an item wakes
/// `send_waker` (see `poll_recv`), so a sender blocked in `poll_send_ready`
/// learns that capacity is available again.
#[derive(Debug)]
struct RelayDatagramRecvQueue {
    // Bounded queue holding the received datagrams (capacity set in `new`).
    queue: ConcurrentQueue<RelayRecvDatagram>,
    // Woken when an item is pushed, so a task pending in `poll_recv` re-polls.
    recv_waker: AtomicWaker,
    // Woken when an item is popped or the queue closes, so a task pending in
    // `poll_send_ready` re-polls.
    send_waker: AtomicWaker,
}
21092110

21102111
impl RelayDatagramRecvQueue {
21112112
/// Creates a new, empty queue with a fixed size bound of 512 items.
21122113
fn new() -> Self {
21132114
Self {
21142115
queue: ConcurrentQueue::bounded(512),
2115-
waker: AtomicWaker::new(),
2116+
recv_waker: AtomicWaker::new(),
2117+
send_waker: AtomicWaker::new(),
21162118
}
21172119
}
21182120

@@ -2125,10 +2127,21 @@ impl RelayDatagramRecvQueue {
21252127
item: RelayRecvDatagram,
21262128
) -> Result<(), concurrent_queue::PushError<RelayRecvDatagram>> {
21272129
self.queue.push(item).inspect(|_| {
2128-
self.waker.wake();
2130+
self.recv_waker.wake();
21292131
})
21302132
}
21312133

2134+
fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
2135+
if self.queue.is_closed() {
2136+
Poll::Ready(Err(anyhow!("Queue closed")))
2137+
} else if !self.queue.is_full() {
2138+
Poll::Ready(Ok(()))
2139+
} else {
2140+
self.send_waker.register(cx.waker());
2141+
Poll::Pending
2142+
}
2143+
}
2144+
21322145
    /// Polls for new items in the queue.
    ///
    /// Although this method is available from `&self`, it must not be polled
    /// concurrently from multiple tasks: there is only a single `recv_waker`
    /// slot, so the waker registered last wins.  Interior mutability
    /// (`AtomicWaker`, `ConcurrentQueue`) is what makes polling from `&self`
    /// possible at all.
    fn poll_recv(&self, cx: &mut Context) -> Poll<Result<RelayRecvDatagram>> {
        match self.queue.pop() {
            Ok(value) => {
                // An item was removed: tell a possibly-blocked sender that
                // there is free capacity again.
                self.send_waker.wake();
                Poll::Ready(Ok(value))
            }
            Err(concurrent_queue::PopError::Empty) => {
                // Register our waker first, then re-check the queue.  This
                // ordering closes the race where an item is pushed (and the
                // wake issued) between the failed pop above and registration.
                self.recv_waker.register(cx.waker());

                match self.queue.pop() {
                    Ok(value) => {
                        self.send_waker.wake();
                        // We got an item after all; drop our registration so
                        // we are not woken spuriously later.
                        self.recv_waker.take();
                        Poll::Ready(Ok(value))
                    }
                    Err(concurrent_queue::PopError::Empty) => Poll::Pending,
                    Err(concurrent_queue::PopError::Closed) => {
                        self.recv_waker.take();
                        // Wake any blocked sender so it observes the closure.
                        self.send_waker.wake();
                        Poll::Ready(Err(anyhow!("Queue closed")))
                    }
                }
            }
            Err(concurrent_queue::PopError::Closed) => {
                // Wake any blocked sender so it observes the closure.
                self.send_waker.wake();
                Poll::Ready(Err(anyhow!("Queue closed")))
            }
        }
    }
21652186
}

iroh/src/magicsock/relay_actor.rs

Lines changed: 148 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ use n0_future::{
5151
time::{self, Duration, Instant, MissedTickBehavior},
5252
FuturesUnorderedBounded, SinkExt, StreamExt,
5353
};
54-
use tokio::sync::{mpsc, oneshot};
54+
use tokio::sync::{
55+
mpsc::{self, OwnedPermit},
56+
oneshot,
57+
};
5558
use tokio_util::sync::CancellationToken;
5659
use tracing::{debug, error, event, info_span, instrument, trace, warn, Instrument, Level};
5760
use url::Url;
@@ -159,6 +162,20 @@ struct ActiveRelayActor {
159162
/// Token indicating the [`ActiveRelayActor`] should stop.
160163
stop_token: CancellationToken,
161164
metrics: Arc<MagicsockMetrics>,
165+
/// Received relay packets that could not yet be forwarded to the magicsocket.
166+
pending_received: Option<PendingRecv>,
167+
}
168+
169+
/// Datagrams from a received relay packet that could not all be forwarded yet.
#[derive(Debug)]
struct PendingRecv {
    // Remaining datagrams; the datagram that failed to send has been pushed
    // back to the front via `PacketSplitIter::push_front`.
    packet_iter: PacketSplitIter,
    // Which receive path was full when forwarding stopped, i.e. which path
    // must become ready before forwarding can resume.
    blocked_on: RecvPath,
}
174+
175+
/// The two paths received relay datagrams are forwarded into.
#[derive(Debug)]
enum RecvPath {
    // Regular datagrams, forwarded to the `RelayDatagramRecvQueue`.
    Data,
    // DISCO messages, forwarded to the `relay_disco_recv` channel.
    Disco,
}
163180

164181
#[derive(Debug)]
@@ -263,6 +280,7 @@ impl ActiveRelayActor {
263280
inactive_timeout: Box::pin(time::sleep(RELAY_INACTIVE_CLEANUP_TIME)),
264281
stop_token,
265282
metrics,
283+
pending_received: None,
266284
}
267285
}
268286

@@ -612,7 +630,8 @@ impl ActiveRelayActor {
612630
let fut = client_sink.send_all(&mut packet_stream);
613631
self.run_sending(fut, &mut state, &mut client_stream).await?;
614632
}
615-
msg = client_stream.next() => {
633+
_ = forward_pending(&mut self.pending_received, &self.relay_datagrams_recv, &mut self.relay_disco_recv), if self.pending_received.is_some() => {}
634+
msg = client_stream.next(), if self.pending_received.is_none() => {
616635
let Some(msg) = msg else {
617636
break Err(anyhow!("Stream closed by server."));
618637
};
@@ -659,33 +678,14 @@ impl ActiveRelayActor {
659678
state.last_packet_src = Some(remote_node_id);
660679
state.nodes_present.insert(remote_node_id);
661680
}
662-
for datagram in PacketSplitIter::new(self.url.clone(), remote_node_id, data) {
663-
let Ok(datagram) = datagram else {
664-
warn!("Invalid packet split");
665-
break;
666-
};
667-
match crate::disco::source_and_box_bytes(&datagram.buf) {
668-
Some((source, sealed_box)) => {
669-
if remote_node_id != source {
670-
// TODO: return here?
671-
warn!("Received relay disco message from connection for {}, but with message from {}", remote_node_id.fmt_short(), source.fmt_short());
672-
}
673-
let message = RelayDiscoMessage {
674-
source,
675-
sealed_box,
676-
relay_url: datagram.url.clone(),
677-
relay_remote_node_id: datagram.src,
678-
};
679-
if let Err(err) = self.relay_disco_recv.try_send(message) {
680-
warn!("Dropping received relay disco packet: {err:#}");
681-
}
682-
}
683-
None => {
684-
if let Err(err) = self.relay_datagrams_recv.try_send(datagram) {
685-
warn!("Dropping received relay data packet: {err:#}");
686-
}
687-
}
688-
}
681+
let packet_iter = PacketSplitIter::new(self.url.clone(), remote_node_id, data);
682+
if let Some(pending) = handle_received_packet_iter(
683+
packet_iter,
684+
None,
685+
&self.relay_datagrams_recv,
686+
&mut self.relay_disco_recv,
687+
) {
688+
self.pending_received = Some(pending);
689689
}
690690
}
691691
ReceivedMessage::NodeGone(node_id) => {
@@ -769,7 +769,8 @@ impl ActiveRelayActor {
769769
break Err(anyhow!("Ping timeout"));
770770
}
771771
// No need to read the inbox or datagrams to send.
772-
msg = client_stream.next() => {
772+
_ = forward_pending(&mut self.pending_received, &self.relay_datagrams_recv, &mut self.relay_disco_recv), if self.pending_received.is_some() => {}
773+
msg = client_stream.next(), if self.pending_received.is_none() => {
773774
let Some(msg) = msg else {
774775
break Err(anyhow!("Stream closed by server."));
775776
};
@@ -788,6 +789,105 @@ impl ActiveRelayActor {
788789
}
789790
}
790791

792+
/// Forward pending received packets to their queues.
793+
///
794+
/// If `maybe_pending` is not empty, this will wait for the path the last received item
795+
/// is blocked on (via [`PendingRecv::blocked_on`]) to become unblocked. It will then forward
796+
/// the pending items, until a queue is blocked again. In that case, the remaining items will
797+
/// be put back into `maybe_pending`. If all items could be sent, `maybe_pending` will be set
798+
/// to `None`.
799+
///
800+
/// This function is cancellation-safe: If the future is dropped at any point, all items are guaranteed
801+
/// to either be sent into their respective queues, or are still in `maybe_pending`.
802+
async fn forward_pending(
803+
maybe_pending: &mut Option<PendingRecv>,
804+
relay_datagrams_recv: &RelayDatagramRecvQueue,
805+
relay_disco_recv: &mut mpsc::Sender<RelayDiscoMessage>,
806+
) {
807+
// We take a mutable reference onto the inner value.
808+
// we're not `take`ing it here, because this would make the function not cancellation safe.
809+
let Some(ref mut pending) = maybe_pending else {
810+
return;
811+
};
812+
let disco_permit = match pending.blocked_on {
813+
RecvPath::Data => {
814+
std::future::poll_fn(|cx| relay_datagrams_recv.poll_send_ready(cx))
815+
.await
816+
.ok();
817+
None
818+
}
819+
RecvPath::Disco => {
820+
let Ok(permit) = relay_disco_recv.clone().reserve_owned().await else {
821+
return;
822+
};
823+
Some(permit)
824+
}
825+
};
826+
// We now take the inner value by value. it is cancellation safe here because
827+
// no further `await`s occur after here.
828+
// The unwrap is guaranteed to be safe because we checked above that it is not none.
829+
#[allow(clippy::unwrap_used, reason = "checked above")]
830+
let pending = maybe_pending.take().unwrap();
831+
if let Some(pending) = handle_received_packet_iter(
832+
pending.packet_iter,
833+
disco_permit,
834+
relay_datagrams_recv,
835+
relay_disco_recv,
836+
) {
837+
*maybe_pending = Some(pending);
838+
}
839+
}
840+
841+
fn handle_received_packet_iter(
842+
mut packet_iter: PacketSplitIter,
843+
mut disco_permit: Option<OwnedPermit<RelayDiscoMessage>>,
844+
relay_datagrams_recv: &RelayDatagramRecvQueue,
845+
relay_disco_recv: &mut mpsc::Sender<RelayDiscoMessage>,
846+
) -> Option<PendingRecv> {
847+
let remote_node_id = packet_iter.remote_node_id();
848+
for datagram in &mut packet_iter {
849+
let Ok(datagram) = datagram else {
850+
warn!("Invalid packet split");
851+
return None;
852+
};
853+
match crate::disco::source_and_box_bytes(&datagram.buf) {
854+
Some((source, sealed_box)) => {
855+
if remote_node_id != source {
856+
// TODO: return here?
857+
warn!("Received relay disco message from connection for {}, but with message from {}", remote_node_id.fmt_short(), source.fmt_short());
858+
}
859+
let message = RelayDiscoMessage {
860+
source,
861+
sealed_box,
862+
relay_url: datagram.url.clone(),
863+
relay_remote_node_id: datagram.src,
864+
};
865+
if let Some(permit) = disco_permit.take() {
866+
permit.send(message);
867+
} else if let Err(err) = relay_disco_recv.try_send(message) {
868+
warn!("Dropping received relay disco packet: {err:#}");
869+
packet_iter.push_front(datagram);
870+
return Some(PendingRecv {
871+
packet_iter,
872+
blocked_on: RecvPath::Disco,
873+
});
874+
}
875+
}
876+
None => {
877+
if let Err(err) = relay_datagrams_recv.try_send(datagram) {
878+
warn!("Dropping received relay data packet: {err:#}");
879+
packet_iter.push_front(err.into_inner());
880+
return Some(PendingRecv {
881+
packet_iter,
882+
blocked_on: RecvPath::Data,
883+
});
884+
}
885+
}
886+
}
887+
}
888+
None
889+
}
890+
791891
/// Shared state when the [`ActiveRelayActor`] is connected to a relay server.
792892
///
793893
/// Common state between [`ActiveRelayActor::run_connected`] and
@@ -1270,12 +1370,22 @@ struct PacketSplitIter {
12701370
url: RelayUrl,
12711371
src: NodeId,
12721372
bytes: Bytes,
1373+
next: Option<RelayRecvDatagram>,
12731374
}
12741375

12751376
impl PacketSplitIter {
12761377
/// Create a new PacketSplitIter from a packet.
12771378
fn new(url: RelayUrl, src: NodeId, bytes: Bytes) -> Self {
1278-
Self { url, src, bytes }
1379+
Self {
1380+
url,
1381+
src,
1382+
bytes,
1383+
next: None,
1384+
}
1385+
}
1386+
1387+
    /// Returns the [`NodeId`] of the remote node this packet was received from.
    fn remote_node_id(&self) -> NodeId {
        self.src
    }
12801390

12811391
fn fail(&mut self) -> Option<std::io::Result<RelayRecvDatagram>> {
@@ -1285,13 +1395,20 @@ impl PacketSplitIter {
12851395
"",
12861396
)))
12871397
}
1398+
1399+
    /// Pushes a datagram back to the front of the iterator.
    ///
    /// The next call to `next()` yields `item` before continuing with the
    /// remaining bytes.  Only one slot exists: pushing a second item before
    /// the first is consumed overwrites it.
    fn push_front(&mut self, item: RelayRecvDatagram) {
        self.next = Some(item);
    }
12881402
}
12891403

12901404
impl Iterator for PacketSplitIter {
12911405
type Item = std::io::Result<RelayRecvDatagram>;
12921406

12931407
fn next(&mut self) -> Option<Self::Item> {
12941408
use bytes::Buf;
1409+
if let Some(item) = self.next.take() {
1410+
return Some(Ok(item));
1411+
}
12951412
if self.bytes.has_remaining() {
12961413
if self.bytes.remaining() < 2 {
12971414
return self.fail();

0 commit comments

Comments
 (0)