diff --git a/dev-scripts/config/knot.conf b/dev-scripts/config/knot.conf index 34b3dba..e6e04ef 100644 --- a/dev-scripts/config/knot.conf +++ b/dev-scripts/config/knot.conf @@ -8,7 +8,7 @@ log: acl: - id: example_acl address: [ 127.0.0.1, ::1] - action: transfer + action: [transfer, update] template: - id: default diff --git a/dev-scripts/docker-compose.yml b/dev-scripts/docker-compose.yml index c509d4b..5738e1c 100644 --- a/dev-scripts/docker-compose.yml +++ b/dev-scripts/docker-compose.yml @@ -2,7 +2,7 @@ services: knot: image: cznic/knot volumes: - - $PWD/zones:/storage/zones:ro - - $PWD/config:/config:ro + - ./zones:/storage/zones:ro + - ./config:/config:ro command: knotd network_mode: host diff --git a/src/main.rs b/src/main.rs index 676e88e..e4ed571 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,7 @@ async fn rocket() -> rocket::Rocket { .attach(DbConn::fairing()) .mount("/api/v1", routes![ get_zone_records, + create_zone_records, get_zones, create_zone, add_member_to_zone, diff --git a/src/models/dns.rs b/src/models/dns.rs index 2bb22b9..7f3e8a3 100644 --- a/src/models/dns.rs +++ b/src/models/dns.rs @@ -1,15 +1,15 @@ -use std::net::{Ipv6Addr, Ipv4Addr}; +use std::{convert::{TryFrom, TryInto}, future::Future, net::{Ipv6Addr, Ipv4Addr}, pin::Pin, task::{Context, Poll}}; use std::fmt; use std::ops::{Deref, DerefMut}; use rocket::{Request, State, http::Status, request::{FromParam, FromRequest, Outcome}}; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tokio::{net::TcpStream as TokioTcpStream, task}; -use trust_dns_client::{client::AsyncClient, serialize::binary::BinEncoder, tcp::TcpClientStream}; +use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, serialize::binary::BinEncoder, tcp::TcpClientStream}; use trust_dns_proto::error::{ProtoError}; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; @@ -38,14 +38,14 @@ pub enum RData { }, #[serde(rename_all = "PascalCase")] CNAME { - target: String + target: SerdeName }, // HINFO(HINFO), // HTTPS(SVCB), #[serde(rename_all = "PascalCase")] MX { preference: u16, - mail_exchanger: String + mail_exchanger: SerdeName }, // NAPTR(NAPTR), #[serde(rename_all = "PascalCase")] @@ -54,18 +54,18 @@ pub enum RData { }, #[serde(rename_all = "PascalCase")] NS { - target: String + target: SerdeName }, // OPENPGPKEY(OPENPGPKEY), // OPT(OPT), #[serde(rename_all = "PascalCase")] PTR { - target: String + target: SerdeName }, #[serde(rename_all = "PascalCase")] SOA { - master_server_name: String, - maintainer_name: String, + master_server_name: SerdeName, + maintainer_name: SerdeName, refresh: i32, retry: i32, expire: i32, @@ -74,7 +74,7 @@ pub enum RData { }, #[serde(rename_all = "PascalCase")] SRV { - server: String, + server: SerdeName, port: u16, priority: u16, weight: u16, @@ -115,7 +115,7 @@ impl From for RData { data: String::new() }, trust_dns_types::RData::CNAME(target) => RData::CNAME { - target: target.to_utf8() + target: SerdeName(target) }, trust_dns_types::RData::CAA(caa) => RData::CAA { issuer_critical: caa.issuer_critical(), @@ -124,20 +124,20 @@ impl From for RData { }, trust_dns_types::RData::MX(mx) => RData::MX { preference: mx.preference(), - mail_exchanger: mx.exchange().to_utf8() + mail_exchanger: SerdeName(mx.exchange().clone()) }, trust_dns_types::RData::NULL(null) => RData::NULL { data: base64::encode(null.anything().map(|data| data.to_vec()).unwrap_or_default()) }, trust_dns_types::RData::NS(target) => RData::NS { - target: target.to_utf8() + target: SerdeName(target) }, trust_dns_types::RData::PTR(target) => RData::PTR { - target: target.to_utf8() + target: SerdeName(target) }, trust_dns_types::RData::SOA(soa) => RData::SOA { - master_server_name: soa.mname().to_utf8(), - maintainer_name: soa.rname().to_utf8(), + master_server_name: SerdeName(soa.mname().clone()), + maintainer_name: SerdeName(soa.rname().clone()), refresh: soa.refresh(), retry: soa.retry(), expire: soa.expire(), @@ -145,7 +145,7 @@ impl From for RData { serial: soa.serial() }, trust_dns_types::RData::SRV(srv) => RData::SRV { - server: srv.target().to_utf8(), + server: SerdeName(srv.target().clone()), port: srv.port(), priority: srv.priority(), weight: srv.weight(), @@ -173,6 +173,51 @@ impl From for RData { } } +impl TryFrom for trust_dns_types::RData { + type Error = ProtoError; + + fn try_from(rdata: RData) -> Result { + Ok(match rdata { + RData::A { address } => trust_dns_types::RData::A(address), + RData::AAAA { address } => trust_dns_types::RData::AAAA(address), + RData::CAA { issuer_critical, value, property_tag } => { + let property = trust_dns_types::caa::Property::from(property_tag); + let caa_value = { + // TODO: duplicate of trust_dns_client::serialize::txt::rdata_parser::caa::parse + // because caa::read_value is private + match property { + trust_dns_types::caa::Property::Issue | trust_dns_types::caa::Property::IssueWild => { + let value = trust_dns_types::caa::read_issuer(value.as_bytes())?; + trust_dns_types::caa::Value::Issuer(value.0, value.1) + } + trust_dns_types::caa::Property::Iodef => { + let url = trust_dns_types::caa::read_iodef(value.as_bytes())?; + trust_dns_types::caa::Value::Url(url) + } + trust_dns_types::caa::Property::Unknown(_) => trust_dns_types::caa::Value::Unknown(value.as_bytes().to_vec()), + } + }; + trust_dns_types::RData::CAA(trust_dns_types::caa::CAA { + issuer_critical, + tag: property, + value: caa_value, + }) + }, + RData::CNAME { target } => todo!(), + RData::MX { preference, mail_exchanger } => todo!(), + RData::NULL { data } => todo!(), + RData::NS { target } => todo!(), + RData::PTR { target } => todo!(), + RData::SOA { master_server_name, maintainer_name, refresh, retry, expire, minimum, serial } => todo!(), + RData::SRV { server, port, priority, weight } => todo!(), + RData::SSHFP { algorithm, digest_type, fingerprint } => todo!(), + RData::TXT { text } => todo!(), + RData::DNSSEC(_) => todo!(), + RData::Unknown { code, data } => todo!(), + }) + } +} + struct CAAValue<'a>(&'a trust_dns_types::caa::Value); // trust_dns Display implementation panics if no parameters @@ -225,11 +270,44 @@ impl From for DNSClass { } } +impl From for trust_dns_types::DNSClass { + fn from(dns_class: DNSClass) -> trust_dns_types::DNSClass { + match dns_class { + DNSClass::IN => trust_dns_types::DNSClass::IN, + DNSClass::CH => trust_dns_types::DNSClass::CH, + DNSClass::HS => trust_dns_types::DNSClass::HS, + DNSClass::NONE => trust_dns_types::DNSClass::NONE, + DNSClass::ANY => trust_dns_types::DNSClass::ANY, + DNSClass::OPT(v) => trust_dns_types::DNSClass::OPT(v), + } + } +} + + +// Reimplement this type here as ClientReponse in trust-dns crate have private fields +pub struct ClientResponse(pub(crate) R) +where + R: Future> + Send + Unpin + 'static; + +impl Future for ClientResponse +where + R: Future> + Send + Unpin + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // This is from the future_utils crate, we simply reuse the reexport from Rocket + rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from) + } +} + + #[derive(Deserialize, Serialize)] pub struct Record { #[serde(rename = "Name")] - pub name: String, + pub name: SerdeName, + // TODO: Make class optional, default to IN #[serde(rename = "Class")] pub dns_class: DNSClass, #[serde(rename = "TTL")] @@ -241,8 +319,7 @@ pub struct Record { impl From for Record { fn from(record: trust_dns_types::Record) -> Record { Record { - name: record.name().to_utf8(), - //rr_type: record.rr_type().into(), + name: SerdeName(record.name().clone()), dns_class: record.dns_class().into(), ttl: record.ttl(), rdata: record.into_data().into(), @@ -250,8 +327,59 @@ impl From for Record { } } +impl TryFrom for trust_dns_types::Record { + type Error = ProtoError; + + fn try_from(record: Record) -> Result { + let mut trust_dns_record = trust_dns_types::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?); + trust_dns_record.set_dns_class(record.dns_class.into()); + Ok(trust_dns_record) + } +} + #[derive(Debug)] -pub struct AbsoluteName(Name); +pub struct SerdeName(Name); + +impl<'de> Deserialize<'de> for SerdeName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de> + { + use serde::de::Error; + + String::deserialize(deserializer) + .and_then(|string| + Name::from_utf8(&string) + .map_err(|e| Error::custom(e.to_string())) + ).map( SerdeName) + } +} + +impl Deref for SerdeName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Serialize for SerdeName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer + { + self.0.to_utf8().serialize(serializer) + } +} + +impl SerdeName { + fn into_inner(self) -> Name { + self.0 + } +} + + +#[derive(Debug, Deserialize)] +pub struct AbsoluteName(SerdeName); impl<'r> FromParam<'r> for AbsoluteName { type Error = ProtoError; @@ -261,33 +389,16 @@ impl<'r> FromParam<'r> for AbsoluteName { if !name.is_fqdn() { name.set_fqdn(true); } - Ok(AbsoluteName(name)) - } -} - -impl<'de> Deserialize<'de> for AbsoluteName { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de> - { - use serde::de::Error; - - String::deserialize(deserializer) - .and_then(|string| - AbsoluteName::from_param(&string) - .map_err(|e| Error::custom(e.to_string())) - ) + Ok(AbsoluteName(SerdeName(name))) } } impl Deref for AbsoluteName { type Target = Name; fn deref(&self) -> &Self::Target { - &self.0 + &self.0.0 } } - - pub struct DnsClient(AsyncClient); impl Deref for DnsClient { diff --git a/src/routes/zones.rs b/src/routes/zones.rs index c29f2b1..4da2978 100644 --- a/src/routes/zones.rs +++ b/src/routes/zones.rs @@ -1,13 +1,23 @@ +use std::convert::TryFrom; +use std::convert::TryInto; + use rocket::Response; use rocket::http::Status; use rocket_contrib::json::Json; -use trust_dns_client::client::ClientHandle; +use serde_json::json; + +use trust_dns_client::{client::ClientHandle, op::UpdateMessage}; use trust_dns_client::op::ResponseCode; use trust_dns_client::rr::{DNSClass, RecordType}; +use trust_dns_proto::DnsHandle; +pub use trust_dns_client::op::Message; +pub use trust_dns_client::op::OpCode; +pub use trust_dns_client::op::Query; +pub use trust_dns_client::op::MessageType; -use crate::{DbConn, models::dns}; +use crate::{DbConn, models::{dns, trust_dns_types}}; use crate::models::errors::{ErrorResponse, make_500}; use crate::models::users::{LocalUser, UserInfo, Zone, AddZoneMemberRequest, CreateZoneRequest}; @@ -56,6 +66,81 @@ pub async fn get_zone_records( Ok(Json(records)) } +#[post("/zones//records", data = "")] +pub async fn create_zone_records( + mut client: dns::DnsClient, + conn: DbConn, + user_info: Result, + zone: dns::AbsoluteName, + new_records: Json> +) -> Result, ErrorResponse> { + + let user_info = user_info?; + let zone_name = zone.to_utf8(); + + conn.run(move |c| { + if user_info.is_admin() { + Zone::get_by_name(c, &zone_name) + } else { + user_info.get_zone(c, &zone_name) + } + }).await?; + // TODO: What about relative names (also in cnames and stuff) + // TODO: error handling + let records: Vec = new_records.into_inner().into_iter().map(|r| r.try_into().unwrap()).collect(); + + let bad_zone_records: Vec<_> = records.iter().filter(|record| !zone.zone_of(record.name())).collect(); + // TODO: Get zone class from somewhere instead of always assuming IN + let bad_class_records: Vec<_> = records.iter().filter(|record| record.dns_class() != DNSClass::IN).collect(); + + if !bad_zone_records.is_empty() { + return ErrorResponse::new( + Status::BadRequest, + format!("Record list contains records whose name that do not belong to the zone {}", *zone) + ).with_details( + json!(bad_zone_records.into_iter().map(|r| r.name().to_utf8()).collect::>()) + ).err() + } + + if !bad_class_records.is_empty() { + return ErrorResponse::new( + Status::BadRequest, + "Record list contains records whose class differs from the zone class `IN`".into() + ).with_details( + json!(bad_class_records.into_iter().map(|r| r.name().to_utf8()).collect::>()) + ).err() + } + + let mut zone_query = Query::new(); + zone_query.set_name(zone.clone()) + .set_query_class(DNSClass::IN) + .set_query_type(RecordType::SOA); + let mut message = Message::new(); + + // TODO: set random / time based id + message + .set_id(0) + .set_message_type(MessageType::Query) + .set_op_code(OpCode::Update) + .set_recursion_desired(false); + message.add_zone(zone_query); + message.add_updates(records); + + { + let edns = message.edns_mut(); + edns.set_max_payload(1232); + edns.set_version(0); + } + + // TODO: check if NOERROR or something + let _response = { + let query = dns::ClientResponse(client.send(message)); + query.await.map_err(make_500)? + }; + + Ok(Json(())) +} + #[get("/zones")] pub async fn get_zones( conn: DbConn, @@ -83,7 +168,6 @@ pub async fn create_zone( ) -> Result, ErrorResponse> { user_info?.check_admin()?; - // Check if the zone exists in the DNS server let response = { let query = client.query(zone_request.name.clone(), DNSClass::IN, RecordType::SOA);