// Listing metadata: 287 lines, 8.3 KiB, Rust.
#![warn(clippy::pedantic)]
|
|
#![allow(clippy::unused_async)]
|
|
|
|
mod tasks;
|
|
mod users;
|
|
|
|
use axum::extract::Extension;
|
|
use axum::extract::Query;
|
|
use axum::http::StatusCode;
|
|
use axum::response::Html;
|
|
use axum::response::IntoResponse;
|
|
use axum::response::Redirect;
|
|
use axum::response::Response;
|
|
use axum::routing::get_service;
|
|
use axum::{
|
|
routing::{get, post},
|
|
Router,
|
|
};
|
|
use axum_sessions::extractors::ReadableSession;
|
|
use axum_sessions::{async_session::MemoryStore, SessionLayer};
|
|
use serde::Deserialize;
|
|
use sqlx::postgres::PgPoolOptions;
|
|
use sqlx::{Pool, Postgres};
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use tasks::RawTask;
|
|
use tasks::Task;
|
|
use tera::Tera;
|
|
use thiserror::Error as ThisError;
|
|
use tower::ServiceBuilder;
|
|
use tower_http::services::ServeDir;
|
|
use tower_http::trace::TraceLayer;
|
|
use tracing::info;
|
|
use tracing::warn;
|
|
use tracing_subscriber::EnvFilter;
|
|
use tracing_subscriber::FmtSubscriber;
|
|
|
|
#[derive(ThisError, Debug)]
|
|
pub enum Error {
|
|
Database(#[from] sqlx::Error),
|
|
Tera(#[from] tera::Error),
|
|
Tokio(#[from] tokio::task::JoinError),
|
|
Pbkdf2(pbkdf2::password_hash::Error),
|
|
Session(#[from] axum_sessions::async_session::serde_json::Error),
|
|
Io(#[from] std::io::Error),
|
|
Toml(#[from] toml::de::Error),
|
|
Other(String),
|
|
Redirect(Redirect),
|
|
StatusCode(StatusCode),
|
|
}
|
|
|
|
impl std::fmt::Display for Error {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
use std::fmt::Display;
|
|
#[allow(clippy::enum_glob_use)]
|
|
use Error::*;
|
|
match self {
|
|
Database(e) => Display::fmt(e, f),
|
|
Tera(e) => Display::fmt(e, f),
|
|
Tokio(e) => Display::fmt(e, f),
|
|
Pbkdf2(e) => Display::fmt(e, f),
|
|
Session(e) => Display::fmt(e, f),
|
|
Io(e) => Display::fmt(e, f),
|
|
Toml(e) => Display::fmt(e, f),
|
|
Other(e) => write!(f, "Error: {}", e),
|
|
Redirect(_) => write!(f, "[redirect]"),
|
|
StatusCode(c) => Display::fmt(c, f),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for Error {
|
|
fn into_response(self) -> Response {
|
|
match self {
|
|
Self::Redirect(r) => r.into_response(),
|
|
_ => self.to_string().into_response(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Generates one `impl From<Source> for Error` per `Source => Variant` pair,
/// wrapping the converted value in `Error::Variant`.
///
/// Used for variants whose payload cannot take thiserror's `#[from]`.
/// A trailing comma after the last pair is accepted.
macro_rules! manual_from {
    ($($source:ty => $variant:ident),* $(,)?) => {
        $(
            impl From<$source> for Error {
                fn from(value: $source) -> Self {
                    Self::$variant(value)
                }
            }
        )*
    };
}
|
|
|
|
// We need to implement these manually because they don't meet the bounds set by thiserror
|
|
manual_from! {
|
|
pbkdf2::password_hash::Error => Pbkdf2,
|
|
String => Other,
|
|
Redirect => Redirect,
|
|
StatusCode => StatusCode
|
|
}
|
|
|
|
#[derive(Deserialize, Debug, Clone)]
|
|
pub struct Config {
|
|
secret: String,
|
|
connection_string: String,
|
|
template_dir: Option<String>,
|
|
static_dir: Option<std::path::PathBuf>,
|
|
address: Option<SocketAddr>,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Error> {
|
|
let subscriber = FmtSubscriber::builder()
|
|
.with_env_filter(EnvFilter::from_env("TMTD_LOG"))
|
|
.with_ansi(true)
|
|
.with_file(true)
|
|
.with_line_number(true)
|
|
.finish();
|
|
|
|
tracing::subscriber::set_global_default(subscriber).unwrap();
|
|
|
|
let filename = std::env::var("TMTD_CONFIG").unwrap_or_else(|_| "config.toml".into());
|
|
info!("Loading file {}", filename);
|
|
let contents = std::fs::read_to_string(filename)?;
|
|
let config: Config = toml::from_str(&contents)?;
|
|
info!("Config loaded, connecting to database...");
|
|
let pool = Arc::new(
|
|
PgPoolOptions::new()
|
|
.connect(&config.connection_string)
|
|
.await?,
|
|
);
|
|
|
|
info!("Creating default tables (if needed)");
|
|
sqlx::query("create table if not exists users(id serial primary key, username text not null unique, password_hash text not null)").execute(&*pool).await?;
|
|
sqlx::query("create table if not exists tasks(id serial primary key, owner int not null, title text not null, description text not null, status int not null)").execute(&*pool).await?;
|
|
|
|
info!("Loading templates...");
|
|
let tera = Arc::new(Tera::new(&format!(
|
|
"{}/**/*.html",
|
|
config.template_dir.as_deref().unwrap_or("templates")
|
|
))?);
|
|
|
|
info!("Creating session layer");
|
|
let store = MemoryStore::new();
|
|
if config.secret.len() < 64 {
|
|
return Err("Secret must be at least 64 bytes!".to_string().into());
|
|
}
|
|
let session_layer =
|
|
SessionLayer::new(store, config.secret.as_bytes()).with_cookie_name("2m2d_session");
|
|
|
|
let task_routes = Router::new()
|
|
.route("/update/:id", get(tasks::update_form))
|
|
.route("/update/:id", post(tasks::update_backend))
|
|
.route("/create", get(tasks::create_form))
|
|
.route("/create", post(tasks::create_backend))
|
|
.route("/:id", get(tasks::task_detail))
|
|
.route("/delete/:id", get(tasks::delete_form))
|
|
.route("/delete/:id", post(tasks::delete_backend));
|
|
|
|
let serve_dir = get_service(ServeDir::new(
|
|
config.static_dir.unwrap_or_else(|| "static".into()),
|
|
))
|
|
.handle_error(|e: std::io::Error| async move {
|
|
warn!("Unhandled server error: {}", e);
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
format!("Unhandled server error: {}", e),
|
|
)
|
|
});
|
|
|
|
let app = Router::new()
|
|
.route("/", get(homepage))
|
|
.route("/register", post(users::create_user))
|
|
.route("/register", get(users::register_form))
|
|
.route("/login", get(users::login_form))
|
|
.route("/login", post(users::login_backend))
|
|
.route("/logout", get(users::logout_form))
|
|
.route("/logout", post(users::logout_backend))
|
|
.nest("/tasks", task_routes)
|
|
.nest("/static", serve_dir)
|
|
.layer(
|
|
ServiceBuilder::new()
|
|
.layer(Extension(pool))
|
|
.layer(Extension(tera))
|
|
.layer(session_layer)
|
|
.layer(TraceLayer::new_for_http()),
|
|
);
|
|
|
|
info!("Starting server...");
|
|
let address = config
|
|
.address
|
|
.unwrap_or_else(|| "0.0.0.0:3000".parse().unwrap());
|
|
info!("Server started on {}", address);
|
|
axum::Server::bind(&address)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Evaluates to the `String` stored under `"logged_in_as"` in `$session`.
///
/// If that key is missing, it logs a warning and makes the *enclosing function*
/// return `Err(Redirect::to($to).into())`. It can therefore only be used inside
/// functions whose error type converts from `axum::response::Redirect`.
#[macro_export]
macro_rules! login_or_redirect {
    ($session:expr, $to:literal) => {
        if let Some(logged_in) = $session.get::<String>("logged_in_as") {
            logged_in
        } else {
            ::tracing::warn!("User not logged in, redirecting to {}", $to);
            return Err(::axum::response::Redirect::to($to).into());
        }
    };
}
|
|
|
|
/// Builds a `tera::Context` from `"key" => value` pairs.
///
/// Each value is inserted by reference (`&value`), so the caller keeps
/// ownership. With no pairs the result is an empty context.
#[macro_export]
macro_rules! ctx {
    ($($name:literal => $value:expr),*) => {{
        // `mut` is unused when the macro is invoked with no pairs.
        #[allow(unused_mut)]
        let mut context = ::tera::Context::new();
        $(context.insert($name, &$value);)*
        context
    }};
}
|
|
|
|
/// Returns `Err($err.into())` from the *enclosing function* when `$e` is
/// **true**, after logging the error with `tracing::warn!`.
///
/// Despite the name, `$e` is the failure condition, not the asserted one.
/// The error expression is evaluated exactly once, so an argument such as
/// `format!(..)` is not built twice and side effects do not repeat.
#[macro_export]
macro_rules! assert_or {
    ($e:expr, $err:expr) => {
        if $e {
            let err = $err;
            ::tracing::warn!("{}", err);
            return Err(err.into());
        }
    };
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct FilterParam {
|
|
filter: Option<u8>,
|
|
}
|
|
|
|
async fn homepage(
|
|
session: ReadableSession,
|
|
Extension(tera): Extension<Arc<Tera>>,
|
|
Extension(pool): Extension<Arc<Pool<Postgres>>>,
|
|
Query(FilterParam { filter }): Query<FilterParam>,
|
|
) -> Result<Html<String>, Error> {
|
|
let username = login_or_redirect!(session, "/login");
|
|
|
|
info!("Getting user ID...");
|
|
let (id,): (i32,) = sqlx::query_as("select id from users where username=$1")
|
|
.bind(&username)
|
|
.fetch_one(&*pool)
|
|
.await?;
|
|
|
|
info!("Getting tasks");
|
|
let mut tasks: Vec<RawTask> =
|
|
sqlx::query_as("select id,title,description,status from tasks where owner=$1")
|
|
.bind(id)
|
|
.fetch_all(&*pool)
|
|
.await?;
|
|
|
|
info!("Sorting and filtering tasks");
|
|
// TODO: refactor this when let_chains is finally stable
|
|
if let Some(filter) = filter {
|
|
if filter != 0 {
|
|
tasks.retain(|rt| rt.status == (filter as i32) - 1);
|
|
}
|
|
}
|
|
let mut tasks: Vec<Task> = tasks.into_iter().map(Task::from).collect();
|
|
tasks.sort_unstable();
|
|
tasks.reverse();
|
|
|
|
info!("rendering pagee");
|
|
let rendered = tera.render(
|
|
"home.html",
|
|
&ctx! {
|
|
"tasks" => tasks,
|
|
"filter" => filter.unwrap_or(255) // placeholder that won't match anything
|
|
},
|
|
)?;
|
|
|
|
info!("Rendering finished");
|
|
Ok(Html(rendered))
|
|
}
|