improve error type

main
Hannaeko 2021-04-02 17:12:29 -04:00
parent b758c87521
commit 7df4792ec5
4 changed files with 34 additions and 23 deletions

View File

@ -1,18 +1,20 @@
use serde::Serialize; use serde::Serialize;
use rocket::http::Status; use rocket::http::Status;
use rocket::request::Request; use rocket::request::{Request, Outcome};
use rocket::response::{self, Response, Responder}; use rocket::response::{self, Response, Responder};
use rocket_contrib::json::Json; use rocket_contrib::json::Json;
use crate::models::users::UserError; use crate::models::users::UserError;
use serde_json::Value;
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
pub struct ErrorResponse<T> { pub struct ErrorResponse {
#[serde(with = "StatusDef")] #[serde(with = "StatusDef")]
#[serde(flatten)] #[serde(flatten)]
pub status: Status, pub status: Status,
pub message: String, pub message: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<T> pub details: Option<Value>
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -23,8 +25,8 @@ struct StatusDef {
reason: &'static str, reason: &'static str,
} }
impl<T> ErrorResponse<T> { impl ErrorResponse {
pub fn new(status: Status, message: String) -> ErrorResponse<T> { pub fn new(status: Status, message: String) -> ErrorResponse {
ErrorResponse { ErrorResponse {
status, status,
message, message,
@ -32,26 +34,26 @@ impl<T> ErrorResponse<T> {
} }
} }
pub fn with_details(self, details: T) -> ErrorResponse<T> { pub fn with_details<T: Serialize> (self, details: T) -> ErrorResponse {
ErrorResponse { ErrorResponse {
details: Some(details), details: serde_json::to_value(details).ok(),
..self ..self
} }
} }
pub fn err<R>(self) -> Result<R, ErrorResponse<T>> { pub fn err<R>(self) -> Result<R, ErrorResponse> {
Err(self) Err(self)
} }
} }
impl<'r, T: Serialize> Responder<'r, 'static> for ErrorResponse<T> { impl<'r> Responder<'r, 'static> for ErrorResponse {
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
let status = self.status; let status = self.status;
Response::build_from(Json(self).respond_to(req)?).status(status).ok() Response::build_from(Json(self).respond_to(req)?).status(status).ok()
} }
} }
impl From<UserError> for ErrorResponse<()> { impl From<UserError> for ErrorResponse {
fn from(e: UserError) -> Self { fn from(e: UserError) -> Self {
match e { match e {
UserError::NotFound => ErrorResponse::new(Status::Unauthorized, "Provided credentials or token do not match any existing user".into()), UserError::NotFound => ErrorResponse::new(Status::Unauthorized, "Provided credentials or token do not match any existing user".into()),
@ -66,7 +68,15 @@ impl From<UserError> for ErrorResponse<()> {
} }
} }
pub fn make_500<E: std::fmt::Debug>(e: E) -> ErrorResponse<()> {
impl<S> Into<Outcome<S, ErrorResponse>> for ErrorResponse {
fn into(self) -> Outcome<S, ErrorResponse> {
Outcome::Failure((self.status.clone(), self))
}
}
pub fn make_500<E: std::fmt::Debug>(e: E) -> ErrorResponse {
println!("{:?}", e); println!("{:?}", e);
ErrorResponse::new(Status::InternalServerError, "An unexpected error occured.".into()) ErrorResponse::new(Status::InternalServerError, "An unexpected error occured.".into())
} }

View File

@ -4,7 +4,6 @@ use diesel::result::Error as DieselError;
use diesel_derive_enum::DbEnum; use diesel_derive_enum::DbEnum;
use rocket::{State, request::{FromRequest, Request, Outcome}}; use rocket::{State, request::{FromRequest, Request, Outcome}};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use rocket::http::Status;
use chrono::serde::ts_seconds; use chrono::serde::ts_seconds;
use chrono::prelude::{DateTime, Utc}; use chrono::prelude::{DateTime, Utc};
use chrono::Duration; use chrono::Duration;
@ -21,6 +20,7 @@ use jsonwebtoken::{
use crate::schema::*; use crate::schema::*;
use crate::DbConn; use crate::DbConn;
use crate::config::Config; use crate::config::Config;
use crate::models::errors::ErrorResponse;
const BEARER: &'static str = "Bearer "; const BEARER: &'static str = "Bearer ";
@ -95,7 +95,7 @@ pub struct UserInfo {
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for UserInfo { impl<'r> FromRequest<'r> for UserInfo {
type Error = UserError; type Error = ErrorResponse;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let auth_header = match request.headers().get_one(AUTH_HEADER) { let auth_header = match request.headers().get_one(AUTH_HEADER) {
@ -106,7 +106,7 @@ impl<'r> FromRequest<'r> for UserInfo {
let token = if auth_header.starts_with(BEARER) { let token = if auth_header.starts_with(BEARER) {
auth_header.trim_start_matches(BEARER) auth_header.trim_start_matches(BEARER)
} else { } else {
return Outcome::Failure((Status::BadRequest, UserError::MalformedHeader)) return ErrorResponse::from(UserError::MalformedHeader).into()
}; };
// TODO: Better error handling // TODO: Better error handling
@ -116,12 +116,12 @@ impl<'r> FromRequest<'r> for UserInfo {
let token_data = AuthClaims::decode( let token_data = AuthClaims::decode(
token, &config.web_app.secret token, &config.web_app.secret
).map_err(|e| match e.into_kind() { ).map_err(|e| match e.into_kind() {
JwtErrorKind::ExpiredSignature => (Status::Unauthorized, UserError::ExpiredToken), JwtErrorKind::ExpiredSignature => UserError::ExpiredToken,
_ => (Status::BadRequest, UserError::BadToken), _ => UserError::BadToken,
}); });
let token_data = match token_data { let token_data = match token_data {
Err(e) => return Outcome::Failure(e), Err(e) => return ErrorResponse::from(e).into(),
Ok(data) => data Ok(data) => data
}; };
@ -129,8 +129,8 @@ impl<'r> FromRequest<'r> for UserInfo {
conn.run(|c| { conn.run(|c| {
match LocalUser::get_user_by_uuid(c, user_id) { match LocalUser::get_user_by_uuid(c, user_id) {
Err(UserError::NotFound) => Outcome::Failure((Status::NotFound, UserError::NotFound)), Err(UserError::NotFound) => ErrorResponse::from(UserError::NotFound).into(),
Err(e) => Outcome::Failure((Status::InternalServerError, e)), Err(e) => ErrorResponse::from(e).into(),
Ok(d) => Outcome::Success(d), Ok(d) => Outcome::Success(d),
} }
}).await }).await

View File

@ -13,7 +13,7 @@ pub async fn create_auth_token(
conn: DbConn, conn: DbConn,
config: State<'_, Config>, config: State<'_, Config>,
auth_request: Json<AuthTokenRequest> auth_request: Json<AuthTokenRequest>
) -> Result<Json<AuthTokenResponse>, ErrorResponse<()>> { ) -> Result<Json<AuthTokenResponse>, ErrorResponse> {
let user_info = conn.run(move |c| { let user_info = conn.run(move |c| {
LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password) LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password)
@ -27,7 +27,7 @@ pub async fn create_auth_token(
} }
#[post("/users", data = "<user_request>")] #[post("/users", data = "<user_request>")]
pub async fn create_user<'r>(conn: DbConn, user_request: Json<CreateUserRequest>) -> Result<Response<'r>, ErrorResponse<()>>{ pub async fn create_user<'r>(conn: DbConn, user_request: Json<CreateUserRequest>) -> Result<Response<'r>, ErrorResponse>{
// TODO: Check current user if any to check if user has permission to create users (with or without role) // TODO: Check current user if any to check if user has permission to create users (with or without role)
let _user_info = conn.run(|c| { let _user_info = conn.run(|c| {
LocalUser::create_user(&c, user_request.into_inner()) LocalUser::create_user(&c, user_request.into_inner())

View File

@ -16,9 +16,10 @@ use crate::DnsClient;
#[get("/zones/<zone>/records")] #[get("/zones/<zone>/records")]
pub fn get_zone_records( pub fn get_zone_records(
client: State<DnsClient>, client: State<DnsClient>,
_user_info: UserInfo, user_info: Result<UserInfo, ErrorResponse>,
zone: String zone: String
) -> Result<Json<Vec<dns::Record>>, ErrorResponse<()>> { ) -> Result<Json<Vec<dns::Record>>, ErrorResponse> {
user_info?;
// TODO: Implement FromParam for Name // TODO: Implement FromParam for Name
let name = Name::from_utf8(&zone).unwrap(); let name = Name::from_utf8(&zone).unwrap();