diff --git a/CHANGELOG.md b/CHANGELOG.md index a2a83ff..dcc27c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Works with both `sqlalchemy` and `sqlalchemy_v2` model types - For `sqlalchemy_v2`: uses `Mapped[List[T]]` for one-to-many and `Mapped[T]` for many-to-one +**Schema-Separated Model Files (issue #40)** +- New `split_by_schema` parameter for `create_models()` to generate separate files per database schema +- Each schema gets its own file with a schema-specific Base class (e.g., `Schema1Base`) +- Tables without explicit schema go to a file with the default `Base` class +- Works with both `sqlalchemy` and `sqlalchemy_v2` model types +- File naming: `{schema_name}_{base_filename}.py` (e.g., `schema1_models.py`) + **SQLModel Improvements** - Fixed array type generation (issue #66) - Arrays now properly generate `List[T]` with correct SQLAlchemy ARRAY type diff --git a/omymodels/from_ddl.py b/omymodels/from_ddl.py index 9353ef6..18ac39e 100644 --- a/omymodels/from_ddl.py +++ b/omymodels/from_ddl.py @@ -45,6 +45,7 @@ def create_models( table_prefix: Optional[str] = "", table_suffix: Optional[str] = "", relationships: Optional[bool] = False, + split_by_schema: Optional[bool] = False, ): """models_type can be: "gino", "dataclass", "pydantic" """ # extract data from ddl file @@ -56,7 +57,28 @@ def create_models( sys.exit(0) else: raise NoTablesError() - # generate code + + # Handle split_by_schema mode + if split_by_schema: + output = generate_models_by_schema( + data, + singular, + naming_exceptions, + models_type, + defaults_off, + table_prefix=table_prefix, + table_suffix=table_suffix, + relationships=relationships, + ) + if dump: + save_models_by_schema(output, dump_path) + else: + for schema_name, code in output.items(): + print(f"# === {schema_name} ===") + print(code) + return {"metadata": data, "code": output} + + # generate code (single file mode) output = generate_models_file( data, singular, @@ -140,6 +162,104 @@ def save_models_to_file(models: str, dump_path: str) -> None: f.write(models) +def save_models_by_schema(models_by_schema: Dict[str, str], dump_path: str) -> None: + """Save models split by schema to separate files.""" + folder = os.path.dirname(dump_path) + base_name = os.path.basename(dump_path) + name_without_ext = os.path.splitext(base_name)[0] + + if folder: + os.makedirs(folder, exist_ok=True) + + for schema_name, code in models_by_schema.items(): + file_name = f"{schema_name}_{name_without_ext}.py" if schema_name else f"{name_without_ext}.py" + file_path = os.path.join(folder, file_name) if folder else file_name + with open(file_path, "w+") as f: + f.write(code) + + +def group_tables_by_schema(tables: List) -> Dict[str, List]: + """Group tables by their schema attribute.""" + grouped = {} + for table in tables: + schema = table.table_schema or "" + grouped.setdefault(schema, []).append(table) + return grouped + + +def _schema_to_base_name(schema: str) -> str: + """Convert schema name to Base class name (e.g., 'my_schema' -> 'MySchemaBase').""" + if not schema: + return "Base" + # Convert snake_case or kebab-case to PascalCase + parts = schema.replace("-", "_").split("_") + pascal = "".join(part.capitalize() for part in parts) + return f"{pascal}Base" + + +def generate_models_by_schema( + data: Dict[str, List], + singular: bool = False, + exceptions: Optional[List] = None, + models_type: str = "gino", + defaults_off: Optional[bool] = False, + table_prefix: Optional[str] = "", + table_suffix: Optional[str] = "", + relationships: Optional[bool] = False, +) -> Dict[str, str]: + """Generate models split by schema, each with its own Base class.""" + from omymodels.generators import get_generator_by_type, render_jinja2_template + + results = {} + tables_by_schema = group_tables_by_schema(data["tables"]) + + # Collect relationships across all tables if enabled + relationships_map = {} + if relationships: + relationships_map = collect_relationships(data["tables"]) + + for schema_name, tables in tables_by_schema.items(): + generator = get_generator_by_type(models_type) + add_custom_types_to_generator(data["types"], generator) + + models_str = "" + header = "" + + # Include types only in the first (or default) schema file + if data["types"] and schema_name == "": + types_generator = enum.ModelGenerator(data["types"]) + models_str += types_generator.create_types() + header += types_generator.create_header() + + for table in tables: + models_str += generator.generate_model( + table, + singular, + exceptions, + schema_global=False, # Always include schema in __table_args__ + defaults_off=defaults_off, + table_prefix=table_prefix, + table_suffix=table_suffix, + relationships=relationships_map.get(table.name, []) if relationships else [], + ) + + header += generator.create_header(tables, schema=False, models_str=models_str) + + # Generate code with schema-specific Base name + base_name = _schema_to_base_name(schema_name) + output = render_jinja2_template( + models_type, models_str, header, base_name=base_name + ) + + # Replace class inheritance from Base to custom base name + if base_name != "Base": + output = output.replace("(Base):", f"({base_name}):") + + results[schema_name] = output + + return results + + def _add_relationship( relationships: Dict, table_name: str, fk_column: str, ref_table: str, ref_column: str ): diff --git a/omymodels/generators.py b/omymodels/generators.py index 937634b..1df9557 100644 --- a/omymodels/generators.py +++ b/omymodels/generators.py @@ -85,13 +85,16 @@ def get_supported_models() -> List[str]: return list(list_generators().keys()) -def render_jinja2_template(models_type: str, models: str, headers: str) -> str: +def render_jinja2_template( + models_type: str, models: str, headers: str, base_name: str = "Base" +) -> str: """Render Jinja2 template for model output. Args: models_type: Generator type name models: Generated model code headers: Generated header/imports code + base_name: Name for the Base class (default: "Base") Returns: Rendered template as string @@ -107,5 +110,5 @@ def render_jinja2_template(models_type: str, models: str, headers: str) -> str: with open(template_file) as t: template = t.read() template = Template(template) - params = {"models": models, "headers": headers} + params = {"models": models, "headers": headers, "base_name": base_name} return template.render(**params) diff --git a/omymodels/models/sqlalchemy/sqlalchemy.jinja2 b/omymodels/models/sqlalchemy/sqlalchemy.jinja2 index 4094d3a..19cf646 100644 --- a/omymodels/models/sqlalchemy/sqlalchemy.jinja2 +++ b/omymodels/models/sqlalchemy/sqlalchemy.jinja2 @@ -2,5 +2,5 @@ import sqlalchemy as sa from sqlalchemy.ext.declarative import declarative_base {{ headers }} -Base = declarative_base() +{{ base_name }} = declarative_base() {{ models }} \ No newline at end of file diff --git a/omymodels/models/sqlalchemy_v2/sqlalchemy_v2.jinja2 b/omymodels/models/sqlalchemy_v2/sqlalchemy_v2.jinja2 index 000d7b5..bd7deb1 100644 --- a/omymodels/models/sqlalchemy_v2/sqlalchemy_v2.jinja2 +++ b/omymodels/models/sqlalchemy_v2/sqlalchemy_v2.jinja2 @@ -5,6 +5,6 @@ from sqlalchemy import ( from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column {{ headers }} -class Base(DeclarativeBase): +class {{ base_name }}(DeclarativeBase): pass {{ models }} \ No newline at end of file diff --git a/tests/functional/generator/test_sqlalchemy.py b/tests/functional/generator/test_sqlalchemy.py index 2e50ddb..794716c 100644 --- a/tests/functional/generator/test_sqlalchemy.py +++ b/tests/functional/generator/test_sqlalchemy.py @@ -476,3 +476,68 @@ class Comments(Base): """ result = create_models(ddl, models_type="sqlalchemy", relationships=True)["code"] assert result == expected + + +def test_split_by_schema(): + """Test that split_by_schema generates separate files per schema with custom Base.""" + ddl = """ +CREATE SCHEMA schema1; +CREATE SCHEMA schema2; + +CREATE TABLE schema1.users ( + id int PRIMARY KEY, + name varchar NOT NULL +); + +CREATE TABLE schema2.orders ( + id int PRIMARY KEY, + total decimal(10,2) +); +""" + result = create_models(ddl, models_type="sqlalchemy", split_by_schema=True, dump=False) + code = result["code"] + + # Should have two schemas + assert "schema1" in code + assert "schema2" in code + + # Check schema1 output + schema1_code = code["schema1"] + assert "Schema1Base = declarative_base()" in schema1_code + assert "class Users(Schema1Base):" in schema1_code + assert 'dict(schema="schema1")' in schema1_code + + # Check schema2 output + schema2_code = code["schema2"] + assert "Schema2Base = declarative_base()" in schema2_code + assert "class Orders(Schema2Base):" in schema2_code + assert 'dict(schema="schema2")' in schema2_code + + +def test_split_by_schema_with_no_schema_tables(): + """Test split_by_schema handles tables without explicit schema.""" + ddl = """ +CREATE SCHEMA myschema; + +CREATE TABLE myschema.users ( + id int PRIMARY KEY +); + +CREATE TABLE public_table ( + id int PRIMARY KEY +); +""" + result = create_models(ddl, models_type="sqlalchemy", split_by_schema=True, dump=False) + code = result["code"] + + # Should have myschema and empty string for tables without schema + assert "myschema" in code + assert "" in code + + # Check myschema output + assert "MyschemaBase = declarative_base()" in code["myschema"] + assert "class Users(MyschemaBase):" in code["myschema"] + + # Check default schema output (no schema) + assert "Base = declarative_base()" in code[""] + assert "class PublicTable(Base):" in code[""] diff --git a/tests/functional/generator/test_sqlalchemy_v2.py b/tests/functional/generator/test_sqlalchemy_v2.py index eb9cdda..b87f22d 100644 --- a/tests/functional/generator/test_sqlalchemy_v2.py +++ b/tests/functional/generator/test_sqlalchemy_v2.py @@ -250,3 +250,27 @@ def test_relationships_multiple_foreign_keys(): # Check child relationships assert 'user: Mapped["Users"] = relationship("Users", back_populates="comments")' in code assert 'post: Mapped["Posts"] = relationship("Posts", back_populates="comments")' in code + + +def test_split_by_schema(): + """Test split_by_schema with SQLAlchemy 2.0 style models.""" + ddl = """ +CREATE SCHEMA schema1; + +CREATE TABLE schema1.users ( + id int PRIMARY KEY, + name varchar NOT NULL +); +""" + result = create_models(ddl, models_type="sqlalchemy_v2", split_by_schema=True, dump=False) + code = result["code"] + + # Should have schema1 + assert "schema1" in code + + # Check schema1 output has SQLAlchemy 2.0 style with custom Base + schema1_code = code["schema1"] + assert "class Schema1Base(DeclarativeBase):" in schema1_code + assert "class Users(Schema1Base):" in schema1_code + assert "id: Mapped[int]" in schema1_code + assert 'dict(schema="schema1")' in schema1_code