Compare commits

...

6 commits

5 changed files with 151 additions and 65 deletions

18
Cargo.lock generated
View file

@ -488,9 +488,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.112"
version = "0.2.113"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125"
checksum = "eef78b64d87775463c549fbd80e19249ef436ea3bf1de2a1eb7e717ec7fab1e9"
[[package]]
name = "libsqlite3-sys"
@ -884,18 +884,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.133"
version = "1.0.134"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97565067517b60e2d1ea8b268e59ce036de907ac523ad83a0475da04e818989a"
checksum = "96b3c34c1690edf8174f5b289a336ab03f568a4460d8c6df75f2f3a692b3bc6a"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.133"
version = "1.0.134"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed201699328568d8d08208fdd080e3ff594e6c422e438b6705905da01005d537"
checksum = "784ed1fbfa13fe191077537b0d70ec8ad1e903cfe04831da608aa36457cb653d"
dependencies = [
"proc-macro2",
"quote",
@ -961,9 +961,9 @@ checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
[[package]]
name = "socket2"
version = "0.4.2"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516"
checksum = "0f82496b90c36d70af5fcd482edaa2e0bd16fade569de1330405fecbbdac736b"
dependencies = [
"libc",
"winapi",
@ -1391,7 +1391,7 @@ dependencies = [
[[package]]
name = "xuproxy"
version = "0.1.0"
version = "0.2.0"
dependencies = [
"anyhow",
"hex",

View file

@ -1,6 +1,6 @@
[package]
name = "xuproxy"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
[profile.release]
@ -18,6 +18,6 @@ toml = "0.5"
serde = { version = "1.0", features = ["derive"] }
hmac-sha256 = "1.1"
hex = "0.4"
reqwest = { version = "0.11", features = ["multipart", "json"] }
reqwest = { version = "0.11", default-features = false, features = ["multipart", "json", "rustls-tls"] }
once_cell = "1.9"
serde_json = "1.0"

View file

@ -7,8 +7,11 @@ use tokio::sync::{
#[derive(Debug)]
enum Task {
GetFile(oneshot::Sender<Option<String>>, String),
AddFile(oneshot::Sender<bool>, String, String),
// syntax: TaskName(oneshot::Sender<ReturnType>, /* param_name */ ParamType),
GetFile(oneshot::Sender<Option<String>>, /* path */ String),
AddFile(oneshot::Sender<bool>, /* path */ String, /* url */ String, /* mid */ u64),
Cleanup(oneshot::Sender<Option<Vec<u64>>>, /* older_than */ u64),
}
pub struct DbExecutor {
@ -21,8 +24,8 @@ impl DbExecutor {
let (tx, rx) = unbounded_channel();
let db = rusqlite::Connection::open(dbpath)?;
db.execute(
"create table if not exists files(\
id integer primary key, path text not null, url text not null, \
"create table if not exists files(id integer primary key,\
path text not null, url text not null, mid integer not null, \
timestamp integer default (strftime('%s','now')))",
[],
)?;
@ -33,12 +36,12 @@ impl DbExecutor {
pub fn run(&mut self) {
while let Some(task) = self.rx.blocking_recv() {
match task {
Task::GetFile(tx, p) => {
Task::GetFile(tx, path) => {
let paste = self
.db
.query_row(
"select url from files where path=? limit 1",
params![p],
params![path],
|r| r.get(0),
)
.optional()
@ -46,18 +49,39 @@ impl DbExecutor {
error!("A database error has occurred: {}", v);
None
});
tx.send(paste).unwrap();
let _ = tx.send(paste);
}
Task::AddFile(tx, p, u) => {
if let Err(e) = self
Task::AddFile(tx, path, url, mid) => {
let _ = if let Err(e) = self
.db
.execute("insert into files(path,url) values(?,?)", params![p, u])
.execute("insert into files(path,url,mid) values(?,?,?)", params![path, url, mid])
{
error!("A database error has occurred: {}", e);
tx.send(false).unwrap();
tx.send(false)
} else {
tx.send(true).unwrap();
}
tx.send(true)
};
}
Task::Cleanup(tx, older_than) => {
let _ = tx.send(match self.db
.prepare("delete from files where timestamp < strftime('%s','now')-? returning mid")
.and_then(|mut v| v.query(params![older_than])
.and_then(|mut v| {
let mut mids: Vec<u64> = Vec::new();
while let Some(row) = v.next()? {
mids.push(row.get(0)?);
}
Ok(mids)
}))
{
Ok(o) => {
Some(o)
}
Err(e) => {
error!("A database error has occurred: {}", e);
None
}
});
}
}
}
@ -76,15 +100,18 @@ impl Clone for ExecutorConnection {
}
}
impl ExecutorConnection {
pub async fn add_file(&self, path: String, url: String) -> bool {
let (otx, orx) = oneshot::channel();
self.tx.send(Task::AddFile(otx, path, url)).unwrap();
orx.await.unwrap()
}
pub async fn get_file(&self, path: String) -> Option<String> {
let (otx, orx) = oneshot::channel();
self.tx.send(Task::GetFile(otx, path)).unwrap();
orx.await.unwrap()
macro_rules! executor_wrapper {
($name:ident, $task:expr, $ret:ty, $($arg:ident: $ty:ty),*) => {
pub async fn $name(&self, $($arg: $ty),*) -> $ret {
let (otx, orx) = oneshot::channel();
self.tx.send($task(otx, $($arg),*)).unwrap();
orx.await.unwrap()
}
}
}
impl ExecutorConnection {
executor_wrapper!(add_file, Task::AddFile, bool, path: String, url: String, mid: u64);
executor_wrapper!(cleanup, Task::Cleanup, Option<Vec<u64>>, older_than: u64);
executor_wrapper!(get_file, Task::GetFile, Option<String>, path: String);
}

View file

@ -1,12 +1,16 @@
use crate::Bytes;
use anyhow::anyhow;
use log::debug;
use log::{debug, trace, warn};
use once_cell::sync::OnceCell;
use reqwest::header::HeaderMap;
use reqwest::multipart::{Form, Part};
use reqwest::{Body, Client, IntoUrl};
use reqwest::{Body, Client, IntoUrl, StatusCode};
use serde_json::Value;
use std::fmt::Display;
use std::time::Duration;
use tokio::time::sleep;
// note: only delete has rate-limit handling
static CLIENT: OnceCell<Client> = OnceCell::new();
@ -14,21 +18,22 @@ pub async fn upload_webhook<T: Into<Body>>(
webhook: &str,
file: T,
filename: &str,
) -> anyhow::Result<String> {
) -> anyhow::Result<(String, u64)> {
debug!("Uploading '{}' to Discord", filename);
let client = CLIENT.get_or_init(Client::new);
let form = Form::new().part("file", Part::stream(file).file_name(filename.to_string()));
let req = client
let resp = client
.post(webhook)
.multipart(form)
.send()
.await?
.json::<Value>()
.await?;
if let Some(u) = req["attachments"][0]["url"].as_str() {
Ok(u.into())
trace!("Received JSON from Discord: {}", resp);
if let (Some(u), Some(i)) = (resp["attachments"][0]["url"].as_str(), resp["id"].as_str().and_then(|f| f.parse::<u64>().ok())) {
Ok((u.into(), i))
} else {
Err(anyhow!("Discord response didn't include the URL"))
Err(anyhow!("Discord response didn't include the URL or message ID"))
}
}
@ -67,3 +72,33 @@ pub async fn get<U: IntoUrl + Display>(url: U) -> anyhow::Result<(u64, String, B
Err(anyhow!("Discord response didn't include the URL"))
}
}
pub async fn delete(webhook: &str, mid: u64) -> anyhow::Result<()> {
debug!("Deleting message with ID {}", mid);
let client = CLIENT.get_or_init(Client::new);
let resp = client.delete(format!("{}/messages/{}", webhook, mid)).send().await?;
if resp.status() != StatusCode::NO_CONTENT {
Err(anyhow!(resp.text().await?))
} else {
let rt_header = resp.headers()
.get("X-RateLimit-Remaining")
.and_then(|v| v.to_str().ok()?.parse::<u64>().ok())
.and_then(|v| {
if v == 0 {
resp.headers()
.get("X-RateLimit-Reset-After")
.and_then(|v| v.to_str().ok()?.parse::<f64>().ok())
} else {
Some(0.0)
}
});
if let Some(rt) = rt_header {
if rt > 0.0 {
sleep(Duration::from_secs_f64(rt)).await;
}
} else {
warn!("Couldn't await the rate-limit, because there was a problem with the rate-limit header in Discord's response")
}
Ok(())
}
}

View file

@ -6,11 +6,14 @@ use std::{env, fs, thread};
use crate::database::{DbExecutor, ExecutorConnection};
use log::{debug, error, info, warn};
use serde::Deserialize;
use tokio::sync::oneshot;
use tokio::sync::broadcast;
use tokio::sync::broadcast::Receiver;
use warp::http::{Response, StatusCode};
use warp::hyper::body::Bytes;
use warp::path::FullPath;
use warp::{any, body, header, path, query, Filter, Reply};
use tokio::time::sleep;
use tokio::time::Duration;
mod database;
mod discord;
@ -41,6 +44,7 @@ struct Config {
webhook: String,
secret: String,
dbpath: String,
cleanup: Option<u64>
}
#[derive(Debug, Deserialize)]
@ -56,7 +60,7 @@ async fn main() -> anyhow::Result<()> {
let config_path = config_var.as_deref().unwrap_or("xuproxy.toml");
info!("Loading config '{}'", config_path);
let config_str = fs::read_to_string(config_path)?;
let config: Config = toml::from_str(&config_str)?;
let config: Arc<Config> = Arc::new(toml::from_str(&config_str)?);
info!("Initializing database...");
let (mut db_exec, db_conn) = DbExecutor::create(&config.dbpath)?;
let executor_thread = thread::spawn(move || {
@ -64,14 +68,17 @@ async fn main() -> anyhow::Result<()> {
log::info!("Database executor shutting down");
});
let (ctx, crx) = oneshot::channel();
let server_task = tokio::spawn(run_server(config, db_conn, crx));
let (ctx, _) = broadcast::channel(1);
let server_task = tokio::spawn(run_server(config.clone(), db_conn.clone(), ctx.subscribe()));
let cleanup_task = config.cleanup.map(|_| tokio::spawn(cleanup_task(db_conn, config, ctx.subscribe())));
terminate_signal().await;
info!("Shutdown signal received, powering down");
let _ = ctx.send(());
server_task
.await
server_task.await
.unwrap_or_else(|e| error!("Couldn't await the server task: {}", e));
if let Some(t) = cleanup_task {
t.await.unwrap_or_else(|e| error!("Couldn't await the cleanup task: {}", e));
}
executor_thread.join().unwrap();
Ok(())
}
@ -109,13 +116,14 @@ async fn handle_put(
debug!("Token '{}' doesn't match HMAC secret", query.v);
return StatusCode::FORBIDDEN;
}
match discord::upload_webhook(&config.webhook, body, filename_str).await {
Err(e) => {
warn!("Could not upload file to Discord: {}", e);
StatusCode::FORBIDDEN
}
Ok(url) => {
if db.add_file(filename_str.to_string(), url).await {
Ok(o) => {
if db.add_file(filename_str.to_string(), o.0, o.1).await {
StatusCode::CREATED
} else {
StatusCode::FORBIDDEN
@ -124,6 +132,27 @@ async fn handle_put(
}
}
async fn cleanup_task(db: ExecutorConnection, conf: Arc<Config>, mut cancel: Receiver<()>) {
let older_than = conf.cleanup.unwrap()*3600;
loop {
debug!("Starting daily cleanup...");
if let Some(mids) = db.cleanup(older_than).await {
let midslen = mids.len();
for mid in mids {
if let Err(e) = discord::delete(&conf.webhook, mid).await {
warn!("Couldn't delete message {}: {}", mid, e);
}
}
info!("Daily cleanup complete! Removed {} entries", midslen);
}
tokio::select! {
_ = sleep(Duration::from_secs(86400)) => {},
_ = cancel.recv() => break
}
}
info!("Daily cleanup task has been shut down");
}
async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: bool) -> impl Reply {
let filename_str = &filename.as_str()[1..];
debug!("Received GET request, name({})", filename_str);
@ -139,9 +168,10 @@ async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: boo
return Response::builder()
.header("Content-Length", o.0)
.header("Content-Type", o.1)
.body("").into_response()
.body("")
.into_response()
}
Err(e) => e
Err(e) => e,
}
} else {
match discord::get(&url).await {
@ -153,7 +183,7 @@ async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: boo
.unwrap()
.into_response();
}
Err(e) => e
Err(e) => e,
}
};
warn!("Could not retrieve '{}' from Discord: {}", url, err);
@ -163,9 +193,7 @@ async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: boo
}
}
async fn run_server(conf: Config, db: ExecutorConnection, cancel: oneshot::Receiver<()>) {
let conf = Arc::new(conf);
async fn run_server(conf: Arc<Config>, db: ExecutorConnection, mut cancel: Receiver<()>) {
let put_route = warp::put()
.and(path::full())
.and(header::<u64>("content-length"))
@ -206,18 +234,14 @@ async fn run_server(conf: Config, db: ExecutorConnection, cancel: oneshot::Recei
.tls()
.cert_path(&tls.cert)
.key_path(&tls.key)
.bind_with_graceful_shutdown(conf.address, async {
let _ = cancel.await;
})
.1
.await;
.bind_with_graceful_shutdown(conf.address, async move {
let _ = cancel.recv().await;
}).1.await;
} else {
warp::serve(routes)
.bind_with_graceful_shutdown(conf.address, async {
let _ = cancel.await;
})
.1
.await;
.bind_with_graceful_shutdown(conf.address, async move {
let _ = cancel.recv().await;
}).1.await;
};
info!("Webserver shutting down");
}