1
Fork 0
generated from lucas/rust-template
Code Issues 1 Pull requests 2 Activity
Var Dump:
dumpVar: only available in dev mode
Mailing List

Compare commits

...

7 commits

Author SHA1 Message Date
18efbb0957
Catch failing worker threads
Previously if one of the threads other than the server failed,
it would log the error but continue to run in a broken state.
So instead this makes sure than if any thread finishes the whole
application stops.
Since all worker threads should always wait indefinitely for more work,
this should only happen if one of them bails on an error.
2024-09-24 11:48:15 +02:00
c105ac80cd
Add Dockerfile 2024-09-20 10:31:19 +02:00
f5c64b788f
Use Rustls for TLS
Slim Docker images like Alpine or Distroless don't ship OpenSSL by
default, and rather than installing that, Rustls can be linked
statically.
2024-09-20 10:28:08 +02:00
e9796333bb
Reduce logging noise 2024-09-20 10:27:50 +02:00
23d27389d3
Expose logging functions to Lua 2024-09-19 13:52:36 +02:00
d2cb39f9a2
Rework API tasks
GitHub expects a 'Last-Modified' header, and honoring an
'X-Poll-Interval' header for their notifications endpoint.
Other services might also have certain limitations that require
customizing every API query individually.

Since that's not possible if API tasks are configured once and run off
of an interval, this reworks them so that the config needs to trigger
every query individually. A `delay` parameter allows re-creating the
same intervals that were possible before.

This also moves the configuration for Ntfy to the Lua file.
2024-09-18 11:31:10 +02:00
a02bf39f27
Implement initial boilerplate
The general architecture is a collection of four
threads.
Three threads that run their dedicated Tokio runtime and handle
the axum server, send outgoing notifications and run scheduled API
requests respectively. The fourth thread runs the Lua VM.

Channels exist between the threads to send messages and allow the Lua
script to orchestrate and configure everything.
2024-09-17 13:49:27 +02:00
12 changed files with 2017 additions and 13 deletions

1
.dockerignore Normal file
View file

@ -0,0 +1 @@
target/

2
.gitignore vendored
View file

@ -1 +1,3 @@
/target
.envrc
lua/

1301
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -6,7 +6,18 @@ edition = "2021"
license = "EUPL-1.2"
[dependencies]
axum = "0.7.5"
color-eyre = "0.6.3"
mlua = { version = "0.9.9", features = ["luajit", "macros", "serialize", "vendored"] }
reqwest = { version = "0.12.7", default-features = false, features = ["json", "rustls-tls"] }
serde = { version = "1.0.209", features = ["derive"] }
serde_json = "1.0.128"
serde_repr = "0.1.19"
tokio = { version = "1.40.0", features = ["rt", "sync"] }
tracing = "0.1.40"
tracing-error = "0.2.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
[profile.release]
strip = true
lto = true

27
Dockerfile Normal file
View file

@ -0,0 +1,27 @@
FROM rust:1.81.0-slim AS build
WORKDIR /src
RUN set -ex; \
apt-get update; \
apt-get install -y --no-install-recommends \
make \
libluajit-5.1-dev \
; \
rm -rf /var/lib/apt/lists/*;
COPY . /src
RUN --mount=type=cache,id=cargo-registry,target=/cargo/registry \
--mount=type=cache,id=cargo-target,target=/src/target \
cargo build --release --locked && cp /src/target/release/ntfy-collector /src/;
FROM gcr.io/distroless/cc-debian12 AS final
WORKDIR /ntfy-collector
ENV CONFIG_PATH=/ntfy-collector/lua/config.lua
ENV LUA_PATH=/ntfy-collector/lua/?.lua;/ntfy-collector/lua/?/init.lua
COPY --from=build /src/ntfy-collector /usr/bin/ntfy-collector
COPY ./lua ./lua
CMD ["/usr/bin/ntfy-collector"]

View file

@ -1,3 +1,3 @@
# Rust Project Template
# Ntfy Collector
A simple boilerplate for Rust projects
A daemon to collect notifications from various places and forward them to Ntfy.

View file

@ -1,16 +1,124 @@
use std::fs;
use std::sync::mpsc::Sender;
use std::thread::{self, JoinHandle};
use color_eyre::eyre::{bail, eyre, Context as _};
use color_eyre::Result;
use tracing_error::ErrorLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter};
use tracing_subscriber::layer::SubscriberExt as _;
use tracing_subscriber::util::SubscriberInitExt as _;
use tracing_subscriber::{fmt, EnvFilter};
pub(crate) mod types;
mod worker {
mod api;
mod lua;
mod sender;
mod server;
pub use api::worker as api;
pub use lua::worker as lua;
pub use sender::worker as sender;
pub use server::worker as server;
}
pub(crate) static APP_USER_AGENT: &str =
concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
fn spawn<T, F>(
name: &'static str,
close_tx: Sender<&'static str>,
task: F,
) -> Result<JoinHandle<Result<T>>>
where
F: FnOnce() -> Result<T> + Send + 'static,
T: Send + 'static,
{
thread::Builder::new()
.name(name.to_string())
.spawn(move || {
let res = task();
if let Err(err) = &res {
tracing::error!("Thread '{}' errored: {:?}", name, err);
}
let _ = close_tx.send(name);
res
})
.wrap_err_with(|| format!("Failed to create thread '{}'", name))
}
fn main() -> Result<()> {
color_eyre::install()?;
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
.with(fmt::layer().pretty())
.with(fmt::layer().compact())
.with(ErrorLayer::new(fmt::format::Pretty::default()))
.init();
println!("Hello, world!");
let config_path = std::env::var("CONFIG_PATH").wrap_err("Missing variable 'CONFIG_PATH'")?;
Ok(())
// A channel send to each thread to signal that any of them finished.
// Since all workers are supposed to be running indefinitely, waiting for
// events from the outside, they can only stop because of an error.
// But to be able
let (close_tx, close_rx) = std::sync::mpsc::channel();
let (sender_tx, sender_rx) = tokio::sync::mpsc::unbounded_channel();
let (event_tx, event_rx) = std::sync::mpsc::channel();
// A channel that lets other threads, mostly the Lua code, register
// a task with the API fetcher.
let (api_tx, api_rx) = tokio::sync::mpsc::unbounded_channel();
let lua_thread = {
let code = fs::read_to_string(&config_path)
.wrap_err_with(|| format!("Failed to read config from '{}'", config_path))?;
spawn("lua", close_tx.clone(), move || {
worker::lua(code, event_rx, api_tx, sender_tx)
})?
};
let api_thread = {
let event_tx = event_tx.clone();
spawn("api", close_tx.clone(), move || {
worker::api(api_rx, event_tx)
})?
};
let sender_thread = {
let event_tx = event_tx.clone();
spawn("sender", close_tx.clone(), move || {
worker::sender(event_tx, sender_rx)
})?
};
let server_thread = spawn("server", close_tx, move || worker::server(event_tx))?;
match close_rx.recv() {
Ok(name) => match name {
"api" => api_thread
.join()
.map_err(|err| eyre!("Thread 'api' panicked: {:?}", err))
.and_then(|res| res),
"lua" => lua_thread
.join()
.map_err(|err| eyre!("Thread 'lua' panicked: {:?}", err))
.and_then(|res| res),
"sender" => sender_thread
.join()
.map_err(|err| eyre!("Thread 'sender' panicked: {:?}", err))
.and_then(|res| res),
"server" => server_thread
.join()
.map_err(|err| eyre!("Thread 'server' panicked: {:?}", err))
.and_then(|res| res),
_ => bail!("Unknown thread '{}'", name),
},
Err(_) => unreachable!(
"Any thread given this channel will send a closing notification before dropping it"
),
}
}

183
src/types.rs Normal file
View file

@ -0,0 +1,183 @@
use std::collections::HashMap;
use mlua::IntoLua;
use serde::{Deserialize, Serialize};
/// Converts a `serde_json::Value` to a `mlua::Value`.
fn json_to_lua(value: serde_json::Value, lua: &mlua::Lua) -> mlua::Result<mlua::Value<'_>> {
match value {
serde_json::Value::Null => Ok(mlua::Value::Nil),
serde_json::Value::Bool(value) => Ok(mlua::Value::Boolean(value)),
serde_json::Value::Number(value) => match value.as_f64() {
Some(number) => Ok(mlua::Value::Number(number)),
None => Err(mlua::Error::ToLuaConversionError {
from: "serde_json::Value::Number",
to: "Number",
message: Some("Number cannot be represented by a floating-point type".into()),
}),
},
serde_json::Value::String(value) => lua.create_string(value).map(mlua::Value::String),
serde_json::Value::Array(value) => {
let tbl = lua.create_table_with_capacity(value.len(), 0)?;
for v in value {
let v = json_to_lua(v, lua)?;
tbl.push(v)?;
}
Ok(mlua::Value::Table(tbl))
}
serde_json::Value::Object(value) => {
let tbl = lua.create_table_with_capacity(0, value.len())?;
for (k, v) in value {
let k = lua.create_string(k)?;
let v = json_to_lua(v, lua)?;
tbl.set(k, v)?;
}
Ok(mlua::Value::Table(tbl))
}
}
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct WebhookEvent {
pub topic: String,
pub query: HashMap<String, String>,
pub body: serde_json::Value,
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct ApiEvent {
pub id: String,
pub body: serde_json::Value,
pub headers: HashMap<String, Vec<u8>>,
pub status: u16,
}
impl<'lua> IntoLua<'lua> for ApiEvent {
fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> {
let headers = lua.create_table_with_capacity(0, self.headers.len())?;
for (k, v) in self.headers {
let k = lua.create_string(k)?;
let v = lua.create_string(v)?;
headers.set(k, v)?;
}
let body = json_to_lua(self.body, lua)?;
let tbl = lua.create_table_with_capacity(0, 4)?;
lua.create_string(self.id)
.and_then(|id| tbl.set("id", id))?;
tbl.set("status", self.status)?;
tbl.set("headers", headers)?;
tbl.set("body", body)?;
Ok(mlua::Value::Table(tbl))
}
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct ErrorEvent {
pub id: String,
pub message: String,
}
#[derive(Clone, Debug)]
pub(crate) enum Event {
Webhook(WebhookEvent),
Api(ApiEvent),
Error(ErrorEvent),
}
impl Event {
pub fn error(id: String, message: impl ToString) -> Self {
Self::Error(ErrorEvent {
id,
message: message.to_string(),
})
}
}
#[derive(Clone, Debug, serde_repr::Deserialize_repr, serde_repr::Serialize_repr)]
#[repr(u8)]
pub(crate) enum NtfyPriority {
Min = 1,
Low = 2,
Default = 3,
High = 4,
Max = 5,
}
impl Default for NtfyPriority {
fn default() -> Self {
Self::Default
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct NtfyAction {
pub action: String,
pub label: String,
pub url: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct NtfyMessage {
// Needs to be deserialized from Lua, but should not be sent to Ntfy
#[serde(skip_serializing)]
pub url: String,
// Needs to be deserialized from Lua, but should not be sent to Ntfy
#[serde(skip_serializing)]
pub token: String,
pub topic: String,
pub message: String,
#[serde(default)]
pub title: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tags: Vec<String>,
#[serde(default)]
pub priority: NtfyPriority,
#[serde(default)]
pub click: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub actions: Vec<NtfyAction>,
#[serde(default)]
pub markdown: bool,
#[serde(default)]
pub icon: Option<String>,
#[serde(default)]
pub delay: Option<String>,
}
#[derive(Clone, Debug)]
pub(crate) enum Message {
Ntfy(NtfyMessage),
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) enum Method {
#[serde(alias = "GET", alias = "get")]
Get,
#[serde(alias = "POST", alias = "post")]
Post,
}
impl Default for Method {
fn default() -> Self {
Self::Get
}
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct ApiTask {
pub id: String,
#[serde(default)]
pub delay: u64,
#[serde(default)]
pub method: Method,
pub url: String,
#[serde(default)]
pub body: serde_json::Value,
#[serde(default)]
pub query: HashMap<String, String>,
#[serde(default)]
pub headers: HashMap<String, String>,
}

121
src/worker/api.rs Normal file
View file

@ -0,0 +1,121 @@
use std::collections::HashMap;
use std::sync::mpsc::Sender;
use std::time::Duration;
use color_eyre::eyre::Context as _;
use color_eyre::Result;
use reqwest::{header, Client};
use tokio::runtime;
use tokio::sync::mpsc::UnboundedReceiver;
use tracing::Instrument as _;
use crate::types::{ApiEvent, ApiTask, Event, Method};
use crate::APP_USER_AGENT;
fn header_map_to_hashmap(headers: &header::HeaderMap) -> HashMap<String, Vec<u8>> {
let mut map = HashMap::with_capacity(headers.len());
for (k, v) in headers {
map.insert(k.to_string(), v.as_bytes().to_vec());
}
map
}
#[tracing::instrument]
async fn perform_request(client: &Client, task: &ApiTask) -> Result<Event> {
let req = match task.method {
Method::Get => client.get(&task.url),
Method::Post => {
let body = serde_json::to_vec(&task.body)
.expect("Type `serde_json::Value` should always be serializable");
client.post(&task.url).body(body)
}
};
let res = req
.query(&task.query)
.headers({
let mut headers = header::HeaderMap::new();
for (k, v) in &task.headers {
// Non-ASCII characters aren't supported anyways for header name, so no need to map
// those.
let k = header::HeaderName::from_lowercase(k.to_ascii_lowercase().as_bytes())
.wrap_err_with(|| format!("Invalid header name '{}'", k))?;
let v = header::HeaderValue::from_str(v)
.wrap_err_with(|| format!("Invalid header value '{}'", v))?;
headers.insert(k, v);
}
headers
})
.send()
.await?;
let status = res.status();
let headers = header_map_to_hashmap(res.headers());
let body: serde_json::Value = if let Ok(text) = res.text().await {
serde_json::from_str(&text).unwrap_or(serde_json::Value::String(text))
} else {
serde_json::Value::Null
};
let event = ApiEvent {
id: task.id.clone(),
headers,
body,
status: status.as_u16(),
};
Ok(Event::Api(event))
}
#[tracing::instrument(skip_all)]
pub async fn run(mut api_rx: UnboundedReceiver<ApiTask>, event_tx: Sender<Event>) -> Result<()> {
let (task_tx, mut task_rx) = tokio::sync::mpsc::channel(64);
let client = Client::builder()
.user_agent(APP_USER_AGENT)
.build()
.wrap_err("Failed to build HTTP client")?;
tokio::spawn({
let span = tracing::info_span!("api receiver");
async move {
while let Some(task) = api_rx.recv().await {
tracing::trace!(id = task.id, "Received new API task");
let _ = task_tx
.send(async move {
tokio::time::sleep(Duration::from_secs(task.delay)).await;
task
})
.await;
}
tracing::error!("API task channel closed");
}
.instrument(span)
});
while let Some(task) = task_rx.recv().await {
let task = task.await;
let event = perform_request(&client, &task)
.await
.wrap_err("Failed to perform API request")?;
event_tx.send(event).wrap_err("Failed to send event")?;
}
Ok(())
}
#[tracing::instrument("api::worker", skip_all)]
pub fn worker(api_rx: UnboundedReceiver<ApiTask>, event_tx: Sender<Event>) -> Result<()> {
let rt = runtime::Builder::new_current_thread()
.thread_name("api")
.enable_io()
.enable_time()
.build()?;
rt.block_on(run(api_rx, event_tx))
}

118
src/worker/lua.rs Normal file
View file

@ -0,0 +1,118 @@
use std::sync::mpsc::Receiver;
use color_eyre::Result;
use mlua::{Function, IntoLua as _, Lua, LuaSerdeExt};
use tokio::sync::mpsc::UnboundedSender;
use crate::types::{ApiTask, Event, Message, NtfyMessage};
#[tracing::instrument(skip_all)]
pub fn worker(
config: String,
event_rx: Receiver<Event>,
api_tx: UnboundedSender<ApiTask>,
ntfy_tx: UnboundedSender<Message>,
) -> Result<()> {
let lua = Lua::new();
let globals = lua.globals();
let config = lua.load(config).set_name("config");
lua.scope(|scope| {
let log_trace_fn = scope.create_function(|_, s: String| {
tracing::trace!(name: "lua", "{}", s);
Ok(())
})?;
let log_debug_fn = scope.create_function(|_, s: String| {
tracing::debug!(name: "lua", "{}", s);
Ok(())
})?;
let log_info_fn = scope.create_function(|_, s: String| {
tracing::info!(name: "lua", "{}", s);
Ok(())
})?;
let log_warn_fn = scope.create_function(|_, s: String| {
tracing::warn!(name: "lua", "{}", s);
Ok(())
})?;
let log_error_fn = scope.create_function(|_, s: String| {
tracing::error!(name: "lua", "{}", s);
Ok(())
})?;
let log_tbl = lua.create_table_with_capacity(0, 5)?;
log_tbl.set("trace", log_trace_fn)?;
log_tbl.set("debug", log_debug_fn)?;
log_tbl.set("info", log_info_fn)?;
log_tbl.set("warn", log_warn_fn)?;
log_tbl.set("error", log_error_fn)?;
let ntfy_fn = scope.create_function_mut(|_, data: mlua::Value| {
let data: NtfyMessage = lua.from_value(data)?;
tracing::trace!(topic = data.topic, title = data.title, msg = data.message,"Sending Ntfy message");
match ntfy_tx.send(Message::Ntfy(data)) {
Ok(_) => Ok((true, mlua::Value::Nil)),
Err(_) => {
let msg = lua.create_string("Failed to send message")?;
Ok((false, mlua::Value::String(msg)))
}
}
})?;
let set_api_task_fn = scope.create_function_mut(|_, data: mlua::Value| {
let task: ApiTask = lua.from_value(data)?;
tracing::trace!("Sending task request: id = {}, delay = {}", task.id, task.delay);
match api_tx.send(task) {
Ok(_) => Ok((true, mlua::Value::Nil)),
Err(_) => {
let msg = lua.create_string("Failed to trigger task")?;
Ok((false, mlua::Value::String(msg)))
}
}
})?;
globals.set("log", log_tbl)?;
globals.set("ntfy", ntfy_fn)?;
globals.set("api_task", set_api_task_fn)?;
config.exec()?;
let event_fn: Function = match globals.get("on_event") {
Ok(f) => f,
Err(err) => match err {
mlua::Error::FromLuaConversionError { from, to: _, message: _ } => {
let err = mlua::Error::runtime(format!("Global function 'on_event' not defined properly. Got value of type '{}'", from));
return Err(err);
}
err => return Err(err),
},
};
// Main blocking loop. As long as we can receive events, this scope will stay active.
while let Ok(event) = event_rx.recv() {
match event {
Event::Webhook(data) => {
tracing::trace!(id = data.topic, "Received webhook event");
let data = lua.to_value(&data)?;
event_fn.call::<_, ()>(("webhook", data))?
}
Event::Api(data) => {
tracing::trace!(id = data.id, status = data.status, "Received api event");
let data = data.into_lua(&lua)?;
event_fn.call::<_, ()>(("api", data))?
}
Event::Error(data) => {
tracing::trace!(id = data.id, message = data.message, "Received error event");
let data = lua.to_value(&data)?;
event_fn.call::<_, ()>(("error", data))?
}
}
}
Ok(())
})?;
Ok(())
}

70
src/worker/sender.rs Normal file
View file

@ -0,0 +1,70 @@
use std::sync::mpsc::Sender;
use color_eyre::eyre;
use color_eyre::Result;
use reqwest::header;
use reqwest::Client;
use tokio::runtime;
use tokio::sync::mpsc::UnboundedReceiver;
use crate::types::{Event, Message, NtfyMessage};
use crate::APP_USER_AGENT;
#[tracing::instrument]
async fn send_ntfy(client: &Client, data: &NtfyMessage) -> Result<()> {
let body = serde_json::to_string(data)?;
let res = client
.post(&data.url)
.headers({
let mut headers = header::HeaderMap::new();
let auth = format!("Bearer {}", data.token);
let mut auth = header::HeaderValue::from_str(&auth)?;
auth.set_sensitive(true);
headers.insert(header::AUTHORIZATION, auth);
headers
})
.body(body)
.send()
.await?;
if res.status().as_u16() >= 400 {
let body = res.text().await.unwrap_or_default();
eyre::bail!("Ntfy server returned error: {}", body);
}
Ok(())
}
#[tracing::instrument(skip_all)]
pub async fn run(event_tx: Sender<Event>, mut ntfy_rx: UnboundedReceiver<Message>) -> Result<()> {
let ntfy_client = Client::builder().user_agent(APP_USER_AGENT).build()?;
while let Some(message) = ntfy_rx.recv().await {
match message {
Message::Ntfy(data) => {
tracing::trace!(topic = data.topic, "Sending Ntfy notification");
if let Err(err) = send_ntfy(&ntfy_client, &data).await {
tracing::error!("Failed to send to Ntfy: {:?}", err);
event_tx.send(Event::error(data.topic, err))?;
}
}
}
}
tracing::debug!("Stopped receiving messages");
Ok(())
}
#[tracing::instrument(skip_all)]
pub fn worker(event_tx: Sender<Event>, ntfy_rx: UnboundedReceiver<Message>) -> Result<()> {
let rt = runtime::Builder::new_current_thread()
.thread_name("sender")
.enable_io()
.build()?;
rt.block_on(run(event_tx, ntfy_rx))
}

76
src/worker/server.rs Normal file
View file

@ -0,0 +1,76 @@
use std::collections::HashMap;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{Json, Router};
use color_eyre::eyre::Context as _;
use color_eyre::Result;
use tokio::net::TcpListener;
use tokio::runtime;
use crate::types::{Event, WebhookEvent};
struct AppState {
event_tx: Sender<Event>,
}
pub fn worker(event_tx: Sender<Event>) -> Result<()> {
let rt = runtime::Builder::new_current_thread()
.thread_name("server")
.enable_io()
.build()?;
rt.block_on(task(event_tx))
}
async fn task(event_tx: Sender<Event>) -> Result<()> {
let state = AppState { event_tx };
let shared_state = Arc::new(state);
let app = Router::new()
.route(
"/webhook/:topic",
get(webhook_handler).post(webhook_handler),
)
.with_state(shared_state);
let listener = TcpListener::bind("0.0.0.0:3000")
.await
.wrap_err("Failed to bind to TCP socket")?;
tracing::info!("Listening on \"0.0.0.0:3000\"");
axum::serve(listener, app)
.await
.wrap_err("Failed to start server")
}
async fn webhook_handler(
State(state): State<Arc<AppState>>,
Path(topic): Path<String>,
Query(query): Query<HashMap<String, String>>,
body: Option<Json<serde_json::Value>>,
) -> impl IntoResponse {
let event = Event::Webhook(WebhookEvent {
topic,
query,
body: body
.map(|Json(body)| body)
.unwrap_or(serde_json::Value::Null),
});
tracing::debug!("Received webhook event: {:?}", &event);
if state.event_tx.send(event).is_err() {
tracing::error!("Failed to trigger webhook event");
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to trigger webhook event",
)
} else {
(StatusCode::NO_CONTENT, "")
}
}