use connector trait directly from request

This commit is contained in:
Hannaeko 2022-04-22 21:01:26 +02:00
parent e44e6eea63
commit c324f5ec9b
4 changed files with 64 additions and 38 deletions

View file

@ -6,6 +6,12 @@ use serde::{Deserialize, Deserializer};
use chrono::Duration;
use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}};
use rocket::outcome::try_outcome;
use crate::dns::{DnsClient, DnsConnectorClient, RecordConnector, ZoneConnector};
#[derive(Debug, Deserialize)]
pub struct Config {
pub dns: DnsClientConfig,
@ -37,3 +43,50 @@ pub fn load(file_name: PathBuf) -> Config {
let file_content = fs::read_to_string(file_name).expect("could not read config file");
toml::from_str(&file_content).expect("could not parse config file")
}
// TODO: Maybe remove this
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DnsClient {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<&State<Config>>().await);
match DnsClient::new(config.dns.server).await {
Err(e) => {
println!("Failed to connect to DNS server: {}", e);
Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => Outcome::Success(c)
}
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Box<dyn RecordConnector> {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<&State<Config>>().await);
match DnsClient::new(config.dns.server).await {
Err(e) => {
println!("Failed to connect to DNS server: {}", e);
Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => Outcome::Success(Box::new(DnsConnectorClient::new(c)))
}
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Box<dyn ZoneConnector> {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<&State<Config>>().await);
match DnsClient::new(config.dns.server).await {
Err(e) => {
println!("Failed to connect to DNS server: {}", e);
Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => Outcome::Success(Box::new(DnsConnectorClient::new(c)))
}
}
}

View file

@ -2,15 +2,11 @@ use std::{future::Future, pin::Pin, task::{Context, Poll}};
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}};
use rocket::outcome::try_outcome;
use tokio::{net::TcpStream as TokioTcpStream, task};
use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream};
use trust_dns_proto::error::ProtoError;
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
use crate::config::Config;
pub struct DnsClient(AsyncClient);
@ -42,22 +38,6 @@ impl DnsClient {
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DnsClient {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = try_outcome!(request.guard::<&State<Config>>().await);
match DnsClient::new(config.dns.server).await {
Err(e) => {
println!("Failed to connect to DNS server: {}", e);
Outcome::Failure((Status::InternalServerError, ()))
},
Ok(c) => Outcome::Success(c)
}
}
}
// Reimplement this type here as ClientReponse in trust-dns crate have private fields
pub struct ClientResponse<R>(pub(crate) R)
where

View file

@ -1,11 +1,12 @@
use crate::dns;
// TODO: Use model types instead of dns types as input / output and only convert internaly?
// Zone content api
// E.g.: DNS update + axfr, zone file read + write
#[async_trait]
pub trait RecordConnector {
pub trait RecordConnector: Send {
async fn get_records(&mut self, zone: dns::Name, class: dns::DNSClass) -> Result<Vec<dns::Record>, Box<dyn ConnectorError>>;
async fn add_records(&mut self, zone: dns::Name, class: dns::DNSClass, new_records: Vec<dns::Record>) -> Result<(), Box<dyn ConnectorError>>;
async fn update_records(&mut self, zone: dns::Name, class: dns::DNSClass, old_records: Vec<dns::Record>, new_records: Vec<dns::Record>) -> Result<(), Box<dyn ConnectorError>>;
@ -16,7 +17,7 @@ pub trait RecordConnector {
// Zone management api, todo
// E.g.: Manage catalog zone, dynamically generate knot / bind / nsd config...
#[async_trait]
pub trait ZoneConnector {
pub trait ZoneConnector: Send {
// get_zones
// add_zone
// delete_zone
@ -26,4 +27,5 @@ pub trait ZoneConnector {
pub trait ConnectorError: std::fmt::Debug + std::fmt::Display {
fn is_proto_error(&self) -> bool;
fn zone_name(&self) -> Option<dns::Name>;
}
}

View file

@ -3,14 +3,14 @@ use rocket::http::Status;
use rocket::serde::json::Json;
use crate::DbConn;
use crate::dns::{DnsClient, DnsConnectorClient, RecordConnector, ZoneConnector};
use crate::dns::{RecordConnector, ZoneConnector};
use crate::models;
use crate::models::{ParseRecordList};
#[get("/zones/<zone>/records")]
pub async fn get_zone_records(
client: DnsClient,
mut dns_api: Box<dyn RecordConnector>,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName
@ -27,8 +27,6 @@ pub async fn get_zone_records(
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
let dns_records = dns_api.get_records(zone.clone(), models::DNSClass::IN.into()).await?;
let records: Vec<_> = dns_records.into_iter().map(models::Record::from).collect();
@ -37,7 +35,7 @@ pub async fn get_zone_records(
#[post("/zones/<zone>/records", data = "<new_records>")]
pub async fn create_zone_records(
client: DnsClient,
mut dns_api: Box<dyn RecordConnector>,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
@ -55,8 +53,6 @@ pub async fn create_zone_records(
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.add_records(
zone.clone(),
models::DNSClass::IN.into(),
@ -68,7 +64,7 @@ pub async fn create_zone_records(
#[put("/zones/<zone>/records", data = "<update_records_request>")]
pub async fn update_zone_records(
client: DnsClient,
mut dns_api: Box<dyn RecordConnector>,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
@ -88,8 +84,6 @@ pub async fn update_zone_records(
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.update_records(
zone.clone(),
models::DNSClass::IN.into(),
@ -102,7 +96,7 @@ pub async fn update_zone_records(
#[delete("/zones/<zone>/records", data = "<records>")]
pub async fn delete_zone_records(
client: DnsClient,
mut dns_api: Box<dyn RecordConnector>,
conn: DbConn,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone: models::AbsoluteName,
@ -120,8 +114,6 @@ pub async fn delete_zone_records(
}
}).await?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.delete_records(
zone.clone(),
models::DNSClass::IN.into(),
@ -152,13 +144,12 @@ pub async fn get_zones(
#[post("/zones", data = "<zone_request>")]
pub async fn create_zone(
conn: DbConn,
client: DnsClient,
mut dns_api: Box<dyn ZoneConnector>,
user_info: Result<models::UserInfo, models::ErrorResponse>,
zone_request: Json<models::CreateZoneRequest>,
) -> Result<Json<models::Zone>, models::ErrorResponse> {
user_info?.check_admin()?;
let mut dns_api = DnsConnectorClient::new(client);
dns_api.zone_exists(zone_request.name.clone(), models::DNSClass::IN.into()).await?;
let zone = conn.run(move |c| {