diff --git a/.gitignore b/.gitignore index a631336..48f07a6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ config.toml db.sqlite +__pycache__ +/env \ No newline at end of file diff --git a/api.yml b/api.yml new file mode 100644 index 0000000..fdef762 --- /dev/null +++ b/api.yml @@ -0,0 +1,437 @@ +openapi: '3.0.0' +info: + description: '' + version: 0.1.0-dev + title: Nomilo + + +components: + securitySchemes: + ApiToken: + type: http + scheme: bearer + bearerFormat: JWT + + parameters: + ZoneName: + name: zone + in: path + schema: + type: string + required: true + + schemas: + UserRequest: + type: object + required: + - username + - password + - email + properties: + username: + type: string + password: + type: string + email: + type: string + role: + type: string + enum: + - admin + - zoneadmin + + TokenRequest: + type: object + required: + - username + - password + properties: + username: + type: string + password: + type: string + + TokenResponse: + type: object + required: + - token + properties: + token: + type: string + + AddZoneMemberRequest: + type: object + required: + - id + properties: + id: + type: string + + CreateZoneRequest: + type: object + required: + - name + properties: + name: + type: string + + Zone: + type: object + required: + - name + properties: + name: + type: string + + ZoneList: + type: array + items: + $ref: '#/components/schemas/Zone' + + RecordBase: + type: object + required: + - Name + - Class + - TTL + - Type + properties: + Name: + type: string + Class: + type: string + enum: + - IN + - CH + - HS + - NONE + - ANY + TTL: + type: integer + Type: + type: string + + RecordTypeA: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Address + properties: + Address: + type: string + + RecordTypeAAAA: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Address + properties: + Address: + type: string + + RecordTypeCAA: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + required: + - IssuerCritical + - Value + - PropertyTag + properties: + IssuerCritical: + type: boolean + Value: + type: string + PropertyTag: + type: string + + RecordTypeCNAME: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Target + properties: + Target: + type: string + + RecordTypeMX: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Preference + - MailExchanger + properties: + Preference: + type: integer + MailExchanger: + type: string + + RecordTypeNS: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Target + properties: + Target: + type: string + + RecordTypePTR: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Target + properties: + Target: + type: string + + RecordTypeSOA: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - MasterServerName + - MaintainerName + - Refresh + - Retry + - Expire + - Minimum + - Serial + properties: + MasterServerName: + type: string + MaintainerName: + type: string + Refresh: + type: integer + Retry: + type: integer + Expire: + type: integer + Minimum: + type: integer + Serial: + type: integer + + RecordTypeSRV: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Server + - Port + - Priority + - Weight + properties: + Server: + type: string + Port: + type: integer + Priority: + type: integer + Weight: + type: integer + + RecordTypeSSHFP: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Algorithm + - DigestType + - Fingerprint + properties: + Algorithm: + type: integer + DigestType: + type: integer + Fingerprint: + type: string + + RecordTypeTXT: + type: object + allOf: + - $ref: '#/components/schemas/RecordBase' + - type: object + required: + - Text + properties: + Text: + type: string + + Record: + type: object + oneOf: + - $ref: '#/components/schemas/RecordTypeA' + - $ref: '#/components/schemas/RecordTypeAAAA' + - $ref: '#/components/schemas/RecordTypeCAA' + - $ref: '#/components/schemas/RecordTypeCNAME' + - $ref: '#/components/schemas/RecordTypeMX' + - $ref: '#/components/schemas/RecordTypeNS' + - $ref: '#/components/schemas/RecordTypePTR' + - $ref: '#/components/schemas/RecordTypeSOA' + - $ref: '#/components/schemas/RecordTypeSRV' + - $ref: '#/components/schemas/RecordTypeSSHFP' + - $ref: '#/components/schemas/RecordTypeTXT' + discriminator: + propertyName: Type + mapping: + A: '#/components/schemas/RecordTypeA' + AAAA: '#/components/schemas/RecordTypeAAAA' + CAA: '#/components/schemas/RecordTypeCAA' + CNAME: '#/components/schemas/RecordTypeCNAME' + MX: '#/components/schemas/RecordTypeMX' + NS: '#/components/schemas/RecordTypeNS' + PTR: '#/components/schemas/RecordTypePTR' + SOA: '#/components/schemas/RecordTypeSOA' + SRV: '#/components/schemas/RecordTypeSRV' + SSHFP: '#/components/schemas/RecordTypeSSHFP' + TXT: '#/components/schemas/RecordTypeTXT' + + RecordList: + type: array + items: + $ref: '#/components/schemas/Record' + + UpdateRecordsRequest: + type: object + required: + - oldRecords + - newRecords + properties: + oldRecords: + $ref: '#/components/schemas/RecordList' + newRecords: + $ref: '#/components/schemas/RecordList' + + +paths: + '/users': + post: + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UserRequest' + responses: + '201': + description: '' + + '/users/me/token': + post: + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/TokenRequest' + responses: + '200': + description: '' + content: + application/json: + schema: + $ref: '#/components/schemas/TokenResponse' + '/zones': + get: + security: + - ApiToken: [] + responses: + '200': + description: '' + content: + application/json: + schema: + $ref: '#/components/schemas/ZoneList' + post: + security: + - ApiToken: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateZoneRequest' + responses: + '200': + description: '' + content: + application/json: + schema: + $ref: '#/components/schemas/Zone' + + '/zones/{zone}/members': + parameters: + - $ref: '#/components/parameters/ZoneName' + post: + security: + - ApiToken: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AddZoneMemberRequest' + responses: + '201': + description: '' + + '/zones/{zone}/records': + parameters: + - $ref: '#/components/parameters/ZoneName' + get: + security: + - ApiToken: [] + responses: + '200': + description: '' + content: + application/json: + schema: + $ref: '#/components/schemas/RecordList' + post: + security: + - ApiToken: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RecordList' + responses: + '200': + description: '' + put: + security: + - ApiToken: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateRecordsRequest' + responses: + '200': + description: '' + + delete: + security: + - ApiToken: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RecordList' + responses: + '200': + description: '' + diff --git a/dev-scripts/config/knot.conf b/dev-scripts/config/knot.conf new file mode 100644 index 0000000..e6e04ef --- /dev/null +++ b/dev-scripts/config/knot.conf @@ -0,0 +1,24 @@ +server: + listen: [ 0.0.0.0@5353, ::@5353 ] + +log: + - target: stderr + any: debug + +acl: + - id: example_acl + address: [ 127.0.0.1, ::1] + action: [transfer, update] + +template: + - id: default + file: "zones/%s.zone" + journal-content: all + zonefile-load: difference-no-serial + zonefile-sync: -1 + serial-policy: dateserial + +zone: + - domain: example.com + acl: example_acl + template: default diff --git a/dev-scripts/docker-compose.yml b/dev-scripts/docker-compose.yml new file mode 100644 index 0000000..5738e1c --- /dev/null +++ b/dev-scripts/docker-compose.yml @@ -0,0 +1,8 @@ +services: + knot: + image: cznic/knot + volumes: + - ./zones:/storage/zones:ro + - ./config:/config:ro + command: knotd + network_mode: host diff --git a/dev-scripts/zones/example.com.zone b/dev-scripts/zones/example.com.zone new file mode 100644 index 0000000..aea8c99 --- /dev/null +++ b/dev-scripts/zones/example.com.zone @@ -0,0 +1,13 @@ +example.com. IN SOA ns.example.com. admin.example.com. ( + 2020250101 ; serial + 28800 ; refresh (8 hours) + 7200 ; retry (2 hours) + 2419200 ; expire (4 weeks) + 300 ; minimum (5 minutes) + ) + +example.com. 84600 IN NS ns.example.com. + +srv1.example.com. 600 IN A 198.51.100.3 +srv1.example.com. 600 IN AAAA 2001:db8:cafe:bc68::2 +www 600 IN CNAME srv1 diff --git a/docs/Testing.md b/docs/Testing.md new file mode 100644 index 0000000..891bbed --- /dev/null +++ b/docs/Testing.md @@ -0,0 +1,22 @@ +# Testing + +To run the end-to-end tests the OpenAPI Python client should be generated first: +``` +openapi-generator generate -i ./api.yml -g python --package-name nomilo_client -o ./python_client +``` + +Then install the client, here a virtual env is created for this purpose: +``` +python -m venv env +env/bin/pip install ./python_client +``` + +Finally start the name server. It will listen on `127.0.0.1:5353`, be sure to update the configuration accordingly. +``` +docker-compose -f ./dev-scripts/docker-compose.yml up -d +``` + +You are now all set to run the e2e tests. Note that Nomilo must be started first. +``` +env/bin/python -m unittest e2e/*.py +``` \ No newline at end of file diff --git a/e2e/zones.py b/e2e/zones.py new file mode 100644 index 0000000..d4784de --- /dev/null +++ b/e2e/zones.py @@ -0,0 +1,157 @@ +from nomilo_client import ApiClient, Configuration +from nomilo_client.api.default_api import DefaultApi +from nomilo_client.models import ( + TokenRequest, + RecordTypeSOA, + RecordTypeAAAA, + RecordTypeCNAME, + RecordTypeNS, + RecordTypeTXT, + RecordList, + UpdateRecordsRequest, +) + +import logging +import string +import random + +import unittest +import warnings + + +logging.basicConfig(level=logging.DEBUG) + +HOST = 'http://localhost:8000/api/v1' +USER='toto' +PASSWORD='supersecure' + + +def build_api(host: str): + conf = Configuration(host=HOST) + api_client = ApiClient(configuration=conf) + return DefaultApi(api_client) + +def build_authenticated_api(host: str, token: TokenRequest): + auth_conf = Configuration(host=host, access_token=token.token) + api_client = ApiClient(configuration=auth_conf) + return DefaultApi(api_client) + +def random_string(length): + return ''.join(random.choice(string.ascii_lowercase) for x in range(length)) + +def random_name(zone): + return '%s.%s' % (random_string(16), zone) + + +class TestZones(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Ignore warning about unclosed socket + warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) + + api = build_api(HOST) + token = api.users_me_token_post(token_request=TokenRequest(username=USER,password=PASSWORD)) + cls.api = build_authenticated_api(HOST, token) + + def test_get_zones(self): + zones = self.api.zones_get() + zone_name = zones.value[0].name + self.assertEqual(zone_name, 'example.com.') + + def test_get_records(self): + records = self.api.zones_zone_records_get(zone='example.com.') + for record in records.value: + if type(record) is RecordTypeSOA: + with self.subTest(type='soa'): + self.assertEqual(record.name, 'example.com.') + + if type(record) is RecordTypeAAAA: + with self.subTest(type='ns'): + self.assertEqual(record.name, 'srv1.example.com.') + self.assertEqual(record.address, '2001:db8:cafe:bc68::2') + + if type(record) is RecordTypeCNAME: + with self.subTest(type='cname'): + self.assertEqual(record.name, 'www.example.com.') + self.assertEqual(record.target, 'srv1.example.com.') + + if type(record) is RecordTypeNS: + with self.subTest(type='ns'): + self.assertEqual(record.name, 'example.com.') + self.assertEqual(record.target, 'ns.example.com.') + + def test_create_records(self): + new_record = RecordTypeTXT( + _class='IN', + ttl=300, + name=random_name('example.com.'), + text=random_string(32), + type='TXT' + ) + + self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList(value=[new_record])) + records = self.api.zones_zone_records_get(zone='example.com.') + found = False + for record in records.value: + if type(record) is RecordTypeTXT and record.name == new_record.name: + self.assertEqual(record.text, new_record.text, msg='New record does not have the expected value') + found = True + + self.assertTrue(found, msg='New record not found in zone records') + + def test_update_records(self): + name = random_name('example.com.') + old_record = RecordTypeTXT( + _class='IN', + ttl=300, + name=name, + text='old value', + type='TXT' + ) + + new_record = RecordTypeTXT( + _class='IN', + ttl=300, + name=name, + text='new value', + type='TXT' + ) + + self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList([old_record])) + + update_records_request = UpdateRecordsRequest( + old_records=RecordList([old_record]), + new_records=RecordList([new_record]), + ) + + self.api.zones_zone_records_put(zone='example.com.', update_records_request=update_records_request) + + records = self.api.zones_zone_records_get(zone='example.com.') + found = False + for record in records.value: + if type(record) is RecordTypeTXT and record.name == name: + self.assertEqual(record.text, new_record.text, msg='New record does not have the expected value') + found = True + + self.assertTrue(found, msg='Updated record not found in zone records') + + def test_delete_records(self): + name = random_name('example.com.') + record = RecordTypeTXT( + _class='IN', + ttl=300, + name=name, + text=random_string(32), + type='TXT' + ) + + self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList([record])) + self.api.zones_zone_records_delete(zone='example.com.', record_list=RecordList([record])) + + records = self.api.zones_zone_records_get(zone='example.com.') + found = False + for record in records.value: + if type(record) is RecordTypeTXT and record.name == name: + found = True + + self.assertFalse(found, msg='Delete record found in zone records') \ No newline at end of file diff --git a/src/dns/client.rs b/src/dns/client.rs new file mode 100644 index 0000000..dfcd761 --- /dev/null +++ b/src/dns/client.rs @@ -0,0 +1,69 @@ +use std::{future::Future, pin::Pin, task::{Context, Poll}}; +use std::ops::{Deref, DerefMut}; + +use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}}; +use tokio::{net::TcpStream as TokioTcpStream, task}; +use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream}; +use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; + +use crate::config::Config; + + + +pub struct DnsClient(AsyncClient); + +impl Deref for DnsClient { + type Target = AsyncClient; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DnsClient { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for DnsClient { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + let config = try_outcome!(request.guard::>().await); + let (stream, handle) = TcpClientStream::>::new(config.dns.server); + let client = AsyncClient::with_timeout( + stream, + handle, + std::time::Duration::from_secs(5), + None); + let (client, bg) = match client.await { + Err(e) => { + println!("Failed to connect to DNS server {:#?}", e); + return Outcome::Failure((Status::InternalServerError, ())) + }, + Ok(c) => c + }; + task::spawn(bg); + Outcome::Success(DnsClient(client)) + } +} + +// Reimplement this type here as ClientReponse in trust-dns crate have private fields +pub struct ClientResponse(pub(crate) R) +where + R: Future> + Send + Unpin + 'static; + +impl Future for ClientResponse +where + R: Future> + Send + Unpin + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // This is from the future_utils crate, we simply reuse the reexport from Rocket + rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from) + } +} \ No newline at end of file diff --git a/src/dns/connector.rs b/src/dns/connector.rs new file mode 100644 index 0000000..70f75e8 --- /dev/null +++ b/src/dns/connector.rs @@ -0,0 +1,27 @@ +use crate::dns; + +// TODO: Use model types instead of dns types as input / output and only convert internaly? + +// Zone content api +// E.g.: DNS update + axfr, zone file read + write +#[async_trait] +pub trait RecordConnector { + type Error; + + async fn get_records(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result, Self::Error>; + async fn add_records(&mut self, zone: dns::Name, class: dns::DNSClass, new_records: Vec) -> Result<(), Self::Error>; + async fn update_records(&mut self, zone: dns::Name, class: dns::DNSClass, old_records: Vec, new_records: Vec) -> Result<(), Self::Error>; + async fn delete_records(&mut self, zone: dns::Name, class: dns::DNSClass, records: Vec) -> Result<(), Self::Error>; + // delete_records +} + +// Zone management api, todo +// E.g.: Manage catalog zone, dynamically generate knot / bind / nsd config... +#[async_trait] +pub trait ZoneConnector { + type Error; + // get_zones + // add_zone + // delete_zone + async fn zone_exists(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result<(), Self::Error>; +} \ No newline at end of file diff --git a/src/dns/dns_connector.rs b/src/dns/dns_connector.rs new file mode 100644 index 0000000..aad60e5 --- /dev/null +++ b/src/dns/dns_connector.rs @@ -0,0 +1,236 @@ +use trust_dns_proto::DnsHandle; +use trust_dns_client::client::ClientHandle; +use trust_dns_client::rr::{DNSClass, RecordType}; +use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode}; +use trust_dns_client::error::ClientError; + +use super::{Name, Record, RData}; +use super::client::{ClientResponse, DnsClient}; +use super::connector::{RecordConnector, ZoneConnector}; + + +const MAX_PAYLOAD_LEN: u16 = 1232; + + +#[derive(Debug)] +pub enum DnsConnectorError { + ClientError(ClientError), + ResponceNotOk { + code: ResponseCode, + zone: Name, + }, +} + +pub struct DnsConnectorClient { + client: DnsClient +} + +impl DnsConnectorClient { + pub fn new(client: DnsClient) -> Self { + DnsConnectorClient { + client + } + } +} + + +#[async_trait] +impl RecordConnector for DnsConnectorClient { + type Error = DnsConnectorError; + + async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result, Self::Error> + { + let response = { + let query = self.client.query(zone.clone(), class, RecordType::AXFR); + query.await.map_err(|e| DnsConnectorError::ClientError(e))? + }; + + if response.response_code() != ResponseCode::NoError { + return Err(DnsConnectorError::ResponceNotOk { + code: response.response_code(), + zone: zone, + }); + } + + let answers = response.answers(); + let mut records: Vec<_> = answers.to_vec().into_iter() + .filter(|record| !matches!(record.rdata(), RData::NULL { .. } | RData::DNSSEC(_))) + .collect(); + + // AXFR response ends with SOA, we remove it so it is not doubled in the response. + records.pop(); + Ok(records) + } + + async fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec) -> Result<(), Self::Error> + { + // Taken from trust_dns_client::op::update_message::append + // The original function can not be used as is because it takes a RecordSet and not a Record list + + let mut zone_query = Query::new(); + zone_query.set_name(zone.clone()) + .set_query_class(class) + .set_query_type(RecordType::SOA); + + let mut message = Message::new(); + + // TODO: set random / time based id + message + .set_id(0) + .set_message_type(MessageType::Query) + .set_op_code(OpCode::Update) + .set_recursion_desired(false); + message.add_zone(zone_query); + message.add_updates(new_records); + + { + let edns = message.edns_mut(); + edns.set_max_payload(MAX_PAYLOAD_LEN); + edns.set_version(0); + } + + let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsConnectorError::ClientError(e))?; + + if response.response_code() != ResponseCode::NoError { + return Err(DnsConnectorError::ResponceNotOk { + code: response.response_code(), + zone: zone, + }); + } + + Ok(()) + } + + async fn update_records(&mut self, zone: Name, class: DNSClass, old_records: Vec, new_records: Vec) -> Result<(), Self::Error> + { + // Taken from trust_dns_client::op::update_message::compare_and_swap + // The original function can not be used as is because it takes a RecordSet and not a Record list + + // for updates, the query section is used for the zone + let mut zone_query = Query::new(); + zone_query.set_name(zone.clone()) + .set_query_class(class) + .set_query_type(RecordType::SOA); + + let mut message: Message = Message::new(); + + // build the message + // TODO: set random / time based id + message + .set_id(0) + .set_message_type(MessageType::Query) + .set_op_code(OpCode::Update) + .set_recursion_desired(false); + message.add_zone(zone_query); + + // make sure the record is what is expected + let mut prerequisite = old_records.clone(); + for record in prerequisite.iter_mut() { + record.set_ttl(0); + } + message.add_pre_requisites(prerequisite); + + // add the delete for the old record + let mut delete = old_records; + for record in delete.iter_mut() { + // the class must be none for delete + record.set_dns_class(DNSClass::NONE); + // the TTL should be 0 + record.set_ttl(0); + } + message.add_updates(delete); + + // insert the new record... + message.add_updates(new_records); + + // Extended dns + { + let edns = message.edns_mut(); + edns.set_max_payload(MAX_PAYLOAD_LEN); + edns.set_version(0); + } + + let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsConnectorError::ClientError(e))?; + + if response.response_code() != ResponseCode::NoError { + return Err(DnsConnectorError::ResponceNotOk { + code: response.response_code(), + zone: zone, + }); + } + + Ok(()) + } + + async fn delete_records(&mut self, zone: Name, class: DNSClass, records: Vec) -> Result<(), Self::Error> + { + // for updates, the query section is used for the zone + let mut zone_query = Query::new(); + zone_query.set_name(zone.clone()) + .set_query_class(class) + .set_query_type(RecordType::SOA); + + let mut message: Message = Message::new(); + + // build the message + // TODO: set random / time based id + message + .set_id(0) + .set_message_type(MessageType::Query) + .set_op_code(OpCode::Update) + .set_recursion_desired(false); + message.add_zone(zone_query); + + let mut delete = records; + for record in delete.iter_mut() { + // the class must be none for delete + record.set_dns_class(DNSClass::NONE); + // the TTL should be 0 + record.set_ttl(0); + } + message.add_updates(delete); + + // Extended dns + { + let edns = message.edns_mut(); + edns.set_max_payload(MAX_PAYLOAD_LEN); + edns.set_version(0); + } + + let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsConnectorError::ClientError(e))?; + + if response.response_code() != ResponseCode::NoError { + return Err(DnsConnectorError::ResponceNotOk { + code: response.response_code(), + zone: zone, + }); + } + + Ok(()) + + } +} + + +#[async_trait] +impl ZoneConnector for DnsConnectorClient { + type Error = DnsConnectorError; + + async fn zone_exists(&mut self, zone: Name, class: DNSClass) -> Result<(), Self::Error> + { + let response = { + let query = self.client.query(zone.clone(), class, RecordType::SOA); + query.await.map_err(|e| DnsConnectorError::ClientError(e))? + }; + + if response.response_code() != ResponseCode::NoError { + return Err(DnsConnectorError::ResponceNotOk { + code: response.response_code(), + zone: zone, + }); + } + + Ok(()) + } + +} \ No newline at end of file diff --git a/src/dns/mod.rs b/src/dns/mod.rs new file mode 100644 index 0000000..a5fc33d --- /dev/null +++ b/src/dns/mod.rs @@ -0,0 +1,17 @@ +pub mod client; +pub mod dns_connector; +pub mod connector; + +// Reexport trust dns types for convenience +pub use trust_dns_client::rr::rdata::{ + DNSSECRData, caa, sshfp, mx, null, soa, srv, txt +}; +pub use trust_dns_client::rr::{ + RData, DNSClass, Record +}; +pub use trust_dns_proto::rr::Name; + +// Reexport module types +pub use connector::{RecordConnector, ZoneConnector}; +pub use dns_connector::{DnsConnectorClient, DnsConnectorError}; +pub use client::DnsClient; \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 7dff5a5..4870eee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ mod models; mod config; mod schema; mod routes; +mod dns; use routes::users::*; use routes::zones::*; @@ -26,9 +27,13 @@ async fn rocket() -> rocket::Rocket { .attach(DbConn::fairing()) .mount("/api/v1", routes![ get_zone_records, + create_zone_records, + update_zone_records, + delete_zone_records, get_zones, + create_zone, add_member_to_zone, create_auth_token, - create_user + create_user, ]) } diff --git a/src/models/auth.rs b/src/models/auth.rs new file mode 100644 index 0000000..05fca5e --- /dev/null +++ b/src/models/auth.rs @@ -0,0 +1,63 @@ +use uuid::Uuid; +use serde::{Serialize, Deserialize}; +use chrono::serde::ts_seconds; +use chrono::prelude::{DateTime, Utc}; +use chrono::Duration; +use jsonwebtoken::{ + encode, decode, + Header, Validation, + Algorithm as JwtAlgorithm, EncodingKey, DecodingKey, + errors::Result as JwtResult +}; + +use crate::models::user::UserInfo; + + + +#[derive(Debug, Serialize, Deserialize)] +pub struct AuthClaims { + pub jti: String, + pub sub: String, + #[serde(with = "ts_seconds")] + pub exp: DateTime, + #[serde(with = "ts_seconds")] + pub iat: DateTime, +} + +#[derive(Debug, Serialize)] +pub struct AuthTokenResponse { + pub token: String +} + +#[derive(Debug, Deserialize)] +pub struct AuthTokenRequest { + pub username: String, + pub password: String, +} + +impl AuthClaims { + pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims { + let jti = Uuid::new_v4().to_simple().to_string(); + let iat = Utc::now(); + let exp = iat + token_duration; + + AuthClaims { + jti, + sub: user_info.id.clone(), + exp, + iat, + } + } + + pub fn decode(token: &str, secret: &str) -> JwtResult { + decode::( + token, + &DecodingKey::from_secret(secret.as_ref()), + &Validation::new(JwtAlgorithm::HS256) + ).map(|data| data.claims) + } + + pub fn encode(self, secret: &str) -> JwtResult { + encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref())) + } +} \ No newline at end of file diff --git a/src/models/class.rs b/src/models/class.rs new file mode 100644 index 0000000..7cb45d9 --- /dev/null +++ b/src/models/class.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +use crate::dns; + + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub enum DNSClass { + IN, + CH, + HS, + NONE, + ANY, + OPT(u16), +} + +impl From for DNSClass { + fn from(dns_class: dns::DNSClass) -> DNSClass { + match dns_class { + dns::DNSClass::IN => DNSClass::IN, + dns::DNSClass::CH => DNSClass::CH, + dns::DNSClass::HS => DNSClass::HS, + dns::DNSClass::NONE => DNSClass::NONE, + dns::DNSClass::ANY => DNSClass::ANY, + dns::DNSClass::OPT(v) => DNSClass::OPT(v), + } + } +} + +impl From for dns::DNSClass { + fn from(dns_class: DNSClass) -> dns::DNSClass { + match dns_class { + DNSClass::IN => dns::DNSClass::IN, + DNSClass::CH => dns::DNSClass::CH, + DNSClass::HS => dns::DNSClass::HS, + DNSClass::NONE => dns::DNSClass::NONE, + DNSClass::ANY => dns::DNSClass::ANY, + DNSClass::OPT(v) => dns::DNSClass::OPT(v), + } + } +} \ No newline at end of file diff --git a/src/models/dns.rs b/src/models/dns.rs deleted file mode 100644 index 2bb22b9..0000000 --- a/src/models/dns.rs +++ /dev/null @@ -1,329 +0,0 @@ -use std::net::{Ipv6Addr, Ipv4Addr}; -use std::fmt; -use std::ops::{Deref, DerefMut}; - - -use rocket::{Request, State, http::Status, request::{FromParam, FromRequest, Outcome}}; - -use serde::{Deserialize, Deserializer, Serialize}; - -use tokio::{net::TcpStream as TokioTcpStream, task}; - -use trust_dns_client::{client::AsyncClient, serialize::binary::BinEncoder, tcp::TcpClientStream}; -use trust_dns_proto::error::{ProtoError}; -use trust_dns_proto::iocompat::AsyncIoTokioAsStd; - - -use super::trust_dns_types::{self, Name}; -use crate::config::Config; - - -#[derive(Deserialize, Serialize)] -#[serde(tag = "Type")] -#[serde(rename_all = "UPPERCASE")] -pub enum RData { - #[serde(rename_all = "PascalCase")] - A { - address: Ipv4Addr - }, - #[serde(rename_all = "PascalCase")] - AAAA { - address: Ipv6Addr - }, - #[serde(rename_all = "PascalCase")] - CAA { - issuer_critical: bool, - value: String, - property_tag: String, - }, - #[serde(rename_all = "PascalCase")] - CNAME { - target: String - }, - // HINFO(HINFO), - // HTTPS(SVCB), - #[serde(rename_all = "PascalCase")] - MX { - preference: u16, - mail_exchanger: String - }, - // NAPTR(NAPTR), - #[serde(rename_all = "PascalCase")] - NULL { - data: String - }, - #[serde(rename_all = "PascalCase")] - NS { - target: String - }, - // OPENPGPKEY(OPENPGPKEY), - // OPT(OPT), - #[serde(rename_all = "PascalCase")] - PTR { - target: String - }, - #[serde(rename_all = "PascalCase")] - SOA { - master_server_name: String, - maintainer_name: String, - refresh: i32, - retry: i32, - expire: i32, - minimum: u32, - serial: u32 - }, - #[serde(rename_all = "PascalCase")] - SRV { - server: String, - port: u16, - priority: u16, - weight: u16, - }, - #[serde(rename_all = "PascalCase")] - SSHFP { - algorithm: u8, - digest_type: u8, - fingerprint: String, - }, - // SVCB(SVCB), - // TLSA(TLSA), - #[serde(rename_all = "PascalCase")] - TXT { - text: String - }, - - // TODO: Eventually allow deserialization of DNSSEC records - #[serde(skip)] - DNSSEC(trust_dns_types::DNSSECRData), - #[serde(rename_all = "PascalCase")] - Unknown { - code: u16, - data: String, - }, - // ZERO, -} - -impl From for RData { - fn from(rdata: trust_dns_types::RData) -> RData { - match rdata { - trust_dns_types::RData::A(address) => RData::A { address }, - trust_dns_types::RData::AAAA(address) => RData::AAAA { address }, - // Still a draft, no iana number yet, I don't to put something that is not currently supported so that's why NULL and not unknown. - // TODO: probably need better error here, I don't know what to do about that as this would require to change the From for something else. - // (empty data because I'm lazy) - trust_dns_types::RData::ANAME(_) => RData::NULL { - data: String::new() - }, - trust_dns_types::RData::CNAME(target) => RData::CNAME { - target: target.to_utf8() - }, - trust_dns_types::RData::CAA(caa) => RData::CAA { - issuer_critical: caa.issuer_critical(), - value: format!("{}", CAAValue(caa.value())), - property_tag: caa.tag().as_str().to_string(), - }, - trust_dns_types::RData::MX(mx) => RData::MX { - preference: mx.preference(), - mail_exchanger: mx.exchange().to_utf8() - }, - trust_dns_types::RData::NULL(null) => RData::NULL { - data: base64::encode(null.anything().map(|data| data.to_vec()).unwrap_or_default()) - }, - trust_dns_types::RData::NS(target) => RData::NS { - target: target.to_utf8() - }, - trust_dns_types::RData::PTR(target) => RData::PTR { - target: target.to_utf8() - }, - trust_dns_types::RData::SOA(soa) => RData::SOA { - master_server_name: soa.mname().to_utf8(), - maintainer_name: soa.rname().to_utf8(), - refresh: soa.refresh(), - retry: soa.retry(), - expire: soa.expire(), - minimum: soa.minimum(), - serial: soa.serial() - }, - trust_dns_types::RData::SRV(srv) => RData::SRV { - server: srv.target().to_utf8(), - port: srv.port(), - priority: srv.priority(), - weight: srv.weight(), - }, - trust_dns_types::RData::SSHFP(sshfp) => RData::SSHFP { - algorithm: sshfp.algorithm().into(), - digest_type: sshfp.fingerprint_type().into(), - fingerprint: trust_dns_types::sshfp::HEX.encode(sshfp.fingerprint()), - }, - trust_dns_types::RData::TXT(txt) => RData::TXT { text: format!("{}", txt) }, - trust_dns_types::RData::DNSSEC(data) => RData::DNSSEC(data), - rdata => { - let code = rdata.to_record_type().into(); - let mut data = Vec::new(); - let mut encoder = BinEncoder::new(&mut data); - // TODO: need better error handling (use TryFrom ?) - rdata.emit(&mut encoder).expect("could not encode data"); - - RData::Unknown { - code, - data: base64::encode(data), - } - } - } - } -} - -struct CAAValue<'a>(&'a trust_dns_types::caa::Value); - -// trust_dns Display implementation panics if no parameters -// Implementation based on caa::emit_value -// Also the quotes are strips to render in JSON -impl<'a> fmt::Display for CAAValue<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self.0 { - trust_dns_types::caa::Value::Issuer(name, parameters) => { - - if let Some(name) = name { - write!(f, "{}", name)?; - } - - if name.is_none() && parameters.is_empty() { - write!(f, ";")?; - } - - for value in parameters { - write!(f, "; {}", value)?; - } - } - trust_dns_types::caa::Value::Url(url) => write!(f, "{}", url)?, - trust_dns_types::caa::Value::Unknown(v) => write!(f, "{:?}", v)?, - } - Ok(()) - } -} - -#[derive(Deserialize, Serialize)] -pub enum DNSClass { - IN, - CH, - HS, - NONE, - ANY, - OPT(u16), -} - -impl From for DNSClass { - fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass { - match dns_class { - trust_dns_types::DNSClass::IN => DNSClass::IN, - trust_dns_types::DNSClass::CH => DNSClass::CH, - trust_dns_types::DNSClass::HS => DNSClass::HS, - trust_dns_types::DNSClass::NONE => DNSClass::NONE, - trust_dns_types::DNSClass::ANY => DNSClass::ANY, - trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v), - } - } -} - - -#[derive(Deserialize, Serialize)] -pub struct Record { - #[serde(rename = "Name")] - pub name: String, - #[serde(rename = "Class")] - pub dns_class: DNSClass, - #[serde(rename = "TTL")] - pub ttl: u32, - #[serde(flatten)] - pub rdata: RData, -} - -impl From for Record { - fn from(record: trust_dns_types::Record) -> Record { - Record { - name: record.name().to_utf8(), - //rr_type: record.rr_type().into(), - dns_class: record.dns_class().into(), - ttl: record.ttl(), - rdata: record.into_data().into(), - } - } -} - -#[derive(Debug)] -pub struct AbsoluteName(Name); - -impl<'r> FromParam<'r> for AbsoluteName { - type Error = ProtoError; - - fn from_param(param: &'r str) -> Result { - let mut name = Name::from_utf8(¶m)?; - if !name.is_fqdn() { - name.set_fqdn(true); - } - Ok(AbsoluteName(name)) - } -} - -impl<'de> Deserialize<'de> for AbsoluteName { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de> - { - use serde::de::Error; - - String::deserialize(deserializer) - .and_then(|string| - AbsoluteName::from_param(&string) - .map_err(|e| Error::custom(e.to_string())) - ) - } -} - -impl Deref for AbsoluteName { - type Target = Name; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - - -pub struct DnsClient(AsyncClient); - -impl Deref for DnsClient { - type Target = AsyncClient; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for DnsClient { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for DnsClient { - type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { - let config = try_outcome!(request.guard::>().await); - let (stream, handle) = TcpClientStream::>::new(config.dns.server); - let client = AsyncClient::with_timeout( - stream, - handle, - std::time::Duration::from_secs(5), - None); - let (client, bg) = match client.await { - Err(e) => { - println!("Failed to connect to DNS server {:#?}", e); - return Outcome::Failure((Status::InternalServerError, ())) - }, - Ok(c) => c - }; - task::spawn(bg); - Outcome::Success(DnsClient(client)) - } -} diff --git a/src/models/errors.rs b/src/models/errors.rs index 589d10d..18ff43c 100644 --- a/src/models/errors.rs +++ b/src/models/errors.rs @@ -3,8 +3,37 @@ use rocket::http::Status; use rocket::request::{Request, Outcome}; use rocket::response::{self, Response, Responder}; use rocket_contrib::json::Json; -use crate::models::users::UserError; use serde_json::Value; +use djangohashers::{HasherError}; +use diesel::result::Error as DieselError; +use crate::dns::DnsConnectorError; +use crate::models; + +#[derive(Debug)] +pub enum UserError { + ZoneNotFound, + NotFound, + UserConflict, + BadCreds, + BadToken, + ExpiredToken, + MalformedHeader, + PermissionDenied, + DbError(DieselError), + PasswordError(HasherError), +} + +impl From for UserError { + fn from(e: HasherError) -> Self { + UserError::PasswordError(e) + } +} + +impl From for UserError { + fn from(e: DieselError) -> Self { + UserError::DbError(e) + } +} #[derive(Serialize, Debug)] @@ -70,6 +99,55 @@ impl From for ErrorResponse { } } +impl From for ErrorResponse { + fn from(e: DnsConnectorError) -> Self { + match e { + DnsConnectorError::ResponceNotOk { code, zone } => { + println!("Query for zone {} failed with code {}", zone, code); + + ErrorResponse::new( + Status::NotFound, + "Zone could not be found".into() + ).with_details(json!({ + "zone_name": zone.to_utf8() + })) + }, + DnsConnectorError::ClientError(e) => make_500(e) + } + } +} + +impl From for ErrorResponse { + fn from(e: models::RecordListParseError) -> Self { + match e { + models::RecordListParseError::RecordNotInZone { zone, class, mismatched_class, mismatched_zone} => { + ErrorResponse::new( + Status::BadRequest, + "Record list contains records that do not belong to the zone".into() + ).with_details( + json!({ + "zone_name": zone.to_utf8(), + "class": models::DNSClass::from(class), + "mismatched_class": mismatched_class, + "mismatched_zone": mismatched_zone, + }) + ) + }, + models::RecordListParseError::ParseError { zone, bad_records } => { + ErrorResponse::new( + Status::BadRequest, + "Record list contains records that could not be parsed into DNS records".into() + ).with_details( + json!({ + "zone_name": zone.to_utf8(), + "records": bad_records + }) + ) + } + } + } +} + impl From for Outcome { fn from(e: ErrorResponse) -> Self { diff --git a/src/models/mod.rs b/src/models/mod.rs index dd71b7c..67a95b8 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,13 +1,19 @@ -pub mod dns; +//pub mod dns; +pub mod auth; +pub mod class; pub mod errors; -pub mod users; +pub mod name; +pub mod rdata; +pub mod record; +pub mod user; +pub mod zone; -pub mod trust_dns_types { - pub use trust_dns_client::rr::rdata::{ - DNSSECRData, caa, sshfp, - }; - pub use trust_dns_client::rr::{ - RData, DNSClass, Record - }; - pub use trust_dns_proto::rr::Name; -} +// Reexport types for convenience +pub use auth::{AuthClaims, AuthTokenRequest, AuthTokenResponse}; +pub use class::DNSClass; +pub use errors::{UserError, ErrorResponse, make_500}; +pub use name::{AbsoluteName, SerdeName}; +pub use user::{LocalUser, UserInfo, Role, UserZone, User, CreateUserRequest}; +pub use rdata::RData; +pub use record::{Record, RecordList, ParseRecordList, RecordListParseError, UpdateRecordsRequest}; +pub use zone::{Zone, AddZoneMemberRequest, CreateZoneRequest}; \ No newline at end of file diff --git a/src/models/name.rs b/src/models/name.rs new file mode 100644 index 0000000..11438e9 --- /dev/null +++ b/src/models/name.rs @@ -0,0 +1,78 @@ +use std::ops::Deref; + + +use rocket::request::FromParam; +use serde::{Deserialize, Serialize, Deserializer, Serializer}; +use trust_dns_proto::error::ProtoError; + +use crate::dns::Name; + + +#[derive(Debug, Clone)] +pub struct SerdeName(pub(crate)Name); + +impl Deref for SerdeName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'de> Deserialize<'de> for SerdeName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de> + { + use serde::de::Error; + + String::deserialize(deserializer) + .and_then(|string| + Name::from_utf8(&string) + .map_err(|e| Error::custom(e.to_string())) + ).map( SerdeName) + } +} + +impl Serialize for SerdeName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer + { + self.0.to_utf8().serialize(serializer) + } +} + +impl SerdeName { + pub fn into_inner(self) -> Name { + self.0 + } +} + + +#[derive(Debug, Deserialize)] +pub struct AbsoluteName(SerdeName); + +impl<'r> FromParam<'r> for AbsoluteName { + type Error = ProtoError; + + fn from_param(param: &'r str) -> Result { + let mut name = Name::from_utf8(¶m)?; + if !name.is_fqdn() { + name.set_fqdn(true); + } + Ok(AbsoluteName(SerdeName(name))) + } +} + +impl Deref for AbsoluteName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0.0 + } +} + +impl AbsoluteName { + pub fn into_inner(self) -> Name { + self.0.0 + } +} \ No newline at end of file diff --git a/src/models/rdata.rs b/src/models/rdata.rs new file mode 100644 index 0000000..34b81e7 --- /dev/null +++ b/src/models/rdata.rs @@ -0,0 +1,282 @@ +use std::fmt; +use std::convert::TryFrom; +use std::net::{Ipv6Addr, Ipv4Addr}; + +use serde::{Deserialize, Serialize}; + +use trust_dns_client::serialize::binary::BinEncoder; +use trust_dns_proto::error::ProtoError; + +use crate::dns; +use super::name::SerdeName; + + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(tag = "Type")] +#[serde(rename_all = "UPPERCASE")] +pub enum RData { + #[serde(rename_all = "PascalCase")] + A { + address: Ipv4Addr + }, + #[serde(rename_all = "PascalCase")] + AAAA { + address: Ipv6Addr + }, + #[serde(rename_all = "PascalCase")] + CAA { + issuer_critical: bool, + value: String, + property_tag: String, + }, + #[serde(rename_all = "PascalCase")] + CNAME { + target: SerdeName + }, + // HINFO(HINFO), + // HTTPS(SVCB), + #[serde(rename_all = "PascalCase")] + MX { + preference: u16, + mail_exchanger: SerdeName + }, + // NAPTR(NAPTR), + #[serde(rename_all = "PascalCase")] + NULL { + data: String + }, + #[serde(rename_all = "PascalCase")] + NS { + target: SerdeName + }, + // OPENPGPKEY(OPENPGPKEY), + // OPT(OPT), + #[serde(rename_all = "PascalCase")] + PTR { + target: SerdeName + }, + #[serde(rename_all = "PascalCase")] + SOA { + master_server_name: SerdeName, + maintainer_name: SerdeName, + refresh: i32, + retry: i32, + expire: i32, + minimum: u32, + serial: u32 + }, + #[serde(rename_all = "PascalCase")] + SRV { + server: SerdeName, + port: u16, + priority: u16, + weight: u16, + }, + #[serde(rename_all = "PascalCase")] + SSHFP { + algorithm: u8, + digest_type: u8, + fingerprint: String, + }, + // SVCB(SVCB), + // TLSA(TLSA), + #[serde(rename_all = "PascalCase")] + TXT { + text: String + }, + + // TODO: Eventually allow deserialization of DNSSEC records + #[serde(skip)] + DNSSEC(dns::DNSSECRData), + #[serde(rename_all = "PascalCase")] + Unknown { + code: u16, + data: String, + }, + // ZERO, + + // TODO: DS + // TODO: TLSA +} + +impl From for RData { + fn from(rdata: dns::RData) -> RData { + match rdata { + dns::RData::A(address) => RData::A { address }, + dns::RData::AAAA(address) => RData::AAAA { address }, + // Still a draft, no iana number yet, I don't to put something that is not currently supported so that's why NULL and not unknown. + // TODO: probably need better error here, I don't know what to do about that as this would require to change the From for something else. + // (empty data because I'm lazy) + dns::RData::ANAME(_) => RData::NULL { + data: String::new() + }, + dns::RData::CNAME(target) => RData::CNAME { + target: SerdeName(target) + }, + dns::RData::CAA(caa) => RData::CAA { + issuer_critical: caa.issuer_critical(), + value: format!("{}", CAAValue(caa.value())), + property_tag: caa.tag().as_str().to_string(), + }, + dns::RData::MX(mx) => RData::MX { + preference: mx.preference(), + mail_exchanger: SerdeName(mx.exchange().clone()) + }, + dns::RData::NULL(null) => RData::NULL { + data: base64::encode(null.anything().map(|data| data.to_vec()).unwrap_or_default()) + }, + dns::RData::NS(target) => RData::NS { + target: SerdeName(target) + }, + dns::RData::PTR(target) => RData::PTR { + target: SerdeName(target) + }, + dns::RData::SOA(soa) => RData::SOA { + master_server_name: SerdeName(soa.mname().clone()), + maintainer_name: SerdeName(soa.rname().clone()), + refresh: soa.refresh(), + retry: soa.retry(), + expire: soa.expire(), + minimum: soa.minimum(), + serial: soa.serial() + }, + dns::RData::SRV(srv) => RData::SRV { + server: SerdeName(srv.target().clone()), + port: srv.port(), + priority: srv.priority(), + weight: srv.weight(), + }, + dns::RData::SSHFP(sshfp) => RData::SSHFP { + algorithm: sshfp.algorithm().into(), + digest_type: sshfp.fingerprint_type().into(), + fingerprint: dns::sshfp::HEX.encode(sshfp.fingerprint()), + }, + //TODO: This might alter data if not utf8 compatible, probably need to be replaced + //TODO: check whether concatenating txt data is harmful or not + dns::RData::TXT(txt) => RData::TXT { text: format!("{}", txt) }, + dns::RData::DNSSEC(data) => RData::DNSSEC(data), + rdata => { + let code = rdata.to_record_type().into(); + let mut data = Vec::new(); + let mut encoder = BinEncoder::new(&mut data); + // TODO: need better error handling (use TryFrom ?) + rdata.emit(&mut encoder).expect("could not encode data"); + + RData::Unknown { + code, + data: base64::encode(data), + } + } + } + } +} + +impl TryFrom for dns::RData { + type Error = ProtoError; + + fn try_from(rdata: RData) -> Result { + Ok(match rdata { + RData::A { address } => dns::RData::A(address), + RData::AAAA { address } => dns::RData::AAAA(address), + // TODO: Round trip test all types below (currently not tested...) + RData::CAA { issuer_critical, value, property_tag } => { + let property = dns::caa::Property::from(property_tag); + let caa_value = { + // TODO: duplicate of trust_dns_client::serialize::txt::rdata_parser::caa::parse + // because caa::read_value is private + match property { + dns::caa::Property::Issue | dns::caa::Property::IssueWild => { + let value = dns::caa::read_issuer(value.as_bytes())?; + dns::caa::Value::Issuer(value.0, value.1) + } + dns::caa::Property::Iodef => { + let url = dns::caa::read_iodef(value.as_bytes())?; + dns::caa::Value::Url(url) + } + dns::caa::Property::Unknown(_) => dns::caa::Value::Unknown(value.as_bytes().to_vec()), + } + }; + dns::RData::CAA(dns::caa::CAA { + issuer_critical, + tag: property, + value: caa_value, + }) + }, + RData::CNAME { target } => dns::RData::CNAME(target.into_inner()), + RData::MX { preference, mail_exchanger } => dns::RData::MX( + dns::mx::MX::new(preference, mail_exchanger.into_inner()) + ), + RData::NULL { data } => dns::RData::NULL( + dns::null::NULL::with( + base64::decode(data).map_err(|e| ProtoError::from(format!("{}", e)))? + ) + ), + RData::NS { target } => dns::RData::NS(target.into_inner()), + RData::PTR { target } => dns::RData::PTR(target.into_inner()), + RData::SOA { + master_server_name, + maintainer_name, + refresh, + retry, + expire, + minimum, + serial + } => dns::RData::SOA( + dns::soa::SOA::new( + master_server_name.into_inner(), + maintainer_name.into_inner(), + serial, + refresh, + retry, + expire, + minimum, + ) + ), + RData::SRV { server, port, priority, weight } => dns::RData::SRV( + dns::srv::SRV::new(priority, weight, port, server.into_inner()) + ), + RData::SSHFP { algorithm, digest_type, fingerprint } => dns::RData::SSHFP( + dns::sshfp::SSHFP::new( + // NOTE: This allows unassigned algorithms + dns::sshfp::Algorithm::from(algorithm), + dns::sshfp::FingerprintType::from(digest_type), + dns::sshfp::HEX.decode(fingerprint.as_bytes()).map_err(|e| ProtoError::from(format!("{}", e)))? + ) + ), + RData::TXT { text } => dns::RData::TXT(dns::txt::TXT::new(vec![text])), + // TODO: Error out for DNSSEC? Prefer downstream checks? + RData::DNSSEC(_) => todo!(), + // TODO: Disallow unknown? (could be used to bypass unsopported types?) Prefer downstream checks? + RData::Unknown { code: _code, data: _data } => todo!(), + }) + } +} + +struct CAAValue<'a>(&'a dns::caa::Value); + +// trust_dns Display implementation panics if no parameters +// Implementation based on caa::emit_value +// Also the quotes are strips to render in JSON +impl<'a> fmt::Display for CAAValue<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self.0 { + dns::caa::Value::Issuer(name, parameters) => { + + if let Some(name) = name { + write!(f, "{}", name)?; + } + + if name.is_none() && parameters.is_empty() { + write!(f, ";")?; + } + + for value in parameters { + write!(f, "; {}", value)?; + } + } + dns::caa::Value::Url(url) => write!(f, "{}", url)?, + dns::caa::Value::Unknown(v) => write!(f, "{:?}", v)?, + } + Ok(()) + } +} diff --git a/src/models/record.rs b/src/models/record.rs new file mode 100644 index 0000000..8e51efd --- /dev/null +++ b/src/models/record.rs @@ -0,0 +1,122 @@ +use std::convert::{TryFrom, TryInto}; +use serde::{Deserialize, Serialize}; +use trust_dns_proto::error::ProtoError; + +use crate::dns; +use super::name::SerdeName; +use super::class::DNSClass; +use super::rdata::RData; + + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Record { + #[serde(rename = "Name")] + pub name: SerdeName, + // TODO: Make class optional, default to IN + #[serde(rename = "Class")] + pub dns_class: DNSClass, + #[serde(rename = "TTL")] + pub ttl: u32, + #[serde(flatten)] + pub rdata: RData, +} + +impl From for Record { + fn from(record: dns::Record) -> Record { + Record { + name: SerdeName(record.name().clone()), + dns_class: record.dns_class().into(), + ttl: record.ttl(), + rdata: record.into_data().into(), + } + } +} + +impl TryFrom for dns::Record { + type Error = ProtoError; + + fn try_from(record: Record) -> Result { + let mut trust_dns_record = dns::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?); + trust_dns_record.set_dns_class(record.dns_class.into()); + Ok(trust_dns_record) + } +} + + +pub type RecordList = Vec; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UpdateRecordsRequest { + pub old_records: RecordList, + pub new_records: RecordList, +} + +pub enum RecordListParseError { + ParseError { + bad_records: Vec, + zone: dns::Name, + }, + RecordNotInZone { + zone: dns::Name, + class: dns::DNSClass, + mismatched_class: Vec, + mismatched_zone: Vec, + }, +} + +pub trait ParseRecordList { + fn try_into_dns_type(self, zone: dns::Name, class: dns::DNSClass) -> Result, RecordListParseError>; +} + +impl ParseRecordList for RecordList { + fn try_into_dns_type(self, zone: dns::Name, class: dns::DNSClass) -> Result, RecordListParseError> { + // TODO: What about relative names (also in cnames and stuff) + let mut bad_records = Vec::new(); + let mut records: Vec = Vec::new(); + let mut mismatched_class: Vec = Vec::new(); + let mut mismatched_zone: Vec = Vec::new(); + + for record in self.into_iter() { + let this_record = record.clone(); + if let Ok(record) = dns::Record::try_from(record) { + let mut good_record = true; + + if !zone.zone_of(record.name()) { + mismatched_zone.push(this_record.clone()); + good_record = false; + } + + if record.dns_class() != class { + mismatched_class.push(this_record.clone()); + good_record = false; + } + + if good_record { + records.push(record); + } + } else { + bad_records.push(this_record.clone()); + } + } + + if !bad_records.is_empty() { + return Err(RecordListParseError::ParseError { + zone, + bad_records, + }); + } + + if !mismatched_class.is_empty() || !mismatched_zone.is_empty() { + return Err(RecordListParseError::RecordNotInZone { + zone, + class, + mismatched_zone, + mismatched_class + }); + } + + return Ok(records) + } +} + diff --git a/src/models/users.rs b/src/models/user.rs similarity index 65% rename from src/models/users.rs rename to src/models/user.rs index 2c9d334..95db8e8 100644 --- a/src/models/users.rs +++ b/src/models/user.rs @@ -3,28 +3,23 @@ use diesel::prelude::*; use diesel::result::Error as DieselError; use diesel_derive_enum::DbEnum; use rocket::{State, request::{FromRequest, Request, Outcome}}; -use serde::{Serialize, Deserialize}; -use chrono::serde::ts_seconds; -use chrono::prelude::{DateTime, Utc}; -use chrono::Duration; +use serde::{Deserialize}; // TODO: Maybe just use argon2 crate directly -use djangohashers::{make_password_with_algorithm, check_password, HasherError, Algorithm}; +use djangohashers::{make_password_with_algorithm, check_password, Algorithm}; use jsonwebtoken::{ - encode, decode, - Header, Validation, - Algorithm as JwtAlgorithm, EncodingKey, DecodingKey, - errors::Result as JwtResult, errors::ErrorKind as JwtErrorKind }; use crate::schema::*; use crate::DbConn; use crate::config::Config; -use crate::models::errors::{ErrorResponse, make_500}; +use crate::models::errors::{UserError, ErrorResponse, make_500}; +use crate::models::zone::Zone; +use crate::models::auth::AuthClaims; const BEARER: &str = "Bearer "; -const AUTH_HEADER: &str = "Authentication"; +const AUTH_HEADER: &str = "Authorization"; #[derive(Debug, DbEnum, Deserialize, Clone)] @@ -61,14 +56,6 @@ pub struct UserZone { pub zone_id: String, } -#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)] -#[table_name = "zone"] -pub struct Zone { - #[serde(skip)] - pub id: String, - pub name: String, -} - #[derive(Debug, Deserialize)] pub struct CreateUserRequest { pub username: String, @@ -77,37 +64,6 @@ pub struct CreateUserRequest { pub role: Option } -#[derive(Debug, Deserialize)] -pub struct AddZoneMemberRequest { - pub id: String, -} - -// pub struct LdapUserAssociation { -// user_id: Uuid, -// ldap_id: String -// } - -#[derive(Debug, Serialize, Deserialize)] -pub struct AuthClaims { - pub jti: String, - pub sub: String, - #[serde(with = "ts_seconds")] - pub exp: DateTime, - #[serde(with = "ts_seconds")] - pub iat: DateTime, -} - -#[derive(Debug, Serialize)] -pub struct AuthTokenResponse { - pub token: String -} - -#[derive(Debug, Deserialize)] -pub struct AuthTokenRequest { - pub username: String, - pub password: String, -} - #[derive(Debug)] pub struct UserInfo { pub id: String, @@ -200,32 +156,6 @@ impl<'r> FromRequest<'r> for UserInfo { } } -#[derive(Debug)] -pub enum UserError { - ZoneNotFound, - NotFound, - UserConflict, - BadCreds, - BadToken, - ExpiredToken, - MalformedHeader, - PermissionDenied, - DbError(DieselError), - PasswordError(HasherError), -} - -impl From for UserError { - fn from(e: HasherError) -> Self { - UserError::PasswordError(e) - } -} - -impl From for UserError { - fn from(e: DieselError) -> Self { - UserError::DbError(e) - } -} - impl LocalUser { pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result { use crate::schema::localuser::dsl::*; @@ -315,76 +245,4 @@ impl LocalUser { username: client_localuser.username, }) } - -} - -impl AuthClaims { - pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims { - let jti = Uuid::new_v4().to_simple().to_string(); - let iat = Utc::now(); - let exp = iat + token_duration; - - AuthClaims { - jti, - sub: user_info.id.clone(), - exp, - iat, - } - } - - pub fn decode(token: &str, secret: &str) -> JwtResult { - decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::new(JwtAlgorithm::HS256) - ).map(|data| data.claims) - } - - pub fn encode(self, secret: &str) -> JwtResult { - encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref())) - } -} - -// NOTE: Should probably not be implemented here -// also, "UserError" seems like a misleading name -impl Zone { - pub fn get_all(conn: &diesel::SqliteConnection) -> Result, UserError> { - use crate::schema::zone::dsl::*; - - zone.get_results(conn) - .map_err(UserError::DbError) - } - - pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result { - use crate::schema::zone::dsl::*; - - zone.filter(name.eq(zone_name)) - .get_result(conn) - .map_err(|e| match e { - DieselError::NotFound => UserError::ZoneNotFound, - other => UserError::DbError(other) - }) - } - - pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> { - use crate::schema::user_zone::dsl::*; - - let new_user_zone = UserZone { - zone_id: self.id.clone(), - user_id: new_member.id.clone() - }; - - let res = diesel::insert_into(user_zone) - .values(new_user_zone) - .execute(conn); - - match res { - // If user has already access to the zone, safely ignore the conflit - // TODO: use 'on conflict do nothing' in postgres when we get there - Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (), - Err(e) => return Err(e.into()), - Ok(_) => () - }; - Ok(()) - } -} +} \ No newline at end of file diff --git a/src/models/zone.rs b/src/models/zone.rs new file mode 100644 index 0000000..afca690 --- /dev/null +++ b/src/models/zone.rs @@ -0,0 +1,93 @@ +use crate::models::user::UserInfo; + +use uuid::Uuid; +use diesel::prelude::*; +use diesel::result::Error as DieselError; +use serde::{Serialize, Deserialize}; + +use crate::schema::*; +use super::name::AbsoluteName; +use super::user::UserZone; +use super::errors::UserError; + + +#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)] +#[table_name = "zone"] +pub struct Zone { + #[serde(skip)] + pub id: String, + pub name: String, +} + +#[derive(Debug, Deserialize)] +pub struct AddZoneMemberRequest { + pub id: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateZoneRequest { + pub name: AbsoluteName, +} + +// NOTE: Should probably not be implemented here +// also, "UserError" seems like a misleading name +impl Zone { + pub fn get_all(conn: &diesel::SqliteConnection) -> Result, UserError> { + use crate::schema::zone::dsl::*; + + zone.get_results(conn) + .map_err(UserError::DbError) + } + + pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result { + use crate::schema::zone::dsl::*; + + zone.filter(name.eq(zone_name)) + .get_result(conn) + .map_err(|e| match e { + DieselError::NotFound => UserError::ZoneNotFound, + other => UserError::DbError(other) + }) + } + + pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result { + use crate::schema::zone::dsl::*; + + let new_zone = Zone { + id: Uuid::new_v4().to_simple().to_string(), + name: zone_request.name.to_utf8(), + }; + + diesel::insert_into(zone) + .values(&new_zone) + .execute(conn) + .map_err(|e| match e { + DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict, + other => UserError::DbError(other) + })?; + Ok(new_zone) + } + + + pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> { + use crate::schema::user_zone::dsl::*; + + let new_user_zone = UserZone { + zone_id: self.id.clone(), + user_id: new_member.id.clone() + }; + + let res = diesel::insert_into(user_zone) + .values(new_user_zone) + .execute(conn); + + match res { + // If user has already access to the zone, safely ignore the conflit + // TODO: use 'on conflict do nothing' in postgres when we get there + Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (), + Err(e) => return Err(e.into()), + Ok(_) => () + }; + Ok(()) + } +} \ No newline at end of file diff --git a/src/routes/users.rs b/src/routes/users.rs index e6967b3..bf0ba61 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -4,39 +4,32 @@ use rocket::http::Status; use crate::config::Config; use crate::DbConn; -use crate::models::errors::{ErrorResponse, make_500}; -use crate::models::users::{ - LocalUser, - CreateUserRequest, - AuthClaims, - AuthTokenRequest, - AuthTokenResponse -}; +use crate::models; #[post("/users/me/token", data = "")] pub async fn create_auth_token( conn: DbConn, config: State<'_, Config>, - auth_request: Json -) -> Result, ErrorResponse> { + auth_request: Json +) -> Result, models::ErrorResponse> { let user_info = conn.run(move |c| { - LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password) + models::LocalUser::get_user_by_creds(c, &auth_request.username, &auth_request.password) }).await?; - let token = AuthClaims::new(&user_info, config.web_app.token_duration) + let token = models::AuthClaims::new(&user_info, config.web_app.token_duration) .encode(&config.web_app.secret) - .map_err(make_500)?; + .map_err(models::make_500)?; - Ok(Json(AuthTokenResponse { token })) + Ok(Json(models::AuthTokenResponse { token })) } #[post("/users", data = "")] -pub async fn create_user<'r>(conn: DbConn, user_request: Json) -> Result, ErrorResponse> { +pub async fn create_user<'r>(conn: DbConn, user_request: Json) -> Result, models::ErrorResponse> { // TODO: Check current user if any to check if user has permission to create users (with or without role) conn.run(|c| { - LocalUser::create_user(&c, user_request.into_inner()) + models::LocalUser::create_user(&c, user_request.into_inner()) }).await?; Response::build() diff --git a/src/routes/zones.rs b/src/routes/zones.rs index c893b5d..94cdd89 100644 --- a/src/routes/zones.rs +++ b/src/routes/zones.rs @@ -3,69 +3,145 @@ use rocket::http::Status; use rocket_contrib::json::Json; -use trust_dns_client::client::ClientHandle; -use trust_dns_client::op::ResponseCode; -use trust_dns_client::rr::{DNSClass, RecordType}; - -use crate::{DbConn, models::dns}; -use crate::models::errors::{ErrorResponse, make_500}; -use crate::models::users::{LocalUser, UserInfo, Zone, AddZoneMemberRequest}; +use crate::DbConn; +use crate::dns::{DnsClient, DnsConnectorClient, RecordConnector, ZoneConnector}; +use crate::models; +use crate::models::{ParseRecordList}; #[get("/zones//records")] pub async fn get_zone_records( - mut client: dns::DnsClient, + client: DnsClient, conn: DbConn, - user_info: Result, - zone: dns::AbsoluteName -) -> Result>, ErrorResponse> { + user_info: Result, + zone: models::AbsoluteName +) -> Result, models::ErrorResponse> { let user_info = user_info?; + let zone_name = zone.to_string(); - if !user_info.is_admin() { - let zone_name = zone.clone().to_string(); - conn.run(move |c| { + conn.run(move |c| { + if user_info.is_admin() { + models::Zone::get_by_name(c, &zone_name) + } else { user_info.get_zone(c, &zone_name) - }).await?; - } + } + }).await?; - let response = { - let query = client.query(zone.clone(), DNSClass::IN, RecordType::AXFR); - query.await.map_err(make_500)? - }; + let mut dns_api = DnsConnectorClient::new(client); - // TODO: Better error handling (ex. not authorized should be 500) - if response.response_code() != ResponseCode::NoError { - println!("Querrying of zone {} failed with code {}", *zone, response.response_code()); - return ErrorResponse::new( - Status::NotFound, - format!("Zone {} could not be found", *zone) - ).err() - } - - let answers = response.answers(); - let mut records: Vec<_> = answers.to_vec().into_iter() - .map(dns::Record::from) - .filter(|record| !matches!(record.rdata, dns::RData::NULL { .. } | dns::RData::DNSSEC(_))) - .collect(); - - // AXFR response ends with SOA, we remove it so it is not doubled in the response. - records.pop(); + let dns_records = dns_api.get_records(zone.clone(), models::DNSClass::IN.into()).await?; + let records: Vec<_> = dns_records.into_iter().map(models::Record::from).collect(); Ok(Json(records)) } -// TODO: the post version of that +#[post("/zones//records", data = "")] +pub async fn create_zone_records( + client: DnsClient, + conn: DbConn, + user_info: Result, + zone: models::AbsoluteName, + new_records: Json +) -> Result, models::ErrorResponse> { + + let user_info = user_info?; + let zone_name = zone.to_utf8(); + + conn.run(move |c| { + if user_info.is_admin() { + models::Zone::get_by_name(c, &zone_name) + } else { + user_info.get_zone(c, &zone_name) + } + }).await?; + + let mut dns_api = DnsConnectorClient::new(client); + + dns_api.add_records( + zone.clone(), + models::DNSClass::IN.into(), + new_records.into_inner().try_into_dns_type(zone.into_inner(), models::DNSClass::IN.into())? + ).await?; + + return Ok(Json(())); +} + +#[put("/zones//records", data = "")] +pub async fn update_zone_records( + client: DnsClient, + conn: DbConn, + user_info: Result, + zone: models::AbsoluteName, + update_records_request: Json +) -> Result, models::ErrorResponse> { + + let user_info = user_info?; + let zone = zone.into_inner(); + let zone_name = zone.to_utf8(); + let update_records_request = update_records_request.into_inner(); + + conn.run(move |c| { + if user_info.is_admin() { + models::Zone::get_by_name(c, &zone_name) + } else { + user_info.get_zone(c, &zone_name) + } + }).await?; + + let mut dns_api = DnsConnectorClient::new(client); + + dns_api.update_records( + zone.clone(), + models::DNSClass::IN.into(), + update_records_request.old_records.try_into_dns_type(zone.clone(), models::DNSClass::IN.into())?, + update_records_request.new_records.try_into_dns_type(zone, models::DNSClass::IN.into())?, + ).await?; + + return Ok(Json(())); +} + +#[delete("/zones//records", data = "")] +pub async fn delete_zone_records( + client: DnsClient, + conn: DbConn, + user_info: Result, + zone: models::AbsoluteName, + records: Json +) -> Result, models::ErrorResponse> { + + let user_info = user_info?; + let zone_name = zone.to_utf8(); + + conn.run(move |c| { + if user_info.is_admin() { + models::Zone::get_by_name(c, &zone_name) + } else { + user_info.get_zone(c, &zone_name) + } + }).await?; + + let mut dns_api = DnsConnectorClient::new(client); + + dns_api.delete_records( + zone.clone(), + models::DNSClass::IN.into(), + records.into_inner().try_into_dns_type(zone.into_inner(), models::DNSClass::IN.into())? + ).await?; + + return Ok(Json(())); +} + #[get("/zones")] pub async fn get_zones( conn: DbConn, - user_info: Result, -) -> Result>, ErrorResponse> { + user_info: Result, +) -> Result>, models::ErrorResponse> { let user_info = user_info?; let zones = conn.run(move |c| { if user_info.is_admin() { - Zone::get_all(c) + models::Zone::get_all(c) } else { user_info.get_zones(c) } @@ -74,25 +150,44 @@ pub async fn get_zones( Ok(Json(zones)) } +#[post("/zones", data = "")] +pub async fn create_zone( + conn: DbConn, + client: DnsClient, + user_info: Result, + zone_request: Json, +) -> Result, models::ErrorResponse> { + user_info?.check_admin()?; + + let mut dns_api = DnsConnectorClient::new(client); + dns_api.zone_exists(zone_request.name.clone(), models::DNSClass::IN.into()).await?; + + let zone = conn.run(move |c| { + models::Zone::create_zone(c, zone_request.into_inner()) + }).await?; + + Ok(Json(zone)) +} + #[post("/zones//members", data = "")] pub async fn add_member_to_zone<'r>( conn: DbConn, - zone: dns::AbsoluteName, - user_info: Result, - zone_member_request: Json -) -> Result, ErrorResponse> { + zone: models::AbsoluteName, + user_info: Result, + zone_member_request: Json +) -> Result, models::ErrorResponse> { let user_info = user_info?; let zone_name = zone.to_utf8(); conn.run(move |c| { let zone = if user_info.is_admin() { - Zone::get_by_name(c, &zone_name) + models::Zone::get_by_name(c, &zone_name) } else { user_info.get_zone(c, &zone_name) }?; - let new_member = LocalUser::get_user_by_uuid(c, &zone_member_request.id)?; + let new_member = models::LocalUser::get_user_by_uuid(c, &zone_member_request.id)?; zone.add_member(&c, &new_member) }).await?; diff --git a/src/schema.rs b/src/schema.rs index 775c738..10c2d36 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,6 +1,6 @@ table! { use diesel::sql_types::*; - use crate::models::users::*; + use crate::models::user::*; localuser (user_id) { user_id -> Text,