diff --git a/.gitignore b/.gitignore index 3509de0..a233cf0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target tmtd.toml .idea -.vs \ No newline at end of file +.vs +.env diff --git a/Cargo.lock b/Cargo.lock index a3e5a9b..51a3d0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1715,6 +1715,7 @@ dependencies = [ "anyhow", "async-sqlx-session", "chrono", + "dotenv", "serde", "sqlx", "tera", diff --git a/Cargo.toml b/Cargo.toml index a9a8e90..69e85ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ lto = true tokio = { version = "1", features = ["rt", "sync", "signal", "macros"] } toml = "0.5" serde = { version = "1.0", features = ["derive"] } -sqlx = { version = "0.5", default-features = false, features = ["runtime-tokio-rustls", "postgres"] } +sqlx = { version = "0.5", default-features = false, features = ["runtime-tokio-rustls", "postgres", "macros"] } tracing = "0.1" tracing-subscriber = "0.3" warp = { version = "0.3", default-features = false } @@ -20,3 +20,9 @@ tera = "1.15" async-sqlx-session = { version = "0.4", default-features = false, features = ["pg"] } anyhow = "1.0" chrono = "0.4" +dotenv = "0.15" + +[build-dependencies] +toml = "0.5" +serde = { version = "1.0", features = ["derive"] } + diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..021f16d --- /dev/null +++ b/build.rs @@ -0,0 +1,20 @@ +use std::env; +use std::fs; +use serde::Deserialize; + +#[allow(dead_code)] +#[derive(Deserialize)] +struct Config { + listen_addr: String, + admin_pass: String, + connection_string: String, + log_level: Option, +} + +fn main() { + let config_var = env::var("TMTD_CONFIG"); + let config_path = config_var.as_deref().unwrap_or("tmtd.toml"); + let config = fs::read_to_string(config_path).unwrap(); + let cfg: Config = toml::from_str(&config).unwrap(); + fs::write(".env", format!("DATABASE_URL={}", cfg.connection_string)).unwrap(); +} diff --git a/src/database.rs b/src/database.rs index 3f9f71a..213196f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -53,19 +53,22 @@ impl Database { self.0.close().await; } - pub async fn create_user(&self, username: &str, passwod: &str) -> anyhow::Result<()> { + pub async fn create_user(&self, username: &str, passwod: &str) -> anyhow::Result { if username.len() < 17 { - let mut conn = self.0.acquire().await?; let mut hasher = DefaultHasher::new(); passwod.hash(&mut hasher); - let hash = hasher.finish(); - tracing::debug!("Password hash for user {}: {:x}", username, hash); - conn.execute(format!(" - INSERT INTO users(username, hash) - VALUES ({}, {:x}) - RETURNING id; - ", username, hash).as_str()).await?; - return Ok(()) + let hash = format!("{:x}", hasher.finish()); + tracing::debug!("Password hash for user {}: {}", username, hash); + let user = sqlx::query!( + "INSERT INTO users (username, hash) + VALUES ($1::Varchar, $2::Varchar) + RETURNING id", + username, + hash + ) + .fetch_one(&self.0) + .await?; + return Ok(user.id) } Err(anyhow::anyhow!("Username is longer then 16 characters")) } diff --git a/src/main.rs b/src/main.rs index f0900fa..f9ff2fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -54,6 +54,7 @@ async fn terminate_signal() { #[tokio::main(flavor = "current_thread")] async fn main() -> anyhow::Result<()> { + dotenv::dotenv().ok(); let config_var = env::var("TMTD_CONFIG"); let config_path = config_var.as_deref().unwrap_or("tmtd.toml"); println!("Loading config from '{}'...", config_path); @@ -75,7 +76,6 @@ async fn main() -> anyhow::Result<()> { let session_store = PostgresSessionStore::from_client(database.pool()).with_table_name("sessions"); session_store.migrate().await?; - database.create_user("famfo", "1234").await?; let cleanup_task = spawn_session_cleanup_task(&session_store, Duration::from_secs(600), ctx.subscribe());