diff --git a/src/models/mod.rs b/src/models/mod.rs index d04ea25..f200898 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,15 +1,14 @@ -//pub mod dns; -pub mod auth; pub mod class; pub mod errors; pub mod name; pub mod rdata; pub mod record; +pub mod session; pub mod user; pub mod zone; // Reexport types for convenience -pub use auth::{AuthTokenRequest, Session}; +pub use session::{AuthTokenRequest, Session}; pub use class::DNSClass; pub use errors::{UserError, ErrorResponse, make_500}; pub use name::{AbsoluteName, SerdeName}; diff --git a/src/models/auth.rs b/src/models/session.rs similarity index 55% rename from src/models/auth.rs rename to src/models/session.rs index b614624..6e850da 100644 --- a/src/models/auth.rs +++ b/src/models/session.rs @@ -7,10 +7,17 @@ use diesel::prelude::*; use rand::Rng; use rand::rngs::OsRng; use rand::distributions::Alphanumeric; +use rocket::request::{FromRequest, Request, Outcome}; +use rocket::outcome::try_outcome; -use crate::models::user::UserInfo; use crate::schema::*; -use crate::models::errors::UserError; +use crate::DbConn; +use crate::models::user::UserInfo; +use crate::models::errors::{UserError, ErrorResponse, make_500}; + +const BEARER: &str = "Bearer "; +const AUTH_HEADER: &str = "Authorization"; +pub const COOKIE_NAME: &str = "session_id"; #[derive(Debug, Deserialize)] @@ -75,4 +82,49 @@ impl Session { Ok(user_session) } + + fn get_token_from_header<'r>(request: &'r Request<'_>) -> Outcome { + let auth_header = match request.headers().get_one(AUTH_HEADER) { + None => return Outcome::Forward(()), + Some(auth_header) => auth_header, + }; + + let token = if auth_header.starts_with(BEARER) { + auth_header.trim_start_matches(BEARER).to_string() + } else { + return ErrorResponse::from(UserError::MalformedHeader).into(); + }; + + Outcome::Success(token) + } + + fn get_token_from_cookie<'r>(request: &'r Request<'_>) -> Outcome { + match request.cookies().get(COOKIE_NAME) { + None => Outcome::Forward(()), + Some(session_cookie) => Outcome::Success(session_cookie.value().to_string()), + } + } +} + + + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Session { + type Error = ErrorResponse; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let token = try_outcome!( + Session::get_token_from_header(request) + .forward_then(|_| Session::get_token_from_cookie(request)) + ); + + let conn = try_outcome!(request.guard::().await.map_failure(make_500)); + + conn.run(move |c| { + match Session::from_session_id(c, &token) { + Err(e) => ErrorResponse::from(e).into(), + Ok(s) => Outcome::Success(s), + } + }).await + } } diff --git a/src/models/user.rs b/src/models/user.rs index 2802f25..df4beba 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -16,12 +16,7 @@ use crate::DbConn; use crate::models::errors::{UserError, ErrorResponse, make_500}; use crate::models::zone::Zone; -use crate::models::auth::Session; - - -const BEARER: &str = "Bearer "; -const AUTH_HEADER: &str = "Authorization"; -pub const COOKIE_NAME: &str = "session_id"; +use crate::models::session::Session; #[derive(Debug, DbEnum, Deserialize, Clone)] @@ -115,49 +110,14 @@ impl UserInfo { } } -fn get_token_from_header<'r>(request: &'r Request<'_>) -> Outcome { - let auth_header = match request.headers().get_one(AUTH_HEADER) { - None => return Outcome::Forward(()), - Some(auth_header) => auth_header, - }; - - let token = if auth_header.starts_with(BEARER) { - auth_header.trim_start_matches(BEARER).to_string() - } else { - return ErrorResponse::from(UserError::MalformedHeader).into(); - }; - - Outcome::Success(token) -} - -fn get_token_from_cookie<'r>(request: &'r Request<'_>) -> Outcome { - match request.cookies().get(COOKIE_NAME) { - None => Outcome::Forward(()), - Some(session_cookie) => Outcome::Success(session_cookie.value().to_string()), - } -} - #[rocket::async_trait] impl<'r> FromRequest<'r> for UserInfo { type Error = ErrorResponse; async fn from_request(request: &'r Request<'_>) -> Outcome { - let token = try_outcome!( - get_token_from_header(request) - .forward_then(|_| get_token_from_cookie(request)) - ); - + let session = try_outcome!(request.guard::().await.map_failure(make_500)); let conn = try_outcome!(request.guard::().await.map_failure(make_500)); - let session_res = conn.run(move |c| { - Session::from_session_id(c, &token) - }).await; - - let session = match session_res { - Err(e) => return ErrorResponse::from(e).into(), - Ok(s) => s, - }; - conn.run(move |c| { match LocalUser::get_user_by_uuid(c, &session.user_id) { Err(e) => ErrorResponse::from(e).into(), diff --git a/src/routes/users.rs b/src/routes/users.rs index 4b97be2..cfe6738 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -33,7 +33,7 @@ pub async fn create_auth_token( // About unwrap: I guess too bad if session time is over year 9999 (current max time if time-rs) let expires = time::OffsetDateTime::from_unix_timestamp(session.expires_at.timestamp()).unwrap(); - let session_cookie = Cookie::build(models::user::COOKIE_NAME, session.session_id.clone()) + let session_cookie = Cookie::build(models::session::COOKIE_NAME, session.session_id.clone()) .same_site(SameSite::Strict) .secure(true) .http_only(true)