#![warn(clippy::pedantic)] #![allow( clippy::module_name_repetitions, clippy::similar_names, clippy::let_underscore_drop )] mod db; mod download; mod index; mod new_crate; mod owners; mod search; mod yank; use async_trait::async_trait; use axum::{ body::Body, extract::{FromRequest, RequestParts}, http::{header::AUTHORIZATION, StatusCode}, response::{IntoResponse, Response}, routing::{delete, get, put}, Extension, Json, Router, }; use db::DbUser; use download::download; use owners::{add_owners, list_owners, remove_owners}; use search::search; use serde::{ser::SerializeStruct, Deserialize, Serialize}; use sqlx::{error::BoxDynError, postgres::PgPoolOptions, query_as, Pool, Postgres}; use std::{fmt::Display, net::SocketAddr, path::PathBuf, sync::Arc}; use tokio::{fs, sync::Mutex}; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; use tracing_subscriber::filter::LevelFilter; use yank::{unyank, yank}; #[derive(Serialize)] pub struct Errors { errors: Vec, } impl Errors { #[allow(clippy::needless_pass_by_value)] // I'd fix this but frankly I'm too lazy fn new(v: impl ToString) -> Self { Self { errors: vec![SingleError { detail: v.to_string(), }], } } fn new_many(v: impl IntoIterator) -> Errors { Self { errors: v .into_iter() .map(|v| SingleError { detail: v.to_string(), }) .collect(), } } } impl From for Errors { fn from(v: T) -> Self { Self::new(v) } } impl IntoResponse for Errors { fn into_response(self) -> Response { Json(self).into_response() } } impl From for Response { fn from(v: Errors) -> Self { v.into_response() } } pub type RespResult = std::result::Result; #[derive(Serialize)] struct SingleError { detail: String, } fn get_crate_prefix(s: &str) -> Result { if s.is_empty() { return Err("Crate name must be non-empty"); } else if !s.is_ascii() { return Err("Crate name must be ASCII"); } let mut buf = PathBuf::new(); if s.len() == 1 { buf.push("1"); } else if s.len() == 2 { buf.push("2"); } else if s.len() == 3 { buf.push("3"); let c = s.chars().next().unwrap().to_ascii_lowercase(); let mut b = [0]; buf.push(c.encode_utf8(&mut b)); } else { let mut b: [u8; 4] = s.as_bytes()[0..4].try_into().unwrap(); let s = std::str::from_utf8_mut(&mut b).unwrap(); s.make_ascii_lowercase(); buf.push(&s[0..2]); buf.push(&s[2..4]); } Ok(buf) } pub struct State { index_dir: PathBuf, crate_dir: PathBuf, db: Pool, } static INDEX_LOCK: Mutex<()> = Mutex::const_new(()); #[derive(Deserialize)] pub struct Config { working_dir: PathBuf, postgres_uri: String, listen_uri: SocketAddr, } #[tokio::main] async fn main() -> Result<(), BoxDynError> { let config = toml::from_str::( &fs::read_to_string("/etc/warehouse/warehouse.toml") .await .expect("Config file missing"), ) .expect("Invalid config file"); let mut index_dir = config.working_dir.clone(); index_dir.push("index"); let mut crate_dir = config.working_dir; crate_dir.push("crates"); let state = Arc::new(State { index_dir, crate_dir, db: PgPoolOptions::new().connect(&config.postgres_uri).await?, }); tracing_subscriber::fmt() .with_max_level(LevelFilter::TRACE) .init(); db::init(&state.db).await?; let app = Router::new() .route("/api/v1/crates/new", put(new_crate::new_crate)) .route("/api/v1/crates/:crate_name/:version/yank", delete(yank)) .route("/api/v1/crates/:crate_name/:version/unyank", put(unyank)) .route( "/api/v1/crates/:crate_name/owners", get(list_owners).put(add_owners).delete(remove_owners), ) .route( "/api/v1/crates/:crate_name/:version/download", get(download), ) .route("/api/v1/crates", get(search)) .layer( ServiceBuilder::new() .layer(Extension(state.clone())) .layer(TraceLayer::new_for_http()), ); fs::create_dir_all(&state.index_dir).await?; fs::create_dir_all(&state.crate_dir).await?; axum::Server::bind(&config.listen_uri) .serve(app.into_make_service()) .await .unwrap(); Ok(()) } pub struct Auth(DbUser); #[async_trait] impl FromRequest for Auth { type Rejection = Response; async fn from_request(req: &mut RequestParts) -> Result { let Extension(state): Extension> = Extension::from_request(req) .await .map_err(IntoResponse::into_response)?; if let Some(auth) = req .headers() .get(AUTHORIZATION) .and_then(|v| v.to_str().ok()) { if let Some(db_user) = sqlx::query_as::<_, DbUser>("SELECT * FROM users WHERE credential = $1 LIMIT 1") .bind(auth) .fetch_optional(&state.db) .await .map_err(db_error) .map_err(IntoResponse::into_response)? { Ok(Self(db_user)) } else { Err(StatusCode::FORBIDDEN.into_response()) } } else { Err(StatusCode::FORBIDDEN.into_response()) } } } /// Checks if the user has permissions to modify the crate. /// /// # Errors /// Returns a FORBIDDEN status code if the user does not have the required permissions. pub async fn auth(crate_name: &str, auth_user: &DbUser, state: &State) -> Result<(), Response> { let (is_authenticated,): (bool,) = query_as("SELECT $1 = ANY (crates.owners) FROM crates WHERE name = $2") .bind(auth_user.id) .bind(&crate_name) .fetch_one(&state.db) .await .map_err(db_error) .map_err(IntoResponse::into_response)?; if !is_authenticated { return Err(StatusCode::FORBIDDEN.into_response()); } Ok(()) } pub struct Success; impl Serialize for Success { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { let mut serialize_struct = serializer.serialize_struct("Success", 1)?; serialize_struct.serialize_field("ok", &true)?; serialize_struct.end() } } impl IntoResponse for Success { fn into_response(self) -> Response { Json(self).into_response() } } fn internal_error(e: E) -> Errors { Errors::new(format_args!("Internal server error: {e}")) } fn db_error(e: E) -> Errors { Errors::new(format_args!("Database error: {e}")) }