diff --git a/engine/agnostic/engine.go b/engine/agnostic/engine.go index 48d0661..4df5835 100644 --- a/engine/agnostic/engine.go +++ b/engine/agnostic/engine.go @@ -22,6 +22,20 @@ func NewEngine() *Engine { e.schemas = make(map[string]*Schema) e.schemas[DefaultSchema] = NewSchema(DefaultSchema) + // create information_schema with a 'tables' relation used by clients (e.g. GORM) + info := NewSchema("information_schema") + // minimal columns used by GORM queries: table_schema, table_name, table_type + attrs := []Attribute{ + NewAttribute("table_schema", "varchar"), + NewAttribute("table_name", "varchar"), + NewAttribute("table_type", "varchar"), + } + // create relation (no primary key) + if r, err := NewRelation("information_schema", "tables", attrs, nil); err == nil { + info.Add("tables", r) + } + e.schemas["information_schema"] = info + return e } diff --git a/engine/agnostic/predicate.go b/engine/agnostic/predicate.go index 726e3d1..7f26e73 100644 --- a/engine/agnostic/predicate.go +++ b/engine/agnostic/predicate.go @@ -1093,10 +1093,21 @@ func (p *EqPredicate) Right() (Predicate, bool) { } func (p *EqPredicate) Relation() string { - if p.left.Relation() != "" { + // Handle nil cases first + if p.left == nil && p.right == nil { + return "" + } + if p.left == nil { + return p.right.Relation() + } + if p.right == nil { return p.left.Relation() } + // Check left first, then fall back to right (if empty) + if p.left.Relation() != "" { + return p.left.Relation() + } return p.right.Relation() } diff --git a/engine/agnostic/transaction.go b/engine/agnostic/transaction.go index 997e9a4..cf46104 100644 --- a/engine/agnostic/transaction.go +++ b/engine/agnostic/transaction.go @@ -150,6 +150,27 @@ func (t *Transaction) CreateRelation(schemaName, relName string, attributes []At log.Debug("CreateRelation(%s,%s,%s,%s)", schemaName, relName, attributes, pk) t.lock(r) + + // maintain information_schema.tables so external tools (eg. GORM) can query table existence + // insert a row into information_schema.tables: table_schema, table_name, table_type + // Use default schema if empty + sch := schemaName + if sch == "" { + sch = DefaultSchema + } + // best-effort: if information_schema exists, insert a metadata row via Transaction.Insert + if t.CheckSchema("information_schema") { + vals := map[string]any{ + "table_schema": sch, + "table_name": relName, + "table_type": "BASE TABLE", + } + _, err := t.Insert("information_schema", "tables", vals) + if err != nil { + // do not fail relation creation because of metadata insertion; just log + log.Warn("could not update information_schema.tables: %s", err) + } + } return nil } @@ -170,6 +191,23 @@ func (t *Transaction) DropRelation(schemaName, relName string) error { } t.changes.PushBack(c) + // remove metadata from information_schema.tables if present + sch := schemaName + if sch == "" { + sch = DefaultSchema + } + if t.CheckSchema("information_schema") { + // build predicate: table_schema = sch AND table_name = relName + left := NewEqPredicate(NewAttributeValueFunctor("tables", "table_schema"), NewConstValueFunctor(sch)) + right := NewEqPredicate(NewAttributeValueFunctor("tables", "table_name"), NewConstValueFunctor(relName)) + p := NewAndPredicate(left, right) + // selectors can be nil for Delete + _, _, err := t.Delete("information_schema", "tables", nil, p) + if err != nil { + log.Warn("could not remove information_schema.tables entry for %s.%s: %s", sch, relName, err) + } + } + return nil } diff --git a/engine/executor/engine.go b/engine/executor/engine.go index 284ccf0..0c675b8 100644 --- a/engine/executor/engine.go +++ b/engine/executor/engine.go @@ -514,12 +514,21 @@ func selectExecutor(t *Tx, selectDecl *parser.Decl, args []NamedValue) (int64, i } sorters = append(sorters, s) case parser.LimitToken: - limit, err := strconv.ParseInt(selectDecl.Decl[i].Decl[0].Lexeme, 10, 64) - if err != nil { - return 0, 0, nil, nil, fmt.Errorf("wrong limit value: %s", err) + if len(selectDecl.Decl[i].Decl) == 0 { + return 0, 0, nil, nil, fmt.Errorf("LIMIT clause requires a value") } - s := agnostic.NewLimitSorter(limit) - sorters = append(sorters, s) + decl := selectDecl.Decl[i].Decl[0] + if decl.Token == parser.NumberToken || decl.Token == parser.StringToken { + limit, err := strconv.ParseInt(decl.Lexeme, 10, 64) + if err != nil { + return 0, 0, nil, nil, fmt.Errorf("wrong LIMIT value: %s", err) + } + // Always add limit sorter last to ensure it's applied after other sorters + sorters = append(sorters, agnostic.NewLimitSorter(limit)) + } else { + return 0, 0, nil, nil, fmt.Errorf("LIMIT clause requires a number value") + } + continue } } diff --git a/engine/executor/tx.go b/engine/executor/tx.go index 2902224..a74ee1d 100644 --- a/engine/executor/tx.go +++ b/engine/executor/tx.go @@ -284,7 +284,7 @@ func (t *Tx) getPredicates(decl []*parser.Decl, schema, fromTableName string, ar switch leftS.Token { case parser.CurrentSchemaToken: - left = agnostic.NewConstValueFunctor(schema) + right = agnostic.NewConstValueFunctor(schema) case parser.NamedArgToken: for _, arg := range args { if leftS.Lexeme == arg.Name { diff --git a/engine/parser/create.go b/engine/parser/create.go index da3ffd5..6949150 100644 --- a/engine/parser/create.go +++ b/engine/parser/create.go @@ -175,6 +175,78 @@ func (p *parser) parseIndex(tokens []Token) (*Decl, error) { return indexDecl, nil } +func (p *parser) parseForeignKeyConstraint() (*Decl, error) { + constraintDecl, err := p.consumeToken(ConstraintToken) + if err != nil { + return nil, err + } + + // Constraint name + nameDecl, err := p.parseQuotedToken() + if err != nil { + return nil, err + } + constraintDecl.Add(nameDecl) + + // FOREIGN KEY + foreignDecl, err := p.consumeToken(ForeignToken) + if err != nil { + return nil, err + } + keyDecl, err := p.consumeToken(KeyToken) + if err != nil { + return nil, err + } + constraintDecl.Add(foreignDecl) + foreignDecl.Add(keyDecl) + + // (column_name) + _, err = p.consumeToken(BracketOpeningToken) + if err != nil { + return nil, err + } + columnDecl, err := p.parseQuotedToken() + if err != nil { + return nil, err + } + keyDecl.Add(columnDecl) + _, err = p.consumeToken(BracketClosingToken) + if err != nil { + return nil, err + } + + // REFERENCES table_name(column_name) + referencesDecl, err := p.consumeToken(ReferencesToken) + if err != nil { + return nil, err + } + keyDecl.Add(referencesDecl) + + // Referenced table + tableDecl, err := p.parseTableName() + if err != nil { + return nil, err + } + referencesDecl.Add(tableDecl) + + // Referenced column + _, err = p.consumeToken(BracketOpeningToken) + if err != nil { + return nil, err + } + refColumnDecl, err := p.parseQuotedToken() + if err != nil { + return nil, err + } + tableDecl.Add(refColumnDecl) + _, err = p.consumeToken(BracketClosingToken) + if err != nil { + return nil, err + } + + return constraintDecl, nil +} + func (p *parser) parseTable(tokens []Token) (*Decl, error) { var err error tableDecl := NewDecl(tokens[p.index]) @@ -220,15 +292,28 @@ func (p *parser) parseTable(tokens []Token) (*Decl, error) { for p.index < len(tokens) { - switch p.cur().Token { - case PrimaryToken: - pkDecl, err := p.parsePrimaryKey() - if err != nil { - return nil, err + // Handle primary key and constraints at table level + if p.cur().Token == PrimaryToken || p.cur().Token == ConstraintToken { + switch p.cur().Token { + case PrimaryToken: + pkDecl, err := p.parsePrimaryKey() + if err != nil { + return nil, err + } + tableDecl.Add(pkDecl) + case ConstraintToken: + constraintDecl, err := p.parseForeignKeyConstraint() + if err != nil { + return nil, err + } + tableDecl.Add(constraintDecl) + } + + // After constraint/key, expect either comma or closing bracket + if p.cur().Token == CommaToken { + p.index++ } - tableDecl.Add(pkDecl) continue - default: } // Closing bracket ? diff --git a/engine/parser/lexer.go b/engine/parser/lexer.go index 8bdbc07..e123903 100644 --- a/engine/parser/lexer.go +++ b/engine/parser/lexer.go @@ -22,6 +22,11 @@ const ( GreaterOrEqualToken BacktickToken + // Constraint tokens + ConstraintToken + ForeignToken + ReferencesToken + // QuoteToken DoubleQuoteToken @@ -128,6 +133,7 @@ func (l *lexer) lex(instruction []byte) ([]Token, error) { securityPos := 0 var matchers []Matcher + matchers = append(matchers, l.MatchNumberToken) // Match numbers first matchers = append(matchers, l.MatchArgTokenODBC) matchers = append(matchers, l.MatchNamedArgToken) matchers = append(matchers, l.MatchArgToken) @@ -191,9 +197,10 @@ func (l *lexer) lex(instruction []byte) ([]Token, error) { matchers = append(matchers, l.genericStringMatcher("or", OrToken)) matchers = append(matchers, l.genericStringMatcher("asc", AscToken)) matchers = append(matchers, l.genericStringMatcher("desc", DescToken)) - matchers = append(matchers, l.genericStringMatcher("limit", LimitToken)) - matchers = append(matchers, l.genericStringMatcher("is", IsToken)) - matchers = append(matchers, l.genericStringMatcher("for", ForToken)) + matchers = append(matchers, l.genericStringMatcher("limit", LimitToken)) + matchers = append(matchers, l.genericStringMatcher("is", IsToken)) + matchers = append(matchers, l.genericStringMatcher("for", ForToken)) + matchers = append(matchers, l.MatchNumberToken) // Ensure numbers are properly matched matchers = append(matchers, l.genericStringMatcher("default", DefaultToken)) matchers = append(matchers, l.genericStringMatcher("localtimestamp", LocalTimestampToken)) matchers = append(matchers, l.genericStringMatcher("false", FalseToken)) @@ -204,6 +211,10 @@ func (l *lexer) lex(instruction []byte) ([]Token, error) { matchers = append(matchers, l.genericStringMatcher("on", OnToken)) matchers = append(matchers, l.genericStringMatcher("collate", CollateToken)) matchers = append(matchers, l.genericStringMatcher("nocase", NocaseToken)) + // Constraint matchers + matchers = append(matchers, l.genericStringMatcher("constraint", ConstraintToken)) + matchers = append(matchers, l.genericStringMatcher("foreign", ForeignToken)) + matchers = append(matchers, l.genericStringMatcher("references", ReferencesToken)) // Type Matcher matchers = append(matchers, l.genericStringMatcher("decimal", DecimalToken)) matchers = append(matchers, l.genericStringMatcher("primary", PrimaryToken))