diff --git a/CHANGELOG.md b/CHANGELOG.md index b590b6a..a32f594 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Uses `X | None` union syntax for nullable columns - Supports all column types, foreign keys, indexes, and constraints +**SQLAlchemy Relationships (issue #47)** +- New `relationships` parameter for `create_models()` to generate `relationship()` with `back_populates` +- Automatically generates bidirectional relationships for foreign keys: + - Parent side (one-to-many): collection attribute pointing to children + - Child side (many-to-one): attribute pointing to parent +- Works with both `sqlalchemy` and `sqlalchemy_v2` model types +- For `sqlalchemy_v2`: uses `Mapped[List[T]]` for one-to-many and `Mapped[T]` for many-to-one + **SQLModel Improvements** - Fixed array type generation (issue #66) - Arrays now properly generate `List[T]` with correct SQLAlchemy ARRAY type diff --git a/omymodels/from_ddl.py b/omymodels/from_ddl.py index 5f27d73..9353ef6 100644 --- a/omymodels/from_ddl.py +++ b/omymodels/from_ddl.py @@ -44,6 +44,7 @@ def create_models( no_auto_snake_case: Optional[bool] = False, table_prefix: Optional[str] = "", table_suffix: Optional[str] = "", + relationships: Optional[bool] = False, ): """models_type can be: "gino", "dataclass", "pydantic" """ # extract data from ddl file @@ -65,6 +66,7 @@ def create_models( defaults_off, table_prefix=table_prefix, table_suffix=table_suffix, + relationships=relationships, ) if dump: save_models_to_file(output, dump_path) @@ -138,6 +140,56 @@ def save_models_to_file(models: str, dump_path: str) -> None: f.write(models) +def _add_relationship( + relationships: Dict, table_name: str, fk_column: str, ref_table: str, ref_column: str +): + """Helper to add both sides of a relationship.""" + relationships.setdefault(table_name, []).append({ + "type": "many_to_one", + "fk_column": fk_column, + "ref_table": ref_table, + "ref_column": ref_column, + "child_table_name": table_name, + }) + relationships.setdefault(ref_table, []).append({ + "type": "one_to_many", + "child_table": table_name, + "fk_column": fk_column, + }) + + +def _get_alter_columns(table) -> List: + """Get ALTER TABLE columns if they exist.""" + if hasattr(table, 'alter') and table.alter: + return table.alter.get("columns", []) + return [] + + +def collect_relationships(tables: List) -> Dict: + """Collect foreign key relationships between tables.""" + relationships = {} + + for table in tables: + for column in table.columns: + if column.references and column.references.get("table"): + _add_relationship( + relationships, table.name, column.name, + column.references["table"], + column.references.get("column") or column.name + ) + + for alter_col in _get_alter_columns(table): + ref_info = alter_col.get("references") + if ref_info and ref_info.get("table"): + _add_relationship( + relationships, table.name, alter_col["name"], + ref_info["table"], + ref_info.get("column") or alter_col["name"] + ) + + return relationships + + def generate_models_file( data: Dict[str, List], singular: bool = False, @@ -147,6 +199,7 @@ def generate_models_file( defaults_off: Optional[bool] = False, table_prefix: Optional[str] = "", table_suffix: Optional[str] = "", + relationships: Optional[bool] = False, ) -> str: """method to prepare full file with all Models &""" models_str = "" @@ -159,6 +212,11 @@ def generate_models_file( if data["tables"]: add_custom_types_to_generator(data["types"], generator) + # Collect relationships if enabled + relationships_map = {} + if relationships: + relationships_map = collect_relationships(data["tables"]) + for table in data["tables"]: models_str += generator.generate_model( table, @@ -168,6 +226,7 @@ def generate_models_file( defaults_off=defaults_off, table_prefix=table_prefix, table_suffix=table_suffix, + relationships=relationships_map.get(table.name, []) if relationships else [], ) header += generator.create_header( data["tables"], schema=schema_global, models_str=models_str diff --git a/omymodels/models/sqlalchemy/core.py b/omymodels/models/sqlalchemy/core.py index 1c01f15..5285517 100644 --- a/omymodels/models/sqlalchemy/core.py +++ b/omymodels/models/sqlalchemy/core.py @@ -18,6 +18,7 @@ def __init__(self): self.postgresql_dialect_cols = set() self.constraint = False self.im_index = False + self.relationship_import = False self.types_mapping = types_mapping self.templates = st self.prefix = "sa." @@ -46,14 +47,16 @@ def generate_model( singular: bool = True, exceptions: Optional[List] = None, schema_global: Optional[bool] = True, + relationships: Optional[List] = None, *args, **kwargs, ) -> str: """method to prepare one Model defention - name & tablename & columns""" model = "" + model_name = create_class_name(table.name, singular, exceptions) model = st.model_template.format( - model_name=create_class_name(table.name, singular, exceptions), + model_name=model_name, table_name=table.name, ) for column in table.columns: @@ -62,8 +65,61 @@ def generate_model( ) if table.indexes or table.alter or table.checks or not schema_global: model = logic.add_table_args(self, model, table, schema_global) + + # Generate relationships if enabled + if relationships: + model += self._generate_relationships( + relationships, singular, exceptions + ) return model + def _generate_relationships( + self, + relationships: List[Dict], + singular: bool, + exceptions: Optional[List] = None, + ) -> str: + """Generate relationship() lines for the model.""" + result = "\n" + self.relationship_import = True + + for rel in relationships: + if rel["type"] == "many_to_one": + # Child side: reference to parent + # e.g., posts.user = relationship("Users", back_populates="posts") + ref_table = rel["ref_table"] + fk_column = rel["fk_column"] + child_table_name = rel["child_table_name"] + related_class = create_class_name(ref_table, singular, exceptions) + # Attribute name derived from FK column (user_id -> user) + attr_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else ref_table.lower() + # back_populates points to the collection on the parent (uses child table name) + back_pop_name = child_table_name.lower().replace("-", "_") + back_populates = st.back_populates_template.format(attr_name=back_pop_name) + result += st.relationship_template.format( + attr_name=attr_name, + related_class=related_class, + back_populates=back_populates, + ) + elif rel["type"] == "one_to_many": + # Parent side: collection of children + # e.g., users.posts = relationship("Posts", back_populates="user") + child_table = rel["child_table"] + fk_column = rel["fk_column"] + related_class = create_class_name(child_table, singular, exceptions) + # Attribute name is the child table name (as-is, since table names are typically plural) + attr_name = child_table.lower().replace("-", "_") + # back_populates points to the single parent reference on the child + # Derived from FK column (user_id -> user) + back_pop_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else child_table.lower() + back_populates = st.back_populates_template.format(attr_name=back_pop_name) + result += st.relationship_template.format( + attr_name=attr_name, + related_class=related_class, + back_populates=back_populates, + ) + return result + def create_header( self, tables: List[Dict], schema: bool = False, *args, **kwargs ) -> str: @@ -82,4 +138,6 @@ def create_header( header += st.unique_cons_import + "\n" if self.im_index: header += st.index_import + "\n" + if self.relationship_import: + header += st.relationship_import + "\n" return header diff --git a/omymodels/models/sqlalchemy/templates.py b/omymodels/models/sqlalchemy/templates.py index f2f2eff..09196c8 100644 --- a/omymodels/models/sqlalchemy/templates.py +++ b/omymodels/models/sqlalchemy/templates.py @@ -51,3 +51,8 @@ class {model_name}(Base):\n on_delete = ', ondelete="{mode}"' on_update = ', onupdate="{mode}"' + +# relationship templates +relationship_import = "from sqlalchemy.orm import relationship" +relationship_template = ' {attr_name} = relationship("{related_class}"{back_populates})\n' +back_populates_template = ', back_populates="{attr_name}"' diff --git a/omymodels/models/sqlalchemy_v2/core.py b/omymodels/models/sqlalchemy_v2/core.py index 265d3bf..41fcae7 100644 --- a/omymodels/models/sqlalchemy_v2/core.py +++ b/omymodels/models/sqlalchemy_v2/core.py @@ -24,6 +24,7 @@ def __init__(self): self.time_import = False self.uuid_import = False self.fk_import = False + self.relationship_import = False self.types_mapping = types_mapping self.templates = st self.prefix = "" @@ -225,6 +226,7 @@ def generate_model( singular: bool = True, exceptions: Optional[List] = None, schema_global: Optional[bool] = True, + relationships: Optional[List] = None, *args, **kwargs, ) -> str: @@ -242,8 +244,68 @@ def generate_model( if table.indexes or table.alter or table.checks or not schema_global: model = self._add_table_args(model, table, schema_global) + # Generate relationships if enabled + if relationships: + model += self._generate_relationships( + relationships, singular, exceptions + ) + return model + def _generate_relationships( + self, + relationships: List[Dict], + singular: bool, + exceptions: Optional[List] = None, + ) -> str: + """Generate relationship() lines for the model.""" + result = "\n" + self.relationship_import = True + self.typing_imports.add("List") + + for rel in relationships: + if rel["type"] == "many_to_one": + # Child side: reference to parent + # e.g., author: Mapped["Authors"] = relationship("Authors", back_populates="books") + ref_table = rel["ref_table"] + fk_column = rel["fk_column"] + child_table_name = rel["child_table_name"] + related_class = create_class_name(ref_table, singular, exceptions) + # Attribute name derived from FK column (author_id -> author) + attr_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else ref_table.lower() + # back_populates points to the collection on the parent (uses child table name) + back_pop_name = child_table_name.lower().replace("-", "_") + back_populates = st.back_populates_template.format(attr_name=back_pop_name) + # Type hint for many-to-one is the related class (quoted for forward ref) + type_hint = f'"{related_class}"' + result += st.relationship_template.format( + attr_name=attr_name, + type_hint=type_hint, + related_class=related_class, + back_populates=back_populates, + ) + elif rel["type"] == "one_to_many": + # Parent side: collection of children + # e.g., books: Mapped[List["Books"]] = relationship("Books", back_populates="author") + child_table = rel["child_table"] + fk_column = rel["fk_column"] + related_class = create_class_name(child_table, singular, exceptions) + # Attribute name is the child table name (as-is, since table names are typically plural) + attr_name = child_table.lower().replace("-", "_") + # back_populates points to the single parent reference on the child + # Derived from FK column (author_id -> author) + back_pop_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else child_table.lower() + back_populates = st.back_populates_template.format(attr_name=back_pop_name) + # Type hint for one-to-many is List of related class (quoted for forward ref) + type_hint = f'List["{related_class}"]' + result += st.relationship_template.format( + attr_name=attr_name, + type_hint=type_hint, + related_class=related_class, + back_populates=back_populates, + ) + return result + def _add_table_args( self, model: str, table: Dict, schema_global: bool = True ) -> str: @@ -320,4 +382,7 @@ def create_header( if self.im_index: parts.append(st.index_import + "\n") + if self.relationship_import: + parts.append(st.relationship_import + "\n") + return "".join(parts) diff --git a/omymodels/models/sqlalchemy_v2/templates.py b/omymodels/models/sqlalchemy_v2/templates.py index 0a5bfd7..8d5b08a 100644 --- a/omymodels/models/sqlalchemy_v2/templates.py +++ b/omymodels/models/sqlalchemy_v2/templates.py @@ -60,3 +60,8 @@ class {model_name}(Base): on_delete = ', ondelete="{mode}"' on_update = ', onupdate="{mode}"' + +# relationship templates +relationship_import = "from sqlalchemy.orm import relationship" +relationship_template = ' {attr_name}: Mapped[{type_hint}] = relationship("{related_class}"{back_populates})\n' +back_populates_template = ', back_populates="{attr_name}"' diff --git a/tests/functional/generator/test_sqlalchemy.py b/tests/functional/generator/test_sqlalchemy.py index 523ecfe..2e50ddb 100644 --- a/tests/functional/generator/test_sqlalchemy.py +++ b/tests/functional/generator/test_sqlalchemy.py @@ -361,3 +361,118 @@ class Table2(Base): """ result = create_models(ddl, schema_global=False, models_type="sqlalchemy")["code"] assert result == expected + + +def test_relationships_with_back_populates(): + """Test that relationships=True generates relationship() with back_populates.""" + expected = """import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + + +Base = declarative_base() + + +class Users(Base): + + __tablename__ = 'users' + + id = sa.Column(sa.Integer(), primary_key=True) + name = sa.Column(sa.String(), nullable=False) + + posts = relationship("Posts", back_populates="user") + + +class Posts(Base): + + __tablename__ = 'posts' + + id = sa.Column(sa.Integer(), primary_key=True) + title = sa.Column(sa.String(), nullable=False) + user_id = sa.Column(sa.Integer(), sa.ForeignKey('users.id')) + + user = relationship("Users", back_populates="posts") +""" + ddl = """ +CREATE TABLE "users" ( + "id" int PRIMARY KEY, + "name" varchar NOT NULL +); + +CREATE TABLE "posts" ( + "id" int PRIMARY KEY, + "title" varchar NOT NULL, + "user_id" int +); + +ALTER TABLE "posts" ADD FOREIGN KEY ("user_id") REFERENCES "users" ("id"); +""" + result = create_models(ddl, models_type="sqlalchemy", relationships=True)["code"] + assert result == expected + + +def test_relationships_multiple_foreign_keys(): + """Test relationships with multiple foreign keys in the same table.""" + expected = """import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + + +Base = declarative_base() + + +class Users(Base): + + __tablename__ = 'users' + + id = sa.Column(sa.Integer(), primary_key=True) + name = sa.Column(sa.String()) + + comments = relationship("Comments", back_populates="user") + + +class Posts(Base): + + __tablename__ = 'posts' + + id = sa.Column(sa.Integer(), primary_key=True) + title = sa.Column(sa.String()) + + comments = relationship("Comments", back_populates="post") + + +class Comments(Base): + + __tablename__ = 'comments' + + id = sa.Column(sa.Integer(), primary_key=True) + text = sa.Column(sa.String()) + user_id = sa.Column(sa.Integer(), sa.ForeignKey('users.id')) + post_id = sa.Column(sa.Integer(), sa.ForeignKey('posts.id')) + + user = relationship("Users", back_populates="comments") + post = relationship("Posts", back_populates="comments") +""" + ddl = """ +CREATE TABLE "users" ( + "id" int PRIMARY KEY, + "name" varchar +); + +CREATE TABLE "posts" ( + "id" int PRIMARY KEY, + "title" varchar +); + +CREATE TABLE "comments" ( + "id" int PRIMARY KEY, + "text" varchar, + "user_id" int, + "post_id" int +); + +ALTER TABLE "comments" ADD FOREIGN KEY ("user_id") REFERENCES "users" ("id"); +ALTER TABLE "comments" ADD FOREIGN KEY ("post_id") REFERENCES "posts" ("id"); +""" + result = create_models(ddl, models_type="sqlalchemy", relationships=True)["code"] + assert result == expected diff --git a/tests/functional/generator/test_sqlalchemy_v2.py b/tests/functional/generator/test_sqlalchemy_v2.py index a5ff5e0..eb9cdda 100644 --- a/tests/functional/generator/test_sqlalchemy_v2.py +++ b/tests/functional/generator/test_sqlalchemy_v2.py @@ -186,3 +186,67 @@ def test_array_types(): assert "Mapped[List[str] | None]" in code assert "Mapped[List[int] | None]" in code assert "ARRAY(" in code + + +def test_relationships_with_back_populates(): + """Test that relationships=True generates relationship() with back_populates.""" + ddl = """ +CREATE TABLE users ( + id int PRIMARY KEY, + name varchar NOT NULL +); + +CREATE TABLE posts ( + id int PRIMARY KEY, + title varchar NOT NULL, + user_id int +); + +ALTER TABLE posts ADD FOREIGN KEY (user_id) REFERENCES users (id); +""" + result = create_models(ddl, models_type="sqlalchemy_v2", relationships=True) + code = result["code"] + + # Check relationship import + assert "from sqlalchemy.orm import relationship" in code + + # Check parent (Users) has posts relationship + assert 'posts: Mapped[List["Posts"]] = relationship("Posts", back_populates="user")' in code + + # Check child (Posts) has user relationship + assert 'user: Mapped["Users"] = relationship("Users", back_populates="posts")' in code + + +def test_relationships_multiple_foreign_keys(): + """Test relationships with multiple foreign keys in the same table.""" + ddl = """ +CREATE TABLE users ( + id int PRIMARY KEY, + name varchar +); + +CREATE TABLE posts ( + id int PRIMARY KEY, + title varchar +); + +CREATE TABLE comments ( + id int PRIMARY KEY, + text varchar, + user_id int, + post_id int +); + +ALTER TABLE comments ADD FOREIGN KEY (user_id) REFERENCES users (id); +ALTER TABLE comments ADD FOREIGN KEY (post_id) REFERENCES posts (id); +""" + result = create_models(ddl, models_type="sqlalchemy_v2", relationships=True) + code = result["code"] + + # Check parent relationships + assert 'comments: Mapped[List["Comments"]] = relationship("Comments", back_populates="user")' in code + assert 'comments: Mapped[List["Comments"]] = relationship("Comments", back_populates="post")' in code + + # Check child relationships + assert 'user: Mapped["Users"] = relationship("Users", back_populates="comments")' in code + assert 'post: Mapped["Posts"] = relationship("Posts", back_populates="comments")' in code diff --git a/tests/integration/sqlalchemy/test_sqlalchemy.py b/tests/integration/sqlalchemy/test_sqlalchemy.py index f0f28b6..eefd527 100644 --- a/tests/integration/sqlalchemy/test_sqlalchemy.py +++ b/tests/integration/sqlalchemy/test_sqlalchemy.py @@ -118,3 +118,43 @@ def test_sqlalchemy_multiple_tables(load_generated_code) -> None: assert issubclass(module.Posts, module.Base) os.remove(os.path.abspath(module.__file__)) + + +def test_sqlalchemy_relationships_with_back_populates(load_generated_code) -> None: + """Integration test: verify relationship() with back_populates is generated correctly.""" + ddl = """ + CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL + ); + + CREATE TABLE books ( + id SERIAL PRIMARY KEY, + title VARCHAR(200) NOT NULL, + author_id INT + ); + + ALTER TABLE books ADD FOREIGN KEY (author_id) REFERENCES authors (id); + """ + result = create_models(ddl, models_type="sqlalchemy", relationships=True)["code"] + + module = load_generated_code(result) + + # Verify models exist + assert hasattr(module, "Authors") + assert hasattr(module, "Books") + + # Verify relationships are defined as class attributes + # The relationship should be in __mapper__.relationships + from sqlalchemy.orm import configure_mappers + configure_mappers() + + # Check Authors has 'books' relationship + author_relationships = module.Authors.__mapper__.relationships + assert "books" in author_relationships.keys() + + # Check Books has 'author' relationship + book_relationships = module.Books.__mapper__.relationships + assert "author" in book_relationships.keys() + + os.remove(os.path.abspath(module.__file__)) diff --git a/tests/integration/sqlalchemy_v2/test_sqlalchemy_v2.py b/tests/integration/sqlalchemy_v2/test_sqlalchemy_v2.py index 5e8558c..6c22d07 100644 --- a/tests/integration/sqlalchemy_v2/test_sqlalchemy_v2.py +++ b/tests/integration/sqlalchemy_v2/test_sqlalchemy_v2.py @@ -210,3 +210,42 @@ def test_sqlalchemy_v2_complete_example(load_generated_code) -> None: assert len(user_id_col.foreign_keys) == 1 os.remove(os.path.abspath(module.__file__)) + + +def test_sqlalchemy_v2_relationships_with_back_populates(load_generated_code) -> None: + """Integration test: verify relationship() with back_populates is generated correctly.""" + ddl = """ + CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL + ); + + CREATE TABLE books ( + id SERIAL PRIMARY KEY, + title VARCHAR(200) NOT NULL, + author_id INT + ); + + ALTER TABLE books ADD FOREIGN KEY (author_id) REFERENCES authors (id); + """ + result = create_models(ddl, models_type="sqlalchemy_v2", relationships=True)["code"] + + module = load_generated_code(result) + + # Verify models exist + assert hasattr(module, "Authors") + assert hasattr(module, "Books") + + # Verify relationships are defined + from sqlalchemy.orm import configure_mappers + configure_mappers() + + # Check Authors has 'books' relationship + author_relationships = module.Authors.__mapper__.relationships + assert "books" in author_relationships.keys() + + # Check Books has 'author' relationship + book_relationships = module.Books.__mapper__.relationships + assert "author" in book_relationships.keys() + + os.remove(os.path.abspath(module.__file__))