1
Fork 0
generated from lucas/rust-template

Compare commits

..

No commits in common. "18efbb095780b95922145b3fc626a0a1a7c6e035" and "4faa02d2af37eccac7d57a83998a602531115c49" have entirely different histories.

12 changed files with 13 additions and 2017 deletions

View file

@ -1 +0,0 @@
target/

2
.gitignore vendored
View file

@ -1,3 +1 @@
/target /target
.envrc
lua/

1301
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -6,18 +6,7 @@ edition = "2021"
license = "EUPL-1.2" license = "EUPL-1.2"
[dependencies] [dependencies]
axum = "0.7.5"
color-eyre = "0.6.3" color-eyre = "0.6.3"
mlua = { version = "0.9.9", features = ["luajit", "macros", "serialize", "vendored"] }
reqwest = { version = "0.12.7", default-features = false, features = ["json", "rustls-tls"] }
serde = { version = "1.0.209", features = ["derive"] }
serde_json = "1.0.128"
serde_repr = "0.1.19"
tokio = { version = "1.40.0", features = ["rt", "sync"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-error = "0.2.0" tracing-error = "0.2.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
[profile.release]
strip = true
lto = true

View file

@ -1,27 +0,0 @@
FROM rust:1.81.0-slim AS build
WORKDIR /src
RUN set -ex; \
apt-get update; \
apt-get install -y --no-install-recommends \
make \
libluajit-5.1-dev \
; \
rm -rf /var/lib/apt/lists/*;
COPY . /src
RUN --mount=type=cache,id=cargo-registry,target=/cargo/registry \
--mount=type=cache,id=cargo-target,target=/src/target \
cargo build --release --locked && cp /src/target/release/ntfy-collector /src/;
FROM gcr.io/distroless/cc-debian12 AS final
WORKDIR /ntfy-collector
ENV CONFIG_PATH=/ntfy-collector/lua/config.lua
ENV LUA_PATH=/ntfy-collector/lua/?.lua;/ntfy-collector/lua/?/init.lua
COPY --from=build /src/ntfy-collector /usr/bin/ntfy-collector
COPY ./lua ./lua
CMD ["/usr/bin/ntfy-collector"]

View file

@ -1,3 +1,3 @@
# Ntfy Collector # Rust Project Template
A daemon to collect notifications from various places and forward them to Ntfy. A simple boilerplate for Rust projects

View file

@ -1,124 +1,16 @@
use std::fs;
use std::sync::mpsc::Sender;
use std::thread::{self, JoinHandle};
use color_eyre::eyre::{bail, eyre, Context as _};
use color_eyre::Result; use color_eyre::Result;
use tracing_error::ErrorLayer; use tracing_error::ErrorLayer;
use tracing_subscriber::layer::SubscriberExt as _; use tracing_subscriber::{fmt, layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter};
use tracing_subscriber::util::SubscriberInitExt as _;
use tracing_subscriber::{fmt, EnvFilter};
pub(crate) mod types;
mod worker {
mod api;
mod lua;
mod sender;
mod server;
pub use api::worker as api;
pub use lua::worker as lua;
pub use sender::worker as sender;
pub use server::worker as server;
}
pub(crate) static APP_USER_AGENT: &str =
concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
fn spawn<T, F>(
name: &'static str,
close_tx: Sender<&'static str>,
task: F,
) -> Result<JoinHandle<Result<T>>>
where
F: FnOnce() -> Result<T> + Send + 'static,
T: Send + 'static,
{
thread::Builder::new()
.name(name.to_string())
.spawn(move || {
let res = task();
if let Err(err) = &res {
tracing::error!("Thread '{}' errored: {:?}", name, err);
}
let _ = close_tx.send(name);
res
})
.wrap_err_with(|| format!("Failed to create thread '{}'", name))
}
fn main() -> Result<()> { fn main() -> Result<()> {
color_eyre::install()?; color_eyre::install()?;
tracing_subscriber::registry() tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into())) .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
.with(fmt::layer().compact()) .with(fmt::layer().pretty())
.with(ErrorLayer::new(fmt::format::Pretty::default())) .with(ErrorLayer::new(fmt::format::Pretty::default()))
.init(); .init();
let config_path = std::env::var("CONFIG_PATH").wrap_err("Missing variable 'CONFIG_PATH'")?; println!("Hello, world!");
// A channel send to each thread to signal that any of them finished. Ok(())
// Since all workers are supposed to be running indefinitely, waiting for
// events from the outside, they can only stop because of an error.
// But to be able
let (close_tx, close_rx) = std::sync::mpsc::channel();
let (sender_tx, sender_rx) = tokio::sync::mpsc::unbounded_channel();
let (event_tx, event_rx) = std::sync::mpsc::channel();
// A channel that lets other threads, mostly the Lua code, register
// a task with the API fetcher.
let (api_tx, api_rx) = tokio::sync::mpsc::unbounded_channel();
let lua_thread = {
let code = fs::read_to_string(&config_path)
.wrap_err_with(|| format!("Failed to read config from '{}'", config_path))?;
spawn("lua", close_tx.clone(), move || {
worker::lua(code, event_rx, api_tx, sender_tx)
})?
};
let api_thread = {
let event_tx = event_tx.clone();
spawn("api", close_tx.clone(), move || {
worker::api(api_rx, event_tx)
})?
};
let sender_thread = {
let event_tx = event_tx.clone();
spawn("sender", close_tx.clone(), move || {
worker::sender(event_tx, sender_rx)
})?
};
let server_thread = spawn("server", close_tx, move || worker::server(event_tx))?;
match close_rx.recv() {
Ok(name) => match name {
"api" => api_thread
.join()
.map_err(|err| eyre!("Thread 'api' panicked: {:?}", err))
.and_then(|res| res),
"lua" => lua_thread
.join()
.map_err(|err| eyre!("Thread 'lua' panicked: {:?}", err))
.and_then(|res| res),
"sender" => sender_thread
.join()
.map_err(|err| eyre!("Thread 'sender' panicked: {:?}", err))
.and_then(|res| res),
"server" => server_thread
.join()
.map_err(|err| eyre!("Thread 'server' panicked: {:?}", err))
.and_then(|res| res),
_ => bail!("Unknown thread '{}'", name),
},
Err(_) => unreachable!(
"Any thread given this channel will send a closing notification before dropping it"
),
}
} }

View file

@ -1,183 +0,0 @@
use std::collections::HashMap;
use mlua::IntoLua;
use serde::{Deserialize, Serialize};
/// Converts a `serde_json::Value` to a `mlua::Value`.
fn json_to_lua(value: serde_json::Value, lua: &mlua::Lua) -> mlua::Result<mlua::Value<'_>> {
match value {
serde_json::Value::Null => Ok(mlua::Value::Nil),
serde_json::Value::Bool(value) => Ok(mlua::Value::Boolean(value)),
serde_json::Value::Number(value) => match value.as_f64() {
Some(number) => Ok(mlua::Value::Number(number)),
None => Err(mlua::Error::ToLuaConversionError {
from: "serde_json::Value::Number",
to: "Number",
message: Some("Number cannot be represented by a floating-point type".into()),
}),
},
serde_json::Value::String(value) => lua.create_string(value).map(mlua::Value::String),
serde_json::Value::Array(value) => {
let tbl = lua.create_table_with_capacity(value.len(), 0)?;
for v in value {
let v = json_to_lua(v, lua)?;
tbl.push(v)?;
}
Ok(mlua::Value::Table(tbl))
}
serde_json::Value::Object(value) => {
let tbl = lua.create_table_with_capacity(0, value.len())?;
for (k, v) in value {
let k = lua.create_string(k)?;
let v = json_to_lua(v, lua)?;
tbl.set(k, v)?;
}
Ok(mlua::Value::Table(tbl))
}
}
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct WebhookEvent {
pub topic: String,
pub query: HashMap<String, String>,
pub body: serde_json::Value,
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct ApiEvent {
pub id: String,
pub body: serde_json::Value,
pub headers: HashMap<String, Vec<u8>>,
pub status: u16,
}
impl<'lua> IntoLua<'lua> for ApiEvent {
fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> {
let headers = lua.create_table_with_capacity(0, self.headers.len())?;
for (k, v) in self.headers {
let k = lua.create_string(k)?;
let v = lua.create_string(v)?;
headers.set(k, v)?;
}
let body = json_to_lua(self.body, lua)?;
let tbl = lua.create_table_with_capacity(0, 4)?;
lua.create_string(self.id)
.and_then(|id| tbl.set("id", id))?;
tbl.set("status", self.status)?;
tbl.set("headers", headers)?;
tbl.set("body", body)?;
Ok(mlua::Value::Table(tbl))
}
}
#[derive(Clone, Debug, Serialize)]
pub(crate) struct ErrorEvent {
pub id: String,
pub message: String,
}
#[derive(Clone, Debug)]
pub(crate) enum Event {
Webhook(WebhookEvent),
Api(ApiEvent),
Error(ErrorEvent),
}
impl Event {
pub fn error(id: String, message: impl ToString) -> Self {
Self::Error(ErrorEvent {
id,
message: message.to_string(),
})
}
}
#[derive(Clone, Debug, serde_repr::Deserialize_repr, serde_repr::Serialize_repr)]
#[repr(u8)]
pub(crate) enum NtfyPriority {
Min = 1,
Low = 2,
Default = 3,
High = 4,
Max = 5,
}
impl Default for NtfyPriority {
fn default() -> Self {
Self::Default
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct NtfyAction {
pub action: String,
pub label: String,
pub url: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct NtfyMessage {
// Needs to be deserialized from Lua, but should not be sent to Ntfy
#[serde(skip_serializing)]
pub url: String,
// Needs to be deserialized from Lua, but should not be sent to Ntfy
#[serde(skip_serializing)]
pub token: String,
pub topic: String,
pub message: String,
#[serde(default)]
pub title: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tags: Vec<String>,
#[serde(default)]
pub priority: NtfyPriority,
#[serde(default)]
pub click: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub actions: Vec<NtfyAction>,
#[serde(default)]
pub markdown: bool,
#[serde(default)]
pub icon: Option<String>,
#[serde(default)]
pub delay: Option<String>,
}
#[derive(Clone, Debug)]
pub(crate) enum Message {
Ntfy(NtfyMessage),
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) enum Method {
#[serde(alias = "GET", alias = "get")]
Get,
#[serde(alias = "POST", alias = "post")]
Post,
}
impl Default for Method {
fn default() -> Self {
Self::Get
}
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct ApiTask {
pub id: String,
#[serde(default)]
pub delay: u64,
#[serde(default)]
pub method: Method,
pub url: String,
#[serde(default)]
pub body: serde_json::Value,
#[serde(default)]
pub query: HashMap<String, String>,
#[serde(default)]
pub headers: HashMap<String, String>,
}

View file

@ -1,121 +0,0 @@
use std::collections::HashMap;
use std::sync::mpsc::Sender;
use std::time::Duration;
use color_eyre::eyre::Context as _;
use color_eyre::Result;
use reqwest::{header, Client};
use tokio::runtime;
use tokio::sync::mpsc::UnboundedReceiver;
use tracing::Instrument as _;
use crate::types::{ApiEvent, ApiTask, Event, Method};
use crate::APP_USER_AGENT;
fn header_map_to_hashmap(headers: &header::HeaderMap) -> HashMap<String, Vec<u8>> {
let mut map = HashMap::with_capacity(headers.len());
for (k, v) in headers {
map.insert(k.to_string(), v.as_bytes().to_vec());
}
map
}
#[tracing::instrument]
async fn perform_request(client: &Client, task: &ApiTask) -> Result<Event> {
let req = match task.method {
Method::Get => client.get(&task.url),
Method::Post => {
let body = serde_json::to_vec(&task.body)
.expect("Type `serde_json::Value` should always be serializable");
client.post(&task.url).body(body)
}
};
let res = req
.query(&task.query)
.headers({
let mut headers = header::HeaderMap::new();
for (k, v) in &task.headers {
// Non-ASCII characters aren't supported anyways for header name, so no need to map
// those.
let k = header::HeaderName::from_lowercase(k.to_ascii_lowercase().as_bytes())
.wrap_err_with(|| format!("Invalid header name '{}'", k))?;
let v = header::HeaderValue::from_str(v)
.wrap_err_with(|| format!("Invalid header value '{}'", v))?;
headers.insert(k, v);
}
headers
})
.send()
.await?;
let status = res.status();
let headers = header_map_to_hashmap(res.headers());
let body: serde_json::Value = if let Ok(text) = res.text().await {
serde_json::from_str(&text).unwrap_or(serde_json::Value::String(text))
} else {
serde_json::Value::Null
};
let event = ApiEvent {
id: task.id.clone(),
headers,
body,
status: status.as_u16(),
};
Ok(Event::Api(event))
}
#[tracing::instrument(skip_all)]
pub async fn run(mut api_rx: UnboundedReceiver<ApiTask>, event_tx: Sender<Event>) -> Result<()> {
let (task_tx, mut task_rx) = tokio::sync::mpsc::channel(64);
let client = Client::builder()
.user_agent(APP_USER_AGENT)
.build()
.wrap_err("Failed to build HTTP client")?;
tokio::spawn({
let span = tracing::info_span!("api receiver");
async move {
while let Some(task) = api_rx.recv().await {
tracing::trace!(id = task.id, "Received new API task");
let _ = task_tx
.send(async move {
tokio::time::sleep(Duration::from_secs(task.delay)).await;
task
})
.await;
}
tracing::error!("API task channel closed");
}
.instrument(span)
});
while let Some(task) = task_rx.recv().await {
let task = task.await;
let event = perform_request(&client, &task)
.await
.wrap_err("Failed to perform API request")?;
event_tx.send(event).wrap_err("Failed to send event")?;
}
Ok(())
}
#[tracing::instrument("api::worker", skip_all)]
pub fn worker(api_rx: UnboundedReceiver<ApiTask>, event_tx: Sender<Event>) -> Result<()> {
let rt = runtime::Builder::new_current_thread()
.thread_name("api")
.enable_io()
.enable_time()
.build()?;
rt.block_on(run(api_rx, event_tx))
}

View file

@ -1,118 +0,0 @@
use std::sync::mpsc::Receiver;
use color_eyre::Result;
use mlua::{Function, IntoLua as _, Lua, LuaSerdeExt};
use tokio::sync::mpsc::UnboundedSender;
use crate::types::{ApiTask, Event, Message, NtfyMessage};
#[tracing::instrument(skip_all)]
pub fn worker(
config: String,
event_rx: Receiver<Event>,
api_tx: UnboundedSender<ApiTask>,
ntfy_tx: UnboundedSender<Message>,
) -> Result<()> {
let lua = Lua::new();
let globals = lua.globals();
let config = lua.load(config).set_name("config");
lua.scope(|scope| {
let log_trace_fn = scope.create_function(|_, s: String| {
tracing::trace!(name: "lua", "{}", s);
Ok(())
})?;
let log_debug_fn = scope.create_function(|_, s: String| {
tracing::debug!(name: "lua", "{}", s);
Ok(())
})?;
let log_info_fn = scope.create_function(|_, s: String| {
tracing::info!(name: "lua", "{}", s);
Ok(())
})?;
let log_warn_fn = scope.create_function(|_, s: String| {
tracing::warn!(name: "lua", "{}", s);
Ok(())
})?;
let log_error_fn = scope.create_function(|_, s: String| {
tracing::error!(name: "lua", "{}", s);
Ok(())
})?;
let log_tbl = lua.create_table_with_capacity(0, 5)?;
log_tbl.set("trace", log_trace_fn)?;
log_tbl.set("debug", log_debug_fn)?;
log_tbl.set("info", log_info_fn)?;
log_tbl.set("warn", log_warn_fn)?;
log_tbl.set("error", log_error_fn)?;
let ntfy_fn = scope.create_function_mut(|_, data: mlua::Value| {
let data: NtfyMessage = lua.from_value(data)?;
tracing::trace!(topic = data.topic, title = data.title, msg = data.message,"Sending Ntfy message");
match ntfy_tx.send(Message::Ntfy(data)) {
Ok(_) => Ok((true, mlua::Value::Nil)),
Err(_) => {
let msg = lua.create_string("Failed to send message")?;
Ok((false, mlua::Value::String(msg)))
}
}
})?;
let set_api_task_fn = scope.create_function_mut(|_, data: mlua::Value| {
let task: ApiTask = lua.from_value(data)?;
tracing::trace!("Sending task request: id = {}, delay = {}", task.id, task.delay);
match api_tx.send(task) {
Ok(_) => Ok((true, mlua::Value::Nil)),
Err(_) => {
let msg = lua.create_string("Failed to trigger task")?;
Ok((false, mlua::Value::String(msg)))
}
}
})?;
globals.set("log", log_tbl)?;
globals.set("ntfy", ntfy_fn)?;
globals.set("api_task", set_api_task_fn)?;
config.exec()?;
let event_fn: Function = match globals.get("on_event") {
Ok(f) => f,
Err(err) => match err {
mlua::Error::FromLuaConversionError { from, to: _, message: _ } => {
let err = mlua::Error::runtime(format!("Global function 'on_event' not defined properly. Got value of type '{}'", from));
return Err(err);
}
err => return Err(err),
},
};
// Main blocking loop. As long as we can receive events, this scope will stay active.
while let Ok(event) = event_rx.recv() {
match event {
Event::Webhook(data) => {
tracing::trace!(id = data.topic, "Received webhook event");
let data = lua.to_value(&data)?;
event_fn.call::<_, ()>(("webhook", data))?
}
Event::Api(data) => {
tracing::trace!(id = data.id, status = data.status, "Received api event");
let data = data.into_lua(&lua)?;
event_fn.call::<_, ()>(("api", data))?
}
Event::Error(data) => {
tracing::trace!(id = data.id, message = data.message, "Received error event");
let data = lua.to_value(&data)?;
event_fn.call::<_, ()>(("error", data))?
}
}
}
Ok(())
})?;
Ok(())
}

View file

@ -1,70 +0,0 @@
use std::sync::mpsc::Sender;
use color_eyre::eyre;
use color_eyre::Result;
use reqwest::header;
use reqwest::Client;
use tokio::runtime;
use tokio::sync::mpsc::UnboundedReceiver;
use crate::types::{Event, Message, NtfyMessage};
use crate::APP_USER_AGENT;
#[tracing::instrument]
async fn send_ntfy(client: &Client, data: &NtfyMessage) -> Result<()> {
let body = serde_json::to_string(data)?;
let res = client
.post(&data.url)
.headers({
let mut headers = header::HeaderMap::new();
let auth = format!("Bearer {}", data.token);
let mut auth = header::HeaderValue::from_str(&auth)?;
auth.set_sensitive(true);
headers.insert(header::AUTHORIZATION, auth);
headers
})
.body(body)
.send()
.await?;
if res.status().as_u16() >= 400 {
let body = res.text().await.unwrap_or_default();
eyre::bail!("Ntfy server returned error: {}", body);
}
Ok(())
}
#[tracing::instrument(skip_all)]
pub async fn run(event_tx: Sender<Event>, mut ntfy_rx: UnboundedReceiver<Message>) -> Result<()> {
let ntfy_client = Client::builder().user_agent(APP_USER_AGENT).build()?;
while let Some(message) = ntfy_rx.recv().await {
match message {
Message::Ntfy(data) => {
tracing::trace!(topic = data.topic, "Sending Ntfy notification");
if let Err(err) = send_ntfy(&ntfy_client, &data).await {
tracing::error!("Failed to send to Ntfy: {:?}", err);
event_tx.send(Event::error(data.topic, err))?;
}
}
}
}
tracing::debug!("Stopped receiving messages");
Ok(())
}
#[tracing::instrument(skip_all)]
pub fn worker(event_tx: Sender<Event>, ntfy_rx: UnboundedReceiver<Message>) -> Result<()> {
let rt = runtime::Builder::new_current_thread()
.thread_name("sender")
.enable_io()
.build()?;
rt.block_on(run(event_tx, ntfy_rx))
}

View file

@ -1,76 +0,0 @@
use std::collections::HashMap;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{Json, Router};
use color_eyre::eyre::Context as _;
use color_eyre::Result;
use tokio::net::TcpListener;
use tokio::runtime;
use crate::types::{Event, WebhookEvent};
struct AppState {
event_tx: Sender<Event>,
}
pub fn worker(event_tx: Sender<Event>) -> Result<()> {
let rt = runtime::Builder::new_current_thread()
.thread_name("server")
.enable_io()
.build()?;
rt.block_on(task(event_tx))
}
async fn task(event_tx: Sender<Event>) -> Result<()> {
let state = AppState { event_tx };
let shared_state = Arc::new(state);
let app = Router::new()
.route(
"/webhook/:topic",
get(webhook_handler).post(webhook_handler),
)
.with_state(shared_state);
let listener = TcpListener::bind("0.0.0.0:3000")
.await
.wrap_err("Failed to bind to TCP socket")?;
tracing::info!("Listening on \"0.0.0.0:3000\"");
axum::serve(listener, app)
.await
.wrap_err("Failed to start server")
}
async fn webhook_handler(
State(state): State<Arc<AppState>>,
Path(topic): Path<String>,
Query(query): Query<HashMap<String, String>>,
body: Option<Json<serde_json::Value>>,
) -> impl IntoResponse {
let event = Event::Webhook(WebhookEvent {
topic,
query,
body: body
.map(|Json(body)| body)
.unwrap_or(serde_json::Value::Null),
});
tracing::debug!("Received webhook event: {:?}", &event);
if state.event_tx.send(event).is_err() {
tracing::error!("Failed to trigger webhook event");
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to trigger webhook event",
)
} else {
(StatusCode::NO_CONTENT, "")
}
}