diff --git a/.gitignore b/.gitignore index 1d00310..c19d4d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ /target .envrc -test.lua +*.lua diff --git a/Cargo.lock b/Cargo.lock index 401708f..96fac98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -704,7 +704,6 @@ dependencies = [ "serde_json", "serde_repr", "tokio", - "tokio-stream", "tracing", "tracing-error", "tracing-subscriber", @@ -1376,17 +1375,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-stream" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - [[package]] name = "tokio-util" version = "0.7.12" diff --git a/Cargo.toml b/Cargo.toml index 526226c..2c2b303 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ serde = { version = "1.0.209", features = ["derive"] } serde_json = "1.0.128" serde_repr = "0.1.19" tokio = { version = "1.40.0", features = ["rt", "sync"] } -tokio-stream = "0.1.16" tracing = "0.1.40" tracing-error = "0.2.0" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/src/main.rs b/src/main.rs index e0e1a4e..3e1a403 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,9 @@ mod worker { pub use server::worker as server; } +pub(crate) static APP_USER_AGENT: &str = + concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); + fn spawn(name: &'static str, task: F) -> Result>> where F: FnOnce() -> Result + Send + 'static, diff --git a/src/types.rs b/src/types.rs index 71e4ff5..57de440 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,42 @@ use std::collections::HashMap; +use mlua::IntoLua; use serde::{Deserialize, Serialize}; +/// Converts a `serde_json::Value` to a `mlua::Value`. +fn json_to_lua(value: serde_json::Value, lua: &mlua::Lua) -> mlua::Result> { + match value { + serde_json::Value::Null => Ok(mlua::Value::Nil), + serde_json::Value::Bool(value) => Ok(mlua::Value::Boolean(value)), + serde_json::Value::Number(value) => match value.as_f64() { + Some(number) => Ok(mlua::Value::Number(number)), + None => Err(mlua::Error::ToLuaConversionError { + from: "serde_json::Value::Number", + to: "Number", + message: Some("Number cannot be represented by a floating-point type".into()), + }), + }, + serde_json::Value::String(value) => lua.create_string(value).map(mlua::Value::String), + serde_json::Value::Array(value) => { + let tbl = lua.create_table_with_capacity(value.len(), 0)?; + for v in value { + let v = json_to_lua(v, lua)?; + tbl.push(v)?; + } + Ok(mlua::Value::Table(tbl)) + } + serde_json::Value::Object(value) => { + let tbl = lua.create_table_with_capacity(0, value.len())?; + for (k, v) in value { + let k = lua.create_string(k)?; + let v = json_to_lua(v, lua)?; + tbl.set(k, v)?; + } + Ok(mlua::Value::Table(tbl)) + } + } +} + #[derive(Clone, Debug, Serialize)] pub(crate) struct WebhookEvent { pub topic: String, @@ -13,9 +48,32 @@ pub(crate) struct WebhookEvent { pub(crate) struct ApiEvent { pub id: String, pub body: serde_json::Value, + pub headers: HashMap>, pub status: u16, } +impl<'lua> IntoLua<'lua> for ApiEvent { + fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result> { + let headers = lua.create_table_with_capacity(0, self.headers.len())?; + for (k, v) in self.headers { + let k = lua.create_string(k)?; + let v = lua.create_string(v)?; + headers.set(k, v)?; + } + + let body = json_to_lua(self.body, lua)?; + let tbl = lua.create_table_with_capacity(0, 4)?; + + lua.create_string(self.id) + .and_then(|id| tbl.set("id", id))?; + tbl.set("status", self.status)?; + tbl.set("headers", headers)?; + tbl.set("body", body)?; + + Ok(mlua::Value::Table(tbl)) + } +} + #[derive(Clone, Debug, Serialize)] pub(crate) struct ErrorEvent { pub id: String, @@ -63,8 +121,15 @@ pub(crate) struct NtfyAction { #[derive(Clone, Debug, Deserialize, Serialize)] pub(crate) struct NtfyMessage { + // Needs to be deserialized from Lua, but should not be sent to Ntfy + #[serde(skip_serializing)] + pub url: String, + // Needs to be deserialized from Lua, but should not be sent to Ntfy + #[serde(skip_serializing)] + pub token: String, pub topic: String, pub message: String, + #[serde(default)] pub title: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tags: Vec, @@ -89,16 +154,22 @@ pub(crate) enum Message { #[derive(Clone, Debug, Deserialize)] pub(crate) enum Method { + #[serde(alias = "GET", alias = "get")] Get, + #[serde(alias = "POST", alias = "post")] Post, } #[derive(Clone, Debug, Deserialize)] pub(crate) struct ApiTask { pub id: String, - pub interval: u64, + #[serde(default)] + pub delay: u64, pub method: Method, pub url: String, pub body: serde_json::Value, + #[serde(default)] pub query: HashMap, + #[serde(default)] + pub headers: HashMap, } diff --git a/src/worker/api.rs b/src/worker/api.rs index 2afdcfb..593b849 100644 --- a/src/worker/api.rs +++ b/src/worker/api.rs @@ -1,17 +1,25 @@ -use std::sync::mpsc::Sender; -use std::sync::Arc; +use std::ascii::AsciiExt; use std::time::Duration; +use std::{collections::HashMap, sync::mpsc::Sender}; use color_eyre::eyre::Context as _; use color_eyre::Result; -use reqwest::Client; +use reqwest::{header, Client}; use tokio::runtime; use tokio::sync::mpsc::UnboundedReceiver; -use tokio::sync::Mutex; -use tokio_stream::wrappers::IntervalStream; -use tokio_stream::{StreamExt as _, StreamMap}; use crate::types::{ApiEvent, ApiTask, Event, Method}; +use crate::APP_USER_AGENT; + +fn header_map_to_hashmap(headers: &header::HeaderMap) -> HashMap> { + let mut map = HashMap::with_capacity(headers.len()); + + for (k, v) in headers { + map.insert(k.to_string(), v.as_bytes().to_vec()); + } + + map +} async fn perform_request(client: &Client, task: &ApiTask) -> Result { let req = match task.method { @@ -22,70 +30,81 @@ async fn perform_request(client: &Client, task: &ApiTask) -> Result { client.post(&task.url).body(body) } }; - let res = req.query(&task.query).send().await?; - let status = res.status().as_u16(); - let body: serde_json::Value = res.json().await?; + let res = req + .query(&task.query) + .headers({ + let mut headers = header::HeaderMap::new(); + + for (k, v) in &task.headers { + // Non-ASCII characters aren't supported anyways for header name, so no need to map + // those. + let k = header::HeaderName::from_lowercase(k.to_ascii_lowercase().as_bytes()) + .wrap_err_with(|| format!("Invalid header name '{}'", k))?; + let v = header::HeaderValue::from_str(v) + .wrap_err_with(|| format!("Invalid header value '{}'", v))?; + + headers.insert(k, v); + } + + headers + }) + .send() + .await?; + let status = res.status(); + let headers = header_map_to_hashmap(res.headers()); + + let body: serde_json::Value = if let Ok(text) = res.text().await { + tracing::trace!("Response body: {}", text); + serde_json::from_str(&text).unwrap_or(serde_json::Value::String(text)) + } else { + tracing::trace!("Response body: NULL"); + serde_json::Value::Null + }; let event = ApiEvent { id: task.id.clone(), + headers, body, - status, + status: status.as_u16(), }; Ok(Event::Api(event)) } #[tracing::instrument(skip_all)] pub async fn run(mut api_rx: UnboundedReceiver, event_tx: Sender) -> Result<()> { - let tasks = Arc::new(Mutex::new(StreamMap::new())); + let (task_tx, mut task_rx) = tokio::sync::mpsc::channel(64); let client = Client::builder() + .user_agent(APP_USER_AGENT) .build() .wrap_err("Failed to build HTTP client")?; tokio::spawn({ - let tasks = tasks.clone(); async move { while let Some(task) = api_rx.recv().await { tracing::trace!("Received new API task: {:?}", task); - let id = task.id.clone(); - let interval = tokio::time::interval(Duration::from_secs(task.interval)); - - let task = Arc::new(task); - let mut tasks = tasks.lock().await; - tasks.insert(id, IntervalStream::new(interval).map(move |_| task.clone())); + let _ = task_tx + .send(async move { + tokio::time::sleep(Duration::from_secs(task.delay)).await; + task + }) + .await; } tracing::error!("API task channel closed"); } }); - loop { - // We need to guarantee that the lock on `tasks` is released as soon as possible. - // To ensure that it isn't held for the entirety of the `sleep`, the indirection - // via the extra `bool` value is used, so that the lock can be dropped immediately, - // and during the `sleep` new streams can be created. - let wait = { - let mut tasks = tasks.lock().await; - match tasks.next().await { - Some((_, task)) => { - let event = perform_request(&client, &task) - .await - .wrap_err("Failed to perform API request")?; - event_tx.send(event).wrap_err("Failed to send event")?; - - false - } - None => true, - } - }; - - if wait { - // Ideally we would be able to wait explicitly for the first stream - // to be registered. But as a workaround, we have to idle wait. - tokio::time::sleep(Duration::from_millis(500)).await; - } + while let Some(task) = task_rx.recv().await { + let task = task.await; + let event = perform_request(&client, &task) + .await + .wrap_err("Failed to perform API request")?; + event_tx.send(event).wrap_err("Failed to send event")?; } + + Ok(()) } #[tracing::instrument("api::worker", skip_all)] diff --git a/src/worker/lua.rs b/src/worker/lua.rs index 9911f3e..003327a 100644 --- a/src/worker/lua.rs +++ b/src/worker/lua.rs @@ -1,7 +1,7 @@ use std::sync::mpsc::Receiver; use color_eyre::Result; -use mlua::{Function, Lua, LuaSerdeExt}; +use mlua::{Function, IntoLua as _, Lua, LuaSerdeExt}; use tokio::sync::mpsc::UnboundedSender; use crate::types::{ApiTask, Event, Message}; @@ -70,7 +70,7 @@ pub fn worker( event_fn.call::<_, ()>(("webhook", data))? } Event::Api(data) => { - let data = lua.to_value(&data)?; + let data = data.into_lua(&lua)?; event_fn.call::<_, ()>(("api", data))? } Event::Error(data) => { diff --git a/src/worker/sender.rs b/src/worker/sender.rs index 79a18be..9288760 100644 --- a/src/worker/sender.rs +++ b/src/worker/sender.rs @@ -1,22 +1,38 @@ use std::sync::mpsc::Sender; -use color_eyre::eyre::{self, Context as _}; +use color_eyre::eyre; use color_eyre::Result; -use reqwest::{header, Client}; +use reqwest::header; +use reqwest::Client; use tokio::runtime; use tokio::sync::mpsc::UnboundedReceiver; use crate::types::{Event, Message, NtfyMessage}; - -static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); +use crate::APP_USER_AGENT; #[tracing::instrument] -async fn send_ntfy(client: &Client, url: &String, data: &NtfyMessage) -> Result<()> { +async fn send_ntfy(client: &Client, data: &NtfyMessage) -> Result<()> { let body = serde_json::to_string(data)?; - tracing::trace!("JSON: {}", body); - let res = client.post(url).body(body).send().await?; + + let res = client + .post(&data.url) + .headers({ + let mut headers = header::HeaderMap::new(); + + let auth = format!("Bearer {}", data.token); + let mut auth = header::HeaderValue::from_str(&auth)?; + auth.set_sensitive(true); + + headers.insert(header::AUTHORIZATION, auth); + + headers + }) + .body(body) + .send() + .await?; + if res.status().as_u16() >= 400 { - let body = res.text().await?; + let body = res.text().await.unwrap_or_default(); eyre::bail!("Ntfy server returned error: {}", body); } @@ -25,30 +41,14 @@ async fn send_ntfy(client: &Client, url: &String, data: &NtfyMessage) -> Result< #[tracing::instrument(skip_all)] pub async fn run(event_tx: Sender, mut ntfy_rx: UnboundedReceiver) -> Result<()> { - let ntfy_url = std::env::var("NTFY_URL").wrap_err("Missing env var 'NTFY_URL'")?; - let ntfy_token = std::env::var("NTFY_TOKEN").wrap_err("Missing env var 'NTFY_TOKEN'")?; - - let ntfy_client = Client::builder() - .default_headers({ - let mut headers = header::HeaderMap::new(); - - let auth = format!("Bearer {}", ntfy_token); - let mut auth = header::HeaderValue::from_str(&auth)?; - auth.set_sensitive(true); - - headers.insert(header::AUTHORIZATION, auth); - - headers - }) - .user_agent(APP_USER_AGENT) - .build()?; + let ntfy_client = Client::builder().user_agent(APP_USER_AGENT).build()?; while let Some(message) = ntfy_rx.recv().await { tracing::trace!("Received notification: {:?}", message); match message { Message::Ntfy(data) => { - if let Err(err) = send_ntfy(&ntfy_client, &ntfy_url, &data).await { + if let Err(err) = send_ntfy(&ntfy_client, &data).await { tracing::error!("Failed to send to Ntfy: {:?}", err); event_tx.send(Event::error(data.topic, err))?; }