Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 54 additions & 9 deletions backend/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,42 @@ func NewMysqlBackend(host string, port int, user, password, database string, opt
}

b := &mysqlBackend{
dsn: dsn,
db: db,
workerName: getWorkerName(options),
options: options,
dsn: dsn,
db: db,
workerName: getWorkerName(options),
options: options,
ownsConnection: true,
}

if options.ApplyMigrations {
if err := b.Migrate(); err != nil {
panic(err)
}
}

return b
}

// NewMysqlBackendWithDB creates a new MySQL backend using an existing database connection.
// When using this constructor, the backend will not close the database connection when Close() is called.
// Migrations are disabled by default; to enable them, use WithApplyMigrations(true) along with
// WithMigrationDSN to provide a DSN that supports multi-statement queries.
func NewMysqlBackendWithDB(db *sql.DB, opts ...option) *mysqlBackend {
options := &options{
Options: backend.ApplyOptions(),
ApplyMigrations: false,
}

for _, opt := range opts {
opt(options)
}

b := &mysqlBackend{
dsn: "",
db: db,
workerName: getWorkerName(options),
options: options,
ownsConnection: false,
}

if options.ApplyMigrations {
Expand All @@ -68,23 +100,36 @@ func NewMysqlBackend(host string, port int, user, password, database string, opt
}

type mysqlBackend struct {
dsn string
db *sql.DB
workerName string
options *options
dsn string
db *sql.DB
workerName string
options *options
ownsConnection bool
}

func (mb *mysqlBackend) FeatureSupported(feature backend.Feature) bool {
return true
}

func (mb *mysqlBackend) Close() error {
if !mb.ownsConnection {
return nil
}
return mb.db.Close()
}

// Migrate applies any pending database migrations.
func (mb *mysqlBackend) Migrate() error {
schemaDsn := mb.dsn + "&multiStatements=true"
// Determine which DSN to use for migrations
var schemaDsn string
if mb.options.MigrationDSN != "" {
schemaDsn = mb.options.MigrationDSN
} else if mb.dsn != "" {
schemaDsn = mb.dsn + "&multiStatements=true"
} else {
return errors.New("cannot apply migrations: no DSN available; use WithMigrationDSN option or apply migrations externally")
}

db, err := sql.Open("mysql", schemaDsn)
if err != nil {
return fmt.Errorf("opening schema database: %w", err)
Expand Down
123 changes: 123 additions & 0 deletions backend/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,126 @@ func Test_MysqlBackend_WorkerName(t *testing.T) {
}
})
}

func Test_MysqlBackendWithDB(t *testing.T) {
if testing.Short() {
t.Skip()
}

t.Run("UsesProvidedConnection", func(t *testing.T) {
// Create database for test
adminDB, err := sql.Open("mysql", fmt.Sprintf("%s:%s@/?parseTime=true&interpolateParams=true", testUser, testPassword))
if err != nil {
t.Fatal(err)
}

dbName := "test_withdb_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if _, err := adminDB.Exec("CREATE DATABASE " + dbName); err != nil {
t.Fatal(err)
}
defer func() {
adminDB.Exec("DROP DATABASE IF EXISTS " + dbName)
adminDB.Close()
}()

// Create our own connection to the test database
dsn := fmt.Sprintf("%s:%s@tcp(localhost:3306)/%s?parseTime=true&interpolateParams=true", testUser, testPassword, dbName)
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Create backend with existing connection and migration DSN
migrationDSN := dsn + "&multiStatements=true"
backend := NewMysqlBackendWithDB(db,
WithApplyMigrations(true),
WithMigrationDSN(migrationDSN),
)

// Verify the backend uses our connection
if backend.db != db {
t.Error("Backend should use provided db connection")
}
if backend.ownsConnection {
t.Error("Backend should not own the connection")
}

// Close backend - should NOT close our connection
if err := backend.Close(); err != nil {
t.Fatal(err)
}

// Verify our connection is still usable
if err := db.Ping(); err != nil {
t.Errorf("Connection should still be open after backend.Close(): %v", err)
}
})

t.Run("MigrationsDisabledByDefault", func(t *testing.T) {
// Create database for test
adminDB, err := sql.Open("mysql", fmt.Sprintf("%s:%s@/?parseTime=true&interpolateParams=true", testUser, testPassword))
if err != nil {
t.Fatal(err)
}

dbName := "test_withdb2_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if _, err := adminDB.Exec("CREATE DATABASE " + dbName); err != nil {
t.Fatal(err)
}
defer func() {
adminDB.Exec("DROP DATABASE IF EXISTS " + dbName)
adminDB.Close()
}()

dsn := fmt.Sprintf("%s:%s@tcp(localhost:3306)/%s?parseTime=true&interpolateParams=true", testUser, testPassword, dbName)
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Create backend without enabling migrations
backend := NewMysqlBackendWithDB(db)
defer backend.Close()

// Tables should not exist since migrations weren't applied
_, err = db.Exec("SELECT 1 FROM instances LIMIT 1")
if err == nil {
t.Error("Expected error because table should not exist")
}
})

t.Run("MigrationFailsWithoutDSN", func(t *testing.T) {
// Create database for test
adminDB, err := sql.Open("mysql", fmt.Sprintf("%s:%s@/?parseTime=true&interpolateParams=true", testUser, testPassword))
if err != nil {
t.Fatal(err)
}

dbName := "test_withdb3_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if _, err := adminDB.Exec("CREATE DATABASE " + dbName); err != nil {
t.Fatal(err)
}
defer func() {
adminDB.Exec("DROP DATABASE IF EXISTS " + dbName)
adminDB.Close()
}()

dsn := fmt.Sprintf("%s:%s@tcp(localhost:3306)/%s?parseTime=true&interpolateParams=true", testUser, testPassword, dbName)
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Create backend without migration DSN - should panic when trying to migrate
defer func() {
if r := recover(); r == nil {
t.Error("Expected panic when ApplyMigrations=true without MigrationDSN")
}
}()

NewMysqlBackendWithDB(db, WithApplyMigrations(true))
})
}
14 changes: 14 additions & 0 deletions backend/mysql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ type options struct {

// ApplyMigrations automatically applies database migrations on startup.
ApplyMigrations bool

// MigrationDSN is an optional DSN to use for running migrations. This is useful when
// using NewMysqlBackendWithDB where no DSN is available. The DSN must support
// multi-statement queries (e.g., include &multiStatements=true).
MigrationDSN string
}

type option func(*options)
Expand All @@ -30,6 +35,15 @@ func WithMySQLOptions(f func(db *sql.DB)) option {
}
}

// WithMigrationDSN sets the DSN to use for running migrations. This is required when
// using NewMysqlBackendWithDB with ApplyMigrations enabled. The DSN should support
// multi-statement queries.
func WithMigrationDSN(dsn string) option {
return func(o *options) {
o.MigrationDSN = dsn
}
}

// WithBackendOptions allows to pass generic backend options.
func WithBackendOptions(opts ...backend.BackendOption) option {
return func(o *options) {
Expand Down
75 changes: 61 additions & 14 deletions backend/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,42 @@ func NewPostgresBackend(host string, port int, user, password, database string,
}

b := &postgresBackend{
dsn: dsn,
db: db,
workerName: getWorkerName(options),
options: options,
dsn: dsn,
db: db,
workerName: getWorkerName(options),
options: options,
ownsConnection: true,
}

if options.ApplyMigrations {
if err := b.Migrate(); err != nil {
panic(err)
}
}

return b
}

// NewPostgresBackendWithDB creates a new Postgres backend using an existing database connection.
// When using this constructor, the backend will not close the database connection when Close() is called.
// Migrations can still be applied using WithApplyMigrations(true) as Postgres does not require
// special connection settings for migrations.
func NewPostgresBackendWithDB(db *sql.DB, opts ...option) *postgresBackend {
options := &options{
Options: backend.ApplyOptions(),
ApplyMigrations: false,
}

for _, opt := range opts {
opt(options)
}

b := &postgresBackend{
dsn: "",
db: db,
workerName: getWorkerName(options),
options: options,
ownsConnection: false,
}

if options.ApplyMigrations {
Expand All @@ -67,26 +99,39 @@ func NewPostgresBackend(host string, port int, user, password, database string,
}

type postgresBackend struct {
dsn string
db *sql.DB
workerName string
options *options
dsn string
db *sql.DB
workerName string
options *options
ownsConnection bool
}

func (pb *postgresBackend) FeatureSupported(feature backend.Feature) bool {
return true
}

func (pb *postgresBackend) Close() error {
if !pb.ownsConnection {
return nil
}
return pb.db.Close()
}

// Migrate applies any pending database migrations.
func (pb *postgresBackend) Migrate() error {
schemaDsn := pb.dsn
db, err := sql.Open("pgx", schemaDsn)
if err != nil {
return fmt.Errorf("opening schema database: %w", err)
var db *sql.DB
var needsClose bool

if pb.dsn != "" {
var err error
db, err = sql.Open("pgx", pb.dsn)
if err != nil {
return fmt.Errorf("opening schema database: %w", err)
}
needsClose = true
} else {
db = pb.db
needsClose = false
}

dbi, err := postgres.WithInstance(db, &postgres.Config{})
Expand All @@ -110,8 +155,10 @@ func (pb *postgresBackend) Migrate() error {
}
}

if err := db.Close(); err != nil {
return fmt.Errorf("closing schema database: %w", err)
if needsClose {
if err := db.Close(); err != nil {
return fmt.Errorf("closing schema database: %w", err)
}
}

return nil
Expand Down
Loading
Loading