diff --git a/src/bots/leek.rs b/src/bots/leek.rs index 411077f..27206e4 100644 --- a/src/bots/leek.rs +++ b/src/bots/leek.rs @@ -1,6 +1,9 @@ use arrayvec::{ArrayString, CapacityError}; use rand::Rng; -use std::{error::Error, fmt::{Debug, Display}}; +use std::{ + error::Error, + fmt::{Debug, Display}, +}; #[derive(Debug)] pub struct LeekCapacityError(CapacityError); @@ -19,7 +22,6 @@ impl From> for LeekCapacityError { } } - type LeekResult = Result, LeekCapacityError>; pub fn mock(input: &str) -> LeekResult { @@ -99,4 +101,3 @@ pub fn owoify(input: &str) -> LeekResult { builder.try_push_str("~~")?; Ok(builder) } - diff --git a/src/bots/misc.rs b/src/bots/misc.rs index d4eb697..0fbdab5 100644 --- a/src/bots/misc.rs +++ b/src/bots/misc.rs @@ -1,9 +1,16 @@ -use std::collections::HashMap; use arrayvec::ArrayString; use meval::Context; use serde_json::Value; +use std::collections::HashMap; use std::fmt::Write; +#[derive(Debug)] +pub enum LeekCommand { + Owo, + Leet, + Mock, +} + pub async fn get_waifu_pic(category: &str) -> anyhow::Result> { let api_resp = reqwest::get(format!("https://api.waifu.pics/sfw/{}", category)) .await? @@ -15,16 +22,47 @@ pub async fn get_waifu_pic(category: &str) -> anyhow::Result> { Ok(url) } -pub fn mathbot(author: String, expr: Option<&str>, last_evals: &mut HashMap) -> anyhow::Result> { +pub fn mathbot( + author: String, + expr: Option<&str>, + last_evals: &mut HashMap, +) -> anyhow::Result> { if let Some(expr) = expr { let last_eval = last_evals.entry(author).or_insert(0.0); let mut meval_ctx = Context::new(); let mut result = ArrayString::new(); let value = meval::eval_str_with_context(expr, meval_ctx.var("x", *last_eval))?; *last_eval = value; + tracing::debug!("{} = {}", expr, value); write!(result, "{} = {}", expr, value)?; Ok(result) } else { Ok(ArrayString::from("No expression to evaluate")?) } } + +pub async fn execute_leek( + state: &mut crate::AppState, + cmd: LeekCommand, + channel: &str, + nick: &str, +) -> anyhow::Result<()> { + match state.last_msgs.get(nick) { + Some(msg) => { + tracing::debug!("Executing {:?} on {:?}", cmd, msg); + let output = match cmd { + LeekCommand::Owo => super::leek::owoify(msg)?, + LeekCommand::Leet => super::leek::leetify(msg)?, + LeekCommand::Mock => super::leek::mock(msg)?, + }; + state.client.privmsg(channel, &output).await?; + } + None => { + state + .client + .privmsg(channel, "No last messages found.") + .await?; + } + } + Ok(()) +} diff --git a/src/bots/mod.rs b/src/bots/mod.rs index 81f6a61..99c5cd8 100644 --- a/src/bots/mod.rs +++ b/src/bots/mod.rs @@ -1,4 +1,4 @@ pub mod leek; -pub mod title; pub mod misc; pub mod sed; +pub mod title; diff --git a/src/bots/sed.rs b/src/bots/sed.rs index ef8da5c..8c39c83 100644 --- a/src/bots/sed.rs +++ b/src/bots/sed.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, error::Error}; +use std::{error::Error, fmt::Display}; use arrayvec::{ArrayString, CapacityError}; use fancy_regex::Regex; @@ -9,7 +9,7 @@ use sedregex::find_and_replace; pub enum SedError { Capacity(CapacityError), Regex(fancy_regex::Error), - SedRegex(sedregex::ErrorKind) + SedRegex(sedregex::ErrorKind), } impl Display for SedError { @@ -56,8 +56,8 @@ pub fn resolve(prev_msg: &str, cmd: &str) -> SedResult { Ok(Some(ArrayString::from(&formatted)?)) } else { Ok(None) - } + }; } Ok(None) -} \ No newline at end of file +} diff --git a/src/bots/title.rs b/src/bots/title.rs index ca54cf2..45543c7 100644 --- a/src/bots/title.rs +++ b/src/bots/title.rs @@ -23,7 +23,8 @@ async fn resolve_spotify( // } tracing::debug!( "Resolving Spotify resource '{}' with id '{}'", - resource_type, resource_id + resource_type, + resource_id ); match resource_type { "track" => { @@ -89,6 +90,7 @@ impl Titlebot { r"(?:https?|spotify):(?://open\.spotify\.com/)?(track|artist|album|playlist)[/:]([a-zA-Z0-9]*)", )?; let mut spotify = ClientCredsSpotify::new(spotify_creds); + spotify.request_token().await?; Ok(Self { url_regex, @@ -103,6 +105,7 @@ impl Titlebot { tracing::debug!("{}", message); let tp_group = m.get(1).unwrap(); let id_group = m.get(2).unwrap(); + return Ok(Some( resolve_spotify( &mut self.spotify, @@ -114,6 +117,7 @@ impl Titlebot { } else if let Some(m) = self.url_regex.find(&message)? { let url = &message[m.start()..m.end()]; tracing::debug!("url: {}", url); + let response = reqwest::get(url).await?; if let Some(header) = response.headers().get("Content-Type") { tracing::debug!("response header: {}", header.to_str()?); @@ -121,6 +125,7 @@ impl Titlebot { return Ok(None); } } + let body = response.text().await?; if let Some(tm) = self.title_regex.find(&body)? { let title_match = &body[tm.start()..tm.end()]; diff --git a/src/database.rs b/src/database.rs index 1113653..b85331e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,5 +1,8 @@ -use rusqlite::{OptionalExtension, params}; -use tokio::sync::{mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, oneshot}; +use rusqlite::{params, OptionalExtension}; +use tokio::sync::{ + mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + oneshot, +}; #[derive(Debug)] enum Task { @@ -17,8 +20,11 @@ impl DbExecutor { pub fn create(dbpath: &str) -> rusqlite::Result<(Self, ExecutorConnection)> { let (tx, rx) = unbounded_channel(); let db = rusqlite::Connection::open(dbpath)?; - db.execute("create table if not exists quotes(id integer primary key,\ - username text not null, quote text not null)", [])?; + db.execute( + "create table if not exists quotes(id integer primary key,\ + username text not null, quote text not null)", + [], + )?; tracing::debug!("Database connected ({})", dbpath); Ok((Self { rx, db }, ExecutorConnection { tx })) } @@ -28,7 +34,9 @@ impl DbExecutor { match task { Task::AddQuote(tx, quote, author) => { if let Err(e) = self.db.execute( - "insert into quotes(quote,username) values(?,?)", params![quote,author]) { + "insert into quotes(quote,username) values(?,?)", + params![quote, author], + ) { tracing::error!("A database error has occurred: {}", e); tx.send(false).unwrap(); } else { @@ -57,7 +65,9 @@ pub struct ExecutorConnection { impl Clone for ExecutorConnection { fn clone(&self) -> Self { - Self { tx: self.tx.clone() } + Self { + tx: self.tx.clone(), + } } } diff --git a/src/main.rs b/src/main.rs index 3a14388..380a94e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,20 @@ mod bots; mod database; +use crate::database::{DbExecutor, ExecutorConnection}; use arrayvec::ArrayString; use async_circe::{commands::Command, Client, Config}; use bots::title::Titlebot; -use bots::{leek, misc, sed}; +use bots::{misc, misc::LeekCommand, sed}; use rspotify::Credentials; use serde::Deserialize; +use std::fmt::Write; use std::fs::File; use std::io::Read; +use std::thread; use std::{collections::HashMap, env}; use tokio::select; use tracing_subscriber::EnvFilter; -use std::fmt::Write; -use std::thread; -use crate::database::{DbExecutor, ExecutorConnection}; // this will be displayed when the help command is used const HELP: &[&str] = &[ @@ -45,13 +45,13 @@ async fn terminate_signal() { let _ = ctrlc.recv().await; } -struct AppState { +pub struct AppState { prefix: String, client: Client, last_msgs: HashMap, last_eval: HashMap, titlebot: Titlebot, - db: ExecutorConnection + db: ExecutorConnection, } #[derive(Deserialize)] @@ -65,7 +65,7 @@ struct ClientConf { spotify_client_id: String, spotify_client_secret: String, prefix: String, - db_path: Option + db_path: Option, } #[tokio::main(flavor = "current_thread")] @@ -74,13 +74,15 @@ async fn main() -> anyhow::Result<()> { .with_env_filter(EnvFilter::from_env("UBERBOT_LOG")) .init(); - let mut file = File::open(env::var("UBERBOT_CONFIG").unwrap_or_else(|_| "uberbot.toml".to_string()))?; + let mut file = + File::open(env::var("UBERBOT_CONFIG").unwrap_or_else(|_| "uberbot.toml".to_string()))?; let mut client_conf = String::new(); file.read_to_string(&mut client_conf)?; let client_config: ClientConf = toml::from_str(&client_conf)?; - let (db_exec, db_conn) = DbExecutor::create(client_config.db_path.as_deref().unwrap_or("uberbot.db3"))?; + let (db_exec, db_conn) = + DbExecutor::create(client_config.db_path.as_deref().unwrap_or("uberbot.db3"))?; let exec_thread = thread::spawn(move || { db_exec.run(); tracing::info!("Database executor has been shut down"); @@ -99,6 +101,7 @@ async fn main() -> anyhow::Result<()> { client_config.port, client_config.username, ); + let mut client = Client::new(config).await?; client.identify().await?; @@ -108,14 +111,16 @@ async fn main() -> anyhow::Result<()> { last_msgs: HashMap::new(), last_eval: HashMap::new(), titlebot: Titlebot::create(spotify_creds).await?, - db: db_conn + db: db_conn, }; if let Err(e) = executor(state).await { tracing::error!("Error in message loop: {}", e); } - exec_thread.join(); + if let Err(e) = exec_thread.join() { + tracing::error!("Error while shutting down the database: {:?}", e); + } tracing::info!("Shutting down"); Ok(()) @@ -146,28 +151,6 @@ async fn message_loop(state: &mut AppState) -> anyhow::Result<()> { Ok(()) } -#[derive(Debug)] -enum LeekCommand { - Owo, Leet, Mock -} -async fn execute_leek(state: &mut AppState, cmd: LeekCommand, channel: &str, nick: &str) -> anyhow::Result<()> { - match state.last_msgs.get(nick) { - Some(msg) => { - tracing::debug!("Executing {:?} on {:?}", cmd, msg); - let output = match cmd { - LeekCommand::Owo => leek::owoify(msg)?, - LeekCommand::Leet => leek::leetify(msg)?, - LeekCommand::Mock => leek::mock(msg)? - }; - state.client.privmsg(channel, &output).await?; - } - None => { - state.client.privmsg(channel, "No last messages found.").await?; - } - } - Ok(()) -} - fn separate_to_space(str: &str, prefix_len: usize) -> (&str, Option<&str>) { if let Some(o) = str.find(' ') { (&str[prefix_len..o], Some(&str[o + 1..])) @@ -219,13 +202,26 @@ async fn handle_privmsg( state.client.privmsg(&channel, response).await?; } "mock" => { - execute_leek(state, LeekCommand::Mock, channel, remainder.unwrap_or(&nick)).await?; + misc::execute_leek( + state, + LeekCommand::Mock, + channel, + remainder.unwrap_or(&nick), + ) + .await?; } "leet" => { - execute_leek(state, LeekCommand::Leet, channel, remainder.unwrap_or(&nick)).await?; + misc::execute_leek( + state, + LeekCommand::Leet, + channel, + remainder.unwrap_or(&nick), + ) + .await?; } "owo" => { - execute_leek(state, LeekCommand::Owo, channel, remainder.unwrap_or(&nick)).await?; + misc::execute_leek(state, LeekCommand::Owo, channel, remainder.unwrap_or(&nick)) + .await?; } "ev" => { let result = misc::mathbot(nick, remainder, &mut state.last_eval)?; @@ -234,20 +230,32 @@ async fn handle_privmsg( "grab" => { if let Some(target) = remainder { if target == nick { - state.client.privmsg(&channel, "You can't grab yourself").await?; - return Ok(()) + state + .client + .privmsg(&channel, "You can't grab yourself") + .await?; + return Ok(()); } if let Some(prev_msg) = state.last_msgs.get(target) { if state.db.add_quote(prev_msg.clone(), target.into()).await { state.client.privmsg(&channel, "Quote added").await?; } else { - state.client.privmsg(&channel, "A database error has occurred").await?; + state + .client + .privmsg(&channel, "A database error has occurred") + .await?; } } else { - state.client.privmsg(&channel, "No previous messages to grab").await?; + state + .client + .privmsg(&channel, "No previous messages to grab") + .await?; } } else { - state.client.privmsg(&channel, "No nickname to grab").await?; + state + .client + .privmsg(&channel, "No nickname to grab") + .await?; } } "quot" => {