diff --git a/Cargo.lock b/Cargo.lock index 5901571..e56bff3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1411,6 +1411,7 @@ dependencies = [ "serde", "thiserror", "tokio", + "tower", "tower-http", "tower-livereload", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 5e3d7f5..2afab14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,4 @@ sea-orm-migration = { version = "0.12.6", features = ["sqlx-sqlite"] } uuid = { version = "1.5.0", features = ["v7", "atomic", "fast-rng", "macro-diagnostics"] } password-hash = "0.5.0" axum-login = "0.7.3" +tower = "0.4.13" diff --git a/src/router.rs b/src/router.rs index c797a59..558f9cd 100644 --- a/src/router.rs +++ b/src/router.rs @@ -4,18 +4,21 @@ use crate::{error::AppError, views}; use argon2::password_hash::rand_core::OsRng; use argon2::password_hash::SaltString; use argon2::{Argon2, PasswordHasher}; -use axum::async_trait; +use axum::error_handling::HandleErrorLayer; use axum::extract::State; +use axum::{async_trait, BoxError}; use axum::{http::StatusCode, response::Html, routing::get, Form, Router}; use axum_csrf::{CsrfConfig, CsrfLayer, CsrfToken}; -use axum_login::{AuthUser, AuthnBackend, UserId}; +use axum_login::tower_sessions::{MemoryStore, SessionManagerLayer}; +use axum_login::{AuthManagerLayer, AuthUser, AuthnBackend, UserId}; use base64::prelude::*; use maud::html; use notify::Watcher; +use password_hash::{PasswordHash, PasswordVerifier}; use sea_orm::*; use serde::Deserialize; -use std::sync::Arc; use std::{env, path::Path}; +use tower::ServiceBuilder; use tower_http::services::ServeDir; use tower_livereload::LiveReloadLayer; use tracing::{info, instrument}; @@ -41,6 +44,16 @@ pub async fn new() -> Result { let cookie_key = cookie::Key::from(&cookie_key_bytes); let csrf_config = CsrfConfig::default().with_key(Some(cookie_key)); + let session_store = MemoryStore::default(); + let login_session_manager = + SessionManagerLayer::new(session_store).with_name("login_sessions.sid"); + + let auth_service = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| async { + StatusCode::BAD_REQUEST + })) + .layer(AuthManagerLayer::new(state.clone(), login_session_manager)); + let router = Router::new() .fallback(|| async { (StatusCode::NOT_FOUND, "404 page not found") }) .nest("/app", app_router) @@ -51,6 +64,7 @@ pub async fn new() -> Result { .route("/all_users", get(views::all_users)) .with_state(state) .layer(CsrfLayer::new(csrf_config)) + .layer(auth_service) .layer(live_reload_layer); Ok(router) @@ -94,13 +108,18 @@ where .to_string()) } +fn password_verify(password: S, current_digest: S) -> Result<(), password_hash::Error> +where + S: AsRef, +{ + let password_bytes = password.as_ref().as_bytes(); + let current = PasswordHash::new(current_digest.as_ref())?; + Ok(Argon2::default().verify_password(password_bytes, ¤t)?) +} + impl user::ActiveModel {} -async fn register( - c: CsrfToken, - State(s): State>, - Form(f): Form, -) -> AppRes { +async fn register(c: CsrfToken, State(s): State, Form(f): Form) -> AppRes { csrf_verify(c, &f.authenticity_token)?; // TODO: handle duplicate username @@ -151,12 +170,18 @@ impl AuthnBackend for state::State { type Error = AppError; async fn authenticate(&self, l: Self::Credentials) -> Result, Self::Error> { - Ok(User::find() + match User::find() .filter(user::Column::Username.eq(l.username)) // TODO: will this have index problems since I'm searching over the password digest? - .filter(user::Column::PasswordDigest.eq(password_digest(l.password)?)) .one(&self.db) - .await?) + .await? + { + Some(user) => match password_verify(&l.password, &user.password_digest) { + Ok(()) => Ok(Some(user)), + Err(e) => Err(e.into()), + }, + None => Ok(None), + } } async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { diff --git a/src/state.rs b/src/state.rs index b7a1480..356a2d0 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,4 +1,4 @@ -use std::{env, sync::Arc}; +use std::env; use sea_orm::{Database, DatabaseConnection}; use sea_orm_migration::MigratorTrait; @@ -11,12 +11,12 @@ pub struct State { } impl State { - pub async fn new() -> Result, anyhow::Error> { + pub async fn new() -> Result { let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); let db = Database::connect(database_url).await?; Migrator::refresh(&db).await?; - Ok(Arc::new(State { db })) + Ok(State { db }) } } diff --git a/src/views.rs b/src/views.rs index 0d6b5e4..880311b 100644 --- a/src/views.rs +++ b/src/views.rs @@ -12,7 +12,6 @@ use axum::{ use axum_csrf::CsrfToken; use maud::html; use sea_orm::EntityTrait; -use std::sync::Arc; use tracing::instrument; pub async fn csrf(csrf: CsrfToken, cb: F) -> impl IntoResponse @@ -108,7 +107,7 @@ pub async fn login(t: CsrfToken) -> impl IntoResponse { .await } -pub async fn all_users(State(s): State>) -> AppRes { +pub async fn all_users(State(s): State) -> AppRes { let users: Vec = User::find().all(&s.db).await?; Ok((