diff --git a/handlers.go b/handlers.go index 285a7a2f..2be05072 100644 --- a/handlers.go +++ b/handlers.go @@ -154,15 +154,16 @@ func (s *server) authalice(next http.Handler) http.Handler { if !found { log.Info().Msg("Looking for user information in DB") // Checks DB from matching user and store user values in context - rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0 FROM users WHERE token=$1 LIMIT 1", token) + rows, err := s.db.Query("SELECT id,name,webhook,jid,events,proxy_url,qrcode,history,hmac_key IS NOT NULL AND length(hmac_key) > 0,CASE WHEN s3_enabled THEN 'true' ELSE 'false' END,COALESCE(media_delivery, 'base64') FROM users WHERE token=$1 LIMIT 1", token) if err != nil { s.Respond(w, r, http.StatusInternalServerError, err) return } defer rows.Close() var history sql.NullInt64 + var s3Enabled, mediaDelivery string for rows.Next() { - err = rows.Scan(&txtid, &name, &webhook, &jid, &events, &proxy_url, &qrcode, &history, &hasHmac) + err = rows.Scan(&txtid, &name, &webhook, &jid, &events, &proxy_url, &qrcode, &history, &hasHmac, &s3Enabled, &mediaDelivery) if err != nil { s.Respond(w, r, http.StatusInternalServerError, err) return @@ -176,16 +177,18 @@ func (s *server) authalice(next http.Handler) http.Handler { log.Debug().Str("userId", txtid).Bool("historyValid", history.Valid).Int64("historyValue", history.Int64).Str("historyStr", historyStr).Msg("User authentication - history debug") v := Values{map[string]string{ - "Id": txtid, - "Name": name, - "Jid": jid, - "Webhook": webhook, - "Token": token, - "Proxy": proxy_url, - "Events": events, - "Qrcode": qrcode, - "History": historyStr, - "HasHmac": strconv.FormatBool(hasHmac), + "Id": txtid, + "Name": name, + "Jid": jid, + "Webhook": webhook, + "Token": token, + "Proxy": proxy_url, + "Events": events, + "Qrcode": qrcode, + "History": historyStr, + "HasHmac": strconv.FormatBool(hasHmac), + "S3Enabled": s3Enabled, + "MediaDelivery": mediaDelivery, }} userinfocache.Set(token, v, cache.NoExpiration) diff --git a/helpers.go b/helpers.go index c1d8fabc..83c4c628 100644 --- a/helpers.go +++ b/helpers.go @@ -538,6 +538,7 @@ func ProcessOutgoingMedia(userID string, contactJID string, messageID string, da // Process S3 upload if enabled if s3Config.Enabled && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(userID) // Process S3 upload (outgoing messages are always in outbox) s3Data, err := GetS3Manager().ProcessMediaForS3( context.Background(), diff --git a/main.go b/main.go index f70e7ff3..b0502240 100755 --- a/main.go +++ b/main.go @@ -377,6 +377,9 @@ func main() { } }() + // Set DB reference in S3Manager for lazy client initialization + GetS3Manager().SetDB(db) + // Initialize the schema if err = initializeSchema(db); err != nil { log.Fatal().Err(err).Msg("Failed to initialize schema") diff --git a/s3manager.go b/s3manager.go index ba5bd502..77c71e9a 100644 --- a/s3manager.go +++ b/s3manager.go @@ -9,6 +9,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/jmoiron/sqlx" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" @@ -32,6 +33,7 @@ type S3Config struct { // S3Manager manages S3 operations type S3Manager struct { mu sync.RWMutex + db *sqlx.DB clients map[string]*s3.Client configs map[string]*S3Config } @@ -47,6 +49,56 @@ func GetS3Manager() *S3Manager { return s3Manager } +// SetDB sets the database reference for lazy S3 client initialization +func (m *S3Manager) SetDB(db *sqlx.DB) { + m.mu.Lock() + defer m.mu.Unlock() + m.db = db +} + +// EnsureClientFromDB loads S3 config from DB and initializes client if enabled. Returns true if client is available. +func (m *S3Manager) EnsureClientFromDB(userID string) bool { + if _, _, ok := m.GetClient(userID); ok { + return true + } + m.mu.RLock() + db := m.db + m.mu.RUnlock() + if db == nil { + return false + } + var s3DbConfig struct { + Enabled bool `db:"s3_enabled"` + Endpoint string `db:"s3_endpoint"` + Region string `db:"s3_region"` + Bucket string `db:"s3_bucket"` + AccessKey string `db:"s3_access_key"` + SecretKey string `db:"s3_secret_key"` + PathStyle bool `db:"s3_path_style"` + PublicURL string `db:"s3_public_url"` + MediaDelivery string `db:"media_delivery"` + RetentionDays int `db:"s3_retention_days"` + } + query := `SELECT s3_enabled, s3_endpoint, s3_region, s3_bucket, s3_access_key, s3_secret_key, s3_path_style, s3_public_url, COALESCE(media_delivery, 'base64') AS media_delivery, COALESCE(s3_retention_days, 30) AS s3_retention_days FROM users WHERE id = $1` + query = db.Rebind(query) + if err := db.Get(&s3DbConfig, query, userID); err != nil || !s3DbConfig.Enabled { + return false + } + config := &S3Config{ + Enabled: s3DbConfig.Enabled, + Endpoint: s3DbConfig.Endpoint, + Region: s3DbConfig.Region, + Bucket: s3DbConfig.Bucket, + AccessKey: s3DbConfig.AccessKey, + SecretKey: s3DbConfig.SecretKey, + PathStyle: s3DbConfig.PathStyle, + PublicURL: s3DbConfig.PublicURL, + MediaDelivery: s3DbConfig.MediaDelivery, + RetentionDays: s3DbConfig.RetentionDays, + } + return m.InitializeS3Client(userID, config) == nil +} + // InitializeS3Client creates or updates S3 client for a user func (m *S3Manager) InitializeS3Client(userID string, config *S3Config) error { if !config.Enabled { @@ -192,7 +244,13 @@ func (m *S3Manager) GenerateS3Key(userID, contactJID, messageID string, mimeType func (m *S3Manager) UploadToS3(ctx context.Context, userID string, key string, data []byte, mimeType string) error { client, config, ok := m.GetClient(userID) if !ok { - return fmt.Errorf("S3 client not initialized for user %s", userID) + // Try lazy init from DB if available (handles reconnect-after-restart) + if m.EnsureClientFromDB(userID) { + client, config, ok = m.GetClient(userID) + } + if !ok { + return fmt.Errorf("S3 client not initialized for user %s", userID) + } } // Set content type and cache headers for preview diff --git a/wmiau.go b/wmiau.go index f34012ac..517ea280 100644 --- a/wmiau.go +++ b/wmiau.go @@ -43,6 +43,11 @@ type MyClient struct { s *server } +// ensureS3ClientForUser loads S3 config from DB and initializes client if not already present (lazy init for reconnect-after-restart) +func ensureS3ClientForUser(userID string) { + GetS3Manager().EnsureClientFromDB(userID) +} + func sendToGlobalWebHook(jsonData []byte, token string, userID string) { jsonDataStr := string(jsonData) @@ -295,49 +300,7 @@ func (s *server) connectOnStartup() { // Initialize S3 client if configured go func(userID string) { - var s3Config struct { - Enabled bool `db:"s3_enabled"` - Endpoint string `db:"s3_endpoint"` - Region string `db:"s3_region"` - Bucket string `db:"s3_bucket"` - AccessKey string `db:"s3_access_key"` - SecretKey string `db:"s3_secret_key"` - PathStyle bool `db:"s3_path_style"` - PublicURL string `db:"s3_public_url"` - RetentionDays int `db:"s3_retention_days"` - } - - err := s.db.Get(&s3Config, ` - SELECT s3_enabled, s3_endpoint, s3_region, s3_bucket, - s3_access_key, s3_secret_key, s3_path_style, - s3_public_url, s3_retention_days - FROM users WHERE id = $1`, userID) - - if err != nil { - log.Error().Err(err).Str("userID", userID).Msg("Failed to get S3 config") - return - } - - if s3Config.Enabled { - config := &S3Config{ - Enabled: s3Config.Enabled, - Endpoint: s3Config.Endpoint, - Region: s3Config.Region, - Bucket: s3Config.Bucket, - AccessKey: s3Config.AccessKey, - SecretKey: s3Config.SecretKey, - PathStyle: s3Config.PathStyle, - PublicURL: s3Config.PublicURL, - RetentionDays: s3Config.RetentionDays, - } - - err = GetS3Manager().InitializeS3Client(userID, config) - if err != nil { - log.Error().Err(err).Str("userID", userID).Msg("Failed to initialize S3 client on startup") - } else { - log.Info().Str("userID", userID).Msg("S3 client initialized on startup") - } - } + GetS3Manager().EnsureClientFromDB(userID) }(txtid) } } @@ -461,6 +424,9 @@ func (s *server) startClient(userID string, textjid string, token string, subscr } clientManager.SetHTTPClient(userID, httpClient) + // Initialize S3 client if configured (needed when user reconnects after container restart - connectOnStartup only runs for connected=1) + GetS3Manager().EnsureClientFromDB(userID) + if client.Store.ID == nil { // No ID stored, new login qrChan, err := client.GetQRChannel(context.Background()) @@ -817,6 +783,11 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) { s3Config.MediaDelivery = myuserinfo.(Values).Get("MediaDelivery") } + // Lazy init S3 client if needed (handles reconnect-after-restart when connectOnStartup skipped this user) + if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(txtid) + } + postmap["type"] = "Message" dowebhook = 1 metaParts := []string{fmt.Sprintf("pushname: %s", evt.Info.PushName), fmt.Sprintf("timestamp: %s", evt.Info.Timestamp)} @@ -867,6 +838,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) { // Process S3 upload if enabled if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(txtid) // Get sender JID for inbox/outbox determination isIncoming := evt.Info.IsFromMe == false contactJID := evt.Info.Sender.String() @@ -955,6 +927,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) { // Process S3 upload if enabled if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(txtid) // Get sender JID for inbox/outbox determination isIncoming := evt.Info.IsFromMe == false contactJID := evt.Info.Sender.String() @@ -1048,6 +1021,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) { // Process S3 upload if enabled if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(txtid) // Get sender JID for inbox/outbox determination isIncoming := evt.Info.IsFromMe == false contactJID := evt.Info.Sender.String() @@ -1130,6 +1104,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) { // Process S3 upload if enabled if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(txtid) // Get sender JID for inbox/outbox determination isIncoming := evt.Info.IsFromMe == false contactJID := evt.Info.Sender.String() @@ -1212,6 +1187,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) { // if using S3 (same stream as other media) if s3Config.Enabled == "true" && (s3Config.MediaDelivery == "s3" || s3Config.MediaDelivery == "both") { + ensureS3ClientForUser(txtid) isIncoming := evt.Info.IsFromMe == false contactJID := evt.Info.Sender.String() if evt.Info.IsGroup {