This commit is contained in:
2023-08-09 02:17:44 +02:00
parent eff6960f3e
commit 9b35b41ad1

View File

@@ -63,7 +63,32 @@ pub static CHAT_DONATION_NOTIFICATIONS_CACHE: Lazy<Cache<ChatId, ()>> = Lazy::ne
}); });
type Routes = Arc<RwLock<HashMap<String, UnboundedSender<Result<Update, std::convert::Infallible>>>>>; type Routes = Arc<RwLock<HashMap<String, (StopToken, ClosableSender<Result<Update, std::convert::Infallible>>)>>>;
struct ClosableSender<T> {
origin: std::sync::Arc<std::sync::RwLock<Option<mpsc::UnboundedSender<T>>>>,
}
impl<T> Clone for ClosableSender<T> {
fn clone(&self) -> Self {
Self { origin: self.origin.clone() }
}
}
impl<T> ClosableSender<T> {
fn new(sender: mpsc::UnboundedSender<T>) -> Self {
Self { origin: std::sync::Arc::new(std::sync::RwLock::new(Some(sender))) }
}
fn get(&self) -> Option<mpsc::UnboundedSender<T>> {
self.origin.read().unwrap().clone()
}
fn close(&mut self) {
self.origin.write().unwrap().take();
}
}
#[derive(Default, Clone)] #[derive(Default, Clone)]
@@ -74,42 +99,36 @@ struct ServerState {
pub struct BotsManager { pub struct BotsManager {
port: u16, port: u16,
stop_token: StopToken,
stop_flag: StopFlag,
state: ServerState state: ServerState
} }
impl BotsManager { impl BotsManager {
pub fn create() -> Self { pub fn create() -> Self {
let (stop_token, stop_flag) = mk_stop_token();
BotsManager { BotsManager {
port: 8000, port: 8000,
stop_token,
stop_flag,
state: ServerState { state: ServerState {
routers: Arc::new(RwLock::new(HashMap::new())) routers: Arc::new(RwLock::new(HashMap::new()))
} }
} }
} }
fn get_listener(&self) -> (UnboundedSender<Result<Update, std::convert::Infallible>>, impl UpdateListener<Err = Infallible>) { fn get_listener(&self) -> (StopToken, StopFlag, UnboundedSender<Result<Update, std::convert::Infallible>>, impl UpdateListener<Err = Infallible>) {
let (tx, rx): (UpdateSender, _) = mpsc::unbounded_channel(); let (tx, rx): (UpdateSender, _) = mpsc::unbounded_channel();
let (stop_token, stop_flag) = mk_stop_token();
let stream = UnboundedReceiverStream::new(rx); let stream = UnboundedReceiverStream::new(rx);
let listener = StatefulListener::new( let listener = StatefulListener::new(
(stream, self.stop_token.clone()), (stream, stop_token.clone()),
tuple_first_mut, tuple_first_mut,
|state: &mut (_, StopToken)| { |state: &mut (_, StopToken)| {
state.1.clone() state.1.clone()
}, },
); );
(tx, listener) (stop_token, stop_flag, tx, listener)
} }
async fn start_bot(&mut self, bot_data: &BotData) -> bool { async fn start_bot(&mut self, bot_data: &BotData) -> bool {
@@ -137,11 +156,11 @@ impl BotsManager {
.dependencies(dptree::deps![bot_data.cache]) .dependencies(dptree::deps![bot_data.cache])
.build(); .build();
let (tx, listener) = self.get_listener(); let (stop_token, _stop_flag, tx, listener) = self.get_listener();
{ {
let mut routers = self.state.routers.write().unwrap(); let mut routers = self.state.routers.write().unwrap();
routers.insert(token.to_string(), tx); routers.insert(token.to_string(), (stop_token, ClosableSender::new(tx)));
} }
let host = format!("{}:{}", &config::CONFIG.webhook_base_url, self.port); let host = format!("{}:{}", &config::CONFIG.webhook_base_url, self.port);
@@ -191,44 +210,38 @@ impl BotsManager {
async fn telegram_request( async fn telegram_request(
State(ServerState { routers }): State<ServerState>, State(ServerState { routers }): State<ServerState>,
Path(token): Path<String>, Path(token): Path<String>,
// secret_header: XTelegramBotApiSecretToken,
input: String, input: String,
) -> impl IntoResponse { ) -> impl IntoResponse {
// // FIXME: use constant time comparison here
// if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) {
// return StatusCode::UNAUTHORIZED;
// }
let routes = routers.read().unwrap(); let routes = routers.read().unwrap();
let tx = routes.get(&token); let tx = routes.get(&token);
let tx = match tx { let (stop_token, r_tx) = match tx {
Some(tx) => { Some(tx) => tx,
tx
// match tx.get() {
// None => return StatusCode::SERVICE_UNAVAILABLE,
// // Do not process updates after `.stop()` is called even if the server is still
// // running (useful for when you need to stop the bot but can't stop the server).
// // TODO
// // _ if flag.is_stopped() => {
// // tx.close();
// // return StatusCode::SERVICE_UNAVAILABLE;
// // }
// Some(tx) => tx,
// };
},
None => return StatusCode::NOT_FOUND, None => return StatusCode::NOT_FOUND,
}; };
let tx = match r_tx.get() {
Some(v) => v,
None => {
stop_token.stop();
routers.write().unwrap().remove(&token);
return StatusCode::SERVICE_UNAVAILABLE;
},
};
match serde_json::from_str::<Update>(&input) { match serde_json::from_str::<Update>(&input) {
Ok(mut update) => { Ok(mut update) => {
// See HACK comment in
// `teloxide_core::net::request::process_response::{closure#0}`
if let UpdateKind::Error(value) = &mut update.kind { if let UpdateKind::Error(value) = &mut update.kind {
*value = serde_json::from_str(&input).unwrap_or_default(); *value = serde_json::from_str(&input).unwrap_or_default();
} }
tx.send(Ok(update)).expect("Cannot send an incoming update from the webhook") if let Err(err) = tx.send(Ok(update)) {
log::error!("{:?}", err);
stop_token.stop();
routers.write().unwrap().remove(&token);
return StatusCode::SERVICE_UNAVAILABLE;
}
} }
Err(error) => { Err(error) => {
log::error!( log::error!(
@@ -244,8 +257,6 @@ impl BotsManager {
StatusCode::OK StatusCode::OK
} }
let stop_token = self.stop_token.clone();
let stop_flag = self.stop_flag.clone();
let port = self.port; let port = self.port;
let router = axum::Router::new() let router = axum::Router::new()
@@ -260,18 +271,23 @@ impl BotsManager {
axum::Server::bind(&addr) axum::Server::bind(&addr)
.serve(router.into_make_service()) .serve(router.into_make_service())
.with_graceful_shutdown(stop_flag)
.await .await
.map_err(|err| {
stop_token.stop();
err
})
.expect("Axum server error"); .expect("Axum server error");
log::info!("Webserver shutdown..."); log::info!("Webserver shutdown...");
}); });
} }
pub async fn stop_all(self) {
let routers = self.state.routers.read().unwrap();
for (stop_token, _) in routers.values() {
stop_token.stop();
}
sleep(Duration::from_secs(5)).await;
}
pub async fn start(running: Arc<AtomicBool>) { pub async fn start(running: Arc<AtomicBool>) {
let mut manager = BotsManager::create(); let mut manager = BotsManager::create();
@@ -281,14 +297,10 @@ impl BotsManager {
loop { loop {
if !running.load(Ordering::SeqCst) { if !running.load(Ordering::SeqCst) {
manager.stop_token.stop(); manager.stop_all().await;
return; return;
}; };
if manager.stop_flag.is_stopped() {
return;
}
if i == 0 { if i == 0 {
manager.check().await; manager.check().await;
} }