use std::{future::Future, pin::Pin, task::{Context, Poll}}; use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}}; use rocket::outcome::try_outcome; use tokio::{net::TcpStream as TokioTcpStream, task}; use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream}; use trust_dns_proto::error::ProtoError; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; use crate::config::Config; pub struct DnsClient(AsyncClient); impl Deref for DnsClient { type Target = AsyncClient; fn deref(&self) -> &Self::Target { &self.0 } } impl DerefMut for DnsClient { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } impl DnsClient { pub async fn new(addr: SocketAddr) -> Result<Self, ProtoError> { let (stream, handle) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(addr); let client = AsyncClient::with_timeout( stream, handle, std::time::Duration::from_secs(5), None); let (client, bg) = client.await?; task::spawn(bg); return Ok(DnsClient(client)) } } #[rocket::async_trait] impl<'r> FromRequest<'r> for DnsClient { type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> { let config = try_outcome!(request.guard::<&State<Config>>().await); match DnsClient::new(config.dns.server).await { Err(e) => { println!("Failed to connect to DNS server {:#?}", e); Outcome::Failure((Status::InternalServerError, ())) }, Ok(c) => Outcome::Success(c) } } } // Reimplement this type here as ClientReponse in trust-dns crate have private fields pub struct ClientResponse<R>(pub(crate) R) where R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static; impl<R> Future for ClientResponse<R> where R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static, { type Output = Result<DnsResponse, ClientError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { // This is from the future_utils crate, we simply reuse the reexport from Rocket rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from) } }