From d2cb39f9a24cf3b25743619084652f60f83e13ff Mon Sep 17 00:00:00 2001
From: Lucas Schwiderski
Date: Wed, 18 Sep 2024 11:31:10 +0200
Subject: [PATCH] Rework API tasks

GitHub expects a 'Last-Modified' header to be sent, and an
'X-Poll-Interval' header to be honored, for their notifications
endpoint. Other services might also have certain limitations that
require customizing every API query individually.

Since that's not possible when API tasks are configured once and then
run on a fixed interval, this reworks them so that the config has to
trigger every query individually. A `delay` parameter allows
re-creating the same intervals that were possible before.

This also moves the configuration for Ntfy to the Lua file.
---
 .gitignore           |   2 +-
 Cargo.lock           |  12 -----
 Cargo.toml           |   1 -
 src/main.rs          |   3 ++
 src/types.rs         |  73 +++++++++++++++++++++++++++++-
 src/worker/api.rs    | 105 +++++++++++++++++++++++++------------------
 src/worker/lua.rs    |   4 +-
 src/worker/sender.rs |  52 ++++++++++-----------
 8 files changed, 166 insertions(+), 86 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1d00310..c19d4d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
 /target
 .envrc
-test.lua
+*.lua
diff --git a/Cargo.lock b/Cargo.lock
index 401708f..96fac98 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -704,7 +704,6 @@ dependencies = [
  "serde_json",
  "serde_repr",
  "tokio",
- "tokio-stream",
  "tracing",
  "tracing-error",
  "tracing-subscriber",
@@ -1376,17 +1375,6 @@ dependencies = [
  "tokio",
 ]
 
-[[package]]
-name = "tokio-stream"
-version = "0.1.16"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
-dependencies = [
- "futures-core",
- "pin-project-lite",
- "tokio",
-]
-
 [[package]]
 name = "tokio-util"
 version = "0.7.12"
diff --git a/Cargo.toml b/Cargo.toml
index 526226c..2c2b303 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,7 +14,6 @@ serde = { version = "1.0.209", features = ["derive"] }
 serde_json = "1.0.128"
 serde_repr = "0.1.19"
 tokio = { version = "1.40.0", features = ["rt", "sync"] }
-tokio-stream = "0.1.16"
 tracing = "0.1.40"
 tracing-error = "0.2.0"
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
diff --git a/src/main.rs b/src/main.rs
index e0e1a4e..3e1a403 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -21,6 +21,9 @@ mod worker {
     pub use server::worker as server;
 }
 
+pub(crate) static APP_USER_AGENT: &str =
+    concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
+
 fn spawn<F, T>(name: &'static str, task: F) -> Result<JoinHandle<Result<T>>>
 where
     F: FnOnce() -> Result<T> + Send + 'static,
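
Note on the new constant: APP_USER_AGENT expands to "<crate name>/<crate version>" from the Cargo metadata and is shared by the HTTP clients built in src/worker/api.rs and src/worker/sender.rs below. A minimal sketch of how it is meant to be applied; the build_client helper is illustrative and not part of the patch:

    static APP_USER_AGENT: &str =
        concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

    // Every outgoing client identifies itself with the same user agent.
    fn build_client() -> reqwest::Result<reqwest::Client> {
        reqwest::Client::builder()
            .user_agent(APP_USER_AGENT)
            .build()
    }

    fn main() -> reqwest::Result<()> {
        let _client = build_client()?;
        println!("user agent: {APP_USER_AGENT}");
        Ok(())
    }
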
diff --git a/src/types.rs b/src/types.rs
index 71e4ff5..57de440 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -1,7 +1,42 @@
 use std::collections::HashMap;
 
+use mlua::IntoLua;
 use serde::{Deserialize, Serialize};
 
+/// Converts a `serde_json::Value` to a `mlua::Value`.
+fn json_to_lua(value: serde_json::Value, lua: &mlua::Lua) -> mlua::Result<mlua::Value<'_>> {
+    match value {
+        serde_json::Value::Null => Ok(mlua::Value::Nil),
+        serde_json::Value::Bool(value) => Ok(mlua::Value::Boolean(value)),
+        serde_json::Value::Number(value) => match value.as_f64() {
+            Some(number) => Ok(mlua::Value::Number(number)),
+            None => Err(mlua::Error::ToLuaConversionError {
+                from: "serde_json::Value::Number",
+                to: "Number",
+                message: Some("Number cannot be represented by a floating-point type".into()),
+            }),
+        },
+        serde_json::Value::String(value) => lua.create_string(value).map(mlua::Value::String),
+        serde_json::Value::Array(value) => {
+            let tbl = lua.create_table_with_capacity(value.len(), 0)?;
+            for v in value {
+                let v = json_to_lua(v, lua)?;
+                tbl.push(v)?;
+            }
+            Ok(mlua::Value::Table(tbl))
+        }
+        serde_json::Value::Object(value) => {
+            let tbl = lua.create_table_with_capacity(0, value.len())?;
+            for (k, v) in value {
+                let k = lua.create_string(k)?;
+                let v = json_to_lua(v, lua)?;
+                tbl.set(k, v)?;
+            }
+            Ok(mlua::Value::Table(tbl))
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize)]
 pub(crate) struct WebhookEvent {
     pub topic: String,
@@ -13,9 +48,32 @@ pub(crate) struct ApiEvent {
     pub id: String,
     pub body: serde_json::Value,
+    pub headers: HashMap<String, Vec<u8>>,
     pub status: u16,
 }
 
+impl<'lua> IntoLua<'lua> for ApiEvent {
+    fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> {
+        let headers = lua.create_table_with_capacity(0, self.headers.len())?;
+        for (k, v) in self.headers {
+            let k = lua.create_string(k)?;
+            let v = lua.create_string(v)?;
+            headers.set(k, v)?;
+        }
+
+        let body = json_to_lua(self.body, lua)?;
+        let tbl = lua.create_table_with_capacity(0, 4)?;
+
+        lua.create_string(self.id)
+            .and_then(|id| tbl.set("id", id))?;
+        tbl.set("status", self.status)?;
+        tbl.set("headers", headers)?;
+        tbl.set("body", body)?;
+
+        Ok(mlua::Value::Table(tbl))
+    }
+}
+
 #[derive(Clone, Debug, Serialize)]
 pub(crate) struct ErrorEvent {
     pub id: String,
@@ -63,8 +121,15 @@ pub(crate) struct NtfyAction {
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub(crate) struct NtfyMessage {
+    // Needs to be deserialized from Lua, but should not be sent to Ntfy
+    #[serde(skip_serializing)]
+    pub url: String,
+    // Needs to be deserialized from Lua, but should not be sent to Ntfy
+    #[serde(skip_serializing)]
+    pub token: String,
     pub topic: String,
     pub message: String,
+    #[serde(default)]
     pub title: Option<String>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tags: Vec<String>,
@@ -89,16 +154,22 @@ pub(crate) enum Message {
 
 #[derive(Clone, Debug, Deserialize)]
 pub(crate) enum Method {
+    #[serde(alias = "GET", alias = "get")]
     Get,
+    #[serde(alias = "POST", alias = "post")]
     Post,
 }
 
 #[derive(Clone, Debug, Deserialize)]
 pub(crate) struct ApiTask {
     pub id: String,
-    pub interval: u64,
+    #[serde(default)]
+    pub delay: u64,
     pub method: Method,
     pub url: String,
     pub body: serde_json::Value,
+    #[serde(default)]
     pub query: HashMap<String, String>,
+    #[serde(default)]
+    pub headers: HashMap<String, String>,
 }
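
Note on the reworked ApiTask: with #[serde(default)] on delay, query, and headers, plus the GET/get and POST/post aliases on Method, a task only has to provide id, method, url, and body. The sketch below exercises the same serde attributes through JSON with trimmed-down copies of the types; in the running program tasks come from the Lua config instead, and the id, URL, and header shown here are made-up values:

    use std::collections::HashMap;

    use serde::Deserialize;

    // Trimmed-down copies of Method and ApiTask from src/types.rs, for illustration only.
    #[derive(Debug, Deserialize)]
    enum Method {
        #[serde(alias = "GET", alias = "get")]
        Get,
        #[serde(alias = "POST", alias = "post")]
        Post,
    }

    #[derive(Debug, Deserialize)]
    struct ApiTask {
        id: String,
        #[serde(default)]
        delay: u64,
        method: Method,
        url: String,
        body: serde_json::Value,
        #[serde(default)]
        query: HashMap<String, String>,
        #[serde(default)]
        headers: HashMap<String, String>,
    }

    fn main() -> serde_json::Result<()> {
        // delay, query, and headers may be omitted; "get" matches through the alias.
        let task: ApiTask = serde_json::from_str(
            r#"{
                "id": "github-notifications",
                "method": "get",
                "url": "https://api.github.com/notifications",
                "body": null,
                "headers": { "accept": "application/vnd.github+json" }
            }"#,
        )?;
        assert_eq!(task.delay, 0);
        assert!(task.query.is_empty());
        println!("{task:?}");
        Ok(())
    }
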
diff --git a/src/worker/api.rs b/src/worker/api.rs
index 2afdcfb..593b849 100644
--- a/src/worker/api.rs
+++ b/src/worker/api.rs
@@ -1,17 +1,25 @@
-use std::sync::mpsc::Sender;
-use std::sync::Arc;
+use std::ascii::AsciiExt;
 use std::time::Duration;
+use std::{collections::HashMap, sync::mpsc::Sender};
 
 use color_eyre::eyre::Context as _;
 use color_eyre::Result;
-use reqwest::Client;
+use reqwest::{header, Client};
 use tokio::runtime;
 use tokio::sync::mpsc::UnboundedReceiver;
-use tokio::sync::Mutex;
-use tokio_stream::wrappers::IntervalStream;
-use tokio_stream::{StreamExt as _, StreamMap};
 
 use crate::types::{ApiEvent, ApiTask, Event, Method};
+use crate::APP_USER_AGENT;
+
+fn header_map_to_hashmap(headers: &header::HeaderMap) -> HashMap<String, Vec<u8>> {
+    let mut map = HashMap::with_capacity(headers.len());
+
+    for (k, v) in headers {
+        map.insert(k.to_string(), v.as_bytes().to_vec());
+    }
+
+    map
+}
 
 async fn perform_request(client: &Client, task: &ApiTask) -> Result<Event> {
     let req = match task.method {
@@ -22,70 +30,81 @@ async fn perform_request(client: &Client, task: &ApiTask) -> Result<Event> {
             client.post(&task.url).body(body)
         }
     };
-    let res = req.query(&task.query).send().await?;
-    let status = res.status().as_u16();
-    let body: serde_json::Value = res.json().await?;
+    let res = req
+        .query(&task.query)
+        .headers({
+            let mut headers = header::HeaderMap::new();
+
+            for (k, v) in &task.headers {
+                // Non-ASCII characters aren't supported anyways for header name, so no need to map
+                // those.
+                let k = header::HeaderName::from_lowercase(k.to_ascii_lowercase().as_bytes())
+                    .wrap_err_with(|| format!("Invalid header name '{}'", k))?;
+                let v = header::HeaderValue::from_str(v)
+                    .wrap_err_with(|| format!("Invalid header value '{}'", v))?;
+
+                headers.insert(k, v);
+            }
+
+            headers
+        })
+        .send()
+        .await?;
+    let status = res.status();
+    let headers = header_map_to_hashmap(res.headers());
+
+    let body: serde_json::Value = if let Ok(text) = res.text().await {
+        tracing::trace!("Response body: {}", text);
+        serde_json::from_str(&text).unwrap_or(serde_json::Value::String(text))
+    } else {
+        tracing::trace!("Response body: NULL");
+        serde_json::Value::Null
+    };
 
     let event = ApiEvent {
         id: task.id.clone(),
+        headers,
         body,
-        status,
+        status: status.as_u16(),
     };
 
     Ok(Event::Api(event))
 }
 
 #[tracing::instrument(skip_all)]
 pub async fn run(mut api_rx: UnboundedReceiver<ApiTask>, event_tx: Sender<Event>) -> Result<()> {
-    let tasks = Arc::new(Mutex::new(StreamMap::new()));
+    let (task_tx, mut task_rx) = tokio::sync::mpsc::channel(64);
 
     let client = Client::builder()
+        .user_agent(APP_USER_AGENT)
        .build()
        .wrap_err("Failed to build HTTP client")?;
 
     tokio::spawn({
-        let tasks = tasks.clone();
         async move {
             while let Some(task) = api_rx.recv().await {
                 tracing::trace!("Received new API task: {:?}", task);
-                let id = task.id.clone();
-                let interval = tokio::time::interval(Duration::from_secs(task.interval));
-
-                let task = Arc::new(task);
-                let mut tasks = tasks.lock().await;
-                tasks.insert(id, IntervalStream::new(interval).map(move |_| task.clone()));
+                let _ = task_tx
+                    .send(async move {
+                        tokio::time::sleep(Duration::from_secs(task.delay)).await;
+                        task
+                    })
+                    .await;
             }
 
             tracing::error!("API task channel closed");
         }
     });
 
-    loop {
-        // We need to guarantee that the lock on `tasks` is released as soon as possible.
-        // To ensure that it isn't held for the entirety of the `sleep`, the indirection
-        // via the extra `bool` value is used, so that the lock can be dropped immediately,
-        // and during the `sleep` new streams can be created.
-        let wait = {
-            let mut tasks = tasks.lock().await;
-            match tasks.next().await {
-                Some((_, task)) => {
-                    let event = perform_request(&client, &task)
-                        .await
-                        .wrap_err("Failed to perform API request")?;
-                    event_tx.send(event).wrap_err("Failed to send event")?;
-
-                    false
-                }
-                None => true,
-            }
-        };
-
-        if wait {
-            // Ideally we would be able to wait explicitly for the first stream
-            // to be registered. But as a workaround, we have to idle wait.
-            tokio::time::sleep(Duration::from_millis(500)).await;
-        }
+    while let Some(task) = task_rx.recv().await {
+        let task = task.await;
+        let event = perform_request(&client, &task)
+            .await
+            .wrap_err("Failed to perform API request")?;
+        event_tx.send(event).wrap_err("Failed to send event")?;
     }
+
+    Ok(())
 }
 
 #[tracing::instrument("api::worker", skip_all)]
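
Note on the delay mechanism: the point of carrying response headers on ApiEvent is that the config can pick the next delay based on what the service reports, for example GitHub's X-Poll-Interval mentioned in the commit message. In the reworked worker that decision happens in the Lua config when it re-submits the task; the sketch below only shows the parsing step, and next_delay is a hypothetical helper, not part of the patch. Header names come out of reqwest's HeaderMap lower-cased and the values arrive as raw bytes, matching the HashMap<String, Vec<u8>> built by header_map_to_hashmap:

    use std::collections::HashMap;

    // Hypothetical helper: derive the next task delay from the headers of the
    // previous response, falling back to a default when the header is missing
    // or not a number.
    fn next_delay(headers: &HashMap<String, Vec<u8>>, fallback: u64) -> u64 {
        headers
            .get("x-poll-interval")
            .and_then(|v| std::str::from_utf8(v).ok())
            .and_then(|v| v.trim().parse::<u64>().ok())
            .unwrap_or(fallback)
    }

    fn main() {
        let mut headers = HashMap::new();
        headers.insert("x-poll-interval".to_string(), b"60".to_vec());
        // A re-submitted ApiTask would use this value as its delay.
        assert_eq!(next_delay(&headers, 30), 60);
    }
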
diff --git a/src/worker/lua.rs b/src/worker/lua.rs
index 9911f3e..003327a 100644
--- a/src/worker/lua.rs
+++ b/src/worker/lua.rs
@@ -1,7 +1,7 @@
 use std::sync::mpsc::Receiver;
 
 use color_eyre::Result;
-use mlua::{Function, Lua, LuaSerdeExt};
+use mlua::{Function, IntoLua as _, Lua, LuaSerdeExt};
 use tokio::sync::mpsc::UnboundedSender;
 
 use crate::types::{ApiTask, Event, Message};
@@ -70,7 +70,7 @@ pub fn worker(
                 event_fn.call::<_, ()>(("webhook", data))?
             }
             Event::Api(data) => {
-                let data = lua.to_value(&data)?;
+                let data = data.into_lua(&lua)?;
                 event_fn.call::<_, ()>(("api", data))?
             }
             Event::Error(data) => {
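
Note on the Lua side: the calling convention is unchanged, the worker still looks up the configured handler and calls it with an event kind plus a data table; only the table for "api" events is now built by the IntoLua impl from src/types.rs instead of going through serde. A self-contained mlua sketch of that convention; the on_event name and all field values are made up, the real handler comes from the user's Lua config:

    use mlua::{Function, Lua};

    fn main() -> mlua::Result<()> {
        let lua = Lua::new();

        // Hypothetical handler; in the real worker this is defined by the Lua config.
        lua.load(
            r#"
            function on_event(kind, data)
                print(kind, data.status, data.headers["x-poll-interval"])
            end
            "#,
        )
        .exec()?;

        // A table shaped roughly like the one ApiEvent::into_lua produces.
        let headers = lua.create_table()?;
        headers.set("x-poll-interval", "60")?;

        let data = lua.create_table()?;
        data.set("id", "github-notifications")?;
        data.set("status", 200)?;
        data.set("headers", headers)?;

        // Same calling convention as src/worker/lua.rs: event kind plus data table.
        let event_fn: Function = lua.globals().get("on_event")?;
        event_fn.call::<_, ()>(("api", data))
    }
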
diff --git a/src/worker/sender.rs b/src/worker/sender.rs
index 79a18be..9288760 100644
--- a/src/worker/sender.rs
+++ b/src/worker/sender.rs
@@ -1,22 +1,38 @@
 use std::sync::mpsc::Sender;
 
-use color_eyre::eyre::{self, Context as _};
+use color_eyre::eyre;
 use color_eyre::Result;
-use reqwest::{header, Client};
+use reqwest::header;
+use reqwest::Client;
 use tokio::runtime;
 use tokio::sync::mpsc::UnboundedReceiver;
 
 use crate::types::{Event, Message, NtfyMessage};
-
-static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
+use crate::APP_USER_AGENT;
 
 #[tracing::instrument]
-async fn send_ntfy(client: &Client, url: &String, data: &NtfyMessage) -> Result<()> {
+async fn send_ntfy(client: &Client, data: &NtfyMessage) -> Result<()> {
     let body = serde_json::to_string(data)?;
-    tracing::trace!("JSON: {}", body);
-    let res = client.post(url).body(body).send().await?;
+
+    let res = client
+        .post(&data.url)
+        .headers({
+            let mut headers = header::HeaderMap::new();
+
+            let auth = format!("Bearer {}", data.token);
+            let mut auth = header::HeaderValue::from_str(&auth)?;
+            auth.set_sensitive(true);
+
+            headers.insert(header::AUTHORIZATION, auth);
+
+            headers
+        })
+        .body(body)
+        .send()
+        .await?;
+
     if res.status().as_u16() >= 400 {
-        let body = res.text().await?;
+        let body = res.text().await.unwrap_or_default();
         eyre::bail!("Ntfy server returned error: {}", body);
     }
 
@@ -25,30 +41,14 @@
 
 #[tracing::instrument(skip_all)]
 pub async fn run(event_tx: Sender<Event>, mut ntfy_rx: UnboundedReceiver<Message>) -> Result<()> {
-    let ntfy_url = std::env::var("NTFY_URL").wrap_err("Missing env var 'NTFY_URL'")?;
-    let ntfy_token = std::env::var("NTFY_TOKEN").wrap_err("Missing env var 'NTFY_TOKEN'")?;
-
-    let ntfy_client = Client::builder()
-        .default_headers({
-            let mut headers = header::HeaderMap::new();
-
-            let auth = format!("Bearer {}", ntfy_token);
-            let mut auth = header::HeaderValue::from_str(&auth)?;
-            auth.set_sensitive(true);
-
-            headers.insert(header::AUTHORIZATION, auth);
-
-            headers
-        })
-        .user_agent(APP_USER_AGENT)
-        .build()?;
+    let ntfy_client = Client::builder().user_agent(APP_USER_AGENT).build()?;
 
     while let Some(message) = ntfy_rx.recv().await {
         tracing::trace!("Received notification: {:?}", message);
 
         match message {
             Message::Ntfy(data) => {
-                if let Err(err) = send_ntfy(&ntfy_client, &ntfy_url, &data).await {
+                if let Err(err) = send_ntfy(&ntfy_client, &data).await {
                     tracing::error!("Failed to send to Ntfy: {:?}", err);
                     event_tx.send(Event::error(data.topic, err))?;
                 }
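
Note on the Ntfy changes: since the Ntfy settings now come from the Lua file, every NtfyMessage carries its own url and token, and the two #[serde(skip_serializing)] fields keep both out of the JSON body that send_ntfy posts. The token is only used for the Authorization header, which set_sensitive(true) also hides from debug output. A sketch with a trimmed-down copy of the struct; all field values are made up:

    use serde::Serialize;

    // Trimmed-down copy of NtfyMessage from src/types.rs, for illustration only.
    #[derive(Serialize)]
    struct NtfyMessage {
        // Read from the Lua config, but never sent to the Ntfy server.
        #[serde(skip_serializing)]
        url: String,
        #[serde(skip_serializing)]
        token: String,
        topic: String,
        message: String,
    }

    fn main() -> serde_json::Result<()> {
        let msg = NtfyMessage {
            url: "https://ntfy.example.com".into(),
            token: "tk_secret".into(),
            topic: "builds".into(),
            message: "API task failed".into(),
        };

        // The URL is the request target, not part of the payload.
        println!("would POST to {}", msg.url);

        // Only topic and message end up in the request body.
        assert_eq!(
            serde_json::to_string(&msg)?,
            r#"{"topic":"builds","message":"API task failed"}"#
        );
        Ok(())
    }
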