From c32c02a5ac599aa3c02b71bbe25516ab9031b562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABl=20Berthaud-M=C3=BCller?= Date: Sat, 23 Apr 2022 12:14:19 +0200 Subject: [PATCH] rework cli to add db init --- Cargo.lock | 32 +++ Cargo.toml | 1 + src/cli/mod.rs | 41 ++++ src/cli/serve.rs | 59 +++++ src/cli/user.rs | 57 +++++ src/main.rs | 64 ++---- src/models/dns.bak.rs | 480 ++++++++++++++++++++++++++++++++++++++++ src/models/user.rs | 2 +- src/models/users.bak.rs | 414 ++++++++++++++++++++++++++++++++++ 9 files changed, 1108 insertions(+), 42 deletions(-) create mode 100644 src/cli/mod.rs create mode 100644 src/cli/serve.rs create mode 100644 src/cli/user.rs create mode 100644 src/models/dns.bak.rs create mode 100644 src/models/users.bak.rs diff --git a/Cargo.lock b/Cargo.lock index b87802a..564e47a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -403,6 +403,16 @@ dependencies = [ "syn", ] +[[package]] +name = "diesel_migrations" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf3cde8413353dc7f5d72fa8ce0b99a560a359d2c5ef1e5817ca731cd9008f4c" +dependencies = [ + "migrations_internals", + "migrations_macros", +] + [[package]] name = "digest" version = "0.10.3" @@ -886,6 +896,27 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +[[package]] +name = "migrations_internals" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b4fc84e4af020b837029e017966f86a1c2d5e83e64b589963d5047525995860" +dependencies = [ + "diesel", +] + +[[package]] +name = "migrations_macros" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9753f12909fd8d923f75ae5c3258cae1ed3c8ec052e1b38c93c21a6d157f789c" +dependencies = [ + "migrations_internals", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "mime" version = "0.3.16" @@ -953,6 +984,7 @@ dependencies = [ "clap", "diesel", "diesel-derive-enum", + "diesel_migrations", "djangohashers", "figment", "humantime", diff --git a/Cargo.toml b/Cargo.toml index 57627c4..8909672 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ toml = "0.5" base64 = "0.13.0" uuid = { version = "0.8.2", features = ["v4", "serde"] } diesel = { version = "1.4", features = ["sqlite"] } +diesel_migrations = "1.4" diesel-derive-enum = { version = "1", features = ["sqlite"] } djangohashers = { version = "1.4.0", features = ["with_argon2"], default-features = false } jsonwebtoken = "7.2.0" diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..7f5835b --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,41 @@ +pub mod serve; +pub mod user; + +use clap::{Parser, Subcommand}; +use figment::Figment; + +use crate::config::Config; + +use serve::ServeCommand; +use user::UserCommand; + +#[derive(Parser)] +#[clap(author, version, about, long_about = None)] +#[clap(propagate_version = true)] +pub struct NomiloCli { + #[clap(subcommand)] + pub command: Command +} + +#[derive(Subcommand)] +pub enum Command { + /// Lauch web server + Serve(ServeCommand), + /// Manage users + #[clap(subcommand)] + User(UserCommand) +} + + +pub trait NomiloCommand { + fn run(self, figment: Figment, app_config: Config); +} + +impl NomiloCommand for NomiloCli { + fn run(self, figment: Figment, app_config: Config) { + match self.command { + Command::Serve(sub) => sub.run(figment, app_config), + Command::User(sub) => sub.run(figment, app_config), + }; + } +} diff --git a/src/cli/serve.rs b/src/cli/serve.rs new file mode 100644 index 0000000..833cef5 --- /dev/null +++ b/src/cli/serve.rs @@ -0,0 +1,59 @@ +use std::process::exit; + +use clap::Parser; +use rocket::{Rocket, Build}; +use rocket::fairing::AdHoc; +use figment::Figment; + +use crate::config::Config; +use crate::routes::users::*; +use crate::routes::zones::*; +use crate::{DbConn}; +use crate::cli::NomiloCommand; + + +#[derive(Parser)] +pub struct ServeCommand; + +async fn run_migrations(rocket: Rocket) -> Rocket { + embed_migrations!("migrations"); + + let conn = match DbConn::get_one(&rocket).await { + Some(c) => c, + None => { + error!("Could not get a database connection"); + exit(1); + } + }; + + if let Err(e) = conn.run(|c| embedded_migrations::run(c)).await { + error!("Error running migrations: {}", e); + exit(1) + } + + rocket + +} + +impl NomiloCommand for ServeCommand { + fn run(self, figment: Figment, app_config: Config) { + rocket::async_main(async move { + let _res = rocket::custom(figment) + .manage(app_config) + .attach(DbConn::fairing()) + .attach(AdHoc::on_ignite("Database migration", run_migrations)) + .mount("/api/v1", routes![ + get_zone_records, + create_zone_records, + update_zone_records, + delete_zone_records, + get_zones, + create_zone, + add_member_to_zone, + create_auth_token, + create_user, + ]) + .launch().await; + }); + } +} diff --git a/src/cli/user.rs b/src/cli/user.rs new file mode 100644 index 0000000..9835487 --- /dev/null +++ b/src/cli/user.rs @@ -0,0 +1,57 @@ +use std::process::exit; + +use clap::{Parser, Subcommand}; +use figment::Figment; + +use crate::config::Config; +use crate::cli::{NomiloCommand}; +use crate::models::{LocalUser, CreateUserRequest, Role}; +use crate::get_db_conn; + +#[derive(Subcommand)] +pub enum UserCommand { + /// Add new user + Add(AddUserCommand), +} + + +#[derive(Parser)] +pub struct AddUserCommand { + #[clap(long = "--name", short = 'n')] + pub name: String, + #[clap(long = "--email", short = 'e')] + pub email: String, + #[clap(long = "--is-admin", short = 'a')] + pub is_admin: bool, + #[clap(long = "--password", short = 'p')] + pub password: Option, +} + + +impl NomiloCommand for UserCommand { + fn run(self, figment: Figment, app_config: Config) { + match self { + UserCommand::Add(sub) => sub.run(figment, app_config), + }; + } +} + +impl NomiloCommand for AddUserCommand { + fn run(self, figment: Figment, _app_config: Config) { + + let res = LocalUser::create_user(&get_db_conn(&figment), CreateUserRequest { + username: self.name, + email: self.email, + role: Some(if self.is_admin { Role::Admin } else { Role::ZoneAdmin }), + password: self.password.unwrap(), + }); + + match res { + Ok(_) => println!("Successfully added user"), + Err(err) => { + eprintln!("Error while adding user: {:?}", err); + exit(1); + } + }; + } +} diff --git a/src/main.rs b/src/main.rs index 0804876..89342c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,60 +1,47 @@ #![feature(proc_macro_hygiene, decl_macro)] + #[macro_use] extern crate rocket; #[macro_use] extern crate diesel; +#[macro_use] extern crate diesel_migrations; mod models; mod config; mod schema; mod routes; mod dns; +mod cli; use std::process::exit; +use clap::Parser; use figment::{Figment, Profile, providers::{Format, Toml, Env}}; use rocket_sync_db_pools::database; -use clap::{Parser, Subcommand}; - -use routes::users::*; -use routes::zones::*; +use diesel::prelude::*; +use crate::cli::{NomiloCli, NomiloCommand}; #[database("sqlite")] pub struct DbConn(diesel::SqliteConnection); -#[derive(Parser)] -#[clap(author, version, about, long_about = None)] -#[clap(propagate_version = true)] -struct Nomilo { - #[clap(subcommand)] - command: Command -} -#[derive(Subcommand)] -enum Command { - /// Lauch web server - Serve, -} +pub fn get_db_conn(figment: &Figment) -> diesel::SqliteConnection { + let url = match figment.focus("databases.sqlite").extract_inner::("url") { + Ok(url) => url, + Err(e) => { + eprintln!("Error loading configuration: {}", e); + exit(1); + } + }; -fn serve(figment: Figment, app_config: config::Config) { - rocket::async_main(async move { - let _res = rocket::custom(figment) - .manage(app_config) - .attach(DbConn::fairing()) - .mount("/api/v1", routes![ - get_zone_records, - create_zone_records, - update_zone_records, - delete_zone_records, - get_zones, - create_zone, - add_member_to_zone, - create_auth_token, - create_user, - ]) - .launch().await; - }); + match diesel::SqliteConnection::establish(&url) { + Ok(c) => c, + Err(e) => { + eprintln!("Error connecting to database at \"{}\": {}", url, e); + exit(1); + } + } } @@ -72,11 +59,6 @@ fn main() { } }; - let nomilo = Nomilo::parse(); - - match nomilo.command { - Command::Serve => serve(figment, app_config), - }; - - + let nomilo = NomiloCli::parse(); + nomilo.run(figment, app_config); } diff --git a/src/models/dns.bak.rs b/src/models/dns.bak.rs new file mode 100644 index 0000000..c9d13b8 --- /dev/null +++ b/src/models/dns.bak.rs @@ -0,0 +1,480 @@ +use std::{convert::{TryFrom, TryInto}, future::Future, net::{Ipv6Addr, Ipv4Addr}, pin::Pin, task::{Context, Poll}}; +use std::fmt; +use std::ops::{Deref, DerefMut}; + + +use rocket::{Request, State, http::Status, request::{FromParam, FromRequest, Outcome}}; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use tokio::{net::TcpStream as TokioTcpStream, task}; + +use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, serialize::binary::BinEncoder, tcp::TcpClientStream}; +use trust_dns_proto::error::{ProtoError}; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; + + +use super::trust_dns_types::{self, Name}; +use crate::config::Config; + + +#[derive(Deserialize, Serialize, Clone)] +#[serde(tag = "Type")] +#[serde(rename_all = "UPPERCASE")] +pub enum RData { + #[serde(rename_all = "PascalCase")] + A { + address: Ipv4Addr + }, + #[serde(rename_all = "PascalCase")] + AAAA { + address: Ipv6Addr + }, + #[serde(rename_all = "PascalCase")] + CAA { + issuer_critical: bool, + value: String, + property_tag: String, + }, + #[serde(rename_all = "PascalCase")] + CNAME { + target: SerdeName + }, + // HINFO(HINFO), + // HTTPS(SVCB), + #[serde(rename_all = "PascalCase")] + MX { + preference: u16, + mail_exchanger: SerdeName + }, + // NAPTR(NAPTR), + #[serde(rename_all = "PascalCase")] + NULL { + data: String + }, + #[serde(rename_all = "PascalCase")] + NS { + target: SerdeName + }, + // OPENPGPKEY(OPENPGPKEY), + // OPT(OPT), + #[serde(rename_all = "PascalCase")] + PTR { + target: SerdeName + }, + #[serde(rename_all = "PascalCase")] + SOA { + master_server_name: SerdeName, + maintainer_name: SerdeName, + refresh: i32, + retry: i32, + expire: i32, + minimum: u32, + serial: u32 + }, + #[serde(rename_all = "PascalCase")] + SRV { + server: SerdeName, + port: u16, + priority: u16, + weight: u16, + }, + #[serde(rename_all = "PascalCase")] + SSHFP { + algorithm: u8, + digest_type: u8, + fingerprint: String, + }, + // SVCB(SVCB), + // TLSA(TLSA), + #[serde(rename_all = "PascalCase")] + TXT { + text: String + }, + + // TODO: Eventually allow deserialization of DNSSEC records + #[serde(skip)] + DNSSEC(trust_dns_types::DNSSECRData), + #[serde(rename_all = "PascalCase")] + Unknown { + code: u16, + data: String, + }, + // ZERO, + + // TODO: DS +} + +impl From for RData { + fn from(rdata: trust_dns_types::RData) -> RData { + match rdata { + trust_dns_types::RData::A(address) => RData::A { address }, + trust_dns_types::RData::AAAA(address) => RData::AAAA { address }, + // Still a draft, no iana number yet, I don't to put something that is not currently supported so that's why NULL and not unknown. + // TODO: probably need better error here, I don't know what to do about that as this would require to change the From for something else. + // (empty data because I'm lazy) + trust_dns_types::RData::ANAME(_) => RData::NULL { + data: String::new() + }, + trust_dns_types::RData::CNAME(target) => RData::CNAME { + target: SerdeName(target) + }, + trust_dns_types::RData::CAA(caa) => RData::CAA { + issuer_critical: caa.issuer_critical(), + value: format!("{}", CAAValue(caa.value())), + property_tag: caa.tag().as_str().to_string(), + }, + trust_dns_types::RData::MX(mx) => RData::MX { + preference: mx.preference(), + mail_exchanger: SerdeName(mx.exchange().clone()) + }, + trust_dns_types::RData::NULL(null) => RData::NULL { + data: base64::encode(null.anything().map(|data| data.to_vec()).unwrap_or_default()) + }, + trust_dns_types::RData::NS(target) => RData::NS { + target: SerdeName(target) + }, + trust_dns_types::RData::PTR(target) => RData::PTR { + target: SerdeName(target) + }, + trust_dns_types::RData::SOA(soa) => RData::SOA { + master_server_name: SerdeName(soa.mname().clone()), + maintainer_name: SerdeName(soa.rname().clone()), + refresh: soa.refresh(), + retry: soa.retry(), + expire: soa.expire(), + minimum: soa.minimum(), + serial: soa.serial() + }, + trust_dns_types::RData::SRV(srv) => RData::SRV { + server: SerdeName(srv.target().clone()), + port: srv.port(), + priority: srv.priority(), + weight: srv.weight(), + }, + trust_dns_types::RData::SSHFP(sshfp) => RData::SSHFP { + algorithm: sshfp.algorithm().into(), + digest_type: sshfp.fingerprint_type().into(), + fingerprint: trust_dns_types::sshfp::HEX.encode(sshfp.fingerprint()), + }, + //TODO: This might alter data if not utf8 compatible, probably need to be replaced + //TODO: check whether concatenating txt data is harmful or not + trust_dns_types::RData::TXT(txt) => RData::TXT { text: format!("{}", txt) }, + trust_dns_types::RData::DNSSEC(data) => RData::DNSSEC(data), + rdata => { + let code = rdata.to_record_type().into(); + let mut data = Vec::new(); + let mut encoder = BinEncoder::new(&mut data); + // TODO: need better error handling (use TryFrom ?) + rdata.emit(&mut encoder).expect("could not encode data"); + + RData::Unknown { + code, + data: base64::encode(data), + } + } + } + } +} + +impl TryFrom for trust_dns_types::RData { + type Error = ProtoError; + + fn try_from(rdata: RData) -> Result { + Ok(match rdata { + RData::A { address } => trust_dns_types::RData::A(address), + RData::AAAA { address } => trust_dns_types::RData::AAAA(address), + // TODO: Round trip test all types below (currently not tested...) + RData::CAA { issuer_critical, value, property_tag } => { + let property = trust_dns_types::caa::Property::from(property_tag); + let caa_value = { + // TODO: duplicate of trust_dns_client::serialize::txt::rdata_parser::caa::parse + // because caa::read_value is private + match property { + trust_dns_types::caa::Property::Issue | trust_dns_types::caa::Property::IssueWild => { + let value = trust_dns_types::caa::read_issuer(value.as_bytes())?; + trust_dns_types::caa::Value::Issuer(value.0, value.1) + } + trust_dns_types::caa::Property::Iodef => { + let url = trust_dns_types::caa::read_iodef(value.as_bytes())?; + trust_dns_types::caa::Value::Url(url) + } + trust_dns_types::caa::Property::Unknown(_) => trust_dns_types::caa::Value::Unknown(value.as_bytes().to_vec()), + } + }; + trust_dns_types::RData::CAA(trust_dns_types::caa::CAA { + issuer_critical, + tag: property, + value: caa_value, + }) + }, + RData::CNAME { target } => trust_dns_types::RData::CNAME(target.into_inner()), + RData::MX { preference, mail_exchanger } => trust_dns_types::RData::MX( + trust_dns_types::mx::MX::new(preference, mail_exchanger.into_inner()) + ), + RData::NULL { data } => trust_dns_types::RData::NULL( + trust_dns_types::null::NULL::with( + base64::decode(data).map_err(|e| ProtoError::from(format!("{}", e)))? + ) + ), + RData::NS { target } => trust_dns_types::RData::NS(target.into_inner()), + RData::PTR { target } => trust_dns_types::RData::PTR(target.into_inner()), + RData::SOA { + master_server_name, + maintainer_name, + refresh, + retry, + expire, + minimum, + serial + } => trust_dns_types::RData::SOA( + trust_dns_types::soa::SOA::new( + master_server_name.into_inner(), + maintainer_name.into_inner(), + serial, + refresh, + retry, + expire, + minimum, + ) + ), + RData::SRV { server, port, priority, weight } => trust_dns_types::RData::SRV( + trust_dns_types::srv::SRV::new(priority, weight, port, server.into_inner()) + ), + RData::SSHFP { algorithm, digest_type, fingerprint } => trust_dns_types::RData::SSHFP( + trust_dns_types::sshfp::SSHFP::new( + // NOTE: This allows unassigned algorithms + trust_dns_types::sshfp::Algorithm::from(algorithm), + trust_dns_types::sshfp::FingerprintType::from(digest_type), + trust_dns_types::sshfp::HEX.decode(fingerprint.as_bytes()).map_err(|e| ProtoError::from(format!("{}", e)))? + ) + ), + RData::TXT { text } => trust_dns_types::RData::TXT(trust_dns_types::txt::TXT::new(vec![text])), + // TODO: Error out for DNSSEC? Prefer downstream checks? + RData::DNSSEC(_) => todo!(), + // TODO: Disallow unknown? (could be used to bypass unsopported types?) Prefer downstream checks? + RData::Unknown { code, data } => todo!(), + }) + } +} + +struct CAAValue<'a>(&'a trust_dns_types::caa::Value); + +// trust_dns Display implementation panics if no parameters +// Implementation based on caa::emit_value +// Also the quotes are strips to render in JSON +impl<'a> fmt::Display for CAAValue<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self.0 { + trust_dns_types::caa::Value::Issuer(name, parameters) => { + + if let Some(name) = name { + write!(f, "{}", name)?; + } + + if name.is_none() && parameters.is_empty() { + write!(f, ";")?; + } + + for value in parameters { + write!(f, "; {}", value)?; + } + } + trust_dns_types::caa::Value::Url(url) => write!(f, "{}", url)?, + trust_dns_types::caa::Value::Unknown(v) => write!(f, "{:?}", v)?, + } + Ok(()) + } +} + +#[derive(Deserialize, Serialize, Clone)] +pub enum DNSClass { + IN, + CH, + HS, + NONE, + ANY, + OPT(u16), +} + +impl From for DNSClass { + fn from(dns_class: trust_dns_types::DNSClass) -> DNSClass { + match dns_class { + trust_dns_types::DNSClass::IN => DNSClass::IN, + trust_dns_types::DNSClass::CH => DNSClass::CH, + trust_dns_types::DNSClass::HS => DNSClass::HS, + trust_dns_types::DNSClass::NONE => DNSClass::NONE, + trust_dns_types::DNSClass::ANY => DNSClass::ANY, + trust_dns_types::DNSClass::OPT(v) => DNSClass::OPT(v), + } + } +} + +impl From for trust_dns_types::DNSClass { + fn from(dns_class: DNSClass) -> trust_dns_types::DNSClass { + match dns_class { + DNSClass::IN => trust_dns_types::DNSClass::IN, + DNSClass::CH => trust_dns_types::DNSClass::CH, + DNSClass::HS => trust_dns_types::DNSClass::HS, + DNSClass::NONE => trust_dns_types::DNSClass::NONE, + DNSClass::ANY => trust_dns_types::DNSClass::ANY, + DNSClass::OPT(v) => trust_dns_types::DNSClass::OPT(v), + } + } +} + + +// Reimplement this type here as ClientReponse in trust-dns crate have private fields +pub struct ClientResponse(pub(crate) R) +where + R: Future> + Send + Unpin + 'static; + +impl Future for ClientResponse +where + R: Future> + Send + Unpin + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // This is from the future_utils crate, we simply reuse the reexport from Rocket + rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from) + } +} + + + +#[derive(Deserialize, Serialize, Clone)] +pub struct Record { + #[serde(rename = "Name")] + pub name: SerdeName, + // TODO: Make class optional, default to IN + #[serde(rename = "Class")] + pub dns_class: DNSClass, + #[serde(rename = "TTL")] + pub ttl: u32, + #[serde(flatten)] + pub rdata: RData, +} + +impl From for Record { + fn from(record: trust_dns_types::Record) -> Record { + Record { + name: SerdeName(record.name().clone()), + dns_class: record.dns_class().into(), + ttl: record.ttl(), + rdata: record.into_data().into(), + } + } +} + +impl TryFrom for trust_dns_types::Record { + type Error = ProtoError; + + fn try_from(record: Record) -> Result { + let mut trust_dns_record = trust_dns_types::Record::from_rdata(record.name.into_inner(), record.ttl, record.rdata.try_into()?); + trust_dns_record.set_dns_class(record.dns_class.into()); + Ok(trust_dns_record) + } +} + +#[derive(Debug, Clone)] +pub struct SerdeName(Name); + +impl<'de> Deserialize<'de> for SerdeName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de> + { + use serde::de::Error; + + String::deserialize(deserializer) + .and_then(|string| + Name::from_utf8(&string) + .map_err(|e| Error::custom(e.to_string())) + ).map( SerdeName) + } +} + +impl Deref for SerdeName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Serialize for SerdeName { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer + { + self.0.to_utf8().serialize(serializer) + } +} + +impl SerdeName { + fn into_inner(self) -> Name { + self.0 + } +} + + +#[derive(Debug, Deserialize)] +pub struct AbsoluteName(SerdeName); + +impl<'r> FromParam<'r> for AbsoluteName { + type Error = ProtoError; + + fn from_param(param: &'r str) -> Result { + let mut name = Name::from_utf8(¶m)?; + if !name.is_fqdn() { + name.set_fqdn(true); + } + Ok(AbsoluteName(SerdeName(name))) + } +} + +impl Deref for AbsoluteName { + type Target = Name; + fn deref(&self) -> &Self::Target { + &self.0.0 + } +} +pub struct DnsClient(AsyncClient); + +impl Deref for DnsClient { + type Target = AsyncClient; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DnsClient { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for DnsClient { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> Outcome { + let config = try_outcome!(request.guard::>().await); + let (stream, handle) = TcpClientStream::>::new(config.dns.server); + let client = AsyncClient::with_timeout( + stream, + handle, + std::time::Duration::from_secs(5), + None); + let (client, bg) = match client.await { + Err(e) => { + println!("Failed to connect to DNS server {:#?}", e); + return Outcome::Failure((Status::InternalServerError, ())) + }, + Ok(c) => c + }; + task::spawn(bg); + Outcome::Success(DnsClient(client)) + } +} diff --git a/src/models/user.rs b/src/models/user.rs index a52dbc9..108f1c7 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -173,7 +173,7 @@ impl LocalUser { username: user_request.username.clone(), password: make_password_with_algorithm(&user_request.password, Algorithm::Argon2), // TODO: Use role from request - role: Role::ZoneAdmin, + role: if let Some(user_role) = user_request.role { user_role } else { Role::ZoneAdmin }, }; let res = UserInfo { diff --git a/src/models/users.bak.rs b/src/models/users.bak.rs new file mode 100644 index 0000000..d19142c --- /dev/null +++ b/src/models/users.bak.rs @@ -0,0 +1,414 @@ +use uuid::Uuid; +use diesel::prelude::*; +use diesel::result::Error as DieselError; +use diesel_derive_enum::DbEnum; +use rocket::{State, request::{FromRequest, Request, Outcome}}; +use serde::{Serialize, Deserialize}; +use chrono::serde::ts_seconds; +use chrono::prelude::{DateTime, Utc}; +use chrono::Duration; +// TODO: Maybe just use argon2 crate directly +use djangohashers::{make_password_with_algorithm, check_password, HasherError, Algorithm}; +use jsonwebtoken::{ + encode, decode, + Header, Validation, + Algorithm as JwtAlgorithm, EncodingKey, DecodingKey, + errors::Result as JwtResult, + errors::ErrorKind as JwtErrorKind +}; + +use crate::schema::*; +use crate::DbConn; +use crate::config::Config; +use crate::models::errors::{ErrorResponse, make_500}; +use crate::models::dns::AbsoluteName; + +const BEARER: &str = "Bearer "; +const AUTH_HEADER: &str = "Authorization"; + + +#[derive(Debug, DbEnum, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum Role { + #[db_rename = "admin"] + Admin, + #[db_rename = "zoneadmin"] + ZoneAdmin, +} + +// TODO: Store Uuid instead of string?? +#[derive(Debug, Queryable, Identifiable, Insertable)] +#[table_name = "user"] +pub struct User { + pub id: String, +} + +#[derive(Debug, Queryable, Identifiable, Insertable)] +#[table_name = "localuser"] +#[primary_key(user_id)] +pub struct LocalUser { + pub user_id: String, + pub username: String, + pub password: String, + pub role: Role, +} + +#[derive(Debug, Queryable, Identifiable, Insertable)] +#[table_name = "user_zone"] +#[primary_key(user_id, zone_id)] +pub struct UserZone { + pub user_id: String, + pub zone_id: String, +} + +#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)] +#[table_name = "zone"] +pub struct Zone { + #[serde(skip)] + pub id: String, + pub name: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateUserRequest { + pub username: String, + pub password: String, + pub email: String, + pub role: Option +} + +#[derive(Debug, Deserialize)] +pub struct AddZoneMemberRequest { + pub id: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateZoneRequest { + pub name: AbsoluteName, +} + +// pub struct LdapUserAssociation { +// user_id: Uuid, +// ldap_id: String +// } + +#[derive(Debug, Serialize, Deserialize)] +pub struct AuthClaims { + pub jti: String, + pub sub: String, + #[serde(with = "ts_seconds")] + pub exp: DateTime, + #[serde(with = "ts_seconds")] + pub iat: DateTime, +} + +#[derive(Debug, Serialize)] +pub struct AuthTokenResponse { + pub token: String +} + +#[derive(Debug, Deserialize)] +pub struct AuthTokenRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug)] +pub struct UserInfo { + pub id: String, + pub role: Role, + pub username: String, +} + +impl UserInfo { + pub fn is_admin(&self) -> bool { + matches!(self.role, Role::Admin) + } + + pub fn check_admin(&self) -> Result<(), UserError> { + if self.is_admin() { + Ok(()) + } else { + Err(UserError::PermissionDenied) + } + } + + pub fn get_zone(&self, conn: &diesel::SqliteConnection, zone_name: &str) -> Result { + use crate::schema::user_zone::dsl::*; + use crate::schema::zone::dsl::*; + + let (res_zone, _): (Zone, UserZone) = zone.inner_join(user_zone) + .filter(name.eq(zone_name)) + .filter(user_id.eq(&self.id)) + .get_result(conn) + .map_err(|e| match e { + DieselError::NotFound => UserError::ZoneNotFound, + other => UserError::DbError(other) + })?; + + Ok(res_zone) + } + + pub fn get_zones(&self, conn: &diesel::SqliteConnection) -> Result, UserError> { + use crate::schema::user_zone::dsl::*; + use crate::schema::zone::dsl::*; + + let res: Vec<(Zone, UserZone)> = zone.inner_join(user_zone) + .filter(user_id.eq(&self.id)) + .get_results(conn) + .map_err(UserError::DbError)?; + + Ok(res.into_iter().map(|(z, _)| z).collect()) + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for UserInfo { + type Error = ErrorResponse; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let auth_header = match request.headers().get_one(AUTH_HEADER) { + None => return Outcome::Forward(()), + Some(auth_header) => auth_header, + }; + + let token = if auth_header.starts_with(BEARER) { + auth_header.trim_start_matches(BEARER) + } else { + return ErrorResponse::from(UserError::MalformedHeader).into() + }; + + let config = try_outcome!(request.guard::>().await.map_failure(make_500)); + let conn = try_outcome!(request.guard::().await.map_failure(make_500)); + + let token_data = AuthClaims::decode( + token, &config.web_app.secret + ).map_err(|e| match e.into_kind() { + JwtErrorKind::ExpiredSignature => UserError::ExpiredToken, + _ => UserError::BadToken, + }); + + let token_data = match token_data { + Err(e) => return ErrorResponse::from(e).into(), + Ok(data) => data + }; + + let user_id = token_data.sub; + + conn.run(move |c| { + match LocalUser::get_user_by_uuid(c, &user_id) { + Err(UserError::NotFound) => ErrorResponse::from(UserError::NotFound).into(), + Err(e) => ErrorResponse::from(e).into(), + Ok(d) => Outcome::Success(d), + } + }).await + } +} + +#[derive(Debug)] +pub enum UserError { + ZoneNotFound, + NotFound, + UserConflict, + BadCreds, + BadToken, + ExpiredToken, + MalformedHeader, + PermissionDenied, + DbError(DieselError), + PasswordError(HasherError), +} + +impl From for UserError { + fn from(e: HasherError) -> Self { + UserError::PasswordError(e) + } +} + +impl From for UserError { + fn from(e: DieselError) -> Self { + UserError::DbError(e) + } +} + +impl LocalUser { + pub fn create_user(conn: &diesel::SqliteConnection, user_request: CreateUserRequest) -> Result { + use crate::schema::localuser::dsl::*; + use crate::schema::user::dsl::*; + + let new_user_id = Uuid::new_v4().to_simple().to_string(); + + let new_user = User { + id: new_user_id.clone(), + }; + + let new_localuser = LocalUser { + user_id: new_user_id, + username: user_request.username.clone(), + password: make_password_with_algorithm(&user_request.password, Algorithm::Argon2), + // TODO: Use role from request + role: Role::ZoneAdmin, + }; + + let res = UserInfo { + id: new_user.id.clone(), + role: new_localuser.role.clone(), + username: new_localuser.username.clone(), + }; + + conn.immediate_transaction(|| -> diesel::QueryResult<()> { + diesel::insert_into(user) + .values(new_user) + .execute(conn)?; + + diesel::insert_into(localuser) + .values(new_localuser) + .execute(conn)?; + + Ok(()) + }).map_err(|e| match e { + DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict, + other => UserError::DbError(other) + })?; + + Ok(res) + } + + pub fn get_user_by_creds( + conn: &diesel::SqliteConnection, + request_username: &str, + request_password: &str + ) -> Result { + + use crate::schema::localuser::dsl::*; + use crate::schema::user::dsl::*; + + let (client_user, client_localuser): (User, LocalUser) = user.inner_join(localuser) + .filter(username.eq(request_username)) + .get_result(conn) + .map_err(|e| match e { + DieselError::NotFound => UserError::BadCreds, + other => UserError::DbError(other) + })?; + + if !check_password(&request_password, &client_localuser.password)? { + return Err(UserError::BadCreds); + } + + Ok(UserInfo { + id: client_user.id, + role: client_localuser.role, + username: client_localuser.username, + }) + } + + pub fn get_user_by_uuid(conn: &diesel::SqliteConnection, request_user_id: &str) -> Result { + use crate::schema::localuser::dsl::*; + use crate::schema::user::dsl::*; + + let (client_user, client_localuser): (User, LocalUser) = user.inner_join(localuser) + .filter(id.eq(request_user_id)) + .get_result(conn) + .map_err(|e| match e { + DieselError::NotFound => UserError::NotFound, + other => UserError::DbError(other) + })?; + + Ok(UserInfo { + id: client_user.id, + role: client_localuser.role, + username: client_localuser.username, + }) + } + +} + +impl AuthClaims { + pub fn new(user_info: &UserInfo, token_duration: Duration) -> AuthClaims { + let jti = Uuid::new_v4().to_simple().to_string(); + let iat = Utc::now(); + let exp = iat + token_duration; + + AuthClaims { + jti, + sub: user_info.id.clone(), + exp, + iat, + } + } + + pub fn decode(token: &str, secret: &str) -> JwtResult { + decode::( + token, + &DecodingKey::from_secret(secret.as_ref()), + &Validation::new(JwtAlgorithm::HS256) + ).map(|data| data.claims) + } + + pub fn encode(self, secret: &str) -> JwtResult { + encode(&Header::default(), &self, &EncodingKey::from_secret(secret.as_ref())) + } +} + +// NOTE: Should probably not be implemented here +// also, "UserError" seems like a misleading name +impl Zone { + pub fn get_all(conn: &diesel::SqliteConnection) -> Result, UserError> { + use crate::schema::zone::dsl::*; + + zone.get_results(conn) + .map_err(UserError::DbError) + } + + pub fn get_by_name(conn: &diesel::SqliteConnection, zone_name: &str) -> Result { + use crate::schema::zone::dsl::*; + + zone.filter(name.eq(zone_name)) + .get_result(conn) + .map_err(|e| match e { + DieselError::NotFound => UserError::ZoneNotFound, + other => UserError::DbError(other) + }) + } + + pub fn create_zone(conn: &diesel::SqliteConnection, zone_request: CreateZoneRequest) -> Result { + use crate::schema::zone::dsl::*; + + let new_zone = Zone { + id: Uuid::new_v4().to_simple().to_string(), + name: zone_request.name.to_utf8(), + }; + + diesel::insert_into(zone) + .values(&new_zone) + .execute(conn) + .map_err(|e| match e { + DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _) => UserError::UserConflict, + other => UserError::DbError(other) + })?; + Ok(new_zone) + } + + + pub fn add_member(&self, conn: &diesel::SqliteConnection, new_member: &UserInfo) -> Result<(), UserError> { + use crate::schema::user_zone::dsl::*; + + let new_user_zone = UserZone { + zone_id: self.id.clone(), + user_id: new_member.id.clone() + }; + + let res = diesel::insert_into(user_zone) + .values(new_user_zone) + .execute(conn); + + match res { + // If user has already access to the zone, safely ignore the conflit + // TODO: use 'on conflict do nothing' in postgres when we get there + Err(DieselError::DatabaseError(diesel::result::DatabaseErrorKind::UniqueViolation, _)) => (), + Err(e) => return Err(e.into()), + Ok(_) => () + }; + Ok(()) + } +}