125 lines
3.4 KiB
Rust

use crate::items;
use crate::proto_utils::Proto;
use rocket::State;
use std::collections::HashMap;
use std::sync::Mutex;
use uuid::Uuid;
/// In-memory authentication state shared with request handlers via Rocket's
/// managed state.
///
/// Holds the session tokens issued by `login`. Tokens live only in process
/// memory behind a [`Mutex`]; no expiry is applied here.
#[derive(Default)]
pub struct AuthState {
    /// Maps an issued session token to the username it was issued for.
    tokens: Mutex<HashMap<String, String>>,
}

impl AuthState {
    /// Creates an empty `AuthState` with no issued tokens.
    ///
    /// Equivalent to [`AuthState::default`]; kept for existing callers.
    pub fn new() -> Self {
        Self::default()
    }
}
/// A bearer token taken from a request's `Authorization` header.
///
/// Produced by this module's `FromRequest` implementation once the token has
/// been found in the [`AuthState`] token map.
// NOTE(review): presumably the inner string is not read anywhere, which would
// otherwise raise a `dead_code` warning — confirm before removing the allow.
#[allow(dead_code)]
#[derive(Debug)]
pub struct Token(pub String);
#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for Token {
    type Error = ();

    /// Authenticates the request from its `Authorization` header.
    ///
    /// The header must have the form `<scheme> <token>`. The scheme must be
    /// `Bearer`, compared ASCII case-insensitively as HTTP auth schemes are
    /// (RFC 7235 §2.1). Everything after the first space is taken as the
    /// token. The guard succeeds only if that token is a key in the managed
    /// [`AuthState`]'s token map.
    ///
    /// # Outcomes
    /// - `Success(Token)`: the token is currently issued.
    /// - `Error((Unauthorized, ()))`: the header is missing, has the wrong
    ///   scheme, or carries an unknown token.
    /// - If `AuthState` is not managed by this Rocket instance, the
    ///   `&State<AuthState>` guard's own failure is propagated instead of
    ///   panicking.
    async fn from_request(
        request: &'r rocket::Request<'_>,
    ) -> rocket::request::Outcome<Self, Self::Error> {
        use rocket::http::Status;
        use rocket::request::Outcome;

        // Split "<scheme> <credentials>" at the first space. For a "Bearer "
        // header this yields exactly what `strip_prefix("Bearer ")` would.
        let bearer = request
            .headers()
            .get_one("Authorization")
            .and_then(|value| value.split_once(' '))
            .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
            .map(|(_, token)| token);
        let Some(token) = bearer else {
            return Outcome::Error((Status::Unauthorized, ()));
        };

        // A missing `AuthState` is a server misconfiguration. Propagate the
        // state guard's error outcome rather than panicking in the pipeline.
        let state = rocket::outcome::try_outcome!(request.guard::<&State<AuthState>>().await);

        // The lock guard is a temporary, so it is released at the end of this
        // statement.
        let is_known = state
            .tokens
            .lock()
            .expect("AuthState token map mutex poisoned")
            .contains_key(token);

        if is_known {
            Outcome::Success(Token(token.to_string()))
        } else {
            Outcome::Error((Status::Unauthorized, ()))
        }
    }
}
/// Checks a username/password pair and, on success, issues a session token.
///
/// The user is looked up by exact name in the managed user list. The password
/// is checked with `bcrypt::verify`, and any bcrypt error counts as a mismatch.
/// On success a new UUID v4 string is stored in [`AuthState`] against the
/// username and returned in the response. Otherwise an empty token and
/// `success: false` are returned.
///
/// NOTE(review): when the username is unknown, no bcrypt work is performed, so
/// response timing may reveal whether a username exists. Consider verifying
/// against a dummy hash of the same cost; confirm against the threat model.
#[post("/login", data = "<request>")]
pub fn login(
    state: &State<AuthState>,
    user_list: &State<Vec<crate::User>>,
    request: Proto<items::LoginRequest>,
) -> items::LoginResponse {
    let credentials = request.into_inner();

    let password_ok = user_list
        .iter()
        .find(|candidate| candidate.name == credentials.username)
        .is_some_and(|user| {
            bcrypt::verify(&credentials.password, &user.password_hash).unwrap_or(false)
        });

    if !password_ok {
        return items::LoginResponse {
            token: String::new(),
            success: false,
            message: "Invalid credentials".to_string(),
        };
    }

    let token = Uuid::new_v4().to_string();
    state
        .tokens
        .lock()
        .unwrap()
        .insert(token.clone(), credentials.username);

    items::LoginResponse {
        token,
        success: true,
        message: "Login successful".to_string(),
    }
}
/// Revokes a session token by removing it from [`AuthState`].
///
/// Reports `success: true` only when the submitted token was present. The
/// token comes from the request body; this handler does not take the `Token`
/// guard.
#[post("/logout", data = "<request>")]
pub fn logout(
    state: &State<AuthState>,
    request: Proto<items::LogoutRequest>,
) -> items::LogoutResponse {
    let payload = request.into_inner();
    let revoked = state.tokens.lock().unwrap().remove(&payload.token);

    let (success, message) = match revoked {
        Some(_) => (true, "Logged out successfully"),
        None => (false, "Invalid token"),
    };

    items::LogoutResponse {
        success,
        message: message.to_string(),
    }
}
/// Reports whether a session token is currently issued, and for which user.
///
/// Looks the token up in [`AuthState`]. When it is found, the response
/// carries the associated username. Otherwise `authenticated` is `false` and
/// the username is empty.
#[post("/get_auth_status", data = "<request>")]
pub fn get_auth_status(
    state: &State<AuthState>,
    request: Proto<items::AuthStatusRequest>,
) -> items::AuthStatusResponse {
    let query = request.into_inner();
    // Clone the username out so the lock is released at the end of this
    // statement.
    let owner = state.tokens.lock().unwrap().get(&query.token).cloned();

    match owner {
        Some(username) => items::AuthStatusResponse {
            authenticated: true,
            username,
            message: "Authenticated".to_string(),
        },
        None => items::AuthStatusResponse {
            authenticated: false,
            username: String::new(),
            message: "Not authenticated".to_string(),
        },
    }
}