diff --git a/crates/lib-dialects/src/postgres.rs b/crates/lib-dialects/src/postgres.rs
index 4531958e8..304cdbec3 100644
--- a/crates/lib-dialects/src/postgres.rs
+++ b/crates/lib-dialects/src/postgres.rs
@@ -6710,3 +6710,174 @@ pub fn statement_segment() -> Matchable {
         false,
     )
 }
+
+/// Adds pgvector type support to the postgres dialect.
+///
+/// Registers `VECTOR`, `HALFVEC` and `SPARSEVEC` as unreserved keywords and replaces the
+/// `DatatypeSegment` grammar with one that also accepts those types, each followed by an
+/// optional `BracketedArguments` (the dimension, e.g. `VECTOR(3)`).
+///
+/// Intended to be called on a dialect whose `expand()` has already run.
+pub fn apply_pgvector_extension(postgres: &mut Dialect) {
+    use sqruff_lib_core::parser::types::DialectElementType;
+
+    // The type names pgvector introduces.
+    const PGVECTOR_KEYWORDS: &[&str] = &["VECTOR", "HALFVEC", "SPARSEVEC"];
+
+    // `expand()` has already been called and will not process newly added set members, so
+    // each keyword is added to the set and registered in the library as a parser by hand.
+    for &kw in PGVECTOR_KEYWORDS {
+        postgres.add_keyword_to_set("unreserved_keywords", kw);
+        let parser = StringParser::new(kw, SyntaxKind::Keyword);
+        postgres.add([(kw.into(), DialectElementType::Matchable(parser.to_matchable()))]);
+    }
+
+    // NOTE(review): this looks like a copy of the base `DatatypeSegment` grammar with one
+    // extra pgvector branch; confirm, and keep the two in sync when the base changes.
+    postgres.replace_grammar(
+        "DatatypeSegment",
+        Sequence::new(vec_of_erased![
+            Sequence::new(vec_of_erased![
+                Ref::new("SingleIdentifierGrammar"),
+                Ref::new("DotSegment")
+            ])
+            .config(|this| {
+                this.allow_gaps = false;
+                this.optional();
+            }),
+            one_of(vec_of_erased![
+                Ref::new("WellKnownTextGeometrySegment"),
+                Ref::new("DateTimeTypeIdentifier"),
+                Sequence::new(vec_of_erased![one_of(vec_of_erased![
+                    Ref::keyword("SMALLINT"),
+                    Ref::keyword("INTEGER"),
+                    Ref::keyword("INT"),
+                    Ref::keyword("INT2"),
+                    Ref::keyword("INT4"),
+                    Ref::keyword("INT8"),
+                    Ref::keyword("BIGINT"),
+                    Ref::keyword("FLOAT4"),
+                    Ref::keyword("FLOAT8"),
+                    Ref::keyword("REAL"),
+                    Sequence::new(vec_of_erased![
+                        Ref::keyword("DOUBLE"),
+                        Ref::keyword("PRECISION")
+                    ]),
+                    Ref::keyword("SMALLSERIAL"),
+                    Ref::keyword("SERIAL"),
+                    Ref::keyword("SERIAL2"),
+                    Ref::keyword("SERIAL4"),
+                    Ref::keyword("SERIAL8"),
+                    Ref::keyword("BIGSERIAL"),
+                    // Numeric types [(precision)]
+                    Sequence::new(vec_of_erased![
+                        one_of(vec_of_erased![Ref::keyword("FLOAT")]),
+                        Ref::new("BracketedArguments").optional()
+                    ]),
+                    // Numeric types [precision ["," scale])]
+                    Sequence::new(vec_of_erased![
+                        one_of(vec_of_erased![
+                            Ref::keyword("DECIMAL"),
+                            Ref::keyword("NUMERIC")
+                        ]),
+                        Ref::new("BracketedArguments").optional()
+                    ]),
+                    // Monetary type
+                    Ref::keyword("MONEY"),
+                    // Character types
+                    one_of(vec_of_erased![
+                        Sequence::new(vec_of_erased![
+                            one_of(vec_of_erased![
+                                Ref::keyword("BPCHAR"),
+                                Ref::keyword("CHAR"),
+                                Sequence::new(vec_of_erased![
+                                    Ref::keyword("CHAR"),
+                                    Ref::keyword("VARYING")
+                                ]),
+                                Ref::keyword("CHARACTER"),
+                                Sequence::new(vec_of_erased![
+                                    Ref::keyword("CHARACTER"),
+                                    Ref::keyword("VARYING")
+                                ]),
+                                Ref::keyword("VARCHAR")
+                            ]),
+                            Ref::new("BracketedArguments").optional()
+                        ]),
+                        Ref::keyword("TEXT")
+                    ]),
+                    // Binary type
+                    Ref::keyword("BYTEA"),
+                    // Boolean types
+                    one_of(vec_of_erased![
+                        Ref::keyword("BOOLEAN"),
+                        Ref::keyword("BOOL")
+                    ]),
+                    // Geometric types
+                    one_of(vec_of_erased![
+                        Ref::keyword("POINT"),
+                        Ref::keyword("LINE"),
+                        Ref::keyword("LSEG"),
+                        Ref::keyword("BOX"),
+                        Ref::keyword("PATH"),
+                        Ref::keyword("POLYGON"),
+                        Ref::keyword("CIRCLE")
+                    ]),
+                    // Network address types
+                    one_of(vec_of_erased![
+                        Ref::keyword("CIDR"),
+                        Ref::keyword("INET"),
+                        Ref::keyword("MACADDR"),
+                        Ref::keyword("MACADDR8")
+                    ]),
+                    // Text search types
+                    one_of(vec_of_erased![
+                        Ref::keyword("TSVECTOR"),
+                        Ref::keyword("TSQUERY")
+                    ]),
+                    // pgvector types (dimension is optional)
+                    Sequence::new(vec_of_erased![
+                        one_of(vec_of_erased![
+                            Ref::keyword("VECTOR"),
+                            Ref::keyword("HALFVEC"),
+                            Ref::keyword("SPARSEVEC")
+                        ]),
+                        Ref::new("BracketedArguments").optional()
+                    ]),
+                    // Bit string types
+                    Sequence::new(vec_of_erased![
+                        Ref::keyword("BIT"),
+                        one_of(vec_of_erased![Ref::keyword("VARYING")])
+                            .config(|this| this.optional()),
+                        Ref::new("BracketedArguments").optional()
+                    ]),
+                    // UUID type
+                    Ref::keyword("UUID"),
+                    // XML type
+                    Ref::keyword("XML"),
+                    // JSON types
+                    one_of(vec_of_erased![Ref::keyword("JSON"), Ref::keyword("JSONB")]),
+                    // Range types
+                    Ref::keyword("INT4RANGE"),
+                    Ref::keyword("INT8RANGE"),
+                    Ref::keyword("NUMRANGE"),
+                    Ref::keyword("TSRANGE"),
+                    Ref::keyword("TSTZRANGE"),
+                    Ref::keyword("DATERANGE"),
+                    // pg_lsn type
+                    Ref::keyword("PG_LSN")
+                ])]),
+                Ref::new("DatatypeIdentifierSegment")
+            ]),
+            one_of(vec_of_erased![
+                AnyNumberOf::new(vec_of_erased![
+                    Bracketed::new(vec_of_erased![Ref::new("ExpressionSegment").optional()])
+                        .config(|this| this.bracket_type("square"))
+                ]),
+                Ref::new("ArrayTypeSegment"),
+                Ref::new("SizedArrayTypeSegment"),
+            ])
+            .config(|this| this.optional()),
+        ])
+        .to_matchable(),
+    );
+}
diff --git a/crates/lib/Cargo.toml b/crates/lib/Cargo.toml
index 094ee986c..7406e0967 100644
--- a/crates/lib/Cargo.toml
+++ b/crates/lib/Cargo.toml
@@ -38,8 +38,10 @@ name = "depth_map"
 harness = false
 
 [features]
+default = ["postgres"]
 parser = ["sqruff-lib-core/stringify"]
 python = ["pyo3", "sqruff-lib-core/serde"]
+postgres = ["sqruff-lib-dialects/postgres"]
 
 [dependencies]
 sqruff-lib-core.workspace = true
diff --git a/crates/lib/src/core/config.rs b/crates/lib/src/core/config.rs
index 64e82ccbf..be6ee20e6 100644
--- a/crates/lib/src/core/config.rs
+++ b/crates/lib/src/core/config.rs
@@ -10,6 +10,8 @@ use sqruff_lib_core::dialects::init::{DialectKind, dialect_readout};
 use sqruff_lib_core::errors::SQLFluffUserError;
 use sqruff_lib_core::parser::Parser;
 use sqruff_lib_dialects::kind_to_dialect;
+#[cfg(feature = "postgres")]
+use sqruff_lib_dialects::postgres::apply_pgvector_extension;
 
 use crate::utils::reflow::config::ReflowConfig;
 
@@ -121,7 +123,32 @@ impl FluffConfig {
             _value => DialectKind::default(),
         };
 
-        let dialect = kind_to_dialect(&dialect);
+        // Only the postgres extension hook below mutates the dialect, so `mut` would be
+        // flagged as unused when that feature is compiled out.
+        #[cfg_attr(not(feature = "postgres"), allow(unused_mut))]
+        let mut dialect = kind_to_dialect(&dialect);
+
+        // Apply dialect extensions named in the `extensions` value of the `postgres` config
+        // section (a comma-separated, case-insensitive list).
+        #[cfg(feature = "postgres")]
+        if let Some(ref mut dialect) = dialect {
+            let wants_pgvector = dialect.name() == DialectKind::Postgres
+                && configs
+                    .get("postgres")
+                    .and_then(|section| section.as_map())
+                    .and_then(|section| section.get("extensions"))
+                    .and_then(|value| value.as_string())
+                    .is_some_and(|list| {
+                        list.split(',').any(|ext| ext.trim().eq_ignore_ascii_case("pgvector"))
+                    });
+
+            // Applied at most once however often the name is listed, because every call
+            // replaces the whole `DatatypeSegment` grammar.
+            if wants_pgvector {
+                apply_pgvector_extension(dialect);
+            }
+        }
+
         for (in_key, out_key) in [
             // Deal with potential ignore & warning parameters
             ("ignore", "ignore"),