diff --git a/src/bots/leek.rs b/src/bots/leek.rs index 11f479c..e4c1d28 100644 --- a/src/bots/leek.rs +++ b/src/bots/leek.rs @@ -1,4 +1,4 @@ -use arrayvec::{ArrayString}; +use arrayvec::ArrayString; use rand::Rng; use std::{ error::Error, diff --git a/src/database.rs b/src/database.rs index bd8af2f..f118756 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,15 +1,16 @@ use rusqlite::{params, OptionalExtension}; +use serde::Serialize; use tokio::sync::{ mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, oneshot, }; -use serde::Serialize; #[derive(Debug)] enum Task { AddQuote(oneshot::Sender, Quote), GetQuote(oneshot::Sender>, Option), - // implement search WITH PAGINATION + Search(oneshot::Sender>>, String), + Random20(oneshot::Sender>>) } pub struct DbExecutor { @@ -20,7 +21,7 @@ pub struct DbExecutor { #[derive(Serialize, Debug)] pub struct Quote { pub author: String, - pub quote: String + pub quote: String, } impl DbExecutor { @@ -61,6 +62,33 @@ impl DbExecutor { }); tx.send(quote).unwrap(); } + Task::Search(tx, query) => { + tx.send(match self.db + .prepare("select quote,username from quotes where quote like '%'||?1||'%' order by quote asc limit 50") + .and_then(|mut v| v.query(params![query]) + .and_then(|mut v| { + let mut quotes: Vec = Vec::with_capacity(50); + while let Some(row) = v.next()? { + quotes.push(Quote { + quote: row.get(0)?, + author: row.get(1)?, + }); + } + Ok(quotes) + })) + { + Ok(o) => { + Some(o) + } + Err(e) => { + tracing::error!("A database error has occurred: {}", e); + None + } + }).unwrap(); + } + Task::Random20(tx) => { + tx.send(None).unwrap(); + } } } } @@ -78,15 +106,24 @@ impl Clone for ExecutorConnection { } } -impl ExecutorConnection { - pub async fn add_quote(&self, quote: Quote) -> bool { - let (otx, orx) = oneshot::channel(); - self.tx.send(Task::AddQuote(otx, quote)).unwrap(); - orx.await.unwrap() - } - pub async fn get_quote(&self, author: Option) -> Option { - let (otx, orx) = oneshot::channel(); - self.tx.send(Task::GetQuote(otx, author)).unwrap(); - orx.await.unwrap() +macro_rules! executor_wrapper { + ($name:ident, $task:expr, $ret:ty, $($arg:ident: $ty:ty),*) => { + pub async fn $name(&self, $($arg: $ty),*) -> $ret { + let (otx, orx) = oneshot::channel(); + self.tx.send($task(otx, $($arg),*)).unwrap(); + orx.await.unwrap() } + } +} + +impl ExecutorConnection { + // WARNING: these methods are NOT cancel-safe + executor_wrapper!(add_quote, Task::AddQuote, bool, quote: Quote); + executor_wrapper!( + get_quote, + Task::GetQuote, + Option, + author: Option + ); + executor_wrapper!(search, Task::Search, Option>, query: String); } diff --git a/src/main.rs b/src/main.rs index 9aade81..13f8023 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use irc::proto::{ChannelExt, Command, Prefix}; use rspotify::Credentials; use serde::Deserialize; use tokio::sync::broadcast; -use tokio::sync::mpsc::{unbounded_channel}; +use tokio::sync::mpsc::unbounded_channel; use tracing_subscriber::EnvFilter; use crate::bots::{leek, misc, sed, title}; @@ -142,7 +142,7 @@ async fn main() -> anyhow::Result<()> { client.clone(), client_config.git_channel, http_listen, - ctx.subscribe() + ctx.subscribe(), )); let state = AppState { @@ -258,7 +258,8 @@ async fn handle_privmsg( "waifu" => { let category = remainder.unwrap_or("waifu"); let url = misc::get_waifu_pic(category).await?; - let response = url.as_deref() + let response = url + .as_deref() .unwrap_or("Invalid category. Valid categories: https://waifu.pics/docs"); state.client.send_privmsg(origin, response)?; } @@ -299,7 +300,14 @@ async fn handle_privmsg( return Ok(()); } if let Some(prev_msg) = state.last_msgs.get(target) { - if state.db.add_quote(Quote{quote:prev_msg.clone(), author:target.into()}).await { + if state + .db + .add_quote(Quote { + quote: prev_msg.clone(), + author: target.into(), + }) + .await + { state.client.send_privmsg(target, "Quote added")?; } else { state diff --git a/src/res/quote_tmpl.hbs b/src/res/quote_tmpl.hbs index 95b8636..85d6be3 100644 --- a/src/res/quote_tmpl.hbs +++ b/src/res/quote_tmpl.hbs @@ -16,13 +16,16 @@
-
+
+ {{#if flash}} +

{{flash}}

+ {{/if}} {{#if quotes}} diff --git a/src/web_service.rs b/src/web_service.rs index 0305659..97df282 100644 --- a/src/web_service.rs +++ b/src/web_service.rs @@ -1,20 +1,21 @@ +use crate::database::Quote; use crate::ExecutorConnection; +use handlebars::Handlebars; use irc::client::Client; +use lazy_static::lazy_static; use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; use serde_json::Value::Null; use std::net::SocketAddr; use std::sync::Arc; -use handlebars::Handlebars; -use lazy_static::lazy_static; use tokio::sync::broadcast::Receiver; use warp::{reply, Filter, Reply}; -use serde::Serialize; -use crate::database::Quote; lazy_static! { static ref HANDLEBARS: Handlebars<'static> = { let mut reg = Handlebars::new(); - reg.register_template_string("quotes", include_str!("res/quote_tmpl.hbs")).unwrap(); + reg.register_template_string("quotes", include_str!("res/quote_tmpl.hbs")) + .unwrap(); reg }; } @@ -24,12 +25,13 @@ pub async fn run( wh_irc: Arc, wh_channel: String, listen: SocketAddr, - mut cancel: Receiver<()> + mut cancel: Receiver<()>, ) { - let quote_get = warp::path("quotes") - .and(warp::get()) + let quote_get = warp::get() + .and(warp::path("quotes")) + .and(warp::query::()) .and(warp::any().map(move || db.clone())) - .map(handle_get_quote); + .then(handle_get_quote); let webhook_post = warp::path("webhook") .and(warp::post()) @@ -38,31 +40,58 @@ pub async fn run( .and(warp::any().map(move || wh_channel.clone())) .map(handle_webhook); - let filter = quote_get.or(webhook_post); - warp::serve(filter).bind_with_graceful_shutdown(listen, async move { - let _ = cancel.recv().await; - }).1.await; + let routes = webhook_post.or(quote_get); + warp::serve(routes) + .bind_with_graceful_shutdown(listen, async move { + let _ = cancel.recv().await; + }) + .1 + .await; tracing::info!("Web service finished"); } #[derive(Serialize)] struct QuotesTemplate { - quotes: Option> + quotes: Option>, + flash: Option, } -fn handle_get_quote(_: ExecutorConnection) -> impl Reply { - match HANDLEBARS.render("quotes", &QuotesTemplate{quotes: Some(vec![ - Quote{quote:"something".into(),author:"by someone".into()}, - Quote{quote:"something different".into(),author:"by someone else".into()}, - Quote{quote:"something even more different".into(),author:"by nobody".into()} - ])}) { +#[derive(Deserialize)] +struct QuotesQuery { + query: Option, +} + +async fn handle_get_quote(query: QuotesQuery, db: ExecutorConnection) -> impl Reply { + let template = if let Some(query) = query.query { + if let Some(quotes) = db.search(query.clone()).await { + let quotes_count = quotes.len(); + QuotesTemplate { + quotes: Some(quotes), + flash: Some(format!("Displaying {}/50 results for query \"{}\"", quotes_count, query)), + } + } else { + QuotesTemplate { + quotes: None, + flash: Some("A database error has occurred".into()), + } + } + } else { + QuotesTemplate { + quotes: None, + flash: None, + } + }; + match HANDLEBARS.render("quotes", &template) { Ok(o) => reply::html(o).into_response(), Err(e) => { tracing::warn!("Error while rendering template: {}", e); - reply::with_status("Failed to render template", StatusCode::INTERNAL_SERVER_ERROR).into_response() + reply::with_status( + "Failed to render template", + StatusCode::INTERNAL_SERVER_ERROR, + ) + .into_response() } } - } #[allow(clippy::needless_pass_by_value)]