diff --git a/internal/mycli/cli_output.go b/internal/mycli/cli_output.go index 9cd3ee70..efb2f853 100644 --- a/internal/mycli/cli_output.go +++ b/internal/mycli/cli_output.go @@ -1,27 +1,21 @@ package mycli import ( - "cmp" _ "embed" "fmt" "io" - "iter" "log/slog" "math" - "slices" "strings" "text/template" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" - "github.com/apstndb/go-runewidthex" "github.com/apstndb/lox" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" "github.com/apstndb/spanner-mycli/internal/mycli/metrics" "github.com/go-sprout/sprout" "github.com/go-sprout/sprout/group/hermetic" - "github.com/mattn/go-runewidth" - "github.com/ngicks/go-iterator-helper/hiter" - "github.com/ngicks/go-iterator-helper/hiter/stringsiter" "github.com/samber/lo" ) @@ -31,7 +25,7 @@ func renderTableHeader(header TableHeader, verbose bool) []string { return nil } - return header.internalRender(verbose) + return header.Render(verbose) } // extractTableColumnNames extracts pure column names from the table header without type information. @@ -93,27 +87,41 @@ func printTableData(sysVars *systemVariables, screenWidth int, out io.Writer, re displayFormat = enums.DisplayModeTable // Fall back to table format } + // Build FormatConfig from systemVariables + config := sysVars.toFormatConfig() + + // For SQL export, resolve the table name from Result if available + if displayFormat.IsSQLExport() && result.SQLTableNameForExport != "" { + config.SQLTableName = result.SQLTableNameForExport + } + // Create the appropriate formatter based on the display mode - formatter, err := NewFormatter(displayFormat) + formatter, err := format.NewFormatter(displayFormat) if err != nil { return fmt.Errorf("failed to create formatter: %w", err) } + // For table mode, pass verbose headers and column align via WriteTableWithParams + if displayFormat == enums.DisplayModeUnspecified || displayFormat == enums.DisplayModeTable || displayFormat == enums.DisplayModeTableComment || displayFormat == enums.DisplayModeTableDetailComment { + verboseHeaders := renderTableHeader(result.TableHeader, true) + tableMode := displayFormat + if tableMode == enums.DisplayModeUnspecified { + tableMode = enums.DisplayModeTable + } + return format.WriteTableWithParams(out, result.Rows, columnNames, config, screenWidth, tableMode, format.TableParams{ + VerboseHeaders: verboseHeaders, + ColumnAlign: result.ColumnAlign, + }) + } + // Format and write the result - // Individual formatters handle empty columns appropriately for their format - if err := formatter(out, result, columnNames, sysVars, screenWidth); err != nil { + if err := formatter(out, result.Rows, columnNames, config, screenWidth); err != nil { return fmt.Errorf("formatting failed for mode %v: %w", sysVars.CLIFormat, err) } return nil } -func calculateWidth(result *Result, wc *widthCalculator, screenWidth int, rows []Row) []int { - names := extractTableColumnNames(result.TableHeader) - header := renderTableHeader(result.TableHeader, true) - return calculateOptimalWidth(wc, screenWidth, names, slices.Concat(sliceOf(toRow(header...)), rows)) -} - func printResult(sysVars *systemVariables, screenWidth int, out io.Writer, result *Result, interactive bool, input string) error { if sysVars.MarkdownCodeblock { fmt.Fprintln(out, "```sql") @@ -271,215 +279,6 @@ func resultLine(outputTemplate *template.Template, result *Result, verbose bool) return fmt.Sprintf("Query OK%s%s%s\n%s", affectedRowsPart, elapsedTimePart, batchInfo, detail) } -func calculateOptimalWidth(wc *widthCalculator, screenWidth int, header []string, rows []Row) []int { - // table overhead is: - // len(`| |`) + - // len(` | `) * len(columns) - 1 - overheadWidth := 4 + 3*(len(header)-1) - - // don't mutate - termWidthWithoutOverhead := screenWidth - overheadWidth - - slog.Debug("screen width info", "screenWidth", screenWidth, "remainsWidth", termWidthWithoutOverhead) - - formatIntermediate := func(remainsWidth int, adjustedWidths []int) string { - return fmt.Sprintf("remaining %v, adjustedWidths: %v", remainsWidth-lo.Sum(adjustedWidths), adjustedWidths) - } - - adjustedWidths := adjustByHeader(header, termWidthWithoutOverhead) - - slog.Debug("adjustByName", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) - - var transposedRows [][]string - for columnIdx := range len(header) { - transposedRows = append(transposedRows, slices.Collect( - hiter.Map( - func(in Row) string { - return lo.Must(lo.Nth(in, columnIdx)) // columnIdx represents the index of the column in the row - }, - hiter.Concat(hiter.Once(toRow(header...)), slices.Values(rows)), - ))) - } - - widthCounts := wc.calculateWidthCounts(adjustedWidths, transposedRows) - for { - slog.Debug("widthCounts", "counts", widthCounts) - - firstCounts := hiter.Map( - func(wcs []WidthCount) WidthCount { - return lo.FirstOr(wcs, invalidWidthCount) - }, - slices.Values(widthCounts)) - - // find the largest count idx within available width - idx, target := wc.maxIndex(termWidthWithoutOverhead-lo.Sum(adjustedWidths), adjustedWidths, firstCounts) - if idx < 0 || target.Count() < 1 { - break - } - - widthCounts[idx] = widthCounts[idx][1:] - adjustedWidths[idx] = target.Length() - - slog.Debug("adjusting", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) - } - - slog.Debug("semi final", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) - - // Add rest to the longest shortage column. - longestWidths := lo.Map(widthCounts, func(item []WidthCount, _ int) int { - return hiter.Max(hiter.Map(WidthCount.Length, slices.Values(item))) - }) - - idx, _ := MaxWithIdx(math.MinInt, hiter.Unify( - func(longestWidth, adjustedWidth int) int { - return longestWidth - adjustedWidth - }, - hiter.Pairs(slices.Values(longestWidths), slices.Values(adjustedWidths)))) - - if idx != -1 { - adjustedWidths[idx] += termWidthWithoutOverhead - lo.Sum(adjustedWidths) - } - - slog.Debug("final", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) - - return adjustedWidths -} - -func MaxWithIdx[E cmp.Ordered](fallback E, seq iter.Seq[E]) (int, E) { - return MaxByWithIdx(fallback, lox.Identity, seq) -} - -func MaxByWithIdx[O cmp.Ordered, E any](fallback E, f func(E) O, seq iter.Seq[E]) (int, E) { - val := fallback - idx := -1 - current := -1 - for v := range seq { - current++ - if f(val) < f(v) { - val = v - idx = current - } - } - return idx, val -} - -func (wc *widthCalculator) StringWidth(s string) int { - return wc.Condition.StringWidth(s) -} - -func (wc *widthCalculator) maxWidth(s string) int { - return hiter.Max(hiter.Map( - wc.StringWidth, - stringsiter.SplitFunc(s, 0, stringsiter.CutNewLine))) -} - -func clipToMax[S interface{ ~[]E }, E cmp.Ordered](s S, maxValue E) iter.Seq[E] { - return hiter.Map( - func(in E) E { - return min(in, maxValue) - }, - slices.Values(s), - ) -} - -func asc[T cmp.Ordered](left, right T) int { - switch { - case left < right: - return -1 - case left > right: - return 1 - default: - return 0 - } -} - -func desc[T cmp.Ordered](left, right T) int { - return asc(right, left) -} - -func adjustToSum(limit int, vs []int) ([]int, int) { - sumVs := lo.Sum(vs) - remains := limit - sumVs - if remains >= 0 { - return vs, remains - } - - curVs := vs - for i := 1; ; i++ { - rev := slices.SortedFunc(slices.Values(lo.Uniq(vs)), desc) - v, ok := hiter.Nth(i, slices.Values(rev)) - if !ok { - break - } - curVs = slices.Collect(clipToMax(vs, v)) - if lo.Sum(curVs) <= limit { - break - } - } - return curVs, limit - lo.Sum(curVs) -} - -var invalidWidthCount = WidthCount{ - // impossible to fit any width - width: math.MaxInt, - // least significant - count: math.MinInt, -} - -type widthCalculator struct{ Condition *runewidthex.Condition } - -func (wc *widthCalculator) maxIndex(ignoreMax int, adjustWidths []int, seq iter.Seq[WidthCount]) (int, WidthCount) { - return MaxByWithIdx( - invalidWidthCount, - WidthCount.Count, - hiter.Unify( - func(adjustWidth int, wc WidthCount) WidthCount { - return lo.Ternary(wc.Length()-adjustWidth <= ignoreMax, wc, invalidWidthCount) - }, - hiter.Pairs(slices.Values(adjustWidths), seq))) -} - -func (wc *widthCalculator) countWidth(ss []string) iter.Seq[WidthCount] { - return hiter.Map( - func(e lo.Entry[int, int]) WidthCount { - return WidthCount{ - width: e.Key, - count: e.Value, - } - }, - slices.Values(lox.EntriesSortedByKey(lo.CountValuesBy(ss, wc.maxWidth)))) -} - -func (wc *widthCalculator) calculateWidthCounts(currentWidths []int, rows [][]string) [][]WidthCount { - var result [][]WidthCount - for columnNo := range len(currentWidths) { - currentWidth := currentWidths[columnNo] - columnValues := rows[columnNo] - largerWidthCounts := slices.Collect( - hiter.Filter( - func(v WidthCount) bool { - return v.Length() > currentWidth - }, - wc.countWidth(columnValues), - )) - result = append(result, largerWidthCounts) - } - return result -} - -type WidthCount struct{ width, count int } - -func (wc WidthCount) Length() int { return wc.width } -func (wc WidthCount) Count() int { return wc.count } - -func adjustByHeader(headers []string, availableWidth int) []int { - nameWidths := slices.Collect(hiter.Map(runewidth.StringWidth, slices.Values(headers))) - - adjustWidths, _ := adjustToSum(availableWidth, nameWidths) - - return adjustWidths -} - func formatTypedHeaderColumn(field *sppb.StructType_Field) string { return field.GetName() + "\n" + formatTypeSimple(field.GetType()) } diff --git a/internal/mycli/cli_output_test.go b/internal/mycli/cli_output_test.go index c41cd710..deb08266 100644 --- a/internal/mycli/cli_output_test.go +++ b/internal/mycli/cli_output_test.go @@ -2,13 +2,13 @@ package mycli import ( "bytes" - "io" "regexp" "strings" "testing" "github.com/MakeNowJust/heredoc/v2" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" ) // Helper functions for common test operations @@ -379,13 +379,12 @@ func TestFormatHelpers(t *testing.T) { t.Parallel() // Helper to test empty input handling for different formatters - testEmptyFormatter := func(t *testing.T, name string, formatter func(io.Writer, *Result, []string, *systemVariables, int) error) { + testEmptyFormatter := func(t *testing.T, name string, formatter format.FormatFunc) { t.Helper() t.Run(name+" with empty input", func(t *testing.T) { var buf bytes.Buffer - result := &Result{Rows: []Row{}} - sysVars := &systemVariables{SkipColumnNames: false} - err := formatter(&buf, result, []string{}, sysVars, 0) + config := format.FormatConfig{SkipColumnNames: false} + err := formatter(&buf, []Row{}, []string{}, config, 0) if err != nil { t.Errorf("expected nil for empty columns, got error: %v", err) } @@ -395,9 +394,12 @@ func TestFormatHelpers(t *testing.T) { }) } - testEmptyFormatter(t, "formatHTML", formatHTML) - testEmptyFormatter(t, "formatXML", formatXML) - testEmptyFormatter(t, "formatCSV", formatCSV) + htmlFormatter, _ := format.NewFormatter(enums.DisplayModeHTML) + xmlFormatter, _ := format.NewFormatter(enums.DisplayModeXML) + csvFormatter, _ := format.NewFormatter(enums.DisplayModeCSV) + testEmptyFormatter(t, "formatHTML", htmlFormatter) + testEmptyFormatter(t, "formatXML", xmlFormatter) + testEmptyFormatter(t, "formatCSV", csvFormatter) t.Run("XML with large dataset", func(t *testing.T) { // Test with a larger dataset to ensure performance @@ -412,9 +414,8 @@ func TestFormatHelpers(t *testing.T) { } var buf bytes.Buffer - result := &Result{Rows: rows} - sysVars := &systemVariables{SkipColumnNames: false} - err := formatXML(&buf, result, columns, sysVars, 0) + config := format.FormatConfig{SkipColumnNames: false} + err := xmlFormatter(&buf, rows, columns, config, 0) if err != nil { t.Errorf("unexpected error: %v", err) } diff --git a/internal/mycli/execute_sql.go b/internal/mycli/execute_sql.go index 397ee7c5..46b893a6 100644 --- a/internal/mycli/execute_sql.go +++ b/internal/mycli/execute_sql.go @@ -16,6 +16,7 @@ import ( "github.com/apstndb/gsqlutils" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" "github.com/apstndb/spanner-mycli/internal/mycli/metrics" "github.com/apstndb/spanvalue" "github.com/ngicks/go-iterator-helper/hiter" @@ -84,7 +85,7 @@ func prepareFormatConfig(sql string, sysVars *systemVariables) (*spanvalue.Forma // Auto-detect table name if not explicitly set if sysVars.SQLTableName == "" { - detectedTableName, detectionErr := extractTableNameFromQuery(sql) + detectedTableName, detectionErr := format.ExtractTableNameFromQuery(sql) if detectedTableName != "" { // Create a copy of sysVars to use the detected table name for this execution only. // This is important for: diff --git a/internal/mycli/format/config.go b/internal/mycli/format/config.go new file mode 100644 index 00000000..c4116c43 --- /dev/null +++ b/internal/mycli/format/config.go @@ -0,0 +1,38 @@ +package format + +import "io" + +// Row is a type alias for a row of string values. +// Using a type alias (not a new type) ensures zero breaking change at call sites. +type Row = []string + +// FormatConfig holds configuration values needed by formatters. +// This replaces the dependency on *systemVariables, exposing only the fields +// that formatters actually use. +type FormatConfig struct { + TabWidth int + Verbose bool + SkipColumnNames bool + SQLTableName string + SQLBatchSize int64 + PreviewRows int64 +} + +// FormatFunc is a function type that formats and writes result data. +// It takes an output writer, rows, column names, config, and screen width. +type FormatFunc func(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error + +// StreamingFormatter defines the interface for format-specific streaming output. +// Each format (CSV, TAB, etc.) implements this interface to handle streaming output. +type StreamingFormatter interface { + // InitFormat is called once with column names and configuration. + // For table formats, previewRows contains the first N rows for width calculation. + // For other formats, previewRows may be empty as they don't need preview. + InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error + + // WriteRow outputs a single row. + WriteRow(row Row) error + + // FinishFormat completes the output (e.g., closing tags, final flush). + FinishFormat() error +} diff --git a/internal/mycli/format/format.go b/internal/mycli/format/format.go new file mode 100644 index 00000000..26662e19 --- /dev/null +++ b/internal/mycli/format/format.go @@ -0,0 +1,260 @@ +package format + +// This file contains output formatters for query results. +// It implements various output formats (TABLE, CSV, HTML, XML, etc.) with proper error handling. +// All formatters follow a consistent pattern where errors are propagated instead of logged and ignored. + +import ( + "cmp" + "fmt" + "io" + "regexp" + "slices" + "strings" + + "github.com/apstndb/go-runewidthex" + "github.com/apstndb/spanner-mycli/enums" + "github.com/ngicks/go-iterator-helper/hiter" + "github.com/olekukonko/tablewriter" + "github.com/olekukonko/tablewriter/renderer" + "github.com/olekukonko/tablewriter/tw" +) + +var ( + topLeftRe = regexp.MustCompile(`^\+`) + bottomRightRe = regexp.MustCompile(`\+$`) +) + +// writeBuffered writes to a temporary buffer first, and only writes to out if no error occurs. +// This is useful for formats that need to build the entire output before writing. +func writeBuffered(out io.Writer, buildFunc func(out io.Writer) error) error { + var buf strings.Builder + err := buildFunc(&buf) + if err != nil { + return err + } + + output := buf.String() + if output != "" { + _, err = fmt.Fprint(out, output) + return err + } + return nil +} + +// formatTable formats output as an ASCII table. +// verboseNames provides the verbose header names (with type info) when Verbose is true. +// columnAlign provides per-column alignment for special statements like EXPLAIN. +func formatTable(mode enums.DisplayMode) FormatFunc { + return func(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + return writeBuffered(out, func(out io.Writer) error { + return WriteTable(out, rows, columnNames, config, screenWidth, mode) + }) + } +} + +// TableParams holds additional parameters for table formatting that are not +// part of the standard FormatConfig (used only by table format). +type TableParams struct { + // VerboseHeaders contains header strings rendered with type information. + // These may include newlines (e.g., "Name\nSTRING") and are used for display + // when Verbose mode is enabled. + VerboseHeaders []string + ColumnAlign []tw.Align +} + +// WriteTable writes the table to the provided writer. +// verboseNames and columnAlign are passed separately because they are specific to table formatting +// and not available in the generic FormatConfig. +func WriteTable(w io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int, mode enums.DisplayMode) error { + return WriteTableWithParams(w, rows, columnNames, config, screenWidth, mode, TableParams{}) +} + +// WriteTableWithParams writes the table with additional table-specific parameters. +func WriteTableWithParams(w io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int, mode enums.DisplayMode, params TableParams) error { + rw := runewidthex.NewCondition() + rw.TabWidth = cmp.Or(config.TabWidth, 4) + + // For comment modes, we need to manipulate the output, so use a buffer + var tableBuf strings.Builder + tableWriter := w + if mode == enums.DisplayModeTableComment || mode == enums.DisplayModeTableDetailComment { + tableWriter = &tableBuf + } + + // Create a table that writes to tableWriter + table := tablewriter.NewTable(tableWriter, + tablewriter.WithRenderer( + renderer.NewBlueprint(tw.Rendition{Symbols: tw.NewSymbols(tw.StyleASCII)})), + tablewriter.WithHeaderAlignment(tw.AlignLeft), + tablewriter.WithTrimSpace(tw.Off), + tablewriter.WithHeaderAutoFormat(tw.Off), + ).Configure(func(twConfig *tablewriter.Config) { + if len(params.ColumnAlign) > 0 { + twConfig.Row.Alignment.PerColumn = params.ColumnAlign + } + twConfig.Row.Formatting.AutoWrap = tw.WrapNone + }) + + wc := &widthCalculator{Condition: rw} + + // Use verbose names for width calculation if available + headerForWidth := columnNames + if config.Verbose && len(params.VerboseHeaders) > 0 { + headerForWidth = params.VerboseHeaders + } + adjustedWidths := CalculateWidth(columnNames, headerForWidth, wc, screenWidth, rows) + + // Determine display headers + displayHeaders := columnNames + if config.Verbose && len(params.VerboseHeaders) > 0 { + displayHeaders = params.VerboseHeaders + } + + headers := slices.Collect(hiter.Unify( + rw.Wrap, + hiter.Pairs( + slices.Values(displayHeaders), + slices.Values(adjustedWidths)))) + + if !config.SkipColumnNames { + table.Header(headers) + } + + for _, row := range rows { + wrappedColumns := slices.Collect(hiter.Unify( + rw.Wrap, + hiter.Pairs(slices.Values(row), slices.Values(adjustedWidths)))) + if err := table.Append(wrappedColumns); err != nil { + return fmt.Errorf("failed to append row: %w", err) + } + } + + forceTableRender := config.Verbose && len(headers) > 0 + + if forceTableRender || len(rows) > 0 { + if err := table.Render(); err != nil { + return fmt.Errorf("failed to render table: %w", err) + } + } + + // Handle comment mode transformations + if mode == enums.DisplayModeTableComment || mode == enums.DisplayModeTableDetailComment { + s := strings.TrimSpace(tableBuf.String()) + // Sanitize */ in table content to prevent premature SQL comment closure. + s = strings.ReplaceAll(s, "*/", "* /") + s = strings.ReplaceAll(s, "\n", "\n ") + s = topLeftRe.ReplaceAllLiteralString(s, "/*") + + if mode == enums.DisplayModeTableComment { + s = bottomRightRe.ReplaceAllLiteralString(s, "*/") + } + + if s != "" { + if _, err := fmt.Fprintln(w, s); err != nil { + return err + } + } + } + + return nil +} + +// formatVertical formats output in vertical format where each row is displayed +// with column names on the left and values on the right. +func formatVertical(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + return ExecuteWithFormatter(NewVerticalFormatter(out), rows, columnNames, config) +} + +// formatTab formats output as tab-separated values. +func formatTab(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + return ExecuteWithFormatter(NewTabFormatter(out, config.SkipColumnNames), rows, columnNames, config) +} + +// formatCSV formats output as comma-separated values following RFC 4180. +func formatCSV(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + return ExecuteWithFormatter(NewCSVFormatter(out, config.SkipColumnNames), rows, columnNames, config) +} + +// formatHTML formats output as an HTML table. +func formatHTML(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + return ExecuteWithFormatter(NewHTMLFormatter(out, config.SkipColumnNames), rows, columnNames, config) +} + +// formatXML formats output as XML. +func formatXML(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + return ExecuteWithFormatter(NewXMLFormatter(out, config.SkipColumnNames), rows, columnNames, config) +} + +// NewFormatter creates a new formatter function based on the display mode. +func NewFormatter(mode enums.DisplayMode) (FormatFunc, error) { + switch mode { + case enums.DisplayModeUnspecified: + return formatTable(enums.DisplayModeTable), nil + case enums.DisplayModeTable, enums.DisplayModeTableComment, enums.DisplayModeTableDetailComment: + return formatTable(mode), nil + case enums.DisplayModeVertical: + return formatVertical, nil + case enums.DisplayModeTab: + return formatTab, nil + case enums.DisplayModeCSV: + return formatCSV, nil + case enums.DisplayModeHTML: + return formatHTML, nil + case enums.DisplayModeXML: + return formatXML, nil + case enums.DisplayModeSQLInsert, enums.DisplayModeSQLInsertOrIgnore, enums.DisplayModeSQLInsertOrUpdate: + return FormatSQL(mode), nil + default: + return nil, fmt.Errorf("unsupported display mode: %v", mode) + } +} + +// ExecuteWithFormatter executes buffered formatting using a streaming formatter. +// This reduces duplication in formatCSV, formatTab, formatVertical, etc. +func ExecuteWithFormatter(formatter StreamingFormatter, rows []Row, columnNames []string, config FormatConfig) error { + if len(columnNames) == 0 { + return nil + } + + if err := formatter.InitFormat(columnNames, config, nil); err != nil { + return err + } + + for i, row := range rows { + if err := formatter.WriteRow(row); err != nil { + return fmt.Errorf("failed to write row %d: %w", i+1, err) + } + } + + return formatter.FinishFormat() +} + +// NewStreamingFormatter creates a streaming formatter for the given display mode. +// Note: Table formats (Table, TableComment, TableDetailComment) require screenWidth +// and should be created with NewTableStreamingFormatter directly by the caller. +func NewStreamingFormatter(mode enums.DisplayMode, out io.Writer, config FormatConfig) (StreamingFormatter, error) { + switch mode { + case enums.DisplayModeCSV: + return NewCSVFormatter(out, config.SkipColumnNames), nil + case enums.DisplayModeTab: + return NewTabFormatter(out, config.SkipColumnNames), nil + case enums.DisplayModeVertical: + return NewVerticalFormatter(out), nil + case enums.DisplayModeHTML: + return NewHTMLFormatter(out, config.SkipColumnNames), nil + case enums.DisplayModeXML: + return NewXMLFormatter(out, config.SkipColumnNames), nil + case enums.DisplayModeSQLInsert, enums.DisplayModeSQLInsertOrIgnore, enums.DisplayModeSQLInsertOrUpdate: + return NewSQLStreamingFormatter(out, config, mode) + case enums.DisplayModeTable, enums.DisplayModeTableComment, enums.DisplayModeTableDetailComment: + // Table formats need screenWidth, so they must be created by the caller + // Return a dummy formatter for isStreamingSupported check + if out == io.Discard { + return NewTableStreamingFormatter(out, config, 0, 0), nil + } + return nil, fmt.Errorf("table formats require screenWidth - use NewTableStreamingFormatter directly") + default: + return nil, fmt.Errorf("unsupported streaming format: %v", mode) + } +} diff --git a/internal/mycli/formatters_sql.go b/internal/mycli/format/sql.go similarity index 88% rename from internal/mycli/formatters_sql.go rename to internal/mycli/format/sql.go index 26e294e6..c1b58779 100644 --- a/internal/mycli/formatters_sql.go +++ b/internal/mycli/format/sql.go @@ -1,4 +1,4 @@ -// formatters_sql.go implements SQL export formatting for query results. +// Package format implements SQL export formatting for query results. // It generates INSERT, INSERT OR IGNORE, and INSERT OR UPDATE statements // that can be used for database migration, backup/restore, and test data generation. // @@ -13,7 +13,7 @@ // - Would enable format-specific optimizations and better separation of concerns // // The implementation uses memefish's ast.Path for correct identifier handling. -package mycli +package format import ( "fmt" @@ -25,7 +25,7 @@ import ( "github.com/cloudspannerecosystem/memefish/ast" ) -// extractTableNameFromQuery attempts to extract a table name from a simple SELECT query. +// ExtractTableNameFromQuery attempts to extract a table name from a simple SELECT query. // It supports simple SELECT patterns including: // - SELECT * FROM table_name // - SELECT columns FROM table_name @@ -50,7 +50,7 @@ import ( // The error messages are intended for debug logging to help understand why // auto-detection failed, allowing users to adjust their queries or use explicit // CLI_SQL_TABLE_NAME setting. -func extractTableNameFromQuery(sql string) (string, error) { +func ExtractTableNameFromQuery(sql string) (string, error) { // Validation flow: // 1. Parse SQL and verify it's a SELECT statement // 2. Navigate through AST structure (handling Query wrapper for ORDER BY/LIMIT) @@ -230,7 +230,7 @@ func NewSQLFormatter(out io.Writer, mode enums.DisplayMode, tableName string, ba return nil, fmt.Errorf("CLI_SQL_BATCH_SIZE %d exceeds maximum supported value on this platform", batchSize) } - tablePath, err := parseSimpleTablePath(tableName) + tablePath, err := ParseSimpleTablePath(tableName) if err != nil { return nil, err } @@ -245,12 +245,12 @@ func NewSQLFormatter(out io.Writer, mode enums.DisplayMode, tableName string, ba }, nil } -// parseSimpleTablePath converts a simple table path string from CLI input to an ast.Path. +// ParseSimpleTablePath converts a simple table path string from CLI input to an ast.Path. // This function handles user-friendly input where reserved words don't need quoting. // Examples: "Users", "Order" (reserved word OK), "myschema.Users" // The function does NOT parse SQL expressions - it simply splits on dots. // Quoting for reserved words is handled automatically by ast.Ident.SQL() during output. -func parseSimpleTablePath(input string) (*ast.Path, error) { +func ParseSimpleTablePath(input string) (*ast.Path, error) { // Trim spaces and check for empty input input = strings.TrimSpace(input) if input == "" { @@ -386,16 +386,14 @@ func (f *SQLFormatter) flushBatch() error { } // SQLStreamingFormatter implements StreamingFormatter for SQL export. -// Note: While this supports streaming, partitioned queries currently buffer all results -// before formatting, so streaming benefits are not realized for partitioned queries. type SQLStreamingFormatter struct { formatter *SQLFormatter initialized bool } // NewSQLStreamingFormatter creates a new streaming SQL formatter. -func NewSQLStreamingFormatter(out io.Writer, sysVars *systemVariables, mode enums.DisplayMode) (*SQLStreamingFormatter, error) { - if sysVars.SQLTableName == "" { +func NewSQLStreamingFormatter(out io.Writer, config FormatConfig, mode enums.DisplayMode) (*SQLStreamingFormatter, error) { + if config.SQLTableName == "" { return nil, fmt.Errorf("SQL export requires a table name. Auto-detection failed (query may be too complex).\n" + "Options:\n" + " 1. Use DUMP TABLE for full table exports\n" + @@ -403,7 +401,7 @@ func NewSQLStreamingFormatter(out io.Writer, sysVars *systemVariables, mode enum " 3. Ensure your query matches: SELECT * FROM table_name [WHERE/ORDER BY/LIMIT]") } - formatter, err := NewSQLFormatter(out, mode, sysVars.SQLTableName, sysVars.SQLBatchSize) + formatter, err := NewSQLFormatter(out, mode, config.SQLTableName, config.SQLBatchSize) if err != nil { return nil, err } @@ -415,11 +413,9 @@ func NewSQLStreamingFormatter(out io.Writer, sysVars *systemVariables, mode enum } // InitFormat initializes the formatter with column information. -func (s *SQLStreamingFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (s *SQLStreamingFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { s.initialized = true - // WriteHeader will validate column names - columns := extractTableColumnNames(header) - return s.formatter.WriteHeader(columns) + return s.formatter.WriteHeader(columnNames) } // WriteRow outputs a single row. @@ -431,23 +427,14 @@ func (s *SQLStreamingFormatter) WriteRow(row Row) error { } // FinishFormat completes the SQL export. -func (s *SQLStreamingFormatter) FinishFormat(stats QueryStats, rowCount int64) error { +func (s *SQLStreamingFormatter) FinishFormat() error { return s.formatter.Finish() } -// formatSQL is the non-streaming formatter for SQL export. -// PRECONDITION: result.TableHeader must contain valid column information. -// The TableHeader is essential for generating the column names in INSERT statements -// (e.g., INSERT INTO table(col1, col2, ...) VALUES ...). -// Without valid column headers, SQL export cannot generate syntactically correct INSERT statements. -func formatSQL(mode enums.DisplayMode) FormatFunc { - return func(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - // Use the table name from Result if available (for buffered mode with auto-detection) - // Otherwise fall back to sysVars.SQLTableName - tableName := result.SQLTableNameForExport - if tableName == "" { - tableName = sysVars.SQLTableName - } +// FormatSQL is the non-streaming formatter for SQL export. +func FormatSQL(mode enums.DisplayMode) FormatFunc { + return func(out io.Writer, rows []Row, columnNames []string, config FormatConfig, screenWidth int) error { + tableName := config.SQLTableName if tableName == "" { return fmt.Errorf("SQL export requires a table name. Auto-detection failed (query may be too complex).\n" + @@ -457,7 +444,7 @@ func formatSQL(mode enums.DisplayMode) FormatFunc { " 3. Ensure your query matches: SELECT * FROM table_name [WHERE/ORDER BY/LIMIT]") } - formatter, err := NewSQLFormatter(out, mode, tableName, sysVars.SQLBatchSize) + formatter, err := NewSQLFormatter(out, mode, tableName, config.SQLBatchSize) if err != nil { return err } @@ -468,7 +455,7 @@ func formatSQL(mode enums.DisplayMode) FormatFunc { } // Write all rows - for _, row := range result.Rows { + for _, row := range rows { if err := formatter.WriteRow(row); err != nil { return err } diff --git a/internal/mycli/formatters_sql_test.go b/internal/mycli/format/sql_test.go similarity index 97% rename from internal/mycli/formatters_sql_test.go rename to internal/mycli/format/sql_test.go index 8ba9e37e..930ac0ac 100644 --- a/internal/mycli/formatters_sql_test.go +++ b/internal/mycli/format/sql_test.go @@ -1,4 +1,4 @@ -package mycli +package format import ( "testing" @@ -80,7 +80,7 @@ func TestParseSimpleTablePath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := parseSimpleTablePath(tt.input) + got, err := ParseSimpleTablePath(tt.input) if tt.wantError != "" { require.Error(t, err) @@ -303,6 +303,13 @@ func TestExtractTableNameFromQuery(t *testing.T) { }, { name: "SELECT with CTE", + query: "SELECT * FROM active_users", + wantTableName: "active_users", + wantError: "", + description: "simple table reference works", + }, + { + name: "SELECT with CTE full", query: "WITH active_users AS (SELECT * FROM Users WHERE status = 'ACTIVE') SELECT * FROM active_users", wantTableName: "", wantError: "CTE (WITH clause) not supported", @@ -445,7 +452,7 @@ func TestExtractTableNameFromQuery(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := extractTableNameFromQuery(tt.query) + got, err := ExtractTableNameFromQuery(tt.query) if tt.wantError != "" { // Expecting an error @@ -498,7 +505,7 @@ func TestParseSimpleTablePathSQL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - path, err := parseSimpleTablePath(tt.input) + path, err := ParseSimpleTablePath(tt.input) require.NoError(t, err) assert.Equal(t, tt.wantSQL, path.SQL()) }) diff --git a/internal/mycli/formatters_common.go b/internal/mycli/format/streaming_common.go similarity index 77% rename from internal/mycli/formatters_common.go rename to internal/mycli/format/streaming_common.go index f9870ce3..aa8c6b17 100644 --- a/internal/mycli/formatters_common.go +++ b/internal/mycli/format/streaming_common.go @@ -1,4 +1,4 @@ -package mycli +package format import ( "encoding/csv" @@ -25,19 +25,18 @@ func NewCSVFormatter(out io.Writer, skipHeaders bool) *CSVFormatter { } // InitFormat writes CSV headers if needed. -func (f *CSVFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (f *CSVFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { if f.initialized { return nil } - columns := extractTableColumnNames(header) - if len(columns) == 0 { + if len(columnNames) == 0 { return nil } // Write headers unless skipping if !f.skipHeaders { - if err := f.writer.Write(columns); err != nil { + if err := f.writer.Write(columnNames); err != nil { return fmt.Errorf("failed to write CSV header: %w", err) } } @@ -62,7 +61,7 @@ func (f *CSVFormatter) WriteRow(row Row) error { } // FinishFormat completes CSV output. -func (f *CSVFormatter) FinishFormat(stats QueryStats, rowCount int64) error { +func (f *CSVFormatter) FinishFormat() error { f.writer.Flush() if err := f.writer.Error(); err != nil { return fmt.Errorf("CSV writer error: %w", err) @@ -87,21 +86,20 @@ func NewTabFormatter(out io.Writer, skipHeaders bool) *TabFormatter { } // InitFormat writes tab-separated headers if needed. -func (f *TabFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (f *TabFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { if f.initialized { return nil } - columns := extractTableColumnNames(header) - f.columns = columns + f.columns = columnNames - if len(columns) == 0 { + if len(columnNames) == 0 { return nil } // Write headers unless skipping if !f.skipHeaders { - if _, err := fmt.Fprintln(f.out, strings.Join(columns, "\t")); err != nil { + if _, err := fmt.Fprintln(f.out, strings.Join(columnNames, "\t")); err != nil { return fmt.Errorf("failed to write TAB header: %w", err) } } @@ -124,8 +122,7 @@ func (f *TabFormatter) WriteRow(row Row) error { } // FinishFormat completes tab-separated output. -func (f *TabFormatter) FinishFormat(stats QueryStats, rowCount int64) error { - // Nothing to do for tab format +func (f *TabFormatter) FinishFormat() error { return nil } @@ -147,21 +144,20 @@ func NewVerticalFormatter(out io.Writer) *VerticalFormatter { } // InitFormat prepares vertical format output. -func (f *VerticalFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (f *VerticalFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { if f.initialized { return nil } - columns := extractTableColumnNames(header) - f.columns = columns + f.columns = columnNames - if len(columns) == 0 { + if len(columnNames) == 0 { return nil } // Calculate max column name length for alignment f.maxLen = 0 - for _, col := range columns { + for _, col := range columnNames { if len(col) > f.maxLen { f.maxLen = len(col) } @@ -200,7 +196,6 @@ func (f *VerticalFormatter) WriteRow(row Row) error { } // FinishFormat completes vertical format output. -func (f *VerticalFormatter) FinishFormat(stats QueryStats, rowCount int64) error { - // Nothing to do for vertical format +func (f *VerticalFormatter) FinishFormat() error { return nil } diff --git a/internal/mycli/formatters_html_xml_streaming.go b/internal/mycli/format/streaming_html_xml.go similarity index 85% rename from internal/mycli/formatters_html_xml_streaming.go rename to internal/mycli/format/streaming_html_xml.go index 5d7563f0..0ab963e9 100644 --- a/internal/mycli/formatters_html_xml_streaming.go +++ b/internal/mycli/format/streaming_html_xml.go @@ -1,4 +1,4 @@ -package mycli +package format import ( "bytes" @@ -26,15 +26,14 @@ func NewHTMLFormatter(out io.Writer, skipHeaders bool) *HTMLFormatter { } // InitFormat writes the HTML table opening and headers. -func (f *HTMLFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (f *HTMLFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { if f.initialized { return nil } - columns := extractTableColumnNames(header) - f.columns = columns + f.columns = columnNames - if len(columns) == 0 { + if len(columnNames) == 0 { return nil } @@ -48,7 +47,7 @@ func (f *HTMLFormatter) InitFormat(header TableHeader, sysVars *systemVariables, if _, err := fmt.Fprint(f.out, ""); err != nil { return err } - for _, col := range columns { + for _, col := range columnNames { if _, err := fmt.Fprintf(f.out, "%s", html.EscapeString(col)); err != nil { return err } @@ -86,7 +85,7 @@ func (f *HTMLFormatter) WriteRow(row Row) error { } // FinishFormat completes the HTML table. -func (f *HTMLFormatter) FinishFormat(stats QueryStats, rowCount int64) error { +func (f *HTMLFormatter) FinishFormat() error { if _, err := fmt.Fprintln(f.out, ""); err != nil { return err } @@ -112,15 +111,14 @@ func NewXMLFormatter(out io.Writer, skipHeaders bool) *XMLFormatter { } // InitFormat writes the XML declaration and starts the result set. -func (f *XMLFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (f *XMLFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { if f.initialized { return nil } - columns := extractTableColumnNames(header) - f.columns = columns + f.columns = columnNames - if len(columns) == 0 { + if len(columnNames) == 0 { return nil } @@ -139,7 +137,7 @@ func (f *XMLFormatter) InitFormat(header TableHeader, sysVars *systemVariables, if _, err := fmt.Fprint(f.out, "
"); err != nil { return err } - for _, col := range columns { + for _, col := range columnNames { if _, err := fmt.Fprintf(f.out, "%s", xmlEscape(col)); err != nil { return err } @@ -183,7 +181,7 @@ func (f *XMLFormatter) WriteRow(row Row) error { } // FinishFormat completes the XML output. -func (f *XMLFormatter) FinishFormat(stats QueryStats, rowCount int64) error { +func (f *XMLFormatter) FinishFormat() error { // Close resultset if _, err := fmt.Fprintln(f.out, ""); err != nil { return err diff --git a/internal/mycli/formatters_table_streaming.go b/internal/mycli/format/streaming_table.go similarity index 66% rename from internal/mycli/formatters_table_streaming.go rename to internal/mycli/format/streaming_table.go index 95966a62..8367b1d4 100644 --- a/internal/mycli/formatters_table_streaming.go +++ b/internal/mycli/format/streaming_table.go @@ -1,6 +1,7 @@ -package mycli +package format import ( + "cmp" "fmt" "io" "slices" @@ -16,7 +17,7 @@ import ( // It uses a configurable number of preview rows to calculate optimal column widths. type TableStreamingFormatter struct { out io.Writer - sysVars *systemVariables + config FormatConfig screenWidth int table *tablewriter.Table columns []string @@ -29,10 +30,10 @@ type TableStreamingFormatter struct { // NewTableStreamingFormatter creates a new table streaming formatter. // previewSize determines how many rows to use for width calculation (0 = all rows). -func NewTableStreamingFormatter(out io.Writer, sysVars *systemVariables, screenWidth int, previewSize int) *TableStreamingFormatter { +func NewTableStreamingFormatter(out io.Writer, config FormatConfig, screenWidth int, previewSize int) *TableStreamingFormatter { return &TableStreamingFormatter{ out: out, - sysVars: sysVars, + config: config, screenWidth: screenWidth, previewSize: previewSize, previewRows: []Row{}, @@ -40,18 +41,17 @@ func NewTableStreamingFormatter(out io.Writer, sysVars *systemVariables, screenW } // InitFormat initializes the table with preview rows for width calculation. -func (f *TableStreamingFormatter) InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error { +func (f *TableStreamingFormatter) InitFormat(columnNames []string, config FormatConfig, previewRows []Row) error { if f.initialized { return nil } - columns := extractTableColumnNames(header) - f.columns = columns - f.sysVars = sysVars + f.columns = columnNames + f.config = config f.previewRows = previewRows // Calculate optimal widths using preview rows - f.calculateWidths(columns, previewRows) + f.calculateWidths(columnNames, previewRows) // Create table with streaming configuration f.table = tablewriter.NewTable(f.out, @@ -63,12 +63,12 @@ func (f *TableStreamingFormatter) InitFormat(header TableHeader, sysVars *system tablewriter.WithStreaming(tw.StreamConfig{ Enable: true, }), - ).Configure(func(config *tablewriter.Config) { + ).Configure(func(twConfig *tablewriter.Config) { // Note: Column alignment is not set here because: // 1. Regular SQL queries don't specify column alignment (defaults to left) // 2. EXPLAIN/DESCRIBE statements that use custom alignment don't use streaming mode // They return a complete Result with ColumnAlign set and use buffered formatting - config.Row.Formatting.AutoWrap = tw.WrapNone + twConfig.Row.Formatting.AutoWrap = tw.WrapNone }) // Start streaming table @@ -77,18 +77,14 @@ func (f *TableStreamingFormatter) InitFormat(header TableHeader, sysVars *system } // Set headers - if !f.sysVars.SkipColumnNames && len(columns) > 0 { + if !f.config.SkipColumnNames && len(columnNames) > 0 { // Apply calculated widths to headers - headers := f.wrapHeaders(columns) - // Header method doesn't return an error in tablewriter v1.0.9 + headers := f.wrapHeaders(columnNames) f.table.Header(headers) } f.initialized = true - // Don't write preview rows here - they will be written by the TablePreviewProcessor - // after this initialization completes - return nil } @@ -102,8 +98,7 @@ func (f *TableStreamingFormatter) WriteRow(row Row) error { // Check if we have enough preview rows if f.previewSize > 0 && f.rowsBuffered >= f.previewSize { // Initialize with buffered rows - header := simpleTableHeader(f.columns) - return f.InitFormat(header, f.sysVars, f.previewRows) + return f.InitFormat(f.columns, f.config, f.previewRows) } return nil } @@ -122,11 +117,10 @@ func (f *TableStreamingFormatter) writeRowInternal(row Row) error { } // FinishFormat completes the table output. -func (f *TableStreamingFormatter) FinishFormat(stats QueryStats, rowCount int64) error { +func (f *TableStreamingFormatter) FinishFormat() error { // Initialize if not done yet (e.g., fewer rows than preview size) if !f.initialized && len(f.columns) > 0 { - header := simpleTableHeader(f.columns) - if err := f.InitFormat(header, f.sysVars, f.previewRows); err != nil { + if err := f.InitFormat(f.columns, f.config, f.previewRows); err != nil { return err } } @@ -142,32 +136,14 @@ func (f *TableStreamingFormatter) FinishFormat(stats QueryStats, rowCount int64) } // calculateWidths calculates optimal column widths based on preview rows. -// If previewRows is empty, it uses only header names for width calculation. func (f *TableStreamingFormatter) calculateWidths(columns []string, previewRows []Row) { rw := runewidthex.NewCondition() - rw.TabWidth = 4 - if f.sysVars != nil && f.sysVars.TabWidth > 0 { - rw.TabWidth = int(f.sysVars.TabWidth) - } + rw.TabWidth = cmp.Or(f.config.TabWidth, 4) wc := &widthCalculator{Condition: rw} - // If no preview rows, calculate based on headers only - rowsForCalculation := previewRows - if len(previewRows) == 0 { - // Use empty rows for header-only calculation - // calculateWidth will use header widths as the baseline - rowsForCalculation = []Row{} - } - - // Create a mock result for width calculation - mockResult := &Result{ - TableHeader: simpleTableHeader(columns), - Rows: rowsForCalculation, - } - - // Calculate optimal widths - f.widths = calculateWidth(mockResult, wc, f.screenWidth, rowsForCalculation) + // Calculate optimal widths using column names as header + f.widths = CalculateWidth(columns, columns, wc, f.screenWidth, previewRows) } // wrapHeaders wraps headers according to calculated widths. @@ -177,10 +153,7 @@ func (f *TableStreamingFormatter) wrapHeaders(headers []string) []string { } rw := runewidthex.NewCondition() - rw.TabWidth = 4 - if f.sysVars != nil && f.sysVars.TabWidth > 0 { - rw.TabWidth = int(f.sysVars.TabWidth) - } + rw.TabWidth = cmp.Or(f.config.TabWidth, 4) return slices.Collect(hiter.Unify( rw.Wrap, @@ -194,10 +167,7 @@ func (f *TableStreamingFormatter) wrapRow(row Row) []string { } rw := runewidthex.NewCondition() - rw.TabWidth = 4 - if f.sysVars != nil && f.sysVars.TabWidth > 0 { - rw.TabWidth = int(f.sysVars.TabWidth) - } + rw.TabWidth = cmp.Or(f.config.TabWidth, 4) return slices.Collect(hiter.Unify( rw.Wrap, diff --git a/internal/mycli/format/width.go b/internal/mycli/format/width.go new file mode 100644 index 00000000..f89a0585 --- /dev/null +++ b/internal/mycli/format/width.go @@ -0,0 +1,240 @@ +package format + +import ( + "cmp" + "fmt" + "iter" + "log/slog" + "math" + "slices" + + "github.com/apstndb/go-runewidthex" + "github.com/apstndb/lox" + "github.com/mattn/go-runewidth" + "github.com/ngicks/go-iterator-helper/hiter" + "github.com/ngicks/go-iterator-helper/hiter/stringsiter" + "github.com/samber/lo" +) + +// CalculateWidth calculates optimal column widths for table rendering. +// columnNames are the plain column names, verboseHeaders are optionally +// the verbose header strings (with type info, may contain newlines). +// Both are used for width calculation. +func CalculateWidth(columnNames []string, verboseHeaders []string, wc *widthCalculator, screenWidth int, rows []Row) []int { + return calculateOptimalWidth(wc, screenWidth, columnNames, slices.Concat([]Row{verboseHeaders}, rows)) +} + +func calculateOptimalWidth(wc *widthCalculator, screenWidth int, header []string, rows []Row) []int { + // table overhead is: + // len(`| |`) + + // len(` | `) * len(columns) - 1 + overheadWidth := 4 + 3*(len(header)-1) + + // don't mutate + termWidthWithoutOverhead := screenWidth - overheadWidth + + slog.Debug("screen width info", "screenWidth", screenWidth, "remainsWidth", termWidthWithoutOverhead) + + formatIntermediate := func(remainsWidth int, adjustedWidths []int) string { + return fmt.Sprintf("remaining %v, adjustedWidths: %v", remainsWidth-lo.Sum(adjustedWidths), adjustedWidths) + } + + adjustedWidths := adjustByHeader(header, termWidthWithoutOverhead) + + slog.Debug("adjustByName", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) + + var transposedRows [][]string + for columnIdx := range len(header) { + transposedRows = append(transposedRows, slices.Collect( + hiter.Map( + func(in Row) string { + return lo.Must(lo.Nth(in, columnIdx)) // columnIdx represents the index of the column in the row + }, + hiter.Concat(hiter.Once(Row(header)), slices.Values(rows)), + ))) + } + + widthCounts := wc.calculateWidthCounts(adjustedWidths, transposedRows) + for { + slog.Debug("widthCounts", "counts", widthCounts) + + firstCounts := hiter.Map( + func(wcs []WidthCount) WidthCount { + return lo.FirstOr(wcs, invalidWidthCount) + }, + slices.Values(widthCounts)) + + // find the largest count idx within available width + idx, target := wc.maxIndex(termWidthWithoutOverhead-lo.Sum(adjustedWidths), adjustedWidths, firstCounts) + if idx < 0 || target.Count() < 1 { + break + } + + widthCounts[idx] = widthCounts[idx][1:] + adjustedWidths[idx] = target.Length() + + slog.Debug("adjusting", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) + } + + slog.Debug("semi final", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) + + // Add rest to the longest shortage column. + longestWidths := lo.Map(widthCounts, func(item []WidthCount, _ int) int { + return hiter.Max(hiter.Map(WidthCount.Length, slices.Values(item))) + }) + + idx, _ := MaxWithIdx(math.MinInt, hiter.Unify( + func(longestWidth, adjustedWidth int) int { + return longestWidth - adjustedWidth + }, + hiter.Pairs(slices.Values(longestWidths), slices.Values(adjustedWidths)))) + + if idx != -1 { + adjustedWidths[idx] += termWidthWithoutOverhead - lo.Sum(adjustedWidths) + } + + slog.Debug("final", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths)) + + return adjustedWidths +} + +// MaxWithIdx returns the index and value of the maximum element in seq. +func MaxWithIdx[E cmp.Ordered](fallback E, seq iter.Seq[E]) (int, E) { + return MaxByWithIdx(fallback, lox.Identity, seq) +} + +// MaxByWithIdx returns the index and value of the element with the maximum key. +func MaxByWithIdx[O cmp.Ordered, E any](fallback E, f func(E) O, seq iter.Seq[E]) (int, E) { + val := fallback + idx := -1 + current := -1 + for v := range seq { + current++ + if f(val) < f(v) { + val = v + idx = current + } + } + return idx, val +} + +type widthCalculator struct{ Condition *runewidthex.Condition } + +func (wc *widthCalculator) StringWidth(s string) int { + return wc.Condition.StringWidth(s) +} + +func (wc *widthCalculator) maxWidth(s string) int { + return hiter.Max(hiter.Map( + wc.StringWidth, + stringsiter.SplitFunc(s, 0, stringsiter.CutNewLine))) +} + +func clipToMax[S interface{ ~[]E }, E cmp.Ordered](s S, maxValue E) iter.Seq[E] { + return hiter.Map( + func(in E) E { + return min(in, maxValue) + }, + slices.Values(s), + ) +} + +func asc[T cmp.Ordered](left, right T) int { + switch { + case left < right: + return -1 + case left > right: + return 1 + default: + return 0 + } +} + +func desc[T cmp.Ordered](left, right T) int { + return asc(right, left) +} + +func adjustToSum(limit int, vs []int) ([]int, int) { + sumVs := lo.Sum(vs) + remains := limit - sumVs + if remains >= 0 { + return vs, remains + } + + curVs := vs + for i := 1; ; i++ { + rev := slices.SortedFunc(slices.Values(lo.Uniq(vs)), desc) + v, ok := hiter.Nth(i, slices.Values(rev)) + if !ok { + break + } + curVs = slices.Collect(clipToMax(vs, v)) + if lo.Sum(curVs) <= limit { + break + } + } + return curVs, limit - lo.Sum(curVs) +} + +var invalidWidthCount = WidthCount{ + // impossible to fit any width + width: math.MaxInt, + // least significant + count: math.MinInt, +} + +func (wc *widthCalculator) maxIndex(ignoreMax int, adjustWidths []int, seq iter.Seq[WidthCount]) (int, WidthCount) { + return MaxByWithIdx( + invalidWidthCount, + WidthCount.Count, + hiter.Unify( + func(adjustWidth int, wc WidthCount) WidthCount { + return lo.Ternary(wc.Length()-adjustWidth <= ignoreMax, wc, invalidWidthCount) + }, + hiter.Pairs(slices.Values(adjustWidths), seq))) +} + +func (wc *widthCalculator) countWidth(ss []string) iter.Seq[WidthCount] { + return hiter.Map( + func(e lo.Entry[int, int]) WidthCount { + return WidthCount{ + width: e.Key, + count: e.Value, + } + }, + slices.Values(lox.EntriesSortedByKey(lo.CountValuesBy(ss, wc.maxWidth)))) +} + +func (wc *widthCalculator) calculateWidthCounts(currentWidths []int, rows [][]string) [][]WidthCount { + var result [][]WidthCount + for columnNo := range len(currentWidths) { + currentWidth := currentWidths[columnNo] + columnValues := rows[columnNo] + largerWidthCounts := slices.Collect( + hiter.Filter( + func(v WidthCount) bool { + return v.Length() > currentWidth + }, + wc.countWidth(columnValues), + )) + result = append(result, largerWidthCounts) + } + return result +} + +// WidthCount tracks the width and frequency of column values. +type WidthCount struct{ width, count int } + +// Length returns the width value. +func (wc WidthCount) Length() int { return wc.width } + +// Count returns the frequency count. +func (wc WidthCount) Count() int { return wc.count } + +func adjustByHeader(headers []string, availableWidth int) []int { + nameWidths := slices.Collect(hiter.Map(runewidth.StringWidth, slices.Values(headers))) + + adjustWidths, _ := adjustToSum(availableWidth, nameWidths) + + return adjustWidths +} diff --git a/internal/mycli/formatter_utils.go b/internal/mycli/formatter_utils.go index 22b3d286..6d163238 100644 --- a/internal/mycli/formatter_utils.go +++ b/internal/mycli/formatter_utils.go @@ -1,82 +1,30 @@ package mycli import ( - "fmt" "io" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" ) -// createStreamingFormatter creates a streaming formatter for the given display mode. -// This is the single source of truth for formatter creation logic. -// Note: Table formats (Table, TableComment, TableDetailComment) require screenWidth -// and should be created with NewTableStreamingFormatter directly by the caller. -func createStreamingFormatter(mode enums.DisplayMode, out io.Writer, sysVars *systemVariables) (StreamingFormatter, error) { - switch mode { - case enums.DisplayModeCSV: - return NewCSVFormatter(out, sysVars.SkipColumnNames), nil - case enums.DisplayModeTab: - return NewTabFormatter(out, sysVars.SkipColumnNames), nil - case enums.DisplayModeVertical: - return NewVerticalFormatter(out), nil - case enums.DisplayModeHTML: - return NewHTMLFormatter(out, sysVars.SkipColumnNames), nil - case enums.DisplayModeXML: - return NewXMLFormatter(out, sysVars.SkipColumnNames), nil - case enums.DisplayModeSQLInsert, enums.DisplayModeSQLInsertOrIgnore, enums.DisplayModeSQLInsertOrUpdate: - return NewSQLStreamingFormatter(out, sysVars, mode) - case enums.DisplayModeTable, enums.DisplayModeTableComment, enums.DisplayModeTableDetailComment: - // Table formats need screenWidth, so they must be created by the caller - // Return a dummy formatter for isStreamingSupported check - if out == io.Discard { - // This is just for checking support - return NewTableStreamingFormatter(out, sysVars, 0, 0), nil - } - return nil, fmt.Errorf("table formats require screenWidth - use NewTableStreamingFormatter directly") - default: - return nil, fmt.Errorf("unsupported streaming format: %v", mode) - } -} - -// executeWithFormatter executes buffered formatting using a streaming formatter. -// This reduces duplication in formatCSV, formatTab, formatVertical, etc. -func executeWithFormatter(formatter StreamingFormatter, result *Result, columnNames []string, sysVars *systemVariables) error { - if len(columnNames) == 0 { - return nil - } - - // Pass TableHeader directly - formatters can extract what they need - if err := formatter.InitFormat(result.TableHeader, sysVars, nil); err != nil { - return err - } - - // Write all rows - for i, row := range result.Rows { - if err := formatter.WriteRow(row); err != nil { - return fmt.Errorf("failed to write row %d: %w", i+1, err) - } - } - - // Finish formatting - return formatter.FinishFormat(QueryStats{}, int64(len(result.Rows))) -} - // createStreamingProcessorForMode creates a streaming processor for the given display mode. // This is the single source of truth for streaming processor creation logic, // used by both execute_sql.go and streaming.go to avoid duplication. func createStreamingProcessorForMode(mode enums.DisplayMode, out io.Writer, sysVars *systemVariables, screenWidth int) (RowProcessor, error) { + config := sysVars.toFormatConfig() + // Special handling for table formats with preview (need screenWidth) if mode == enums.DisplayModeTable || mode == enums.DisplayModeTableComment || mode == enums.DisplayModeTableDetailComment { previewSize := int(sysVars.TablePreviewRows) if previewSize < 0 { previewSize = 0 // 0 means headers-only preview (stream all rows) } - tableFormatter := NewTableStreamingFormatter(out, sysVars, screenWidth, previewSize) + tableFormatter := format.NewTableStreamingFormatter(out, config, screenWidth, previewSize) return NewTablePreviewProcessor(tableFormatter, previewSize), nil } // For non-table formats, use unified creation - formatter, err := createStreamingFormatter(mode, out, sysVars) + formatter, err := format.NewStreamingFormatter(mode, out, config) if err != nil { return nil, err } diff --git a/internal/mycli/formatters.go b/internal/mycli/formatters.go deleted file mode 100644 index fbbf1065..00000000 --- a/internal/mycli/formatters.go +++ /dev/null @@ -1,303 +0,0 @@ -package mycli - -// This file contains output formatters for query results. -// It implements various output formats (TABLE, CSV, HTML, XML, etc.) with proper error handling. -// All formatters follow a consistent pattern where errors are propagated instead of logged and ignored. - -import ( - "cmp" - "encoding/xml" - "fmt" - "html" - "io" - "regexp" - "slices" - "strings" - - "github.com/apstndb/go-runewidthex" - "github.com/apstndb/spanner-mycli/enums" - "github.com/ngicks/go-iterator-helper/hiter" - "github.com/olekukonko/tablewriter" - "github.com/olekukonko/tablewriter/renderer" - "github.com/olekukonko/tablewriter/tw" -) - -var ( - topLeftRe = regexp.MustCompile(`^\+`) - bottomRightRe = regexp.MustCompile(`\+$`) -) - -// writeBuffered writes to a temporary buffer first, and only writes to out if no error occurs. -// This is useful for formats that need to build the entire output before writing. -func writeBuffered(out io.Writer, buildFunc func(out io.Writer) error) error { - var buf strings.Builder - err := buildFunc(&buf) - if err != nil { - return err - } - - output := buf.String() - if output != "" { - _, err = fmt.Fprint(out, output) - return err - } - return nil -} - -// FormatFunc is a function type that formats and writes result data. -// It takes an output writer and result data, and returns an error if any write operation fails. -type FormatFunc func(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error - -// formatTable formats output as an ASCII table. -func formatTable(mode enums.DisplayMode) FormatFunc { - return func(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - return writeBuffered(out, func(out io.Writer) error { - return writeTable(out, result, columnNames, sysVars, screenWidth, mode) - }) - } -} - -// writeTable writes the table to the provided writer. -func writeTable(w io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int, mode enums.DisplayMode) error { - rw := runewidthex.NewCondition() - rw.TabWidth = cmp.Or(int(sysVars.TabWidth), 4) - - rows := result.Rows - - // For comment modes, we need to manipulate the output, so use a buffer - var tableBuf strings.Builder - tableWriter := w - if mode == enums.DisplayModeTableComment || mode == enums.DisplayModeTableDetailComment { - tableWriter = &tableBuf - } - - // Create a table that writes to tableWriter - table := tablewriter.NewTable(tableWriter, - tablewriter.WithRenderer( - renderer.NewBlueprint(tw.Rendition{Symbols: tw.NewSymbols(tw.StyleASCII)})), - tablewriter.WithHeaderAlignment(tw.AlignLeft), - tablewriter.WithTrimSpace(tw.Off), - tablewriter.WithHeaderAutoFormat(tw.Off), - ).Configure(func(config *tablewriter.Config) { - // Use the new Row.Alignment.PerColumn field instead of deprecated Row.ColumnAligns - if len(result.ColumnAlign) > 0 { - config.Row.Alignment.PerColumn = result.ColumnAlign - } - config.Row.Formatting.AutoWrap = tw.WrapNone - }) - - wc := &widthCalculator{Condition: rw} - adjustedWidths := calculateWidth(result, wc, screenWidth, rows) - - headers := slices.Collect(hiter.Unify( - rw.Wrap, - hiter.Pairs( - slices.Values(renderTableHeader(result.TableHeader, sysVars.Verbose)), - slices.Values(adjustedWidths)))) - - if !sysVars.SkipColumnNames { - table.Header(headers) - } - - for _, row := range rows { - wrappedColumns := slices.Collect(hiter.Unify( - rw.Wrap, - hiter.Pairs(slices.Values(row), slices.Values(adjustedWidths)))) - if err := table.Append(wrappedColumns); err != nil { - return fmt.Errorf("failed to append row: %w", err) - } - } - - forceTableRender := sysVars.Verbose && len(headers) > 0 - - if forceTableRender || len(rows) > 0 { - if err := table.Render(); err != nil { - return fmt.Errorf("failed to render table: %w", err) - } - } - - // Handle comment mode transformations - if mode == enums.DisplayModeTableComment || mode == enums.DisplayModeTableDetailComment { - s := strings.TrimSpace(tableBuf.String()) - s = strings.ReplaceAll(s, "\n", "\n ") - s = topLeftRe.ReplaceAllLiteralString(s, "/*") - - if mode == enums.DisplayModeTableComment { - s = bottomRightRe.ReplaceAllLiteralString(s, "*/") - } - - if s != "" { - if _, err := fmt.Fprintln(w, s); err != nil { - return err - } - } - } - - return nil -} - -// formatVertical formats output in vertical format where each row is displayed -// with column names on the left and values on the right. -func formatVertical(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - return executeWithFormatter(NewVerticalFormatter(out), result, columnNames, sysVars) -} - -// formatTab formats output as tab-separated values. -func formatTab(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - return executeWithFormatter(NewTabFormatter(out, sysVars.SkipColumnNames), result, columnNames, sysVars) -} - -// formatCSV formats output as comma-separated values following RFC 4180. -func formatCSV(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - return executeWithFormatter(NewCSVFormatter(out, sysVars.SkipColumnNames), result, columnNames, sysVars) -} - -// formatHTML formats output as an HTML table. -// This is a streaming format that outputs row-by-row without buffering. -func formatHTML(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - // Skip formatting if there are no columns (consistent with formatTab, formatVertical, and formatCSV) - if len(columnNames) == 0 { - return nil - } - - if _, err := fmt.Fprint(out, ""); err != nil { - return err - } - - // Add header row unless skipping column names - if !sysVars.SkipColumnNames { - if _, err := fmt.Fprint(out, ""); err != nil { - return err - } - for _, col := range columnNames { - if _, err := fmt.Fprintf(out, "", html.EscapeString(col)); err != nil { - return err - } - } - if _, err := fmt.Fprint(out, ""); err != nil { - return err - } - } - - // Add data rows - for _, row := range result.Rows { - if _, err := fmt.Fprint(out, ""); err != nil { - return err - } - for _, col := range row { - if _, err := fmt.Fprintf(out, "", html.EscapeString(col)); err != nil { - return err - } - } - if _, err := fmt.Fprint(out, ""); err != nil { - return err - } - } - - if _, err := fmt.Fprintln(out, "
%s
%s
"); err != nil { - return err - } - return nil -} - -// xmlField represents a field element in XML output. -type xmlField struct { - XMLName xml.Name `xml:"field"` - Value string `xml:",chardata"` -} - -// xmlRow represents a row element containing multiple fields. -type xmlRow struct { - XMLName xml.Name `xml:"row"` - Fields []xmlField `xml:"field"` -} - -// xmlHeader represents the optional header element containing column names. -type xmlHeader struct { - XMLName xml.Name `xml:"header"` - Fields []xmlField `xml:"field"` -} - -// xmlResultSet represents the root element of the XML output. -type xmlResultSet struct { - XMLName xml.Name `xml:"resultset"` - XMLNS string `xml:"xmlns:xsi,attr"` - Header *xmlHeader `xml:"header,omitempty"` - Rows []xmlRow `xml:"row"` -} - -// formatXML formats output as XML. -func formatXML(out io.Writer, result *Result, columnNames []string, sysVars *systemVariables, screenWidth int) error { - return writeBuffered(out, func(out io.Writer) error { - // Skip formatting if there are no columns (consistent with other formatters) - if len(columnNames) == 0 { - return nil - } - - // Build the result set structure - resultSet := xmlResultSet{ - XMLNS: "http://www.w3.org/2001/XMLSchema-instance", - Rows: make([]xmlRow, 0, len(result.Rows)), - } - - // Add header fields only if not skipping column names - if !sysVars.SkipColumnNames { - header := &xmlHeader{Fields: make([]xmlField, 0, len(columnNames))} - for _, col := range columnNames { - header.Fields = append(header.Fields, xmlField{Value: col}) - } - resultSet.Header = header - } - - // Add rows - for _, row := range result.Rows { - xmlRow := xmlRow{Fields: make([]xmlField, 0, len(row))} - for _, col := range row { - xmlRow.Fields = append(xmlRow.Fields, xmlField{Value: col}) - } - resultSet.Rows = append(resultSet.Rows, xmlRow) - } - - // Write XML declaration - if _, err := fmt.Fprintln(out, ""); err != nil { - return err - } - - // Marshal the result set - encoder := xml.NewEncoder(out) - encoder.Indent("", "\t") - if err := encoder.Encode(resultSet); err != nil { - return fmt.Errorf("xml encode failed: %w", err) - } - if _, err := fmt.Fprintln(out); err != nil { - return err - } // Add final newline - - return nil - }) -} - -// NewFormatter creates a new formatter function based on the display mode. -func NewFormatter(mode enums.DisplayMode) (FormatFunc, error) { - switch mode { - case enums.DisplayModeUnspecified: - // Should not happen as it's handled in main.go, but provide a sensible default - return formatTable(enums.DisplayModeTable), nil - case enums.DisplayModeTable, enums.DisplayModeTableComment, enums.DisplayModeTableDetailComment: - return formatTable(mode), nil - case enums.DisplayModeVertical: - return formatVertical, nil - case enums.DisplayModeTab: - return formatTab, nil - case enums.DisplayModeCSV: - return formatCSV, nil - case enums.DisplayModeHTML: - return formatHTML, nil - case enums.DisplayModeXML: - return formatXML, nil - case enums.DisplayModeSQLInsert, enums.DisplayModeSQLInsertOrIgnore, enums.DisplayModeSQLInsertOrUpdate: - return formatSQL(mode), nil - default: - return nil, fmt.Errorf("unsupported display mode: %v", mode) - } -} diff --git a/internal/mycli/row_processor.go b/internal/mycli/row_processor.go index 1ecfa1c8..03a36c80 100644 --- a/internal/mycli/row_processor.go +++ b/internal/mycli/row_processor.go @@ -4,6 +4,7 @@ import ( "io" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/apstndb/spanner-mycli/internal/mycli/format" ) // RowProcessor handles rows either in buffered or streaming mode. @@ -29,14 +30,14 @@ type BufferedProcessor struct { rows []Row metadata *sppb.ResultSetMetadata sysVars *systemVariables - formatter FormatFunc + formatter format.FormatFunc out io.Writer screenWidth int result *Result // Accumulates the complete result } // NewBufferedProcessor creates a processor that collects all rows before formatting. -func NewBufferedProcessor(formatter FormatFunc, out io.Writer, screenWidth int) *BufferedProcessor { +func NewBufferedProcessor(formatter format.FormatFunc, out io.Writer, screenWidth int) *BufferedProcessor { return &BufferedProcessor{ formatter: formatter, out: out, @@ -66,11 +67,11 @@ func (p *BufferedProcessor) Finish(stats QueryStats, rowCount int64) error { p.result.Stats = stats p.result.AffectedRows = len(p.rows) - // Extract column names for formatting + // Extract column names and construct FormatConfig for formatting columnNames := extractTableColumnNames(p.result.TableHeader) + config := p.sysVars.toFormatConfig() - // Use the existing formatter - return p.formatter(p.out, p.result, columnNames, p.sysVars, p.screenWidth) + return p.formatter(p.out, p.result.Rows, columnNames, config, p.screenWidth) } // StreamingProcessor processes rows immediately as they arrive. @@ -78,35 +79,19 @@ func (p *BufferedProcessor) Finish(stats QueryStats, rowCount int64) error { type StreamingProcessor struct { metadata *sppb.ResultSetMetadata sysVars *systemVariables - formatter StreamingFormatter + formatter format.StreamingFormatter out io.Writer screenWidth int rowCount int64 initialized bool } -// StreamingFormatter defines the interface for format-specific streaming output. -// Each format (CSV, TAB, etc.) implements this interface to handle streaming output. -type StreamingFormatter interface { - // InitFormat is called once with table header information to output headers. - // TableHeader provides both column names and type information (if available). - // For table formats, previewRows contains the first N rows for width calculation. - // For other formats, previewRows may be empty as they don't need preview. - InitFormat(header TableHeader, sysVars *systemVariables, previewRows []Row) error - - // WriteRow outputs a single row. - WriteRow(row Row) error - - // FinishFormat completes the output (e.g., closing tags, final flush). - FinishFormat(stats QueryStats, rowCount int64) error -} - // TablePreviewProcessor collects a configurable number of rows for table width calculation. // This allows table formats to determine optimal column widths before starting output. type TablePreviewProcessor struct { previewSize int // Number of rows to preview (0 = all rows for non-streaming) previewRows []Row // Collected preview rows - formatter StreamingFormatter + formatter format.StreamingFormatter metadata *sppb.ResultSetMetadata sysVars *systemVariables initialized bool @@ -115,7 +100,7 @@ type TablePreviewProcessor struct { // NewTablePreviewProcessor creates a processor that previews rows for width calculation. // previewSize of 0 means collect all rows (non-streaming mode). -func NewTablePreviewProcessor(formatter StreamingFormatter, previewSize int) *TablePreviewProcessor { +func NewTablePreviewProcessor(formatter format.StreamingFormatter, previewSize int) *TablePreviewProcessor { return &TablePreviewProcessor{ formatter: formatter, previewSize: previewSize, @@ -172,7 +157,7 @@ func (p *TablePreviewProcessor) Finish(stats QueryStats, rowCount int64) error { } } - return p.formatter.FinishFormat(stats, rowCount) + return p.formatter.FinishFormat() } // initializeFormatter initializes the formatter with preview rows. @@ -182,9 +167,11 @@ func (p *TablePreviewProcessor) initializeFormatter() error { } header := toTableHeader(p.metadata.GetRowType().GetFields()) + columnNames := extractTableColumnNames(header) + config := p.sysVars.toFormatConfig() // Initialize formatter with preview rows for width calculation - if err := p.formatter.InitFormat(header, p.sysVars, p.previewRows); err != nil { + if err := p.formatter.InitFormat(columnNames, config, p.previewRows); err != nil { return err } @@ -201,7 +188,7 @@ func (p *TablePreviewProcessor) initializeFormatter() error { } // NewStreamingProcessor creates a processor that outputs rows immediately. -func NewStreamingProcessor(formatter StreamingFormatter, out io.Writer, screenWidth int) *StreamingProcessor { +func NewStreamingProcessor(formatter format.StreamingFormatter, out io.Writer, screenWidth int) *StreamingProcessor { return &StreamingProcessor{ formatter: formatter, out: out, @@ -216,10 +203,11 @@ func (p *StreamingProcessor) Init(metadata *sppb.ResultSetMetadata, sysVars *sys // Get header from metadata header := toTableHeader(metadata.GetRowType().GetFields()) + columnNames := extractTableColumnNames(header) + config := sysVars.toFormatConfig() // Initialize the format (e.g., write CSV headers) - // For streaming formats, we don't need preview rows - if err := p.formatter.InitFormat(header, sysVars, nil); err != nil { + if err := p.formatter.InitFormat(columnNames, config, nil); err != nil { return err } @@ -238,5 +226,5 @@ func (p *StreamingProcessor) ProcessRow(row Row) error { // Finish completes the streaming output. func (p *StreamingProcessor) Finish(stats QueryStats, rowCount int64) error { - return p.formatter.FinishFormat(stats, rowCount) + return p.formatter.FinishFormat() } diff --git a/internal/mycli/statement_processing.go b/internal/mycli/statement_processing.go index 4cb0f533..1b1e3fdf 100644 --- a/internal/mycli/statement_processing.go +++ b/internal/mycli/statement_processing.go @@ -28,6 +28,7 @@ import ( sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/apstndb/gsqlutils/stmtkind" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" "github.com/apstndb/spanner-mycli/internal/mycli/metrics" "github.com/cloudspannerecosystem/memefish" "github.com/cloudspannerecosystem/memefish/ast" @@ -69,8 +70,9 @@ type BatchInfo struct { } type TableHeader interface { - // internalRender shouldn't be called directly. Use renderTableHeader(). - internalRender(verbose bool) []string + // Render returns the header strings. When verbose is true, type information is included. + // Use renderTableHeader() as a nil-safe wrapper. + Render(verbose bool) []string // structFields returns the complete field information with types if available. // Returns (fields, true) if type information is available. // Returns (nil, false) if only column names are available (e.g., simpleTableHeader). @@ -79,7 +81,7 @@ type TableHeader interface { type simpleTableHeader []string -func (th simpleTableHeader) internalRender(verbose bool) []string { +func (th simpleTableHeader) Render(verbose bool) []string { return th } @@ -147,7 +149,7 @@ func toTableHeader[T interface { type typesTableHeader []*sppb.StructType_Field -func (th typesTableHeader) internalRender(verbose bool) []string { +func (th typesTableHeader) Render(verbose bool) []string { var result []string for _, f := range th { if verbose { @@ -223,7 +225,8 @@ type Result struct { Metrics *metrics.ExecutionMetrics // Performance metrics for query execution } -type Row []string +// Row is a row of string values. It is a type alias for format.Row (= []string). +type Row = format.Row // QueryStats contains query statistics. // Some fields may not have a valid value depending on the environment. diff --git a/internal/mycli/statements_dump.go b/internal/mycli/statements_dump.go index 16eabff3..ba04d2b0 100644 --- a/internal/mycli/statements_dump.go +++ b/internal/mycli/statements_dump.go @@ -10,6 +10,7 @@ import ( "cloud.google.com/go/spanner" dbadminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" ) // DumpDatabaseStatement represents DUMP DATABASE statement @@ -206,9 +207,9 @@ func executeDumpBuffered(ctx context.Context, session *Session, mode dumpMode, s if len(dataResult.Rows) > 0 { var buf bytes.Buffer - tempVars := *session.systemVariables - tempVars.SQLTableName, tempVars.CLIFormat = table, enums.DisplayModeSQLInsert - if err := formatSQL(enums.DisplayModeSQLInsert)(&buf, dataResult, extractTableColumnNames(dataResult.TableHeader), &tempVars, 0); err != nil { + config := session.systemVariables.toFormatConfig() + config.SQLTableName = table + if err := format.FormatSQL(enums.DisplayModeSQLInsert)(&buf, dataResult.Rows, extractTableColumnNames(dataResult.TableHeader), config, 0); err != nil { return fmt.Errorf("failed to format SQL for table %s: %w", table, err) } if buf.Len() > 0 { diff --git a/internal/mycli/streaming.go b/internal/mycli/streaming.go index 3330d57e..7809e09c 100644 --- a/internal/mycli/streaming.go +++ b/internal/mycli/streaming.go @@ -10,6 +10,7 @@ import ( "cloud.google.com/go/spanner" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" "github.com/apstndb/spanner-mycli/internal/mycli/metrics" ) @@ -182,7 +183,7 @@ func shouldUseStreaming(sysVars *systemVariables) bool { // isStreamingSupported checks if a specific display mode supports streaming. func isStreamingSupported(mode enums.DisplayMode) bool { - _, err := createStreamingFormatter(mode, io.Discard, &systemVariables{}) + _, err := format.NewStreamingFormatter(mode, io.Discard, format.FormatConfig{}) return err == nil } diff --git a/internal/mycli/streaming_test.go b/internal/mycli/streaming_test.go index 36593dd9..7b6ea68a 100644 --- a/internal/mycli/streaming_test.go +++ b/internal/mycli/streaming_test.go @@ -6,6 +6,7 @@ import ( sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" "github.com/stretchr/testify/assert" ) @@ -179,18 +180,16 @@ func TestBufferedVsStreaming(t *testing.T) { // Buffered output var bufBuffered bytes.Buffer - result := &Result{ - TableHeader: simpleTableHeader(columnNames), - Rows: rows, - } - err := formatCSV(&bufBuffered, result, columnNames, sysVars, 0) + config := sysVars.toFormatConfig() + csvFormatter, err := format.NewFormatter(enums.DisplayModeCSV) + assert.NoError(t, err) + err = csvFormatter(&bufBuffered, rows, columnNames, config, 0) assert.NoError(t, err) // Streaming output var bufStreaming bytes.Buffer - formatter := NewCSVFormatter(&bufStreaming, sysVars.SkipColumnNames) - header := simpleTableHeader(columnNames) - err = formatter.InitFormat(header, sysVars, nil) + formatter := format.NewCSVFormatter(&bufStreaming, sysVars.SkipColumnNames) + err = formatter.InitFormat(columnNames, config, nil) assert.NoError(t, err) for _, row := range rows { @@ -198,7 +197,7 @@ func TestBufferedVsStreaming(t *testing.T) { assert.NoError(t, err) } - err = formatter.FinishFormat(QueryStats{}, int64(len(rows))) + err = formatter.FinishFormat() assert.NoError(t, err) // Compare outputs diff --git a/internal/mycli/system_variables.go b/internal/mycli/system_variables.go index 70f9a01e..453447fc 100644 --- a/internal/mycli/system_variables.go +++ b/internal/mycli/system_variables.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "github.com/apstndb/spanner-mycli/enums" + "github.com/apstndb/spanner-mycli/internal/mycli/format" "github.com/bufbuild/protocompile" "github.com/cloudspannerecosystem/memefish/ast" "google.golang.org/protobuf/reflect/protodesc" @@ -173,6 +174,18 @@ type systemVariables struct { RetryAbortsInternally bool // RETRY_ABORTS_INTERNALLY (unimplemented) } +// toFormatConfig converts the formatter-relevant fields of systemVariables into a format.FormatConfig. +func (sv *systemVariables) toFormatConfig() format.FormatConfig { + return format.FormatConfig{ + TabWidth: int(sv.TabWidth), + Verbose: sv.Verbose, + SkipColumnNames: sv.SkipColumnNames, + SQLTableName: sv.SQLTableName, + SQLBatchSize: sv.SQLBatchSize, + PreviewRows: sv.TablePreviewRows, + } +} + // parseEndpoint parses an endpoint string into host and port components. // It returns an error if the endpoint is invalid. func parseEndpoint(endpoint string) (host string, port int, err error) {