nomilo/src/dns/message.rs

95 lines
2.7 KiB
Rust
Raw Normal View History

2022-03-04 12:08:03 +00:00
use trust_dns_proto::DnsHandle;
use trust_dns_client::rr::{DNSClass, RecordType};
2022-03-04 12:50:57 +00:00
use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode};
use trust_dns_client::error::ClientError;
2022-03-04 12:08:03 +00:00
use trust_dns_proto::error::ProtoError;
2022-03-04 12:50:57 +00:00
use trust_dns_client::proto::xfer::{DnsRequestOptions};
2022-03-04 12:08:03 +00:00
2022-03-04 16:17:15 +00:00
use super::{Name, Record, RData};
2022-03-04 12:08:03 +00:00
use super::client::{ClientResponse};
2022-03-04 12:50:57 +00:00
#[derive(Debug)]
2022-03-04 12:08:03 +00:00
pub enum MessageError {
RecordNotInZone {
zone: Name,
class: DNSClass,
mismatched_class: Vec<Record>,
mismatched_zone: Vec<Record>,
2022-03-04 12:50:57 +00:00
},
ClientError(ClientError),
ResponceNotOk(ResponseCode)
2022-03-04 12:08:03 +00:00
}
2022-03-04 12:50:57 +00:00
#[async_trait]
2022-03-04 12:08:03 +00:00
pub trait DnsMessage: DnsHandle<Error = ProtoError> + Send {
2022-03-04 12:50:57 +00:00
async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result<Vec<Record>, MessageError>
{
let response = {
let mut query = Query::query(zone, RecordType::AXFR);
query.set_query_class(class);
ClientResponse(self.lookup(query, DnsRequestOptions::default())).await.map_err(|e| MessageError::ClientError(e))?
};
if response.response_code() != ResponseCode::NoError {
return Err(MessageError::ResponceNotOk(response.response_code()));
}
let answers = response.answers();
let mut records: Vec<_> = answers.to_vec().into_iter()
.filter(|record| !matches!(record.rdata(), RData::NULL { .. } | RData::DNSSEC(_)))
.collect();
// AXFR response ends with SOA, we remove it so it is not doubled in the response.
records.pop();
Ok(records)
}
fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<ClientResponse<<Self as DnsHandle>::Response>, MessageError>
2022-03-04 12:08:03 +00:00
{
let mut mismatched_class = Vec::new();
let mut mismatched_zone = Vec::new();
for record in new_records.iter() {
if !zone.zone_of(record.name()) {
mismatched_zone.push(record.clone());
}
if record.dns_class() != class {
mismatched_class.push(record.clone());
}
}
if mismatched_class.len() > 0 || mismatched_zone.len() > 0 {
return Err(MessageError::RecordNotInZone {
zone,
class,
mismatched_zone,
mismatched_class
})
}
let mut zone_query = Query::new();
zone_query.set_name(zone.clone())
.set_query_class(class)
.set_query_type(RecordType::SOA);
let mut message = Message::new();
// TODO: set random / time based id
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Update)
.set_recursion_desired(false);
message.add_zone(zone_query);
message.add_updates(new_records);
{
let edns = message.edns_mut();
edns.set_max_payload(1232);
edns.set_version(0);
}
return Ok(ClientResponse(self.send(message)));
}
}