diff --git a/.gitignore b/.gitignore index 2960b25..1728cbd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target /xuproxy.toml -.idea/ \ No newline at end of file +.idea/ +/xuproxy.db \ No newline at end of file diff --git a/src/database.rs b/src/database.rs index 0a06e4a..ded21e1 100644 --- a/src/database.rs +++ b/src/database.rs @@ -36,9 +36,11 @@ impl DbExecutor { Task::GetFile(tx, p) => { let paste = self .db - .query_row("select url from files where path=? limit 1", params![p], |r| { - r.get(0) - }) + .query_row( + "select url from files where path=? limit 1", + params![p], + |r| r.get(0), + ) .optional() .unwrap_or_else(|v| { error!("A database error has occurred: {}", v); diff --git a/src/main.rs b/src/main.rs index 8087e0e..1bd463b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,19 @@ use hmac_sha256::HMAC; use std::net::SocketAddr; use std::sync::Arc; -use std::{env, fs}; +use std::{env, fs, thread}; -use log::{debug, info}; +use crate::database::{DbExecutor, ExecutorConnection}; +use log::{debug, error, info}; use serde::Deserialize; -use tokio::select; +use tokio::sync::oneshot; use warp::http::StatusCode; use warp::hyper::body::Bytes; -use warp::{any, body, header, path, query, Filter, Reply}; use warp::path::FullPath; +use warp::{any, body, header, path, query, Filter, Reply}; -mod discord; mod database; +mod discord; #[cfg(unix)] async fn terminate_signal() { @@ -39,6 +40,7 @@ struct Config { address: SocketAddr, webhook: String, secret: String, + dbpath: String, } #[derive(Debug, Deserialize)] @@ -55,12 +57,20 @@ async fn main() -> anyhow::Result<()> { info!("Loading config '{}'", config_path); let config_str = fs::read_to_string(config_path)?; let config: Config = toml::from_str(&config_str)?; + info!("Initializing database..."); + let (mut db_exec, db_conn) = DbExecutor::create(&config.dbpath)?; + let executor_thread = thread::spawn(move || { + db_exec.run(); + log::info!("Database executor shutting down"); + }); - select! { - r = run_server(config) => r?, - _ = terminate_signal() => {} - } - info!("Shutting down..."); + let (ctx, crx) = oneshot::channel(); + let server_task = tokio::spawn(run_server(config, db_conn, crx)); + terminate_signal().await; + info!("Shutdown signal received, powering down"); + let _ = ctx.send(()); + server_task.await.unwrap_or_else(|e| error!("Couldn't await the server task: {}", e)); + executor_thread.join().unwrap(); Ok(()) } @@ -75,6 +85,7 @@ async fn handle_put( query: PutQueryString, body: Bytes, config: Arc, + db: ExecutorConnection ) -> impl Reply { let filename_str = &filename.as_str()[1..]; debug!( @@ -92,15 +103,26 @@ async fn handle_put( debug!("Token '{}' doesn't match HMAC secret", query.v); return StatusCode::FORBIDDEN; } - if let Err(e) = discord::upload_webhook(&config.webhook, body, filename_str).await { - debug!("Could not upload file to Discord: {}", e); - return StatusCode::FORBIDDEN; + match discord::upload_webhook(&config.webhook, body, filename_str).await { + Err(e) => { + debug!("Could not upload file to Discord: {}", e); + StatusCode::FORBIDDEN + } + Ok(url) => { + if !db.add_file(filename_str.to_string(), url).await { + StatusCode::FORBIDDEN + } else { + StatusCode::CREATED + } + } } - - StatusCode::CREATED } -async fn run_server(conf: Config) -> anyhow::Result<()> { +async fn run_server( + conf: Config, + db: ExecutorConnection, + cancel: oneshot::Receiver<()> +) { let conf = Arc::new(conf); let put_route = warp::put() @@ -112,10 +134,12 @@ async fn run_server(conf: Config) -> anyhow::Result<()> { let conf = conf.clone(); move || conf.clone() })) + .and(any().map({ + let db = db.clone(); + move || db.clone() + })) .then(handle_put); - - let routes = put_route; if let Some(tls) = &conf.tls { @@ -123,10 +147,13 @@ async fn run_server(conf: Config) -> anyhow::Result<()> { .tls() .cert_path(&tls.cert) .key_path(&tls.key) - .run(conf.address) - .await + .bind_with_graceful_shutdown(conf.address, async { + let _ = cancel.await; + }).1.await; } else { - warp::serve(routes).run(conf.address).await - } - Ok(()) + warp::serve(routes).bind_with_graceful_shutdown(conf.address, async { + let _ = cancel.await; + }).1.await; + }; + info!("Webserver shutting down"); }