warehouse/src/main.rs
2022-09-08 13:14:43 -05:00

269 lines
7 KiB
Rust

#![warn(clippy::pedantic)]
#![allow(
clippy::module_name_repetitions,
clippy::similar_names,
clippy::let_underscore_drop
)]
mod db;
mod download;
mod index;
mod new_crate;
mod owners;
mod search;
mod yank;
use async_trait::async_trait;
use axum::{
body::Body,
extract::{FromRequest, RequestParts},
http::{header::AUTHORIZATION, StatusCode},
response::{IntoResponse, Response},
routing::{delete, get, put},
Extension, Json, Router,
};
use db::DbUser;
use download::download;
use owners::{add_owners, list_owners, remove_owners};
use search::search;
use serde::{ser::SerializeStruct, Deserialize, Serialize};
use sqlx::{error::BoxDynError, postgres::PgPoolOptions, query_as, Pool, Postgres};
use std::{fmt::Display, net::SocketAddr, path::PathBuf, sync::Arc};
use tokio::{fs, sync::Mutex};
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use tracing_subscriber::filter::LevelFilter;
use yank::{unyank, yank};
/// JSON error body returned by handlers.
///
/// Serializes as `{"errors": [{"detail": "..."}, ...]}`, one entry per
/// [`SingleError`].
#[derive(Serialize)]
pub struct Errors {
// The list of individual error messages, in the order they were added.
errors: Vec<SingleError>,
}
impl Errors {
    /// Builds an error payload containing exactly one message.
    fn new(v: impl ToString) -> Self {
        // Delegating to `new_many` consumes `v`, so no pass-by-value allowance is needed.
        Self::new_many(std::iter::once(v))
    }

    /// Builds an error payload with one entry per item of `v`, preserving order.
    fn new_many(v: impl IntoIterator<Item = impl ToString>) -> Errors {
        let mut errors = Vec::new();
        for message in v {
            errors.push(SingleError {
                detail: message.to_string(),
            });
        }
        Self { errors }
    }
}
impl<T: ToString> From<T> for Errors {
fn from(v: T) -> Self {
Self::new(v)
}
}
/// Renders the error payload as a JSON response body.
impl IntoResponse for Errors {
    fn into_response(self) -> Response {
        let body = Json(self);
        body.into_response()
    }
}
impl From<Errors> for Response {
fn from(v: Errors) -> Self {
v.into_response()
}
}
/// Handler result type: the success value (by default [`Success`], which serializes
/// as `{"ok": true}`) or an already-rendered error [`Response`].
pub type RespResult<T = Success> = std::result::Result<T, Response>;
/// One entry of the [`Errors`] list.
#[derive(Serialize)]
struct SingleError {
// Human-readable error message, serialized under the `detail` key.
detail: String,
}
/// Computes the relative index directory for crate `s`, following Cargo's registry
/// index layout:
///
/// - 1-char names → `1`
/// - 2-char names → `2`
/// - 3-char names → `3/{first char}`
/// - longer names → `{chars 0..2}/{chars 2..4}`
///
/// Characters are lowercased before building the path.
///
/// # Errors
/// Returns a static message if the name is empty, is not ASCII, or contains anything
/// other than ASCII alphanumerics, `-` and `_`. Rejecting other characters prevents
/// names such as `../x` or `ab/cd` from producing path components that escape the
/// index directory (pushing an absolute component onto a `PathBuf` replaces it).
fn get_crate_prefix(s: &str) -> Result<PathBuf, &'static str> {
    if s.is_empty() {
        return Err("Crate name must be non-empty");
    } else if !s.is_ascii() {
        return Err("Crate name must be ASCII");
    } else if !s
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
    {
        return Err("Crate name may only contain ASCII alphanumerics, `-` and `_`");
    }
    // All bytes are ASCII, so byte-index slicing below is always on char boundaries.
    let lower = s.to_ascii_lowercase();
    let mut buf = PathBuf::new();
    match lower.len() {
        1 => buf.push("1"),
        2 => buf.push("2"),
        3 => {
            buf.push("3");
            buf.push(&lower[..1]);
        }
        _ => {
            buf.push(&lower[..2]);
            buf.push(&lower[2..4]);
        }
    }
    Ok(buf)
}
/// Shared application state. `main` wraps it in an `Arc` and hands it to handlers
/// through an axum `Extension` layer.
pub struct State {
// `<working_dir>/index` (built in `main`); presumably the registry index tree laid out by `get_crate_prefix` — confirm in callers.
index_dir: PathBuf,
// `<working_dir>/crates` (built in `main`); presumably where crate archives are stored — confirm in `new_crate`/`download`.
crate_dir: PathBuf,
// Postgres connection pool.
db: Pool<Postgres>,
}
/// Process-wide async (tokio) mutex guarding no data of its own.
/// NOTE(review): not used in this file; the name suggests it serializes modifications
/// of the on-disk index — confirm against its users in the submodules.
static INDEX_LOCK: Mutex<()> = Mutex::const_new(());
/// Server configuration, deserialized in `main` from the TOML file at
/// `/etc/warehouse/warehouse.toml`.
#[derive(Deserialize)]
pub struct Config {
// Base directory; `main` uses its `index` and `crates` subdirectories and creates them if missing.
working_dir: PathBuf,
// Connection string passed to `PgPoolOptions::connect`.
postgres_uri: String,
// Socket address the HTTP server binds to.
listen_uri: SocketAddr,
}
/// Entry point.
///
/// Steps:
/// 1. Loads the configuration.
/// 2. Connects to Postgres and initialises the schema.
/// 3. Makes sure the index and crate directories exist.
/// 4. Serves the registry HTTP API under `/api/v1/crates` until the server stops.
///
/// # Errors
/// Returns an error on any of the following:
/// - the config file cannot be read or parsed;
/// - connecting to or initialising the database fails;
/// - the working directories cannot be created;
/// - the HTTP server fails.
#[tokio::main]
async fn main() -> Result<(), BoxDynError> {
    const CONFIG_PATH: &str = "/etc/warehouse/warehouse.toml";

    // Install the subscriber before doing any work so that the initial database
    // connection and schema setup are traced too.
    tracing_subscriber::fmt()
        .with_max_level(LevelFilter::TRACE)
        .init();

    // Propagate failures (with their real cause) instead of panicking with a fixed,
    // possibly misleading message.
    let config_text = fs::read_to_string(CONFIG_PATH)
        .await
        .map_err(|e| format!("Failed to read config file {CONFIG_PATH}: {e}"))?;
    let config = toml::from_str::<Config>(&config_text)
        .map_err(|e| format!("Invalid config file {CONFIG_PATH}: {e}"))?;

    let state = Arc::new(State {
        index_dir: config.working_dir.join("index"),
        crate_dir: config.working_dir.join("crates"),
        db: PgPoolOptions::new().connect(&config.postgres_uri).await?,
    });
    db::init(&state.db).await?;

    let app = Router::new()
        .route("/api/v1/crates/new", put(new_crate::new_crate))
        .route("/api/v1/crates/:crate_name/:version/yank", delete(yank))
        .route("/api/v1/crates/:crate_name/:version/unyank", put(unyank))
        .route(
            "/api/v1/crates/:crate_name/owners",
            get(list_owners).put(add_owners).delete(remove_owners),
        )
        .route(
            "/api/v1/crates/:crate_name/:version/download",
            get(download),
        )
        .route("/api/v1/crates", get(search))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(Arc::clone(&state)))
                .layer(TraceLayer::new_for_http()),
        );

    fs::create_dir_all(&state.index_dir).await?;
    fs::create_dir_all(&state.crate_dir).await?;

    // Server errors are returned to the caller rather than unwrapped.
    axum::Server::bind(&config.listen_uri)
        .serve(app.into_make_service())
        .await?;
    Ok(())
}
/// Extractor for an authenticated user. It resolves the request's `Authorization`
/// header value to a row of the `users` table (matched on the `credential` column).
/// It rejects the request with `403 Forbidden` when the header is missing, cannot be
/// read as a string, or matches no user.
pub struct Auth(DbUser);
#[async_trait]
impl FromRequest<Body> for Auth {
type Rejection = Response;
async fn from_request(req: &mut RequestParts<Body>) -> Result<Self, Self::Rejection> {
let Extension(state): Extension<Arc<State>> = Extension::from_request(req)
.await
.map_err(IntoResponse::into_response)?;
if let Some(auth) = req
.headers()
.get(AUTHORIZATION)
.and_then(|v| v.to_str().ok())
{
if let Some(db_user) =
sqlx::query_as::<_, DbUser>("SELECT * FROM users WHERE credential = $1 LIMIT 1")
.bind(auth)
.fetch_optional(&state.db)
.await
.map_err(db_error)
.map_err(IntoResponse::into_response)?
{
Ok(Self(db_user))
} else {
Err(StatusCode::FORBIDDEN.into_response())
}
} else {
Err(StatusCode::FORBIDDEN.into_response())
}
}
}
/// Checks if the user has permissions to modify the crate.
///
/// # Errors
/// Returns a FORBIDDEN status code if the user does not have the required permissions.
pub async fn auth(crate_name: &str, auth_user: &DbUser, state: &State) -> Result<(), Response> {
let (is_authenticated,): (bool,) =
query_as("SELECT $1 = ANY (crates.owners) FROM crates WHERE name = $2")
.bind(auth_user.id)
.bind(&crate_name)
.fetch_one(&state.db)
.await
.map_err(db_error)
.map_err(IntoResponse::into_response)?;
if !is_authenticated {
return Err(StatusCode::FORBIDDEN.into_response());
}
Ok(())
}
/// Generic success marker. It serializes as `{"ok": true}` and is the default
/// payload type of `RespResult`.
pub struct Success;
impl Serialize for Success {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut serialize_struct = serializer.serialize_struct("Success", 1)?;
serialize_struct.serialize_field("ok", &true)?;
serialize_struct.end()
}
}
impl IntoResponse for Success {
fn into_response(self) -> Response {
Json(self).into_response()
}
}
/// Wraps an arbitrary displayable failure in an [`Errors`] payload whose single
/// message is prefixed with `Internal server error: `.
fn internal_error<E: Display>(e: E) -> Errors {
    let message = format!("Internal server error: {e}");
    Errors::new(message)
}
/// Wraps a database failure in an [`Errors`] payload whose single message is
/// prefixed with `Database error: `.
fn db_error<E: Display>(e: E) -> Errors {
    let message = format!("Database error: {e}");
    Errors::new(message)
}