Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def whoami(
Returns the current authenticated user
"""
user = await User.get_by_username(session, current_user.username)
return UserOutput.from_orm(user)
return UserOutput.model_validate(user, from_attributes=True)


@router.get("/token/")
Expand Down
7 changes: 5 additions & 2 deletions datajunction-server/datajunction_server/api/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ async def list_attributes(
List all available attribute types.
"""
attributes = await AttributeType.get_all(session)
return [AttributeTypeBase.from_orm(attr) for attr in attributes]
return [
AttributeTypeBase.model_validate(attr, from_attributes=True)
for attr in attributes
]


@router.post(
Expand All @@ -62,7 +65,7 @@ async def add_attribute_type(
message=f"Attribute type `{data.name}` already exists!",
)
attribute_type = await AttributeType.create(session, data)
return AttributeTypeBase.from_orm(attribute_type)
return AttributeTypeBase.model_validate(attribute_type, from_attributes=True)


async def default_attribute_types(session: AsyncSession = Depends(get_session)):
Expand Down
6 changes: 3 additions & 3 deletions datajunction-server/datajunction_server/api/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def list_catalogs(
"""
statement = select(Catalog).options(joinedload(Catalog.engines))
return [
CatalogInfo.from_orm(catalog)
CatalogInfo.model_validate(catalog, from_attributes=True)
for catalog in (await session.execute(statement)).unique().scalars()
]

Expand Down Expand Up @@ -111,7 +111,7 @@ async def add_catalog(
await session.commit()
await session.refresh(catalog, ["engines"])

return CatalogInfo.from_orm(catalog)
return CatalogInfo.model_validate(catalog, from_attributes=True)


@router.post(
Expand All @@ -136,7 +136,7 @@ async def add_engines_to_catalog(
session.add(catalog)
await session.commit()
await session.refresh(catalog)
return CatalogInfo.from_orm(catalog)
return CatalogInfo.model_validate(catalog, from_attributes=True)


async def list_new_engines(
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def create_a_collection(
await session.commit()
await session.refresh(collection)

return CollectionInfo.from_orm(collection)
return CollectionInfo.model_validate(collection, from_attributes=True)


@router.delete("/collections/{name}", status_code=HTTPStatus.NO_CONTENT)
Expand Down
6 changes: 3 additions & 3 deletions datajunction-server/datajunction_server/api/cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def cube_materialization_info(
Granularity.YEAR: "0 0 1 1 *", # Runs at midnight on January 1st every year
}
upsert = UpsertCubeMaterialization(
job=MaterializationJobTypeEnum.DRUID_CUBE,
job=MaterializationJobTypeEnum.DRUID_CUBE.value.name,
strategy=(
MaterializationStrategy.INCREMENTAL_TIME
if temporal_partition
Expand All @@ -186,7 +186,7 @@ async def cube_materialization_info(
metrics=cube_config.metrics,
strategy=upsert.strategy,
schedule=upsert.schedule,
job=upsert.job.name,
job=upsert.job.name, # type: ignore
measures_materializations=cube_config.measures_materializations,
combiners=cube_config.combiners,
)
Expand Down Expand Up @@ -299,7 +299,7 @@ async def get_cube_dimension_values(
value=row[0 : count_column[0]] if count_column else row,
count=row[count_column[0]] if count_column else None,
)
for row in result.results.__root__[0].rows
for row in result.results.root[0].rows
]
return DimensionValues( # pragma: no cover
dimensions=[
Expand Down
28 changes: 18 additions & 10 deletions datajunction-server/datajunction_server/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,15 @@ async def add_availability_state(
table=data.table,
valid_through_ts=data.valid_through_ts,
url=data.url,
min_temporal_partition=data.min_temporal_partition,
max_temporal_partition=data.max_temporal_partition,
min_temporal_partition=[
str(part) for part in data.min_temporal_partition or []
],
max_temporal_partition=[
str(part) for part in data.max_temporal_partition or []
],
partitions=[
partition.dict() if not isinstance(partition, Dict) else partition
for partition in data.partitions # type: ignore
partition.model_dump() if not isinstance(partition, Dict) else partition
for partition in (data.partitions or [])
],
categorical_partitions=data.categorical_partitions,
temporal_partitions=data.temporal_partitions,
Expand All @@ -159,10 +163,14 @@ async def add_availability_state(
entity_type=EntityType.AVAILABILITY,
node=node.name, # type: ignore
activity_type=ActivityType.CREATE,
pre=AvailabilityStateBase.from_orm(old_availability).dict()
pre=AvailabilityStateBase.model_validate(
old_availability,
).model_dump()
if old_availability
else {},
post=AvailabilityStateBase.from_orm(node_revision.availability).dict(),
post=AvailabilityStateBase.model_validate(
node_revision.availability,
).model_dump(),
user=current_user.username,
),
session=session,
Expand Down Expand Up @@ -333,8 +341,8 @@ async def get_data(
)

# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = generated_sql.columns # type: ignore
if result.results.root: # pragma: no cover
result.results.root[0].columns = generated_sql.columns # type: ignore
return result


Expand Down Expand Up @@ -518,8 +526,8 @@ async def get_data_for_metrics(
)

# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = translated_sql.columns or []
if result.results.root: # pragma: no cover
result.results.root[0].columns = translated_sql.columns or []
return result


Expand Down
4 changes: 2 additions & 2 deletions datajunction-server/datajunction_server/api/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def submit(self, spec: DeploymentSpec, context: DeploymentContext) -> str:
deployment = Deployment(
uuid=deployment_uuid,
namespace=spec.namespace,
spec=spec.dict(),
spec=spec.model_dump(),
status=DeploymentStatus.PENDING,
created_by_id=context.current_user.id,
)
Expand Down Expand Up @@ -109,7 +109,7 @@ async def update_status(
deployment = await session.get(Deployment, deployment_uuid)
deployment.status = status
if results is not None:
deployment.results = [r.dict() for r in results]
deployment.results = [r.model_dump() for r in results]
await session.commit()

async def _run_deployment(
Expand Down
4 changes: 2 additions & 2 deletions datajunction-server/datajunction_server/api/djsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ async def get_data_for_djsql(
)

# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = translated_sql.columns or []
if result.results.root: # pragma: no cover
result.results.root[0].columns = translated_sql.columns or []
return result


Expand Down
8 changes: 5 additions & 3 deletions datajunction-server/datajunction_server/api/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def list_engines(
List all available engines
"""
return [
EngineInfo.from_orm(engine)
EngineInfo.model_validate(engine)
for engine in (await session.execute(select(Engine))).scalars()
]

Expand All @@ -58,7 +58,9 @@ async def get_an_engine(
"""
Return an engine by name and version
"""
return EngineInfo.from_orm(await get_engine(session, name, version))
return EngineInfo.model_validate(
await get_engine(session, name, version),
)


@router.post(
Expand Down Expand Up @@ -95,4 +97,4 @@ async def add_engine(
await session.commit()
await session.refresh(engine)

return EngineInfo.from_orm(engine)
return EngineInfo.model_validate(engine)
10 changes: 5 additions & 5 deletions datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,8 @@ async def query_event_stream(
"query end state detected (%s), sending final event to the client",
query_next.state,
)
if query_next.results.__root__: # pragma: no cover
query_next.results.__root__[0].columns = columns or []
if query_next.results.root: # pragma: no cover
query_next.results.root[0].columns = columns or []
yield {
"event": "message",
"id": uuid.uuid4(),
Expand Down Expand Up @@ -879,13 +879,13 @@ def get_node_revision_materialization(
)
if materialization.strategy != MaterializationStrategy.INCREMENTAL_TIME:
info.urls = [info.urls[0]]
materialization_config_output = MaterializationConfigOutput.from_orm(
materialization_config_output = MaterializationConfigOutput.model_validate(
materialization,
)
materializations.append(
MaterializationConfigInfoUnified(
**materialization_config_output.dict(),
**info.dict(),
**materialization_config_output.model_dump(),
**info.model_dump(),
),
)
return materializations
4 changes: 2 additions & 2 deletions datajunction-server/datajunction_server/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def list_history(
offset=offset,
limit=limit,
)
return [HistoryOutput.from_orm(entry) for entry in hist]
return [HistoryOutput.model_validate(entry) for entry in hist]


@router.get("/history/", response_model=List[HistoryOutput])
Expand Down Expand Up @@ -86,4 +86,4 @@ async def list_history_by_node_context(
)
result = await session.execute(statement)
hist = result.scalars().all()
return [HistoryOutput.from_orm(entry) for entry in hist]
return [HistoryOutput.model_validate(entry) for entry in hist]
29 changes: 19 additions & 10 deletions datajunction-server/datajunction_server/api/materializations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import logging
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Callable, List
from typing import Annotated, Callable, List

from fastapi import Depends, Request
from pydantic import Discriminator
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -68,7 +69,9 @@ def materialization_jobs_info() -> JSONResponse:
return JSONResponse(
status_code=200,
content={
"job_types": [value.value.dict() for value in MaterializationJobTypeEnum],
"job_types": [
value.value.model_dump() for value in MaterializationJobTypeEnum
],
"strategies": [
{"name": value, "label": labelize(value)}
for value in MaterializationStrategy
Expand All @@ -84,7 +87,10 @@ def materialization_jobs_info() -> JSONResponse:
)
async def upsert_materialization(
node_name: str,
data: UpsertMaterialization | UpsertCubeMaterialization,
materialization: Annotated[
UpsertCubeMaterialization | UpsertMaterialization,
Discriminator("job"),
],
*,
session: AsyncSession = Depends(get_session),
request: Request,
Expand Down Expand Up @@ -117,20 +123,20 @@ async def upsert_materialization(
current_revision = node.current # type: ignore
old_materializations = {mat.name: mat for mat in current_revision.materializations}

if data.strategy == MaterializationStrategy.INCREMENTAL_TIME:
if materialization.strategy == MaterializationStrategy.INCREMENTAL_TIME: # type: ignore
if not node.current.temporal_partition_columns(): # type: ignore
raise DJInvalidInputException(
http_status_code=HTTPStatus.BAD_REQUEST,
message="Cannot create materialization with strategy "
f"`{data.strategy}` without specifying a time partition column!",
f"`{materialization.strategy}` without specifying a time partition column!", # type: ignore
)

# Create a new materialization
new_materialization = await create_new_materialization(
session,
current_revision,
data,
validate_access,
materialization,
validate_access, # type: ignore
current_user=current_user,
)

Expand Down Expand Up @@ -194,7 +200,7 @@ async def upsert_materialization(
f"already exists for node `{node_name}` but was deactivated. It has now been "
f"restored."
),
"info": existing_materialization_info.dict(),
"info": existing_materialization_info.model_dump(),
},
)
# If changes are detected, update the existing or save the new materialization
Expand Down Expand Up @@ -498,7 +504,10 @@ async def run_materialization_backfill(
)
backfill = Backfill(
materialization=materialization,
spec=[backfill_partition.dict() for backfill_partition in backfill_partitions],
spec=[
backfill_partition.model_dump()
for backfill_partition in backfill_partitions
],
urls=materialization_output.urls,
)
materialization.backfills.append(backfill)
Expand All @@ -511,7 +520,7 @@ async def run_materialization_backfill(
details={
"materialization": materialization_name,
"partition": [
backfill_partition.dict()
backfill_partition.model_dump()
for backfill_partition in backfill_partitions
],
},
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/api/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def hard_delete_node_namespace(
status_code=HTTPStatus.OK,
content={
"message": f"The namespace `{namespace}` has been completely removed.",
"impact": impacts.dict(),
"impact": impacts.model_dump(),
},
)

Expand Down
Loading
Loading