Add command traits

This commit is contained in:
lemonsh 2022-07-16 12:21:23 +02:00
parent 3299807172
commit fa039d9070
4 changed files with 109 additions and 173 deletions

View file

@ -1,6 +1,7 @@
[irc] [irc]
host = "karx.xyz" host = "karx.xyz"
port = 6697 port = 6697
tls = true
username = "uberbot" username = "uberbot"
channels = ["#main", "#no-normies"] channels = ["#main", "#no-normies"]
mode = "+B" mode = "+B"

60
src/bot.rs Normal file
View file

@ -0,0 +1,60 @@
use std::collections::HashMap;
use fancy_regex::Regex;
use crate::ExecutorConnection;
use async_trait::async_trait;
fn separate_to_space(str: &str, prefix_len: usize) -> (&str, Option<&str>) {
if let Some(o) = str.find(' ') {
(&str[prefix_len..o], Some(&str[o + 1..]))
} else {
(&str[prefix_len..], None)
}
}
pub trait RegexCommand {
fn execute(&mut self, message: String) -> anyhow::Result<String>;
}
#[async_trait]
pub trait NormalCommand {
async fn execute(&mut self, last_msg: &HashMap<String, String>, message: String) -> anyhow::Result<String>;
}
#[derive(Default)]
struct Commands {
regex: Vec<(Regex, Box<dyn RegexCommand + Send>)>,
normal: HashMap<String, Box<dyn NormalCommand + Send>>,
}
pub struct Bot<SF: FnMut(String, String) -> anyhow::Result<()>> {
last_msg: HashMap<String, String>,
prefix: String,
db: ExecutorConnection,
commands: Commands,
sendmsg: SF
}
impl<SF: FnMut(String, String) -> anyhow::Result<()>> Bot<SF> {
pub fn new(prefix: String, db: ExecutorConnection, sendmsg: SF) -> Self {
Bot {
last_msg: HashMap::new(),
prefix,
db,
commands: Commands::default(),
sendmsg
}
}
pub fn add_command<C: NormalCommand + Send + 'static>(&mut self, name: String, cmd: C) {
self.commands.normal.insert(name, Box::new(cmd));
}
pub fn add_regex_command<C: RegexCommand + Send + 'static>(&mut self, regex: Regex, cmd: C) {
self.commands.regex.push((regex, Box::new(cmd)));
}
pub async fn handle_message(&mut self, origin: &str, author: &str, content: &str) -> anyhow::Result<()> {
(self.sendmsg)(origin.into(), content.into()).unwrap();
Ok(())
}
}

View file

@ -3,16 +3,23 @@ use meval::Context;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Write; use std::fmt::Write;
use crate::bot::NormalCommand;
use async_trait::async_trait;
pub async fn get_waifu_pic(category: &str) -> anyhow::Result<Option<String>> { pub struct Waifu;
let api_resp = reqwest::get(format!("https://api.waifu.pics/sfw/{}", category))
.await? #[async_trait]
.text() impl NormalCommand for Waifu {
.await?; async fn execute(&mut self, _last_msg: &HashMap<String, String>, message: String) -> anyhow::Result<String> {
let api_resp = api_resp.trim(); let api_resp = reqwest::get(format!("https://api.waifu.pics/sfw/{}", message))
let value: Value = serde_json::from_str(api_resp)?; .await?
let url = value["url"].as_str().map(ToString::to_string); .text()
Ok(url) .await?;
let api_resp = api_resp.trim();
let value: Value = serde_json::from_str(api_resp)?;
let url = value["url"].as_str().unwrap_or("Invalid API Response.").to_string();
Ok(url)
}
} }
pub fn mathbot( pub fn mathbot(

View file

@ -1,14 +1,12 @@
#![allow(clippy::match_wildcard_for_single_variants)] #![allow(clippy::match_wildcard_for_single_variants)]
use std::fmt::Write;
use std::fs::File; use std::fs::File;
use std::io::Read; use std::io::Read;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::{collections::HashMap, env}; use std::env;
use std::fmt::Display;
use arrayvec::ArrayString;
use futures_util::stream::StreamExt; use futures_util::stream::StreamExt;
use irc::client::prelude::Config; use irc::client::prelude::Config;
use irc::client::{Client, ClientStream}; use irc::client::{Client, ClientStream};
@ -19,9 +17,11 @@ use tokio::select;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use tokio::sync::mpsc::unbounded_channel; use tokio::sync::mpsc::unbounded_channel;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use crate::bot::Bot;
use crate::bots::misc::Waifu;
use crate::bots::{leek, misc, sed, title}; use crate::config::UberConfig;
use crate::database::{DbExecutor, ExecutorConnection, Quote}; use crate::database::{DbExecutor, ExecutorConnection};
mod bots; mod bots;
mod database; mod database;
@ -57,14 +57,10 @@ async fn terminate_signal() {
let _ = ctrlc.recv().await; let _ = ctrlc.recv().await;
} }
pub struct AppState { pub struct AppState<SF: FnMut(String, String) -> anyhow::Result<()>> {
prefix: String,
client: Arc<Client>, client: Arc<Client>,
stream: ClientStream, stream: ClientStream,
last_msgs: HashMap<String, String>, bot: Bot<SF>
last_eval: HashMap<String, f64>,
titlebot: title::Titlebot,
db: ExecutorConnection,
} }
#[tokio::main] #[tokio::main]
@ -78,38 +74,30 @@ async fn main() -> anyhow::Result<()> {
let mut client_conf = String::new(); let mut client_conf = String::new();
file.read_to_string(&mut client_conf)?; file.read_to_string(&mut client_conf)?;
let client_config: ClientConf = toml::from_str(&client_conf)?; let cfg: UberConfig = toml::from_str(&client_conf)?;
let (db_exec, db_conn) = let (db_exec, db_conn) =
DbExecutor::create(client_config.db_path.as_deref().unwrap_or("uberbot.db3"))?; DbExecutor::create(cfg.db_path.as_deref().unwrap_or("uberbot.db3"))?;
let exec_thread = thread::spawn(move || { let exec_thread = thread::spawn(move || {
db_exec.run(); db_exec.run();
tracing::info!("Database executor has been shut down"); tracing::info!("Database executor has been shut down");
}); });
let spotify_creds = Credentials::new(
&client_config.spotify_client_id,
&client_config.spotify_client_secret,
);
let http_listen = client_config
.http_listen
.unwrap_or_else(|| SocketAddr::from(([127, 0, 0, 1], 5000)));
let uber_ver = concat!("Überbot ", env!("CARGO_PKG_VERSION")); let uber_ver = concat!("Überbot ", env!("CARGO_PKG_VERSION"));
let irc_config = Config { let irc_config = Config {
nickname: Some( nickname: Some(
client_config cfg
.irc
.nickname .nickname
.unwrap_or_else(|| client_config.username.clone()), .unwrap_or_else(|| cfg.irc.username.clone()),
), ),
username: Some(client_config.username.clone()), username: Some(cfg.irc.username.clone()),
realname: Some(client_config.username), realname: Some(cfg.irc.username),
server: Some(client_config.host), server: Some(cfg.irc.host),
port: Some(client_config.port), port: Some(cfg.irc.port),
use_tls: Some(client_config.tls), use_tls: Some(cfg.irc.tls),
channels: client_config.channels, channels: cfg.irc.channels,
umodes: client_config.mode, umodes: cfg.irc.mode,
user_info: Some(uber_ver.into()), user_info: Some(uber_ver.into()),
version: Some(uber_ver.into()), version: Some(uber_ver.into()),
..Config::default() ..Config::default()
@ -122,14 +110,17 @@ async fn main() -> anyhow::Result<()> {
let (ctx, _) = broadcast::channel(1); let (ctx, _) = broadcast::channel(1);
let (etx, mut erx) = unbounded_channel(); let (etx, mut erx) = unbounded_channel();
let mut bot = Bot::new(cfg.irc.prefix, db_conn, {
let client = client.clone();
move |target, msg| Ok(client.send_privmsg(target, msg)?)
});
bot.add_command("waifu".into(), Waifu);
let state = AppState { let state = AppState {
prefix: client_config.prefix,
client: client.clone(), client: client.clone(),
stream, stream,
last_msgs: HashMap::new(), bot
last_eval: HashMap::new(),
titlebot: title::Titlebot::create(spotify_creds).await?,
db: db_conn,
}; };
let message_loop_task = tokio::spawn(async move { let message_loop_task = tokio::spawn(async move {
if let Err(e) = message_loop(state).await { if let Err(e) = message_loop(state).await {
@ -156,15 +147,17 @@ async fn main() -> anyhow::Result<()> {
message_loop_task message_loop_task
.await .await
.unwrap_or_else(|e| tracing::warn!("Couldn't join the web service: {:?}", e)); .unwrap_or_else(|e| tracing::warn!("Couldn't join the web service: {:?}", e));
tracing::info!("Message loop finished");
exec_thread exec_thread
.join() .join()
.unwrap_or_else(|e| tracing::warn!("Couldn't join the database: {:?}", e)); .unwrap_or_else(|e| tracing::warn!("Couldn't join the database: {:?}", e));
tracing::info!("Executor thread finished");
tracing::info!("Shutdown complete!"); tracing::info!("Shutdown complete!");
Ok(()) Ok(())
} }
async fn message_loop(mut state: AppState) -> anyhow::Result<()> { async fn message_loop<SF: FnMut(String, String) -> anyhow::Result<()>>(mut state: AppState<SF>) -> anyhow::Result<()> {
while let Some(message) = state.stream.next().await.transpose()? { while let Some(message) = state.stream.next().await.transpose()? {
if let Command::PRIVMSG(ref origin, content) = message.command { if let Command::PRIVMSG(ref origin, content) = message.command {
if origin.is_channel_name() { if origin.is_channel_name() {
@ -172,7 +165,7 @@ async fn message_loop(mut state: AppState) -> anyhow::Result<()> {
Prefix::Nickname(name, _, _) => Some(&name[..]), Prefix::Nickname(name, _, _) => Some(&name[..]),
_ => None, _ => None,
}) { }) {
if let Err(e) = handle_privmsg(&mut state, author, origin, content).await { if let Err(e) = state.bot.handle_message(origin, author, &content).await {
state state
.client .client
.send_privmsg(origin, &format!("Error: {}", e))?; .send_privmsg(origin, &format!("Error: {}", e))?;
@ -183,132 +176,7 @@ async fn message_loop(mut state: AppState) -> anyhow::Result<()> {
} }
} }
} }
tracing::info!("Message loop finished");
Ok(()) Ok(())
} }
fn separate_to_space(str: &str, prefix_len: usize) -> (&str, Option<&str>) {
if let Some(o) = str.find(' ') {
(&str[prefix_len..o], Some(&str[o + 1..]))
} else {
(&str[prefix_len..], None)
}
}
#[allow(clippy::too_many_lines)]
async fn handle_privmsg(
state: &mut AppState,
author: &str,
origin: &str,
content: String,
) -> anyhow::Result<()> {
if !content.starts_with(state.prefix.as_str()) {
if let Some(titlebot_msg) = state.titlebot.resolve(&content).await? {
state.client.send_privmsg(origin, &titlebot_msg)?;
}
if let Some(prev_msg) = state.last_msgs.get(author) {
if let Some(formatted) = sed::resolve(prev_msg, &content)? {
let mut result = ArrayString::<512>::new();
write!(result, "<{}> {}", author, formatted)?;
state.client.send_privmsg(origin, &result)?;
state.last_msgs.insert(author.into(), formatted.to_string());
return Ok(());
}
}
state.last_msgs.insert(author.into(), content);
return Ok(());
}
let (command, remainder) = separate_to_space(&content, state.prefix.len());
tracing::debug!("Command received ({:?}; {:?})", command, remainder);
match command {
"help" => {
for help_line in HELP {
state.client.send_privmsg(origin, help_line)?;
}
}
"waifu" => {
let category = remainder.unwrap_or("waifu");
let url = misc::get_waifu_pic(category).await?;
let response = url
.as_deref()
.unwrap_or("Invalid category. Valid categories: https://waifu.pics/docs");
state.client.send_privmsg(origin, response)?;
}
"mock" => {
leek::execute(
state,
leek::Command::Mock,
origin,
remainder.unwrap_or(author),
)?;
}
"leet" => {
leek::execute(
state,
leek::Command::Leet,
origin,
remainder.unwrap_or(author),
)?;
}
"owo" => {
leek::execute(
state,
leek::Command::Owo,
origin,
remainder.unwrap_or(author),
)?;
}
"ev" => {
let result = misc::mathbot(author.into(), remainder, &mut state.last_eval)?;
state.client.send_privmsg(origin, &result)?;
}
"grab" => {
if let Some(target) = remainder {
if target == author {
state
.client
.send_privmsg(origin, "You can't grab yourself")?;
return Ok(());
}
if let Some(prev_msg) = state.last_msgs.get(target) {
if state
.db
.add_quote(Quote {
quote: prev_msg.clone(),
author: target.into(),
})
.await
{
state.client.send_privmsg(origin, "Quote added")?;
} else {
state
.client
.send_privmsg(origin, "A database error has occurred")?;
}
} else {
state
.client
.send_privmsg(origin, "No previous messages to grab")?;
}
} else {
state.client.send_privmsg(origin, "No nickname to grab")?;
}
}
"quot" => {
if let Some(quote) = state.db.get_quote(remainder.map(ToString::to_string)).await {
let mut resp = ArrayString::<512>::new();
write!(resp, "\"{}\" ~{}", quote.quote, quote.author)?;
state.client.send_privmsg(origin, &resp)?;
} else {
state.client.send_privmsg(origin, "No quotes found")?;
}
}
_ => {
state.client.send_privmsg(origin, "Unknown command")?;
}
}
Ok(())
}