From 7df4792ec50303351da7b942021292259e9547b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABl=20Berthaud-M=C3=BCller?= Date: Fri, 2 Apr 2021 17:12:29 -0400 Subject: [PATCH] improve error type --- src/models/errors.rs | 32 +++++++++++++++++++++----------- src/models/users.rs | 16 ++++++++-------- src/routes/users.rs | 4 ++-- src/routes/zones.rs | 5 +++-- 4 files changed, 34 insertions(+), 23 deletions(-) diff --git a/src/models/errors.rs b/src/models/errors.rs index ddc47c3..c7973a3 100644 --- a/src/models/errors.rs +++ b/src/models/errors.rs @@ -1,18 +1,20 @@ use serde::Serialize; use rocket::http::Status; -use rocket::request::Request; +use rocket::request::{Request, Outcome}; use rocket::response::{self, Response, Responder}; use rocket_contrib::json::Json; use crate::models::users::UserError; +use serde_json::Value; + #[derive(Serialize, Debug)] -pub struct ErrorResponse { +pub struct ErrorResponse { #[serde(with = "StatusDef")] #[serde(flatten)] pub status: Status, pub message: String, #[serde(skip_serializing_if = "Option::is_none")] - pub details: Option + pub details: Option } #[derive(Serialize)] @@ -23,8 +25,8 @@ struct StatusDef { reason: &'static str, } -impl ErrorResponse { - pub fn new(status: Status, message: String) -> ErrorResponse { +impl ErrorResponse { + pub fn new(status: Status, message: String) -> ErrorResponse { ErrorResponse { status, message, @@ -32,26 +34,26 @@ impl ErrorResponse { } } - pub fn with_details(self, details: T) -> ErrorResponse { + pub fn with_details (self, details: T) -> ErrorResponse { ErrorResponse { - details: Some(details), + details: serde_json::to_value(details).ok(), ..self } } - pub fn err(self) -> Result> { + pub fn err(self) -> Result { Err(self) } } -impl<'r, T: Serialize> Responder<'r, 'static> for ErrorResponse { +impl<'r> Responder<'r, 'static> for ErrorResponse { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { let status = self.status; Response::build_from(Json(self).respond_to(req)?).status(status).ok() } } -impl From for ErrorResponse<()> { +impl From for ErrorResponse { fn from(e: UserError) -> Self { match e { UserError::NotFound => ErrorResponse::new(Status::Unauthorized, "Provided credentials or token do not match any existing user".into()), @@ -66,7 +68,15 @@ impl From for ErrorResponse<()> { } } -pub fn make_500(e: E) -> ErrorResponse<()> { + +impl Into> for ErrorResponse { + fn into(self) -> Outcome { + Outcome::Failure((self.status.clone(), self)) + } +} + + +pub fn make_500(e: E) -> ErrorResponse { println!("{:?}", e); ErrorResponse::new(Status::InternalServerError, "An unexpected error occured.".into()) } diff --git a/src/models/users.rs b/src/models/users.rs index c9b5fe0..59b3645 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -4,7 +4,6 @@ use diesel::result::Error as DieselError; use diesel_derive_enum::DbEnum; use rocket::{State, request::{FromRequest, Request, Outcome}}; use serde::{Serialize, Deserialize}; -use rocket::http::Status; use chrono::serde::ts_seconds; use chrono::prelude::{DateTime, Utc}; use chrono::Duration; @@ -21,6 +20,7 @@ use jsonwebtoken::{ use crate::schema::*; use crate::DbConn; use crate::config::Config; +use crate::models::errors::ErrorResponse; const BEARER: &'static str = "Bearer "; @@ -95,7 +95,7 @@ pub struct UserInfo { #[rocket::async_trait] impl<'r> FromRequest<'r> for UserInfo { - type Error = UserError; + type Error = ErrorResponse; async fn from_request(request: &'r Request<'_>) -> Outcome { let auth_header = match request.headers().get_one(AUTH_HEADER) { @@ -106,7 +106,7 @@ impl<'r> FromRequest<'r> for UserInfo { let token = if auth_header.starts_with(BEARER) { auth_header.trim_start_matches(BEARER) } else { - return Outcome::Failure((Status::BadRequest, UserError::MalformedHeader)) + return ErrorResponse::from(UserError::MalformedHeader).into() }; // TODO: Better error handling @@ -116,12 +116,12 @@ impl<'r> FromRequest<'r> for UserInfo { let token_data = AuthClaims::decode( token, &config.web_app.secret ).map_err(|e| match e.into_kind() { - JwtErrorKind::ExpiredSignature => (Status::Unauthorized, UserError::ExpiredToken), - _ => (Status::BadRequest, UserError::BadToken), + JwtErrorKind::ExpiredSignature => UserError::ExpiredToken, + _ => UserError::BadToken, }); let token_data = match token_data { - Err(e) => return Outcome::Failure(e), + Err(e) => return ErrorResponse::from(e).into(), Ok(data) => data }; @@ -129,8 +129,8 @@ impl<'r> FromRequest<'r> for UserInfo { conn.run(|c| { match LocalUser::get_user_by_uuid(c, user_id) { - Err(UserError::NotFound) => Outcome::Failure((Status::NotFound, UserError::NotFound)), - Err(e) => Outcome::Failure((Status::InternalServerError, e)), + Err(UserError::NotFound) => ErrorResponse::from(UserError::NotFound).into(), + Err(e) => ErrorResponse::from(e).into(), Ok(d) => Outcome::Success(d), } }).await diff --git a/src/routes/users.rs b/src/routes/users.rs index 17fefa2..0937901 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -13,7 +13,7 @@ pub async fn create_auth_token( conn: DbConn, config: State<'_, Config>, auth_request: Json -) -> Result, ErrorResponse<()>> { +) -> Result, ErrorResponse> { let user_info = conn.run(move |c| { LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password) @@ -27,7 +27,7 @@ pub async fn create_auth_token( } #[post("/users", data = "")] -pub async fn create_user<'r>(conn: DbConn, user_request: Json) -> Result, ErrorResponse<()>>{ +pub async fn create_user<'r>(conn: DbConn, user_request: Json) -> Result, ErrorResponse>{ // TODO: Check current user if any to check if user has permission to create users (with or without role) let _user_info = conn.run(|c| { LocalUser::create_user(&c, user_request.into_inner()) diff --git a/src/routes/zones.rs b/src/routes/zones.rs index 1ccbe73..bc72497 100644 --- a/src/routes/zones.rs +++ b/src/routes/zones.rs @@ -16,9 +16,10 @@ use crate::DnsClient; #[get("/zones//records")] pub fn get_zone_records( client: State, - _user_info: UserInfo, + user_info: Result, zone: String -) -> Result>, ErrorResponse<()>> { +) -> Result>, ErrorResponse> { + user_info?; // TODO: Implement FromParam for Name let name = Name::from_utf8(&zone).unwrap();