diff --git a/rust-lib/flowy-user/src/services/user/user_session.rs b/rust-lib/flowy-user/src/services/user/user_session.rs index 350cb3ab69..ef5fd115b7 100644 --- a/rust-lib/flowy-user/src/services/user/user_session.rs +++ b/rust-lib/flowy-user/src/services/user/user_session.rs @@ -18,7 +18,7 @@ use flowy_database::{ }; use flowy_infra::kv::KV; use flowy_sqlite::ConnectionPool; -use flowy_ws::{WsController, WsMessage, WsMessageHandler}; +use flowy_ws::{connect::Retry, WsController, WsMessage, WsMessageHandler}; use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use std::{sync::Arc, time::Duration}; @@ -47,7 +47,7 @@ pub struct UserSession { #[allow(dead_code)] server: Server, session: RwLock>, - ws_controller: RwLock, + ws_controller: Arc>, status_callback: SessionStatusCallback, } @@ -55,7 +55,7 @@ impl UserSession { pub fn new(config: UserSessionConfig, status_callback: SessionStatusCallback) -> Self { let db = UserDB::new(&config.root_dir); let server = construct_user_server(); - let ws_controller = RwLock::new(WsController::new()); + let ws_controller = Arc::new(RwLock::new(WsController::new())); let user_session = Self { database: db, config, @@ -278,7 +278,12 @@ impl UserSession { fn start_ws_connection(&self, token: &str) -> Result<(), UserError> { let addr = format!("{}/{}", flowy_net::config::WS_ADDR.as_str(), token); - let _ = self.ws_controller.write().connect(addr); + let ws_controller = self.ws_controller.clone(); + let retry = Retry::new(&addr, move |addr| { + ws_controller.write().connect(addr.to_owned()); + }); + + let _ = self.ws_controller.write().connect_with_retry(addr, retry); Ok(()) } } diff --git a/rust-lib/flowy-ws/src/connect.rs b/rust-lib/flowy-ws/src/connect.rs index 065c17a889..1a01d70ec3 100644 --- a/rust-lib/flowy-ws/src/connect.rs +++ b/rust-lib/flowy-ws/src/connect.rs @@ -56,8 +56,11 @@ impl Future for WsConnection { loop { return match ready!(self.as_mut().project().fut.poll(cx)) { Ok((stream, _)) => { - log::debug!("🐴 ws connect success: {:?}", error); - let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); + log::debug!("🐴 ws connect success"); + let (msg_tx, ws_rx) = ( + self.msg_tx.take().expect("WsConnection should be call once "), + self.ws_rx.take().expect("WsConnection should be call once "), + ); Poll::Ready(Ok(WsStream::new(msg_tx, ws_rx, stream))) }, Err(error) => { @@ -72,6 +75,7 @@ impl Future for WsConnection { type Fut = BoxFuture<'static, Result<(), WsError>>; #[pin_project] pub struct WsStream { + msg_tx: MsgSender, #[pin] inner: Option<(Fut, Fut)>, } @@ -80,6 +84,7 @@ impl WsStream { pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: WebSocketStream>) -> Self { let (ws_write, ws_read) = stream.split(); Self { + msg_tx: msg_tx.clone(), inner: Some(( Box::pin(async move { let _ = ws_read.for_each(|message| async { post_message(msg_tx.clone(), message) }).await; @@ -102,15 +107,15 @@ impl Future for WsStream { type Output = Result<(), WsError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let (mut left, mut right) = self.inner.take().unwrap(); - match left.poll_unpin(cx) { + let (mut ws_read, mut ws_write) = self.inner.take().unwrap(); + match ws_read.poll_unpin(cx) { Poll::Ready(l) => Poll::Ready(l), Poll::Pending => { // - match right.poll_unpin(cx) { + match ws_write.poll_unpin(cx) { Poll::Ready(r) => Poll::Ready(r), Poll::Pending => { - self.inner = Some((left, right)); + self.inner = Some((ws_read, ws_write)); Poll::Pending }, } @@ -144,3 +149,35 @@ fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> Serv error } + +pub struct Retry { + f: F, + retry_time: usize, + addr: String, +} + +impl Retry +where + F: Fn(&str), +{ + pub fn new(addr: &str, f: F) -> Self { + Self { + f, + retry_time: 3, + addr: addr.to_owned(), + } + } +} + +impl Future for Retry +where + F: Fn(&str), +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + (self.f)(&self.addr); + + Poll::Ready(()) + } +} diff --git a/rust-lib/flowy-ws/src/lib.rs b/rust-lib/flowy-ws/src/lib.rs index 307c5ff600..65e5213152 100644 --- a/rust-lib/flowy-ws/src/lib.rs +++ b/rust-lib/flowy-ws/src/lib.rs @@ -1,4 +1,4 @@ -mod connect; +pub mod connect; pub mod errors; mod msg; pub mod protobuf; diff --git a/rust-lib/flowy-ws/src/ws.rs b/rust-lib/flowy-ws/src/ws.rs index 59edac514d..bbed3e95b0 100644 --- a/rust-lib/flowy-ws/src/ws.rs +++ b/rust-lib/flowy-ws/src/ws.rs @@ -3,10 +3,13 @@ use flowy_net::errors::ServerError; use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures_core::{ready, Stream}; +use crate::connect::Retry; +use futures_core::future::BoxFuture; use pin_project::pin_project; use std::{ collections::HashMap, future::Future, + marker::PhantomData, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -24,6 +27,7 @@ pub trait WsMessageHandler: Sync + Send + 'static { pub struct WsController { sender: Option>, handlers: HashMap>, + addr: Option, } impl WsController { @@ -31,6 +35,7 @@ impl WsController { let controller = Self { sender: None, handlers: HashMap::new(), + addr: None, }; controller } @@ -44,25 +49,41 @@ impl WsController { Ok(()) } - pub fn connect(&mut self, addr: String) -> Result, ServerError> { + pub fn connect(&mut self, addr: String) -> Result, ServerError> { self._connect(addr.clone(), None) } + + pub fn connect_with_retry(&mut self, addr: String, retry: Retry) -> Result, ServerError> + where + F: Fn(&str) + Send + Sync + 'static, + { + self._connect(addr, Some(Box::pin(async { retry.await }))) + } + + fn _connect(&mut self, addr: String, retry: Option>) -> Result, ServerError> { log::debug!("🐴 ws connect: {}", &addr); - let (connection, handlers) = self.make_connect(addr); - Ok(tokio::spawn(async { - tokio::select! { - result = connection => { - match result { - Ok(stream) => { - tokio::spawn(stream).await; - // stream.start().await; + let (connection, handlers) = self.make_connect(addr.clone()); + Ok(tokio::spawn(async move { + match connection.await { + Ok(stream) => { + tokio::select! { + result = stream => { + match result { + Ok(_) => {}, + Err(e) => { + // TODO: retry? + log::error!("ws stream error {:?}", e); + } + } }, - Err(e) => { - // TODO: retry? - log::error!("ws connect failed {:?}", e); - } - } + result = handlers => log::debug!("handlers completed {:?}", result), + }; }, - result = handlers => log::debug!("handlers completed {:?}", result), - }; + Err(e) => match retry { + None => log::error!("ws connect {} failed {:?}", addr, e), + Some(retry) => { + tokio::spawn(retry); + }, + }, + } })) } @@ -81,6 +102,7 @@ impl WsController { let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded(); let handlers = self.handlers.clone(); self.sender = Some(Arc::new(WsSender::new(ws_tx))); + self.addr = Some(addr.clone()); (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx)) } @@ -109,10 +131,10 @@ impl Future for WsHandlers { loop { match ready!(self.as_mut().project().msg_rx.poll_next(cx)) { None => { - // log::debug!("🐴 ws handler done"); - return Poll::Pending; + return Poll::Ready(()); }, Some(message) => { + log::debug!("🐴 ws handler receive message"); let message = WsMessage::from(message); match self.handlers.get(&message.source) { None => log::error!("Can't find any handler for message: {:?}", message),