//! kak-highlight/src/daemon/worker.rs — worker thread pool for the
//! highlight daemon: accepts client connections as tasks, highlights
//! request content with tree-sitter, and replies via `kak -p`.

use std::collections::{HashMap, VecDeque};
use std::io::{Read, Write};
use std::iter;
use std::os::unix::net::UnixStream;
use std::process::{Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use color_eyre::eyre::{self, Context};
use color_eyre::Result;
use crossbeam::deque::{Injector, Stealer, Worker};
use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter};
use crate::kakoune::{self, editor_quote};
use crate::Request;
/// A unit of work for a worker thread: an accepted client connection,
/// or the error produced while accepting it (the worker surfaces the
/// `Err` case when handling the task — see `handle_connection`).
type Task = Result<UnixStream>;
/// Owns the worker thread pool. Tasks are pushed onto the shared global
/// `injector` queue; `terminate` is the shutdown flag the workers poll;
/// `threads` keeps the join handles so `terminate()` can join them.
pub struct TaskScheduler {
    injector: Arc<Injector<Task>>,
    terminate: Arc<AtomicBool>,
    threads: Vec<JoinHandle<()>>,
}
/// Per-worker state handed to each spawned thread.
pub struct TaskContext {
    // This thread's local FIFO deque of tasks.
    worker: Worker<Task>,
    // Shared global queue new tasks are scheduled onto.
    injector: Arc<Injector<Task>>,
    // Stealers for every worker's local deque (work stealing).
    stealers: Arc<Vec<Stealer<Task>>>,
    // Shutdown flag, polled between tasks.
    terminate: Arc<AtomicBool>,
    // Per-thread tree-sitter highlighter (not shared; it is stateful).
    highlighter: Highlighter,
    // Shared, pre-configured highlight queries.
    highlight_config: Arc<HighlightConfiguration>,
    // Capture name -> value appended after `|` in the range-spec
    // (presumably a Kakoune face name — see `process_request`).
    tokens: Arc<HashMap<String, String>>,
}
impl TaskScheduler {
    /// Builds the scheduler: compiles the tree-sitter Rust highlight
    /// configuration once, then spawns `workers` work-stealing threads
    /// that share it (and the termination flag) via `Arc`.
    ///
    /// `tokens` maps tree-sitter highlight capture names to the values
    /// emitted in the range-spec responses (see `process_request`).
    ///
    /// # Errors
    /// Fails if the highlight configuration is invalid or a worker
    /// thread cannot be spawned.
    pub fn new(workers: u8, tokens: HashMap<String, String>) -> Result<Self> {
        let terminate = Arc::new(AtomicBool::new(false));
        let injector = Arc::new(Injector::new());
        let workers: Vec<_> = (0..workers).map(|_| Worker::new_fifo()).collect();
        let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect();
        let stealers = Arc::new(stealers);
        let mut highlight_config = HighlightConfiguration::new(
            tree_sitter_rust::language(),
            tree_sitter_rust::HIGHLIGHT_QUERY,
            "",
            "",
        )
        .wrap_err("Invalid highlighter config")?;
        // NOTE: `configure` fixes the capture-name -> highlight-index
        // mapping in this iteration order. `process_request` re-collects
        // the keys of the very same (unmodified) map behind the `Arc`,
        // relying on `HashMap` iteration order being stable for an
        // unmodified instance so the indices line up.
        let names: Vec<_> = tokens.keys().collect();
        tracing::debug!("Highlighter tokens: {:?}", names);
        highlight_config.configure(&names);
        let highlight_config = Arc::new(highlight_config);
        let tokens = Arc::new(tokens);
        let threads = workers
            .into_iter()
            .enumerate()
            .map(|(i, worker)| {
                let ctx = TaskContext {
                    worker,
                    injector: injector.clone(),
                    stealers: stealers.clone(),
                    terminate: terminate.clone(),
                    // Highlighter is per-thread; the config is shared.
                    highlighter: Highlighter::new(),
                    highlight_config: highlight_config.clone(),
                    tokens: tokens.clone(),
                };
                thread::Builder::new()
                    .name(format!("worker-{}", i))
                    .spawn(|| thread_handler(ctx))
                    .map_err(From::from)
            })
            .collect::<Result<Vec<_>>>()
            .wrap_err("Failed to spawn worker threads")?;
        tracing::info!("Started {} worker threads", threads.len());
        Ok(Self {
            injector,
            terminate,
            threads,
        })
    }

    /// Queues a task on the global injector; an idle worker steals it.
    ///
    /// Takes `&self` (relaxed from `&mut self`): `Injector::push` only
    /// needs a shared reference, so tasks can be scheduled from shared
    /// contexts. Existing `&mut` callers keep compiling.
    pub fn schedule(&self, task: Task) {
        self.injector.push(task)
    }

    /// Flags every worker to stop, then joins them all, logging (rather
    /// than propagating) any worker panic. Workers poll the flag between
    /// tasks, so a task already being handled runs to completion.
    pub fn terminate(self) {
        self.terminate.store(true, Ordering::Relaxed);
        for handle in self.threads {
            if let Err(err) = handle.join() {
                tracing::error!("Worker thread panicked: {:?}", err);
            }
        }
    }
}
/// Fetches the next task: first from this thread's local deque, then by
/// batch-stealing from the global injector, then from sibling workers.
/// Returns `None` only when every source reports definitively empty.
#[tracing::instrument]
fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
    if let Some(task) = local.pop() {
        return Some(task);
    }
    // Retry while any steal attempt lost a race (`Steal::Retry`); stop as
    // soon as we either obtain a task or every queue is empty.
    loop {
        let attempt = global
            .steal_batch_and_pop(local)
            .or_else(|| stealers.iter().map(|s| s.steal()).collect());
        if !attempt.is_retry() {
            return attempt.success();
        }
    }
}
/// Worker thread main loop: poll for tasks until termination is flagged,
/// handling each connection and logging (not propagating) its errors.
#[tracing::instrument(skip_all)]
fn thread_handler(mut ctx: TaskContext) {
    while !ctx.terminate.load(Ordering::Relaxed) {
        match find_task(&ctx.worker, &ctx.injector, &ctx.stealers) {
            Some(task) => {
                if let Err(err) = handle_connection(&mut ctx, task) {
                    tracing::error!("{:?}", err);
                }
            }
            // Nothing to do right now; back off briefly before re-polling.
            None => thread::sleep(Duration::from_millis(50)),
        }
    }
}
/// Handles one client connection end-to-end: read the TOML request,
/// highlight it, and send the resulting Kakoune command back to the
/// requesting session/client via `kak -p`.
///
/// # Errors
/// Fails on connection/read errors, request parse errors, or any failure
/// spawning, writing to, or waiting on the `kak` process. Highlighting
/// errors are NOT propagated — they are converted into a `fail …`
/// command so the editor shows the error.
#[tracing::instrument(skip_all)]
fn handle_connection(ctx: &mut TaskContext, task: Task) -> Result<()> {
    let mut stream = task.wrap_err("Failed to receive client connection")?;
    let mut buf = String::new();
    stream
        .read_to_string(&mut buf)
        .wrap_err("Failed to read from connection")?;
    let req: Request = toml::from_str(&buf).wrap_err("Failed to parse request")?;
    tracing::info!("Received request");
    tracing::debug!(?req);
    let response = process_request(ctx, &req)
        .unwrap_or_else(|err| format!("fail {}", editor_quote(format!("{}", err))));
    let mut child = Command::new("kak")
        .args(["-p", &req.session])
        .stdin(Stdio::piped())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .spawn()
        .wrap_err("Failed to spawn Kakoune command")?;
    {
        // `stdin(Stdio::piped())` guarantees this handle exists.
        let mut stdin = child
            .stdin
            .take()
            .ok_or_else(|| eyre::eyre!("Failed to get stdin for Kakoune command"))?;
        let command = format!(
            "evaluate-commands -client {} -verbatim -- {}",
            req.client, response
        );
        tracing::info!("Writing response");
        tracing::debug!(command);
        stdin
            .write_all(command.as_bytes())
            .wrap_err("Failed to write to Kakoune stdin")?;
        // `stdin` drops here, closing the pipe so `kak -p` sees EOF.
    }
    // BUG FIX: the child was previously never waited on, leaving a
    // zombie process per request. Reap it and surface a failed exit.
    let status = child.wait().wrap_err("Failed to wait for Kakoune command")?;
    if !status.success() {
        eyre::bail!("Kakoune command exited with {}", status);
    }
    Ok(())
}
#[tracing::instrument(skip(ctx, req), fields(
session = req.session,
client = req.client,
content_len = req.content.len(),
timestamp = req.timestamp,
))]
fn process_request(ctx: &mut TaskContext, req: &Request) -> Result<String> {
let names: Vec<_> = ctx.tokens.keys().collect();
let highlights = ctx
.highlighter
.highlight(&ctx.highlight_config, req.content.as_bytes(), None, |_| {
None
})
.wrap_err("Failed to highlight content")?;
let mut stack = VecDeque::new();
let mut range_spec = String::new();
for res in highlights {
match res? {
HighlightEvent::Source { start, end } => {
if let Some(index) = stack.back() {
// Tree-sitter actually returns the byte position after the token
// as `end` here.
let end = end.saturating_sub(1);
let range =
kakoune::range_from_byte_offsets(req.content.as_bytes(), start, end);
tracing::trace!(start, end, ?range, index);
let spec = format!(
"{}.{},{}.{}|{}",
range.start_point.row,
range.start_point.column,
range.end_point.row,
range.end_point.column,
ctx.tokens[names[*index]]
);
range_spec.push(' ');
range_spec.push_str(&spec);
}
}
HighlightEvent::HighlightStart(index) => {
stack.push_back(index.0);
}
HighlightEvent::HighlightEnd => {
// Tree-sitter shouldn't call this when there is nothing on the stack,
// but it wouldn't matter anyways.
let _ = stack.pop_back();
}
}
}
let response = format!(
"set-option buffer kak_highlight_ranges {}{range_spec}",
req.timestamp
);
Ok(response)
}