From 334358d3804bc79ac89c4183787e12640efd8cd1 Mon Sep 17 00:00:00 2001 From: lemon-sh Date: Thu, 20 Jan 2022 15:47:30 +0100 Subject: [PATCH] Add GET handler --- src/discord.rs | 45 +++++++++++++++++++++++++++++++++++---------- src/main.rs | 50 +++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 15 deletions(-) diff --git a/src/discord.rs b/src/discord.rs index e2af1a5..269fe6d 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -1,8 +1,13 @@ -use anyhow::anyhow; +// warning: trash error handling here + +use std::fmt::Display; +use log::debug; use once_cell::sync::OnceCell; use reqwest::multipart::{Form, Part}; -use reqwest::{Body, Client}; +use reqwest::{Body, Client, IntoUrl}; +use reqwest::header::HeaderMap; use serde_json::Value; +use crate::Bytes; static CLIENT: OnceCell = OnceCell::new(); @@ -10,19 +15,39 @@ pub async fn upload_webhook>( webhook: &str, file: T, filename: &str, -) -> anyhow::Result { +) -> Option { + debug!("Uploading '{}' to Discord", filename); let client = CLIENT.get_or_init(Client::new); let form = Form::new().part("file", Part::stream(file).file_name(filename.to_string())); let req = client .post(webhook) .multipart(form) .send() - .await? + .await.ok()? .json::() - .await?; - if let Some(url) = req["attachments"][0]["url"].as_str() { - Ok(url.into()) - } else { - Err(anyhow!("Discord didn't include the URL in the response")) - } + .await.ok()?; + req["attachments"][0]["url"].as_str().map(|v| v.to_string()) +} + +fn extract_headers(h: &HeaderMap) -> Option<(u64, String)> { + let content_length = h.get("content-length").and_then(|v| v.to_str().ok()).and_then(|v| v.parse::().ok())?; + let mime = h.get("content-type").and_then(|v| v.to_str().ok()).map(|v| v.to_string())?; + Some((content_length, mime)) +} + +pub async fn head(url: U) -> Option<(u64, String)> { + debug!("Downloading headers of '{}' from Discord", url); + let client = CLIENT.get_or_init(Client::new); + let resp = client.head(url).send().await.ok()?; + let headers = resp.headers(); + extract_headers(headers) +} + +pub async fn get(url: U) -> Option<(u64, String, Bytes)> { + debug!("Downloading '{}' from Discord", url); + let client = CLIENT.get_or_init(Client::new); + let resp = client.get(url).send().await.ok()?; + let headers = extract_headers(resp.headers())?; + let bytes = resp.bytes().await.ok()?; + Some((headers.0, headers.1, bytes)) } diff --git a/src/main.rs b/src/main.rs index 1bd463b..e06651f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use crate::database::{DbExecutor, ExecutorConnection}; use log::{debug, error, info}; use serde::Deserialize; use tokio::sync::oneshot; -use warp::http::StatusCode; +use warp::http::{Response, StatusCode}; use warp::hyper::body::Bytes; use warp::path::FullPath; use warp::{any, body, header, path, query, Filter, Reply}; @@ -92,6 +92,10 @@ async fn handle_put( "Received PUT request, name({}) length({}) token({})", filename_str, length, query.v ); + if filename_str.is_empty() { + debug!("Empty file path submitted"); + return StatusCode::FORBIDDEN; + } let mut supplied_token = [0_u8; 32]; if let Err(e) = hex::decode_to_slice(&query.v, &mut supplied_token) { debug!("Failed to parse hex string '{}': {}", query.v, e); @@ -104,11 +108,11 @@ async fn handle_put( return StatusCode::FORBIDDEN; } match discord::upload_webhook(&config.webhook, body, filename_str).await { - Err(e) => { - debug!("Could not upload file to Discord: {}", e); + None => { + debug!("Could not upload '{}' to Discord", filename_str); StatusCode::FORBIDDEN } - Ok(url) => { + Some(url) => { if !db.add_file(filename_str.to_string(), url).await { StatusCode::FORBIDDEN } else { @@ -118,6 +122,34 @@ async fn handle_put( } } +async fn handle_get(filename: FullPath, db: ExecutorConnection) -> impl Reply { + let filename_str = &filename.as_str()[1..]; + debug!("Received GET request, name({})", filename_str); + if filename_str.is_empty() { + debug!("Empty file path submitted"); + return StatusCode::FORBIDDEN.into_response(); + } + match db.get_file(filename_str.to_string()).await { + Some(url) => { + match discord::get(&url).await { + Some(o) => { + Response::builder() + .header("Content-length", o.0) + .header("Content-type", o.1) + .body(o.2).unwrap().into_response() + } + None => { + debug!("Could not download '{}' from Discord", url); + StatusCode::FORBIDDEN.into_response() + } + } + } + None => { + StatusCode::NOT_FOUND.into_response() + } + } +} + async fn run_server( conf: Config, db: ExecutorConnection, @@ -140,7 +172,15 @@ async fn run_server( })) .then(handle_put); - let routes = put_route; + let get_route = warp::get() + .and(path::full()) + .and(any().map({ + let db = db.clone(); + move || db.clone() + })) + .then(handle_get); + + let routes = put_route.or(get_route); if let Some(tls) = &conf.tls { warp::serve(routes)