diff --git a/backend/Cargo.toml b/backend/Cargo.toml index f4de65394d..1e0d8b106b 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -12,7 +12,6 @@ actix-http = "2.2.1" actix-web-actors = "3" actix-codec = "0.3" - futures = "0.3.15" bytes = "0.5" toml = "0.5.8" @@ -21,7 +20,10 @@ log = "0.4.14" serde_json = "1.0" serde = { version = "1.0", features = ["derive"] } serde_repr = "0.1" +derive_more = {version = "0.99", features = ["display"]} +protobuf = {version = "2.20.0"} flowy-log = { path = "../rust-lib/flowy-log" } +flowy-user = { path = "../rust-lib/flowy-user" } [dependencies.sqlx] version = "0.5.2" diff --git a/backend/src/config/config.rs b/backend/src/config/config.rs index 430cc05de2..f820a19c06 100644 --- a/backend/src/config/config.rs +++ b/backend/src/config/config.rs @@ -1,11 +1,18 @@ -use std::convert::TryFrom; +use crate::config::DatabaseConfig; +use std::{convert::TryFrom, sync::Arc}; pub struct Config { pub http_port: u16, + pub database: Arc, } impl Config { - pub fn new() -> Self { Config { http_port: 3030 } } + pub fn new() -> Self { + Config { + http_port: 3030, + database: Arc::new(DatabaseConfig::default()), + } + } pub fn server_addr(&self) -> String { format!("0.0.0.0:{}", self.http_port) } } diff --git a/backend/src/config/const_define.rs b/backend/src/config/const_define.rs index b0eb7de9c4..60cadfdf0e 100644 --- a/backend/src/config/const_define.rs +++ b/backend/src/config/const_define.rs @@ -2,3 +2,4 @@ use std::time::Duration; pub const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8); pub const PING_TIMEOUT: Duration = Duration::from_secs(60); +pub const MAX_PAYLOAD_SIZE: usize = 262_144; // max payload size is 256k diff --git a/backend/src/config/database/config.toml b/backend/src/config/database/config.toml new file mode 100644 index 0000000000..2c3db17d38 --- /dev/null +++ b/backend/src/config/database/config.toml @@ -0,0 +1,5 @@ +host = "localhost" +port = 5433 +username = "postgres" +password = "password" +database_name = "flowy" \ No newline at end of file diff --git a/backend/src/config/database/database.rs b/backend/src/config/database/database.rs new file mode 100644 index 0000000000..1013f26f16 --- /dev/null +++ b/backend/src/config/database/database.rs @@ -0,0 +1,33 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct DatabaseConfig { + username: String, + password: String, + port: u16, + host: String, + database_name: String, +} + +impl DatabaseConfig { + pub fn connect_url(&self) -> String { + format!( + "postgres://{}:{}@{}:{}/{}", + self.username, self.password, self.host, self.port, self.database_name + ) + } + + pub fn set_env_db_url(&self) { + let url = self.connect_url(); + std::env::set_var("DATABASE_URL", url); + } +} + +impl std::default::Default for DatabaseConfig { + fn default() -> DatabaseConfig { + let toml_str: &str = include_str!("config.toml"); + let config: DatabaseConfig = toml::from_str(toml_str).unwrap(); + config.set_env_db_url(); + config + } +} diff --git a/backend/src/config/database/mod.rs b/backend/src/config/database/mod.rs new file mode 100644 index 0000000000..d505655ade --- /dev/null +++ b/backend/src/config/database/mod.rs @@ -0,0 +1,3 @@ +mod database; + +pub use database::*; diff --git a/backend/src/config/mod.rs b/backend/src/config/mod.rs index 8d0e0324bf..bb3ca0c2c3 100644 --- a/backend/src/config/mod.rs +++ b/backend/src/config/mod.rs @@ -1,5 +1,7 @@ mod config; mod const_define; +mod database; pub use config::*; pub use const_define::*; +pub use database::*; diff --git a/backend/src/context.rs b/backend/src/context.rs index b53df14de7..9db0fd9f3b 100644 --- a/backend/src/context.rs +++ b/backend/src/context.rs @@ -1,19 +1,28 @@ -use crate::{config::Config, ws_service::WSServer}; +use crate::{config::Config, user_service::Auth, ws_service::WSServer}; use actix::Addr; + +use sqlx::PgPool; use std::sync::Arc; pub struct AppContext { pub config: Arc, - pub server: Addr, + pub ws_server: Addr, + pub db_pool: Arc, + pub auth: Arc, } impl AppContext { - pub fn new(server: Addr) -> Self { + pub fn new( + config: Arc, + ws_server: Addr, + db_pool: Arc, + auth: Arc, + ) -> Self { AppContext { - config: Arc::new(Config::new()), - server, + config, + ws_server, + db_pool, + auth, } } - - pub fn ws_server(&self) -> Addr { self.server.clone() } } diff --git a/backend/src/entities/mod.rs b/backend/src/entities/mod.rs new file mode 100644 index 0000000000..9dbbe4cc26 --- /dev/null +++ b/backend/src/entities/mod.rs @@ -0,0 +1,7 @@ +mod response; +mod response_serde; +mod server_code; + +pub use response::*; +pub use response_serde::*; +pub use server_code::*; diff --git a/backend/src/entities/response.rs b/backend/src/entities/response.rs new file mode 100644 index 0000000000..6a020b101f --- /dev/null +++ b/backend/src/entities/response.rs @@ -0,0 +1,45 @@ +use crate::{entities::ServerCode, errors::ServerError}; +use actix_web::{body::Body, HttpResponse, ResponseError}; + +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct ServerResponse { + pub msg: String, + pub data: Option, + pub code: ServerCode, +} + +impl ServerResponse { + pub fn new(data: Option, msg: &str, code: ServerCode) -> Self { + ServerResponse { + msg: msg.to_owned(), + data, + code, + } + } + + pub fn from_data(data: T, msg: &str, code: ServerCode) -> Self { + Self::new(Some(data), msg, code) + } +} + +impl ServerResponse { + pub fn success() -> Self { Self::from_msg("", ServerCode::Success) } + + pub fn from_msg(msg: &str, code: ServerCode) -> Self { + Self::new(Some("".to_owned()), msg, code) + } +} + +impl std::convert::Into for ServerResponse { + fn into(self) -> HttpResponse { + match serde_json::to_string(&self) { + Ok(body) => HttpResponse::Ok().body(Body::from(body)), + Err(e) => { + let msg = format!("Serial error: {:?}", e); + ServerError::InternalError(msg).error_response() + }, + } + } +} diff --git a/backend/src/entities/response_serde.rs b/backend/src/entities/response_serde.rs new file mode 100644 index 0000000000..7149fa0980 --- /dev/null +++ b/backend/src/entities/response_serde.rs @@ -0,0 +1,128 @@ +use crate::entities::{ServerCode, ServerResponse}; +use serde::{ + de::{self, MapAccess, Visitor}, + Deserialize, + Deserializer, + Serialize, +}; +use std::{fmt, marker::PhantomData, str::FromStr}; + +pub trait ServerData<'a>: Serialize + Deserialize<'a> + FromStr {} +impl<'de, T: ServerData<'de>> Deserialize<'de> for ServerResponse { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ServerResponseVisitor(PhantomData T>); + impl<'de, T> Visitor<'de> for ServerResponseVisitor + where + T: ServerData<'de>, + { + type Value = ServerResponse; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Duration") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut msg = None; + let mut data: Option = None; + let mut code: Option = None; + while let Some(key) = map.next_key()? { + match key { + "msg" => { + if msg.is_some() { + return Err(de::Error::duplicate_field("msg")); + } + msg = Some(map.next_value()?); + }, + "code" => { + if code.is_some() { + return Err(de::Error::duplicate_field("code")); + } + code = Some(map.next_value()?); + }, + "data" => { + if data.is_some() { + return Err(de::Error::duplicate_field("data")); + } + data = match MapAccess::next_value::>(&mut map) { + Ok(wrapper) => wrapper.value, + Err(err) => return Err(err), + }; + }, + _ => panic!(), + } + } + let msg = msg.ok_or_else(|| de::Error::missing_field("msg"))?; + let code = code.ok_or_else(|| de::Error::missing_field("code"))?; + Ok(Self::Value::new(data, msg, code)) + } + } + const FIELDS: &'static [&'static str] = &["msg", "code", "data"]; + deserializer.deserialize_struct( + "ServerResponse", + FIELDS, + ServerResponseVisitor(PhantomData), + ) + } +} + +struct DeserializeWith<'de, T: ServerData<'de>> { + value: Option, + phantom: PhantomData<&'de ()>, +} + +impl<'de, T: ServerData<'de>> Deserialize<'de> for DeserializeWith<'de, T> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(DeserializeWith { + value: match string_or_data(deserializer) { + Ok(val) => val, + Err(e) => return Err(e), + }, + phantom: PhantomData, + }) + } +} + +fn string_or_data<'de, D, T>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, + T: ServerData<'de>, +{ + struct StringOrData(PhantomData T>); + impl<'de, T: ServerData<'de>> Visitor<'de> for StringOrData { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("string or struct impl deserialize") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match FromStr::from_str(value) { + Ok(val) => Ok(Some(val)), + Err(_e) => Ok(None), + } + } + + fn visit_map(self, map: M) -> Result + where + M: MapAccess<'de>, + { + match Deserialize::deserialize(de::value::MapAccessDeserializer::new(map)) { + Ok(val) => Ok(Some(val)), + Err(e) => Err(e), + } + } + } + deserializer.deserialize_any(StringOrData(PhantomData)) +} diff --git a/backend/src/entities/server_code.rs b/backend/src/entities/server_code.rs new file mode 100644 index 0000000000..f0790f0a71 --- /dev/null +++ b/backend/src/entities/server_code.rs @@ -0,0 +1,12 @@ +use serde_repr::*; + +#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)] +#[repr(u16)] +pub enum ServerCode { + Success = 0, + InvalidToken = 1, + InternalError = 2, + Unauthorized = 3, + PayloadOverflow = 4, + PayloadSerdeFail = 5, +} diff --git a/backend/src/errors.rs b/backend/src/errors.rs index acbf82d2ed..5da2cccc23 100644 --- a/backend/src/errors.rs +++ b/backend/src/errors.rs @@ -1,3 +1,48 @@ -pub struct ServerError {} +use crate::entities::{ServerCode, ServerResponse}; +use actix_web::{error::ResponseError, HttpResponse}; +use protobuf::ProtobufError; +use std::fmt::Formatter; -// pub enum ErrorCode {} +#[derive(Debug)] +pub enum ServerError { + InternalError(String), + BadRequest(ServerResponse), + Unauthorized, +} + +impl std::fmt::Display for ServerError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ServerError::InternalError(_) => f.write_str("Internal Server Error"), + ServerError::BadRequest(request) => { + let msg = format!("Bad Request: {:?}", request); + f.write_str(&msg) + }, + ServerError::Unauthorized => f.write_str("Unauthorized"), + } + } +} + +impl ResponseError for ServerError { + fn error_response(&self) -> HttpResponse { + match self { + ServerError::InternalError(msg) => { + let msg = format!("Internal Server Error. {}", msg); + let resp = ServerResponse::from_msg(&msg, ServerCode::InternalError); + HttpResponse::InternalServerError().json(resp) + }, + ServerError::BadRequest(ref resp) => HttpResponse::BadRequest().json(resp), + ServerError::Unauthorized => { + let resp = ServerResponse::from_msg("Unauthorized", ServerCode::Unauthorized); + HttpResponse::Unauthorized().json(resp) + }, + } + } +} + +impl std::convert::From for ServerError { + fn from(err: ProtobufError) -> Self { + let msg = format!("{:?}", err); + ServerError::InternalError(msg) + } +} diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 1ff4a08f8d..4333f4ea66 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,6 +1,8 @@ mod config; mod context; +mod entities; mod errors; mod routers; pub mod startup; +pub mod user_service; pub mod ws_service; diff --git a/backend/src/routers/helper.rs b/backend/src/routers/helper.rs new file mode 100644 index 0000000000..b7d7fb5de9 --- /dev/null +++ b/backend/src/routers/helper.rs @@ -0,0 +1,34 @@ +use crate::{ + config::MAX_PAYLOAD_SIZE, + entities::{ServerCode, ServerResponse}, + errors::ServerError, +}; +use actix_web::web; +use futures::StreamExt; +use protobuf::{Message, ProtobufResult}; + +pub async fn parse_from_payload(payload: web::Payload) -> Result { + let bytes = poll_payload(payload).await?; + parse_from_bytes(&bytes) +} + +pub fn parse_from_bytes(bytes: &[u8]) -> Result { + let result: ProtobufResult = Message::parse_from_bytes(&bytes); + match result { + Ok(data) => Ok(data), + Err(e) => Err(e.into()), + } +} + +pub async fn poll_payload(mut payload: web::Payload) -> Result { + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk.map_err(|e| ServerError::InternalError(format!("{:?}", e)))?; + if (body.len() + chunk.len()) > MAX_PAYLOAD_SIZE { + let resp = ServerResponse::from_msg("Payload overflow", ServerCode::PayloadOverflow); + return Err(ServerError::BadRequest(resp)); + } + body.extend_from_slice(&chunk); + } + Ok(body) +} diff --git a/backend/src/routers/mod.rs b/backend/src/routers/mod.rs index 15edb0affa..41bdf89771 100644 --- a/backend/src/routers/mod.rs +++ b/backend/src/routers/mod.rs @@ -1,3 +1,6 @@ +mod helper; +mod user; pub(crate) mod ws; +pub use user::*; pub use ws::*; diff --git a/backend/src/routers/user.rs b/backend/src/routers/user.rs new file mode 100644 index 0000000000..bc482c1a9e --- /dev/null +++ b/backend/src/routers/user.rs @@ -0,0 +1,23 @@ +use crate::user_service::Auth; +use actix_web::{ + web::{Data, Payload}, + Error, + HttpRequest, + HttpResponse, +}; +use flowy_user::protobuf::SignUpRequest; + +use crate::{entities::ServerResponse, routers::helper::parse_from_payload}; + +use std::sync::Arc; + +pub async fn user_register( + request: HttpRequest, + payload: Payload, + auth: Data>, +) -> Result { + let request: SignUpRequest = parse_from_payload(payload).await?; + // ProtobufError + let resp = ServerResponse::success(); + Ok(resp.into()) +} diff --git a/backend/src/routers/ws.rs b/backend/src/routers/ws.rs index 3ce3584c9f..199d8d3d25 100644 --- a/backend/src/routers/ws.rs +++ b/backend/src/routers/ws.rs @@ -1,6 +1,6 @@ use crate::ws_service::{entities::SessionId, WSClient, WSServer}; use actix::Addr; -use actix_http::{body::Body, Response}; + use actix_web::{ get, web::{Data, Path, Payload}, diff --git a/backend/src/startup.rs b/backend/src/startup.rs index c53ff6b0b8..4f0d3d6b78 100644 --- a/backend/src/startup.rs +++ b/backend/src/startup.rs @@ -1,6 +1,13 @@ -use crate::{context::AppContext, routers::*, ws_service::WSServer}; +use crate::{ + config::Config, + context::AppContext, + routers::*, + user_service::Auth, + ws_service::WSServer, +}; use actix::Actor; use actix_web::{dev::Server, middleware, web, App, HttpServer, Scope}; +use sqlx::PgPool; use std::{net::TcpListener, sync::Arc}; pub fn run(app_ctx: Arc, listener: TcpListener) -> Result { @@ -9,7 +16,9 @@ pub fn run(app_ctx: Arc, listener: TcpListener) -> Result Scope { web::scope("/ws").service(ws::start_connection) } pub async fn init_app_context() -> Arc { let _ = flowy_log::Builder::new("flowy").env_filter("Debug").build(); + let config = Arc::new(Config::new()); + + // TODO: what happened when PgPool connect fail? + let db_pool = Arc::new( + PgPool::connect(&config.database.connect_url()) + .await + .expect("Failed to connect to Postgres."), + ); let ws_server = WSServer::new().start(); - let ctx = AppContext::new(ws_server); + let auth = Arc::new(Auth::new(db_pool.clone())); + + let ctx = AppContext::new(config, ws_server, db_pool, auth); Arc::new(ctx) } diff --git a/backend/src/user_service/auth.rs b/backend/src/user_service/auth.rs new file mode 100644 index 0000000000..6ee457a57e --- /dev/null +++ b/backend/src/user_service/auth.rs @@ -0,0 +1,14 @@ +use crate::errors::ServerError; +use flowy_user::protobuf::SignUpRequest; +use sqlx::PgPool; +use std::sync::Arc; + +pub struct Auth { + db_pool: Arc, +} + +impl Auth { + pub fn new(db_pool: Arc) -> Self { Self { db_pool } } + + pub fn handle_sign_up(&self, request: SignUpRequest) -> Result<(), ServerError> { Ok(()) } +} diff --git a/backend/src/user_service/mod.rs b/backend/src/user_service/mod.rs new file mode 100644 index 0000000000..ca51de07ab --- /dev/null +++ b/backend/src/user_service/mod.rs @@ -0,0 +1,3 @@ +mod auth; + +pub use auth::*; diff --git a/backend/src/ws_service/ws_client.rs b/backend/src/ws_service/ws_client.rs index 5a473e43e4..8c0ac23577 100644 --- a/backend/src/ws_service/ws_client.rs +++ b/backend/src/ws_service/ws_client.rs @@ -16,7 +16,6 @@ use actix::{ AsyncContext, ContextFutureSpawner, Handler, - Recipient, Running, StreamHandler, WrapFuture, diff --git a/backend/src/ws_service/ws_server.rs b/backend/src/ws_service/ws_server.rs index 120d483bc8..4a254da986 100644 --- a/backend/src/ws_service/ws_server.rs +++ b/backend/src/ws_service/ws_server.rs @@ -3,7 +3,6 @@ use crate::{ ws_service::{ entities::{Connect, Disconnect, Session, SessionId}, ClientMessage, - WSClient, }, }; use actix::{Actor, Context, Handler}; diff --git a/rust-lib/flowy-log/src/lib.rs b/rust-lib/flowy-log/src/lib.rs index 4ccf33e7a8..402a3c7e5d 100644 --- a/rust-lib/flowy-log/src/lib.rs +++ b/rust-lib/flowy-log/src/lib.rs @@ -24,7 +24,7 @@ impl Builder { self } - pub fn local(mut self, directory: impl AsRef) -> Self { + pub fn local(self, directory: impl AsRef) -> Self { let directory = directory.as_ref().to_str().unwrap().to_owned(); let local_file_name = format!("{}.log", &self.name); let file_appender = tracing_appender::rolling::daily(directory, local_file_name); diff --git a/rust-lib/flowy-user/src/lib.rs b/rust-lib/flowy-user/src/lib.rs index 55099e214f..f9078b797e 100644 --- a/rust-lib/flowy-user/src/lib.rs +++ b/rust-lib/flowy-user/src/lib.rs @@ -3,7 +3,7 @@ pub mod errors; pub mod event; mod handlers; pub mod module; -mod protobuf; +pub mod protobuf; mod services; pub mod sql_tables;