From 4438a114370acd93576f26851aa706f45c53eefe Mon Sep 17 00:00:00 2001 From: Lucas Schwiderski Date: Thu, 13 Apr 2023 16:55:59 +0200 Subject: [PATCH] feat: Improve daemon setup Properly clean up on termination and improve thread setup. --- Cargo.lock | 103 +++++++++++++++++++++++++++++++ Cargo.toml | 2 + src/client.rs | 33 ++++++++++ src/daemon/listener.rs | 70 +++++++++++++++++++++ src/daemon/mod.rs | 66 ++++++++++++++++++++ src/daemon/worker.rs | 137 +++++++++++++++++++++++++++++++++++++++++ src/kakoune.rs | 4 ++ src/main.rs | 91 ++++++--------------------- 8 files changed, 435 insertions(+), 71 deletions(-) create mode 100644 src/client.rs create mode 100644 src/daemon/listener.rs create mode 100644 src/daemon/mod.rs create mode 100644 src/daemon/worker.rs create mode 100644 src/kakoune.rs diff --git a/Cargo.lock b/Cargo.lock index b11684d..b7b3ade 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -183,6 +183,73 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" +dependencies = [ + "cfg-if", +] + [[package]] name = "errno" version = "0.3.1" @@ -283,7 +350,9 @@ version = "0.1.0" dependencies = [ "clap", "color-eyre", + "crossbeam", "serde", + "signal-hook", "toml", "tracing", "tracing-error", @@ -332,6 +401,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memoffset" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +dependencies = [ + "autocfg", +] + [[package]] name = "miniz_oxide" version = "0.6.2" @@ -446,6 +524,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "serde" version = "1.0.160" @@ -484,6 +568,25 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "smallvec" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 5e87427..fdf3696 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,9 @@ edition = "2021" [dependencies] clap = { version = "4.2.1", features = ["cargo", "color", "unicode", "std", "derive"] } color-eyre = "0.6.2" +crossbeam = "0.8.2" serde = { version = "1.0.160", features = ["derive"] } +signal-hook = "0.3.15" toml = "0.7.3" tracing = "0.1.37" tracing-error = "0.2.0" diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..1be1a10 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,33 @@ +use std::io::{stdin, Read, Write}; +use std::os::unix::net::UnixStream; +use std::path::PathBuf; + +use color_eyre::eyre::Context; +use color_eyre::Result; + +use crate::Request; + +pub fn handle(runtime_dir: PathBuf) -> Result<()> { + let mut buf = String::new(); + stdin() + .read_to_string(&mut buf) + .wrap_err("Failed to read stdin")?; + + let req: Request = toml::from_str(&buf).wrap_err("Failed to parse request")?; + + tracing::trace!("Received request: {:?}", req); + + let path = runtime_dir.join(&req.session).with_extension("s"); + tracing::debug!(path = %path.display()); + + let mut socket = UnixStream::connect(&path) + .wrap_err_with(|| format!("Failed to connect to daemon socket {}", path.display()))?; + + socket + .write_all(buf.as_bytes()) + .wrap_err("Failed to send request")?; + + tracing::info!("Sent request to {}", path.display()); + + Ok(()) +} diff --git a/src/daemon/listener.rs b/src/daemon/listener.rs new file mode 100644 index 0000000..36e98e7 --- /dev/null +++ b/src/daemon/listener.rs @@ -0,0 +1,70 @@ +use std::fs; +use std::io::ErrorKind; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::{Path, PathBuf}; +use std::thread; + +use color_eyre::eyre::Context; +use color_eyre::Result; +use crossbeam::channel::{unbounded, Receiver}; + +type Task = Result; + +pub struct Listener { + path: PathBuf, + receiver: Receiver, +} + +impl Listener { + pub fn new(runtime_dir: impl AsRef, session: String) -> Result { + let runtime_dir = runtime_dir.as_ref(); + + if let Err(err) = fs::metadata(runtime_dir) { + if err.kind() == ErrorKind::NotFound { + fs::create_dir_all(runtime_dir).wrap_err_with(|| { + format!( + "Failed to create runtime directory {}", + runtime_dir.display() + ) + })?; + } + } + + let path = runtime_dir.join(session).with_extension("s"); + let listener = UnixListener::bind(&path) + .wrap_err_with(|| format!("Failed to bind listener to {}", path.display()))?; + + tracing::info!("Listening on {}", path.display()); + + let (tx, rx) = unbounded(); + + thread::Builder::new() + .name("socker-listener".into()) + .spawn(move || { + for stream in listener.incoming() { + tx.send(stream.map_err(From::from)) + .expect("Failed to send stream"); + } + }) + .wrap_err("Failed to spawn listener thread")?; + + Ok(Self { path, receiver: rx }) + } + + pub fn receiver(&self) -> &Receiver { + &self.receiver + } + + pub fn close(&mut self) -> Result<()> { + tracing::trace!("Removing socket '{}'", self.path.display()); + + fs::remove_file(&self.path) + .wrap_err_with(|| format!("Failed to remove socket file '{}'", self.path.display())) + } +} + +impl Drop for Listener { + fn drop(&mut self) { + self.close().expect("Failed to close Listener") + } +} diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs new file mode 100644 index 0000000..9ffd5e1 --- /dev/null +++ b/src/daemon/mod.rs @@ -0,0 +1,66 @@ +use std::path::PathBuf; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::thread; + +use color_eyre::eyre::Context; +use color_eyre::Result; +use crossbeam::channel::bounded; +use crossbeam::select; +use signal_hook::consts::{SIGINT, TERM_SIGNALS}; +use signal_hook::flag; +use signal_hook::iterator::Signals; + +use crate::daemon::worker::TaskScheduler; + +use self::listener::Listener; + +mod listener; +mod worker; + +#[tracing::instrument] +pub fn handle(runtime_dir: PathBuf, session: String, workers: u8) -> Result<()> { + // Ensure that sending a termination signal twice will immediately kill + { + let term_now = Arc::new(AtomicBool::new(false)); + for sig in TERM_SIGNALS { + flag::register_conditional_shutdown(*sig, 1, Arc::clone(&term_now))?; + flag::register(*sig, Arc::clone(&term_now))?; + } + } + + let rx_signals = { + let mut signals = Signals::new([SIGINT])?; + let (tx, rx) = bounded(1); + + thread::Builder::new() + .name("signal-handler".into()) + .spawn(move || { + for sig in &mut signals { + tx.send(sig).expect("Failed to send signal"); + } + }) + .wrap_err("Failed to start signal handler")?; + + rx + }; + + let mut scheduler = TaskScheduler::new(workers).wrap_err("Failed to create task scheduler")?; + let listener = Listener::new(runtime_dir, session).wrap_err("Failed to create listener")?; + + loop { + select! { + recv(listener.receiver()) -> task => scheduler.schedule(task?), + recv(rx_signals) -> _ => { + tracing::info!( + "Received shutdown signal, waiting for workers to finish. \ + Send termination again to force quit."); + break; + }, + } + } + + scheduler.terminate(); + + Ok(()) +} diff --git a/src/daemon/worker.rs b/src/daemon/worker.rs new file mode 100644 index 0000000..6eff457 --- /dev/null +++ b/src/daemon/worker.rs @@ -0,0 +1,137 @@ +use std::io::Read; +use std::os::unix::net::UnixStream; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; +use std::time::Duration; +use std::{fs, iter}; + +use color_eyre::eyre::{self, Context}; +use color_eyre::Result; +use crossbeam::deque::{Injector, Stealer, Worker}; + +use crate::kakoune::editor_quote; +use crate::Request; + +type Task = Result; + +pub struct TaskScheduler { + injector: Arc>, + terminate: Arc, + threads: Vec>, +} + +impl TaskScheduler { + pub fn new(workers: u8) -> Result { + let terminate = Arc::new(AtomicBool::new(false)); + + let injector = Arc::new(Injector::new()); + let workers: Vec<_> = (0..workers).map(|_| Worker::new_fifo()).collect(); + let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect(); + let stealers = Arc::new(stealers); + + let threads = workers + .into_iter() + .enumerate() + .map(|(i, worker)| { + let injector = injector.clone(); + let stealers = stealers.clone(); + let terminate = terminate.clone(); + + thread::Builder::new() + .name(format!("worker-{}", i)) + .spawn(|| thread_handler(worker, injector, stealers, terminate)) + .map_err(From::from) + }) + .collect::>>() + .wrap_err("Failed to spawn worker threads")?; + + tracing::info!("Started {} worker threads", threads.len()); + + Ok(Self { + injector, + terminate, + threads, + }) + } + + pub fn schedule(&mut self, task: Task) { + self.injector.push(task) + } + + pub fn terminate(self) { + self.terminate.store(true, Ordering::Relaxed); + + for handle in self.threads { + if let Err(err) = handle.join() { + tracing::error!("Worker thread panicked: {:?}", err); + } + } + } +} + +#[tracing::instrument] +fn find_task(local: &Worker, global: &Injector, stealers: &[Stealer]) -> Option { + local.pop().or_else(|| { + iter::repeat_with(|| { + global + .steal_batch_and_pop(local) + .or_else(|| stealers.iter().map(|s| s.steal()).collect()) + }) + .find(|s| !s.is_retry()) + .and_then(|s| s.success()) + }) +} + +#[tracing::instrument] +fn thread_handler( + worker: Worker, + injector: Arc>, + stealers: Arc>>, + terminate: Arc, +) { + loop { + let task = 'find_task: loop { + if terminate.load(Ordering::Relaxed) { + return; + } + + if let Some(task) = find_task(&worker, &injector, &stealers) { + break 'find_task task; + } + + thread::sleep(Duration::from_millis(50)); + }; + + if let Err(err) = handle_connection(task) { + tracing::error!("{:?}", err); + } + } +} + +#[tracing::instrument(skip_all)] +fn handle_connection(task: Task) -> Result<()> { + let mut stream = task.wrap_err("Failed to receive client connection")?; + + let mut buf = String::new(); + stream + .read_to_string(&mut buf) + .wrap_err("Failed to read from connection")?; + let req: Request = toml::from_str(&buf).wrap_err("Failed to parse request")?; + + tracing::info!("Received request: {:?}", req); + + let response = process_request(&req) + .unwrap_or_else(|err| format!("fail {}", editor_quote(format!("{}", err)))); + + tracing::debug!("Sending response:\n{}", response); + + fs::write(&req.fifo, response.as_bytes()).wrap_err("Failed to write to command fifo")?; + + Ok(()) +} + +#[tracing::instrument] +fn process_request(req: &Request) -> Result { + eyre::bail!("Not implemented") +} diff --git a/src/kakoune.rs b/src/kakoune.rs new file mode 100644 index 0000000..1b8dae2 --- /dev/null +++ b/src/kakoune.rs @@ -0,0 +1,4 @@ +pub fn editor_quote(s: impl AsRef) -> String { + // TODO + format!("'{}'", s.as_ref()) +} diff --git a/src/main.rs b/src/main.rs index 0b8bf3e..cad939e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,6 @@ -use std::fs::{self, File}; -use std::io::{stderr, stdin, ErrorKind, Read, Write}; -use std::os::unix::net::{UnixListener, UnixStream}; +use std::fs::File; +use std::io::stderr; use std::path::PathBuf; -use std::thread; use clap::{Parser, Subcommand}; use color_eyre::eyre::Context; @@ -13,6 +11,10 @@ use tracing_error::ErrorLayer; use tracing_subscriber::fmt; use tracing_subscriber::prelude::*; +mod client; +mod daemon; +mod kakoune; + #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Cli { @@ -35,12 +37,18 @@ enum Command { Daemon { /// The Kakoune session this daemon belongs to session: String, + /// The number of worker threads to spawn + #[arg(short = 'n', long, default_value = "2", action = clap::ArgAction::Count)] + workers: u8, }, } #[derive(Clone, Debug, Deserialize)] struct Request { + /// The Kakoune session this request came from session: String, + /// The command FIFO provided by Kakoune + fifo: String, } #[tracing::instrument] @@ -54,7 +62,12 @@ fn main() -> Result<()> { 2 => LevelFilter::DEBUG, _ => LevelFilter::TRACE, }; - let stderr_layer = fmt::layer().pretty().with_writer(stderr); + + let stderr_layer = if cfg!(debug_assertions) { + fmt::layer().pretty().with_writer(stderr).boxed() + } else { + fmt::layer().compact().with_writer(stderr).boxed() + }; let file_layer = if let Some(path) = cli.log { let f = File::create(&path) .wrap_err_with(|| format!("Failed to create log file '{}'", path.display()))?; @@ -78,71 +91,7 @@ fn main() -> Result<()> { .join("kak-highlight"); match cli.command { - Command::Daemon { session } => { - if let Err(err) = fs::metadata(&runtime_dir) { - if err.kind() == ErrorKind::NotFound { - fs::create_dir_all(&runtime_dir).wrap_err_with(|| { - format!( - "Failed to create runtime directory {}", - runtime_dir.display() - ) - })?; - } - } - - let path = runtime_dir.join(session).with_extension("s"); - let listener = UnixListener::bind(&path) - .wrap_err_with(|| format!("Failed to bind listener to {}", path.display()))?; - - tracing::info!("Listening on {}", path.display()); - - for stream in listener.incoming() { - let stream = stream.wrap_err("Failed to accept incoming connection")?; - thread::spawn(|| { - if let Err(err) = handle_connection(stream) { - tracing::error!("{:?}", err); - } - }); - } - - Ok(()) - } - Command::Request => { - let mut buf = String::new(); - stdin() - .read_to_string(&mut buf) - .wrap_err("Failed to read stdin")?; - - let req: Request = toml::from_str(&buf).wrap_err("Failed to parse request")?; - - tracing::trace!("Received request: {:?}", req); - - let path = runtime_dir.join(&req.session).with_extension("s"); - tracing::debug!(path = %path.display()); - - let mut socket = UnixStream::connect(&path).wrap_err_with(|| { - format!("Failed to connect to daemon socket {}", path.display()) - })?; - - socket - .write_all(buf.as_bytes()) - .wrap_err("Failed to send response")?; - - tracing::info!("Sent request to {}", path.display()); - - Ok(()) - } + Command::Daemon { session, workers } => daemon::handle(runtime_dir, session, workers), + Command::Request => client::handle(runtime_dir), } } - -#[tracing::instrument(skip_all)] -fn handle_connection(mut read: impl Read) -> Result<()> { - let mut buf = String::new(); - read.read_to_string(&mut buf) - .wrap_err("Failed to read stdin")?; - let req: Request = toml::from_str(&buf).wrap_err("Failed to parse request")?; - - tracing::info!("Received request: {:?}", req); - - todo!(); -}