This commit is contained in:
Hannaeko 2024-12-15 21:21:03 +01:00
parent efadd4dda0
commit 39cef3b600
27 changed files with 2764 additions and 2206 deletions

2333
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,35 +1,26 @@
[package]
name = "nomilo"
version = "0.1.0-dev"
authors = ["DNS Witch Collective <dns-witch@familier.net.eu.org>"]
version = "0.2.0-dev"
authors = ["DNS Witch Collective <dns-witch@dns-witch.eu.org>"]
edition = "2021"
license = "AGPL-3.0-or-later"
readme = "README.md"
repository = "https://git.familier.net.eu.org/dns-witch/nomilo"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
trust-dns-client = { version = "0.22", features = ["dnssec-openssl"] }
trust-dns-proto = "0.22"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
rocket = { version = "0.5.0-rc.2", features = ["json"], default-features = false }
rocket_sync_db_pools = { default-features = false, features = ["diesel_sqlite_pool"], version = "0.1.0-rc.2"}
base64 = "0.21"
uuid = { version = "0.8", features = ["v4", "serde"] }
diesel = { version = "1.4", features = ["sqlite", "chrono"] }
diesel_migrations = "1.4"
diesel-derive-enum = { version = "1", features = ["sqlite"] }
chrono = { version = "0.4", features = ["serde"] }
humantime = "2.1"
tokio = "1"
figment = { version = "0.10", features = ["toml", "env"] }
clap = {version = "3", features = ["derive", "cargo"]}
argon2 = {version = "0.4", default-features = false, features = ["alloc", "password-hash"] }
rand = "0.8"
tera = {version = "1", default-features = false}
# From trust-dns-client
futures-util = { version = "0.3", default-features = false, features = ["std"] }
# From rocket / cookie-rs
time = "0.3"
#uuid = { version = "1.11", features = ["v4", "serde"] }
#chrono = { version = "0.4", features = ["serde"] }
#humantime = "2.1"
tokio = {version = "1", default-features = false, features = [ "macros", "rt-multi-thread" ] }
#clap = { version = "4", features = [ "derive", "cargo" ] }
#argon2 = { version = "0.5", default-features = false, features = ["alloc", "password-hash"] }
#rand = "0.8"
#tera = { version = "1", default-features = false }
domain = { version = "0.10.3", features = [ "tsig", "unstable-client-transport" ]}
axum = { version = "0.8.0-alpha.1", default-features = false, features = [ "http1", "json", "form", "query", "tokio" ]}
bb8 = { version = "0.9" }
rusqlite = { version = "0.32"}
async-trait = { version = "0.1" }

View file

@ -1,8 +1,16 @@
services:
knot:
image: cznic/knot
image: cznic/knot:3.4
volumes:
- ./zones:/storage/zones:ro
- ./config:/config:ro
command: knotd
command: knotd --verbose
network_mode: host
named:
image: internetsystemsconsortium/bind9:9.20
volumes:
- ./zones:/var/lib/bind:ro
- ./config:/etc/bind:ro
#command: named -g
network_mode: host

View file

@ -7,6 +7,7 @@ example.com. IN SOA ns.example.com. admin.example.com. (
)
example.com. 84600 IN NS ns.example.com.
ns.example.com. 84600 IN A 198.51.100.3
srv1.example.com. 600 IN A 198.51.100.3
srv1.example.com. 600 IN AAAA 2001:db8:cafe:bc68::2

65
src/database.rs Normal file
View file

@ -0,0 +1,65 @@
use std::sync::Arc;
use crate::ressouces::zone::ZoneModel;
pub trait Db: ZoneModel + Send + Sync {}
pub type BoxedDb = Arc<dyn Db>;
impl Db for sqlite::SqliteDB {}
pub mod sqlite {
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Clone)]
pub struct SqliteDB {
pub pool: bb8::Pool<SqliteConnManager>
}
impl SqliteDB {
pub async fn new(path: PathBuf) -> Self {
let pool = bb8::Pool::builder()
.build(SqliteConnManager::new(path))
.await
.expect("Unable to connect to database");
SqliteDB {
pool,
}
}
}
#[derive(Clone)]
pub struct SqliteConnManager {
path: Arc<PathBuf>
}
impl SqliteConnManager {
pub fn new(path: PathBuf) -> Self {
SqliteConnManager {
path: Arc::new(path)
}
}
}
impl bb8::ManageConnection for SqliteConnManager {
type Connection = rusqlite::Connection;
type Error = rusqlite::Error;
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
let opt = self.clone();
tokio::task::spawn_blocking(move || {
rusqlite::Connection::open(opt.path.as_ref())
}).await.unwrap()
}
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
tokio::task::block_in_place(|| conn.execute_batch(""))
}
fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
false
}
}
}

View file

@ -1,271 +0,0 @@
use trust_dns_proto::DnsHandle;
use trust_dns_client::client::ClientHandle;
use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode, Edns};
use trust_dns_client::error::ClientError;
use super::{Name, Record, RData};
use super::client::{ClientResponse, DnsClient};
use super::connector::{RecordConnector, ZoneConnector, ConnectorError, ConnectorResult};
const MAX_PAYLOAD_LEN: u16 = 1232;
#[derive(Debug)]
pub enum DnsConnectorError {
ClientError(ClientError),
ResponceNotOk {
code: ResponseCode,
zone: Name,
},
}
pub struct DnsConnectorClient {
client: DnsClient
}
impl DnsConnectorClient {
pub fn new(client: DnsClient) -> Self {
DnsConnectorClient {
client
}
}
}
impl ConnectorError for DnsConnectorError {
fn zone_name(&self) -> Option<Name> {
if let DnsConnectorError::ResponceNotOk { code: _code, zone } = self {
Some(zone.clone())
} else {
None
}
}
fn is_proto_error(&self) -> bool {
return matches!(self, DnsConnectorError::ClientError(_));
}
}
impl std::fmt::Display for DnsConnectorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DnsConnectorError::ClientError(e) => {
write!(f, "DNS client error: {}", e)
},
DnsConnectorError::ResponceNotOk { code, zone } => {
write!(f, "Query for zone \"{}\" failed with code \"{}\"", zone, code)
}
}
}
}
fn set_edns(message: &mut Message) {
let edns = message.extensions_mut().get_or_insert_with(Edns::new);
edns.set_max_payload(MAX_PAYLOAD_LEN);
edns.set_version(0);
}
#[async_trait]
impl RecordConnector for DnsConnectorClient {
//type Error = DnsConnectorError;
async fn get_records(&mut self, zone: Name, class: DNSClass) -> ConnectorResult<Vec<Record>>
{
let response = {
let query = self.client.query(zone.clone(), class, RecordType::AXFR);
match query.await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
}
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter()
.filter(|record| record.data().is_some() && !matches!(record.data().unwrap(), RData::NULL { .. } | RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(records)
}
async fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> ConnectorResult<()>
{
// Taken from trust_dns_client::op::update_message::append
// The original function can not be used as is because it takes a RecordSet and not a Record list
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
message.add_updates(new_records);
set_edns(&mut message);
let response = match ClientResponse(self.client.send(message)).await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
async fn update_records(&mut self, zone: Name, class: DNSClass, old_records: Vec<Record>, new_records: Vec<Record>) -> ConnectorResult<()>
{
// Taken from trust_dns_client::op::update_message::compare_and_swap
// The original function can not be used as is because it takes a RecordSet and not a Record list
// for updates, the query section is used for the zone
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
// make sure the record is what is expected
let mut prerequisite = old_records.clone();
for record in prerequisite.iter_mut() {
record.set_ttl(0);
}
message.add_pre_requisites(prerequisite);
// add the delete for the old record
let mut delete = old_records;
for record in delete.iter_mut() {
// the class must be none for delete
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// insert the new record...
message.add_updates(new_records);
// Extended dns
set_edns(&mut message);
let response = match ClientResponse(self.client.send(message)).await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
async fn delete_records(&mut self, zone: Name, class: DNSClass, records: Vec<Record>) -> ConnectorResult<()>
{
// for updates, the query section is used for the zone
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
let mut delete = records;
for record in delete.iter_mut() {
// the class must be none for delete
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// Extended dns
set_edns(&mut message);
let response = match ClientResponse(self.client.send(message)).await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
}
#[async_trait]
impl ZoneConnector for DnsConnectorClient {
async fn zone_exists(&mut self, zone: Name, class: DNSClass) -> ConnectorResult<()>
{
let response = {
info!("Querying SOA for name {}", zone);
let query = self.client.query(zone.clone(), class, RecordType::SOA);
match query.await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
}
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
}

569
src/dns/dns_driver.rs Normal file
View file

@ -0,0 +1,569 @@
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use std::sync::Arc;
use domain::base::iana::{Opcode, Rcode};
use domain::base::{message_builder, name, wire};
use domain::base::{MessageBuilder, Name, Rtype};
use domain::net::client::{tsig, stream};
use domain::tsig::{Algorithm, Key, KeyName};
use domain::net::client::request::{self, RequestMessage, RequestMessageMulti, SendRequest, SendRequestMulti};
use tokio::net::TcpStream;
use super::{rdata, record};
use super::{RecordDriver, ZoneDriver, DnsDriverError};
use crate::errors::Error;
use async_trait::async_trait;
pub struct DnsDriverConfig {
pub address: SocketAddr,
pub tsig: Option<TsigConfig>
}
pub struct TsigConfig {
pub key_name: KeyName,
pub secret: Vec<u8>,
pub algorithm: Algorithm
}
#[derive(Clone)]
pub struct DnsDriver {
pub addr: SocketAddr,
pub tsig_key: Option<Arc<Key>>,
}
type TsigDnsClient = tsig::Connection<stream::Connection<
tsig::RequestMessage<request::RequestMessage<Vec<u8>>, Arc<Key>>,
tsig::RequestMessage<request::RequestMessageMulti<Vec<u8>>, Arc<Key>>
>, Arc<Key>>;
impl DnsDriver {
pub fn from_config(config: DnsDriverConfig) -> Self {
let key = config.tsig.map(|tsig_config| {
Arc::new(
Key::new(
tsig_config.algorithm,
&tsig_config.secret,
tsig_config.key_name,
None,
None
).expect("Failed to build key"),
)
});
Self {
addr: config.address,
tsig_key: key,
}
}
async fn client<Req, ReqMulti>(&self) -> Result<stream::Connection<Req, ReqMulti>, DnsDriverError>
where
Req: request::ComposeRequest + Send + Sync + 'static,
ReqMulti: request::ComposeRequestMulti + Send + Sync + 'static,
{
let mut stream_config = stream::Config::default();
stream_config.set_response_timeout(
Duration::from_millis(100),
);
let tcp_connect = TcpStream::connect(self.addr).await?;
let (tcp_conn, transport) = stream::Connection::with_config(
tcp_connect, stream_config
);
tokio::spawn(transport.run());
Ok(tcp_conn)
}
async fn tsig_client(&self) -> Result<Option<TsigDnsClient>, DnsDriverError>
{
if let Some(ref key) = self.tsig_key {
let conn = self.client().await?;
Ok(Some(tsig::Connection::new(key.clone(), conn)))
} else {
Ok(None)
}
}
}
#[async_trait]
impl ZoneDriver for DnsDriver {
async fn zone_exists(&self, zone: &str) -> Result<(), DnsDriverError> {
let client = self.client::<_, RequestMessageMulti<Vec<u8>>>().await?;
let mut msg = MessageBuilder::new_vec().question();
msg.push((
Name::vec_from_str(zone)?,
Rtype::SOA,
))?;
let req = RequestMessage::new(msg)?;
let res = SendRequest::send_request(&client, req)
.get_response()
.await?;
let rcode = res.header().rcode();
match rcode {
Rcode::NOERROR => Ok(()),
Rcode::NXDOMAIN | Rcode::REFUSED => Err(DnsDriverError::ZoneNotFound {
name: zone.to_string(),
}),
rcode => Err(DnsDriverError::ServerError {
rcode: rcode.to_string(),
name: zone.to_string(),
qtype: Rtype::SOA.to_string()
})
}
}
}
#[async_trait]
impl RecordDriver for DnsDriver {
/// ------------- AXFR -------------
async fn get_records(&self, zone: &str) -> Result<Vec<record::Record>, DnsDriverError> {
let mut msg = MessageBuilder::new_vec();
msg.header_mut().set_ad(true);
let mut msg = msg.question();
msg.push((
Name::vec_from_str(zone)?,
Rtype::AXFR,
))?;
let req = RequestMessageMulti::new(msg)?;
let tsig_client = self.tsig_client().await?;
let mut request = if let Some(client) = tsig_client {
SendRequestMulti::send_request(&client, req)
} else {
let client = self.client::<RequestMessage<Vec<u8>>,_>().await?;
SendRequestMulti::send_request(&client, req)
};
let mut records = Vec::new();
while let Some(reply) = request.get_response().await? {
let rcode = reply.header().rcode();
if rcode != Rcode::NOERROR {
return Err(DnsDriverError::ServerError {
rcode: rcode.to_string(),
name: zone.to_string(),
qtype: Rtype::AXFR.to_string()
});
}
let answer = reply.answer()?;
for record in answer.limit_to::<rdata::ParsedRData<_, _>>() {
let record = record?;
records.push(record.into())
}
}
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(records)
}
/// ------------- Dynamic Update - RFC 2136 -------------
///
/// 2 - Update Message Format
/// +---------------------+
/// | Header |
/// +---------------------+
/// | Zone | specifies the zone to be updated (RFC1035 Question)
/// +---------------------+
/// | Prerequisite | RRs or RRsets which must (not) preexist (RFC1035 Answer)
/// +---------------------+
/// | Update | RRs or RRsets to be added or deleted (RFC1035 Authority)
/// +---------------------+
/// | Additional Data | additional data
/// +---------------------+
/// 2.2 - Message Header
///
/// OPCODE is set to UPDATE.
/// UPDATE uses only one flag bit (QR).
///
/// 2.3 - Zone Section
///
/// The ZNAME is the zone name, the ZTYPE must be SOA, and the ZCLASS is
/// the zone's class.
///
/// 3.2.4 - Table Of Metavalues Used In Prerequisite Section
///
/// TTL must be specified as zero (0) for all prerequisite
///
/// CLASS TYPE RDATA Meaning
/// ------------------------------------------------------------
/// ANY ANY empty Name is in use
/// ANY rrset empty RRset exists (value independent)
/// NONE ANY empty Name is not in use
/// NONE rrset empty RRset does not exist
/// zone rrset rr RRset exists (value dependent) - Match against ALL RR in a RRset!!
///
/// 3.4.2.6 - Table Of Metavalues Used In Update Section
///
/// CLASS TYPE RDATA Meaning
/// ---------------------------------------------------------
/// ANY ANY empty Delete all RRsets from a name - TTL must be specified as zero (0)
/// ANY rrset empty Delete an RRset - TTL must be specified as zero (0)
/// NONE rrset rr Delete an RR from an RRset - TTL must be specified as zero (0)
/// zone rrset rr Add to an RRset
async fn add_records(&self, zone: &str, new_records: &[record::DnsRecordImpl]) -> Result<(), DnsDriverError> {
let mut msg = MessageBuilder::new_vec();
msg.header_mut().set_opcode(Opcode::UPDATE);
let mut msg = msg.question();
msg.push((
Name::vec_from_str(zone)?,
Rtype::SOA,
))?;
let mut msg = msg.authority();
for record in new_records {
msg.push(record)?;
}
let req = RequestMessage::new(msg)?;
let tsig_client = self.tsig_client().await?;
let mut request = if let Some(client) = tsig_client {
SendRequest::send_request(&client, req)
} else {
let client = self.client::<_, RequestMessageMulti<Vec<u8>>>().await?;
SendRequest::send_request(&client, req)
};
let reply = request.get_response().await?;
let rcode = reply.header().rcode();
if rcode != Rcode::NOERROR {
Err(DnsDriverError::ServerError {
rcode: rcode.to_string(),
name: zone.to_string(),
qtype: "UPDATE".to_string(),
})
} else {
Ok(())
}
}
}
impl From<io::Error> for DnsDriverError {
fn from(value: io::Error) -> Self {
DnsDriverError::ConnectionError { reason: Box::new(value) }
}
}
impl From<request::Error> for DnsDriverError {
fn from(value: request::Error) -> Self {
DnsDriverError::ConnectionError { reason: Box::new(value) }
}
}
impl From<message_builder::PushError> for DnsDriverError {
fn from(value: message_builder::PushError) -> Self {
DnsDriverError::OperationError { reason: Box::new(value) }
}
}
impl From<name::FromStrError> for DnsDriverError {
fn from(value: name::FromStrError) -> Self {
DnsDriverError::OperationError { reason: Box::new(value) }
}
}
impl From<wire::ParseError> for DnsDriverError {
fn from(value: wire::ParseError) -> Self {
DnsDriverError::OperationError { reason: Box::new(value) }
}
}
/*
use trust_dns_proto::DnsHandle;
use trust_dns_client::client::ClientHandle;
use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode, Edns};
use trust_dns_client::error::ClientError;
use super::{Name, Record, RData};
use super::client::{ClientResponse, DnsClient};
use super::connector::{RecordConnector, ZoneConnector, ConnectorError, ConnectorResult};
const MAX_PAYLOAD_LEN: u16 = 1232;
#[derive(Debug)]
pub enum DnsConnectorError {
ClientError(ClientError),
ResponceNotOk {
code: ResponseCode,
zone: Name,
},
}
pub struct DnsConnectorClient {
client: DnsClient
}
impl DnsConnectorClient {
pub fn new(client: DnsClient) -> Self {
DnsConnectorClient {
client
}
}
}
impl ConnectorError for DnsConnectorError {
fn zone_name(&self) -> Option<Name> {
if let DnsConnectorError::ResponceNotOk { code: _code, zone } = self {
Some(zone.clone())
} else {
None
}
}
fn is_proto_error(&self) -> bool {
return matches!(self, DnsConnectorError::ClientError(_));
}
}
impl std::fmt::Display for DnsConnectorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DnsConnectorError::ClientError(e) => {
write!(f, "DNS client error: {}", e)
},
DnsConnectorError::ResponceNotOk { code, zone } => {
write!(f, "Query for zone \"{}\" failed with code \"{}\"", zone, code)
}
}
}
}
fn set_edns(message: &mut Message) {
let edns = message.extensions_mut().get_or_insert_with(Edns::new);
edns.set_max_payload(MAX_PAYLOAD_LEN);
edns.set_version(0);
}
#[async_trait]
impl RecordConnector for DnsConnectorClient {
//type Error = DnsConnectorError;
async fn get_records(&mut self, zone: Name, class: DNSClass) -> ConnectorResult<Vec<Record>>
{
let response = {
let query = self.client.query(zone.clone(), class, RecordType::AXFR);
match query.await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
}
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter()
.filter(|record| record.data().is_some() && !matches!(record.data().unwrap(), RData::NULL { .. } | RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(records)
}
async fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> ConnectorResult<()>
{
// Taken from trust_dns_client::op::update_message::append
// The original function can not be used as is because it takes a RecordSet and not a Record list
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
message.add_updates(new_records);
set_edns(&mut message);
let response = match ClientResponse(self.client.send(message)).await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
async fn update_records(&mut self, zone: Name, class: DNSClass, old_records: Vec<Record>, new_records: Vec<Record>) -> ConnectorResult<()>
{
// Taken from trust_dns_client::op::update_message::compare_and_swap
// The original function can not be used as is because it takes a RecordSet and not a Record list
// for updates, the query section is used for the zone
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
// make sure the record is what is expected
let mut prerequisite = old_records.clone();
for record in prerequisite.iter_mut() {
record.set_ttl(0);
}
message.add_pre_requisites(prerequisite);
// add the delete for the old record
let mut delete = old_records;
for record in delete.iter_mut() {
// the class must be none for delete
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// insert the new record...
message.add_updates(new_records);
// Extended dns
set_edns(&mut message);
let response = match ClientResponse(self.client.send(message)).await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
async fn delete_records(&mut self, zone: Name, class: DNSClass, records: Vec<Record>) -> ConnectorResult<()>
{
// for updates, the query section is used for the zone
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
let mut delete = records;
for record in delete.iter_mut() {
// the class must be none for delete
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// Extended dns
set_edns(&mut message);
let response = match ClientResponse(self.client.send(message)).await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
}
#[async_trait]
impl ZoneConnector for DnsConnectorClient {
async fn zone_exists(&mut self, zone: Name, class: DNSClass) -> ConnectorResult<()>
{
let response = {
info!("Querying SOA for name {}", zone);
let query = self.client.query(zone.clone(), class, RecordType::SOA);
match query.await.map_err(|e| Box::new(DnsConnectorError::ClientError(e))) {
Err(e) => return Err(e),
Ok(v) => v,
}
};
if response.response_code() != ResponseCode::NoError {
return Err(Box::new(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
}));
}
Ok(())
}
}
*/

View file

@ -1,3 +1,37 @@
pub mod rdata;
pub mod record;
pub mod dns_driver;
use std::sync::Arc;
use async_trait::async_trait;
pub type BoxedZoneDriver = Arc<dyn ZoneDriver>;
pub enum DnsDriverError {
ConnectionError { reason: Box<dyn std::error::Error> },
OperationError { reason: Box<dyn std::error::Error> },
ServerError { rcode: String, name: String, qtype: String },
ZoneNotFound { name: String },
}
#[async_trait]
pub trait ZoneDriver: Send + Sync {
// get_zones
// add_zone
// delete_zone
async fn zone_exists(&self, zone: &str) -> Result<(), DnsDriverError>;
}
#[async_trait]
pub trait RecordDriver: Send + Sync {
async fn get_records(&self, zone: &str) -> Result<Vec<record::Record>, DnsDriverError>;
async fn add_records(&self, zone: &str, new_records: &[record::DnsRecordImpl]) -> Result<(), DnsDriverError>;
//async fn update_records(&mut self, zone: dns::Name, class: dns::DNSClass, old_records: Vec<dns::Record>, new_records: Vec<dns::Record>) -> ConnectorResult<()>;
//async fn delete_records(&mut self, zone: dns::Name, class: dns::DNSClass, records: Vec<dns::Record>) -> ConnectorResult<()>;
}
/*
pub mod client;
pub mod dns_connector;
pub mod connector;
@ -53,3 +87,4 @@ impl<'r> FromRequest<'r> for Box<dyn ZoneConnector> {
}
}
}
*/

528
src/dns/rdata.rs Normal file
View file

@ -0,0 +1,528 @@
use std::fmt::Write;
use std::net::{Ipv4Addr, Ipv6Addr};
use domain::base::rdata::ComposeRecordData;
use domain::base::scan::Symbol;
use domain::base::wire::{Composer, ParseError};
use domain::base::{Name, ParseRecordData, ParsedName, RecordData, Rtype, ToName, Ttl};
use domain::rdata;
use domain::dep::octseq::{Parser, Octets};
use serde::{Deserialize, Serialize};
use crate::errors::Error;
use crate::validation;
use crate::macros::{append_errors, push_error};
use super::record::RecordParseError;
/// Type used to serialize / deserialize resource records data to response / request
///
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "rdata")]
#[serde(rename_all = "UPPERCASE")]
pub enum RData {
A(A),
Aaaa(Aaaa),
// TODO: CAA
Cname(Cname),
// TODO: DS
Mx(Mx),
Ns(Ns),
Ptr(Ptr),
Soa(Soa),
Srv(Srv),
// TODO: SSHFP
// TODO: SVCB / HTTPS
// TODO: TLSA
Txt(Txt),
}
pub enum ParsedRData<Name, Octs> {
A(rdata::A),
Aaaa(rdata::Aaaa),
Cname(rdata::Cname<Name>),
Mx(rdata::Mx<Name>),
Ns(rdata::Ns<Name>),
Ptr(rdata::Ptr<Name>),
Soa(rdata::Soa<Name>),
Srv(rdata::Srv<Name>),
Txt(rdata::Txt<Octs>),
}
impl<Name: ToString, Octs: AsRef<[u8]>> From<ParsedRData<Name, Octs>> for RData {
fn from(value: ParsedRData<Name, Octs>) -> Self {
match value {
ParsedRData::A(record_rdata) => RData::A(record_rdata.into()),
ParsedRData::Aaaa(record_rdata) => RData::Aaaa(record_rdata.into()),
ParsedRData::Cname(record_rdata) => RData::Cname(record_rdata.into()),
ParsedRData::Mx(record_rdata) => RData::Mx(record_rdata.into()),
ParsedRData::Ns(record_rdata) => RData::Ns(record_rdata.into()),
ParsedRData::Ptr(record_rdata) => RData::Ptr(record_rdata.into()),
ParsedRData::Soa(record_rdata) => RData::Soa(record_rdata.into()),
ParsedRData::Srv(record_rdata) => RData::Srv(record_rdata.into()),
ParsedRData::Txt(record_rdata) => RData::Txt(record_rdata.into()),
}
}
}
impl TryFrom<RData> for ParsedRData<Name<Vec<u8>>, Vec<u8>> {
type Error = Vec<Error>;
fn try_from(value: RData) -> Result<Self, Self::Error> {
let rdata = match value {
RData::A(record_rdata) => ParsedRData::A(record_rdata.parse_record()?),
RData::Aaaa(record_rdata) => ParsedRData::Aaaa(record_rdata.parse_record()?),
RData::Cname(record_rdata) => ParsedRData::Cname(record_rdata.parse_record()?),
RData::Mx(record_rdata) => ParsedRData::Mx(record_rdata.parse_record()?),
RData::Ns(record_rdata) => ParsedRData::Ns(record_rdata.parse_record()?),
RData::Ptr(record_rdata) => ParsedRData::Ptr(record_rdata.parse_record()?),
RData::Soa(record_rdata) => ParsedRData::Soa(record_rdata.parse_record()?),
RData::Srv(record_rdata) => ParsedRData::Srv(record_rdata.parse_record()?),
RData::Txt(record_rdata) => ParsedRData::Txt(record_rdata.parse_record()?),
};
Ok(rdata)
}
}
macro_rules! parse_name {
($value:expr, $field:ident, $rtype:literal, $errors:expr) => {
{
let name = push_error!(
validation::normalize_domain(&$value.$field),
$errors, concat!("/", stringify!($field))
);
let name = name.and_then(|name| {
push_error!(
name.parse::<Name<_>>().map_err(|e| {
Error::from(RecordParseError::RDataUnknown {
input: $value.$field,
field: stringify!(field).to_string(),
rtype: $rtype.to_string(),
}).with_cause(&e.to_string())
}),
$errors, concat!("/", stringify!($field))
)
});
name
}
};
}
/* --------- A --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct A {
pub address: String,
}
impl From<rdata::A> for A {
fn from(record_data: rdata::A) -> Self {
A { address: record_data.addr().to_string() }
}
}
impl A {
pub fn parse_record(self) -> Result<rdata::A, Vec<Error>> {
let mut errors = Vec::new();
let address = push_error!(self.address.parse::<Ipv4Addr>().map_err(|e| {
Error::from(RecordParseError::Ip4Address { input: self.address })
.with_cause(&e.to_string())
.with_path("/address")
}), errors);
if errors.is_empty() {
Ok(rdata::A::new(address.unwrap()))
} else {
Err(errors)
}
}
}
/* --------- AAAA --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Aaaa {
pub address: String,
}
impl From<rdata::Aaaa> for Aaaa {
fn from(record_data: rdata::Aaaa) -> Self {
Aaaa { address: record_data.addr().to_string() }
}
}
impl Aaaa {
pub fn parse_record(self) -> Result<rdata::Aaaa, Vec<Error>> {
let mut errors = Vec::new();
let address = push_error!(self.address.parse::<Ipv6Addr>().map_err(|e| {
Error::from(RecordParseError::Ip6Address { input: self.address })
.with_cause(&e.to_string())
.with_path("/address")
}), errors);
if errors.is_empty() {
Ok(rdata::Aaaa::new(address.unwrap()))
} else {
Err(errors)
}
}
}
/* --------- CNAME --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Cname {
pub target: String,
}
impl<N: ToString> From<rdata::Cname<N>> for Cname {
fn from(record_data: rdata::Cname<N>) -> Self {
Cname { target: record_data.cname().to_string() }
}
}
impl Cname {
pub fn parse_record(self) -> Result<rdata::Cname<Name<Vec<u8>>>, Vec<Error>> {
let mut errors = Vec::new();
let cname = parse_name!(self, target, "CNAME", errors);
if errors.is_empty() {
Ok(rdata::Cname::new(cname.unwrap()))
} else {
Err(errors)
}
}
}
/* --------- MX --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Mx {
pub preference: u16,
pub mail_exchanger: String,
}
impl<N: ToString> From<rdata::Mx<N>> for Mx {
fn from(record_data: rdata::Mx<N>) -> Self {
Mx {
preference: record_data.preference(),
mail_exchanger: record_data.exchange().to_string()
}
}
}
impl Mx {
fn parse_record(self) -> Result<rdata::Mx<Name<Vec<u8>>>, Vec<Error>> {
let mut errors = Vec::new();
let mail_exchanger = parse_name!(self, mail_exchanger, "MX", errors);
if errors.is_empty() {
Ok(rdata::Mx::new(self.preference, mail_exchanger.unwrap()))
} else {
Err(errors)
}
}
}
/* --------- NS --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Ns {
pub target: String,
}
impl<N: ToString> From<rdata::Ns<N>> for Ns {
fn from(record_rdata: rdata::Ns<N>) -> Self {
Ns {
target: record_rdata.nsdname().to_string(),
}
}
}
impl Ns {
fn parse_record(self) -> Result<rdata::Ns<Name<Vec<u8>>>, Vec<Error>> {
let mut errors = Vec::new();
let ns_name = parse_name!(self, target, "NS", errors);
if errors.is_empty() {
Ok(rdata::Ns::new(ns_name.unwrap()))
} else {
Err(errors)
}
}
}
/* --------- PTR --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Ptr {
pub target: String,
}
impl<N: ToString> From<rdata::Ptr<N>> for Ptr {
fn from(record_rdata: rdata::Ptr<N>) -> Self {
Ptr {
target: record_rdata.ptrdname().to_string(),
}
}
}
impl Ptr {
fn parse_record(self) -> Result<rdata::Ptr<Name<Vec<u8>>>, Vec<Error>> {
let mut errors = Vec::new();
let ptr_name = parse_name!(self, target, "PTR", errors);
if errors.is_empty() {
Ok(rdata::Ptr::new(ptr_name.unwrap()))
} else {
Err(errors)
}
}
}
/* --------- SOA --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Soa {
pub primary_server: String,
pub maintainer: String,
pub refresh: u32,
pub retry: u32,
pub expire: u32,
pub minimum: u32,
pub serial: u32,
}
impl<N: ToString> From<rdata::Soa<N>> for Soa {
fn from(record_rdata: rdata::Soa<N>) -> Self {
Soa {
primary_server: record_rdata.mname().to_string(),
maintainer: record_rdata.rname().to_string(),
refresh: record_rdata.refresh().as_secs(),
retry: record_rdata.retry().as_secs(),
expire: record_rdata.expire().as_secs(),
minimum: record_rdata.minimum().as_secs(),
serial: record_rdata.serial().into(),
}
}
}
impl Soa {
fn parse_record(self) -> Result<rdata::Soa<Name<Vec<u8>>>, Vec<Error>> {
let mut errors = Vec::new();
let primary_ns = parse_name!(self, primary_server, "SOA", errors);
let maintainer = parse_name!(self, maintainer, "SOA", errors);
if errors.is_empty() {
Ok(rdata::Soa::new(
primary_ns.unwrap(),
maintainer.unwrap(),
self.refresh.into(),
Ttl::from_secs(self.retry),
Ttl::from_secs(self.expire),
Ttl::from_secs(self.minimum),
Ttl::from_secs(self.serial),
))
} else {
Err(errors)
}
}
}
/* --------- SRV --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Srv {
pub server: String,
pub port: u16,
pub priority: u16,
pub weight: u16,
}
impl<N: ToString> From<rdata::Srv<N>> for Srv {
fn from(record_data: rdata::Srv<N>) -> Self {
Srv {
server: record_data.target().to_string(),
priority: record_data.priority(),
weight: record_data.weight(),
port: record_data.port(),
}
}
}
impl Srv {
fn parse_record(self) -> Result<rdata::Srv<Name<Vec<u8>>>, Vec<Error>> {
let mut errors = Vec::new();
let server = parse_name!(self, server, "SRV", errors);
if errors.is_empty() {
Ok(rdata::Srv::new(
self.priority,
self.weight,
self.port,
server.unwrap(),
))
} else {
Err(errors)
}
}
}
/* --------- TXT --------- */
#[derive(Debug, Deserialize, Serialize)]
pub struct Txt {
pub text: String,
}
impl<O: AsRef<[u8]>> From<rdata::Txt<O>> for Txt {
fn from(record_data: rdata::Txt<O>) -> Self {
let mut concatenated_text = String::new();
for text in record_data.iter() {
for c in text {
// Escapes '\' and non printable chars
let c = Symbol::display_from_octet(*c);
write!(concatenated_text, "{}", c).unwrap();
}
}
Txt {
text: concatenated_text
}
}
}
impl Txt {
fn parse_record(self) -> Result<rdata::Txt<Vec<u8>>, Vec<Error>> {
let mut errors = Vec::new();
let data = append_errors!(validation::parse_txt_data(&self.text), errors, "/text");
let data = data.and_then(|data| {
push_error!(rdata::Txt::build_from_slice(&data).map_err(|e| {
Error::from(RecordParseError::RDataUnknown {
input: self.text,
field: "text".into(),
rtype: "TXT".into(),
}).with_cause(&e.to_string())
.with_path("/text")
}), errors)
});
if errors.is_empty() {
Ok(data.unwrap())
} else {
Err(errors)
}
}
}
/* --------- ParsedRData: domain traits impl --------- */
impl<Name, Octs> ParsedRData<Name, Octs> {
pub fn rtype(&self) -> Rtype {
match self {
ParsedRData::A(_) => Rtype::A,
ParsedRData::Aaaa(_) => Rtype::AAAA,
ParsedRData::Cname(_) => Rtype::CNAME,
ParsedRData::Mx(_) => Rtype::MX,
ParsedRData::Ns(_) => Rtype::NS,
ParsedRData::Ptr(_) => Rtype::PTR,
ParsedRData::Soa(_) => Rtype::SOA,
ParsedRData::Srv(_) => Rtype::SRV,
ParsedRData::Txt(_) => Rtype::TXT,
}
}
}
impl<Name, Octs> RecordData for ParsedRData<Name, Octs> {
fn rtype(&self) -> Rtype {
ParsedRData::rtype(self)
}
}
impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> for ParsedRData<ParsedName<Octs::Range<'a>>, Octs::Range<'a>> {
fn parse_rdata(
rtype: Rtype,
parser: &mut Parser<'a, Octs>,
) -> Result<Option<Self>, ParseError> {
let record = match rtype {
Rtype::A => ParsedRData::A(rdata::A::parse(parser)?),
Rtype::AAAA => ParsedRData::Aaaa(rdata::Aaaa::parse(parser)?),
Rtype::CNAME => ParsedRData::Cname(rdata::Cname::parse(parser)?),
Rtype::MX => ParsedRData::Mx(rdata::Mx::parse(parser)?),
Rtype::NS => ParsedRData::Ns(rdata::Ns::parse(parser)?),
Rtype::PTR => ParsedRData::Ptr(rdata::Ptr::parse(parser)?),
Rtype::SOA => ParsedRData::Soa(rdata::Soa::parse(parser)?),
Rtype::SRV => ParsedRData::Srv(rdata::Srv::parse(parser)?),
Rtype::TXT => ParsedRData::Txt(rdata::Txt::parse(parser)?),
_ => return Ok(None)
};
Ok(Some(record))
}
}
impl<Name: ToName, Octs: AsRef<[u8]>> ComposeRecordData for ParsedRData<Name, Octs> {
fn rdlen(&self, compress: bool) -> Option<u16> {
match self {
ParsedRData::A(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Aaaa(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Cname(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Mx(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Ns(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Ptr(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Soa(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Srv(record_rdata) => record_rdata.rdlen(compress),
ParsedRData::Txt(record_rdata) => record_rdata.rdlen(compress),
}
}
fn compose_rdata<Target: Composer + ?Sized>(
&self,
target: &mut Target,
) -> Result<(), Target::AppendError> {
match self {
ParsedRData::A(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Aaaa(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Cname(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Mx(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Ns(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Ptr(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Soa(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Srv(record_rdata) => record_rdata.compose_rdata(target),
ParsedRData::Txt(record_rdata) => record_rdata.compose_rdata(target),
}
}
fn compose_canonical_rdata<Target: Composer + ?Sized>(
&self,
target: &mut Target,
) -> Result<(), Target::AppendError> {
match self {
ParsedRData::A(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Aaaa(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Cname(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Mx(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Ns(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Ptr(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Soa(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Srv(record_rdata) => record_rdata.compose_canonical_rdata(target),
ParsedRData::Txt(record_rdata) => record_rdata.compose_canonical_rdata(target),
}
}
}

128
src/dns/record.rs Normal file
View file

@ -0,0 +1,128 @@
use serde::{Deserialize, Serialize};
use domain::base::{iana::Class, Name, Record as DnsRecord, Ttl};
use crate::{errors::Error, validation};
use crate::macros::{append_errors, push_error};
use super::rdata::{ParsedRData, RData};
pub enum RecordParseError {
Ip4Address { input: String },
Ip6Address { input: String },
RDataUnknown { input: String, field: String, rtype: String },
NameUnknown { input: String },
NotInZone { name: String, zone: String },
}
pub enum RecordError {
Validation { suberrors: Vec<Error> },
}
pub(crate) type DnsRecordImpl = DnsRecord<
Name<Vec<u8>>,
ParsedRData<Name<Vec<u8>>,Vec<u8>>
>;
#[derive(Debug, Deserialize, Serialize)]
pub struct Record {
pub name: String,
pub ttl: u32,
#[serde(flatten)]
pub rdata: RData
}
impl<Name: ToString, Oct: AsRef<[u8]>> From<DnsRecord<Name, ParsedRData<Name, Oct>>> for Record {
fn from(value: DnsRecord<Name, ParsedRData<Name, Oct>>) -> Self {
Record {
name: value.owner().to_string(),
ttl: value.ttl().as_secs(),
rdata: value.into_data().into(),
}
}
}
impl Record {
fn convert(self, zone_name: &Name<Vec<u8>>) -> Result<DnsRecordImpl, Vec<Error>> {
let mut errors = Vec::new();
let name = push_error!(validation::normalize_domain(&self.name), errors, "/name");
let name = name.and_then(|name| push_error!(name.parse::<Name<_>>().map_err(|e| {
Error::from(RecordParseError::NameUnknown {
input: self.name.clone()
}).with_cause(&e.to_string())
}), errors, "/name"));
let name = name.and_then(|name| {
if !name.ends_with(zone_name) {
errors.push(
Error::from(RecordParseError::NotInZone { name: self.name, zone: zone_name.to_string() })
.with_path("/name")
);
None
} else {
Some(name)
}
});
let ttl = Ttl::from_secs(self.ttl);
let rdata = append_errors!(ParsedRData::try_from(self.rdata), errors, "/rdata");
if errors.is_empty() {
Ok(DnsRecord::new(name.unwrap(), Class::IN, ttl, rdata.unwrap()))
} else {
Err(errors)
}
}
}
#[derive(Debug, Deserialize)]
pub struct RecordList(Vec<Record>);
impl RecordList {
fn convert(self, zone_name: &Name<Vec<u8>>) -> Result<Vec<DnsRecordImpl>, Vec<Error>> {
let mut errors = Vec::new();
let mut records = Vec::new();
for (index, record) in self.0.into_iter().enumerate() {
let record = append_errors!(record.convert(zone_name), errors, &format!("/{index}"));
if let Some(record) = record {
records.push(record)
}
}
if errors.is_empty() {
Ok(records)
} else {
Err(errors)
}
}
}
#[derive(Debug,Deserialize)]
pub struct AddRecordsRequest {
pub new_records: RecordList
}
pub struct AddRecords {
pub new_records: Vec<DnsRecordImpl>
}
impl AddRecordsRequest {
pub fn validate(self, zone_name: &str) -> Result<AddRecords, Error> {
let zone_name: Name<Vec<u8>> = zone_name.parse().expect("zone name is assumed to be valid");
let mut errors = Vec::new();
let records = append_errors!(self.new_records.convert(&zone_name), errors, "/new_records");
if errors.is_empty() {
Ok(AddRecords {
new_records: records.unwrap(),
})
} else {
Err(Error::from(RecordError::Validation { suberrors: errors }))
}
}
}

370
src/errors.rs Normal file
View file

@ -0,0 +1,370 @@
use std::fmt;
use axum::http::{self, StatusCode};
use axum::response::{AppendHeaders, IntoResponse, Response};
use axum::Json;
use serde::{Serialize, Serializer};
use serde_json::{Value, json};
use crate::dns::{DnsDriverError, ZoneDriver};
use crate::dns::record::{RecordError, RecordParseError};
use crate::ressouces::zone::ZoneError;
use crate::validation::{DomainValidationError, TxtParseError};
#[derive(Debug, Serialize)]
pub struct Error {
#[serde(skip)]
cause: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_status")]
status: Option<StatusCode>,
code: String,
description: String,
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
errors: Option<Vec<Error>>,
}
pub fn serialize_status<S>(status: &Option<StatusCode>, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
if let Some(status) = status {
serializer.serialize_u16(status.as_u16())
} else {
serializer.serialize_unit()
}
}
impl Error {
pub fn new(code: &str, description: &str) -> Self {
Error {
cause: None,
status: None,
code: code.into(),
description: description.into(),
details: None,
path: None,
errors: None
}
}
pub fn with_cause(self, cause: &str) -> Self {
Self {
cause: Some(cause.into()),
..self
}
}
pub fn with_status(self, status: StatusCode) -> Self {
Self {
status: Some(status),
..self
}
}
pub fn with_path(self, path: &str) -> Self {
if let Some(current_path) = self.path {
Self {
path: Some(format!("{path}{current_path}")),
..self
}
} else {
Self {
path: Some(path.into()),
..self
}
}
}
pub fn with_details<T: Serialize> (self, details: T) -> Self {
let mut new_details = serde_json::to_value(details).expect("failed to convert details to serde_json::Value");
let details = self.details;
// append new details to existing details
if let Some(mut details) = details {
if let Some(object) = details.as_object_mut() {
if let Some(new_object) = new_details.as_object_mut() {
object.append(new_object);
return Self {
details: Some(details),
..self
}
}
}
}
Self {
details: Some(new_details),
..self
}
}
pub fn with_suberrors(self, mut errors: Vec<Error>) -> Self {
for error in &mut errors {
error.status = None;
}
Self {
errors: Some(errors),
..self
}
}
pub fn override_status(self, status: StatusCode) -> Self {
if self.status.is_some() {
self.with_status(status)
} else {
self
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.description)?;
if let Some(cause) = &self.cause {
write!(f, ": {}", cause)?;
}
if self.status.is_some() || self.details.is_some() {
write!(f, " (")?;
}
if let Some(status) = &self.status {
write!(f, "status = {}", status)?;
}
if let Some(details) = &self.details {
if self.status.is_some() {
write!(f, ", ")?;
}
write!(f, "details = {}", serde_json::to_string(details).expect("Failed to serialize error details"))?;
}
if self.status.is_some() || self.details.is_some() {
write!(f, ")")?;
}
Ok(())
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
if let Some(status) = self.status {
(status, Json(self)).into_response()
} else {
eprintln!("{}", self);
(
StatusCode::INTERNAL_SERVER_ERROR,
AppendHeaders([
(http::header::CONTENT_TYPE, "application/json")
]),
r#"{"status": 500,"description":"Internal server error","code":"internal"}"#
).into_response()
}
}
}
impl From<bb8::RunError<rusqlite::Error>> for Error {
fn from(value: bb8::RunError<rusqlite::Error>) -> Self {
Error::new("db:pool", "Failed to get database connection from pool")
.with_cause(&value.to_string())
}
}
impl From<rusqlite::Error> for Error {
fn from(value: rusqlite::Error) -> Self {
Error::new("db:sqlite", "Sqlite failure")
.with_cause(&format!("{:?}", value))
}
}
impl From<ZoneError> for Error {
fn from(value: ZoneError) -> Self {
match value {
ZoneError::ZoneConflict { name } => {
Error::new("zone:conflict", "Zone {zone_name} already exists")
.with_details(json!({
"zone_name": name
}))
.with_status(StatusCode::CONFLICT)
},
ZoneError::NotFound { name } => {
Error::new("zone:not_found", "The zone {zone_name} could not be found")
.with_details(json!({
"zone_name": name
}))
.with_status(StatusCode::NOT_FOUND)
},
ZoneError::Validation { suberrors } => {
Error::new("zone:validation", "Error while validating zone input data")
.with_suberrors(suberrors)
.with_status(StatusCode::BAD_REQUEST)
},
ZoneError::NotExistsNs { name } => {
Error::new("zone:not_exists_ns", "The zone {zone_name} does not exist on the name server")
.with_details(json!({
"zone_name": name
}))
.with_status(StatusCode::BAD_REQUEST)
}
}
}
}
impl From<DomainValidationError> for Error {
fn from(value: DomainValidationError) -> Self {
match value {
DomainValidationError::CharactersNotPermitted { label } => {
Error::new("domain:characters_not_permitted", "Domain name label {label} contains characters not permitted. The allowed characters are lowercase alphanumeric characters (a-z and 0-9), the dash ('-'), the underscore ('_') and the forward slash ('/').")
.with_details(json!({
"label": label
}))
},
DomainValidationError::EmptyDomain => {
Error::new("domain:empty_domain", "Domain name can not be empty or the root domain ('.')")
},
DomainValidationError::EmptyLabel => {
Error::new("domain:empty_label", "Domain name contains empty labels (repeated dots)")
},
DomainValidationError::DomainTooLong { length } => {
Error::new("domain:domain_too_long", "Domain name too long ({length} characters), the maximum length is 255 characters")
.with_details(json!({
"length": length
}))
},
DomainValidationError::LabelToolLong { length, label } => {
Error::new("domain:label_too_long", "Domain name label {label} is too long ({label_length} characters), the maximum length is 63 characters")
.with_details(json!({
"label": label,
"length": length,
}))
},
}
}
}
impl From<TxtParseError> for Error {
fn from(value: TxtParseError) -> Self {
match value {
TxtParseError::BadEscapeDigitIndexTooHigh { sequence } => {
Error::new("record:txt:parse:escape_decimal_index_too_high", "Octect escape sequence should be between 000 and 255. Offending escape sequence: \\{sequence}")
.with_details(json!({
"sequence": sequence
}))
},
TxtParseError::BadEscapeDigitsNotDigits { sequence } => {
Error::new("record:txt:parse:escape_decimal_not_digits", "Expected an octect escape sequence due to the presence of a back slash (\\) followed by a digit but found non digit characters. Offending escape sequence: \\{sequence}")
.with_details(json!({
"sequence": sequence
}))
},
TxtParseError::BadEscapeDigitsTooShort { sequence } => {
Error::new("record:txt:parse:escape_decimal_too_short", "Expected an octect escape sequence due to the presence of a back slash (\\) followed by a digit but found found {sequence_lenght} characters instead of three. Offending escape sequence: \\{sequence}")
.with_details(json!({
"sequence": sequence,
"sequence_lenght": sequence.len()
}))
},
TxtParseError::MissingEscape => {
Error::new("record:txt:parse:escape_missing", "Expected an escape sequence due to the presence of a back slash (\\) at the end of the input but found nothing")
},
TxtParseError::NonAscii { character } => {
Error::new("record:txt:parse:non_ascii", "Found a non ASCII character ({character}). Only printable ASCII characters are allowed.")
.with_details(json!({
"character": character
}))
}
}
}
}
impl From<DnsDriverError> for Error {
fn from(value: DnsDriverError) -> Self {
match value {
DnsDriverError::ConnectionError { reason } => {
Error::new("dns:connection", "Error while connecting to the name server")
.with_cause(&reason.to_string())
},
DnsDriverError::OperationError { reason } => {
Error::new("dns:operation", "DNS operation error")
.with_cause(&reason.to_string())
},
DnsDriverError::ServerError { rcode, name, qtype } => {
Error::new("dns:server", "Unexpected response to query")
.with_details(json!({
"rcode": rcode,
"name": name,
"qtype": qtype,
}))
},
DnsDriverError::ZoneNotFound { name } => {
Error::new("dns:zone_not_found", "The zone {zone_name} does not exist on the name server")
.with_details(json!({
"zone_name": name
}))
}
}
}
}
impl From<RecordParseError> for Error {
fn from(value: RecordParseError) -> Self {
match value {
RecordParseError::Ip4Address { input } => {
Error::new("record:parse:ip4", "The following IPv4 address {input} is invalid. IPv4 addresses should have four numbers, each between 0 and 255, separated by dots.")
.with_details(json!({
"input": input
}))
},
RecordParseError::Ip6Address { input } => {
Error::new("record:parse:ip6", "The following IPv4 address {input} is invalid. IPv6 addresses should have eight groups of four hexadecimal digit separated by colons. Leftmost zeros in a group can be omitted, sequence of zeros can be shorted by a double colons.")
.with_details(json!({
"input": input
}))
},
RecordParseError::RDataUnknown { input, field, rtype } => {
Error::new("record:parse:rdata_unknown", "Unknown error while parsing record rdata field")
.with_details(json!({
"input": input,
"field": field,
"rtype": rtype,
}))
},
RecordParseError::NameUnknown { input } => {
Error::new("record:parse:name_unknown", "Unknown error while parsing record name")
.with_details(json!({
"input": input
}))
},
RecordParseError::NotInZone { name, zone } => {
Error::new("record:parse:not_in_zone", "The domain name {name} is not in the current zone ({zone})")
.with_details(json!({
"name": name,
"zone": zone
}))
}
}
}
}
impl From<RecordError > for Error {
fn from(value: RecordError) -> Self {
match value {
RecordError::Validation { suberrors } => {
Error::new("record:validation", "Error while validating input records")
.with_suberrors(suberrors)
.with_status(StatusCode::BAD_REQUEST)
}
}
}
}

35
src/macros.rs Normal file
View file

@ -0,0 +1,35 @@
macro_rules! push_error {
($value:expr, $errors:expr) => {
match $value {
Err(error) => { $errors.push(error); None },
Ok(value) => Some(value)
}
};
($value:expr, $errors:expr, $path:expr) => {
match $value {
Err(error) => { $errors.push(error.with_path($path)); None },
Ok(value) => Some(value)
}
};
}
macro_rules! append_errors {
($value:expr, $errors:expr) => {
match $value {
Err(mut err) => { $errors.append(&mut err); None },
Ok(value) => Some(value)
}
};
($value:expr, $errors:expr, $path:expr) => {
match $value {
Err(err) => { $errors.extend(err.into_iter().map(|e| {
e.with_path($path)
})); None },
Ok(value) => Some(value)
}
};
}
pub(crate) use append_errors;
pub(crate) use push_error;

View file

@ -1,26 +1,86 @@
#![feature(proc_macro_hygiene, decl_macro)]
//#![feature(proc_macro_hygiene, decl_macro)]
/*
#[macro_use] extern crate rocket;
#[macro_use] extern crate diesel;
#[macro_use] extern crate diesel_migrations;
*/
mod routes;
mod cli;
mod config;
//mod routes;
//mod cli;
//mod config;
//mod models;
//mod schema;
//mod template;
//mod controllers;
//use std::process::exit;
//use clap::Parser;
//use figment::{Figment, Profile, providers::{Format, Toml, Env}};
//use rocket_sync_db_pools::database;
//use diesel::prelude::*;
mod errors;
mod dns;
mod models;
mod schema;
mod template;
mod controllers;
mod routes;
mod ressouces;
mod database;
mod validation;
mod macros;
use std::process::exit;
use std::sync::Arc;
use clap::Parser;
use figment::{Figment, Profile, providers::{Format, Toml, Env}};
use rocket_sync_db_pools::database;
use diesel::prelude::*;
use axum::Router;
use axum::routing;
use database::sqlite::SqliteDB;
use database::Db;
use dns::dns_driver::DnsDriverConfig;
use dns::dns_driver::TsigConfig;
use dns::{ZoneDriver, RecordDriver};
#[derive(Clone)]
pub struct AppState {
zone: Arc<dyn ZoneDriver>,
records: Arc<dyn RecordDriver>,
db: Arc<dyn Db>,
}
#[tokio::main]
async fn main() {
let dns_driver = dns::dns_driver::DnsDriver::from_config(DnsDriverConfig {
address: "127.0.0.1:5353".parse().unwrap(),
tsig: Some(TsigConfig {
key_name: "dev".parse().unwrap(),
secret: domain::utils::base64::decode::<Vec<u8>>("mbmz4J3Efm1BUjqe12M1RHsOnPjYhKQe+2iKO4tL+a4=").unwrap(),
algorithm: domain::tsig::Algorithm::Sha256,
})
});
let dns_driver = Arc::new(dns_driver);
let app_state = AppState {
zone: dns_driver.clone(),
records: dns_driver.clone(),
db: Arc::new(SqliteDB::new("db.sqlite".into()).await),
};
let app = Router::new()
.route("/admin/zones", routing::post(routes::api::zones::create_zone))
.route("/zones/{zone_name}/records", routing::get(routes::api::zones::get_zone_records))
.route("/zones/{zone_name}/records", routing::post(routes::api::zones::create_zone_records))
.with_state(app_state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:8000").await.unwrap();
axum::serve(listener, app).await.unwrap();
}
/*
use crate::cli::{NomiloCli, NomiloCommand};
#[database("sqlite")]
@ -63,3 +123,4 @@ fn main() {
let nomilo = NomiloCli::parse();
nomilo.run(figment, app_config);
}
*/

View file

@ -1,93 +0,0 @@
use crate::models::user::UserInfo;
use uuid::Uuid;
use diesel::prelude::*;
use diesel::result::Error as DieselError;
use serde::{Serialize, Deserialize};
use crate::schema::*;
use super::name::AbsoluteName;
use super::user::UserZone;
use super::errors::UserError;
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
#[table_name = "zone"]
pub struct Zone {
#[serde(skip)]
pub id: String,
pub name: String,
}
#[derive(Debug, Deserialize)]
pub struct AddZoneMemberRequest {
pub id: String,
}
#[derive(Debug, Deserialize, FromForm)]
pub struct CreateZoneRequest {
pub name: AbsoluteName,
}
// NOTE: Should probably not be implemented here
// also, "UserError" seems like a misleading name
impl Zone {
pub fn get_all(conn: &diesel::SqliteConnection) -> Result<Vec<Zone>, UserError> {
use crate::schema::zone::dsl::*;
zone.get_results(conn)
.map_err(UserError::DbError)
}
pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
zone.filter(name.eq(zone_name))
.get_result(conn)
.map_err(|e| match e {
DieselError::NotFound => UserError::ZoneNotFound,
other => UserError::DbError(other)
})
}
pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
let new_zone = Zone {
id: Uuid::new_v4().to_simple().to_string(),
name: zone_request.name.to_utf8(),
};
diesel::insert_into(zone)
.values(&new_zone)
.execute(conn)
.map_err(|e| match e {
DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict,
other => UserError::DbError(other)
})?;
Ok(new_zone)
}
pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> {
use crate::schema::user_zone::dsl::*;
let new_user_zone = UserZone {
zone_id: self.id.clone(),
user_id: new_member.id.clone()
};
let res = diesel::insert_into(user_zone)
.values(new_user_zone)
.execute(conn);
match res {
// If user has already access to the zone, safely ignore the conflit
// TODO: use 'on conflict do nothing' in postgres when we get there
Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (),
Err(e) => return Err(e.into()),
Ok(_) => ()
};
Ok(())
}
}

View file

@ -1,3 +1,4 @@
/*
pub mod class;
pub mod errors;
pub mod name;
@ -16,3 +17,6 @@ pub use user::{LocalUser, UserInfo, Role, UserZone, User, CreateUserRequest};
pub use rdata::RData;
pub use record::{Record, RecordList, ParseRecordList, RecordListParseError, UpdateRecordsRequest};
pub use zone::{Zone, AddZoneMemberRequest, CreateZoneRequest};
*/
pub mod zone;

View file

@ -1,3 +1,4 @@
/*
use uuid::Uuid;
use diesel::prelude::*;
use diesel::result::Error as DieselError;
@ -234,3 +235,4 @@ impl LocalUser {
})
}
}
*/

222
src/ressouces/zone.rs Normal file
View file

@ -0,0 +1,222 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use rusqlite::Error as RusqliteError;
use crate::database::{BoxedDb, sqlite::SqliteDB};
use crate::dns::{BoxedZoneDriver, DnsDriverError};
use crate::errors::Error;
use crate::macros::push_error;
use crate::validation;
pub enum ZoneError {
ZoneConflict { name: String },
NotFound { name: String },
NotExistsNs { name: String },
Validation { suberrors: Vec<Error> },
}
#[derive(Debug, Serialize)]
pub struct Zone {
pub name: String,
}
impl Zone {
pub async fn create(create_zone: CreateZoneRequest, zone_driver: BoxedZoneDriver, db: BoxedDb) -> Result<Self, Error> {
let create_zone = create_zone.validate()?;
zone_driver.zone_exists(&create_zone.name)
.await
.map_err(|e| {
match e {
DnsDriverError::ZoneNotFound { name } => {
Error::from(ZoneError::NotExistsNs { name })
.with_path("/name")
},
e => Error::from(e)
}
})?;
db.create_zone(create_zone).await
}
}
#[derive(Deserialize)]
pub struct CreateZoneRequest {
pub name: String
}
pub struct CreateZone {
pub name: String
}
impl CreateZoneRequest {
pub fn validate(self) -> Result<CreateZone, Error> {
let mut errors = Vec::new();
let name = push_error!(validation::normalize_domain(&self.name), errors, "/name");
name.ok_or(Error::from(ZoneError::Validation { suberrors: errors }))
.map(|name| {
CreateZone { name }
})
}
}
#[async_trait]
pub trait ZoneModel: Send + Sync {
async fn create_zone(&self, create_zone: CreateZone) -> Result<Zone, Error>;
async fn get_zone_by_name(&self, zone_name: &str) -> Result<Zone, Error>;
}
#[async_trait]
impl ZoneModel for SqliteDB {
async fn create_zone(&self, create_zone: CreateZone) -> Result<Zone, Error> {
let pool = self.pool.clone();
let conn = pool.get().await?;
tokio::task::block_in_place(move || {
let mut stmt = conn.prepare("insert into zones (name) values (?1) returning *")?;
let zone = stmt.query_row((&create_zone.name,), |row| {
Ok(Zone {
name: row.get(0)?
})
}).map_err(|e| {
match e {
/* SQLITE_CONSTRAINT_PRIMARYKEY */
RusqliteError::SqliteFailure(e, _) if e.extended_code == 1555 => {
Error::from(ZoneError::ZoneConflict { name: create_zone.name })
.with_path("/name")
},
e => Error::new("internal:zone:create", "Failed to create zone")
.with_cause(&e.to_string())
}
})?;
Ok(zone)
})
}
async fn get_zone_by_name(&self, zone_name: &str) -> Result<Zone, Error> {
let pool = self.pool.clone();
let conn = pool.get().await?;
tokio::task::block_in_place(move || {
let mut stmt = conn.prepare("select * from zones where name = ?1")?;
let zone = stmt.query_row((zone_name,), |row| {
Ok(Zone {
name: row.get(0)?
})
}).map_err(|e| {
match e {
RusqliteError::QueryReturnedNoRows => {
Error::from(ZoneError::NotFound { name: zone_name.to_string() })
},
e => Error::new("internal:zone:get_by_name", "Failed to fetch zone by name")
.with_cause(&e.to_string())
}
})?;
Ok(zone)
})
}
}
/*
use crate::models::user::UserInfo;
use uuid::Uuid;
use diesel::prelude::*;
use diesel::result::Error as DieselError;
use serde::{Serialize, Deserialize};
use crate::schema::*;
use super::name::AbsoluteName;
use super::user::UserZone;
use super::errors::UserError;
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
#[table_name = "zone"]
pub struct Zone {
#[serde(skip)]
pub id: String,
pub name: String,
}
#[derive(Debug, Deserialize)]
pub struct AddZoneMemberRequest {
pub id: String,
}
#[derive(Debug, Deserialize, FromForm)]
pub struct CreateZoneRequest {
pub name: AbsoluteName,
}
// NOTE: Should probably not be implemented here
// also, "UserError" seems like a misleading name
impl Zone {
pub fn get_all(conn: &diesel::SqliteConnection) -> Result<Vec<Zone>, UserError> {
use crate::schema::zone::dsl::*;
zone.get_results(conn)
.map_err(UserError::DbError)
}
pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
zone.filter(name.eq(zone_name))
.get_result(conn)
.map_err(|e| match e {
DieselError::NotFound => UserError::ZoneNotFound,
other => UserError::DbError(other)
})
}
pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
let new_zone = Zone {
id: Uuid::new_v4().to_simple().to_string(),
name: zone_request.name.to_utf8(),
};
diesel::insert_into(zone)
.values(&new_zone)
.execute(conn)
.map_err(|e| match e {
DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict,
other => UserError::DbError(other)
})?;
Ok(new_zone)
}
pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> {
use crate::schema::user_zone::dsl::*;
let new_user_zone = UserZone {
zone_id: self.id.clone(),
user_id: new_member.id.clone()
};
let res = diesel::insert_into(user_zone)
.values(new_user_zone)
.execute(conn);
match res {
// If user has already access to the zone, safely ignore the conflit
// TODO: use 'on conflict do nothing' in postgres when we get there
Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (),
Err(e) => return Err(e.into()),
Ok(_) => ()
};
Ok(())
}
}
*/

View file

@ -1,5 +1,5 @@
pub mod users;
//pub mod users;
pub mod zones;
pub use users::*;
pub use zones::*;
//pub use users::*;
//pub use zones::*;

View file

@ -1,3 +1,50 @@
use axum::extract::{Path, State};
use axum::Json;
use crate::dns::record::{AddRecordsRequest, Record};
use crate::AppState;
use crate::errors::Error;
use crate::ressouces::zone::{CreateZoneRequest, Zone};
pub async fn create_zone(
State(app): State<AppState>,
Json(create_zone): Json<CreateZoneRequest>,
) -> Result<Json<Zone>, Error>
{
Zone::create(create_zone, app.zone, app.db).await.map(Json)
}
pub async fn get_zone_records(
Path(zone_name): Path<String>,
State(app): State<AppState>,
) -> Result<Json<Vec<Record>>, Error>
{
let zone = app.db.get_zone_by_name(&zone_name).await?;
let records = app.records.get_records(&zone.name).await?;
Ok(Json(records))
}
pub async fn create_zone_records(
Path(zone_name): Path<String>,
State(app): State<AppState>,
Json(add_records): Json<AddRecordsRequest>,
) -> Result<Json<Vec<Record>>, Error>
{
let zone = app.db.get_zone_by_name(&zone_name).await?;
let add_records = add_records.validate(&zone.name)?;
app.records.add_records(&zone.name, &add_records.new_records).await?;
let records = add_records.new_records.into_iter()
.map(|r| r.into())
.collect();
Ok(Json(records))
}
/*
use rocket::http::Status;
use rocket::serde::json::Json;
@ -178,3 +225,4 @@ pub async fn add_member_to_zone<'r>(
Ok(Status::Created) // TODO: change this?
}
*/

View file

@ -1,2 +1,2 @@
pub mod ui;
//pub mod ui;
pub mod api;

116
src/validation.rs Normal file
View file

@ -0,0 +1,116 @@
use crate::errors::Error;
pub enum DomainValidationError {
EmptyDomain,
DomainTooLong { length: usize },
CharactersNotPermitted { label: String },
LabelToolLong { length: usize, label: String },
EmptyLabel
}
/// Not complete but probably good enough
/// https://doc.zonemaster.fr/v2024.1/specifications/tests/RequirementsAndNormalizationOfDomainNames.html
/// TODO: No support of dots in labels, how to handle RNAME in SOA?
pub fn normalize_domain(domain_name: &str) -> Result<String, Error> {
let domain = domain_name.strip_prefix('.').unwrap_or(domain_name).to_lowercase();
if domain.is_empty() {
Err(Error::from(DomainValidationError::EmptyDomain))
} else if domain.as_bytes().len() > 255 {
Err(Error::from(DomainValidationError::DomainTooLong { length: domain.as_bytes().len() }))
} else {
let labels = domain.split('.').collect::<Vec<_>>();
if labels.iter().any(|l| l.is_empty()) {
return Err(
Error::from(DomainValidationError::EmptyLabel)
);
}
for label in labels {
if !label.chars().all(|c| {
// allow for '/' for reverse zone
c.is_ascii_alphanumeric() || c == '-' || c == '/' || c == '_'
}) {
return Err(
Error::from(DomainValidationError::CharactersNotPermitted { label: label.into() })
);
}
if label.as_bytes().len() > 63 {
return Err(Error::from(DomainValidationError::LabelToolLong {
label: label.into(),
length: label.as_bytes().len()
}));
}
}
Ok(domain)
}
}
pub enum TxtParseError {
MissingEscape,
NonAscii { character: String },
BadEscapeDigitsTooShort { sequence: String },
BadEscapeDigitsNotDigits { sequence: String },
BadEscapeDigitIndexTooHigh { sequence: String },
}
pub fn parse_txt_data(text: &str) -> Result<Vec<u8>, Vec<Error>> {
let mut chars = text.chars();
let mut errors = Vec::new();
let mut data = Vec::new();
#[inline]
fn printable(ch: char) -> bool {
ch.is_ascii() && ('\u{20}'..='\u{7E}').contains(&ch)
}
while let Some(ch) = chars.next() {
if ch == '\\' {
match chars.next() {
Some(ch) => {
if ch.is_ascii_digit() {
let mut digits: Vec<_> = chars.by_ref().take(2).collect();
digits.insert(0, ch);
if digits.len() < 3 {
errors.push(Error::from(TxtParseError::BadEscapeDigitsTooShort { sequence: String::from_iter(digits) }))
} else if digits.iter().all(|c| c.is_ascii_digit()) {
errors.push(Error::from(TxtParseError::BadEscapeDigitsNotDigits { sequence: String::from_iter(digits) }))
} else {
let index = {
digits[0].to_digit(10).unwrap() * 100 +
digits[1].to_digit(10).unwrap() * 10 +
digits[2].to_digit(10).unwrap()
};
if index > 255 {
errors.push(Error::from(TxtParseError::BadEscapeDigitIndexTooHigh { sequence: String::from_iter(digits) }))
}
}
} else if printable(ch) {
data.push(ch as u8)
} else {
errors.push(Error::from(TxtParseError::NonAscii { character: ch.into() }))
}
},
None => {
errors.push(Error::from(TxtParseError::MissingEscape))
}
}
} else if printable(ch) {
data.push(ch as u8);
} else {
errors.push(Error::from(TxtParseError::NonAscii { character: ch.into() }))
}
}
//TODO: check txt data max length?
if errors.is_empty() {
Ok(data)
} else {
Err(errors)
}
}