Compare commits
6 commits
Author | SHA1 | Date | |
---|---|---|---|
lemonsh | bc83fce101 | ||
lemon-sh | 5aebcc1fce | ||
lemon-sh | 21d8edbef1 | ||
lemon-sh | 608a2e8797 | ||
lemon-sh | c58db54696 | ||
lemon-sh | d7e68e827e |
18
Cargo.lock
generated
18
Cargo.lock
generated
|
@ -488,9 +488,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
|||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.112"
|
||||
version = "0.2.113"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125"
|
||||
checksum = "eef78b64d87775463c549fbd80e19249ef436ea3bf1de2a1eb7e717ec7fab1e9"
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
|
@ -884,18 +884,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.133"
|
||||
version = "1.0.134"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97565067517b60e2d1ea8b268e59ce036de907ac523ad83a0475da04e818989a"
|
||||
checksum = "96b3c34c1690edf8174f5b289a336ab03f568a4460d8c6df75f2f3a692b3bc6a"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.133"
|
||||
version = "1.0.134"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed201699328568d8d08208fdd080e3ff594e6c422e438b6705905da01005d537"
|
||||
checksum = "784ed1fbfa13fe191077537b0d70ec8ad1e903cfe04831da608aa36457cb653d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -961,9 +961,9 @@ checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
|
|||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.4.2"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516"
|
||||
checksum = "0f82496b90c36d70af5fcd482edaa2e0bd16fade569de1330405fecbbdac736b"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
|
@ -1391,7 +1391,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "xuproxy"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"hex",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "xuproxy"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
|
||||
[profile.release]
|
||||
|
@ -18,6 +18,6 @@ toml = "0.5"
|
|||
serde = { version = "1.0", features = ["derive"] }
|
||||
hmac-sha256 = "1.1"
|
||||
hex = "0.4"
|
||||
reqwest = { version = "0.11", features = ["multipart", "json"] }
|
||||
reqwest = { version = "0.11", default-features = false, features = ["multipart", "json", "rustls-tls"] }
|
||||
once_cell = "1.9"
|
||||
serde_json = "1.0"
|
||||
|
|
|
@ -7,8 +7,11 @@ use tokio::sync::{
|
|||
|
||||
#[derive(Debug)]
|
||||
enum Task {
|
||||
GetFile(oneshot::Sender<Option<String>>, String),
|
||||
AddFile(oneshot::Sender<bool>, String, String),
|
||||
// syntax: TaskName(oneshot::Sender<ReturnType>, /* param_name */ ParamType),
|
||||
|
||||
GetFile(oneshot::Sender<Option<String>>, /* path */ String),
|
||||
AddFile(oneshot::Sender<bool>, /* path */ String, /* url */ String, /* mid */ u64),
|
||||
Cleanup(oneshot::Sender<Option<Vec<u64>>>, /* older_than */ u64),
|
||||
}
|
||||
|
||||
pub struct DbExecutor {
|
||||
|
@ -21,8 +24,8 @@ impl DbExecutor {
|
|||
let (tx, rx) = unbounded_channel();
|
||||
let db = rusqlite::Connection::open(dbpath)?;
|
||||
db.execute(
|
||||
"create table if not exists files(\
|
||||
id integer primary key, path text not null, url text not null, \
|
||||
"create table if not exists files(id integer primary key,\
|
||||
path text not null, url text not null, mid integer not null, \
|
||||
timestamp integer default (strftime('%s','now')))",
|
||||
[],
|
||||
)?;
|
||||
|
@ -33,12 +36,12 @@ impl DbExecutor {
|
|||
pub fn run(&mut self) {
|
||||
while let Some(task) = self.rx.blocking_recv() {
|
||||
match task {
|
||||
Task::GetFile(tx, p) => {
|
||||
Task::GetFile(tx, path) => {
|
||||
let paste = self
|
||||
.db
|
||||
.query_row(
|
||||
"select url from files where path=? limit 1",
|
||||
params![p],
|
||||
params![path],
|
||||
|r| r.get(0),
|
||||
)
|
||||
.optional()
|
||||
|
@ -46,18 +49,39 @@ impl DbExecutor {
|
|||
error!("A database error has occurred: {}", v);
|
||||
None
|
||||
});
|
||||
tx.send(paste).unwrap();
|
||||
let _ = tx.send(paste);
|
||||
}
|
||||
Task::AddFile(tx, p, u) => {
|
||||
if let Err(e) = self
|
||||
Task::AddFile(tx, path, url, mid) => {
|
||||
let _ = if let Err(e) = self
|
||||
.db
|
||||
.execute("insert into files(path,url) values(?,?)", params![p, u])
|
||||
.execute("insert into files(path,url,mid) values(?,?,?)", params![path, url, mid])
|
||||
{
|
||||
error!("A database error has occurred: {}", e);
|
||||
tx.send(false).unwrap();
|
||||
tx.send(false)
|
||||
} else {
|
||||
tx.send(true).unwrap();
|
||||
}
|
||||
tx.send(true)
|
||||
};
|
||||
}
|
||||
Task::Cleanup(tx, older_than) => {
|
||||
let _ = tx.send(match self.db
|
||||
.prepare("delete from files where timestamp < strftime('%s','now')-? returning mid")
|
||||
.and_then(|mut v| v.query(params![older_than])
|
||||
.and_then(|mut v| {
|
||||
let mut mids: Vec<u64> = Vec::new();
|
||||
while let Some(row) = v.next()? {
|
||||
mids.push(row.get(0)?);
|
||||
}
|
||||
Ok(mids)
|
||||
}))
|
||||
{
|
||||
Ok(o) => {
|
||||
Some(o)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("A database error has occurred: {}", e);
|
||||
None
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -76,15 +100,18 @@ impl Clone for ExecutorConnection {
|
|||
}
|
||||
}
|
||||
|
||||
impl ExecutorConnection {
|
||||
pub async fn add_file(&self, path: String, url: String) -> bool {
|
||||
let (otx, orx) = oneshot::channel();
|
||||
self.tx.send(Task::AddFile(otx, path, url)).unwrap();
|
||||
orx.await.unwrap()
|
||||
}
|
||||
pub async fn get_file(&self, path: String) -> Option<String> {
|
||||
let (otx, orx) = oneshot::channel();
|
||||
self.tx.send(Task::GetFile(otx, path)).unwrap();
|
||||
orx.await.unwrap()
|
||||
macro_rules! executor_wrapper {
|
||||
($name:ident, $task:expr, $ret:ty, $($arg:ident: $ty:ty),*) => {
|
||||
pub async fn $name(&self, $($arg: $ty),*) -> $ret {
|
||||
let (otx, orx) = oneshot::channel();
|
||||
self.tx.send($task(otx, $($arg),*)).unwrap();
|
||||
orx.await.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutorConnection {
|
||||
executor_wrapper!(add_file, Task::AddFile, bool, path: String, url: String, mid: u64);
|
||||
executor_wrapper!(cleanup, Task::Cleanup, Option<Vec<u64>>, older_than: u64);
|
||||
executor_wrapper!(get_file, Task::GetFile, Option<String>, path: String);
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
use crate::Bytes;
|
||||
use anyhow::anyhow;
|
||||
use log::debug;
|
||||
use log::{debug, trace, warn};
|
||||
use once_cell::sync::OnceCell;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::multipart::{Form, Part};
|
||||
use reqwest::{Body, Client, IntoUrl};
|
||||
use reqwest::{Body, Client, IntoUrl, StatusCode};
|
||||
use serde_json::Value;
|
||||
use std::fmt::Display;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
// note: only delete has rate-limit handling
|
||||
|
||||
static CLIENT: OnceCell<Client> = OnceCell::new();
|
||||
|
||||
|
@ -14,21 +18,22 @@ pub async fn upload_webhook<T: Into<Body>>(
|
|||
webhook: &str,
|
||||
file: T,
|
||||
filename: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
) -> anyhow::Result<(String, u64)> {
|
||||
debug!("Uploading '{}' to Discord", filename);
|
||||
let client = CLIENT.get_or_init(Client::new);
|
||||
let form = Form::new().part("file", Part::stream(file).file_name(filename.to_string()));
|
||||
let req = client
|
||||
let resp = client
|
||||
.post(webhook)
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await?
|
||||
.json::<Value>()
|
||||
.await?;
|
||||
if let Some(u) = req["attachments"][0]["url"].as_str() {
|
||||
Ok(u.into())
|
||||
trace!("Received JSON from Discord: {}", resp);
|
||||
if let (Some(u), Some(i)) = (resp["attachments"][0]["url"].as_str(), resp["id"].as_str().and_then(|f| f.parse::<u64>().ok())) {
|
||||
Ok((u.into(), i))
|
||||
} else {
|
||||
Err(anyhow!("Discord response didn't include the URL"))
|
||||
Err(anyhow!("Discord response didn't include the URL or message ID"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,3 +72,33 @@ pub async fn get<U: IntoUrl + Display>(url: U) -> anyhow::Result<(u64, String, B
|
|||
Err(anyhow!("Discord response didn't include the URL"))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete(webhook: &str, mid: u64) -> anyhow::Result<()> {
|
||||
debug!("Deleting message with ID {}", mid);
|
||||
let client = CLIENT.get_or_init(Client::new);
|
||||
let resp = client.delete(format!("{}/messages/{}", webhook, mid)).send().await?;
|
||||
if resp.status() != StatusCode::NO_CONTENT {
|
||||
Err(anyhow!(resp.text().await?))
|
||||
} else {
|
||||
let rt_header = resp.headers()
|
||||
.get("X-RateLimit-Remaining")
|
||||
.and_then(|v| v.to_str().ok()?.parse::<u64>().ok())
|
||||
.and_then(|v| {
|
||||
if v == 0 {
|
||||
resp.headers()
|
||||
.get("X-RateLimit-Reset-After")
|
||||
.and_then(|v| v.to_str().ok()?.parse::<f64>().ok())
|
||||
} else {
|
||||
Some(0.0)
|
||||
}
|
||||
});
|
||||
if let Some(rt) = rt_header {
|
||||
if rt > 0.0 {
|
||||
sleep(Duration::from_secs_f64(rt)).await;
|
||||
}
|
||||
} else {
|
||||
warn!("Couldn't await the rate-limit, because there was a problem with the rate-limit header in Discord's response")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
72
src/main.rs
72
src/main.rs
|
@ -6,11 +6,14 @@ use std::{env, fs, thread};
|
|||
use crate::database::{DbExecutor, ExecutorConnection};
|
||||
use log::{debug, error, info, warn};
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use warp::http::{Response, StatusCode};
|
||||
use warp::hyper::body::Bytes;
|
||||
use warp::path::FullPath;
|
||||
use warp::{any, body, header, path, query, Filter, Reply};
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::Duration;
|
||||
|
||||
mod database;
|
||||
mod discord;
|
||||
|
@ -41,6 +44,7 @@ struct Config {
|
|||
webhook: String,
|
||||
secret: String,
|
||||
dbpath: String,
|
||||
cleanup: Option<u64>
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -56,7 +60,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let config_path = config_var.as_deref().unwrap_or("xuproxy.toml");
|
||||
info!("Loading config '{}'", config_path);
|
||||
let config_str = fs::read_to_string(config_path)?;
|
||||
let config: Config = toml::from_str(&config_str)?;
|
||||
let config: Arc<Config> = Arc::new(toml::from_str(&config_str)?);
|
||||
info!("Initializing database...");
|
||||
let (mut db_exec, db_conn) = DbExecutor::create(&config.dbpath)?;
|
||||
let executor_thread = thread::spawn(move || {
|
||||
|
@ -64,14 +68,17 @@ async fn main() -> anyhow::Result<()> {
|
|||
log::info!("Database executor shutting down");
|
||||
});
|
||||
|
||||
let (ctx, crx) = oneshot::channel();
|
||||
let server_task = tokio::spawn(run_server(config, db_conn, crx));
|
||||
let (ctx, _) = broadcast::channel(1);
|
||||
let server_task = tokio::spawn(run_server(config.clone(), db_conn.clone(), ctx.subscribe()));
|
||||
let cleanup_task = config.cleanup.map(|_| tokio::spawn(cleanup_task(db_conn, config, ctx.subscribe())));
|
||||
terminate_signal().await;
|
||||
info!("Shutdown signal received, powering down");
|
||||
let _ = ctx.send(());
|
||||
server_task
|
||||
.await
|
||||
server_task.await
|
||||
.unwrap_or_else(|e| error!("Couldn't await the server task: {}", e));
|
||||
if let Some(t) = cleanup_task {
|
||||
t.await.unwrap_or_else(|e| error!("Couldn't await the cleanup task: {}", e));
|
||||
}
|
||||
executor_thread.join().unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
@ -109,13 +116,14 @@ async fn handle_put(
|
|||
debug!("Token '{}' doesn't match HMAC secret", query.v);
|
||||
return StatusCode::FORBIDDEN;
|
||||
}
|
||||
|
||||
match discord::upload_webhook(&config.webhook, body, filename_str).await {
|
||||
Err(e) => {
|
||||
warn!("Could not upload file to Discord: {}", e);
|
||||
StatusCode::FORBIDDEN
|
||||
}
|
||||
Ok(url) => {
|
||||
if db.add_file(filename_str.to_string(), url).await {
|
||||
Ok(o) => {
|
||||
if db.add_file(filename_str.to_string(), o.0, o.1).await {
|
||||
StatusCode::CREATED
|
||||
} else {
|
||||
StatusCode::FORBIDDEN
|
||||
|
@ -124,6 +132,27 @@ async fn handle_put(
|
|||
}
|
||||
}
|
||||
|
||||
async fn cleanup_task(db: ExecutorConnection, conf: Arc<Config>, mut cancel: Receiver<()>) {
|
||||
let older_than = conf.cleanup.unwrap()*3600;
|
||||
loop {
|
||||
debug!("Starting daily cleanup...");
|
||||
if let Some(mids) = db.cleanup(older_than).await {
|
||||
let midslen = mids.len();
|
||||
for mid in mids {
|
||||
if let Err(e) = discord::delete(&conf.webhook, mid).await {
|
||||
warn!("Couldn't delete message {}: {}", mid, e);
|
||||
}
|
||||
}
|
||||
info!("Daily cleanup complete! Removed {} entries", midslen);
|
||||
}
|
||||
tokio::select! {
|
||||
_ = sleep(Duration::from_secs(86400)) => {},
|
||||
_ = cancel.recv() => break
|
||||
}
|
||||
}
|
||||
info!("Daily cleanup task has been shut down");
|
||||
}
|
||||
|
||||
async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: bool) -> impl Reply {
|
||||
let filename_str = &filename.as_str()[1..];
|
||||
debug!("Received GET request, name({})", filename_str);
|
||||
|
@ -139,9 +168,10 @@ async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: boo
|
|||
return Response::builder()
|
||||
.header("Content-Length", o.0)
|
||||
.header("Content-Type", o.1)
|
||||
.body("").into_response()
|
||||
.body("")
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => e
|
||||
Err(e) => e,
|
||||
}
|
||||
} else {
|
||||
match discord::get(&url).await {
|
||||
|
@ -153,7 +183,7 @@ async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: boo
|
|||
.unwrap()
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => e
|
||||
Err(e) => e,
|
||||
}
|
||||
};
|
||||
warn!("Could not retrieve '{}' from Discord: {}", url, err);
|
||||
|
@ -163,9 +193,7 @@ async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: boo
|
|||
}
|
||||
}
|
||||
|
||||
async fn run_server(conf: Config, db: ExecutorConnection, cancel: oneshot::Receiver<()>) {
|
||||
let conf = Arc::new(conf);
|
||||
|
||||
async fn run_server(conf: Arc<Config>, db: ExecutorConnection, mut cancel: Receiver<()>) {
|
||||
let put_route = warp::put()
|
||||
.and(path::full())
|
||||
.and(header::<u64>("content-length"))
|
||||
|
@ -206,18 +234,14 @@ async fn run_server(conf: Config, db: ExecutorConnection, cancel: oneshot::Recei
|
|||
.tls()
|
||||
.cert_path(&tls.cert)
|
||||
.key_path(&tls.key)
|
||||
.bind_with_graceful_shutdown(conf.address, async {
|
||||
let _ = cancel.await;
|
||||
})
|
||||
.1
|
||||
.await;
|
||||
.bind_with_graceful_shutdown(conf.address, async move {
|
||||
let _ = cancel.recv().await;
|
||||
}).1.await;
|
||||
} else {
|
||||
warp::serve(routes)
|
||||
.bind_with_graceful_shutdown(conf.address, async {
|
||||
let _ = cancel.await;
|
||||
})
|
||||
.1
|
||||
.await;
|
||||
.bind_with_graceful_shutdown(conf.address, async move {
|
||||
let _ = cancel.recv().await;
|
||||
}).1.await;
|
||||
};
|
||||
info!("Webserver shutting down");
|
||||
}
|
||||
|
|
Reference in a new issue