From edbb2933a17d94158c18ed2c6a7f505f2b2ad72c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ivan=20Schre=CC=81ter?= Date: Mon, 14 Apr 2025 17:41:45 +0200 Subject: [PATCH 1/2] Refactor: Reading log is an immutable operation Change the receiver to `&self` for methods reading the log and adjust callers. --- openraft/src/storage/helper.rs | 4 ++-- openraft/src/storage/v2/raft_log_reader.rs | 4 ++-- stores/memstore/src/lib.rs | 2 +- tests/tests/life_cycle/t10_initialization.rs | 2 +- tests/tests/membership/t11_add_learner.rs | 2 +- tests/tests/membership/t20_change_membership.rs | 2 +- tests/tests/snapshot_building/t10_build_snapshot.rs | 2 +- .../tests/snapshot_streaming/t30_purge_in_snapshot_logs.rs | 6 +++--- .../t34_replication_does_not_block_purge.rs | 2 +- 9 files changed, 13 insertions(+), 13 deletions(-) diff --git a/openraft/src/storage/helper.rs b/openraft/src/storage/helper.rs index 95eb4b00a..31fd1aa78 100644 --- a/openraft/src/storage/helper.rs +++ b/openraft/src/storage/helper.rs @@ -186,7 +186,7 @@ where chunk_size, ); - let mut log_reader = self.log_store.get_log_reader().await; + let log_reader = self.log_store.get_log_reader().await; while start < end { let chunk_end = std::cmp::min(end, start + chunk_size); @@ -290,7 +290,7 @@ where let step = 64; let mut res = vec![]; - let mut log_reader = self.log_store.get_log_reader().await; + let log_reader = self.log_store.get_log_reader().await; while start < end { let step_start = std::cmp::max(start, end.saturating_sub(step)); diff --git a/openraft/src/storage/v2/raft_log_reader.rs b/openraft/src/storage/v2/raft_log_reader.rs index 91f501f03..000ff4340 100644 --- a/openraft/src/storage/v2/raft_log_reader.rs +++ b/openraft/src/storage/v2/raft_log_reader.rs @@ -48,7 +48,7 @@ where C: RaftTypeConfig /// - The read operation must be transactional. That is, it should not reflect any state changes /// that occur after the read operation has commenced. 
async fn try_get_log_entries + Clone + Debug + OptionalSend>( - &mut self, + &self, range: RB, ) -> Result, StorageError>; @@ -70,7 +70,7 @@ where C: RaftTypeConfig /// /// The default implementation just returns the full range of log entries. #[since(version = "0.10.0")] - async fn limited_get_log_entries(&mut self, start: u64, end: u64) -> Result, StorageError> { + async fn limited_get_log_entries(&self, start: u64, end: u64) -> Result, StorageError> { self.try_get_log_entries(start..end).await } diff --git a/stores/memstore/src/lib.rs b/stores/memstore/src/lib.rs index 18cc353a0..e4758b260 100644 --- a/stores/memstore/src/lib.rs +++ b/stores/memstore/src/lib.rs @@ -239,7 +239,7 @@ pub fn new_mem_store() -> (Arc, Arc) { impl RaftLogReader for Arc { async fn try_get_log_entries + Clone + Debug + OptionalSend>( - &mut self, + &self, range: RB, ) -> Result>, StorageError> { let mut entries = vec![]; diff --git a/tests/tests/life_cycle/t10_initialization.rs b/tests/tests/life_cycle/t10_initialization.rs index 824644eb5..350b1ba0f 100644 --- a/tests/tests/life_cycle/t10_initialization.rs +++ b/tests/tests/life_cycle/t10_initialization.rs @@ -123,7 +123,7 @@ async fn initialization() -> anyhow::Result<()> { } for i in [0, 1, 2] { - let (mut sto, mut sm) = router.get_storage_handle(&1)?; + let (sto, mut sm) = router.get_storage_handle(&1)?; let first = sto.try_get_log_entries(0..2).await?.into_iter().next(); tracing::info!( diff --git a/tests/tests/membership/t11_add_learner.rs b/tests/tests/membership/t11_add_learner.rs index 02d877e6b..cfa97f9f2 100644 --- a/tests/tests/membership/t11_add_learner.rs +++ b/tests/tests/membership/t11_add_learner.rs @@ -68,7 +68,7 @@ async fn add_learner_basic() -> Result<()> { tracing::info!(log_index, "--- add_learner blocks until the replication catches up"); { - let (mut sto1, _sm1) = router.get_storage_handle(&1)?; + let (sto1, _sm1) = router.get_storage_handle(&1)?; let logs = sto1.try_get_log_entries(..).await?; diff --git 
a/tests/tests/membership/t20_change_membership.rs b/tests/tests/membership/t20_change_membership.rs index 36792f0fd..130774650 100644 --- a/tests/tests/membership/t20_change_membership.rs +++ b/tests/tests/membership/t20_change_membership.rs @@ -91,7 +91,7 @@ async fn change_with_new_learner_blocking() -> anyhow::Result<()> { tracing::info!(log_index, "--- change_membership blocks until success: {:?}", res); for node_id in 0..2 { - let (mut sto, _sm) = router.get_storage_handle(&node_id)?; + let (sto, _sm) = router.get_storage_handle(&node_id)?; let logs = sto.try_get_log_entries(..).await?; assert_eq!(log_index, logs[logs.len() - 1].log_id.index(), "node: {}", node_id); // 0-th log diff --git a/tests/tests/snapshot_building/t10_build_snapshot.rs b/tests/tests/snapshot_building/t10_build_snapshot.rs index 2829b2b73..caab32ee2 100644 --- a/tests/tests/snapshot_building/t10_build_snapshot.rs +++ b/tests/tests/snapshot_building/t10_build_snapshot.rs @@ -98,7 +98,7 @@ async fn build_snapshot() -> Result<()> { "--- logs should be deleted after installing snapshot; left only the last one" ); { - let (mut sto, _sm) = router.get_storage_handle(&1)?; + let (sto, _sm) = router.get_storage_handle(&1)?; let logs = sto.try_get_log_entries(..).await?; assert_eq!(2, logs.len()); assert_eq!(log_id(1, 0, log_index - 1), logs[0].log_id) diff --git a/tests/tests/snapshot_streaming/t30_purge_in_snapshot_logs.rs b/tests/tests/snapshot_streaming/t30_purge_in_snapshot_logs.rs index 208ce8ad5..e5e6374b0 100644 --- a/tests/tests/snapshot_streaming/t30_purge_in_snapshot_logs.rs +++ b/tests/tests/snapshot_streaming/t30_purge_in_snapshot_logs.rs @@ -35,14 +35,14 @@ async fn purge_in_snapshot_logs() -> Result<()> { let leader = router.get_raft_handle(&0)?; let learner = router.get_raft_handle(&1)?; - let (mut sto0, mut _sm0) = router.get_storage_handle(&0)?; + let (sto0, mut _sm0) = router.get_storage_handle(&0)?; tracing::info!(log_index, "--- build snapshot on leader, check purged log"); { 
log_index += router.client_request_many(0, "0", 10).await?; leader.trigger().snapshot().await?; leader.wait(timeout()).snapshot(log_id(1, 0, log_index), "building 1st snapshot").await?; - let (mut sto0, mut _sm0) = router.get_storage_handle(&0)?; + let (sto0, mut _sm0) = router.get_storage_handle(&0)?; // Wait for purge to complete. sleep(Duration::from_millis(500)).await; @@ -77,7 +77,7 @@ async fn purge_in_snapshot_logs() -> Result<()> { learner.wait(timeout()).snapshot(log_id(1, 0, log_index), "learner install snapshot").await?; - let (mut sto1, mut _sm) = router.get_storage_handle(&1)?; + let (sto1, mut _sm) = router.get_storage_handle(&1)?; let logs = sto1.try_get_log_entries(..).await?; assert_eq!(0, logs.len()); } diff --git a/tests/tests/snapshot_streaming/t34_replication_does_not_block_purge.rs b/tests/tests/snapshot_streaming/t34_replication_does_not_block_purge.rs index 28c09f460..4ecc5628e 100644 --- a/tests/tests/snapshot_streaming/t34_replication_does_not_block_purge.rs +++ b/tests/tests/snapshot_streaming/t34_replication_does_not_block_purge.rs @@ -51,7 +51,7 @@ async fn replication_does_not_block_purge() -> Result<()> { sleep(Duration::from_millis(500)).await; - let (mut sto0, mut _sm0) = router.get_storage_handle(&0)?; + let (sto0, mut _sm0) = router.get_storage_handle(&0)?; let logs = sto0.try_get_log_entries(..).await?; assert_eq!(max_keep as usize, logs.len(), "leader's local logs are purged"); } From 99976ee7286992751ff14c3ca0a75090abacab2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ivan=20Schre=CC=81ter?= Date: Mon, 14 Apr 2025 17:42:26 +0200 Subject: [PATCH 2/2] Refactor: Move the SM worker loop to an SM method, so SM::apply() can be implemented more efficiently with overlapped I/O and similar. This is only a quick-and-dirty change, with no documentation on the newly added structs, but it should be self-evident.
--- openraft/src/core/raft_core.rs | 2 +- openraft/src/core/sm/handle.rs | 21 +- openraft/src/core/sm/worker.rs | 84 +---- openraft/src/storage/mod.rs | 4 + openraft/src/storage/v2/mod.rs | 4 + openraft/src/storage/v2/raft_state_machine.rs | 341 ++++++++++++++++++ 6 files changed, 376 insertions(+), 80 deletions(-) diff --git a/openraft/src/core/raft_core.rs b/openraft/src/core/raft_core.rs index 67f90894a..9626804f7 100644 --- a/openraft/src/core/raft_core.rs +++ b/openraft/src/core/raft_core.rs @@ -776,7 +776,7 @@ where network, snapshot_network, self.log_store.get_log_reader().await, - self.sm_handle.new_snapshot_reader(), + self.sm_handle.new_snapshot_reader(self.tx_notification.clone()), self.tx_notification.clone(), tracing::span!(parent: &self.span, Level::DEBUG, "replication", id=display(&self.id), target=display(&target)), ) diff --git a/openraft/src/core/sm/handle.rs b/openraft/src/core/sm/handle.rs index 76b51dc85..7840b044a 100644 --- a/openraft/src/core/sm/handle.rs +++ b/openraft/src/core/sm/handle.rs @@ -3,7 +3,9 @@ use crate::async_runtime::MpscUnboundedSender; use crate::async_runtime::MpscUnboundedWeakSender; use crate::async_runtime::SendError; +use crate::core::notification::Notification; use crate::core::sm; +use crate::storage::RaftStateMachineCommand; use crate::storage::Snapshot; use crate::type_config::alias::JoinHandleOf; use crate::type_config::alias::MpscUnboundedSenderOf; @@ -15,7 +17,9 @@ use crate::RaftTypeConfig; pub(crate) struct Handle where C: RaftTypeConfig { - pub(in crate::core::sm) cmd_tx: MpscUnboundedSenderOf>, + pub(in crate::core::sm) cmd_tx: MpscUnboundedSenderOf>, + + pub(in crate::core::sm) tx_notification: MpscUnboundedSenderOf>, #[allow(dead_code)] pub(in crate::core::sm) join_handle: JoinHandleOf, @@ -24,15 +28,20 @@ where C: RaftTypeConfig impl Handle where C: RaftTypeConfig { - pub(crate) fn send(&mut self, cmd: sm::Command) -> Result<(), SendError>> { + pub(crate) fn send(&mut self, cmd: sm::Command) -> Result<(), 
SendError>> { tracing::debug!("sending command to state machine worker: {:?}", cmd); + let cmd = RaftStateMachineCommand::new(cmd, &self.tx_notification); self.cmd_tx.send(cmd) } /// Create a [`SnapshotReader`] to get the current snapshot from the state machine. - pub(crate) fn new_snapshot_reader(&self) -> SnapshotReader { + pub(crate) fn new_snapshot_reader( + &self, + tx_notification: MpscUnboundedSenderOf>, + ) -> SnapshotReader { SnapshotReader { cmd_tx: self.cmd_tx.downgrade(), + tx_notification, } } } @@ -46,7 +55,9 @@ where C: RaftTypeConfig /// It is weak because the [`Worker`] watches the close event of this channel for shutdown. /// /// [`Worker`]: sm::worker::Worker - cmd_tx: MpscUnboundedWeakSenderOf>, + cmd_tx: MpscUnboundedWeakSenderOf>, + + tx_notification: MpscUnboundedSenderOf>, } impl SnapshotReader @@ -68,7 +79,7 @@ where C: RaftTypeConfig }; // If fail to send command, cmd is dropped and tx will be dropped. - let _ = cmd_tx.send(cmd); + let _ = cmd_tx.send(RaftStateMachineCommand::new(cmd, &self.tx_notification)); let got = match rx.await { Ok(x) => x, diff --git a/openraft/src/core/sm/worker.rs b/openraft/src/core/sm/worker.rs index d55beef0b..89f241a2f 100644 --- a/openraft/src/core/sm/worker.rs +++ b/openraft/src/core/sm/worker.rs @@ -3,14 +3,11 @@ use std::collections::BTreeMap; use anyerror::AnyError; use tracing_futures::Instrument; -use crate::async_runtime::MpscUnboundedReceiver; use crate::async_runtime::MpscUnboundedSender; use crate::async_runtime::OneshotSender; -use crate::base::BoxAsyncOnceMut; use crate::core::notification::Notification; use crate::core::raft_msg::ResultSender; use crate::core::sm::handle::Handle; -use crate::core::sm::Command; use crate::core::sm::CommandResult; use crate::core::sm::Response; use crate::core::ApplyResult; @@ -23,6 +20,7 @@ use crate::raft::ClientWriteResponse; #[cfg(doc)] use crate::storage::RaftLogStorage; use crate::storage::RaftStateMachine; +use crate::storage::RaftStateMachineCommand; use 
crate::storage::Snapshot; use crate::type_config::alias::JoinHandleOf; use crate::type_config::alias::LogIdOf; @@ -48,7 +46,7 @@ where log_reader: LR, /// Read command from RaftCore to execute. - cmd_rx: MpscUnboundedReceiverOf>, + cmd_rx: MpscUnboundedReceiverOf>, /// Send back the result of the command to RaftCore. resp_tx: MpscUnboundedSenderOf>, @@ -68,6 +66,7 @@ where span: tracing::Span, ) -> Handle { let (cmd_tx, cmd_rx) = C::mpsc_unbounded(); + let tx_notification = resp_tx.clone(); let worker = Worker { state_machine, @@ -78,7 +77,11 @@ where let join_handle = worker.do_spawn(span); - Handle { cmd_tx, join_handle } + Handle { + cmd_tx, + tx_notification, + join_handle, + } } fn do_spawn(mut self, span: tracing::Span) -> JoinHandleOf { @@ -98,76 +101,9 @@ where #[tracing::instrument(level = "debug", skip_all)] async fn worker_loop(&mut self) -> Result<(), StorageError> { - loop { - let cmd = self.cmd_rx.recv().await; - let cmd = match cmd { - None => { - tracing::info!("{}: rx closed, state machine worker quit", func_name!()); - return Ok(()); - } - Some(x) => x, - }; - - tracing::debug!("{}: received command: {:?}", func_name!(), cmd); - - match cmd { - Command::BuildSnapshot => { - tracing::info!("{}: build snapshot", func_name!()); - - // It is a read operation and is spawned, and it responds in another task - self.build_snapshot(self.resp_tx.clone()).await; - } - Command::GetSnapshot { tx } => { - tracing::info!("{}: get snapshot", func_name!()); - - self.get_snapshot(tx).await?; - // GetSnapshot does not respond to RaftCore - } - Command::InstallFullSnapshot { io_id, snapshot } => { - tracing::info!("{}: install complete snapshot", func_name!()); - - let meta = snapshot.meta.clone(); - self.state_machine.install_snapshot(&meta, snapshot.snapshot).await?; - - tracing::info!("Done install complete snapshot, meta: {}", meta); - - let res = CommandResult::new(Ok(Response::InstallSnapshot((io_id, Some(meta))))); - let _ = 
self.resp_tx.send(Notification::sm(res)); - } - Command::BeginReceivingSnapshot { tx } => { - tracing::info!("{}: BeginReceivingSnapshot", func_name!()); - - let snapshot_data = self.state_machine.begin_receiving_snapshot().await?; - - let _ = tx.send(Ok(snapshot_data)); - // No response to RaftCore - } - Command::Apply { - first, - last, - mut client_resp_channels, - } => { - let resp = self.apply(first, last, &mut client_resp_channels).await?; - let res = CommandResult::new(Ok(Response::Apply(resp))); - let _ = self.resp_tx.send(Notification::sm(res)); - } - Command::Func { func, input_sm_type } => { - tracing::debug!("{}: run user defined Func", func_name!()); - - let res: Result>, _> = func.downcast(); - if let Ok(f) = res { - f(&mut self.state_machine).await; - } else { - tracing::warn!( - "User-defined SM function uses incorrect state machine type, expected: {}, got: {}", - std::any::type_name::(), - input_sm_type - ); - }; - } - }; - } + self.state_machine.worker(&mut self.cmd_rx, &self.log_reader).await } + #[tracing::instrument(level = "debug", skip_all)] async fn apply( &mut self, diff --git a/openraft/src/storage/mod.rs b/openraft/src/storage/mod.rs index 731c03e07..015189bc0 100644 --- a/openraft/src/storage/mod.rs +++ b/openraft/src/storage/mod.rs @@ -19,8 +19,12 @@ pub use self::log_state::LogState; pub use self::snapshot::Snapshot; pub use self::snapshot_meta::SnapshotMeta; pub use self::snapshot_signature::SnapshotSignature; +pub use self::v2::ApplyResultSender; +pub use self::v2::BuildSnapshotResultSender; +pub use self::v2::InstallSnapshotResultSender; pub use self::v2::RaftLogReader; pub use self::v2::RaftLogStorage; pub use self::v2::RaftLogStorageExt; pub use self::v2::RaftSnapshotBuilder; pub use self::v2::RaftStateMachine; +pub use self::v2::RaftStateMachineCommand; diff --git a/openraft/src/storage/v2/mod.rs b/openraft/src/storage/v2/mod.rs index 3564cf260..d977bd64a 100644 --- a/openraft/src/storage/v2/mod.rs +++ 
b/openraft/src/storage/v2/mod.rs @@ -13,4 +13,8 @@ pub use self::raft_log_reader::RaftLogReader; pub use self::raft_log_storage::RaftLogStorage; pub use self::raft_log_storage_ext::RaftLogStorageExt; pub use self::raft_snapshot_builder::RaftSnapshotBuilder; +pub use self::raft_state_machine::ApplyResultSender; +pub use self::raft_state_machine::BuildSnapshotResultSender; +pub use self::raft_state_machine::InstallSnapshotResultSender; pub use self::raft_state_machine::RaftStateMachine; +pub use self::raft_state_machine::RaftStateMachineCommand; diff --git a/openraft/src/storage/v2/raft_state_machine.rs b/openraft/src/storage/v2/raft_state_machine.rs index c2886976c..c1e00ad1d 100644 --- a/openraft/src/storage/v2/raft_state_machine.rs +++ b/openraft/src/storage/v2/raft_state_machine.rs @@ -1,9 +1,39 @@ +use std::collections::BTreeMap; +use std::fmt::Debug; + +use anyerror::AnyError; use openraft_macros::add_async_trait; use openraft_macros::since; +use super::RaftLogReader; +use crate::alias::AsyncRuntimeOf; +use crate::alias::MpscUnboundedReceiverOf; +use crate::alias::MpscUnboundedSenderOf; +use crate::alias::ResponderOf; +use crate::alias::SnapshotDataOf; +use crate::async_runtime::MpscUnboundedReceiver; +use crate::async_runtime::MpscUnboundedSender; +use crate::base::BoxAny; +use crate::base::BoxAsyncOnceMut; +use crate::core::notification::Notification; +use crate::core::raft_msg::ResultSender; +use crate::core::sm::Command; +use crate::core::sm::CommandResult; +use crate::core::sm::Response; +use crate::core::ApplyResult; +use crate::display_ext::DisplayOptionExt; +use crate::display_ext::DisplaySliceExt; +use crate::entry::RaftEntry; +use crate::entry::RaftPayload; +use crate::error::Infallible; +use crate::raft::responder::Responder; +use crate::raft::ClientWriteResponse; +use crate::raft_state::IOId; use crate::storage::Snapshot; use crate::storage::SnapshotMeta; use crate::type_config::alias::LogIdOf; +use crate::type_config::OneshotSender; +use 
crate::AsyncRuntime; use crate::OptionalSend; use crate::OptionalSync; use crate::RaftSnapshotBuilder; @@ -11,6 +41,167 @@ use crate::RaftTypeConfig; use crate::StorageError; use crate::StoredMembership; +pub struct BuildSnapshotResultSender { + resp_tx: MpscUnboundedSenderOf>, +} + +impl BuildSnapshotResultSender { + pub(crate) fn new(resp_tx: MpscUnboundedSenderOf>) -> Self { + Self { resp_tx } + } + + pub fn send(self, result: Result, StorageError>) { + let result = result.map(|snap| Response::BuildSnapshot(snap.meta)); + let cmd_res = CommandResult::new(result); + let _ = self.resp_tx.send(Notification::sm(cmd_res)); + } +} + +pub struct InstallSnapshotResultSender { + io_id: IOId, + resp_tx: MpscUnboundedSenderOf>, +} + +impl InstallSnapshotResultSender { + pub(crate) fn new(io_id: IOId, resp_tx: MpscUnboundedSenderOf>) -> Self { + Self { io_id, resp_tx } + } + + pub fn send(self, meta: SnapshotMeta) { + let res = CommandResult::new(Ok(Response::InstallSnapshot((self.io_id, Some(meta))))); + let _ = self.resp_tx.send(Notification::sm(res)); + } +} + +pub struct ApplyResultSender { + since: u64, + end: u64, + resp_tx: MpscUnboundedSenderOf>, +} + +impl ApplyResultSender { + pub(crate) fn new(since: u64, end: u64, resp_tx: MpscUnboundedSenderOf>) -> Self { + Self { since, end, resp_tx } + } + + pub fn send(self, result: Result, StorageError>) { + let res = CommandResult::new(result.map(|last_applied| { + Response::Apply(ApplyResult { + since: self.since, + end: self.end, + last_applied, + }) + })); + let _ = self.resp_tx.send(Notification::sm(res)); + } +} + +/// The payload of a state machine command. +pub enum RaftStateMachineCommand +where C: RaftTypeConfig +{ + /// Instruct the state machine to create a snapshot based on its most recent view. + BuildSnapshot { tx: BuildSnapshotResultSender }, + + /// Get the latest built snapshot. 
+ GetSnapshot { tx: ResultSender>> }, + + BeginReceivingSnapshot { + tx: ResultSender, Infallible>, + }, + + InstallFullSnapshot { + /// The IO id used to update IO progress. + /// + /// Installing a snapshot is considered as an IO of AppendEntries `[0, + /// snapshot.last_log_id]` + snapshot: Snapshot, + tx: InstallSnapshotResultSender, + }, + + /// Apply the log entries to the state machine. + Apply { + /// The first log id to apply, inclusive. + first: LogIdOf, + + /// The last log id to apply, inclusive. + last: LogIdOf, + + client_resp_channels: BTreeMap>, + + tx: ApplyResultSender, + }, + + /// Apply a custom function to the state machine. + /// + /// To erase the type parameter `SM`, it is a + /// `Box Box> + Send + 'static>` + /// wrapped in a `Box` + Func { + func: BoxAny, + /// The SM type user specified, for debug purpose. + input_sm_type: &'static str, + }, +} + +impl Debug for RaftStateMachineCommand +where C: RaftTypeConfig +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::BuildSnapshot { .. } => write!(f, "BuildSnapshot"), + Self::GetSnapshot { .. } => write!(f, "GetSnapshot"), + Self::BeginReceivingSnapshot { .. } => write!(f, "BeginReceivingSnapshot"), + Self::InstallFullSnapshot { snapshot, tx, .. } => { + write!( + f, + "InstallFullSnapshot: meta: {:?}, io_id: {:?}", + snapshot.meta, tx.io_id + ) + } + Self::Apply { first, last, .. } => { + write!(f, "Apply: [{first},{last}]") + } + Self::Func { input_sm_type, .. 
} => { + write!(f, "Func({})", input_sm_type) + } + } + } +} + +impl RaftStateMachineCommand +where C: RaftTypeConfig +{ + pub(crate) fn new(cmd: Command, resp_tx: &MpscUnboundedSenderOf>) -> Self { + match cmd { + Command::BuildSnapshot => Self::BuildSnapshot { + tx: BuildSnapshotResultSender::new(resp_tx.clone()), + }, + Command::GetSnapshot { tx } => Self::GetSnapshot { tx }, + Command::BeginReceivingSnapshot { tx } => Self::BeginReceivingSnapshot { tx }, + Command::InstallFullSnapshot { io_id, snapshot } => Self::InstallFullSnapshot { + snapshot, + tx: InstallSnapshotResultSender::new(io_id, resp_tx.clone()), + }, + Command::Apply { + first, + last, + client_resp_channels, + } => { + let since = first.index; + let end = last.index + 1; + Self::Apply { + first, + last, + client_resp_channels, + tx: ApplyResultSender::new(since, end, resp_tx.clone()), + } + } + Command::Func { func, input_sm_type } => Self::Func { func, input_sm_type }, + } + } +} + /// API for state machine and snapshot. /// /// Snapshot is part of the state machine, because usually a snapshot is the persisted state of the @@ -24,6 +215,85 @@ where C: RaftTypeConfig /// Snapshot builder type. type SnapshotBuilder: RaftSnapshotBuilder; + /// Run state machine worker on this state machine. + async fn worker>( + &mut self, + cmd_rx: &mut MpscUnboundedReceiverOf>, + log_reader: &LR, + ) -> Result<(), StorageError> { + loop { + let cmd = cmd_rx.recv().await; + let cmd = match cmd { + None => { + tracing::info!("{}: rx closed, state machine worker quit", func_name!()); + return Ok(()); + } + Some(x) => x, + }; + + tracing::debug!("{}: received command: {:?}", func_name!(), cmd); + + match cmd { + RaftStateMachineCommand::BuildSnapshot { tx } => { + tracing::info!("{}: build snapshot", func_name!()); + // TODO: does this need to be abortable? 
+ // use futures::future::abortable; + // let (fu, abort_handle) = abortable(async move { builder.build_snapshot().await }); + let mut builder = self.get_snapshot_builder().await; + // run the snapshot in a concurrent task + let _handle = AsyncRuntimeOf::::spawn(async move { + let res = builder.build_snapshot().await; + tx.send(res); + }); + tracing::info!("{} returning; spawned snapshot building task", func_name!()); + } + RaftStateMachineCommand::GetSnapshot { tx } => { + tracing::info!("{}: get snapshot", func_name!()); + let snapshot = self.get_current_snapshot().await?; + tracing::info!( + "sending back snapshot: meta: {}", + snapshot.as_ref().map(|s| &s.meta).display() + ); + let _ = tx.send(Ok(snapshot)); + } + RaftStateMachineCommand::InstallFullSnapshot { snapshot, tx } => { + tracing::info!("{}: install complete snapshot", func_name!()); + let meta = snapshot.meta.clone(); + self.install_snapshot(&meta, snapshot.snapshot).await?; + tracing::info!("Done install complete snapshot, meta: {}", meta); + tx.send(meta); + } + RaftStateMachineCommand::BeginReceivingSnapshot { tx } => { + tracing::info!("{}: BeginReceivingSnapshot", func_name!()); + let snapshot_data = self.begin_receiving_snapshot().await?; + let _ = tx.send(Ok(snapshot_data)); + } + RaftStateMachineCommand::Apply { + first, + last, + mut client_resp_channels, + tx, + } => { + self.apply_from_log(first, last, log_reader, &mut client_resp_channels, tx).await?; + } + RaftStateMachineCommand::Func { func, input_sm_type } => { + tracing::debug!("{}: run user defined Func", func_name!()); + + let res: Result>, _> = func.downcast(); + if let Ok(f) = res { + f(self).await; + } else { + tracing::warn!( + "User-defined SM function uses incorrect state machine type, expected: {}, got: {}", + std::any::type_name::(), + input_sm_type + ); + }; + } + }; + } + } + // TODO: This can be made into sync, provided all state machines will use atomic read or the // like. 
// --- @@ -69,6 +339,77 @@ where C: RaftTypeConfig I: IntoIterator + OptionalSend, I::IntoIter: OptionalSend; + #[tracing::instrument(level = "debug", skip_all)] + async fn apply_from_log>( + &mut self, + first: LogIdOf, + last: LogIdOf, + log_reader: &LR, + client_resp_channels: &mut BTreeMap>, + final_resp_tx: ApplyResultSender, + ) -> Result<(), StorageError> { + // TODO: prepare response before apply, + // so that an Entry does not need to be Clone, + // and no references will be used by apply + + let since = first.index(); + let end = last.index() + 1; + + let entries = log_reader.try_get_log_entries(since..end).await?; + if entries.len() != (end - since) as usize { + return Err(StorageError::read_logs(AnyError::error(format!( + "returned log entries count({}) does not match the input([{}, {}]))", + entries.len(), + since, + end + )))); + } + tracing::debug!(entries = display(entries.display()), "about to apply"); + + let last_applied = last; + + // Fake complain: avoid using `collect()` when not needed + #[allow(clippy::needless_collect)] + let applying_entries = entries.iter().map(|e| (e.log_id(), e.get_membership())).collect::>(); + + let n_entries = end - since; + + let apply_results = self.apply(entries).await?; + + let n_replies = apply_results.len() as u64; + + debug_assert_eq!( + n_entries, n_replies, + "n_entries: {} should equal n_replies: {}", + n_entries, n_replies + ); + + let mut results = apply_results.into_iter(); + let mut applying_entries = applying_entries.into_iter(); + for log_index in since..end { + let (log_id, membership) = applying_entries.next().unwrap(); + let resp = results.next().unwrap(); + let tx = client_resp_channels.remove(&log_index); + tracing::debug!( + log_id = debug(&log_id), + membership = debug(&membership), + "send_response" + ); + + if let Some(tx) = tx { + let res = Ok(ClientWriteResponse { + log_id, + data: resp, + membership, + }); + tx.send(res); + } + } + + final_resp_tx.send(Ok(last_applied)); + Ok(()) + } + /// 
Get the snapshot builder for the state machine. /// /// Usually it returns a snapshot view of the state machine(i.e., subsequent changes to the