269 lines
7 KiB
Rust
269 lines
7 KiB
Rust
#![warn(clippy::pedantic)]
|
|
#![allow(
|
|
clippy::module_name_repetitions,
|
|
clippy::similar_names,
|
|
clippy::let_underscore_drop
|
|
)]
|
|
|
|
mod db;
|
|
mod download;
|
|
mod index;
|
|
mod new_crate;
|
|
mod owners;
|
|
mod search;
|
|
mod yank;
|
|
|
|
use async_trait::async_trait;
|
|
use axum::{
|
|
body::Body,
|
|
extract::{FromRequest, RequestParts},
|
|
http::{header::AUTHORIZATION, StatusCode},
|
|
response::{IntoResponse, Response},
|
|
routing::{delete, get, put},
|
|
Extension, Json, Router,
|
|
};
|
|
use db::DbUser;
|
|
use download::download;
|
|
use owners::{add_owners, list_owners, remove_owners};
|
|
use search::search;
|
|
use serde::{ser::SerializeStruct, Deserialize, Serialize};
|
|
use sqlx::{error::BoxDynError, postgres::PgPoolOptions, query_as, Pool, Postgres};
|
|
use std::{fmt::Display, net::SocketAddr, path::PathBuf, sync::Arc};
|
|
use tokio::{fs, sync::Mutex};
|
|
use tower::ServiceBuilder;
|
|
use tower_http::trace::TraceLayer;
|
|
use tracing_subscriber::filter::LevelFilter;
|
|
use yank::{unyank, yank};
|
|
|
|
#[derive(Serialize)]
|
|
pub struct Errors {
|
|
errors: Vec<SingleError>,
|
|
}
|
|
|
|
impl Errors {
|
|
#[allow(clippy::needless_pass_by_value)] // I'd fix this but frankly I'm too lazy
|
|
fn new(v: impl ToString) -> Self {
|
|
Self {
|
|
errors: vec![SingleError {
|
|
detail: v.to_string(),
|
|
}],
|
|
}
|
|
}
|
|
|
|
fn new_many(v: impl IntoIterator<Item = impl ToString>) -> Errors {
|
|
Self {
|
|
errors: v
|
|
.into_iter()
|
|
.map(|v| SingleError {
|
|
detail: v.to_string(),
|
|
})
|
|
.collect(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T: ToString> From<T> for Errors {
|
|
fn from(v: T) -> Self {
|
|
Self::new(v)
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for Errors {
|
|
fn into_response(self) -> Response {
|
|
Json(self).into_response()
|
|
}
|
|
}
|
|
|
|
impl From<Errors> for Response {
|
|
fn from(v: Errors) -> Self {
|
|
v.into_response()
|
|
}
|
|
}
|
|
|
|
/// Result type whose error side is an already-built HTTP [`Response`].
///
/// The success type defaults to [`Success`].
pub type RespResult<T = Success> = Result<T, Response>;
|
|
|
|
#[derive(Serialize)]
|
|
struct SingleError {
|
|
detail: String,
|
|
}
|
|
|
|
/// Computes the relative directory prefix used for a crate name.
///
/// The layout depends on the length of the name:
/// - 1 character: `1`
/// - 2 characters: `2`
/// - 3 characters: `3/<first character>`
/// - 4 or more characters: `<first two characters>/<next two characters>`
///
/// Characters taken from the name are lowercased.
///
/// # Errors
/// Returns a static message if the name is empty, is not ASCII, or contains a
/// character other than an ASCII letter, digit, `-` or `_`.
fn get_crate_prefix(s: &str) -> Result<PathBuf, &'static str> {
    if s.is_empty() {
        return Err("Crate name must be non-empty");
    }
    if !s.is_ascii() {
        return Err("Crate name must be ASCII");
    }
    // Parts of the name become path components below. Without this check a
    // name such as "/etc" would make `PathBuf::push` discard the prefix and
    // produce an absolute path, and names containing ".." or separators could
    // point outside the intended directory.
    let only_safe_chars = s
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_');
    if !only_safe_chars {
        return Err("Crate name may only contain ASCII letters, digits, '-' and '_'");
    }

    // Slicing by byte index is safe here: the name is known to be pure ASCII.
    let mut buf = PathBuf::new();
    match s.len() {
        1 => buf.push("1"),
        2 => buf.push("2"),
        3 => {
            buf.push("3");
            buf.push(s[..1].to_ascii_lowercase());
        }
        _ => {
            let head = s[..4].to_ascii_lowercase();
            buf.push(&head[..2]);
            buf.push(&head[2..]);
        }
    }

    Ok(buf)
}
|
|
|
|
pub struct State {
|
|
index_dir: PathBuf,
|
|
crate_dir: PathBuf,
|
|
db: Pool<Postgres>,
|
|
}
|
|
|
|
/// Process-wide async mutex guarding no data of its own.
///
/// NOTE(review): presumably held while the on-disk index is being modified;
/// its users are outside this part of the file — confirm.
static INDEX_LOCK: Mutex<()> = Mutex::const_new(());
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct Config {
|
|
working_dir: PathBuf,
|
|
postgres_uri: String,
|
|
listen_uri: SocketAddr,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), BoxDynError> {
|
|
let config = toml::from_str::<Config>(
|
|
&fs::read_to_string("/etc/warehouse/warehouse.toml")
|
|
.await
|
|
.expect("Config file missing"),
|
|
)
|
|
.expect("Invalid config file");
|
|
|
|
let mut index_dir = config.working_dir.clone();
|
|
index_dir.push("index");
|
|
let mut crate_dir = config.working_dir;
|
|
crate_dir.push("crates");
|
|
let state = Arc::new(State {
|
|
index_dir,
|
|
crate_dir,
|
|
db: PgPoolOptions::new().connect(&config.postgres_uri).await?,
|
|
});
|
|
|
|
tracing_subscriber::fmt()
|
|
.with_max_level(LevelFilter::TRACE)
|
|
.init();
|
|
|
|
db::init(&state.db).await?;
|
|
|
|
let app = Router::new()
|
|
.route("/api/v1/crates/new", put(new_crate::new_crate))
|
|
.route("/api/v1/crates/:crate_name/:version/yank", delete(yank))
|
|
.route("/api/v1/crates/:crate_name/:version/unyank", put(unyank))
|
|
.route(
|
|
"/api/v1/crates/:crate_name/owners",
|
|
get(list_owners).put(add_owners).delete(remove_owners),
|
|
)
|
|
.route(
|
|
"/api/v1/crates/:crate_name/:version/download",
|
|
get(download),
|
|
)
|
|
.route("/api/v1/crates", get(search))
|
|
.layer(
|
|
ServiceBuilder::new()
|
|
.layer(Extension(state.clone()))
|
|
.layer(TraceLayer::new_for_http()),
|
|
);
|
|
|
|
fs::create_dir_all(&state.index_dir).await?;
|
|
fs::create_dir_all(&state.crate_dir).await?;
|
|
|
|
axum::Server::bind(&config.listen_uri)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Request extractor wrapping the user whose stored credential matches the
/// request's `Authorization` header.
pub struct Auth(DbUser);
|
|
|
|
#[async_trait]
|
|
impl FromRequest<Body> for Auth {
|
|
type Rejection = Response;
|
|
|
|
async fn from_request(req: &mut RequestParts<Body>) -> Result<Self, Self::Rejection> {
|
|
let Extension(state): Extension<Arc<State>> = Extension::from_request(req)
|
|
.await
|
|
.map_err(IntoResponse::into_response)?;
|
|
|
|
if let Some(auth) = req
|
|
.headers()
|
|
.get(AUTHORIZATION)
|
|
.and_then(|v| v.to_str().ok())
|
|
{
|
|
if let Some(db_user) =
|
|
sqlx::query_as::<_, DbUser>("SELECT * FROM users WHERE credential = $1 LIMIT 1")
|
|
.bind(auth)
|
|
.fetch_optional(&state.db)
|
|
.await
|
|
.map_err(db_error)
|
|
.map_err(IntoResponse::into_response)?
|
|
{
|
|
Ok(Self(db_user))
|
|
} else {
|
|
Err(StatusCode::FORBIDDEN.into_response())
|
|
}
|
|
} else {
|
|
Err(StatusCode::FORBIDDEN.into_response())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Checks if the user has permissions to modify the crate.
|
|
///
|
|
/// # Errors
|
|
/// Returns a FORBIDDEN status code if the user does not have the required permissions.
|
|
pub async fn auth(crate_name: &str, auth_user: &DbUser, state: &State) -> Result<(), Response> {
|
|
let (is_authenticated,): (bool,) =
|
|
query_as("SELECT $1 = ANY (crates.owners) FROM crates WHERE name = $2")
|
|
.bind(auth_user.id)
|
|
.bind(&crate_name)
|
|
.fetch_one(&state.db)
|
|
.await
|
|
.map_err(db_error)
|
|
.map_err(IntoResponse::into_response)?;
|
|
|
|
if !is_authenticated {
|
|
return Err(StatusCode::FORBIDDEN.into_response());
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Marker for a successful operation; serializes as a struct with `ok: true`.
pub struct Success;
|
|
|
|
impl Serialize for Success {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: serde::Serializer,
|
|
{
|
|
let mut serialize_struct = serializer.serialize_struct("Success", 1)?;
|
|
serialize_struct.serialize_field("ok", &true)?;
|
|
serialize_struct.end()
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for Success {
|
|
fn into_response(self) -> Response {
|
|
Json(self).into_response()
|
|
}
|
|
}
|
|
|
|
fn internal_error<E: Display>(e: E) -> Errors {
|
|
Errors::new(format_args!("Internal server error: {e}"))
|
|
}
|
|
|
|
fn db_error<E: Display>(e: E) -> Errors {
|
|
Errors::new(format_args!("Database error: {e}"))
|
|
}
|