diff --git a/backend/src/auth.rs b/backend/src/auth.rs index 4921228..7246c14 100644 --- a/backend/src/auth.rs +++ b/backend/src/auth.rs @@ -18,6 +18,38 @@ impl AuthState { } } +#[derive(Debug)] +#[allow(dead_code)] +pub struct Token(pub String); + +#[rocket::async_trait] +impl<'r> rocket::request::FromRequest<'r> for Token { + type Error = (); + + async fn from_request( + request: &'r rocket::Request<'_>, + ) -> rocket::request::Outcome { + let token = request.headers().get_one("Authorization"); + + match token { + Some(token) => { + // Check if token starts with "Bearer " + if let Some(token) = token.strip_prefix("Bearer ") { + let state = request.guard::<&State>().await.unwrap(); + let tokens = state.tokens.lock().unwrap(); + + if tokens.contains_key(token) { + return rocket::request::Outcome::Success(Token(token.to_string())); + } + } + + rocket::request::Outcome::Error((rocket::http::Status::Unauthorized, ())) + } + None => rocket::request::Outcome::Error((rocket::http::Status::Unauthorized, ())), + } + } +} + #[post("/login", data = "")] pub fn login( state: &State, diff --git a/backend/src/main.rs b/backend/src/main.rs index eda72b7..5e2f97e 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -25,7 +25,11 @@ impl std::ops::Deref for User { } #[get("/")] -fn get_user(user_list: &rocket::State>, name: String) -> Option { +fn get_user( + _token: auth::Token, + user_list: &rocket::State>, + name: String, +) -> Option { user_list .iter() .find(|user| user.person.name == name) @@ -33,7 +37,7 @@ fn get_user(user_list: &rocket::State>, name: String) -> Option>) -> items::PersonList { +fn get_users(_token: auth::Token, user_list: &rocket::State>) -> items::PersonList { items::PersonList { person: user_list .inner() diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 5ec67df..5456315 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -10,7 +10,11 @@ function App() { useEffect(() => { if (!token) return; - fetch("/api") + fetch("/api", { + headers: { + Authorization: `Bearer ${token}`, + }, + }) .then((res) => res.arrayBuffer()) .then((buffer) => { const list = PersonList.decode(new Uint8Array(buffer));