Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Uses `X | None` union syntax for nullable columns
- Supports all column types, foreign keys, indexes, and constraints

**SQLAlchemy Relationships (issue #47)**
- New `relationships` parameter for `create_models()` to generate `relationship()` with `back_populates`
- Automatically generates bidirectional relationships for foreign keys:
- Parent side (one-to-many): collection attribute pointing to children
- Child side (many-to-one): attribute pointing to parent
- Works with both `sqlalchemy` and `sqlalchemy_v2` model types
- For `sqlalchemy_v2`: uses `Mapped[List[T]]` for one-to-many and `Mapped[T]` for many-to-one

**SQLModel Improvements**
- Fixed array type generation (issue #66)
- Arrays now properly generate `List[T]` with correct SQLAlchemy ARRAY type
Expand Down
59 changes: 59 additions & 0 deletions omymodels/from_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def create_models(
no_auto_snake_case: Optional[bool] = False,
table_prefix: Optional[str] = "",
table_suffix: Optional[str] = "",
relationships: Optional[bool] = False,
):
"""models_type can be: "gino", "dataclass", "pydantic" """
# extract data from ddl file
Expand All @@ -65,6 +66,7 @@ def create_models(
defaults_off,
table_prefix=table_prefix,
table_suffix=table_suffix,
relationships=relationships,
)
if dump:
save_models_to_file(output, dump_path)
Expand Down Expand Up @@ -138,6 +140,56 @@ def save_models_to_file(models: str, dump_path: str) -> None:
f.write(models)


def _add_relationship(
relationships: Dict, table_name: str, fk_column: str, ref_table: str, ref_column: str
):
"""Helper to add both sides of a relationship."""
relationships.setdefault(table_name, []).append({
"type": "many_to_one",
"fk_column": fk_column,
"ref_table": ref_table,
"ref_column": ref_column,
"child_table_name": table_name,
})
relationships.setdefault(ref_table, []).append({
"type": "one_to_many",
"child_table": table_name,
"fk_column": fk_column,
})


def _get_alter_columns(table) -> List:
"""Get ALTER TABLE columns if they exist."""
if hasattr(table, 'alter') and table.alter:
return table.alter.get("columns", [])
return []


def collect_relationships(tables: List) -> Dict:
"""Collect foreign key relationships between tables."""
relationships = {}

for table in tables:
for column in table.columns:
if column.references and column.references.get("table"):
_add_relationship(
relationships, table.name, column.name,
column.references["table"],
column.references.get("column") or column.name
)

for alter_col in _get_alter_columns(table):
ref_info = alter_col.get("references")
if ref_info and ref_info.get("table"):
_add_relationship(
relationships, table.name, alter_col["name"],
ref_info["table"],
ref_info.get("column") or alter_col["name"]
)

return relationships


def generate_models_file(
data: Dict[str, List],
singular: bool = False,
Expand All @@ -147,6 +199,7 @@ def generate_models_file(
defaults_off: Optional[bool] = False,
table_prefix: Optional[str] = "",
table_suffix: Optional[str] = "",
relationships: Optional[bool] = False,
) -> str:
"""method to prepare full file with all Models &"""
models_str = ""
Expand All @@ -159,6 +212,11 @@ def generate_models_file(
if data["tables"]:
add_custom_types_to_generator(data["types"], generator)

# Collect relationships if enabled
relationships_map = {}
if relationships:
relationships_map = collect_relationships(data["tables"])

for table in data["tables"]:
models_str += generator.generate_model(
table,
Expand All @@ -168,6 +226,7 @@ def generate_models_file(
defaults_off=defaults_off,
table_prefix=table_prefix,
table_suffix=table_suffix,
relationships=relationships_map.get(table.name, []) if relationships else [],
)
header += generator.create_header(
data["tables"], schema=schema_global, models_str=models_str
Expand Down
60 changes: 59 additions & 1 deletion omymodels/models/sqlalchemy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self):
self.postgresql_dialect_cols = set()
self.constraint = False
self.im_index = False
self.relationship_import = False
self.types_mapping = types_mapping
self.templates = st
self.prefix = "sa."
Expand Down Expand Up @@ -46,14 +47,16 @@ def generate_model(
singular: bool = True,
exceptions: Optional[List] = None,
schema_global: Optional[bool] = True,
relationships: Optional[List] = None,
*args,
**kwargs,
) -> str:
"""method to prepare one Model defention - name & tablename & columns"""
model = ""
model_name = create_class_name(table.name, singular, exceptions)

model = st.model_template.format(
model_name=create_class_name(table.name, singular, exceptions),
model_name=model_name,
table_name=table.name,
)
for column in table.columns:
Expand All @@ -62,8 +65,61 @@ def generate_model(
)
if table.indexes or table.alter or table.checks or not schema_global:
model = logic.add_table_args(self, model, table, schema_global)

# Generate relationships if enabled
if relationships:
model += self._generate_relationships(
relationships, singular, exceptions
)
return model

def _generate_relationships(
self,
relationships: List[Dict],
singular: bool,
exceptions: Optional[List] = None,
) -> str:
"""Generate relationship() lines for the model."""
result = "\n"
self.relationship_import = True

for rel in relationships:
if rel["type"] == "many_to_one":
# Child side: reference to parent
# e.g., posts.user = relationship("Users", back_populates="posts")
ref_table = rel["ref_table"]
fk_column = rel["fk_column"]
child_table_name = rel["child_table_name"]
related_class = create_class_name(ref_table, singular, exceptions)
# Attribute name derived from FK column (user_id -> user)
attr_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else ref_table.lower()
# back_populates points to the collection on the parent (uses child table name)
back_pop_name = child_table_name.lower().replace("-", "_")
back_populates = st.back_populates_template.format(attr_name=back_pop_name)
result += st.relationship_template.format(
attr_name=attr_name,
related_class=related_class,
back_populates=back_populates,
)
elif rel["type"] == "one_to_many":
# Parent side: collection of children
# e.g., users.posts = relationship("Posts", back_populates="user")
child_table = rel["child_table"]
fk_column = rel["fk_column"]
related_class = create_class_name(child_table, singular, exceptions)
# Attribute name is the child table name (as-is, since table names are typically plural)
attr_name = child_table.lower().replace("-", "_")
# back_populates points to the single parent reference on the child
# Derived from FK column (user_id -> user)
back_pop_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else child_table.lower()
back_populates = st.back_populates_template.format(attr_name=back_pop_name)
result += st.relationship_template.format(
attr_name=attr_name,
related_class=related_class,
back_populates=back_populates,
)
return result

def create_header(
self, tables: List[Dict], schema: bool = False, *args, **kwargs
) -> str:
Expand All @@ -82,4 +138,6 @@ def create_header(
header += st.unique_cons_import + "\n"
if self.im_index:
header += st.index_import + "\n"
if self.relationship_import:
header += st.relationship_import + "\n"
return header
5 changes: 5 additions & 0 deletions omymodels/models/sqlalchemy/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ class {model_name}(Base):\n

on_delete = ', ondelete="{mode}"'
on_update = ', onupdate="{mode}"'

# relationship templates
relationship_import = "from sqlalchemy.orm import relationship"
relationship_template = ' {attr_name} = relationship("{related_class}"{back_populates})\n'
back_populates_template = ', back_populates="{attr_name}"'
65 changes: 65 additions & 0 deletions omymodels/models/sqlalchemy_v2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self):
self.time_import = False
self.uuid_import = False
self.fk_import = False
self.relationship_import = False
self.types_mapping = types_mapping
self.templates = st
self.prefix = ""
Expand Down Expand Up @@ -225,6 +226,7 @@ def generate_model(
singular: bool = True,
exceptions: Optional[List] = None,
schema_global: Optional[bool] = True,
relationships: Optional[List] = None,
*args,
**kwargs,
) -> str:
Expand All @@ -242,8 +244,68 @@ def generate_model(
if table.indexes or table.alter or table.checks or not schema_global:
model = self._add_table_args(model, table, schema_global)

# Generate relationships if enabled
if relationships:
model += self._generate_relationships(
relationships, singular, exceptions
)

return model

def _generate_relationships(
self,
relationships: List[Dict],
singular: bool,
exceptions: Optional[List] = None,
) -> str:
"""Generate relationship() lines for the model."""
result = "\n"
self.relationship_import = True
self.typing_imports.add("List")

for rel in relationships:
if rel["type"] == "many_to_one":
# Child side: reference to parent
# e.g., author: Mapped["Authors"] = relationship("Authors", back_populates="books")
ref_table = rel["ref_table"]
fk_column = rel["fk_column"]
child_table_name = rel["child_table_name"]
related_class = create_class_name(ref_table, singular, exceptions)
# Attribute name derived from FK column (author_id -> author)
attr_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else ref_table.lower()
# back_populates points to the collection on the parent (uses child table name)
back_pop_name = child_table_name.lower().replace("-", "_")
back_populates = st.back_populates_template.format(attr_name=back_pop_name)
# Type hint for many-to-one is the related class (quoted for forward ref)
type_hint = f'"{related_class}"'
result += st.relationship_template.format(
attr_name=attr_name,
type_hint=type_hint,
related_class=related_class,
back_populates=back_populates,
)
elif rel["type"] == "one_to_many":
# Parent side: collection of children
# e.g., books: Mapped[List["Books"]] = relationship("Books", back_populates="author")
child_table = rel["child_table"]
fk_column = rel["fk_column"]
related_class = create_class_name(child_table, singular, exceptions)
# Attribute name is the child table name (as-is, since table names are typically plural)
attr_name = child_table.lower().replace("-", "_")
# back_populates points to the single parent reference on the child
# Derived from FK column (author_id -> author)
back_pop_name = fk_column.replace("_id", "") if fk_column.endswith("_id") else child_table.lower()
back_populates = st.back_populates_template.format(attr_name=back_pop_name)
# Type hint for one-to-many is List of related class (quoted for forward ref)
type_hint = f'List["{related_class}"]'
result += st.relationship_template.format(
attr_name=attr_name,
type_hint=type_hint,
related_class=related_class,
back_populates=back_populates,
)
return result

def _add_table_args(
self, model: str, table: Dict, schema_global: bool = True
) -> str:
Expand Down Expand Up @@ -320,4 +382,7 @@ def create_header(
if self.im_index:
parts.append(st.index_import + "\n")

if self.relationship_import:
parts.append(st.relationship_import + "\n")

return "".join(parts)
5 changes: 5 additions & 0 deletions omymodels/models/sqlalchemy_v2/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,8 @@ class {model_name}(Base):

on_delete = ', ondelete="{mode}"'
on_update = ', onupdate="{mode}"'

# relationship templates
relationship_import = "from sqlalchemy.orm import relationship"
relationship_template = ' {attr_name}: Mapped[{type_hint}] = relationship("{related_class}"{back_populates})\n'
back_populates_template = ', back_populates="{attr_name}"'
Loading