This commit is contained in:
Hannaeko 2022-03-04 13:08:03 +01:00
parent b7db84e9a8
commit 4a4362715c
16 changed files with 549 additions and 466 deletions

40
src/dns/class.rs Normal file
View file

@ -0,0 +1,40 @@
use serde::{Deserialize, Serialize};
use super::trust_dns_types;
#[derive(Deserialize, Serialize, Clone)]
pub enum DNSClass {
IN,
CH,
HS,
NONE,
ANY,
OPT(u16),
}
impl From<trust_dns_types::DNSClass> for DNSClass {
fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass {
match dns_class {
trust_dns_types::DNSClass::IN => DNSClass::IN,
trust_dns_types::DNSClass::CH => DNSClass::CH,
trust_dns_types::DNSClass::HS => DNSClass::HS,
trust_dns_types::DNSClass::NONE => DNSClass::NONE,
trust_dns_types::DNSClass::ANY => DNSClass::ANY,
trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v),
}
}
}
impl From<DNSClass> for trust_dns_types::DNSClass {
fn from(dns_class: DNSClass) -> trust_dns_types::DNSClass {
match dns_class {
DNSClass::IN => trust_dns_types::DNSClass::IN,
DNSClass::CH => trust_dns_types::DNSClass::CH,
DNSClass::HS => trust_dns_types::DNSClass::HS,
DNSClass::NONE => trust_dns_types::DNSClass::NONE,
DNSClass::ANY => trust_dns_types::DNSClass::ANY,
DNSClass::OPT(v) => trust_dns_types::DNSClass::OPT(v),
}
}
}

71
src/dns/client.rs Normal file
View file

@ -0,0 +1,71 @@
use std::{future::Future, pin::Pin, task::{Context, Poll}};
use std::ops::{Deref, DerefMut};
use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}};
use tokio::{net::TcpStream as TokioTcpStream, task};
use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream};
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use crate::config::Config;
use super::message::DnsMessage;
pub struct DnsClient(AsyncClient);
impl Deref for DnsClient {
type Target = AsyncClient;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsClient {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl DnsMessage for AsyncClient {}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DnsClient {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<State<Config>>().await);
let (stream, handle) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(config.dns.server);
let client = AsyncClient::with_timeout(
stream,
handle,
std::time::Duration::from_secs(5),
None);
let (client, bg) = match client.await {
Err(e) => {
println!("Failed to connect to DNS server {:#?}", e);
return Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => c
};
task::spawn(bg);
Outcome::Success(DnsClient(client))
}
}
// Reimplement this type here as ClientReponse in trust-dns crate have private fields
pub struct ClientResponse<R>(pub(crate) R)
where
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
impl<R> Future for ClientResponse<R>
where
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
type Output = Result<DnsResponse, ClientError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// This is from the future_utils crate, we simply reuse the reexport from Rocket
rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from)
}
}

67
src/dns/message.rs Normal file
View file

@ -0,0 +1,67 @@
use trust_dns_proto::DnsHandle;
use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query};
use trust_dns_proto::error::ProtoError;
use super::trust_dns_types::{Name, Record};
use super::client::{ClientResponse};
pub enum MessageError {
RecordNotInZone {
zone: Name,
class: DNSClass,
mismatched_class: Vec<Record>,
mismatched_zone: Vec<Record>,
}
}
pub trait DnsMessage: DnsHandle<Error = ProtoError> + Send {
fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<ClientResponse<Self::Response>, MessageError>
{
let mut mismatched_class = Vec::new();
let mut mismatched_zone = Vec::new();
for record in new_records.iter() {
if !zone.zone_of(record.name()) {
mismatched_zone.push(record.clone());
}
if record.dns_class() != class {
mismatched_class.push(record.clone());
}
}
if mismatched_class.len() > 0 || mismatched_zone.len() > 0 {
return Err(MessageError::RecordNotInZone {
zone,
class,
mismatched_zone,
mismatched_class
})
}
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
message.add_updates(new_records);
{
let edns = message.edns_mut();
edns.set_max_payload(1232);
edns.set_version(0);
}
return Ok(ClientResponse(self.send(message)));
}
}

16
src/dns/mod.rs Normal file
View file

@ -0,0 +1,16 @@
pub mod class;
pub mod name;
pub mod rdata;
pub mod record;
pub mod client;
pub mod message;
pub mod trust_dns_types {
pub use trust_dns_client::rr::rdata::{
DNSSECRData, caa, sshfp, mx, null, soa, srv, txt
};
pub use trust_dns_client::rr::{
RData, DNSClass, Record
};
pub use trust_dns_proto::rr::Name;
}

72
src/dns/name.rs Normal file
View file

@ -0,0 +1,72 @@
use std::ops::Deref;
use rocket::request::FromParam;
use serde::{Deserialize, Serialize, Deserializer, Serializer};
use trust_dns_proto::error::ProtoError;
use super::trust_dns_types::Name;
#[derive(Debug, Clone)]
pub struct SerdeName(pub(crate)Name);
impl Deref for SerdeName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de> Deserialize<'de> for SerdeName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|string|
Name::from_utf8(&string)
.map_err(|e| Error::custom(e.to_string()))
).map( SerdeName)
}
}
impl Serialize for SerdeName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer
{
self.0.to_utf8().serialize(serializer)
}
}
impl SerdeName {
pub fn into_inner(self) -> Name {
self.0
}
}
#[derive(Debug, Deserialize)]
pub struct AbsoluteName(SerdeName);
impl<'r> FromParam<'r> for AbsoluteName {
type Error = ProtoError;
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
let mut name = Name::from_utf8(&param)?;
if !name.is_fqdn() {
name.set_fqdn(true);
}
Ok(AbsoluteName(SerdeName(name)))
}
}
impl Deref for AbsoluteName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0.0
}
}

View file

@ -1,21 +1,14 @@
use std::{convert::{TryFrom, TryInto}, future::Future, net::{Ipv6Addr, Ipv4Addr}, pin::Pin, task::{Context, Poll}};
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::convert::TryFrom;
use std::net::{Ipv6Addr, Ipv4Addr};
use serde::{Deserialize, Serialize};
use rocket::{Request, State, http::Status, request::{FromParam, FromRequest, Outcome}}; use trust_dns_client::serialize::binary::BinEncoder;
use trust_dns_proto::error::ProtoError;
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::trust_dns_types;
use super::name::SerdeName;
use tokio::{net::TcpStream as TokioTcpStream, task};
use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, serialize::binary::BinEncoder, tcp::TcpClientStream};
use trust_dns_proto::error::{ProtoError};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use super::trust_dns_types::{self, Name};
use crate::config::Config;
#[derive(Deserialize, Serialize, Clone)] #[derive(Deserialize, Serialize, Clone)]
@ -103,6 +96,7 @@ pub enum RData {
// ZERO, // ZERO,
// TODO: DS // TODO: DS
// TODO: TLSA
} }
impl From<trust_dns_types::RData> for RData { impl From<trust_dns_types::RData> for RData {
@ -286,195 +280,3 @@ impl<'a> fmt::Display for CAAValue<'a> {
Ok(()) Ok(())
} }
} }
#[derive(Deserialize, Serialize, Clone)]
pub enum DNSClass {
IN,
CH,
HS,
NONE,
ANY,
OPT(u16),
}
impl From<trust_dns_types::DNSClass> for DNSClass {
fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass {
match dns_class {
trust_dns_types::DNSClass::IN => DNSClass::IN,
trust_dns_types::DNSClass::CH => DNSClass::CH,
trust_dns_types::DNSClass::HS => DNSClass::HS,
trust_dns_types::DNSClass::NONE => DNSClass::NONE,
trust_dns_types::DNSClass::ANY => DNSClass::ANY,
trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v),
}
}
}
impl From<DNSClass> for trust_dns_types::DNSClass {
fn from(dns_class: DNSClass) -> trust_dns_types::DNSClass {
match dns_class {
DNSClass::IN => trust_dns_types::DNSClass::IN,
DNSClass::CH => trust_dns_types::DNSClass::CH,
DNSClass::HS => trust_dns_types::DNSClass::HS,
DNSClass::NONE => trust_dns_types::DNSClass::NONE,
DNSClass::ANY => trust_dns_types::DNSClass::ANY,
DNSClass::OPT(v) => trust_dns_types::DNSClass::OPT(v),
}
}
}
// Reimplement this type here as ClientReponse in trust-dns crate have private fields
pub struct ClientResponse<R>(pub(crate) R)
where
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
impl<R> Future for ClientResponse<R>
where
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
type Output = Result<DnsResponse, ClientError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// This is from the future_utils crate, we simply reuse the reexport from Rocket
rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from)
}
}
#[derive(Deserialize, Serialize, Clone)]
pub struct Record {
#[serde(rename = "Name")]
pub name: SerdeName,
// TODO: Make class optional, default to IN
#[serde(rename = "Class")]
pub dns_class: DNSClass,
#[serde(rename = "TTL")]
pub ttl: u32,
#[serde(flatten)]
pub rdata: RData,
}
impl From<trust_dns_types::Record> for Record {
fn from(record: trust_dns_types::Record) -> Record {
Record {
name: SerdeName(record.name().clone()),
dns_class: record.dns_class().into(),
ttl: record.ttl(),
rdata: record.into_data().into(),
}
}
}
impl TryFrom<Record> for trust_dns_types::Record {
type Error = ProtoError;
fn try_from(record: Record) -> Result<Self, Self::Error> {
let mut trust_dns_record = trust_dns_types::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?);
trust_dns_record.set_dns_class(record.dns_class.into());
Ok(trust_dns_record)
}
}
#[derive(Debug, Clone)]
pub struct SerdeName(Name);
impl<'de> Deserialize<'de> for SerdeName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|string|
Name::from_utf8(&string)
.map_err(|e| Error::custom(e.to_string()))
).map( SerdeName)
}
}
impl Deref for SerdeName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Serialize for SerdeName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer
{
self.0.to_utf8().serialize(serializer)
}
}
impl SerdeName {
fn into_inner(self) -> Name {
self.0
}
}
#[derive(Debug, Deserialize)]
pub struct AbsoluteName(SerdeName);
impl<'r> FromParam<'r> for AbsoluteName {
type Error = ProtoError;
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
let mut name = Name::from_utf8(&param)?;
if !name.is_fqdn() {
name.set_fqdn(true);
}
Ok(AbsoluteName(SerdeName(name)))
}
}
impl Deref for AbsoluteName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0.0
}
}
pub struct DnsClient(AsyncClient);
impl Deref for DnsClient {
type Target = AsyncClient;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsClient {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DnsClient {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<State<Config>>().await);
let (stream, handle) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(config.dns.server);
let client = AsyncClient::with_timeout(
stream,
handle,
std::time::Duration::from_secs(5),
None);
let (client, bg) = match client.await {
Err(e) => {
println!("Failed to connect to DNS server {:#?}", e);
return Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => c
};
task::spawn(bg);
Outcome::Success(DnsClient(client))
}
}

43
src/dns/record.rs Normal file
View file

@ -0,0 +1,43 @@
use std::convert::{TryFrom, TryInto};
use serde::{Deserialize, Serialize};
use trust_dns_proto::error::ProtoError;
use super::trust_dns_types;
use super::name::SerdeName;
use super::class::DNSClass;
use super::rdata::RData;
#[derive(Deserialize, Serialize, Clone)]
pub struct Record {
#[serde(rename = "Name")]
pub name: SerdeName,
// TODO: Make class optional, default to IN
#[serde(rename = "Class")]
pub dns_class: DNSClass,
#[serde(rename = "TTL")]
pub ttl: u32,
#[serde(flatten)]
pub rdata: RData,
}
impl From<trust_dns_types::Record> for Record {
fn from(record: trust_dns_types::Record) -> Record {
Record {
name: SerdeName(record.name().clone()),
dns_class: record.dns_class().into(),
ttl: record.ttl(),
rdata: record.into_data().into(),
}
}
}
impl TryFrom<Record> for trust_dns_types::Record {
type Error = ProtoError;
fn try_from(record: Record) -> Result<Self, Self::Error> {
let mut trust_dns_record = trust_dns_types::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?);
trust_dns_record.set_dns_class(record.dns_class.into());
Ok(trust_dns_record)
}
}

View file

@ -8,6 +8,7 @@ mod models;
mod config; mod config;
mod schema; mod schema;
mod routes; mod routes;
mod dns;
use routes::users::*; use routes::users::*;
use routes::zones::*; use routes::zones::*;

63
src/models/auth.rs Normal file
View file

@ -0,0 +1,63 @@
use uuid::Uuid;
use serde::{Serialize, Deserialize};
use chrono::serde::ts_seconds;
use chrono::prelude::{DateTime, Utc};
use chrono::Duration;
use jsonwebtoken::{
encode, decode,
Header, Validation,
Algorithm as JwtAlgorithm, EncodingKey, DecodingKey,
errors::Result as JwtResult
};
use crate::models::user::UserInfo;
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthClaims {
pub jti: String,
pub sub: String,
#[serde(with = "ts_seconds")]
pub exp: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub iat: DateTime<Utc>,
}
#[derive(Debug, Serialize)]
pub struct AuthTokenResponse {
pub token: String
}
#[derive(Debug, Deserialize)]
pub struct AuthTokenRequest {
pub username: String,
pub password: String,
}
impl AuthClaims {
pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims {
let jti = Uuid::new_v4().to_simple().to_string();
let iat = Utc::now();
let exp = iat + token_duration;
AuthClaims {
jti,
sub: user_info.id.clone(),
exp,
iat,
}
}
pub fn decode(token: &str, secret: &str) -> JwtResult<AuthClaims> {
decode::<AuthClaims>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::new(JwtAlgorithm::HS256)
).map(|data| data.claims)
}
pub fn encode(self, secret: &str) -> JwtResult<String> {
encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref()))
}
}

View file

@ -3,8 +3,36 @@ use rocket::http::Status;
use rocket::request::{Request, Outcome}; use rocket::request::{Request, Outcome};
use rocket::response::{self, Response, Responder}; use rocket::response::{self, Response, Responder};
use rocket_contrib::json::Json; use rocket_contrib::json::Json;
use crate::models::users::UserError;
use serde_json::Value; use serde_json::Value;
use djangohashers::{HasherError};
use diesel::result::Error as DieselError;
#[derive(Debug)]
pub enum UserError {
ZoneNotFound,
NotFound,
UserConflict,
BadCreds,
BadToken,
ExpiredToken,
MalformedHeader,
PermissionDenied,
DbError(DieselError),
PasswordError(HasherError),
}
impl From<HasherError> for UserError {
fn from(e: HasherError) -> Self {
UserError::PasswordError(e)
}
}
impl From<DieselError> for UserError {
fn from(e: DieselError) -> Self {
UserError::DbError(e)
}
}
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]

View file

@ -1,13 +1,5 @@
pub mod dns; //pub mod dns;
pub mod errors; pub mod errors;
pub mod users; pub mod user;
pub mod zone;
pub mod trust_dns_types { pub mod auth;
pub use trust_dns_client::rr::rdata::{
DNSSECRData, caa, sshfp, mx, null, soa, srv, txt
};
pub use trust_dns_client::rr::{
RData, DNSClass, Record
};
pub use trust_dns_proto::rr::Name;
}

View file

@ -3,25 +3,20 @@ use diesel::prelude::*;
use diesel::result::Error as DieselError; use diesel::result::Error as DieselError;
use diesel_derive_enum::DbEnum; use diesel_derive_enum::DbEnum;
use rocket::{State, request::{FromRequest, Request, Outcome}}; use rocket::{State, request::{FromRequest, Request, Outcome}};
use serde::{Serialize, Deserialize}; use serde::{Deserialize};
use chrono::serde::ts_seconds;
use chrono::prelude::{DateTime, Utc};
use chrono::Duration;
// TODO: Maybe just use argon2 crate directly // TODO: Maybe just use argon2 crate directly
use djangohashers::{make_password_with_algorithm, check_password, HasherError, Algorithm}; use djangohashers::{make_password_with_algorithm, check_password, Algorithm};
use jsonwebtoken::{ use jsonwebtoken::{
encode, decode,
Header, Validation,
Algorithm as JwtAlgorithm, EncodingKey, DecodingKey,
errors::Result as JwtResult,
errors::ErrorKind as JwtErrorKind errors::ErrorKind as JwtErrorKind
}; };
use crate::schema::*; use crate::schema::*;
use crate::DbConn; use crate::DbConn;
use crate::config::Config; use crate::config::Config;
use crate::models::errors::{ErrorResponse, make_500}; use crate::models::errors::{UserError, ErrorResponse, make_500};
use crate::models::dns::AbsoluteName; use crate::models::zone::Zone;
use crate::models::auth::AuthClaims;
const BEARER: &str = "Bearer "; const BEARER: &str = "Bearer ";
const AUTH_HEADER: &str = "Authorization"; const AUTH_HEADER: &str = "Authorization";
@ -61,14 +56,6 @@ pub struct UserZone {
pub zone_id: String, pub zone_id: String,
} }
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
#[table_name = "zone"]
pub struct Zone {
#[serde(skip)]
pub id: String,
pub name: String,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct CreateUserRequest { pub struct CreateUserRequest {
pub username: String, pub username: String,
@ -77,42 +64,6 @@ pub struct CreateUserRequest {
pub role: Option<Role> pub role: Option<Role>
} }
#[derive(Debug, Deserialize)]
pub struct AddZoneMemberRequest {
pub id: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateZoneRequest {
pub name: AbsoluteName,
}
// pub struct LdapUserAssociation {
// user_id: Uuid,
// ldap_id: String
// }
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthClaims {
pub jti: String,
pub sub: String,
#[serde(with = "ts_seconds")]
pub exp: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub iat: DateTime<Utc>,
}
#[derive(Debug, Serialize)]
pub struct AuthTokenResponse {
pub token: String
}
#[derive(Debug, Deserialize)]
pub struct AuthTokenRequest {
pub username: String,
pub password: String,
}
#[derive(Debug)] #[derive(Debug)]
pub struct UserInfo { pub struct UserInfo {
pub id: String, pub id: String,
@ -205,32 +156,6 @@ impl<'r> FromRequest<'r> for UserInfo {
} }
} }
#[derive(Debug)]
pub enum UserError {
ZoneNotFound,
NotFound,
UserConflict,
BadCreds,
BadToken,
ExpiredToken,
MalformedHeader,
PermissionDenied,
DbError(DieselError),
PasswordError(HasherError),
}
impl From<HasherError> for UserError {
fn from(e: HasherError) -> Self {
UserError::PasswordError(e)
}
}
impl From<DieselError> for UserError {
fn from(e: DieselError) -> Self {
UserError::DbError(e)
}
}
impl LocalUser { impl LocalUser {
pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result<UserInfo, UserError> { pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result<UserInfo, UserError> {
use crate::schema::localuser::dsl::*; use crate::schema::localuser::dsl::*;
@ -320,95 +245,4 @@ impl LocalUser {
username: client_localuser.username, username: client_localuser.username,
}) })
} }
}
}
impl AuthClaims {
pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims {
let jti = Uuid::new_v4().to_simple().to_string();
let iat = Utc::now();
let exp = iat + token_duration;
AuthClaims {
jti,
sub: user_info.id.clone(),
exp,
iat,
}
}
pub fn decode(token: &str, secret: &str) -> JwtResult<AuthClaims> {
decode::<AuthClaims>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::new(JwtAlgorithm::HS256)
).map(|data| data.claims)
}
pub fn encode(self, secret: &str) -> JwtResult<String> {
encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref()))
}
}
// NOTE: Should probably not be implemented here
// also, "UserError" seems like a misleading name
impl Zone {
pub fn get_all(conn: &diesel::SqliteConnection) -> Result<Vec<Zone>, UserError> {
use crate::schema::zone::dsl::*;
zone.get_results(conn)
.map_err(UserError::DbError)
}
pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
zone.filter(name.eq(zone_name))
.get_result(conn)
.map_err(|e| match e {
DieselError::NotFound => UserError::ZoneNotFound,
other => UserError::DbError(other)
})
}
pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
let new_zone = Zone {
id: Uuid::new_v4().to_simple().to_string(),
name: zone_request.name.to_utf8(),
};
diesel::insert_into(zone)
.values(&new_zone)
.execute(conn)
.map_err(|e| match e {
DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict,
other => UserError::DbError(other)
})?;
Ok(new_zone)
}
pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> {
use crate::schema::user_zone::dsl::*;
let new_user_zone = UserZone {
zone_id: self.id.clone(),
user_id: new_member.id.clone()
};
let res = diesel::insert_into(user_zone)
.values(new_user_zone)
.execute(conn);
match res {
// If user has already access to the zone, safely ignore the conflit
// TODO: use 'on conflict do nothing' in postgres when we get there
Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (),
Err(e) => return Err(e.into()),
Ok(_) => ()
};
Ok(())
}
}

93
src/models/zone.rs Normal file
View file

@ -0,0 +1,93 @@
use crate::models::user::UserInfo;
use uuid::Uuid;
use diesel::prelude::*;
use diesel::result::Error as DieselError;
use serde::{Serialize, Deserialize};
use crate::schema::*;
use crate::dns::name::AbsoluteName;
use crate::models::user::UserZone;
use crate::models::errors::UserError;
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
#[table_name = "zone"]
pub struct Zone {
#[serde(skip)]
pub id: String,
pub name: String,
}
#[derive(Debug, Deserialize)]
pub struct AddZoneMemberRequest {
pub id: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateZoneRequest {
pub name: AbsoluteName,
}
// NOTE: Should probably not be implemented here
// also, "UserError" seems like a misleading name
impl Zone {
pub fn get_all(conn: &diesel::SqliteConnection) -> Result<Vec<Zone>, UserError> {
use crate::schema::zone::dsl::*;
zone.get_results(conn)
.map_err(UserError::DbError)
}
pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
zone.filter(name.eq(zone_name))
.get_result(conn)
.map_err(|e| match e {
DieselError::NotFound => UserError::ZoneNotFound,
other => UserError::DbError(other)
})
}
pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
let new_zone = Zone {
id: Uuid::new_v4().to_simple().to_string(),
name: zone_request.name.to_utf8(),
};
diesel::insert_into(zone)
.values(&new_zone)
.execute(conn)
.map_err(|e| match e {
DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict,
other => UserError::DbError(other)
})?;
Ok(new_zone)
}
pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> {
use crate::schema::user_zone::dsl::*;
let new_user_zone = UserZone {
zone_id: self.id.clone(),
user_id: new_member.id.clone()
};
let res = diesel::insert_into(user_zone)
.values(new_user_zone)
.execute(conn);
match res {
// If user has already access to the zone, safely ignore the conflit
// TODO: use 'on conflict do nothing' in postgres when we get there
Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (),
Err(e) => return Err(e.into()),
Ok(_) => ()
};
Ok(())
}
}

View file

@ -5,13 +5,8 @@ use rocket::http::Status;
use crate::config::Config; use crate::config::Config;
use crate::DbConn; use crate::DbConn;
use crate::models::errors::{ErrorResponse, make_500}; use crate::models::errors::{ErrorResponse, make_500};
use crate::models::users::{ use crate::models::user::{LocalUser, CreateUserRequest};
LocalUser, use crate::models::auth::{AuthClaims, AuthTokenRequest, AuthTokenResponse};
CreateUserRequest,
AuthClaims,
AuthTokenRequest,
AuthTokenResponse
};
#[post("/users/me/token", data = "<auth_request>")] #[post("/users/me/token", data = "<auth_request>")]

View file

@ -5,27 +5,30 @@ use rocket::http::Status;
use rocket_contrib::json::Json; use rocket_contrib::json::Json;
use trust_dns_client::{client::ClientHandle, op::UpdateMessage}; use trust_dns_client::client::ClientHandle;
use trust_dns_client::op::ResponseCode; use trust_dns_client::op::ResponseCode;
use trust_dns_client::rr::{DNSClass, RecordType}; use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_proto::DnsHandle;
pub use trust_dns_client::op::Message; pub use trust_dns_client::op::Message;
pub use trust_dns_client::op::OpCode; pub use trust_dns_client::op::OpCode;
pub use trust_dns_client::op::Query; pub use trust_dns_client::op::Query;
pub use trust_dns_client::op::MessageType; pub use trust_dns_client::op::MessageType;
use crate::{DbConn, models::{dns, trust_dns_types}}; use crate::{dns::{self, trust_dns_types}, DbConn};
use crate::models::errors::{ErrorResponse, make_500}; use crate::models::errors::{ErrorResponse, make_500};
use crate::models::users::{LocalUser, UserInfo, Zone, AddZoneMemberRequest, CreateZoneRequest}; use crate::models::user::{LocalUser, UserInfo};
use crate::models::zone::{Zone, AddZoneMemberRequest, CreateZoneRequest};
use crate::dns::message::DnsMessage;
use crate::dns::message::MessageError;
#[get("/zones/<zone>/records")] #[get("/zones/<zone>/records")]
pub async fn get_zone_records( pub async fn get_zone_records(
mut client: dns::DnsClient, mut client: dns::client::DnsClient,
conn: DbConn, conn: DbConn,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<UserInfo, ErrorResponse>,
zone: dns::AbsoluteName zone: dns::name::AbsoluteName
) -> Result<Json<Vec<dns::Record>>, ErrorResponse> { ) -> Result<Json<Vec<dns::record::Record>>, ErrorResponse> {
let user_info = user_info?; let user_info = user_info?;
let zone_name = zone.to_string(); let zone_name = zone.to_string();
@ -55,8 +58,8 @@ pub async fn get_zone_records(
let answers = response.answers(); let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter() let mut records: Vec<_> = answers.to_vec().into_iter()
.map(dns::Record::from) .map(dns::record::Record::from)
.filter(|record| !matches!(record.rdata, dns::RData::NULL { .. } | dns::RData::DNSSEC(_))) .filter(|record| !matches!(record.rdata, dns::rdata::RData::NULL { .. } | dns::rdata::RData::DNSSEC(_)))
.collect(); .collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response. // AXFR response ends with SOA, we remove it so it is not doubled in the response.
@ -67,11 +70,11 @@ pub async fn get_zone_records(
#[post("/zones/<zone>/records", data = "<new_records>")] #[post("/zones/<zone>/records", data = "<new_records>")]
pub async fn create_zone_records( pub async fn create_zone_records(
mut client: dns::DnsClient, mut client: dns::client::DnsClient,
conn: DbConn, conn: DbConn,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<UserInfo, ErrorResponse>,
zone: dns::AbsoluteName, zone: dns::name::AbsoluteName,
new_records: Json<Vec<dns::Record>> new_records: Json<Vec<dns::record::Record>>
) -> Result<Json<()>, ErrorResponse> { ) -> Result<Json<()>, ErrorResponse> {
let user_info = user_info?; let user_info = user_info?;
@ -98,10 +101,6 @@ pub async fn create_zone_records(
} }
} }
let bad_zone_records: Vec<_> = records.iter().filter(|record| !zone.zone_of(record.name())).collect();
// TODO: Get zone class from somewhere instead of always assuming IN
let bad_class_records: Vec<_> = records.iter().filter(|record| record.dns_class() != DNSClass::IN).collect();
if !bad_records.is_empty() { if !bad_records.is_empty() {
return ErrorResponse::new( return ErrorResponse::new(
Status::BadRequest, Status::BadRequest,
@ -114,54 +113,21 @@ pub async fn create_zone_records(
).err(); ).err();
} }
if !bad_zone_records.is_empty() { let response = match client.add_records(zone.clone(), DNSClass::IN, records) {
return ErrorResponse::new( Ok(query) => query.await.map_err(make_500)?,
Status::BadRequest, Err(MessageError::RecordNotInZone { zone, class, mismatched_class, mismatched_zone}) => {
"Record list contains records whose name does not belong to the zone".into() return ErrorResponse::new(
).with_details( Status::BadRequest,
json!({ "Record list contains records that do not belong to the zone".into()
"zone_name": zone.to_utf8(), ).with_details(
"records": bad_zone_records.into_iter().map(|r| r.clone().into()).collect::<Vec<dns::Record>>() json!({
}) "zone_name": zone.to_utf8(),
).err(); "class": dns::class::DNSClass::from(class),
} "mismatched_class": mismatched_class.into_iter().map(|r| r.clone().into()).collect::<Vec<dns::record::Record>>(),
"mismatched_zone": mismatched_zone.into_iter().map(|r| r.clone().into()).collect::<Vec<dns::record::Record>>(),
if !bad_class_records.is_empty() { })
return ErrorResponse::new( ).err();
Status::BadRequest, }
"Record list contains records whose class differs from the zone class `IN`".into()
).with_details(
json!({
"zone_name": zone.to_utf8(),
"records": bad_class_records.into_iter().map(|r| r.clone().into()).collect::<Vec<dns::Record>>()
})
).err();
}
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(DNSClass::IN)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
message.add_updates(records);
{
let edns = message.edns_mut();
edns.set_max_payload(1232);
edns.set_version(0);
}
let response = {
let query = dns::ClientResponse(client.send(message));
query.await.map_err(make_500)?
}; };
// TODO: better error handling // TODO: better error handling
@ -199,7 +165,7 @@ pub async fn get_zones(
#[post("/zones", data = "<zone_request>")] #[post("/zones", data = "<zone_request>")]
pub async fn create_zone( pub async fn create_zone(
conn: DbConn, conn: DbConn,
mut client: dns::DnsClient, mut client: dns::client::DnsClient,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<UserInfo, ErrorResponse>,
zone_request: Json<CreateZoneRequest>, zone_request: Json<CreateZoneRequest>,
) -> Result<Json<Zone>, ErrorResponse> { ) -> Result<Json<Zone>, ErrorResponse> {
@ -230,7 +196,7 @@ pub async fn create_zone(
#[post("/zones/<zone>/members", data = "<zone_member_request>")] #[post("/zones/<zone>/members", data = "<zone_member_request>")]
pub async fn add_member_to_zone<'r>( pub async fn add_member_to_zone<'r>(
conn: DbConn, conn: DbConn,
zone: dns::AbsoluteName, zone: dns::name::AbsoluteName,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<UserInfo, ErrorResponse>,
zone_member_request: Json<AddZoneMemberRequest> zone_member_request: Json<AddZoneMemberRequest>
) -> Result<Response<'r>, ErrorResponse> { ) -> Result<Response<'r>, ErrorResponse> {

View file

@ -1,6 +1,6 @@
table! { table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::models::users::*; use crate::models::user::*;
localuser (user_id) { localuser (user_id) {
user_id -> Text, user_id -> Text,