generated from lucas/rust-template
Rework API tasks
GitHub expects a 'Last-Modified' header, and honoring an 'X-Poll-Interval' header for their notifications endpoint. Other services might also have certain limitations that require customizing every API query individually. Since that's not possible if API tasks are configured once and run off of an interval, this reworks them so that the config needs to trigger every query individually. A `delay` parameter allows re-creating the same intervals that were possible before. This also moves the configuration for Ntfy to the Lua file.
This commit is contained in:
parent
a02bf39f27
commit
d2cb39f9a2
8 changed files with 166 additions and 86 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,3 +1,3 @@
|
|||
/target
|
||||
.envrc
|
||||
test.lua
|
||||
*.lua
|
||||
|
|
12
Cargo.lock
generated
12
Cargo.lock
generated
|
@ -704,7 +704,6 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_repr",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-error",
|
||||
"tracing-subscriber",
|
||||
|
@ -1376,17 +1375,6 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.12"
|
||||
|
|
|
@ -14,7 +14,6 @@ serde = { version = "1.0.209", features = ["derive"] }
|
|||
serde_json = "1.0.128"
|
||||
serde_repr = "0.1.19"
|
||||
tokio = { version = "1.40.0", features = ["rt", "sync"] }
|
||||
tokio-stream = "0.1.16"
|
||||
tracing = "0.1.40"
|
||||
tracing-error = "0.2.0"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
|
|
|
@ -21,6 +21,9 @@ mod worker {
|
|||
pub use server::worker as server;
|
||||
}
|
||||
|
||||
pub(crate) static APP_USER_AGENT: &str =
|
||||
concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
|
||||
|
||||
fn spawn<T, F>(name: &'static str, task: F) -> Result<JoinHandle<Result<T>>>
|
||||
where
|
||||
F: FnOnce() -> Result<T> + Send + 'static,
|
||||
|
|
73
src/types.rs
73
src/types.rs
|
@ -1,7 +1,42 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use mlua::IntoLua;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Converts a `serde_json::Value` to a `mlua::Value`.
|
||||
fn json_to_lua(value: serde_json::Value, lua: &mlua::Lua) -> mlua::Result<mlua::Value<'_>> {
|
||||
match value {
|
||||
serde_json::Value::Null => Ok(mlua::Value::Nil),
|
||||
serde_json::Value::Bool(value) => Ok(mlua::Value::Boolean(value)),
|
||||
serde_json::Value::Number(value) => match value.as_f64() {
|
||||
Some(number) => Ok(mlua::Value::Number(number)),
|
||||
None => Err(mlua::Error::ToLuaConversionError {
|
||||
from: "serde_json::Value::Number",
|
||||
to: "Number",
|
||||
message: Some("Number cannot be represented by a floating-point type".into()),
|
||||
}),
|
||||
},
|
||||
serde_json::Value::String(value) => lua.create_string(value).map(mlua::Value::String),
|
||||
serde_json::Value::Array(value) => {
|
||||
let tbl = lua.create_table_with_capacity(value.len(), 0)?;
|
||||
for v in value {
|
||||
let v = json_to_lua(v, lua)?;
|
||||
tbl.push(v)?;
|
||||
}
|
||||
Ok(mlua::Value::Table(tbl))
|
||||
}
|
||||
serde_json::Value::Object(value) => {
|
||||
let tbl = lua.create_table_with_capacity(0, value.len())?;
|
||||
for (k, v) in value {
|
||||
let k = lua.create_string(k)?;
|
||||
let v = json_to_lua(v, lua)?;
|
||||
tbl.set(k, v)?;
|
||||
}
|
||||
Ok(mlua::Value::Table(tbl))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub(crate) struct WebhookEvent {
|
||||
pub topic: String,
|
||||
|
@ -13,9 +48,32 @@ pub(crate) struct WebhookEvent {
|
|||
pub(crate) struct ApiEvent {
|
||||
pub id: String,
|
||||
pub body: serde_json::Value,
|
||||
pub headers: HashMap<String, Vec<u8>>,
|
||||
pub status: u16,
|
||||
}
|
||||
|
||||
impl<'lua> IntoLua<'lua> for ApiEvent {
|
||||
fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> {
|
||||
let headers = lua.create_table_with_capacity(0, self.headers.len())?;
|
||||
for (k, v) in self.headers {
|
||||
let k = lua.create_string(k)?;
|
||||
let v = lua.create_string(v)?;
|
||||
headers.set(k, v)?;
|
||||
}
|
||||
|
||||
let body = json_to_lua(self.body, lua)?;
|
||||
let tbl = lua.create_table_with_capacity(0, 4)?;
|
||||
|
||||
lua.create_string(self.id)
|
||||
.and_then(|id| tbl.set("id", id))?;
|
||||
tbl.set("status", self.status)?;
|
||||
tbl.set("headers", headers)?;
|
||||
tbl.set("body", body)?;
|
||||
|
||||
Ok(mlua::Value::Table(tbl))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub(crate) struct ErrorEvent {
|
||||
pub id: String,
|
||||
|
@ -63,8 +121,15 @@ pub(crate) struct NtfyAction {
|
|||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub(crate) struct NtfyMessage {
|
||||
// Needs to be deserialized from Lua, but should not be sent to Ntfy
|
||||
#[serde(skip_serializing)]
|
||||
pub url: String,
|
||||
// Needs to be deserialized from Lua, but should not be sent to Ntfy
|
||||
#[serde(skip_serializing)]
|
||||
pub token: String,
|
||||
pub topic: String,
|
||||
pub message: String,
|
||||
#[serde(default)]
|
||||
pub title: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tags: Vec<String>,
|
||||
|
@ -89,16 +154,22 @@ pub(crate) enum Message {
|
|||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) enum Method {
|
||||
#[serde(alias = "GET", alias = "get")]
|
||||
Get,
|
||||
#[serde(alias = "POST", alias = "post")]
|
||||
Post,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) struct ApiTask {
|
||||
pub id: String,
|
||||
pub interval: u64,
|
||||
#[serde(default)]
|
||||
pub delay: u64,
|
||||
pub method: Method,
|
||||
pub url: String,
|
||||
pub body: serde_json::Value,
|
||||
#[serde(default)]
|
||||
pub query: HashMap<String, String>,
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
}
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
use std::sync::mpsc::Sender;
|
||||
use std::sync::Arc;
|
||||
use std::ascii::AsciiExt;
|
||||
use std::time::Duration;
|
||||
use std::{collections::HashMap, sync::mpsc::Sender};
|
||||
|
||||
use color_eyre::eyre::Context as _;
|
||||
use color_eyre::Result;
|
||||
use reqwest::Client;
|
||||
use reqwest::{header, Client};
|
||||
use tokio::runtime;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_stream::wrappers::IntervalStream;
|
||||
use tokio_stream::{StreamExt as _, StreamMap};
|
||||
|
||||
use crate::types::{ApiEvent, ApiTask, Event, Method};
|
||||
use crate::APP_USER_AGENT;
|
||||
|
||||
fn header_map_to_hashmap(headers: &header::HeaderMap) -> HashMap<String, Vec<u8>> {
|
||||
let mut map = HashMap::with_capacity(headers.len());
|
||||
|
||||
for (k, v) in headers {
|
||||
map.insert(k.to_string(), v.as_bytes().to_vec());
|
||||
}
|
||||
|
||||
map
|
||||
}
|
||||
|
||||
async fn perform_request(client: &Client, task: &ApiTask) -> Result<Event> {
|
||||
let req = match task.method {
|
||||
|
@ -22,70 +30,81 @@ async fn perform_request(client: &Client, task: &ApiTask) -> Result<Event> {
|
|||
client.post(&task.url).body(body)
|
||||
}
|
||||
};
|
||||
let res = req.query(&task.query).send().await?;
|
||||
let status = res.status().as_u16();
|
||||
|
||||
let body: serde_json::Value = res.json().await?;
|
||||
let res = req
|
||||
.query(&task.query)
|
||||
.headers({
|
||||
let mut headers = header::HeaderMap::new();
|
||||
|
||||
for (k, v) in &task.headers {
|
||||
// Non-ASCII characters aren't supported anyways for header name, so no need to map
|
||||
// those.
|
||||
let k = header::HeaderName::from_lowercase(k.to_ascii_lowercase().as_bytes())
|
||||
.wrap_err_with(|| format!("Invalid header name '{}'", k))?;
|
||||
let v = header::HeaderValue::from_str(v)
|
||||
.wrap_err_with(|| format!("Invalid header value '{}'", v))?;
|
||||
|
||||
headers.insert(k, v);
|
||||
}
|
||||
|
||||
headers
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
let status = res.status();
|
||||
let headers = header_map_to_hashmap(res.headers());
|
||||
|
||||
let body: serde_json::Value = if let Ok(text) = res.text().await {
|
||||
tracing::trace!("Response body: {}", text);
|
||||
serde_json::from_str(&text).unwrap_or(serde_json::Value::String(text))
|
||||
} else {
|
||||
tracing::trace!("Response body: NULL");
|
||||
serde_json::Value::Null
|
||||
};
|
||||
|
||||
let event = ApiEvent {
|
||||
id: task.id.clone(),
|
||||
headers,
|
||||
body,
|
||||
status,
|
||||
status: status.as_u16(),
|
||||
};
|
||||
Ok(Event::Api(event))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn run(mut api_rx: UnboundedReceiver<ApiTask>, event_tx: Sender<Event>) -> Result<()> {
|
||||
let tasks = Arc::new(Mutex::new(StreamMap::new()));
|
||||
let (task_tx, mut task_rx) = tokio::sync::mpsc::channel(64);
|
||||
|
||||
let client = Client::builder()
|
||||
.user_agent(APP_USER_AGENT)
|
||||
.build()
|
||||
.wrap_err("Failed to build HTTP client")?;
|
||||
|
||||
tokio::spawn({
|
||||
let tasks = tasks.clone();
|
||||
async move {
|
||||
while let Some(task) = api_rx.recv().await {
|
||||
tracing::trace!("Received new API task: {:?}", task);
|
||||
|
||||
let id = task.id.clone();
|
||||
let interval = tokio::time::interval(Duration::from_secs(task.interval));
|
||||
|
||||
let task = Arc::new(task);
|
||||
let mut tasks = tasks.lock().await;
|
||||
tasks.insert(id, IntervalStream::new(interval).map(move |_| task.clone()));
|
||||
let _ = task_tx
|
||||
.send(async move {
|
||||
tokio::time::sleep(Duration::from_secs(task.delay)).await;
|
||||
task
|
||||
})
|
||||
.await;
|
||||
}
|
||||
tracing::error!("API task channel closed");
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
// We need to guarantee that the lock on `tasks` is released as soon as possible.
|
||||
// To ensure that it isn't held for the entirety of the `sleep`, the indirection
|
||||
// via the extra `bool` value is used, so that the lock can be dropped immediately,
|
||||
// and during the `sleep` new streams can be created.
|
||||
let wait = {
|
||||
let mut tasks = tasks.lock().await;
|
||||
match tasks.next().await {
|
||||
Some((_, task)) => {
|
||||
while let Some(task) = task_rx.recv().await {
|
||||
let task = task.await;
|
||||
let event = perform_request(&client, &task)
|
||||
.await
|
||||
.wrap_err("Failed to perform API request")?;
|
||||
event_tx.send(event).wrap_err("Failed to send event")?;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
};
|
||||
|
||||
if wait {
|
||||
// Ideally we would be able to wait explicitly for the first stream
|
||||
// to be registered. But as a workaround, we have to idle wait.
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument("api::worker", skip_all)]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::mpsc::Receiver;
|
||||
|
||||
use color_eyre::Result;
|
||||
use mlua::{Function, Lua, LuaSerdeExt};
|
||||
use mlua::{Function, IntoLua as _, Lua, LuaSerdeExt};
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
|
||||
use crate::types::{ApiTask, Event, Message};
|
||||
|
@ -70,7 +70,7 @@ pub fn worker(
|
|||
event_fn.call::<_, ()>(("webhook", data))?
|
||||
}
|
||||
Event::Api(data) => {
|
||||
let data = lua.to_value(&data)?;
|
||||
let data = data.into_lua(&lua)?;
|
||||
event_fn.call::<_, ()>(("api", data))?
|
||||
}
|
||||
Event::Error(data) => {
|
||||
|
|
|
@ -1,22 +1,38 @@
|
|||
use std::sync::mpsc::Sender;
|
||||
|
||||
use color_eyre::eyre::{self, Context as _};
|
||||
use color_eyre::eyre;
|
||||
use color_eyre::Result;
|
||||
use reqwest::{header, Client};
|
||||
use reqwest::header;
|
||||
use reqwest::Client;
|
||||
use tokio::runtime;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
|
||||
use crate::types::{Event, Message, NtfyMessage};
|
||||
|
||||
static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
|
||||
use crate::APP_USER_AGENT;
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn send_ntfy(client: &Client, url: &String, data: &NtfyMessage) -> Result<()> {
|
||||
async fn send_ntfy(client: &Client, data: &NtfyMessage) -> Result<()> {
|
||||
let body = serde_json::to_string(data)?;
|
||||
tracing::trace!("JSON: {}", body);
|
||||
let res = client.post(url).body(body).send().await?;
|
||||
|
||||
let res = client
|
||||
.post(&data.url)
|
||||
.headers({
|
||||
let mut headers = header::HeaderMap::new();
|
||||
|
||||
let auth = format!("Bearer {}", data.token);
|
||||
let mut auth = header::HeaderValue::from_str(&auth)?;
|
||||
auth.set_sensitive(true);
|
||||
|
||||
headers.insert(header::AUTHORIZATION, auth);
|
||||
|
||||
headers
|
||||
})
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if res.status().as_u16() >= 400 {
|
||||
let body = res.text().await?;
|
||||
let body = res.text().await.unwrap_or_default();
|
||||
eyre::bail!("Ntfy server returned error: {}", body);
|
||||
}
|
||||
|
||||
|
@ -25,30 +41,14 @@ async fn send_ntfy(client: &Client, url: &String, data: &NtfyMessage) -> Result<
|
|||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn run(event_tx: Sender<Event>, mut ntfy_rx: UnboundedReceiver<Message>) -> Result<()> {
|
||||
let ntfy_url = std::env::var("NTFY_URL").wrap_err("Missing env var 'NTFY_URL'")?;
|
||||
let ntfy_token = std::env::var("NTFY_TOKEN").wrap_err("Missing env var 'NTFY_TOKEN'")?;
|
||||
|
||||
let ntfy_client = Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = header::HeaderMap::new();
|
||||
|
||||
let auth = format!("Bearer {}", ntfy_token);
|
||||
let mut auth = header::HeaderValue::from_str(&auth)?;
|
||||
auth.set_sensitive(true);
|
||||
|
||||
headers.insert(header::AUTHORIZATION, auth);
|
||||
|
||||
headers
|
||||
})
|
||||
.user_agent(APP_USER_AGENT)
|
||||
.build()?;
|
||||
let ntfy_client = Client::builder().user_agent(APP_USER_AGENT).build()?;
|
||||
|
||||
while let Some(message) = ntfy_rx.recv().await {
|
||||
tracing::trace!("Received notification: {:?}", message);
|
||||
|
||||
match message {
|
||||
Message::Ntfy(data) => {
|
||||
if let Err(err) = send_ntfy(&ntfy_client, &ntfy_url, &data).await {
|
||||
if let Err(err) = send_ntfy(&ntfy_client, &data).await {
|
||||
tracing::error!("Failed to send to Ntfy: {:?}", err);
|
||||
event_tx.send(Event::error(data.topic, err))?;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue