From 22761ee92ea41eb9014f85ca4a2e7a0d6d23fd5c Mon Sep 17 00:00:00 2001 From: Daniel Flanagan Date: Tue, 14 Nov 2023 14:07:47 -0600 Subject: [PATCH] Setup a proper error type for the web application --- src/error.rs | 46 +++++++++++++++++++++++++++++++--------------- src/router.rs | 13 +++---------- src/views.rs | 34 +++++++++++++++++++--------------- 3 files changed, 53 insertions(+), 40 deletions(-) diff --git a/src/error.rs b/src/error.rs index fb3678e..01cae08 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,25 +1,41 @@ +use core::fmt; + use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; +use thiserror::Error; -pub struct AppError(anyhow::Error); +#[derive(Error, Debug)] +pub enum AppError { + InvalidCsrf(#[from] axum_csrf::CsrfError), + Other(#[from] anyhow::Error), +} + +impl AppError { + fn status_code(&self) -> StatusCode { + match self { + AppError::Other(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::InvalidCsrf(_) => StatusCode::BAD_REQUEST, + } + } + + fn message(&self) -> String { + match self { + AppError::Other(e) => format!("something went wrong: {}", e), + AppError::InvalidCsrf(e) => format!("unable to verify csrf: {}", e), + } + } +} + +impl fmt::Display for AppError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message()) + } +} impl IntoResponse for AppError { fn into_response(self) -> Response { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Something went wrong: {}", self.0), - ) - .into_response() - } -} - -impl From for AppError -where - E: Into, -{ - fn from(err: E) -> Self { - Self(err.into()) + (self.status_code(), self.message()).into_response() } } diff --git a/src/router.rs b/src/router.rs index acffcbb..fec7920 100644 --- a/src/router.rs +++ b/src/router.rs @@ -66,7 +66,7 @@ struct Register { } impl<'a> TryInto> for &'a Register { - type Error = argon2::password_hash::Error; + type Error = anyhow::Error; fn try_into(self: &'a Register) -> Result, Self::Error> { let salt = SaltString::generate(&mut OsRng); @@ -86,18 +86,11 @@ impl<'a> TryInto> for &'a Register { } async fn register( - csrf_token: CsrfToken, + c: CsrfToken, Form(register): Form, ) -> Result { // TODO: https://docs.rs/axum_csrf/latest/axum_csrf/#prevent-post-replay-attacks-with-csrf - - let v = csrf_token.verify(®ister.authenticity_token); - if v.is_err() { - return Ok(( - StatusCode::BAD_REQUEST, - Html(html! { "invalid request" }.into_string()), - )); - } + c.verify(®ister.authenticity_token)?; let new_user: NewUser = (®ister).try_into()?; diff --git a/src/views.rs b/src/views.rs index b75d7ce..76a1978 100644 --- a/src/views.rs +++ b/src/views.rs @@ -12,6 +12,14 @@ use crate::{ partials::{footer, header}, }; +pub async fn csrf(csrf: CsrfToken, cb: F) -> impl IntoResponse +where + F: Fn(&str) -> Html, +{ + let token = csrf.authenticity_token().unwrap(); + (csrf, cb(&token)) +} + #[instrument] pub async fn index() -> Html { Html( @@ -39,10 +47,8 @@ pub async fn index() -> Html { ) } -pub async fn register(csrf: CsrfToken) -> impl IntoResponse { - let token = csrf.authenticity_token().unwrap(); - ( - csrf, +pub async fn register(t: CsrfToken) -> impl IntoResponse { + csrf(t, |token| { Html( html! { (header()) @@ -64,22 +70,20 @@ pub async fn register(csrf: CsrfToken) -> impl IntoResponse { (footer()) } .into_string(), - ), - ) - .into_response() + ) + }) + .await } -pub async fn login(csrf: CsrfToken) -> impl IntoResponse { - let token = csrf.authenticity_token().unwrap(); - ( - csrf, +pub async fn login(t: CsrfToken) -> impl IntoResponse { + csrf(t, |token| { Html( html! { (header()) main class="prose" { h1 { "Login" } form method="post" { - input type="hidden" name="authenticity_token" value=(token) {} + input type="hidden" name="authenticity_token" value=(token) {} label { input {} } @@ -91,9 +95,9 @@ pub async fn login(csrf: CsrfToken) -> impl IntoResponse { (footer()) } .into_string(), - ), - ) - .into_response() + ) + }) + .await } #[allow(unreachable_code)]