Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ jobs:
run: echo "$APPDATA\Python\Scripts" >> $GITHUB_PATH
- name: Configure poetry
run: poetry config virtualenvs.in-project true
- name: Install system dependencies (for PyQt6)
run: sudo apt-get update && sudo apt-get install -y libegl1
- name: Set up cache
uses: actions/cache@v3
id: cache
Expand All @@ -49,5 +51,7 @@ jobs:
run: timeout 10s poetry run pip --version || rm -rf .venv
- name: Install dependencies
run: poetry install
- name: Run tests
run: poetry run pytest
- name: Run tests with Xvfb (for PyQt support)
run: |
export QT_QPA_PLATFORM=offscreen
xvfb-run -a poetry run pytest
15 changes: 12 additions & 3 deletions niaaml_gui/widgets/pipeline_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def keyPressEvent(self, event):

else:
super().keyPressEvent(event)
self.pipelineStateChanged.emit()


def progress_start(self, maximum: int | None = None) -> None:
Expand Down Expand Up @@ -367,6 +366,9 @@ def _clear_connection_highlights(self):
self._highlighted_circles.clear()

def is_pipeline_ready(self) -> bool:
if not self.block_data:
return False

for block, info in self.block_data.items():
if hasattr(block, "get_value"):
value = block.get_value()
Expand All @@ -378,6 +380,7 @@ def is_pipeline_ready(self) -> bool:




class InteractiveConfigBlock(QGraphicsPathItem):
__niaamlFeatureSelectionAlgorithmsMap = (
FeatureSelectionAlgorithmFactory().get_name_to_classname_mapping()
Expand All @@ -402,9 +405,11 @@ def __init__(
is_number_input: bool = False,
dropdown_options=None,
icon_path: str | None = None,
is_sidebar=True,
readonly=False
):
super().__init__()

self.readonly = readonly
self.shape = shape.lower()
self.icon_path = icon_path
self.label = label
Expand Down Expand Up @@ -494,6 +499,9 @@ def __init__(
self._layout_contents()

def _handle_click_action(self):

if self.readonly:
return
if self.dropdown_options:
return
elif self.label in ["Feature Selection", "Feature Transform", "Classifier"]:
Expand Down Expand Up @@ -613,7 +621,8 @@ def _open_csv_editor(self):
if not getattr(self, "value", ""):
QMessageBox.warning(None, "No CSV", "Please pick a CSV file first.")
return
CSVEditorWindow(self.value).show()
self.csv_editor_window = CSVEditorWindow(self.value)
self.csv_editor_window.show()

def getPath(self):
dlg = QFileDialog
Expand Down
2 changes: 1 addition & 1 deletion niaaml_gui/widgets/sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, canvas):
lst.setSpacing(2)
lst.setDragEnabled(True)
lst.mouseMoveEvent = self._wrapped_start_drag
lst.itemClicked.connect(self.handle_click)
#lst.itemClicked.connect(self.handle_click)

for label, icon_file in items.items():
icon_path = os.path.join(icon_dir, icon_file)
Expand Down
1,500 changes: 752 additions & 748 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ repository = "https://github.com/firefly-cpp/NiaAML-GUI"
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
python = ">=3.10,<3.13"
niapy = "^2.5.2"
QtAwesome = "^1.2.3"
niaaml = "^2.1.2"
scikit-learn = "^1.6.1"
pyqt6 = "^6.6.0"
pytest-qt = "^4.5.0"
pyqt-feedback-flow = "^0.3.5"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
Expand Down
13 changes: 5 additions & 8 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_pipeline_run_works_fine(self):
has_header=True,
contains_classes=True,
)

pipeline.optimize(
data_reader.get_x(),
data_reader.get_y(),
Expand All @@ -95,21 +96,17 @@ def test_pipeline_run_works_fine(self):
"ParticleSwarmAlgorithm",
"Accuracy",
)
predicted = pipeline.run(
pandas.DataFrame(
numpy.random.uniform(
low=0.0, high=15.0, size=(30, data_reader.get_x().shape[1])
)
)
)

self.assertEqual(predicted.shape, (30,))
test_data = data_reader.get_x().iloc[:30]
predicted = pipeline.run(test_data)

self.assertEqual(predicted.shape, (30,))
s1 = set(data_reader.get_y())
s2 = set(predicted)
self.assertTrue(s2.issubset(s1))
self.assertTrue(len(s2) > 0 and len(s2) <= 2)


def test_pipeline_export_works_fine(self):
pipeline = Pipeline(
feature_selection_algorithm=SelectKBest(),
Expand Down
27 changes: 27 additions & 0 deletions tests/test_pipeline_canvas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from pytestqt.qtbot import QtBot

from niaaml_gui.widgets.pipeline_canvas import PipelineCanvas


@pytest.fixture
def canvas(qtbot: QtBot):
widget = PipelineCanvas()
qtbot.addWidget(widget)
return widget


def test_add_config_block_emits_signal(canvas, qtbot):
with qtbot.waitSignal(canvas.pipelineStateChanged, timeout=1000):
canvas.add_config_block("Select CSV File")


def test_pipeline_ready_false_when_empty(canvas):
canvas.block_data.clear()
canvas.scene.clear()

print("Block data keys:", canvas.block_data.keys())

assert len(canvas.block_data) == 0
assert not canvas.is_pipeline_ready()

85 changes: 85 additions & 0 deletions tests/test_pipeline_controls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest
from pytestqt.qtbot import QtBot
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QWidget, QVBoxLayout

from niaaml_gui.widgets.pipeline_controls import PipelineControlsWidget
from niaaml_gui.widgets.pipeline_canvas import PipelineCanvas


@pytest.fixture
def controls(qtbot: QtBot):
widget = PipelineControlsWidget()
qtbot.addWidget(widget)
return widget


def test_run_button_initially_disabled(controls):
assert controls.run_button.isEnabled() is False


def test_enable_disable_run_button(controls):
controls.setRunEnabled(True)
assert controls.run_button.isEnabled() is True

controls.setRunEnabled(False)
assert controls.run_button.isEnabled() is False


def test_run_button_emits_signal(controls, qtbot):
with qtbot.waitSignal(controls.runClicked, timeout=1000):
controls.setRunEnabled(True)
qtbot.mouseClick(controls.run_button, Qt.MouseButton.LeftButton)



@pytest.fixture
def full_widget(qtbot: QtBot):
canvas = PipelineCanvas()
controls = PipelineControlsWidget()

container = QWidget()
layout = QVBoxLayout()
layout.addWidget(canvas)
layout.addWidget(controls)
container.setLayout(layout)

qtbot.addWidget(container)

def update_run_button_state():
ready = canvas.is_pipeline_ready()
controls.run_button.setEnabled(ready)


container.canvas = canvas
container.controls = controls
container.update_fn = update_run_button_state

canvas.pipelineStateChanged.connect(update_run_button_state)

return container


def test_update_run_button_state_false_when_incomplete(full_widget):
canvas = full_widget.canvas
controls = full_widget.controls
update_fn = full_widget.update_fn

canvas.add_config_block("Select CSV File")
update_fn()

assert controls.run_button.isEnabled() is False


def test_update_run_button_state_true_when_ready(full_widget):
canvas = full_widget.canvas
controls = full_widget.controls
update_fn = full_widget.update_fn

canvas.add_config_block("Select CSV File")
block = list(canvas.block_data.items())[0][1]
block["path"] = "tests/tests_files/dataset_no_header_no_classes.csv"

update_fn()

assert controls.run_button.isEnabled() is True
Loading