diff --git a/.flake8 b/.flake8 index e872590..9ea07d0 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] max-line-length = 88 -exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache \ No newline at end of file +exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,alembic/versions \ No newline at end of file diff --git a/app/crud/base.py b/app/crud/base.py index 86903fa..f15c916 100644 --- a/app/crud/base.py +++ b/app/crud/base.py @@ -3,7 +3,6 @@ from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from sqlalchemy.orm import Session -import logging from app.database.base import Base diff --git a/app/crud/paper.py b/app/crud/paper.py index 544de1e..a3e90a6 100644 --- a/app/crud/paper.py +++ b/app/crud/paper.py @@ -4,8 +4,7 @@ class CRUDPaper(CRUDBase[Paper, PaperCreate, PaperUpdate]): - pass - + pass paper = CRUDPaper(Paper) diff --git a/app/crud/paper_with_code.py b/app/crud/paper_with_code.py index a51556b..476c22a 100644 --- a/app/crud/paper_with_code.py +++ b/app/crud/paper_with_code.py @@ -41,7 +41,10 @@ def get_multi_model_metrics_by_identifier( 'model_identifier': row.model_identifier, 'model_name': row.model_name, 'model_hardware_burden': row.model_hardware_burden, - 'model_operation_per_network_pass': row.model_gflops if row.model_gflops else row.model_multiply_adds, + 'model_operation_per_network_pass': ( + row.model_gflops + if row.model_gflops else row.model_multiply_adds + ), 'paper_identifier': row.paper_identifier, }) @@ -83,7 +86,10 @@ def get_model_metrics_by_identifier( 'model_identifier': response[0].model_identifier, 'model_name': response[0].model_name, 'model_hardware_burden': response[0].model_hardware_burden, - 'model_operation_per_network_pass': response[0].model_gflops if response[0].model_gflops else response[0].model_multiply_adds, + 'model_operation_per_network_pass': ( + response[0].model_gflops + if response[0].model_gflops else response[0].model_multiply_adds + ), 'paper_identifier': response[0].paper_identifier, } diff --git a/app/crud/task.py b/app/crud/task.py index 1354011..117b623 100644 --- a/app/crud/task.py +++ b/app/crud/task.py @@ -357,7 +357,10 @@ def get_models( 'gflops': row.model_gflops, 'number_of_parameters': row.model_number_of_parameters, 'multiply_adds': row.model_multiply_adds, - 'operation_per_network_pass': row.model_gflops if row.model_gflops else row.model_multiply_adds, + 'operation_per_network_pass': ( + row.model_gflops + if row.model_gflops else row.model_multiply_adds + ), 'hardware_burden': row.model_hardware_burden, 'paper_title': row.paper_title, 'paper_code_link': row.paper_code_link, @@ -443,7 +446,10 @@ def get_models_csv( row.accuracy_type: row.accuracy_value, 'model_gflops': row.model_gflops, 'model_multiply_adds': row.model_multiply_adds, - 'model_operation_per_network_pass': row.model_gflops if row.model_gflops else row.model_multiply_adds, + 'model_operation_per_network_pass': ( + row.model_gflops + if row.model_gflops else row.model_multiply_adds + ), 'model_extra_training_time': row.model_extra_training_time, 'model_number_of_cpus': row.model_number_of_cpus, 'model_cpu': row.model_cpu, diff --git a/app/database.py b/app/database.py index 8c5312d..75f723a 100644 --- a/app/database.py +++ b/app/database.py @@ -10,4 +10,4 @@ ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() \ No newline at end of file +Base = declarative_base() diff --git a/app/models/cpu.py b/app/models/cpu.py index cf61aaf..4738195 100644 --- a/app/models/cpu.py +++ b/app/models/cpu.py @@ -14,4 +14,4 @@ class Cpu(Base): tdp = Column(Float(precision=3)) gflops = Column(Float(precision=3)) die_size = Column(Integer) - year = Column(Integer) \ No newline at end of file + year = Column(Integer) diff --git a/app/models/gpu.py b/app/models/gpu.py index 9844649..e4a9cd3 100644 --- a/app/models/gpu.py +++ b/app/models/gpu.py @@ -11,4 +11,4 @@ class Gpu(Base): tdp = Column(Float(precision=3)) gflops = Column(Float(precision=3)) die_size = Column(Integer) - year = Column(Integer) \ No newline at end of file + year = Column(Integer) diff --git a/app/models/task_dataset.py b/app/models/task_dataset.py index 310851f..e3b6e57 100644 --- a/app/models/task_dataset.py +++ b/app/models/task_dataset.py @@ -1,8 +1,6 @@ from sqlalchemy import event -import logging -from sqlalchemy.sql.expression import bindparam, select, text -from app.models import Dataset, Task -from sqlalchemy.sql.functions import func + +from sqlalchemy.sql.expression import text from app.database.base import Base from sqlalchemy import Column, Integer, ForeignKey, String from sqlalchemy.orm import relationship @@ -25,8 +23,11 @@ class TaskDataset(Base): def my_before_insert_listener(mapper, connection, target): target.identifier = connection.execute( - text("select concat(task.identifier,'-on-', dataset.identifier) from task, dataset where task.id = %d and dataset.id = %d" % - (target.task_id, target.dataset_id)) + text( + "select concat(task.identifier,'-on-', dataset.identifier) from task, " + "dataset where task.id = %d and dataset.id = %d" % + (target.task_id, target.dataset_id) + ) ).scalar() diff --git a/app/routes/model.py b/app/routes/model.py index 6405ba2..da536c6 100644 --- a/app/routes/model.py +++ b/app/routes/model.py @@ -74,7 +74,9 @@ def get_models_csv( media_type="text/csv" ) - response.headers["Content-Disposition"] = f"attachment; filename={task_id}-{dataset_id}.csv" + response.headers[ + "Content-Disposition" + ] = f"attachment; filename={task_id}-{dataset_id}.csv" return response diff --git a/app/schemas/msg.py b/app/schemas/msg.py index 860e9f3..945e0c6 100644 --- a/app/schemas/msg.py +++ b/app/schemas/msg.py @@ -2,4 +2,4 @@ class Msg(BaseModel): - msg: str \ No newline at end of file + msg: str diff --git a/app/schemas/task.py b/app/schemas/task.py index ba3576d..0ae2b3c 100644 --- a/app/schemas/task.py +++ b/app/schemas/task.py @@ -1,8 +1,6 @@ -from app import models from pydantic.main import BaseModel from typing import List, Optional -from .dataset import Dataset from .model import Model # Shared properties