Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/mycli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
tcspanner "github.com/testcontainers/testcontainers-go/modules/gcloud/spanner"

"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/streamio"

"github.com/cloudspannerecosystem/memefish"

Expand Down Expand Up @@ -573,7 +574,7 @@ func run(ctx context.Context, opts *spannerOptions) error {
var originalOut io.Writer = os.Stdout

// Create StreamManager for managing all input/output streams
streamManager := NewStreamManager(os.Stdin, originalOut, errStream)
streamManager := streamio.NewStreamManager(os.Stdin, originalOut, errStream)
// StreamManager will automatically detect if os.Stdout is a TTY
if term.IsTerminal(int(os.Stdout.Fd())) {
streamManager.SetTtyStream(os.Stdout)
Expand Down
6 changes: 4 additions & 2 deletions internal/mycli/cli_current_width_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"os"
"strconv"
"testing"

"github.com/apstndb/spanner-mycli/internal/mycli/streamio"
)

func TestCliCurrentWidthWithTee(t *testing.T) {
Expand All @@ -15,7 +17,7 @@ func TestCliCurrentWidthWithTee(t *testing.T) {

t.Run("with TtyStream in StreamManager", func(t *testing.T) {
// Setup StreamManager with a buffer for tee output
teeManager := NewStreamManager(os.Stdin, os.Stdout, os.Stderr)
teeManager := streamio.NewStreamManager(os.Stdin, os.Stdout, os.Stderr)
teeManager.SetTtyStream(os.Stdout)

sysVars := &systemVariables{
Expand All @@ -42,7 +44,7 @@ func TestCliCurrentWidthWithTee(t *testing.T) {
t.Run("without TtyStream and non-file stream", func(t *testing.T) {
// Setup StreamManager with non-TTY output
consoleBuf := &bytes.Buffer{}
teeManager := NewStreamManager(io.NopCloser(bytes.NewReader(nil)), consoleBuf, consoleBuf)
teeManager := streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), consoleBuf, consoleBuf)
// Do not set TTY stream

sysVars := &systemVariables{
Expand Down
3 changes: 2 additions & 1 deletion internal/mycli/cli_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/apstndb/lox"
"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/decoder"
"github.com/apstndb/spanner-mycli/internal/mycli/format"
"github.com/apstndb/spanner-mycli/internal/mycli/metrics"
"github.com/go-sprout/sprout"
Expand Down Expand Up @@ -280,5 +281,5 @@ func resultLine(outputTemplate *template.Template, result *Result, verbose bool)
}

func formatTypedHeaderColumn(field *sppb.StructType_Field) string {
return field.GetName() + "\n" + formatTypeSimple(field.GetType())
return field.GetName() + "\n" + decoder.FormatTypeSimple(field.GetType())
}
15 changes: 8 additions & 7 deletions internal/mycli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

"github.com/apstndb/spanner-mycli/internal/mycli/streamio"
"github.com/apstndb/spanner-mycli/internal/protostruct"
)

Expand Down Expand Up @@ -1122,7 +1123,7 @@ func TestCli_handleExit(t *testing.T) {
t.Parallel()
outBuf := &bytes.Buffer{}
sysVars := &systemVariables{
StreamManager: NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, outBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, outBuf),
}
cli := &Cli{
SessionHandler: NewSessionHandler(&Session{}), // Dummy session, Close() is now safe with nil client
Expand Down Expand Up @@ -1163,7 +1164,7 @@ func TestCli_ExitOnError(t *testing.T) {
t.Run(tt.desc, func(t *testing.T) {
errBuf := &bytes.Buffer{}
sysVars := &systemVariables{
StreamManager: NewStreamManager(io.NopCloser(bytes.NewReader(nil)), errBuf, errBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), errBuf, errBuf),
}
cli := &Cli{
SessionHandler: NewSessionHandler(&Session{}), // Dummy session, Close() is now safe with nil client
Expand Down Expand Up @@ -1266,7 +1267,7 @@ func TestCli_handleSpecialStatements(t *testing.T) {
}

// Create StreamManager with the test streams
sysVars.StreamManager = NewStreamManager(
sysVars.StreamManager = streamio.NewStreamManager(
io.NopCloser(strings.NewReader(tt.confirmInput)), // InStream for confirm
outBuf,
errBuf,
Expand Down Expand Up @@ -1323,7 +1324,7 @@ func TestCli_PrintResult(t *testing.T) {
sysVars := &systemVariables{
UsePager: tt.usePager,
CLIFormat: enums.DisplayModeTab, // Use TAB format for predictable output
StreamManager: NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, outBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, outBuf),
}
cli := &Cli{
SystemVariables: sysVars,
Expand Down Expand Up @@ -1367,7 +1368,7 @@ func TestCli_PrintBatchError(t *testing.T) {
t.Run(tt.desc, func(t *testing.T) {
errBuf := &bytes.Buffer{}
sysVars := &systemVariables{
StreamManager: NewStreamManager(io.NopCloser(bytes.NewReader(nil)), errBuf, errBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), errBuf, errBuf),
}
cli := &Cli{
SystemVariables: sysVars,
Expand Down Expand Up @@ -1553,7 +1554,7 @@ func TestCli_executeSourceFile(t *testing.T) {
sysVars := &systemVariables{
BuildStatementMode: enums.ParseModeFallback,
CLIFormat: enums.DisplayModeTab,
StreamManager: NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, outBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, outBuf),
}
session := &Session{systemVariables: sysVars}

Expand Down Expand Up @@ -1639,7 +1640,7 @@ func TestCli_executeSourceFile_FileTooLarge(t *testing.T) {
outBuf := &bytes.Buffer{}
errBuf := &bytes.Buffer{}
sysVars := &systemVariables{
StreamManager: NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, errBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(bytes.NewReader(nil)), outBuf, errBuf),
}
cli := &Cli{
SessionHandler: NewSessionHandler(&Session{}),
Expand Down
16 changes: 10 additions & 6 deletions internal/mycli/decoder.go → internal/mycli/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// limitations under the License.
//

package mycli
package decoder

import (
"cmp"
Expand All @@ -40,7 +40,11 @@ func DecodeRow(row *spanner.Row) ([]string, error) {
return spanvalue.FormatRowSpannerCLICompatible(row)
}

func formatConfigWithProto(fds *descriptorpb.FileDescriptorSet, multiline bool) (*spanvalue.FormatConfig, error) {
// FormatConfigWithProto creates a format configuration for decoding Spanner values,
// with special handling for PROTO and ENUM types based on the provided protobuf file descriptor set.
// If fds is nil, it returns a config without custom proto/enum support.
// The multiline parameter controls the formatting of protobuf messages.
func FormatConfigWithProto(fds *descriptorpb.FileDescriptorSet, multiline bool) (*spanvalue.FormatConfig, error) {
types, err := dynamicTypesByFDS(fds)
if err != nil {
return nil, err
Expand Down Expand Up @@ -134,8 +138,8 @@ func formatEnum(types protoEnumResolver) func(formatter spanvalue.Formatter, val
}
}

// formatTypeSimple is format type for headers.
func formatTypeSimple(typ *sppb.Type) string {
// FormatTypeSimple is format type for headers.
func FormatTypeSimple(typ *sppb.Type) string {
return spantype.FormatType(typ, spantype.FormatOption{
Struct: spantype.StructModeBase,
Proto: spantype.ProtoEnumModeLeafWithKind,
Expand All @@ -144,7 +148,7 @@ func formatTypeSimple(typ *sppb.Type) string {
})
}

// formatTypeVerbose is format type for DESCRIBE.
func formatTypeVerbose(typ *sppb.Type) string {
// FormatTypeVerbose is format type for DESCRIBE.
func FormatTypeVerbose(typ *sppb.Type) string {
return spantype.FormatTypeMoreVerbose(typ)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// limitations under the License.
//

package mycli
package decoder

import (
"math/big"
Expand Down Expand Up @@ -399,7 +399,7 @@ func TestDecodeColumn(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
got, err := lo.Must(formatConfigWithProto(test.fds, test.multiline)).FormatToplevelColumn(createColumnValue(t, test.value))
got, err := lo.Must(FormatConfigWithProto(test.fds, test.multiline)).FormatToplevelColumn(createColumnValue(t, test.value))
if err != nil {
t.Error(err)
}
Expand All @@ -410,7 +410,7 @@ func TestDecodeColumn(t *testing.T) {
t.Error(err)
}
if diff := cmp.Diff(nm.Interface(), test.wantMessage, protocmp.Transform()); diff != "" {
t.Errorf("formatConfigWithProto(%v) mismatch (-got +want):\n%s", test.value, diff)
t.Errorf("FormatConfigWithProto(%v) mismatch (-got +want):\n%s", test.value, diff)
}
} else {
if got != test.want {
Expand Down Expand Up @@ -442,15 +442,15 @@ func TestDecodeColumnRoundTripEnum(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
got, err := lo.Must(formatConfigWithProto(test.fds, false)).FormatToplevelColumn(createColumnValue(t, test.value))
got, err := lo.Must(FormatConfigWithProto(test.fds, false)).FormatToplevelColumn(createColumnValue(t, test.value))
if err != nil {
t.Error(err)
}

gotEnumValue := test.want.Type().Descriptor().Values().ByName(protoreflect.Name(got))
gotNumber := gotEnumValue.Number()
if gotNumber != test.want.Number() {
t.Errorf("formatConfigWithProto(%v): %v(%v), want: %v(%v)", test.value, gotEnumValue.Name(), gotNumber, test.want, test.want.Number())
t.Errorf("FormatConfigWithProto(%v): %v(%v), want: %v(%v)", test.value, gotEnumValue.Name(), gotNumber, test.want, test.want.Number())
}
})
}
Expand Down Expand Up @@ -482,7 +482,7 @@ func TestDecodeColumnRoundTripProto(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
got, err := lo.Must(formatConfigWithProto(test.fds, test.multiline)).FormatToplevelColumn(createColumnValue(t, test.value))
got, err := lo.Must(FormatConfigWithProto(test.fds, test.multiline)).FormatToplevelColumn(createColumnValue(t, test.value))
if err != nil {
t.Error(err)
}
Expand All @@ -492,7 +492,7 @@ func TestDecodeColumnRoundTripProto(t *testing.T) {
t.Error(err)
}
if diff := cmp.Diff(nm.Interface(), test.want, protocmp.Transform()); diff != "" {
t.Errorf("formatConfigWithProto(%v) mismatch (-got +want):\n%s", test.value, diff)
t.Errorf("FormatConfigWithProto(%v) mismatch (-got +want):\n%s", test.value, diff)
}
})
}
Expand Down Expand Up @@ -549,9 +549,9 @@ func TestFormatTypeVerbose(t *testing.T) {
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
got := formatTypeVerbose(test.sppbType)
got := FormatTypeVerbose(test.sppbType)
if diff := cmp.Diff(got, test.want); diff != "" {
t.Errorf("formatTypeVerbose(%v) mismatch (-got +want):\n%s", test.sppbType, diff)
t.Errorf("FormatTypeVerbose(%v) mismatch (-got +want):\n%s", test.sppbType, diff)
}
})
}
Expand Down
5 changes: 3 additions & 2 deletions internal/mycli/execute_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/apstndb/gsqlutils"
"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/decoder"
"github.com/apstndb/spanner-mycli/internal/mycli/format"
"github.com/apstndb/spanner-mycli/internal/mycli/metrics"
"github.com/apstndb/spanvalue"
Expand Down Expand Up @@ -104,7 +105,7 @@ func prepareFormatConfig(sql string, sysVars *systemVariables) (*spanvalue.Forma
default:
// Use regular display formatting for other modes
// formatConfigWithProto handles custom proto descriptors if set
fc, err = formatConfigWithProto(sysVars.ProtoDescriptor, sysVars.MultilineProtoText)
fc, err = decoder.FormatConfigWithProto(sysVars.ProtoDescriptor, sysVars.MultilineProtoText)
usingSQLLiterals = false
}

Expand Down Expand Up @@ -726,7 +727,7 @@ func spannerRowToRow(fc *spanvalue.FormatConfig) func(row *spanner.Row) (Row, er
}

func runPartitionedQuery(ctx context.Context, session *Session, sql string) (*Result, error) {
fc, err := formatConfigWithProto(session.systemVariables.ProtoDescriptor, session.systemVariables.MultilineProtoText)
fc, err := decoder.FormatConfigWithProto(session.systemVariables.ProtoDescriptor, session.systemVariables.MultilineProtoText)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion internal/mycli/integration_dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/apstndb/spanner-mycli/enums"
"github.com/apstndb/spanner-mycli/internal/mycli/streamio"
"github.com/google/go-cmp/cmp"
)

Expand Down Expand Up @@ -267,7 +268,7 @@ func TestDumpWithStreaming(t *testing.T) {
// Replace the session's output stream with our buffer
// This simulates streaming mode with captured output
originalStream := session.systemVariables.StreamManager
session.systemVariables.StreamManager = NewStreamManager(
session.systemVariables.StreamManager = streamio.NewStreamManager(
originalStream.GetInStream(),
&buf, // Use our buffer as output
originalStream.GetErrStream(),
Expand Down
7 changes: 4 additions & 3 deletions internal/mycli/integration_mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/apstndb/spanemuboost"
"github.com/apstndb/spanner-mycli/internal/mycli/streamio"
"github.com/cloudspannerecosystem/memefish/ast"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/samber/lo"
Expand All @@ -26,7 +27,7 @@ func setupMCPClientServer(t *testing.T, ctx context.Context, session *Session) (
session.systemVariables.Verbose = true // Set Verbose to true to ensure result line is printed

// Update the session's StreamManager to use the output buffer
session.systemVariables.StreamManager = NewStreamManager(io.NopCloser(strings.NewReader("")), &outputBuf, &outputBuf)
session.systemVariables.StreamManager = streamio.NewStreamManager(io.NopCloser(strings.NewReader("")), &outputBuf, &outputBuf)

cli := &Cli{
SessionHandler: NewSessionHandler(session),
Expand Down Expand Up @@ -252,7 +253,7 @@ func testRunMCPWithNonExistentDatabase(t *testing.T) {
defer func() { _ = pipeWriter.Close() }()

// Create StreamManager with the pipe for input
sysVarsNonExistent.StreamManager = NewStreamManager(pipeReader, &outputBuf, &outputBuf)
sysVarsNonExistent.StreamManager = streamio.NewStreamManager(pipeReader, &outputBuf, &outputBuf)

cli, err := NewCli(ctx, nil, &sysVarsNonExistent)
if err != nil {
Expand Down Expand Up @@ -444,7 +445,7 @@ func TestRunMCP(t *testing.T) {
StatementTimeout: lo.ToPtr(1 * time.Hour), // Long timeout for integration tests
AutoWrap: true, // Set a different value
EnableHighlight: true, // Set a different value
StreamManager: NewStreamManager(io.NopCloser(strings.NewReader("")), &outputBuf, &outputBuf),
StreamManager: streamio.NewStreamManager(io.NopCloser(strings.NewReader("")), &outputBuf, &outputBuf),
}
cli := &Cli{
SessionHandler: NewSessionHandler(session),
Expand Down
14 changes: 8 additions & 6 deletions internal/mycli/integration_meta_command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
"path/filepath"
"strings"
"testing"

"github.com/apstndb/spanner-mycli/internal/mycli/streamio"
)

// createTestCli creates a test CLI with the given input, output, and error streams
func createTestCli(t *testing.T, ctx context.Context, input io.ReadCloser, output, errOutput io.Writer, sysVars *systemVariables) *Cli {
// Create StreamManager with the provided streams
sysVars.StreamManager = NewStreamManager(input, output, errOutput)
sysVars.StreamManager = streamio.NewStreamManager(input, output, errOutput)

cli, err := NewCli(ctx, nil, sysVars)
if err != nil {
Expand Down Expand Up @@ -178,7 +180,7 @@ SELECT "foo" AS s;`
output := &bytes.Buffer{}

// Create StreamManager with the test streams
sysVars.StreamManager = NewStreamManager(io.NopCloser(input), output, output)
sysVars.StreamManager = streamio.NewStreamManager(io.NopCloser(input), output, output)

cli := &Cli{
SessionHandler: sessionHandler,
Expand Down Expand Up @@ -224,7 +226,7 @@ SELECT "foo" AS s;`
output := &bytes.Buffer{}

// Create StreamManager with the test streams
sysVars.StreamManager = NewStreamManager(io.NopCloser(input), output, output)
sysVars.StreamManager = streamio.NewStreamManager(io.NopCloser(input), output, output)

cli := &Cli{
SessionHandler: sessionHandler,
Expand Down Expand Up @@ -291,7 +293,7 @@ SELECT "foo" AS s;`
output := &bytes.Buffer{}

// Create StreamManager with the test streams
sysVars.StreamManager = NewStreamManager(io.NopCloser(input), output, output)
sysVars.StreamManager = streamio.NewStreamManager(io.NopCloser(input), output, output)

cli := &Cli{
SessionHandler: sessionHandler,
Expand Down Expand Up @@ -342,7 +344,7 @@ SELECT "foo" AS s;`
output := &bytes.Buffer{}

// Create StreamManager with the test streams
sysVars.StreamManager = NewStreamManager(io.NopCloser(input), output, output)
sysVars.StreamManager = streamio.NewStreamManager(io.NopCloser(input), output, output)

cli := &Cli{
SessionHandler: sessionHandler,
Expand Down Expand Up @@ -493,7 +495,7 @@ SELECT "foo" AS s;`
input := strings.NewReader(commands + "\nexit;\n")

// Create StreamManager with the test streams
sysVars.StreamManager = NewStreamManager(io.NopCloser(input), consoleBuf, consoleBuf)
sysVars.StreamManager = streamio.NewStreamManager(io.NopCloser(input), consoleBuf, consoleBuf)

cli := &Cli{
SessionHandler: sessionHandler,
Expand Down
Loading