diff --git a/src/dns/class.rs b/src/dns/class.rs new file mode 100644 index 0000000..e95eab8 --- /dev/null +++ b/src/dns/class.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +use super::trust_dns_types; + + +#[derive(Deserialize, Serialize, Clone)] +pub enum DNSClass { + IN, + CH, + HS, + NONE, + ANY, + OPT(u16), +} + +impl From for DNSClass { + fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass { + match dns_class { + trust_dns_types::DNSClass::IN => DNSClass::IN, + trust_dns_types::DNSClass::CH => DNSClass::CH, + trust_dns_types::DNSClass::HS => DNSClass::HS, + trust_dns_types::DNSClass::NONE => DNSClass::NONE, + trust_dns_types::DNSClass::ANY => DNSClass::ANY, + trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v), + } + } +} + +impl From for trust_dns_types::DNSClass { + fn from(dns_class: DNSClass) -> trust_dns_types::DNSClass { + match dns_class { + DNSClass::IN => trust_dns_types::DNSClass::IN, + DNSClass::CH => trust_dns_types::DNSClass::CH, + DNSClass::HS => trust_dns_types::DNSClass::HS, + DNSClass::NONE => trust_dns_types::DNSClass::NONE, + DNSClass::ANY => trust_dns_types::DNSClass::ANY, + DNSClass::OPT(v) => trust_dns_types::DNSClass::OPT(v), + } + } +} \ No newline at end of file diff --git a/src/dns/client.rs b/src/dns/client.rs new file mode 100644 index 0000000..da942fe --- /dev/null +++ b/src/dns/client.rs @@ -0,0 +1,71 @@ +use std::{future::Future, pin::Pin, task::{Context, Poll}}; +use std::ops::{Deref, DerefMut}; + +use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}}; +use tokio::{net::TcpStream as TokioTcpStream, task}; +use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream}; +use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; + +use crate::config::Config; +use super::message::DnsMessage; + + +pub struct DnsClient(AsyncClient); + +impl Deref for DnsClient { + type Target = AsyncClient; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DnsClient { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl DnsMessage for AsyncClient {} + + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for DnsClient { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + let config = try_outcome!(request.guard::>().await); + let (stream, handle) = TcpClientStream::>::new(config.dns.server); + let client = AsyncClient::with_timeout( + stream, + handle, + std::time::Duration::from_secs(5), + None); + let (client, bg) = match client.await { + Err(e) => { + println!("Failed to connect to DNS server {:#?}", e); + return Outcome::Failure((Status::InternalServerError, ())) + }, + Ok(c) => c + }; + task::spawn(bg); + Outcome::Success(DnsClient(client)) + } +} + +// Reimplement this type here as ClientReponse in trust-dns crate have private fields +pub struct ClientResponse(pub(crate) R) +where + R: Future> + Send + Unpin + 'static; + +impl Future for ClientResponse +where + R: Future> + Send + Unpin + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // This is from the future_utils crate, we simply reuse the reexport from Rocket + rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from) + } +} \ No newline at end of file diff --git a/src/dns/message.rs b/src/dns/message.rs new file mode 100644 index 0000000..1e1e62f --- /dev/null +++ b/src/dns/message.rs @@ -0,0 +1,67 @@ +use trust_dns_proto::DnsHandle; +use trust_dns_client::rr::{DNSClass, RecordType}; +use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query}; +use trust_dns_proto::error::ProtoError; + +use super::trust_dns_types::{Name, Record}; +use super::client::{ClientResponse}; + + +pub enum MessageError { + RecordNotInZone { + zone: Name, + class: DNSClass, + mismatched_class: Vec, + mismatched_zone: Vec, + } +} + + +pub trait DnsMessage: DnsHandle + Send { + fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec) -> Result, MessageError> + { + let mut mismatched_class = Vec::new(); + let mut mismatched_zone = Vec::new(); + + for record in new_records.iter() { + if !zone.zone_of(record.name()) { + mismatched_zone.push(record.clone()); + } + if record.dns_class() != class { + mismatched_class.push(record.clone()); + } + } + + if mismatched_class.len() > 0 || mismatched_zone.len() > 0 { + return Err(MessageError::RecordNotInZone { + zone, + class, + mismatched_zone, + mismatched_class + }) + } + + let mut zone_query = Query::new(); + zone_query.set_name(zone.clone()) + .set_query_class(class) + .set_query_type(RecordType::SOA); + let mut message = Message::new(); + + // TODO: set random / time based id + message + .set_id(0) + .set_message_type(MessageType::Query) + .set_op_code(OpCode::Update) + .set_recursion_desired(false); + message.add_zone(zone_query); + message.add_updates(new_records); + + { + let edns = message.edns_mut(); + edns.set_max_payload(1232); + edns.set_version(0); + } + + return Ok(ClientResponse(self.send(message))); + } +} \ No newline at end of file diff --git a/src/dns/mod.rs b/src/dns/mod.rs new file mode 100644 index 0000000..8143b60 --- /dev/null +++ b/src/dns/mod.rs @@ -0,0 +1,16 @@ +pub mod class; +pub mod name; +pub mod rdata; +pub mod record; +pub mod client; +pub mod message; + +pub mod trust_dns_types { + pub use trust_dns_client::rr::rdata::{ + DNSSECRData, caa, sshfp, mx, null, soa, srv, txt + }; + pub use trust_dns_client::rr::{ + RData, DNSClass, Record + }; + pub use trust_dns_proto::rr::Name; +} \ No newline at end of file diff --git a/src/dns/name.rs b/src/dns/name.rs new file mode 100644 index 0000000..4dca8e8 --- /dev/null +++ b/src/dns/name.rs @@ -0,0 +1,72 @@ +use std::ops::Deref; + + +use rocket::request::FromParam; +use serde::{Deserialize, Serialize, Deserializer, Serializer}; +use trust_dns_proto::error::ProtoError; + +use super::trust_dns_types::Name; + + +#[derive(Debug, Clone)] +pub struct SerdeName(pub(crate)Name); + +impl Deref for SerdeName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'de> Deserialize<'de> for SerdeName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de> + { + use serde::de::Error; + + String::deserialize(deserializer) + .and_then(|string| + Name::from_utf8(&string) + .map_err(|e| Error::custom(e.to_string())) + ).map( SerdeName) + } +} + +impl Serialize for SerdeName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer + { + self.0.to_utf8().serialize(serializer) + } +} + +impl SerdeName { + pub fn into_inner(self) -> Name { + self.0 + } +} + + +#[derive(Debug, Deserialize)] +pub struct AbsoluteName(SerdeName); + +impl<'r> FromParam<'r> for AbsoluteName { + type Error = ProtoError; + + fn from_param(param: &'r str) -> Result { + let mut name = Name::from_utf8(¶m)?; + if !name.is_fqdn() { + name.set_fqdn(true); + } + Ok(AbsoluteName(SerdeName(name))) + } +} + +impl Deref for AbsoluteName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0.0 + } +} \ No newline at end of file diff --git a/src/models/dns.rs b/src/dns/rdata.rs similarity index 63% rename from src/models/dns.rs rename to src/dns/rdata.rs index c9d13b8..99f53af 100644 --- a/src/models/dns.rs +++ b/src/dns/rdata.rs @@ -1,21 +1,14 @@ -use std::{convert::{TryFrom, TryInto}, future::Future, net::{Ipv6Addr, Ipv4Addr}, pin::Pin, task::{Context, Poll}}; use std::fmt; -use std::ops::{Deref, DerefMut}; +use std::convert::TryFrom; +use std::net::{Ipv6Addr, Ipv4Addr}; +use serde::{Deserialize, Serialize}; -use rocket::{Request, State, http::Status, request::{FromParam, FromRequest, Outcome}}; +use trust_dns_client::serialize::binary::BinEncoder; +use trust_dns_proto::error::ProtoError; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -use tokio::{net::TcpStream as TokioTcpStream, task}; - -use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, serialize::binary::BinEncoder, tcp::TcpClientStream}; -use trust_dns_proto::error::{ProtoError}; -use trust_dns_proto::iocompat::AsyncIoTokioAsStd; - - -use super::trust_dns_types::{self, Name}; -use crate::config::Config; +use super::trust_dns_types; +use super::name::SerdeName; #[derive(Deserialize, Serialize, Clone)] @@ -103,6 +96,7 @@ pub enum RData { // ZERO, // TODO: DS + // TODO: TLSA } impl From for RData { @@ -286,195 +280,3 @@ impl<'a> fmt::Display for CAAValue<'a> { Ok(()) } } - -#[derive(Deserialize, Serialize, Clone)] -pub enum DNSClass { - IN, - CH, - HS, - NONE, - ANY, - OPT(u16), -} - -impl From for DNSClass { - fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass { - match dns_class { - trust_dns_types::DNSClass::IN => DNSClass::IN, - trust_dns_types::DNSClass::CH => DNSClass::CH, - trust_dns_types::DNSClass::HS => DNSClass::HS, - trust_dns_types::DNSClass::NONE => DNSClass::NONE, - trust_dns_types::DNSClass::ANY => DNSClass::ANY, - trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v), - } - } -} - -impl From for trust_dns_types::DNSClass { - fn from(dns_class: DNSClass) -> trust_dns_types::DNSClass { - match dns_class { - DNSClass::IN => trust_dns_types::DNSClass::IN, - DNSClass::CH => trust_dns_types::DNSClass::CH, - DNSClass::HS => trust_dns_types::DNSClass::HS, - DNSClass::NONE => trust_dns_types::DNSClass::NONE, - DNSClass::ANY => trust_dns_types::DNSClass::ANY, - DNSClass::OPT(v) => trust_dns_types::DNSClass::OPT(v), - } - } -} - - -// Reimplement this type here as ClientReponse in trust-dns crate have private fields -pub struct ClientResponse(pub(crate) R) -where - R: Future> + Send + Unpin + 'static; - -impl Future for ClientResponse -where - R: Future> + Send + Unpin + 'static, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // This is from the future_utils crate, we simply reuse the reexport from Rocket - rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from) - } -} - - - -#[derive(Deserialize, Serialize, Clone)] -pub struct Record { - #[serde(rename = "Name")] - pub name: SerdeName, - // TODO: Make class optional, default to IN - #[serde(rename = "Class")] - pub dns_class: DNSClass, - #[serde(rename = "TTL")] - pub ttl: u32, - #[serde(flatten)] - pub rdata: RData, -} - -impl From for Record { - fn from(record: trust_dns_types::Record) -> Record { - Record { - name: SerdeName(record.name().clone()), - dns_class: record.dns_class().into(), - ttl: record.ttl(), - rdata: record.into_data().into(), - } - } -} - -impl TryFrom for trust_dns_types::Record { - type Error = ProtoError; - - fn try_from(record: Record) -> Result { - let mut trust_dns_record = trust_dns_types::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?); - trust_dns_record.set_dns_class(record.dns_class.into()); - Ok(trust_dns_record) - } -} - -#[derive(Debug, Clone)] -pub struct SerdeName(Name); - -impl<'de> Deserialize<'de> for SerdeName { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de> - { - use serde::de::Error; - - String::deserialize(deserializer) - .and_then(|string| - Name::from_utf8(&string) - .map_err(|e| Error::custom(e.to_string())) - ).map( SerdeName) - } -} - -impl Deref for SerdeName { - type Target = Name; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Serialize for SerdeName { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer - { - self.0.to_utf8().serialize(serializer) - } -} - -impl SerdeName { - fn into_inner(self) -> Name { - self.0 - } -} - - -#[derive(Debug, Deserialize)] -pub struct AbsoluteName(SerdeName); - -impl<'r> FromParam<'r> for AbsoluteName { - type Error = ProtoError; - - fn from_param(param: &'r str) -> Result { - let mut name = Name::from_utf8(¶m)?; - if !name.is_fqdn() { - name.set_fqdn(true); - } - Ok(AbsoluteName(SerdeName(name))) - } -} - -impl Deref for AbsoluteName { - type Target = Name; - fn deref(&self) -> &Self::Target { - &self.0.0 - } -} -pub struct DnsClient(AsyncClient); - -impl Deref for DnsClient { - type Target = AsyncClient; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for DnsClient { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for DnsClient { - type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { - let config = try_outcome!(request.guard::>().await); - let (stream, handle) = TcpClientStream::>::new(config.dns.server); - let client = AsyncClient::with_timeout( - stream, - handle, - std::time::Duration::from_secs(5), - None); - let (client, bg) = match client.await { - Err(e) => { - println!("Failed to connect to DNS server {:#?}", e); - return Outcome::Failure((Status::InternalServerError, ())) - }, - Ok(c) => c - }; - task::spawn(bg); - Outcome::Success(DnsClient(client)) - } -} diff --git a/src/dns/record.rs b/src/dns/record.rs new file mode 100644 index 0000000..dc1debb --- /dev/null +++ b/src/dns/record.rs @@ -0,0 +1,43 @@ +use std::convert::{TryFrom, TryInto}; +use serde::{Deserialize, Serialize}; +use trust_dns_proto::error::ProtoError; + +use super::trust_dns_types; +use super::name::SerdeName; +use super::class::DNSClass; +use super::rdata::RData; + + +#[derive(Deserialize, Serialize, Clone)] +pub struct Record { + #[serde(rename = "Name")] + pub name: SerdeName, + // TODO: Make class optional, default to IN + #[serde(rename = "Class")] + pub dns_class: DNSClass, + #[serde(rename = "TTL")] + pub ttl: u32, + #[serde(flatten)] + pub rdata: RData, +} + +impl From for Record { + fn from(record: trust_dns_types::Record) -> Record { + Record { + name: SerdeName(record.name().clone()), + dns_class: record.dns_class().into(), + ttl: record.ttl(), + rdata: record.into_data().into(), + } + } +} + +impl TryFrom for trust_dns_types::Record { + type Error = ProtoError; + + fn try_from(record: Record) -> Result { + let mut trust_dns_record = trust_dns_types::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?); + trust_dns_record.set_dns_class(record.dns_class.into()); + Ok(trust_dns_record) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e4ed571..03aea0b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ mod models; mod config; mod schema; mod routes; +mod dns; use routes::users::*; use routes::zones::*; diff --git a/src/models/auth.rs b/src/models/auth.rs new file mode 100644 index 0000000..05fca5e --- /dev/null +++ b/src/models/auth.rs @@ -0,0 +1,63 @@ +use uuid::Uuid; +use serde::{Serialize, Deserialize}; +use chrono::serde::ts_seconds; +use chrono::prelude::{DateTime, Utc}; +use chrono::Duration; +use jsonwebtoken::{ + encode, decode, + Header, Validation, + Algorithm as JwtAlgorithm, EncodingKey, DecodingKey, + errors::Result as JwtResult +}; + +use crate::models::user::UserInfo; + + + +#[derive(Debug, Serialize, Deserialize)] +pub struct AuthClaims { + pub jti: String, + pub sub: String, + #[serde(with = "ts_seconds")] + pub exp: DateTime, + #[serde(with = "ts_seconds")] + pub iat: DateTime, +} + +#[derive(Debug, Serialize)] +pub struct AuthTokenResponse { + pub token: String +} + +#[derive(Debug, Deserialize)] +pub struct AuthTokenRequest { + pub username: String, + pub password: String, +} + +impl AuthClaims { + pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims { + let jti = Uuid::new_v4().to_simple().to_string(); + let iat = Utc::now(); + let exp = iat + token_duration; + + AuthClaims { + jti, + sub: user_info.id.clone(), + exp, + iat, + } + } + + pub fn decode(token: &str, secret: &str) -> JwtResult { + decode::( + token, + &DecodingKey::from_secret(secret.as_ref()), + &Validation::new(JwtAlgorithm::HS256) + ).map(|data| data.claims) + } + + pub fn encode(self, secret: &str) -> JwtResult { + encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref())) + } +} \ No newline at end of file diff --git a/src/models/errors.rs b/src/models/errors.rs index 589d10d..c7edd76 100644 --- a/src/models/errors.rs +++ b/src/models/errors.rs @@ -3,8 +3,36 @@ use rocket::http::Status; use rocket::request::{Request, Outcome}; use rocket::response::{self, Response, Responder}; use rocket_contrib::json::Json; -use crate::models::users::UserError; use serde_json::Value; +use djangohashers::{HasherError}; +use diesel::result::Error as DieselError; + + +#[derive(Debug)] +pub enum UserError { + ZoneNotFound, + NotFound, + UserConflict, + BadCreds, + BadToken, + ExpiredToken, + MalformedHeader, + PermissionDenied, + DbError(DieselError), + PasswordError(HasherError), +} + +impl From for UserError { + fn from(e: HasherError) -> Self { + UserError::PasswordError(e) + } +} + +impl From for UserError { + fn from(e: DieselError) -> Self { + UserError::DbError(e) + } +} #[derive(Serialize, Debug)] diff --git a/src/models/mod.rs b/src/models/mod.rs index a43b162..707a5da 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,13 +1,5 @@ -pub mod dns; +//pub mod dns; pub mod errors; -pub mod users; - -pub mod trust_dns_types { - pub use trust_dns_client::rr::rdata::{ - DNSSECRData, caa, sshfp, mx, null, soa, srv, txt - }; - pub use trust_dns_client::rr::{ - RData, DNSClass, Record - }; - pub use trust_dns_proto::rr::Name; -} +pub mod user; +pub mod zone; +pub mod auth; diff --git a/src/models/users.rs b/src/models/user.rs similarity index 61% rename from src/models/users.rs rename to src/models/user.rs index d19142c..95db8e8 100644 --- a/src/models/users.rs +++ b/src/models/user.rs @@ -3,25 +3,20 @@ use diesel::prelude::*; use diesel::result::Error as DieselError; use diesel_derive_enum::DbEnum; use rocket::{State, request::{FromRequest, Request, Outcome}}; -use serde::{Serialize, Deserialize}; -use chrono::serde::ts_seconds; -use chrono::prelude::{DateTime, Utc}; -use chrono::Duration; +use serde::{Deserialize}; // TODO: Maybe just use argon2 crate directly -use djangohashers::{make_password_with_algorithm, check_password, HasherError, Algorithm}; +use djangohashers::{make_password_with_algorithm, check_password, Algorithm}; use jsonwebtoken::{ - encode, decode, - Header, Validation, - Algorithm as JwtAlgorithm, EncodingKey, DecodingKey, - errors::Result as JwtResult, errors::ErrorKind as JwtErrorKind }; use crate::schema::*; use crate::DbConn; use crate::config::Config; -use crate::models::errors::{ErrorResponse, make_500}; -use crate::models::dns::AbsoluteName; +use crate::models::errors::{UserError, ErrorResponse, make_500}; +use crate::models::zone::Zone; +use crate::models::auth::AuthClaims; + const BEARER: &str = "Bearer "; const AUTH_HEADER: &str = "Authorization"; @@ -61,14 +56,6 @@ pub struct UserZone { pub zone_id: String, } -#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)] -#[table_name = "zone"] -pub struct Zone { - #[serde(skip)] - pub id: String, - pub name: String, -} - #[derive(Debug, Deserialize)] pub struct CreateUserRequest { pub username: String, @@ -77,42 +64,6 @@ pub struct CreateUserRequest { pub role: Option } -#[derive(Debug, Deserialize)] -pub struct AddZoneMemberRequest { - pub id: String, -} - -#[derive(Debug, Deserialize)] -pub struct CreateZoneRequest { - pub name: AbsoluteName, -} - -// pub struct LdapUserAssociation { -// user_id: Uuid, -// ldap_id: String -// } - -#[derive(Debug, Serialize, Deserialize)] -pub struct AuthClaims { - pub jti: String, - pub sub: String, - #[serde(with = "ts_seconds")] - pub exp: DateTime, - #[serde(with = "ts_seconds")] - pub iat: DateTime, -} - -#[derive(Debug, Serialize)] -pub struct AuthTokenResponse { - pub token: String -} - -#[derive(Debug, Deserialize)] -pub struct AuthTokenRequest { - pub username: String, - pub password: String, -} - #[derive(Debug)] pub struct UserInfo { pub id: String, @@ -205,32 +156,6 @@ impl<'r> FromRequest<'r> for UserInfo { } } -#[derive(Debug)] -pub enum UserError { - ZoneNotFound, - NotFound, - UserConflict, - BadCreds, - BadToken, - ExpiredToken, - MalformedHeader, - PermissionDenied, - DbError(DieselError), - PasswordError(HasherError), -} - -impl From for UserError { - fn from(e: HasherError) -> Self { - UserError::PasswordError(e) - } -} - -impl From for UserError { - fn from(e: DieselError) -> Self { - UserError::DbError(e) - } -} - impl LocalUser { pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result { use crate::schema::localuser::dsl::*; @@ -320,95 +245,4 @@ impl LocalUser { username: client_localuser.username, }) } - -} - -impl AuthClaims { - pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims { - let jti = Uuid::new_v4().to_simple().to_string(); - let iat = Utc::now(); - let exp = iat + token_duration; - - AuthClaims { - jti, - sub: user_info.id.clone(), - exp, - iat, - } - } - - pub fn decode(token: &str, secret: &str) -> JwtResult { - decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::new(JwtAlgorithm::HS256) - ).map(|data| data.claims) - } - - pub fn encode(self, secret: &str) -> JwtResult { - encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref())) - } -} - -// NOTE: Should probably not be implemented here -// also, "UserError" seems like a misleading name -impl Zone { - pub fn get_all(conn: &diesel::SqliteConnection) -> Result, UserError> { - use crate::schema::zone::dsl::*; - - zone.get_results(conn) - .map_err(UserError::DbError) - } - - pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result { - use crate::schema::zone::dsl::*; - - zone.filter(name.eq(zone_name)) - .get_result(conn) - .map_err(|e| match e { - DieselError::NotFound => UserError::ZoneNotFound, - other => UserError::DbError(other) - }) - } - - pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result { - use crate::schema::zone::dsl::*; - - let new_zone = Zone { - id: Uuid::new_v4().to_simple().to_string(), - name: zone_request.name.to_utf8(), - }; - - diesel::insert_into(zone) - .values(&new_zone) - .execute(conn) - .map_err(|e| match e { - DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict, - other => UserError::DbError(other) - })?; - Ok(new_zone) - } - - - pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> { - use crate::schema::user_zone::dsl::*; - - let new_user_zone = UserZone { - zone_id: self.id.clone(), - user_id: new_member.id.clone() - }; - - let res = diesel::insert_into(user_zone) - .values(new_user_zone) - .execute(conn); - - match res { - // If user has already access to the zone, safely ignore the conflit - // TODO: use 'on conflict do nothing' in postgres when we get there - Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (), - Err(e) => return Err(e.into()), - Ok(_) => () - }; - Ok(()) - } -} +} \ No newline at end of file diff --git a/src/models/zone.rs b/src/models/zone.rs new file mode 100644 index 0000000..47e62d1 --- /dev/null +++ b/src/models/zone.rs @@ -0,0 +1,93 @@ +use crate::models::user::UserInfo; + +use uuid::Uuid; +use diesel::prelude::*; +use diesel::result::Error as DieselError; +use serde::{Serialize, Deserialize}; + +use crate::schema::*; +use crate::dns::name::AbsoluteName; +use crate::models::user::UserZone; +use crate::models::errors::UserError; + + +#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)] +#[table_name = "zone"] +pub struct Zone { + #[serde(skip)] + pub id: String, + pub name: String, +} + +#[derive(Debug, Deserialize)] +pub struct AddZoneMemberRequest { + pub id: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateZoneRequest { + pub name: AbsoluteName, +} + +// NOTE: Should probably not be implemented here +// also, "UserError" seems like a misleading name +impl Zone { + pub fn get_all(conn: &diesel::SqliteConnection) -> Result, UserError> { + use crate::schema::zone::dsl::*; + + zone.get_results(conn) + .map_err(UserError::DbError) + } + + pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result { + use crate::schema::zone::dsl::*; + + zone.filter(name.eq(zone_name)) + .get_result(conn) + .map_err(|e| match e { + DieselError::NotFound => UserError::ZoneNotFound, + other => UserError::DbError(other) + }) + } + + pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result { + use crate::schema::zone::dsl::*; + + let new_zone = Zone { + id: Uuid::new_v4().to_simple().to_string(), + name: zone_request.name.to_utf8(), + }; + + diesel::insert_into(zone) + .values(&new_zone) + .execute(conn) + .map_err(|e| match e { + DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict, + other => UserError::DbError(other) + })?; + Ok(new_zone) + } + + + pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> { + use crate::schema::user_zone::dsl::*; + + let new_user_zone = UserZone { + zone_id: self.id.clone(), + user_id: new_member.id.clone() + }; + + let res = diesel::insert_into(user_zone) + .values(new_user_zone) + .execute(conn); + + match res { + // If user has already access to the zone, safely ignore the conflit + // TODO: use 'on conflict do nothing' in postgres when we get there + Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (), + Err(e) => return Err(e.into()), + Ok(_) => () + }; + Ok(()) + } +} \ No newline at end of file diff --git a/src/routes/users.rs b/src/routes/users.rs index e6967b3..524d801 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -5,13 +5,8 @@ use rocket::http::Status; use crate::config::Config; use crate::DbConn; use crate::models::errors::{ErrorResponse, make_500}; -use crate::models::users::{ - LocalUser, - CreateUserRequest, - AuthClaims, - AuthTokenRequest, - AuthTokenResponse -}; +use crate::models::user::{LocalUser, CreateUserRequest}; +use crate::models::auth::{AuthClaims, AuthTokenRequest, AuthTokenResponse}; #[post("/users/me/token", data = "")] diff --git a/src/routes/zones.rs b/src/routes/zones.rs index ae7ddeb..d251543 100644 --- a/src/routes/zones.rs +++ b/src/routes/zones.rs @@ -5,27 +5,30 @@ use rocket::http::Status; use rocket_contrib::json::Json; -use trust_dns_client::{client::ClientHandle, op::UpdateMessage}; +use trust_dns_client::client::ClientHandle; use trust_dns_client::op::ResponseCode; use trust_dns_client::rr::{DNSClass, RecordType}; -use trust_dns_proto::DnsHandle; + pub use trust_dns_client::op::Message; pub use trust_dns_client::op::OpCode; pub use trust_dns_client::op::Query; pub use trust_dns_client::op::MessageType; -use crate::{DbConn, models::{dns, trust_dns_types}}; +use crate::{dns::{self, trust_dns_types}, DbConn}; use crate::models::errors::{ErrorResponse, make_500}; -use crate::models::users::{LocalUser, UserInfo, Zone, AddZoneMemberRequest, CreateZoneRequest}; +use crate::models::user::{LocalUser, UserInfo}; +use crate::models::zone::{Zone, AddZoneMemberRequest, CreateZoneRequest}; +use crate::dns::message::DnsMessage; +use crate::dns::message::MessageError; #[get("/zones//records")] pub async fn get_zone_records( - mut client: dns::DnsClient, + mut client: dns::client::DnsClient, conn: DbConn, user_info: Result, - zone: dns::AbsoluteName -) -> Result>, ErrorResponse> { + zone: dns::name::AbsoluteName +) -> Result>, ErrorResponse> { let user_info = user_info?; let zone_name = zone.to_string(); @@ -55,8 +58,8 @@ pub async fn get_zone_records( let answers = response.answers(); let mut records: Vec<_> = answers.to_vec().into_iter() - .map(dns::Record::from) - .filter(|record| !matches!(record.rdata, dns::RData::NULL { .. } | dns::RData::DNSSEC(_))) + .map(dns::record::Record::from) + .filter(|record| !matches!(record.rdata, dns::rdata::RData::NULL { .. } | dns::rdata::RData::DNSSEC(_))) .collect(); // AXFR response ends with SOA, we remove it so it is not doubled in the response. @@ -67,11 +70,11 @@ pub async fn get_zone_records( #[post("/zones//records", data = "")] pub async fn create_zone_records( - mut client: dns::DnsClient, + mut client: dns::client::DnsClient, conn: DbConn, user_info: Result, - zone: dns::AbsoluteName, - new_records: Json> + zone: dns::name::AbsoluteName, + new_records: Json> ) -> Result, ErrorResponse> { let user_info = user_info?; @@ -98,10 +101,6 @@ pub async fn create_zone_records( } } - let bad_zone_records: Vec<_> = records.iter().filter(|record| !zone.zone_of(record.name())).collect(); - // TODO: Get zone class from somewhere instead of always assuming IN - let bad_class_records: Vec<_> = records.iter().filter(|record| record.dns_class() != DNSClass::IN).collect(); - if !bad_records.is_empty() { return ErrorResponse::new( Status::BadRequest, @@ -114,54 +113,21 @@ pub async fn create_zone_records( ).err(); } - if !bad_zone_records.is_empty() { - return ErrorResponse::new( - Status::BadRequest, - "Record list contains records whose name does not belong to the zone".into() - ).with_details( - json!({ - "zone_name": zone.to_utf8(), - "records": bad_zone_records.into_iter().map(|r| r.clone().into()).collect::>() - }) - ).err(); - } - - if !bad_class_records.is_empty() { - return ErrorResponse::new( - Status::BadRequest, - "Record list contains records whose class differs from the zone class `IN`".into() - ).with_details( - json!({ - "zone_name": zone.to_utf8(), - "records": bad_class_records.into_iter().map(|r| r.clone().into()).collect::>() - }) - ).err(); - } - - let mut zone_query = Query::new(); - zone_query.set_name(zone.clone()) - .set_query_class(DNSClass::IN) - .set_query_type(RecordType::SOA); - let mut message = Message::new(); - - // TODO: set random / time based id - message - .set_id(0) - .set_message_type(MessageType::Query) - .set_op_code(OpCode::Update) - .set_recursion_desired(false); - message.add_zone(zone_query); - message.add_updates(records); - - { - let edns = message.edns_mut(); - edns.set_max_payload(1232); - edns.set_version(0); - } - - let response = { - let query = dns::ClientResponse(client.send(message)); - query.await.map_err(make_500)? + let response = match client.add_records(zone.clone(), DNSClass::IN, records) { + Ok(query) => query.await.map_err(make_500)?, + Err(MessageError::RecordNotInZone { zone, class, mismatched_class, mismatched_zone}) => { + return ErrorResponse::new( + Status::BadRequest, + "Record list contains records that do not belong to the zone".into() + ).with_details( + json!({ + "zone_name": zone.to_utf8(), + "class": dns::class::DNSClass::from(class), + "mismatched_class": mismatched_class.into_iter().map(|r| r.clone().into()).collect::>(), + "mismatched_zone": mismatched_zone.into_iter().map(|r| r.clone().into()).collect::>(), + }) + ).err(); + } }; // TODO: better error handling @@ -199,7 +165,7 @@ pub async fn get_zones( #[post("/zones", data = "")] pub async fn create_zone( conn: DbConn, - mut client: dns::DnsClient, + mut client: dns::client::DnsClient, user_info: Result, zone_request: Json, ) -> Result, ErrorResponse> { @@ -230,7 +196,7 @@ pub async fn create_zone( #[post("/zones//members", data = "")] pub async fn add_member_to_zone<'r>( conn: DbConn, - zone: dns::AbsoluteName, + zone: dns::name::AbsoluteName, user_info: Result, zone_member_request: Json ) -> Result, ErrorResponse> { diff --git a/src/schema.rs b/src/schema.rs index 775c738..10c2d36 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,6 +1,6 @@ table! { use diesel::sql_types::*; - use crate::models::users::*; + use crate::models::user::*; localuser (user_id) { user_id -> Text,