Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 25 additions & 226 deletions internal/mycli/cli_output.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
package mycli

import (
"cmp"
_ "embed"
"fmt"
"io"
"iter"
"log/slog"
"math"
"slices"
"strings"
"text/template"

sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/apstndb/go-runewidthex"
"github.com/apstndb/lox"
"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/format"
"github.com/apstndb/spanner-mycli/internal/mycli/metrics"
"github.com/go-sprout/sprout"
"github.com/go-sprout/sprout/group/hermetic"
"github.com/mattn/go-runewidth"
"github.com/ngicks/go-iterator-helper/hiter"
"github.com/ngicks/go-iterator-helper/hiter/stringsiter"
"github.com/samber/lo"
)

Expand All @@ -31,7 +25,7 @@ func renderTableHeader(header TableHeader, verbose bool) []string {
return nil
}

return header.internalRender(verbose)
return header.Render(verbose)
}

// extractTableColumnNames extracts pure column names from the table header without type information.
Expand Down Expand Up @@ -93,27 +87,41 @@ func printTableData(sysVars *systemVariables, screenWidth int, out io.Writer, re
displayFormat = enums.DisplayModeTable // Fall back to table format
}

// Build FormatConfig from systemVariables
config := sysVars.toFormatConfig()

// For SQL export, resolve the table name from Result if available
if displayFormat.IsSQLExport() && result.SQLTableNameForExport != "" {
config.SQLTableName = result.SQLTableNameForExport
}

// Create the appropriate formatter based on the display mode
formatter, err := NewFormatter(displayFormat)
formatter, err := format.NewFormatter(displayFormat)
if err != nil {
return fmt.Errorf("failed to create formatter: %w", err)
}

// For table mode, pass verbose headers and column align via WriteTableWithParams
if displayFormat == enums.DisplayModeUnspecified || displayFormat == enums.DisplayModeTable || displayFormat == enums.DisplayModeTableComment || displayFormat == enums.DisplayModeTableDetailComment {
verboseHeaders := renderTableHeader(result.TableHeader, true)
tableMode := displayFormat
if tableMode == enums.DisplayModeUnspecified {
tableMode = enums.DisplayModeTable
}
return format.WriteTableWithParams(out, result.Rows, columnNames, config, screenWidth, tableMode, format.TableParams{
VerboseHeaders: verboseHeaders,
ColumnAlign: result.ColumnAlign,
})
}

// Format and write the result
// Individual formatters handle empty columns appropriately for their format
if err := formatter(out, result, columnNames, sysVars, screenWidth); err != nil {
if err := formatter(out, result.Rows, columnNames, config, screenWidth); err != nil {
return fmt.Errorf("formatting failed for mode %v: %w", sysVars.CLIFormat, err)
}

return nil
}

func calculateWidth(result *Result, wc *widthCalculator, screenWidth int, rows []Row) []int {
names := extractTableColumnNames(result.TableHeader)
header := renderTableHeader(result.TableHeader, true)
return calculateOptimalWidth(wc, screenWidth, names, slices.Concat(sliceOf(toRow(header...)), rows))
}

func printResult(sysVars *systemVariables, screenWidth int, out io.Writer, result *Result, interactive bool, input string) error {
if sysVars.MarkdownCodeblock {
fmt.Fprintln(out, "```sql")
Expand Down Expand Up @@ -271,215 +279,6 @@ func resultLine(outputTemplate *template.Template, result *Result, verbose bool)
return fmt.Sprintf("Query OK%s%s%s\n%s", affectedRowsPart, elapsedTimePart, batchInfo, detail)
}

func calculateOptimalWidth(wc *widthCalculator, screenWidth int, header []string, rows []Row) []int {
// table overhead is:
// len(`| |`) +
// len(` | `) * len(columns) - 1
overheadWidth := 4 + 3*(len(header)-1)

// don't mutate
termWidthWithoutOverhead := screenWidth - overheadWidth

slog.Debug("screen width info", "screenWidth", screenWidth, "remainsWidth", termWidthWithoutOverhead)

formatIntermediate := func(remainsWidth int, adjustedWidths []int) string {
return fmt.Sprintf("remaining %v, adjustedWidths: %v", remainsWidth-lo.Sum(adjustedWidths), adjustedWidths)
}

adjustedWidths := adjustByHeader(header, termWidthWithoutOverhead)

slog.Debug("adjustByName", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths))

var transposedRows [][]string
for columnIdx := range len(header) {
transposedRows = append(transposedRows, slices.Collect(
hiter.Map(
func(in Row) string {
return lo.Must(lo.Nth(in, columnIdx)) // columnIdx represents the index of the column in the row
},
hiter.Concat(hiter.Once(toRow(header...)), slices.Values(rows)),
)))
}

widthCounts := wc.calculateWidthCounts(adjustedWidths, transposedRows)
for {
slog.Debug("widthCounts", "counts", widthCounts)

firstCounts := hiter.Map(
func(wcs []WidthCount) WidthCount {
return lo.FirstOr(wcs, invalidWidthCount)
},
slices.Values(widthCounts))

// find the largest count idx within available width
idx, target := wc.maxIndex(termWidthWithoutOverhead-lo.Sum(adjustedWidths), adjustedWidths, firstCounts)
if idx < 0 || target.Count() < 1 {
break
}

widthCounts[idx] = widthCounts[idx][1:]
adjustedWidths[idx] = target.Length()

slog.Debug("adjusting", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths))
}

slog.Debug("semi final", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths))

// Add rest to the longest shortage column.
longestWidths := lo.Map(widthCounts, func(item []WidthCount, _ int) int {
return hiter.Max(hiter.Map(WidthCount.Length, slices.Values(item)))
})

idx, _ := MaxWithIdx(math.MinInt, hiter.Unify(
func(longestWidth, adjustedWidth int) int {
return longestWidth - adjustedWidth
},
hiter.Pairs(slices.Values(longestWidths), slices.Values(adjustedWidths))))

if idx != -1 {
adjustedWidths[idx] += termWidthWithoutOverhead - lo.Sum(adjustedWidths)
}

slog.Debug("final", "info", formatIntermediate(termWidthWithoutOverhead, adjustedWidths))

return adjustedWidths
}

func MaxWithIdx[E cmp.Ordered](fallback E, seq iter.Seq[E]) (int, E) {
return MaxByWithIdx(fallback, lox.Identity, seq)
}

func MaxByWithIdx[O cmp.Ordered, E any](fallback E, f func(E) O, seq iter.Seq[E]) (int, E) {
val := fallback
idx := -1
current := -1
for v := range seq {
current++
if f(val) < f(v) {
val = v
idx = current
}
}
return idx, val
}

func (wc *widthCalculator) StringWidth(s string) int {
return wc.Condition.StringWidth(s)
}

func (wc *widthCalculator) maxWidth(s string) int {
return hiter.Max(hiter.Map(
wc.StringWidth,
stringsiter.SplitFunc(s, 0, stringsiter.CutNewLine)))
}

func clipToMax[S interface{ ~[]E }, E cmp.Ordered](s S, maxValue E) iter.Seq[E] {
return hiter.Map(
func(in E) E {
return min(in, maxValue)
},
slices.Values(s),
)
}

func asc[T cmp.Ordered](left, right T) int {
switch {
case left < right:
return -1
case left > right:
return 1
default:
return 0
}
}

func desc[T cmp.Ordered](left, right T) int {
return asc(right, left)
}

func adjustToSum(limit int, vs []int) ([]int, int) {
sumVs := lo.Sum(vs)
remains := limit - sumVs
if remains >= 0 {
return vs, remains
}

curVs := vs
for i := 1; ; i++ {
rev := slices.SortedFunc(slices.Values(lo.Uniq(vs)), desc)
v, ok := hiter.Nth(i, slices.Values(rev))
if !ok {
break
}
curVs = slices.Collect(clipToMax(vs, v))
if lo.Sum(curVs) <= limit {
break
}
}
return curVs, limit - lo.Sum(curVs)
}

var invalidWidthCount = WidthCount{
// impossible to fit any width
width: math.MaxInt,
// least significant
count: math.MinInt,
}

type widthCalculator struct{ Condition *runewidthex.Condition }

func (wc *widthCalculator) maxIndex(ignoreMax int, adjustWidths []int, seq iter.Seq[WidthCount]) (int, WidthCount) {
return MaxByWithIdx(
invalidWidthCount,
WidthCount.Count,
hiter.Unify(
func(adjustWidth int, wc WidthCount) WidthCount {
return lo.Ternary(wc.Length()-adjustWidth <= ignoreMax, wc, invalidWidthCount)
},
hiter.Pairs(slices.Values(adjustWidths), seq)))
}

func (wc *widthCalculator) countWidth(ss []string) iter.Seq[WidthCount] {
return hiter.Map(
func(e lo.Entry[int, int]) WidthCount {
return WidthCount{
width: e.Key,
count: e.Value,
}
},
slices.Values(lox.EntriesSortedByKey(lo.CountValuesBy(ss, wc.maxWidth))))
}

func (wc *widthCalculator) calculateWidthCounts(currentWidths []int, rows [][]string) [][]WidthCount {
var result [][]WidthCount
for columnNo := range len(currentWidths) {
currentWidth := currentWidths[columnNo]
columnValues := rows[columnNo]
largerWidthCounts := slices.Collect(
hiter.Filter(
func(v WidthCount) bool {
return v.Length() > currentWidth
},
wc.countWidth(columnValues),
))
result = append(result, largerWidthCounts)
}
return result
}

type WidthCount struct{ width, count int }

func (wc WidthCount) Length() int { return wc.width }
func (wc WidthCount) Count() int { return wc.count }

func adjustByHeader(headers []string, availableWidth int) []int {
nameWidths := slices.Collect(hiter.Map(runewidth.StringWidth, slices.Values(headers)))

adjustWidths, _ := adjustToSum(availableWidth, nameWidths)

return adjustWidths
}

func formatTypedHeaderColumn(field *sppb.StructType_Field) string {
return field.GetName() + "\n" + formatTypeSimple(field.GetType())
}
23 changes: 12 additions & 11 deletions internal/mycli/cli_output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package mycli

import (
"bytes"
"io"
"regexp"
"strings"
"testing"

"github.com/MakeNowJust/heredoc/v2"
"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/format"
)

// Helper functions for common test operations
Expand Down Expand Up @@ -379,13 +379,12 @@ func TestFormatHelpers(t *testing.T) {
t.Parallel()

// Helper to test empty input handling for different formatters
testEmptyFormatter := func(t *testing.T, name string, formatter func(io.Writer, *Result, []string, *systemVariables, int) error) {
testEmptyFormatter := func(t *testing.T, name string, formatter format.FormatFunc) {
t.Helper()
t.Run(name+" with empty input", func(t *testing.T) {
var buf bytes.Buffer
result := &Result{Rows: []Row{}}
sysVars := &systemVariables{SkipColumnNames: false}
err := formatter(&buf, result, []string{}, sysVars, 0)
config := format.FormatConfig{SkipColumnNames: false}
err := formatter(&buf, []Row{}, []string{}, config, 0)
if err != nil {
t.Errorf("expected nil for empty columns, got error: %v", err)
}
Expand All @@ -395,9 +394,12 @@ func TestFormatHelpers(t *testing.T) {
})
}

testEmptyFormatter(t, "formatHTML", formatHTML)
testEmptyFormatter(t, "formatXML", formatXML)
testEmptyFormatter(t, "formatCSV", formatCSV)
htmlFormatter, _ := format.NewFormatter(enums.DisplayModeHTML)
xmlFormatter, _ := format.NewFormatter(enums.DisplayModeXML)
csvFormatter, _ := format.NewFormatter(enums.DisplayModeCSV)
testEmptyFormatter(t, "formatHTML", htmlFormatter)
testEmptyFormatter(t, "formatXML", xmlFormatter)
testEmptyFormatter(t, "formatCSV", csvFormatter)

t.Run("XML with large dataset", func(t *testing.T) {
// Test with a larger dataset to ensure performance
Expand All @@ -412,9 +414,8 @@ func TestFormatHelpers(t *testing.T) {
}

var buf bytes.Buffer
result := &Result{Rows: rows}
sysVars := &systemVariables{SkipColumnNames: false}
err := formatXML(&buf, result, columns, sysVars, 0)
config := format.FormatConfig{SkipColumnNames: false}
err := xmlFormatter(&buf, rows, columns, config, 0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/mycli/execute_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/apstndb/gsqlutils"
"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/format"
"github.com/apstndb/spanner-mycli/internal/mycli/metrics"
"github.com/apstndb/spanvalue"
"github.com/ngicks/go-iterator-helper/hiter"
Expand Down Expand Up @@ -84,7 +85,7 @@ func prepareFormatConfig(sql string, sysVars *systemVariables) (*spanvalue.Forma

// Auto-detect table name if not explicitly set
if sysVars.SQLTableName == "" {
detectedTableName, detectionErr := extractTableNameFromQuery(sql)
detectedTableName, detectionErr := format.ExtractTableNameFromQuery(sql)
if detectedTableName != "" {
// Create a copy of sysVars to use the detected table name for this execution only.
// This is important for:
Expand Down
Loading