use trust_dns_proto::DnsHandle; use trust_dns_client::rr::{DNSClass, RecordType}; use trust_dns_client::op::{UpdateMessage, OpCode, MessageType, Message, Query, ResponseCode}; use trust_dns_client::error::ClientError; use trust_dns_client::proto::xfer::{DnsRequestOptions}; use super::{Name, Record, RData}; use super::client::{ClientResponse, DnsClient}; use super::api::{RecordApi, ZoneApi}; #[derive(Debug)] pub enum DnsApiError { RecordNotInZone { zone: Name, class: DNSClass, mismatched_class: Vec<Record>, mismatched_zone: Vec<Record>, }, ClientError(ClientError), ResponceNotOk { code: ResponseCode, zone: Name, }, } pub struct DnsApiClient { client: DnsClient } impl DnsApiClient { pub fn new(client: DnsClient) -> Self { DnsApiClient { client } } } #[async_trait] impl RecordApi for DnsApiClient { type Error = DnsApiError; async fn get_records(&mut self, zone: Name, class: DNSClass) -> Result<Vec<Record>, Self::Error> { let response = { let mut query = Query::query(zone.clone(), RecordType::AXFR); query.set_query_class(class); ClientResponse(self.client.lookup(query, DnsRequestOptions::default())).await.map_err(|e| DnsApiError::ClientError(e))? }; if response.response_code() != ResponseCode::NoError { return Err(DnsApiError::ResponceNotOk { code: response.response_code(), zone: zone, }); } let answers = response.answers(); let mut records: Vec<_> = answers.to_vec().into_iter() .filter(|record| !matches!(record.rdata(), RData::NULL { .. } | RData::DNSSEC(_))) .collect(); // AXFR response ends with SOA, we remove it so it is not doubled in the response. records.pop(); Ok(records) } async fn add_records(&mut self, zone: Name, class: DNSClass, new_records: Vec<Record>) -> Result<(), Self::Error> { let mut mismatched_class = Vec::new(); let mut mismatched_zone = Vec::new(); for record in new_records.iter() { if !zone.zone_of(record.name()) { mismatched_zone.push(record.clone()); } if record.dns_class() != class { mismatched_class.push(record.clone()); } } if mismatched_class.len() > 0 || mismatched_zone.len() > 0 { return Err(DnsApiError::RecordNotInZone { zone, class, mismatched_zone, mismatched_class }) } let mut zone_query = Query::new(); zone_query.set_name(zone.clone()) .set_query_class(class) .set_query_type(RecordType::SOA); let mut message = Message::new(); // TODO: set random / time based id message .set_id(0) .set_message_type(MessageType::Query) .set_op_code(OpCode::Update) .set_recursion_desired(false); message.add_zone(zone_query); message.add_updates(new_records); { let edns = message.edns_mut(); edns.set_max_payload(1232); edns.set_version(0); } let response = ClientResponse(self.client.send(message)).await.map_err(|e| DnsApiError::ClientError(e))?; if response.response_code() != ResponseCode::NoError { return Err(DnsApiError::ResponceNotOk { code: response.response_code(), zone: zone, }); } Ok(()) } } #[async_trait] impl ZoneApi for DnsApiClient { type Error = DnsApiError; async fn zone_exists(&mut self, zone: Name, class: DNSClass) -> Result<(), Self::Error> { let response = { let mut query = Query::query(zone.clone(), RecordType::SOA); query.set_query_class(class); ClientResponse(self.client.lookup(query, DnsRequestOptions::default())).await.map_err(|e| DnsApiError::ClientError(e))? }; if response.response_code() != ResponseCode::NoError { return Err(DnsApiError::ResponceNotOk { code: response.response_code(), zone: zone, }); } Ok(()) } }