Merge branch 'feature/zone-updates' into 'master'

Feature/zone updates

See merge request dns-witch/nomilo!2
main
Gaël Berthaud-Müller 2022-03-05 13:00:33 +00:00
commit 51d447242b
25 changed files with 1951 additions and 555 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@
config.toml config.toml
db.sqlite db.sqlite
__pycache__
/env

437
api.yml 100644
View File

@ -0,0 +1,437 @@
openapi: '3.0.0'
info:
description: ''
version: 0.1.0-dev
title: Nomilo
components:
securitySchemes:
ApiToken:
type: http
scheme: bearer
bearerFormat: JWT
parameters:
ZoneName:
name: zone
in: path
schema:
type: string
required: true
schemas:
UserRequest:
type: object
required:
- username
- password
- email
properties:
username:
type: string
password:
type: string
email:
type: string
role:
type: string
enum:
- admin
- zoneadmin
TokenRequest:
type: object
required:
- username
- password
properties:
username:
type: string
password:
type: string
TokenResponse:
type: object
required:
- token
properties:
token:
type: string
AddZoneMemberRequest:
type: object
required:
- id
properties:
id:
type: string
CreateZoneRequest:
type: object
required:
- name
properties:
name:
type: string
Zone:
type: object
required:
- name
properties:
name:
type: string
ZoneList:
type: array
items:
$ref: '#/components/schemas/Zone'
RecordBase:
type: object
required:
- Name
- Class
- TTL
- Type
properties:
Name:
type: string
Class:
type: string
enum:
- IN
- CH
- HS
- NONE
- ANY
TTL:
type: integer
Type:
type: string
RecordTypeA:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Address
properties:
Address:
type: string
RecordTypeAAAA:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Address
properties:
Address:
type: string
RecordTypeCAA:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
required:
- IssuerCritical
- Value
- PropertyTag
properties:
IssuerCritical:
type: boolean
Value:
type: string
PropertyTag:
type: string
RecordTypeCNAME:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Target
properties:
Target:
type: string
RecordTypeMX:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Preference
- MailExchanger
properties:
Preference:
type: integer
MailExchanger:
type: string
RecordTypeNS:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Target
properties:
Target:
type: string
RecordTypePTR:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Target
properties:
Target:
type: string
RecordTypeSOA:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- MasterServerName
- MaintainerName
- Refresh
- Retry
- Expire
- Minimum
- Serial
properties:
MasterServerName:
type: string
MaintainerName:
type: string
Refresh:
type: integer
Retry:
type: integer
Expire:
type: integer
Minimum:
type: integer
Serial:
type: integer
RecordTypeSRV:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Server
- Port
- Priority
- Weight
properties:
Server:
type: string
Port:
type: integer
Priority:
type: integer
Weight:
type: integer
RecordTypeSSHFP:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Algorithm
- DigestType
- Fingerprint
properties:
Algorithm:
type: integer
DigestType:
type: integer
Fingerprint:
type: string
RecordTypeTXT:
type: object
allOf:
- $ref: '#/components/schemas/RecordBase'
- type: object
required:
- Text
properties:
Text:
type: string
Record:
type: object
oneOf:
- $ref: '#/components/schemas/RecordTypeA'
- $ref: '#/components/schemas/RecordTypeAAAA'
- $ref: '#/components/schemas/RecordTypeCAA'
- $ref: '#/components/schemas/RecordTypeCNAME'
- $ref: '#/components/schemas/RecordTypeMX'
- $ref: '#/components/schemas/RecordTypeNS'
- $ref: '#/components/schemas/RecordTypePTR'
- $ref: '#/components/schemas/RecordTypeSOA'
- $ref: '#/components/schemas/RecordTypeSRV'
- $ref: '#/components/schemas/RecordTypeSSHFP'
- $ref: '#/components/schemas/RecordTypeTXT'
discriminator:
propertyName: Type
mapping:
A: '#/components/schemas/RecordTypeA'
AAAA: '#/components/schemas/RecordTypeAAAA'
CAA: '#/components/schemas/RecordTypeCAA'
CNAME: '#/components/schemas/RecordTypeCNAME'
MX: '#/components/schemas/RecordTypeMX'
NS: '#/components/schemas/RecordTypeNS'
PTR: '#/components/schemas/RecordTypePTR'
SOA: '#/components/schemas/RecordTypeSOA'
SRV: '#/components/schemas/RecordTypeSRV'
SSHFP: '#/components/schemas/RecordTypeSSHFP'
TXT: '#/components/schemas/RecordTypeTXT'
RecordList:
type: array
items:
$ref: '#/components/schemas/Record'
UpdateRecordsRequest:
type: object
required:
- oldRecords
- newRecords
properties:
oldRecords:
$ref: '#/components/schemas/RecordList'
newRecords:
$ref: '#/components/schemas/RecordList'
paths:
'/users':
post:
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UserRequest'
responses:
'201':
description: ''
'/users/me/token':
post:
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/TokenRequest'
responses:
'200':
description: ''
content:
application/json:
schema:
$ref: '#/components/schemas/TokenResponse'
'/zones':
get:
security:
- ApiToken: []
responses:
'200':
description: ''
content:
application/json:
schema:
$ref: '#/components/schemas/ZoneList'
post:
security:
- ApiToken: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/CreateZoneRequest'
responses:
'200':
description: ''
content:
application/json:
schema:
$ref: '#/components/schemas/Zone'
'/zones/{zone}/members':
parameters:
- $ref: '#/components/parameters/ZoneName'
post:
security:
- ApiToken: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/AddZoneMemberRequest'
responses:
'201':
description: ''
'/zones/{zone}/records':
parameters:
- $ref: '#/components/parameters/ZoneName'
get:
security:
- ApiToken: []
responses:
'200':
description: ''
content:
application/json:
schema:
$ref: '#/components/schemas/RecordList'
post:
security:
- ApiToken: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RecordList'
responses:
'200':
description: ''
put:
security:
- ApiToken: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UpdateRecordsRequest'
responses:
'200':
description: ''
delete:
security:
- ApiToken: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RecordList'
responses:
'200':
description: ''

View File

@ -0,0 +1,24 @@
server:
listen: [ 0.0.0.0@5353, ::@5353 ]
log:
- target: stderr
any: debug
acl:
- id: example_acl
address: [ 127.0.0.1, ::1]
action: [transfer, update]
template:
- id: default
file: "zones/%s.zone"
journal-content: all
zonefile-load: difference-no-serial
zonefile-sync: -1
serial-policy: dateserial
zone:
- domain: example.com
acl: example_acl
template: default

View File

@ -0,0 +1,8 @@
services:
knot:
image: cznic/knot
volumes:
- ./zones:/storage/zones:ro
- ./config:/config:ro
command: knotd
network_mode: host

View File

@ -0,0 +1,13 @@
example.com. IN SOA ns.example.com. admin.example.com. (
2020250101 ; serial
28800 ; refresh (8 hours)
7200 ; retry (2 hours)
2419200 ; expire (4 weeks)
300 ; minimum (5 minutes)
)
example.com. 84600 IN NS ns.example.com.
srv1.example.com. 600 IN A 198.51.100.3
srv1.example.com. 600 IN AAAA 2001:db8:cafe:bc68::2
www 600 IN CNAME srv1

22
docs/Testing.md 100644
View File

@ -0,0 +1,22 @@
# Testing
To run the end-to-end tests the OpenAPI Python client should be generated first:
```
openapi-generator generate -i ./api.yml -g python --package-name nomilo_client -o ./python_client
```
Then install the client, here a virtual env is created for this purpose:
```
python -m venv env
env/bin/pip install ./python_client
```
Finally start the name server. It will listen on `127.0.0.1:5353`, be sure to update the configuration accordingly.
```
docker-compose -f ./dev-scripts/docker-compose.yml up -d
```
You are now all set to run the e2e tests. Note that Nomilo must be started first.
```
env/bin/python -m unittest e2e/*.py
```

157
e2e/zones.py 100644
View File

@ -0,0 +1,157 @@
from nomilo_client import ApiClient, Configuration
from nomilo_client.api.default_api import DefaultApi
from nomilo_client.models import (
TokenRequest,
RecordTypeSOA,
RecordTypeAAAA,
RecordTypeCNAME,
RecordTypeNS,
RecordTypeTXT,
RecordList,
UpdateRecordsRequest,
)
import logging
import string
import random
import unittest
import warnings
logging.basicConfig(level=logging.DEBUG)
HOST = 'http://localhost:8000/api/v1'
USER='toto'
PASSWORD='supersecure'
def build_api(host: str):
conf = Configuration(host=HOST)
api_client = ApiClient(configuration=conf)
return DefaultApi(api_client)
def build_authenticated_api(host: str, token: TokenRequest):
auth_conf = Configuration(host=host, access_token=token.token)
api_client = ApiClient(configuration=auth_conf)
return DefaultApi(api_client)
def random_string(length):
return ''.join(random.choice(string.ascii_lowercase) for x in range(length))
def random_name(zone):
return '%s.%s' % (random_string(16), zone)
class TestZones(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ignore warning about unclosed socket
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
api = build_api(HOST)
token = api.users_me_token_post(token_request=TokenRequest(username=USER,password=PASSWORD))
cls.api = build_authenticated_api(HOST, token)
def test_get_zones(self):
zones = self.api.zones_get()
zone_name = zones.value[0].name
self.assertEqual(zone_name, 'example.com.')
def test_get_records(self):
records = self.api.zones_zone_records_get(zone='example.com.')
for record in records.value:
if type(record) is RecordTypeSOA:
with self.subTest(type='soa'):
self.assertEqual(record.name, 'example.com.')
if type(record) is RecordTypeAAAA:
with self.subTest(type='ns'):
self.assertEqual(record.name, 'srv1.example.com.')
self.assertEqual(record.address, '2001:db8:cafe:bc68::2')
if type(record) is RecordTypeCNAME:
with self.subTest(type='cname'):
self.assertEqual(record.name, 'www.example.com.')
self.assertEqual(record.target, 'srv1.example.com.')
if type(record) is RecordTypeNS:
with self.subTest(type='ns'):
self.assertEqual(record.name, 'example.com.')
self.assertEqual(record.target, 'ns.example.com.')
def test_create_records(self):
new_record = RecordTypeTXT(
_class='IN',
ttl=300,
name=random_name('example.com.'),
text=random_string(32),
type='TXT'
)
self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList(value=[new_record]))
records = self.api.zones_zone_records_get(zone='example.com.')
found = False
for record in records.value:
if type(record) is RecordTypeTXT and record.name == new_record.name:
self.assertEqual(record.text, new_record.text, msg='New record does not have the expected value')
found = True
self.assertTrue(found, msg='New record not found in zone records')
def test_update_records(self):
name = random_name('example.com.')
old_record = RecordTypeTXT(
_class='IN',
ttl=300,
name=name,
text='old value',
type='TXT'
)
new_record = RecordTypeTXT(
_class='IN',
ttl=300,
name=name,
text='new value',
type='TXT'
)
self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList([old_record]))
update_records_request = UpdateRecordsRequest(
old_records=RecordList([old_record]),
new_records=RecordList([new_record]),
)
self.api.zones_zone_records_put(zone='example.com.', update_records_request=update_records_request)
records = self.api.zones_zone_records_get(zone='example.com.')
found = False
for record in records.value:
if type(record) is RecordTypeTXT and record.name == name:
self.assertEqual(record.text, new_record.text, msg='New record does not have the expected value')
found = True
self.assertTrue(found, msg='Updated record not found in zone records')
def test_delete_records(self):
name = random_name('example.com.')
record = RecordTypeTXT(
_class='IN',
ttl=300,
name=name,
text=random_string(32),
type='TXT'
)
self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList([record]))
self.api.zones_zone_records_delete(zone='example.com.', record_list=RecordList([record]))
records = self.api.zones_zone_records_get(zone='example.com.')
found = False
for record in records.value:
if type(record) is RecordTypeTXT and record.name == name:
found = True
self.assertFalse(found, msg='Delete record found in zone records')

69
src/dns/client.rs 100644
View File

@ -0,0 +1,69 @@
use std::{future::Future, pin::Pin, task::{Context, Poll}};
use std::ops::{Deref, DerefMut};
use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}};
use tokio::{net::TcpStream as TokioTcpStream, task};
use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream};
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use crate::config::Config;
pub struct DnsClient(AsyncClient);
impl Deref for DnsClient {
type Target = AsyncClient;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsClient {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DnsClient {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<State<Config>>().await);
let (stream, handle) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(config.dns.server);
let client = AsyncClient::with_timeout(
stream,
handle,
std::time::Duration::from_secs(5),
None);
let (client, bg) = match client.await {
Err(e) => {
println!("Failed to connect to DNS server {:#?}", e);
return Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => c
};
task::spawn(bg);
Outcome::Success(DnsClient(client))
}
}
// Reimplement this type here as ClientReponse in trust-dns crate have private fields
pub struct ClientResponse<R>(pub(crate) R)
where
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
impl<R> Future for ClientResponse<R>
where
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
type Output = Result<DnsResponse, ClientError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// This is from the future_utils crate, we simply reuse the reexport from Rocket
rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from)
}
}

View File

@ -0,0 +1,27 @@
use crate::dns;
// TODO: Use model types instead of dns types as input / output and only convert internaly?
// Zone content api
// E.g.: DNS update + axfr, zone file read + write
#[async_trait]
pub trait RecordConnector {
type Error;
async fn get_records(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, Self::Error>;
async fn add_records(&mut self, zone: dns::Name, class: dns::DNSClass, new_records: Vec<dns::Record>) -> Result<(), Self::Error>;
async fn update_records(&mut self, zone: dns::Name, class: dns::DNSClass, old_records: Vec<dns::Record>, new_records: Vec<dns::Record>) -> Result<(), Self::Error>;
async fn delete_records(&mut self, zone: dns::Name, class: dns::DNSClass, records: Vec<dns::Record>) -> Result<(), Self::Error>;
// delete_records
}
// Zone management api, todo
// E.g.: Manage catalog zone, dynamically generate knot / bind / nsd config...
#[async_trait]
pub trait ZoneConnector {
type Error;
// get_zones
// add_zone
// delete_zone
async fn zone_exists(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result<(), Self::Error>;
}

View File

@ -0,0 +1,236 @@
use trust_dns_proto::DnsHandle;
use trust_dns_client::client::ClientHandle;
use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode};
use trust_dns_client::error::ClientError;
use super::{Name, Record, RData};
use super::client::{ClientResponse, DnsClient};
use super::connector::{RecordConnector, ZoneConnector};
const MAX_PAYLOAD_LEN: u16 = 1232;
#[derive(Debug)]
pub enum DnsConnectorError {
ClientError(ClientError),
ResponceNotOk {
code: ResponseCode,
zone: Name,
},
}
pub struct DnsConnectorClient {
client: DnsClient
}
impl DnsConnectorClient {
pub fn new(client: DnsClient) -> Self {
DnsConnectorClient {
client
}
}
}
#[async_trait]
impl RecordConnector for DnsConnectorClient {
type Error = DnsConnectorError;
async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result<Vec<Record>, Self::Error>
{
let response = {
let query = self.client.query(zone.clone(), class, RecordType::AXFR);
query.await.map_err(|e| DnsConnectorError::ClientError(e))?
};
if response.response_code() != ResponseCode::NoError {
return Err(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
});
}
let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter()
.filter(|record| !matches!(record.rdata(), RData::NULL { .. } | RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(records)
}
async fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<(), Self::Error>
{
// Taken from trust_dns_client::op::update_message::append
// The original function can not be used as is because it takes a RecordSet and not a Record list
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
message.add_updates(new_records);
{
let edns = message.edns_mut();
edns.set_max_payload(MAX_PAYLOAD_LEN);
edns.set_version(0);
}
let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsConnectorError::ClientError(e))?;
if response.response_code() != ResponseCode::NoError {
return Err(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
});
}
Ok(())
}
async fn update_records(&mut self, zone: Name, class: DNSClass, old_records: Vec<Record>, new_records: Vec<Record>) -> Result<(), Self::Error>
{
// Taken from trust_dns_client::op::update_message::compare_and_swap
// The original function can not be used as is because it takes a RecordSet and not a Record list
// for updates, the query section is used for the zone
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
// make sure the record is what is expected
let mut prerequisite = old_records.clone();
for record in prerequisite.iter_mut() {
record.set_ttl(0);
}
message.add_pre_requisites(prerequisite);
// add the delete for the old record
let mut delete = old_records;
for record in delete.iter_mut() {
// the class must be none for delete
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// insert the new record...
message.add_updates(new_records);
// Extended dns
{
let edns = message.edns_mut();
edns.set_max_payload(MAX_PAYLOAD_LEN);
edns.set_version(0);
}
let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsConnectorError::ClientError(e))?;
if response.response_code() != ResponseCode::NoError {
return Err(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
});
}
Ok(())
}
async fn delete_records(&mut self, zone: Name, class: DNSClass, records: Vec<Record>) -> Result<(), Self::Error>
{
// for updates, the query section is used for the zone
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
let mut delete = records;
for record in delete.iter_mut() {
// the class must be none for delete
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// Extended dns
{
let edns = message.edns_mut();
edns.set_max_payload(MAX_PAYLOAD_LEN);
edns.set_version(0);
}
let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsConnectorError::ClientError(e))?;
if response.response_code() != ResponseCode::NoError {
return Err(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
});
}
Ok(())
}
}
#[async_trait]
impl ZoneConnector for DnsConnectorClient {
type Error = DnsConnectorError;
async fn zone_exists(&mut self, zone: Name, class: DNSClass) -> Result<(), Self::Error>
{
let response = {
let query = self.client.query(zone.clone(), class, RecordType::SOA);
query.await.map_err(|e| DnsConnectorError::ClientError(e))?
};
if response.response_code() != ResponseCode::NoError {
return Err(DnsConnectorError::ResponceNotOk {
code: response.response_code(),
zone: zone,
});
}
Ok(())
}
}

17
src/dns/mod.rs 100644
View File

@ -0,0 +1,17 @@
pub mod client;
pub mod dns_connector;
pub mod connector;
// Reexport trust dns types for convenience
pub use trust_dns_client::rr::rdata::{
DNSSECRData, caa, sshfp, mx, null, soa, srv, txt
};
pub use trust_dns_client::rr::{
RData, DNSClass, Record
};
pub use trust_dns_proto::rr::Name;
// Reexport module types
pub use connector::{RecordConnector, ZoneConnector};
pub use dns_connector::{DnsConnectorClient, DnsConnectorError};
pub use client::DnsClient;

View File

@ -8,6 +8,7 @@ mod models;
mod config; mod config;
mod schema; mod schema;
mod routes; mod routes;
mod dns;
use routes::users::*; use routes::users::*;
use routes::zones::*; use routes::zones::*;
@ -26,9 +27,13 @@ async fn rocket() -> rocket::Rocket {
.attach(DbConn::fairing()) .attach(DbConn::fairing())
.mount("/api/v1", routes![ .mount("/api/v1", routes![
get_zone_records, get_zone_records,
create_zone_records,
update_zone_records,
delete_zone_records,
get_zones, get_zones,
create_zone,
add_member_to_zone, add_member_to_zone,
create_auth_token, create_auth_token,
create_user create_user,
]) ])
} }

63
src/models/auth.rs 100644
View File

@ -0,0 +1,63 @@
use uuid::Uuid;
use serde::{Serialize, Deserialize};
use chrono::serde::ts_seconds;
use chrono::prelude::{DateTime, Utc};
use chrono::Duration;
use jsonwebtoken::{
encode, decode,
Header, Validation,
Algorithm as JwtAlgorithm, EncodingKey, DecodingKey,
errors::Result as JwtResult
};
use crate::models::user::UserInfo;
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthClaims {
pub jti: String,
pub sub: String,
#[serde(with = "ts_seconds")]
pub exp: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub iat: DateTime<Utc>,
}
#[derive(Debug, Serialize)]
pub struct AuthTokenResponse {
pub token: String
}
#[derive(Debug, Deserialize)]
pub struct AuthTokenRequest {
pub username: String,
pub password: String,
}
impl AuthClaims {
pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims {
let jti = Uuid::new_v4().to_simple().to_string();
let iat = Utc::now();
let exp = iat + token_duration;
AuthClaims {
jti,
sub: user_info.id.clone(),
exp,
iat,
}
}
pub fn decode(token: &str, secret: &str) -> JwtResult<AuthClaims> {
decode::<AuthClaims>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::new(JwtAlgorithm::HS256)
).map(|data| data.claims)
}
pub fn encode(self, secret: &str) -> JwtResult<String> {
encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref()))
}
}

View File

@ -0,0 +1,40 @@
use serde::{Deserialize, Serialize};
use crate::dns;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum DNSClass {
IN,
CH,
HS,
NONE,
ANY,
OPT(u16),
}
impl From<dns::DNSClass> for DNSClass {
fn from(dns_class: dns::DNSClass) -> DNSClass {
match dns_class {
dns::DNSClass::IN => DNSClass::IN,
dns::DNSClass::CH => DNSClass::CH,
dns::DNSClass::HS => DNSClass::HS,
dns::DNSClass::NONE => DNSClass::NONE,
dns::DNSClass::ANY => DNSClass::ANY,
dns::DNSClass::OPT(v) => DNSClass::OPT(v),
}
}
}
impl From<DNSClass> for dns::DNSClass {
fn from(dns_class: DNSClass) -> dns::DNSClass {
match dns_class {
DNSClass::IN => dns::DNSClass::IN,
DNSClass::CH => dns::DNSClass::CH,
DNSClass::HS => dns::DNSClass::HS,
DNSClass::NONE => dns::DNSClass::NONE,
DNSClass::ANY => dns::DNSClass::ANY,
DNSClass::OPT(v) => dns::DNSClass::OPT(v),
}
}
}

View File

@ -1,329 +0,0 @@
use std::net::{Ipv6Addr, Ipv4Addr};
use std::fmt;
use std::ops::{Deref, DerefMut};
use rocket::{Request, State, http::Status, request::{FromParam, FromRequest, Outcome}};
use serde::{Deserialize, Deserializer, Serialize};
use tokio::{net::TcpStream as TokioTcpStream, task};
use trust_dns_client::{client::AsyncClient, serialize::binary::BinEncoder, tcp::TcpClientStream};
use trust_dns_proto::error::{ProtoError};
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use super::trust_dns_types::{self, Name};
use crate::config::Config;
#[derive(Deserialize, Serialize)]
#[serde(tag = "Type")]
#[serde(rename_all = "UPPERCASE")]
pub enum RData {
#[serde(rename_all = "PascalCase")]
A {
address: Ipv4Addr
},
#[serde(rename_all = "PascalCase")]
AAAA {
address: Ipv6Addr
},
#[serde(rename_all = "PascalCase")]
CAA {
issuer_critical: bool,
value: String,
property_tag: String,
},
#[serde(rename_all = "PascalCase")]
CNAME {
target: String
},
// HINFO(HINFO),
// HTTPS(SVCB),
#[serde(rename_all = "PascalCase")]
MX {
preference: u16,
mail_exchanger: String
},
// NAPTR(NAPTR),
#[serde(rename_all = "PascalCase")]
NULL {
data: String
},
#[serde(rename_all = "PascalCase")]
NS {
target: String
},
// OPENPGPKEY(OPENPGPKEY),
// OPT(OPT),
#[serde(rename_all = "PascalCase")]
PTR {
target: String
},
#[serde(rename_all = "PascalCase")]
SOA {
master_server_name: String,
maintainer_name: String,
refresh: i32,
retry: i32,
expire: i32,
minimum: u32,
serial: u32
},
#[serde(rename_all = "PascalCase")]
SRV {
server: String,
port: u16,
priority: u16,
weight: u16,
},
#[serde(rename_all = "PascalCase")]
SSHFP {
algorithm: u8,
digest_type: u8,
fingerprint: String,
},
// SVCB(SVCB),
// TLSA(TLSA),
#[serde(rename_all = "PascalCase")]
TXT {
text: String
},
// TODO: Eventually allow deserialization of DNSSEC records
#[serde(skip)]
DNSSEC(trust_dns_types::DNSSECRData),
#[serde(rename_all = "PascalCase")]
Unknown {
code: u16,
data: String,
},
// ZERO,
}
impl From<trust_dns_types::RData> for RData {
fn from(rdata: trust_dns_types::RData) -> RData {
match rdata {
trust_dns_types::RData::A(address) => RData::A { address },
trust_dns_types::RData::AAAA(address) => RData::AAAA { address },
// Still a draft, no iana number yet, I don't to put something that is not currently supported so that's why NULL and not unknown.
// TODO: probably need better error here, I don't know what to do about that as this would require to change the From for something else.
// (empty data because I'm lazy)
trust_dns_types::RData::ANAME(_) => RData::NULL {
data: String::new()
},
trust_dns_types::RData::CNAME(target) => RData::CNAME {
target: target.to_utf8()
},
trust_dns_types::RData::CAA(caa) => RData::CAA {
issuer_critical: caa.issuer_critical(),
value: format!("{}", CAAValue(caa.value())),
property_tag: caa.tag().as_str().to_string(),
},
trust_dns_types::RData::MX(mx) => RData::MX {
preference: mx.preference(),
mail_exchanger: mx.exchange().to_utf8()
},
trust_dns_types::RData::NULL(null) => RData::NULL {
data: base64::encode(null.anything().map(|data| data.to_vec()).unwrap_or_default())
},
trust_dns_types::RData::NS(target) => RData::NS {
target: target.to_utf8()
},
trust_dns_types::RData::PTR(target) => RData::PTR {
target: target.to_utf8()
},
trust_dns_types::RData::SOA(soa) => RData::SOA {
master_server_name: soa.mname().to_utf8(),
maintainer_name: soa.rname().to_utf8(),
refresh: soa.refresh(),
retry: soa.retry(),
expire: soa.expire(),
minimum: soa.minimum(),
serial: soa.serial()
},
trust_dns_types::RData::SRV(srv) => RData::SRV {
server: srv.target().to_utf8(),
port: srv.port(),
priority: srv.priority(),
weight: srv.weight(),
},
trust_dns_types::RData::SSHFP(sshfp) => RData::SSHFP {
algorithm: sshfp.algorithm().into(),
digest_type: sshfp.fingerprint_type().into(),
fingerprint: trust_dns_types::sshfp::HEX.encode(sshfp.fingerprint()),
},
trust_dns_types::RData::TXT(txt) => RData::TXT { text: format!("{}", txt) },
trust_dns_types::RData::DNSSEC(data) => RData::DNSSEC(data),
rdata => {
let code = rdata.to_record_type().into();
let mut data = Vec::new();
let mut encoder = BinEncoder::new(&mut data);
// TODO: need better error handling (use TryFrom ?)
rdata.emit(&mut encoder).expect("could not encode data");
RData::Unknown {
code,
data: base64::encode(data),
}
}
}
}
}
struct CAAValue<'a>(&'a trust_dns_types::caa::Value);
// trust_dns Display implementation panics if no parameters
// Implementation based on caa::emit_value
// Also the quotes are strips to render in JSON
impl<'a> fmt::Display for CAAValue<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self.0 {
trust_dns_types::caa::Value::Issuer(name, parameters) => {
if let Some(name) = name {
write!(f, "{}", name)?;
}
if name.is_none() && parameters.is_empty() {
write!(f, ";")?;
}
for value in parameters {
write!(f, "; {}", value)?;
}
}
trust_dns_types::caa::Value::Url(url) => write!(f, "{}", url)?,
trust_dns_types::caa::Value::Unknown(v) => write!(f, "{:?}", v)?,
}
Ok(())
}
}
#[derive(Deserialize, Serialize)]
pub enum DNSClass {
IN,
CH,
HS,
NONE,
ANY,
OPT(u16),
}
impl From<trust_dns_types::DNSClass> for DNSClass {
fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass {
match dns_class {
trust_dns_types::DNSClass::IN => DNSClass::IN,
trust_dns_types::DNSClass::CH => DNSClass::CH,
trust_dns_types::DNSClass::HS => DNSClass::HS,
trust_dns_types::DNSClass::NONE => DNSClass::NONE,
trust_dns_types::DNSClass::ANY => DNSClass::ANY,
trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v),
}
}
}
#[derive(Deserialize, Serialize)]
pub struct Record {
#[serde(rename = "Name")]
pub name: String,
#[serde(rename = "Class")]
pub dns_class: DNSClass,
#[serde(rename = "TTL")]
pub ttl: u32,
#[serde(flatten)]
pub rdata: RData,
}
impl From<trust_dns_types::Record> for Record {
fn from(record: trust_dns_types::Record) -> Record {
Record {
name: record.name().to_utf8(),
//rr_type: record.rr_type().into(),
dns_class: record.dns_class().into(),
ttl: record.ttl(),
rdata: record.into_data().into(),
}
}
}
#[derive(Debug)]
pub struct AbsoluteName(Name);
impl<'r> FromParam<'r> for AbsoluteName {
type Error = ProtoError;
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
let mut name = Name::from_utf8(&param)?;
if !name.is_fqdn() {
name.set_fqdn(true);
}
Ok(AbsoluteName(name))
}
}
impl<'de> Deserialize<'de> for AbsoluteName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|string|
AbsoluteName::from_param(&string)
.map_err(|e| Error::custom(e.to_string()))
)
}
}
impl Deref for AbsoluteName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct DnsClient(AsyncClient);
impl Deref for DnsClient {
type Target = AsyncClient;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsClient {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DnsClient {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<State<Config>>().await);
let (stream, handle) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(config.dns.server);
let client = AsyncClient::with_timeout(
stream,
handle,
std::time::Duration::from_secs(5),
None);
let (client, bg) = match client.await {
Err(e) => {
println!("Failed to connect to DNS server {:#?}", e);
return Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => c
};
task::spawn(bg);
Outcome::Success(DnsClient(client))
}
}

View File

@ -3,8 +3,37 @@ use rocket::http::Status;
use rocket::request::{Request, Outcome}; use rocket::request::{Request, Outcome};
use rocket::response::{self, Response, Responder}; use rocket::response::{self, Response, Responder};
use rocket_contrib::json::Json; use rocket_contrib::json::Json;
use crate::models::users::UserError;
use serde_json::Value; use serde_json::Value;
use djangohashers::{HasherError};
use diesel::result::Error as DieselError;
use crate::dns::DnsConnectorError;
use crate::models;
#[derive(Debug)]
pub enum UserError {
ZoneNotFound,
NotFound,
UserConflict,
BadCreds,
BadToken,
ExpiredToken,
MalformedHeader,
PermissionDenied,
DbError(DieselError),
PasswordError(HasherError),
}
impl From<HasherError> for UserError {
fn from(e: HasherError) -> Self {
UserError::PasswordError(e)
}
}
impl From<DieselError> for UserError {
fn from(e: DieselError) -> Self {
UserError::DbError(e)
}
}
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
@ -70,6 +99,55 @@ impl From<UserError> for ErrorResponse {
} }
} }
impl From<DnsConnectorError> for ErrorResponse {
fn from(e: DnsConnectorError) -> Self {
match e {
DnsConnectorError::ResponceNotOk { code, zone } => {
println!("Query for zone {} failed with code {}", zone, code);
ErrorResponse::new(
Status::NotFound,
"Zone could not be found".into()
).with_details(json!({
"zone_name": zone.to_utf8()
}))
},
DnsConnectorError::ClientError(e) => make_500(e)
}
}
}
impl From<models::RecordListParseError> for ErrorResponse {
fn from(e: models::RecordListParseError) -> Self {
match e {
models::RecordListParseError::RecordNotInZone { zone, class, mismatched_class, mismatched_zone} => {
ErrorResponse::new(
Status::BadRequest,
"Record list contains records that do not belong to the zone".into()
).with_details(
json!({
"zone_name": zone.to_utf8(),
"class": models::DNSClass::from(class),
"mismatched_class": mismatched_class,
"mismatched_zone": mismatched_zone,
})
)
},
models::RecordListParseError::ParseError { zone, bad_records } => {
ErrorResponse::new(
Status::BadRequest,
"Record list contains records that could not be parsed into DNS records".into()
).with_details(
json!({
"zone_name": zone.to_utf8(),
"records": bad_records
})
)
}
}
}
}
impl<S> From<ErrorResponse> for Outcome<S, ErrorResponse> { impl<S> From<ErrorResponse> for Outcome<S, ErrorResponse> {
fn from(e: ErrorResponse) -> Self { fn from(e: ErrorResponse) -> Self {

View File

@ -1,13 +1,19 @@
pub mod dns; //pub mod dns;
pub mod auth;
pub mod class;
pub mod errors; pub mod errors;
pub mod users; pub mod name;
pub mod rdata;
pub mod record;
pub mod user;
pub mod zone;
pub mod trust_dns_types { // Reexport types for convenience
pub use trust_dns_client::rr::rdata::{ pub use auth::{AuthClaims, AuthTokenRequest, AuthTokenResponse};
DNSSECRData, caa, sshfp, pub use class::DNSClass;
}; pub use errors::{UserError, ErrorResponse, make_500};
pub use trust_dns_client::rr::{ pub use name::{AbsoluteName, SerdeName};
RData, DNSClass, Record pub use user::{LocalUser, UserInfo, Role, UserZone, User, CreateUserRequest};
}; pub use rdata::RData;
pub use trust_dns_proto::rr::Name; pub use record::{Record, RecordList, ParseRecordList, RecordListParseError, UpdateRecordsRequest};
} pub use zone::{Zone, AddZoneMemberRequest, CreateZoneRequest};

78
src/models/name.rs 100644
View File

@ -0,0 +1,78 @@
use std::ops::Deref;
use rocket::request::FromParam;
use serde::{Deserialize, Serialize, Deserializer, Serializer};
use trust_dns_proto::error::ProtoError;
use crate::dns::Name;
#[derive(Debug, Clone)]
pub struct SerdeName(pub(crate)Name);
impl Deref for SerdeName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de> Deserialize<'de> for SerdeName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|string|
Name::from_utf8(&string)
.map_err(|e| Error::custom(e.to_string()))
).map( SerdeName)
}
}
impl Serialize for SerdeName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer
{
self.0.to_utf8().serialize(serializer)
}
}
impl SerdeName {
pub fn into_inner(self) -> Name {
self.0
}
}
#[derive(Debug, Deserialize)]
pub struct AbsoluteName(SerdeName);
impl<'r> FromParam<'r> for AbsoluteName {
type Error = ProtoError;
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
let mut name = Name::from_utf8(&param)?;
if !name.is_fqdn() {
name.set_fqdn(true);
}
Ok(AbsoluteName(SerdeName(name)))
}
}
impl Deref for AbsoluteName {
type Target = Name;
fn deref(&self) -> &Self::Target {
&self.0.0
}
}
impl AbsoluteName {
pub fn into_inner(self) -> Name {
self.0.0
}
}

282
src/models/rdata.rs 100644
View File

@ -0,0 +1,282 @@
use std::fmt;
use std::convert::TryFrom;
use std::net::{Ipv6Addr, Ipv4Addr};
use serde::{Deserialize, Serialize};
use trust_dns_client::serialize::binary::BinEncoder;
use trust_dns_proto::error::ProtoError;
use crate::dns;
use super::name::SerdeName;
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "Type")]
#[serde(rename_all = "UPPERCASE")]
pub enum RData {
#[serde(rename_all = "PascalCase")]
A {
address: Ipv4Addr
},
#[serde(rename_all = "PascalCase")]
AAAA {
address: Ipv6Addr
},
#[serde(rename_all = "PascalCase")]
CAA {
issuer_critical: bool,
value: String,
property_tag: String,
},
#[serde(rename_all = "PascalCase")]
CNAME {
target: SerdeName
},
// HINFO(HINFO),
// HTTPS(SVCB),
#[serde(rename_all = "PascalCase")]
MX {
preference: u16,
mail_exchanger: SerdeName
},
// NAPTR(NAPTR),
#[serde(rename_all = "PascalCase")]
NULL {
data: String
},
#[serde(rename_all = "PascalCase")]
NS {
target: SerdeName
},
// OPENPGPKEY(OPENPGPKEY),
// OPT(OPT),
#[serde(rename_all = "PascalCase")]
PTR {
target: SerdeName
},
#[serde(rename_all = "PascalCase")]
SOA {
master_server_name: SerdeName,
maintainer_name: SerdeName,
refresh: i32,
retry: i32,
expire: i32,
minimum: u32,
serial: u32
},
#[serde(rename_all = "PascalCase")]
SRV {
server: SerdeName,
port: u16,
priority: u16,
weight: u16,
},
#[serde(rename_all = "PascalCase")]
SSHFP {
algorithm: u8,
digest_type: u8,
fingerprint: String,
},
// SVCB(SVCB),
// TLSA(TLSA),
#[serde(rename_all = "PascalCase")]
TXT {
text: String
},
// TODO: Eventually allow deserialization of DNSSEC records
#[serde(skip)]
DNSSEC(dns::DNSSECRData),
#[serde(rename_all = "PascalCase")]
Unknown {
code: u16,
data: String,
},
// ZERO,
// TODO: DS
// TODO: TLSA
}
impl From<dns::RData> for RData {
fn from(rdata: dns::RData) -> RData {
match rdata {
dns::RData::A(address) => RData::A { address },
dns::RData::AAAA(address) => RData::AAAA { address },
// Still a draft, no iana number yet, I don't to put something that is not currently supported so that's why NULL and not unknown.
// TODO: probably need better error here, I don't know what to do about that as this would require to change the From for something else.
// (empty data because I'm lazy)
dns::RData::ANAME(_) => RData::NULL {
data: String::new()
},
dns::RData::CNAME(target) => RData::CNAME {
target: SerdeName(target)
},
dns::RData::CAA(caa) => RData::CAA {
issuer_critical: caa.issuer_critical(),
value: format!("{}", CAAValue(caa.value())),
property_tag: caa.tag().as_str().to_string(),
},
dns::RData::MX(mx) => RData::MX {
preference: mx.preference(),
mail_exchanger: SerdeName(mx.exchange().clone())
},
dns::RData::NULL(null) => RData::NULL {
data: base64::encode(null.anything().map(|data| data.to_vec()).unwrap_or_default())
},
dns::RData::NS(target) => RData::NS {
target: SerdeName(target)
},
dns::RData::PTR(target) => RData::PTR {
target: SerdeName(target)
},
dns::RData::SOA(soa) => RData::SOA {
master_server_name: SerdeName(soa.mname().clone()),
maintainer_name: SerdeName(soa.rname().clone()),
refresh: soa.refresh(),
retry: soa.retry(),
expire: soa.expire(),
minimum: soa.minimum(),
serial: soa.serial()
},
dns::RData::SRV(srv) => RData::SRV {
server: SerdeName(srv.target().clone()),
port: srv.port(),
priority: srv.priority(),
weight: srv.weight(),
},
dns::RData::SSHFP(sshfp) => RData::SSHFP {
algorithm: sshfp.algorithm().into(),
digest_type: sshfp.fingerprint_type().into(),
fingerprint: dns::sshfp::HEX.encode(sshfp.fingerprint()),
},
//TODO: This might alter data if not utf8 compatible, probably need to be replaced
//TODO: check whether concatenating txt data is harmful or not
dns::RData::TXT(txt) => RData::TXT { text: format!("{}", txt) },
dns::RData::DNSSEC(data) => RData::DNSSEC(data),
rdata => {
let code = rdata.to_record_type().into();
let mut data = Vec::new();
let mut encoder = BinEncoder::new(&mut data);
// TODO: need better error handling (use TryFrom ?)
rdata.emit(&mut encoder).expect("could not encode data");
RData::Unknown {
code,
data: base64::encode(data),
}
}
}
}
}
impl TryFrom<RData> for dns::RData {
type Error = ProtoError;
fn try_from(rdata: RData) -> Result<Self, Self::Error> {
Ok(match rdata {
RData::A { address } => dns::RData::A(address),
RData::AAAA { address } => dns::RData::AAAA(address),
// TODO: Round trip test all types below (currently not tested...)
RData::CAA { issuer_critical, value, property_tag } => {
let property = dns::caa::Property::from(property_tag);
let caa_value = {
// TODO: duplicate of trust_dns_client::serialize::txt::rdata_parser::caa::parse
// because caa::read_value is private
match property {
dns::caa::Property::Issue | dns::caa::Property::IssueWild => {
let value = dns::caa::read_issuer(value.as_bytes())?;
dns::caa::Value::Issuer(value.0, value.1)
}
dns::caa::Property::Iodef => {
let url = dns::caa::read_iodef(value.as_bytes())?;
dns::caa::Value::Url(url)
}
dns::caa::Property::Unknown(_) => dns::caa::Value::Unknown(value.as_bytes().to_vec()),
}
};
dns::RData::CAA(dns::caa::CAA {
issuer_critical,
tag: property,
value: caa_value,
})
},
RData::CNAME { target } => dns::RData::CNAME(target.into_inner()),
RData::MX { preference, mail_exchanger } => dns::RData::MX(
dns::mx::MX::new(preference, mail_exchanger.into_inner())
),
RData::NULL { data } => dns::RData::NULL(
dns::null::NULL::with(
base64::decode(data).map_err(|e| ProtoError::from(format!("{}", e)))?
)
),
RData::NS { target } => dns::RData::NS(target.into_inner()),
RData::PTR { target } => dns::RData::PTR(target.into_inner()),
RData::SOA {
master_server_name,
maintainer_name,
refresh,
retry,
expire,
minimum,
serial
} => dns::RData::SOA(
dns::soa::SOA::new(
master_server_name.into_inner(),
maintainer_name.into_inner(),
serial,
refresh,
retry,
expire,
minimum,
)
),
RData::SRV { server, port, priority, weight } => dns::RData::SRV(
dns::srv::SRV::new(priority, weight, port, server.into_inner())
),
RData::SSHFP { algorithm, digest_type, fingerprint } => dns::RData::SSHFP(
dns::sshfp::SSHFP::new(
// NOTE: This allows unassigned algorithms
dns::sshfp::Algorithm::from(algorithm),
dns::sshfp::FingerprintType::from(digest_type),
dns::sshfp::HEX.decode(fingerprint.as_bytes()).map_err(|e| ProtoError::from(format!("{}", e)))?
)
),
RData::TXT { text } => dns::RData::TXT(dns::txt::TXT::new(vec![text])),
// TODO: Error out for DNSSEC? Prefer downstream checks?
RData::DNSSEC(_) => todo!(),
// TODO: Disallow unknown? (could be used to bypass unsopported types?) Prefer downstream checks?
RData::Unknown { code: _code, data: _data } => todo!(),
})
}
}
struct CAAValue<'a>(&'a dns::caa::Value);
// trust_dns Display implementation panics if no parameters
// Implementation based on caa::emit_value
// Also the quotes are strips to render in JSON
impl<'a> fmt::Display for CAAValue<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self.0 {
dns::caa::Value::Issuer(name, parameters) => {
if let Some(name) = name {
write!(f, "{}", name)?;
}
if name.is_none() && parameters.is_empty() {
write!(f, ";")?;
}
for value in parameters {
write!(f, "; {}", value)?;
}
}
dns::caa::Value::Url(url) => write!(f, "{}", url)?,
dns::caa::Value::Unknown(v) => write!(f, "{:?}", v)?,
}
Ok(())
}
}

View File

@ -0,0 +1,122 @@
use std::convert::{TryFrom, TryInto};
use serde::{Deserialize, Serialize};
use trust_dns_proto::error::ProtoError;
use crate::dns;
use super::name::SerdeName;
use super::class::DNSClass;
use super::rdata::RData;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Record {
#[serde(rename = "Name")]
pub name: SerdeName,
// TODO: Make class optional, default to IN
#[serde(rename = "Class")]
pub dns_class: DNSClass,
#[serde(rename = "TTL")]
pub ttl: u32,
#[serde(flatten)]
pub rdata: RData,
}
impl From<dns::Record> for Record {
fn from(record: dns::Record) -> Record {
Record {
name: SerdeName(record.name().clone()),
dns_class: record.dns_class().into(),
ttl: record.ttl(),
rdata: record.into_data().into(),
}
}
}
impl TryFrom<Record> for dns::Record {
type Error = ProtoError;
fn try_from(record: Record) -> Result<Self, Self::Error> {
let mut trust_dns_record = dns::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?);
trust_dns_record.set_dns_class(record.dns_class.into());
Ok(trust_dns_record)
}
}
pub type RecordList = Vec<Record>;
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpdateRecordsRequest {
pub old_records: RecordList,
pub new_records: RecordList,
}
pub enum RecordListParseError {
ParseError {
bad_records: Vec<Record>,
zone: dns::Name,
},
RecordNotInZone {
zone: dns::Name,
class: dns::DNSClass,
mismatched_class: Vec<Record>,
mismatched_zone: Vec<Record>,
},
}
pub trait ParseRecordList {
fn try_into_dns_type(self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, RecordListParseError>;
}
impl ParseRecordList for RecordList {
fn try_into_dns_type(self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, RecordListParseError> {
// TODO: What about relative names (also in cnames and stuff)
let mut bad_records = Vec::new();
let mut records: Vec<dns::Record> = Vec::new();
let mut mismatched_class: Vec<Record> = Vec::new();
let mut mismatched_zone: Vec<Record> = Vec::new();
for record in self.into_iter() {
let this_record = record.clone();
if let Ok(record) = dns::Record::try_from(record) {
let mut good_record = true;
if !zone.zone_of(record.name()) {
mismatched_zone.push(this_record.clone());
good_record = false;
}
if record.dns_class() != class {
mismatched_class.push(this_record.clone());
good_record = false;
}
if good_record {
records.push(record);
}
} else {
bad_records.push(this_record.clone());
}
}
if !bad_records.is_empty() {
return Err(RecordListParseError::ParseError {
zone,
bad_records,
});
}
if !mismatched_class.is_empty() || !mismatched_zone.is_empty() {
return Err(RecordListParseError::RecordNotInZone {
zone,
class,
mismatched_zone,
mismatched_class
});
}
return Ok(records)
}
}

View File

@ -3,28 +3,23 @@ use diesel::prelude::*;
use diesel::result::Error as DieselError; use diesel::result::Error as DieselError;
use diesel_derive_enum::DbEnum; use diesel_derive_enum::DbEnum;
use rocket::{State, request::{FromRequest, Request, Outcome}}; use rocket::{State, request::{FromRequest, Request, Outcome}};
use serde::{Serialize, Deserialize}; use serde::{Deserialize};
use chrono::serde::ts_seconds;
use chrono::prelude::{DateTime, Utc};
use chrono::Duration;
// TODO: Maybe just use argon2 crate directly // TODO: Maybe just use argon2 crate directly
use djangohashers::{make_password_with_algorithm, check_password, HasherError, Algorithm}; use djangohashers::{make_password_with_algorithm, check_password, Algorithm};
use jsonwebtoken::{ use jsonwebtoken::{
encode, decode,
Header, Validation,
Algorithm as JwtAlgorithm, EncodingKey, DecodingKey,
errors::Result as JwtResult,
errors::ErrorKind as JwtErrorKind errors::ErrorKind as JwtErrorKind
}; };
use crate::schema::*; use crate::schema::*;
use crate::DbConn; use crate::DbConn;
use crate::config::Config; use crate::config::Config;
use crate::models::errors::{ErrorResponse, make_500}; use crate::models::errors::{UserError, ErrorResponse, make_500};
use crate::models::zone::Zone;
use crate::models::auth::AuthClaims;
const BEARER: &str = "Bearer "; const BEARER: &str = "Bearer ";
const AUTH_HEADER: &str = "Authentication"; const AUTH_HEADER: &str = "Authorization";
#[derive(Debug, DbEnum, Deserialize, Clone)] #[derive(Debug, DbEnum, Deserialize, Clone)]
@ -61,14 +56,6 @@ pub struct UserZone {
pub zone_id: String, pub zone_id: String,
} }
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
#[table_name = "zone"]
pub struct Zone {
#[serde(skip)]
pub id: String,
pub name: String,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct CreateUserRequest { pub struct CreateUserRequest {
pub username: String, pub username: String,
@ -77,37 +64,6 @@ pub struct CreateUserRequest {
pub role: Option<Role> pub role: Option<Role>
} }
#[derive(Debug, Deserialize)]
pub struct AddZoneMemberRequest {
pub id: String,
}
// pub struct LdapUserAssociation {
// user_id: Uuid,
// ldap_id: String
// }
#[derive(Debug, Serialize, Deserialize)]
pub struct AuthClaims {
pub jti: String,
pub sub: String,
#[serde(with = "ts_seconds")]
pub exp: DateTime<Utc>,
#[serde(with = "ts_seconds")]
pub iat: DateTime<Utc>,
}
#[derive(Debug, Serialize)]
pub struct AuthTokenResponse {
pub token: String
}
#[derive(Debug, Deserialize)]
pub struct AuthTokenRequest {
pub username: String,
pub password: String,
}
#[derive(Debug)] #[derive(Debug)]
pub struct UserInfo { pub struct UserInfo {
pub id: String, pub id: String,
@ -200,32 +156,6 @@ impl<'r> FromRequest<'r> for UserInfo {
} }
} }
#[derive(Debug)]
pub enum UserError {
ZoneNotFound,
NotFound,
UserConflict,
BadCreds,
BadToken,
ExpiredToken,
MalformedHeader,
PermissionDenied,
DbError(DieselError),
PasswordError(HasherError),
}
impl From<HasherError> for UserError {
fn from(e: HasherError) -> Self {
UserError::PasswordError(e)
}
}
impl From<DieselError> for UserError {
fn from(e: DieselError) -> Self {
UserError::DbError(e)
}
}
impl LocalUser { impl LocalUser {
pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result<UserInfo, UserError> { pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result<UserInfo, UserError> {
use crate::schema::localuser::dsl::*; use crate::schema::localuser::dsl::*;
@ -315,76 +245,4 @@ impl LocalUser {
username: client_localuser.username, username: client_localuser.username,
}) })
} }
}
impl AuthClaims {
pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims {
let jti = Uuid::new_v4().to_simple().to_string();
let iat = Utc::now();
let exp = iat + token_duration;
AuthClaims {
jti,
sub: user_info.id.clone(),
exp,
iat,
}
}
pub fn decode(token: &str, secret: &str) -> JwtResult<AuthClaims> {
decode::<AuthClaims>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::new(JwtAlgorithm::HS256)
).map(|data| data.claims)
}
pub fn encode(self, secret: &str) -> JwtResult<String> {
encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref()))
}
}
// NOTE: Should probably not be implemented here
// also, "UserError" seems like a misleading name
impl Zone {
pub fn get_all(conn: &diesel::SqliteConnection) -> Result<Vec<Zone>, UserError> {
use crate::schema::zone::dsl::*;
zone.get_results(conn)
.map_err(UserError::DbError)
}
pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
zone.filter(name.eq(zone_name))
.get_result(conn)
.map_err(|e| match e {
DieselError::NotFound => UserError::ZoneNotFound,
other => UserError::DbError(other)
})
}
pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> {
use crate::schema::user_zone::dsl::*;
let new_user_zone = UserZone {
zone_id: self.id.clone(),
user_id: new_member.id.clone()
};
let res = diesel::insert_into(user_zone)
.values(new_user_zone)
.execute(conn);
match res {
// If user has already access to the zone, safely ignore the conflit
// TODO: use 'on conflict do nothing' in postgres when we get there
Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (),
Err(e) => return Err(e.into()),
Ok(_) => ()
};
Ok(())
}
} }

93
src/models/zone.rs 100644
View File

@ -0,0 +1,93 @@
use crate::models::user::UserInfo;
use uuid::Uuid;
use diesel::prelude::*;
use diesel::result::Error as DieselError;
use serde::{Serialize, Deserialize};
use crate::schema::*;
use super::name::AbsoluteName;
use super::user::UserZone;
use super::errors::UserError;
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
#[table_name = "zone"]
pub struct Zone {
#[serde(skip)]
pub id: String,
pub name: String,
}
#[derive(Debug, Deserialize)]
pub struct AddZoneMemberRequest {
pub id: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateZoneRequest {
pub name: AbsoluteName,
}
// NOTE: Should probably not be implemented here
// also, "UserError" seems like a misleading name
impl Zone {
pub fn get_all(conn: &diesel::SqliteConnection) -> Result<Vec<Zone>, UserError> {
use crate::schema::zone::dsl::*;
zone.get_results(conn)
.map_err(UserError::DbError)
}
pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
zone.filter(name.eq(zone_name))
.get_result(conn)
.map_err(|e| match e {
DieselError::NotFound => UserError::ZoneNotFound,
other => UserError::DbError(other)
})
}
pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result<Zone, UserError> {
use crate::schema::zone::dsl::*;
let new_zone = Zone {
id: Uuid::new_v4().to_simple().to_string(),
name: zone_request.name.to_utf8(),
};
diesel::insert_into(zone)
.values(&new_zone)
.execute(conn)
.map_err(|e| match e {
DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict,
other => UserError::DbError(other)
})?;
Ok(new_zone)
}
pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> {
use crate::schema::user_zone::dsl::*;
let new_user_zone = UserZone {
zone_id: self.id.clone(),
user_id: new_member.id.clone()
};
let res = diesel::insert_into(user_zone)
.values(new_user_zone)
.execute(conn);
match res {
// If user has already access to the zone, safely ignore the conflit
// TODO: use 'on conflict do nothing' in postgres when we get there
Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (),
Err(e) => return Err(e.into()),
Ok(_) => ()
};
Ok(())
}
}

View File

@ -4,39 +4,32 @@ use rocket::http::Status;
use crate::config::Config; use crate::config::Config;
use crate::DbConn; use crate::DbConn;
use crate::models::errors::{ErrorResponse, make_500}; use crate::models;
use crate::models::users::{
LocalUser,
CreateUserRequest,
AuthClaims,
AuthTokenRequest,
AuthTokenResponse
};
#[post("/users/me/token", data = "<auth_request>")] #[post("/users/me/token", data = "<auth_request>")]
pub async fn create_auth_token( pub async fn create_auth_token(
conn: DbConn, conn: DbConn,
config: State<'_, Config>, config: State<'_, Config>,
auth_request: Json<AuthTokenRequest> auth_request: Json<models::AuthTokenRequest>
) -> Result<Json<AuthTokenResponse>, ErrorResponse> { ) -> Result<Json<models::AuthTokenResponse>, models::ErrorResponse> {
let user_info = conn.run(move |c| { let user_info = conn.run(move |c| {
LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password) models::LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password)
}).await?; }).await?;
let token = AuthClaims::new(&user_info, config.web_app.token_duration) let token = models::AuthClaims::new(&user_info, config.web_app.token_duration)
.encode(&config.web_app.secret) .encode(&config.web_app.secret)
.map_err(make_500)?; .map_err(models::make_500)?;
Ok(Json(AuthTokenResponse { token })) Ok(Json(models::AuthTokenResponse { token }))
} }
#[post("/users", data = "<user_request>")] #[post("/users", data = "<user_request>")]
pub async fn create_user<'r>(conn: DbConn, user_request: Json<CreateUserRequest>) -> Result<Response<'r>, ErrorResponse> { pub async fn create_user<'r>(conn: DbConn, user_request: Json<models::CreateUserRequest>) -> Result<Response<'r>, models::ErrorResponse> {
// TODO: Check current user if any to check if user has permission to create users (with or without role) // TODO: Check current user if any to check if user has permission to create users (with or without role)
conn.run(|c| { conn.run(|c| {
LocalUser::create_user(&c, user_request.into_inner()) models::LocalUser::create_user(&c, user_request.into_inner())
}).await?; }).await?;
Response::build() Response::build()

View File

@ -3,69 +3,145 @@ use rocket::http::Status;
use rocket_contrib::json::Json; use rocket_contrib::json::Json;
use trust_dns_client::client::ClientHandle; use crate::DbConn;
use trust_dns_client::op::ResponseCode; use crate::dns::{DnsClient, DnsConnectorClient, RecordConnector, ZoneConnector};
use trust_dns_client::rr::{DNSClass, RecordType}; use crate::models;
use crate::models::{ParseRecordList};
use crate::{DbConn, models::dns};
use crate::models::errors::{ErrorResponse, make_500};
use crate::models::users::{LocalUser, UserInfo, Zone, AddZoneMemberRequest};
#[get("/zones/<zone>/records")] #[get("/zones/<zone>/records")]
pub async fn get_zone_records( pub async fn get_zone_records(
mut client: dns::DnsClient, client: DnsClient,
conn: DbConn, conn: DbConn,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: dns::AbsoluteName zone: models::AbsoluteName
) -> Result<Json<Vec<dns::Record>>, ErrorResponse> { ) -> Result<Json<models::RecordList>, models::ErrorResponse> {
let user_info = user_info?; let user_info = user_info?;
let zone_name = zone.to_string();
if !user_info.is_admin() { conn.run(move |c| {
let zone_name = zone.clone().to_string(); if user_info.is_admin() {
conn.run(move |c| { models::Zone::get_by_name(c, &zone_name)
} else {
user_info.get_zone(c, &zone_name) user_info.get_zone(c, &zone_name)
}).await?; }
} }).await?;
let response = { let mut dns_api = DnsConnectorClient::new(client);
let query = client.query(zone.clone(), DNSClass::IN, RecordType::AXFR);
query.await.map_err(make_500)?
};
// TODO: Better error handling (ex. not authorized should be 500) let dns_records = dns_api.get_records(zone.clone(), models::DNSClass::IN.into()).await?;
if response.response_code() != ResponseCode::NoError { let records: Vec<_> = dns_records.into_iter().map(models::Record::from).collect();
println!("Querrying of zone {} failed with code {}", *zone, response.response_code());
return ErrorResponse::new(
Status::NotFound,
format!("Zone {} could not be found", *zone)
).err()
}
let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter()
.map(dns::Record::from)
.filter(|record| !matches!(record.rdata, dns::RData::NULL { .. } | dns::RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(Json(records)) Ok(Json(records))
} }
// TODO: the post version of that #[post("/zones/<zone>/records", data = "<new_records>")]
pub async fn create_zone_records(
client: DnsClient,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
new_records: Json<models::RecordList>
) -> Result<Json<()>, models::ErrorResponse> {
let user_info = user_info?;
let zone_name = zone.to_utf8();
conn.run(move |c| {
if user_info.is_admin() {
models::Zone::get_by_name(c, &zone_name)
} else {
user_info.get_zone(c, &zone_name)
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.add_records(
zone.clone(),
models::DNSClass::IN.into(),
new_records.into_inner().try_into_dns_type(zone.into_inner(), models::DNSClass::IN.into())?
).await?;
return Ok(Json(()));
}
#[put("/zones/<zone>/records", data = "<update_records_request>")]
pub async fn update_zone_records(
client: DnsClient,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
update_records_request: Json<models::UpdateRecordsRequest>
) -> Result<Json<()>, models::ErrorResponse> {
let user_info = user_info?;
let zone = zone.into_inner();
let zone_name = zone.to_utf8();
let update_records_request = update_records_request.into_inner();
conn.run(move |c| {
if user_info.is_admin() {
models::Zone::get_by_name(c, &zone_name)
} else {
user_info.get_zone(c, &zone_name)
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.update_records(
zone.clone(),
models::DNSClass::IN.into(),
update_records_request.old_records.try_into_dns_type(zone.clone(), models::DNSClass::IN.into())?,
update_records_request.new_records.try_into_dns_type(zone, models::DNSClass::IN.into())?,
).await?;
return Ok(Json(()));
}
#[delete("/zones/<zone>/records", data = "<records>")]
pub async fn delete_zone_records(
client: DnsClient,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
records: Json<models::RecordList>
) -> Result<Json<()>, models::ErrorResponse> {
let user_info = user_info?;
let zone_name = zone.to_utf8();
conn.run(move |c| {
if user_info.is_admin() {
models::Zone::get_by_name(c, &zone_name)
} else {
user_info.get_zone(c, &zone_name)
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.delete_records(
zone.clone(),
models::DNSClass::IN.into(),
records.into_inner().try_into_dns_type(zone.into_inner(), models::DNSClass::IN.into())?
).await?;
return Ok(Json(()));
}
#[get("/zones")] #[get("/zones")]
pub async fn get_zones( pub async fn get_zones(
conn: DbConn, conn: DbConn,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<models::UserInfo, models::ErrorResponse>,
) -> Result<Json<Vec<Zone>>, ErrorResponse> { ) -> Result<Json<Vec<models::Zone>>, models::ErrorResponse> {
let user_info = user_info?; let user_info = user_info?;
let zones = conn.run(move |c| { let zones = conn.run(move |c| {
if user_info.is_admin() { if user_info.is_admin() {
Zone::get_all(c) models::Zone::get_all(c)
} else { } else {
user_info.get_zones(c) user_info.get_zones(c)
} }
@ -74,25 +150,44 @@ pub async fn get_zones(
Ok(Json(zones)) Ok(Json(zones))
} }
#[post("/zones", data = "<zone_request>")]
pub async fn create_zone(
conn: DbConn,
client: DnsClient,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone_request: Json<models::CreateZoneRequest>,
) -> Result<Json<models::Zone>, models::ErrorResponse> {
user_info?.check_admin()?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.zone_exists(zone_request.name.clone(), models::DNSClass::IN.into()).await?;
let zone = conn.run(move |c| {
models::Zone::create_zone(c, zone_request.into_inner())
}).await?;
Ok(Json(zone))
}
#[post("/zones/<zone>/members", data = "<zone_member_request>")] #[post("/zones/<zone>/members", data = "<zone_member_request>")]
pub async fn add_member_to_zone<'r>( pub async fn add_member_to_zone<'r>(
conn: DbConn, conn: DbConn,
zone: dns::AbsoluteName, zone: models::AbsoluteName,
user_info: Result<UserInfo, ErrorResponse>, user_info: Result<models::UserInfo, models::ErrorResponse>,
zone_member_request: Json<AddZoneMemberRequest> zone_member_request: Json<models::AddZoneMemberRequest>
) -> Result<Response<'r>, ErrorResponse> { ) -> Result<Response<'r>, models::ErrorResponse> {
let user_info = user_info?; let user_info = user_info?;
let zone_name = zone.to_utf8(); let zone_name = zone.to_utf8();
conn.run(move |c| { conn.run(move |c| {
let zone = if user_info.is_admin() { let zone = if user_info.is_admin() {
Zone::get_by_name(c, &zone_name) models::Zone::get_by_name(c, &zone_name)
} else { } else {
user_info.get_zone(c, &zone_name) user_info.get_zone(c, &zone_name)
}?; }?;
let new_member = LocalUser::get_user_by_uuid(c, &zone_member_request.id)?; let new_member = models::LocalUser::get_user_by_uuid(c, &zone_member_request.id)?;
zone.add_member(&c, &new_member) zone.add_member(&c, &new_member)
}).await?; }).await?;

View File

@ -1,6 +1,6 @@
table! { table! {
use diesel::sql_types::*; use diesel::sql_types::*;
use crate::models::users::*; use crate::models::user::*;
localuser (user_id) { localuser (user_id) {
user_id -> Text, user_id -> Text,