Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `table_prefix` and `table_suffix` parameters for class name customization
- Boolean defaults 0/1 converted to False/True
- Expanded `datetime_now_check` with more SQL datetime keywords
- VARCHAR(n) and CHAR(n) now generate `Field(max_length=n)` for Pydantic validation (issue #48)

**SQLAlchemy 2.0 Support (issue #49)**
- New `sqlalchemy_v2` models type with modern SQLAlchemy 2.0 syntax
Expand Down
39 changes: 33 additions & 6 deletions omymodels/models/pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from omymodels.models.pydantic.types import types_mapping
from omymodels.types import big_integer_types, integer_types, string_types, text_types

# Types that support max_length constraint
MAX_LENGTH_TYPES = string_types


class ModelGenerator:
def __init__(self):
Expand Down Expand Up @@ -74,9 +77,18 @@ def get_not_custom_type(self, type_str: str) -> str:
self.typing_imports.add("List")
return _type

def _should_add_max_length(self, column: Column) -> bool:
"""Check if column should have max_length constraint."""
if not column.size:
return False
# Only add max_length for string types (varchar, char, etc.), not text
original_type = column.type.lower().split("[")[0]
return original_type in MAX_LENGTH_TYPES

def generate_attr(self, column: Column, defaults_off: bool) -> str:
_type = None
original_type = column.type # Keep original for array detection
max_length = column.size if self._should_add_max_length(column) else None

if column.nullable:
self.typing_imports.add("Optional")
Expand All @@ -99,13 +111,20 @@ def generate_attr(self, column: Column, defaults_off: bool) -> str:
arg_name = column.name
field_params = None

# Check if we need Field() for alias or generated column
# Check if we need Field() for alias, generated column, or max_length
generated_as = getattr(column, "generated_as", None)
if not self._is_valid_identifier(column.name) or generated_as is not None:
field_params = self._get_field_params(column, defaults_off)
needs_field = (
not self._is_valid_identifier(column.name)
or generated_as is not None
or max_length is not None
)

if needs_field:
field_params = self._get_field_params(column, defaults_off, max_length)
if field_params:
self.imports.add("Field")
arg_name = self._generate_valid_identifier(column.name)
if not self._is_valid_identifier(column.name):
arg_name = self._generate_valid_identifier(column.name)
else:
if column.default is not None and not defaults_off:
field_params = self._get_default_value_string(column)
Expand All @@ -118,20 +137,28 @@ def generate_attr(self, column: Column, defaults_off: bool) -> str:

return column_str

def _get_field_params(self, column: Column, defaults_off: bool) -> str:
def _get_field_params(
self, column: Column, defaults_off: bool, max_length: int = None
) -> str:
params = []

if not self._is_valid_identifier(column.name):
params.append(f'alias="{column.name}"')

if column.default is not None and not defaults_off:
# For nullable fields with max_length, add default=None
if column.nullable and max_length is not None and not defaults_off:
params.append("default=None")
elif column.default is not None and not defaults_off:
if default_value := self._get_default_value_string(column):
params.append(f"default{default_value.replace(' ', '')}")

generated_as = getattr(column, "generated_as", None)
if generated_as is not None:
params.append("exclude=True")

if max_length is not None:
params.append(f"max_length={max_length}")

if params:
return f" = Field({', '.join(params)})"
return ""
Expand Down
45 changes: 43 additions & 2 deletions omymodels/models/pydantic_v2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from omymodels.helpers import create_class_name, datetime_now_check
from omymodels.models.pydantic_v2 import templates as pt
from omymodels.models.pydantic_v2.types import types_mapping
from omymodels.types import datetime_types
from omymodels.types import datetime_types, string_types

# Types that support max_length constraint
MAX_LENGTH_TYPES = string_types


class ModelGenerator:
Expand Down Expand Up @@ -49,8 +52,17 @@ def get_not_custom_type(self, column: Column) -> str:
self.uuid_import = True
return _type

def _should_add_max_length(self, column: Column) -> bool:
"""Check if column should have max_length constraint."""
if not column.size:
return False
# Only add max_length for string types (varchar, char, etc.), not text
original_type = column.type.lower().split("[")[0]
return original_type in MAX_LENGTH_TYPES

def generate_attr(self, column: Column, defaults_off: bool) -> str:
_type = None
max_length = column.size if self._should_add_max_length(column) else None

# Pydantic v2 uses X | None syntax
if column.nullable:
Expand All @@ -65,14 +77,43 @@ def generate_attr(self, column: Column, defaults_off: bool) -> str:

column_str = column_str.format(arg_name=column.name, type=_type)

if column.default is not None and not defaults_off:
# Handle max_length with Field()
if max_length is not None:
self.imports.add("Field")
field_params = []
# Handle defaults
if column.nullable and not defaults_off:
field_params.append("default=None")
elif column.default is not None and not defaults_off:
default_val = self._get_default_value(column)
if default_val:
field_params.append(f"default={default_val}")
field_params.append(f"max_length={max_length}")
column_str += f" = Field({', '.join(field_params)})"
elif column.default is not None and not defaults_off:
column_str = self.add_default_values(column_str, column)
elif column.nullable and not defaults_off:
# Nullable fields without explicit default should default to None
column_str += pt.pydantic_default_attr.format(default="None")

return column_str

def _get_default_value(self, column: Column) -> str:
"""Get formatted default value for Field()."""
if column.default is None or str(column.default).upper() == "NULL":
return ""

# Handle datetime default values
if column.type.upper() in datetime_types:
if datetime_now_check(column.default.lower()):
return "datetime.datetime.now()"

# Add quotes for string defaults if not already quoted
default_val = column.default
if isinstance(default_val, str) and "'" not in default_val and '"' not in default_val:
default_val = f"'{default_val}'"
return default_val

@staticmethod
def add_default_values(column_str: str, column: Column) -> str:
# Handle datetime default values
Expand Down
60 changes: 55 additions & 5 deletions tests/functional/generator/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ def test_pydantic_models_generator():

expected = """from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field


class UserHistory(BaseModel):
runid: Optional[float]
job_id: Optional[float]
id: str
user: str
status: str
id: str = Field(max_length=100)
user: str = Field(max_length=100)
status: str = Field(max_length=10)
event_time: datetime = datetime.now()
comment: str = 'none'
comment: str = Field(default='none', max_length=1000)
"""
assert result == expected

Expand Down Expand Up @@ -320,3 +320,53 @@ class TestDefaults(BaseModel):
col_timestamp: Optional[datetime]
"""
assert expected == result["code"]


def test_pydantic_varchar_max_length():
"""Test that VARCHAR(n) generates Field(max_length=n).

Regression test for issue #48.
"""
ddl = """
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(255),
bio TEXT
);
"""
result = create_models(ddl, models_type="pydantic")
expected = """from typing import Optional
from pydantic import BaseModel, Field


class Users(BaseModel):
id: int
name: str = Field(max_length=100)
email: Optional[str] = Field(default=None, max_length=255)
bio: Optional[str]
"""
assert expected == result["code"]


def test_pydantic_char_max_length():
"""Test that CHAR(n) generates Field(max_length=n).

Regression test for issue #48.
"""
ddl = """
CREATE TABLE codes (
code CHAR(10) NOT NULL,
description VARCHAR(200)
);
"""
result = create_models(ddl, models_type="pydantic")
expected = """from typing import Optional
from pydantic import BaseModel, Field


class Codes(BaseModel):
code: str = Field(max_length=10)
description: Optional[str] = Field(default=None, max_length=200)
"""
assert expected == result["code"]
64 changes: 59 additions & 5 deletions tests/functional/generator/test_pydantic_v2_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ def test_pydantic_v2_models_generator():
expected = """from __future__ import annotations

import datetime
from pydantic import BaseModel
from pydantic import BaseModel, Field


class UserHistory(BaseModel):

runid: float | None = None
job_id: float | None = None
id: str
user: str
status: str
id: str = Field(max_length=100)
user: str = Field(max_length=100)
status: str = Field(max_length=10)
event_time: datetime.datetime = datetime.datetime.now()
comment: str = 'none'
comment: str = Field(default='none', max_length=1000)
"""
assert result == expected

Expand Down Expand Up @@ -241,3 +241,57 @@ class OptionalData(BaseModel):
active: bool | None = None
"""
assert expected == result


def test_pydantic_v2_varchar_max_length():
"""Test that VARCHAR(n) generates Field(max_length=n) in Pydantic v2.

Regression test for issue #48.
"""
ddl = """
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(255),
bio TEXT
);
"""
result = create_models(ddl, models_type="pydantic_v2")
expected = """from __future__ import annotations

from pydantic import BaseModel, Field


class Users(BaseModel):

id: int
name: str = Field(max_length=100)
email: str | None = Field(default=None, max_length=255)
bio: str | None = None
"""
assert expected == result["code"]


def test_pydantic_v2_char_max_length():
"""Test that CHAR(n) generates Field(max_length=n) in Pydantic v2.

Regression test for issue #48.
"""
ddl = """
CREATE TABLE codes (
code CHAR(10) NOT NULL,
description VARCHAR(200)
);
"""
result = create_models(ddl, models_type="pydantic_v2")
expected = """from __future__ import annotations

from pydantic import BaseModel, Field


class Codes(BaseModel):

code: str = Field(max_length=10)
description: str | None = Field(default=None, max_length=200)
"""
assert expected == result["code"]
35 changes: 35 additions & 0 deletions tests/integration/pydantic/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import pytest

from omymodels import create_models


Expand Down Expand Up @@ -31,3 +33,36 @@ def test_pydantic_models_are_working_as_expected(load_generated_code) -> None:
assert used_model

os.remove(os.path.abspath(module.__file__))


def test_pydantic_max_length_validation(load_generated_code) -> None:
"""Integration test: verify max_length constraint is enforced (issue #48)."""
from pydantic import ValidationError

ddl = """
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(10) NOT NULL,
email VARCHAR(50)
);
"""
result = create_models(ddl, models_type="pydantic")["code"]

module = load_generated_code(result)

# Valid data within max_length
user = module.Users(id=1, name="John", email="john@example.com")
assert user.name == "John"
assert user.email == "john@example.com"

# Name exceeds max_length of 10
with pytest.raises(ValidationError) as exc_info:
module.Users(id=2, name="A" * 11, email="test@example.com")
assert "name" in str(exc_info.value)

# Email exceeds max_length of 50
with pytest.raises(ValidationError) as exc_info:
module.Users(id=3, name="Jane", email="a" * 51)
assert "email" in str(exc_info.value)

os.remove(os.path.abspath(module.__file__))
Loading