From c324f5ec9b3e7bd82cee538c9f6e3fec8c6270a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABl=20Berthaud-M=C3=BCller?= Date: Fri, 22 Apr 2022 21:01:26 +0200 Subject: [PATCH] use connector trait directly from request --- src/config.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++ src/dns/client.rs | 20 ----------------- src/dns/connector.rs | 8 ++++--- src/routes/zones.rs | 21 +++++------------- 4 files changed, 64 insertions(+), 38 deletions(-) diff --git a/src/config.rs b/src/config.rs index 935d597..74c6f3a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,6 +6,12 @@ use serde::{Deserialize, Deserializer}; use chrono::Duration; +use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}}; +use rocket::outcome::try_outcome; + +use crate::dns::{DnsClient, DnsConnectorClient, RecordConnector, ZoneConnector}; + + #[derive(Debug, Deserialize)] pub struct Config { pub dns: DnsClientConfig, @@ -37,3 +43,50 @@ pub fn load(file_name: PathBuf) -> Config { let file_content = fs::read_to_string(file_name).expect("could not read config file"); toml::from_str(&file_content).expect("could not parse config file") } + +// TODO: Maybe remove this +#[rocket::async_trait] +impl<'r> FromRequest<'r> for DnsClient { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + let config = try_outcome!(request.guard::<&State>().await); + match DnsClient::new(config.dns.server).await { + Err(e) => { + println!("Failed to connect to DNS server: {}", e); + Outcome::Failure((Status::InternalServerError, ())) + }, + Ok(c) => Outcome::Success(c) + } + } +} + + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Box { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + let config = try_outcome!(request.guard::<&State>().await); + match DnsClient::new(config.dns.server).await { + Err(e) => { + println!("Failed to connect to DNS server: {}", e); + Outcome::Failure((Status::InternalServerError, ())) + }, + Ok(c) => Outcome::Success(Box::new(DnsConnectorClient::new(c))) + } + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Box { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + let config = try_outcome!(request.guard::<&State>().await); + match DnsClient::new(config.dns.server).await { + Err(e) => { + println!("Failed to connect to DNS server: {}", e); + Outcome::Failure((Status::InternalServerError, ())) + }, + Ok(c) => Outcome::Success(Box::new(DnsConnectorClient::new(c))) + } + } +} diff --git a/src/dns/client.rs b/src/dns/client.rs index 7f77c73..e914b26 100644 --- a/src/dns/client.rs +++ b/src/dns/client.rs @@ -2,15 +2,11 @@ use std::{future::Future, pin::Pin, task::{Context, Poll}}; use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; -use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}}; -use rocket::outcome::try_outcome; use tokio::{net::TcpStream as TokioTcpStream, task}; use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream}; use trust_dns_proto::error::ProtoError; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use crate::config::Config; - pub struct DnsClient(AsyncClient); @@ -42,22 +38,6 @@ impl DnsClient { } } - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for DnsClient { - type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { - let config = try_outcome!(request.guard::<&State>().await); - match DnsClient::new(config.dns.server).await { - Err(e) => { - println!("Failed to connect to DNS server: {}", e); - Outcome::Failure((Status::InternalServerError, ())) - }, - Ok(c) => Outcome::Success(c) - } - } -} - // Reimplement this type here as ClientReponse in trust-dns crate have private fields pub struct ClientResponse(pub(crate) R) where diff --git a/src/dns/connector.rs b/src/dns/connector.rs index 00ddc6b..670b93f 100644 --- a/src/dns/connector.rs +++ b/src/dns/connector.rs @@ -1,11 +1,12 @@ use crate::dns; + // TODO: Use model types instead of dns types as input / output and only convert internaly? // Zone content api // E.g.: DNS update + axfr, zone file read + write #[async_trait] -pub trait RecordConnector { +pub trait RecordConnector: Send { async fn get_records(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result, Box>; async fn add_records(&mut self, zone: dns::Name, class: dns::DNSClass, new_records: Vec) -> Result<(), Box>; async fn update_records(&mut self, zone: dns::Name, class: dns::DNSClass, old_records: Vec, new_records: Vec) -> Result<(), Box>; @@ -16,7 +17,7 @@ pub trait RecordConnector { // Zone management api, todo // E.g.: Manage catalog zone, dynamically generate knot / bind / nsd config... #[async_trait] -pub trait ZoneConnector { +pub trait ZoneConnector: Send { // get_zones // add_zone // delete_zone @@ -26,4 +27,5 @@ pub trait ZoneConnector { pub trait ConnectorError: std::fmt::Debug + std::fmt::Display { fn is_proto_error(&self) -> bool; fn zone_name(&self) -> Option; -} \ No newline at end of file +} + diff --git a/src/routes/zones.rs b/src/routes/zones.rs index c425665..dbec883 100644 --- a/src/routes/zones.rs +++ b/src/routes/zones.rs @@ -3,14 +3,14 @@ use rocket::http::Status; use rocket::serde::json::Json; use crate::DbConn; -use crate::dns::{DnsClient, DnsConnectorClient, RecordConnector, ZoneConnector}; +use crate::dns::{RecordConnector, ZoneConnector}; use crate::models; use crate::models::{ParseRecordList}; #[get("/zones//records")] pub async fn get_zone_records( - client: DnsClient, + mut dns_api: Box, conn: DbConn, user_info: Result, zone: models::AbsoluteName @@ -27,8 +27,6 @@ pub async fn get_zone_records( } }).await?; - let mut dns_api = DnsConnectorClient::new(client); - let dns_records = dns_api.get_records(zone.clone(), models::DNSClass::IN.into()).await?; let records: Vec<_> = dns_records.into_iter().map(models::Record::from).collect(); @@ -37,7 +35,7 @@ pub async fn get_zone_records( #[post("/zones//records", data = "")] pub async fn create_zone_records( - client: DnsClient, + mut dns_api: Box, conn: DbConn, user_info: Result, zone: models::AbsoluteName, @@ -55,8 +53,6 @@ pub async fn create_zone_records( } }).await?; - let mut dns_api = DnsConnectorClient::new(client); - dns_api.add_records( zone.clone(), models::DNSClass::IN.into(), @@ -68,7 +64,7 @@ pub async fn create_zone_records( #[put("/zones//records", data = "")] pub async fn update_zone_records( - client: DnsClient, + mut dns_api: Box, conn: DbConn, user_info: Result, zone: models::AbsoluteName, @@ -88,8 +84,6 @@ pub async fn update_zone_records( } }).await?; - let mut dns_api = DnsConnectorClient::new(client); - dns_api.update_records( zone.clone(), models::DNSClass::IN.into(), @@ -102,7 +96,7 @@ pub async fn update_zone_records( #[delete("/zones//records", data = "")] pub async fn delete_zone_records( - client: DnsClient, + mut dns_api: Box, conn: DbConn, user_info: Result, zone: models::AbsoluteName, @@ -120,8 +114,6 @@ pub async fn delete_zone_records( } }).await?; - let mut dns_api = DnsConnectorClient::new(client); - dns_api.delete_records( zone.clone(), models::DNSClass::IN.into(), @@ -152,13 +144,12 @@ pub async fn get_zones( #[post("/zones", data = "")] pub async fn create_zone( conn: DbConn, - client: DnsClient, + mut dns_api: Box, user_info: Result, zone_request: Json, ) -> Result, models::ErrorResponse> { user_info?.check_admin()?; - let mut dns_api = DnsConnectorClient::new(client); dns_api.zone_exists(zone_request.name.clone(), models::DNSClass::IN.into()).await?; let zone = conn.run(move |c| {