tokio::spawn() instead of a thicc select!

This commit is contained in:
lemon-sh 2022-01-26 19:17:07 +01:00
parent 5577475b6c
commit ac5f14a21e

View file

@ -2,19 +2,19 @@ use std::fmt::Write;
use std::fs::File;
use std::io::Read;
use std::net::SocketAddr;
use std::sync::Arc;
use std::thread;
use std::{collections::HashMap, env};
use arrayvec::ArrayString;
use futures_util::stream::StreamExt;
use irc::client::prelude::Config;
use irc::client::Client;
use irc::client::{Client, ClientStream};
use irc::proto::{ChannelExt, Command, Prefix};
use rspotify::Credentials;
use serde::Deserialize;
use tokio::select;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tracing_log::LogTracer;
use tokio::sync::broadcast;
use tokio::sync::mpsc::{unbounded_channel};
use tracing_subscriber::EnvFilter;
use crate::bots::{leek, misc, sed, title};
@ -55,12 +55,12 @@ async fn terminate_signal() {
pub struct AppState {
prefix: String,
client: Client,
client: Arc<Client>,
stream: ClientStream,
last_msgs: HashMap<String, String>,
last_eval: HashMap<String, f64>,
titlebot: title::Titlebot,
db: ExecutorConnection,
git_channel: String,
}
#[derive(Deserialize)]
@ -80,9 +80,8 @@ struct ClientConf {
git_channel: String,
}
#[tokio::main(flavor = "current_thread")]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
LogTracer::init()?;
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_env("UBERBOT_LOG"))
.init();
@ -112,7 +111,11 @@ async fn main() -> anyhow::Result<()> {
let uber_ver = concat!("Überbot ", env!("CARGO_PKG_VERSION"));
let irc_config = Config {
nickname: client_config.nickname,
nickname: Some(
client_config
.nickname
.unwrap_or_else(|| client_config.username.clone()),
),
username: Some(client_config.username.clone()),
realname: Some(client_config.username),
server: Some(client_config.host),
@ -124,66 +127,76 @@ async fn main() -> anyhow::Result<()> {
version: Some(uber_ver.into()),
..Config::default()
};
let client = Client::from_config(irc_config).await?;
let mut client = Client::from_config(irc_config).await?;
let stream = client.stream()?;
client.identify()?;
let client = Arc::new(client);
let (ctx, _) = broadcast::channel(1);
let (etx, mut erx) = unbounded_channel();
let web_task = tokio::spawn(web_service::run(
db_conn.clone(),
client.clone(),
client_config.git_channel,
http_listen,
ctx.subscribe()
));
let state = AppState {
prefix: client_config.prefix,
client,
client: client.clone(),
stream,
last_msgs: HashMap::new(),
last_eval: HashMap::new(),
titlebot: title::Titlebot::create(spotify_creds).await?,
db: db_conn,
git_channel: client_config.git_channel,
};
let message_loop_task = tokio::spawn(async move {
if let Err(e) = message_loop(state).await {
let _ = etx.send(e);
}
});
let (git_tx, git_recv) = channel(512);
if let Err(e) = executor(state, git_tx, git_recv, http_listen).await {
tracing::error!("Error in message loop: {}", e);
}
if let Err(e) = exec_thread.join() {
tracing::error!("Error while shutting down the database: {:?}", e);
}
tracing::info!("Shutting down");
Ok(())
}
async fn executor(
mut state: AppState,
git_tx: Sender<String>,
mut git_recv: Receiver<String>,
http_listen: SocketAddr,
) -> anyhow::Result<()> {
let web_db = state.db.clone();
select! {
r = web_service::run(web_db, git_tx, http_listen) => r?,
r = message_loop(&mut state) => r?,
r = git_recv.recv() => {
if let Some(message) = r {
state.client.send_privmsg(&state.git_channel, &message)?;
tokio::select! {
_ = terminate_signal() => {
tracing::info!("Received shutdown signal, sending QUIT message");
client.send_quit("überbot shutting down")?;
}
e = erx.recv() => {
if let Some(e) = e {
tracing::error!("An error has occurred, shutting down: {}", e);
} else {
tracing::error!("Error channel has been dropped due to an unknown error, shutting down");
}
}
_ = terminate_signal() => {
tracing::info!("Sending QUIT message");
state.client.send_quit("überbot shutting down")?;
}
}
tracing::info!("Closing services...");
let _ = ctx.send(());
web_task
.await
.unwrap_or_else(|e| tracing::warn!("Couldn't join the web service: {:?}", e));
message_loop_task
.await
.unwrap_or_else(|e| tracing::warn!("Couldn't join the web service: {:?}", e));
exec_thread
.join()
.unwrap_or_else(|e| tracing::warn!("Couldn't join the database: {:?}", e));
tracing::info!("Shutdown complete!");
Ok(())
}
async fn message_loop(state: &mut AppState) -> anyhow::Result<()> {
let mut stream = state.client.stream()?;
while let Some(message) = stream.next().await.transpose()? {
async fn message_loop(mut state: AppState) -> anyhow::Result<()> {
while let Some(message) = state.stream.next().await.transpose()? {
if let Command::PRIVMSG(ref origin, content) = message.command {
if origin.is_channel_name() {
if let Some(author) = message.prefix.as_ref().and_then(|p| match p {
Prefix::Nickname(name, _, _) => Some(&name[..]),
_ => None,
}) {
if let Err(e) = handle_privmsg(state, author, origin, content).await {
if let Err(e) = handle_privmsg(&mut state, author, origin, content).await {
state
.client
.send_privmsg(origin, &format!("Error: {}", e))?;
@ -194,6 +207,7 @@ async fn message_loop(state: &mut AppState) -> anyhow::Result<()> {
}
}
}
tracing::info!("Message loop finished");
Ok(())
}