Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,16 @@ func (s *server) authalice(next http.Handler) http.Handler {
if !found {
log.Info().Msg("Looking for user information in DB")
// Checks DB from matching user and store user values in context
rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0 FROM users WHERE token=$1 LIMIT 1", token)
rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0,CASE WHEN s3_enabled THEN 'true' ELSE 'false' END,COALESCE(media_delivery, 'base64') FROM users WHERE token=$1 LIMIT 1", token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This SQL query has become quite long, which can harm readability and maintainability. It's a good practice to extract complex queries into named constants. This improves readability and separates SQL from Go logic.

For example, you could define a constant at the package level:

const authaliceUserQuery = `
    SELECT
        id, name, webhook, jid, events, proxy_url, qrcode, history,
        hmac_key IS NOT NULL AND length(hmac_key) > 0,
        CASE WHEN s3_enabled THEN 'true' ELSE 'false' END,
        COALESCE(media_delivery, 'base64')
    FROM users
    WHERE token=$1
    LIMIT 1`

Then, you can use this constant in your query call, making the code cleaner.

Suggested change
rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0,CASE WHEN s3_enabled THEN 'true' ELSE 'false' END,COALESCE(media_delivery, 'base64') FROM users WHERE token=$1 LIMIT 1", token)
rows, err := s.db.Query(authaliceUserQuery, token)

if err != nil {
s.Respond(w, r, http.StatusInternalServerError, err)
return
}
defer rows.Close()
var history sql.NullInt64
var s3Enabled, mediaDelivery string
for rows.Next() {
err = rows.Scan(&txtid, &name, &webhook, &jid, &events, &proxy_url, &qrcode, &history, &hasHmac)
err = rows.Scan(&txtid, &name, &webhook, &jid, &events, &proxy_url, &qrcode, &history, &hasHmac, &s3Enabled, &mediaDelivery)
if err != nil {
s.Respond(w, r, http.StatusInternalServerError, err)
return
Expand All @@ -176,16 +177,18 @@ func (s *server) authalice(next http.Handler) http.Handler {
log.Debug().Str("userId", txtid).Bool("historyValid", history.Valid).Int64("historyValue", history.Int64).Str("historyStr", historyStr).Msg("User authentication - history debug")

v := Values{map[string]string{
"Id": txtid,
"Name": name,
"Jid": jid,
"Webhook": webhook,
"Token": token,
"Proxy": proxy_url,
"Events": events,
"Qrcode": qrcode,
"History": historyStr,
"HasHmac": strconv.FormatBool(hasHmac),
"Id": txtid,
"Name": name,
"Jid": jid,
"Webhook": webhook,
"Token": token,
"Proxy": proxy_url,
"Events": events,
"Qrcode": qrcode,
"History": historyStr,
"HasHmac": strconv.FormatBool(hasHmac),
"S3Enabled": s3Enabled,
"MediaDelivery": mediaDelivery,
}}
Comment on lines 179 to 192
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and long-term maintainability, it's a good practice to keep the keys in map literals sorted alphabetically. This makes it easier for developers to find specific keys when reading or modifying the code.

Suggested change
v := Values{map[string]string{
"Id": txtid,
"Name": name,
"Jid": jid,
"Webhook": webhook,
"Token": token,
"Proxy": proxy_url,
"Events": events,
"Qrcode": qrcode,
"History": historyStr,
"HasHmac": strconv.FormatBool(hasHmac),
"Id": txtid,
"Name": name,
"Jid": jid,
"Webhook": webhook,
"Token": token,
"Proxy": proxy_url,
"Events": events,
"Qrcode": qrcode,
"History": historyStr,
"HasHmac": strconv.FormatBool(hasHmac),
"S3Enabled": s3Enabled,
"MediaDelivery": mediaDelivery,
}}
v := Values{map[string]string{
"Events": events,
"HasHmac": strconv.FormatBool(hasHmac),
"History": historyStr,
"Id": txtid,
"Jid": jid,
"MediaDelivery": mediaDelivery,
"Name": name,
"Proxy": proxy_url,
"Qrcode": qrcode,
"S3Enabled": s3Enabled,
"Token": token,
"Webhook": webhook,
}}


userinfocache.Set(token, v, cache.NoExpiration)
Expand Down
1 change: 1 addition & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ func ProcessOutgoingMedia(userID string, contactJID string, messageID string, da

// Process S3 upload if enabled
if s3Config.Enabled && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(userID)
// Process S3 upload (outgoing messages are always in outbox)
s3Data, err := GetS3Manager().ProcessMediaForS3(
context.Background(),
Expand Down
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ func main() {
}
}()

// Set DB reference in S3Manager for lazy client initialization
GetS3Manager().SetDB(db)

// Initialize the schema
if err = initializeSchema(db); err != nil {
log.Fatal().Err(err).Msg("Failed to initialize schema")
Expand Down
60 changes: 59 additions & 1 deletion s3manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/jmoiron/sqlx"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
Expand All @@ -32,6 +33,7 @@ type S3Config struct {
// S3Manager manages S3 operations
type S3Manager struct {
mu sync.RWMutex
db *sqlx.DB
clients map[string]*s3.Client
configs map[string]*S3Config
}
Expand All @@ -47,6 +49,56 @@ func GetS3Manager() *S3Manager {
return s3Manager
}

// SetDB sets the database reference for lazy S3 client initialization
func (m *S3Manager) SetDB(db *sqlx.DB) {
m.mu.Lock()
defer m.mu.Unlock()
m.db = db
}

// EnsureClientFromDB loads S3 config from DB and initializes client if enabled. Returns true if client is available.
func (m *S3Manager) EnsureClientFromDB(userID string) bool {
if _, _, ok := m.GetClient(userID); ok {
return true
}
m.mu.RLock()
db := m.db
m.mu.RUnlock()
if db == nil {
return false
}
var s3DbConfig struct {
Enabled bool `db:"s3_enabled"`
Endpoint string `db:"s3_endpoint"`
Region string `db:"s3_region"`
Bucket string `db:"s3_bucket"`
AccessKey string `db:"s3_access_key"`
SecretKey string `db:"s3_secret_key"`
PathStyle bool `db:"s3_path_style"`
PublicURL string `db:"s3_public_url"`
MediaDelivery string `db:"media_delivery"`
RetentionDays int `db:"s3_retention_days"`
}
query := `SELECT s3_enabled, s3_endpoint, s3_region, s3_bucket, s3_access_key, s3_secret_key, s3_path_style, s3_public_url, COALESCE(media_delivery, 'base64') AS media_delivery, COALESCE(s3_retention_days, 30) AS s3_retention_days FROM users WHERE id = $1`
query = db.Rebind(query)
if err := db.Get(&s3DbConfig, query, userID); err != nil || !s3DbConfig.Enabled {
return false
}
config := &S3Config{
Enabled: s3DbConfig.Enabled,
Endpoint: s3DbConfig.Endpoint,
Region: s3DbConfig.Region,
Bucket: s3DbConfig.Bucket,
AccessKey: s3DbConfig.AccessKey,
SecretKey: s3DbConfig.SecretKey,
PathStyle: s3DbConfig.PathStyle,
PublicURL: s3DbConfig.PublicURL,
MediaDelivery: s3DbConfig.MediaDelivery,
RetentionDays: s3DbConfig.RetentionDays,
}
return m.InitializeS3Client(userID, config) == nil
}

// InitializeS3Client creates or updates S3 client for a user
func (m *S3Manager) InitializeS3Client(userID string, config *S3Config) error {
if !config.Enabled {
Expand Down Expand Up @@ -192,7 +244,13 @@ func (m *S3Manager) GenerateS3Key(userID, contactJID, messageID string, mimeType
func (m *S3Manager) UploadToS3(ctx context.Context, userID string, key string, data []byte, mimeType string) error {
client, config, ok := m.GetClient(userID)
if !ok {
return fmt.Errorf("S3 client not initialized for user %s", userID)
// Try lazy init from DB if available (handles reconnect-after-restart)
if m.EnsureClientFromDB(userID) {
client, config, ok = m.GetClient(userID)
}
if !ok {
return fmt.Errorf("S3 client not initialized for user %s", userID)
}
}

// Set content type and cache headers for preview
Expand Down
62 changes: 19 additions & 43 deletions wmiau.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ type MyClient struct {
s *server
}

// ensureS3ClientForUser loads S3 config from DB and initializes client if not already present (lazy init for reconnect-after-restart)
func ensureS3ClientForUser(userID string) {
GetS3Manager().EnsureClientFromDB(userID)
}

func sendToGlobalWebHook(jsonData []byte, token string, userID string) {
jsonDataStr := string(jsonData)

Expand Down Expand Up @@ -295,49 +300,7 @@ func (s *server) connectOnStartup() {

// Initialize S3 client if configured
go func(userID string) {
var s3Config struct {
Enabled bool `db:"s3_enabled"`
Endpoint string `db:"s3_endpoint"`
Region string `db:"s3_region"`
Bucket string `db:"s3_bucket"`
AccessKey string `db:"s3_access_key"`
SecretKey string `db:"s3_secret_key"`
PathStyle bool `db:"s3_path_style"`
PublicURL string `db:"s3_public_url"`
RetentionDays int `db:"s3_retention_days"`
}

err := s.db.Get(&s3Config, `
SELECT s3_enabled, s3_endpoint, s3_region, s3_bucket,
s3_access_key, s3_secret_key, s3_path_style,
s3_public_url, s3_retention_days
FROM users WHERE id = $1`, userID)

if err != nil {
log.Error().Err(err).Str("userID", userID).Msg("Failed to get S3 config")
return
}

if s3Config.Enabled {
config := &S3Config{
Enabled: s3Config.Enabled,
Endpoint: s3Config.Endpoint,
Region: s3Config.Region,
Bucket: s3Config.Bucket,
AccessKey: s3Config.AccessKey,
SecretKey: s3Config.SecretKey,
PathStyle: s3Config.PathStyle,
PublicURL: s3Config.PublicURL,
RetentionDays: s3Config.RetentionDays,
}

err = GetS3Manager().InitializeS3Client(userID, config)
if err != nil {
log.Error().Err(err).Str("userID", userID).Msg("Failed to initialize S3 client on startup")
} else {
log.Info().Str("userID", userID).Msg("S3 client initialized on startup")
}
}
GetS3Manager().EnsureClientFromDB(userID)
}(txtid)
}
}
Expand Down Expand Up @@ -461,6 +424,9 @@ func (s *server) startClient(userID string, textjid string, token string, subscr
}
clientManager.SetHTTPClient(userID, httpClient)

// Initialize S3 client if configured (needed when user reconnects after container restart - connectOnStartup only runs for connected=1)
GetS3Manager().EnsureClientFromDB(userID)

if client.Store.ID == nil {
// No ID stored, new login
qrChan, err := client.GetQRChannel(context.Background())
Expand Down Expand Up @@ -817,6 +783,11 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {
s3Config.MediaDelivery = myuserinfo.(Values).Get("MediaDelivery")
}

// Lazy init S3 client if needed (handles reconnect-after-restart when connectOnStartup skipped this user)
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
}

postmap["type"] = "Message"
dowebhook = 1
metaParts := []string{fmt.Sprintf("pushname: %s", evt.Info.PushName), fmt.Sprintf("timestamp: %s", evt.Info.Timestamp)}
Expand Down Expand Up @@ -867,6 +838,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -955,6 +927,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -1048,6 +1021,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -1130,6 +1104,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// Process S3 upload if enabled
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
// Get sender JID for inbox/outbox determination
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
Expand Down Expand Up @@ -1212,6 +1187,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {

// if using S3 (same stream as other media)
if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") {
ensureS3ClientForUser(txtid)
isIncoming := evt.Info.IsFromMe == false
contactJID := evt.Info.Sender.String()
if evt.Info.IsGroup {
Expand Down