From a973f0001f7e2834ce66516367f4d578ae5d4b8f Mon Sep 17 00:00:00 2001 From: Colin Marc Date: Sat, 22 Feb 2025 15:31:01 +0100 Subject: [PATCH] feat: add a higher-level async Client --- Cargo.toml | 18 +- README.md | 8 +- examples/playback.rs | 8 +- examples/playback_async.rs | 131 ++++++ examples/record.rs | 5 +- examples/record_async.rs | 147 +++++++ src/client.rs | 783 ++++++++++++++++++++++++++++++++++ src/client/playback_source.rs | 109 +++++ src/client/playback_stream.rs | 184 ++++++++ src/client/reactor.rs | 575 +++++++++++++++++++++++++ src/client/record_sink.rs | 280 ++++++++++++ src/client/record_stream.rs | 161 +++++++ src/lib.rs | 2 + src/protocol.rs | 11 +- 14 files changed, 2406 insertions(+), 16 deletions(-) create mode 100644 examples/playback_async.rs create mode 100644 examples/record_async.rs create mode 100644 src/client.rs create mode 100644 src/client/playback_source.rs create mode 100644 src/client/playback_stream.rs create mode 100644 src/client/reactor.rs create mode 100644 src/client/record_sink.rs create mode 100644 src/client/record_stream.rs diff --git a/Cargo.toml b/Cargo.toml index 591f2fe..ea1fbde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,20 +14,26 @@ license = "MIT" members = ["patrace"] [dependencies] -bitflags = "2.4.1" -byteorder = "1.5.0" -enum-primitive-derive = "0.3.0" -num-traits = "0.2.17" -thiserror = "1.0.51" +bitflags = "2" +byteorder = "1" +enum-primitive-derive = "0.3" +futures = "0.3" +log = "0.4" +mio = { version = "1", features = ["os-ext", "os-poll", "net"] } +num-traits = "0.2" +thiserror = "1" [dev-dependencies] anyhow = "1.0.76" assert_matches = "1.5.0" hound = "3.5.1" indicatif = "0.17.7" -mio = { version = "1", features = ["os-ext", "os-poll", "net"] } mio-timerfd = "0.2.0" pretty_assertions = "1.4.0" +rand = "0.9" +test-log = "0.2" +tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7", features = ["io", "compat"] } [features] _integration-tests = [] diff --git a/README.md b/README.md index 954145f..b62432e 100644 --- a/README.md +++ b/README.md @@ -7,16 +7,16 @@ This is a native rust implementation of the [PulseAudio](https://www.freedesktop Currently implemented: - Low-level serialization and deserialization of the wire format (called "tagstructs") + - A higher level `async`-friendly API Not yet implemented (but contributions welcome!) - - A higher level `async`-friendly API - `memfd`/`shm` shenanigans for zero-copy streaming Examples: - [Listing sinks](examples/list-sinks.rs) - [Subscribing to server events](examples/subscribe.rs) - - [Playing an audio file](examples/playback.rs) - - [Recording audio](examples/record.rs) - - [Acting as a sound server](examples/server.rs) \ No newline at end of file + - [Playing an audio file](examples/playback.rs) and the [async version](examples/playback_async.rs) + - [Recording audio](examples/record.rs) and the [async version](examples/record_async.rs) + - [Acting as a sound server](examples/server.rs) diff --git a/examples/playback.rs b/examples/playback.rs index 3d07774..65ef629 100644 --- a/examples/playback.rs +++ b/examples/playback.rs @@ -1,5 +1,7 @@ -// To run this example, run the following command: -// cargo run --example playback -- testfiles/victory.wav +//! A simple example that plays a WAV file to the server. +//! +//! Run with: +//! cargo run --example playback -- testfiles/victory.wav use std::{ ffi::CString, @@ -65,7 +67,7 @@ fn main() -> anyhow::Result<()> { }, channel_map, cvolume: Some(protocol::ChannelVolume::norm(2)), - sink_name: Some(CString::new("@DEFAULT_SINK@")?), + sink_name: Some(protocol::DEFAULT_SINK.to_owned()), ..Default::default() }), protocol_version, diff --git a/examples/playback_async.rs b/examples/playback_async.rs new file mode 100644 index 0000000..9b29ed3 --- /dev/null +++ b/examples/playback_async.rs @@ -0,0 +1,131 @@ +//! An example of using the higher-level [pulseaudio::Client] API to play audio +//! with an async runtime. +//! +//! Run with: +//! cargo run --example playback -- testfiles/victory.wav + +use std::{fs::File, io, path::Path, time}; + +use anyhow::{bail, Context as _}; +use pulseaudio::{protocol, AsPlaybackSource, Client, PlaybackStream}; + +// We're using tokio as a runtime here, but tokio is not a dependency of the +// crate, and it should be compatible with any executor. +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + if args.len() != 2 { + println!("Usage: {} ", args[0]); + return Ok(()); + } + + // Load the audio file, and choose parameters for the playback stream based + // on the format of the audio. We only support 16bit integer PCM in this + // example. + let file = File::open(Path::new(&args[1]))?; + let mut wav_reader = hound::WavReader::new(file)?; + let spec = wav_reader.spec(); + + let format = match (spec.bits_per_sample, spec.sample_format) { + (16, hound::SampleFormat::Int) => protocol::SampleFormat::S16Le, + _ => bail!( + "unsupported sample format: {}bit {:?}", + spec.bits_per_sample, + spec.sample_format, + ), + }; + + let channel_map = match spec.channels { + 1 => protocol::ChannelMap::mono(), + 2 => protocol::ChannelMap::stereo(), + _ => bail!("unsupported channel count: {}", spec.channels), + }; + + // Set up a progress bar for displaying during playback. + let file_duration = + time::Duration::from_secs(wav_reader.duration() as u64 / spec.sample_rate as u64); + let file_bytes = + wav_reader.duration() as u64 * (spec.channels * spec.bits_per_sample / 8) as u64; + let pb = indicatif::ProgressBar::new(file_bytes) + .with_style(indicatif::ProgressStyle::with_template(&format!( + "[{{elapsed_precise}} / {}] {{bar}} {{msg}}", + indicatif::FormattedDuration(file_duration) + ))?) + .with_finish(indicatif::ProgressFinish::AndLeave); + + let params = protocol::PlaybackStreamParams { + sample_spec: protocol::SampleSpec { + format, + channels: spec.channels as u8, + sample_rate: spec.sample_rate, + }, + channel_map, + cvolume: Some(protocol::ChannelVolume::norm(2)), + sink_name: Some(protocol::DEFAULT_SINK.to_owned()), + ..Default::default() + }; + + // First, establish a connection to the PulseAudio server. + let client = Client::from_env(c"test-playback-rs").context("Failed to create client")?; + + // Create a callback function, which is called by the client to write data + // to the stream. + let callback = move |data: &mut [u8]| copy_chunk(&mut wav_reader, data); + + let stream = client + .create_playback_stream(params, callback.as_playback_source()) + .await + .context("Failed to create playback stream")?; + + // Update our progress bar in a loop while waiting for the stream to finish. + tokio::select! { + res = stream.play_all() => res.context("Failed to play stream")?, + _ = async { + loop { + if let Err(err) = update_progress(stream.clone(), pb.clone()).await { + eprintln!("Failed to update progress: {}", err); + } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + } => (), + } + + Ok(()) +} + +async fn update_progress( + stream: PlaybackStream, + pb: indicatif::ProgressBar, +) -> Result<(), pulseaudio::ClientError> { + let timing_info = stream.timing_info().await?; + + // Use the information from the server to display the current playback latency. + let latency = time::Duration::from_micros(timing_info.sink_usec + timing_info.source_usec); + + pb.set_message(format!("{}ms latency", latency.as_millis())); + + // The playback position is the server's offset into the buffer. + // We'll use that to update the progress bar. + pb.set_position(timing_info.read_offset as u64); + Ok(()) +} + +fn copy_chunk(wav_reader: &mut hound::WavReader, buf: &mut [u8]) -> usize { + use byteorder::WriteBytesExt; + let len = buf.len(); + assert!(len % 2 == 0); + + let mut cursor = std::io::Cursor::new(buf); + for sample in wav_reader.samples::().filter_map(Result::ok) { + if cursor.write_i16::(sample).is_err() { + break; + } + + if cursor.position() == len as u64 { + break; + } + } + + cursor.position() as usize +} diff --git a/examples/record.rs b/examples/record.rs index 7162c45..a3ccf36 100644 --- a/examples/record.rs +++ b/examples/record.rs @@ -1,4 +1,5 @@ -//! A simple example that records audio from the default input. +//! A simple example that records audio from the default input and saves it as +//! a WAV file. //! //! Run with: //! cargo run --example record /tmp/recording.wav @@ -29,7 +30,7 @@ pub fn main() -> anyhow::Result<()> { sock.get_mut(), 10, &protocol::Command::GetSourceInfo(protocol::GetSourceInfo { - name: Some(CString::new("@DEFAULT_SOURCE@")?), + name: Some(protocol::DEFAULT_SOURCE.to_owned()), ..Default::default() }), protocol_version, diff --git a/examples/record_async.rs b/examples/record_async.rs new file mode 100644 index 0000000..d45c722 --- /dev/null +++ b/examples/record_async.rs @@ -0,0 +1,147 @@ +//! An example using the higher-level [pulseaudio::Client] API with an async +//! runtime to record audio. +//! +//! Run with: +//! cargo run --example record_async /tmp/recording.wav + +use std::{ + fs::File, + io::{self, BufWriter, Read}, + path::Path, + time, +}; + +use anyhow::{bail, Context as _}; +use futures::StreamExt as _; +use pulseaudio::{protocol, Client}; +use tokio::sync::oneshot; +use tokio_util::{compat::FuturesAsyncReadCompatExt as _, io::ReaderStream}; + +// We're using tokio as a runtime here, but tokio is not a dependency of the +// crate, and it should be compatible with any executor. +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + if args.len() != 2 { + println!("Usage: {} ", args[0]); + return Ok(()); + } + + // First, establish a connection to the PulseAudio server. + let client = Client::from_env(c"test-record-rs").context("Failed to create client")?; + + // Determine the default stream format. + let source_info = client + .source_info_by_name(protocol::DEFAULT_SOURCE.to_owned()) + .await?; + + // Create a record stream on the server. This will negotiate the actual + // format. + let params = protocol::RecordStreamParams { + source_index: Some(source_info.index), + sample_spec: protocol::SampleSpec { + format: source_info.sample_spec.format, + channels: source_info.channel_map.num_channels(), + sample_rate: source_info.sample_spec.sample_rate, + }, + channel_map: source_info.channel_map, + cvolume: Some(protocol::ChannelVolume::norm(2)), + ..Default::default() + }; + + // Create a buffer that implements AsyncRead. + let buffer = pulseaudio::RecordBuffer::new(1024 * 1024 * 1024); + let stream = client + .create_record_stream(params, buffer.as_record_sink()) + .await?; + + // Create the output file. + let sample_spec = stream.sample_spec().clone(); + let (bits_per_sample, sample_format) = match sample_spec.format { + protocol::SampleFormat::S16Le => (16, hound::SampleFormat::Int), + protocol::SampleFormat::Float32Le => (32, hound::SampleFormat::Float), + protocol::SampleFormat::S32Le => (32, hound::SampleFormat::Int), + _ => bail!("unsupported sample format: {:?}", sample_spec.format), + }; + + let spec = hound::WavSpec { + channels: stream.channel_map().num_channels() as u16, + sample_rate: sample_spec.sample_rate, + bits_per_sample, + sample_format, + }; + + let file = BufWriter::new(File::create(Path::new(&args[1]))?); + let mut wav_writer = hound::WavWriter::new(file, spec)?; + + let mut bytes = ReaderStream::new(buffer.compat()); + tokio::spawn(async move { + while let Some(chunk) = bytes.next().await { + write_chunk(&mut wav_writer, sample_spec.format, &chunk?)?; + } + + Ok::<(), anyhow::Error>(()) + }); + + // Wait for the stream to start. + stream.started().await?; + eprintln!("Recording... [press enter to finish]"); + + // Wait for the user to press enter. + read_stdin().await?; + + // If we quit now, we'll miss out on anything still in the server's buffer. + // Instead, we can measure the stream latency and wait that long before + // deleting the stream. + // + // To calculate the latency, we measure the difference between the + // read/write offset on the buffer, and add the source's inherent latency. + let timing_info = stream.timing_info().await?; + let offset = (timing_info.write_offset as u64) + .checked_sub(timing_info.read_offset as u64) + .unwrap_or(0); + let latency = time::Duration::from_micros(timing_info.source_usec) + + sample_spec.bytes_to_duration(offset as usize); + tokio::time::sleep(latency).await; + + stream.delete().await?; + eprintln!("Saved recording to {}", args[1]); + + Ok(()) +} + +async fn read_stdin() -> io::Result<()> { + let (done_tx, done_rx) = oneshot::channel(); + std::thread::spawn(|| { + let mut buf = [0; 1]; + let _ = done_tx.send(std::io::stdin().read(&mut buf).map(|_| ())); + }); + + done_rx.await.unwrap() +} + +fn write_chunk( + wav_writer: &mut hound::WavWriter>, + format: protocol::SampleFormat, + chunk: &[u8], +) -> anyhow::Result<()> { + use byteorder::ReadBytesExt as _; + + let mut cursor = io::Cursor::new(chunk); + while cursor.position() < cursor.get_ref().len() as u64 { + match format { + protocol::SampleFormat::S16Le => { + wav_writer.write_sample(cursor.read_i16::()?)? + } + protocol::SampleFormat::Float32Le => { + wav_writer.write_sample(cursor.read_f32::()?)? + } + protocol::SampleFormat::S32Le => { + wav_writer.write_sample(cursor.read_i32::()?)? + } + _ => unreachable!(), + }; + } + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..433b22e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,783 @@ +use std::{ + ffi::{CStr, CString}, + io::{BufReader, Read, Write}, +}; + +use mio::net::UnixStream; + +use super::protocol; + +mod playback_source; +mod playback_stream; +mod reactor; +mod record_sink; +mod record_stream; + +pub use playback_source::*; +pub use playback_stream::*; +pub use record_sink::*; +pub use record_stream::*; + +/// An error encountered by a [Client]. +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + /// The PulseAudio server socket couldn't be located.. + #[error("PulseAudio server unavailable")] + ServerUnavailable, + /// The server sent an invalid sequence number in reply to a command. + #[error("Unexpected sequence number")] + UnexpectedSequenceNumber, + /// A protocol-level error, like an invalid message. + #[error("Protocol error")] + Protocol(#[from] protocol::ProtocolError), + /// An error message sent by the server in response to a command. + #[error("Server error: {0}")] + ServerError(protocol::PulseError), + /// An error occurred reading or writing to the socket, or communicating + /// with the worker thread. + #[error("I/O error")] + Io(#[from] std::io::Error), + /// The client has disconnected, usually because an error occurred. + #[error("Client disconnected")] + Disconnected, +} + +/// The result of a [Client] operation. +pub type Result = std::result::Result; + +/// A PulseAudio client. +/// +/// The client object can be freely cloned and shared between threads. +#[derive(Clone)] +pub struct Client { + desc: String, + handle: reactor::ReactorHandle, +} + +impl std::fmt::Debug for Client { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Client").field(&self.desc).finish() + } +} + +impl Client { + /// Creates a new client, using the environment to find the socket and cookie file. + /// + /// See the documentation for [socket_path_from_env](super::socket_path_from_env) and + /// [cookie_path_from_env](super::cookie_path_from_env) for an explanation + /// of how the socket path and cookie are determined. + pub fn from_env(client_name: impl AsRef) -> Result { + let socket_path = super::socket_path_from_env().ok_or(ClientError::ServerUnavailable)?; + let cookie = super::cookie_path_from_env().and_then(|p| std::fs::read(p).ok()); + + log::info!( + "connecting to PulseAudio server at {}", + socket_path.display() + ); + let socket = std::os::unix::net::UnixStream::connect(socket_path)?; + Self::new_unix(client_name, socket, cookie) + } + + /// Creates a new client, using the given connected unix domain socket to + /// communicate with the PulseAudio server. + pub fn new_unix( + client_name: impl AsRef, + mut socket: std::os::unix::net::UnixStream, + cookie: Option>, + ) -> std::result::Result { + let desc = if let Some(path) = socket.peer_addr()?.as_pathname() { + format!("unix:{}", path.display()) + } else { + "".into() + }; + + // Perform the handshake. + let protocol_version; + { + let mut reader = BufReader::new(&mut socket); + let cookie = cookie.as_ref().map(AsRef::as_ref).unwrap_or(&[]).to_owned(); + let auth = protocol::AuthParams { + version: protocol::MAX_VERSION, + supports_shm: false, + supports_memfd: false, + cookie, + }; + + let auth_reply: protocol::AuthReply = roundtrip_blocking( + &mut reader, + protocol::Command::Auth(auth), + 0, + protocol::MAX_VERSION, + )?; + + protocol_version = std::cmp::min(protocol::MAX_VERSION, auth_reply.version); + + let mut props = protocol::Props::new(); + props.set(protocol::Prop::ApplicationName, client_name.as_ref()); + + let _: protocol::SetClientNameReply = roundtrip_blocking( + &mut reader, + protocol::Command::SetClientName(props), + 1, + protocol_version, + )?; + } + + // Set up the reactor. + socket.set_nonblocking(true)?; + let socket = UnixStream::from_std(socket); + let handle = reactor::Reactor::spawn(socket, protocol_version)?; + + Ok(Self { desc, handle }) + } + + /// Fetches basic information on the server. + pub async fn server_info(&self) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetServerInfo) + .await + } + + /// Fetches all clients connected to the server. + pub async fn list_clients(&self) -> Result> { + self.handle + .roundtrip_reply(protocol::Command::GetClientInfoList) + .await + } + + /// Fetches a connected client by its index. + pub async fn client_info(&self, index: u32) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetClientInfo(index)) + .await + } + + /// Fetches all sinks available on the server. + pub async fn list_sinks(&self) -> Result> { + self.handle + .roundtrip_reply(protocol::Command::GetSinkInfoList) + .await + } + + /// Fetches all sources available on the server. + pub async fn list_sources(&self) -> Result> { + self.handle + .roundtrip_reply(protocol::Command::GetSourceInfoList) + .await + } + + /// Fetches a specific sink by its index. + pub async fn sink_info(&self, index: u32) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetSinkInfo(protocol::GetSinkInfo { + index: Some(index), + name: None, + })) + .await + } + + /// Fetches a specific sink by name. + pub async fn sink_info_by_name(&self, name: CString) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetSinkInfo(protocol::GetSinkInfo { + index: None, + name: Some(name), + })) + .await + } + + /// Fetches a specific source by its index. + pub async fn source_info(&self, index: u32) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetSourceInfo(protocol::GetSourceInfo { + index: Some(index), + name: None, + })) + .await + } + + /// Fetches a specific source by name. + pub async fn source_info_by_name(&self, name: CString) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetSourceInfo(protocol::GetSourceInfo { + index: None, + name: Some(name), + })) + .await + } + + /// Looks up a sink by its index. + pub async fn lookup_sink(&self, index: u32) -> Result { + let cmd = protocol::Command::LookupSink(CString::new(index.to_string()).unwrap()); + let reply = self + .handle + .roundtrip_reply::(cmd) + .await?; + Ok(reply.0) + } + + /// Looks up a sink by its name. + pub async fn lookup_sink_by_name(&self, name: CString) -> Result { + let cmd = protocol::Command::LookupSink(name); + let reply = self + .handle + .roundtrip_reply::(cmd) + .await?; + Ok(reply.0) + } + + /// Looks up a source by its index. + pub async fn lookup_source(&self, index: u32) -> Result { + let cmd = protocol::Command::LookupSource(CString::new(index.to_string()).unwrap()); + let reply = self + .handle + .roundtrip_reply::(cmd) + .await?; + Ok(reply.0) + } + + /// Looks up a source by its name. + pub async fn lookup_source_by_name(&self, name: CString) -> Result { + let cmd = protocol::Command::LookupSource(name); + let reply = self + .handle + .roundtrip_reply::(cmd) + .await?; + Ok(reply.0) + } + + /// Fetches a specific card by its index. + pub async fn card_info(&self, index: u32) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetCardInfo(protocol::GetCardInfo { + index: Some(index), + name: None, + })) + .await + } + + /// Fetches a specific card by its name. + pub async fn card_info_by_name(&self, name: CString) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetCardInfo(protocol::GetCardInfo { + index: None, + name: Some(name), + })) + .await + } + + /// Fetches all cards available on the server. + pub async fn list_cards(&self) -> Result> { + self.handle + .roundtrip_reply(protocol::Command::GetCardInfoList) + .await + } + + /// Fetches a specific module. + pub async fn module_info(&self, index: u32) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetModuleInfo(index)) + .await + } + + /// Fetches all modules. + pub async fn list_modules(&self) -> Result> { + self.handle + .roundtrip_reply(protocol::Command::GetModuleInfoList) + .await + } + + /// Fetches memory usage information from the server. + pub async fn stat(&self) -> Result { + self.handle.roundtrip_reply(protocol::Command::Stat).await + } + + /// Fetches a specific sample. + pub async fn sample_info(&self, index: u32) -> Result { + self.handle + .roundtrip_reply(protocol::Command::GetSampleInfo(index)) + .await + } + + /// Fetches all samples available on the server. + pub async fn list_samples(&self) -> Result> { + self.handle + .roundtrip_reply(protocol::Command::GetSampleInfoList) + .await + } + + /// Sets the default sink. + pub async fn set_default_sink(&self, name: CString) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::SetDefaultSink(name)) + .await + } + + /// Sets the default source. + pub async fn set_default_source(&self, name: CString) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::SetDefaultSource(name)) + .await + } + + /// Kills a client. + pub async fn kill_client(&self, index: u32) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::KillClient(index)) + .await + } + + /// Kills a sink input. + pub async fn kill_sink_input(&self, index: u32) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::KillSinkInput(index)) + .await + } + + /// Kills a source output. + pub async fn kill_source_output(&self, index: u32) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::KillSourceOutput(index)) + .await + } + + /// Suspends a sink by its index. + pub async fn suspend_sink(&self, index: u32, suspend: bool) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::SuspendSink(protocol::SuspendParams { + device_index: Some(index), + device_name: None, + suspend, + })) + .await + } + + /// Suspends a sink by its name. + pub async fn suspend_sink_by_name(&self, name: CString, suspend: bool) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::SuspendSink(protocol::SuspendParams { + device_index: None, + device_name: Some(name), + suspend, + })) + .await + } + + /// Suspends a source by its index. + pub async fn suspend_source(&self, index: u32, suspend: bool) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::SuspendSource(protocol::SuspendParams { + device_index: Some(index), + device_name: None, + suspend, + })) + .await + } + + /// Suspends a source by its name. + pub async fn suspend_source_by_name(&self, name: CString, suspend: bool) -> Result<()> { + self.handle + .roundtrip_ack(protocol::Command::SuspendSource(protocol::SuspendParams { + device_index: None, + device_name: Some(name), + suspend, + })) + .await + } + + /// Creates a new playback stream. The given callback will be called when the + /// server requests data for the stream. + pub async fn create_playback_stream( + &self, + params: protocol::PlaybackStreamParams, + source: impl PlaybackSource, + ) -> Result { + PlaybackStream::new(self.handle.clone(), params, source).await + } + + /// Creates a new record stream. The returned handle implements + /// [AsyncRead](futures::io::AsyncRead) for extracting the raw audio data. + pub async fn create_record_stream( + &self, + params: protocol::RecordStreamParams, + sink: impl RecordSink, + ) -> Result { + RecordStream::new(self.handle.clone(), params, sink).await + } +} + +fn roundtrip_blocking( + socket: &mut BufReader, + cmd: protocol::Command, + seq: u32, + protocol_version: u16, +) -> Result { + log::debug!("CLIENT [{seq}]: {cmd:?}"); + protocol::write_command_message(socket.get_mut(), seq, &cmd, protocol_version)?; + + let (seq, reply) = protocol::read_reply_message(socket, protocol_version)?; + if seq != seq { + return Err(ClientError::UnexpectedSequenceNumber); + } + + Ok(reply) +} +#[cfg(all(test, feature = "_integration-tests"))] +mod tests { + use std::time; + + use super::*; + use anyhow::anyhow; + use anyhow::Context as _; + use futures::executor::block_on; + use rand::Rng; + + fn random_client_name() -> CString { + CString::new(format!( + "pulseaudio-rs-test-{}", + rand::rng().random_range(0..10000) + )) + .unwrap() + } + + #[test_log::test] + fn server_info() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let server_info = block_on(client.server_info())?; + assert!(server_info.server_name.is_some()); + + Ok(()) + } + + #[test_log::test] + fn list_clients() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let client_list = block_on(client.list_clients())?; + assert!(!client_list.is_empty()); + + Ok(()) + } + + #[test_log::test] + fn client_info() -> anyhow::Result<()> { + let client_name = random_client_name(); + let client = + Client::from_env(client_name.clone()).context("connecting to PulseAudio server")?; + + let client_list = block_on(client.list_clients())?; + assert!(!client_list.is_empty()); + + let expected = &client_list + .iter() + .find(|client| client.name == client_name) + .ok_or(anyhow!("no client with matching name"))?; + let client_info = block_on(client.client_info(expected.index))?; + + assert_eq!(**expected, client_info); + + Ok(()) + } + + #[test_log::test] + fn list_sinks() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let info_list = block_on(client.list_sinks())?; + assert!(!info_list.is_empty()); + + Ok(()) + } + + #[test_log::test] + fn list_sources() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let info_list = block_on(client.list_sources())?; + assert!(!info_list.is_empty()); + + Ok(()) + } + + #[test_log::test] + fn sink_info() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let sink_list = block_on(client.list_sinks())?; + assert!(!sink_list.is_empty()); + + let mut expected = sink_list[0].clone(); + let mut sink_info = block_on(client.sink_info(expected.index))?; + + expected.actual_latency = 0; + sink_info.actual_latency = 0; + assert_eq!(expected, sink_info); + + Ok(()) + } + + #[test_log::test] + fn sink_info_by_name() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let sink_list = block_on(client.list_sinks())?; + assert!(!sink_list.is_empty()); + + let mut expected = sink_list[0].clone(); + let mut sink_info = block_on(client.sink_info_by_name(expected.name.clone()))?; + + expected.actual_latency = 0; + sink_info.actual_latency = 0; + assert_eq!(expected, sink_info); + + Ok(()) + } + + #[test_log::test] + fn source_info() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let source_list = block_on(client.list_sources())?; + assert!(!source_list.is_empty()); + + let expected = &source_list[0]; + let source_info = block_on(client.source_info(expected.index))?; + + assert_eq!(expected, &source_info); + + Ok(()) + } + + #[test_log::test] + fn source_info_by_name() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let source_list = block_on(client.list_sources())?; + assert!(!source_list.is_empty()); + + let expected = &source_list[0]; + let source_info = block_on(client.source_info_by_name(expected.name.clone()))?; + + assert_eq!(expected, &source_info); + + Ok(()) + } + + #[test_log::test] + fn lookup_sink() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let sink_list = block_on(client.list_sinks())?; + assert!(!sink_list.is_empty()); + + let expected = &sink_list[0]; + let sink_index = block_on(client.lookup_sink(expected.index))?; + + assert_eq!(expected.index, sink_index); + + Ok(()) + } + + #[test_log::test] + fn lookup_sink_by_name() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let sink_list = block_on(client.list_sinks())?; + assert!(!sink_list.is_empty()); + + let expected = &sink_list[0]; + let sink_index = block_on(client.lookup_sink_by_name(expected.name.clone()))?; + + assert_eq!(expected.index, sink_index); + + Ok(()) + } + + #[test_log::test] + fn lookup_source() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let source_list = block_on(client.list_sources())?; + assert!(!source_list.is_empty()); + + let expected = &source_list[0]; + let source_index = block_on(client.lookup_source(expected.index))?; + + assert_eq!(expected.index, source_index); + + Ok(()) + } + + #[test_log::test] + fn lookup_source_by_name() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let source_list = block_on(client.list_sources())?; + assert!(!source_list.is_empty()); + + let expected = &source_list[0]; + let source_index = block_on(client.lookup_source_by_name(expected.name.clone()))?; + + assert_eq!(expected.index, source_index); + + Ok(()) + } + + #[test_log::test] + fn card_info() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let card_list = block_on(client.list_cards())?; + + if !card_list.is_empty() { + let expected = &card_list[0]; + let card_info = block_on(client.card_info(expected.index))?; + + assert_eq!(expected, &card_info); + } + + Ok(()) + } + + #[test_log::test] + fn card_info_by_name() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let card_list = block_on(client.list_cards())?; + + if !card_list.is_empty() { + let expected = &card_list[0]; + let card_info = block_on(client.card_info_by_name(expected.name.clone()))?; + + assert_eq!(expected, &card_info); + } + + Ok(()) + } + + #[test_log::test] + fn list_cards() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let _card_list = block_on(client.list_cards())?; + Ok(()) + } + + #[test_log::test] + fn module_info() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let module_list = block_on(client.list_modules())?; + assert!(!module_list.is_empty()); + + let expected = &module_list[0]; + let module_info = block_on(client.module_info(expected.index))?; + + assert_eq!(expected, &module_info); + + Ok(()) + } + + #[test_log::test] + fn list_modules() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let module_list = block_on(client.list_modules())?; + assert!(!module_list.is_empty()); + + Ok(()) + } + + #[test_log::test] + fn stat() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let stat_info = block_on(client.stat())?; + assert!(stat_info.memblock_total > 0); + + Ok(()) + } + + #[test_log::test] + fn sample_info() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let sample_list = block_on(client.list_samples())?; + if sample_list.is_empty() { + return Ok(()); + } + + let expected = &sample_list[0]; + let sample_info = block_on(client.sample_info(expected.index))?; + + assert_eq!(expected, &sample_info); + + Ok(()) + } + + #[test_log::test] + fn list_samples() -> anyhow::Result<()> { + let client = + Client::from_env(random_client_name()).context("connecting to PulseAudio server")?; + + let _sample_list = block_on(client.list_samples())?; + Ok(()) + } + + #[test_log::test] + fn kill_client() -> anyhow::Result<()> { + let client_name1 = random_client_name(); + let client1 = Client::from_env(client_name1.clone())?; + let client2 = Client::from_env(random_client_name())?; + + let client_list = block_on(client2.list_clients())?; + assert!(!client_list.is_empty()); + + let client1_info = client_list + .iter() + .find(|client| client.name == client_name1) + .ok_or(anyhow!("no client1 with matching name"))?; + + block_on(client2.kill_client(client1_info.index))?; + + // Listing things should eventually fail with client1. + let start = time::Instant::now(); + loop { + match block_on(client1.server_info()).err() { + Some(ClientError::Disconnected) => break, + _ if start.elapsed() < time::Duration::from_secs(1) => { + std::thread::sleep(time::Duration::from_millis(10)) + } + _ => panic!("client still connected"), + } + } + + let client_list = block_on(client2.list_clients())?; + assert!(client_list + .iter() + .find(|client| client.name == client1_info.name) + .is_none()); + + Ok(()) + } +} diff --git a/src/client/playback_source.rs b/src/client/playback_source.rs new file mode 100644 index 0000000..f8914e4 --- /dev/null +++ b/src/client/playback_source.rs @@ -0,0 +1,109 @@ +use std::pin::Pin; + +/// An audio source for a playback stream. At its core, this is just a callback +/// that is driven by the server to generate samples. +/// +/// # Example: using a callback +/// +/// A callback can be used as a [PlaybackSource] via [AsPlaybackSource]: +/// +/// ```no_run +/// # use pulseaudio::*; +/// # let client = Client::from_env(c"client").unwrap(); +/// # let params = protocol::PlaybackStreamParams::default(); +/// let callback = move |buf: &mut [u8]| { +/// // Here, we're just returning silence. +/// buf.fill(0); +/// // We have to return the number of bytes writen, which can be less than +/// // the buffer size. However, if we return 0 bytes, that's considered an +/// // EOF, and the callback won't be called anymore. +/// buf.len() +/// }; +/// +/// # let _ = +/// client.create_playback_stream(params, callback.as_playback_source()); +/// ``` +/// +/// # Example: using a type that implements AsyncRead +/// +/// Types that implement [futures::io::AsyncRead] can also used as a source. In +/// this case, any error will be considered EOF. +/// +/// ```no_run +/// # use pulseaudio::*; +/// # use futures::TryStreamExt; +/// # let client = Client::from_env(c"client").unwrap(); +/// # let params = protocol::PlaybackStreamParams::default(); +/// // Here we'll create an arbitrary stream, but this could just as easily be +/// // a PCM file or network stream or something else. +/// let stream = futures::stream::iter([ +/// Ok(vec![0, 0]), +/// Ok(vec![0, 0]), +/// Ok(vec![0, 0]), +/// Ok(vec![0, 0]), +/// ]); +/// +/// # let _ = +/// client.create_playback_stream(params, stream.into_async_read()); +/// ``` +pub trait PlaybackSource: Send + 'static { + #[allow(missing_docs)] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut futures::task::Context<'_>, + buf: &mut [u8], + ) -> futures::task::Poll; +} + +/// A trait for converting a callback into an [AudioSource]. +pub trait AsPlaybackSource { + /// Converts the callback into an [AudioSource]. + fn as_playback_source(self) -> impl PlaybackSource; +} + +struct CallbackWrapper usize + Send + 'static>(T); + +impl PlaybackSource for CallbackWrapper +where + T: FnMut(&mut [u8]) -> usize + Send + 'static, +{ + fn poll_read( + self: Pin<&mut CallbackWrapper>, + _cx: &mut futures::task::Context<'_>, + buf: &mut [u8], + ) -> futures::task::Poll { + let len = unsafe { + let pinned_closure = Pin::get_unchecked_mut(self); + pinned_closure.0(buf) + }; + + // We don't need to worry about waking up the reactor, because the + // closure always returns Ok(n) or Ok(0). + futures::task::Poll::Ready(len) + } +} + +impl AsPlaybackSource for T +where + T: FnMut(&mut [u8]) -> usize + Send + 'static, +{ + fn as_playback_source(self) -> impl PlaybackSource { + CallbackWrapper(self) + } +} + +impl PlaybackSource for T +where + T: futures::AsyncRead + Send + 'static, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut futures::task::Context<'_>, + buf: &mut [u8], + ) -> futures::task::Poll { + futures::AsyncRead::poll_read(self, cx, buf).map(|result| match result { + Ok(n) => n, + Err(_) => 0, + }) + } +} diff --git a/src/client/playback_stream.rs b/src/client/playback_stream.rs new file mode 100644 index 0000000..fb92a56 --- /dev/null +++ b/src/client/playback_stream.rs @@ -0,0 +1,184 @@ +use std::ffi::CString; +use std::sync::Arc; +use std::time; + +use futures::channel::oneshot; +use futures::FutureExt as _; + +use super::reactor::ReactorHandle; +use super::{ClientError, PlaybackSource, Result as ClientResult}; +use crate::protocol; + +/// A stream of audio data sent from the client to the server for playback in +/// a sink. +/// +/// The stream handle can be freely cloned and shared between threads. +#[derive(Clone)] +pub struct PlaybackStream(Arc); + +struct InnerPlaybackStream { + handle: ReactorHandle, + info: protocol::CreatePlaybackStreamReply, + eof_notify: futures::future::Shared>, +} + +impl std::fmt::Debug for PlaybackStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PlaybackStream") + .field(&self.0.info.channel) + .finish() + } +} + +impl PlaybackStream { + pub(super) async fn new( + handle: ReactorHandle, + params: protocol::PlaybackStreamParams, + source: impl PlaybackSource, + ) -> Result { + let (eof_tx, eof_rx) = oneshot::channel(); + let info = handle + .insert_playback_stream(params, source, Some(eof_tx)) + .await?; + + Ok(Self(Arc::new(InnerPlaybackStream { + handle, + info, + eof_notify: eof_rx.shared(), + }))) + } + + /// The ID of the stream. + pub fn channel(&self) -> u32 { + self.0.info.channel + } + + /// The attributes of the server-side buffer. + pub fn buffer_attr(&self) -> &protocol::stream::BufferAttr { + &self.0.info.buffer_attr + } + + /// The sample specification for the stream. Can differ from the client's + /// requested sample spec. + pub fn sample_spec(&self) -> &protocol::SampleSpec { + &self.0.info.sample_spec + } + + /// The channel map for the stream. + pub fn channel_map(&self) -> &protocol::ChannelMap { + &self.0.info.channel_map + } + + /// The sink the stream is connected to. + pub fn sink(&self) -> u32 { + self.0.info.sink_index + } + + /// Sets the name of the playback stream. + pub async fn set_name(&self, name: CString) -> ClientResult<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::SetPlaybackStreamName( + protocol::SetStreamNameParams { + index: self.0.info.stream_index, + name, + }, + )) + .await + } + + /// Fetches playback timing information for the playback stream. + pub async fn timing_info(&self) -> ClientResult { + self.0 + .handle + .roundtrip_reply(protocol::Command::GetPlaybackLatency( + protocol::LatencyParams { + channel: self.0.info.channel, + now: time::SystemTime::now(), + }, + )) + .await + } + + /// Corks the playback stream (temporarily pausing playback). + pub async fn cork(&self) -> ClientResult<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::CorkPlaybackStream( + protocol::CorkStreamParams { + channel: self.0.info.channel, + cork: true, + }, + )) + .await + } + + /// Uncorks the playback stream. + pub async fn uncork(&self) -> ClientResult<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::CorkPlaybackStream( + protocol::CorkStreamParams { + channel: self.0.info.channel, + cork: false, + }, + )) + .await + } + + /// Returns a future that resolves when the stream's [AudioSource] has reached the end. + pub async fn source_eof(&self) -> ClientResult<()> { + self.0 + .eof_notify + .clone() + .await + .map_err(|_| ClientError::Disconnected) + } + + /// Waits until the given [AudioSource] has reached the end (and returns 0 in [AudioSource::poll_read]), + /// and then instructs the server to drain the buffer before ending the stream. + pub async fn play_all(&self) -> ClientResult<()> { + self.source_eof().await?; + self.drain().await?; + Ok(()) + } + + /// Instructs the server to play any remaining data in the buffer, then end + /// the stream. This method returns once the stream has finished. + pub async fn drain(&self) -> ClientResult<()> { + self.0 + .handle + .mark_playback_stream_draining(self.0.info.channel); + self.0 + .handle + .roundtrip_ack(protocol::Command::DrainPlaybackStream(self.0.info.channel)) + .await + } + + /// Instructs the server to discard any buffered data. + pub async fn flush(&self) -> super::Result<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::FlushPlaybackStream(self.0.info.channel)) + .await + } + + /// Deletes the stream from the server. + pub async fn delete(self) -> ClientResult<()> { + self.0 + .handle + .delete_playback_stream(self.0.info.channel) + .await + } +} + +impl Drop for InnerPlaybackStream { + fn drop(&mut self) { + // Sends the delete command to the server, but doesn't wait for the + // response. + let _ = self + .handle + .delete_playback_stream(self.info.channel) + .now_or_never(); + } +} diff --git a/src/client/reactor.rs b/src/client/reactor.rs new file mode 100644 index 0000000..4da2849 --- /dev/null +++ b/src/client/reactor.rs @@ -0,0 +1,575 @@ +use std::{ + collections::BTreeMap, + io::{self}, + pin::Pin, + sync::{ + atomic::{self, AtomicU32}, + mpsc::{Receiver, Sender, TryRecvError}, + Arc, Mutex, Weak, + }, + task::{Context, Poll}, + thread::JoinHandle, +}; + +use futures::channel::oneshot; +use mio::net::UnixStream; + +use crate::protocol::{self, DescriptorFlags}; + +use super::{ClientError, PlaybackSource, RecordSink}; + +type ReplyResult<'a> = + Result<(&'a mut ReactorState, &'a mut dyn io::BufRead), protocol::PulseError>; +type ReplyHandler = Box) + Send + 'static>; + +struct PlaybackStreamState { + stream_info: protocol::CreatePlaybackStreamReply, + source: Pin>, + + requested_bytes: usize, + done: bool, + eof_notify: Option>, +} + +pub(super) struct RecordStreamState { + sink: Box, + start_notify: Option>, +} + +#[derive(Default)] +struct ReactorState { + handlers: BTreeMap, + playback_streams: BTreeMap, + record_streams: BTreeMap, +} + +struct SharedState { + protocol_version: u16, + next_seq: AtomicU32, + _thread_handle: JoinHandle>, +} + +// We need to wrap this to implement futures::task::ArcWake. +struct Waker(mio::Waker); + +impl futures::task::ArcWake for Waker { + fn wake_by_ref(arc_self: &Arc) { + let _ = arc_self.0.wake(); + } +} + +#[derive(Clone)] +pub(super) struct ReactorHandle { + state: Weak>, + shared: Arc, + outgoing: Sender<(u32, protocol::Command)>, + waker: Arc, +} + +impl ReactorHandle { + pub(super) async fn roundtrip_reply( + &self, + cmd: protocol::Command, + ) -> Result { + let seq = self.next_seq(); + + // Install a handler for the sequence number. + let (tx, rx) = oneshot::channel(); + let protocol_version = self.shared.protocol_version; + self.install_handler(seq, move |res: ReplyResult<'_>| { + let _ = match res { + Ok((_, buf)) => tx.send(read_tagstruct(buf, protocol_version)), + Err(err) => tx.send(Err(ClientError::ServerError(err))), + }; + })?; + + // Send the message. + self.write_command(seq, cmd)?; + + // Wait for the response. + rx.await.map_err(|_| ClientError::Disconnected)? + } + + pub(super) async fn roundtrip_ack(&self, cmd: protocol::Command) -> Result<(), ClientError> { + let seq = self.next_seq(); + + // Install a handler for the sequence number. + let (tx, rx) = oneshot::channel(); + self.install_handler(seq, move |res: ReplyResult<'_>| { + let _ = match res { + Ok(_) => tx.send(Ok(())), + Err(err) => tx.send(Err(ClientError::ServerError(err))), + }; + })?; + + // Send the message. + self.write_command(seq, cmd)?; + + // Wait for the response. + rx.await.map_err(|_| ClientError::Disconnected)? + } + + pub(super) async fn insert_playback_stream( + &self, + params: protocol::PlaybackStreamParams, + source: impl PlaybackSource, + eof_notify: Option>, + ) -> Result { + // This is the seq for the CreatePlaybackStream command. + let seq = self.next_seq(); + + let protocol_version = self.shared.protocol_version; + let handler = move |res: ReplyResult<'_>| { + let (state, buf) = res.map_err(ClientError::ServerError)?; + let stream_info: protocol::CreatePlaybackStreamReply = + read_tagstruct(buf, protocol_version)?; + + let requested_bytes = stream_info.requested_bytes as usize; + state.playback_streams.insert( + stream_info.channel, + PlaybackStreamState { + stream_info: stream_info.clone(), + source: Box::pin(source), + + requested_bytes, + done: false, + eof_notify, + }, + ); + + Ok(stream_info) + }; + + let (tx, rx) = oneshot::channel(); + self.install_handler(seq, move |res: ReplyResult<'_>| { + let _ = tx.send(handler(res)); + })?; + + // Send the message. + self.write_command(seq, protocol::Command::CreatePlaybackStream(params))?; + + // Wait for the response. + rx.await.map_err(|_| ClientError::Disconnected)? + } + + pub(super) async fn delete_playback_stream(&self, channel: u32) -> Result<(), ClientError> { + let seq = self.next_seq(); + + let (tx, rx) = oneshot::channel(); + self.install_handler(seq, move |res| { + if let Ok((state, _ack)) = res { + state.playback_streams.remove(&channel); + } + + let _ = tx.send(()); + })?; + + self.write_command(seq, protocol::Command::DeletePlaybackStream(channel))?; + rx.await.map_err(|_| ClientError::Disconnected) + } + + pub(super) fn mark_playback_stream_draining(&self, channel: u32) { + if let Some(state) = self.state.upgrade() { + if let Some(stream) = state.lock().unwrap().playback_streams.get_mut(&channel) { + stream.done = true; + } + } + } + + pub(super) async fn insert_record_stream( + &self, + params: protocol::RecordStreamParams, + sink: impl RecordSink, + start_notify: Option>, + ) -> Result { + let seq = self.next_seq(); + + let protocol_version = self.shared.protocol_version; + let handler = move |res: ReplyResult<'_>| { + let (state, buf) = res.map_err(ClientError::ServerError)?; + let stream_info: protocol::CreateRecordStreamReply = + read_tagstruct(buf, protocol_version)?; + + state.record_streams.insert( + stream_info.channel, + RecordStreamState { + sink: Box::new(sink), + start_notify, + }, + ); + + Ok(stream_info) + }; + + let (tx, rx) = oneshot::channel(); + self.install_handler(seq, move |res: ReplyResult<'_>| { + let _ = tx.send(handler(res)); + })?; + + // Send the message. + self.write_command(seq, protocol::Command::CreateRecordStream(params))?; + + // Wait for the response. + rx.await.map_err(|_| ClientError::Disconnected)? + } + + pub(super) async fn delete_record_stream(&self, channel: u32) -> Result<(), ClientError> { + let seq = self.next_seq(); + + let (tx, rx) = oneshot::channel(); + self.install_handler(seq, move |res| { + if let Ok((state, _ack)) = res { + state.record_streams.remove(&channel); + } + + let _ = tx.send(()); + })?; + + self.write_command(seq, protocol::Command::DeleteRecordStream(channel))?; + rx.await.map_err(|_| ClientError::Disconnected) + } + + fn write_command(&self, seq: u32, cmd: protocol::Command) -> Result<(), ClientError> { + self.outgoing + .send((seq, cmd)) + .map_err(|_| ClientError::Disconnected)?; + self.waker.0.wake()?; + + Ok(()) + } + + fn install_handler(&self, seq: u32, handler: F) -> Result<(), ClientError> + where + F: FnOnce(ReplyResult<'_>) + Send + 'static, + { + self.state + .upgrade() + .ok_or(ClientError::Disconnected)? + .lock() + .unwrap() + .handlers + .insert(seq, Box::new(handler)); + + Ok(()) + } + + fn next_seq(&self) -> u32 { + self.shared.next_seq.fetch_add(1, atomic::Ordering::Relaxed) + } +} + +pub(super) const WAKER: mio::Token = mio::Token(0); +pub(super) const SOCKET: mio::Token = mio::Token(1); + +pub(super) struct Reactor { + socket: UnixStream, + poll: mio::Poll, + waker: Arc, + state: Arc>, + outgoing: Receiver<(u32, protocol::Command)>, + protocol_version: u16, + + write_buf: Vec, + read_buf: Vec, + in_progress_read: Option, +} + +impl Reactor { + pub(super) fn spawn( + mut socket: UnixStream, + protocol_version: u16, + ) -> Result { + let poll = mio::Poll::new()?; + let waker = Arc::new(Waker(mio::Waker::new(poll.registry(), WAKER)?)); + poll.registry().register( + &mut socket, + SOCKET, + mio::Interest::READABLE | mio::Interest::WRITABLE, + )?; + + let state = Arc::new(Mutex::new(ReactorState::default())); + + let (cmd_tx, cmd_rx) = std::sync::mpsc::channel(); + let mut reactor = Self { + socket, + poll, + waker: waker.clone(), + state: state.clone(), + outgoing: cmd_rx, + protocol_version, + + write_buf: Vec::new(), + read_buf: Vec::new(), + in_progress_read: None, + }; + + let reactor_thread = std::thread::spawn(move || match reactor.run() { + Ok(_) => Ok(()), + Err(err) => { + log::error!("Reactor error: {}", err); + Err(err) + } + }); + + Ok(ReactorHandle { + state: Arc::downgrade(&state), + outgoing: cmd_tx, + waker, + shared: Arc::new(SharedState { + protocol_version, + next_seq: AtomicU32::new(1024), + _thread_handle: reactor_thread, + }), + }) + } + + pub(super) fn run(&mut self) -> Result<(), ClientError> { + let mut events = mio::Events::with_capacity(1024); + + loop { + self.poll.poll(&mut events, None)?; + self.recv()?; + + // Handle any requested writes. + self.write_streams()?; + self.write_commands()?; + } + } + + fn recv(&mut self) -> Result<(), ClientError> { + use io::Read; + + 'read: loop { + let off = self.read_buf.len(); + self.read_buf.resize(off + 1024 * 1024, 0); + + match self.socket.read(&mut self.read_buf[off..]) { + Ok(0) => return Err(ClientError::Disconnected), + Ok(n) => self.read_buf.truncate(off + n), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => { + self.read_buf.truncate(off); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + + // Decode messages (there may be multiple). + while !self.read_buf.is_empty() { + // Continue the previous read, if it was unfinished. + let desc = if let Some(desc) = self.in_progress_read.take() { + desc + } else if self.read_buf.len() >= protocol::DESCRIPTOR_SIZE { + protocol::read_descriptor(&mut io::Cursor::new(&self.read_buf))? + } else { + log::trace!("very short read ({} bytes)", self.read_buf.len()); + continue 'read; + }; + + // If we don't have all the message, poll until we do. + let len = desc.length as usize + protocol::DESCRIPTOR_SIZE; + if self.read_buf.len() < len { + self.in_progress_read = Some(desc); + log::trace!("partial read ({}/{} bytes)", self.read_buf.len(), len); + continue 'read; + } + + if desc.channel == u32::MAX { + self.handle_command(len); + } else { + // Stream data for a record stream. + let mut guard = self.state.lock().unwrap(); + if let Some(RecordStreamState { sink, start_notify }) = + guard.record_streams.get_mut(&desc.channel) + { + log::trace!("reading {len} bytes from stream {}", desc.channel,); + if let Some(start_notify) = start_notify.take() { + let _ = start_notify.send(()); + } + + sink.write(&self.read_buf[protocol::DESCRIPTOR_SIZE..len]) + } else { + log::warn!("Received data for unknown record stream {}", desc.channel); + } + } + + self.read_buf.drain(..len); + } + } + } + + fn handle_command(&mut self, len: usize) { + let mut cursor = io::Cursor::new(&self.read_buf[protocol::DESCRIPTOR_SIZE..len]); + let (seq, cmd) = + match protocol::Command::read_tag_prefixed(&mut cursor, self.protocol_version) { + Ok((seq, cmd)) => (seq, cmd), + Err(err) => { + log::error!("failed to read command message: {}", err); + return; + } + }; + + let mut state = self.state.lock().unwrap(); + + log::debug!("SERVER [{}]: {cmd:?}", seq as i32); + if matches!(cmd, protocol::Command::Reply | protocol::Command::Error(_)) { + let Some(handler) = state.handlers.remove(&seq) else { + log::warn!("no reply handler found for sequence {}", seq); + return; + }; + + match cmd { + protocol::Command::Reply => handler(Ok((&mut state, &mut cursor))), + protocol::Command::Error(err) => handler(Err(err)), + _ => unreachable!(), + } + return; + } + + match cmd { + protocol::Command::Started(channel) => { + if state.playback_streams.contains_key(&channel) { + log::debug!("stream started: {}", channel); + } else { + log::error!("unknown stream: {}", channel); + } + } + protocol::Command::Request(protocol::Request { channel, length }) => { + if let Some(stream) = state.playback_streams.get_mut(&channel) { + stream.requested_bytes += length as usize; + } else { + log::error!("unknown stream: {}", channel); + } + } + _ => log::debug!("ignoring unexpected command: {:?}", cmd), + } + } + + fn write_commands(&mut self) -> Result<(), ClientError> { + loop { + // Drain the write buffer... + if !drain_buf(&mut self.write_buf, &mut self.socket)? { + return Ok(()); + } + + // ...and encode new command messages into it. + match self.outgoing.try_recv() { + Ok((seq, cmd)) => { + log::debug!("CLIENT [{seq}]: {cmd:?}"); + protocol::encode_command_message( + &mut self.write_buf, + seq, + &cmd, + self.protocol_version, + )?; + } + Err(TryRecvError::Empty) => return Ok(()), + Err(TryRecvError::Disconnected) => return Err(ClientError::Disconnected), + }; + } + } + + fn write_streams(&mut self) -> Result<(), ClientError> { + if !drain_buf(&mut self.write_buf, &mut self.socket)? { + return Ok(()); + } + + let mut state = self.state.lock().unwrap(); + for stream in state.playback_streams.values_mut() { + if stream.done { + continue; + } + + while stream.requested_bytes > 0 { + let requested = stream.requested_bytes; + + self.write_buf + .resize(protocol::DESCRIPTOR_SIZE + requested, 0); + + let waker = futures::task::waker(self.waker.clone()); + let mut cx = Context::from_waker(&waker); + let mut buf = &mut self.write_buf[protocol::DESCRIPTOR_SIZE..]; + let len = match PlaybackSource::poll_read(stream.source.as_mut(), &mut cx, &mut buf) + { + Poll::Ready(0) => { + log::debug!( + "source for stream {} reached EOF", + stream.stream_info.channel + ); + + stream.done = true; + stream.eof_notify.take().map(|done| done.send(())); + self.write_buf.clear(); + break; + } + Poll::Pending => { + self.write_buf.clear(); + break; + } + Poll::Ready(n) => n, + }; + + let len = len.min(requested); + if len == 0 { + log::debug!( + "callback for stream {} returned no data", + stream.stream_info.channel + ); + + self.write_buf.clear(); + break; + } + + log::trace!( + "writing {len} bytes to stream {} (requested {})", + stream.stream_info.channel, + stream.requested_bytes + ); + + self.write_buf.truncate(protocol::DESCRIPTOR_SIZE + len); + stream.requested_bytes -= len; + + let desc = protocol::Descriptor { + length: len as u32, + channel: stream.stream_info.channel, + offset: 0, + flags: DescriptorFlags::empty(), + }; + + protocol::encode_descriptor( + (&mut self.write_buf[..protocol::DESCRIPTOR_SIZE]) + .try_into() + .unwrap(), + &desc, + ); + + if !drain_buf(&mut self.write_buf, &mut self.socket)? { + return Ok(()); + } + } + } + + Ok(()) + } +} + +fn drain_buf(buf: &mut Vec, w: &mut impl io::Write) -> Result { + while !buf.is_empty() { + match w.write(&buf) { + Ok(0) => return Ok(false), + Ok(n) => buf.drain(..n), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(false), + Err(err) => return Err(err), + }; + } + + Ok(true) +} + +fn read_tagstruct( + buf: &mut dyn io::BufRead, + protocol_version: u16, +) -> Result { + protocol::TagStructReader::new(buf, protocol_version) + .read() + .map_err(Into::into) +} diff --git a/src/client/record_sink.rs b/src/client/record_sink.rs new file mode 100644 index 0000000..1f2f867 --- /dev/null +++ b/src/client/record_sink.rs @@ -0,0 +1,280 @@ +use std::{ + collections::VecDeque, + io, + sync::{Arc, Mutex}, + task::{Poll, Waker}, +}; + +use futures::AsyncRead; + +/// An audio sink for a record stream. At its core, this is just a callback +/// that is called whenever the server sends samples for the stream. +/// +/// # Example: using a callback +/// +/// A callback can be used directly as a [RecordSink]. +/// +/// ```no_run +/// # use pulseaudio::*; +/// # let client = Client::from_env(c"client").unwrap(); +/// # let params = protocol::RecordStreamParams::default(); +/// let callback = move |buf: &[u8]| { +/// // Process the audio data somehow. +/// }; +/// +/// # let _ = +/// client.create_record_stream(params, callback); +/// ``` +/// +/// # Example: using RecordBuffer +/// +/// You can use a [RecordBuffer] to integrate with the async ecosystem, as it +/// implements [futures::AsyncRead]. +/// +/// Because of the inversion of control, data must be first written to the +/// buffer as it arrives from the server, and can then be read. This entails +/// an extra copy. +/// +/// ```no_run +/// # use pulseaudio::*; +/// # let client = Client::from_env(c"client").unwrap(); +/// # let params = protocol::RecordStreamParams::default(); +/// // The size we pass determines the maximum amount that will be buffered. +/// let mut buffer = RecordBuffer::new(usize::MAX); +/// +/// # let _ = +/// client.create_record_stream(params, buffer.as_record_sink()); +/// +/// // Now we can read from the buffer. +/// # let mut dst = Vec::new(); +/// # async { +/// use futures::io::AsyncReadExt; +/// buffer.read(&mut dst).await?; +/// # Ok::<(), std::io::Error>(()) +/// # }; +/// ``` +pub trait RecordSink: Send + 'static { + #[allow(missing_docs)] + fn write(&mut self, data: &[u8]); +} + +impl RecordSink for T +where + T: FnMut(&[u8]) + Send + 'static, +{ + fn write(&mut self, data: &[u8]) { + self(data); + } +} + +/// A buffer for adapting a record stream in situations where an implementation +/// of [AsyncRead](futures::io::AsyncRead) is required. +pub struct RecordBuffer { + inner: Arc>, + capacity: usize, +} + +struct InnerRecordBuffer { + buf: VecDeque, + waker: Option, + eof: bool, +} + +impl RecordBuffer { + /// Create a new record buffer with the given capacity. If you created the + /// record stream with a specific set of + /// [Buffer Attributes](protocol::BufferAttr), the capacity should be at + /// least equal to the `max_length` parameter. Alternatively, just pick + /// something reasonably large. + pub fn new(capacity: usize) -> Self { + Self { + inner: Arc::new(Mutex::new(InnerRecordBuffer { + buf: VecDeque::with_capacity(capacity), + waker: None, + eof: false, + })), + capacity, + } + } +} + +impl std::fmt::Debug for RecordBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RecordBuffer") + .field("capacity", &self.capacity) + .finish() + } +} + +impl AsyncRead for RecordBuffer { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut inner = self.inner.lock().unwrap(); + if inner.eof { + return Poll::Ready(Ok(0)); + } + + let (ref mut front, _) = inner.buf.as_slices(); + if front.is_empty() { + inner.waker = match inner.waker.take() { + Some(w) if w.will_wake(cx.waker()) => Some(w), + _ => Some(cx.waker().clone()), + }; + + return Poll::Pending; + } + + let n = io::Read::read(front, buf)?; + inner.buf.drain(..n); + Poll::Ready(Ok(n)) + } +} + +/// A newtype for the Drop implementation, which sets +/// an EOF flag for the reader. +struct RecordBufferSink(Arc>); + +impl Drop for RecordBufferSink { + fn drop(&mut self) { + let mut inner = self.0.lock().unwrap(); + inner.eof = true; + if let Some(w) = inner.waker.take() { + w.wake(); + } + } +} + +impl RecordSink for RecordBufferSink { + fn write(&mut self, data: &[u8]) { + if data.len() == 0 { + return; + } + + let mut inner = self.0.lock().unwrap(); + + let len = inner.buf.len(); + let to_write = data.len(); + let capacity = inner.buf.capacity(); + + if to_write > capacity { + inner.buf.clear(); + inner.buf.extend(&data[..capacity]); + } else if to_write + len > capacity { + inner.buf.drain(..to_write.min(len)); + inner.buf.extend(data); + } else { + inner.buf.extend(data); + } + + if let Some(waker) = inner.waker.take() { + waker.wake(); + } + } +} + +impl RecordBuffer { + /// Creates a type suitable for use as a [RecordSink] when creating a new + /// [RecordStream]. + pub fn as_record_sink(&self) -> impl RecordSink { + RecordBufferSink(self.inner.clone()) + } +} +#[cfg(test)] +mod tests { + use super::*; + use std::{ + pin::Pin, + sync::{Arc, Mutex}, + }; + + #[test] + fn record_buffer_asyncread() { + let mut buffer = RecordBuffer::new(10); + let mut sink = buffer.as_record_sink(); + + sink.write(&[1, 2, 3, 4, 5]); + + let mut read_buf = [0; 3]; + + let waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + + match Pin::new(&mut buffer).poll_read(&mut cx, &mut read_buf) { + Poll::Ready(Ok(n)) => { + assert_eq!(n, 3); + assert_eq!(&read_buf[..n], &[1, 2, 3]); + } + _ => panic!("expected ready"), + } + + match Pin::new(&mut buffer).poll_read(&mut cx, &mut read_buf) { + Poll::Ready(Ok(n)) => { + assert_eq!(n, 2); + assert_eq!(&read_buf[..n], &[4, 5]); + } + _ => panic!("expected ready"), + } + + match Pin::new(&mut buffer).poll_read(&mut cx, &mut read_buf) { + Poll::Pending => (), + _ => panic!("expected pending"), + } + + drop(sink); + + match Pin::new(&mut buffer).poll_read(&mut cx, &mut read_buf) { + Poll::Ready(Ok(n)) => { + assert_eq!(n, 0); + } + _ => panic!("expected ready"), + } + } + + #[test] + fn record_buffer_write() { + let buffer = RecordBuffer { + inner: Arc::new(Mutex::new(InnerRecordBuffer { + buf: VecDeque::with_capacity(10), + waker: None, + eof: false, + })), + capacity: 10, + }; + + let mut sink = buffer.as_record_sink(); + + sink.write(&[1, 2, 3, 4, 5]); + + { + let inner = buffer.inner.lock().unwrap(); + assert_eq!(inner.buf.len(), 5); + } + + sink.write(&[6, 7, 8, 9, 10]); + + { + let inner = buffer.inner.lock().unwrap(); + assert_eq!(inner.buf.len(), 10); + assert_eq!( + inner.buf.as_slices(), + (&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10][..], &[][..]) + ); + } + + sink.write(&[11, 12, 13]); + + { + let mut inner = buffer.inner.lock().unwrap(); + assert_eq!(inner.buf.len(), 10); + + inner.buf.make_contiguous(); + assert_eq!( + inner.buf.as_slices(), + (&[4, 5, 6, 7, 8, 9, 10, 11, 12, 13][..], &[][..]) + ); + } + } +} diff --git a/src/client/record_stream.rs b/src/client/record_stream.rs new file mode 100644 index 0000000..dac8c54 --- /dev/null +++ b/src/client/record_stream.rs @@ -0,0 +1,161 @@ +use std::{ffi::CString, sync::Arc, time}; + +use futures::{channel::oneshot, FutureExt as _}; + +use super::{reactor::ReactorHandle, ClientError, RecordSink, Result as ClientResult}; +use crate::protocol; + +/// A stream of audio data sent from the server to the client, originating from +/// a source. +/// +/// The stream handle can be freely cloned and shared between threads. +#[derive(Clone)] +pub struct RecordStream(Arc); + +struct InnerRecordStream { + handle: ReactorHandle, + info: protocol::CreateRecordStreamReply, + start_notify: futures::future::Shared>, +} + +impl std::fmt::Debug for RecordStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("RecordStream") + .field(&self.0.info.channel) + .finish() + } +} + +impl RecordStream { + pub(super) async fn new( + handle: ReactorHandle, + params: protocol::RecordStreamParams, + sink: impl RecordSink, + ) -> Result { + let (start_tx, start_rx) = oneshot::channel(); + let info = handle + .insert_record_stream(params, sink, Some(start_tx)) + .await?; + + Ok(Self(Arc::new(InnerRecordStream { + handle, + info, + start_notify: start_rx.shared(), + }))) + } + + /// The ID of the stream. + pub fn channel(&self) -> u32 { + self.0.info.channel + } + + /// The attributes of the server-side buffer. + pub fn buffer_attr(&self) -> &protocol::stream::BufferAttr { + &self.0.info.buffer_attr + } + + /// The sample specification for the stream. Can differ from the client's + /// requested sample spec. + pub fn sample_spec(&self) -> &protocol::SampleSpec { + &self.0.info.sample_spec + } + + /// The channel map for the stream. + pub fn channel_map(&self) -> &protocol::ChannelMap { + &self.0.info.channel_map + } + + /// The sink the stream is connected to. + pub fn sink(&self) -> u32 { + self.0.info.sink_index + } + + /// Sets the name of the record stream. + pub async fn set_name(&self, name: CString) -> ClientResult<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::SetRecordStreamName( + protocol::SetStreamNameParams { + index: self.0.info.stream_index, + name, + }, + )) + .await + } + + /// Fetches record timing information for the record stream. + pub async fn timing_info(&self) -> ClientResult { + self.0 + .handle + .roundtrip_reply(protocol::Command::GetRecordLatency( + protocol::LatencyParams { + channel: self.0.info.channel, + now: time::SystemTime::now(), + }, + )) + .await + } + + /// Corks the record stream (temporarily pausing recording). + pub async fn cork(&self) -> ClientResult<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::CorkRecordStream( + protocol::CorkStreamParams { + channel: self.0.info.channel, + cork: true, + }, + )) + .await + } + + /// Uncorks the record stream. + pub async fn uncork(&self) -> ClientResult<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::CorkRecordStream( + protocol::CorkStreamParams { + channel: self.0.info.channel, + cork: false, + }, + )) + .await + } + + /// Returns a future that resolves when the first bytes are written to + /// the stream by the server. + pub async fn started(&self) -> ClientResult<()> { + self.0 + .start_notify + .clone() + .await + .map_err(|_| ClientError::Disconnected) + } + + /// Instructs the server to discard any buffered data. + pub async fn flush(&self) -> super::Result<()> { + self.0 + .handle + .roundtrip_ack(protocol::Command::FlushRecordStream(self.0.info.channel)) + .await + } + + /// Deletes the stream from the server. + pub async fn delete(self) -> ClientResult<()> { + self.0 + .handle + .delete_record_stream(self.0.info.channel) + .await + } +} + +impl Drop for InnerRecordStream { + fn drop(&mut self) { + // Sends the delete command to the server, but doesn't wait for the + // response. + let _ = self + .handle + .delete_record_stream(self.info.channel) + .now_or_never(); + } +} diff --git a/src/lib.rs b/src/lib.rs index b3100af..0533030 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,9 @@ use std::path::PathBuf; +mod client; pub mod protocol; +pub use client::*; /// Attempts to determine the socket path from the runtime environment, checking /// the following locations in order: diff --git a/src/protocol.rs b/src/protocol.rs index dd86519..de8e89a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -5,7 +5,10 @@ mod serde; mod error; -use std::io::{BufRead, Cursor, Read, Seek, SeekFrom, Write}; +use std::{ + ffi::CStr, + io::{BufRead, Cursor, Read, Seek, SeekFrom, Write}, +}; use bitflags::bitflags; use byteorder::NetworkEndian; @@ -29,6 +32,12 @@ pub const DESCRIPTOR_SIZE: usize = 5 * 4; /// for stream data, as well as the maximum buffer size, in bytes. pub const MAX_MEMBLOCKQ_LENGTH: usize = 4 * 1024 * 1024; +/// The protocol uses this sink name to indicate the default sink. +pub const DEFAULT_SINK: &CStr = c"@DEFAULT_SINK@"; + +/// The protocol uses this source name to indicate the default source. +pub const DEFAULT_SOURCE: &CStr = c"@DEFAULT_SOURCE@"; + bitflags! { /// Special message types. #[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]