248 lines
7.8 KiB
Rust
248 lines
7.8 KiB
Rust
use hmac_sha256::HMAC;
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use std::{env, fs, thread};
|
|
|
|
use crate::database::{DbExecutor, ExecutorConnection};
|
|
use log::{debug, error, info, warn};
|
|
use serde::Deserialize;
|
|
use tokio::sync::broadcast;
|
|
use tokio::sync::broadcast::Receiver;
|
|
use warp::http::{Response, StatusCode};
|
|
use warp::hyper::body::Bytes;
|
|
use warp::path::FullPath;
|
|
use warp::{any, body, header, path, query, Filter, Reply};
|
|
use tokio::time::sleep;
|
|
use tokio::time::Duration;
|
|
|
|
mod database;
|
|
mod discord;
|
|
|
|
#[cfg(unix)]
|
|
async fn terminate_signal() {
|
|
use tokio::signal::unix::{signal, SignalKind};
|
|
let mut sigterm = signal(SignalKind::terminate()).unwrap();
|
|
let mut sigint = signal(SignalKind::interrupt()).unwrap();
|
|
debug!("Installed ctrl+c handler");
|
|
tokio::select! {
|
|
_ = sigterm.recv() => {},
|
|
_ = sigint.recv() => {}
|
|
}
|
|
}
|
|
/// Resolves once the console delivers a Ctrl+C event.
///
/// # Panics
/// Panics if the Ctrl+C listener cannot be registered.
#[cfg(windows)]
async fn terminate_signal() {
    use tokio::signal::windows::ctrl_c;

    let mut ctrl_c_events = ctrl_c().unwrap();
    debug!("Installed ctrl+c handler");
    // The returned value is irrelevant: either an event arrived or the stream ended.
    ctrl_c_events.recv().await;
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct Config {
|
|
tls: Option<TlsConfig>,
|
|
address: SocketAddr,
|
|
webhook: String,
|
|
secret: String,
|
|
dbpath: String,
|
|
cleanup: Option<u64>
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct TlsConfig {
|
|
cert: String,
|
|
key: String,
|
|
}
|
|
|
|
#[tokio::main(flavor = "current_thread")]
|
|
async fn main() -> anyhow::Result<()> {
|
|
pretty_env_logger::init_custom_env("XUPROXY_LOG");
|
|
let config_var = env::var("XUPROXY_CONFIG");
|
|
let config_path = config_var.as_deref().unwrap_or("xuproxy.toml");
|
|
info!("Loading config '{}'", config_path);
|
|
let config_str = fs::read_to_string(config_path)?;
|
|
let config: Arc<Config> = Arc::new(toml::from_str(&config_str)?);
|
|
info!("Initializing database...");
|
|
let (mut db_exec, db_conn) = DbExecutor::create(&config.dbpath)?;
|
|
let executor_thread = thread::spawn(move || {
|
|
db_exec.run();
|
|
log::info!("Database executor shutting down");
|
|
});
|
|
|
|
let (ctx, _) = broadcast::channel(1);
|
|
let server_task = tokio::spawn(run_server(config.clone(), db_conn.clone(), ctx.subscribe()));
|
|
let cleanup_task = config.cleanup.map(|_| tokio::spawn(cleanup_task(db_conn, config, ctx.subscribe())));
|
|
terminate_signal().await;
|
|
info!("Shutdown signal received, powering down");
|
|
let _ = ctx.send(());
|
|
server_task.await
|
|
.unwrap_or_else(|e| error!("Couldn't await the server task: {}", e));
|
|
if let Some(t) = cleanup_task {
|
|
t.await.unwrap_or_else(|e| error!("Couldn't await the cleanup task: {}", e));
|
|
}
|
|
executor_thread.join().unwrap();
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct PutQueryString {
|
|
v: String,
|
|
}
|
|
|
|
async fn handle_put(
|
|
filename: FullPath,
|
|
length: u64,
|
|
query: PutQueryString,
|
|
body: Bytes,
|
|
config: Arc<Config>,
|
|
db: ExecutorConnection,
|
|
) -> impl Reply {
|
|
let filename_str = &filename.as_str()[1..];
|
|
debug!(
|
|
"Received PUT request, name({}) length({}) token({})",
|
|
filename_str, length, query.v
|
|
);
|
|
if filename_str.is_empty() {
|
|
debug!("Empty file path submitted");
|
|
return StatusCode::FORBIDDEN;
|
|
}
|
|
let mut supplied_token = [0_u8; 32];
|
|
if let Err(e) = hex::decode_to_slice(&query.v, &mut supplied_token) {
|
|
debug!("Failed to parse hex string '{}': {}", query.v, e);
|
|
return StatusCode::FORBIDDEN;
|
|
}
|
|
let hmac_input = format!("{} {}", filename_str, length);
|
|
let calculated_token = HMAC::mac(hmac_input.as_bytes(), config.secret.as_bytes());
|
|
if supplied_token != calculated_token {
|
|
debug!("Token '{}' doesn't match HMAC secret", query.v);
|
|
return StatusCode::FORBIDDEN;
|
|
}
|
|
|
|
match discord::upload_webhook(&config.webhook, body, filename_str).await {
|
|
Err(e) => {
|
|
warn!("Could not upload file to Discord: {}", e);
|
|
StatusCode::FORBIDDEN
|
|
}
|
|
Ok(o) => {
|
|
if db.add_file(filename_str.to_string(), o.0, o.1).await {
|
|
StatusCode::CREATED
|
|
} else {
|
|
StatusCode::FORBIDDEN
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn cleanup_task(db: ExecutorConnection, conf: Arc<Config>, mut cancel: Receiver<()>) {
|
|
let older_than = conf.cleanup.unwrap()*3600;
|
|
loop {
|
|
debug!("Starting daily cleanup...");
|
|
if let Some(mids) = db.cleanup(older_than).await {
|
|
let midslen = mids.len();
|
|
for mid in mids {
|
|
if let Err(e) = discord::delete(&conf.webhook, mid).await {
|
|
warn!("Couldn't delete message {}: {}", mid, e);
|
|
}
|
|
}
|
|
info!("Daily cleanup complete! Removed {} entries", midslen);
|
|
}
|
|
tokio::select! {
|
|
_ = sleep(Duration::from_secs(86400)) => {},
|
|
_ = cancel.recv() => break
|
|
}
|
|
}
|
|
info!("Daily cleanup task has been shut down");
|
|
}
|
|
|
|
async fn handle_get_n_head(filename: FullPath, db: ExecutorConnection, head: bool) -> impl Reply {
|
|
let filename_str = &filename.as_str()[1..];
|
|
debug!("Received GET request, name({})", filename_str);
|
|
if filename_str.is_empty() {
|
|
debug!("Empty file path submitted");
|
|
return StatusCode::FORBIDDEN.into_response();
|
|
}
|
|
match db.get_file(filename_str.to_string()).await {
|
|
Some(url) => {
|
|
let err = if head {
|
|
match discord::head(&url).await {
|
|
Ok(o) => {
|
|
return Response::builder()
|
|
.header("Content-Length", o.0)
|
|
.header("Content-Type", o.1)
|
|
.body("")
|
|
.into_response()
|
|
}
|
|
Err(e) => e,
|
|
}
|
|
} else {
|
|
match discord::get(&url).await {
|
|
Ok(o) => {
|
|
return Response::builder()
|
|
.header("Content-Length", o.0)
|
|
.header("Content-Type", o.1)
|
|
.body(o.2)
|
|
.unwrap()
|
|
.into_response();
|
|
}
|
|
Err(e) => e,
|
|
}
|
|
};
|
|
warn!("Could not retrieve '{}' from Discord: {}", url, err);
|
|
StatusCode::FORBIDDEN.into_response()
|
|
}
|
|
None => StatusCode::NOT_FOUND.into_response(),
|
|
}
|
|
}
|
|
|
|
async fn run_server(conf: Arc<Config>, db: ExecutorConnection, mut cancel: Receiver<()>) {
|
|
let put_route = warp::put()
|
|
.and(path::full())
|
|
.and(header::<u64>("content-length"))
|
|
.and(query::<PutQueryString>())
|
|
.and(body::bytes())
|
|
.and(any().map({
|
|
let conf = conf.clone();
|
|
move || conf.clone()
|
|
}))
|
|
.and(any().map({
|
|
let db = db.clone();
|
|
move || db.clone()
|
|
}))
|
|
.then(handle_put);
|
|
|
|
let get_route = warp::get()
|
|
.and(path::full())
|
|
.and(any().map({
|
|
let db = db.clone();
|
|
move || db.clone()
|
|
}))
|
|
.and(any().map(|| false)) // head parameter
|
|
.then(handle_get_n_head);
|
|
|
|
let head_route = warp::head()
|
|
.and(path::full())
|
|
.and(any().map({
|
|
let db = db.clone();
|
|
move || db.clone()
|
|
}))
|
|
.and(any().map(|| true)) // head parameter
|
|
.then(handle_get_n_head);
|
|
|
|
let routes = put_route.or(get_route).or(head_route);
|
|
|
|
if let Some(tls) = &conf.tls {
|
|
warp::serve(routes)
|
|
.tls()
|
|
.cert_path(&tls.cert)
|
|
.key_path(&tls.key)
|
|
.bind_with_graceful_shutdown(conf.address, async move {
|
|
let _ = cancel.recv().await;
|
|
}).1.await;
|
|
} else {
|
|
warp::serve(routes)
|
|
.bind_with_graceful_shutdown(conf.address, async move {
|
|
let _ = cancel.recv().await;
|
|
}).1.await;
|
|
};
|
|
info!("Webserver shutting down");
|
|
}
|