From ea6b093c44f7ccf60fcdea811f0d56843cb87e0b Mon Sep 17 00:00:00 2001 From: Adam Brightwell Date: Mon, 23 Feb 2026 18:28:04 -0500 Subject: [PATCH] Add `tsvector` type support Implement PostgreSQL `tsvector` type with support for: - Lexemes with positions and weights (A, B, C, D) - Binary and text format encoding/decoding - Quote and backslash escape handling - Array type support - CopyFrom operations Note: Some escape sequences (doubled quotes, backslash escapes) are PostgreSQL-specific and not supported by CockroachDB. Resolves #2483 --- copy_from_test.go | 101 ++++++++ pgconn/pgconn_test.go | 1 + pgtype/pgtype.go | 2 + pgtype/pgtype_default.go | 3 + pgtype/tsvector.go | 503 +++++++++++++++++++++++++++++++++++++++ pgtype/tsvector_test.go | 465 ++++++++++++++++++++++++++++++++++++ 6 files changed, 1075 insertions(+) create mode 100644 pgtype/tsvector.go create mode 100644 pgtype/tsvector_test.go diff --git a/copy_from_test.go b/copy_from_test.go index 44ca4508a..d3264c122 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxtest" "github.com/stretchr/testify/require" ) @@ -452,6 +453,106 @@ func TestConnCopyFromJSON(t *testing.T) { ensureConnValid(t, conn) } +func TestConnCopyFromTSVector(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + pgxtest.SkipCockroachDB(t, conn, "CockroachDB handles tsvector escaping differently") + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, `create temporary table tmp_tsv (id int, t tsvector)`) + require.NoError(t, err) + + inputRows := [][]any{ + // Text format: core functionality. + {1, `'a':1A 'cat':5 'fat':2B,4C`}, // Multiple lexemes with positions and weights. + {2, `'bare'`}, // Single lexeme with no positions. + {3, `'multi':1,2,3,4,5`}, // Multiple positions (default weight D). + {4, `'test':1A,2B,3C,4D`}, // All four weights on one lexeme. + {5, `'word':1D`}, // Explicit weight D (normalizes to no suffix). + {6, `'high':16383A`}, // High position number (near 14-bit max). + + // Text format: escaping. + {7, `'don''t'`}, // Quote escaping (doubled single quote). + {8, `'don\'t'`}, // Quote escaping (backslash). + {9, `'ab\\c'`}, // Backslash in lexeme. + {10, `'\ foo'`}, // Escaped space. + + // Text format: special characters. + {11, `'café' 'naïve'`}, // Unicode lexemes. + {12, `'a:b' 'c,d'`}, // Delimiter-like characters (colon, comma). + + // Struct format: tests binary encoding path. + {13, pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {Word: "alpha", Positions: []pgtype.TSVectorPosition{{Position: 1, Weight: pgtype.TSVectorWeightA}}}, + {Word: "beta", Positions: []pgtype.TSVectorPosition{{Position: 2, Weight: pgtype.TSVectorWeightB}}}, + {Word: "gamma", Positions: nil}, + }, + Valid: true, + }}, + {14, pgtype.TSVector{Valid: true}}, // Empty valid tsvector (no lexemes). + + // NULL handling. + {15, pgtype.TSVector{Valid: false}}, // Invalid (NULL) TSVector struct. + {16, nil}, // Nil value. + } + + copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"tmp_tsv"}, []string{"id", "t"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, err := conn.Query(ctx, "select id, t::text from tmp_tsv order by id nulls last") + require.NoError(t, err) + + var outputRows [][]any + for rows.Next() { + row, err := rows.Values() + require.NoError(t, err) + outputRows = append(outputRows, row) + } + require.NoError(t, rows.Err()) + + expectedOutputRows := [][]any{ + // Text format: core functionality. + {int32(1), `'a':1A 'cat':5 'fat':2B,4C`}, + {int32(2), `'bare'`}, + {int32(3), `'multi':1,2,3,4,5`}, + {int32(4), `'test':1A,2B,3C,4`}, + {int32(5), `'word':1`}, + {int32(6), `'high':16383A`}, + + // Text format: escaping. + {int32(7), `'don''t'`}, + {int32(8), `'don''t'`}, + {int32(9), `'ab\\c'`}, + {int32(10), `' foo'`}, + + // Text format: special characters. + {int32(11), `'café' 'naïve'`}, + {int32(12), `'a:b' 'c,d'`}, + + // Struct format. + {int32(13), `'alpha':1A 'beta':2B 'gamma'`}, + {int32(14), ``}, + + // NULL handling. + {int32(15), nil}, + {int32(16), nil}, + } + require.Equal(t, expectedOutputRows, outputRows) + + ensureConnValid(t, conn) +} + type clientFailSource struct { count int err error diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 33ba05d90..a0debce25 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -170,6 +170,7 @@ func TestConnectOAuthError(t *testing.T) { _, err = pgconn.ConnectConfig(context.Background(), config) require.Error(t, err, "connect should return error for invalid token") } + func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) { t.Parallel() diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 991a91125..536bc3c8f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -96,6 +96,8 @@ const ( RecordArrayOID = 2287 UUIDOID = 2950 UUIDArrayOID = 2951 + TSVectorOID = 3614 + TSVectorArrayOID = 3643 JSONBOID = 3802 JSONBArrayOID = 3807 DaterangeOID = 3912 diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index 5648d89bf..42b39d827 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -81,6 +81,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "tsvector", OID: TSVectorOID, Codec: TSVectorCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) @@ -164,6 +165,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsvector", OID: TSVectorArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TSVectorOID]}}) defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) @@ -242,6 +244,7 @@ func initDefaultMap() { registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[TSVector](defaultMap, "tsvector") registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") defaultMap.buildReflectTypeToType() diff --git a/pgtype/tsvector.go b/pgtype/tsvector.go new file mode 100644 index 000000000..8a91b1e75 --- /dev/null +++ b/pgtype/tsvector.go @@ -0,0 +1,503 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type TSVectorScanner interface { + ScanTSVector(TSVector) error +} + +type TSVectorValuer interface { + TSVectorValue() (TSVector, error) +} + +// TSVector represents a PostgreSQL tsvector value. +type TSVector struct { + Lexemes []TSVectorLexeme + Valid bool +} + +// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions. +type TSVectorLexeme struct { + Word string + Positions []TSVectorPosition +} + +// ScanTSVector implements the [TSVectorScanner] interface. +func (t *TSVector) ScanTSVector(v TSVector) error { + *t = v + return nil +} + +// TSVectorValue implements the [TSVectorValuer] interface. +func (t TSVector) TSVectorValue() (TSVector, error) { + return t, nil +} + +func (t TSVector) String() string { + buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil) + return string(buf) +} + +// Scan implements the [database/sql.Scanner] interface. +func (t *TSVector) Scan(src any) error { + if src == nil { + *t = TSVector{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToTSVectorScanner{}.scanString(src, t) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (t TSVector) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + + return string(buf), nil +} + +// TSVectorWeight represents the weight label of a lexeme position in a tsvector. +type TSVectorWeight byte + +const ( + TSVectorWeightA = TSVectorWeight('A') + TSVectorWeightB = TSVectorWeight('B') + TSVectorWeightC = TSVectorWeight('C') + TSVectorWeightD = TSVectorWeight('D') +) + +// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL. +func tsvectorWeightToBinary(w TSVectorWeight) uint16 { + switch w { + case TSVectorWeightA: + return 3 + case TSVectorWeightB: + return 2 + case TSVectorWeightC: + return 1 + default: + return 0 // D or unset + } +} + +// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight. +func tsvectorWeightFromBinary(b uint16) TSVectorWeight { + switch b { + case 3: + return TSVectorWeightA + case 2: + return TSVectorWeightB + case 1: + return TSVectorWeightC + default: + return TSVectorWeightD + } +} + +// TSVectorPosition represents a lexeme position and its optional weight within a tsvector. +type TSVectorPosition struct { + Position uint16 + Weight TSVectorWeight +} + +func (p TSVectorPosition) String() string { + s := strconv.FormatUint(uint64(p.Position), 10) + if p.Weight != 0 && p.Weight != TSVectorWeightD { + s += string(p.Weight) + } + return s +} + +type TSVectorCodec struct{} + +func (TSVectorCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TSVectorCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(TSVectorValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTSVectorCodecBinary{} + case TextFormatCode: + return encodePlanTSVectorCodecText{} + } + + return nil +} + +type encodePlanTSVectorCodecBinary struct{} + +func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes))) + + for _, entry := range tsv.Lexemes { + buf = append(buf, entry.Word...) + buf = append(buf, 0x00) + buf = pgio.AppendUint16(buf, uint16(len(entry.Positions))) + + // Each position is a uint16: weight (2 bits) | position (14 bits) + for _, pos := range entry.Positions { + packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF + buf = pgio.AppendUint16(buf, packed) + } + } + + return buf, nil +} + +type scanPlanBinaryTSVectorToTSVectorScanner struct{} + +func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + rp := 0 + + const ( + uint16Len = 2 + uint32Len = 4 + ) + + if len(src[rp:]) < uint32Len { + return fmt.Errorf("tsvector incomplete %v", src) + } + entryCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += uint32Len + + var tsv TSVector + if entryCount > 0 { + tsv.Lexemes = make([]TSVectorLexeme, entryCount) + } + + for i := range entryCount { + nullIndex := bytes.IndexByte(src[rp:], 0x00) + if nullIndex == -1 { + return fmt.Errorf("invalid tsvector binary format: missing null terminator") + } + + lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])} + rp += nullIndex + 1 // skip past null terminator + + // Read position count. + if len(src[rp:]) < uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete position count") + } + + numPositions := int(binary.BigEndian.Uint16(src[rp:])) + rp += uint16Len + + // Read each packed position: weight (2 bits) | position (14 bits) + if len(src[rp:]) < numPositions*uint16Len { + return fmt.Errorf("invalid tsvector binary format: incomplete positions") + } + + if numPositions > 0 { + lexeme.Positions = make([]TSVectorPosition, numPositions) + for pos := range numPositions { + packed := binary.BigEndian.Uint16(src[rp:]) + rp += uint16Len + lexeme.Positions[pos] = TSVectorPosition{ + Position: packed & 0x3FFF, + Weight: tsvectorWeightFromBinary(packed >> 14), + } + } + } + + tsv.Lexemes[i] = lexeme + } + tsv.Valid = true + + return scanner.ScanTSVector(tsv) +} + +var tsvectorLexemeReplacer = strings.NewReplacer( + `\`, `\\`, + `'`, `\'`, +) + +type encodePlanTSVectorCodecText struct{} + +func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) { + tsv, err := value.(TSVectorValuer).TSVectorValue() + if err != nil { + return nil, err + } + + if !tsv.Valid { + return nil, nil + } + + for i, lex := range tsv.Lexemes { + if i > 0 { + buf = append(buf, ' ') + } + + buf = append(buf, '\'') + buf = append(buf, tsvectorLexemeReplacer.Replace(lex.Word)...) + buf = append(buf, '\'') + + sep := byte(':') + for _, p := range lex.Positions { + buf = append(buf, sep) + buf = append(buf, p.String()...) + sep = ',' + } + } + + return buf, nil +} + +func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanBinaryTSVectorToTSVectorScanner{} + } + case TextFormatCode: + switch target.(type) { + case TSVectorScanner: + return scanPlanTextAnyToTSVectorScanner{} + } + } + + return nil +} + +type scanPlanTextAnyToTSVectorScanner struct{} + +func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(TSVectorScanner) + + if src == nil { + return scanner.ScanTSVector(TSVector{}) + } + + return s.scanString(string(src), scanner) +} + +func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error { + tsv, err := parseTSVector(src) + if err != nil { + return err + } + return scanner.ScanTSVector(tsv) +} + +func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var tsv TSVector + err := codecScan(c, m, oid, format, src, &tsv) + if err != nil { + return nil, err + } + return tsv, nil +} + +type tsvectorParser struct { + str string + pos int +} + +func (p *tsvectorParser) atEnd() bool { + return p.pos >= len(p.str) +} + +func (p *tsvectorParser) peek() byte { + return p.str[p.pos] +} + +func (p *tsvectorParser) consume() (byte, bool) { + if p.pos >= len(p.str) { + return 0, true + } + b := p.str[p.pos] + p.pos++ + return b, false +} + +func (p *tsvectorParser) consumeSpaces() { + for !p.atEnd() && p.peek() == ' ' { + p.consume() + } +} + +// consumeLexeme consumes a single-quoted lexeme, handling single quotes and backslash escapes. +func (p *tsvectorParser) consumeLexeme() (string, error) { + ch, end := p.consume() + if end || ch != '\'' { + return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote") + } + + var buf strings.Builder + for { + ch, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme") + } + + switch ch { + case '\'': + // Escaped quote ('') — write a literal single quote + if !p.atEnd() && p.peek() == '\'' { + p.consume() + buf.WriteByte('\'') + } else { + // Closing quote — lexeme is complete + return buf.String(), nil + } + case '\\': + next, end := p.consume() + if end { + return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash") + } + buf.WriteByte(next) + default: + buf.WriteByte(ch) + } + } +} + +// consumePositions consumes a comma-separated list of position[weight] values. +func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) { + var positions []TSVectorPosition + + for { + pos, err := p.consumePosition() + if err != nil { + return nil, err + } + positions = append(positions, pos) + + if p.atEnd() || p.peek() != ',' { + break + } + + p.consume() // skip ',' + } + + return positions, nil +} + +// consumePosition consumes a single position number with optional weight letter. +func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) { + start := p.pos + + for !p.atEnd() && p.peek() >= '0' && p.peek() <= '9' { + p.consume() + } + + if p.pos == start { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number") + } + + num, err := strconv.ParseUint(p.str[start:p.pos], 10, 16) + if err != nil { + return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", p.str[start:p.pos]) + } + + pos := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD} + + // Check for optional weight letter + if !p.atEnd() { + switch p.peek() { + case 'A', 'a': + pos.Weight = TSVectorWeightA + case 'B', 'b': + pos.Weight = TSVectorWeightB + case 'C', 'c': + pos.Weight = TSVectorWeightC + case 'D', 'd': + pos.Weight = TSVectorWeightD + default: + return pos, nil + } + p.consume() + } + + return pos, nil +} + +// parseTSVector parses a PostgreSQL tsvector text representation. +func parseTSVector(s string) (TSVector, error) { + result := TSVector{} + p := &tsvectorParser{str: strings.TrimSpace(s), pos: 0} + + for !p.atEnd() { + p.consumeSpaces() + if p.atEnd() { + break + } + + word, err := p.consumeLexeme() + if err != nil { + return TSVector{}, err + } + + entry := TSVectorLexeme{Word: word} + + // Check for optional positions after ':' + if !p.atEnd() && p.peek() == ':' { + p.consume() // skip ':' + + positions, err := p.consumePositions() + if err != nil { + return TSVector{}, err + } + entry.Positions = positions + } + + result.Lexemes = append(result.Lexemes, entry) + } + + result.Valid = true + + return result, nil +} diff --git a/pgtype/tsvector_test.go b/pgtype/tsvector_test.go new file mode 100644 index 000000000..e0d48ec9f --- /dev/null +++ b/pgtype/tsvector_test.go @@ -0,0 +1,465 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxtest" +) + +func isExpectedEqTSVector(a any) func(any) bool { + return func(v any) bool { + at := a.(pgtype.TSVector) + vt := v.(pgtype.TSVector) + + if len(at.Lexemes) != len(vt.Lexemes) { + return false + } + + if at.Valid != vt.Valid { + return false + } + + for i := range at.Lexemes { + atLexeme := at.Lexemes[i] + vtLexeme := vt.Lexemes[i] + + if atLexeme.Word != vtLexeme.Word { + return false + } + + if len(atLexeme.Positions) != len(vtLexeme.Positions) { + return false + } + + for j := range atLexeme.Positions { + if atLexeme.Positions[j] != vtLexeme.Positions[j] { + return false + } + } + } + + return true + } +} + +func tsvectorConnTestRunner(t *testing.T) pgxtest.ConnTestRunner { + ctr := defaultConnTestRunner + ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var tsvectorOID uint32 + err := conn.QueryRow(context.Background(), `select oid from pg_type where typname = 'tsvector'`).Scan(&tsvectorOID) + if err != nil { + t.Skipf("Skipping; cannot find tsvector OID") + } + + conn.TypeMap().RegisterType(&pgtype.Type{Name: "tsvector", OID: tsvectorOID, Codec: pgtype.TSVectorCodec{}}) + } + return ctr +} + +func TestTSVectorCodecBinary(t *testing.T) { + t.Run("Core", func(t *testing.T) { + tests := []pgxtest.ValueRoundTripTest{ + // NULL. + { + Param: pgtype.TSVector{}, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{}), + }, + // Empty but valid tsvector (no lexemes). + { + Param: pgtype.TSVector{Valid: true}, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{Valid: true}), + }, + // Single lexeme with no positions. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{{"fat", nil}}, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{{"fat", nil}}, + Valid: true, + }), + }, + // Multiple lexemes with positions and weights. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"cat", []pgtype.TSVectorPosition{{1, pgtype.TSVectorWeightA}}}, + {"dog", []pgtype.TSVectorPosition{{2, pgtype.TSVectorWeightB}}}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"cat", []pgtype.TSVectorPosition{{1, pgtype.TSVectorWeightA}}}, + {"dog", []pgtype.TSVectorPosition{{2, pgtype.TSVectorWeightB}}}, + }, + Valid: true, + }), + }, + // All four weight types (A, B, C, D) on a single lexeme. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"word", []pgtype.TSVectorPosition{ + {1, pgtype.TSVectorWeightA}, + {2, pgtype.TSVectorWeightB}, + {3, pgtype.TSVectorWeightC}, + {4, pgtype.TSVectorWeightD}, + }}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"word", []pgtype.TSVectorPosition{ + {1, pgtype.TSVectorWeightA}, + {2, pgtype.TSVectorWeightB}, + {3, pgtype.TSVectorWeightC}, + {4, pgtype.TSVectorWeightD}, + }}, + }, + Valid: true, + }), + }, + // Multiple positions per lexeme. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"spaceship", []pgtype.TSVectorPosition{ + {2, pgtype.TSVectorWeightD}, + {33, pgtype.TSVectorWeightA}, + {34, pgtype.TSVectorWeightB}, + {35, pgtype.TSVectorWeightC}, + {36, pgtype.TSVectorWeightD}, + }}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"spaceship", []pgtype.TSVectorPosition{ + {2, pgtype.TSVectorWeightD}, + {33, pgtype.TSVectorWeightA}, + {34, pgtype.TSVectorWeightB}, + {35, pgtype.TSVectorWeightC}, + {36, pgtype.TSVectorWeightD}, + }}, + }, + Valid: true, + }), + }, + // Lexeme word containing a space. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"1 2", nil}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"1 2", nil}, + }, + Valid: true, + }), + }, + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, tsvectorConnTestRunner(t), pgxtest.KnownOIDQueryExecModes, "tsvector", tests) + }) + + t.Run("SpecialCharacters", func(t *testing.T) { + tests := []pgxtest.ValueRoundTripTest{ + // Lexeme words containing a single quote. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"D'Artagnan", []pgtype.TSVectorPosition{}}, + {"cats'", []pgtype.TSVectorPosition{}}, + {"don't", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"D'Artagnan", []pgtype.TSVectorPosition{}}, + {"cats'", []pgtype.TSVectorPosition{}}, + {"don't", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + // Unicode lexemes. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"café", []pgtype.TSVectorPosition{}}, + {"naïve", []pgtype.TSVectorPosition{}}, + {"日本語", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"café", []pgtype.TSVectorPosition{}}, + {"naïve", []pgtype.TSVectorPosition{}}, + {"日本語", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + // Lexeme words containing backslashes. + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {`ab\c`, []pgtype.TSVectorPosition{}}, + {`back\slash`, []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {`ab\c`, []pgtype.TSVectorPosition{}}, + {`back\slash`, []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + // Lexeme words containing delimiter characters (colon, comma). + { + Param: pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"a:b", []pgtype.TSVectorPosition{}}, + {"c,d", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"a:b", []pgtype.TSVectorPosition{}}, + {"c,d", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, tsvectorConnTestRunner(t), pgxtest.KnownOIDQueryExecModes, "tsvector", tests) + }) +} + +func TestTSVectorCodecText(t *testing.T) { + t.Run("Core", func(t *testing.T) { + tests := []pgxtest.ValueRoundTripTest{ + // NULL. + { + Param: pgtype.TSVector{}, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{}), + }, + // Empty but valid tsvector (no lexemes). + { + Param: pgtype.TSVector{Valid: true}, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{Valid: true}), + }, + // Single lexeme with no positions. + { + Param: "'fat'", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{{"fat", nil}}, + Valid: true, + }), + }, + // Multiple lexemes with positions and weights. + { + Param: "'cat':1A 'dog':2B", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"cat", []pgtype.TSVectorPosition{{1, pgtype.TSVectorWeightA}}}, + {"dog", []pgtype.TSVectorPosition{{2, pgtype.TSVectorWeightB}}}, + }, + Valid: true, + }), + }, + // All four weight types (A, B, C, D) on a single lexeme. + { + Param: "'word':1A,2B,3C,4D", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"word", []pgtype.TSVectorPosition{ + {1, pgtype.TSVectorWeightA}, + {2, pgtype.TSVectorWeightB}, + {3, pgtype.TSVectorWeightC}, + {4, pgtype.TSVectorWeightD}, + }}, + }, + Valid: true, + }), + }, + // Multiple positions per lexeme. + { + Param: "'spaceship':2,33A,34B,35C,36D", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"spaceship", []pgtype.TSVectorPosition{ + {2, pgtype.TSVectorWeightD}, + {33, pgtype.TSVectorWeightA}, + {34, pgtype.TSVectorWeightB}, + {35, pgtype.TSVectorWeightC}, + {36, pgtype.TSVectorWeightD}, + }}, + }, + Valid: true, + }), + }, + // Lowercase weight letters are accepted and normalized to uppercase. + { + Param: "'cat':2b", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"cat", []pgtype.TSVectorPosition{{2, pgtype.TSVectorWeightB}}}, + }, + Valid: true, + }), + }, + // Leading and trailing whitespace is trimmed. + { + Param: " 'fat' ", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"fat", nil}, + }, + Valid: true, + }), + }, + // Lexeme word containing a space. + { + Param: "'1 2'", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"1 2", nil}, + }, + Valid: true, + }), + }, + // Backslash quote escape (\'). + { + Param: `'D\'Artagnan' 'cats\'' 'don\'t'`, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"D'Artagnan", []pgtype.TSVectorPosition{}}, + {"cats'", []pgtype.TSVectorPosition{}}, + {"don't", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + // Lexeme words containing delimiter characters (colon, comma). + { + Param: `'a:b' 'c,d'`, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"a:b", []pgtype.TSVectorPosition{}}, + {"c,d", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.AllQueryExecModes, "tsvector", tests) + }) + + t.Run("SpecialCharacters", func(t *testing.T) { + tests := []pgxtest.ValueRoundTripTest{ + // Unicode lexemes. + { + Param: "'café' 'naïve' '日本語'", + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"café", []pgtype.TSVectorPosition{}}, + {"naïve", []pgtype.TSVectorPosition{}}, + {"日本語", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + // Escaped space in lexeme word. + { + Param: `'\ '`, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {" ", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.AllQueryExecModes, "tsvector", tests) + }) + + t.Run("PostgreSQL", func(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support these escape sequences in tsvector") + + tests := []pgxtest.ValueRoundTripTest{ + // Doubled quote escape (''). + { + Param: `'D''Artagnan' 'cats''' 'don''t'`, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"D'Artagnan", []pgtype.TSVectorPosition{}}, + {"cats'", []pgtype.TSVectorPosition{}}, + {"don't", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + // Escaped backslashes in lexeme words. + { + Param: `'AB\\\c' '\\as' 'ab\\\\c' 'ab\\c' 'abc'`, + Result: new(pgtype.TSVector), + Test: isExpectedEqTSVector(pgtype.TSVector{ + Lexemes: []pgtype.TSVectorLexeme{ + {"AB\\c", []pgtype.TSVectorPosition{}}, + {"\\as", []pgtype.TSVectorPosition{}}, + {"ab\\\\c", []pgtype.TSVectorPosition{}}, + {"ab\\c", []pgtype.TSVectorPosition{}}, + {"abc", []pgtype.TSVectorPosition{}}, + }, + Valid: true, + }), + }, + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.AllQueryExecModes, "tsvector", tests) + }) +}