move record fetching in message module

main
Hannaeko 2022-03-04 13:50:57 +01:00
parent 690987010d
commit 094539d376
2 changed files with 47 additions and 27 deletions

View File

@ -1,24 +1,52 @@
use trust_dns_proto::DnsHandle; use trust_dns_proto::DnsHandle;
use trust_dns_client::rr::{DNSClass, RecordType}; use trust_dns_client::rr::{DNSClass, RecordType};
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query}; use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode};
use trust_dns_client::error::ClientError;
use trust_dns_proto::error::ProtoError; use trust_dns_proto::error::ProtoError;
use trust_dns_client::proto::xfer::{DnsRequestOptions};
use super::trust_dns_types::{Name, Record}; use super::trust_dns_types::{Name, Record, RData};
use super::client::{ClientResponse}; use super::client::{ClientResponse};
#[derive(Debug)]
pub enum MessageError { pub enum MessageError {
RecordNotInZone { RecordNotInZone {
zone: Name, zone: Name,
class: DNSClass, class: DNSClass,
mismatched_class: Vec<Record>, mismatched_class: Vec<Record>,
mismatched_zone: Vec<Record>, mismatched_zone: Vec<Record>,
} },
ClientError(ClientError),
ResponceNotOk(ResponseCode)
} }
#[async_trait]
pub trait DnsMessage: DnsHandle<Error = ProtoError> + Send { pub trait DnsMessage: DnsHandle<Error = ProtoError> + Send {
fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<ClientResponse<Self::Response>, MessageError> async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result<Vec<Record>, MessageError>
{
let response = {
let mut query = Query::query(zone, RecordType::AXFR);
query.set_query_class(class);
ClientResponse(self.lookup(query, DnsRequestOptions::default())).await.map_err(|e| MessageError::ClientError(e))?
};
if response.response_code() != ResponseCode::NoError {
return Err(MessageError::ResponceNotOk(response.response_code()));
}
let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter()
.filter(|record| !matches!(record.rdata(), RData::NULL { .. } | RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(records)
}
fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<ClientResponse<<Self as DnsHandle>::Response>, MessageError>
{ {
let mut mismatched_class = Vec::new(); let mut mismatched_class = Vec::new();
let mut mismatched_zone = Vec::new(); let mut mismatched_zone = Vec::new();

View File

@ -41,29 +41,20 @@ pub async fn get_zone_records(
} }
}).await?; }).await?;
let response = { let records: Vec<_> = match client.get_records(zone.clone(), DNSClass::IN).await {
let query = client.query(zone.clone(), DNSClass::IN, RecordType::AXFR); Ok(records) => records.into_iter().map(dns::record::Record::from).collect(),
query.await.map_err(make_500)?
};
if response.response_code() != ResponseCode::NoError { Err(MessageError::ResponceNotOk(code)) => {
println!("Querrying AXFR of zone {} failed with code {}", *zone, response.response_code()); println!("Querrying AXFR of zone {} failed with code {}", *zone, code);
return ErrorResponse::new( return ErrorResponse::new(
Status::NotFound, Status::NotFound,
"Zone could not be found".into() "Zone could not be found".into()
).with_details(json!({ ).with_details(json!({
"zone_name": zone.to_utf8() "zone_name": zone.to_utf8()
})).err(); })).err();
} },
Err(err) => { return make_500(err).err(); },
let answers = response.answers(); };
let mut records: Vec<_> = answers.to_vec().into_iter()
.map(dns::record::Record::from)
.filter(|record| !matches!(record.rdata, dns::rdata::RData::NULL { .. } | dns::rdata::RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(Json(records)) Ok(Json(records))
} }
@ -127,7 +118,8 @@ pub async fn create_zone_records(
"mismatched_zone": mismatched_zone.into_iter().map(|r| r.clone().into()).collect::<Vec<dns::record::Record>>(), "mismatched_zone": mismatched_zone.into_iter().map(|r| r.clone().into()).collect::<Vec<dns::record::Record>>(),
}) })
).err(); ).err();
} },
Err(e) => return make_500(e).err()
}; };
// TODO: better error handling // TODO: better error handling