diff --git a/README.md b/README.md index ae960011..9f6c7f19 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,9 @@ This course is designed with a lot of details, so that everyone, even with very - Lecture #74: [Graceful shutdown gRPC/HTTP servers and async worker](https://www.youtube.com/watch?v=TdB2W8l4dHw&list=PLy_6D98if3ULEtXtNSY_2qN21VCKgoQAE) - Lecture #75: [Go 1.22 fixed the most common for-loop trap](https://www.youtube.com/watch?v=rIHO9TqJtQQ&list=PLy_6D98if3ULEtXtNSY_2qN21VCKgoQAE) - Lecture #76: [Setup CORS policy with Go and VueJS](https://www.youtube.com/watch?v=hOz4f4SdArc&list=PLy_6D98if3ULEtXtNSY_2qN21VCKgoQAE) +- Lecture #77: [Upgrade golang JWT package to v5](https://www.youtube.com/watch?v=iVk3jOF1Cv4&list=PLy_6D98if3ULEtXtNSY_2qN21VCKgoQAE) -## Frontend course videos (Vue.JS) +## Frontend crash course videos (Vue.JS) - Lecture #1: [Build reactive web app with VueJS](https://www.youtube.com/watch?v=fRGgDBCWQJg&list=PLy_6D98if3UI3rsFRTHM1LMtVprYMp-GT) - Lecture #2: [Introduction to Vue router and Vue component](https://www.youtube.com/watch?v=4rv484TofFA&list=PLy_6D98if3UI3rsFRTHM1LMtVprYMp-GT) diff --git a/api/middleware.go b/api/middleware.go index dfd55252..56b5dddd 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -42,7 +42,7 @@ func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { } accessToken := fields[1] - payload, err := tokenMaker.VerifyToken(accessToken) + payload, err := tokenMaker.VerifyToken(accessToken, token.TokenTypeAccessToken) if err != nil { ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) return diff --git a/api/middleware_test.go b/api/middleware_test.go index 8322f0da..9065ac1e 100644 --- a/api/middleware_test.go +++ b/api/middleware_test.go @@ -22,7 +22,7 @@ func addAuthorization( role string, duration time.Duration, ) { - token, payload, err := tokenMaker.CreateToken(username, role, duration) + token, payload, err := tokenMaker.CreateToken(username, role, duration, token.TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, payload) diff --git a/api/token.go b/api/token.go index 7d7b6152..7183ad79 100644 --- a/api/token.go +++ b/api/token.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" db "github.com/techschool/simplebank/db/sqlc" + "github.com/techschool/simplebank/token" ) type renewAccessTokenRequest struct { @@ -26,7 +27,7 @@ func (server *Server) renewAccessToken(ctx *gin.Context) { return } - refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken) + refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken, token.TokenTypeRefreshToken) if err != nil { ctx.JSON(http.StatusUnauthorized, errorResponse(err)) return @@ -70,6 +71,7 @@ func (server *Server) renewAccessToken(ctx *gin.Context) { refreshPayload.Username, refreshPayload.Role, server.config.AccessTokenDuration, + token.TokenTypeAccessToken, ) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) diff --git a/api/user.go b/api/user.go index 94de4768..640f3c6c 100644 --- a/api/user.go +++ b/api/user.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" db "github.com/techschool/simplebank/db/sqlc" + "github.com/techschool/simplebank/token" "github.com/techschool/simplebank/util" ) @@ -111,6 +112,7 @@ func (server *Server) loginUser(ctx *gin.Context) { user.Username, user.Role, server.config.AccessTokenDuration, + token.TokenTypeAccessToken, ) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) @@ -121,6 +123,7 @@ func (server *Server) loginUser(ctx *gin.Context) { user.Username, user.Role, server.config.RefreshTokenDuration, + token.TokenTypeRefreshToken, ) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) diff --git a/app.env b/app.env index 8bd7f24e..7f047d31 100644 --- a/app.env +++ b/app.env @@ -5,7 +5,7 @@ MIGRATION_URL=file://db/migration HTTP_SERVER_ADDRESS=0.0.0.0:8080 GRPC_SERVER_ADDRESS=0.0.0.0:9090 TOKEN_SYMMETRIC_KEY=12345678901234567890123456789012 -ACCESS_TOKEN_DURATION=15m +ACCESS_TOKEN_DURATION=1m REFRESH_TOKEN_DURATION=24h REDIS_ADDRESS=0.0.0.0:6379 EMAIL_SENDER_NAME=Simple Bank diff --git a/gapi/authorization.go b/gapi/authorization.go index 892467db..3c76d663 100644 --- a/gapi/authorization.go +++ b/gapi/authorization.go @@ -37,7 +37,7 @@ func (server *Server) authorizeUser(ctx context.Context, accessibleRoles []strin } accessToken := fields[1] - payload, err := server.tokenMaker.VerifyToken(accessToken) + payload, err := server.tokenMaker.VerifyToken(accessToken, token.TokenTypeAccessToken) if err != nil { return nil, fmt.Errorf("invalid access token: %s", err) } diff --git a/gapi/main_test.go b/gapi/main_test.go index ab741f60..36bda757 100644 --- a/gapi/main_test.go +++ b/gapi/main_test.go @@ -26,8 +26,8 @@ func newTestServer(t *testing.T, store db.Store, taskDistributor worker.TaskDist return server } -func newContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, role string, duration time.Duration) context.Context { - accessToken, _, err := tokenMaker.CreateToken(username, role, duration) +func newContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, role string, duration time.Duration, tokenType token.TokenType) context.Context { + accessToken, _, err := tokenMaker.CreateToken(username, role, duration, tokenType) require.NoError(t, err) bearerToken := fmt.Sprintf("%s %s", authorizationBearer, accessToken) diff --git a/gapi/rpc_create_user.go b/gapi/rpc_create_user.go index ec87dd0d..66b2ee60 100644 --- a/gapi/rpc_create_user.go +++ b/gapi/rpc_create_user.go @@ -50,7 +50,7 @@ func (server *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) txResult, err := server.store.CreateUserTx(ctx, arg) if err != nil { if db.ErrorCode(err) == db.UniqueViolation { - return nil, status.Errorf(codes.AlreadyExists, err.Error()) + return nil, status.Error(codes.AlreadyExists, err.Error()) } return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) } diff --git a/gapi/rpc_login_user.go b/gapi/rpc_login_user.go index 259ba213..139a953d 100644 --- a/gapi/rpc_login_user.go +++ b/gapi/rpc_login_user.go @@ -6,6 +6,7 @@ import ( db "github.com/techschool/simplebank/db/sqlc" "github.com/techschool/simplebank/pb" + "github.com/techschool/simplebank/token" "github.com/techschool/simplebank/util" "github.com/techschool/simplebank/val" "google.golang.org/genproto/googleapis/rpc/errdetails" @@ -37,6 +38,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) ( user.Username, user.Role, server.config.AccessTokenDuration, + token.TokenTypeAccessToken, ) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create access token") @@ -46,6 +48,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) ( user.Username, user.Role, server.config.RefreshTokenDuration, + token.TokenTypeRefreshToken, ) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create refresh token") diff --git a/gapi/rpc_update_user_test.go b/gapi/rpc_update_user_test.go index c906201e..0774aced 100644 --- a/gapi/rpc_update_user_test.go +++ b/gapi/rpc_update_user_test.go @@ -67,7 +67,7 @@ func TestUpdateUserAPI(t *testing.T) { Return(updatedUser, nil) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute, token.TokenTypeAccessToken) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.NoError(t, err) @@ -112,7 +112,7 @@ func TestUpdateUserAPI(t *testing.T) { Return(updatedUser, nil) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, banker.Username, banker.Role, time.Minute) + return newContextWithBearerToken(t, tokenMaker, banker.Username, banker.Role, time.Minute, token.TokenTypeAccessToken) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.NoError(t, err) @@ -136,7 +136,7 @@ func TestUpdateUserAPI(t *testing.T) { Times(0) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, other.Username, other.Role, time.Minute) + return newContextWithBearerToken(t, tokenMaker, other.Username, other.Role, time.Minute, token.TokenTypeAccessToken) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) @@ -159,7 +159,7 @@ func TestUpdateUserAPI(t *testing.T) { Return(db.User{}, db.ErrRecordNotFound) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute, token.TokenTypeAccessToken) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) @@ -181,7 +181,7 @@ func TestUpdateUserAPI(t *testing.T) { Times(0) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute, token.TokenTypeAccessToken) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) @@ -203,7 +203,29 @@ func TestUpdateUserAPI(t *testing.T) { Times(0) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { - return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, -time.Minute) + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, -time.Minute, token.TokenTypeAccessToken) + }, + checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unauthenticated, st.Code()) + }, + }, + { + name: "WrongTokenType", + req: &pb.UpdateUserRequest{ + Username: user.Username, + FullName: &newName, + Email: &newEmail, + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + UpdateUser(gomock.Any(), gomock.Any()). + Times(0) + }, + buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { + return newContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute, token.TokenTypeRefreshToken) }, checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) { require.Error(t, err) diff --git a/go.mod b/go.mod index 6ae7ee9e..024d9a16 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/techschool/simplebank -go 1.22 +go 1.24 require ( github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29 diff --git a/simplebank b/simplebank new file mode 100755 index 00000000..7f819373 Binary files /dev/null and b/simplebank differ diff --git a/token/jwt_maker.go b/token/jwt_maker.go index c130dfae..59af469d 100644 --- a/token/jwt_maker.go +++ b/token/jwt_maker.go @@ -24,8 +24,8 @@ func NewJWTMaker(secretKey string) (Maker, error) { } // CreateToken creates a new token for a specific username and duration -func (maker *JWTMaker) CreateToken(username string, role string, duration time.Duration) (string, *Payload, error) { - payload, err := NewPayload(username, role, duration) +func (maker *JWTMaker) CreateToken(username string, role string, duration time.Duration, tokenType TokenType) (string, *Payload, error) { + payload, err := NewPayload(username, role, duration, tokenType) if err != nil { return "", payload, err } @@ -36,7 +36,7 @@ func (maker *JWTMaker) CreateToken(username string, role string, duration time.D } // VerifyToken checks if the token is valid or not -func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) { +func (maker *JWTMaker) VerifyToken(token string, tokenType TokenType) (*Payload, error) { keyFunc := func(token *jwt.Token) (interface{}, error) { _, ok := token.Method.(*jwt.SigningMethodHMAC) if !ok { @@ -58,5 +58,10 @@ func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) { return nil, ErrInvalidToken } + err = payload.Valid(tokenType) + if err != nil { + return nil, err + } + return payload, nil } diff --git a/token/jwt_maker_test.go b/token/jwt_maker_test.go index 5ab81777..277c3010 100644 --- a/token/jwt_maker_test.go +++ b/token/jwt_maker_test.go @@ -20,12 +20,12 @@ func TestJWTMaker(t *testing.T) { issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, payload, err := maker.CreateToken(username, role, duration) + token, payload, err := maker.CreateToken(username, role, duration, TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) - payload, err = maker.VerifyToken(token) + payload, err = maker.VerifyToken(token, TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, token) @@ -40,19 +40,19 @@ func TestExpiredJWTToken(t *testing.T) { maker, err := NewJWTMaker(util.RandomString(32)) require.NoError(t, err) - token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, -time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, -time.Minute, TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) - payload, err = maker.VerifyToken(token) + payload, err = maker.VerifyToken(token, TokenTypeAccessToken) require.Error(t, err) require.EqualError(t, err, ErrExpiredToken.Error()) require.Nil(t, payload) } func TestInvalidJWTTokenAlgNone(t *testing.T) { - payload, err := NewPayload(util.RandomOwner(), util.DepositorRole, time.Minute) + payload, err := NewPayload(util.RandomOwner(), util.DepositorRole, time.Minute, TokenTypeAccessToken) require.NoError(t, err) jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload) @@ -62,7 +62,22 @@ func TestInvalidJWTTokenAlgNone(t *testing.T) { maker, err := NewJWTMaker(util.RandomString(32)) require.NoError(t, err) - payload, err = maker.VerifyToken(token) + payload, err = maker.VerifyToken(token, TokenTypeAccessToken) + require.Error(t, err) + require.EqualError(t, err, ErrInvalidToken.Error()) + require.Nil(t, payload) +} + +func TestJWTWrongTokenType(t *testing.T) { + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, time.Minute, TokenTypeAccessToken) + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotEmpty(t, payload) + + payload, err = maker.VerifyToken(token, TokenTypeRefreshToken) require.Error(t, err) require.EqualError(t, err, ErrInvalidToken.Error()) require.Nil(t, payload) diff --git a/token/maker.go b/token/maker.go index 9466e787..9c65fe95 100644 --- a/token/maker.go +++ b/token/maker.go @@ -7,8 +7,8 @@ import ( // Maker is an interface for managing tokens type Maker interface { // CreateToken creates a new token for a specific username and duration - CreateToken(username string, role string, duration time.Duration) (string, *Payload, error) + CreateToken(username string, role string, duration time.Duration, tokenType TokenType) (string, *Payload, error) // VerifyToken checks if the token is valid or not - VerifyToken(token string) (*Payload, error) + VerifyToken(token string, tokenType TokenType) (*Payload, error) } diff --git a/token/paseto_maker.go b/token/paseto_maker.go index d855837f..8ee9ffcb 100644 --- a/token/paseto_maker.go +++ b/token/paseto_maker.go @@ -29,8 +29,8 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) { } // CreateToken creates a new token for a specific username and duration -func (maker *PasetoMaker) CreateToken(username string, role string, duration time.Duration) (string, *Payload, error) { - payload, err := NewPayload(username, role, duration) +func (maker *PasetoMaker) CreateToken(username string, role string, duration time.Duration, tokenType TokenType) (string, *Payload, error) { + payload, err := NewPayload(username, role, duration, tokenType) if err != nil { return "", payload, err } @@ -40,7 +40,7 @@ func (maker *PasetoMaker) CreateToken(username string, role string, duration tim } // VerifyToken checks if the token is valid or not -func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { +func (maker *PasetoMaker) VerifyToken(token string, tokenType TokenType) (*Payload, error) { payload := &Payload{} err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil) @@ -48,7 +48,7 @@ func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { return nil, ErrInvalidToken } - err = payload.Valid() + err = payload.Valid(tokenType) if err != nil { return nil, err } diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go index f95573cf..86f74159 100644 --- a/token/paseto_maker_test.go +++ b/token/paseto_maker_test.go @@ -19,12 +19,12 @@ func TestPasetoMaker(t *testing.T) { issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, payload, err := maker.CreateToken(username, role, duration) + token, payload, err := maker.CreateToken(username, role, duration, TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) - payload, err = maker.VerifyToken(token) + payload, err = maker.VerifyToken(token, TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, token) @@ -39,13 +39,28 @@ func TestExpiredPasetoToken(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(32)) require.NoError(t, err) - token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, -time.Minute) + token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, -time.Minute, TokenTypeAccessToken) require.NoError(t, err) require.NotEmpty(t, token) require.NotEmpty(t, payload) - payload, err = maker.VerifyToken(token) + payload, err = maker.VerifyToken(token, TokenTypeAccessToken) require.Error(t, err) require.EqualError(t, err, ErrExpiredToken.Error()) require.Nil(t, payload) } + +func TestPasetoWrongTokenType(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + token, payload, err := maker.CreateToken(util.RandomOwner(), util.DepositorRole, time.Minute, TokenTypeAccessToken) + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotEmpty(t, payload) + + payload, err = maker.VerifyToken(token, TokenTypeRefreshToken) + require.Error(t, err) + require.EqualError(t, err, ErrInvalidToken.Error()) + require.Nil(t, payload) +} diff --git a/token/payload.go b/token/payload.go index 05ec6397..2130681f 100644 --- a/token/payload.go +++ b/token/payload.go @@ -14,9 +14,17 @@ var ( ErrExpiredToken = errors.New("token has expired") ) +type TokenType byte + +const ( + TokenTypeAccessToken = 1 + TokenTypeRefreshToken = 2 +) + // Payload contains the payload data of the token type Payload struct { ID uuid.UUID `json:"id"` + Type TokenType `json:"token_type"` Username string `json:"username"` Role string `json:"role"` IssuedAt time.Time `json:"issued_at"` @@ -24,7 +32,7 @@ type Payload struct { } // NewPayload creates a new token payload with a specific username and duration -func NewPayload(username string, role string, duration time.Duration) (*Payload, error) { +func NewPayload(username string, role string, duration time.Duration, tokenType TokenType) (*Payload, error) { tokenID, err := uuid.NewRandom() if err != nil { return nil, err @@ -32,6 +40,7 @@ func NewPayload(username string, role string, duration time.Duration) (*Payload, payload := &Payload{ ID: tokenID, + Type: tokenType, Username: username, Role: role, IssuedAt: time.Now(), @@ -41,7 +50,10 @@ func NewPayload(username string, role string, duration time.Duration) (*Payload, } // Valid checks if the token payload is valid or not -func (payload *Payload) Valid() error { +func (payload *Payload) Valid(tokenType TokenType) error { + if payload.Type != tokenType { + return ErrInvalidToken + } if time.Now().After(payload.ExpiredAt) { return ErrExpiredToken }