2m2d/src/main.rs

287 lines
8.3 KiB
Rust

#![warn(clippy::pedantic)]
#![allow(clippy::unused_async)]
mod tasks;
mod users;
use axum::extract::Extension;
use axum::extract::Query;
use axum::http::StatusCode;
use axum::response::Html;
use axum::response::IntoResponse;
use axum::response::Redirect;
use axum::response::Response;
use axum::routing::get_service;
use axum::{
routing::{get, post},
Router,
};
use axum_sessions::extractors::ReadableSession;
use axum_sessions::{async_session::MemoryStore, SessionLayer};
use serde::Deserialize;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
use std::net::SocketAddr;
use std::sync::Arc;
use tasks::RawTask;
use tasks::Task;
use tera::Tera;
use thiserror::Error as ThisError;
use tower::ServiceBuilder;
use tower_http::services::ServeDir;
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing::warn;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::FmtSubscriber;
/// Application-wide error type returned by `main` and the request handlers.
///
/// Variants marked `#[from]` get their `From` impls from thiserror; `Pbkdf2`, `Other`,
/// `Redirect` and `StatusCode` get theirs from the `manual_from!` invocation in this file.
/// No variant carries an `#[error(...)]` attribute because `Display` is implemented by hand.
#[derive(ThisError, Debug)]
pub enum Error {
    /// Failure reported by `sqlx` (connection, query, row decoding).
    Database(#[from] sqlx::Error),
    /// Template loading or rendering failure reported by Tera.
    Tera(#[from] tera::Error),
    /// A tokio task failed to join.
    Tokio(#[from] tokio::task::JoinError),
    /// Password-hashing error from `pbkdf2`.
    Pbkdf2(pbkdf2::password_hash::Error),
    /// JSON (de)serialization error surfaced by the session store.
    Session(#[from] axum_sessions::async_session::serde_json::Error),
    /// Filesystem / OS I/O error.
    Io(#[from] std::io::Error),
    /// Config-file TOML parse error.
    Toml(#[from] toml::de::Error),
    /// Free-form error message.
    Other(String),
    /// Not a failure as such: a redirect propagated through the `Err` path
    /// (e.g. by `login_or_redirect!`).
    Redirect(Redirect),
    /// A bare HTTP status code.
    StatusCode(StatusCode),
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use std::fmt::Display;
#[allow(clippy::enum_glob_use)]
use Error::*;
match self {
Database(e) => Display::fmt(e, f),
Tera(e) => Display::fmt(e, f),
Tokio(e) => Display::fmt(e, f),
Pbkdf2(e) => Display::fmt(e, f),
Session(e) => Display::fmt(e, f),
Io(e) => Display::fmt(e, f),
Toml(e) => Display::fmt(e, f),
Other(e) => write!(f, "Error: {}", e),
Redirect(_) => write!(f, "[redirect]"),
StatusCode(c) => Display::fmt(c, f),
}
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
match self {
Self::Redirect(r) => r.into_response(),
_ => self.to_string().into_response(),
}
}
}
/// Generates `impl From<Source> for Error` for each `Source => Variant` pair.
///
/// Each generated impl simply wraps the incoming value in `Error::Variant`. It is used
/// for variants whose payload types thiserror's `#[from]` attribute cannot handle.
macro_rules! manual_from {
    ($($source:ty => $variant:ident),*) => {
        $(
            impl ::std::convert::From<$source> for Error {
                fn from(value: $source) -> Error {
                    Error::$variant(value)
                }
            }
        )*
    };
}
// These payload types don't satisfy the trait bounds thiserror's `#[from]` requires,
// so their `From` conversions are generated by hand instead.
manual_from! {
    String => Other,
    StatusCode => StatusCode,
    Redirect => Redirect,
    pbkdf2::password_hash::Error => Pbkdf2
}
/// Server configuration, deserialized from the TOML file named by `TMTD_CONFIG`
/// (`config.toml` by default).
///
/// NOTE(review): the derived `Debug` prints `secret` and `connection_string` verbatim;
/// avoid logging this struct.
#[derive(Deserialize, Debug, Clone)]
pub struct Config {
    /// Secret passed to `SessionLayer::new`; `main` rejects values shorter than 64 bytes.
    secret: String,
    /// Postgres connection string handed to `PgPoolOptions::connect`.
    connection_string: String,
    /// Root directory searched for `**/*.html` templates (`main` falls back to `templates`).
    template_dir: Option<String>,
    /// Directory served under `/static` (`main` falls back to `static`).
    static_dir: Option<std::path::PathBuf>,
    /// Listen address (`main` falls back to `0.0.0.0:3000`).
    address: Option<SocketAddr>,
}
#[tokio::main]
async fn main() -> Result<(), Error> {
let subscriber = FmtSubscriber::builder()
.with_env_filter(EnvFilter::from_env("TMTD_LOG"))
.with_ansi(true)
.with_file(true)
.with_line_number(true)
.finish();
tracing::subscriber::set_global_default(subscriber).unwrap();
let filename = std::env::var("TMTD_CONFIG").unwrap_or_else(|_| "config.toml".into());
info!("Loading file {}", filename);
let contents = std::fs::read_to_string(filename)?;
let config: Config = toml::from_str(&contents)?;
info!("Config loaded, connecting to database...");
let pool = Arc::new(
PgPoolOptions::new()
.connect(&config.connection_string)
.await?,
);
info!("Creating default tables (if needed)");
sqlx::query("create table if not exists users(id serial primary key, username text not null unique, password_hash text not null)").execute(&*pool).await?;
sqlx::query("create table if not exists tasks(id serial primary key, owner int not null, title text not null, description text not null, status int not null)").execute(&*pool).await?;
info!("Loading templates...");
let tera = Arc::new(Tera::new(&format!(
"{}/**/*.html",
config.template_dir.as_deref().unwrap_or("templates")
))?);
info!("Creating session layer");
let store = MemoryStore::new();
if config.secret.len() < 64 {
return Err("Secret must be at least 64 bytes!".to_string().into());
}
let session_layer =
SessionLayer::new(store, config.secret.as_bytes()).with_cookie_name("2m2d_session");
let task_routes = Router::new()
.route("/update/:id", get(tasks::update_form))
.route("/update/:id", post(tasks::update_backend))
.route("/create", get(tasks::create_form))
.route("/create", post(tasks::create_backend))
.route("/:id", get(tasks::task_detail))
.route("/delete/:id", get(tasks::delete_form))
.route("/delete/:id", post(tasks::delete_backend));
let serve_dir = get_service(ServeDir::new(
config.static_dir.unwrap_or_else(|| "static".into()),
))
.handle_error(|e: std::io::Error| async move {
warn!("Unhandled server error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled server error: {}", e),
)
});
let app = Router::new()
.route("/", get(homepage))
.route("/register", post(users::create_user))
.route("/register", get(users::register_form))
.route("/login", get(users::login_form))
.route("/login", post(users::login_backend))
.route("/logout", get(users::logout_form))
.route("/logout", post(users::logout_backend))
.nest("/tasks", task_routes)
.nest("/static", serve_dir)
.layer(
ServiceBuilder::new()
.layer(Extension(pool))
.layer(Extension(tera))
.layer(session_layer)
.layer(TraceLayer::new_for_http()),
);
info!("Starting server...");
let address = config
.address
.unwrap_or_else(|| "0.0.0.0:3000".parse().unwrap());
info!("Server started on {}", address);
axum::Server::bind(&address)
.serve(app.into_make_service())
.await
.unwrap();
Ok(())
}
/// Evaluates to the `String` stored under `"logged_in_as"` in `$session`.
///
/// When that entry is absent, it logs a warning and makes the *enclosing function*
/// return `Err(Redirect::to($to).into())`. It can therefore only be used in functions
/// whose error type implements `From<axum::response::Redirect>`, such as this crate's `Error`.
#[macro_export]
macro_rules! login_or_redirect {
    ($session:expr, $to:literal) => {
        if let Some(logged_in) = $session.get::<String>("logged_in_as") {
            logged_in
        } else {
            ::tracing::warn!("User not logged in, redirecting to {}", $to);
            return Err(::axum::response::Redirect::to($to).into());
        }
    };
}
/// Builds a `tera::Context` from `"key" => value` pairs.
///
/// Each value is inserted by reference via `Context::insert`, so it must be
/// serializable. A trailing comma after the last pair is now accepted, and `ctx! {}`
/// yields an empty context.
#[macro_export]
macro_rules! ctx {
    ($($key:literal => $val:expr),* $(,)?) => {{
        // `mut` is unused when the macro is invoked with no pairs.
        #[allow(unused_mut)]
        let mut ctx = ::tera::Context::new();
        $(
            ctx.insert($key, &$val);
        )*
        ctx
    }};
}
/// Returns early with `Err($err.into())` when `$e` evaluates to `true`.
///
/// Despite the name, `$e` is the *failure* condition, not an invariant that must hold.
/// When it is `true`, the error is logged with `tracing::warn!` and returned; when it
/// is `false`, execution continues.
///
/// `$err` must implement `Display` and convert into the enclosing function's error
/// type. It is evaluated at most once: it was previously expanded twice (once for the
/// log, once for the return), duplicating any side effects or allocations.
#[macro_export]
macro_rules! assert_or {
    ($e:expr, $err:expr) => {
        if $e {
            // Bind once; this local is hygienic and cannot clash with caller names.
            let err = $err;
            ::tracing::warn!("{}", err);
            return Err(err.into());
        }
    };
}
/// Query string accepted by the home page (`/?filter=N`).
#[derive(Deserialize)]
struct FilterParam {
    /// Optional status filter: absent or `0` shows all tasks; `N > 0` selects raw
    /// task status `N - 1`.
    filter: Option<u8>,
}
/// Renders `home.html` with the logged-in user's tasks, optionally filtered by status.
///
/// Redirects to `/login` when the session has no `logged_in_as` entry.
/// With `?filter=N` and `N > 0`, only tasks whose raw `status` equals `N - 1` are kept.
/// A missing filter or `0` shows every task. Tasks are listed in descending `Task` order.
///
/// # Errors
/// Returns an error if a database query fails (including when the session's username has
/// no row in `users`) or if the template fails to render.
async fn homepage(
    session: ReadableSession,
    Extension(tera): Extension<Arc<Tera>>,
    Extension(pool): Extension<Arc<Pool<Postgres>>>,
    Query(FilterParam { filter }): Query<FilterParam>,
) -> Result<Html<String>, Error> {
    let username = login_or_redirect!(session, "/login");
    info!("Getting user ID...");
    let (id,): (i32,) = sqlx::query_as("select id from users where username=$1")
        .bind(&username)
        .fetch_one(&*pool)
        .await?;
    info!("Getting tasks");
    let mut tasks: Vec<RawTask> =
        sqlx::query_as("select id,title,description,status from tasks where owner=$1")
            .bind(id)
            .fetch_all(&*pool)
            .await?;
    info!("Sorting and filtering tasks");
    // `0` means "no filter"; any other value N selects raw status N - 1.
    if let Some(selected) = filter.filter(|&f| f != 0) {
        let wanted_status = i32::from(selected) - 1;
        tasks.retain(|rt| rt.status == wanted_status);
    }
    let mut tasks: Vec<Task> = tasks.into_iter().map(Task::from).collect();
    // Descending order in a single pass (replaces sort + reverse).
    tasks.sort_unstable_by(|a, b| b.cmp(a));
    info!("Rendering page");
    let rendered = tera.render(
        "home.html",
        &ctx! {
            "tasks" => tasks,
            // 255 is a placeholder that won't match any filter option
            "filter" => filter.unwrap_or(255)
        },
    )?;
    info!("Rendering finished");
    Ok(Html(rendered))
}