add records update

This commit is contained in:
Hannaeko 2022-03-04 21:55:27 +01:00
parent 77cc634257
commit 3767cc6ea0
12 changed files with 266 additions and 87 deletions

23
api.yml
View file

@ -307,6 +307,17 @@ components:
items:
$ref: '#/components/schemas/Record'
UpdateRecordsRequest:
type: object
required:
- oldRecords
- newRecords
properties:
oldRecords:
$ref: '#/components/schemas/RecordList'
newRecords:
$ref: '#/components/schemas/RecordList'
paths:
'/users':
@ -400,3 +411,15 @@ paths:
responses:
'200':
description: ''
put:
security:
- ApiToken: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UpdateRecordsRequest'
responses:
'200':
description: ''

View file

@ -7,7 +7,8 @@ from nomilo_client.models import (
RecordTypeCNAME,
RecordTypeNS,
RecordTypeTXT,
RecordList
RecordList,
UpdateRecordsRequest,
)
import logging
@ -97,3 +98,39 @@ class TestZones(unittest.TestCase):
found = True
self.assertTrue(found, msg='New record not found in zone records')
def test_update_records(self):
name = random_name('example.com.')
old_record = RecordTypeTXT(
_class='IN',
ttl=300,
name=name,
text='old value',
type='TXT'
)
new_record = RecordTypeTXT(
_class='IN',
ttl=300,
name=name,
text='new value',
type='TXT'
)
self.api.zones_zone_records_post(zone='example.com.', record_list=RecordList(value=[old_record]))
update_records_request = UpdateRecordsRequest(
old_records=RecordList([old_record]),
new_records=RecordList([new_record]),
)
self.api.zones_zone_records_put(zone='example.com.', update_records_request=update_records_request)
records = self.api.zones_zone_records_get(zone='example.com.')
found = False
for record in records.value:
if type(record) is RecordTypeTXT and record.name == name:
self.assertEqual(record.text, new_record.text, msg='New record does not have the expected value')
found = True
self.assertTrue(found, msg='New record not found in zone records')

View file

@ -10,7 +10,7 @@ pub trait RecordApi {
async fn get_records(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, Self::Error>;
async fn add_records(&mut self, zone: dns::Name, class: dns::DNSClass, new_records: Vec<dns::Record>) -> Result<(), Self::Error>;
// update_records
async fn update_records(&mut self, zone: dns::Name, class: dns::DNSClass, old_records: Vec<dns::Record>, new_records: Vec<dns::Record>) -> Result<(), Self::Error>;
// delete_records
}
@ -22,6 +22,5 @@ pub trait ZoneApi {
// get_zones
// add_zone
// delete_zone
// zone_exists
async fn zone_exists(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result<(), Self::Error>;
}

View file

@ -1,8 +1,8 @@
use trust_dns_proto::DnsHandle;
use trust_dns_client::client::ClientHandle;
use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode};
use trust_dns_client::error::ClientError;
use trust_dns_client::proto::xfer::{DnsRequestOptions};
use super::{Name, Record, RData};
use super::client::{ClientResponse, DnsClient};
@ -11,12 +11,6 @@ use super::api::{RecordApi, ZoneApi};
#[derive(Debug)]
pub enum DnsApiError {
RecordNotInZone {
zone: Name,
class: DNSClass,
mismatched_class: Vec<Record>,
mismatched_zone: Vec<Record>,
},
ClientError(ClientError),
ResponceNotOk {
code: ResponseCode,
@ -44,9 +38,8 @@ impl RecordApi for DnsApiClient {
async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result<Vec<Record>, Self::Error>
{
let response = {
let mut query = Query::query(zone.clone(), RecordType::AXFR);
query.set_query_class(class);
ClientResponse(self.client.lookup(query, DnsRequestOptions::default())).await.map_err(|e| DnsApiError::ClientError(e))?
let query = self.client.query(zone.clone(), class, RecordType::AXFR);
query.await.map_err(|e| DnsApiError::ClientError(e))?
};
if response.response_code() != ResponseCode::NoError {
@ -68,31 +61,14 @@ impl RecordApi for DnsApiClient {
async fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<(), Self::Error>
{
let mut mismatched_class = Vec::new();
let mut mismatched_zone = Vec::new();
for record in new_records.iter() {
if !zone.zone_of(record.name()) {
mismatched_zone.push(record.clone());
}
if record.dns_class() != class {
mismatched_class.push(record.clone());
}
}
if mismatched_class.len() > 0 || mismatched_zone.len() > 0 {
return Err(DnsApiError::RecordNotInZone {
zone,
class,
mismatched_zone,
mismatched_class
})
}
// Taken from trust_dns_client::op::update_message::append
// The original function can not be used as is because it takes a RecordSet and not a Record list
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
@ -121,6 +97,68 @@ impl RecordApi for DnsApiClient {
Ok(())
}
async fn update_records(&mut self, zone: Name, class: DNSClass, old_records: Vec<Record>, new_records: Vec<Record>) -> Result<(), Self::Error>
{
// Taken from trust_dns_client::op::update_message::compare_and_swap
// The original function can not be used as is because it takes a RecordSet and not a Record list
// for updates, the query section is used for the zone
let mut zone_query: Query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message: Message = Message::new();
// build the message
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
// make sure the record is what is expected
let mut prerequisite = old_records.clone();
for record in prerequisite.iter_mut() {
record.set_ttl(0);
}
message.add_pre_requisites(prerequisite);
// add the delete for the old record
let mut delete = old_records;
// the class must be none for delete
for record in delete.iter_mut() {
record.set_dns_class(DNSClass::NONE);
// the TTL should be 0
record.set_ttl(0);
}
message.add_updates(delete);
// insert the new record...
message.add_updates(new_records);
// Extended dns
{
let edns = message.edns_mut();
edns.set_max_payload(1232);
edns.set_version(0);
}
let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsApiError::ClientError(e))?;
if response.response_code() != ResponseCode::NoError {
return Err(DnsApiError::ResponceNotOk {
code: response.response_code(),
zone: zone,
});
}
Ok(())
}
}
@ -131,9 +169,8 @@ impl ZoneApi for DnsApiClient {
async fn zone_exists(&mut self, zone: Name, class: DNSClass) -> Result<(), Self::Error>
{
let response = {
let mut query = Query::query(zone.clone(), RecordType::SOA);
query.set_query_class(class);
ClientResponse(self.client.lookup(query, DnsRequestOptions::default())).await.map_err(|e| DnsApiError::ClientError(e))?
let query = self.client.query(zone.clone(), class, RecordType::SOA);
query.await.map_err(|e| DnsApiError::ClientError(e))?
};
if response.response_code() != ResponseCode::NoError {

View file

@ -14,3 +14,4 @@ pub use trust_dns_proto::rr::Name;
// Reexport module types
pub use api::{RecordApi, ZoneApi};
pub use dns_api::DnsApiClient;
pub use client::DnsClient;

View file

@ -28,6 +28,7 @@ async fn rocket() -> rocket::Rocket {
.mount("/api/v1", routes![
get_zone_records,
create_zone_records,
update_zone_records,
get_zones,
create_zone,
add_member_to_zone,

View file

@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use crate::dns;
#[derive(Deserialize, Serialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum DNSClass {
IN,
CH,

View file

@ -102,19 +102,6 @@ impl From<UserError> for ErrorResponse {
impl From<DnsApiError> for ErrorResponse {
fn from(e: DnsApiError) -> Self {
match e {
DnsApiError::RecordNotInZone { zone, class, mismatched_class, mismatched_zone} => {
ErrorResponse::new(
Status::BadRequest,
"Record list contains records that do not belong to the zone".into()
).with_details(
json!({
"zone_name": zone.to_utf8(),
"class": models::DNSClass::from(class),
"mismatched_class": mismatched_class.into_iter().map(|r| r.clone().into()).collect::<Vec<models::Record>>(),
"mismatched_zone": mismatched_zone.into_iter().map(|r| r.clone().into()).collect::<Vec<models::Record>>(),
})
)
},
DnsApiError::ResponceNotOk { code, zone } => {
println!("Query for zone {} failed with code {}", zone, code);
@ -132,16 +119,33 @@ impl From<DnsApiError> for ErrorResponse {
impl From<models::RecordListParseError> for ErrorResponse {
fn from(e: models::RecordListParseError) -> Self {
models::ErrorResponse::new(
match e {
models::RecordListParseError::RecordNotInZone { zone, class, mismatched_class, mismatched_zone} => {
ErrorResponse::new(
Status::BadRequest,
"Record list contains records that do not belong to the zone".into()
).with_details(
json!({
"zone_name": zone.to_utf8(),
"class": models::DNSClass::from(class),
"mismatched_class": mismatched_class,
"mismatched_zone": mismatched_zone,
})
)
},
models::RecordListParseError::ParseError { zone, bad_records } => {
ErrorResponse::new(
Status::BadRequest,
"Record list contains records that could not be parsed into DNS records".into()
).with_details(
json!({
"zone_name": e.zone.to_utf8(),
"records": e.bad_records
"zone_name": zone.to_utf8(),
"records": bad_records
})
)
}
}
}
}

View file

@ -15,5 +15,5 @@ pub use errors::{UserError, ErrorResponse, make_500};
pub use name::{AbsoluteName, SerdeName};
pub use user::{LocalUser, UserInfo, Role, UserZone, User, CreateUserRequest};
pub use rdata::RData;
pub use record::{Record, RecordList, ParseRecordList, RecordListParseError};
pub use record::{Record, RecordList, ParseRecordList, RecordListParseError, UpdateRecordsRequest};
pub use zone::{Zone, AddZoneMemberRequest, CreateZoneRequest};

View file

@ -11,7 +11,7 @@ use crate::dns;
use super::name::SerdeName;
#[derive(Deserialize, Serialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "Type")]
#[serde(rename_all = "UPPERCASE")]
pub enum RData {

View file

@ -8,7 +8,7 @@ use super::class::DNSClass;
use super::rdata::RData;
#[derive(Deserialize, Serialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Record {
#[serde(rename = "Name")]
pub name: SerdeName,
@ -45,36 +45,78 @@ impl TryFrom<Record> for dns::Record {
pub type RecordList = Vec<Record>;
pub struct RecordListParseError {
pub bad_records: Vec<Record>,
pub zone: dns::Name,
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpdateRecordsRequest {
pub old_records: RecordList,
pub new_records: RecordList,
}
pub enum RecordListParseError {
ParseError {
bad_records: Vec<Record>,
zone: dns::Name,
},
RecordNotInZone {
zone: dns::Name,
class: dns::DNSClass,
mismatched_class: Vec<Record>,
mismatched_zone: Vec<Record>,
},
}
pub trait ParseRecordList {
fn try_into_dns_type(self, zone: dns::Name) -> Result<Vec<dns::Record>, RecordListParseError>;
fn try_into_dns_type(self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, RecordListParseError>;
}
impl ParseRecordList for RecordList {
fn try_into_dns_type(self, zone: dns::Name) -> Result<Vec<dns::Record>, RecordListParseError> {
fn try_into_dns_type(self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, RecordListParseError> {
// TODO: What about relative names (also in cnames and stuff)
let mut bad_records = Vec::new();
let mut records: Vec<dns::Record> = Vec::new();
let mut mismatched_class: Vec<Record> = Vec::new();
let mut mismatched_zone: Vec<Record> = Vec::new();
for record in self.into_iter() {
let this_record = record.clone();
if let Ok(record) = record.try_into() {
if let Ok(record) = dns::Record::try_from(record) {
let mut good_record = true;
if !zone.zone_of(record.name()) {
mismatched_zone.push(this_record.clone());
good_record = false;
}
if record.dns_class() != class {
mismatched_class.push(this_record.clone());
good_record = false;
}
if good_record {
records.push(record);
}
} else {
bad_records.push(this_record.clone());
}
}
if !bad_records.is_empty() {
return Err(RecordListParseError {
return Err(RecordListParseError::ParseError {
zone,
bad_records,
});
}
if !mismatched_class.is_empty() || !mismatched_zone.is_empty() {
return Err(RecordListParseError::RecordNotInZone {
zone,
class,
mismatched_zone,
mismatched_class
});
}
return Ok(records)
}
}

View file

@ -4,19 +4,18 @@ use rocket::http::Status;
use rocket_contrib::json::Json;
use crate::DbConn;
use crate::dns;
use crate::dns::{DnsClient, DnsApiClient, RecordApi, ZoneApi};
use crate::models;
use crate::dns::{RecordApi, ZoneApi};
use crate::models::{ParseRecordList};
#[get("/zones/<zone>/records")]
pub async fn get_zone_records(
client: dns::client::DnsClient,
client: DnsClient,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName
) -> Result<Json<Vec<models::Record>>, models::ErrorResponse> {
) -> Result<Json<models::RecordList>, models::ErrorResponse> {
let user_info = user_info?;
let zone_name = zone.to_string();
@ -29,9 +28,9 @@ pub async fn get_zone_records(
}
}).await?;
let mut dns_api = dns::DnsApiClient::new(client);
let mut dns_api = DnsApiClient::new(client);
let dns_records = dns_api.get_records(zone.clone(), dns::DNSClass::IN).await?;
let dns_records = dns_api.get_records(zone.clone(), models::DNSClass::IN.into()).await?;
let records: Vec<_> = dns_records.into_iter().map(models::Record::from).collect();
Ok(Json(records))
@ -39,11 +38,11 @@ pub async fn get_zone_records(
#[post("/zones/<zone>/records", data = "<new_records>")]
pub async fn create_zone_records(
client: dns::client::DnsClient,
client: DnsClient,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
new_records: Json<Vec<models::Record>>
new_records: Json<models::RecordList>
) -> Result<Json<()>, models::ErrorResponse> {
let user_info = user_info?;
@ -58,17 +57,53 @@ pub async fn create_zone_records(
}).await?;
let mut dns_api = dns::DnsApiClient::new(client);
let mut dns_api = DnsApiClient::new(client);
dns_api.add_records(
zone.clone(),
models::DNSClass::IN.into(),
new_records.into_inner().try_into_dns_type(zone.into_inner())?
new_records.into_inner().try_into_dns_type(zone.into_inner(), models::DNSClass::IN.into())?
).await?;
return Ok(Json(()));
}
#[put("/zones/<zone>/records", data = "<update_records_request>")]
pub async fn update_zone_records(
client: DnsClient,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
update_records_request: Json<models::UpdateRecordsRequest>
) -> Result<Json<()>, models::ErrorResponse> {
let user_info = user_info?;
let zone = zone.into_inner();
let zone_name = zone.to_utf8();
let update_records_request = update_records_request.into_inner();
conn.run(move |c| {
if user_info.is_admin() {
models::Zone::get_by_name(c, &zone_name)
} else {
user_info.get_zone(c, &zone_name)
}
}).await?;
let mut dns_api = DnsApiClient::new(client);
dns_api.update_records(
zone.clone(),
models::DNSClass::IN.into(),
update_records_request.old_records.try_into_dns_type(zone.clone(), models::DNSClass::IN.into())?,
update_records_request.new_records.try_into_dns_type(zone, models::DNSClass::IN.into())?,
).await?;
return Ok(Json(()));
}
#[get("/zones")]
pub async fn get_zones(
conn: DbConn,
@ -90,14 +125,14 @@ pub async fn get_zones(
#[post("/zones", data = "<zone_request>")]
pub async fn create_zone(
conn: DbConn,
client: dns::client::DnsClient,
client: DnsClient,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone_request: Json<models::CreateZoneRequest>,
) -> Result<Json<models::Zone>, models::ErrorResponse> {
user_info?.check_admin()?;
let mut dns_api = dns::DnsApiClient::new(client);
dns_api.zone_exists(zone_request.name.clone(), dns::DNSClass::IN).await?;
let mut dns_api = DnsApiClient::new(client);
dns_api.zone_exists(zone_request.name.clone(), models::DNSClass::IN.into()).await?;
let zone = conn.run(move |c| {
models::Zone::create_zone(c, zone_request.into_inner())