use trust_dns_proto::DnsHandle; use trust_dns_client::rr::{DNSClass, RecordType}; use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode}; use trust_dns_client::error::ClientError; use trust_dns_proto::error::ProtoError; use trust_dns_client::proto::xfer::{DnsRequestOptions}; use super::{Name, Record, RData}; use super::client::{ClientResponse}; #[derive(Debug)] pub enum MessageError { RecordNotInZone { zone: Name, class: DNSClass, mismatched_class: Vec, mismatched_zone: Vec, }, ClientError(ClientError), ResponceNotOk(ResponseCode) } #[async_trait] pub trait DnsMessage: DnsHandle + Send { async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result, MessageError> { let response = { let mut query = Query::query(zone, RecordType::AXFR); query.set_query_class(class); ClientResponse(self.lookup(query, DnsRequestOptions::default())).await.map_err(|e| MessageError::ClientError(e))? }; if response.response_code() != ResponseCode::NoError { return Err(MessageError::ResponceNotOk(response.response_code())); } let answers = response.answers(); let mut records: Vec<_> = answers.to_vec().into_iter() .filter(|record| !matches!(record.rdata(), RData::NULL { .. } | RData::DNSSEC(_))) .collect(); // AXFR response ends with SOA, we remove it so it is not doubled in the response. records.pop(); Ok(records) } fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec) -> Result::Response>, MessageError> { let mut mismatched_class = Vec::new(); let mut mismatched_zone = Vec::new(); for record in new_records.iter() { if !zone.zone_of(record.name()) { mismatched_zone.push(record.clone()); } if record.dns_class() != class { mismatched_class.push(record.clone()); } } if mismatched_class.len() > 0 || mismatched_zone.len() > 0 { return Err(MessageError::RecordNotInZone { zone, class, mismatched_zone, mismatched_class }) } let mut zone_query = Query::new(); zone_query.set_name(zone.clone()) .set_query_class(class) .set_query_type(RecordType::SOA); let mut message = Message::new(); // TODO: set random / time based id message .set_id(0) .set_message_type(MessageType::Query) .set_op_code(OpCode::Update) .set_recursion_desired(false); message.add_zone(zone_query); message.add_updates(new_records); { let edns = message.edns_mut(); edns.set_max_payload(1232); edns.set_version(0); } return Ok(ClientResponse(self.send(message))); } }