From c3b12e9e72b8ef9a34e762c0bd0889d8a6a2122e Mon Sep 17 00:00:00 2001 From: lemonsh Date: Sat, 1 Jan 2022 22:35:27 +0100 Subject: [PATCH] Port quotquot --- .gitignore | 1 + Cargo.toml | 5 ++-- src/bots/misc.rs | 1 - src/database.rs | 75 ++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 57 ++++++++++++++++++++++++++++++++---- 5 files changed, 130 insertions(+), 9 deletions(-) create mode 100644 src/database.rs diff --git a/.gitignore b/.gitignore index 8372106..1bdef75 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ uberbot_*.toml uberbot.toml /Cargo.lock .idea +*.db3 diff --git a/Cargo.toml b/Cargo.toml index 431a050..521aee7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,5 +22,6 @@ arrayvec = "0.7" rand = "0.8" meval = "0.2" async-circe = { git = "https://git.karx.xyz/circe/async-circe" } -lazy_static = "1.4.0" -sedregex = "0.2.5" +lazy_static = "1.4" +sedregex = "0.2" +rusqlite = { version = "0.26", features = ["bundled"] } diff --git a/src/bots/misc.rs b/src/bots/misc.rs index ddaf08d..d4eb697 100644 --- a/src/bots/misc.rs +++ b/src/bots/misc.rs @@ -10,7 +10,6 @@ pub async fn get_waifu_pic(category: &str) -> anyhow::Result> { .text() .await?; let api_resp = api_resp.trim(); - tracing::debug!("API response: {}", api_resp); let value: Value = serde_json::from_str(&api_resp)?; let url = value["url"].as_str().map(|v| v.to_string()); Ok(url) diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..1113653 --- /dev/null +++ b/src/database.rs @@ -0,0 +1,75 @@ +use rusqlite::{OptionalExtension, params}; +use tokio::sync::{mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, oneshot}; + +#[derive(Debug)] +enum Task { + AddQuote(oneshot::Sender, String, String), + GetQuote(oneshot::Sender>, Option), + // implement search WITH PAGINATION +} + +pub struct DbExecutor { + rx: UnboundedReceiver, + db: rusqlite::Connection, +} + +impl DbExecutor { + pub fn create(dbpath: &str) -> rusqlite::Result<(Self, ExecutorConnection)> { + let (tx, rx) = unbounded_channel(); + let db = rusqlite::Connection::open(dbpath)?; + db.execute("create table if not exists quotes(id integer primary key,\ + username text not null, quote text not null)", [])?; + tracing::debug!("Database connected ({})", dbpath); + Ok((Self { rx, db }, ExecutorConnection { tx })) + } + + pub fn run(mut self) { + while let Some(task) = self.rx.blocking_recv() { + match task { + Task::AddQuote(tx, quote, author) => { + if let Err(e) = self.db.execute( + "insert into quotes(quote,username) values(?,?)", params![quote,author]) { + tracing::error!("A database error has occurred: {}", e); + tx.send(false).unwrap(); + } else { + tx.send(true).unwrap(); + } + } + Task::GetQuote(tx, author) => { + let quote = if let Some(ref author) = author { + self.db.query_row("select quote,username from quotes where username=? order by random() limit 1", params![author], |v| Ok((v.get(0)?, v.get(1)?))) + } else { + self.db.query_row("select quote,username from quotes order by random() limit 1", params![], |v| Ok((v.get(0)?, v.get(1)?))) + }.optional().unwrap_or_else(|e| { + tracing::error!("A database error has occurred: {}", e); + None + }); + tx.send(quote).unwrap(); + } + } + } + } +} + +pub struct ExecutorConnection { + tx: UnboundedSender, +} + +impl Clone for ExecutorConnection { + fn clone(&self) -> Self { + Self { tx: self.tx.clone() } + } +} + +impl ExecutorConnection { + pub async fn add_quote(&self, quote: String, author: String) -> bool { + let (otx, orx) = oneshot::channel(); + self.tx.send(Task::AddQuote(otx, quote, author)).unwrap(); + orx.await.unwrap() + } + pub async fn get_quote(&self, author: Option) -> Option<(String, String)> { + let (otx, orx) = oneshot::channel(); + self.tx.send(Task::GetQuote(otx, author)).unwrap(); + orx.await.unwrap() + } +} diff --git a/src/main.rs b/src/main.rs index a8fd6e9..aab10c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ mod bots; +mod database; use arrayvec::ArrayString; use async_circe::{commands::Command, Client, Config}; @@ -12,6 +13,8 @@ use std::{collections::HashMap, env}; use tokio::select; use tracing_subscriber::EnvFilter; use std::fmt::Write; +use std::thread; +use crate::database::{DbExecutor, ExecutorConnection}; // this will be displayed when the help command is used const HELP: &[&str] = &[ @@ -48,6 +51,7 @@ struct AppState { last_msgs: HashMap, last_eval: HashMap, titlebot: Titlebot, + db: ExecutorConnection } #[derive(Deserialize)] @@ -61,6 +65,7 @@ struct ClientConf { spotify_client_id: String, spotify_client_secret: String, prefix: String, + db_path: Option } #[tokio::main(flavor = "current_thread")] @@ -75,6 +80,12 @@ async fn main() -> anyhow::Result<()> { let client_config: ClientConf = toml::from_str(&client_conf)?; + let (db_exec, db_conn) = DbExecutor::create(client_config.db_path.as_deref().unwrap_or("uberbot.db3"))?; + thread::spawn(move || { + db_exec.run(); + tracing::info!("Database executor has been shut down"); + }); + let spotify_creds = Credentials::new( &client_config.spotify_client_id, &client_config.spotify_client_secret, @@ -97,6 +108,7 @@ async fn main() -> anyhow::Result<()> { last_msgs: HashMap::new(), last_eval: HashMap::new(), titlebot: Titlebot::create(spotify_creds).await?, + db: db_conn }; if let Err(e) = executor(state).await { @@ -133,12 +145,14 @@ async fn message_loop(state: &mut AppState) -> anyhow::Result<()> { Ok(()) } +#[derive(Debug)] enum LeekCommand { Owo, Leet, Mock } async fn execute_leek(state: &mut AppState, cmd: LeekCommand, channel: &str, nick: &str) -> anyhow::Result<()> { match state.last_msgs.get(nick) { Some(msg) => { + tracing::debug!("Executing {:?} on {:?}", cmd, msg); let output = match cmd { LeekCommand::Owo => leek::owoify(msg)?, LeekCommand::Leet => leek::leetify(msg)?, @@ -153,6 +167,14 @@ async fn execute_leek(state: &mut AppState, cmd: LeekCommand, channel: &str, nic Ok(()) } +fn separate_to_space(str: &str, prefix_len: usize) -> (&str, Option<&str>) { + if let Some(o) = str.find(' ') { + (&str[prefix_len..o], Some(&str[o + 1..])) + } else { + (&str[prefix_len..], None) + } +} + async fn handle_privmsg( state: &mut AppState, nick: String, @@ -177,12 +199,7 @@ async fn handle_privmsg( state.last_msgs.insert(nick, message); return Ok(()); } - let space_index = message.find(' '); - let (command, remainder) = if let Some(o) = space_index { - (&message[state.prefix.len()..o], Some(&message[o + 1..])) - } else { - (&message[state.prefix.len()..], None) - }; + let (command, remainder) = separate_to_space(&message, state.prefix.len()); tracing::debug!("Command received ({:?}; {:?})", command, remainder); match command { @@ -213,6 +230,34 @@ async fn handle_privmsg( let result = misc::mathbot(nick, remainder, &mut state.last_eval)?; state.client.privmsg(&channel, &result).await?; } + "grab" => { + if let Some(target) = remainder { + if target == nick { + state.client.privmsg(&channel, "You can't grab yourself").await?; + return Ok(()) + } + if let Some(prev_msg) = state.last_msgs.get(target) { + if state.db.add_quote(prev_msg.clone(), target.into()).await { + state.client.privmsg(&channel, "Quote added").await?; + } else { + state.client.privmsg(&channel, "A database error has occurred").await?; + } + } else { + state.client.privmsg(&channel, "No previous messages to grab").await?; + } + } else { + state.client.privmsg(&channel, "No nickname to grab").await?; + } + } + "quot" => { + if let Some(quote) = state.db.get_quote(remainder.map(|v| v.to_string())).await { + let mut resp = ArrayString::<512>::new(); + write!(resp, "\"{}\" ~{}", quote.0, quote.1)?; + state.client.privmsg(&channel, &resp).await?; + } else { + state.client.privmsg(&channel, "No quotes found").await?; + } + } _ => { state.client.privmsg(&channel, "Unknown command").await?; }