-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.go
More file actions
419 lines (347 loc) · 11 KB
/
main.go
File metadata and controls
419 lines (347 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
package main
import (
	"bytes"
	"context"
	_ "embed"
	"errors"
	"flag"
	"fmt"
	"log"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"gopkg.in/yaml.v3"

	"github.com/ioj/sqlty/compiler"
	"github.com/ioj/sqlty/config"
	"github.com/ioj/sqlty/db"
	"github.com/ioj/sqlty/generator"
	"github.com/ioj/sqlty/stmt"
	"github.com/ioj/sqlty/watcher"
)
// Build identification printed by the -version flag. Both values are ordinary
// package-level string variables with placeholder defaults so that a release
// build can override them at link time (can be set via ldflags).
var (
version = "dev"
commit = "unknown"
)
// defaultTypes holds the raw bytes of db/types.yaml, embedded into the binary
// at build time. loadTypes decodes it as YAML when the configuration does not
// name a custom types file.
//go:embed db/types.yaml
var defaultTypes []byte
// Command-line flags, registered on the default flag set. Each variable is a
// pointer whose target holds the parsed value once main has called flag.Parse.
var (
showVersion = flag.Bool("version", false, "print version and exit")
showHelp = flag.Bool("help", false, "show help")
verbose = flag.Bool("verbose", false, "enable verbose output")
// The default name is also compared against in main to decide whether a
// missing config file means "show usage" or "report an error".
configFile = flag.String("config", "sqlty.yaml", "config file path")
// NOTE: compiledir uses this value as the deadline of the context for the
// whole compilation run (connection and type resolution), not only for
// establishing the connection.
timeout = flag.Duration("timeout", 30*time.Second, "database connection timeout")
watchMode = flag.Bool("watch", false, "watch for file changes and recompile automatically")
)
// compilationContext bundles what a compilation pass works with: the context
// handed to resolver calls, the loaded configuration, the database-backed type
// resolver and the code generator. Instances are built by
// newCompilationContext and must be released with Close.
type compilationContext struct {
	ctx      context.Context      // passed to the resolver when it is queried
	cfg      *config.Config       // package name and output paths are read from here
	resolver *db.Resolver         // resolves query types, enums and composite types
	gen      *generator.Generator // writes the generated Go files
	verbose  bool                 // enables per-file progress output in watch mode
}
// newCompilationContext loads the type mappings, opens the database resolver
// and initialises the generator, returning all of them bundled together.
//
// On failure no resource is left open: if the generator cannot be created, the
// resolver that was already opened is closed before the error is returned.
// When verbose is set, a progress line is printed before connecting.
func newCompilationContext(ctx context.Context, cfg *config.Config, verbose bool) (*compilationContext, error) {
	typeMappings, err := loadTypes(cfg)
	if err != nil {
		return nil, err
	}

	if verbose {
		fmt.Println("Connecting to database...")
	}

	cc := &compilationContext{ctx: ctx, cfg: cfg, verbose: verbose}

	if cc.resolver, err = db.NewResolver(ctx, cfg.DBURL, typeMappings); err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	if cc.gen, err = generator.New(cfg.Paths.Templates, cfg.Paths.Cache); err != nil {
		// The resolver is already open at this point; do not leak it.
		cc.resolver.Close()
		return nil, fmt.Errorf("failed to initialize generator: %w", err)
	}

	return cc, nil
}
// Close shuts down the generator first and the database resolver second,
// releasing everything newCompilationContext acquired.
func (cc *compilationContext) Close() {
	g, r := cc.gen, cc.resolver
	g.Close()
	r.Close()
}
// loadTypes returns the PostgreSQL type translations to use for a run.
//
// The base set is decoded from the file named by cfg.DefaultTypes when that is
// set, and from the embedded default YAML otherwise. The translations listed
// directly in the configuration (cfg.Types) are appended after the base set.
func loadTypes(cfg *config.Config) ([]db.PGTypeTranslation, error) {
	var translations db.PGTypeTranslationsFile

	if path := cfg.DefaultTypes; path != "" {
		file, err := os.Open(path)
		if err != nil {
			return nil, fmt.Errorf("failed to open custom types file %s: %w", path, err)
		}
		defer file.Close()

		if err := yaml.NewDecoder(file).Decode(&translations); err != nil {
			return nil, fmt.Errorf("failed to decode custom types file %s: %w", path, err)
		}
	} else if err := yaml.NewDecoder(bytes.NewReader(defaultTypes)).Decode(&translations); err != nil {
		return nil, fmt.Errorf("failed to decode default types: %w", err)
	}

	return append(translations.Types, cfg.Types...), nil
}
// getOutputFilename maps a SQL source path to the path of the Go file that is
// generated for it: the source's base name with its final extension replaced
// by ".gen.go", placed inside outputDir. Directory components of sourcePath
// are discarded.
func getOutputFilename(sourcePath, outputDir string) string {
	name := filepath.Base(sourcePath)
	ext := filepath.Ext(name)
	generated := strings.TrimSuffix(name, ext) + ".gen.go"
	return filepath.Join(outputDir, generated)
}
// compileSingleFile compiles one SQL source file, resolves its parameter and
// result types through the resolver, and writes the generated Go file into the
// configured output directory.
//
// An empty source file is not an error: a warning is printed and nil is
// returned. A query that has no return values but is not already in exec mode
// is switched to exec mode, with a warning. Every other failure is returned
// wrapped with the source file name.
func compileSingleFile(cc *compilationContext, fname string) error {
	q, err := compiler.CompileFile(fname)
	if err != nil {
		// errors.Is rather than ==, so the empty-file case is still recognised
		// when the sentinel arrives wrapped in additional context.
		if errors.Is(err, compiler.ErrEmptyFile) {
			fmt.Printf("warn: %v is empty, ignoring\n", fname)
			return nil
		}
		return fmt.Errorf("%s: %w", fname, err)
	}

	params, returns, err := cc.resolver.ResolveTypes(cc.ctx, q.PreparedQuery(), q.NotNullArray())
	if err != nil {
		return fmt.Errorf("%s: failed to resolve types: %w", fname, err)
	}

	stmtq, err := q.StmtQuery(cc.cfg.PackageName, params, returns)
	if err != nil {
		return fmt.Errorf("%s: %w", fname, err)
	}

	if (returns == nil || len(returns.Params) == 0) && stmtq.ExecMode != stmt.ExecModeExec {
		fmt.Printf("warn: %s has no return values, changing exec mode to @exec\n", stmtq.Name)
		stmtq.ExecMode = stmt.ExecModeExec
	}

	fnameOut := getOutputFilename(fname, cc.cfg.Paths.Output)
	if err := cc.gen.Query(q.Template(), fnameOut, stmtq); err != nil {
		return fmt.Errorf("%s: failed to generate query: %w", fname, err)
	}

	return nil
}
// regenerateSharedFiles writes the generated files that are not tied to a
// single query, in this order: enums, database utilities, composite types.
// It stops at the first failure and returns it wrapped with the step name.
func regenerateSharedFiles(cc *compilationContext) error {
	pkg := cc.cfg.PackageName
	outDir := cc.cfg.Paths.Output

	enumSet := &stmt.Enums{PackageName: pkg, Enums: cc.resolver.Enums()}
	if err := cc.gen.Enums(outDir, enumSet); err != nil {
		return fmt.Errorf("failed to generate enums: %w", err)
	}

	dbUtils := &stmt.DB{PackageName: pkg}
	if err := cc.gen.DB(outDir, dbUtils); err != nil {
		return fmt.Errorf("failed to generate db utilities: %w", err)
	}

	// Composite types are looked up through the resolver only after the
	// first two files have been written.
	resolved, err := cc.resolver.CompositeTypes(cc.ctx)
	if err != nil {
		return fmt.Errorf("failed to resolve composite types: %w", err)
	}

	composites := &stmt.CompositeTypes{PackageName: pkg, Types: resolved}
	if err := cc.gen.CompositeTypes(outDir, composites); err != nil {
		return fmt.Errorf("failed to generate composite types: %w", err)
	}

	return nil
}
// runGoimports formats the generated code by running "goimports -w ." with
// outputDir as the working directory. It is best-effort: when the tool is not
// on PATH or exits with an error, a warning is printed and nothing is
// returned to the caller.
func runGoimports(outputDir string) {
	_, lookErr := exec.LookPath("goimports")
	if lookErr != nil {
		fmt.Println("warn: goimports not found, generated code may need manual formatting")
		return
	}

	cmd := exec.Command("goimports", "-w", ".")
	cmd.Dir = outputDir

	out, runErr := cmd.CombinedOutput()
	if runErr == nil {
		return
	}

	fmt.Printf("warn: goimports failed: %v\n%s\n", runErr, string(out))
	fmt.Println("Generated code may need manual formatting")
}
// findSQLFiles lists the files matching "*.sql" directly inside sourceDir.
// Subdirectories are not searched. An error is returned only when the glob
// pattern built from sourceDir is rejected by filepath.Glob.
func findSQLFiles(sourceDir string) ([]string, error) {
	pattern := filepath.Join(sourceDir, "*.sql")

	matches, globErr := filepath.Glob(pattern)
	if globErr != nil {
		return nil, fmt.Errorf("failed to find SQL files: %w", globErr)
	}

	return matches, nil
}
// compiledir performs one full compilation of the configured source
// directory: it regenerates the shared files, compiles every *.sql file in
// turn, and finally runs goimports over the output directory.
//
// The timeout bounds the context used for the whole run. The function fails
// when the source directory contains no SQL files and stops at the first file
// that does not compile; in that case goimports is not run.
func compiledir(cfg *config.Config, verbose bool, timeout time.Duration) error {
	sources, err := findSQLFiles(cfg.Paths.Source)
	if err != nil {
		return err
	}

	total := len(sources)
	if total == 0 {
		return fmt.Errorf("no *.sql files to compile in %v", cfg.Paths.Source)
	}
	if verbose {
		fmt.Printf("Found %d SQL files to process\n", total)
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	cc, err := newCompilationContext(ctx, cfg, verbose)
	if err != nil {
		return err
	}
	defer cc.Close()

	if err := regenerateSharedFiles(cc); err != nil {
		return err
	}

	// Only the time spent inside compileSingleFile is accumulated.
	var elapsed time.Duration
	for idx, src := range sources {
		if verbose {
			fmt.Printf("[%d/%d] Processing %s\n", idx+1, total, filepath.Base(src))
		}

		started := time.Now()
		if err := compileSingleFile(cc, src); err != nil {
			return err
		}
		elapsed += time.Since(started)
	}

	runGoimports(cfg.Paths.Output)

	if verbose {
		fmt.Printf("\nTiming: total=%v\n", elapsed.Round(time.Millisecond))
		fmt.Printf("Generated %d query files\n", total)
	}

	return nil
}
// watch runs an initial full compilation and then recompiles changed SQL
// files until the process receives an interrupt or SIGTERM, or until one of
// the watcher's channels is closed.
//
// A failure of the initial compilation is reported but is not fatal. The
// timeout applies to that initial compilation only; the long-lived compilation
// context used afterwards is bound to a context that is cancelled on shutdown.
// An error is returned only when the watcher or the compilation context cannot
// be set up.
func watch(cfg *config.Config, verbose bool, timeout time.Duration) error {
	fmt.Println("Starting watch mode...")

	// Initial full compilation
	if err := compiledir(cfg, verbose, timeout); err != nil {
		fmt.Printf("Initial compilation error: %v\n", err)
		fmt.Println("Continuing to watch for changes...")
	} else {
		fmt.Println("Initial compilation complete.")
	}

	// Create watcher with 100ms debounce
	w, err := watcher.New(cfg.Paths.Source, 100*time.Millisecond)
	if err != nil {
		return fmt.Errorf("failed to create watcher: %w", err)
	}
	defer w.Close()

	// Setup context with signal handling
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Handle interrupt signals. The registration is undone and the goroutine
	// is released on every return path, not only after a signal arrived.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigCh)
	go func() {
		select {
		case <-sigCh:
			fmt.Println("\nShutting down...")
			cancel()
		case <-ctx.Done():
			// watch is returning for another reason; nothing to announce.
		}
	}()

	// Create persistent compilation context for watch mode
	cc, err := newCompilationContext(ctx, cfg, verbose)
	if err != nil {
		return fmt.Errorf("failed to setup compilation context: %w", err)
	}
	defer cc.Close()

	// Named errCh rather than errors so the local does not shadow the
	// standard library package of that name.
	events, errCh := w.Start(ctx)

	fmt.Printf("Watching for changes in %s...\n", cfg.Paths.Source)

	for {
		select {
		case <-ctx.Done():
			return nil
		case event, ok := <-events:
			if !ok {
				return nil
			}
			handleFileChanges(cc, event.Files)
		case err, ok := <-errCh:
			if !ok {
				return nil
			}
			fmt.Printf("Watch error: %v\n", err)
		}
	}
}
// handleFileChanges processes one batch of changed source files and then
// refreshes the shared generated files.
//
// For each path: if the source no longer exists, its generated file is removed
// (a generated file that is already missing is not reported); otherwise the
// source is recompiled. Failures are printed rather than returned, so one bad
// file does not end the caller's loop. Shared files are regenerated and
// goimports is run after every batch, even when some files failed.
func handleFileChanges(cc *compilationContext, files []string) {
	// The banner ends with a newline so that the warnings, errors and verbose
	// progress lines printed below each start on a line of their own.
	fmt.Printf("Recompiling %d file(s)...\n", len(files))

	hasError := false
	for _, fname := range files {
		// Check if file still exists (might have been deleted)
		if _, err := os.Stat(fname); os.IsNotExist(err) {
			// File was deleted, remove the generated file
			fnameOut := getOutputFilename(fname, cc.cfg.Paths.Output)
			if err := os.Remove(fnameOut); err != nil && !os.IsNotExist(err) {
				fmt.Printf("Error removing %s: %v\n", fnameOut, err)
			} else if cc.verbose {
				fmt.Printf("Removed %s\n", fnameOut)
			}
			continue
		}

		if cc.verbose {
			fmt.Printf("Processing %s\n", filepath.Base(fname))
		}

		if err := compileSingleFile(cc, fname); err != nil {
			fmt.Printf("Error: %v\n", err)
			hasError = true
		}
	}

	// Always regenerate shared files
	if err := regenerateSharedFiles(cc); err != nil {
		fmt.Printf("Error generating shared files: %v\n", err)
		hasError = true
	}

	runGoimports(cc.cfg.Paths.Output)

	if !hasError {
		fmt.Println("Recompilation complete.")
	} else {
		fmt.Println("Recompilation completed with errors.")
	}
}
// main parses the command-line flags, loads the configuration file and then
// either compiles the source directory once or enters watch mode.
//
// -help and -version print their output and exit with status 0. When the
// default config file is missing, the usage text is shown and the process
// also exits with status 0; a missing explicitly named config file, a
// configuration error or a compilation error terminates via log.Fatalf.
func main() {
	flag.Usage = func() {
		fmt.Fprint(os.Stderr, "SQLty - SQL code generator for Go with PostgreSQL\n\n")
		fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0])
		fmt.Fprint(os.Stderr, "Options:\n")
		flag.PrintDefaults()

		// Sample configuration, written line by line after the flag list.
		for _, line := range []string{
			"\nConfiguration:\n",
			" Create a sqlty.yaml file with:\n",
			" dburl: postgres://user:pass@localhost:5432/mydb\n",
			" paths:\n",
			" source: ./queries\n",
			" output: ./db\n",
			" package: db\n",
		} {
			fmt.Fprint(os.Stderr, line)
		}
	}

	flag.Parse()

	switch {
	case *showHelp:
		flag.Usage()
		os.Exit(0)
	case *showVersion:
		fmt.Printf("sqlty version %s (commit %s)\n", version, commit)
		os.Exit(0)
	}

	// A missing config file is treated differently depending on whether the
	// path is the built-in default or one the user asked for.
	if _, statErr := os.Stat(*configFile); os.IsNotExist(statErr) {
		if *configFile != "sqlty.yaml" {
			log.Fatalf("Configuration file not found: %s", *configFile)
		}
		fmt.Fprint(os.Stderr, "No configuration file found.\n\n")
		flag.Usage()
		os.Exit(0)
	}

	cfg, err := config.LoadFrom(*configFile)
	if err != nil {
		log.Fatalf("Configuration error: %v", err)
	}

	if *watchMode {
		if err := watch(cfg, *verbose, *timeout); err != nil {
			log.Fatalf("Error: %v", err)
		}
		return
	}

	if err := compiledir(cfg, *verbose, *timeout); err != nil {
		log.Fatalf("Error: %v", err)
	}
	if *verbose {
		fmt.Println("Done!")
	}
}