11use std:: sync:: { Arc , OnceLock } ;
22
3- use axum:: { extract:: State , http:: StatusCode , routing:: post, Json , Router } ;
3+ use axum:: {
4+ extract:: State ,
5+ http:: { HeaderMap , StatusCode } ,
6+ routing:: post,
7+ Json , Router ,
8+ } ;
49use jsonwebtoken:: { get_current_timestamp, DecodingKey , EncodingKey , Validation } ;
510use log:: { info, warn} ;
611use ring:: rand:: { SecureRandom , SystemRandom } ;
712use serde:: { Deserialize , Serialize } ;
13+ use serde_repr:: { Deserialize_repr , Serialize_repr } ;
814
915use crate :: { util, AppState } ;
1016
1117#[ derive( Deserialize , Clone ) ]
1218pub struct AuthConfig {
1319 route : String ,
20+ refresh_subroute : String ,
1421 secret_path : String ,
15- valid_secs : u64 ,
22+ valid_secs_refresh : u64 ,
23+ valid_secs_session : u64 ,
1624}
1725
1826#[ derive( Deserialize ) ]
@@ -21,11 +29,19 @@ pub struct AuthRequest {
2129 password : String ,
2230}
2331
32+ #[ repr( u8 ) ]
33+ #[ derive( Deserialize_repr , Serialize_repr , PartialEq , Eq ) ]
34+ pub enum TokenKind {
35+ Refresh = 0 ,
36+ Session = 1 ,
37+ }
38+
2439#[ derive( Deserialize , Serialize ) ]
2540pub struct Claims {
26- sub : String , // account id as a string
27- crt : u64 , // creation timestamp in UTC
28- exp : u64 , // expiration timestamp in UTC
41+ sub : String , // account id as a string
42+ crt : u64 , // creation timestamp in UTC
43+ exp : u64 , // expiration timestamp in UTC
44+ kind : TokenKind , // kind of token
2945}
3046
3147static SECRET_KEY : OnceLock < Vec < u8 > > = OnceLock :: new ( ) ;
@@ -55,21 +71,33 @@ pub fn register(
5571 rng : & SystemRandom ,
5672) -> Router < Arc < AppState > > {
5773 let route = & config. route ;
74+ let refresh_route = util:: get_subroute ( route, & config. refresh_subroute ) ;
5875 info ! ( "Registering auth route @ {}" , route) ;
76+ info ! ( "\t Refresh route @ {}" , refresh_route) ;
5977 check_secret ( & config. secret_path , rng) ;
60- routes. route ( route, post ( do_auth) )
78+ routes
79+ . route ( route, post ( do_auth) )
80+ . route ( & refresh_route, post ( do_refresh) )
6181}
6282
63- fn gen_jwt ( account_id : i64 , valid_secs : u64 ) -> Result < String , String > {
83+ fn gen_jwt ( auth_config : & AuthConfig , account_id : i64 , kind : TokenKind ) -> Result < String , String > {
6484 let secret = SECRET_KEY . get ( ) . unwrap ( ) ;
6585 let key = EncodingKey :: from_secret ( secret) ;
86+
87+ let valid_secs = match kind {
88+ TokenKind :: Refresh => auth_config. valid_secs_refresh ,
89+ TokenKind :: Session => auth_config. valid_secs_session ,
90+ } ;
91+
6692 let crt = get_current_timestamp ( ) ;
6793 let exp = crt + valid_secs;
6894 let claims = Claims {
6995 sub : account_id. to_string ( ) ,
7096 crt,
7197 exp,
98+ kind,
7299 } ;
100+
73101 jsonwebtoken:: encode ( & jsonwebtoken:: Header :: default ( ) , & claims, & key)
74102 . map_err ( |e| format ! ( "JWT error: {}" , e) )
75103}
@@ -85,7 +113,7 @@ fn get_validator(account_id: Option<i64>) -> Validation {
85113 validation
86114}
87115
88- pub fn validate_jwt ( jwt : & str ) -> Result < i64 , String > {
116+ pub fn validate_jwt ( jwt : & str , kind : TokenKind ) -> Result < i64 , String > {
89117 let Some ( secret) = SECRET_KEY . get ( ) else {
90118 return Err ( "Auth module not initialized" . to_string ( ) ) ;
91119 } ;
@@ -102,6 +130,10 @@ pub fn validate_jwt(jwt: &str) -> Result<i64, String> {
102130 return Err ( "Expired JWT" . to_string ( ) ) ;
103131 }
104132
133+ if token. claims . kind != kind {
134+ return Err ( "Bad token kind" . to_string ( ) ) ;
135+ }
136+
105137 match token. claims . sub . parse ( ) {
106138 Ok ( id) => Ok ( id) ,
107139 Err ( e) => Err ( format ! ( "Bad account ID: {}" , e) ) ,
@@ -118,8 +150,11 @@ async fn do_auth(
118150 warn ! ( "Auth error: {}" , e) ;
119151 ( StatusCode :: UNAUTHORIZED , "Invalid credentials" . to_string ( ) )
120152 } ) ?;
121- let valid_secs = app. config . auth . as_ref ( ) . unwrap ( ) . valid_secs ;
122- match gen_jwt ( account_id, valid_secs) {
153+ match gen_jwt (
154+ app. config . auth . as_ref ( ) . unwrap ( ) ,
155+ account_id,
156+ TokenKind :: Refresh ,
157+ ) {
123158 Ok ( jwt) => Ok ( jwt) ,
124159 Err ( e) => {
125160 warn ! ( "Auth error: {}" , e) ;
@@ -130,3 +165,30 @@ async fn do_auth(
130165 }
131166 }
132167}
168+
169+ async fn do_refresh (
170+ State ( app) : State < Arc < AppState > > ,
171+ headers : HeaderMap ,
172+ ) -> Result < String , ( StatusCode , String ) > {
173+ assert ! ( app. is_tls) ;
174+ let db = app. db . lock ( ) . await ;
175+ // TODO validate the refresh token against the last password reset timestamp
176+ let account_id = match util:: validate_authed_request ( & headers, TokenKind :: Refresh ) {
177+ Ok ( id) => id,
178+ Err ( e) => return Err ( ( StatusCode :: UNAUTHORIZED , e) ) ,
179+ } ;
180+ match gen_jwt (
181+ app. config . auth . as_ref ( ) . unwrap ( ) ,
182+ account_id,
183+ TokenKind :: Session ,
184+ ) {
185+ Ok ( jwt) => Ok ( jwt) ,
186+ Err ( e) => {
187+ warn ! ( "Refresh error: {}" , e) ;
188+ Err ( (
189+ StatusCode :: INTERNAL_SERVER_ERROR ,
190+ "Server error" . to_string ( ) ,
191+ ) )
192+ }
193+ }
194+ }
0 commit comments