Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Works with both `sqlalchemy` and `sqlalchemy_v2` model types
- For `sqlalchemy_v2`: uses `Mapped[List[T]]` for one-to-many and `Mapped[T]` for many-to-one

**Schema-Separated Model Files (issue #40)**
- New `split_by_schema` parameter for `create_models()` to generate separate files per database schema
- Each schema gets its own file with a schema-specific Base class (e.g., `Schema1Base`)
- Tables without explicit schema go to a file with the default `Base` class
- Works with both `sqlalchemy` and `sqlalchemy_v2` model types
- File naming: `{schema_name}_{base_filename}.py` (e.g., `schema1_models.py`)

**SQLModel Improvements**
- Fixed array type generation (issue #66)
- Arrays now properly generate `List[T]` with correct SQLAlchemy ARRAY type
Expand Down
122 changes: 121 additions & 1 deletion omymodels/from_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def create_models(
table_prefix: Optional[str] = "",
table_suffix: Optional[str] = "",
relationships: Optional[bool] = False,
split_by_schema: Optional[bool] = False,
):
"""models_type can be: "gino", "dataclass", "pydantic" """
# extract data from ddl file
Expand All @@ -56,7 +57,28 @@ def create_models(
sys.exit(0)
else:
raise NoTablesError()
# generate code

# Handle split_by_schema mode
if split_by_schema:
output = generate_models_by_schema(
data,
singular,
naming_exceptions,
models_type,
defaults_off,
table_prefix=table_prefix,
table_suffix=table_suffix,
relationships=relationships,
)
if dump:
save_models_by_schema(output, dump_path)
else:
for schema_name, code in output.items():
print(f"# === {schema_name} ===")
print(code)
return {"metadata": data, "code": output}

# generate code (single file mode)
output = generate_models_file(
data,
singular,
Expand Down Expand Up @@ -140,6 +162,104 @@ def save_models_to_file(models: str, dump_path: str) -> None:
f.write(models)


def save_models_by_schema(models_by_schema: Dict[str, str], dump_path: str) -> None:
"""Save models split by schema to separate files."""
folder = os.path.dirname(dump_path)
base_name = os.path.basename(dump_path)
name_without_ext = os.path.splitext(base_name)[0]

if folder:
os.makedirs(folder, exist_ok=True)

for schema_name, code in models_by_schema.items():
file_name = f"{schema_name}_{name_without_ext}.py" if schema_name else f"{name_without_ext}.py"
file_path = os.path.join(folder, file_name) if folder else file_name
with open(file_path, "w+") as f:
f.write(code)


def group_tables_by_schema(tables: List) -> Dict[str, List]:
"""Group tables by their schema attribute."""
grouped = {}
for table in tables:
schema = table.table_schema or ""
grouped.setdefault(schema, []).append(table)
return grouped


def _schema_to_base_name(schema: str) -> str:
"""Convert schema name to Base class name (e.g., 'my_schema' -> 'MySchemaBase')."""
if not schema:
return "Base"
# Convert snake_case or kebab-case to PascalCase
parts = schema.replace("-", "_").split("_")
pascal = "".join(part.capitalize() for part in parts)
return f"{pascal}Base"


def generate_models_by_schema(
data: Dict[str, List],
singular: bool = False,
exceptions: Optional[List] = None,
models_type: str = "gino",
defaults_off: Optional[bool] = False,
table_prefix: Optional[str] = "",
table_suffix: Optional[str] = "",
relationships: Optional[bool] = False,
) -> Dict[str, str]:
"""Generate models split by schema, each with its own Base class."""
from omymodels.generators import get_generator_by_type, render_jinja2_template

results = {}
tables_by_schema = group_tables_by_schema(data["tables"])

# Collect relationships across all tables if enabled
relationships_map = {}
if relationships:
relationships_map = collect_relationships(data["tables"])

for schema_name, tables in tables_by_schema.items():
generator = get_generator_by_type(models_type)
add_custom_types_to_generator(data["types"], generator)

models_str = ""
header = ""

# Include types only in the first (or default) schema file
if data["types"] and schema_name == "":
types_generator = enum.ModelGenerator(data["types"])
models_str += types_generator.create_types()
header += types_generator.create_header()

for table in tables:
models_str += generator.generate_model(
table,
singular,
exceptions,
schema_global=False, # Always include schema in __table_args__
defaults_off=defaults_off,
table_prefix=table_prefix,
table_suffix=table_suffix,
relationships=relationships_map.get(table.name, []) if relationships else [],
)

header += generator.create_header(tables, schema=False, models_str=models_str)

# Generate code with schema-specific Base name
base_name = _schema_to_base_name(schema_name)
output = render_jinja2_template(
models_type, models_str, header, base_name=base_name
)

# Replace class inheritance from Base to custom base name
if base_name != "Base":
output = output.replace("(Base):", f"({base_name}):")

results[schema_name] = output

return results


def _add_relationship(
relationships: Dict, table_name: str, fk_column: str, ref_table: str, ref_column: str
):
Expand Down
7 changes: 5 additions & 2 deletions omymodels/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,16 @@ def get_supported_models() -> List[str]:
return list(list_generators().keys())


def render_jinja2_template(models_type: str, models: str, headers: str) -> str:
def render_jinja2_template(
models_type: str, models: str, headers: str, base_name: str = "Base"
) -> str:
"""Render Jinja2 template for model output.

Args:
models_type: Generator type name
models: Generated model code
headers: Generated header/imports code
base_name: Name for the Base class (default: "Base")

Returns:
Rendered template as string
Expand All @@ -107,5 +110,5 @@ def render_jinja2_template(models_type: str, models: str, headers: str) -> str:
with open(template_file) as t:
template = t.read()
template = Template(template)
params = {"models": models, "headers": headers}
params = {"models": models, "headers": headers, "base_name": base_name}
return template.render(**params)
2 changes: 1 addition & 1 deletion omymodels/models/sqlalchemy/sqlalchemy.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
{{ headers }}

Base = declarative_base()
{{ base_name }} = declarative_base()
{{ models }}
2 changes: 1 addition & 1 deletion omymodels/models/sqlalchemy_v2/sqlalchemy_v2.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ from sqlalchemy import (
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
{{ headers }}

class Base(DeclarativeBase):
class {{ base_name }}(DeclarativeBase):
pass
{{ models }}
65 changes: 65 additions & 0 deletions tests/functional/generator/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,68 @@ class Comments(Base):
"""
result = create_models(ddl, models_type="sqlalchemy", relationships=True)["code"]
assert result == expected


def test_split_by_schema():
"""Test that split_by_schema generates separate files per schema with custom Base."""
ddl = """
CREATE SCHEMA schema1;
CREATE SCHEMA schema2;

CREATE TABLE schema1.users (
id int PRIMARY KEY,
name varchar NOT NULL
);

CREATE TABLE schema2.orders (
id int PRIMARY KEY,
total decimal(10,2)
);
"""
result = create_models(ddl, models_type="sqlalchemy", split_by_schema=True, dump=False)
code = result["code"]

# Should have two schemas
assert "schema1" in code
assert "schema2" in code

# Check schema1 output
schema1_code = code["schema1"]
assert "Schema1Base = declarative_base()" in schema1_code
assert "class Users(Schema1Base):" in schema1_code
assert 'dict(schema="schema1")' in schema1_code

# Check schema2 output
schema2_code = code["schema2"]
assert "Schema2Base = declarative_base()" in schema2_code
assert "class Orders(Schema2Base):" in schema2_code
assert 'dict(schema="schema2")' in schema2_code


def test_split_by_schema_with_no_schema_tables():
"""Test split_by_schema handles tables without explicit schema."""
ddl = """
CREATE SCHEMA myschema;

CREATE TABLE myschema.users (
id int PRIMARY KEY
);

CREATE TABLE public_table (
id int PRIMARY KEY
);
"""
result = create_models(ddl, models_type="sqlalchemy", split_by_schema=True, dump=False)
code = result["code"]

# Should have myschema and empty string for tables without schema
assert "myschema" in code
assert "" in code

# Check myschema output
assert "MyschemaBase = declarative_base()" in code["myschema"]
assert "class Users(MyschemaBase):" in code["myschema"]

# Check default schema output (no schema)
assert "Base = declarative_base()" in code[""]
assert "class PublicTable(Base):" in code[""]
24 changes: 24 additions & 0 deletions tests/functional/generator/test_sqlalchemy_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,27 @@ def test_relationships_multiple_foreign_keys():
# Check child relationships
assert 'user: Mapped["Users"] = relationship("Users", back_populates="comments")' in code
assert 'post: Mapped["Posts"] = relationship("Posts", back_populates="comments")' in code


def test_split_by_schema():
"""Test split_by_schema with SQLAlchemy 2.0 style models."""
ddl = """
CREATE SCHEMA schema1;

CREATE TABLE schema1.users (
id int PRIMARY KEY,
name varchar NOT NULL
);
"""
result = create_models(ddl, models_type="sqlalchemy_v2", split_by_schema=True, dump=False)
code = result["code"]

# Should have schema1
assert "schema1" in code

# Check schema1 output has SQLAlchemy 2.0 style with custom Base
schema1_code = code["schema1"]
assert "class Schema1Base(DeclarativeBase):" in schema1_code
assert "class Users(Schema1Base):" in schema1_code
assert "id: Mapped[int]" in schema1_code
assert 'dict(schema="schema1")' in schema1_code