From 1bcc5b923d4b892410debe12b79d9ec8e7b5545b Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 11 Feb 2025 21:01:51 +0100 Subject: [PATCH 01/36] remove old code --- api/apps.go | 177 ------------------------ api/handlers.go | 321 -------------------------------------------- api/routes.go | 7 + api/service.go | 30 +---- api/service_test.go | 5 +- api/tokens.go | 179 ------------------------ client/client.go | 136 ------------------- client/config.go | 40 ------ cmd/authapi/main.go | 11 +- db/db.go | 111 --------------- db/mongo/apps.go | 149 -------------------- db/mongo/mongo.go | 161 ---------------------- db/mongo/tokens.go | 113 ---------------- db/temp.go | 152 --------------------- helpers/consts.go | 50 ------- helpers/helpers.go | 110 --------------- 16 files changed, 15 insertions(+), 1737 deletions(-) delete mode 100644 api/apps.go delete mode 100644 api/handlers.go create mode 100644 api/routes.go delete mode 100644 api/tokens.go delete mode 100644 client/client.go delete mode 100644 client/config.go delete mode 100644 db/db.go delete mode 100644 db/mongo/apps.go delete mode 100644 db/mongo/mongo.go delete mode 100644 db/mongo/tokens.go delete mode 100644 db/temp.go delete mode 100644 helpers/consts.go delete mode 100644 helpers/helpers.go diff --git a/api/apps.go b/api/apps.go deleted file mode 100644 index cd78c48..0000000 --- a/api/apps.go +++ /dev/null @@ -1,177 +0,0 @@ -package api - -import ( - "encoding/hex" - "fmt" - - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/helpers" -) - -// authApp method creates a new app based on the provided name, email, redirectURL -// and duration. It returns the app id and the app secret. If the name, email or -// redirectURL are empty, it returns an error. If the duration is less than the -// minimum duration, it returns an error. If something fails during the process, -// it returns an error. The app id and the app secret are generated based on the -// email using the generateApp function. The app is stored in the database using -// the app id as the key. The secret is stored in the database using the hashed -// secret as the key. The hashed secret is required to be compared with the -// secret provided by the user in the requests. -func (s *Service) authApp(name, email, redirectURL string, duration uint64) (string, string, error) { - // check if the name, email, and redirectURL are not empty - if len(name) == 0 || len(email) == 0 || len(redirectURL) == 0 { - return "", "", fmt.Errorf("name, email, and redirectURL are required") - } - // check if the duration is valid - if duration < helpers.MinTokenDuration { - return "", "", fmt.Errorf("duration must be at least %d seconds", helpers.MinTokenDuration) - } - // compose the app struct for the database - appData := &db.App{ - Name: name, - AdminEmail: email, - SessionDuration: duration, - RedirectURL: redirectURL, - UsersQuota: helpers.DefaultUsersQuota, - } - // generate app based on email - appId, secret, hSecret, err := generateApp(appData.AdminEmail) - if err != nil { - return "", "", err - } - // store app in the database - if err := s.db.SetApp(appId, appData); err != nil { - return "", "", err - } - // store secret in the database - if err := s.db.SetSecret(hSecret, appId); err != nil { - return "", "", err - } - return appId, secret, nil -} - -// appMetadata method retrieves the app data based on the app id. If the app id is -// empty, it returns an error. If something fails during the process, it returns -// an error. The app data includes the name, the email of the admin, the redirect -// URL, the duration, the users quota, and the current users. The current users -// are retrieved from the database using the app id to count the number of tokens -// for the app. -func (s *Service) appMetadata(appId string) (AppData, error) { - dbApp, err := s.db.AppById(appId) - if err != nil { - return AppData{}, err - } - app := AppData{ - Name: dbApp.Name, - Email: dbApp.AdminEmail, - RedirectURL: dbApp.RedirectURL, - Duration: dbApp.SessionDuration, - UsersQuota: dbApp.UsersQuota, - } - // get the number of current tokens for the app, if it fails, it returns 0 - app.CurrentUsers, _ = s.db.CountTokens(appId) - return app, nil -} - -// updateAppMetadata method updates the app metadata based on the app id, name, -// redirectURL, and duration. If the app id is empty, it returns an error. If -// the duration is non zero an less than the minimum duration, it returns an -// error. If something fails during the process, it returns an error. -func (s *Service) updateAppMetadata(appId, name, redirectURL string, duration uint64) error { - // check if the app id is not empty - if len(appId) == 0 { - return fmt.Errorf("app id is required") - } - // check if the duration is valid - if duration != 0 && duration < helpers.MinTokenDuration { - return fmt.Errorf("duration must be at least %d seconds", helpers.MinTokenDuration) - } - // get app from the database - app, err := s.db.AppById(appId) - if err != nil { - return err - } - // update app metadata - if name != "" { - app.Name = name - } - if redirectURL != "" { - app.RedirectURL = redirectURL - } - if duration != 0 { - app.SessionDuration = duration - } - // store app in the database - return s.db.SetApp(appId, app) -} - -// removeApp method removes an app based on the app id. If the app id is empty, -// it returns an error. If something fails during the process, it returns an -// error. It also removes all the tokens for the app from the database using -// the app id as the prefix to find them. -func (s *Service) removeApp(appId string) error { - // check if the app id is not empty - if len(appId) == 0 { - return fmt.Errorf("app id is required") - } - // remove all the tokens for the app from the database, using the app id as - // the prefix - if err := s.db.DeleteTokensByPrefix(appId); err != nil { - return err - } - // remove app from the database - return s.db.DeleteApp(appId) -} - -func (s *Service) validSecret(appId, rawSecret string) bool { - secret, err := helpers.Hash(rawSecret, helpers.SecretSize) - if err != nil { - return false - } - valid, err := s.db.ValidSecret(secret, appId) - if err != nil { - return false - } - return valid -} - -// generateApp function generates an app based on the email. It returns the app -// id, the app secret and the hashed secret. If the email is empty or something -// fails during the process, it returns an error. The app id is generated -// hashing the email with a length of 4 bytes. The app secret is generated -// using the appSecret function. -func generateApp(email string) (string, string, string, error) { - if len(email) == 0 { - return "", "", "", fmt.Errorf("email is required") - } - // hash email - hEmail, err := helpers.Hash(email, helpers.EmailHashSize) - if err != nil { - return "", "", "", err - } - bAppNonce := helpers.RandBytes(helpers.AppNonceSize) - hAppNonce := hex.EncodeToString(bAppNonce) - appId := hEmail + hAppNonce - // generate secret - secret, hSecret, err := appSecret() - if err != nil { - return "", "", "", err - } - return appId, secret, hSecret, nil -} - -// appSecret function generates an new app secret. It returns the secret, the -// hashed secret and an error if something fails during the process. The secret -// is a random sequence of 16 bytes encoded as a hexadecimal string. The hashed -// secret is required to store the secret in the database without exposing it. -func appSecret() (string, string, error) { - // generate secret - bSecret := helpers.RandBytes(helpers.SecretSize) - secret := hex.EncodeToString(bSecret) - // hash secret - hSecret, err := helpers.Hash(secret, helpers.SecretSize) - if err != nil { - return "", "", err - } - return secret, hSecret, nil -} diff --git a/api/handlers.go b/api/handlers.go deleted file mode 100644 index dc2c72f..0000000 --- a/api/handlers.go +++ /dev/null @@ -1,321 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "io" - "log" - "net/http" - - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/email" - "github.com/simpleauthlink/authapi/helpers" -) - -// userTokenHandler method generates a token for the user and sends it via email -// to the user's email address. The token is generated based on the app id -// and the user's email address. The token is stored in the database with an -// expiration time. It gets the app secret from the helpers.AppSecretHeader -// header and the user's email address from the request body. If it success it -// sends an "Ok" response. If something goes wrong, it sends an internal server -// error response. If the app secret is missing or the request body is invalid, -// it sends a bad request response. -func (s *Service) userTokenHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // read body - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - if err != nil { - log.Println("ERR: error reading request body:", err) - http.Error(w, "error reading request body", http.StatusInternalServerError) - return - } - // parse request - req := &TokenRequest{} - if err := json.Unmarshal(body, req); err != nil { - log.Println("ERR: error parsing request body:", err) - http.Error(w, "error parsing request body", http.StatusBadRequest) - return - } - // check if the email is allowed - if !s.emailQueue.Allowed(req.Email) { - http.Error(w, "disallowed domain", http.StatusBadRequest) - return - } - // generate token - magicLink, token, appName, err := s.magicLink(appSecret, req.Email, req.RedirectURL, req.Duration) - if err != nil { - log.Println("ERR: error generating token:", err) - http.Error(w, "error generating token", http.StatusInternalServerError) - return - } - // compose and push the email to the queue to be sent, if it fails, delete - // the token from the database, log the error and send an error response - emailData := email.NewUserEmailData(appName, req.Email, magicLink, token) - emailBody, err := email.ParseTemplate(s.cfg.TokenEmailTemplate, emailData) - if err != nil { - log.Println("ERR: error parsing email template:", err) - http.Error(w, "error parsing email template", http.StatusInternalServerError) - return - } - if err := s.emailQueue.Push(&email.Email{ - To: req.Email, - Subject: fmt.Sprintf(userTokenSubject, appName), - Body: emailBody, - }); err != nil { - log.Println("ERR: error sending email:", err) - if err := s.db.DeleteToken(db.Token(token)); err != nil { - log.Println("ERR: error deleting token:", err) - } - http.Error(w, "error sending email", http.StatusInternalServerError) - return - } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} - -// validateUserTokenHandler method validates the user token. It gets the token -// from the helpers.TokenQueryParam query string and checks if it is valid. If -// the token is valid, it sends a response with the "Ok" message. If the token -// is invalid, it sends an unauthorized response. If the token is missing, it -// sends a bad request response. -func (s *Service) validateUserTokenHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token - if !s.validUserToken(token, appSecret) { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} - -// appTokenHandler method generates creates an app in the service, it generates -// an app id and a secret for the app. It sends the app id and the secret via -// email to the app's email address. It gets the app name, email, callback, and -// duration from the request body. If it success it sends an "Ok" response. If -// something goes wrong, it sends an internal server error response. If the -// request body is invalid, it sends a bad request response. -func (s *Service) appTokenHandler(w http.ResponseWriter, r *http.Request) { - // read body - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - if err != nil { - log.Println("ERR: error reading request body:", err) - http.Error(w, "error reading request body", http.StatusInternalServerError) - return - } - app := &AppData{} - if err := json.Unmarshal(body, app); err != nil { - log.Println("ERR: error parsing request body:", err) - http.Error(w, "error parsing request body", http.StatusBadRequest) - return - } - // check if the email is allowed - if !s.emailQueue.Allowed(app.Email) { - http.Error(w, "disallowed domain", http.StatusBadRequest) - return - } - // generate token - appId, secret, err := s.authApp(app.Name, app.Email, app.RedirectURL, app.Duration) - if err != nil { - log.Println("ERR: error generating token:", err) - http.Error(w, "error generating token", http.StatusInternalServerError) - return - } - emailData := email.NewAppEmailData(appId, app.Name, app.RedirectURL, secret, app.Email) - emailBody, err := email.ParseTemplate(s.cfg.AppEmailTemplate, emailData) - if err != nil { - log.Println("ERR: error parsing email template:", err) - http.Error(w, "error parsing email template", http.StatusInternalServerError) - return - } - // compose and push the email to the queue to be sent if it fails, delete - // the app from the database, log the error and send an error response - if err := s.emailQueue.Push(&email.Email{ - To: app.Email, - Subject: fmt.Sprintf(appTokenSubject, app.Name), - Body: emailBody, - }); err != nil { - log.Println("ERR: error sending email:", err) - if err := s.removeApp(appId); err != nil { - log.Println("ERR: error deleting app:", err) - } - http.Error(w, "error sending email", http.StatusInternalServerError) - return - } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} - -// appHandler method gets the app metadata from the service. It gets the app id -// from the token provided in the URL query. If the token is missing, it sends -// a bad request response. If the token is invalid or is not an admin token, it -// sends an unauthorized response. If the app is not found, it sends a not found -// response. If it success it sends the app metadata. If something goes wrong, -// it sends an internal server error response. -func (s *Service) appHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token and get the app id - appId, valid := s.validAdminToken(token, appSecret) - if !valid { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - // get the app from the database - app, err := s.appMetadata(appId) - if err != nil { - if err == db.ErrAppNotFound { - http.Error(w, "app not found", http.StatusNotFound) - return - } - log.Println("ERR: error getting app:", err) - http.Error(w, "error getting app", http.StatusInternalServerError) - return - } - // encode the app metadata - res, err := json.Marshal(&app) - if err != nil { - log.Println("ERR: error marshaling app:", err) - http.Error(w, "error marshaling app", http.StatusInternalServerError) - return - } - // send response - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write(res); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} - -// updateAppHandler method updates an app in the service. It gets the app id -// from the URL path and the app name, callback, and duration from the request -// body. If the app id is missing, it sends a bad request response. If the app -// is not found, it sends a not found response. If it success it sends an Ok -// response. If something goes wrong, it sends an internal server error -// response. -func (s *Service) updateAppHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token and get the app id - appId, valid := s.validAdminToken(token, appSecret) - if !valid { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - // read body - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - if err != nil { - log.Println("ERR: error reading request body:", err) - http.Error(w, "error reading request body", http.StatusInternalServerError) - return - } - // decode the app from the request - app := &AppData{} - if err := json.Unmarshal(body, app); err != nil { - log.Println("ERR: error parsing request body:", err) - http.Error(w, "error parsing request body", http.StatusBadRequest) - return - } - // update the app in the database - if err := s.updateAppMetadata(appId, app.Name, app.RedirectURL, app.Duration); err != nil { - log.Println("ERR: error updating app:", err) - http.Error(w, "error updating app", http.StatusInternalServerError) - return - } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} - -// delAppHandler method deletes an app from the service. It gets the app id from -// the token provided in the URL query. If the token is missing, it sends a bad -// request response. If the token is invalid or is not an admin token, it sends -// an unauthorized response. If it success it sends an Ok response. If something -// goes wrong, it sends an internal server error response. -func (s *Service) delAppHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token and get the app id - appId, valid := s.validAdminToken(token, appSecret) - if !valid { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - // remove the app from the service - if err := s.removeApp(appId); err != nil { - log.Println("ERR: error deleting app:", err) - http.Error(w, "error deleting app", http.StatusInternalServerError) - return - } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} diff --git a/api/routes.go b/api/routes.go new file mode 100644 index 0000000..f10df75 --- /dev/null +++ b/api/routes.go @@ -0,0 +1,7 @@ +package api + +const ( + // HealthCheckPath constant is the path used to check the health of the API + // server. It is a string with a value of "/health". + HealthCheckPath = "/ping" +) diff --git a/api/service.go b/api/service.go index ebe696d..17e09d1 100644 --- a/api/service.go +++ b/api/service.go @@ -12,9 +12,7 @@ import ( "time" "github.com/lucasmenendez/apihandler" - "github.com/simpleauthlink/authapi/db" "github.com/simpleauthlink/authapi/email" - "github.com/simpleauthlink/authapi/helpers" ) // Config struct represents the configuration needed to init the service. It @@ -37,17 +35,16 @@ type Service struct { cancel context.CancelFunc wait sync.WaitGroup cfg *Config - db db.DB emailQueue *email.EmailQueue handler *apihandler.Handler httpServer *http.Server } -// New function creates a new service based on the provided context, the db -// interface and configuration. It initializes the email queue, creates the -// service and sets the api handlers. If something goes wrong during the -// process, it returns an error. -func New(ctx context.Context, db db.DB, cfg *Config) (*Service, error) { +// New function creates a new service based on the provided context and +// configuration. It initializes the email queue, creates the service and +// sets the api handlers. If something goes wrong during the process, it +// returns an error. +func New(ctx context.Context, cfg *Config) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) emailQueue, err := email.NewEmailQueue(internalCtx, &cfg.EmailConfig) if err != nil { @@ -62,7 +59,6 @@ func New(ctx context.Context, db db.DB, cfg *Config) (*Service, error) { ctx: internalCtx, cancel: cancel, cfg: cfg, - db: db, emailQueue: emailQueue, handler: apihandler.NewHandler(&apihandler.Config{ CORS: true, @@ -72,17 +68,9 @@ func New(ctx context.Context, db db.DB, cfg *Config) (*Service, error) { }, }), } - srv.handler.Get(helpers.HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { + srv.handler.Get(HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - // user handlers - srv.handler.Post(helpers.UserEndpointPath, srv.userTokenHandler) - srv.handler.Get(helpers.UserEndpointPath, srv.validateUserTokenHandler) - // app handlers - srv.handler.Get(helpers.AppEndpointPath, srv.appHandler) - srv.handler.Post(helpers.AppEndpointPath, srv.appTokenHandler) - srv.handler.Put(helpers.AppEndpointPath, srv.updateAppHandler) - srv.handler.Delete(helpers.AppEndpointPath, srv.delAppHandler) // build the http server srv.httpServer = &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server, cfg.ServerPort), @@ -96,8 +84,6 @@ func New(ctx context.Context, db db.DB, cfg *Config) (*Service, error) { func (s *Service) Start() error { // start the email queue s.emailQueue.Start() - // start the token cleaner in the background - s.sanityTokenCleaner() // start the api server if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { return err @@ -109,10 +95,6 @@ func (s *Service) Start() error { // background processes to finish. It closes the database. If something goes // wrong during the process, it returns an error. func (s *Service) Stop() error { - // close the database - if err := s.db.Close(); err != nil { - return fmt.Errorf("error closing db: %w", err) - } // stop the email queue s.emailQueue.Stop() // cancel the context and wait for the background processes finish diff --git a/api/service_test.go b/api/service_test.go index c36056e..ba797b3 100644 --- a/api/service_test.go +++ b/api/service_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/simpleauthlink/authapi/db" "github.com/simpleauthlink/authapi/email" ) @@ -13,9 +12,7 @@ func TestNew(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - testDB := new(db.TempDriver) - testDB.Init(nil) - srv, err := New(ctx, testDB, &Config{ + srv, err := New(ctx, &Config{ Server: "localhost", ServerPort: 8080, CleanerCooldown: 30 * time.Second, diff --git a/api/tokens.go b/api/tokens.go deleted file mode 100644 index 2f0f155..0000000 --- a/api/tokens.go +++ /dev/null @@ -1,179 +0,0 @@ -package api - -import ( - "fmt" - "log" - "net/url" - "strings" - "time" - - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/helpers" -) - -// magicLink function generates and returns a magic link, the generated token -// and the associated app name, based on the provided app secret and the user -// email. If the secret or the email are empty, it returns an error. It gets -// the app id from the database based on the secret. It generates a token and -// calculates the expiration time based on the app session duration. It stores -// the token and the expiration time in the database. It returns the magic link -// composed of the app callback and the generated token. -func (s *Service) magicLink(rawSecret, email, redirectURL string, duration uint64) (string, string, string, error) { - // check if the secret and email are not empty - if len(rawSecret) == 0 || len(email) == 0 { - return "", "", "", fmt.Errorf("secret and email are required") - } - // get app secret from raw secret - appSecret, err := helpers.Hash(rawSecret, helpers.SecretSize) - if err != nil { - return "", "", "", err - } - // get app and app id from the database based on the secret - app, appId, err := s.db.AppBySecret(appSecret) - if err != nil { - return "", "", "", err - } - // get the number of tokens for the app using the app id as the prefix - numberOfAppTokens, err := s.db.CountTokens(appId) - if err != nil { - return "", "", "", err - } - // check if the number of tokens is greater than the users quota - if numberOfAppTokens >= app.UsersQuota { - return "", "", "", fmt.Errorf("users quota reached") - } - // generate token and calculate expiration - token, userId, err := helpers.EncodeUserToken(appId, email) - if err != nil { - return "", "", "", err - } - // by default, the session duration is the app session duration but it can - // be overwritten by the request - sessionDuration := app.SessionDuration - if duration > 0 { - sessionDuration = duration - } - expiration := time.Now().Add(time.Duration(sessionDuration) * time.Second) - // check if there is a token for the user and app in the database and delete - // it if it exists - tokenPrefix := strings.Join([]string{appId, userId}, helpers.TokenSeparator) - if err := s.db.DeleteTokensByPrefix(tokenPrefix); err != nil { - if err != db.ErrTokenNotFound { - log.Println("ERR: error checking token:", err) - } - } - // set token and expiration in the database - if err := s.db.SetToken(db.Token(token), expiration); err != nil { - return "", "", "", err - } - // return the magic link based on the app callback and the generated token - // by default, the redirect URL is the app redirect URL but it can be - // overwritten by the request - baseRawURL := app.RedirectURL - if redirectURL != "" { - baseRawURL = redirectURL - } - baseURL, err := url.Parse(baseRawURL) - if err != nil { - return "", "", "", fmt.Errorf("invalid redirect URL: %w", err) - } - urlQuery := baseURL.Query() - urlQuery.Set(helpers.TokenQueryParam, token) - baseURL.RawQuery = urlQuery.Encode() - return helpers.SafeURL(baseURL), token, app.Name, nil -} - -// validUserToken function checks if the provided token is valid. It checks if -// the token is not empty, if the app id is in the database, if the token is not -// expired and if the token is in the database. If the token is invalid, it -// returns false. If something goes wrong during the process, it logs the error -// and returns false. If the token is valid, it returns true. -func (s *Service) validUserToken(token, rawSecret string) bool { - // check if the token and secret are not empty - if len(token) == 0 || len(rawSecret) == 0 { - return false - } - // get the app id from the token - appId, _, err := helpers.DecodeUserToken(token) - if err != nil { - return false - } - // check if the secret is valid - if !s.validSecret(appId, rawSecret) { - return false - } - // get the token expiration from the database - expiration, err := s.db.TokenExpiration(db.Token(token)) - if err != nil { - return false - } - // check if the token is expired - if time.Now().After(expiration) { - if err := s.db.DeleteToken(db.Token(token)); err != nil { - log.Println("ERR: error deleting token:", err) - } - return false - } - return true -} - -// validAdminToken function checks if the provided token is a valid admin token. -// It checks if the token is not empty, if the app id is in the database, if the -// token is not expired and if the token is in the database. If the token is -// invalid, it returns false. It also returns the app id if the token is valid. -func (s *Service) validAdminToken(token, rawSecret string) (string, bool) { - // check if the token and secret are not empty - if len(token) == 0 || len(rawSecret) == 0 { - return "", false - } - // get the app id from the token - appId, userId, err := helpers.DecodeUserToken(token) - if err != nil { - return "", false - } - // the app id is composed by the admin user id hash and a nonce, so - // the app id starts with the admin user id, check if so - if !strings.HasPrefix(appId, userId) { - return "", false - } - // check if the secret is valid - if !s.validSecret(appId, rawSecret) { - return "", false - } - // get the token expiration from the database - expiration, err := s.db.TokenExpiration(db.Token(token)) - if err != nil { - return "", false - } - // check if the token is expired - if time.Now().After(expiration) { - if err := s.db.DeleteToken(db.Token(token)); err != nil { - log.Println("ERR: error deleting token:", err) - } - return "", false - } - return appId, true -} - -// sanityTokenCleaner function starts a goroutine that cleans the expired tokens -// from the database every time the cooldown time is reached. It uses a ticker -// to check the cooldown time and a context to stop the goroutine when the -// service is stopped. If something goes wrong during the process, it logs the -// error. -func (s *Service) sanityTokenCleaner() { - s.wait.Add(1) - go func() { - defer s.wait.Done() - ticker := time.NewTicker(s.cfg.CleanerCooldown) - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - if err := s.db.DeleteExpiredTokens(); err != nil { - log.Println("ERR: error deleting expired tokens:", err) - } - } - } - }() -} diff --git a/client/client.go b/client/client.go deleted file mode 100644 index 71925b5..0000000 --- a/client/client.go +++ /dev/null @@ -1,136 +0,0 @@ -package client - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - - "github.com/simpleauthlink/authapi/api" - "github.com/simpleauthlink/authapi/helpers" -) - -// Client struct represents the client to interact with the API server. It -// contains the configuration of the client. The configuration includes the -// secret of the app and the API endpoint. The API endpoint is optional and if -// it is empty, it uses the default API endpoint. The client provides two -// methods to interact with the API server, RequestToken and ValidateToken. -type Client struct { - config *ClientConfig -} - -// New function creates a new client based on the provided configuration. It -// returns the client and an error if the configuration is invalid. The -// configuration must include, at least, the secret of your app. If the API -// endpoint is empty, it uses the default API endpoint. It validates the config -// and returns an error if the configuration is nil, the secret is empty or the -// API endpoint is invalid. -func New(config *ClientConfig) (*Client, error) { - if err := config.check(); err != nil { - return nil, err - } - return &Client{config: config}, nil -} - -// RequestToken function requests a token for the user based on the provided -// email. It returns an error if the email is empty. It receives the context -// and the token request. The token request includes the email of the user, the -// redirect URL and the session duration. The session duration is optional and -// if it is zero, it uses the default session duration. It creates a new URL -// based on the API endpoint, encodes the request, creates the request, sets -// the secret in the header, sets the content type and makes the request. It -// checks the status code and returns an error if the status code is different -// from 200, if so returns an error trying to decode the body of the response. -func (cli *Client) RequestToken(ctx context.Context, req *api.TokenRequest) error { - if req == nil || req.Email == "" { - return fmt.Errorf("email is required to request a token") - } - // create a new URL based on the API endpoint - url := new(url.URL) - *url = *cli.config.url - // set the path - url.Path = helpers.UserEndpointPath - // encode the request - encodedReq, err := json.Marshal(req) - if err != nil { - return fmt.Errorf("error encoding request: %w", err) - } - // create the request - buf := bytes.NewBuffer(encodedReq) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), buf) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - // set the secret in the header - httpReq.Header.Set(helpers.AppSecretHeader, cli.config.Secret) - // set the content type - httpReq.Header.Set("Content-Type", "application/json") - // make the request - res, err := http.DefaultClient.Do(httpReq) - if err != nil { - return fmt.Errorf("error making request: %w", err) - } - defer res.Body.Close() - // check the status code and return an error if the status code is different - // from 200, if so return an error trying to decode the body of the response - if res.StatusCode != http.StatusOK { - // decode body and return error - msg, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - return fmt.Errorf("unexpected response: [%d] %s", res.StatusCode, string(msg)) - } - return nil -} - -// ValidateToken function validates the token provided using the API server. It -// returns true if the token is valid, false if the token is invalid, or an -// error if something goes wrong during the process. It receives the context, -// the token and the client configuration. The configuration must include, at -// least, the secret of your app. If the API endpoint is empty, it uses the -// default API endpoint. It validates the config and returns an error if the -// configuration is nil, the secret is empty or the API endpoint is invalid. -func (cli *Client) ValidateToken(ctx context.Context, token string) (bool, error) { - // create a new URL based on the API endpoint - url := new(url.URL) - *url = *cli.config.url - // add token to the query - query := url.Query() - query.Set(helpers.TokenQueryParam, token) - // set the path and query - url.Path = helpers.UserEndpointPath - url.RawQuery = query.Encode() - // create the request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) - if err != nil { - return false, fmt.Errorf("error creating request: %w", err) - } - // set the secret in the header - req.Header.Set(helpers.AppSecretHeader, cli.config.Secret) - // make the request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, fmt.Errorf("error making request: %w", err) - } - defer resp.Body.Close() - // check the status code, return true if the status code is 200 or false if - // the status code is 401, otherwise return an error trying to decode the - // body of the response - switch resp.StatusCode { - case http.StatusOK: - return true, nil - case http.StatusUnauthorized: - return false, nil - default: - // decode body and return error - msg, err := io.ReadAll(resp.Body) - if err != nil { - return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - return false, fmt.Errorf("unexpected response: [%d] %s", resp.StatusCode, string(msg)) - } -} diff --git a/client/config.go b/client/config.go deleted file mode 100644 index fc35fef..0000000 --- a/client/config.go +++ /dev/null @@ -1,40 +0,0 @@ -package client - -import ( - "fmt" - "net/url" - - "github.com/simpleauthlink/authapi/helpers" -) - -// ClientConfig struct represents the configuration needed to use the client. -type ClientConfig struct { - // APIEndpoint is the API hostname. - APIEndpoint string - url *url.URL - // Secret is the app secret on the API server. - Secret string -} - -// check function validates the configuration and returns an error if the -// configuration is invalid. It checks if the configuration is nil, if the -// secret is empty, and if the API endpoint is invalid. If the API endpoint is -// empty, it uses the default API endpoint. It returns an error if the -// configuration is nil, the secret is empty or the API endpoint is invalid. -func (conf *ClientConfig) check() error { - if conf == nil { - return fmt.Errorf("config is required") - } - if conf.APIEndpoint == "" { - conf.APIEndpoint = helpers.DefaultAPIEndpoint - } - if conf.Secret == "" { - return fmt.Errorf("secret is required") - } - var err error - conf.url, err = url.Parse(conf.APIEndpoint) - if err != nil { - return fmt.Errorf("invalid API endpoint: %w", err) - } - return nil -} diff --git a/cmd/authapi/main.go b/cmd/authapi/main.go index e1ca2c9..b16ad1c 100644 --- a/cmd/authapi/main.go +++ b/cmd/authapi/main.go @@ -10,7 +10,6 @@ import ( "time" "github.com/simpleauthlink/authapi/api" - "github.com/simpleauthlink/authapi/db/mongo" "github.com/simpleauthlink/authapi/email" ) @@ -83,16 +82,8 @@ func main() { if err != nil { log.Fatalln("ERR: error parsing config:", err) } - // init the database with mongo driver - db := new(mongo.MongoDriver) - if err := db.Init(mongo.Config{ - MongoURI: c.dbURI, - Database: c.dbName, - }); err != nil { - log.Fatalf("error initializing db: %v", err) - } // create the service - service, err := api.New(context.Background(), db, &api.Config{ + service, err := api.New(context.Background(), &api.Config{ EmailConfig: email.EmailConfig{ Address: c.emailAddr, Password: c.emailPass, diff --git a/db/db.go b/db/db.go deleted file mode 100644 index 9f8cec9..0000000 --- a/db/db.go +++ /dev/null @@ -1,111 +0,0 @@ -package db - -import ( - "fmt" - "time" -) - -var ( - // ErrInvalidConfig error is returned when the provided database - // configuration is missing or invalid. - ErrInvalidConfig = fmt.Errorf("invalid database config") - // ErrOpenConn error is returned when the database connection can't be - // opened with the provided configuration. - ErrOpenConn = fmt.Errorf("error opening database") - // ErrCloseConn error is returned when the database connection can't be - // closed. - ErrCloseConn = fmt.Errorf("error closing database") - // ErrAppNotFound error is returned when the desired app is not found in the - // database. - ErrAppNotFound = fmt.Errorf("app not found") - // ErrGetApp error is returned when something fails getting a app from the - // database. - ErrGetApp = fmt.Errorf("error getting the app from database") - // ErrSetApp error is returned when something fails storing a app in the - // database. - ErrSetApp = fmt.Errorf("error storing the app in database") - // ErrDelApp error is returned when something fails deleting a app from the - // database. - ErrDelApp = fmt.Errorf("error deleting the app from database") - // ErrSecretNotFound error is returned when the desired secret is not found - // in the database. - ErrSetSecret = fmt.Errorf("error storing the secret in database") - // ErrDelSecret error is returned when something fails deleting a secret - // from the database. - ErrDelSecret = fmt.Errorf("error deleting the secret from database") - // ErrTokenNotFound error is returned when the desired token is not found in - // the database. - ErrTokenNotFound = fmt.Errorf("token not found") - // ErrGetToken error is returned when something fails getting a token from - // the database. - ErrGetToken = fmt.Errorf("error getting the token from database") - // ErrSetToken error is returned when something fails storing a token in the - // database. - ErrSetToken = fmt.Errorf("error storing the token in database") - // ErrDelToken error is returned when something fails deleting a token from - // the database. - ErrDelToken = fmt.Errorf("error deleting the token from database") -) - -// App struct represents the application information that is stored in the -// database. -type App struct { - Name string - AdminEmail string - SessionDuration uint64 - RedirectURL string - UsersQuota int64 -} - -// Token type represents the token that is stored in the database. -type Token string - -type DB interface { - // Init method allows to the interface implementation to receive some config - // information and init the database connection. It returns an error if the - // config is invalid or the connection can't be opened. - Init(config any) error - // Close method allows to the interface implementation to close the database - // connection. It returns an error if something fails during the closing. - Close() error - // AppById method gets an app from the database based on the app id. It - // returns the app and an error if something goes wrong. - AppById(appId string) (*App, error) - // AppBySecret method gets an app from the database based on the app secret. - // It returns the app, the app id and an error if something goes wrong. - AppBySecret(secret string) (*App, string, error) - // SetApp method stores an app in the database. It returns an error if - // something goes wrong. - SetApp(appId string, app *App) error - // DeleteApp method deletes an app from the database. It returns an error if - // something goes wrong. - DeleteApp(appId string) error - // ValidSecret method checks if a secret is valid. It returns true if the - // secret is valid and false if it is not. - ValidSecret(secret, appId string) (bool, error) - // SetSecret method stores a secret in the database. It returns an error if - // something goes wrong. - SetSecret(secret, appId string) error - // DeleteSecret method deletes a secret from the database. It returns an - // error if something goes wrong. - DeleteSecret(secret string) error - // TokenExpiration method gets the token expiration from the database. It - // returns the expiration time and an error if something goes wrong. - TokenExpiration(token Token) (time.Time, error) - // SetToken method stores a token in the database with an expiration time. - // It returns an error if something goes wrong. - SetToken(token Token, expiration time.Time) error - // DeleteToken method deletes a token from the database. It returns an error - // if something goes wrong. - DeleteToken(token Token) error - // DeleteTokenByPrefix method deletes all the tokens with the provided - // prefix from the database. It returns an error if something goes wrong. - DeleteTokensByPrefix(prefix string) error - // DeleteExpiredTokens method deletes all the expired tokens from the - // database. It returns an error if something goes wrong. - DeleteExpiredTokens() error - // CountTokens method counts the number of tokens in the database. It allows - // to filter the tokens by the provided prefix. It returns the number of - // tokens and an error if something goes wrong. - CountTokens(prefix string) (int64, error) -} diff --git a/db/mongo/apps.go b/db/mongo/apps.go deleted file mode 100644 index f29ce67..0000000 --- a/db/mongo/apps.go +++ /dev/null @@ -1,149 +0,0 @@ -package mongo - -import ( - "context" - "errors" - "time" - - "github.com/simpleauthlink/authapi/db" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type App struct { - ID string `bson:"_id"` - Name string `bson:"name"` - AdminEmail string `bson:"admin_email"` - SessionDuration uint64 `bson:"session_duration"` - RedirectURL string `bson:"redirect_url"` - UsersQuota int64 `bson:"users_quota"` - Secret string `bson:"secret"` -} - -func (md *MongoDriver) AppById(appId string) (*db.App, error) { - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - // get app from the database based on the app id - var app App - if err := md.apps.FindOne(ctx, bson.M{"_id": appId}).Decode(&app); err != nil { - if err == mongo.ErrNoDocuments { - return nil, db.ErrAppNotFound - } - return nil, errors.Join(db.ErrGetApp, err) - } - // return app - return &db.App{ - Name: app.Name, - AdminEmail: app.AdminEmail, - SessionDuration: app.SessionDuration, - RedirectURL: app.RedirectURL, - UsersQuota: app.UsersQuota, - }, nil -} - -func (md *MongoDriver) AppBySecret(secret string) (*db.App, string, error) { - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - // get app from the database based on the app secret - var app App - if err := md.apps.FindOne(ctx, bson.M{"secret": secret}).Decode(&app); err != nil { - if err == mongo.ErrNoDocuments { - return nil, "", db.ErrAppNotFound - } - return nil, "", errors.Join(db.ErrGetApp, err) - } - // return app and app id - return &db.App{ - Name: app.Name, - AdminEmail: app.AdminEmail, - SessionDuration: app.SessionDuration, - RedirectURL: app.RedirectURL, - UsersQuota: app.UsersQuota, - }, app.ID, nil -} - -func (md *MongoDriver) SetApp(appId string, app *db.App) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // create or update app in the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - dbApp, err := dynamicUpdateDocument(App{ - ID: appId, - Name: app.Name, - AdminEmail: app.AdminEmail, - SessionDuration: app.SessionDuration, - RedirectURL: app.RedirectURL, - UsersQuota: app.UsersQuota, - }, nil) - if err != nil { - return errors.Join(db.ErrSetApp, err) - } - opts := options.Update().SetUpsert(true) - if _, err := md.apps.UpdateOne(ctx, bson.M{"_id": appId}, dbApp, opts); err != nil { - return errors.Join(db.ErrSetApp, err) - } - return nil -} - -func (md *MongoDriver) DeleteApp(appId string) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete secret from the database by the app id - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.apps.DeleteOne(ctx, bson.M{"_id": appId}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrAppNotFound - } - return errors.Join(db.ErrDelApp, err) - } - return nil -} - -func (md *MongoDriver) ValidSecret(secret, appId string) (bool, error) { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // get app from the database based on the app id - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - var app App - if err := md.apps.FindOne(ctx, bson.M{"_id": appId}).Decode(&app); err != nil { - if err == mongo.ErrNoDocuments { - return false, db.ErrAppNotFound - } - return false, errors.Join(db.ErrGetApp, err) - } - return app.Secret == secret, nil -} - -func (md *MongoDriver) SetSecret(secret, appId string) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // set secret to app in the database by the app id - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.apps.UpdateOne(ctx, bson.M{"_id": appId}, bson.M{"$set": bson.M{"secret": secret}}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrAppNotFound - } - return errors.Join(db.ErrSetSecret, err) - } - return nil -} - -func (md *MongoDriver) DeleteSecret(secret string) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete secret of the app from the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.apps.UpdateOne(ctx, bson.M{"secret": secret}, bson.M{"$unset": bson.M{"secret": ""}}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrAppNotFound - } - return errors.Join(db.ErrDelSecret, err) - } - return nil -} diff --git a/db/mongo/mongo.go b/db/mongo/mongo.go deleted file mode 100644 index 2e8c183..0000000 --- a/db/mongo/mongo.go +++ /dev/null @@ -1,161 +0,0 @@ -package mongo - -import ( - "context" - "errors" - "fmt" - "reflect" - "sync" - "time" - - "github.com/simpleauthlink/authapi/db" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/mongo/readpref" -) - -const ( - tokensCollection = "tokens" - secretsCollection = "secrets" - appsCollection = "apps" -) - -type Config struct { - MongoURI string - Database string -} - -type MongoDriver struct { - ctx context.Context - cancel context.CancelFunc - config Config - client *mongo.Client - keysLock sync.RWMutex - - tokens *mongo.Collection - apps *mongo.Collection -} - -func (md *MongoDriver) Init(config any) error { - // validate config - cfg, ok := config.(Config) - if !ok { - return db.ErrInvalidConfig - } - if cfg.Database == "" { - return fmt.Errorf("%w: no database name provided", db.ErrInvalidConfig) - } - if cfg.MongoURI == "" { - return fmt.Errorf("%w: no database url provided", db.ErrInvalidConfig) - } - // init the client options - opts := options.Client() - opts.ApplyURI(cfg.MongoURI) - opts.SetMaxConnecting(200) - timeout := time.Second * 10 - opts.ConnectTimeout = &timeout - // connect to the database - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - client, err := mongo.Connect(ctx, opts) - if err != nil { - return errors.Join(db.ErrOpenConn, err) - } - // check if the connection is available - ctx, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel2() - if err := client.Ping(ctx, readpref.Primary()); err != nil { - return errors.Join(db.ErrOpenConn, err) - } - // create the internal context - md.ctx, md.cancel = context.WithCancel(context.Background()) - // set the client and config - md.client = client - md.config = cfg - // instantiate the collections - md.tokens = client.Database(cfg.Database).Collection(tokensCollection) - md.apps = client.Database(cfg.Database).Collection(appsCollection) - // create the indexes - if err := md.createIndexes(); err != nil { - return errors.Join(db.ErrOpenConn, err) - } - return nil -} - -func (md *MongoDriver) Close() error { - md.cancel() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := md.client.Disconnect(ctx); err != nil { - return errors.Join(db.ErrCloseConn, err) - } - return nil -} - -// createIndexes creates the indexes for the collections. It creates an index -// for the app secrets and an index for the token expiration. It returns an -// error if something goes wrong. -func (md *MongoDriver) createIndexes() error { - ctx, cancel := context.WithTimeout(md.ctx, 20*time.Second) - defer cancel() - // create an index for app secrets - if _, err := md.apps.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "secrets", Value: 1}}, // 1 for ascending order - Options: nil, - }); err != nil { - return err - } - // create an index for token expiration - if _, err := md.tokens.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "expiration", Value: 1}}, - Options: nil, - }); err != nil { - return err - } - return nil -} - -// dynamicUpdateDocument creates a BSON update document from a struct, -// including only non-zero fields. It uses reflection to iterate over the -// struct fields and create the update document. The struct fields must have -// a bson tag to be included in the update document. The _id field is skipped. -func dynamicUpdateDocument(item interface{}, alwaysUpdate []string) (bson.M, error) { - // check if the input is a pointer to a struct - val := reflect.ValueOf(item) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - // check if the input is a struct - if !val.IsValid() || val.Kind() != reflect.Struct { - return nil, fmt.Errorf("input must be a valid struct") - } - update := bson.M{} - typ := val.Type() - // create a map for quick lookup of always update fields - alwaysUpdateMap := make(map[string]bool, len(alwaysUpdate)) - for _, tag := range alwaysUpdate { - alwaysUpdateMap[tag] = true - } - // iterate over the struct fields - for i := 0; i < val.NumField(); i++ { - // check if the field can be accessed - field := val.Field(i) - if !field.CanInterface() { - continue - } - // get the field bson tag and type - fieldType := typ.Field(i) - tag := fieldType.Tag.Get("bson") - // skip the field if the tag is empty, "-" or "_id" - if tag == "" || tag == "-" || tag == "_id" { - continue - } - // check if the field should always be updated or is not the zero value - _, alwaysUpdate := alwaysUpdateMap[tag] - if alwaysUpdate || !reflect.DeepEqual(field.Interface(), reflect.Zero(field.Type()).Interface()) { - update[tag] = field.Interface() - } - } - return bson.M{"$set": update}, nil -} diff --git a/db/mongo/tokens.go b/db/mongo/tokens.go deleted file mode 100644 index 1afe378..0000000 --- a/db/mongo/tokens.go +++ /dev/null @@ -1,113 +0,0 @@ -package mongo - -import ( - "context" - "errors" - "time" - - "github.com/simpleauthlink/authapi/db" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type Token struct { - Token db.Token `bson:"_id"` - Expiration int64 `bson:"expiration"` -} - -func (md *MongoDriver) TokenExpiration(token db.Token) (time.Time, error) { - var dbToken Token - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if err := md.tokens.FindOne(ctx, bson.M{"_id": token}).Decode(&dbToken); err != nil { - if err == mongo.ErrNoDocuments { - return time.Time{}, db.ErrTokenNotFound - } - return time.Time{}, errors.Join(db.ErrGetToken, err) - } - return time.Unix(0, dbToken.Expiration), nil -} - -func (md *MongoDriver) SetToken(token db.Token, expiration time.Time) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // set token in the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - dbToken := Token{ - Token: token, - Expiration: expiration.UnixNano(), - } - opts := options.Replace().SetUpsert(true) - if _, err := md.tokens.ReplaceOne(ctx, bson.M{"_id": token}, dbToken, opts); err != nil { - return errors.Join(db.ErrSetToken, err) - } - return nil -} - -func (md *MongoDriver) DeleteToken(token db.Token) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete token from the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.tokens.DeleteOne(ctx, bson.M{"_id": token}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrTokenNotFound - } - return errors.Join(db.ErrDelToken, err) - } - return nil -} - -func (md *MongoDriver) DeleteTokensByPrefix(prefix string) error { - // check if the prefix is empty and return nil if it is - if prefix == "" { - return nil - } - // check if there is a token with the provided prefix in the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.tokens.DeleteMany(ctx, bson.M{"_id": bson.M{"$regex": "^" + prefix}}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrTokenNotFound - } - return errors.Join(db.ErrGetToken, err) - } - return nil -} - -func (md *MongoDriver) DeleteExpiredTokens() error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete expired tokens from the database, filter by expiration time less - // than now - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - dbNow := time.Now().UnixNano() - if _, err := md.tokens.DeleteMany(ctx, bson.M{"expiration": bson.M{"$lt": dbNow}}); err != nil { - return errors.Join(db.ErrDelToken, err) - } - return nil -} - -func (md *MongoDriver) CountTokens(prefix string) (int64, error) { - // count the number of tokens in the database, filter by the provided prefix - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - // filter by prefix if provided - filter := bson.M{} - if prefix != "" { - filter = bson.M{"_id": bson.M{"$regex": "^" + prefix}} - } - // count the number of tokens and return the result - count, err := md.tokens.CountDocuments(ctx, filter) - if err != nil { - if err == mongo.ErrNoDocuments { - return 0, db.ErrTokenNotFound - } - return 0, errors.Join(db.ErrGetToken, err) - } - return count, nil -} diff --git a/db/temp.go b/db/temp.go deleted file mode 100644 index 760b4a8..0000000 --- a/db/temp.go +++ /dev/null @@ -1,152 +0,0 @@ -package db - -import ( - "strings" - "sync" - "time" -) - -type TempDriver struct { - apps map[string]App - secretToApp map[string]string - tokens map[Token]int64 - lock sync.RWMutex -} - -func (tdb *TempDriver) Init(_ any) error { - tdb.apps = make(map[string]App) - tdb.secretToApp = make(map[string]string) - tdb.tokens = make(map[Token]int64) - return nil -} - -func (tdb *TempDriver) Close() error { - return nil -} - -func (tdb *TempDriver) AppById(appId string) (*App, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - app, ok := tdb.apps[appId] - if !ok { - return nil, ErrAppNotFound - } - return &app, nil -} - -func (tdb *TempDriver) AppBySecret(secret string) (*App, string, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - appId, ok := tdb.secretToApp[secret] - if !ok { - return nil, "", ErrAppNotFound - } - app, ok := tdb.apps[appId] - if !ok { - return nil, "", ErrAppNotFound - } - return &app, appId, nil -} - -func (tdb *TempDriver) SetApp(appId string, app *App) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - tdb.apps[appId] = *app - return nil -} - -func (tdb *TempDriver) DeleteApp(appId string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - delete(tdb.apps, appId) - return nil -} - -func (tdb *TempDriver) ValidSecret(secret, appId string) (bool, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - appIdFromSecret, ok := tdb.secretToApp[secret] - if !ok { - return false, nil - } - return appIdFromSecret == appId, nil -} - -func (tdb *TempDriver) SetSecret(secret, appId string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - tdb.secretToApp[secret] = appId - return nil -} - -func (tdb *TempDriver) DeleteSecret(secret string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - delete(tdb.secretToApp, secret) - return nil -} - -func (tdb *TempDriver) TokenExpiration(token Token) (time.Time, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - exp, ok := tdb.tokens[token] - if !ok { - return time.Time{}, ErrTokenNotFound - } - return time.Unix(0, exp), nil -} - -func (tdb *TempDriver) SetToken(token Token, expiration time.Time) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - tdb.tokens[token] = expiration.UnixNano() - return nil -} - -func (tdb *TempDriver) DeleteToken(token Token) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - delete(tdb.tokens, token) - return nil -} - -func (tdb *TempDriver) DeleteTokensByPrefix(prefix string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - if prefix == "" { - return nil - } - for token := range tdb.tokens { - if strings.HasPrefix(string(token), prefix) { - delete(tdb.tokens, token) - } - } - return nil -} - -func (tdb *TempDriver) DeleteExpiredTokens() error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - now := time.Now().UnixNano() - for token, expiration := range tdb.tokens { - if now > expiration { - delete(tdb.tokens, token) - } - } - return nil -} - -func (tdb *TempDriver) CountTokens(prefix string) (int64, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - if prefix == "" { - return int64(len(tdb.tokens)), nil - } - var count int64 - for token := range tdb.tokens { - if strings.HasPrefix(string(token), prefix) { - count++ - } - } - return count, nil -} diff --git a/helpers/consts.go b/helpers/consts.go deleted file mode 100644 index a3c1223..0000000 --- a/helpers/consts.go +++ /dev/null @@ -1,50 +0,0 @@ -package helpers - -const ( - // TokenSeparator constant is the separator used to split the token into - // parts. It is a string with a value of "-". - TokenSeparator = "-" - // TokenQueryParam constant is the query parameter used to send the token in - // the request. It is a string with a value of "token". - TokenQueryParam = "token" - // AppSecretHeader constant is the header used to send the app secret in the - // request. It is a string with a value of "APP_SECRET". - AppSecretHeader = "APP_SECRET" - // DefaultAPIEndpoint constant is the default API endpoint used by the - // client. It is a string with a value of "https://api.simpleauth.link/". - DefaultAPIEndpoint = "https://api.simpleauth.link/" - // HealthCheckPath constant is the path used to check the health of the API - // server. It is a string with a value of "/health". - HealthCheckPath = "/health" - // AppEndpointPath constant is the path used to API endpoints related to - // apps. It is a string with a value of "/app". - AppEndpointPath = "/app" - // UserEndpointPath constant is the path used to API endpoints related to - // users. It is a string with a value of "/user". - UserEndpointPath = "/user" - // MinTokenDuration constant is the minimum duration allowed for a token to - // be valid, which is an integer with a value of 60 (seconds). - MinTokenDuration = 60 // seconds - // defaultUsersQuota constant is the default number of users allowed for an - // app, which is an integer with a value of 100. - DefaultUsersQuota = 100 // users - // UserIdSize constant is the size of the user id, which is an integer with a - // value of 4 (bytes). - UserIdSize = 4 - // AppIdSize constant is the size of the app id, which is an integer with a - // value of 8 (bytes). - AppIdSize = 8 - // EmailHashSize constant is the size of the email hash, which is an integer - // with a value of 4 (bytes). The email hash is used to generate the user id - // and the app id. - EmailHashSize = 4 - // AppNonceSize constant is the size of the app nonce, which is an integer - // with a value of 4 (bytes). The app nonce is used to generate the app id. - AppNonceSize = 4 - // SecretSize constant is the size of the secret, which is an integer with a - // value of 16 (bytes). - SecretSize = 16 - // TokenSize constant is the size of the token, which is an integer with a - // value of 8 (bytes). - TokenSize = 8 -) diff --git a/helpers/helpers.go b/helpers/helpers.go deleted file mode 100644 index ac0d67e..0000000 --- a/helpers/helpers.go +++ /dev/null @@ -1,110 +0,0 @@ -package helpers - -import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "math/rand" - "net/url" - "strings" -) - -// EncodeUserToken function encodes the user information into a token and -// returns it. It receives the app id and the email of the user and returns the -// token and the user id. If the app id or the email are empty, it returns an -// error. The token is composed of three parts separated by a token separator. -// The first part is a random sequence of 8 bytes encoded as a hexadecimal -// string. The second part is the app id and the third part is the user id. The -// user id is generated hashing the email with a length of 4 bytes. The token -// is returned following the token format: -// -// [appId(8)]-[userId(8)]-[randomPart(16)] -func EncodeUserToken(appId, email string) (string, string, error) { - // check if the app id and email are not empty - if len(appId) == 0 || len(email) == 0 { - return "", "", fmt.Errorf("appId and email are required") - } - bToken := RandBytes(TokenSize) - hexToken := hex.EncodeToString(bToken) - // hash email - userId, err := Hash(email, UserIdSize) - if err != nil { - return "", "", err - } - return strings.Join([]string{appId, userId, hexToken}, TokenSeparator), userId, nil -} - -// DecodeUserToken function decodes the user information from the token provided -// and returns the app id and the user id. If the token is invalid, it returns -// an error. It splits the provided token by the token separator and returns the -// second and third parts, which are the app id and the user id respectively, -// following the token format: -// -// [appId(8)]-[userId(8)]-[randomPart(16)] -func DecodeUserToken(token string) (string, string, error) { - tokenParts := strings.Split(token, TokenSeparator) - if len(tokenParts) != 3 { - return "", "", fmt.Errorf("invalid token") - } - return tokenParts[0], tokenParts[1], nil -} - -// RandBytes generates a random byte slice of length n. It returns nil if n is -// less than 1. -func RandBytes(n int) []byte { - if n < 1 { - return nil - } - b := make([]byte, n) - for i := 0; i < n; { - val := rand.Uint64() - for j := 0; j < 8 && i < n; j++ { - b[i] = byte(val & 0xff) - val >>= 8 - i++ - } - } - return b -} - -// Hash generates a hash of the input string using SHA-256 algorithm. The n -// parameter allows to truncate the hash to n bytes. It returns the hash as a -// hexadecimal string. The resulting string will have a length of 2*n. If n is -// less than 1 or greater than the hash length, the full hash will be returned. -// If the input string is empty, it returns an empty string. If something fails -// during the hashing process, it returns an error. -func Hash(input string, n int) (string, error) { - if input == "" { - return "", nil - } - hash := sha256.New() - if _, err := hash.Write([]byte(input)); err != nil { - return "", err - } - bHash := hash.Sum(nil) - if n > 0 && n < len(bHash) { - bHash = bHash[:n] - } - return hex.EncodeToString(bHash), nil -} - -// SafeURL function returns a safe URL string from the provided URL. It returns -// an empty string if the URL is nil. The resulting string will have the format: -// scheme://host/path#fragment?query. If the URL has no path, query or fragment, -// they will be omitted. The query parameters will be encoded. -func SafeURL(url *url.URL) string { - if url == nil { - return "" - } - strURL := fmt.Sprintf("%s://%s", url.Scheme, url.Host) - if url.Path != "" { - strURL += url.Path - } - if url.Fragment != "" { - strURL += fmt.Sprintf("#%s", url.Fragment) - } - if encoded := url.Query().Encode(); encoded != "" { - strURL += fmt.Sprintf("?%s", encoded) - } - return strURL -} From 81c8a037d5389aefa672d68268acdfadfb6d06f9 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 11 Feb 2025 21:02:46 +0100 Subject: [PATCH 02/36] new primitives for new implementation with tests: app, expiration and id --- token/app.go | 95 +++++++++++++++++++++ token/app_test.go | 180 +++++++++++++++++++++++++++++++++++++++ token/consts.go | 18 ++++ token/errors.go | 10 +++ token/expiration.go | 58 +++++++++++++ token/expiration_test.go | 61 +++++++++++++ token/id.go | 57 +++++++++++++ token/id_test.go | 84 ++++++++++++++++++ 8 files changed, 563 insertions(+) create mode 100644 token/app.go create mode 100644 token/app_test.go create mode 100644 token/consts.go create mode 100644 token/errors.go create mode 100644 token/expiration.go create mode 100644 token/expiration_test.go create mode 100644 token/id.go create mode 100644 token/id_test.go diff --git a/token/app.go b/token/app.go new file mode 100644 index 0000000..8b33898 --- /dev/null +++ b/token/app.go @@ -0,0 +1,95 @@ +package token + +import ( + "encoding/base64" + "strings" + "time" +) + +type App struct { + Name string + RedirectURI string + SessionDuration time.Duration +} + +func (app *App) Valid() bool { + if app == nil { + return false + } + if len(app.Name) < appNameMinLen || len(app.Name) > appNameMaxLen { + return false + } + if !uriRegexp.MatchString(app.RedirectURI) || len(app.RedirectURI) > redirectURIMaxLen { + return false + } + if app.SessionDuration < minDuration || app.SessionDuration > maxDuration { + return false + } + return true +} + +func (app *App) Attributes() []string { + return []string{app.Name, app.RedirectURI, app.SessionDuration.String()} +} + +func (app *App) SetAttributes(attrs []string) *App { + if len(attrs) != 3 { + return nil + } + duration, err := time.ParseDuration(attrs[2]) + if err != nil { + return nil + } + app.Name = attrs[0] + app.RedirectURI = attrs[1] + app.SessionDuration = duration + if !app.Valid() { + return nil + } + return app +} + +func (app *App) String() string { + if !app.Valid() { + return "" + } + return strings.Join(app.Attributes(), appDataSeparator) +} + +func (app *App) SetString(data string) *App { + b := strings.Split(data, appDataSeparator) + return app.SetAttributes(b) +} + +func (app *App) Bytes() []byte { + return []byte(app.String()) +} + +func (app *App) SetBytes(data []byte) *App { + return app.SetString(string(data)) +} + +func (app *App) Marshal() []byte { + if !app.Valid() { + return nil + } + bApp := app.Bytes() + b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bApp))) + base64.RawStdEncoding.Encode(b, bApp) + return b +} + +func (app *App) Unmarshal(data []byte) *App { + b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) + if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { + return nil + } + return app.SetBytes(b) +} + +func (app *App) ID() *AppID { + if !app.Valid() { + return nil + } + return new(AppID).SetBytes(app.Marshal()) +} diff --git a/token/app_test.go b/token/app_test.go new file mode 100644 index 0000000..f44976f --- /dev/null +++ b/token/app_test.go @@ -0,0 +1,180 @@ +package token + +import ( + "bytes" + "testing" + "time" +) + +const ( + testAppName = "MySuperMegaApp" + testRedirectURI = "https://example.com/login?app=MySuperMegaApp" + testSessionDuration = time.Minute * 30 +) + +func TestValidApp(t *testing.T) { + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + if !app.Valid() { + t.Errorf("expected valid app data") + } + // test app name + app.Name = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + if app.Valid() { + t.Errorf("expected invalid app data") + } + app.Name = "no" + if app.Valid() { + t.Errorf("expected invalid app data") + } + app.Name = testAppName + // test redirect URI + app.RedirectURI = "https://example.com/login?app=lorem_ipsum_dolor_sit_amet_consectetur_adipiscing_elit_sed_do_eiusmod_tempor_incididunt_ut_labore_et_dolore_magna_aliqua" + if app.Valid() { + t.Errorf("expected invalid app data") + } + app.RedirectURI = "no_url" + if app.Valid() { + t.Errorf("expected invalid app data") + } + app.RedirectURI = testRedirectURI + // test session duration + app.SessionDuration = minDuration - 1 + if app.Valid() { + t.Errorf("expected invalid app data") + } + app.SessionDuration = maxDuration + 1 + if app.Valid() { + t.Errorf("expected invalid app data") + } +} + +func TestAttributesSetAttributesApp(t *testing.T) { + if res := new(App).SetAttributes([]string{}); res != nil { + t.Errorf("expected nil, got %v", res) + } + + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, "no_duration"}); res != nil { + t.Errorf("expected nil, got %v", res) + } + + if res := new(App).SetAttributes([]string{testAppName, "no_url", testSessionDuration.String()}); res != nil { + t.Errorf("expected nil, got %v", res) + } + + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + data := new(App).SetAttributes(app.Attributes()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestStringSetStringApp(t *testing.T) { + if res := new(App).String(); res != "" { + t.Errorf("expected empty string, got %q", res) + } + + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + data := new(App).SetString(app.String()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestBytesSetBytesApp(t *testing.T) { + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + data := new(App).SetBytes(app.Bytes()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestMarshalUnmarshalApp(t *testing.T) { + if res := new(App).Marshal(); res != nil { + t.Errorf("expected nil, got %v", res) + } + + if res := new(App).Unmarshal([]byte{1}); res != nil { + t.Errorf("expected nil, got %v", res) + } + + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + data := new(App).Unmarshal(app.Marshal()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestAppID(t *testing.T) { + if id := new(App).ID(); id != nil { + t.Errorf("expected nil, got %v", id) + } + + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + id := app.ID() + if id == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(id.Bytes(), app.Marshal()) { + t.Errorf("expected %v, got %v", app.Marshal(), id.Bytes()) + } +} diff --git a/token/consts.go b/token/consts.go new file mode 100644 index 0000000..acc11f4 --- /dev/null +++ b/token/consts.go @@ -0,0 +1,18 @@ +package token + +import ( + "regexp" + "time" +) + +const ( + appDataSeparator = "|" + appNameMinLen = 3 + appNameMaxLen = 20 + redirectURIPattern = `^https?://[a-zA-Z0-9-]+(\.[a-zA-Z0-9-]+)+(/[a-zA-Z0-9-._~:/?#[\]@!$&'()*+,;=]*)?$` + redirectURIMaxLen = 80 + minDuration = 5 * time.Minute + maxDuration = 180 * 24 * time.Hour +) + +var uriRegexp = regexp.MustCompile(redirectURIPattern) diff --git a/token/errors.go b/token/errors.go new file mode 100644 index 0000000..287a594 --- /dev/null +++ b/token/errors.go @@ -0,0 +1,10 @@ +package token + +import "fmt" + +var ( + ErrInvalidAppID = fmt.Errorf("invalid app ID") + ErrInvalidAppName = fmt.Errorf("invalid app name") + ErrInvalidRedirectURI = fmt.Errorf("invalid redirect URI") + ErrInvalidSessionDuration = fmt.Errorf("invalid session duration") +) diff --git a/token/expiration.go b/token/expiration.go new file mode 100644 index 0000000..ddb690c --- /dev/null +++ b/token/expiration.go @@ -0,0 +1,58 @@ +package token + +import ( + "encoding/base64" + "time" +) + +type Expiration time.Time + +func NewExpiration(d time.Duration) Expiration { + return Expiration(time.Now().Add(d)) +} + +func (exp *Expiration) String() string { + t := time.Time(*exp) + if t.IsZero() { + return "" + } + return t.Format(time.RFC3339Nano) +} + +func (exp *Expiration) SetString(data string) *Expiration { + t, err := time.Parse(time.RFC3339Nano, data) + if err != nil { + return nil + } + *exp = Expiration(t) + return exp +} + +func (exp *Expiration) Bytes() []byte { + if exp.String() == "" { + return nil + } + return []byte(exp.String()) +} + +func (exp *Expiration) SetBytes(data []byte) *Expiration { + return exp.SetString(string(data)) +} + +func (exp *Expiration) Marshal() []byte { + bExp := exp.Bytes() + if len(bExp) == 0 || bExp[0] == 0 { + return nil + } + b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bExp))) + base64.RawStdEncoding.Encode(b, bExp) + return b +} + +func (exp *Expiration) Unmarshal(data []byte) *Expiration { + b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) + if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { + return nil + } + return exp.SetBytes(b) +} diff --git a/token/expiration_test.go b/token/expiration_test.go new file mode 100644 index 0000000..348a0af --- /dev/null +++ b/token/expiration_test.go @@ -0,0 +1,61 @@ +package token + +import ( + "bytes" + "testing" + "time" +) + +func TestStringSetStringExpiration(t *testing.T) { + exp := NewExpiration(time.Second) + str := exp.String() + decoded := new(Expiration).SetString(str) + if decoded == nil { + t.Fatalf("expected valid expiration, got nil") + } + if exp.String() != decoded.String() { + t.Errorf("expected %v, got %v", exp, decoded) + } + if exp := new(Expiration).SetString("invalid"); exp != nil { + t.Errorf("expected nil, got %v", exp) + } + if exp := new(Expiration).String(); exp != "" { + t.Errorf("expected empty string, got %v", exp) + } +} + +func TestBytesSetBytesExpiration(t *testing.T) { + exp := NewExpiration(time.Second) + b := exp.Bytes() + decoded := new(Expiration).SetBytes(b) + if decoded == nil { + t.Fatalf("expected valid expiration, got nil") + } + if !bytes.Equal(b, decoded.Bytes()) { + t.Errorf("expected %v, got %v", exp, decoded) + } + if exp := new(Expiration).SetBytes([]byte("invalid")); exp != nil { + t.Errorf("expected nil, got %v", exp) + } + if exp := new(Expiration).Bytes(); exp != nil { + t.Errorf("expected nil, got %v", exp) + } +} + +func TestMarshalUnmarshalExpiration(t *testing.T) { + exp := NewExpiration(time.Second) + encoded := exp.Marshal() + decoded := new(Expiration).Unmarshal(encoded) + if decoded == nil { + t.Fatalf("expected valid expiration, got nil") + } + if exp.String() != decoded.String() { + t.Errorf("expected %v, got %v", exp, decoded) + } + if res := new(Expiration).Marshal(); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(Expiration).Unmarshal([]byte{1}); res != nil { + t.Errorf("expected nil, got %v", res) + } +} diff --git a/token/id.go b/token/id.go new file mode 100644 index 0000000..133db3c --- /dev/null +++ b/token/id.go @@ -0,0 +1,57 @@ +package token + +import ( + "crypto/ed25519" + "crypto/sha256" +) + +type AppID string + +func (id *AppID) String() string { + return string(*id) +} + +func (id *AppID) SetString(data string) *AppID { + newID := AppID(data) + if !new(App).Unmarshal(newID.Bytes()).Valid() { + return nil + } + *id = newID + return id +} + +func (id *AppID) Bytes() []byte { + return []byte(*id) +} + +func (id *AppID) SetBytes(data []byte) *AppID { + return id.SetString(string(data)) +} + +func (id *AppID) PrivKey() ed25519.PrivateKey { + bID := id.Bytes() + if len(bID) == 0 { + return nil + } + hID := sha256.Sum256(bID) + return ed25519.NewKeyFromSeed(hID[:]) +} + +func (id *AppID) Sign(data []byte) []byte { + privKey := id.PrivKey() + if len(privKey) == 0 { + return nil + } + hData := sha256.Sum256(data) + return ed25519.Sign(privKey, hData[:]) +} + +func (id *AppID) Verify(data, sig []byte) bool { + privKey := id.PrivKey() + if privKey == nil { + return false + } + pubKey := privKey.Public().(ed25519.PublicKey) + hData := sha256.Sum256(data) + return ed25519.Verify(pubKey, hData[:], sig) +} diff --git a/token/id_test.go b/token/id_test.go new file mode 100644 index 0000000..892d74f --- /dev/null +++ b/token/id_test.go @@ -0,0 +1,84 @@ +package token + +import ( + "bytes" + "testing" +) + +func TestStringSetStringAppID(t *testing.T) { + if id := new(AppID).SetString("testID"); id != nil { + t.Errorf("expected nil, got %v", id) + } + + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + id := app.ID() + if id == nil { + t.Fatalf("error decoding app ID") + } + if id.String() != string(app.Marshal()) { + t.Errorf("expected %s, got %s", string(app.Marshal()), id.String()) + } + newID := new(AppID).SetString(string(app.Marshal())) + if newID == nil { + t.Fatalf("error decoding app ID") + } + if newID.String() != id.String() { + t.Errorf("expected %s, got %s", id.String(), newID.String()) + } +} + +func TestBytesSetBytesAppID(t *testing.T) { + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + id := app.ID() + if id == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(id.Bytes(), app.Marshal()) { + t.Errorf("expected %v, got %v", app.Marshal(), id.Bytes()) + } + newID := new(AppID).SetBytes(app.Marshal()) + if newID == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(newID.Bytes(), id.Bytes()) { + t.Errorf("expected %v, got %v", id.Bytes(), newID.Bytes()) + } +} + +func TestPrivKeySignVerifyAppID(t *testing.T) { + if privKey := new(AppID).PrivKey(); privKey != nil { + t.Errorf("expected nil, got %v", privKey) + } + if sig := new(AppID).Sign([]byte("test data")); sig != nil { + t.Errorf("expected nil, got %v", sig) + } + if new(AppID).Verify([]byte("test data"), []byte("test sig")) { + t.Errorf("expected signature to be invalid") + } + + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + id := app.ID() + if id == nil { + t.Fatalf("error decoding app ID") + } + data := []byte("test data") + sig := id.Sign(data) + if sig == nil { + t.Fatalf("error signing data") + } + if !id.Verify(data, sig) { + t.Errorf("expected signature to be valid") + } +} From 1ec6447f9049af00a7bb087ed6c785d2851e0f37 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 11 Feb 2025 21:05:10 +0100 Subject: [PATCH 03/36] dependencies updated --- go.mod | 19 ++--------------- go.sum | 64 ++-------------------------------------------------------- 2 files changed, 4 insertions(+), 79 deletions(-) diff --git a/go.mod b/go.mod index 973a808..fc0b9ff 100644 --- a/go.mod +++ b/go.mod @@ -2,21 +2,6 @@ module github.com/simpleauthlink/authapi go 1.21 -require ( - github.com/lucasmenendez/apihandler v0.0.7 - go.mongodb.org/mongo-driver v1.15.0 -) +require github.com/lucasmenendez/apihandler v0.0.7 -require ( - github.com/golang/snappy v0.0.1 // indirect - github.com/klauspost/compress v1.13.6 // indirect - github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect - github.com/xdg-go/pbkdf2 v1.0.0 // indirect - github.com/xdg-go/scram v1.1.2 // indirect - github.com/xdg-go/stringprep v1.0.4 // indirect - github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/time v0.6.0 // indirect -) +require golang.org/x/time v0.10.0 // indirect diff --git a/go.sum b/go.sum index afad0e8..09e5a75 100644 --- a/go.sum +++ b/go.sum @@ -1,64 +1,4 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= -github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/lucasmenendez/apihandler v0.0.4 h1:QspySW+hZp45HsLur2VcJQ/EcaRzll6XhOUriPRrHYs= -github.com/lucasmenendez/apihandler v0.0.4/go.mod h1:1R2dcf/Wbr6sx7Gjjv5oWWKgD8Pokib/d5BuCwhaBcA= -github.com/lucasmenendez/apihandler v0.0.5-0.20240520102504-ffd40a81622e h1:NpAmyHWwUwxtYGFiQRfHA1+BGnv05rQ9m/p9YK41yC0= -github.com/lucasmenendez/apihandler v0.0.5-0.20240520102504-ffd40a81622e/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= -github.com/lucasmenendez/apihandler v0.0.6 h1:og9FRFIiPwLAyLbwFS3IvwlewD6/woqlau+1PvISvRY= -github.com/lucasmenendez/apihandler v0.0.6/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= github.com/lucasmenendez/apihandler v0.0.7 h1:OItUaGN5J+KrYFLZnQUNHXnOBP6HZyvlobyk1Jd7JkI= github.com/lucasmenendez/apihandler v0.0.7/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= -github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= -github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= -github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= -github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= -github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mongodb.org/mongo-driver v1.15.0 h1:rJCKC8eEliewXjZGf0ddURtl7tTVy1TK3bfl0gkUSLc= -go.mongodb.org/mongo-driver v1.15.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= From 2871a853d6f3513c717d167e51835974dcbaa173 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 11 Feb 2025 21:22:21 +0100 Subject: [PATCH 04/36] workflow for testing on github actions --- .github/workflows/main.yml | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 .github/workflows/main.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..7109b3b --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,50 @@ +name: Build and Test + +on: + push: + branches: + - main + pull_request: + +jobs: + job_go_checks: + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 1 + - name: Set up Go environment + uses: actions/setup-go@v5 + with: + go-version: "1.23" + - name: Tidy go module + run: | + go mod tidy + if [[ $(git status --porcelain) ]]; then + git diff + echo + echo "go mod tidy made these changes, please run 'go mod tidy' and include those changes in a commit" + exit 1 + fi + - name: Run gofumpt + run: diff -u <(echo -n) <(go run mvdan.cc/gofumpt@@latest -d .) + - name: Run go vet + run: go vet ./... + + job_go_test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 2 + - name: Set up Go environment + uses: actions/setup-go@v5 + with: + go-version: "1.23" + - name: Run Go test -race + run: go test ./... -race -timeout=1h From aa2f41e1a7d272d11045808c513c8c65e9e40ac8 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 11 Feb 2025 21:25:40 +0100 Subject: [PATCH 05/36] old test fixed --- api/service_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/service_test.go b/api/service_test.go index ba797b3..b065eec 100644 --- a/api/service_test.go +++ b/api/service_test.go @@ -19,8 +19,8 @@ func TestNew(t *testing.T) { EmailConfig: email.EmailConfig{ EmailHost: "smtp.gmail.com", EmailPort: 587, - Address: "", - Password: "", + Address: "test@email.com", + Password: "password1234", }, }) if err != nil { From 7ff6911089e49d53b59ff90535718adc99057af0 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Wed, 12 Feb 2025 21:17:33 +0100 Subject: [PATCH 06/36] token and verify token methods --- token/app.go | 7 +++++ token/app_test.go | 14 +++++++--- token/consts.go | 3 ++- token/expiration.go | 16 ++++++++++-- token/expiration_test.go | 34 +++++++++++++++++++++--- token/id.go | 56 +++++++++++++++++++++++++++++++++++++--- token/id_test.go | 42 ++++++++++++++++++++++++++++-- 7 files changed, 156 insertions(+), 16 deletions(-) diff --git a/token/app.go b/token/app.go index 8b33898..0b9fd11 100644 --- a/token/app.go +++ b/token/app.go @@ -93,3 +93,10 @@ func (app *App) ID() *AppID { } return new(AppID).SetBytes(app.Marshal()) } + +func (app *App) SetID(id *AppID) *App { + if id == nil { + return nil + } + return app.Unmarshal(id.Bytes()) +} diff --git a/token/app_test.go b/token/app_test.go index f44976f..7b53d02 100644 --- a/token/app_test.go +++ b/token/app_test.go @@ -89,7 +89,6 @@ func TestStringSetStringApp(t *testing.T) { if res := new(App).String(); res != "" { t.Errorf("expected empty string, got %q", res) } - app := &App{ Name: testAppName, RedirectURI: testRedirectURI, @@ -135,11 +134,9 @@ func TestMarshalUnmarshalApp(t *testing.T) { if res := new(App).Marshal(); res != nil { t.Errorf("expected nil, got %v", res) } - if res := new(App).Unmarshal([]byte{1}); res != nil { t.Errorf("expected nil, got %v", res) } - app := &App{ Name: testAppName, RedirectURI: testRedirectURI, @@ -164,7 +161,6 @@ func TestAppID(t *testing.T) { if id := new(App).ID(); id != nil { t.Errorf("expected nil, got %v", id) } - app := &App{ Name: testAppName, RedirectURI: testRedirectURI, @@ -177,4 +173,14 @@ func TestAppID(t *testing.T) { if !bytes.Equal(id.Bytes(), app.Marshal()) { t.Errorf("expected %v, got %v", app.Marshal(), id.Bytes()) } + if res := new(App).SetID(nil); res != nil { + t.Errorf("expected nil, got %v", res) + } + newApp := new(App).SetID(id) + if newApp == nil { + t.Fatalf("error decoding app ID") + } + if newApp.String() != app.String() { + t.Errorf("expected %s, got %s", app.String(), newApp.String()) + } } diff --git a/token/consts.go b/token/consts.go index acc11f4..989005d 100644 --- a/token/consts.go +++ b/token/consts.go @@ -11,8 +11,9 @@ const ( appNameMaxLen = 20 redirectURIPattern = `^https?://[a-zA-Z0-9-]+(\.[a-zA-Z0-9-]+)+(/[a-zA-Z0-9-._~:/?#[\]@!$&'()*+,;=]*)?$` redirectURIMaxLen = 80 - minDuration = 5 * time.Minute + minDuration = 30 * time.Second maxDuration = 180 * 24 * time.Hour + tokenSeparator = '.' ) var uriRegexp = regexp.MustCompile(redirectURIPattern) diff --git a/token/expiration.go b/token/expiration.go index ddb690c..6f6ab12 100644 --- a/token/expiration.go +++ b/token/expiration.go @@ -7,8 +7,20 @@ import ( type Expiration time.Time -func NewExpiration(d time.Duration) Expiration { - return Expiration(time.Now().Add(d)) +func NewExpiration(d time.Duration) *Expiration { + if d < minDuration || d > maxDuration { + return nil + } + exp := Expiration(time.Now().Add(d)) + return &exp +} + +func (exp *Expiration) Time() time.Time { + return time.Time(*exp) +} + +func (exp *Expiration) Valid() bool { + return time.Now().Before(exp.Time()) } func (exp *Expiration) String() string { diff --git a/token/expiration_test.go b/token/expiration_test.go index 348a0af..ffcb84c 100644 --- a/token/expiration_test.go +++ b/token/expiration_test.go @@ -6,8 +6,36 @@ import ( "time" ) +func TestNewExpirationTime(t *testing.T) { + exp := NewExpiration(minDuration - 1) + if exp != nil { + t.Fatalf("expected nil, got %v", exp) + } + exp = NewExpiration(minDuration * 2) + if exp == nil { + t.Fatalf("expected valid expiration, got nil") + } + expTime := exp.Time() + expected := time.Now().Add(minDuration * 2) + if expected.Sub(expTime) > time.Millisecond*50 { + t.Errorf("expected %v, got %v", expected, expTime) + } +} + +func TestExpirationValid(t *testing.T) { + t.Parallel() + exp := NewExpiration(minDuration) + if !exp.Valid() { + t.Errorf("expected valid expiration, got invalid") + } + time.Sleep(minDuration) + if exp.Valid() { + t.Errorf("expected invalid expiration, got valid") + } +} + func TestStringSetStringExpiration(t *testing.T) { - exp := NewExpiration(time.Second) + exp := NewExpiration(minDuration) str := exp.String() decoded := new(Expiration).SetString(str) if decoded == nil { @@ -25,7 +53,7 @@ func TestStringSetStringExpiration(t *testing.T) { } func TestBytesSetBytesExpiration(t *testing.T) { - exp := NewExpiration(time.Second) + exp := NewExpiration(minDuration) b := exp.Bytes() decoded := new(Expiration).SetBytes(b) if decoded == nil { @@ -43,7 +71,7 @@ func TestBytesSetBytesExpiration(t *testing.T) { } func TestMarshalUnmarshalExpiration(t *testing.T) { - exp := NewExpiration(time.Second) + exp := NewExpiration(minDuration) encoded := exp.Marshal() decoded := new(Expiration).Unmarshal(encoded) if decoded == nil { diff --git a/token/id.go b/token/id.go index 133db3c..5f81437 100644 --- a/token/id.go +++ b/token/id.go @@ -1,8 +1,10 @@ package token import ( + "bytes" "crypto/ed25519" "crypto/sha256" + "encoding/hex" ) type AppID string @@ -43,15 +45,61 @@ func (id *AppID) Sign(data []byte) []byte { return nil } hData := sha256.Sum256(data) - return ed25519.Sign(privKey, hData[:]) + rawSign := ed25519.Sign(privKey, hData[:]) + // encode to hex + sign := make([]byte, hex.EncodedLen(len(rawSign))) + hex.Encode(sign, rawSign) + return sign } -func (id *AppID) Verify(data, sig []byte) bool { +func (id *AppID) Verify(msg, sig []byte) bool { privKey := id.PrivKey() if privKey == nil { return false } + // decode sign from hex + rawSign := make([]byte, hex.DecodedLen(len(sig))) + if _, err := hex.Decode(rawSign, sig); err != nil { + return false + } pubKey := privKey.Public().(ed25519.PublicKey) - hData := sha256.Sum256(data) - return ed25519.Verify(pubKey, hData[:], sig) + hMsg := sha256.Sum256(msg) + return ed25519.Verify(pubKey, hMsg[:], rawSign) +} + +func (id *AppID) NewToken(secret, email string) []byte { + app := new(App).SetID(id) + if app == nil { + return nil + } + exp := NewExpiration(app.SessionDuration) + msg := signMsg(id.Bytes(), []byte(secret), []byte(email), exp.Bytes()) + sig := id.Sign(msg) + return fmtToken(exp.Marshal(), sig) +} + +func (id *AppID) VerifyToken(token []byte, secret, email string) bool { + if len(token) == 0 { + return false + } + parts := bytes.Split(token, []byte{tokenSeparator}) + if len(parts) != 2 { + return false + } + dExp := new(Expiration).Unmarshal(parts[0]) + if dExp == nil || !dExp.Valid() { + return false + } + sig := parts[1] + msg := signMsg(id.Bytes(), []byte(secret), []byte(email), dExp.Bytes()) + return id.Verify(msg, sig) +} + +func signMsg(id, secret, email, exp []byte) []byte { + return bytes.Join([][]byte{id, email, exp, secret}, nil) +} + +func fmtToken(exp, sig []byte) []byte { + t := append(exp, tokenSeparator) + return append(t, sig...) } diff --git a/token/id_test.go b/token/id_test.go index 892d74f..c052fab 100644 --- a/token/id_test.go +++ b/token/id_test.go @@ -3,13 +3,13 @@ package token import ( "bytes" "testing" + "time" ) func TestStringSetStringAppID(t *testing.T) { if id := new(AppID).SetString("testID"); id != nil { t.Errorf("expected nil, got %v", id) } - app := &App{ Name: testAppName, RedirectURI: testRedirectURI, @@ -63,7 +63,6 @@ func TestPrivKeySignVerifyAppID(t *testing.T) { if new(AppID).Verify([]byte("test data"), []byte("test sig")) { t.Errorf("expected signature to be invalid") } - app := &App{ Name: testAppName, RedirectURI: testRedirectURI, @@ -81,4 +80,43 @@ func TestPrivKeySignVerifyAppID(t *testing.T) { if !id.Verify(data, sig) { t.Errorf("expected signature to be valid") } + if id.Verify(data, []byte("invalid sig")) { + t.Errorf("expected signature to be invalid") + } +} + +func TestNewTokenVerifyToken(t *testing.T) { + t.Parallel() + if res := new(AppID).NewToken("", ""); res != nil { + t.Errorf("expected nil, got %v", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: 30 * time.Second, + } + id := app.ID() + if id == nil { + t.Fatalf("error decoding app ID") + } + email := "test@email.com" + secret := "api_secret" + token := id.NewToken(secret, email) + if token == nil { + t.Fatalf("error creating token") + } + if !id.VerifyToken(token, secret, email) { + t.Errorf("expected token to be valid") + } + time.Sleep(app.SessionDuration + 1) + if id.VerifyToken(token, secret, email) { + t.Errorf("expected token to be invalid") + } + if id.VerifyToken(nil, secret, email) { + t.Errorf("expected token to be invalid") + } + exp := NewExpiration(minDuration) + if id.VerifyToken(exp.Marshal(), secret, email) { + t.Errorf("expected token to be invalid") + } } From b66553cbacb93d57981a6e199185a5e26be0014c Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 18 Feb 2025 00:07:18 +0100 Subject: [PATCH 07/36] include hedged nonce during sign generation --- token/app_test.go | 2 ++ token/id.go | 56 +++++++++++++++++++++++++------------- token/id_test.go | 68 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 94 insertions(+), 32 deletions(-) diff --git a/token/app_test.go b/token/app_test.go index 7b53d02..c53211b 100644 --- a/token/app_test.go +++ b/token/app_test.go @@ -12,6 +12,8 @@ const ( testSessionDuration = time.Minute * 30 ) +var testAppSecret = []byte("super_secret_key") + func TestValidApp(t *testing.T) { app := &App{ Name: testAppName, diff --git a/token/id.go b/token/id.go index 5f81437..4fc6225 100644 --- a/token/id.go +++ b/token/id.go @@ -3,6 +3,7 @@ package token import ( "bytes" "crypto/ed25519" + "crypto/hmac" "crypto/sha256" "encoding/hex" ) @@ -30,30 +31,34 @@ func (id *AppID) SetBytes(data []byte) *AppID { return id.SetString(string(data)) } -func (id *AppID) PrivKey() ed25519.PrivateKey { +func (id *AppID) PrivKey(secret []byte) ed25519.PrivateKey { + hFn := hmac.New(sha256.New, secret) + bID := id.Bytes() if len(bID) == 0 { return nil } - hID := sha256.Sum256(bID) - return ed25519.NewKeyFromSeed(hID[:]) + // hID := sha256.Sum256(bID) + hID := hFn.Sum(bID) + return ed25519.NewKeyFromSeed(hID[:32]) } -func (id *AppID) Sign(data []byte) []byte { - privKey := id.PrivKey() +func (id *AppID) Sign(secret, msg []byte) []byte { + privKey := id.PrivKey(secret) if len(privKey) == 0 { return nil } - hData := sha256.Sum256(data) - rawSign := ed25519.Sign(privKey, hData[:]) + hmsg := sha256.Sum256(msg) + data := append(msg, hedgedNonce(privKey[:], hmsg[:])...) + rawSign := ed25519.Sign(privKey, data[:]) // encode to hex sign := make([]byte, hex.EncodedLen(len(rawSign))) hex.Encode(sign, rawSign) return sign } -func (id *AppID) Verify(msg, sig []byte) bool { - privKey := id.PrivKey() +func (id *AppID) Verify(secret, msg, sig []byte) bool { + privKey := id.PrivKey(secret) if privKey == nil { return false } @@ -62,23 +67,24 @@ func (id *AppID) Verify(msg, sig []byte) bool { if _, err := hex.Decode(rawSign, sig); err != nil { return false } + hmsg := sha256.Sum256(msg) + data := append(msg, hedgedNonce(privKey[:], hmsg[:])...) pubKey := privKey.Public().(ed25519.PublicKey) - hMsg := sha256.Sum256(msg) - return ed25519.Verify(pubKey, hMsg[:], rawSign) + return ed25519.Verify(pubKey, data, rawSign) } -func (id *AppID) NewToken(secret, email string) []byte { +func (id *AppID) NewToken(secret []byte, email string) []byte { app := new(App).SetID(id) if app == nil { return nil } exp := NewExpiration(app.SessionDuration) - msg := signMsg(id.Bytes(), []byte(secret), []byte(email), exp.Bytes()) - sig := id.Sign(msg) + msg := signMsg(id.Bytes(), []byte(email), exp.Bytes()) + sig := id.Sign(secret, msg) return fmtToken(exp.Marshal(), sig) } -func (id *AppID) VerifyToken(token []byte, secret, email string) bool { +func (id *AppID) VerifyToken(secret, token []byte, email string) bool { if len(token) == 0 { return false } @@ -91,15 +97,27 @@ func (id *AppID) VerifyToken(token []byte, secret, email string) bool { return false } sig := parts[1] - msg := signMsg(id.Bytes(), []byte(secret), []byte(email), dExp.Bytes()) - return id.Verify(msg, sig) + msg := signMsg(id.Bytes(), []byte(email), dExp.Bytes()) + return id.Verify(secret, msg, sig) } -func signMsg(id, secret, email, exp []byte) []byte { - return bytes.Join([][]byte{id, email, exp, secret}, nil) +func signMsg(id, email, exp []byte) []byte { + res := append(id, email...) + return append(res, exp...) } func fmtToken(exp, sig []byte) []byte { t := append(exp, tokenSeparator) return append(t, sig...) } + +func hedgedNonce(inputs ...[]byte) []byte { + if len(inputs) == 0 || len(inputs[0]) == 0 { + return nil + } + hFn := hmac.New(sha256.New, inputs[0]) + for _, in := range inputs[1:] { + hFn.Write(in) + } + return hFn.Sum(nil) +} diff --git a/token/id_test.go b/token/id_test.go index c052fab..f1e60cd 100644 --- a/token/id_test.go +++ b/token/id_test.go @@ -2,6 +2,8 @@ package token import ( "bytes" + "crypto/hmac" + "crypto/sha256" "testing" "time" ) @@ -54,13 +56,13 @@ func TestBytesSetBytesAppID(t *testing.T) { } func TestPrivKeySignVerifyAppID(t *testing.T) { - if privKey := new(AppID).PrivKey(); privKey != nil { + if privKey := new(AppID).PrivKey(testAppSecret); privKey != nil { t.Errorf("expected nil, got %v", privKey) } - if sig := new(AppID).Sign([]byte("test data")); sig != nil { + if sig := new(AppID).Sign(testAppSecret, []byte("test data")); sig != nil { t.Errorf("expected nil, got %v", sig) } - if new(AppID).Verify([]byte("test data"), []byte("test sig")) { + if new(AppID).Verify(testAppSecret, []byte("test data"), []byte("test sig")) { t.Errorf("expected signature to be invalid") } app := &App{ @@ -73,21 +75,21 @@ func TestPrivKeySignVerifyAppID(t *testing.T) { t.Fatalf("error decoding app ID") } data := []byte("test data") - sig := id.Sign(data) + sig := id.Sign(testAppSecret, data) if sig == nil { t.Fatalf("error signing data") } - if !id.Verify(data, sig) { + if !id.Verify(testAppSecret, data, sig) { t.Errorf("expected signature to be valid") } - if id.Verify(data, []byte("invalid sig")) { + if id.Verify(testAppSecret, data, []byte("invalid sig")) { t.Errorf("expected signature to be invalid") } } func TestNewTokenVerifyToken(t *testing.T) { t.Parallel() - if res := new(AppID).NewToken("", ""); res != nil { + if res := new(AppID).NewToken(nil, ""); res != nil { t.Errorf("expected nil, got %v", res) } app := &App{ @@ -100,23 +102,63 @@ func TestNewTokenVerifyToken(t *testing.T) { t.Fatalf("error decoding app ID") } email := "test@email.com" - secret := "api_secret" - token := id.NewToken(secret, email) + token := id.NewToken(testAppSecret, email) if token == nil { t.Fatalf("error creating token") } - if !id.VerifyToken(token, secret, email) { + if !id.VerifyToken(testAppSecret, token, email) { t.Errorf("expected token to be valid") } time.Sleep(app.SessionDuration + 1) - if id.VerifyToken(token, secret, email) { + if id.VerifyToken(testAppSecret, token, email) { t.Errorf("expected token to be invalid") } - if id.VerifyToken(nil, secret, email) { + if id.VerifyToken(testAppSecret, nil, email) { t.Errorf("expected token to be invalid") } exp := NewExpiration(minDuration) - if id.VerifyToken(exp.Marshal(), secret, email) { + if id.VerifyToken(testAppSecret, exp.Marshal(), email) { t.Errorf("expected token to be invalid") } } + +func Test_signMsg(t *testing.T) { + expected := []byte("testcombineddata") + if res := signMsg([]byte("test"), []byte("combined"), []byte("data")); !bytes.Equal(res, expected) { + t.Errorf("expected %v, got %v", expected, res) + } +} + +func Test_fmtToken(t *testing.T) { + sig := []byte("testsig") + exp := []byte("testexp") + expected := append(exp, tokenSeparator) + expected = append(expected, sig...) + if res := fmtToken(exp, sig); !bytes.Equal(res, expected) { + t.Errorf("expected %v, got %v", expected, res) + } +} + +func Test_hedgedNonce(t *testing.T) { + if res := hedgedNonce(); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := hedgedNonce(nil); res != nil { + t.Errorf("expected nil, got %v", res) + } + seed := []byte("test") + hFn := hmac.New(sha256.New, seed) + expected := hFn.Sum(nil) + if res := hedgedNonce(seed); !bytes.Equal(res, expected) { + t.Errorf("expected %v, got %v", expected, res) + } + + seed = []byte("test") + hFn = hmac.New(sha256.New, seed) + in := []byte("data") + hFn.Write(in) + expected = hFn.Sum(nil) + if res := hedgedNonce(seed, in); !bytes.Equal(res, expected) { + t.Errorf("expected %v, got %v", expected, res) + } +} From 003d0fb5b8920eb5c14d8b9c1f76ef51756a6676 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Thu, 20 Feb 2025 19:34:07 +0100 Subject: [PATCH 08/36] remove old code --- api/service_test.go | 34 ----------- assets/app_email_template.html | 98 -------------------------------- assets/token_email_template.html | 88 ---------------------------- email/disposable.go | 68 ---------------------- email/templates.go | 72 ----------------------- 5 files changed, 360 deletions(-) delete mode 100644 api/service_test.go delete mode 100644 assets/app_email_template.html delete mode 100644 assets/token_email_template.html delete mode 100644 email/disposable.go delete mode 100644 email/templates.go diff --git a/api/service_test.go b/api/service_test.go deleted file mode 100644 index b065eec..0000000 --- a/api/service_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package api - -import ( - "context" - "testing" - "time" - - "github.com/simpleauthlink/authapi/email" -) - -func TestNew(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - srv, err := New(ctx, &Config{ - Server: "localhost", - ServerPort: 8080, - CleanerCooldown: 30 * time.Second, - EmailConfig: email.EmailConfig{ - EmailHost: "smtp.gmail.com", - EmailPort: 587, - Address: "test@email.com", - Password: "password1234", - }, - }) - if err != nil { - t.Errorf("expected nil, got %v", err) - return - } - if srv == nil { - t.Errorf("expected not nil, got nil") - return - } -} diff --git a/assets/app_email_template.html b/assets/app_email_template.html deleted file mode 100644 index 6e77482..0000000 --- a/assets/app_email_template.html +++ /dev/null @@ -1,98 +0,0 @@ - - - - - - - Your app '{{.AppName}}' is ready 🎉 - - - - - - - - -
- - - - - - - - - - - - - - - - -
- -

SimpleAuth.link

-
- 👋 Hi, {{.EmailHandler}}! -

- Your app '{{.AppName}}' has been successfully created ✅. Here are the details of your app: -

- - - - - - - - - - - - - - - - - -
App ID{{.AppID}}
App Name{{.AppName}}
App Secret{{.Secret}}
Redirect URL{{.RedirectURL}}
-

- Check out the documentation to getting started integrating SimpleAuth with your app 🚀. -
- - - - -
- 🤓 Read the documentation -
-
- ⚠️ Remember to keep your app secret safe and secure. ⚠️ -

- You can always regenerate a new app secret. -
-
- - - \ No newline at end of file diff --git a/assets/token_email_template.html b/assets/token_email_template.html deleted file mode 100644 index 2a51785..0000000 --- a/assets/token_email_template.html +++ /dev/null @@ -1,88 +0,0 @@ - - - - - - - Your Magic Link for {{.AppName}} Login - - - - - - - - -
- - - - - - - - - - - - - - - - -
- -

SimpleAuth.link

-
- 👋 Hi, {{.EmailHandler}}! -

- Your magic link to login to '{{.AppName}}' is ready 🎉. -

- Click the button below to login to your account. 👇 -
- - - - - - - -
- Login to - {{.AppName}} -
-
- {{.Token}} -
-
- or copy and paste the following link in your browser: -

-
{{.MagicLink}}
-

- If you did not request this, please ignore this email. -
-
- - - \ No newline at end of file diff --git a/email/disposable.go b/email/disposable.go deleted file mode 100644 index ab62b07..0000000 --- a/email/disposable.go +++ /dev/null @@ -1,68 +0,0 @@ -package email - -import ( - "bufio" - "context" - "errors" - "net/http" - "regexp" - "strings" - "time" -) - -// domainRgx is the regular expression used to validate a domain. -var domainRgx = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) - -// LoadRemoteDisposableDomains loads a list of disposable domains from a remote -// source url. It reads the content of the source url line by line and parses -// each line as a domain. It returns a list of disposable domains or an error if -// something fails. -func LoadRemoteDisposableDomains(ctx context.Context, disposableSrc string) ([]string, error) { - internalCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - // prepare the request - req, err := http.NewRequestWithContext(internalCtx, http.MethodGet, disposableSrc, nil) - if err != nil { - return nil, errors.Join(ErrLoadingDisposableDomains, err) - } - // perform the request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, errors.Join(ErrLoadingDisposableDomains, err) - } - // read the response body line by line - defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) - var domains []string - for scanner.Scan() { - domain := scanner.Text() - if domainRgx.MatchString(domain) { - domains = append(domains, domain) - } - } - if err := scanner.Err(); err != nil { - return nil, errors.Join(ErrLoadingDisposableDomains, err) - } - return domains, nil -} - -// CheckEmail checks if the email address is valid. It compares the domain with -// a list of disallowed domains. It returns true if the email address is valid, -// otherwise it returns false. -func CheckEmail(disallowedDomains []string, email string) bool { - if len(disallowedDomains) == 0 { - return true - } - // split the email address - parts := strings.Split(email, "@") - if len(parts) != 2 { - return false - } - // check the domain - for _, domain := range disallowedDomains { - if domain == parts[1] { - return false - } - } - return true -} diff --git a/email/templates.go b/email/templates.go deleted file mode 100644 index c2aa470..0000000 --- a/email/templates.go +++ /dev/null @@ -1,72 +0,0 @@ -package email - -import ( - "bytes" - "fmt" - "strings" - "text/template" -) - -// UserEmailData struct includes the data required to fill the user email -// template. -type UserEmailData struct { - AppName string - EmailHandler string - MagicLink string - Token string -} - -// AppEmailData struct includes the data required to fill the app email -// template. -type AppEmailData struct { - AppID string - AppName string - RedirectURL string - Secret string - EmailHandler string -} - -// NewUserEmailData creates a new UserEmailData with the provided data. -func NewUserEmailData(appName, email, magicLink, token string) *UserEmailData { - return &UserEmailData{ - AppName: appName, - EmailHandler: emailHandler(email), - MagicLink: magicLink, - Token: token, - } -} - -// NewAppEmailData creates a new AppEmailData with the provided data. -func NewAppEmailData(appID, appName, redirectURL, secret, email string) *AppEmailData { - return &AppEmailData{ - AppID: appID, - AppName: appName, - RedirectURL: redirectURL, - Secret: secret, - EmailHandler: emailHandler(email), - } -} - -// ParseTemplate parses the template file provided with the data provided. It -// returns the parsed template as a string. If an error occurs, it returns the -// error. -func ParseTemplate(templatePath string, data interface{}) (string, error) { - // parse the template file provided - t, err := template.ParseFiles(templatePath) - if err != nil { - return "", err - } - // execute the template to fill it with the data provided - buf := new(bytes.Buffer) - if err := t.Execute(buf, data); err != nil { - return "", fmt.Errorf("error parsing template: %w", err) - } - return buf.String(), nil -} - -// emailHandler method extracts the email handler from the email address. It -// splits the email address by the "@" symbol and returns the first part. -func emailHandler(emailAddress string) string { - emailParts := strings.Split(emailAddress, "@") - return emailParts[0] -} From 810e45d9abf8bd55a74f081c7c257cabd6963e6b Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Thu, 20 Feb 2025 19:37:51 +0100 Subject: [PATCH 09/36] some internal stuff like custom error class and fake smtp server for testing --- internal/error.go | 41 ++++++++++++++ internal/fake_smtp_server.go | 101 +++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 internal/error.go create mode 100644 internal/fake_smtp_server.go diff --git a/internal/error.go b/internal/error.go new file mode 100644 index 0000000..852d61c --- /dev/null +++ b/internal/error.go @@ -0,0 +1,41 @@ +package internal + +import "fmt" + +// Error represents an error with a message and a trace. It is a custom struct +// that can be used to wrap errors with additional information. +type Error struct { + msg string + trace error +} + +// NewErr creates a new Error instance with the given message. It is a helper +// function that simplifies the creation of Error instances in other packages. +func NewErr(msg string) *Error { + return &Error{msg: msg} +} + +// Error returns the error message. If the error has a trace, it is appended to +// the message. The error implements the error interface. +func (e *Error) Error() string { + err := fmt.Errorf("%s", e.msg) + if e.trace != nil { + err = fmt.Errorf("%s: %w", err, e.trace) + } + return err.Error() +} + +// With adds an error as a trace to the error. It is a helper function that +// simplifies the addition of traces to Error instances in other packages. +func (e *Error) With(err error) *Error { + e.trace = err + return e +} + +// Withf adds a formatted error message as a trace to the error. It is a helper +// function that simplifies the addition of formatted traces to Error instances +// in other packages. +func (e *Error) Withf(tmpl string, args ...any) *Error { + e.trace = fmt.Errorf(tmpl, args...) + return e +} diff --git a/internal/fake_smtp_server.go b/internal/fake_smtp_server.go new file mode 100644 index 0000000..e15f5ba --- /dev/null +++ b/internal/fake_smtp_server.go @@ -0,0 +1,101 @@ +package internal + +import ( + "bufio" + "context" + "fmt" + "net" + "strings" +) + +// FakeSMTPServer represents a simple SMTP testing server. +type FakeSMTPServer struct { + addr string + inbox chan string + listener net.Listener +} + +// NewFakeSMTPServer creates a new FakeSMTPServer instance that listens on the +// given address and port and stores the received emails in the inbox channel +// provided. +func NewFakeSMTPServer(addr string, port int, inbox chan string) *FakeSMTPServer { + return &FakeSMTPServer{addr: fmt.Sprintf("%s:%d", addr, port), inbox: inbox} +} + +// Start method launches the test SMTP server. +func (s *FakeSMTPServer) Start(ctx context.Context) error { + var err error + s.listener, err = net.Listen("tcp", s.addr) + if err != nil { + return err + } + go func() { + for { + select { + case <-ctx.Done(): + s.listener.Close() + default: + conn, err := s.listener.Accept() + if err != nil { + return + } + go s.handleConn(conn) + } + } + }() + return nil +} + +// Stop method shuts down the test SMTP server. +func (s *FakeSMTPServer) Stop() { + s.listener.Close() +} + +func (s *FakeSMTPServer) handleConn(conn net.Conn) { + defer conn.Close() + reader := bufio.NewReader(conn) + // send greeting + fmt.Fprintf(conn, "220 Fake SMTP Service Ready\r\n") + var dataBuilder strings.Builder + inData := false + // read incoming data + for { + line, err := reader.ReadString('\n') + if err != nil { + return + } + line = strings.TrimRight(line, "\r\n") + // check if we are in the data section + if inData { + if line == "." { + inData = false + // send back a confirmation and store the data + fmt.Fprintf(conn, "250 OK\r\n") + s.inbox <- dataBuilder.String() + dataBuilder.Reset() + continue + } + dataBuilder.WriteString(line + "\n") + continue + } + // simple command handling + switch { + case strings.HasPrefix(line, "HELO"), strings.HasPrefix(line, "EHLO"): + fmt.Fprintf(conn, "250 Hello\r\n") + case strings.HasPrefix(line, "MAIL FROM:"): + fmt.Fprintf(conn, "250 OK\r\n") + case strings.HasPrefix(line, "RCPT TO:"): + fmt.Fprintf(conn, "250 OK\r\n") + case strings.HasPrefix(line, "DATA"): + // prepare to receive data + fmt.Fprintf(conn, "354 End data with .\r\n") + inData = true + case strings.HasPrefix(line, "QUIT"): + // close the connection + fmt.Fprintf(conn, "221 Bye\r\n") + return + default: + fmt.Fprintf(conn, "250 OK\r\n") + } + } +} From b21a917161de01f0b022640ac5dbe943e8bfc90d Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Thu, 20 Feb 2025 19:38:28 +0100 Subject: [PATCH 10/36] new email package with tests and new templates definitions which support plain and html content --- email/emailqueue.go | 236 ++++++++++++----------- email/emailqueue_test.go | 289 ++++++++++++++++++++++++++++ email/errors.go | 36 ++-- email/template.go | 94 +++++++++ email/template_test.go | 133 +++++++++++++ email/templates/login/definition.go | 30 +++ email/templates/login/template.html | 75 ++++++++ 7 files changed, 771 insertions(+), 122 deletions(-) create mode 100644 email/emailqueue_test.go create mode 100644 email/template.go create mode 100644 email/template_test.go create mode 100644 email/templates/login/definition.go create mode 100644 email/templates/login/template.html diff --git a/email/emailqueue.go b/email/emailqueue.go index 9dac8ec..dbf1e96 100644 --- a/email/emailqueue.go +++ b/email/emailqueue.go @@ -4,40 +4,63 @@ import ( "bytes" "context" "fmt" + "mime/multipart" "net/mail" "net/smtp" "net/textproto" - "regexp" "sync" - "time" ) -// sendRetries is the number of retries to send the email. -const sendRetries = 3 +// defaultSendRetries is the default number of retries to send the email. +const defaultSendRetries = 3 -// emailRgx is the regular expression used to validate an email address. -var emailRgx = regexp.MustCompile(`^[\w-\.]+@([\w-]+\.)+[\w-]{2,}$`) +// Email struct represents the email that is going to be sent. It includes the +// recipient email address, the subject and the body of the email. +type Email struct { + To string + Subject string + Body []byte + PlainBody []byte +} + +// Valid method checks if the email is valid. It returns true if the recipient +// email address, the subject and the body are not empty. +func (e *Email) Valid() bool { + if e.Subject == "" || (len(e.Body) == 0 && len(e.PlainBody) == 0) { + return false + } + _, err := mail.ParseAddress(e.To) + return err == nil +} // EmailConfig struct represents the email configuration that is needed to send // an email using and SMTP server. It includes the email address (used as the // sender address but also as the username for the SMTP server), the email // server hostname, its port and the password. type EmailConfig struct { - Address string - EmailHost string - EmailPort int - Password string - DisposableSrc string - TokenEmailTemplate string - AppEmailTemplate string + FromName string + FromAddress string + SMTPUsername string + SMTPPassword string + SMTPServer string + SMTPPort int + Retries int + ErrorCh chan error } -// Email struct represents the email that is going to be sent. It includes the -// recipient email address, the subject and the body of the email. -type Email struct { - To string - Subject string - Body string +// Valid method checks if the email configuration is valid. It returns true if +// the sender name, the SMTP server and its port are not empty, and the sender +// email address is valid. It also sets the number of retries to the default +// value if it is not set. +func (cfg *EmailConfig) Valid() bool { + if cfg.FromName == "" || cfg.SMTPServer == "" || cfg.SMTPPort == 0 { + return false + } + if cfg.Retries == 0 { + cfg.Retries = defaultSendRetries + } + _, err := mail.ParseAddress(cfg.FromAddress) + return err == nil } // EmailQueue struct represents the email queue. It includes the context and the @@ -45,37 +68,37 @@ type Email struct { // the email, the list of emails to send, and the waiter to wait for the // background process to finish. type EmailQueue struct { - ctx context.Context - cancel context.CancelFunc - cfg *EmailConfig - items []*Email - itemsMtx sync.Mutex - waiter sync.WaitGroup - disallowedDomains []string + ctx context.Context + cancel context.CancelFunc + cfg *EmailConfig + auth smtp.Auth + items []*Email + itemsMtx sync.Mutex + waiter sync.WaitGroup + errCh chan error } // NewEmailQueue creates a new EmailQueue with the provided configuration. func NewEmailQueue(ctx context.Context, cfg *EmailConfig) (*EmailQueue, error) { // check if the configuration is valid - if cfg.Address == "" || !emailRgx.MatchString(cfg.Address) || - cfg.EmailHost == "" || cfg.EmailPort == 0 || cfg.Password == "" { + if !cfg.Valid() { return nil, ErrInvalidConfig } + // init the email queue internalCtx, cancel := context.WithCancel(ctx) - // load the disposable domains if a source is provided - var err error - disallowedDomains := []string{} - if cfg.DisposableSrc != "" { - disallowedDomains, err = LoadRemoteDisposableDomains(internalCtx, cfg.DisposableSrc) + eq := &EmailQueue{ + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + items: []*Email{}, + errCh: cfg.ErrorCh, + } + // init SMTP auth + if cfg.SMTPUsername != "" && cfg.SMTPPassword != "" { + eq.auth = smtp.PlainAuth("", cfg.SMTPUsername, cfg.SMTPPassword, cfg.SMTPServer) } // return the email queue - return &EmailQueue{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - items: []*Email{}, - disallowedDomains: disallowedDomains, - }, err + return eq, nil } // Start method starts the email queue. It listens for new emails in the queue @@ -94,16 +117,16 @@ func (eq *EmailQueue) Start() { continue } if err := eq.Send(e); err != nil { - fmt.Println(err) - } else { - eq.Pop() + if eq.errCh != nil { + eq.errCh <- err + } } } - time.Sleep(time.Second) } }() } +// Stop method stops the email queue. func (eq *EmailQueue) Stop() { eq.cancel() eq.waiter.Wait() @@ -112,29 +135,15 @@ func (eq *EmailQueue) Stop() { // Push method adds a new email to the queue. func (eq *EmailQueue) Push(e *Email) error { // check if the email is valid - if e.To == "" || !emailRgx.MatchString(e.To) || e.Subject == "" || e.Body == "" { + if !e.Valid() { return ErrInvalidEmail } - // check if the email is allowed - if !eq.Allowed(e.To) { - return ErrDisallowedDomain - } eq.itemsMtx.Lock() eq.items = append(eq.items, e) eq.itemsMtx.Unlock() return nil } -// Top method returns the first email in the queue. -func (eq *EmailQueue) Top() *Email { - eq.itemsMtx.Lock() - defer eq.itemsMtx.Unlock() - if len(eq.items) == 0 { - return nil - } - return eq.items[0] -} - // Pop method removes the first email in the queue and returns it. func (eq *EmailQueue) Pop() *Email { eq.itemsMtx.Lock() @@ -152,73 +161,78 @@ func (eq *EmailQueue) Pop() *Email { // It composes the email message, creates the auth object with the email // credentials, the server string with the host and the port, and the receipts. // Finally, it sends the email. If something fails during the process, it -// returns an error. +// returns an error. It can be used even the queue is not started. func (eq *EmailQueue) Send(e *Email) error { + // check if the email is valid + if !e.Valid() { + return ErrInvalidEmail + } // compose the email body - body, err := eq.encodeEmail(e) + body, err := eq.composeBody(e) if err != nil { - return fmt.Errorf("error composing email: %w", err) - } - // check if the email is allowed - if !eq.Allowed(e.To) { - return ErrDisallowedDomain + return ErrComposeEmail.With(err) } - // create the auth object with the email credentials - auth := smtp.PlainAuth("", eq.cfg.Address, eq.cfg.Password, eq.cfg.EmailHost) // create the server string with the host and the port and the receipts - server := fmt.Sprintf("%s:%d", eq.cfg.EmailHost, eq.cfg.EmailPort) + server := fmt.Sprintf("%s:%d", eq.cfg.SMTPServer, eq.cfg.SMTPPort) receipts := []string{e.To} // send the email - for i := 0; i < sendRetries; i++ { - if err = smtp.SendMail(server, auth, eq.cfg.Address, receipts, body); err == nil { + for i := 0; i < eq.cfg.Retries; i++ { + if err = smtp.SendMail(server, eq.auth, eq.cfg.FromAddress, receipts, body); err == nil { break } } if err != nil { - return fmt.Errorf("error sending email: %w", err) + return ErrSendEmail.With(err) } return nil } -// Allowed method checks if the email address is allowed. It compares the domain -// with a list of disallowed domains. It returns true if the email address is -// allowed, otherwise it returns false. -func (eq *EmailQueue) Allowed(address string) bool { - if !emailRgx.MatchString(address) { - return false - } - return CheckEmail(eq.disallowedDomains, address) -} - -// encodeEmail method encodes the email to a byte slice. It validates the from -// and to addresses, sets the headers for the html email, and writes the body. -// It returns the encoded email or an error if something fails during the -// process. -func (eq *EmailQueue) encodeEmail(email *Email) ([]byte, error) { - // validate from address - from, err := mail.ParseAddress(eq.cfg.Address) +// composeBody creates the email body with the message data. It creates a +// multipart email with a plain text and an HTML part. It returns the email +// content as a byte slice or an error if the body could not be composed. +func (eq *EmailQueue) composeBody(msg *Email) ([]byte, error) { + // parse 'to' email address + to, err := mail.ParseAddress(msg.To) if err != nil { - return nil, fmt.Errorf("error parsing address: %w", err) + return nil, ErrParseAddress.With(err) } - // validate to address - to, err := mail.ParseAddress(email.To) - if err != nil { - return nil, fmt.Errorf("error parsing address: %w", err) - } - // set headers for html email - header := textproto.MIMEHeader{} - header.Set(textproto.CanonicalMIMEHeaderKey("from"), from.Address) - header.Set(textproto.CanonicalMIMEHeaderKey("to"), to.Address) - header.Set(textproto.CanonicalMIMEHeaderKey("content-type"), "text/html; charset=UTF-8") - header.Set(textproto.CanonicalMIMEHeaderKey("mime-version"), "1.0") - header.Set(textproto.CanonicalMIMEHeaderKey("subject"), email.Subject) - // init empty message - var buffer bytes.Buffer - // write header - for key, value := range header { - buffer.WriteString(fmt.Sprintf("%s: %s\r\n", key, value[0])) - } - // write body - buffer.WriteString(fmt.Sprintf("\r\n%s", email.Body)) - return buffer.Bytes(), nil + // create email headers + var headers bytes.Buffer + boundary := "----=_Part_0_123456789.123456789" + headers.WriteString(fmt.Sprintf("From: %s\r\n", eq.cfg.FromAddress)) + headers.WriteString(fmt.Sprintf("To: %s\r\n", to.String())) + headers.WriteString(fmt.Sprintf("Subject: %s\r\n", msg.Subject)) + headers.WriteString("MIME-Version: 1.0\r\n") + headers.WriteString(fmt.Sprintf("Content-Type: multipart/alternative; boundary=\"%s\"\r\n", boundary)) + headers.WriteString("\r\n") // blank line between headers and body + // create multipart writer + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if err := writer.SetBoundary(boundary); err != nil { + return nil, ErrSetBoundary.With(err) + } + // plain text part + textPart, _ := writer.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"text/plain; charset=\"UTF-8\""}, + "Content-Transfer-Encoding": {"7bit"}, + }) + if _, err := textPart.Write(msg.PlainBody); err != nil { + return nil, ErrWriteBody.With(err) + } + // HTML part + htmlPart, _ := writer.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"text/html; charset=\"UTF-8\""}, + "Content-Transfer-Encoding": {"7bit"}, + }) + if _, err := htmlPart.Write(msg.Body); err != nil { + return nil, ErrWriteHTMLBody.With(err) + } + if err := writer.Close(); err != nil { + return nil, ErrCloseEmailWriter.With(err) + } + // combine headers and body and return the content + var email bytes.Buffer + email.Write(headers.Bytes()) + email.Write(body.Bytes()) + return email.Bytes(), nil } diff --git a/email/emailqueue_test.go b/email/emailqueue_test.go new file mode 100644 index 0000000..45e61a9 --- /dev/null +++ b/email/emailqueue_test.go @@ -0,0 +1,289 @@ +package email + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/simpleauthlink/authapi/internal" +) + +const ( + testServerAddr = "127.0.0.1" + testServerPort = 2525 + testSenderName = "Test Sender" + testSender = "sender@testmail.com" + testReceiver = "receiver@testmail.com" + testSubject = "Test email" + testBody = "This is a test email" + testHTMLBody = "

This is a test email

" +) + +var inboxChan = make(chan string, 1) + +func TestMain(m *testing.M) { + defer close(inboxChan) + // create context with cancel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // start test SMTP server to receive the email + testSrv := internal.NewFakeSMTPServer(testServerAddr, testServerPort, inboxChan) + if err := testSrv.Start(ctx); err != nil { + panic(err) + } + defer testSrv.Stop() + m.Run() +} + +func TestValidEmail(t *testing.T) { + if !(&Email{ + To: testReceiver, + Subject: testSubject, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be valid") + } + if !(&Email{ + To: testReceiver, + Subject: testSubject, + Body: []byte(testBody), + PlainBody: nil, + }).Valid() { + t.Error("expected email to be valid") + } + if (&Email{ + To: testReceiver, + Subject: "", + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be invalid") + } + if (&Email{ + To: "", + Subject: testSubject, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be invalid") + } + if (&Email{ + To: "invalidEmail", + Subject: testSubject, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be invalid") + } + if (&Email{}).Valid() { + t.Error("expected email to be invalid") + } +} + +func TestValidConfig(t *testing.T) { + if !(&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be valid") + } + if (&EmailConfig{ + SMTPServer: "", + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be invalid") + } + if (&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: 0, + FromName: testSenderName, + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be invalid") + } + if (&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: "", + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be invalid") + } + if (&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: "", + }).Valid() { + t.Error("expected config to be invalid") + } +} + +func TestNewEmailQueue(t *testing.T) { + // create email queue with valid config + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + t.Fatal(err) + } + if eq == nil { + t.Error("expected email queue to be created") + } + // create email queue with auth + eq, err = NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + SMTPUsername: "username", + SMTPPassword: "password", + }) + if err != nil { + t.Fatal(err) + } + if eq == nil { + t.Error("expected email queue to be created") + } + // create email queue with invalid config + eq, err = NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: "", + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err == nil { + t.Error("expected error creating email queue") + } + if eq != nil { + t.Error("expected email queue to be nil") + } +} + +func TestSendEmail(t *testing.T) { + // create email queue but don't start it + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + t.Fatal(err) + } + // send email + if err := eq.Send(&Email{ + To: testReceiver, + Subject: testSubject, + Body: []byte(testHTMLBody), + PlainBody: []byte(testBody), + }); err != nil { + t.Fatal(err) + } + // check if the email was received + select { + case receivedMsg := <-inboxChan: + if !strings.Contains(receivedMsg, testSubject) { + t.Errorf("expected email content to contain %q, got %q", testSubject, receivedMsg) + } + if !strings.Contains(receivedMsg, testBody) { + t.Errorf("expected email content to contain %q, got %q", testBody, receivedMsg) + } + if !strings.Contains(receivedMsg, testHTMLBody) { + t.Errorf("expected email content to contain %q, got %q", testHTMLBody, receivedMsg) + } + case <-time.After(2 * time.Second): + t.Error("timed out waiting for the email to be received") + } + // try to send email to an invalid SMTP server + badEq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: 8080, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + t.Fatal(err) + } + if err := badEq.Send(&Email{ + To: testReceiver, + Subject: testSubject, + Body: nil, + PlainBody: []byte(testBody), + }); err == nil { + t.Error("expected error sending email") + } +} + +func TestPushSendEmail(t *testing.T) { + // create email queue and start it + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error, 1) + eq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + ErrorCh: errCh, + }) + if err != nil { + t.Fatal(err) + } + eq.Start() + defer eq.Stop() + // push email + if err := eq.Push(&Email{ + To: testReceiver, + Subject: testSubject, + Body: nil, + PlainBody: []byte(testBody), + }); err != nil { + t.Fatal(err) + } + // check if the email was received + select { + case receivedMsg := <-inboxChan: + if !strings.Contains(receivedMsg, testSubject) { + t.Errorf("expected email content to contain %q, got %q", testSubject, receivedMsg) + } + if !strings.Contains(receivedMsg, testBody) { + t.Errorf("expected email content to contain %q, got %q", testBody, receivedMsg) + } + case <-time.After(2 * time.Second): + t.Error("timed out waiting for the email to be received") + } + // sleep to pop nil email + time.Sleep(2 * time.Second) + // push invalid email + if err := eq.Push(&Email{}); err == nil { + t.Error("expected error pushing invalid email") + } + // push another email and modify the email queue to fail + email := &Email{ + To: testReceiver, + Subject: testSubject, + Body: nil, + PlainBody: []byte(testBody), + } + if err := eq.Push(email); err != nil { + t.Fatal(err) + } + email.PlainBody = nil + if err := <-errCh; err == nil { + t.Error("expected error sending email") + } +} diff --git a/email/errors.go b/email/errors.go index 33218a7..1415164 100644 --- a/email/errors.go +++ b/email/errors.go @@ -1,19 +1,33 @@ package email -import "fmt" +import "github.com/simpleauthlink/authapi/internal" var ( // ErrInvalidConfig is the error returned when the configuration is invalid. - ErrInvalidConfig = fmt.Errorf("invalid configuration") + ErrInvalidConfig = internal.NewErr("invalid configuration") // ErrInitQueue is the error returned when the queue cannot be initialized. - ErrInitQueue = fmt.Errorf("error initializing the queue") - // ErrInvalidDomain is the error returned when the domain is invalid. - ErrInvalidDomain = fmt.Errorf("invalid domain") - // ErrLoadingDisposableDomains is the error returned when the disposable - // domains cannot be loaded. - ErrLoadingDisposableDomains = fmt.Errorf("error loading disposable domains") - // ErrDisallowedDomain is the error returned when the domain is disallowed. - ErrDisallowedDomain = fmt.Errorf("disallowed domain") + ErrInitQueue = internal.NewErr("error initializing the queue") // ErrInvalidEmail is the error returned when the email is invalid. - ErrInvalidEmail = fmt.Errorf("invalid email") + ErrInvalidEmail = internal.NewErr("invalid email") + // ErrInvalidTemplate is the error returned when the template is invalid. + ErrInvalidTemplate = internal.NewErr("invalid template") + // ErrSendEmail is the error returned when the email cannot be sent. + ErrSendEmail = internal.NewErr("error sending email") + // ErrComposeEmail is the error returned when the email cannot be composed. + ErrComposeEmail = internal.NewErr("error composing email") + // ErrParseAddress is the error returned when the email address cannot + // be parsed. + ErrParseAddress = internal.NewErr("error parsing email address") + // ErrSetBoundary is the error returned when the boundary cannot be set + // when a multipart email is composed. + ErrSetBoundary = internal.NewErr("error setting boundary") + // ErrWriteHTMLBody is the error returned when the email plain body cannot + // be written. + ErrWriteBody = internal.NewErr("error writing email plain body") + // ErrWriteHTMLBody is the error returned when the email HTML body cannot + // be written. + ErrWriteHTMLBody = internal.NewErr("error writing email html body") + // ErrCloseEmailWriter is the error returned when the email writer cannot + // be closed after composing the email. + ErrCloseEmailWriter = internal.NewErr("error closing email writer") ) diff --git a/email/template.go b/email/template.go new file mode 100644 index 0000000..9f4fada --- /dev/null +++ b/email/template.go @@ -0,0 +1,94 @@ +package email + +import ( + "bytes" + htmltemplate "html/template" + texttemplate "text/template" +) + +// EmailTemplate is the definition of an email template, which contains the +// HTML and plain text placeholders to be filled with the data. +type EmailTemplate struct { + HTML string + Plain string +} + +// Compose methods fills the email template with the data and returns the email +// ready to be sent. It returns the email or an error if the template could not +// be filled. It tries to fill both the HTML and plain text templates, but if +// any of them is missing, it will return an error. If some of the placeholders +// in the template are not filled, they will be left as they are. +func (temp *EmailTemplate) Compose(to, subject string, data any) (*Email, error) { + if to == "" || subject == "" { + return nil, ErrComposeEmail + } + // compose the html body + body, err := temp.composeHTML(data) + if err != nil { + return nil, err + } + // compose the plain body + plainBody, err := temp.composePlain(data) + if err != nil { + return nil, err + } + // if both bodies are empty, return an error + if plainBody == nil && body == nil { + return nil, ErrInvalidTemplate + } + // return the email with the filled bodies + email := &Email{ + To: to, + Subject: subject, + Body: body, + PlainBody: plainBody, + } + if !email.Valid() { + return nil, ErrComposeEmail.Withf("resulting email is not valid") + } + return email, nil +} + +// composePlain method fills the plain text template with the data and returns the +// filled content as a byte slice. It returns the filled template or an error +// if the template could not be filled. If the plain text template is empty, it +// returns nil and no error. +func (temp *EmailTemplate) composePlain(data any) ([]byte, error) { + if temp.Plain == "" { + return nil, nil + } + // parse the placeholder plain body template + tmpl, err := texttemplate.New("plain").Parse(temp.Plain) + if err != nil { + return nil, err + } + // inflate the template with the data + buf := new(bytes.Buffer) + if err := tmpl.Execute(buf, data); err != nil { + return nil, err + } + // return the notification with the plain body filled with the data + return buf.Bytes(), nil +} + +// composeHTML method fills the HTML template with the data and returns the filled +// content as a byte slice. It returns the filled template or an error if the +// template could not be filled. If the HTML template is empty, it returns nil +// and no error. +func (temp *EmailTemplate) composeHTML(data any) ([]byte, error) { + if temp.HTML == "" { + return nil, nil + } + // parse the email template + tmpl, err := htmltemplate.New("html").Parse(temp.HTML) + if err != nil { + return nil, err + } + // inflate the template with the data + buf := new(bytes.Buffer) + if err := tmpl.Execute(buf, data); err != nil { + return nil, err + } + // set the body of the notification + return buf.Bytes(), nil +} diff --git a/email/template_test.go b/email/template_test.go new file mode 100644 index 0000000..9cb763d --- /dev/null +++ b/email/template_test.go @@ -0,0 +1,133 @@ +package email + +import "testing" + +var testTemplate = &EmailTemplate{ + HTML: "

{{.Title}}

{{.Content}}

", + Plain: "Title: {{.Title}}\nContent: {{.Content}}", +} + +type testData struct { + Title string + Content string +} + +func TestCompose(t *testing.T) { + // valid data + data := testData{ + Title: "Test Title", + Content: "Test Content", + } + email, err := testTemplate.Compose(testReceiver, testSubject, data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if email.To != testReceiver { + t.Fatalf("got %v, want %v", email.To, testReceiver) + } + if email.Subject != testSubject { + t.Fatalf("got %v, want %v", email.Subject, testSubject) + } + expectedBody := "

Test Title

Test Content

" + expectedPlain := "Title: Test Title\nContent: Test Content" + if string(email.Body) != expectedBody { + t.Fatalf("got %v, want %v", string(email.Body), expectedBody) + } + if string(email.PlainBody) != expectedPlain { + t.Fatalf("got %v, want %v", string(email.PlainBody), expectedPlain) + } + // no subject + if _, err := testTemplate.Compose(testReceiver, "", data); err == nil { + t.Fatalf("expected error, got nil") + } + // no to address + if _, err := testTemplate.Compose("", testSubject, data); err == nil { + t.Fatalf("expected error, got nil") + } + // bad to address + if _, err := testTemplate.Compose("wrongEmail", testSubject, data); err == nil { + t.Fatalf("expected error, got nil") + } + // invalid template + emptyTemplate := &EmailTemplate{} + if _, err := emptyTemplate.Compose(testReceiver, testSubject, data); err == nil { + t.Fatalf("expected error, got nil") + } + // no html template + noHTMLTemplate := &EmailTemplate{Plain: testTemplate.Plain} + onlyPlainEmail, err := noHTMLTemplate.Compose(testReceiver, testSubject, data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if onlyPlainEmail.Body != nil { + t.Fatalf("expected nil, got %v", string(onlyPlainEmail.Body)) + } + if string(onlyPlainEmail.PlainBody) != expectedPlain { + t.Fatalf("got %v, want %v", string(onlyPlainEmail.PlainBody), expectedPlain) + } + // no plain template + noPlainTemplate := &EmailTemplate{HTML: testTemplate.HTML} + onlyHTMLEmail, err := noPlainTemplate.Compose(testReceiver, testSubject, data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if string(onlyHTMLEmail.Body) != expectedBody { + t.Fatalf("got %v, want %v", string(onlyHTMLEmail.Body), expectedBody) + } + if onlyHTMLEmail.PlainBody != nil { + t.Fatalf("expected nil, got %v", string(onlyHTMLEmail.PlainBody)) + } + +} + +func Test_composePlain(t *testing.T) { + // valid data and template + data := testData{ + Title: "Test Title", + Content: "Test Content", + } + body, err := testTemplate.composePlain(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + expected := "Title: Test Title\nContent: Test Content" + if string(body) != expected { + t.Fatalf("got %v, want %v", string(body), expected) + } + // no plain template + wrongPlainTemplate := *testTemplate + wrongPlainTemplate.Plain = "" + body, err = wrongPlainTemplate.composePlain(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if body != nil { + t.Fatalf("expected nil, got %v", string(body)) + } +} + +func Test_composeHTML(t *testing.T) { + // valid data and template + data := testData{ + Title: "Test Title", + Content: "Test Content", + } + body, err := testTemplate.composeHTML(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + expected := "

Test Title

Test Content

" + if string(body) != expected { + t.Fatalf("got %v, want %v", string(body), expected) + } + // no html template + wrongHTMLTemplate := *testTemplate + wrongHTMLTemplate.HTML = "" + body, err = wrongHTMLTemplate.composeHTML(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if body != nil { + t.Fatalf("expected nil, got %v", string(body)) + } +} diff --git a/email/templates/login/definition.go b/email/templates/login/definition.go new file mode 100644 index 0000000..35255ac --- /dev/null +++ b/email/templates/login/definition.go @@ -0,0 +1,30 @@ +package login + +import ( + _ "embed" + + "github.com/simpleauthlink/authapi/email" +) + +//go:embed template.html +var htmlTemplate string + +// Data struct contains the required data to fill the login email template. +type Data struct { + AppName string + Email string + Token string + Link string +} + +// Template is the login email template definition, which contains the HTML +// and plain text templates. +var Template = email.EmailTemplate{ + HTML: htmlTemplate, + Plain: `Hi, {{.Email}} +You can login to '{{.AppName}}' following this link: +{{.Link}} +It contains your login token: '{{.Token}}' +Which is only valid for you and for a short period of time. +If you didn't request this, you can ignore this email.`, +} diff --git a/email/templates/login/template.html b/email/templates/login/template.html new file mode 100644 index 0000000..4a704f6 --- /dev/null +++ b/email/templates/login/template.html @@ -0,0 +1,75 @@ + + + + + + Your Magic Link for {{.AppName}} Login + + + + + + + + +
+ + + + + + + + + + + + + + + + +
+ +

SimpleAuth.link

+
+ 👋 Hi, {{.Email}}! +

+ Your magic link to login to '{{.AppName}}' is ready 🎉. +

+ Click the button below to login to your account. 👇 +
+ + + + + + + +
+ Login to + {{.AppName}} +
+
+ {{.Token}} +
+
+ or copy and paste the following link in your browser: +

+
{{.Link}}
+

+ If you did not request this, please ignore this email. +
+
+ + + \ No newline at end of file From c65f8ceef0ae9c35dd75fde3ddbec578c2486810 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Thu, 20 Feb 2025 19:39:19 +0100 Subject: [PATCH 11/36] update main with new changes and new workflow step to show the test coverage in a report --- .github/parse-test.js | 55 ++++++++++++++++++++++++++++++++++++++ .github/workflows/main.yml | 24 ++++++++++++++++- cmd/authapi/main.go | 37 +++++-------------------- 3 files changed, 84 insertions(+), 32 deletions(-) create mode 100644 .github/parse-test.js diff --git a/.github/parse-test.js b/.github/parse-test.js new file mode 100644 index 0000000..9483cc2 --- /dev/null +++ b/.github/parse-test.js @@ -0,0 +1,55 @@ +const readline = require("readline"); + +const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, + terminal: false, +}); + +const summary = { fail: [], pass: [], skip: [] }; + +rl.on("line", (line) => { + const output = JSON.parse(line); + if ( + output.Action === "pass" || + output.Action === "skip" || + output.Action === "fail" + ) { + if (output.Test) { + summary[output.Action].push(output); + } + } +}); + +function totalTime(entries) { + return entries.reduce((total, l) => total + l.Elapsed, 0); +} + +rl.on("close", () => { + console.log("## 📋 Tests executed"); + console.log("| | Number of Tests | Total Time |"); + console.log("|--|--|--|"); + console.log( + "| ✅ Passed | %d | %fs |", + summary.pass.length, + totalTime(summary.pass) + ); + console.log( + "| ❌ Failed | %d | %fs |", + summary.fail.length, + totalTime(summary.fail) + ); + console.log( + "| 🔜 Skipped | %d | %fs |", + summary.skip.length, + totalTime(summary.skip) + ); + + if (summary.fail.length > 0) { + console.log("\n## Failures\n"); + } + + summary.fail.forEach((test) => { + console.log("* %s (%s) %fs", test.Test, test.Package, test.Elapsed); + }); +}); \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7109b3b..518f23e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -47,4 +47,26 @@ jobs: with: go-version: "1.23" - name: Run Go test -race - run: go test ./... -race -timeout=1h + run: go test ./... -v -race -timeout=1h + - name: convert coverage to html + run: go tool cover -html=cover.out -o cover.html + - name: print test report + run: | + set -o pipefail && cat tests.log | node .github/parse-tests.js >> $GITHUB_STEP_SUMMARY + echo $GITHUB_STEP_SUMMARY + - name: print coverage result + run: | + go tool cover -func=cover.out > ./cover.txt + echo "
📏 Tests coverage" >> $GITHUB_STEP_SUMMARY + echo -e "\n\`\`\`" >> $GITHUB_STEP_SUMMARY + cat ./cover.txt >> $GITHUB_STEP_SUMMARY + echo -e "\`\`\`\n
" >> $GITHUB_STEP_SUMMARY + - name: store code coverage artifact + uses: actions/upload-artifact@v4 + with: + name: report + path: | + tests.log + cover.txt + cover.out + cover.html diff --git a/cmd/authapi/main.go b/cmd/authapi/main.go index b16ad1c..44f29c6 100644 --- a/cmd/authapi/main.go +++ b/cmd/authapi/main.go @@ -16,20 +16,15 @@ import ( const ( defaultHost = "0.0.0.0" defaultPort = 8080 - defaultDatabaseURI = "mongodb://admin:password@localhost:27017/" - defaultDatabaseName = "simpleauth" defaultEmailAddr = "" defaultEmailPass = "" defaultEmailHost = "" defaultEmailPort = 587 defaultTokenEmailTemplate = "assets/token_email_template.html" defaultAppEmailTemplate = "assets/app_email_template.html" - defaultDisposableSrcURL = "https://raw.githubusercontent.com/disposable-email-domains/disposable-email-domains/master/disposable_email_blocklist.conf" hostFlag = "host" portFlag = "port" - dbURIFlag = "db-uri" - dbNameFlag = "db-name" emailAddrFlag = "email-addr" emailPassFlag = "email-pass" emailHostFlag = "email-host" @@ -47,19 +42,15 @@ const ( emailPortFlagDesc = "email server port" tokenEmailTemplateDesc = "path to the html template of new token email" appEmailTemplateDesc = "path to the html template of new app email" - disposableSrcDesc = "source url of list of disposable emails domains" hostEnv = "SIMPLEAUTH_HOST" portEnv = "SIMPLEAUTH_PORT" - dbURIEnv = "SIMPLEAUTH_DB_URI" - dbNameEnv = "SIMPLEAUTH_DB_NAME" emailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" emailPassEnv = "SIMPLEAUTH_EMAIL_PASS" emailHostEnv = "SIMPLEAUTH_EMAIL_HOST" emailPortEnv = "SIMPLEAUTH_EMAIL_PORT" tokenEmailTemplateEnv = "SIMPLEAUTH_TOKEN_EMAIL_TEMPLATE" appEmailTemplateEnv = "SIMPLEAUTH_APP_EMAIL_TEMPLATE" - disposableSrcEnv = "SIMPLEAUTH_DISPOSABLE_SRC" ) type config struct { @@ -85,13 +76,12 @@ func main() { // create the service service, err := api.New(context.Background(), &api.Config{ EmailConfig: email.EmailConfig{ - Address: c.emailAddr, - Password: c.emailPass, - EmailHost: c.emailHost, - EmailPort: c.emailPort, - DisposableSrc: c.disposableSrc, - TokenEmailTemplate: c.tokenEmailTemplate, - AppEmailTemplate: c.appEmailTemplate, + FromName: "SimpleAuthLink", + FromAddress: c.emailAddr, + SMTPUsername: c.emailAddr, + SMTPPassword: c.emailPass, + SMTPServer: c.emailHost, + SMTPPort: c.emailPort, }, Server: c.host, ServerPort: c.port, @@ -115,28 +105,22 @@ func parseConfig() (*config, error) { // get config from flags flag.StringVar(&fhost, hostFlag, defaultHost, hostFlagDesc) flag.IntVar(&fport, portFlag, defaultPort, hostFlagDesc) - flag.StringVar(&fdbURI, dbURIFlag, defaultDatabaseURI, dbURIFlagDesc) - flag.StringVar(&fdbName, dbNameFlag, defaultDatabaseName, dbNameFlagDesc) flag.StringVar(&femailAddr, emailAddrFlag, defaultEmailAddr, emailAddrFlagDesc) flag.StringVar(&femailPass, emailPassFlag, defaultEmailPass, emailPassFlagDesc) flag.StringVar(&femailHost, emailHostFlag, defaultEmailHost, emailHostFlagDesc) flag.StringVar(&ftokenEmailTemplate, tokenEmailTemplateFlag, defaultTokenEmailTemplate, tokenEmailTemplateDesc) flag.StringVar(&fappEmailTemplate, appEmailTemplateFlag, defaultAppEmailTemplate, appEmailTemplateDesc) flag.IntVar(&femailPort, emailPortFlag, defaultEmailPort, emailPortFlagDesc) - flag.StringVar(&fdisposableSrc, disposableSrcFlag, defaultDisposableSrcURL, disposableSrcDesc) flag.Parse() // get config from env envHost := os.Getenv(hostEnv) envPort := os.Getenv(portEnv) - envDBURI := os.Getenv(dbURIEnv) - envDBName := os.Getenv(dbNameEnv) envEmailAddr := os.Getenv(emailAddrEnv) envEmailPass := os.Getenv(emailPassEnv) envEmailHost := os.Getenv(emailHostEnv) envEmailPort := os.Getenv(emailPortEnv) envtokenEmailTemplate := os.Getenv(tokenEmailTemplateEnv) envAppEmailTemplate := os.Getenv(appEmailTemplateEnv) - envDisposableSrc := os.Getenv(disposableSrcEnv) // check if the required flags are set if femailAddr == "" && envEmailAddr == "" { @@ -173,12 +157,6 @@ func parseConfig() (*config, error) { return nil, fmt.Errorf("invalid port value: %s", envPort) } } - if envDBURI != "" { - c.dbURI = envDBURI - } - if envDBName != "" { - c.dbName = envDBName - } if envEmailAddr != "" { c.emailAddr = envEmailAddr } @@ -201,8 +179,5 @@ func parseConfig() (*config, error) { if envAppEmailTemplate != "" { c.appEmailTemplate = envAppEmailTemplate } - if envDisposableSrc != "" { - c.disposableSrc = envDisposableSrc - } return c, nil } From 5fd5ae5e0fca923a28a51df0f64c62c81fbb662b Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Thu, 20 Feb 2025 20:11:34 +0100 Subject: [PATCH 12/36] race fixed --- email/emailqueue.go | 18 +++++++++--------- email/emailqueue_test.go | 32 ++++++++++++++------------------ 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/email/emailqueue.go b/email/emailqueue.go index dbf1e96..bd9110f 100644 --- a/email/emailqueue.go +++ b/email/emailqueue.go @@ -112,8 +112,8 @@ func (eq *EmailQueue) Start() { case <-eq.ctx.Done(): return default: - e := eq.Pop() - if e == nil { + e, ok := eq.Pop() + if !ok { continue } if err := eq.Send(e); err != nil { @@ -133,27 +133,27 @@ func (eq *EmailQueue) Stop() { } // Push method adds a new email to the queue. -func (eq *EmailQueue) Push(e *Email) error { +func (eq *EmailQueue) Push(e Email) error { // check if the email is valid if !e.Valid() { return ErrInvalidEmail } eq.itemsMtx.Lock() - eq.items = append(eq.items, e) + eq.items = append(eq.items, &e) eq.itemsMtx.Unlock() return nil } // Pop method removes the first email in the queue and returns it. -func (eq *EmailQueue) Pop() *Email { +func (eq *EmailQueue) Pop() (Email, bool) { eq.itemsMtx.Lock() defer eq.itemsMtx.Unlock() if len(eq.items) == 0 { - return nil + return Email{}, false } e := eq.items[0] eq.items = eq.items[1:] - return e + return *e, true } // Send method sends the email using the queue configuration. It uses the @@ -162,7 +162,7 @@ func (eq *EmailQueue) Pop() *Email { // credentials, the server string with the host and the port, and the receipts. // Finally, it sends the email. If something fails during the process, it // returns an error. It can be used even the queue is not started. -func (eq *EmailQueue) Send(e *Email) error { +func (eq *EmailQueue) Send(e Email) error { // check if the email is valid if !e.Valid() { return ErrInvalidEmail @@ -190,7 +190,7 @@ func (eq *EmailQueue) Send(e *Email) error { // composeBody creates the email body with the message data. It creates a // multipart email with a plain text and an HTML part. It returns the email // content as a byte slice or an error if the body could not be composed. -func (eq *EmailQueue) composeBody(msg *Email) ([]byte, error) { +func (eq *EmailQueue) composeBody(msg Email) ([]byte, error) { // parse 'to' email address to, err := mail.ParseAddress(msg.To) if err != nil { diff --git a/email/emailqueue_test.go b/email/emailqueue_test.go index 45e61a9..d053388 100644 --- a/email/emailqueue_test.go +++ b/email/emailqueue_test.go @@ -185,7 +185,7 @@ func TestSendEmail(t *testing.T) { t.Fatal(err) } // send email - if err := eq.Send(&Email{ + if err := eq.Send(Email{ To: testReceiver, Subject: testSubject, Body: []byte(testHTMLBody), @@ -208,6 +208,16 @@ func TestSendEmail(t *testing.T) { case <-time.After(2 * time.Second): t.Error("timed out waiting for the email to be received") } + // try to send invalid email + if err := eq.Send(Email{}); err == nil { + t.Error("expected error sending invalid email") + } + // try to compose a invalid email + if body, err := eq.composeBody(Email{}); err == nil { + t.Error("expected error composing invalid email") + } else if body != nil { + t.Error("expected body to be nil") + } // try to send email to an invalid SMTP server badEq, err := NewEmailQueue(ctx, &EmailConfig{ SMTPServer: testServerAddr, @@ -218,7 +228,7 @@ func TestSendEmail(t *testing.T) { if err != nil { t.Fatal(err) } - if err := badEq.Send(&Email{ + if err := badEq.Send(Email{ To: testReceiver, Subject: testSubject, Body: nil, @@ -246,7 +256,7 @@ func TestPushSendEmail(t *testing.T) { eq.Start() defer eq.Stop() // push email - if err := eq.Push(&Email{ + if err := eq.Push(Email{ To: testReceiver, Subject: testSubject, Body: nil, @@ -269,21 +279,7 @@ func TestPushSendEmail(t *testing.T) { // sleep to pop nil email time.Sleep(2 * time.Second) // push invalid email - if err := eq.Push(&Email{}); err == nil { + if err := eq.Push(Email{}); err == nil { t.Error("expected error pushing invalid email") } - // push another email and modify the email queue to fail - email := &Email{ - To: testReceiver, - Subject: testSubject, - Body: nil, - PlainBody: []byte(testBody), - } - if err := eq.Push(email); err != nil { - t.Fatal(err) - } - email.PlainBody = nil - if err := <-errCh; err == nil { - t.Error("expected error sending email") - } } From aaf1ca2b0c325f99f0e501557e4e30fac78a6210 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Thu, 20 Feb 2025 20:14:01 +0100 Subject: [PATCH 13/36] fix workflow --- .github/{parse-test.js => parse-coverage-report.js} | 0 .github/workflows/main.yml | 13 ++++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) rename .github/{parse-test.js => parse-coverage-report.js} (100%) diff --git a/.github/parse-test.js b/.github/parse-coverage-report.js similarity index 100% rename from .github/parse-test.js rename to .github/parse-coverage-report.js diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 518f23e..609f74a 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -48,20 +48,23 @@ jobs: go-version: "1.23" - name: Run Go test -race run: go test ./... -v -race -timeout=1h - - name: convert coverage to html + - name: Rerun Go test to generate coverage report + run: | + go test -v --race -timeout 15m -coverprofile=./cover.out -json ./... > tests.log + - name: Convert report to html run: go tool cover -html=cover.out -o cover.html - - name: print test report + - name: Print coverage report run: | - set -o pipefail && cat tests.log | node .github/parse-tests.js >> $GITHUB_STEP_SUMMARY + set -o pipefail && cat tests.log | node .github/parse-coverage-report.js >> $GITHUB_STEP_SUMMARY echo $GITHUB_STEP_SUMMARY - - name: print coverage result + - name: Print coverage report run: | go tool cover -func=cover.out > ./cover.txt echo "
📏 Tests coverage" >> $GITHUB_STEP_SUMMARY echo -e "\n\`\`\`" >> $GITHUB_STEP_SUMMARY cat ./cover.txt >> $GITHUB_STEP_SUMMARY echo -e "\`\`\`\n
" >> $GITHUB_STEP_SUMMARY - - name: store code coverage artifact + - name: Store coverage report uses: actions/upload-artifact@v4 with: name: report From cbab2f70269eb23e784c3897d30df46a379bb79e Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 2 Mar 2025 01:17:23 +0100 Subject: [PATCH 14/36] new notifications and mail templates, including login mail template --- {email => notification/email}/emailqueue.go | 71 +++++++----------- .../email}/emailqueue_test.go | 73 ++++++++++++------- {email => notification/email}/errors.go | 0 {email => notification/email}/template.go | 22 +++--- .../email}/template_test.go | 44 ++++++++--- notification/notification.go | 33 +++++++++ .../templates/login/definition.go | 9 ++- .../templates/login/template.html | 24 ++++-- 8 files changed, 170 insertions(+), 106 deletions(-) rename {email => notification/email}/emailqueue.go (83%) rename {email => notification/email}/emailqueue_test.go (82%) rename {email => notification/email}/errors.go (100%) rename {email => notification/email}/template.go (84%) rename {email => notification/email}/template_test.go (74%) create mode 100644 notification/notification.go rename {email => notification}/templates/login/definition.go (69%) rename {email => notification}/templates/login/template.html (72%) diff --git a/email/emailqueue.go b/notification/email/emailqueue.go similarity index 83% rename from email/emailqueue.go rename to notification/email/emailqueue.go index bd9110f..f381db3 100644 --- a/email/emailqueue.go +++ b/notification/email/emailqueue.go @@ -9,30 +9,13 @@ import ( "net/smtp" "net/textproto" "sync" + + "github.com/simpleauthlink/authapi/notification" ) // defaultSendRetries is the default number of retries to send the email. const defaultSendRetries = 3 -// Email struct represents the email that is going to be sent. It includes the -// recipient email address, the subject and the body of the email. -type Email struct { - To string - Subject string - Body []byte - PlainBody []byte -} - -// Valid method checks if the email is valid. It returns true if the recipient -// email address, the subject and the body are not empty. -func (e *Email) Valid() bool { - if e.Subject == "" || (len(e.Body) == 0 && len(e.PlainBody) == 0) { - return false - } - _, err := mail.ParseAddress(e.To) - return err == nil -} - // EmailConfig struct represents the email configuration that is needed to send // an email using and SMTP server. It includes the email address (used as the // sender address but also as the username for the SMTP server), the email @@ -72,7 +55,7 @@ type EmailQueue struct { cancel context.CancelFunc cfg *EmailConfig auth smtp.Auth - items []*Email + items []*notification.Notification itemsMtx sync.Mutex waiter sync.WaitGroup errCh chan error @@ -90,7 +73,7 @@ func NewEmailQueue(ctx context.Context, cfg *EmailConfig) (*EmailQueue, error) { ctx: internalCtx, cancel: cancel, cfg: cfg, - items: []*Email{}, + items: []*notification.Notification{}, errCh: cfg.ErrorCh, } // init SMTP auth @@ -132,49 +115,49 @@ func (eq *EmailQueue) Stop() { eq.waiter.Wait() } -// Push method adds a new email to the queue. -func (eq *EmailQueue) Push(e Email) error { - // check if the email is valid - if !e.Valid() { - return ErrInvalidEmail - } - eq.itemsMtx.Lock() - eq.items = append(eq.items, &e) - eq.itemsMtx.Unlock() - return nil -} - // Pop method removes the first email in the queue and returns it. -func (eq *EmailQueue) Pop() (Email, bool) { +func (eq *EmailQueue) Pop() (notification.Notification, bool) { eq.itemsMtx.Lock() defer eq.itemsMtx.Unlock() if len(eq.items) == 0 { - return Email{}, false + return notification.Notification{}, false } e := eq.items[0] eq.items = eq.items[1:] return *e, true } +// Push method adds a new email to the queue. +func (eq *EmailQueue) Push(n notification.Notification) error { + // check if the email is valid + if !n.Valid() { + return ErrInvalidEmail + } + eq.itemsMtx.Lock() + eq.items = append(eq.items, &n) + eq.itemsMtx.Unlock() + return nil +} + // Send method sends the email using the queue configuration. It uses the // email address as the sender address and the username for the SMTP server. // It composes the email message, creates the auth object with the email // credentials, the server string with the host and the port, and the receipts. // Finally, it sends the email. If something fails during the process, it // returns an error. It can be used even the queue is not started. -func (eq *EmailQueue) Send(e Email) error { +func (eq *EmailQueue) Send(n notification.Notification) error { // check if the email is valid - if !e.Valid() { + if !n.Valid() { return ErrInvalidEmail } // compose the email body - body, err := eq.composeBody(e) + body, err := eq.composeBody(n) if err != nil { return ErrComposeEmail.With(err) } // create the server string with the host and the port and the receipts server := fmt.Sprintf("%s:%d", eq.cfg.SMTPServer, eq.cfg.SMTPPort) - receipts := []string{e.To} + receipts := []string{n.Params.To} // send the email for i := 0; i < eq.cfg.Retries; i++ { if err = smtp.SendMail(server, eq.auth, eq.cfg.FromAddress, receipts, body); err == nil { @@ -190,9 +173,9 @@ func (eq *EmailQueue) Send(e Email) error { // composeBody creates the email body with the message data. It creates a // multipart email with a plain text and an HTML part. It returns the email // content as a byte slice or an error if the body could not be composed. -func (eq *EmailQueue) composeBody(msg Email) ([]byte, error) { +func (eq *EmailQueue) composeBody(n notification.Notification) ([]byte, error) { // parse 'to' email address - to, err := mail.ParseAddress(msg.To) + to, err := mail.ParseAddress(n.Params.To) if err != nil { return nil, ErrParseAddress.With(err) } @@ -201,7 +184,7 @@ func (eq *EmailQueue) composeBody(msg Email) ([]byte, error) { boundary := "----=_Part_0_123456789.123456789" headers.WriteString(fmt.Sprintf("From: %s\r\n", eq.cfg.FromAddress)) headers.WriteString(fmt.Sprintf("To: %s\r\n", to.String())) - headers.WriteString(fmt.Sprintf("Subject: %s\r\n", msg.Subject)) + headers.WriteString(fmt.Sprintf("Subject: %s\r\n", n.Params.Subject)) headers.WriteString("MIME-Version: 1.0\r\n") headers.WriteString(fmt.Sprintf("Content-Type: multipart/alternative; boundary=\"%s\"\r\n", boundary)) headers.WriteString("\r\n") // blank line between headers and body @@ -216,7 +199,7 @@ func (eq *EmailQueue) composeBody(msg Email) ([]byte, error) { "Content-Type": {"text/plain; charset=\"UTF-8\""}, "Content-Transfer-Encoding": {"7bit"}, }) - if _, err := textPart.Write(msg.PlainBody); err != nil { + if _, err := textPart.Write(n.PlainBody); err != nil { return nil, ErrWriteBody.With(err) } // HTML part @@ -224,7 +207,7 @@ func (eq *EmailQueue) composeBody(msg Email) ([]byte, error) { "Content-Type": {"text/html; charset=\"UTF-8\""}, "Content-Transfer-Encoding": {"7bit"}, }) - if _, err := htmlPart.Write(msg.Body); err != nil { + if _, err := htmlPart.Write(n.Body); err != nil { return nil, ErrWriteHTMLBody.With(err) } if err := writer.Close(); err != nil { diff --git a/email/emailqueue_test.go b/notification/email/emailqueue_test.go similarity index 82% rename from email/emailqueue_test.go rename to notification/email/emailqueue_test.go index d053388..feb4a1c 100644 --- a/email/emailqueue_test.go +++ b/notification/email/emailqueue_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/simpleauthlink/authapi/internal" + "github.com/simpleauthlink/authapi/notification" ) const ( @@ -37,47 +38,57 @@ func TestMain(m *testing.M) { } func TestValidEmail(t *testing.T) { - if !(&Email{ - To: testReceiver, - Subject: testSubject, + if !(¬ification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, Body: nil, PlainBody: []byte(testBody), }).Valid() { t.Error("expected email to be valid") } - if !(&Email{ - To: testReceiver, - Subject: testSubject, + if !(¬ification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, Body: []byte(testBody), PlainBody: nil, }).Valid() { t.Error("expected email to be valid") } - if (&Email{ - To: testReceiver, - Subject: "", + if (¬ification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: "", + }, Body: nil, PlainBody: []byte(testBody), }).Valid() { t.Error("expected email to be invalid") } - if (&Email{ - To: "", - Subject: testSubject, + if (¬ification.Notification{ + Params: notification.NotificationParams{ + To: "", + Subject: testSubject, + }, Body: nil, PlainBody: []byte(testBody), }).Valid() { t.Error("expected email to be invalid") } - if (&Email{ - To: "invalidEmail", - Subject: testSubject, + if (¬ification.Notification{ + Params: notification.NotificationParams{ + To: "invalidEmail", + Subject: testSubject, + }, Body: nil, PlainBody: []byte(testBody), }).Valid() { t.Error("expected email to be invalid") } - if (&Email{}).Valid() { + if (¬ification.Notification{}).Valid() { t.Error("expected email to be invalid") } } @@ -185,9 +196,11 @@ func TestSendEmail(t *testing.T) { t.Fatal(err) } // send email - if err := eq.Send(Email{ - To: testReceiver, - Subject: testSubject, + if err := eq.Send(notification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, Body: []byte(testHTMLBody), PlainBody: []byte(testBody), }); err != nil { @@ -209,11 +222,11 @@ func TestSendEmail(t *testing.T) { t.Error("timed out waiting for the email to be received") } // try to send invalid email - if err := eq.Send(Email{}); err == nil { + if err := eq.Send(notification.Notification{}); err == nil { t.Error("expected error sending invalid email") } // try to compose a invalid email - if body, err := eq.composeBody(Email{}); err == nil { + if body, err := eq.composeBody(notification.Notification{}); err == nil { t.Error("expected error composing invalid email") } else if body != nil { t.Error("expected body to be nil") @@ -228,9 +241,11 @@ func TestSendEmail(t *testing.T) { if err != nil { t.Fatal(err) } - if err := badEq.Send(Email{ - To: testReceiver, - Subject: testSubject, + if err := badEq.Send(notification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, Body: nil, PlainBody: []byte(testBody), }); err == nil { @@ -256,9 +271,11 @@ func TestPushSendEmail(t *testing.T) { eq.Start() defer eq.Stop() // push email - if err := eq.Push(Email{ - To: testReceiver, - Subject: testSubject, + if err := eq.Push(notification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, Body: nil, PlainBody: []byte(testBody), }); err != nil { @@ -279,7 +296,7 @@ func TestPushSendEmail(t *testing.T) { // sleep to pop nil email time.Sleep(2 * time.Second) // push invalid email - if err := eq.Push(Email{}); err == nil { + if err := eq.Push(notification.Notification{}); err == nil { t.Error("expected error pushing invalid email") } } diff --git a/email/errors.go b/notification/email/errors.go similarity index 100% rename from email/errors.go rename to notification/email/errors.go diff --git a/email/template.go b/notification/email/template.go similarity index 84% rename from email/template.go rename to notification/email/template.go index 9f4fada..4e9bb7a 100644 --- a/email/template.go +++ b/notification/email/template.go @@ -4,6 +4,8 @@ import ( "bytes" htmltemplate "html/template" texttemplate "text/template" + + "github.com/simpleauthlink/authapi/notification" ) // EmailTemplate is the definition of an email template, which contains the @@ -18,34 +20,30 @@ type EmailTemplate struct { // be filled. It tries to fill both the HTML and plain text templates, but if // any of them is missing, it will return an error. If some of the placeholders // in the template are not filled, they will be left as they are. -func (temp *EmailTemplate) Compose(to, subject string, data any) (*Email, error) { - if to == "" || subject == "" { - return nil, ErrComposeEmail +func (temp *EmailTemplate) Compose(params notification.NotificationParams, data any) (notification.Notification, error) { + if !params.Valid() { + return notification.Notification{}, ErrComposeEmail } // compose the html body body, err := temp.composeHTML(data) if err != nil { - return nil, err + return notification.Notification{}, err } // compose the plain body plainBody, err := temp.composePlain(data) if err != nil { - return nil, err + return notification.Notification{}, err } // if both bodies are empty, return an error if plainBody == nil && body == nil { - return nil, ErrInvalidTemplate + return notification.Notification{}, ErrInvalidTemplate } // return the email with the filled bodies - email := &Email{ - To: to, - Subject: subject, + email := notification.Notification{ + Params: params, Body: body, PlainBody: plainBody, } - if !email.Valid() { - return nil, ErrComposeEmail.Withf("resulting email is not valid") - } return email, nil } diff --git a/email/template_test.go b/notification/email/template_test.go similarity index 74% rename from email/template_test.go rename to notification/email/template_test.go index 9cb763d..413d3b0 100644 --- a/email/template_test.go +++ b/notification/email/template_test.go @@ -1,6 +1,10 @@ package email -import "testing" +import ( + "testing" + + "github.com/simpleauthlink/authapi/notification" +) var testTemplate = &EmailTemplate{ HTML: "

{{.Title}}

{{.Content}}

", @@ -18,15 +22,18 @@ func TestCompose(t *testing.T) { Title: "Test Title", Content: "Test Content", } - email, err := testTemplate.Compose(testReceiver, testSubject, data) + email, err := testTemplate.Compose(notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, data) if err != nil { t.Fatalf("expected nil, got error: %v", err) } - if email.To != testReceiver { - t.Fatalf("got %v, want %v", email.To, testReceiver) + if email.Params.To != testReceiver { + t.Fatalf("got %v, want %v", email.Params.To, testReceiver) } - if email.Subject != testSubject { - t.Fatalf("got %v, want %v", email.Subject, testSubject) + if email.Params.Subject != testSubject { + t.Fatalf("got %v, want %v", email.Params.Subject, testSubject) } expectedBody := "

Test Title

Test Content

" expectedPlain := "Title: Test Title\nContent: Test Content" @@ -37,25 +44,38 @@ func TestCompose(t *testing.T) { t.Fatalf("got %v, want %v", string(email.PlainBody), expectedPlain) } // no subject - if _, err := testTemplate.Compose(testReceiver, "", data); err == nil { + if _, err := testTemplate.Compose(notification.NotificationParams{ + To: testReceiver, + Subject: "", + }, data); err == nil { t.Fatalf("expected error, got nil") } // no to address - if _, err := testTemplate.Compose("", testSubject, data); err == nil { + if _, err := testTemplate.Compose(notification.NotificationParams{ + To: "", + Subject: testSubject, + }, data); err == nil { t.Fatalf("expected error, got nil") } // bad to address - if _, err := testTemplate.Compose("wrongEmail", testSubject, data); err == nil { + if _, err := testTemplate.Compose(notification.NotificationParams{ + To: "bad email", + Subject: testSubject, + }, data); err == nil { t.Fatalf("expected error, got nil") } // invalid template emptyTemplate := &EmailTemplate{} - if _, err := emptyTemplate.Compose(testReceiver, testSubject, data); err == nil { + validParams := notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + } + if _, err := emptyTemplate.Compose(validParams, data); err == nil { t.Fatalf("expected error, got nil") } // no html template noHTMLTemplate := &EmailTemplate{Plain: testTemplate.Plain} - onlyPlainEmail, err := noHTMLTemplate.Compose(testReceiver, testSubject, data) + onlyPlainEmail, err := noHTMLTemplate.Compose(validParams, data) if err != nil { t.Fatalf("expected nil, got error: %v", err) } @@ -67,7 +87,7 @@ func TestCompose(t *testing.T) { } // no plain template noPlainTemplate := &EmailTemplate{HTML: testTemplate.HTML} - onlyHTMLEmail, err := noPlainTemplate.Compose(testReceiver, testSubject, data) + onlyHTMLEmail, err := noPlainTemplate.Compose(validParams, data) if err != nil { t.Fatalf("expected nil, got error: %v", err) } diff --git a/notification/notification.go b/notification/notification.go new file mode 100644 index 0000000..308c79a --- /dev/null +++ b/notification/notification.go @@ -0,0 +1,33 @@ +package notification + +import "net/mail" + +type NotificationParams struct { + To string + Subject string +} + +func (p NotificationParams) Valid() bool { + _, err := mail.ParseAddress(p.To) + return err == nil && p.Subject != "" +} + +type Notification struct { + Params NotificationParams + Body []byte + PlainBody []byte +} + +// Valid method checks if the email is valid. It returns true if the recipient +// email address, the subject and the body are not empty. +func (n *Notification) Valid() bool { + return n.Params.Valid() && max(len(n.Body), len(n.PlainBody)) > 0 +} + +type Queue interface { + Start() + Stop() + Pop() (Notification, bool) + Push(Notification) error + Send(Notification) error +} diff --git a/email/templates/login/definition.go b/notification/templates/login/definition.go similarity index 69% rename from email/templates/login/definition.go rename to notification/templates/login/definition.go index 35255ac..35bc3d3 100644 --- a/email/templates/login/definition.go +++ b/notification/templates/login/definition.go @@ -3,7 +3,7 @@ package login import ( _ "embed" - "github.com/simpleauthlink/authapi/email" + "github.com/simpleauthlink/authapi/notification/email" ) //go:embed template.html @@ -17,12 +17,17 @@ type Data struct { Link string } +// Subject returns the email subject based on the login data. +func (d Data) Subject() string { + return "Your token for '" + d.AppName + "'" +} + // Template is the login email template definition, which contains the HTML // and plain text templates. var Template = email.EmailTemplate{ HTML: htmlTemplate, Plain: `Hi, {{.Email}} -You can login to '{{.AppName}}' following this link: +You can access to '{{.AppName}}' app using the following link: {{.Link}} It contains your login token: '{{.Token}}' Which is only valid for you and for a short period of time. diff --git a/email/templates/login/template.html b/notification/templates/login/template.html similarity index 72% rename from email/templates/login/template.html rename to notification/templates/login/template.html index 4a704f6..ec1476c 100644 --- a/email/templates/login/template.html +++ b/notification/templates/login/template.html @@ -12,7 +12,7 @@ - + -
@@ -20,23 +20,25 @@ style="border-collapse: collapse; border: 1px solid #cccccc;">
- + Logo

SimpleAuth.link

- 👋 Hi, {{.Email}}! + +

👋 Hi, {{.Email}}!



- Your magic link to login to '{{.AppName}}' is ready 🎉. + Your magic link is ready! 🎉

+ You can access to your {{.AppName}} account using it 🔐. +
Click the button below to login to your account. 👇
- + -
Login to @@ -44,9 +46,15 @@

SimpleAuth

+ +
+ It contains your login token: +
+
{{.Token}}
+
+ Which is only valid for you and for a short period of time.
- {{.Token}} + If you didn't request this, you can ignore this email.
From 68b0ec581b756b1072b18c0983923d8e751741cb Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 2 Mar 2025 01:21:11 +0100 Subject: [PATCH 15/36] new token package with new structs to better data managment, including expirations, secrets and tokens --- token/app_test.go | 2 +- token/consts.go | 2 +- token/expiration.go | 37 ++++++--- token/expiration_test.go | 54 ++++++++----- token/id.go | 93 +++++++++++++---------- token/id_test.go | 97 ++++++++++++++++-------- token/secret.go | 36 +++++++++ token/secret_test.go | 53 +++++++++++++ token/token.go | 106 ++++++++++++++++++++++++++ token/token_test.go | 159 +++++++++++++++++++++++++++++++++++++++ 10 files changed, 538 insertions(+), 101 deletions(-) create mode 100644 token/secret.go create mode 100644 token/secret_test.go create mode 100644 token/token.go create mode 100644 token/token_test.go diff --git a/token/app_test.go b/token/app_test.go index c53211b..3fed8e8 100644 --- a/token/app_test.go +++ b/token/app_test.go @@ -12,7 +12,7 @@ const ( testSessionDuration = time.Minute * 30 ) -var testAppSecret = []byte("super_secret_key") +var testAppSecret = new(Secret).SetParts([]byte("super_secret_key"), []byte("super_secret_salt")) func TestValidApp(t *testing.T) { app := &App{ diff --git a/token/consts.go b/token/consts.go index 989005d..0de326b 100644 --- a/token/consts.go +++ b/token/consts.go @@ -9,7 +9,7 @@ const ( appDataSeparator = "|" appNameMinLen = 3 appNameMaxLen = 20 - redirectURIPattern = `^https?://[a-zA-Z0-9-]+(\.[a-zA-Z0-9-]+)+(/[a-zA-Z0-9-._~:/?#[\]@!$&'()*+,;=]*)?$` + redirectURIPattern = `^https?://(?:localhost|[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)+)(?::\d+)?(/[a-zA-Z0-9-._~:/?#[\]@!$&'()*+,;=]*)?$` redirectURIMaxLen = 80 minDuration = 30 * time.Second maxDuration = 180 * 24 * time.Hour diff --git a/token/expiration.go b/token/expiration.go index 6f6ab12..6444a90 100644 --- a/token/expiration.go +++ b/token/expiration.go @@ -7,24 +7,43 @@ import ( type Expiration time.Time -func NewExpiration(d time.Duration) *Expiration { - if d < minDuration || d > maxDuration { - return nil - } - exp := Expiration(time.Now().Add(d)) - return &exp +func (exp *Expiration) Valid() bool { + return time.Now().Before(exp.Time()) } func (exp *Expiration) Time() time.Time { return time.Time(*exp) } -func (exp *Expiration) Valid() bool { - return time.Now().Before(exp.Time()) +func (exp *Expiration) SetTime(t time.Time) *Expiration { + // if no expiration is provided, initialize a new one + if exp == nil { + exp = new(Expiration) + } + // set the expiration time + *exp = Expiration(t) + return exp +} + +func (exp *Expiration) Duration() time.Duration { + return time.Until(exp.Time()) +} + +func (exp *Expiration) SetDuration(d time.Duration) *Expiration { + if d < minDuration || d > maxDuration { + return nil + } + if exp == nil { + exp = new(Expiration) + } + return exp.SetTime(time.Now().Add(d)) } func (exp *Expiration) String() string { - t := time.Time(*exp) + if exp == nil { + return "" + } + t := exp.Time() if t.IsZero() { return "" } diff --git a/token/expiration_test.go b/token/expiration_test.go index ffcb84c..2d49f0a 100644 --- a/token/expiration_test.go +++ b/token/expiration_test.go @@ -6,36 +6,52 @@ import ( "time" ) -func TestNewExpirationTime(t *testing.T) { - exp := NewExpiration(minDuration - 1) - if exp != nil { - t.Fatalf("expected nil, got %v", exp) +func TestValidExpiration(t *testing.T) { + t.Parallel() + exp := new(Expiration).SetDuration(minDuration + time.Second) + if !exp.Valid() { + t.Errorf("expected valid expiration, got invalid") + } + time.Sleep(minDuration + (time.Second * 2)) + if exp.Valid() { + t.Errorf("expected invalid expiration, got valid") + } +} + +func TestTimeSetTimeExpiration(t *testing.T) { + var nilExp *Expiration + exp := nilExp.SetTime(time.Now().Add(minDuration * 2)) + if exp == nil { + t.Fatalf("expected valid expiration, got nil") } - exp = NewExpiration(minDuration * 2) + exp = new(Expiration).SetTime(time.Now().Add(minDuration * 2)) if exp == nil { t.Fatalf("expected valid expiration, got nil") } expTime := exp.Time() expected := time.Now().Add(minDuration * 2) - if expected.Sub(expTime) > time.Millisecond*50 { + if expected.Sub(expTime) > time.Millisecond*300 { t.Errorf("expected %v, got %v", expected, expTime) } } -func TestExpirationValid(t *testing.T) { - t.Parallel() - exp := NewExpiration(minDuration) - if !exp.Valid() { - t.Errorf("expected valid expiration, got invalid") +func TestDurationSetDurationExpiration(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration - 1) + if exp != nil { + t.Fatalf("expected nil, got %v", exp) } - time.Sleep(minDuration) - if exp.Valid() { - t.Errorf("expected invalid expiration, got valid") + exp = exp.SetDuration(minDuration * 2) + if exp == nil { + t.Fatalf("expected valid expiration, got nil") + } + expectedDuration := time.Duration(minDuration * 2) + if expectedDuration-exp.Duration() > time.Millisecond*300 { + t.Errorf("expected %v, got %v", expectedDuration, exp.Duration()) } } func TestStringSetStringExpiration(t *testing.T) { - exp := NewExpiration(minDuration) + exp := new(Expiration).SetDuration(minDuration * 2) str := exp.String() decoded := new(Expiration).SetString(str) if decoded == nil { @@ -50,10 +66,14 @@ func TestStringSetStringExpiration(t *testing.T) { if exp := new(Expiration).String(); exp != "" { t.Errorf("expected empty string, got %v", exp) } + var nilExp *Expiration + if exp := nilExp.String(); exp != "" { + t.Errorf("expected empty string, got %v", exp) + } } func TestBytesSetBytesExpiration(t *testing.T) { - exp := NewExpiration(minDuration) + exp := new(Expiration).SetDuration(minDuration * 2) b := exp.Bytes() decoded := new(Expiration).SetBytes(b) if decoded == nil { @@ -71,7 +91,7 @@ func TestBytesSetBytesExpiration(t *testing.T) { } func TestMarshalUnmarshalExpiration(t *testing.T) { - exp := NewExpiration(minDuration) + exp := new(Expiration).SetDuration(minDuration * 2) encoded := exp.Marshal() decoded := new(Expiration).Unmarshal(encoded) if decoded == nil { diff --git a/token/id.go b/token/id.go index 4fc6225..6944296 100644 --- a/token/id.go +++ b/token/id.go @@ -1,11 +1,10 @@ package token import ( - "bytes" "crypto/ed25519" "crypto/hmac" "crypto/sha256" - "encoding/hex" + "encoding/base64" ) type AppID string @@ -31,86 +30,100 @@ func (id *AppID) SetBytes(data []byte) *AppID { return id.SetString(string(data)) } -func (id *AppID) PrivKey(secret []byte) ed25519.PrivateKey { - hFn := hmac.New(sha256.New, secret) - +func (id *AppID) PrivKey(secret Secret) ed25519.PrivateKey { + if id == nil { + return nil + } + if !secret.Valid() { + return nil + } bID := id.Bytes() if len(bID) == 0 { return nil } - // hID := sha256.Sum256(bID) + hFn := hmac.New(sha256.New, secret.Bytes()) hID := hFn.Sum(bID) return ed25519.NewKeyFromSeed(hID[:32]) } -func (id *AppID) Sign(secret, msg []byte) []byte { +func (id *AppID) Sign(secret Secret, msg []byte) []byte { + if id == nil || len(msg) == 0 { + return nil + } privKey := id.PrivKey(secret) if len(privKey) == 0 { return nil } - hmsg := sha256.Sum256(msg) - data := append(msg, hedgedNonce(privKey[:], hmsg[:])...) + data := append(msg, hedgedNonce(privKey[:], msg)...) rawSign := ed25519.Sign(privKey, data[:]) - // encode to hex - sign := make([]byte, hex.EncodedLen(len(rawSign))) - hex.Encode(sign, rawSign) + // encode to base64 + sign := make([]byte, base64.RawStdEncoding.EncodedLen(len(rawSign))) + base64.RawStdEncoding.Encode(sign, rawSign) return sign } -func (id *AppID) Verify(secret, msg, sig []byte) bool { +func (id *AppID) Verify(secret Secret, msg, sig []byte) bool { + if id == nil || len(msg) == 0 || len(sig) == 0 { + return false + } privKey := id.PrivKey(secret) if privKey == nil { return false } - // decode sign from hex - rawSign := make([]byte, hex.DecodedLen(len(sig))) - if _, err := hex.Decode(rawSign, sig); err != nil { + // decode sign from base64 + rawSign := make([]byte, base64.RawStdEncoding.DecodedLen(len(sig))) + if _, err := base64.RawStdEncoding.Decode(rawSign, sig); err != nil { return false } - hmsg := sha256.Sum256(msg) - data := append(msg, hedgedNonce(privKey[:], hmsg[:])...) + data := append(msg, hedgedNonce(privKey[:], msg)...) pubKey := privKey.Public().(ed25519.PublicKey) return ed25519.Verify(pubKey, data, rawSign) } -func (id *AppID) NewToken(secret []byte, email string) []byte { +func (id *AppID) GenerateToken(secret Secret, email string) Token { + if id == nil { + return nil + } app := new(App).SetID(id) if app == nil { return nil } - exp := NewExpiration(app.SessionDuration) - msg := signMsg(id.Bytes(), []byte(email), exp.Bytes()) + exp := new(Expiration).SetDuration(app.SessionDuration) + msg := id.Message(email, *exp) sig := id.Sign(secret, msg) - return fmtToken(exp.Marshal(), sig) + if len(sig) == 0 { + return nil + } + return *new(Token).SetExpiration(*exp).SetSignature(sig) +} + +func (id *AppID) Message(email string, exp Expiration) []byte { + if id == nil || len(email) == 0 || !exp.Valid() { + return nil + } + hmsg := sha256.Sum256(append(append(id.Bytes(), []byte(email)...), exp.Bytes()...)) + return hmsg[:] } -func (id *AppID) VerifyToken(secret, token []byte, email string) bool { - if len(token) == 0 { +func (id *AppID) VerifyToken(token Token, secret Secret, email string) bool { + if id == nil { return false } - parts := bytes.Split(token, []byte{tokenSeparator}) - if len(parts) != 2 { + exp := token.Expiration() + if exp == nil || !exp.Valid() { return false } - dExp := new(Expiration).Unmarshal(parts[0]) - if dExp == nil || !dExp.Valid() { + sig := token.Signature() + if len(sig) == 0 { + return false + } + msg := id.Message(email, *exp) + if len(msg) == 0 { return false } - sig := parts[1] - msg := signMsg(id.Bytes(), []byte(email), dExp.Bytes()) return id.Verify(secret, msg, sig) } -func signMsg(id, email, exp []byte) []byte { - res := append(id, email...) - return append(res, exp...) -} - -func fmtToken(exp, sig []byte) []byte { - t := append(exp, tokenSeparator) - return append(t, sig...) -} - func hedgedNonce(inputs ...[]byte) []byte { if len(inputs) == 0 || len(inputs[0]) == 0 { return nil diff --git a/token/id_test.go b/token/id_test.go index f1e60cd..8eda50e 100644 --- a/token/id_test.go +++ b/token/id_test.go @@ -56,13 +56,27 @@ func TestBytesSetBytesAppID(t *testing.T) { } func TestPrivKeySignVerifyAppID(t *testing.T) { - if privKey := new(AppID).PrivKey(testAppSecret); privKey != nil { + var nilAppID *AppID + if privKey := nilAppID.PrivKey(*testAppSecret); privKey != nil { t.Errorf("expected nil, got %v", privKey) } - if sig := new(AppID).Sign(testAppSecret, []byte("test data")); sig != nil { + if nilSig := nilAppID.Sign(*testAppSecret, []byte("test data")); nilSig != nil { + t.Errorf("expected nil, got %v", nilSig) + } + if nilVerify := nilAppID.Verify(*testAppSecret, []byte("test data"), []byte("test sig")); nilVerify { + t.Errorf("expected signature to be invalid") + } + badSecret := new(Secret) + if nilPrivKey := new(AppID).PrivKey(*badSecret); nilPrivKey != nil { + t.Errorf("expected nil, got %v", nilPrivKey) + } + if privKey := new(AppID).PrivKey(*testAppSecret); privKey != nil { + t.Errorf("expected nil, got %v", privKey) + } + if sig := new(AppID).Sign(*testAppSecret, []byte("test data")); sig != nil { t.Errorf("expected nil, got %v", sig) } - if new(AppID).Verify(testAppSecret, []byte("test data"), []byte("test sig")) { + if new(AppID).Verify(*testAppSecret, []byte("test data"), []byte("test sig")) { t.Errorf("expected signature to be invalid") } app := &App{ @@ -75,70 +89,87 @@ func TestPrivKeySignVerifyAppID(t *testing.T) { t.Fatalf("error decoding app ID") } data := []byte("test data") - sig := id.Sign(testAppSecret, data) + sig := id.Sign(*testAppSecret, data) if sig == nil { t.Fatalf("error signing data") } - if !id.Verify(testAppSecret, data, sig) { + if !id.Verify(*testAppSecret, data, sig) { t.Errorf("expected signature to be valid") } - if id.Verify(testAppSecret, data, []byte("invalid sig")) { + if id.Verify(*testAppSecret, data, []byte("invalid sig")) { t.Errorf("expected signature to be invalid") } } -func TestNewTokenVerifyToken(t *testing.T) { +func TestMessage(t *testing.T) { + if privKey := new(AppID).PrivKey(*testAppSecret); privKey != nil { + t.Errorf("expected nil, got %v", privKey) + } + if sig := new(AppID).Sign(*testAppSecret, []byte("test data")); sig != nil { + t.Errorf("expected nil, got %v", sig) + } +} + +func TestGenerateTokenVerifyToken(t *testing.T) { t.Parallel() - if res := new(AppID).NewToken(nil, ""); res != nil { + var nilAppID *AppID + if res := nilAppID.GenerateToken(nil, ""); res != nil { + t.Errorf("expected nil, got %v", res) + } + if nilAppID.VerifyToken(nil, *testAppSecret, "") { + t.Errorf("expected token to be invalid") + } + if res := new(AppID).GenerateToken(nil, ""); res != nil { t.Errorf("expected nil, got %v", res) } app := &App{ Name: testAppName, RedirectURI: testRedirectURI, - SessionDuration: 30 * time.Second, + SessionDuration: minDuration, } id := app.ID() if id == nil { t.Fatalf("error decoding app ID") } email := "test@email.com" - token := id.NewToken(testAppSecret, email) + if id.VerifyToken([]byte{}, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + dummyToken := new(Token).SetExpiration(*new(Expiration).SetDuration(minDuration * 3)) + if id.VerifyToken(*dummyToken, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + dummyToken = new(Token).SetSignature([]byte("test")) + if id.VerifyToken(*dummyToken, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + if invalidToken := id.GenerateToken(*testAppSecret, ""); invalidToken != nil { + t.Fatalf("expected nil, got %v", invalidToken) + } + + token := id.GenerateToken(*testAppSecret, email) if token == nil { t.Fatalf("error creating token") } - if !id.VerifyToken(testAppSecret, token, email) { + if id.VerifyToken(token, *testAppSecret, "") { + t.Errorf("expected token to be invalid") + } + if !id.VerifyToken(token, *testAppSecret, email) { t.Errorf("expected token to be valid") } - time.Sleep(app.SessionDuration + 1) - if id.VerifyToken(testAppSecret, token, email) { + time.Sleep(app.SessionDuration + time.Second) + if id.VerifyToken(token, *testAppSecret, email) { t.Errorf("expected token to be invalid") } - if id.VerifyToken(testAppSecret, nil, email) { + if id.VerifyToken(nil, *testAppSecret, email) { t.Errorf("expected token to be invalid") } - exp := NewExpiration(minDuration) - if id.VerifyToken(testAppSecret, exp.Marshal(), email) { + exp := new(Expiration).SetDuration(minDuration) + if id.VerifyToken(exp.Marshal(), *testAppSecret, email) { t.Errorf("expected token to be invalid") } } -func Test_signMsg(t *testing.T) { - expected := []byte("testcombineddata") - if res := signMsg([]byte("test"), []byte("combined"), []byte("data")); !bytes.Equal(res, expected) { - t.Errorf("expected %v, got %v", expected, res) - } -} - -func Test_fmtToken(t *testing.T) { - sig := []byte("testsig") - exp := []byte("testexp") - expected := append(exp, tokenSeparator) - expected = append(expected, sig...) - if res := fmtToken(exp, sig); !bytes.Equal(res, expected) { - t.Errorf("expected %v, got %v", expected, res) - } -} - func Test_hedgedNonce(t *testing.T) { if res := hedgedNonce(); res != nil { t.Errorf("expected nil, got %v", res) diff --git a/token/secret.go b/token/secret.go new file mode 100644 index 0000000..e55ef85 --- /dev/null +++ b/token/secret.go @@ -0,0 +1,36 @@ +package token + +import "crypto/sha256" + +type Secret []byte + +func (s *Secret) SetParts(raw ...[]byte) *Secret { + if s == nil { + s = new(Secret) + } + newParts := []byte{} + for _, part := range raw { + if len(part) == 0 { + continue + } + hsecret := sha256.Sum256(part) + newParts = append(newParts, hsecret[:]...) + } + if len(newParts) != 0 { + *s = append(*s, newParts...) + } + return s +} + +func (s *Secret) Bytes() []byte { + return []byte(*s) +} + +func (s *Secret) Valid() bool { + if s == nil { + return false + } + // secret is valid if it has more than 1 part, and each part is hashed + // to a sha256 size + return len(*s) > sha256.Size +} diff --git a/token/secret_test.go b/token/secret_test.go new file mode 100644 index 0000000..4dbac49 --- /dev/null +++ b/token/secret_test.go @@ -0,0 +1,53 @@ +package token + +import ( + "bytes" + "crypto/sha256" + "testing" +) + +func TestSetPartsSecret(t *testing.T) { + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + valid := new(Secret).SetParts(servicePart, appPart) + if valid == nil { + t.Errorf("expected Secret, got nil") + } + hServicePart := sha256.Sum256(servicePart) + hAppPart := sha256.Sum256(appPart) + expected := append(hServicePart[:], hAppPart[:]...) + if !bytes.Equal(valid.Bytes(), expected) { + t.Errorf("expected %x, got %x", expected[:], valid.Bytes()) + } + if noServicePart := new(Secret).SetParts(nil, appPart); !bytes.Equal(noServicePart.Bytes(), hAppPart[:]) { + t.Errorf("expected nil, got %x", noServicePart) + } + if noAppPart := new(Secret).SetParts(servicePart, nil); !bytes.Equal(noAppPart.Bytes(), hServicePart[:]) { + t.Errorf("expected nil, got %x", noAppPart) + } + var nilSecret *Secret + valid = nilSecret.SetParts(servicePart, appPart) + if !bytes.Equal(valid.Bytes(), expected) { + t.Errorf("expected nil, got %v", nilSecret) + } +} + +func TestValidSecret(t *testing.T) { + var nilSecret *Secret + if valid := nilSecret.Valid(); valid { + t.Errorf("expected false, got %v", valid) + } + if valid := new(Secret).Valid(); valid { + t.Errorf("expected false, got %v", valid) + } + servicePart := []byte("service-secret") + singlePart := new(Secret).SetParts(servicePart) + if valid := singlePart.Valid(); valid { + t.Errorf("expected false, got %v", valid) + } + appPart := []byte("app-secret") + valid := new(Secret).SetParts(servicePart, appPart) + if !valid.Valid() { + t.Errorf("expected true, got false") + } +} diff --git a/token/token.go b/token/token.go new file mode 100644 index 0000000..9ff64c2 --- /dev/null +++ b/token/token.go @@ -0,0 +1,106 @@ +package token + +import ( + "bytes" +) + +type Token []byte + +func (t *Token) String() string { + if t == nil { + return "" + } + return string(t.Bytes()) +} + +func (t *Token) SetString(data string) *Token { + if t == nil { + t = new(Token) + } + return t.SetBytes([]byte(data)) +} + +func (t *Token) Bytes() []byte { + if t == nil { + return nil + } + if _, _, ok := t.parts(); !ok { + return nil + } + return []byte(*t) +} + +func (t *Token) SetBytes(data []byte) *Token { + if t == nil { + t = new(Token) + } + ntoken := &Token{} + *ntoken = data + if _, _, ok := ntoken.parts(); !ok { + return t + } + *t = data + return t +} + +func (t *Token) Expiration() *Expiration { + if rawExp, _, ok := t.parts(); ok { + return new(Expiration).Unmarshal(rawExp) + } + return nil +} + +func (t *Token) SetExpiration(exp Expiration) *Token { + // if no token is provided, create a new one + if t == nil { + t = new(Token) + } + // if expiration is invalid, return the current token + if !exp.Valid() { + return t + } + // create a base content with the new expiration and no signature + baseContent := append(exp.Marshal(), tokenSeparator) + // get the current signature, if there is one, return a token with the + // base content and no signature + sig := t.Signature() + if sig == nil { + return t.SetBytes(baseContent) + } + // if there is a signature, update the token to replace the expiration + // part with the new expiration + return t.SetBytes(append(baseContent, sig...)) +} + +func (t *Token) Signature() []byte { + _, sig, _ := t.parts() + return sig +} + +func (t *Token) SetSignature(sig []byte) *Token { + // if no token is provided, create a new one + if t == nil { + t = new(Token) + } + // create a base content with no expiration and the new signature + baseContent := append([]byte{tokenSeparator}, sig...) + // get the current expiration, if there is none, return a token with the + // base content and no expiration + exp := t.Expiration() + if exp == nil { + return t.SetBytes(baseContent) + } + // if there is an expiration, update the token to replace the signature + // part with the new signature + return t.SetBytes(append(exp.Marshal(), baseContent...)) +} + +func (t *Token) parts() ([]byte, []byte, bool) { + if t == nil { + return nil, nil, false + } + if p := bytes.Split([]byte(*t), []byte{tokenSeparator}); len(p) == 2 { + return p[0], p[1], true + } + return nil, nil, false +} diff --git a/token/token_test.go b/token/token_test.go new file mode 100644 index 0000000..6b9c80a --- /dev/null +++ b/token/token_test.go @@ -0,0 +1,159 @@ +package token + +import ( + "bytes" + "testing" + "time" +) + +func TestStringSetStringToken(t *testing.T) { + var token *Token + token.SetString("test") + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token = token.SetString("test") + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + exp := new(Expiration).SetDuration(minDuration * 2) + expected := string(exp.Marshal()) + string(tokenSeparator) + token.SetString(expected) + if token.String() != expected { + t.Errorf("expected %s, got %s", expected, token.String()) + } + + expected = string(tokenSeparator) + "testSignature" + token.SetString(expected) + if token.String() != expected { + t.Errorf("expected %s, got %s", expected, token.String()) + } + + expected = string(exp.Marshal()) + string(tokenSeparator) + "testSignature" + token.SetString(expected) + if token.String() != expected { + t.Errorf("expected %s, got %s", expected, token.String()) + } +} + +func TestBytesSetBytesToken(t *testing.T) { + var token *Token + token.SetBytes([]byte("test")) + if token.Bytes() != nil { + t.Errorf("expected nil, got %v", token.Bytes()) + } + token = token.SetBytes([]byte("test")) + if token.Bytes() != nil { + t.Errorf("expected nil, got %v", token.Bytes()) + } + exp := new(Expiration).SetDuration(minDuration * 2) + onlyExp := append(exp.Marshal(), tokenSeparator) + token.SetBytes(onlyExp) + if !bytes.Equal(token.Bytes(), onlyExp) { + t.Errorf("expected %v, got %v", onlyExp, token.Bytes()) + } + + onlySign := append([]byte{tokenSeparator}, []byte("testSignature")...) + token.SetBytes(onlySign) + if !bytes.Equal(token.Bytes(), onlySign) { + t.Errorf("expected %v, got %v", onlySign, token.Bytes()) + } + + fullToken := append(append(exp.Marshal(), tokenSeparator), []byte("testSignature")...) + token.SetBytes(fullToken) + if !bytes.Equal(token.Bytes(), fullToken) { + t.Errorf("expected %v, got %v", fullToken, token.Bytes()) + } +} + +func TestExpirationSetExpirationToken(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration * 2) + var token *Token + token.SetExpiration(*exp) + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token = token.SetExpiration(Expiration(time.Now().Add(-time.Second))) + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token.SetExpiration(*exp) + justExpExp := token.Expiration() + if justExpExp == nil { + t.Fatalf("expected valid expiration, got nil") + } + if justExpExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, justExpExp) + } + token.SetSignature([]byte("test")) + completeTokenExp := token.Expiration() + if completeTokenExp == nil { + t.Fatalf("expected valid expiration, got nil") + } + if completeTokenExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, completeTokenExp) + } + exp = new(Expiration).SetDuration(minDuration * 3) + token.SetExpiration(*exp) + newExpExp := token.Expiration() + if newExpExp == nil || newExpExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, newExpExp) + } + validStr := token.String() + newtoken := new(Token).SetString(validStr) + newtokenExp := newtoken.Expiration() + if newtokenExp == nil || newtokenExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, newtokenExp) + } +} + +func TestSignatureSetSignatureToken(t *testing.T) { + var token *Token + token.SetSignature([]byte("test")) + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token = token.SetSignature([]byte("test")) + justSignSign := token.Signature() + if !bytes.Equal(justSignSign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), justSignSign) + } + token.SetExpiration(*new(Expiration).SetDuration(minDuration * 3)) + completeTokenSign := token.Signature() + if !bytes.Equal(completeTokenSign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), completeTokenSign) + } + validStr := token.String() + newtoken := new(Token).SetString(validStr) + newtokenSign := newtoken.Signature() + if !bytes.Equal(newtokenSign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), newtokenSign) + } +} + +func Test_partsToken(t *testing.T) { + var token *Token + if _, _, ok := token.parts(); ok { + t.Errorf("expected false, got true") + } + exp := new(Expiration).SetDuration(minDuration * 2) + token = token.SetExpiration(*exp) + rawExp, _, ok := token.parts() + if !ok { + t.Errorf("expected false, got true") + } + if !bytes.Equal(rawExp, exp.Marshal()) { + t.Errorf("expected %v, got %v", exp.Marshal(), rawExp) + } + token.SetSignature([]byte("test")) + rawExp, sign, ok := token.parts() + if !ok { + t.Errorf("expected true, got false") + } + if !bytes.Equal(rawExp, exp.Marshal()) { + t.Errorf("expected %v, got %v", exp.Marshal(), rawExp) + } + if !bytes.Equal(sign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), sign) + } +} From 212fa88d91dc9242fd6bce1136782caae5ce6ea1 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 2 Mar 2025 01:23:06 +0100 Subject: [PATCH 16/36] full api implementation with basic features, includes requests, responses, errors and initial tests --- api/error.go | 78 +++++++++++++++++ api/handlers.go | 124 ++++++++++++++++++++++++++ api/handlers_test.go | 202 +++++++++++++++++++++++++++++++++++++++++++ api/helpers.go | 24 +++++ api/io.go | 60 +++++++++++++ api/routes.go | 3 + api/service.go | 80 +++++++---------- api/service_test.go | 77 +++++++++++++++++ api/types.go | 47 ++++++---- cmd/authapi/main.go | 151 +++++++++++++++----------------- 10 files changed, 695 insertions(+), 151 deletions(-) create mode 100644 api/error.go create mode 100644 api/handlers.go create mode 100644 api/handlers_test.go create mode 100644 api/helpers.go create mode 100644 api/io.go create mode 100644 api/service_test.go diff --git a/api/error.go b/api/error.go new file mode 100644 index 0000000..94697a5 --- /dev/null +++ b/api/error.go @@ -0,0 +1,78 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" +) + +var ( + // Decode data errors + DecodeAppIDRequestErr = newApiErr(1001, http.StatusBadRequest).With("could not decode app id request") + DecodeTokenRequestErr = newApiErr(1002, http.StatusBadRequest).With("could not decode token request") + DecodeTokenStatusRequestErr = newApiErr(1003, http.StatusBadRequest).With("could not decode token status request") + // Encode data errors + EncodeAppIDResponseErr = newApiErr(1010, http.StatusInternalServerError).With("could not encode app id response") + EncodeTokenStatusResponseErr = newApiErr(1011, http.StatusInternalServerError).With("could not encode token status response") + // Bad request errors + InvalidAppHeadersErr = newApiErr(1020, http.StatusBadRequest).With("invalid app headers") + InvalidAppIDErr = newApiErr(1021, http.StatusBadRequest).With("invalid app id") + InvalidAppSecretErr = newApiErr(1022, http.StatusBadRequest).With("invalid app secret") + // Internal errors + GenerateTokenErr = newApiErr(1030, http.StatusInternalServerError).With("could not generate token") + GenerateEmailErr = newApiErr(1031, http.StatusInternalServerError).With("could not generate email") + SendEmailErr = newApiErr(1032, http.StatusInternalServerError).With("could not send email") + InternalErr = newApiErr(1033, http.StatusInternalServerError).With("internal server error") +) + +type APIError struct { + Code int `json:"code"` + Message string `json:"message"` + Err string `json:"error"` + StatusCode int `json:"status_code"` +} + +func (e *APIError) Bytes() []byte { + bErr, err := json.Marshal(e) + if err != nil { + return nil + } + return bErr +} + +func (e *APIError) Error() string { + return fmt.Sprintf("code: %d, message: %s, error: %s, status_code: %d", e.Code, e.Message, e.Err, e.StatusCode) +} + +func (e *APIError) WithErr(err error) *APIError { + if e.Err == "" { + e.Err = err.Error() + return e + } + e.Err = fmt.Sprintf("%s: %s", e.Err, err.Error()) + return e +} + +func (e *APIError) With(msg string) *APIError { + if e.Message == "" { + e.Message = msg + return e + } + e.Message = fmt.Sprintf("%s: %s", e.Message, msg) + return e +} + +func (e *APIError) Write(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(e.StatusCode) + if _, err := w.Write(e.Bytes()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func newApiErr(code, status int) *APIError { + return &APIError{ + Code: code, + StatusCode: status, + } +} diff --git a/api/handlers.go b/api/handlers.go new file mode 100644 index 0000000..1d1c6b4 --- /dev/null +++ b/api/handlers.go @@ -0,0 +1,124 @@ +package api + +import ( + "net/http" + + "github.com/simpleauthlink/authapi/notification" + "github.com/simpleauthlink/authapi/notification/templates/login" + "github.com/simpleauthlink/authapi/token" +) + +func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { + // decode the app data from the request body + req := new(Request[AppIDRequest]) + if err := req.Read(r); err != nil { + DecodeAppIDRequestErr.WithErr(err).Write(w) + return + } + // create the app from the data and check if it is valid + app := req.Data.parseApp() + if !app.Valid() { + InvalidAppIDErr.Write(w) + return + } + // return the app id + if err := ResponseWith(&AppIDResponse{app.ID().String()}).Write(w); err != nil { + EncodeAppIDResponseErr.WithErr(err).Write(w) + } +} + +func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { + // get the app id from the request header + strAppID, strAppSecret, err := appConfigFromRequest(r) + if err != nil { + InvalidAppHeadersErr.WithErr(err).Write(w) + return + } + // decode the app id get the app from it + appID := new(token.AppID).SetString(strAppID) + app := new(token.App).SetID(appID) + // check if the app id is valid (it should be a valid app) + if !app.Valid() { + InvalidAppIDErr.Write(w) + return + } + // decode the token request from the request body + req := new(Request[TokenRequest]) + if err := req.Read(r); err != nil { + DecodeTokenRequestErr.WithErr(err).Write(w) + return + } + // generate user token + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) + if !secret.Valid() { + InvalidAppSecretErr.Write(w) + return + } + token := appID.GenerateToken(*secret, req.Data.Email) + if token == nil { + GenerateTokenErr.With(req.Data.Email).Write(w) + return + } + // compose the email with the token + loginData := login.Data{ + AppName: app.Name, + Email: req.Data.Email, + Token: token.String(), + Link: app.RedirectURI + token.String(), + } + loginEmail, err := login.Template.Compose(notification.NotificationParams{ + To: req.Data.Email, + Subject: loginData.Subject(), + }, loginData) + if err != nil { + GenerateEmailErr.WithErr(err).Write(w) + return + } + // push the email to the notification queue + if err := s.nq.Push(loginEmail); err != nil { + SendEmailErr.WithErr(err).Write(w) + return + } + if err := OkResponse().Write(w); err != nil { + InternalErr.WithErr(err).Write(w) + } +} + +func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { + // get the app id from the request header + strAppID, strAppSecret, err := appConfigFromRequest(r) + if err != nil { + InvalidAppHeadersErr.WithErr(err).Write(w) + return + } + // decode the app id get the app from it + appID := new(token.AppID).SetString(strAppID) + app := new(token.App).SetID(appID) + // check if the app id is valid (it should be a valid app) + if !app.Valid() { + InvalidAppIDErr.Write(w) + return + } + // decode the token status request from the request body + req := new(Request[TokenStatusRequest]) + if err := req.Read(r); err != nil { + DecodeTokenStatusRequestErr.WithErr(err).Write(w) + return + } + // check if the token is valid + tkn := new(token.Token).SetString(req.Data.Token) + exp := tkn.Expiration().Time() + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) + if !secret.Valid() { + InvalidAppSecretErr.Write(w) + return + } + ok := appID.VerifyToken(*tkn, *secret, req.Data.Email) + if err := ResponseWith(&TokenStatusResponse{ + Valid: ok, + Expiration: exp, + }).Write(w); err != nil { + EncodeTokenStatusResponseErr.WithErr(err).Write(w) + return + } +} diff --git a/api/handlers_test.go b/api/handlers_test.go new file mode 100644 index 0000000..36a9842 --- /dev/null +++ b/api/handlers_test.go @@ -0,0 +1,202 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "testing" + "time" + + "github.com/simpleauthlink/authapi/notification" + "github.com/simpleauthlink/authapi/notification/email" + "github.com/simpleauthlink/authapi/notification/templates/login" + "github.com/simpleauthlink/authapi/token" +) + +type testCaseAPIHandler[ReqType, ResType any] struct { + name string + method string + endpoint string + header http.Header + request *ReqType + response *ResType + err *APIError +} + +func (testCase testCaseAPIHandler[Rq, Rs]) url() string { + return fmt.Sprintf("%s%s", testServerApiURL, testCase.endpoint) +} + +func (testCase testCaseAPIHandler[Rq, Rs]) Run(t *testing.T, parallel bool) { + t.Run(testCase.name, func(t *testing.T) { + if parallel { + t.Parallel() + } + var reqBuffer io.Reader + if testCase.request != nil { + rawBody, err := json.Marshal(testCase.request) + if err != nil { + t.Fatalf("could not marshal request: %v", err) + } + reqBuffer = bytes.NewReader(rawBody) + } + req, err := http.NewRequest(testCase.method, testCase.url(), reqBuffer) + if err != nil { + t.Fatalf("could not create request: %v", err) + } + req.Header = testCase.header + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("could not send request: %v", err) + } + defer resp.Body.Close() + switch { + case testCase.err != nil: + if resp.StatusCode != testCase.err.StatusCode { + t.Fatalf("expected status code: %d, got: %d", testCase.err.StatusCode, resp.StatusCode) + } + err := new(APIError) + if err := json.NewDecoder(resp.Body).Decode(err); err != nil { + t.Fatalf("could not decode error response: %v", err) + } + if err.Code != testCase.err.Code { + t.Fatalf("expected error code: %d, got: %d", testCase.err.Code, err.Code) + } + return + case testCase.response != nil: + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code: %d, got: %d", http.StatusOK, resp.StatusCode) + } + expected, err := json.Marshal(testCase.response) + if err != nil { + t.Fatalf("could not marshal response: %v", err) + } + res, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("could not read response: %v", err) + } + if !bytes.Equal(bytes.TrimSpace(expected), bytes.TrimSpace(res)) { + t.Fatalf("expected response: %s, got: %s", expected, res) + } + default: + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code: %d, got: %d", http.StatusOK, resp.StatusCode) + } + } + }) +} + +func TestGenerateAppIDHandler(t *testing.T) { + testApp := &token.App{ + Name: testAppName, + RedirectURI: testAppRedirectURL, + SessionDuration: testAppSessionDuration, + } + testCaseAPIHandler[AppIDRequest, AppIDResponse]{ + name: "valid request", + method: http.MethodPost, + endpoint: AppsPath, + request: &AppIDRequest{ + Name: testApp.Name, + RedirectURL: testApp.RedirectURI, + Duration: int64(testApp.SessionDuration), + Secret: testAppSecret, + }, + response: &AppIDResponse{ + ID: testApp.ID().String(), + }, + }.Run(t, true) + testCaseAPIHandler[AppIDRequest, AppIDResponse]{ + name: "no request", + method: http.MethodPost, + endpoint: AppsPath, + request: nil, + err: DecodeAppIDRequestErr, + }.Run(t, true) + testCaseAPIHandler[AppIDRequest, AppIDResponse]{ + name: "invalid request", + method: http.MethodPost, + endpoint: AppsPath, + request: &AppIDRequest{ + Name: testAppName, + RedirectURL: testAppRedirectURL, + Duration: int64(time.Second), + Secret: testAppSecret, + }, + err: InvalidAppIDErr, + }.Run(t, true) +} + +func TestRequestTokenAndStatusHandler(t *testing.T) { + testApp := &token.App{ + Name: testAppName, + RedirectURI: testAppRedirectURL, + SessionDuration: testAppSessionDuration, + } + testAppID := testApp.ID() + login.Template = email.EmailTemplate{ + HTML: "", + Plain: `\[{{.Token}}]`, + } + testCaseAPIHandler[TokenRequest, any]{ + name: "valid request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + AppSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + }.Run(t, false) + + var testToken *token.Token + select { + case receivedMsg := <-inboxChan: + data := login.Data{ + AppName: testAppName, + Email: testUserEmail, + Token: `(.+\..+)`, + Link: testAppRedirectURL + receivedMsg, + } + notification, err := login.Template.Compose(notification.NotificationParams{ + To: testUserEmail, + Subject: data.Subject(), + }, data) + if err != nil { + t.Fatalf("could not compose notification: %v", err) + } + tokenRgx := regexp.MustCompile(string(notification.PlainBody)) + tokenResult := tokenRgx.FindAllStringSubmatch(receivedMsg, -1) + if len(tokenResult) < 1 || len(tokenResult[0]) < 2 { + t.Fatal("could not find token in email") + } + testToken = new(token.Token).SetString(tokenResult[0][1]) + break + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for the email to be received") + } + + testCaseAPIHandler[TokenStatusRequest, TokenStatusResponse]{ + name: "valid token status request", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + AppSecretHeader: []string{testAppSecret}, + }, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: &TokenStatusResponse{ + Valid: true, + Expiration: testToken.Expiration().Time(), + }, + }.Run(t, false) +} diff --git a/api/helpers.go b/api/helpers.go new file mode 100644 index 0000000..19fd684 --- /dev/null +++ b/api/helpers.go @@ -0,0 +1,24 @@ +package api + +import ( + "fmt" + "net/http" +) + +const ( + AppIDHeader = "APP_ID" + AppSecretHeader = "APP_SECRET" +) + +func appConfigFromRequest(r *http.Request) (string, string, error) { + // get the app id from the request header + strAppID := r.Header.Get(AppIDHeader) + if strAppID == "" { + return "", "", fmt.Errorf("missing app id") + } + strAppSecret := r.Header.Get(AppSecretHeader) + if strAppSecret == "" { + return "", "", fmt.Errorf("missing app secret") + } + return strAppID, strAppSecret, nil +} diff --git a/api/io.go b/api/io.go new file mode 100644 index 0000000..b4b2f2e --- /dev/null +++ b/api/io.go @@ -0,0 +1,60 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" +) + +type Response[T any] struct { + Data T + empty bool +} + +func ResponseWith[T any](data *T) *Response[T] { + if data == nil { + return &Response[T]{empty: true} + } + return &Response[T]{ + Data: *data, + empty: false, + } +} + +func OkResponse() *Response[any] { + return &Response[any]{empty: true} +} + +func (r *Response[T]) Write(w http.ResponseWriter) error { + if !r.empty { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + return json.NewEncoder(w).Encode(r.Data) + } + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("OK")) + return err +} + +type TokenRequest struct { + Email string `json:"email"` +} + +type Request[T any] struct { + Data T +} + +func (req *Request[T]) Read(r *http.Request) error { + if req == nil { + req = new(Request[T]) + } + rawBody, err := io.ReadAll(r.Body) + if err != nil { + return err + } + if len(rawBody) == 0 { + return fmt.Errorf("empty request body") + } + return json.Unmarshal(rawBody, &req.Data) +} diff --git a/api/routes.go b/api/routes.go index f10df75..928805a 100644 --- a/api/routes.go +++ b/api/routes.go @@ -4,4 +4,7 @@ const ( // HealthCheckPath constant is the path used to check the health of the API // server. It is a string with a value of "/health". HealthCheckPath = "/ping" + + AppsPath = "/apps" + TokensPath = "/tokens" ) diff --git a/api/service.go b/api/service.go index 17e09d1..4d32e4c 100644 --- a/api/service.go +++ b/api/service.go @@ -3,7 +3,6 @@ package api import ( "context" "fmt" - "log" "net/http" "os" "os/signal" @@ -12,54 +11,33 @@ import ( "time" "github.com/lucasmenendez/apihandler" - "github.com/simpleauthlink/authapi/email" + "github.com/simpleauthlink/authapi/notification" ) -// Config struct represents the configuration needed to init the service. It -// includes the email configuration, the server hostname, the server port, the -// data path to store the database, and the cleaner cooldown to clean the -// expired tokens. type Config struct { - email.EmailConfig - Server string - ServerPort int - CleanerCooldown time.Duration + Server string + ServerPort int + Secret string } -// Service struct represents the service that is going to be started. It -// includes the context and the cancel function to stop the service, the wait -// group to wait for the background processes to finish, the configuration, -// the database connection and the api handler. type Service struct { ctx context.Context cancel context.CancelFunc wait sync.WaitGroup cfg *Config - emailQueue *email.EmailQueue + nq notification.Queue handler *apihandler.Handler httpServer *http.Server } -// New function creates a new service based on the provided context and -// configuration. It initializes the email queue, creates the service and -// sets the api handlers. If something goes wrong during the process, it -// returns an error. -func New(ctx context.Context, cfg *Config) (*Service, error) { +func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) - emailQueue, err := email.NewEmailQueue(internalCtx, &cfg.EmailConfig) - if err != nil { - if emailQueue == nil { - cancel() - return nil, err - } - log.Println("WRN: something occurs during email queue creation:", err) - } // create the service srv := &Service{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - emailQueue: emailQueue, + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + nq: nq, handler: apihandler.NewHandler(&apihandler.Config{ CORS: true, RateLimitConfig: &apihandler.RateLimitConfig{ @@ -68,9 +46,13 @@ func New(ctx context.Context, cfg *Config) (*Service, error) { }, }), } + // register the routes and handlers srv.handler.Get(HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + OkResponse().Write(w) }) + srv.handler.Post(AppsPath, srv.generateAppIDHandler) + srv.handler.Post(TokensPath, srv.requestTokenHandler) + srv.handler.Put(TokensPath, srv.verifyTokenHandler) // build the http server srv.httpServer = &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server, cfg.ServerPort), @@ -79,11 +61,8 @@ func New(ctx context.Context, cfg *Config) (*Service, error) { return srv, nil } -// Start method starts the service. It starts the token cleaner and the api -// server. If something goes wrong during the process, it returns an error. +// Start method starts the service. func (s *Service) Start() error { - // start the email queue - s.emailQueue.Start() // start the api server if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { return err @@ -91,16 +70,23 @@ func (s *Service) Start() error { return nil } -// Stop method stops the service. It cancels the context and waits for the -// background processes to finish. It closes the database. If something goes -// wrong during the process, it returns an error. -func (s *Service) Stop() error { - // stop the email queue - s.emailQueue.Stop() +func (s *Service) Stop() { // cancel the context and wait for the background processes finish s.cancel() defer s.wait.Wait() - return nil +} + +func (s *Service) Ping() bool { + url := fmt.Sprintf("http://%s:%d%s", s.cfg.Server, s.cfg.ServerPort, HealthCheckPath) + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return false + } + response, err := http.DefaultClient.Do(request) + if err != nil { + return false + } + return response.StatusCode == http.StatusOK } // WaitToShutdown method waits for the service to shutdown. It listens for the @@ -112,10 +98,6 @@ func (s *Service) WaitToShutdown() error { <-done ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - defer func() { - if err := s.Stop(); err != nil { - log.Println(err) - } - }() + defer s.Stop() return s.httpServer.Shutdown(ctx) } diff --git a/api/service_test.go b/api/service_test.go new file mode 100644 index 0000000..4c931a7 --- /dev/null +++ b/api/service_test.go @@ -0,0 +1,77 @@ +package api + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/simpleauthlink/authapi/internal" + "github.com/simpleauthlink/authapi/notification/email" +) + +const ( + testServerAddr = "127.0.0.1" + testServerSMTPPort = 2526 + testServerAPIPort = 5555 + testServerSecret = "server-secret" + testSenderName = "TestAPI" + testSender = "api-test@testmail.com" + testUserEmail = "user@testmail.com" + testAppName = "TestApp" + testAppRedirectURL = "http://testapp.com" + testAppSessionDuration = time.Second * 35 + testAppSecret = "test-secret" +) + +var ( + testServerApiURL = fmt.Sprintf("http://%s:%d", testServerAddr, testServerAPIPort) + inboxChan = make(chan string, 1) +) + +func TestMain(m *testing.M) { + defer close(inboxChan) + // create context with cancel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // start test SMTP server to receive the email + testSrv := internal.NewFakeSMTPServer(testServerAddr, testServerSMTPPort, inboxChan) + if err := testSrv.Start(ctx); err != nil { + panic(err) + } + defer testSrv.Stop() + // create email queue with valid config + eq, err := email.NewEmailQueue(ctx, &email.EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerSMTPPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + panic(err) + } + eq.Start() + defer eq.Stop() + // create the API service + apiSrv, err := New(ctx, &Config{ + Server: testServerAddr, + ServerPort: testServerAPIPort, + Secret: testServerSecret, + }, eq) + if err != nil { + panic(err) + } + go func() { + if err := apiSrv.Start(); err != nil { + panic(err) + } + }() + defer apiSrv.Stop() + // make ping to the server to check if it is running + if ok := apiSrv.Ping(); !ok { + panic("API server is not running") + } + // run the tests + os.Exit(m.Run()) +} diff --git a/api/types.go b/api/types.go index 00d6aa9..9236f42 100644 --- a/api/types.go +++ b/api/types.go @@ -1,27 +1,36 @@ package api -const ( - userTokenSubject = "Here is your magic link for '%s' 🔐" - appTokenSubject = "Your app '%s' is ready! 🎉" +import ( + "time" + + "github.com/simpleauthlink/authapi/token" ) -// TokenRequest struct includes the required information by the API service to -// create a token, which is the email of the user. The app secret is also -// required but it is provided in the request headers. -type TokenRequest struct { - Email string `json:"email"` +type AppIDRequest struct { + Name string `json:"name"` + Duration int64 `json:"session_duration"` RedirectURL string `json:"redirect_url"` - Duration uint64 `json:"session_duration"` + Secret string `json:"secret"` +} + +func (data *AppIDRequest) parseApp() *token.App { + return &token.App{ + Name: data.Name, + RedirectURI: data.RedirectURL, + SessionDuration: time.Duration(data.Duration), + } +} + +type AppIDResponse struct { + ID string `json:"id"` +} + +type TokenStatusRequest struct { + Token string `json:"token"` + Email string `json:"email"` } -// AppData struct includes the required information by the API service to -// create an app, which are the name, the email of the admin, the session -// duration and the callback URL. -type AppData struct { - Name string `json:"name"` - Email string `json:"admin_email"` - Duration uint64 `json:"session_duration"` - RedirectURL string `json:"redirect_url"` - UsersQuota int64 `json:"users_quota"` - CurrentUsers int64 `json:"current_users"` +type TokenStatusResponse struct { + Valid bool `json:"valid"` + Expiration time.Time `json:"expiration"` } diff --git a/cmd/authapi/main.go b/cmd/authapi/main.go index 44f29c6..0d688ff 100644 --- a/cmd/authapi/main.go +++ b/cmd/authapi/main.go @@ -7,64 +7,52 @@ import ( "log" "os" "strconv" - "time" "github.com/simpleauthlink/authapi/api" - "github.com/simpleauthlink/authapi/email" + "github.com/simpleauthlink/authapi/notification/email" ) const ( - defaultHost = "0.0.0.0" - defaultPort = 8080 - defaultEmailAddr = "" - defaultEmailPass = "" - defaultEmailHost = "" - defaultEmailPort = 587 - defaultTokenEmailTemplate = "assets/token_email_template.html" - defaultAppEmailTemplate = "assets/app_email_template.html" + defaultHost = "0.0.0.0" + defaultPort = 8080 + defaultEmailAddr = "" + defaultEmailPass = "" + defaultEmailHost = "" + defaultEmailPort = 587 + defaultSecret = "simpleauthlink-secret" - hostFlag = "host" - portFlag = "port" - emailAddrFlag = "email-addr" - emailPassFlag = "email-pass" - emailHostFlag = "email-host" - emailPortFlag = "email-port" - tokenEmailTemplateFlag = "email-token-template" - appEmailTemplateFlag = "email-app-template" - disposableSrcFlag = "disposable-src" - hostFlagDesc = "service host" - portFlagDesc = "service port" - dbURIFlagDesc = "database uri" - dbNameFlagDesc = "database name" - emailAddrFlagDesc = "email account address" - emailPassFlagDesc = "email account password" - emailHostFlagDesc = "email server host" - emailPortFlagDesc = "email server port" - tokenEmailTemplateDesc = "path to the html template of new token email" - appEmailTemplateDesc = "path to the html template of new app email" + hostFlag = "host" + portFlag = "port" + emailAddrFlag = "email-addr" + emailPassFlag = "email-pass" + emailHostFlag = "email-host" + emailPortFlag = "email-port" + secretFlag = "secret" + hostFlagDesc = "service host" + portFlagDesc = "service port" + emailAddrFlagDesc = "email account address" + emailPassFlagDesc = "email account password" + emailHostFlagDesc = "email server host" + emailPortFlagDesc = "email server port" + secretFlagDesc = "secret used to generate the tokens" - hostEnv = "SIMPLEAUTH_HOST" - portEnv = "SIMPLEAUTH_PORT" - emailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" - emailPassEnv = "SIMPLEAUTH_EMAIL_PASS" - emailHostEnv = "SIMPLEAUTH_EMAIL_HOST" - emailPortEnv = "SIMPLEAUTH_EMAIL_PORT" - tokenEmailTemplateEnv = "SIMPLEAUTH_TOKEN_EMAIL_TEMPLATE" - appEmailTemplateEnv = "SIMPLEAUTH_APP_EMAIL_TEMPLATE" + hostEnv = "SIMPLEAUTH_HOST" + portEnv = "SIMPLEAUTH_PORT" + emailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" + emailPassEnv = "SIMPLEAUTH_EMAIL_PASS" + emailHostEnv = "SIMPLEAUTH_EMAIL_HOST" + emailPortEnv = "SIMPLEAUTH_EMAIL_PORT" + secretEnv = "SIMPLEAUTH_SECRET" ) type config struct { - host string - port int - dbURI string - dbName string - emailAddr string - emailPass string - emailHost string - emailPort int - tokenEmailTemplate string - appEmailTemplate string - disposableSrc string + host string + port int + emailAddr string + emailPass string + emailHost string + emailPort int + secret string } func main() { @@ -73,20 +61,27 @@ func main() { if err != nil { log.Fatalln("ERR: error parsing config:", err) } + // create email queue + emailQueue, err := email.NewEmailQueue(context.Background(), &email.EmailConfig{ + FromName: "SimpleAuthLink", + FromAddress: c.emailAddr, + SMTPUsername: c.emailAddr, + SMTPPassword: c.emailPass, + SMTPServer: c.emailHost, + SMTPPort: c.emailPort, + }) + if err != nil { + log.Fatalln("WRN: something occurs during email queue creation:", err) + } + // start the email queue and defer to stop it + emailQueue.Start() + defer emailQueue.Stop() // create the service service, err := api.New(context.Background(), &api.Config{ - EmailConfig: email.EmailConfig{ - FromName: "SimpleAuthLink", - FromAddress: c.emailAddr, - SMTPUsername: c.emailAddr, - SMTPPassword: c.emailPass, - SMTPServer: c.emailHost, - SMTPPort: c.emailPort, - }, - Server: c.host, - ServerPort: c.port, - CleanerCooldown: 30 * time.Minute, - }) + Server: c.host, + ServerPort: c.port, + Secret: c.secret, + }, emailQueue) if err != nil { log.Fatalln("ERR: error creating service:", err) } @@ -100,7 +95,7 @@ func main() { } func parseConfig() (*config, error) { - var fhost, fdbURI, fdbName, femailAddr, femailPass, femailHost, ftokenEmailTemplate, fappEmailTemplate, fdisposableSrc string + var fhost, femailAddr, femailPass, femailHost, fsecret string var fport, femailPort int // get config from flags flag.StringVar(&fhost, hostFlag, defaultHost, hostFlagDesc) @@ -108,9 +103,8 @@ func parseConfig() (*config, error) { flag.StringVar(&femailAddr, emailAddrFlag, defaultEmailAddr, emailAddrFlagDesc) flag.StringVar(&femailPass, emailPassFlag, defaultEmailPass, emailPassFlagDesc) flag.StringVar(&femailHost, emailHostFlag, defaultEmailHost, emailHostFlagDesc) - flag.StringVar(&ftokenEmailTemplate, tokenEmailTemplateFlag, defaultTokenEmailTemplate, tokenEmailTemplateDesc) - flag.StringVar(&fappEmailTemplate, appEmailTemplateFlag, defaultAppEmailTemplate, appEmailTemplateDesc) flag.IntVar(&femailPort, emailPortFlag, defaultEmailPort, emailPortFlagDesc) + flag.StringVar(&fsecret, secretFlag, defaultSecret, secretFlagDesc) flag.Parse() // get config from env envHost := os.Getenv(hostEnv) @@ -119,9 +113,7 @@ func parseConfig() (*config, error) { envEmailPass := os.Getenv(emailPassEnv) envEmailHost := os.Getenv(emailHostEnv) envEmailPort := os.Getenv(emailPortEnv) - envtokenEmailTemplate := os.Getenv(tokenEmailTemplateEnv) - envAppEmailTemplate := os.Getenv(appEmailTemplateEnv) - + envSecret := os.Getenv(secretEnv) // check if the required flags are set if femailAddr == "" && envEmailAddr == "" { return nil, fmt.Errorf("email address is required, use -%s or set %s env var", emailAddrFlag, emailAddrEnv) @@ -132,19 +124,18 @@ func parseConfig() (*config, error) { if femailHost == "" && envEmailHost == "" { return nil, fmt.Errorf("email host is required, use -%s or set %s env var", emailHostFlag, emailHostEnv) } + if fsecret == "" && envSecret == "" { + return nil, fmt.Errorf("secret is required, use -%s or set %s env var", secretFlag, secretEnv) + } // set flags values by default c := &config{ - host: fhost, - port: fport, - dbURI: fdbURI, - dbName: fdbName, - emailAddr: femailAddr, - emailPass: femailPass, - emailHost: femailHost, - emailPort: femailPort, - tokenEmailTemplate: ftokenEmailTemplate, - appEmailTemplate: fappEmailTemplate, - disposableSrc: fdisposableSrc, + host: fhost, + port: fport, + emailAddr: femailAddr, + emailPass: femailPass, + emailHost: femailHost, + emailPort: femailPort, + secret: fsecret, } // if some flags are not set, set them by env if envHost != "" { @@ -173,11 +164,5 @@ func parseConfig() (*config, error) { return nil, fmt.Errorf("invalid email port value: %s", envEmailPort) } } - if envtokenEmailTemplate != "" { - c.tokenEmailTemplate = envtokenEmailTemplate - } - if envAppEmailTemplate != "" { - c.appEmailTemplate = envAppEmailTemplate - } return c, nil } From 2514b8c32dedbec4d47827668a9a00605fbfd2b7 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 2 Mar 2025 01:28:32 +0100 Subject: [PATCH 17/36] update workflow --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 609f74a..845b40e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -50,7 +50,7 @@ jobs: run: go test ./... -v -race -timeout=1h - name: Rerun Go test to generate coverage report run: | - go test -v --race -timeout 15m -coverprofile=./cover.out -json ./... > tests.log + go test -v -timeout 15m -coverprofile=./cover.out -json ./... > tests.log - name: Convert report to html run: go tool cover -html=cover.out -o cover.html - name: Print coverage report From 7d30e7e752ad96448ff0b200b409645e8682b4ea Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 2 Mar 2025 01:31:22 +0100 Subject: [PATCH 18/36] fixing tests --- api/service_test.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/api/service_test.go b/api/service_test.go index 4c931a7..a964701 100644 --- a/api/service_test.go +++ b/api/service_test.go @@ -69,8 +69,16 @@ func TestMain(m *testing.M) { }() defer apiSrv.Stop() // make ping to the server to check if it is running - if ok := apiSrv.Ping(); !ok { - panic("API server is not running") + nRetries := 5 + for { + if nRetries == 0 { + panic("API server is not running") + } + if ok := apiSrv.Ping(); ok { + break + } + nRetries-- + time.Sleep(time.Second) } // run the tests os.Exit(m.Run()) From 12622504905802d9434b4dd5d6f532538c9fee31 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sat, 8 Mar 2025 13:44:48 +0100 Subject: [PATCH 19/36] token comments and new tests --- token/app.go | 46 +++++++++++++++ token/app_test.go | 15 ++++- token/expiration.go | 51 ++++++++++++++++- token/expiration_test.go | 18 +++++- token/id.go | 117 +++++++++++++++++++++++++++++++++++---- token/id_test.go | 8 +++ token/secret.go | 15 +++++ token/token.go | 55 +++++++++++++++++- 8 files changed, 304 insertions(+), 21 deletions(-) diff --git a/token/app.go b/token/app.go index 0b9fd11..b378caf 100644 --- a/token/app.go +++ b/token/app.go @@ -6,69 +6,105 @@ import ( "time" ) +// App represents an application that can request tokens. It has a name, a +// redirect URI, and a session duration. type App struct { Name string RedirectURI string SessionDuration time.Duration } +// Valid method returns true if the app is valid, false otherwise. An app is +// considered valid if its name is between 3 and 20 characters, its redirect +// URI is a valid URI, and its session duration is between 5 minutes and 24 +// hours. func (app *App) Valid() bool { if app == nil { return false } + // check if the app name is between the min and max length if len(app.Name) < appNameMinLen || len(app.Name) > appNameMaxLen { return false } + // check if the redirect URI is valid if !uriRegexp.MatchString(app.RedirectURI) || len(app.RedirectURI) > redirectURIMaxLen { return false } + // check if the session duration is between the min and max duration if app.SessionDuration < minDuration || app.SessionDuration > maxDuration { return false } return true } +// Attributes method returns the app's attributes as a slice of strings. This +// is useful for encoding the app. func (app *App) Attributes() []string { return []string{app.Name, app.RedirectURI, app.SessionDuration.String()} } +// SetAttributes method sets the app's attributes from a slice of strings. This +// is useful for decoding the app. func (app *App) SetAttributes(attrs []string) *App { + // check if the slice has the correct number of attributes if len(attrs) != 3 { return nil } + // parse the session duration duration, err := time.ParseDuration(attrs[2]) if err != nil { return nil } + // if the app is nil, create a new app + if app == nil { + app = new(App) + } + // set the app's attributes app.Name = attrs[0] app.RedirectURI = attrs[1] app.SessionDuration = duration + // check if the app is valid and return it if it is if !app.Valid() { return nil } return app } +// String method returns the app as a string. This is useful for debugging +// and encoding the app. The resulting string is the app's attributes joined +// by the app data separator. func (app *App) String() string { if !app.Valid() { return "" } + // join the app's attributes with the app data separator return strings.Join(app.Attributes(), appDataSeparator) } +// SetString method sets the app from a string. This is useful for decoding +// the app. The string should be the app's attributes joined by the app data +// separator. func (app *App) SetString(data string) *App { b := strings.Split(data, appDataSeparator) return app.SetAttributes(b) } +// Bytes method returns the app as a byte slice. This is useful for encoding +// the app. It is equivalent to converting the app to a string and then +// converting the string to a byte slice. func (app *App) Bytes() []byte { return []byte(app.String()) } +// SetBytes method sets the app from a byte slice. This is useful for decoding +// the app. It is equivalent to converting the byte slice to a string and then +// converting the string to the app. func (app *App) SetBytes(data []byte) *App { return app.SetString(string(data)) } +// Marshal method returns the app as a base64-encoded byte slice. It is used +// to be included in the app ID, which makes it self-contained. func (app *App) Marshal() []byte { if !app.Valid() { return nil @@ -79,6 +115,8 @@ func (app *App) Marshal() []byte { return b } +// Unmarshal method sets the app from a base64-encoded byte slice. It is used +// to extract the app from the app ID. func (app *App) Unmarshal(data []byte) *App { b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { @@ -87,6 +125,10 @@ func (app *App) Unmarshal(data []byte) *App { return app.SetBytes(b) } +// ID method returns the app ID of the app. The app ID is a self-contained +// representation of the app that can be used to generate tokens. It is +// created by encoding the app as a base64-encoded byte slice using the +// Marshal method. func (app *App) ID() *AppID { if !app.Valid() { return nil @@ -94,6 +136,10 @@ func (app *App) ID() *AppID { return new(AppID).SetBytes(app.Marshal()) } +// SetID method sets the app from an app ID. The app ID is a self-contained +// representation of the app that can be used to generate tokens. The app is +// extracted from the app ID by decoding the app as a base64-encoded byte +// slice using the Unmarshal method. func (app *App) SetID(id *AppID) *App { if id == nil { return nil diff --git a/token/app_test.go b/token/app_test.go index 3fed8e8..0b2f4a2 100644 --- a/token/app_test.go +++ b/token/app_test.go @@ -58,7 +58,6 @@ func TestAttributesSetAttributesApp(t *testing.T) { if res := new(App).SetAttributes([]string{}); res != nil { t.Errorf("expected nil, got %v", res) } - if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, "no_duration"}); res != nil { t.Errorf("expected nil, got %v", res) } @@ -72,6 +71,20 @@ func TestAttributesSetAttributesApp(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } + var nilApp *App + nilData := nilApp.SetAttributes(app.Attributes()) + if nilData == nil { + t.Fatalf("error decoding app data") + } + if nilData.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, nilData.Name) + } + if nilData.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, nilData.RedirectURI) + } + if nilData.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, nilData.SessionDuration) + } data := new(App).SetAttributes(app.Attributes()) if data == nil { t.Fatalf("error decoding app data") diff --git a/token/expiration.go b/token/expiration.go index 6444a90..31e7f40 100644 --- a/token/expiration.go +++ b/token/expiration.go @@ -5,16 +5,24 @@ import ( "time" ) +// Expiration represents a time when a token expires. It is a wrapper around +// time.Time that provides additional methods for setting and getting the +// expiration time. type Expiration time.Time +// Valid method returns true if the expiration is valid, false otherwise. An +// expiration is considered valid if it is in the future. func (exp *Expiration) Valid() bool { return time.Now().Before(exp.Time()) } +// Time method returns the expiration time as a time.Time. func (exp *Expiration) Time() time.Time { return time.Time(*exp) } +// SetTime method sets the expiration time from a time.Time. If the expiration +// is nil, a new expiration is created. func (exp *Expiration) SetTime(t time.Time) *Expiration { // if no expiration is provided, initialize a new one if exp == nil { @@ -22,13 +30,20 @@ func (exp *Expiration) SetTime(t time.Time) *Expiration { } // set the expiration time *exp = Expiration(t) + // if the expiration is invalid, return nil + if !exp.Valid() { + return nil + } return exp } +// Duration method returns the duration until the expiration time. func (exp *Expiration) Duration() time.Duration { return time.Until(exp.Time()) } +// SetDuration method sets the expiration time from a duration. If the duration +// is invalid, the expiration time is not set. func (exp *Expiration) SetDuration(d time.Duration) *Expiration { if d < minDuration || d > maxDuration { return nil @@ -39,26 +54,47 @@ func (exp *Expiration) SetDuration(d time.Duration) *Expiration { return exp.SetTime(time.Now().Add(d)) } +// String method returns the expiration time as a string in RFC3339Nano format. +// It is useful for encoding the expiration time. If the expiration is nil, an +// empty string is returned. func (exp *Expiration) String() string { if exp == nil { return "" } - t := exp.Time() - if t.IsZero() { + if !exp.Valid() { return "" } - return t.Format(time.RFC3339Nano) + return exp.Time().Format(time.RFC3339Nano) } +// SetString method sets the expiration time from a string in RFC3339Nano +// format. It is useful for decoding the expiration time. If the string +// is invalid, the expiration time is not set and nil is returned. If the +// expiration is nil, a new expiration is created. If the resulting expiration +// is invalid, nil is returned. func (exp *Expiration) SetString(data string) *Expiration { + // parse the expiration time t, err := time.Parse(time.RFC3339Nano, data) if err != nil { return nil } + // if the expiration is nil, initialize a new one + if exp == nil { + exp = new(Expiration) + } + // set the expiration time *exp = Expiration(t) + // if the expiration is invalid, return nil + if !exp.Valid() { + return nil + } return exp } +// Bytes method returns the expiration time as a byte slice. It is useful for +// encoding the expiration time. If the expiration is nil, nil is returned. It +// is equivalent to converting the expiration time to a string and then +// converting the string to a byte slice. func (exp *Expiration) Bytes() []byte { if exp.String() == "" { return nil @@ -66,10 +102,16 @@ func (exp *Expiration) Bytes() []byte { return []byte(exp.String()) } +// SetBytes method sets the expiration time from a byte slice. It is useful for +// decoding the expiration time. It is equivalent to converting the byte slice +// to a string and then setting the expiration time from the string. func (exp *Expiration) SetBytes(data []byte) *Expiration { return exp.SetString(string(data)) } +// Marshal method returns the expiration time as a base64 encoded byte slice. It +// is useful for encoding the expiration time. If the expiration is nil or +// invalid, nil is returned. func (exp *Expiration) Marshal() []byte { bExp := exp.Bytes() if len(bExp) == 0 || bExp[0] == 0 { @@ -80,6 +122,9 @@ func (exp *Expiration) Marshal() []byte { return b } +// Unmarshal method sets the expiration time from a base64 encoded byte slice. It +// is useful for decoding the expiration time. If the expiration is nil or +// invalid, nil is returned. func (exp *Expiration) Unmarshal(data []byte) *Expiration { b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { diff --git a/token/expiration_test.go b/token/expiration_test.go index 2d49f0a..f7b2525 100644 --- a/token/expiration_test.go +++ b/token/expiration_test.go @@ -33,6 +33,10 @@ func TestTimeSetTimeExpiration(t *testing.T) { if expected.Sub(expTime) > time.Millisecond*300 { t.Errorf("expected %v, got %v", expected, expTime) } + invalidTime := time.Now().Add(-time.Second) + if exp := new(Expiration).SetTime(invalidTime); exp != nil { + t.Errorf("expected nil, got %v", exp) + } } func TestDurationSetDurationExpiration(t *testing.T) { @@ -67,8 +71,18 @@ func TestStringSetStringExpiration(t *testing.T) { t.Errorf("expected empty string, got %v", exp) } var nilExp *Expiration - if exp := nilExp.String(); exp != "" { - t.Errorf("expected empty string, got %v", exp) + if nilExp.String() != "" { + t.Errorf("expected empty string, got %v", nilExp.String()) + } + if nilExp = nilExp.SetString(str); nilExp == nil { + t.Fatalf("expected valid expiration, got nil") + } + if nilExp.String() != str { + t.Errorf("expected %v, got %v", str, nilExp.String()) + } + invalidTime := time.Now().Add(-time.Second).Format(time.RFC3339Nano) + if exp := new(Expiration).SetString(invalidTime); exp != nil { + t.Errorf("expected nil, got %v", exp) } } diff --git a/token/id.go b/token/id.go index 6944296..9917eda 100644 --- a/token/id.go +++ b/token/id.go @@ -7,34 +7,63 @@ import ( "encoding/base64" ) -type AppID string +// AppID represents an application ID that is used to generate and verify +// tokens. It is a wrapper around a byte slice that provides additional +// methods for setting and getting the application ID. +type AppID []byte +// String method returns the application ID as a string. If the application +// ID is nil, an empty string is returned. It internally calls the Bytes +// method to get the application ID as a byte slice and converts it to a +// string. func (id *AppID) String() string { - return string(*id) + return string(id.Bytes()) } +// SetString method sets the application ID from a string. If the application +// ID is nil, a new application ID is created. If the string is empty, the +// application ID is not set. It internally calls the SetBytes method to set +// the application ID from a byte slice. func (id *AppID) SetString(data string) *AppID { - newID := AppID(data) - if !new(App).Unmarshal(newID.Bytes()).Valid() { - return nil - } - *id = newID - return id + return id.SetBytes([]byte(data)) } +// Bytes method returns the application ID as a byte slice. func (id *AppID) Bytes() []byte { return []byte(*id) } +// SetBytes method sets the application ID from a byte slice. If the +// application ID is nil, a new application ID is created. If the byte slice +// is empty, the application ID is not set. It internally calls the Unmarshal +// method of the App to check that the application ID is valid before setting +// the application ID from the byte slice. If the resulting application is not +// valid, the application ID is not set and nil is returned. func (id *AppID) SetBytes(data []byte) *AppID { - return id.SetString(string(data)) + // check if the application ID is valid + if !new(App).Unmarshal(data).Valid() { + return nil + } + // if no application ID is provided, create a new one + if id == nil { + id = new(AppID) + } + // set the application ID + *id = data + return id } +// PrivKey method returns the private key for the application ID. If the +// application ID is nil or the secret is invalid, nil is returned. It +// internally calls the Bytes method to get the application ID as a byte +// slice and uses it to generate the private key. The private key is +// generated by hashing the application ID with the secret and using the +// resulting hash as the seed for an ed25519 private key. func (id *AppID) PrivKey(secret Secret) ed25519.PrivateKey { - if id == nil { + if !secret.Valid() { return nil } - if !secret.Valid() { + if id == nil { return nil } bID := id.Bytes() @@ -46,26 +75,49 @@ func (id *AppID) PrivKey(secret Secret) ed25519.PrivateKey { return ed25519.NewKeyFromSeed(hID[:32]) } +// Sign method returns the signature of the message for the application ID. +// If the application ID is nil, the message is empty, or the secret is +// invalid, nil is returned. It internally calls the PrivKey method to get +// the private key for the application ID and uses it to sign the message. +// The message is signed by appending a nonce to it and hashing the result +// with the private key. The signature is then encoded to base64 before +// being returned to be used as a part of the token, keeping it as short as +// possible. func (id *AppID) Sign(secret Secret, msg []byte) []byte { + // check if the application ID is valid or the message is empty if id == nil || len(msg) == 0 { return nil } + // get the private key for the application ID and the secret privKey := id.PrivKey(secret) if len(privKey) == 0 { return nil } + // append the message with a nonce and hash it with the private key data := append(msg, hedgedNonce(privKey[:], msg)...) + // sign the data with the private key rawSign := ed25519.Sign(privKey, data[:]) - // encode to base64 + // encode the signature to base64 and return it sign := make([]byte, base64.RawStdEncoding.EncodedLen(len(rawSign))) base64.RawStdEncoding.Encode(sign, rawSign) return sign } +// Verify method returns true if the signature of the message is valid for +// the application ID. If the application ID is nil, the message is empty, +// the signature is empty, or the secret is invalid, false is returned. It +// internally calls the PrivKey method to get the private key for the +// application ID and uses it to verify the signature. The message is +// verified by appending a nonce to it and hashing the result with the +// public key. The signature is then decoded from base64 and verified with +// the public key to ensure that it was signed by the private key. func (id *AppID) Verify(secret Secret, msg, sig []byte) bool { + // check if the application ID is valid or the message and signature are + // not empty if id == nil || len(msg) == 0 || len(sig) == 0 { return false } + // get the private key for the application ID and the secret privKey := id.PrivKey(secret) if privKey == nil { return false @@ -75,55 +127,96 @@ func (id *AppID) Verify(secret Secret, msg, sig []byte) bool { if _, err := base64.RawStdEncoding.Decode(rawSign, sig); err != nil { return false } + // recover the data with the nonce and the message data := append(msg, hedgedNonce(privKey[:], msg)...) + // verify the data with the public key pubKey := privKey.Public().(ed25519.PublicKey) return ed25519.Verify(pubKey, data, rawSign) } +// GenerateToken method returns a token for the application ID. If the app ID +// is nil, the secret is invalid, or the email is empty, nil is returned. It +// internally calls the Message method to generate the message for the token +// and uses it to sign the message. The token is generated by hashing the +// application ID with the email and expiration time, and signing the result +// with the secret. The signature is then used to create a token with the +// expiration time and signature. func (id *AppID) GenerateToken(secret Secret, email string) Token { + // check if the application ID is valid if id == nil { return nil } + // get the application for the application ID app := new(App).SetID(id) if app == nil { return nil } + // calculate the expiration time for the current app exp := new(Expiration).SetDuration(app.SessionDuration) + // get the message to sign msg := id.Message(email, *exp) + // sign the message with the secret sig := id.Sign(secret, msg) + // check if the signature is valid if len(sig) == 0 { return nil } + // create a new token with the expiration time and signature return *new(Token).SetExpiration(*exp).SetSignature(sig) } +// Message method returns the message for the application ID. If the app ID +// is nil, the email is empty, or the expiration time is invalid, nil is +// returned. It is used to generate the message for the token by hashing the +// application ID with the email and expiration time. The message is generated +// by appending the application ID with the email and expiration time, and +// hashing the result with sha256 to create a unique message for the token. func (id *AppID) Message(email string, exp Expiration) []byte { + // check if the application ID is valid, the email is not empty, and the + // expiration time is valid if id == nil || len(email) == 0 || !exp.Valid() { return nil } + // hash the application ID with the email and expiration time hmsg := sha256.Sum256(append(append(id.Bytes(), []byte(email)...), exp.Bytes()...)) return hmsg[:] } +// VerifyToken method returns true if the token is valid for the application +// ID. If the app ID is nil, the token is nil, the secret is invalid, or the +// email is empty, false is returned. It is used to verify the token by +// checking that the expiration time is valid and the signature is correct. +// The token is verified by hashing the application ID with the email and +// expiration time, and verifying the signature with the secret. func (id *AppID) VerifyToken(token Token, secret Secret, email string) bool { + // check if the application ID is valid if id == nil { return false } + // check if the expiration time is valid exp := token.Expiration() if exp == nil || !exp.Valid() { return false } + // check if the token contains a signature sig := token.Signature() if len(sig) == 0 { return false } + // get the message to verify msg := id.Message(email, *exp) if len(msg) == 0 { return false } + // verify the token with the secret return id.Verify(secret, msg, sig) } +// hedgedNonce function returns a nonce for the application ID. It is used to +// generate a unique nonce for the application ID by hashing the application +// ID with the message. The nonce is generated by appending the application ID +// with the message and hashing the result with sha256. It helps to prevent +// replay attacks by ensuring that the nonce is unique for each message. func hedgedNonce(inputs ...[]byte) []byte { if len(inputs) == 0 || len(inputs[0]) == 0 { return nil diff --git a/token/id_test.go b/token/id_test.go index 8eda50e..f8739ec 100644 --- a/token/id_test.go +++ b/token/id_test.go @@ -53,6 +53,14 @@ func TestBytesSetBytesAppID(t *testing.T) { if !bytes.Equal(newID.Bytes(), id.Bytes()) { t.Errorf("expected %v, got %v", id.Bytes(), newID.Bytes()) } + var nilID *AppID + nilID = nilID.SetBytes(app.Marshal()) + if nilID == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(nilID.Bytes(), id.Bytes()) { + t.Errorf("expected %v, got %v", id.Bytes(), nilID.Bytes()) + } } func TestPrivKeySignVerifyAppID(t *testing.T) { diff --git a/token/secret.go b/token/secret.go index e55ef85..2acdc47 100644 --- a/token/secret.go +++ b/token/secret.go @@ -2,12 +2,22 @@ package token import "crypto/sha256" +// Secret represents a secret that is used to sign and verify tokens. It is +// a wrapper around a byte slice that provides additional methods for setting +// and getting the secret. It should have at least 2 parts, each hashed to a +// sha256 size. type Secret []byte +// SetParts method sets the secret's parts from a slice of byte slices. If the +// secret is nil, a new secret is created. If the parts are empty, the secret +// is not set. The parts are hashed to a sha256 size and concatenated to form +// the secret. func (s *Secret) SetParts(raw ...[]byte) *Secret { + // if no secret is provided, initialize a new one if s == nil { s = new(Secret) } + // hash each part to a sha256 size and concatenate them newParts := []byte{} for _, part := range raw { if len(part) == 0 { @@ -16,16 +26,21 @@ func (s *Secret) SetParts(raw ...[]byte) *Secret { hsecret := sha256.Sum256(part) newParts = append(newParts, hsecret[:]...) } + // if there are new parts, append them to the secret if len(newParts) != 0 { *s = append(*s, newParts...) } return s } +// Bytes method returns the secret as a byte slice. func (s *Secret) Bytes() []byte { return []byte(*s) } +// Valid method returns true if the secret is valid, false otherwise. A secret +// is considered valid if it has more than 1 part, and each part is hashed to +// a sha256 size. func (s *Secret) Valid() bool { if s == nil { return false diff --git a/token/token.go b/token/token.go index 9ff64c2..ce90604 100644 --- a/token/token.go +++ b/token/token.go @@ -1,11 +1,16 @@ package token -import ( - "bytes" -) +import "bytes" +// Token is a type that represents a user token. It is a wrapper around a byte +// slice that provides additional methods for setting and getting the token. It +// should have 2 parts, the first part is the expiration time, and the second +// part is the signature. type Token []byte +// String method returns the token as a string. It is useful for encoding the +// token. If the token is nil, an empty string is returned. It internally calls +// the Bytes method to get the token as a byte slice. func (t *Token) String() string { if t == nil { return "" @@ -13,6 +18,10 @@ func (t *Token) String() string { return string(t.Bytes()) } +// SetString method sets the token from a string. It is useful for decoding the +// token. The string should be the token's expiration time and signature joined +// by the token separator. If the token is invalid, the token is not set. It +// internally calls the SetBytes method to set the token from a byte slice. func (t *Token) SetString(data string) *Token { if t == nil { t = new(Token) @@ -20,29 +29,52 @@ func (t *Token) SetString(data string) *Token { return t.SetBytes([]byte(data)) } +// Bytes method returns the token as a byte slice. It is useful for encoding +// the token. If the token is nil, nil is returned. It internally calls the +// parts method to get the token's expiration time and signature as byte +// slices. It checks that the parts are valid before returning the token +// as a byte slice. func (t *Token) Bytes() []byte { + // check if the token is nil if t == nil { return nil } + // check if the token has valid parts if _, _, ok := t.parts(); !ok { return nil } + // return the token as a byte slice return []byte(*t) } +// SetBytes method sets the token from a byte slice. It is useful for +// decoding the token. The byte slice should be the token's expiration time +// and signature joined by the token separator. If the token is invalid, the +// token is not set. It internally calls the parts method to get the token's +// expiration time and signature as byte slices. It checks that the parts are +// valid before setting the token from the byte slice. func (t *Token) SetBytes(data []byte) *Token { + // if no token is provided, create a new one if t == nil { t = new(Token) } + // generate a new token from the data ntoken := &Token{} *ntoken = data + // check if the new token has valid parts if _, _, ok := ntoken.parts(); !ok { return t } + // set the token to the new token *t = data return t } +// Expiration method returns the token's expiration time. It is useful for +// getting the expiration time. If the token is nil, nil is returned. It +// internally calls the parts method to get the token's expiration time and +// signature as byte slices. It checks that the expiration time is valid before +// returning it. func (t *Token) Expiration() *Expiration { if rawExp, _, ok := t.parts(); ok { return new(Expiration).Unmarshal(rawExp) @@ -50,6 +82,11 @@ func (t *Token) Expiration() *Expiration { return nil } +// SetExpiration method sets the token's expiration time. If the token is nil, +// a new token is created. If the expiration time is invalid, the token is not +// set. It internally calls the parts method to replace the token's expiration +// time with the new expiration time. It checks that the expiration time is +// valid before setting it. func (t *Token) SetExpiration(exp Expiration) *Token { // if no token is provided, create a new one if t == nil { @@ -72,11 +109,18 @@ func (t *Token) SetExpiration(exp Expiration) *Token { return t.SetBytes(append(baseContent, sig...)) } +// Signature method returns the token's signature. If the token is nil, nil is +// returned. It internally calls the parts method to get the signature part as +// byte slices. func (t *Token) Signature() []byte { _, sig, _ := t.parts() return sig } +// SetSignature method sets the token's signature. If the token is nil, a +// new token is created. If the signature is nil, the token is not set. It +// internally calls the parts method to replace the token's signature with +// the new signature. func (t *Token) SetSignature(sig []byte) *Token { // if no token is provided, create a new one if t == nil { @@ -95,6 +139,11 @@ func (t *Token) SetSignature(sig []byte) *Token { return t.SetBytes(append(exp.Marshal(), baseContent...)) } +// parts private method returns the token's expiration time and signature as +// byte slices. It also returns a boolean indicating if the token has valid +// parts. If the token is nil, or the parts are invalid, the parts are nil and +// the boolean is false. It splits the token by the token separator and checks +// that the result has 2 parts (expiration time and signature). func (t *Token) parts() ([]byte, []byte, bool) { if t == nil { return nil, nil, false From c7b2f4f7536c592932b782bf1280dc40354ad546 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sat, 8 Mar 2025 13:45:29 +0100 Subject: [PATCH 20/36] new api test and minor changes --- api/handlers_test.go | 68 +++++++++++++++++++++++++++++ api/io.go | 4 -- api/types.go | 4 ++ notification/email/template_test.go | 1 - 4 files changed, 72 insertions(+), 5 deletions(-) diff --git a/api/handlers_test.go b/api/handlers_test.go index 36a9842..00d184e 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -137,6 +137,74 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { SessionDuration: testAppSessionDuration, } testAppID := testApp.ID() + testCaseAPIHandler[TokenRequest, any]{ + name: "no appID request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + AppSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t, false) + testCaseAPIHandler[TokenRequest, any]{ + name: "invalid app id request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{"invalid"}, + AppSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + err: InvalidAppIDErr, + }.Run(t, false) + testCaseAPIHandler[TokenRequest, any]{ + name: "no app secret request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t, false) + testCaseAPIHandler[TokenRequest, any]{ + name: "no email provided", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + AppSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: "", + }, + response: nil, + err: GenerateTokenErr, + }.Run(t, false) + invalid := []byte("invalid") + testCaseAPIHandler[[]byte, any]{ + name: "no request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + AppSecretHeader: []string{testAppSecret}, + }, + request: &invalid, + response: nil, + err: DecodeTokenRequestErr, + }.Run(t, false) + login.Template = email.EmailTemplate{ HTML: "", Plain: `\[{{.Token}}]`, diff --git a/api/io.go b/api/io.go index b4b2f2e..9ec021a 100644 --- a/api/io.go +++ b/api/io.go @@ -37,10 +37,6 @@ func (r *Response[T]) Write(w http.ResponseWriter) error { return err } -type TokenRequest struct { - Email string `json:"email"` -} - type Request[T any] struct { Data T } diff --git a/api/types.go b/api/types.go index 9236f42..755c391 100644 --- a/api/types.go +++ b/api/types.go @@ -25,6 +25,10 @@ type AppIDResponse struct { ID string `json:"id"` } +type TokenRequest struct { + Email string `json:"email"` +} + type TokenStatusRequest struct { Token string `json:"token"` Email string `json:"email"` diff --git a/notification/email/template_test.go b/notification/email/template_test.go index 413d3b0..1d5881d 100644 --- a/notification/email/template_test.go +++ b/notification/email/template_test.go @@ -97,7 +97,6 @@ func TestCompose(t *testing.T) { if onlyHTMLEmail.PlainBody != nil { t.Fatalf("expected nil, got %v", string(onlyHTMLEmail.PlainBody)) } - } func Test_composePlain(t *testing.T) { From 256d4ee9d5eb403d730c90c5d5e1f60382a91483 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sat, 8 Mar 2025 23:31:42 +0100 Subject: [PATCH 21/36] dependencies updated, handlers tests --- .github/workflows/main.yml | 4 +- api/handlers.go | 15 ++------ api/handlers_test.go | 79 +++++++++++++++++++++++++++++++------- api/io.go | 12 ++++-- api/service.go | 19 ++++----- go.mod | 6 +-- go.sum | 8 ++-- 7 files changed, 93 insertions(+), 50 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 845b40e..4eb7bd2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,7 +20,7 @@ jobs: - name: Set up Go environment uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version: "1.24" - name: Tidy go module run: | go mod tidy @@ -45,7 +45,7 @@ jobs: - name: Set up Go environment uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version: "1.24" - name: Run Go test -race run: go test ./... -v -race -timeout=1h - name: Rerun Go test to generate coverage report diff --git a/api/handlers.go b/api/handlers.go index 1d1c6b4..386fe9e 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -22,9 +22,7 @@ func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { return } // return the app id - if err := ResponseWith(&AppIDResponse{app.ID().String()}).Write(w); err != nil { - EncodeAppIDResponseErr.WithErr(err).Write(w) - } + ResponseWith(&AppIDResponse{app.ID().String()}).WriteJSON(w) } func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { @@ -79,9 +77,7 @@ func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { SendEmailErr.WithErr(err).Write(w) return } - if err := OkResponse().Write(w); err != nil { - InternalErr.WithErr(err).Write(w) - } + OkResponse().WriteJSON(w) } func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { @@ -114,11 +110,8 @@ func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { return } ok := appID.VerifyToken(*tkn, *secret, req.Data.Email) - if err := ResponseWith(&TokenStatusResponse{ + ResponseWith(&TokenStatusResponse{ Valid: ok, Expiration: exp, - }).Write(w); err != nil { - EncodeTokenStatusResponseErr.WithErr(err).Write(w) - return - } + }).WriteJSON(w) } diff --git a/api/handlers_test.go b/api/handlers_test.go index 00d184e..8a77be2 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -30,11 +30,8 @@ func (testCase testCaseAPIHandler[Rq, Rs]) url() string { return fmt.Sprintf("%s%s", testServerApiURL, testCase.endpoint) } -func (testCase testCaseAPIHandler[Rq, Rs]) Run(t *testing.T, parallel bool) { +func (testCase testCaseAPIHandler[Rq, Rs]) Run(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { - if parallel { - t.Parallel() - } var reqBuffer io.Reader if testCase.request != nil { rawBody, err := json.Marshal(testCase.request) @@ -108,14 +105,14 @@ func TestGenerateAppIDHandler(t *testing.T) { response: &AppIDResponse{ ID: testApp.ID().String(), }, - }.Run(t, true) + }.Run(t) testCaseAPIHandler[AppIDRequest, AppIDResponse]{ name: "no request", method: http.MethodPost, endpoint: AppsPath, request: nil, err: DecodeAppIDRequestErr, - }.Run(t, true) + }.Run(t) testCaseAPIHandler[AppIDRequest, AppIDResponse]{ name: "invalid request", method: http.MethodPost, @@ -127,7 +124,7 @@ func TestGenerateAppIDHandler(t *testing.T) { Secret: testAppSecret, }, err: InvalidAppIDErr, - }.Run(t, true) + }.Run(t) } func TestRequestTokenAndStatusHandler(t *testing.T) { @@ -149,7 +146,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { }, response: nil, err: InvalidAppHeadersErr, - }.Run(t, false) + }.Run(t) testCaseAPIHandler[TokenRequest, any]{ name: "invalid app id request", method: http.MethodPost, @@ -163,7 +160,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { }, response: nil, err: InvalidAppIDErr, - }.Run(t, false) + }.Run(t) testCaseAPIHandler[TokenRequest, any]{ name: "no app secret request", method: http.MethodPost, @@ -176,7 +173,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { }, response: nil, err: InvalidAppHeadersErr, - }.Run(t, false) + }.Run(t) testCaseAPIHandler[TokenRequest, any]{ name: "no email provided", method: http.MethodPost, @@ -190,7 +187,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { }, response: nil, err: GenerateTokenErr, - }.Run(t, false) + }.Run(t) invalid := []byte("invalid") testCaseAPIHandler[[]byte, any]{ name: "no request", @@ -203,7 +200,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { request: &invalid, response: nil, err: DecodeTokenRequestErr, - }.Run(t, false) + }.Run(t) login.Template = email.EmailTemplate{ HTML: "", @@ -221,7 +218,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { Email: testUserEmail, }, response: nil, - }.Run(t, false) + }.Run(t) var testToken *token.Token select { @@ -266,5 +263,59 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { Valid: true, Expiration: testToken.Expiration().Time(), }, - }.Run(t, false) + }.Run(t) + + testCaseAPIHandler[TokenStatusRequest, any]{ + name: "invalid app id", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{"invalid"}, + AppSecretHeader: []string{testAppSecret}, + }, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: nil, + err: InvalidAppIDErr, + }.Run(t) + + testCaseAPIHandler[TokenStatusRequest, any]{ + name: "no app secret", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + }, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t) + + testCaseAPIHandler[TokenStatusRequest, any]{ + name: "no headers", + method: http.MethodPut, + endpoint: TokensPath, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t) + + testCaseAPIHandler[any, any]{ + name: "no request", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + AppIDHeader: []string{testAppID.String()}, + AppSecretHeader: []string{testAppSecret}, + }, + err: DecodeTokenStatusRequestErr, + }.Run(t) } diff --git a/api/io.go b/api/io.go index 9ec021a..ae2cc16 100644 --- a/api/io.go +++ b/api/io.go @@ -26,15 +26,19 @@ func OkResponse() *Response[any] { return &Response[any]{empty: true} } -func (r *Response[T]) Write(w http.ResponseWriter) error { +func (r *Response[T]) WriteJSON(w http.ResponseWriter) { if !r.empty { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(r.Data) + if err := json.NewEncoder(w).Encode(r.Data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return } w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte("OK")) - return err + if _, err := w.Write([]byte("OK")); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } type Request[T any] struct { diff --git a/api/service.go b/api/service.go index 4d32e4c..a0c5d89 100644 --- a/api/service.go +++ b/api/service.go @@ -32,23 +32,18 @@ type Service struct { func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) + rateLimiter := apihandler.RateLimiter(internalCtx, 100, 100, time.Minute*3) // create the service srv := &Service{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - nq: nq, - handler: apihandler.NewHandler(&apihandler.Config{ - CORS: true, - RateLimitConfig: &apihandler.RateLimitConfig{ - Rate: 2, - Limit: 10, - }, - }), + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + nq: nq, + handler: apihandler.NewHandler(true, rateLimiter), } // register the routes and handlers srv.handler.Get(HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { - OkResponse().Write(w) + OkResponse().WriteJSON(w) }) srv.handler.Post(AppsPath, srv.generateAppIDHandler) srv.handler.Post(TokensPath, srv.requestTokenHandler) diff --git a/go.mod b/go.mod index fc0b9ff..940068e 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/simpleauthlink/authapi -go 1.21 +go 1.24 -require github.com/lucasmenendez/apihandler v0.0.7 +require github.com/lucasmenendez/apihandler v0.0.8 -require golang.org/x/time v0.10.0 // indirect +require golang.org/x/time v0.11.0 // indirect diff --git a/go.sum b/go.sum index 09e5a75..7e7d3a9 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ -github.com/lucasmenendez/apihandler v0.0.7 h1:OItUaGN5J+KrYFLZnQUNHXnOBP6HZyvlobyk1Jd7JkI= -github.com/lucasmenendez/apihandler v0.0.7/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= -golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= -golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +github.com/lucasmenendez/apihandler v0.0.8 h1:xHBNqdg+/eKpmjSvcQIkfIWxhsBcQa5TpwRzU00KugU= +github.com/lucasmenendez/apihandler v0.0.8/go.mod h1:u19tqauhQwxXbR2rw9//dCdM7oNNqzzPbByHh7R1imU= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= From af95dadaf12b762a944a0a7a8cee43fe6ce75070 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sat, 8 Mar 2025 23:50:54 +0100 Subject: [PATCH 22/36] more api tests --- api/helpers_test.go | 60 +++++++++++++++++++++ api/io.go | 3 ++ api/io_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+) create mode 100644 api/helpers_test.go create mode 100644 api/io_test.go diff --git a/api/helpers_test.go b/api/helpers_test.go new file mode 100644 index 0000000..d385de5 --- /dev/null +++ b/api/helpers_test.go @@ -0,0 +1,60 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestAppConfigFromRequest(t *testing.T) { + tests := []struct { + name string + headers map[string]string + expectedAppID string + expectedSecret string + expectError bool + }{ + { + name: "Valid headers", + headers: map[string]string{AppIDHeader: "testAppID", AppSecretHeader: "testAppSecret"}, + expectedAppID: "testAppID", + expectedSecret: "testAppSecret", + expectError: false, + }, + { + name: "Missing app id", + headers: map[string]string{AppSecretHeader: "testAppSecret"}, + expectError: true, + }, + { + name: "Missing app secret", + headers: map[string]string{AppIDHeader: "testAppID"}, + expectError: true, + }, + { + name: "Missing both headers", + headers: map[string]string{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + appID, appSecret, err := appConfigFromRequest(req) + if (err != nil) != tt.expectError { + t.Errorf("expected error: %v, got: %v", tt.expectError, err) + } + if appID != tt.expectedAppID { + t.Errorf("expected appID: %s, got: %s", tt.expectedAppID, appID) + } + if appSecret != tt.expectedSecret { + t.Errorf("expected appSecret: %s, got: %s", tt.expectedSecret, appSecret) + } + }) + } +} diff --git a/api/io.go b/api/io.go index ae2cc16..7d8fd07 100644 --- a/api/io.go +++ b/api/io.go @@ -49,6 +49,9 @@ func (req *Request[T]) Read(r *http.Request) error { if req == nil { req = new(Request[T]) } + if r.Body == nil { + return fmt.Errorf("nil request body") + } rawBody, err := io.ReadAll(r.Body) if err != nil { return err diff --git a/api/io_test.go b/api/io_test.go new file mode 100644 index 0000000..ace6588 --- /dev/null +++ b/api/io_test.go @@ -0,0 +1,126 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestResponseWith(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + nilResp := ResponseWith[Data](nil) + if !nilResp.empty { + t.Errorf("expected response to be empty") + } + data := &Data{Message: "Hello, World!"} + resp := ResponseWith(data) + if resp.empty { + t.Errorf("expected response to be non-empty") + } + if resp.Data.Message != data.Message { + t.Errorf("expected %s, got %s", data.Message, resp.Data.Message) + } +} + +func TestOkResponse(t *testing.T) { + resp := OkResponse() + if !resp.empty { + t.Errorf("expected response to be empty") + } +} + +func TestWriteJSON(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + data := &Data{Message: "Hello, World!"} + resp := ResponseWith(data) + + rr := httptest.NewRecorder() + resp.WriteJSON(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + expected, _ := json.Marshal(data) + if rr.Body.String() != string(expected)+"\n" { + t.Errorf("expected body %s, got %s", string(expected), rr.Body.String()) + } + // write json with nil data + nilResp := ResponseWith[string](nil) + rr = httptest.NewRecorder() + nilResp.WriteJSON(rr) + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + +} + +func TestWriteJSON_Empty(t *testing.T) { + resp := OkResponse() + + rr := httptest.NewRecorder() + resp.WriteJSON(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + if rr.Body.String() != "OK" { + t.Errorf("expected body OK, got %s", rr.Body.String()) + } +} + +func TestRead(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + data := &Data{Message: "Hello, World!"} + body, _ := json.Marshal(data) + req, err := http.NewRequest("POST", "/", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + + var request Request[Data] + if err := request.Read(req); err != nil { + t.Errorf("unexpected error: %v", err) + } + if request.Data.Message != data.Message { + t.Errorf("expected %s, got %s", data.Message, request.Data.Message) + } +} + +func TestRead_EmptyBody(t *testing.T) { + noBody, err := http.NewRequest("POST", "/", nil) + if err != nil { + t.Fatal(err) + } + + if nilReq := new(Request[any]).Read(noBody); nilReq == nil { + t.Errorf("expected error, got nil") + } + + req, err := http.NewRequest("POST", "/", bytes.NewBuffer([]byte(""))) + if err != nil { + t.Fatal(err) + } + + var request *Request[any] + err = request.Read(req) + if err == nil { + t.Errorf("expected error, got nil") + } + err = new(Request[any]).Read(req) + if err == nil { + t.Errorf("expected error, got nil") + } + if err.Error() != "empty request body" { + t.Errorf("expected empty request body error, got %v", err) + } +} From 6daa958ef4d594a1965f8428303d36d069757184 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 9 Mar 2025 00:02:12 +0100 Subject: [PATCH 23/36] more notification tests --- notification/notification_test.go | 90 +++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 notification/notification_test.go diff --git a/notification/notification_test.go b/notification/notification_test.go new file mode 100644 index 0000000..573947c --- /dev/null +++ b/notification/notification_test.go @@ -0,0 +1,90 @@ +package notification + +import ( + "testing" +) + +func TestNotificationParams_Valid(t *testing.T) { + tests := []struct { + name string + params NotificationParams + valid bool + }{ + { + name: "Valid params", + params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + valid: true, + }, + { + name: "Invalid email", + params: NotificationParams{To: "invalid-email", Subject: "Test Subject"}, + valid: false, + }, + { + name: "Empty subject", + params: NotificationParams{To: "test@example.com", Subject: ""}, + valid: false, + }, + { + name: "Empty email and subject", + params: NotificationParams{To: "", Subject: ""}, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.params.Valid(); got != tt.valid { + t.Errorf("expected valid: %v, got: %v", tt.valid, got) + } + }) + } +} + +func TestNotification_Valid(t *testing.T) { + tests := []struct { + name string + notification Notification + valid bool + }{ + { + name: "Valid notification with Body", + notification: Notification{ + Params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + Body: []byte("Test Body"), + }, + valid: true, + }, + { + name: "Valid notification with PlainBody", + notification: Notification{ + Params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + PlainBody: []byte("Test Plain Body"), + }, + valid: true, + }, + { + name: "Invalid notification with empty Body and PlainBody", + notification: Notification{ + Params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + }, + valid: false, + }, + { + name: "Invalid notification with invalid Params", + notification: Notification{ + Params: NotificationParams{To: "invalid-email", Subject: "Test Subject"}, + Body: []byte("Test Body"), + }, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.notification.Valid(); got != tt.valid { + t.Errorf("expected valid: %v, got: %v", tt.valid, got) + } + }) + } +} From 25c3156d558d9f56bc01e72057300a0f8e45fb7b Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 30 Mar 2025 20:24:47 +0200 Subject: [PATCH 24/36] demo cmd --- api/error.go | 17 +-- api/handlers.go | 40 ++++++ api/handlers_test.go | 29 +--- api/io.go | 64 ++++++--- api/routes.go | 6 +- api/service.go | 30 +++- api/types.go | 13 +- cmd/authapi/main.go | 138 ++++--------------- cmd/consts.go | 34 +++++ cmd/demo/main.go | 63 +++++++++ docker-compose.yml | 39 ------ docker/Dockerfile.demo | 16 +++ Dockerfile => docker/Dockerfile.prod | 3 +- example.env | 13 +- internal/osflag/osflag.go | 152 +++++++++++++++++++++ notification/templates/login/definition.go | 29 ++++ 16 files changed, 465 insertions(+), 221 deletions(-) create mode 100644 cmd/consts.go create mode 100644 cmd/demo/main.go delete mode 100644 docker-compose.yml create mode 100644 docker/Dockerfile.demo rename Dockerfile => docker/Dockerfile.prod (71%) create mode 100644 internal/osflag/osflag.go diff --git a/api/error.go b/api/error.go index 94697a5..93e54ad 100644 --- a/api/error.go +++ b/api/error.go @@ -15,9 +15,10 @@ var ( EncodeAppIDResponseErr = newApiErr(1010, http.StatusInternalServerError).With("could not encode app id response") EncodeTokenStatusResponseErr = newApiErr(1011, http.StatusInternalServerError).With("could not encode token status response") // Bad request errors - InvalidAppHeadersErr = newApiErr(1020, http.StatusBadRequest).With("invalid app headers") - InvalidAppIDErr = newApiErr(1021, http.StatusBadRequest).With("invalid app id") - InvalidAppSecretErr = newApiErr(1022, http.StatusBadRequest).With("invalid app secret") + InvalidAppHeadersErr = newApiErr(1020, http.StatusBadRequest).With("invalid app headers") + InvalidAppIDErr = newApiErr(1021, http.StatusBadRequest).With("invalid app id") + InvalidAppSecretErr = newApiErr(1022, http.StatusBadRequest).With("invalid app secret") + InvalidDemoEmailInboxErr = newApiErr(1023, http.StatusBadRequest).With("invalid demo email inbox") // Internal errors GenerateTokenErr = newApiErr(1030, http.StatusInternalServerError).With("could not generate token") GenerateEmailErr = newApiErr(1031, http.StatusInternalServerError).With("could not generate email") @@ -28,8 +29,8 @@ var ( type APIError struct { Code int `json:"code"` Message string `json:"message"` - Err string `json:"error"` - StatusCode int `json:"status_code"` + Err string `json:"error,omitempty"` + statusCode int } func (e *APIError) Bytes() []byte { @@ -41,7 +42,7 @@ func (e *APIError) Bytes() []byte { } func (e *APIError) Error() string { - return fmt.Sprintf("code: %d, message: %s, error: %s, status_code: %d", e.Code, e.Message, e.Err, e.StatusCode) + return fmt.Sprintf("code: %d, message: %s, error: %s, status_code: %d", e.Code, e.Message, e.Err, e.statusCode) } func (e *APIError) WithErr(err error) *APIError { @@ -64,7 +65,7 @@ func (e *APIError) With(msg string) *APIError { func (e *APIError) Write(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/json") - w.WriteHeader(e.StatusCode) + w.WriteHeader(e.statusCode) if _, err := w.Write(e.Bytes()); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -73,6 +74,6 @@ func (e *APIError) Write(w http.ResponseWriter) { func newApiErr(code, status int) *APIError { return &APIError{ Code: code, - StatusCode: status, + statusCode: status, } } diff --git a/api/handlers.go b/api/handlers.go index 386fe9e..6e0757d 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "net/http" "github.com/simpleauthlink/authapi/notification" @@ -115,3 +116,42 @@ func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { Expiration: exp, }).WriteJSON(w) } + +func (s *Service) healthCheckHandler(w http.ResponseWriter, r *http.Request) { + OkResponse().Write(w) +} + +func (s *Service) demoInboxHandler(w http.ResponseWriter, r *http.Request) { + // get the email from get parameters + email := r.URL.Query().Get("email") + if email == "" { + InvalidDemoEmailInboxErr.Write(w) + return + } + // set http headers required for SSE + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + // create a channel for client disconnection + clientGone := r.Context().Done() + // create a response controller + rc := http.NewResponseController(w) + for { + select { + case <-s.ctx.Done(): + case <-clientGone: + return + case msg := <-s.demoMailInbox: + // find the token in the email + if testToken := login.FindToken(email, msg); testToken != nil { + // send an event to the client with the token in the "data" field + if _, err := fmt.Fprintf(w, "data: %s\n\n", testToken); err != nil { + return + } + if err := rc.Flush(); err != nil { + return + } + } + } + } +} diff --git a/api/handlers_test.go b/api/handlers_test.go index 8a77be2..a3d06f8 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -6,11 +6,9 @@ import ( "fmt" "io" "net/http" - "regexp" "testing" "time" - "github.com/simpleauthlink/authapi/notification" "github.com/simpleauthlink/authapi/notification/email" "github.com/simpleauthlink/authapi/notification/templates/login" "github.com/simpleauthlink/authapi/token" @@ -52,8 +50,8 @@ func (testCase testCaseAPIHandler[Rq, Rs]) Run(t *testing.T) { defer resp.Body.Close() switch { case testCase.err != nil: - if resp.StatusCode != testCase.err.StatusCode { - t.Fatalf("expected status code: %d, got: %d", testCase.err.StatusCode, resp.StatusCode) + if resp.StatusCode != testCase.err.statusCode { + t.Fatalf("expected status code: %d, got: %d", testCase.err.statusCode, resp.StatusCode) } err := new(APIError) if err := json.NewDecoder(resp.Body).Decode(err); err != nil { @@ -99,7 +97,7 @@ func TestGenerateAppIDHandler(t *testing.T) { request: &AppIDRequest{ Name: testApp.Name, RedirectURL: testApp.RedirectURI, - Duration: int64(testApp.SessionDuration), + Duration: testApp.SessionDuration.String(), Secret: testAppSecret, }, response: &AppIDResponse{ @@ -120,7 +118,7 @@ func TestGenerateAppIDHandler(t *testing.T) { request: &AppIDRequest{ Name: testAppName, RedirectURL: testAppRedirectURL, - Duration: int64(time.Second), + Duration: time.Second.String(), Secret: testAppSecret, }, err: InvalidAppIDErr, @@ -223,25 +221,10 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { var testToken *token.Token select { case receivedMsg := <-inboxChan: - data := login.Data{ - AppName: testAppName, - Email: testUserEmail, - Token: `(.+\..+)`, - Link: testAppRedirectURL + receivedMsg, - } - notification, err := login.Template.Compose(notification.NotificationParams{ - To: testUserEmail, - Subject: data.Subject(), - }, data) - if err != nil { - t.Fatalf("could not compose notification: %v", err) - } - tokenRgx := regexp.MustCompile(string(notification.PlainBody)) - tokenResult := tokenRgx.FindAllStringSubmatch(receivedMsg, -1) - if len(tokenResult) < 1 || len(tokenResult[0]) < 2 { + testToken = login.FindToken(testUserEmail, receivedMsg) + if testToken == nil { t.Fatal("could not find token in email") } - testToken = new(token.Token).SetString(tokenResult[0][1]) break case <-time.After(2 * time.Second): t.Fatal("timed out waiting for the email to be received") diff --git a/api/io.go b/api/io.go index 7d8fd07..b1429c8 100644 --- a/api/io.go +++ b/api/io.go @@ -7,6 +7,27 @@ import ( "net/http" ) +type Request[T any] struct { + Data T +} + +func (req *Request[T]) Read(r *http.Request) error { + if req == nil { + req = new(Request[T]) + } + if r.Body == nil { + return fmt.Errorf("nil request body") + } + rawBody, err := io.ReadAll(r.Body) + if err != nil { + return err + } + if len(rawBody) == 0 { + return fmt.Errorf("empty request body") + } + return json.Unmarshal(rawBody, &req.Data) +} + type Response[T any] struct { Data T empty bool @@ -22,7 +43,10 @@ func ResponseWith[T any](data *T) *Response[T] { } } -func OkResponse() *Response[any] { +func OkResponse(body ...byte) *Response[any] { + if len(body) > 0 { + return &Response[any]{Data: body, empty: false} + } return &Response[any]{empty: true} } @@ -41,23 +65,31 @@ func (r *Response[T]) WriteJSON(w http.ResponseWriter) { } } -type Request[T any] struct { - Data T -} - -func (req *Request[T]) Read(r *http.Request) error { - if req == nil { - req = new(Request[T]) +func (r *Response[T]) Write(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + if r.empty { + if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return } - if r.Body == nil { - return fmt.Errorf("nil request body") + if data, ok := r.bytes(); ok { + if _, err := w.Write(data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } - rawBody, err := io.ReadAll(r.Body) - if err != nil { - return err +} + +func (r *Response[T]) bytes() ([]byte, bool) { + // check if the response is empty + if r.empty { + return nil, true } - if len(rawBody) == 0 { - return fmt.Errorf("empty request body") + // ensure that the response data is an slice of bytes + switch v := any(r.Data).(type) { + case []byte: + return v, true + default: + return nil, false } - return json.Unmarshal(rawBody, &req.Data) } diff --git a/api/routes.go b/api/routes.go index 928805a..72bbf6f 100644 --- a/api/routes.go +++ b/api/routes.go @@ -4,7 +4,7 @@ const ( // HealthCheckPath constant is the path used to check the health of the API // server. It is a string with a value of "/health". HealthCheckPath = "/ping" - - AppsPath = "/apps" - TokensPath = "/tokens" + AppsPath = "/apps" + TokensPath = "/tokens" + DemoInboxPath = "/demo/inbox" ) diff --git a/api/service.go b/api/service.go index a0c5d89..850f79e 100644 --- a/api/service.go +++ b/api/service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/lucasmenendez/apihandler" + "github.com/simpleauthlink/authapi/internal" "github.com/simpleauthlink/authapi/notification" ) @@ -18,6 +19,10 @@ type Config struct { Server string ServerPort int Secret string + // demo stuff + DemoMode bool + DemoSMTPAddr string + DemoSMTPPort int } type Service struct { @@ -28,6 +33,9 @@ type Service struct { nq notification.Queue handler *apihandler.Handler httpServer *http.Server + // demo stuff + demoMailServer *internal.FakeSMTPServer + demoMailInbox chan string } func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { @@ -41,13 +49,21 @@ func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, err nq: nq, handler: apihandler.NewHandler(true, rateLimiter), } + // demo stuff + if cfg.DemoMode { + srv.demoMailInbox = make(chan string, 1) + srv.demoMailServer = internal.NewFakeSMTPServer(cfg.DemoSMTPAddr, + cfg.DemoSMTPPort, srv.demoMailInbox) + if err := srv.demoMailServer.Start(internalCtx); err != nil { + return nil, err + } + _ = srv.handler.Get(DemoInboxPath, srv.demoInboxHandler) + } // register the routes and handlers - srv.handler.Get(HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { - OkResponse().WriteJSON(w) - }) - srv.handler.Post(AppsPath, srv.generateAppIDHandler) - srv.handler.Post(TokensPath, srv.requestTokenHandler) - srv.handler.Put(TokensPath, srv.verifyTokenHandler) + _ = srv.handler.Post(AppsPath, srv.generateAppIDHandler) + _ = srv.handler.Post(TokensPath, srv.requestTokenHandler) + _ = srv.handler.Put(TokensPath, srv.verifyTokenHandler) + _ = srv.handler.Get(HealthCheckPath, srv.healthCheckHandler) // build the http server srv.httpServer = &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server, cfg.ServerPort), @@ -91,7 +107,7 @@ func (s *Service) WaitToShutdown() error { done := make(chan os.Signal, 1) signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) <-done - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() defer s.Stop() return s.httpServer.Shutdown(ctx) diff --git a/api/types.go b/api/types.go index 755c391..47b958d 100644 --- a/api/types.go +++ b/api/types.go @@ -8,17 +8,20 @@ import ( type AppIDRequest struct { Name string `json:"name"` - Duration int64 `json:"session_duration"` + Duration string `json:"session_duration"` RedirectURL string `json:"redirect_url"` Secret string `json:"secret"` } func (data *AppIDRequest) parseApp() *token.App { - return &token.App{ - Name: data.Name, - RedirectURI: data.RedirectURL, - SessionDuration: time.Duration(data.Duration), + if duration, err := time.ParseDuration(data.Duration); err == nil { + return &token.App{ + Name: data.Name, + RedirectURI: data.RedirectURL, + SessionDuration: duration, + } } + return new(token.App) } type AppIDResponse struct { diff --git a/cmd/authapi/main.go b/cmd/authapi/main.go index 0d688ff..bd54065 100644 --- a/cmd/authapi/main.go +++ b/cmd/authapi/main.go @@ -2,49 +2,15 @@ package main import ( "context" - "flag" "fmt" "log" - "os" - "strconv" "github.com/simpleauthlink/authapi/api" + "github.com/simpleauthlink/authapi/cmd" + "github.com/simpleauthlink/authapi/internal/osflag" "github.com/simpleauthlink/authapi/notification/email" ) -const ( - defaultHost = "0.0.0.0" - defaultPort = 8080 - defaultEmailAddr = "" - defaultEmailPass = "" - defaultEmailHost = "" - defaultEmailPort = 587 - defaultSecret = "simpleauthlink-secret" - - hostFlag = "host" - portFlag = "port" - emailAddrFlag = "email-addr" - emailPassFlag = "email-pass" - emailHostFlag = "email-host" - emailPortFlag = "email-port" - secretFlag = "secret" - hostFlagDesc = "service host" - portFlagDesc = "service port" - emailAddrFlagDesc = "email account address" - emailPassFlagDesc = "email account password" - emailHostFlagDesc = "email server host" - emailPortFlagDesc = "email server port" - secretFlagDesc = "secret used to generate the tokens" - - hostEnv = "SIMPLEAUTH_HOST" - portEnv = "SIMPLEAUTH_PORT" - emailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" - emailPassEnv = "SIMPLEAUTH_EMAIL_PASS" - emailHostEnv = "SIMPLEAUTH_EMAIL_HOST" - emailPortEnv = "SIMPLEAUTH_EMAIL_PORT" - secretEnv = "SIMPLEAUTH_SECRET" -) - type config struct { host string port int @@ -55,12 +21,30 @@ type config struct { secret string } +func (c *config) String() string { + return fmt.Sprintf(`{"server": "%s:%d", "smtpServer": "%s:%d", "smtpAuth": "%s:%s", "secret": "%s"}`, + c.host, c.port, c.emailHost, c.emailPort, c.emailAddr, c.emailPass, c.secret) +} + func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) - c, err := parseConfig() - if err != nil { - log.Fatalln("ERR: error parsing config:", err) + c := new(config) + // get config from flags + osflag.StringVar(&c.host, cmd.HostEnv, cmd.HostFlag, cmd.DefaultHost, cmd.HostFlagDesc, false) + osflag.IntVar(&c.port, cmd.PortEnv, cmd.PortFlag, cmd.DefaultPort, cmd.HostFlagDesc, false) + osflag.StringVar(&c.emailAddr, cmd.EmailAddrEnv, cmd.EmailAddrFlag, cmd.DefaultEmailAddr, cmd.EmailAddrFlagDesc, true) + osflag.StringVar(&c.emailPass, cmd.EmailPassEnv, cmd.EmailPassFlag, cmd.DefaultEmailPass, cmd.EmailPassFlagDesc, true) + osflag.StringVar(&c.emailHost, cmd.EmailHostEnv, cmd.EmailHostFlag, cmd.DefaultEmailHost, cmd.EmailHostFlagDesc, true) + osflag.IntVar(&c.emailPort, cmd.EmailPortEnv, cmd.EmailPortFlag, cmd.DefaultEmailPort, cmd.EmailPortFlagDesc, false) + osflag.StringVar(&c.secret, cmd.SecretEnv, cmd.SecretFlag, cmd.DefaultSecret, cmd.SecretFlagDesc, true) + if err := osflag.Parse(); err != nil { + log.Fatalln("ERR: error parsing flags:", err) + } + if !osflag.Parsed() { + log.Fatalln("ERR: error parsing flags:", "flags not parsed") + osflag.PrintDefaults() } + log.Println("INF: starting service with config:", c.String()) // create email queue emailQueue, err := email.NewEmailQueue(context.Background(), &email.EmailConfig{ FromName: "SimpleAuthLink", @@ -85,84 +69,14 @@ func main() { if err != nil { log.Fatalln("ERR: error creating service:", err) } + // start the service in background go func() { if err := service.Start(); err != nil { log.Fatalln("ERR: error running service:", err) } }() // wait for the service to finish - service.WaitToShutdown() -} - -func parseConfig() (*config, error) { - var fhost, femailAddr, femailPass, femailHost, fsecret string - var fport, femailPort int - // get config from flags - flag.StringVar(&fhost, hostFlag, defaultHost, hostFlagDesc) - flag.IntVar(&fport, portFlag, defaultPort, hostFlagDesc) - flag.StringVar(&femailAddr, emailAddrFlag, defaultEmailAddr, emailAddrFlagDesc) - flag.StringVar(&femailPass, emailPassFlag, defaultEmailPass, emailPassFlagDesc) - flag.StringVar(&femailHost, emailHostFlag, defaultEmailHost, emailHostFlagDesc) - flag.IntVar(&femailPort, emailPortFlag, defaultEmailPort, emailPortFlagDesc) - flag.StringVar(&fsecret, secretFlag, defaultSecret, secretFlagDesc) - flag.Parse() - // get config from env - envHost := os.Getenv(hostEnv) - envPort := os.Getenv(portEnv) - envEmailAddr := os.Getenv(emailAddrEnv) - envEmailPass := os.Getenv(emailPassEnv) - envEmailHost := os.Getenv(emailHostEnv) - envEmailPort := os.Getenv(emailPortEnv) - envSecret := os.Getenv(secretEnv) - // check if the required flags are set - if femailAddr == "" && envEmailAddr == "" { - return nil, fmt.Errorf("email address is required, use -%s or set %s env var", emailAddrFlag, emailAddrEnv) - } - if femailPass == "" && envEmailPass == "" { - return nil, fmt.Errorf("email password is required, use -%s or set %s env var", emailPassFlag, emailPassEnv) - } - if femailHost == "" && envEmailHost == "" { - return nil, fmt.Errorf("email host is required, use -%s or set %s env var", emailHostFlag, emailHostEnv) - } - if fsecret == "" && envSecret == "" { - return nil, fmt.Errorf("secret is required, use -%s or set %s env var", secretFlag, secretEnv) - } - // set flags values by default - c := &config{ - host: fhost, - port: fport, - emailAddr: femailAddr, - emailPass: femailPass, - emailHost: femailHost, - emailPort: femailPort, - secret: fsecret, - } - // if some flags are not set, set them by env - if envHost != "" { - c.host = envHost - } - if envPort != "" { - if nenvPort, err := strconv.Atoi(envPort); err == nil { - c.port = nenvPort - } else { - return nil, fmt.Errorf("invalid port value: %s", envPort) - } - } - if envEmailAddr != "" { - c.emailAddr = envEmailAddr - } - if envEmailPass != "" { - c.emailPass = envEmailPass - } - if envEmailHost != "" { - c.emailHost = envEmailHost - } - if envEmailPort != "" { - if nenvEmailPort, err := strconv.Atoi(envEmailPort); err == nil { - c.emailPort = nenvEmailPort - } else { - return nil, fmt.Errorf("invalid email port value: %s", envEmailPort) - } + if err := service.WaitToShutdown(); err != nil { + log.Fatalln("ERR: error waiting for service to finish:", err) } - return c, nil } diff --git a/cmd/consts.go b/cmd/consts.go new file mode 100644 index 0000000..3d912b6 --- /dev/null +++ b/cmd/consts.go @@ -0,0 +1,34 @@ +package cmd + +const ( + DefaultHost = "0.0.0.0" + DefaultPort = 8080 + DefaultEmailAddr = "" + DefaultEmailPass = "" + DefaultEmailHost = "" + DefaultEmailPort = 587 + DefaultSecret = "simpleauthlink-secret" + + HostFlag = "host" + PortFlag = "port" + EmailAddrFlag = "email-addr" + EmailPassFlag = "email-pass" + EmailHostFlag = "email-host" + EmailPortFlag = "email-port" + SecretFlag = "secret" + HostFlagDesc = "service host" + PortFlagDesc = "service port" + EmailAddrFlagDesc = "email account address" + EmailPassFlagDesc = "email account password" + EmailHostFlagDesc = "email server host" + EmailPortFlagDesc = "email server port" + SecretFlagDesc = "secret used to generate the tokens" + + HostEnv = "SIMPLEAUTH_HOST" + PortEnv = "SIMPLEAUTH_PORT" + EmailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" + EmailPassEnv = "SIMPLEAUTH_EMAIL_PASS" + EmailHostEnv = "SIMPLEAUTH_EMAIL_HOST" + EmailPortEnv = "SIMPLEAUTH_EMAIL_PORT" + SecretEnv = "SIMPLEAUTH_SECRET" +) diff --git a/cmd/demo/main.go b/cmd/demo/main.go new file mode 100644 index 0000000..9fa400c --- /dev/null +++ b/cmd/demo/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "log" + + "github.com/simpleauthlink/authapi/api" + "github.com/simpleauthlink/authapi/cmd" + "github.com/simpleauthlink/authapi/internal/osflag" + "github.com/simpleauthlink/authapi/notification/email" +) + +func main() { + var ( + demoServer string + demoPort int + demoSecret string + ) + osflag.StringVar(&demoServer, cmd.HostEnv, cmd.HostFlag, cmd.DefaultHost, cmd.HostFlagDesc, false) + osflag.IntVar(&demoPort, cmd.PortEnv, cmd.PortFlag, cmd.DefaultPort, cmd.PortFlagDesc, false) + osflag.StringVar(&demoSecret, cmd.SecretEnv, cmd.SecretFlag, cmd.DefaultSecret, cmd.SecretFlagDesc, false) + if err := osflag.Parse(); err != nil { + log.Fatalln("ERR: error parsing flags:", err) + } + log.Println("INF: starting service with config:", demoServer, demoPort, demoSecret) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // create the email queue + emailQueue, err := email.NewEmailQueue(context.Background(), &email.EmailConfig{ + FromName: "SimpleAuthLink Demo", + FromAddress: "demo@simpleauth.link", + SMTPServer: demoServer, + SMTPPort: 2525, + }) + if err != nil { + log.Fatalln("WRN: something occurs during email queue creation:", err) + } + // start the email queue and defer to stop it + emailQueue.Start() + defer emailQueue.Stop() + // create the service + service, err := api.New(ctx, &api.Config{ + Server: demoServer, + ServerPort: demoPort, + Secret: demoSecret, + DemoMode: true, + DemoSMTPAddr: demoServer, + DemoSMTPPort: 2525, + }, emailQueue) + if err != nil { + log.Fatalln("ERR: error creating service:", err) + } + // start the service in background + go func() { + if err := service.Start(); err != nil { + log.Fatalln("ERR: error running service:", err) + } + }() + // wait for the service to finish + if err := service.WaitToShutdown(); err != nil { + log.Fatalln("ERR: error waiting for service to finish:", err) + } +} diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index d85f3c9..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: simpleauthlink - -services: - authapi: - env_file: - - .env - build: - context: ./ - ports: - - ${SIMPLEAUTH_PORT}:${SIMPLEAUTH_PORT} - sysctls: - net.core.somaxconn: 8128 - restart: ${RESTART:-unless-stopped} - depends_on: - - mongo - mongo: - image: mongo - restart: ${RESTART:-unless-stopped} - ports: - - 27017:27017 - environment: - - MONGO_INITDB_ROOT_USERNAME=root - - MONGO_INITDB_ROOT_PASSWORD=authapi - - MONGO_INITDB_DATABASE=simpleauth - volumes: - - mongodb:/data/mongodb - mongo-express: - image: mongo-express - restart: ${RESTART:-unless-stopped} - ports: - - 8081:8081 - environment: - ME_CONFIG_MONGODB_ADMINUSERNAME: root - ME_CONFIG_MONGODB_ADMINPASSWORD: authapi - ME_CONFIG_MONGODB_URL: mongodb://root:authapi@mongo:27017/ -volumes: - mongodb: {} - - diff --git a/docker/Dockerfile.demo b/docker/Dockerfile.demo new file mode 100644 index 0000000..ff4a3b6 --- /dev/null +++ b/docker/Dockerfile.demo @@ -0,0 +1,16 @@ +# build +FROM golang:1.24-alpine as builder + +WORKDIR /app/data +COPY . . + +RUN go mod tidy +RUN go build -o /authapi ./cmd/demo/main.go + +# deploy +FROM alpine:latest + +WORKDIR / +COPY --from=builder /authapi /authapi + +ENTRYPOINT /authapi \ No newline at end of file diff --git a/Dockerfile b/docker/Dockerfile.prod similarity index 71% rename from Dockerfile rename to docker/Dockerfile.prod index 960ed89..204a427 100644 --- a/Dockerfile +++ b/docker/Dockerfile.prod @@ -1,5 +1,5 @@ # build -FROM golang:1.21-alpine as builder +FROM golang:1.24-alpine as builder WORKDIR /app/data COPY . . @@ -12,6 +12,5 @@ FROM alpine:latest WORKDIR / COPY --from=builder /authapi /authapi -COPY --from=builder /app/data/assets /assets ENTRYPOINT /authapi \ No newline at end of file diff --git a/example.env b/example.env index b37a676..17c1588 100644 --- a/example.env +++ b/example.env @@ -1,6 +1,7 @@ -SIMPLEAUTH_EMAIL_ADDR="" -SIMPLEAUTH_EMAIL_PASS="" -SIMPLEAUTH_EMAIL_HOST="" -SIMPLEAUTH_DB_URI="mongodb://root:authapi@mongo:27017/" -SIMPLEAUTH_DB_NAME="simpleauth" -SIMPLEAUTH_DISPOSABLE_SRC="https://raw.githubusercontent.com/disposable-email-domains/disposable-email-domains/master/disposable_email_blocklist.conf" \ No newline at end of file +SIMPLEAUTH_HOST="localhost" +SIMPLEAUTH_PORT=8080 +SIMPLEAUTH_EMAIL_ADDR="test@test.com" +SIMPLEAUTH_EMAIL_PASS="smtp_server_password" +SIMPLEAUTH_EMAIL_HOST="smtp.example.com" +SIMPLEAUTH_EMAIL_PORT=587 +SIMPLEAUTH_SECRET="my_backend_secret" \ No newline at end of file diff --git a/internal/osflag/osflag.go b/internal/osflag/osflag.go new file mode 100644 index 0000000..d71924a --- /dev/null +++ b/internal/osflag/osflag.go @@ -0,0 +1,152 @@ +package osflag + +import ( + "flag" + "fmt" + "os" + "strconv" + "time" +) + +type OsFlagSet struct { + *flag.FlagSet + required map[string]bool + parsed bool +} + +var CommandLine *OsFlagSet + +func init() { + CommandLine = new(OsFlagSet) + if len(os.Args) == 0 { + CommandLine.FlagSet = flag.NewFlagSet("", flag.ExitOnError) + } else { + CommandLine.FlagSet = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + } + CommandLine.required = make(map[string]bool) +} + +func (of *OsFlagSet) BoolVar(p *bool, env, name string, value bool, usage string, required bool) { + var newDefault bool = value + if rawBool := os.Getenv(env); rawBool != "" { + if rawBool == "true" || rawBool == "True" || rawBool == "TRUE" || rawBool == "1" { + newDefault = true + } + } + of.required[name] = required + of.FlagSet.BoolVar(p, name, newDefault, usage) +} + +func (of *OsFlagSet) DurationVar(p *time.Duration, env, name string, value time.Duration, usage string, required bool) { + var newDefault time.Duration = value + if rawDuration := os.Getenv(env); rawDuration != "" { + if dur, err := time.ParseDuration(rawDuration); err == nil { + newDefault = dur + } + } + of.required[name] = required + of.FlagSet.DurationVar(p, name, newDefault, usage) +} + +func (of *OsFlagSet) Float64Var(p *float64, env, name string, value float64, usage string, required bool) { + var newDefault float64 = value + if rawFloat := os.Getenv(env); rawFloat != "" { + if f, err := strconv.ParseFloat(rawFloat, 64); err == nil { + newDefault = f + } + } + of.required[name] = required + of.FlagSet.Float64Var(p, name, newDefault, usage) +} + +func (of *OsFlagSet) IntVar(p *int, env, name string, value int, usage string, required bool) { + var newDefault int = value + if rawInt := os.Getenv(env); rawInt != "" { + if integer, err := strconv.Atoi(rawInt); err == nil { + newDefault = integer + } + } + of.required[name] = required + of.FlagSet.IntVar(p, name, newDefault, usage) +} + +func (of *OsFlagSet) StringVar(p *string, env, name string, value string, usage string, required bool) { + var newDefault string = value + if rawString := os.Getenv(env); rawString != "" { + newDefault = rawString + } + of.required[name] = required + of.FlagSet.StringVar(p, name, newDefault, usage) +} + +func (of *OsFlagSet) UintVar(p *uint, env, name string, value uint, usage string, required bool) { + var newDefault uint + if rawUint := os.Getenv(env); rawUint != "" { + if ui, err := strconv.ParseUint(rawUint, 10, 64); err == nil { + newDefault = uint(ui) + } + } + of.required[name] = required + of.FlagSet.UintVar(p, name, newDefault, usage) +} + +func (of *OsFlagSet) Parse() error { + if err := of.FlagSet.Parse(os.Args[1:]); err != nil { + return err + } + // check if all required flags are set + for name, required := range of.required { + if required { + f := of.FlagSet.Lookup(name) + if f == nil || f.Value.String() == "" { + return fmt.Errorf("required flag %s is not set", name) + } + } + } + of.parsed = of.FlagSet.Parsed() + return nil +} + +func (of *OsFlagSet) Parsed() bool { + return of.parsed +} + +func (of *OsFlagSet) PrintDefaults() { + of.FlagSet.PrintDefaults() +} + +func BoolVar(p *bool, env, name string, value bool, usage string, required bool) { + CommandLine.BoolVar(p, env, name, value, usage, required) +} + +func DurationVar(p *time.Duration, env, name string, value time.Duration, usage string, required bool) { + CommandLine.DurationVar(p, env, name, value, usage, required) +} + +func Float64Var(p *float64, env, name string, value float64, usage string, required bool) { + CommandLine.Float64Var(p, env, name, value, usage, required) +} + +func IntVar(p *int, env, name string, value int, usage string, required bool) { + CommandLine.IntVar(p, env, name, value, usage, required) +} + +func StringVar(p *string, env, name string, value string, usage string, required bool) { + CommandLine.StringVar(p, env, name, value, usage, required) +} + +func UintVar(p *uint, env, name string, value uint, usage string, required bool) { + CommandLine.UintVar(p, env, name, value, usage, required) +} + +func Parse() error { + return CommandLine.Parse() +} + +func Parsed() bool { + return CommandLine.parsed +} + +func PrintDefaults() { + CommandLine.PrintDefaults() +} diff --git a/notification/templates/login/definition.go b/notification/templates/login/definition.go index 35bc3d3..9a66b67 100644 --- a/notification/templates/login/definition.go +++ b/notification/templates/login/definition.go @@ -2,8 +2,11 @@ package login import ( _ "embed" + "regexp" + "github.com/simpleauthlink/authapi/notification" "github.com/simpleauthlink/authapi/notification/email" + "github.com/simpleauthlink/authapi/token" ) //go:embed template.html @@ -22,6 +25,32 @@ func (d Data) Subject() string { return "Your token for '" + d.AppName + "'" } +// FindToken function extracts the token from the email content. It uses a +// regular expression to fill the template with regex and find the token in +// the email content. Then it decodes the token and returns it. If the token +// is not found, it returns nil. +func FindToken(email, content string) *token.Token { + loginData := Data{ + AppName: `.+`, + Email: email, + Token: `(.+\..+)`, + Link: `.+`, + } + loginEmail, err := Template.Compose(notification.NotificationParams{ + To: email, + Subject: loginData.Subject(), + }, loginData) + if err != nil { + return nil + } + tokenRgx := regexp.MustCompile(string(loginEmail.PlainBody)) + tokenResult := tokenRgx.FindAllStringSubmatch(content, -1) + if len(tokenResult) < 1 || len(tokenResult[0]) < 2 { + return nil + } + return new(token.Token).SetString(tokenResult[0][1]) +} + // Template is the login email template definition, which contains the HTML // and plain text templates. var Template = email.EmailTemplate{ From 5797f0a6d892ee77897fed2936d3a962e1f90d98 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 30 Mar 2025 20:52:25 +0200 Subject: [PATCH 25/36] increase ratelimits and new makefile --- Makefile | 12 ++++++++++++ api/service.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..efbcccb --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +.PHONY: run demo clean + +demo: clean + @echo "Building demo image..." + docker build -f docker/Dockerfile.demo -t demo-simpleauthlink . + @echo "Running demo container..." + docker run --name demo-simpleauthlink --env-file .env -p 8080:8080 -d demo-simpleauthlink + +clean: + @echo "Cleaning up previous containers and images..." + -docker rm -f demo-simpleauthlink 2>/dev/null || true + -docker rmi -f demo-simpleauthlink 2>/dev/null || true diff --git a/api/service.go b/api/service.go index 850f79e..f04be52 100644 --- a/api/service.go +++ b/api/service.go @@ -40,7 +40,7 @@ type Service struct { func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) - rateLimiter := apihandler.RateLimiter(internalCtx, 100, 100, time.Minute*3) + rateLimiter := apihandler.RateLimiter(internalCtx, 1000, 1000, time.Minute*3) // create the service srv := &Service{ ctx: internalCtx, From b571e3c1c401930c44dc6d1e4f36c65bc16e3666 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 30 Mar 2025 20:57:15 +0200 Subject: [PATCH 26/36] remove prefix from env vars --- cmd/consts.go | 14 +++++++------- example.env | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cmd/consts.go b/cmd/consts.go index 3d912b6..4205497 100644 --- a/cmd/consts.go +++ b/cmd/consts.go @@ -24,11 +24,11 @@ const ( EmailPortFlagDesc = "email server port" SecretFlagDesc = "secret used to generate the tokens" - HostEnv = "SIMPLEAUTH_HOST" - PortEnv = "SIMPLEAUTH_PORT" - EmailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" - EmailPassEnv = "SIMPLEAUTH_EMAIL_PASS" - EmailHostEnv = "SIMPLEAUTH_EMAIL_HOST" - EmailPortEnv = "SIMPLEAUTH_EMAIL_PORT" - SecretEnv = "SIMPLEAUTH_SECRET" + HostEnv = "HOST" + PortEnv = "PORT" + EmailAddrEnv = "EMAIL_ADDR" + EmailPassEnv = "EMAIL_PASS" + EmailHostEnv = "EMAIL_HOST" + EmailPortEnv = "EMAIL_PORT" + SecretEnv = "SECRET" ) diff --git a/example.env b/example.env index 17c1588..617594f 100644 --- a/example.env +++ b/example.env @@ -1,7 +1,7 @@ -SIMPLEAUTH_HOST="localhost" -SIMPLEAUTH_PORT=8080 -SIMPLEAUTH_EMAIL_ADDR="test@test.com" -SIMPLEAUTH_EMAIL_PASS="smtp_server_password" -SIMPLEAUTH_EMAIL_HOST="smtp.example.com" -SIMPLEAUTH_EMAIL_PORT=587 -SIMPLEAUTH_SECRET="my_backend_secret" \ No newline at end of file +HOST="localhost" +PORT=8080 +EMAIL_ADDR="test@test.com" +EMAIL_PASS="smtp_server_password" +EMAIL_HOST="smtp.example.com" +EMAIL_PORT=587 +SECRET="my_backend_secret" \ No newline at end of file From c9471d4a8ff6cbb669de0579b7bafe4b6c44ee36 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 30 Mar 2025 21:19:28 +0200 Subject: [PATCH 27/36] remove rate limiter for deploy debug --- Makefile | 14 ++++++++++---- api/service.go | 13 +++++++------ docker/Dockerfile.demo | 2 ++ 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index efbcccb..ebd23ff 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,15 @@ demo: clean @echo "Building demo image..." docker build -f docker/Dockerfile.demo -t demo-simpleauthlink . @echo "Running demo container..." - docker run --name demo-simpleauthlink --env-file .env -p 8080:8080 -d demo-simpleauthlink + docker run --name demo-simpleauthlink --env-file .env -p ${PORT}:80 -d demo-simpleauthlink -clean: +clean: checkport @echo "Cleaning up previous containers and images..." - -docker rm -f demo-simpleauthlink 2>/dev/null || true - -docker rmi -f demo-simpleauthlink 2>/dev/null || true + @docker rm -f demo-simpleauthlink 2>/dev/null || true + @echo "Containers cleaned up" + @docker rmi -f demo-simpleauthlink 2>/dev/null || true + @echo "Images cleaned up" + @echo "Cleaning up done" + +checkport: + @echo "Using default port ${PORT}" \ No newline at end of file diff --git a/api/service.go b/api/service.go index f04be52..4fa480c 100644 --- a/api/service.go +++ b/api/service.go @@ -40,14 +40,15 @@ type Service struct { func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) - rateLimiter := apihandler.RateLimiter(internalCtx, 1000, 1000, time.Minute*3) + // rateLimiter := apihandler.RateLimiter(internalCtx, 1000, 1000, time.Minute*3) // create the service srv := &Service{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - nq: nq, - handler: apihandler.NewHandler(true, rateLimiter), + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + nq: nq, + // handler: apihandler.NewHandler(true, rateLimiter), + handler: apihandler.NewHandler(true, nil), } // demo stuff if cfg.DemoMode { diff --git a/docker/Dockerfile.demo b/docker/Dockerfile.demo index ff4a3b6..38694fb 100644 --- a/docker/Dockerfile.demo +++ b/docker/Dockerfile.demo @@ -10,6 +10,8 @@ RUN go build -o /authapi ./cmd/demo/main.go # deploy FROM alpine:latest +EXPOSE 80 + WORKDIR / COPY --from=builder /authapi /authapi From ee8abb2815c8ab9927695978839154bcd17ebde4 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Mon, 31 Mar 2025 00:30:41 +0200 Subject: [PATCH 28/36] more integration of app secret --- api/handlers.go | 44 +++++++++++++++++++++++++++----------------- api/handlers_test.go | 8 ++++++-- api/helpers.go | 4 ++-- api/types.go | 7 ++++--- token/app.go | 39 +++++++++++++++++++++++++++++++-------- token/app_test.go | 38 +++++++++++++++++++++++++++++--------- token/id.go | 5 ++++- token/id_test.go | 24 ++++++++++++++++++++---- token/secret.go | 11 ++++++++++- 9 files changed, 133 insertions(+), 47 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index 6e0757d..f407649 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -9,6 +9,13 @@ import ( "github.com/simpleauthlink/authapi/token" ) +// generateAppIDHandler handles the request to generate an app id it decodes +// the app data from the request body and returns the app id in the response +// body. Every app data information is required to generate the app id. The +// app id is a self-contained representation of the app that can be used to +// generate tokens. It is created by encoding the app as a base64-encoded +// byte slice resulting in concatenating the app name, redirect uri, and +// session duration. func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { // decode the app data from the request body req := new(Request[AppIDRequest]) @@ -17,13 +24,15 @@ func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { return } // create the app from the data and check if it is valid - app := req.Data.parseApp() - if !app.Valid() { + app, appSecret := req.Data.parseApp() + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(appSecret)) + app.SetSecret(secret) + if !app.Valid(secret.Hash()) { InvalidAppIDErr.Write(w) return } // return the app id - ResponseWith(&AppIDResponse{app.ID().String()}).WriteJSON(w) + ResponseWith(&AppIDResponse{app.ID(secret).String()}).WriteJSON(w) } func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { @@ -36,8 +45,14 @@ func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { // decode the app id get the app from it appID := new(token.AppID).SetString(strAppID) app := new(token.App).SetID(appID) + // compose the app secret with both parts + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) + if !secret.Valid() { + InvalidAppSecretErr.Write(w) + return + } // check if the app id is valid (it should be a valid app) - if !app.Valid() { + if !app.Valid(secret.Hash()) { InvalidAppIDErr.Write(w) return } @@ -48,11 +63,6 @@ func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { return } // generate user token - secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) - if !secret.Valid() { - InvalidAppSecretErr.Write(w) - return - } token := appID.GenerateToken(*secret, req.Data.Email) if token == nil { GenerateTokenErr.With(req.Data.Email).Write(w) @@ -91,8 +101,14 @@ func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { // decode the app id get the app from it appID := new(token.AppID).SetString(strAppID) app := new(token.App).SetID(appID) + // compose the app secret with both parts + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) + if !secret.Valid() { + InvalidAppSecretErr.Write(w) + return + } // check if the app id is valid (it should be a valid app) - if !app.Valid() { + if !app.Valid(secret.Hash()) { InvalidAppIDErr.Write(w) return } @@ -105,14 +121,8 @@ func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { // check if the token is valid tkn := new(token.Token).SetString(req.Data.Token) exp := tkn.Expiration().Time() - secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) - if !secret.Valid() { - InvalidAppSecretErr.Write(w) - return - } - ok := appID.VerifyToken(*tkn, *secret, req.Data.Email) ResponseWith(&TokenStatusResponse{ - Valid: ok, + Valid: appID.VerifyToken(*tkn, *secret, req.Data.Email), Expiration: exp, }).WriteJSON(w) } diff --git a/api/handlers_test.go b/api/handlers_test.go index a3d06f8..65d5c81 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -90,6 +90,8 @@ func TestGenerateAppIDHandler(t *testing.T) { RedirectURI: testAppRedirectURL, SessionDuration: testAppSessionDuration, } + secret := new(token.Secret).SetParts([]byte(testServerSecret), []byte(testAppSecret)) + testApp.SetSecret(secret) testCaseAPIHandler[AppIDRequest, AppIDResponse]{ name: "valid request", method: http.MethodPost, @@ -101,7 +103,7 @@ func TestGenerateAppIDHandler(t *testing.T) { Secret: testAppSecret, }, response: &AppIDResponse{ - ID: testApp.ID().String(), + ID: testApp.ID(secret).String(), }, }.Run(t) testCaseAPIHandler[AppIDRequest, AppIDResponse]{ @@ -131,7 +133,9 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { RedirectURI: testAppRedirectURL, SessionDuration: testAppSessionDuration, } - testAppID := testApp.ID() + secret := new(token.Secret).SetParts([]byte(testServerSecret), []byte(testAppSecret)) + testApp.SetSecret(secret) + testAppID := testApp.ID(secret) testCaseAPIHandler[TokenRequest, any]{ name: "no appID request", method: http.MethodPost, diff --git a/api/helpers.go b/api/helpers.go index 19fd684..2635ef6 100644 --- a/api/helpers.go +++ b/api/helpers.go @@ -6,8 +6,8 @@ import ( ) const ( - AppIDHeader = "APP_ID" - AppSecretHeader = "APP_SECRET" + AppIDHeader = "AppID" + AppSecretHeader = "AppSecret" ) func appConfigFromRequest(r *http.Request) (string, string, error) { diff --git a/api/types.go b/api/types.go index 47b958d..4daaebd 100644 --- a/api/types.go +++ b/api/types.go @@ -13,15 +13,16 @@ type AppIDRequest struct { Secret string `json:"secret"` } -func (data *AppIDRequest) parseApp() *token.App { +func (data *AppIDRequest) parseApp() (*token.App, string) { if duration, err := time.ParseDuration(data.Duration); err == nil { - return &token.App{ + app := &token.App{ Name: data.Name, RedirectURI: data.RedirectURL, SessionDuration: duration, } + return app, data.Secret } - return new(token.App) + return new(token.App), "" } type AppIDResponse struct { diff --git a/token/app.go b/token/app.go index b378caf..c8dd65c 100644 --- a/token/app.go +++ b/token/app.go @@ -1,7 +1,9 @@ package token import ( + "bytes" "encoding/base64" + "encoding/hex" "strings" "time" ) @@ -12,13 +14,14 @@ type App struct { Name string RedirectURI string SessionDuration time.Duration + AppSecretHash []byte } // Valid method returns true if the app is valid, false otherwise. An app is // considered valid if its name is between 3 and 20 characters, its redirect // URI is a valid URI, and its session duration is between 5 minutes and 24 // hours. -func (app *App) Valid() bool { +func (app *App) Valid(secretHash []byte) bool { if app == nil { return false } @@ -34,20 +37,23 @@ func (app *App) Valid() bool { if app.SessionDuration < minDuration || app.SessionDuration > maxDuration { return false } + if secretHash != nil { + return bytes.Equal(app.AppSecretHash, secretHash) + } return true } // Attributes method returns the app's attributes as a slice of strings. This // is useful for encoding the app. func (app *App) Attributes() []string { - return []string{app.Name, app.RedirectURI, app.SessionDuration.String()} + return []string{app.Name, app.RedirectURI, app.SessionDuration.String(), hex.EncodeToString(app.AppSecretHash)} } // SetAttributes method sets the app's attributes from a slice of strings. This // is useful for decoding the app. func (app *App) SetAttributes(attrs []string) *App { // check if the slice has the correct number of attributes - if len(attrs) != 3 { + if len(attrs) != 4 { return nil } // parse the session duration @@ -63,8 +69,16 @@ func (app *App) SetAttributes(attrs []string) *App { app.Name = attrs[0] app.RedirectURI = attrs[1] app.SessionDuration = duration + appSecretHash, err := hex.DecodeString(attrs[3]) + if err != nil { + return nil + } + if len(appSecretHash) != 12 { + return nil + } + app.AppSecretHash = appSecretHash[:12] // check if the app is valid and return it if it is - if !app.Valid() { + if !app.Valid(nil) { return nil } return app @@ -74,7 +88,7 @@ func (app *App) SetAttributes(attrs []string) *App { // and encoding the app. The resulting string is the app's attributes joined // by the app data separator. func (app *App) String() string { - if !app.Valid() { + if !app.Valid(nil) { return "" } // join the app's attributes with the app data separator @@ -106,7 +120,7 @@ func (app *App) SetBytes(data []byte) *App { // Marshal method returns the app as a base64-encoded byte slice. It is used // to be included in the app ID, which makes it self-contained. func (app *App) Marshal() []byte { - if !app.Valid() { + if !app.Valid(nil) { return nil } bApp := app.Bytes() @@ -129,8 +143,8 @@ func (app *App) Unmarshal(data []byte) *App { // representation of the app that can be used to generate tokens. It is // created by encoding the app as a base64-encoded byte slice using the // Marshal method. -func (app *App) ID() *AppID { - if !app.Valid() { +func (app *App) ID(secret *Secret) *AppID { + if !app.Valid(secret.Hash()) { return nil } return new(AppID).SetBytes(app.Marshal()) @@ -146,3 +160,12 @@ func (app *App) SetID(id *AppID) *App { } return app.Unmarshal(id.Bytes()) } + +func (app *App) SetSecret(secret *Secret) *App { + if app == nil { + return nil + } + // set the app secret hash + app.AppSecretHash = secret.Hash() + return app +} diff --git a/token/app_test.go b/token/app_test.go index 0b2f4a2..0ccf14b 100644 --- a/token/app_test.go +++ b/token/app_test.go @@ -20,36 +20,36 @@ func TestValidApp(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } - if !app.Valid() { + if !app.Valid(nil) { t.Errorf("expected valid app data") } // test app name app.Name = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." - if app.Valid() { + if app.Valid(nil) { t.Errorf("expected invalid app data") } app.Name = "no" - if app.Valid() { + if app.Valid(nil) { t.Errorf("expected invalid app data") } app.Name = testAppName // test redirect URI app.RedirectURI = "https://example.com/login?app=lorem_ipsum_dolor_sit_amet_consectetur_adipiscing_elit_sed_do_eiusmod_tempor_incididunt_ut_labore_et_dolore_magna_aliqua" - if app.Valid() { + if app.Valid(nil) { t.Errorf("expected invalid app data") } app.RedirectURI = "no_url" - if app.Valid() { + if app.Valid(nil) { t.Errorf("expected invalid app data") } app.RedirectURI = testRedirectURI // test session duration app.SessionDuration = minDuration - 1 - if app.Valid() { + if app.Valid(nil) { t.Errorf("expected invalid app data") } app.SessionDuration = maxDuration + 1 - if app.Valid() { + if app.Valid(nil) { t.Errorf("expected invalid app data") } } @@ -71,6 +71,10 @@ func TestAttributesSetAttributesApp(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) var nilApp *App nilData := nilApp.SetAttributes(app.Attributes()) if nilData == nil { @@ -109,6 +113,10 @@ func TestStringSetStringApp(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) data := new(App).SetString(app.String()) if data == nil { t.Fatalf("error decoding app data") @@ -130,6 +138,10 @@ func TestBytesSetBytesApp(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) data := new(App).SetBytes(app.Bytes()) if data == nil { t.Fatalf("error decoding app data") @@ -157,6 +169,10 @@ func TestMarshalUnmarshalApp(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) data := new(App).Unmarshal(app.Marshal()) if data == nil { t.Fatalf("error decoding app data") @@ -173,7 +189,10 @@ func TestMarshalUnmarshalApp(t *testing.T) { } func TestAppID(t *testing.T) { - if id := new(App).ID(); id != nil { + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + if id := new(App).ID(secret); id != nil { t.Errorf("expected nil, got %v", id) } app := &App{ @@ -181,7 +200,8 @@ func TestAppID(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } - id := app.ID() + app.SetSecret(secret) + id := app.ID(secret) if id == nil { t.Fatalf("error decoding app ID") } diff --git a/token/id.go b/token/id.go index 9917eda..e7f0699 100644 --- a/token/id.go +++ b/token/id.go @@ -30,6 +30,9 @@ func (id *AppID) SetString(data string) *AppID { // Bytes method returns the application ID as a byte slice. func (id *AppID) Bytes() []byte { + if id == nil { + return nil + } return []byte(*id) } @@ -41,7 +44,7 @@ func (id *AppID) Bytes() []byte { // valid, the application ID is not set and nil is returned. func (id *AppID) SetBytes(data []byte) *AppID { // check if the application ID is valid - if !new(App).Unmarshal(data).Valid() { + if !new(App).Unmarshal(data).Valid(nil) { return nil } // if no application ID is provided, create a new one diff --git a/token/id_test.go b/token/id_test.go index f8739ec..c82dc6b 100644 --- a/token/id_test.go +++ b/token/id_test.go @@ -17,7 +17,11 @@ func TestStringSetStringAppID(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } - id := app.ID() + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) if id == nil { t.Fatalf("error decoding app ID") } @@ -39,7 +43,11 @@ func TestBytesSetBytesAppID(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } - id := app.ID() + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) if id == nil { t.Fatalf("error decoding app ID") } @@ -92,7 +100,11 @@ func TestPrivKeySignVerifyAppID(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: testSessionDuration, } - id := app.ID() + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) if id == nil { t.Fatalf("error decoding app ID") } @@ -135,7 +147,11 @@ func TestGenerateTokenVerifyToken(t *testing.T) { RedirectURI: testRedirectURI, SessionDuration: minDuration, } - id := app.ID() + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) if id == nil { t.Fatalf("error decoding app ID") } diff --git a/token/secret.go b/token/secret.go index 2acdc47..97d3fa7 100644 --- a/token/secret.go +++ b/token/secret.go @@ -38,6 +38,15 @@ func (s *Secret) Bytes() []byte { return []byte(*s) } +func (s *Secret) Hash() []byte { + if s == nil { + return nil + } + // hash the secret to a sha256 size + h := sha256.Sum256(*s) + return h[:12] +} + // Valid method returns true if the secret is valid, false otherwise. A secret // is considered valid if it has more than 1 part, and each part is hashed to // a sha256 size. @@ -47,5 +56,5 @@ func (s *Secret) Valid() bool { } // secret is valid if it has more than 1 part, and each part is hashed // to a sha256 size - return len(*s) > sha256.Size + return len(*s)%sha256.Size == 0 && len(*s) > sha256.Size } From 7febd698e97fbaf1ca108291fbed5bb532b117f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Men=C3=A9ndez?= Date: Sun, 6 Apr 2025 21:37:36 +0200 Subject: [PATCH 29/36] Create LICENSE --- LICENSE | 661 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 661 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. From 1dd1edae136f9a41e849935665359580ea0a7467 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Sun, 6 Apr 2025 22:26:52 +0200 Subject: [PATCH 30/36] make file and docker file for production --- Makefile | 21 ++++++++++++++++----- docker/Dockerfile.prod | 2 ++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index ebd23ff..9e5f3f8 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ .PHONY: run demo clean -demo: clean +demo: clean-demo @echo "Building demo image..." docker build -f docker/Dockerfile.demo -t demo-simpleauthlink . @echo "Running demo container..." - docker run --name demo-simpleauthlink --env-file .env -p ${PORT}:80 -d demo-simpleauthlink + docker run --name demo-simpleauthlink --env-file demo.env -p ${PORT}:80 -d demo-simpleauthlink -clean: checkport +clean-demo: @echo "Cleaning up previous containers and images..." @docker rm -f demo-simpleauthlink 2>/dev/null || true @echo "Containers cleaned up" @@ -14,5 +14,16 @@ clean: checkport @echo "Images cleaned up" @echo "Cleaning up done" -checkport: - @echo "Using default port ${PORT}" \ No newline at end of file +api: clean-api + @echo "Building API image..." + docker build -f docker/Dockerfile.prod -t simpleauthlink . + @echo "Running API container..." + docker run --name simpleauthlink --env-file .env -p ${PORT}:80 simpleauthlink + +clean-api: + @echo "Cleaning up previous containers and images..." + @docker rm -f simpleauthlink 2>/dev/null || true + @echo "Containers cleaned up" + @docker rmi -f simpleauthlink 2>/dev/null || true + @echo "Images cleaned up" + @echo "Cleaning up done" \ No newline at end of file diff --git a/docker/Dockerfile.prod b/docker/Dockerfile.prod index 204a427..dc43de2 100644 --- a/docker/Dockerfile.prod +++ b/docker/Dockerfile.prod @@ -10,6 +10,8 @@ RUN go build -o /authapi ./cmd/authapi/main.go # deploy FROM alpine:latest +EXPOSE 80 + WORKDIR / COPY --from=builder /authapi /authapi From ef0c2d271dded0c5b9a79d8e12f9cc3adf1c2847 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Wed, 9 Apr 2025 21:19:40 +0200 Subject: [PATCH 31/36] new tests and fakesmtpserver moved to its own package --- api/service.go | 6 +- api/service_test.go | 4 +- internal/error_test.go | 49 ++++ .../server.go} | 38 ++- internal/fakesmtpserver/server_test.go | 239 ++++++++++++++++++ internal/osflag/osflag_test.go | 136 ++++++++++ notification/email/emailqueue_test.go | 4 +- token/app.go | 7 +- token/app_test.go | 54 +++- token/id_test.go | 8 + token/secret.go | 12 +- token/secret_test.go | 23 ++ 12 files changed, 558 insertions(+), 22 deletions(-) create mode 100644 internal/error_test.go rename internal/{fake_smtp_server.go => fakesmtpserver/server.go} (69%) create mode 100644 internal/fakesmtpserver/server_test.go create mode 100644 internal/osflag/osflag_test.go diff --git a/api/service.go b/api/service.go index 4fa480c..fb53a0a 100644 --- a/api/service.go +++ b/api/service.go @@ -11,7 +11,7 @@ import ( "time" "github.com/lucasmenendez/apihandler" - "github.com/simpleauthlink/authapi/internal" + "github.com/simpleauthlink/authapi/internal/fakesmtpserver" "github.com/simpleauthlink/authapi/notification" ) @@ -34,7 +34,7 @@ type Service struct { handler *apihandler.Handler httpServer *http.Server // demo stuff - demoMailServer *internal.FakeSMTPServer + demoMailServer *fakesmtpserver.FakeSMTPServer demoMailInbox chan string } @@ -53,7 +53,7 @@ func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, err // demo stuff if cfg.DemoMode { srv.demoMailInbox = make(chan string, 1) - srv.demoMailServer = internal.NewFakeSMTPServer(cfg.DemoSMTPAddr, + srv.demoMailServer = fakesmtpserver.NewServer(cfg.DemoSMTPAddr, cfg.DemoSMTPPort, srv.demoMailInbox) if err := srv.demoMailServer.Start(internalCtx); err != nil { return nil, err diff --git a/api/service_test.go b/api/service_test.go index a964701..4703a1e 100644 --- a/api/service_test.go +++ b/api/service_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/simpleauthlink/authapi/internal" + "github.com/simpleauthlink/authapi/internal/fakesmtpserver" "github.com/simpleauthlink/authapi/notification/email" ) @@ -36,7 +36,7 @@ func TestMain(m *testing.M) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // start test SMTP server to receive the email - testSrv := internal.NewFakeSMTPServer(testServerAddr, testServerSMTPPort, inboxChan) + testSrv := fakesmtpserver.NewServer(testServerAddr, testServerSMTPPort, inboxChan) if err := testSrv.Start(ctx); err != nil { panic(err) } diff --git a/internal/error_test.go b/internal/error_test.go new file mode 100644 index 0000000..6cb6e8d --- /dev/null +++ b/internal/error_test.go @@ -0,0 +1,49 @@ +package internal + +import ( + "errors" + "fmt" + "testing" +) + +func TestNewErr(t *testing.T) { + err := NewErr("test message") + if err.msg != "test message" { + t.Errorf("expected message 'test message', got '%s'", err.msg) + } + if err.trace != nil { + t.Errorf("expected nil trace, got '%v'", err.trace) + } +} + +func TestError(t *testing.T) { + err := NewErr("test message") + if err.Error() != "test message" { + t.Errorf("expected 'test message', got '%s'", err.Error()) + } + + wrappedErr := errors.New("wrapped error") + _ = err.With(wrappedErr) + expected := "test message: wrapped error" + if err.Error() != expected { + t.Errorf("expected '%s', got '%s'", expected, err.Error()) + } +} + +func TestWith(t *testing.T) { + err := NewErr("test message") + wrappedErr := errors.New("wrapped error") + _ = err.With(wrappedErr) + if err.trace != wrappedErr { + t.Errorf("expected trace '%v', got '%v'", wrappedErr, err.trace) + } +} + +func TestWithf(t *testing.T) { + err := NewErr("test message") + _ = err.Withf("formatted %s", "error") + expectedTrace := fmt.Errorf("formatted %s", "error").Error() + if err.trace.Error() != expectedTrace { + t.Errorf("expected trace '%s', got '%s'", expectedTrace, err.trace.Error()) + } +} diff --git a/internal/fake_smtp_server.go b/internal/fakesmtpserver/server.go similarity index 69% rename from internal/fake_smtp_server.go rename to internal/fakesmtpserver/server.go index e15f5ba..bbd3ee2 100644 --- a/internal/fake_smtp_server.go +++ b/internal/fakesmtpserver/server.go @@ -1,4 +1,4 @@ -package internal +package fakesmtpserver import ( "bufio" @@ -6,6 +6,7 @@ import ( "fmt" "net" "strings" + "sync" ) // FakeSMTPServer represents a simple SMTP testing server. @@ -13,19 +14,25 @@ type FakeSMTPServer struct { addr string inbox chan string listener net.Listener + mu sync.Mutex // Mutex to protect listener } -// NewFakeSMTPServer creates a new FakeSMTPServer instance that listens on the -// given address and port and stores the received emails in the inbox channel +// NewServer creates a new FakeSMTPServer instance that listens on the given +// address and port and stores the received emails in the inbox channel // provided. -func NewFakeSMTPServer(addr string, port int, inbox chan string) *FakeSMTPServer { - return &FakeSMTPServer{addr: fmt.Sprintf("%s:%d", addr, port), inbox: inbox} +func NewServer(addr string, port int, inbox chan string) *FakeSMTPServer { + return &FakeSMTPServer{ + addr: fmt.Sprintf("%s:%d", addr, port), + inbox: inbox, + } } // Start method launches the test SMTP server. func (s *FakeSMTPServer) Start(ctx context.Context) error { var err error + s.mu.Lock() s.listener, err = net.Listen("tcp", s.addr) + s.mu.Unlock() if err != nil { return err } @@ -33,12 +40,19 @@ func (s *FakeSMTPServer) Start(ctx context.Context) error { for { select { case <-ctx.Done(): - s.listener.Close() + s.Stop() // Use Stop to safely close the listener + return default: - conn, err := s.listener.Accept() - if err != nil { + s.mu.Lock() + listener := s.listener // Copy listener under lock + s.mu.Unlock() + if listener == nil { return } + conn, err := listener.Accept() + if err != nil { + continue + } go s.handleConn(conn) } } @@ -48,7 +62,13 @@ func (s *FakeSMTPServer) Start(ctx context.Context) error { // Stop method shuts down the test SMTP server. func (s *FakeSMTPServer) Stop() { - s.listener.Close() + s.mu.Lock() + listener := s.listener // Copy listener under lock + s.listener = nil // Set listener to nil under lock + s.mu.Unlock() + if listener != nil { + listener.Close() // Close listener outside the lock + } } func (s *FakeSMTPServer) handleConn(conn net.Conn) { diff --git a/internal/fakesmtpserver/server_test.go b/internal/fakesmtpserver/server_test.go new file mode 100644 index 0000000..de6e351 --- /dev/null +++ b/internal/fakesmtpserver/server_test.go @@ -0,0 +1,239 @@ +package fakesmtpserver + +import ( + "bufio" + "context" + "net" + "strconv" + "strings" + "testing" + "time" +) + +func getFreePort() (string, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", err + } + defer listener.Close() + return listener.Addr().String(), nil +} + +func splitHostPort(address string) (string, int, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, err + } + return host, port, nil +} + +func TestFakeSMTPServer(t *testing.T) { + inbox := make(chan string, 1) + address, err := getFreePort() + if err != nil { + t.Fatalf("Failed to get free port: %v", err) + } + host, port, err := splitHostPort(address) + if err != nil { + t.Fatalf("Failed to split host and port: %v", err) + } + server := NewServer(host, port, inbox) + ctx, cancel := context.WithCancel(t.Context()) // Fixed incorrect t.Context() + defer cancel() + + errChan := make(chan error, 1) // Channel to capture errors from the goroutine + + // Start the server + go func() { + if err := server.Start(ctx); err != nil { + errChan <- err // Send error to the channel + } + close(errChan) // Close the channel when done + }() + time.Sleep(100 * time.Millisecond) // Give the server time to start + + // Check for errors from the goroutine + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + default: + // No error, continue with the test + } + + // Connect to the server and send an email + conn, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + // Read greeting + if greeting, _ := reader.ReadString('\n'); !strings.HasPrefix(greeting, "220") { + t.Fatalf("Expected greeting, got: %s", greeting) + } + + // Send HELO + if _, err := writer.WriteString("HELO localhost\r\n"); err != nil { + t.Fatalf("Failed to write HELO command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected HELO response, got: %s", response) + } + + // Send MAIL FROM + if _, err := writer.WriteString("MAIL FROM:\r\n"); err != nil { + t.Fatalf("Failed to write MAIL FROM command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected MAIL FROM response, got: %s", response) + } + + // Send RCPT TO + if _, err := writer.WriteString("RCPT TO:\r\n"); err != nil { + t.Fatalf("Failed to write RCPT TO command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected RCPT TO response, got: %s", response) + } + + // Send DATA + if _, err := writer.WriteString("DATA\r\n"); err != nil { + t.Fatalf("Failed to write DATA command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "354") { + t.Fatalf("Expected DATA response, got: %s", response) + } + + // Send email content + if _, err := writer.WriteString("Subject: Test Email\r\n\r\nThis is a test email.\r\n.\r\n"); err != nil { + t.Fatalf("Failed to write email content: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected email content response, got: %s", response) + } + + // Send QUIT + if _, err := writer.WriteString("QUIT\r\n"); err != nil { + t.Fatalf("Failed to write QUIT command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "221") { + t.Fatalf("Expected QUIT response, got: %s", response) + } + + // Verify email content in inbox + select { + case email := <-inbox: + if !strings.Contains(email, "Subject: Test Email") || !strings.Contains(email, "This is a test email.") { + t.Fatalf("Unexpected email content: %s", email) + } + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for email in inbox") + } + + // Stop the server + server.Stop() +} + +func TestFakeSMTPServer_UnsupportedCommand(t *testing.T) { + inbox := make(chan string, 1) + address, err := getFreePort() + if err != nil { + t.Fatalf("Failed to get free port: %v", err) + } + host, port, err := splitHostPort(address) + if err != nil { + t.Fatalf("Failed to split host and port: %v", err) + } + server := NewServer(host, port, inbox) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errChan := make(chan error, 1) + + // Start the server + go func() { + if err := server.Start(ctx); err != nil { + errChan <- err + } + close(errChan) + }() + time.Sleep(100 * time.Millisecond) + + // Check for errors from the goroutine + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + default: + } + + // Connect to the server and send an unsupported command + conn, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + // Read greeting + if greeting, _ := reader.ReadString('\n'); !strings.HasPrefix(greeting, "220") { + t.Fatalf("Expected greeting, got: %s", greeting) + } + + // Send unsupported command + if _, err := writer.WriteString("FOO BAR\r\n"); err != nil { + t.Fatalf("Failed to write unsupported command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected default response, got: %s", response) + } + + // Stop the server + server.Stop() +} + +func TestFakeSMTPServer_BadAddress(t *testing.T) { + inbox := make(chan string, 1) + server := NewServer("invalid-address", 2527, inbox) // Added a dummy port + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + errChan := make(chan error, 1) + + // Start the server + go func() { + if err := server.Start(ctx); err != nil { + errChan <- err + } + close(errChan) + }() + time.Sleep(100 * time.Millisecond) + + // Check for errors from the goroutine + select { + case err := <-errChan: + if err == nil { + t.Fatalf("Expected error starting server with bad address, got nil") + } + default: + } +} diff --git a/internal/osflag/osflag_test.go b/internal/osflag/osflag_test.go new file mode 100644 index 0000000..a70fbbd --- /dev/null +++ b/internal/osflag/osflag_test.go @@ -0,0 +1,136 @@ +package osflag + +import ( + "flag" + "os" + "testing" + "time" +) + +func resetCommandLine() { + CommandLine = new(OsFlagSet) + CommandLine.FlagSet = flag.NewFlagSet("", flag.ExitOnError) + CommandLine.required = make(map[string]bool) + + // Filter out test framework flags + os.Args = os.Args[:1] +} + +func TestBoolVar(t *testing.T) { + resetCommandLine() + var flagValue bool + os.Setenv("TEST_BOOL", "true") + defer os.Unsetenv("TEST_BOOL") + + CommandLine.BoolVar(&flagValue, "TEST_BOOL", "boolFlag", false, "A boolean flag", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if !flagValue { + t.Errorf("Expected true, got %v", flagValue) + } +} + +func TestDurationVar(t *testing.T) { + resetCommandLine() + var flagValue time.Duration + os.Setenv("TEST_DURATION", "5s") + defer os.Unsetenv("TEST_DURATION") + + CommandLine.DurationVar(&flagValue, "TEST_DURATION", "durationFlag", 0, "A duration flag", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 5*time.Second { + t.Errorf("Expected 5s, got %v", flagValue) + } +} + +func TestFloat64Var(t *testing.T) { + resetCommandLine() + var flagValue float64 + os.Setenv("TEST_FLOAT", "3.14") + defer os.Unsetenv("TEST_FLOAT") + + CommandLine.Float64Var(&flagValue, "TEST_FLOAT", "floatFlag", 0.0, "A float flag", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 3.14 { + t.Errorf("Expected 3.14, got %v", flagValue) + } +} + +func TestIntVar(t *testing.T) { + resetCommandLine() + var flagValue int + os.Setenv("TEST_INT", "42") + defer os.Unsetenv("TEST_INT") + + CommandLine.IntVar(&flagValue, "TEST_INT", "intFlag", 0, "An int flag", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 42 { + t.Errorf("Expected 42, got %v", flagValue) + } +} + +func TestStringVar(t *testing.T) { + resetCommandLine() + var flagValue string + os.Setenv("TEST_STRING", "hello") + defer os.Unsetenv("TEST_STRING") + + CommandLine.StringVar(&flagValue, "TEST_STRING", "stringFlag", "default", "A string flag", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != "hello" { + t.Errorf("Expected 'hello', got %v", flagValue) + } +} + +func TestUintVar(t *testing.T) { + resetCommandLine() + var flagValue uint + os.Setenv("TEST_UINT", "100") + defer os.Unsetenv("TEST_UINT") + + CommandLine.UintVar(&flagValue, "TEST_UINT", "uintFlag", 0, "A uint flag", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 100 { + t.Errorf("Expected 100, got %v", flagValue) + } +} + +func TestRequiredFlag(t *testing.T) { + resetCommandLine() + var flagValue string + CommandLine.StringVar(&flagValue, "", "requiredFlag", "", "A required flag", true) + + if err := CommandLine.Parse(); err == nil { + t.Errorf("Expected error for missing required flag, got nil") + } +} + +func TestDefaultValues(t *testing.T) { + resetCommandLine() + var flagValue string + CommandLine.StringVar(&flagValue, "", "defaultFlag", "defaultValue", "A flag with a default value", false) + if err := CommandLine.Parse(); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != "defaultValue" { + t.Errorf("Expected 'defaultValue', got %v", flagValue) + } +} diff --git a/notification/email/emailqueue_test.go b/notification/email/emailqueue_test.go index feb4a1c..8c3ad2a 100644 --- a/notification/email/emailqueue_test.go +++ b/notification/email/emailqueue_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/simpleauthlink/authapi/internal" + "github.com/simpleauthlink/authapi/internal/fakesmtpserver" "github.com/simpleauthlink/authapi/notification" ) @@ -29,7 +29,7 @@ func TestMain(m *testing.M) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // start test SMTP server to receive the email - testSrv := internal.NewFakeSMTPServer(testServerAddr, testServerPort, inboxChan) + testSrv := fakesmtpserver.NewServer(testServerAddr, testServerPort, inboxChan) if err := testSrv.Start(ctx); err != nil { panic(err) } diff --git a/token/app.go b/token/app.go index c8dd65c..a952a14 100644 --- a/token/app.go +++ b/token/app.go @@ -73,10 +73,10 @@ func (app *App) SetAttributes(attrs []string) *App { if err != nil { return nil } - if len(appSecretHash) != 12 { + if len(appSecretHash) != secretHashSize { return nil } - app.AppSecretHash = appSecretHash[:12] + app.AppSecretHash = appSecretHash // check if the app is valid and return it if it is if !app.Valid(nil) { return nil @@ -165,6 +165,9 @@ func (app *App) SetSecret(secret *Secret) *App { if app == nil { return nil } + if secret == nil { + return app + } // set the app secret hash app.AppSecretHash = secret.Hash() return app diff --git a/token/app_test.go b/token/app_test.go index 0ccf14b..a68acd0 100644 --- a/token/app_test.go +++ b/token/app_test.go @@ -2,6 +2,7 @@ package token import ( "bytes" + "encoding/hex" "testing" "time" ) @@ -52,20 +53,36 @@ func TestValidApp(t *testing.T) { if app.Valid(nil) { t.Errorf("expected invalid app data") } + var nilApp *App + if nilApp.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.AppSecretHash = testAppSecret.Hash() + servicePart := []byte("invalid-service-secret") + appPart := []byte("invalid-app-secret") + invalidSecret := new(Secret).SetParts(servicePart, appPart) + if app.Valid(invalidSecret.Hash()) { + t.Errorf("expected invalid app data") + } } func TestAttributesSetAttributesApp(t *testing.T) { if res := new(App).SetAttributes([]string{}); res != nil { t.Errorf("expected nil, got %v", res) } - if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, "no_duration"}); res != nil { + strHash := hex.EncodeToString(testAppSecret.Hash()) + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, "no_duration", strHash}); res != nil { t.Errorf("expected nil, got %v", res) } - if res := new(App).SetAttributes([]string{testAppName, "no_url", testSessionDuration.String()}); res != nil { t.Errorf("expected nil, got %v", res) } - + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, testSessionDuration.String(), "invalid_hash"}); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, testSessionDuration.String(), "2bbc94cb9c916e1f6f1354ef30c1c80767b85159570304baa402c088180a0ec5"}); res != nil { + t.Errorf("expected nil, got %v", res) + } app := &App{ Name: testAppName, RedirectURI: testRedirectURI, @@ -102,6 +119,12 @@ func TestAttributesSetAttributesApp(t *testing.T) { if data.SessionDuration != testSessionDuration { t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) } + // set an out of range duration to test if the app is valid + attrs := app.Attributes() + attrs[2] = "1s" + if res := new(App).SetAttributes(attrs); res != nil { + t.Errorf("expected nil, got %v", res) + } } func TestStringSetStringApp(t *testing.T) { @@ -219,3 +242,28 @@ func TestAppID(t *testing.T) { t.Errorf("expected %s, got %s", app.String(), newApp.String()) } } + +func TestSetSecretApp(t *testing.T) { + var nilApp *App + if res := nilApp.SetSecret(nil); res != nil { + t.Errorf("expected nil, got %v", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + if res := app.SetSecret(nil); res == nil { + t.Errorf("expected nil, got %v", res) + } + if app.AppSecretHash != nil { + t.Errorf("expected nil, got %v", app.AppSecretHash) + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + if !bytes.Equal(app.AppSecretHash, secret.Hash()) { + t.Errorf("expected %v, got %v", secret.Hash(), app.AppSecretHash) + } +} diff --git a/token/id_test.go b/token/id_test.go index c82dc6b..8419a2d 100644 --- a/token/id_test.go +++ b/token/id_test.go @@ -69,6 +69,14 @@ func TestBytesSetBytesAppID(t *testing.T) { if !bytes.Equal(nilID.Bytes(), id.Bytes()) { t.Errorf("expected %v, got %v", id.Bytes(), nilID.Bytes()) } + // nil app ID + if nilID = new(AppID).SetBytes(nil); nilID != nil { + t.Errorf("expected nil, got %v", nilID) + } + var nilAppID *AppID + if bNilAppID := nilAppID.Bytes(); bNilAppID != nil { + t.Errorf("expected nil, got %v", bNilAppID) + } } func TestPrivKeySignVerifyAppID(t *testing.T) { diff --git a/token/secret.go b/token/secret.go index 97d3fa7..9b54921 100644 --- a/token/secret.go +++ b/token/secret.go @@ -2,6 +2,12 @@ package token import "crypto/sha256" +// secretHashSize is the size of the secret hash. It is used to determine +// the size of the secret when it is hashed. The hash is created by hashing +// the secret to a sha256 size. The hash is used to sign and verify tokens. +// The hash is also used to create the app ID and it is part of it. +const secretHashSize = 12 + // Secret represents a secret that is used to sign and verify tokens. It is // a wrapper around a byte slice that provides additional methods for setting // and getting the secret. It should have at least 2 parts, each hashed to a @@ -38,13 +44,17 @@ func (s *Secret) Bytes() []byte { return []byte(*s) } +// Hash method returns the hash of the secret as a byte slice. The hash is +// created by hashing the secret to a sha256 size. The hash is used to create +// the app ID and it is part of it, but is also used to sign and verify the +// user sessions in the token generation process. func (s *Secret) Hash() []byte { if s == nil { return nil } // hash the secret to a sha256 size h := sha256.Sum256(*s) - return h[:12] + return h[:secretHashSize] } // Valid method returns true if the secret is valid, false otherwise. A secret diff --git a/token/secret_test.go b/token/secret_test.go index 4dbac49..981819f 100644 --- a/token/secret_test.go +++ b/token/secret_test.go @@ -51,3 +51,26 @@ func TestValidSecret(t *testing.T) { t.Errorf("expected true, got false") } } + +func TestSecretHash(t *testing.T) { + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + h := secret.Hash() + if h == nil { + t.Errorf("expected hash, got nil") + } + if len(h) != secretHashSize { + t.Errorf("expected hash size %d, got %d", secretHashSize, len(h)) + } + hSecret := sha256.Sum256(secret.Bytes()) + if !bytes.Equal(h, hSecret[:secretHashSize]) { + t.Errorf("expected %x, got %x", hSecret[:secretHashSize], h) + } + // try to hash a nil secret + var nilSecret *Secret + h = nilSecret.Hash() + if h != nil { + t.Errorf("expected nil, got %x", h) + } +} From d3f94c04e25e530426597d3c1450419a4bfa6494 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Tue, 15 Apr 2025 23:58:55 +0200 Subject: [PATCH 32/36] first README --- README.md | 113 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d1e9f50..2496487 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,112 @@ -# SimpleAuthLink API +[![Last release](https://img.shields.io/github/v/release/simpleauthlink/authapi?color=purple)](https://github.com/simpleauthlink/authapi/releases/latest) +[![GoDoc](https://godoc.org/github.com/simpleauthlink/authapi?status.svg)](https://godoc.org/github.com/simpleauthlink/authapi) +[![Go Report Card](https://goreportcard.com/badge/github.com/simpleauthlink/authapi)](https://goreportcard.com/report/github.com/simpleauthlink/authapi) +[![Build and Test](https://github.com/simpleauthlink/authapi/actions/workflows/main.yml/badge.svg?branch=main)](https://github.com/simpleauthlink/authapi/actions/workflows/main.yml) +[![license](https://img.shields.io/github/license/simpleauthlink/authapi)](LICENSE) -WIP document, read more in the project website: [https://simpleauth.link](https://simpleauth.link). \ No newline at end of file + + + +# SimpleAuth.link API + +> Passwordless authentication for your users using just an email address. + +This repository contains the source code of the SimpleAuth.link API Service. + +Read full [documentation here](https://docs.simpleauth.link). + +--- + +## Technical Details 💻 + +### Token Generation Process 🔑 + +* By leveraging the Ed25519 signature algorithm, the service deterministically generates a private key using your App ID and secret. +* This ensures that each token is cryptographically secure and uniquely tied to your application, eliminating the need to store sensitive keys. + +### Stateless Architecture 🕊️ + +* SimpleAuth.link works without a traditional database. It does not store any user data, including email addresses, on its servers. +* Instead, the data generated is self-contained, requiring no further information or state to be used. This stateless design increases security and reduces the risk of data breaches. + +## Development 🧑‍💻 + +### Prerequisites 📝 + - Go (version 1.24 or later recommended) + - Docker (optional, for containerized deployment) + +### Clone the Repository 📥 +```sh +git clone --branch v2 https://github.com/SimpleAuthLink/authapi.git +cd authapi +``` + +### Code Structure 🪜 +* Modularity: The repository is structured into packages to simplify testing and future expansion. + - `api/`: Contains the core API endpoint definitions and routing logic. + - `cmd/`: Entry points and command-line utilities for running the API server. + - `docker/`: Dockerfiles and configuration for containerized builds. + - `internal/`: Internal packages and libraries used across the application. + - `notification/`: Code handling user notifications. + - `token/`: Modules for token creation and management. + - `.github/`: GitHub-specific workflows and configurations (CI/CD, issue templates, etc.). + + This layout supports modular development and clear separation of concerns across different parts of the API service. + +* Testing: Use the standard Go testing framework to run tests. You can run tests with: + ```sh + go test ./... + ``` + +### Run with go 🦫 +For development purposes, you can run the API server directly with Go. +```sh +go run ./cmd/authapi -h +``` +```sh +Usage of authapi: + -email-addr string + email account address + -email-host string + email server host + -email-pass string + email account password + -email-port int + email server port (default 587) + -host string + service host (default "0.0.0.0") + -port int + service host (default 8080) + -secret string + secret used to generate the tokens (default "simpleauthlink-secret") +``` + +### Run with docker 🐳 + +1. **Prepare the Environment File** + + Copy the `example.env` file to `.env` and edit the file to fill in your parameters: + ```bash + HOST="localhost" + PORT=8080 + EMAIL_ADDR="test@test.com" + EMAIL_PASS="smtp_server_password" + EMAIL_HOST="smtp.example.com" + EMAIL_PORT=587 + SECRET="my_backend_secret" + ``` + +2. **Build the Docker Image** + + Run the following command in the root of your project to build the image: + ```bash + docker build -f docker/Dockerfile.prod -t simpleauthlink . + ``` + +3. **Run the Docker Container** + + Once the image is built, start a container using the environment file: + ```bash + docker run --name simpleauthlink --env-file .env -p 8080:80 simpleauthlink + ``` + This command maps the container’s port 80 to port 8080 on your host. \ No newline at end of file From 306a2598cead43bf04ca75fc11b50c0282bf9251 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Wed, 16 Apr 2025 01:34:06 +0200 Subject: [PATCH 33/36] new api/io package with more tests --- api/error.go | 79 ------------------- api/errors.go | 27 +++++++ api/handlers.go | 15 ++-- api/handlers_test.go | 43 ++++++----- api/helpers.go | 15 ++-- api/helpers_test.go | 6 +- api/io.go | 95 ----------------------- api/io/error.go | 88 +++++++++++++++++++++ api/io/error_test.go | 90 ++++++++++++++++++++++ api/io/req.go | 39 ++++++++++ api/io/req_test.go | 57 ++++++++++++++ api/io/res.go | 113 +++++++++++++++++++++++++++ api/io/res_test.go | 179 +++++++++++++++++++++++++++++++++++++++++++ api/io_test.go | 126 ------------------------------ api/routes.go | 20 ++++- api/service.go | 51 ++++++++++-- 16 files changed, 695 insertions(+), 348 deletions(-) delete mode 100644 api/error.go create mode 100644 api/errors.go delete mode 100644 api/io.go create mode 100644 api/io/error.go create mode 100644 api/io/error_test.go create mode 100644 api/io/req.go create mode 100644 api/io/req_test.go create mode 100644 api/io/res.go create mode 100644 api/io/res_test.go delete mode 100644 api/io_test.go diff --git a/api/error.go b/api/error.go deleted file mode 100644 index 93e54ad..0000000 --- a/api/error.go +++ /dev/null @@ -1,79 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" -) - -var ( - // Decode data errors - DecodeAppIDRequestErr = newApiErr(1001, http.StatusBadRequest).With("could not decode app id request") - DecodeTokenRequestErr = newApiErr(1002, http.StatusBadRequest).With("could not decode token request") - DecodeTokenStatusRequestErr = newApiErr(1003, http.StatusBadRequest).With("could not decode token status request") - // Encode data errors - EncodeAppIDResponseErr = newApiErr(1010, http.StatusInternalServerError).With("could not encode app id response") - EncodeTokenStatusResponseErr = newApiErr(1011, http.StatusInternalServerError).With("could not encode token status response") - // Bad request errors - InvalidAppHeadersErr = newApiErr(1020, http.StatusBadRequest).With("invalid app headers") - InvalidAppIDErr = newApiErr(1021, http.StatusBadRequest).With("invalid app id") - InvalidAppSecretErr = newApiErr(1022, http.StatusBadRequest).With("invalid app secret") - InvalidDemoEmailInboxErr = newApiErr(1023, http.StatusBadRequest).With("invalid demo email inbox") - // Internal errors - GenerateTokenErr = newApiErr(1030, http.StatusInternalServerError).With("could not generate token") - GenerateEmailErr = newApiErr(1031, http.StatusInternalServerError).With("could not generate email") - SendEmailErr = newApiErr(1032, http.StatusInternalServerError).With("could not send email") - InternalErr = newApiErr(1033, http.StatusInternalServerError).With("internal server error") -) - -type APIError struct { - Code int `json:"code"` - Message string `json:"message"` - Err string `json:"error,omitempty"` - statusCode int -} - -func (e *APIError) Bytes() []byte { - bErr, err := json.Marshal(e) - if err != nil { - return nil - } - return bErr -} - -func (e *APIError) Error() string { - return fmt.Sprintf("code: %d, message: %s, error: %s, status_code: %d", e.Code, e.Message, e.Err, e.statusCode) -} - -func (e *APIError) WithErr(err error) *APIError { - if e.Err == "" { - e.Err = err.Error() - return e - } - e.Err = fmt.Sprintf("%s: %s", e.Err, err.Error()) - return e -} - -func (e *APIError) With(msg string) *APIError { - if e.Message == "" { - e.Message = msg - return e - } - e.Message = fmt.Sprintf("%s: %s", e.Message, msg) - return e -} - -func (e *APIError) Write(w http.ResponseWriter) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(e.statusCode) - if _, err := w.Write(e.Bytes()); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func newApiErr(code, status int) *APIError { - return &APIError{ - Code: code, - statusCode: status, - } -} diff --git a/api/errors.go b/api/errors.go new file mode 100644 index 0000000..6613cf8 --- /dev/null +++ b/api/errors.go @@ -0,0 +1,27 @@ +package api + +import ( + "net/http" + + "github.com/simpleauthlink/authapi/api/io" +) + +var ( + // Decode data errors + DecodeAppIDRequestErr = io.NewAPIError(1001, http.StatusBadRequest).With("could not decode app id request") + DecodeTokenRequestErr = io.NewAPIError(1002, http.StatusBadRequest).With("could not decode token request") + DecodeTokenStatusRequestErr = io.NewAPIError(1003, http.StatusBadRequest).With("could not decode token status request") + // Encode data errors + EncodeAppIDResponseErr = io.NewAPIError(1010, http.StatusInternalServerError).With("could not encode app id response") + EncodeTokenStatusResponseErr = io.NewAPIError(1011, http.StatusInternalServerError).With("could not encode token status response") + // Bad request errors + InvalidAppHeadersErr = io.NewAPIError(1020, http.StatusBadRequest).With("invalid app headers") + InvalidAppIDErr = io.NewAPIError(1021, http.StatusBadRequest).With("invalid app id") + InvalidAppSecretErr = io.NewAPIError(1022, http.StatusBadRequest).With("invalid app secret") + InvalidDemoEmailInboxErr = io.NewAPIError(1023, http.StatusBadRequest).With("invalid demo email inbox") + // Internal errors + GenerateTokenErr = io.NewAPIError(1030, http.StatusInternalServerError).With("could not generate token") + GenerateEmailErr = io.NewAPIError(1031, http.StatusInternalServerError).With("could not generate email") + SendEmailErr = io.NewAPIError(1032, http.StatusInternalServerError).With("could not send email") + InternalErr = io.NewAPIError(1033, http.StatusInternalServerError).With("internal server error") +) diff --git a/api/handlers.go b/api/handlers.go index f407649..95f0ae4 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" + "github.com/simpleauthlink/authapi/api/io" "github.com/simpleauthlink/authapi/notification" "github.com/simpleauthlink/authapi/notification/templates/login" "github.com/simpleauthlink/authapi/token" @@ -18,7 +19,7 @@ import ( // session duration. func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { // decode the app data from the request body - req := new(Request[AppIDRequest]) + req := new(io.Request[AppIDRequest]) if err := req.Read(r); err != nil { DecodeAppIDRequestErr.WithErr(err).Write(w) return @@ -32,7 +33,7 @@ func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { return } // return the app id - ResponseWith(&AppIDResponse{app.ID(secret).String()}).WriteJSON(w) + io.ResponseWith(&AppIDResponse{app.ID(secret).String()}).WriteJSON(w) } func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { @@ -57,7 +58,7 @@ func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { return } // decode the token request from the request body - req := new(Request[TokenRequest]) + req := new(io.Request[TokenRequest]) if err := req.Read(r); err != nil { DecodeTokenRequestErr.WithErr(err).Write(w) return @@ -88,7 +89,7 @@ func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { SendEmailErr.WithErr(err).Write(w) return } - OkResponse().WriteJSON(w) + io.OkResponse().WriteJSON(w) } func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { @@ -113,7 +114,7 @@ func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { return } // decode the token status request from the request body - req := new(Request[TokenStatusRequest]) + req := new(io.Request[TokenStatusRequest]) if err := req.Read(r); err != nil { DecodeTokenStatusRequestErr.WithErr(err).Write(w) return @@ -121,14 +122,14 @@ func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { // check if the token is valid tkn := new(token.Token).SetString(req.Data.Token) exp := tkn.Expiration().Time() - ResponseWith(&TokenStatusResponse{ + io.ResponseWith(&TokenStatusResponse{ Valid: appID.VerifyToken(*tkn, *secret, req.Data.Email), Expiration: exp, }).WriteJSON(w) } func (s *Service) healthCheckHandler(w http.ResponseWriter, r *http.Request) { - OkResponse().Write(w) + io.OkResponse().Write(w) } func (s *Service) demoInboxHandler(w http.ResponseWriter, r *http.Request) { diff --git a/api/handlers_test.go b/api/handlers_test.go index 65d5c81..bc55131 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + apiio "github.com/simpleauthlink/authapi/api/io" "github.com/simpleauthlink/authapi/notification/email" "github.com/simpleauthlink/authapi/notification/templates/login" "github.com/simpleauthlink/authapi/token" @@ -21,7 +22,7 @@ type testCaseAPIHandler[ReqType, ResType any] struct { header http.Header request *ReqType response *ResType - err *APIError + err *apiio.APIError } func (testCase testCaseAPIHandler[Rq, Rs]) url() string { @@ -50,10 +51,10 @@ func (testCase testCaseAPIHandler[Rq, Rs]) Run(t *testing.T) { defer resp.Body.Close() switch { case testCase.err != nil: - if resp.StatusCode != testCase.err.statusCode { - t.Fatalf("expected status code: %d, got: %d", testCase.err.statusCode, resp.StatusCode) + if resp.StatusCode != testCase.err.StatusCode { + t.Fatalf("expected status code: %d, got: %d", testCase.err.StatusCode, resp.StatusCode) } - err := new(APIError) + err := new(apiio.APIError) if err := json.NewDecoder(resp.Body).Decode(err); err != nil { t.Fatalf("could not decode error response: %v", err) } @@ -141,7 +142,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPost, endpoint: TokensPath, header: http.Header{ - AppSecretHeader: []string{testAppSecret}, + appSecretHeader: []string{testAppSecret}, }, request: &TokenRequest{ Email: testUserEmail, @@ -154,8 +155,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPost, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{"invalid"}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{"invalid"}, + appSecretHeader: []string{testAppSecret}, }, request: &TokenRequest{ Email: testUserEmail, @@ -168,7 +169,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPost, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, + appIDHeader: []string{testAppID.String()}, }, request: &TokenRequest{ Email: testUserEmail, @@ -181,8 +182,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPost, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, }, request: &TokenRequest{ Email: "", @@ -196,8 +197,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPost, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, }, request: &invalid, response: nil, @@ -213,8 +214,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPost, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, }, request: &TokenRequest{ Email: testUserEmail, @@ -239,8 +240,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPut, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, }, request: &TokenStatusRequest{ Token: testToken.String(), @@ -257,8 +258,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPut, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{"invalid"}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{"invalid"}, + appSecretHeader: []string{testAppSecret}, }, request: &TokenStatusRequest{ Token: testToken.String(), @@ -273,7 +274,7 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPut, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, + appIDHeader: []string{testAppID.String()}, }, request: &TokenStatusRequest{ Token: testToken.String(), @@ -300,8 +301,8 @@ func TestRequestTokenAndStatusHandler(t *testing.T) { method: http.MethodPut, endpoint: TokensPath, header: http.Header{ - AppIDHeader: []string{testAppID.String()}, - AppSecretHeader: []string{testAppSecret}, + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, }, err: DecodeTokenStatusRequestErr, }.Run(t) diff --git a/api/helpers.go b/api/helpers.go index 2635ef6..3511b7d 100644 --- a/api/helpers.go +++ b/api/helpers.go @@ -5,18 +5,19 @@ import ( "net/http" ) -const ( - AppIDHeader = "AppID" - AppSecretHeader = "AppSecret" -) - +// appConfigFromRequest extracts the app id and app secret from the request +// headers. It returns an error if the app id or app secret is missing. The +// app id and app secret are used to authenticate the app making the request. +// The app id is a unique identifier for the app, and the app secret is a +// shared secret used to verify the authenticity of the request for this +// service. func appConfigFromRequest(r *http.Request) (string, string, error) { // get the app id from the request header - strAppID := r.Header.Get(AppIDHeader) + strAppID := r.Header.Get(appIDHeader) if strAppID == "" { return "", "", fmt.Errorf("missing app id") } - strAppSecret := r.Header.Get(AppSecretHeader) + strAppSecret := r.Header.Get(appSecretHeader) if strAppSecret == "" { return "", "", fmt.Errorf("missing app secret") } diff --git a/api/helpers_test.go b/api/helpers_test.go index d385de5..8d0c9e8 100644 --- a/api/helpers_test.go +++ b/api/helpers_test.go @@ -16,19 +16,19 @@ func TestAppConfigFromRequest(t *testing.T) { }{ { name: "Valid headers", - headers: map[string]string{AppIDHeader: "testAppID", AppSecretHeader: "testAppSecret"}, + headers: map[string]string{appIDHeader: "testAppID", appSecretHeader: "testAppSecret"}, expectedAppID: "testAppID", expectedSecret: "testAppSecret", expectError: false, }, { name: "Missing app id", - headers: map[string]string{AppSecretHeader: "testAppSecret"}, + headers: map[string]string{appSecretHeader: "testAppSecret"}, expectError: true, }, { name: "Missing app secret", - headers: map[string]string{AppIDHeader: "testAppID"}, + headers: map[string]string{appIDHeader: "testAppID"}, expectError: true, }, { diff --git a/api/io.go b/api/io.go deleted file mode 100644 index b1429c8..0000000 --- a/api/io.go +++ /dev/null @@ -1,95 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "io" - "net/http" -) - -type Request[T any] struct { - Data T -} - -func (req *Request[T]) Read(r *http.Request) error { - if req == nil { - req = new(Request[T]) - } - if r.Body == nil { - return fmt.Errorf("nil request body") - } - rawBody, err := io.ReadAll(r.Body) - if err != nil { - return err - } - if len(rawBody) == 0 { - return fmt.Errorf("empty request body") - } - return json.Unmarshal(rawBody, &req.Data) -} - -type Response[T any] struct { - Data T - empty bool -} - -func ResponseWith[T any](data *T) *Response[T] { - if data == nil { - return &Response[T]{empty: true} - } - return &Response[T]{ - Data: *data, - empty: false, - } -} - -func OkResponse(body ...byte) *Response[any] { - if len(body) > 0 { - return &Response[any]{Data: body, empty: false} - } - return &Response[any]{empty: true} -} - -func (r *Response[T]) WriteJSON(w http.ResponseWriter) { - if !r.empty { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(r.Data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return - } - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("OK")); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func (r *Response[T]) Write(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) - if r.empty { - if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return - } - if data, ok := r.bytes(); ok { - if _, err := w.Write(data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - } -} - -func (r *Response[T]) bytes() ([]byte, bool) { - // check if the response is empty - if r.empty { - return nil, true - } - // ensure that the response data is an slice of bytes - switch v := any(r.Data).(type) { - case []byte: - return v, true - default: - return nil, false - } -} diff --git a/api/io/error.go b/api/io/error.go new file mode 100644 index 0000000..d5b77c6 --- /dev/null +++ b/api/io/error.go @@ -0,0 +1,88 @@ +package io + +import ( + "encoding/json" + "fmt" + "net/http" +) + +// APIError represents an error response from the API. It includes a code, +// message, and optional error string. The StatusCode field is used to set the +// HTTP status code for the response. It has some methods to manipulate the +// error message and write the error response to an http.ResponseWriter. +type APIError struct { + Code int `json:"code"` + Message string `json:"message"` + Err string `json:"error,omitempty"` + StatusCode int `json:"-"` +} + +// Error implements the error interface for APIError. It returns a string +// representation of the error, including the code, message, error string, +// and status code. It can be used for logging or debugging purposes. +func (e *APIError) Error() string { + return fmt.Sprintf("code: %d, message: %s, error: %s, status_code: %d", e.Code, e.Message, e.Err, e.StatusCode) +} + +// WithErr appends an error message to the existing error string in the +// APIError. If the existing error string is empty, it sets it to the new +// error message. This method is useful for chaining error messages together +// for better debugging and logging. It returns the updated APIError instance +// and also updates the current instance. +func (e *APIError) WithErr(err error) *APIError { + if e.Err == "" { + e.Err = err.Error() + return e + } + e.Err = fmt.Sprintf("%s: %s", e.Err, err.Error()) + return e +} + +// With appends a string message to the existing message in the APIError. If +// the existing message is empty, it sets it to the new message. This method +// is useful for chaining messages together for better debugging and logging. +// It returns the updated APIError instance and also updates the current +// instance. +func (e *APIError) With(msg string) *APIError { + if e.Message == "" { + e.Message = msg + return e + } + e.Message = fmt.Sprintf("%s: %s", e.Message, msg) + return e +} + +// WriteJSON writes the APIError as a JSON response to the provided +// http.ResponseWriter. It sets the Content-Type header to "application/json" +// and writes the status code and serialized error bytes to the response. +// If an error occurs during writing, it writes an internal server error +// response instead. +func (e *APIError) Write(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(e.StatusCode) + if _, err := w.Write(e.bytes()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// bytes serializes the APIError to JSON bytes. If an error occurs during +// serialization, it returns nil. +func (e *APIError) bytes() []byte { + bErr, err := json.Marshal(e) + if err != nil { + return nil + } + return bErr +} + +// NewAPIError creates a new APIError instance with the provided code and +// status code. It initializes the error string and message to empty strings. +// This function is useful for creating a new APIError instance with +// specific error codes and status codes. It returns a pointer to the +// newly created APIError instance. +func NewAPIError(code, status int) *APIError { + return &APIError{ + Code: code, + StatusCode: status, + } +} diff --git a/api/io/error_test.go b/api/io/error_test.go new file mode 100644 index 0000000..4da371f --- /dev/null +++ b/api/io/error_test.go @@ -0,0 +1,90 @@ +package io + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewAPIError(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + if err.Code != 1001 { + t.Errorf("expected code 1001, got %d", err.Code) + } + if err.StatusCode != http.StatusBadRequest { + t.Errorf("expected status code %d, got %d", http.StatusBadRequest, err.StatusCode) + } + if err.Message != "" { + t.Errorf("expected empty message, got %s", err.Message) + } + if err.Err != "" { + t.Errorf("expected empty error string, got %s", err.Err) + } +} + +func TestAPIError_Error(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + err.Message = "Bad Request" + err.Err = "Invalid input" + expected := "code: 1001, message: Bad Request, error: Invalid input, status_code: 400" + if err.Error() != expected { + t.Errorf("expected %s, got %s", expected, err.Error()) + } +} + +func TestAPIError_WithErr(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + _ = err.WithErr(errors.New("Invalid input")) + if err.Err != "Invalid input" { + t.Errorf("expected error string 'Invalid input', got %s", err.Err) + } + + _ = err.WithErr(errors.New("Missing field")) + if err.Err != "Invalid input: Missing field" { + t.Errorf("expected error string 'Invalid input: Missing field', got %s", err.Err) + } +} + +func TestAPIError_With(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + _ = err.With("Bad Request") + if err.Message != "Bad Request" { + t.Errorf("expected message 'Bad Request', got %s", err.Message) + } + + _ = err.With("Invalid input") + if err.Message != "Bad Request: Invalid input" { + t.Errorf("expected message 'Bad Request: Invalid input', got %s", err.Message) + } +} + +func TestAPIError_Write(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + err.Message = "Bad Request" + err.Err = "Invalid input" + + rr := httptest.NewRecorder() + err.Write(rr) + + if status := rr.Code; status != http.StatusBadRequest { + t.Errorf("expected status code %d, got %d", http.StatusBadRequest, status) + } + + expected := `{"code":1001,"message":"Bad Request","error":"Invalid input"}` + if rr.Body.String() != expected { + t.Errorf("expected body %s, got %s", expected, rr.Body.String()) + } +} + +func TestAPIError_bytes(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + err.Message = "Bad Request" + err.Err = "Invalid input" + + data := err.bytes() + expected := `{"code":1001,"message":"Bad Request","error":"Invalid input"}` + if string(data) != expected { + t.Errorf("expected %s, got %s", expected, string(data)) + } +} diff --git a/api/io/req.go b/api/io/req.go new file mode 100644 index 0000000..0b9691d --- /dev/null +++ b/api/io/req.go @@ -0,0 +1,39 @@ +package io + +import ( + "encoding/json" + "fmt" + "io" + "net/http" +) + +// Request represents a request with a generic data type to be unmarshalled +// from the request body. It implements the Read method to read and +// unmarshal the request body into the Data field. +type Request[T any] struct { + Data T +} + +// Read reads the request body and unmarshals it into the Data field of the +// Request struct. It returns an error if the request body is nil or empty, +// or if there is an error during unmarshalling. If the Request struct is nil, +// it initializes a new instance of Request[T]. This method is useful for +// handling incoming requests in a generic way, allowing for different data +// types to be processed without needing to define separate request structs +// for each type. +func (req *Request[T]) Read(r *http.Request) error { + if req == nil { + req = new(Request[T]) + } + if r.Body == nil { + return fmt.Errorf("nil request body") + } + rawBody, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("failed to read request body: %w", err) + } + if len(rawBody) == 0 { + return fmt.Errorf("empty request body") + } + return json.Unmarshal(rawBody, &req.Data) +} diff --git a/api/io/req_test.go b/api/io/req_test.go new file mode 100644 index 0000000..98b9790 --- /dev/null +++ b/api/io/req_test.go @@ -0,0 +1,57 @@ +package io + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" +) + +func TestRead(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + data := &Data{Message: "Hello, World!"} + body, _ := json.Marshal(data) + req, err := http.NewRequest("POST", "/", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + + var request Request[Data] + if err := request.Read(req); err != nil { + t.Errorf("unexpected error: %v", err) + } + if request.Data.Message != data.Message { + t.Errorf("expected %s, got %s", data.Message, request.Data.Message) + } +} + +func TestRead_EmptyBody(t *testing.T) { + noBody, err := http.NewRequest("POST", "/", nil) + if err != nil { + t.Fatal(err) + } + + if nilReq := new(Request[any]).Read(noBody); nilReq == nil { + t.Errorf("expected error, got nil") + } + + req, err := http.NewRequest("POST", "/", bytes.NewBuffer([]byte(""))) + if err != nil { + t.Fatal(err) + } + + var request *Request[any] + err = request.Read(req) + if err == nil { + t.Errorf("expected error, got nil") + } + err = new(Request[any]).Read(req) + if err == nil { + t.Errorf("expected error, got nil") + } + if err.Error() != "empty request body" { + t.Errorf("expected empty request body error, got %v", err) + } +} diff --git a/api/io/res.go b/api/io/res.go new file mode 100644 index 0000000..a73cc8f --- /dev/null +++ b/api/io/res.go @@ -0,0 +1,113 @@ +package io + +import ( + "encoding/json" + "net/http" +) + +// Response represents a response with a generic data type. It can be used to +// send JSON responses or plain text responses. The Data field holds the +// response data, and the empty field indicates whether the response is empty +// or not. It has methods to write the response to an http.ResponseWriter in +// JSON format or plain text format. +type Response[T any] struct { + Data T + empty bool +} + +// ResponseWith creates a new Response instance with the provided data. If +// the data is nil, it returns an empty response. This method is useful for +// creating responses with different data types without needing to define +// separate response structs for each type. It returns a pointer to the +// Response instance. +func ResponseWith[T any](data *T) *Response[T] { + if data == nil { + return &Response[T]{empty: true} + } + return &Response[T]{ + Data: *data, + empty: false, + } +} + +// OkResponse creates a new Response instance with the provided byte slice. +// If the byte slice is empty, it returns an empty response. This method is +// useful for creating responses with raw byte data. It returns a pointer to +// the Response instance. The empty field indicates whether the response is +// empty or not. If the byte slice is empty, the response is considered empty. +// If the byte slice is not empty, the response is considered non-empty. +func OkResponse(body ...byte) *Response[any] { + if len(body) > 0 { + return &Response[any]{Data: body, empty: false} + } + return &Response[any]{empty: true} +} + +// WriteJSON writes the response data to the provided http.ResponseWriter in +// JSON format. It sets the Content-Type header to "application/json" and +// writes the response data as JSON. If the response is empty, it writes a +// plain text "OK" response with a 200 OK status code. If there is an error +// during JSON encoding or response writing, it writes an error response with +// a 500 Internal Server Error status code. This method is useful for sending +// JSON responses to the client. It can be used in HTTP handlers or middleware +// to send structured JSON responses. +func (r *Response[T]) WriteJSON(w http.ResponseWriter) { + if !r.empty { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(r.Data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// Write writes the response data to the provided http.ResponseWriter in +// plain text format. It sets the Content-Type header to "text/plain" and +// writes the response data as plain text. If the response is empty, it writes +// a plain text "OK" response with a 200 OK status code. If there is an error +// during response writing, it writes an error response with a 500 Internal +// Server Error status code. This method is useful for sending plain text +// responses to the client. It can be used in HTTP handlers or middleware to +// send simple text responses. It is a more generic method than WriteJSON, as +// it does not require the response data to be JSON-serializable. +func (r *Response[T]) Write(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + if r.empty { + if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + if data, ok := r.bytes(); ok { + if _, err := w.Write(data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } +} + +// bytes returns the response data as a byte slice. If the response is empty, +// it returns nil and a boolean indicating that the response is empty. If the +// response data is not a byte slice, it returns nil and a boolean indicating +// that the response data is not a byte slice. This method is useful for +// converting the response data to a byte slice for writing to the response +// writer or for further processing. +func (r *Response[T]) bytes() ([]byte, bool) { + // check if the response is empty + if r.empty { + return nil, true + } + // ensure that the response data is an slice of bytes + switch v := any(r.Data).(type) { + case []byte: + return v, true + case string: + return []byte(v), true + default: + return nil, false + } +} diff --git a/api/io/res_test.go b/api/io/res_test.go new file mode 100644 index 0000000..fc6925d --- /dev/null +++ b/api/io/res_test.go @@ -0,0 +1,179 @@ +package io + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestResponseWith(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + nilResp := ResponseWith[Data](nil) + if !nilResp.empty { + t.Errorf("expected response to be empty") + } + data := &Data{Message: "Hello, World!"} + resp := ResponseWith(data) + if resp.empty { + t.Errorf("expected response to be non-empty") + } + if resp.Data.Message != data.Message { + t.Errorf("expected %s, got %s", data.Message, resp.Data.Message) + } +} + +func TestOkResponse(t *testing.T) { + resp := OkResponse() + if !resp.empty { + t.Errorf("expected response to be empty") + } +} + +func TestWriteJSON(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + data := &Data{Message: "Hello, World!"} + resp := ResponseWith(data) + + rr := httptest.NewRecorder() + resp.WriteJSON(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + expected, _ := json.Marshal(data) + if rr.Body.String() != string(expected)+"\n" { // Ensure newline is accounted for + t.Errorf("expected body %s, got %s", string(expected)+"\n", rr.Body.String()) + } + // write json with nil data + nilResp := ResponseWith[string](nil) + rr = httptest.NewRecorder() + nilResp.WriteJSON(rr) + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } +} + +func TestWriteJSON_Empty(t *testing.T) { + resp := OkResponse() + + rr := httptest.NewRecorder() + resp.WriteJSON(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + if rr.Body.String() != "OK" { + t.Errorf("expected body OK, got %s", rr.Body.String()) + } +} + +func TestOkResponse_WithBody(t *testing.T) { + body := []byte("Hello, World!") + resp := OkResponse(body...) + + if resp.empty { + t.Errorf("expected response to be non-empty") + } + + if string(resp.Data.([]byte)) != string(body) { + t.Errorf("expected body %s, got %s", string(body), string(resp.Data.([]byte))) + } +} + +func TestWrite(t *testing.T) { + data := []byte("Hello, World!") + resp := ResponseWith(&data) + + rr := httptest.NewRecorder() + resp.Write(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + // Adjust to match the expected plain text output + if rr.Body.String() != string(data) { + t.Errorf("expected body %s, got %s", string(data), rr.Body.String()) + } +} + +func TestWrite_Empty(t *testing.T) { + resp := OkResponse() + + rr := httptest.NewRecorder() + resp.Write(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + if rr.Body.String() != "OK" { + t.Errorf("expected body OK, got %s", rr.Body.String()) + } +} + +func TestBytes(t *testing.T) { + body := []byte("Hello, World!") + resp := OkResponse(body...) + + data, ok := resp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + + if string(data) != string(body) { + t.Errorf("expected body %s, got %s", string(body), string(data)) + } +} + +func TestBytes_Empty(t *testing.T) { + resp := OkResponse() + + data, ok := resp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + if data != nil { + t.Errorf("expected nil data for empty response, got %v", data) + } + + // custom text response + bmsg := []byte("Hello, World!") + bresp := ResponseWith(&bmsg) + bdata, ok := bresp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + if string(bdata) != string(bmsg) { + t.Errorf("expected body %s, got %s", string(bmsg), string(bdata)) + } + + // string type + msg := "Hello, World!" + sresp := ResponseWith(&msg) + sdata, ok := sresp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + if string(sdata) != msg { + t.Errorf("expected body %s, got %s", msg, string(sdata)) + } + + // invalid type + imsg := 123 + iresp := ResponseWith(&imsg) + idata, ok := iresp.bytes() + if ok { + t.Errorf("expected bytes to be invalid") + } + if idata != nil { + t.Errorf("expected nil data for invalid response, got %v", idata) + } +} diff --git a/api/io_test.go b/api/io_test.go deleted file mode 100644 index ace6588..0000000 --- a/api/io_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" -) - -func TestResponseWith(t *testing.T) { - type Data struct { - Message string `json:"message"` - } - nilResp := ResponseWith[Data](nil) - if !nilResp.empty { - t.Errorf("expected response to be empty") - } - data := &Data{Message: "Hello, World!"} - resp := ResponseWith(data) - if resp.empty { - t.Errorf("expected response to be non-empty") - } - if resp.Data.Message != data.Message { - t.Errorf("expected %s, got %s", data.Message, resp.Data.Message) - } -} - -func TestOkResponse(t *testing.T) { - resp := OkResponse() - if !resp.empty { - t.Errorf("expected response to be empty") - } -} - -func TestWriteJSON(t *testing.T) { - type Data struct { - Message string `json:"message"` - } - data := &Data{Message: "Hello, World!"} - resp := ResponseWith(data) - - rr := httptest.NewRecorder() - resp.WriteJSON(rr) - - if status := rr.Code; status != http.StatusOK { - t.Errorf("expected status code %d, got %d", http.StatusOK, status) - } - - expected, _ := json.Marshal(data) - if rr.Body.String() != string(expected)+"\n" { - t.Errorf("expected body %s, got %s", string(expected), rr.Body.String()) - } - // write json with nil data - nilResp := ResponseWith[string](nil) - rr = httptest.NewRecorder() - nilResp.WriteJSON(rr) - if status := rr.Code; status != http.StatusOK { - t.Errorf("expected status code %d, got %d", http.StatusOK, status) - } - -} - -func TestWriteJSON_Empty(t *testing.T) { - resp := OkResponse() - - rr := httptest.NewRecorder() - resp.WriteJSON(rr) - - if status := rr.Code; status != http.StatusOK { - t.Errorf("expected status code %d, got %d", http.StatusOK, status) - } - - if rr.Body.String() != "OK" { - t.Errorf("expected body OK, got %s", rr.Body.String()) - } -} - -func TestRead(t *testing.T) { - type Data struct { - Message string `json:"message"` - } - data := &Data{Message: "Hello, World!"} - body, _ := json.Marshal(data) - req, err := http.NewRequest("POST", "/", bytes.NewBuffer(body)) - if err != nil { - t.Fatal(err) - } - - var request Request[Data] - if err := request.Read(req); err != nil { - t.Errorf("unexpected error: %v", err) - } - if request.Data.Message != data.Message { - t.Errorf("expected %s, got %s", data.Message, request.Data.Message) - } -} - -func TestRead_EmptyBody(t *testing.T) { - noBody, err := http.NewRequest("POST", "/", nil) - if err != nil { - t.Fatal(err) - } - - if nilReq := new(Request[any]).Read(noBody); nilReq == nil { - t.Errorf("expected error, got nil") - } - - req, err := http.NewRequest("POST", "/", bytes.NewBuffer([]byte(""))) - if err != nil { - t.Fatal(err) - } - - var request *Request[any] - err = request.Read(req) - if err == nil { - t.Errorf("expected error, got nil") - } - err = new(Request[any]).Read(req) - if err == nil { - t.Errorf("expected error, got nil") - } - if err.Error() != "empty request body" { - t.Errorf("expected empty request body error, got %v", err) - } -} diff --git a/api/routes.go b/api/routes.go index 72bbf6f..78631ae 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,10 +1,26 @@ package api +// routes paths constants const ( // HealthCheckPath constant is the path used to check the health of the API // server. It is a string with a value of "/health". HealthCheckPath = "/ping" - AppsPath = "/apps" - TokensPath = "/tokens" + // AppsPath constant is the path used to create the apps in the API server. + AppsPath = "/apps" + // TokensPath constant is the path used to generate and verify the tokens + // in the API server. + TokensPath = "/tokens" + // DemoInboxPath constant is the path used to get the demo email inbox + // in the API server when it runs in demo mode. DemoInboxPath = "/demo/inbox" ) + +// other api related constants +const ( + // appIDHeader constant is the header of the app ID in the request. It is + // used to authenticate the app making the request. + appIDHeader = "AppID" + // appSecretHeader constant is the header of the app secret in the request + // It is used to authenticate the app making the request. + appSecretHeader = "AppSecret" +) diff --git a/api/service.go b/api/service.go index fb53a0a..7fc5223 100644 --- a/api/service.go +++ b/api/service.go @@ -15,6 +15,15 @@ import ( "github.com/simpleauthlink/authapi/notification" ) +// Config struct represents the configuration for the API service. It contains +// the server address, server port, and secret key for the service. The server +// address is the address where the service will listen for incoming requests, +// and the server port is the port number where the service will listen for +// incoming requests. The secret key is used to sign and verify tokens. The +// demo mode is used to enable or disable the demo functionality of the service. +// The demo SMTP address and port are used to configure the demo mail server. +// The demo mode is used to enable or disable the demo functionality of the +// service. type Config struct { Server string ServerPort int @@ -25,6 +34,17 @@ type Config struct { DemoSMTPPort int } +// Service struct represents the API service. It contains the context, cancel +// function, wait group, configuration, notification queue, API handler, HTTP +// server, and demo mail server. The context is used to manage the lifecycle +// of the service, while the wait group is used to wait for background processes +// to finish. The notification queue is used to send notifications, and the API +// handler is used to handle incoming requests. The HTTP server is used to +// serve the API endpoints, and the demo mail server is used to simulate +// sending emails in demo mode. +// The demo mail server is a fake SMTP server that captures emails sent to it +// for testing purposes. The demo mail inbox is a channel that receives the +// captured emails. type Service struct { ctx context.Context cancel context.CancelFunc @@ -38,17 +58,25 @@ type Service struct { demoMailInbox chan string } +// New function creates a new service instance. It takes a context, a config +// struct, and a notification queue as parameters. It returns a pointer to the +// service instance and an error if something goes wrong during the process. +// The function is responsible for setting up the service and its dependencies. +// It handles the configuration, rate limiting, and HTTP server setup. It also +// manages the demo mode functionality, including the demo mail server and +// inbox. The function is designed to be used as a constructor for the service +// and is responsible for initializing all the necessary components for the +// service to function properly. func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) - // rateLimiter := apihandler.RateLimiter(internalCtx, 1000, 1000, time.Minute*3) + rateLimiter := apihandler.RateLimiter(internalCtx, 50, 50, time.Minute*3) // create the service srv := &Service{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - nq: nq, - // handler: apihandler.NewHandler(true, rateLimiter), - handler: apihandler.NewHandler(true, nil), + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + nq: nq, + handler: apihandler.NewHandler(true, rateLimiter), } // demo stuff if cfg.DemoMode { @@ -73,7 +101,7 @@ func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, err return srv, nil } -// Start method starts the service. +// Start method starts the service by starting the http server. func (s *Service) Start() error { // start the api server if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { @@ -82,12 +110,19 @@ func (s *Service) Start() error { return nil } +// Stop method stops the service. It cancels the context and waits for the +// background processes to finish. It also closes the http server and the +// demo mail server if it is running. func (s *Service) Stop() { // cancel the context and wait for the background processes finish s.cancel() defer s.wait.Wait() } +// Ping method checks if the service is up and running. It sends a GET request +// to the health check endpoint and returns true if the response status code +// is 200 OK, otherwise it returns false. If something goes wrong during the +// process, it returns false. func (s *Service) Ping() bool { url := fmt.Sprintf("http://%s:%d%s", s.cfg.Server, s.cfg.ServerPort, HealthCheckPath) request, err := http.NewRequest(http.MethodGet, url, nil) From c87b91825d3bd6d7aef8133ed71c42599122f105 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Wed, 16 Apr 2025 18:02:54 +0200 Subject: [PATCH 34/36] issues with osflag fixed, more comments and tests --- cmd/authapi/main.go | 6 +- cmd/consts.go | 4 + cmd/demo/main.go | 2 +- internal/fakesmtpserver/server.go | 23 +++- internal/osflag/osflag.go | 208 ++++++++++++++++++++++-------- internal/osflag/osflag_test.go | 42 ++++-- 6 files changed, 213 insertions(+), 72 deletions(-) diff --git a/cmd/authapi/main.go b/cmd/authapi/main.go index bd54065..e487ef9 100644 --- a/cmd/authapi/main.go +++ b/cmd/authapi/main.go @@ -15,6 +15,7 @@ type config struct { host string port int emailAddr string + emailUser string emailPass string emailHost string emailPort int @@ -33,11 +34,12 @@ func main() { osflag.StringVar(&c.host, cmd.HostEnv, cmd.HostFlag, cmd.DefaultHost, cmd.HostFlagDesc, false) osflag.IntVar(&c.port, cmd.PortEnv, cmd.PortFlag, cmd.DefaultPort, cmd.HostFlagDesc, false) osflag.StringVar(&c.emailAddr, cmd.EmailAddrEnv, cmd.EmailAddrFlag, cmd.DefaultEmailAddr, cmd.EmailAddrFlagDesc, true) + osflag.StringVar(&c.emailUser, cmd.EmailUserEnv, cmd.EmailUserFlag, cmd.DefaultEmailUser, cmd.EmailUserFlagDesc, true) osflag.StringVar(&c.emailPass, cmd.EmailPassEnv, cmd.EmailPassFlag, cmd.DefaultEmailPass, cmd.EmailPassFlagDesc, true) osflag.StringVar(&c.emailHost, cmd.EmailHostEnv, cmd.EmailHostFlag, cmd.DefaultEmailHost, cmd.EmailHostFlagDesc, true) osflag.IntVar(&c.emailPort, cmd.EmailPortEnv, cmd.EmailPortFlag, cmd.DefaultEmailPort, cmd.EmailPortFlagDesc, false) osflag.StringVar(&c.secret, cmd.SecretEnv, cmd.SecretFlag, cmd.DefaultSecret, cmd.SecretFlagDesc, true) - if err := osflag.Parse(); err != nil { + if err := osflag.Parse(nil); err != nil { log.Fatalln("ERR: error parsing flags:", err) } if !osflag.Parsed() { @@ -49,7 +51,7 @@ func main() { emailQueue, err := email.NewEmailQueue(context.Background(), &email.EmailConfig{ FromName: "SimpleAuthLink", FromAddress: c.emailAddr, - SMTPUsername: c.emailAddr, + SMTPUsername: c.emailUser, SMTPPassword: c.emailPass, SMTPServer: c.emailHost, SMTPPort: c.emailPort, diff --git a/cmd/consts.go b/cmd/consts.go index 4205497..0a5bb9e 100644 --- a/cmd/consts.go +++ b/cmd/consts.go @@ -4,6 +4,7 @@ const ( DefaultHost = "0.0.0.0" DefaultPort = 8080 DefaultEmailAddr = "" + DefaultEmailUser = "" DefaultEmailPass = "" DefaultEmailHost = "" DefaultEmailPort = 587 @@ -12,6 +13,7 @@ const ( HostFlag = "host" PortFlag = "port" EmailAddrFlag = "email-addr" + EmailUserFlag = "email-user" EmailPassFlag = "email-pass" EmailHostFlag = "email-host" EmailPortFlag = "email-port" @@ -19,6 +21,7 @@ const ( HostFlagDesc = "service host" PortFlagDesc = "service port" EmailAddrFlagDesc = "email account address" + EmailUserFlagDesc = "email account username" EmailPassFlagDesc = "email account password" EmailHostFlagDesc = "email server host" EmailPortFlagDesc = "email server port" @@ -27,6 +30,7 @@ const ( HostEnv = "HOST" PortEnv = "PORT" EmailAddrEnv = "EMAIL_ADDR" + EmailUserEnv = "EMAIL_USER" EmailPassEnv = "EMAIL_PASS" EmailHostEnv = "EMAIL_HOST" EmailPortEnv = "EMAIL_PORT" diff --git a/cmd/demo/main.go b/cmd/demo/main.go index 9fa400c..ea4a4aa 100644 --- a/cmd/demo/main.go +++ b/cmd/demo/main.go @@ -19,7 +19,7 @@ func main() { osflag.StringVar(&demoServer, cmd.HostEnv, cmd.HostFlag, cmd.DefaultHost, cmd.HostFlagDesc, false) osflag.IntVar(&demoPort, cmd.PortEnv, cmd.PortFlag, cmd.DefaultPort, cmd.PortFlagDesc, false) osflag.StringVar(&demoSecret, cmd.SecretEnv, cmd.SecretFlag, cmd.DefaultSecret, cmd.SecretFlagDesc, false) - if err := osflag.Parse(); err != nil { + if err := osflag.Parse(nil); err != nil { log.Fatalln("ERR: error parsing flags:", err) } log.Println("INF: starting service with config:", demoServer, demoPort, demoSecret) diff --git a/internal/fakesmtpserver/server.go b/internal/fakesmtpserver/server.go index bbd3ee2..dee96c4 100644 --- a/internal/fakesmtpserver/server.go +++ b/internal/fakesmtpserver/server.go @@ -1,5 +1,13 @@ package fakesmtpserver +// fakesmtpserver package provides a simple SMTP server for testing purposes. +// It allows you to simulate an SMTP server that can receive emails and store +// them in a channel. This is useful for testing email sending functionality +// in applications without needing to set up a real SMTP server. The server +// can be started and stopped, and it handles basic SMTP commands like HELO, +// MAIL FROM, RCPT TO, and DATA. It also provides a way to retrieve the +// received emails from the inbox channel. + import ( "bufio" "context" @@ -40,11 +48,13 @@ func (s *FakeSMTPServer) Start(ctx context.Context) error { for { select { case <-ctx.Done(): - s.Stop() // Use Stop to safely close the listener + // use Stop to safely close the listener + s.Stop() return default: + // copy listener under lock s.mu.Lock() - listener := s.listener // Copy listener under lock + listener := s.listener s.mu.Unlock() if listener == nil { return @@ -62,12 +72,15 @@ func (s *FakeSMTPServer) Start(ctx context.Context) error { // Stop method shuts down the test SMTP server. func (s *FakeSMTPServer) Stop() { + // copy listener under lock s.mu.Lock() - listener := s.listener // Copy listener under lock - s.listener = nil // Set listener to nil under lock + listener := s.listener + // set listener to nil under lock and unlock + s.listener = nil s.mu.Unlock() + // close the listener if it is not nil if listener != nil { - listener.Close() // Close listener outside the lock + listener.Close() } } diff --git a/internal/osflag/osflag.go b/internal/osflag/osflag.go index d71924a..074e298 100644 --- a/internal/osflag/osflag.go +++ b/internal/osflag/osflag.go @@ -1,21 +1,61 @@ package osflag +// osflag package provides a way to manage command line flags and environment +// variables in Go applications. It allows for the creation of command line +// flags that can also be overwritten by environment variables. By default it +// loads a `.env` file, but this can be overridden by passing an `WithEnvFile` +// option to the `Parse` method. It also checks for required flags and ensures +// that they are set before parsing the command line arguments. + import ( + "bufio" "flag" "fmt" "os" - "strconv" + "strings" "time" ) +// Options is a struct that holds options for the Parse method. +type Options struct { + envFile string +} + +// WithEnvFile function creates an Options instance with the specified +// envFile path. If the path is empty, it returns nil. This function is used +// to specify a custom env file path when calling the Parse method. +func WithEnvFile(path string) *Options { + if path == "" { + return nil + } + return &Options{envFile: path} +} + +// osflag is a struct that holds the name, environment variable, and +// required mark of a flag. It is used to manage command line flags +// and their corresponding environment variables. +type osflag struct { + name string + env string + required bool +} + +// OsFlagSet is a struct that embeds flag.FlagSet and adds support for env +// variables. It allows for the creation of command line flags that can also +// be overwritten by environment variables. By default it loads `.env` file, +// but this can be overridden by passing an WithEnvFile option to the Parse +// method. It also checks for required flags and ensures that they are set +// before parsing the command line arguments. type OsFlagSet struct { *flag.FlagSet - required map[string]bool - parsed bool + flags map[string]osflag + parsed bool } +// CommandLine is the default OsFlagSet instance. var CommandLine *OsFlagSet +// init initializes the CommandLine variable with a new OsFlagSet instance. func init() { CommandLine = new(OsFlagSet) if len(os.Args) == 0 { @@ -23,80 +63,77 @@ func init() { } else { CommandLine.FlagSet = flag.NewFlagSet(os.Args[0], flag.ExitOnError) } - CommandLine.required = make(map[string]bool) + CommandLine.flags = make(map[string]osflag) } +// BoolVar method registers a boolean flag with the given name, env variable, +// default value, usage string, and required mark. func (of *OsFlagSet) BoolVar(p *bool, env, name string, value bool, usage string, required bool) { - var newDefault bool = value - if rawBool := os.Getenv(env); rawBool != "" { - if rawBool == "true" || rawBool == "True" || rawBool == "TRUE" || rawBool == "1" { - newDefault = true - } - } - of.required[name] = required - of.FlagSet.BoolVar(p, name, newDefault, usage) + of.flags[name] = osflag{name, env, required} + of.FlagSet.BoolVar(p, name, value, usage) } +// DurationVar method registers a duration flag with the given name, env +// variable, default value, usage string, and required mark. func (of *OsFlagSet) DurationVar(p *time.Duration, env, name string, value time.Duration, usage string, required bool) { - var newDefault time.Duration = value - if rawDuration := os.Getenv(env); rawDuration != "" { - if dur, err := time.ParseDuration(rawDuration); err == nil { - newDefault = dur - } - } - of.required[name] = required - of.FlagSet.DurationVar(p, name, newDefault, usage) + of.flags[name] = osflag{name, env, required} + of.FlagSet.DurationVar(p, name, value, usage) } +// Float64Var method registers a float64 flag with the given name, env variable, +// default value, usage string, and required mark. func (of *OsFlagSet) Float64Var(p *float64, env, name string, value float64, usage string, required bool) { - var newDefault float64 = value - if rawFloat := os.Getenv(env); rawFloat != "" { - if f, err := strconv.ParseFloat(rawFloat, 64); err == nil { - newDefault = f - } - } - of.required[name] = required - of.FlagSet.Float64Var(p, name, newDefault, usage) + of.flags[name] = osflag{name, env, required} + of.FlagSet.Float64Var(p, name, value, usage) } +// IntVar method registers an int flag with the given name, env variable, +// default value, usage string, and required mark. func (of *OsFlagSet) IntVar(p *int, env, name string, value int, usage string, required bool) { - var newDefault int = value - if rawInt := os.Getenv(env); rawInt != "" { - if integer, err := strconv.Atoi(rawInt); err == nil { - newDefault = integer - } - } - of.required[name] = required - of.FlagSet.IntVar(p, name, newDefault, usage) + of.flags[name] = osflag{name, env, required} + of.FlagSet.IntVar(p, name, value, usage) } +// StringVar method registers a string flag with the given name, env variable, +// default value, usage string, and required mark. func (of *OsFlagSet) StringVar(p *string, env, name string, value string, usage string, required bool) { - var newDefault string = value - if rawString := os.Getenv(env); rawString != "" { - newDefault = rawString - } - of.required[name] = required - of.FlagSet.StringVar(p, name, newDefault, usage) + of.flags[name] = osflag{name, env, required} + of.FlagSet.StringVar(p, name, value, usage) } +// UintVar method registers a uint flag with the given name, env variable, +// default value, usage string, and required mark. func (of *OsFlagSet) UintVar(p *uint, env, name string, value uint, usage string, required bool) { - var newDefault uint - if rawUint := os.Getenv(env); rawUint != "" { - if ui, err := strconv.ParseUint(rawUint, 10, 64); err == nil { - newDefault = uint(ui) - } - } - of.required[name] = required - of.FlagSet.UintVar(p, name, newDefault, usage) + of.flags[name] = osflag{name, env, required} + of.FlagSet.UintVar(p, name, value, usage) } -func (of *OsFlagSet) Parse() error { +// Parse method parses the command line arguments and loads the environment +// variables from the specified env file. It checks if all required flags are +// set and it overwrites the command line flags with the values from the env +// variables if they are set. It returns an error if any required flags are +// not set or if there is an error loading the env file. +func (of *OsFlagSet) Parse(opts *Options) error { if err := of.FlagSet.Parse(os.Args[1:]); err != nil { return err } + // load the env file + envFile := ".env" + if opts != nil && opts.envFile != "" { + envFile = opts.envFile + } + if err := loadEnv(envFile); err != nil { + return fmt.Errorf("failed to load env file: %w", err) + } // check if all required flags are set - for name, required := range of.required { - if required { + for name, osf := range of.flags { + if envValue := os.Getenv(osf.env); envValue != "" { + if err := of.FlagSet.Set(name, envValue); err != nil { + return fmt.Errorf("failed to set flag %s from env: %w", name, err) + } + } + // check if the flag is required and not set + if osf.required { f := of.FlagSet.Lookup(name) if f == nil || f.Value.String() == "" { return fmt.Errorf("required flag %s is not set", name) @@ -107,46 +144,107 @@ func (of *OsFlagSet) Parse() error { return nil } +// Parsed method returns true if the command line arguments have been parsed. func (of *OsFlagSet) Parsed() bool { return of.parsed } +// PrintDefaults method prints the default values of all flags. func (of *OsFlagSet) PrintDefaults() { of.FlagSet.PrintDefaults() } +// BoolVar method registers a boolean flag with the given name, env variable, +// default value, usage string, and required mark. func BoolVar(p *bool, env, name string, value bool, usage string, required bool) { CommandLine.BoolVar(p, env, name, value, usage, required) } +// DurationVar method registers a duration flag with the given name, env +// variable, default value, usage string, and required mark. func DurationVar(p *time.Duration, env, name string, value time.Duration, usage string, required bool) { CommandLine.DurationVar(p, env, name, value, usage, required) } +// Float64Var method registers a float64 flag with the given name, env variable, +// default value, usage string, and required mark. func Float64Var(p *float64, env, name string, value float64, usage string, required bool) { CommandLine.Float64Var(p, env, name, value, usage, required) } +// IntVar method registers an int flag with the given name, env variable, +// default value, usage string, and required mark. func IntVar(p *int, env, name string, value int, usage string, required bool) { CommandLine.IntVar(p, env, name, value, usage, required) } +// StringVar method registers a string flag with the given name, env variable, +// default value, usage string, and required mark. func StringVar(p *string, env, name string, value string, usage string, required bool) { CommandLine.StringVar(p, env, name, value, usage, required) } +// UintVar method registers a uint flag with the given name, env variable, +// default value, usage string, and required mark. func UintVar(p *uint, env, name string, value uint, usage string, required bool) { CommandLine.UintVar(p, env, name, value, usage, required) } -func Parse() error { - return CommandLine.Parse() +// Parse method parses the command line arguments and loads the environment +// variables from the specified env file. It checks if all required flags are +// set and it overwrites the command line flags with the values from the env +// variables if they are set. It returns an error if any required flags are +// not set or if there is an error loading the env file. +func Parse(opts *Options) error { + return CommandLine.Parse(opts) } +// Parsed method returns true if the command line arguments have been parsed. func Parsed() bool { return CommandLine.parsed } +// PrintDefaults method prints the default values of all flags. func PrintDefaults() { CommandLine.PrintDefaults() } + +// loadEnv function loads environment variables from a file. If the file does +// not exist, it returns nil and does not raise an error. It reads the file +// line by line, ignoring empty lines and comments. It sets the environment +// variables in the current process using os.Setenv. It removes any "export " +// prefix and surrounding quotes from the variable assignments. It returns +// an error if there is an issue opening or reading the file (different from +// the file not existing). +func loadEnv(path string) error { + envFile, err := os.Open(path) + if err != nil { + // if the file does not exist, return nil + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to open env file: %w", err) + } + defer envFile.Close() + // create a line scanner + scanner := bufio.NewScanner(envFile) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue // skip empty lines and comments + } + // remove "export " prefix if present + line = strings.TrimPrefix(line, "export ") + // split on the first '=' character + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue // or return an error if preferred + } + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + value = strings.Trim(value, "\"'") // remove surrounding quotes + // set var in the current env + os.Setenv(key, value) + } + return nil +} diff --git a/internal/osflag/osflag_test.go b/internal/osflag/osflag_test.go index a70fbbd..15a2a46 100644 --- a/internal/osflag/osflag_test.go +++ b/internal/osflag/osflag_test.go @@ -10,7 +10,7 @@ import ( func resetCommandLine() { CommandLine = new(OsFlagSet) CommandLine.FlagSet = flag.NewFlagSet("", flag.ExitOnError) - CommandLine.required = make(map[string]bool) + CommandLine.flags = make(map[string]osflag) // Filter out test framework flags os.Args = os.Args[:1] @@ -23,7 +23,7 @@ func TestBoolVar(t *testing.T) { defer os.Unsetenv("TEST_BOOL") CommandLine.BoolVar(&flagValue, "TEST_BOOL", "boolFlag", false, "A boolean flag", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -39,7 +39,7 @@ func TestDurationVar(t *testing.T) { defer os.Unsetenv("TEST_DURATION") CommandLine.DurationVar(&flagValue, "TEST_DURATION", "durationFlag", 0, "A duration flag", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -55,7 +55,7 @@ func TestFloat64Var(t *testing.T) { defer os.Unsetenv("TEST_FLOAT") CommandLine.Float64Var(&flagValue, "TEST_FLOAT", "floatFlag", 0.0, "A float flag", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -71,7 +71,7 @@ func TestIntVar(t *testing.T) { defer os.Unsetenv("TEST_INT") CommandLine.IntVar(&flagValue, "TEST_INT", "intFlag", 0, "An int flag", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -87,7 +87,7 @@ func TestStringVar(t *testing.T) { defer os.Unsetenv("TEST_STRING") CommandLine.StringVar(&flagValue, "TEST_STRING", "stringFlag", "default", "A string flag", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -103,7 +103,7 @@ func TestUintVar(t *testing.T) { defer os.Unsetenv("TEST_UINT") CommandLine.UintVar(&flagValue, "TEST_UINT", "uintFlag", 0, "A uint flag", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -117,7 +117,7 @@ func TestRequiredFlag(t *testing.T) { var flagValue string CommandLine.StringVar(&flagValue, "", "requiredFlag", "", "A required flag", true) - if err := CommandLine.Parse(); err == nil { + if err := CommandLine.Parse(nil); err == nil { t.Errorf("Expected error for missing required flag, got nil") } } @@ -126,7 +126,7 @@ func TestDefaultValues(t *testing.T) { resetCommandLine() var flagValue string CommandLine.StringVar(&flagValue, "", "defaultFlag", "defaultValue", "A flag with a default value", false) - if err := CommandLine.Parse(); err != nil { + if err := CommandLine.Parse(nil); err != nil { t.Fatalf("Failed to parse command line: %v", err) } @@ -134,3 +134,27 @@ func TestDefaultValues(t *testing.T) { t.Errorf("Expected 'defaultValue', got %v", flagValue) } } + +func TestLoadEnv(t *testing.T) { + resetCommandLine() + // try to load a non-existing env file (should not error) + if err := loadEnv("non_existing.env"); err != nil { + t.Fatalf("Expected no error for non-existing env file, got: %v", err) + } + // create .env file + envFileContent := []byte("TEST_ENV=envValue") + envFilePath := ".env" + if err := os.WriteFile(envFilePath, envFileContent, 0o644); err != nil { + t.Fatalf("Failed to create env file: %v", err) + } + defer os.Remove(envFilePath) + // parse flags and check the value + var flagValue string + CommandLine.StringVar(&flagValue, "TEST_ENV", "envFlag", "defaultValue", "A flag with an env variable", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + if flagValue != "envValue" { + t.Errorf("Expected 'envValue', got %v", flagValue) + } +} From 8d0ef8edb460626fcc035e54ae13c2c3f8e43530 Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Wed, 16 Apr 2025 18:42:06 +0200 Subject: [PATCH 35/36] example.env updated --- example.env | 1 + 1 file changed, 1 insertion(+) diff --git a/example.env b/example.env index 617594f..b2e07b8 100644 --- a/example.env +++ b/example.env @@ -1,6 +1,7 @@ HOST="localhost" PORT=8080 EMAIL_ADDR="test@test.com" +EMAIL_USER="test@test.com" EMAIL_PASS="smtp_server_password" EMAIL_HOST="smtp.example.com" EMAIL_PORT=587 From 3b44b2ca0fcf4cef406d99980919a0053e01d86e Mon Sep 17 00:00:00 2001 From: Lucas Menendez Date: Wed, 16 Apr 2025 19:19:11 +0200 Subject: [PATCH 36/36] remove ratelimiter, does not work --- api/service.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/service.go b/api/service.go index 7fc5223..acd4b30 100644 --- a/api/service.go +++ b/api/service.go @@ -69,14 +69,13 @@ type Service struct { // service to function properly. func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) - rateLimiter := apihandler.RateLimiter(internalCtx, 50, 50, time.Minute*3) // create the service srv := &Service{ ctx: internalCtx, cancel: cancel, cfg: cfg, nq: nq, - handler: apihandler.NewHandler(true, rateLimiter), + handler: apihandler.NewHandler(true, nil), } // demo stuff if cfg.DemoMode {