From fa039d90702f08c23f9f697b37b318ee59cf471b Mon Sep 17 00:00:00 2001 From: lemonsh Date: Sat, 16 Jul 2022 12:21:23 +0200 Subject: [PATCH] Add command traits --- sample_uberbot.toml | 1 + src/bot.rs | 60 ++++++++++++++ src/bots/misc.rs | 25 ++++-- src/main.rs | 196 ++++++++------------------------------------ 4 files changed, 109 insertions(+), 173 deletions(-) create mode 100644 src/bot.rs diff --git a/sample_uberbot.toml b/sample_uberbot.toml index 95a847a..de8b056 100644 --- a/sample_uberbot.toml +++ b/sample_uberbot.toml @@ -1,6 +1,7 @@ [irc] host = "karx.xyz" port = 6697 +tls = true username = "uberbot" channels = ["#main", "#no-normies"] mode = "+B" diff --git a/src/bot.rs b/src/bot.rs new file mode 100644 index 0000000..a00b776 --- /dev/null +++ b/src/bot.rs @@ -0,0 +1,60 @@ +use std::collections::HashMap; +use fancy_regex::Regex; +use crate::ExecutorConnection; +use async_trait::async_trait; + +fn separate_to_space(str: &str, prefix_len: usize) -> (&str, Option<&str>) { + if let Some(o) = str.find(' ') { + (&str[prefix_len..o], Some(&str[o + 1..])) + } else { + (&str[prefix_len..], None) + } +} + +pub trait RegexCommand { + fn execute(&mut self, message: String) -> anyhow::Result; +} + +#[async_trait] +pub trait NormalCommand { + async fn execute(&mut self, last_msg: &HashMap, message: String) -> anyhow::Result; +} + +#[derive(Default)] +struct Commands { + regex: Vec<(Regex, Box)>, + normal: HashMap>, +} + +pub struct Bot anyhow::Result<()>> { + last_msg: HashMap, + prefix: String, + db: ExecutorConnection, + commands: Commands, + sendmsg: SF +} + +impl anyhow::Result<()>> Bot { + pub fn new(prefix: String, db: ExecutorConnection, sendmsg: SF) -> Self { + Bot { + last_msg: HashMap::new(), + prefix, + db, + commands: Commands::default(), + sendmsg + } + } + + pub fn add_command(&mut self, name: String, cmd: C) { + self.commands.normal.insert(name, Box::new(cmd)); + } + + pub fn add_regex_command(&mut self, regex: Regex, cmd: C) { + self.commands.regex.push((regex, Box::new(cmd))); + } + + pub async fn handle_message(&mut self, origin: &str, author: &str, content: &str) -> anyhow::Result<()> { + (self.sendmsg)(origin.into(), content.into()).unwrap(); + Ok(()) + } +} \ No newline at end of file diff --git a/src/bots/misc.rs b/src/bots/misc.rs index 893ed7d..ecb1f61 100644 --- a/src/bots/misc.rs +++ b/src/bots/misc.rs @@ -3,16 +3,23 @@ use meval::Context; use serde_json::Value; use std::collections::HashMap; use std::fmt::Write; +use crate::bot::NormalCommand; +use async_trait::async_trait; -pub async fn get_waifu_pic(category: &str) -> anyhow::Result> { - let api_resp = reqwest::get(format!("https://api.waifu.pics/sfw/{}", category)) - .await? - .text() - .await?; - let api_resp = api_resp.trim(); - let value: Value = serde_json::from_str(api_resp)?; - let url = value["url"].as_str().map(ToString::to_string); - Ok(url) +pub struct Waifu; + +#[async_trait] +impl NormalCommand for Waifu { + async fn execute(&mut self, _last_msg: &HashMap, message: String) -> anyhow::Result { + let api_resp = reqwest::get(format!("https://api.waifu.pics/sfw/{}", message)) + .await? + .text() + .await?; + let api_resp = api_resp.trim(); + let value: Value = serde_json::from_str(api_resp)?; + let url = value["url"].as_str().unwrap_or("Invalid API Response.").to_string(); + Ok(url) + } } pub fn mathbot( diff --git a/src/main.rs b/src/main.rs index 010a900..a9eb273 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,12 @@ #![allow(clippy::match_wildcard_for_single_variants)] -use std::fmt::Write; use std::fs::File; use std::io::Read; -use std::net::SocketAddr; use std::sync::Arc; use std::thread; -use std::{collections::HashMap, env}; +use std::env; +use std::fmt::Display; -use arrayvec::ArrayString; use futures_util::stream::StreamExt; use irc::client::prelude::Config; use irc::client::{Client, ClientStream}; @@ -19,9 +17,11 @@ use tokio::select; use tokio::sync::broadcast; use tokio::sync::mpsc::unbounded_channel; use tracing_subscriber::EnvFilter; +use crate::bot::Bot; +use crate::bots::misc::Waifu; -use crate::bots::{leek, misc, sed, title}; -use crate::database::{DbExecutor, ExecutorConnection, Quote}; +use crate::config::UberConfig; +use crate::database::{DbExecutor, ExecutorConnection}; mod bots; mod database; @@ -57,14 +57,10 @@ async fn terminate_signal() { let _ = ctrlc.recv().await; } -pub struct AppState { - prefix: String, +pub struct AppState anyhow::Result<()>> { client: Arc, stream: ClientStream, - last_msgs: HashMap, - last_eval: HashMap, - titlebot: title::Titlebot, - db: ExecutorConnection, + bot: Bot } #[tokio::main] @@ -78,38 +74,30 @@ async fn main() -> anyhow::Result<()> { let mut client_conf = String::new(); file.read_to_string(&mut client_conf)?; - let client_config: ClientConf = toml::from_str(&client_conf)?; + let cfg: UberConfig = toml::from_str(&client_conf)?; let (db_exec, db_conn) = - DbExecutor::create(client_config.db_path.as_deref().unwrap_or("uberbot.db3"))?; + DbExecutor::create(cfg.db_path.as_deref().unwrap_or("uberbot.db3"))?; let exec_thread = thread::spawn(move || { db_exec.run(); tracing::info!("Database executor has been shut down"); }); - let spotify_creds = Credentials::new( - &client_config.spotify_client_id, - &client_config.spotify_client_secret, - ); - - let http_listen = client_config - .http_listen - .unwrap_or_else(|| SocketAddr::from(([127, 0, 0, 1], 5000))); - let uber_ver = concat!("Überbot ", env!("CARGO_PKG_VERSION")); let irc_config = Config { nickname: Some( - client_config + cfg + .irc .nickname - .unwrap_or_else(|| client_config.username.clone()), + .unwrap_or_else(|| cfg.irc.username.clone()), ), - username: Some(client_config.username.clone()), - realname: Some(client_config.username), - server: Some(client_config.host), - port: Some(client_config.port), - use_tls: Some(client_config.tls), - channels: client_config.channels, - umodes: client_config.mode, + username: Some(cfg.irc.username.clone()), + realname: Some(cfg.irc.username), + server: Some(cfg.irc.host), + port: Some(cfg.irc.port), + use_tls: Some(cfg.irc.tls), + channels: cfg.irc.channels, + umodes: cfg.irc.mode, user_info: Some(uber_ver.into()), version: Some(uber_ver.into()), ..Config::default() @@ -122,14 +110,17 @@ async fn main() -> anyhow::Result<()> { let (ctx, _) = broadcast::channel(1); let (etx, mut erx) = unbounded_channel(); + let mut bot = Bot::new(cfg.irc.prefix, db_conn, { + let client = client.clone(); + move |target, msg| Ok(client.send_privmsg(target, msg)?) + }); + + bot.add_command("waifu".into(), Waifu); + let state = AppState { - prefix: client_config.prefix, client: client.clone(), stream, - last_msgs: HashMap::new(), - last_eval: HashMap::new(), - titlebot: title::Titlebot::create(spotify_creds).await?, - db: db_conn, + bot }; let message_loop_task = tokio::spawn(async move { if let Err(e) = message_loop(state).await { @@ -156,15 +147,17 @@ async fn main() -> anyhow::Result<()> { message_loop_task .await .unwrap_or_else(|e| tracing::warn!("Couldn't join the web service: {:?}", e)); + tracing::info!("Message loop finished"); exec_thread .join() .unwrap_or_else(|e| tracing::warn!("Couldn't join the database: {:?}", e)); + tracing::info!("Executor thread finished"); tracing::info!("Shutdown complete!"); Ok(()) } -async fn message_loop(mut state: AppState) -> anyhow::Result<()> { +async fn message_loop anyhow::Result<()>>(mut state: AppState) -> anyhow::Result<()> { while let Some(message) = state.stream.next().await.transpose()? { if let Command::PRIVMSG(ref origin, content) = message.command { if origin.is_channel_name() { @@ -172,7 +165,7 @@ async fn message_loop(mut state: AppState) -> anyhow::Result<()> { Prefix::Nickname(name, _, _) => Some(&name[..]), _ => None, }) { - if let Err(e) = handle_privmsg(&mut state, author, origin, content).await { + if let Err(e) = state.bot.handle_message(origin, author, &content).await { state .client .send_privmsg(origin, &format!("Error: {}", e))?; @@ -183,132 +176,7 @@ async fn message_loop(mut state: AppState) -> anyhow::Result<()> { } } } - tracing::info!("Message loop finished"); Ok(()) } -fn separate_to_space(str: &str, prefix_len: usize) -> (&str, Option<&str>) { - if let Some(o) = str.find(' ') { - (&str[prefix_len..o], Some(&str[o + 1..])) - } else { - (&str[prefix_len..], None) - } -} -#[allow(clippy::too_many_lines)] -async fn handle_privmsg( - state: &mut AppState, - author: &str, - origin: &str, - content: String, -) -> anyhow::Result<()> { - if !content.starts_with(state.prefix.as_str()) { - if let Some(titlebot_msg) = state.titlebot.resolve(&content).await? { - state.client.send_privmsg(origin, &titlebot_msg)?; - } - - if let Some(prev_msg) = state.last_msgs.get(author) { - if let Some(formatted) = sed::resolve(prev_msg, &content)? { - let mut result = ArrayString::<512>::new(); - write!(result, "<{}> {}", author, formatted)?; - state.client.send_privmsg(origin, &result)?; - state.last_msgs.insert(author.into(), formatted.to_string()); - return Ok(()); - } - } - - state.last_msgs.insert(author.into(), content); - return Ok(()); - } - let (command, remainder) = separate_to_space(&content, state.prefix.len()); - tracing::debug!("Command received ({:?}; {:?})", command, remainder); - - match command { - "help" => { - for help_line in HELP { - state.client.send_privmsg(origin, help_line)?; - } - } - "waifu" => { - let category = remainder.unwrap_or("waifu"); - let url = misc::get_waifu_pic(category).await?; - let response = url - .as_deref() - .unwrap_or("Invalid category. Valid categories: https://waifu.pics/docs"); - state.client.send_privmsg(origin, response)?; - } - "mock" => { - leek::execute( - state, - leek::Command::Mock, - origin, - remainder.unwrap_or(author), - )?; - } - "leet" => { - leek::execute( - state, - leek::Command::Leet, - origin, - remainder.unwrap_or(author), - )?; - } - "owo" => { - leek::execute( - state, - leek::Command::Owo, - origin, - remainder.unwrap_or(author), - )?; - } - "ev" => { - let result = misc::mathbot(author.into(), remainder, &mut state.last_eval)?; - state.client.send_privmsg(origin, &result)?; - } - "grab" => { - if let Some(target) = remainder { - if target == author { - state - .client - .send_privmsg(origin, "You can't grab yourself")?; - return Ok(()); - } - if let Some(prev_msg) = state.last_msgs.get(target) { - if state - .db - .add_quote(Quote { - quote: prev_msg.clone(), - author: target.into(), - }) - .await - { - state.client.send_privmsg(origin, "Quote added")?; - } else { - state - .client - .send_privmsg(origin, "A database error has occurred")?; - } - } else { - state - .client - .send_privmsg(origin, "No previous messages to grab")?; - } - } else { - state.client.send_privmsg(origin, "No nickname to grab")?; - } - } - "quot" => { - if let Some(quote) = state.db.get_quote(remainder.map(ToString::to_string)).await { - let mut resp = ArrayString::<512>::new(); - write!(resp, "\"{}\" ~{}", quote.quote, quote.author)?; - state.client.send_privmsg(origin, &resp)?; - } else { - state.client.send_privmsg(origin, "No quotes found")?; - } - } - _ => { - state.client.send_privmsg(origin, "Unknown command")?; - } - } - Ok(()) -}