diff --git a/src/bots_manager/mod.rs b/src/bots_manager/mod.rs index 016109f..d5e8087 100644 --- a/src/bots_manager/mod.rs +++ b/src/bots_manager/mod.rs @@ -63,7 +63,32 @@ pub static CHAT_DONATION_NOTIFICATIONS_CACHE: Lazy> = Lazy::ne }); -type Routes = Arc>>>>; +type Routes = Arc>)>>>; + + +struct ClosableSender { + origin: std::sync::Arc>>>, +} + +impl Clone for ClosableSender { + fn clone(&self) -> Self { + Self { origin: self.origin.clone() } + } +} + +impl ClosableSender { + fn new(sender: mpsc::UnboundedSender) -> Self { + Self { origin: std::sync::Arc::new(std::sync::RwLock::new(Some(sender))) } + } + + fn get(&self) -> Option> { + self.origin.read().unwrap().clone() + } + + fn close(&mut self) { + self.origin.write().unwrap().take(); + } +} #[derive(Default, Clone)] @@ -74,42 +99,36 @@ struct ServerState { pub struct BotsManager { port: u16, - stop_token: StopToken, - stop_flag: StopFlag, - state: ServerState } impl BotsManager { pub fn create() -> Self { - let (stop_token, stop_flag) = mk_stop_token(); - BotsManager { port: 8000, - stop_token, - stop_flag, - state: ServerState { routers: Arc::new(RwLock::new(HashMap::new())) } } } - fn get_listener(&self) -> (UnboundedSender>, impl UpdateListener) { + fn get_listener(&self) -> (StopToken, StopFlag, UnboundedSender>, impl UpdateListener) { let (tx, rx): (UpdateSender, _) = mpsc::unbounded_channel(); + let (stop_token, stop_flag) = mk_stop_token(); + let stream = UnboundedReceiverStream::new(rx); let listener = StatefulListener::new( - (stream, self.stop_token.clone()), + (stream, stop_token.clone()), tuple_first_mut, |state: &mut (_, StopToken)| { state.1.clone() }, ); - (tx, listener) + (stop_token, stop_flag, tx, listener) } async fn start_bot(&mut self, bot_data: &BotData) -> bool { @@ -137,11 +156,11 @@ impl BotsManager { .dependencies(dptree::deps![bot_data.cache]) .build(); - let (tx, listener) = self.get_listener(); + let (stop_token, _stop_flag, tx, listener) = self.get_listener(); { let mut routers = self.state.routers.write().unwrap(); - routers.insert(token.to_string(), tx); + routers.insert(token.to_string(), (stop_token, ClosableSender::new(tx))); } let host = format!("{}:{}", &config::CONFIG.webhook_base_url, self.port); @@ -191,44 +210,38 @@ impl BotsManager { async fn telegram_request( State(ServerState { routers }): State, Path(token): Path, - // secret_header: XTelegramBotApiSecretToken, input: String, ) -> impl IntoResponse { - // // FIXME: use constant time comparison here - // if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) { - // return StatusCode::UNAUTHORIZED; - // } let routes = routers.read().unwrap(); let tx = routes.get(&token); - let tx = match tx { - Some(tx) => { - tx - // match tx.get() { - // None => return StatusCode::SERVICE_UNAVAILABLE, - // // Do not process updates after `.stop()` is called even if the server is still - // // running (useful for when you need to stop the bot but can't stop the server). - // // TODO - // // _ if flag.is_stopped() => { - // // tx.close(); - // // return StatusCode::SERVICE_UNAVAILABLE; - // // } - // Some(tx) => tx, - // }; - }, + let (stop_token, r_tx) = match tx { + Some(tx) => tx, None => return StatusCode::NOT_FOUND, }; + let tx = match r_tx.get() { + Some(v) => v, + None => { + stop_token.stop(); + routers.write().unwrap().remove(&token); + return StatusCode::SERVICE_UNAVAILABLE; + }, + }; + match serde_json::from_str::(&input) { Ok(mut update) => { - // See HACK comment in - // `teloxide_core::net::request::process_response::{closure#0}` if let UpdateKind::Error(value) = &mut update.kind { *value = serde_json::from_str(&input).unwrap_or_default(); } - tx.send(Ok(update)).expect("Cannot send an incoming update from the webhook") + if let Err(err) = tx.send(Ok(update)) { + log::error!("{:?}", err); + stop_token.stop(); + routers.write().unwrap().remove(&token); + return StatusCode::SERVICE_UNAVAILABLE; + } } Err(error) => { log::error!( @@ -244,8 +257,6 @@ impl BotsManager { StatusCode::OK } - let stop_token = self.stop_token.clone(); - let stop_flag = self.stop_flag.clone(); let port = self.port; let router = axum::Router::new() @@ -260,18 +271,23 @@ impl BotsManager { axum::Server::bind(&addr) .serve(router.into_make_service()) - .with_graceful_shutdown(stop_flag) .await - .map_err(|err| { - stop_token.stop(); - err - }) .expect("Axum server error"); log::info!("Webserver shutdown..."); }); } + pub async fn stop_all(self) { + let routers = self.state.routers.read().unwrap(); + + for (stop_token, _) in routers.values() { + stop_token.stop(); + } + + sleep(Duration::from_secs(5)).await; + } + pub async fn start(running: Arc) { let mut manager = BotsManager::create(); @@ -281,14 +297,10 @@ impl BotsManager { loop { if !running.load(Ordering::SeqCst) { - manager.stop_token.stop(); + manager.stop_all().await; return; }; - if manager.stop_flag.is_stopped() { - return; - } - if i == 0 { manager.check().await; }