From b221d2952ed4cee84e0d76070f00ef970209cf46 Mon Sep 17 00:00:00 2001
From: Andrei Maslennikov
Date: Fri, 30 Jan 2026 11:32:46 +0100
Subject: [PATCH 1/5] More robust bench execution for Dynamo over k8s

---
 .../systems/kubernetes/kubernetes_system.py | 25 ++++++++-----------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/src/cloudai/systems/kubernetes/kubernetes_system.py b/src/cloudai/systems/kubernetes/kubernetes_system.py
index a74bc1fc6..f9ae85113 100644
--- a/src/cloudai/systems/kubernetes/kubernetes_system.py
+++ b/src/cloudai/systems/kubernetes/kubernetes_system.py
@@ -310,23 +310,18 @@ def _run_genai_perf(self, job: KubernetesJob) -> None:
         frontend_pod = self._get_dynamo_pod_by_role(role="frontend")
 
-        logging.debug(f"Executing genai-perf in pod={frontend_pod} cmd={genai_perf_cmd}")
+        kubectl_exec_cmd = ["kubectl", "exec", "-n", self.default_namespace, frontend_pod, "--", *genai_perf_cmd]
+        logging.debug(f"Executing genai-perf in pod={frontend_pod} cmd={kubectl_exec_cmd}")
         try:
-            genai_results = lazy.k8s.stream.stream(
-                self.core_v1.connect_get_namespaced_pod_exec,
-                name=frontend_pod,
-                namespace=self.default_namespace,
-                command=genai_perf_cmd,
-                stderr=True,
-                stdin=False,
-                stdout=True,
-                tty=False,
-                _request_timeout=60 * 10,
-            )
+            result = subprocess.run(kubectl_exec_cmd, capture_output=True, text=True, timeout=60 * 10)
+            logging.debug(f"genai-perf exited with code {result.returncode}")
             with (job.test_run.output_path / "genai_perf.log").open("w") as f:
-                f.write(genai_results)
+                f.write(result.stdout)
+                if result.stderr:
+                    f.write("\nSTDERR:\n")
+                    f.write(result.stderr)
-        except lazy.k8s.client.ApiException as e:
-            logging.error(f"Error executing genai-perf command in pod '{frontend_pod}': {e}")
+        except Exception as e:
+            logging.error(f"Error executing genai-perf command in pod '{frontend_pod}': {e}")
 
         cp_logs_cmd = " ".join(
             [

From 02531264d75dba51d96c0ccbde332fca3675230d Mon Sep 17 00:00:00 2001
From: Srivatsan Krishnan
Date: Mon, 9 Feb 2026 15:43:11 -0800
Subject: [PATCH 2/5] Add report generation for Nemotron Ultra (which uses the
 Megatron workload)

---
 src/cloudai/registration.py                   |   2 +
 .../workloads/megatron_run/__init__.py        |   3 +-
 .../report_generation_strategy.py             | 158 +++++++++++++++++-
 3 files changed, 160 insertions(+), 3 deletions(-)

diff --git a/src/cloudai/registration.py b/src/cloudai/registration.py
index f9be227e6..d44a273cc 100644
--- a/src/cloudai/registration.py
+++ b/src/cloudai/registration.py
@@ -98,6 +98,7 @@ def register_all():
     )
     from cloudai.workloads.megatron_run import (
         CheckpointTimingReportGenerationStrategy,
+        MegatronRunReportGenerationStrategy,
         MegatronRunSlurmCommandGenStrategy,
         MegatronRunTestDefinition,
     )
@@ -259,6 +260,7 @@ def register_all():
     Registry().add_report(GPTTestDefinition, JaxToolboxReportGenerationStrategy)
     Registry().add_report(GrokTestDefinition, JaxToolboxReportGenerationStrategy)
     Registry().add_report(MegatronRunTestDefinition, CheckpointTimingReportGenerationStrategy)
+    Registry().add_report(MegatronRunTestDefinition, MegatronRunReportGenerationStrategy)
     Registry().add_report(MegatronBridgeTestDefinition, MegatronBridgeReportGenerationStrategy)
     Registry().add_report(NCCLTestDefinition, NcclTestPerformanceReportGenerationStrategy)
     Registry().add_report(NeMoLauncherTestDefinition, NeMoLauncherReportGenerationStrategy)
diff --git a/src/cloudai/workloads/megatron_run/__init__.py b/src/cloudai/workloads/megatron_run/__init__.py
index 960461256..473203447 100644
--- a/src/cloudai/workloads/megatron_run/__init__.py
+++ 
b/src/cloudai/workloads/megatron_run/__init__.py @@ -15,12 +15,13 @@ # limitations under the License. from .megatron_run import MegatronRunCmdArgs, MegatronRunTestDefinition -from .report_generation_strategy import CheckpointTimingReportGenerationStrategy +from .report_generation_strategy import CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy from .slurm_command_gen_strategy import MegatronRunSlurmCommandGenStrategy __all__ = [ "CheckpointTimingReportGenerationStrategy", "MegatronRunCmdArgs", + "MegatronRunReportGenerationStrategy", "MegatronRunSlurmCommandGenStrategy", "MegatronRunTestDefinition", ] diff --git a/src/cloudai/workloads/megatron_run/report_generation_strategy.py b/src/cloudai/workloads/megatron_run/report_generation_strategy.py index 50723a2ca..a7b71399e 100644 --- a/src/cloudai/workloads/megatron_run/report_generation_strategy.py +++ b/src/cloudai/workloads/megatron_run/report_generation_strategy.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,10 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import re +from pathlib import Path +from statistics import mean, median, pstdev +from typing import ClassVar -from cloudai.core import ReportGenerationStrategy +from cloudai.core import METRIC_ERROR, ReportGenerationStrategy CHECKPOINT_REGEX = re.compile(r"(save|load)-checkpoint\s.*:\s\((\d+\.\d+),\s(\d+\.\d+)\)") @@ -59,3 +64,152 @@ def generate_report(self) -> None: for checkpoint_type, timings in [("save", save_timings), ("load", load_timings)]: for t in timings: file.write(f"{checkpoint_type},{t[0]},{t[1]}\n") + + +class MegatronRunReportGenerationStrategy(ReportGenerationStrategy): + """Parse Megatron-LM training logs for step time and GPU TFLOP/s per GPU.""" + + metrics: ClassVar[list[str]] = ["default", "step-time", "tflops-per-gpu"] + + ITERATION_LINE_RE = re.compile( + r"elapsed time per iteration \(ms\):\s*([0-9]+(?:\.[0-9]+)?)" + r".*?" 
+ r"throughput per GPU \(TFLOP/s/GPU\):\s*([0-9]+(?:\.[0-9]+)?)", + re.IGNORECASE, + ) + + def get_log_file(self) -> Path | None: + """Find the stdout log file containing Megatron training output.""" + stdout_path = self.test_run.output_path / "stdout.txt" + if stdout_path.is_file(): + return stdout_path + return None + + @property + def results_file(self) -> Path: + return self.get_log_file() or (self.test_run.output_path / "stdout.txt") + + def can_handle_directory(self) -> bool: + """Check if directory contains Megatron training logs with iteration metrics.""" + log_file = self.get_log_file() + if not log_file: + return False + + with log_file.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + if self.ITERATION_LINE_RE.search(line): + return True + return False + + def _extract(self, log_path: Path) -> tuple[list[float], list[float]]: + """Extract step times (in seconds) and GPU TFLOP/s from log file.""" + step_times_s: list[float] = [] + gpu_tflops: list[float] = [] + + with log_path.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + m = self.ITERATION_LINE_RE.search(line) + if m: + try: + elapsed_ms = float(m.group(1)) + step_times_s.append(elapsed_ms / 1000.0) + gpu_tflops.append(float(m.group(2))) + except (ValueError, TypeError): + logging.debug("Failed to parse iteration metrics line: %s", line.rstrip("\n")) + + if len(step_times_s) > 10: + step_times_s = step_times_s[-10:] + gpu_tflops = gpu_tflops[-10:] + + return step_times_s, gpu_tflops + + def _get_extracted_data(self) -> tuple[Path | None, list[float], list[float]]: + """Get log file and extracted metrics data.""" + log_file = self.get_log_file() + if not log_file: + return None, [], [] + step_times_s, gpu_tflops = self._extract(log_file) + return log_file, step_times_s, gpu_tflops + + def generate_report(self) -> None: + """Generate a summary report with step time and TFLOP/s statistics.""" + log_file, step_times_s, gpu_tflops = self._get_extracted_data() + if not log_file: + logging.error( + "No Megatron training log file found in: %s", + self.test_run.output_path, + ) + return + + summary_file = self.test_run.output_path / "report.txt" + if not step_times_s: + with summary_file.open("w") as f: + f.write("MegatronRun report\n") + f.write("No iteration metrics found in log.\n\n") + f.write("Expected log format:\n") + f.write(" elapsed time per iteration (ms): X.X | throughput per GPU (TFLOP/s/GPU): X.X\n\n") + f.write("Searched file:\n") + f.write(f" - {log_file}\n") + logging.warning("No iteration metrics found under %s (wrote %s)", self.test_run.output_path, summary_file) + return + + step_stats = { + "avg": mean(step_times_s), + "median": median(step_times_s), + "min": min(step_times_s), + "max": max(step_times_s), + "std": pstdev(step_times_s) if len(step_times_s) > 1 else 0.0, + } + + if gpu_tflops: + tflops_stats = { + "avg": mean(gpu_tflops), + "median": median(gpu_tflops), + "min": min(gpu_tflops), + "max": max(gpu_tflops), + "std": pstdev(gpu_tflops) if len(gpu_tflops) > 1 else 0.0, + } + else: + tflops_stats = {"avg": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0} + + with summary_file.open("w") as f: + f.write(f"Source log: {log_file}\n\n") + f.write("Step Time (s)\n") + f.write(f" avg: {step_stats['avg']:.4f}\n") + f.write(f" median: {step_stats['median']:.4f}\n") + f.write(f" min: {step_stats['min']:.4f}\n") + f.write(f" max: {step_stats['max']:.4f}\n") + f.write(f" std: {step_stats['std']:.4f}\n") + f.write("\n") + f.write("TFLOP/s per GPU\n") + f.write(f" avg: 
{tflops_stats['avg']:.2f}\n") + f.write(f" median: {tflops_stats['median']:.2f}\n") + f.write(f" min: {tflops_stats['min']:.2f}\n") + f.write(f" max: {tflops_stats['max']:.2f}\n") + f.write(f" std: {tflops_stats['std']:.2f}\n") + + logging.info("Generated MegatronRun report: %s", summary_file) + + def get_metric(self, metric: str) -> float: + """Get a specific metric value for DSE/optimization.""" + if metric not in {"default", "step-time", "tflops-per-gpu"}: + return METRIC_ERROR + + log_file, step_times_s, gpu_tflops = self._get_extracted_data() + if not log_file: + logging.error( + "No Megatron training log file found in: %s", + self.test_run.output_path, + ) + return METRIC_ERROR + + if not step_times_s: + return METRIC_ERROR + + if metric in {"default", "step-time"}: + return float(mean(step_times_s)) + + if metric == "tflops-per-gpu": + return float(mean(gpu_tflops)) if gpu_tflops else METRIC_ERROR + + return METRIC_ERROR From b309f0c3093c8d846e97a77f6ba5d6f104969af2 Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 9 Feb 2026 15:48:50 -0800 Subject: [PATCH 3/5] add unit tests --- ...megatron_run_report_generation_strategy.py | 249 ++++++++++++++++++ tests/test_test_scenario.py | 3 +- 2 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py diff --git a/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py new file mode 100644 index 000000000..dfecfc1a0 --- /dev/null +++ b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
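+
+# These tests feed synthetic Megatron-LM stdout logs to
+# MegatronRunReportGenerationStrategy and exercise both report generation
+# and the DSE metric hooks ("default", "step-time", "tflops-per-gpu").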
+ +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from cloudai import TestRun +from cloudai.core import METRIC_ERROR +from cloudai.systems.slurm.slurm_system import SlurmSystem +from cloudai.workloads.megatron_run import MegatronRunReportGenerationStrategy + + +@pytest.fixture +def megatron_run_tr(tmp_path: Path) -> TestRun: + """Create a TestRun with sample Megatron-LM training logs.""" + tr = TestRun(name="megatron_run", test=Mock(), num_nodes=1, nodes=[], output_path=tmp_path) + log_content = ( + " [2026-02-06 20:55:02.918469] iteration 292/95367431 | consumed samples: 9344 | " + "elapsed time per iteration (ms): 3075.7 | throughput per GPU (TFLOP/s/GPU): 478.0 | " + "energy per GPU (J/iter/GPU): 1992.5 | power per GPU (W/GPU): 647.8 | learning rate: 9.568256E-08 | " + "global batch size: 32 | lm loss: 2.401035E-02 | loss scale: 1.0 | grad norm: 1.797 |\n" + " [2026-02-06 20:55:05.956222] iteration 293/95367431 | consumed samples: 9376 | " + "elapsed time per iteration (ms): 3037.2 | throughput per GPU (TFLOP/s/GPU): 484.0 | " + "energy per GPU (J/iter/GPU): 1982.6 | power per GPU (W/GPU): 652.8 | learning rate: 9.601024E-08 | " + "global batch size: 32 | lm loss: 2.386082E-02 | loss scale: 1.0 | grad norm: 1.797 |\n" + " [2026-02-06 20:55:08.991445] iteration 294/95367431 | consumed samples: 9408 | " + "elapsed time per iteration (ms): 3035.2 | throughput per GPU (TFLOP/s/GPU): 484.3 | " + "energy per GPU (J/iter/GPU): 1980.1 | power per GPU (W/GPU): 652.5 | learning rate: 9.633792E-08 | " + "global batch size: 32 | lm loss: 2.378540E-02 | loss scale: 1.0 | grad norm: 1.796 |\n" + ) + (tr.output_path / "stdout.txt").write_text(log_content) + return tr + + +@pytest.fixture +def megatron_run_tr_many_iterations(tmp_path: Path) -> TestRun: + """Create a TestRun with many iterations to test last-10 sampling.""" + tr = TestRun(name="megatron_run", test=Mock(), num_nodes=1, nodes=[], output_path=tmp_path) + lines = [] + for i in range(20): + # Create 20 iterations with varying step times and tflops + elapsed_ms = 3000.0 + i * 10 + tflops = 480.0 + i * 0.5 + lines.append( + f" [2026-02-06 20:55:{i:02d}.000000] iteration {i}/100 | consumed samples: {i*32} | " + f"elapsed time per iteration (ms): {elapsed_ms:.1f} | throughput per GPU (TFLOP/s/GPU): {tflops:.1f} | " + f"learning rate: 9.568256E-08 | global batch size: 32 | lm loss: 2.401035E-02 |\n" + ) + (tr.output_path / "stdout.txt").write_text("".join(lines)) + return tr + + +@pytest.fixture +def megatron_run_tr_no_metrics(tmp_path: Path) -> TestRun: + """Create a TestRun with log file but no iteration metrics.""" + tr = TestRun(name="megatron_run", test=Mock(), num_nodes=1, nodes=[], output_path=tmp_path) + log_content = "Some log content without iteration metrics\nAnother line\n" + (tr.output_path / "stdout.txt").write_text(log_content) + return tr + + +@pytest.fixture +def megatron_run_tr_empty(tmp_path: Path) -> TestRun: + """Create a TestRun with no log file.""" + tr = TestRun(name="megatron_run", test=Mock(), num_nodes=1, nodes=[], output_path=tmp_path) + return tr + + +def test_can_handle_directory(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test that strategy can handle directory with valid Megatron logs.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + assert strategy.can_handle_directory() + + +def test_cannot_handle_directory_no_file(slurm_system: SlurmSystem, megatron_run_tr_empty: TestRun) -> None: + """Test that strategy cannot handle 
directory without log file.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_empty) + assert not strategy.can_handle_directory() + + +def test_cannot_handle_directory_no_metrics(slurm_system: SlurmSystem, megatron_run_tr_no_metrics: TestRun) -> None: + """Test that strategy cannot handle directory without iteration metrics.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_metrics) + assert not strategy.can_handle_directory() + + +def test_generate_report(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test report generation with valid logs.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + strategy.generate_report() + + report_path = megatron_run_tr.output_path / "report.txt" + assert report_path.is_file(), "Report file should be created." + + content = report_path.read_text() + assert "Step Time (s)" in content + assert "TFLOP/s per GPU" in content + assert "avg:" in content + assert "median:" in content + assert "min:" in content + assert "max:" in content + assert "std:" in content + + +def test_generate_report_statistics(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test that report contains correct statistics.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + strategy.generate_report() + + report_path = megatron_run_tr.output_path / "report.txt" + content = report_path.read_text() + + # Expected values based on the log content: + # Step times (ms -> s): 3075.7 -> 3.0757, 3037.2 -> 3.0372, 3035.2 -> 3.0352 + # avg step time: (3.0757 + 3.0372 + 3.0352) / 3 = 3.0494 + # TFLOPs: 478.0, 484.0, 484.3 + # avg tflops: (478.0 + 484.0 + 484.3) / 3 = 482.1 + + # Check step time avg is approximately correct (converting ms to s) + assert "3.04" in content or "3.05" in content, "Average step time should be around 3.04-3.05 seconds" + + # Check tflops avg is approximately correct + assert "482" in content, "Average TFLOP/s should be around 482" + + +def test_generate_report_no_metrics(slurm_system: SlurmSystem, megatron_run_tr_no_metrics: TestRun) -> None: + """Test report generation when no metrics are found.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_metrics) + strategy.generate_report() + + report_path = megatron_run_tr_no_metrics.output_path / "report.txt" + assert report_path.is_file(), "Report file should be created even with no metrics." + + content = report_path.read_text() + assert "No iteration metrics found" in content + + +def test_generate_report_no_file(slurm_system: SlurmSystem, megatron_run_tr_empty: TestRun) -> None: + """Test report generation when log file is missing.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_empty) + strategy.generate_report() + + report_path = megatron_run_tr_empty.output_path / "report.txt" + assert not report_path.exists(), "Report file should not be created when log file is missing." 
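+
+
+# NOTE: MegatronRunReportGenerationStrategy._extract() keeps only the last 10
+# parsed iterations; the fixtures used below log at most three, so the
+# expected metric values are plain means over every parsed line.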
+ + +def test_get_metric_step_time(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test getting step-time metric.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + value = strategy.get_metric("step-time") + + # Expected: avg of 3.0757, 3.0372, 3.0352 = 3.0494 seconds + assert value != METRIC_ERROR + assert pytest.approx(value, rel=0.01) == 3.0494 + + +def test_get_metric_default(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test that default metric returns step-time.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + value = strategy.get_metric("default") + + # Default should be the same as step-time + step_time_value = strategy.get_metric("step-time") + assert value == step_time_value + + +def test_get_metric_tflops(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test getting tflops-per-gpu metric.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + value = strategy.get_metric("tflops-per-gpu") + + # Expected: avg of 478.0, 484.0, 484.3 = 482.1 + assert value != METRIC_ERROR + assert pytest.approx(value, rel=0.01) == 482.1 + + +def test_get_metric_invalid(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test that invalid metric returns METRIC_ERROR.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + value = strategy.get_metric("invalid-metric") + assert value == METRIC_ERROR + + +def test_get_metric_no_file(slurm_system: SlurmSystem, megatron_run_tr_empty: TestRun) -> None: + """Test that missing log file returns METRIC_ERROR.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_empty) + value = strategy.get_metric("step-time") + assert value == METRIC_ERROR + + +def test_get_metric_no_metrics(slurm_system: SlurmSystem, megatron_run_tr_no_metrics: TestRun) -> None: + """Test that log without metrics returns METRIC_ERROR.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr_no_metrics) + value = strategy.get_metric("step-time") + assert value == METRIC_ERROR + + +def test_results_file_property(slurm_system: SlurmSystem, megatron_run_tr: TestRun) -> None: + """Test that results_file property returns correct path.""" + strategy = MegatronRunReportGenerationStrategy(slurm_system, megatron_run_tr) + assert strategy.results_file == megatron_run_tr.output_path / "stdout.txt" + + +def test_metrics_class_variable() -> None: + """Test that metrics class variable is correctly defined.""" + assert MegatronRunReportGenerationStrategy.metrics == ["default", "step-time", "tflops-per-gpu"] + + +def test_extract_with_partial_valid_lines(slurm_system: SlurmSystem, tmp_path: Path) -> None: + """Test extraction with some valid and some invalid log lines.""" + tr = TestRun(name="megatron_run", test=Mock(), num_nodes=1, nodes=[], output_path=tmp_path) + log_content = ( + "Some random log line\n" + " [2026-02-06 20:55:02.918469] iteration 292/100 | " + "elapsed time per iteration (ms): 3000.0 | throughput per GPU (TFLOP/s/GPU): 480.0 |\n" + "Another random line without metrics\n" + " [2026-02-06 20:55:05.956222] iteration 293/100 | " + "elapsed time per iteration (ms): 3100.0 | throughput per GPU (TFLOP/s/GPU): 490.0 |\n" + "Final random line\n" + ) + (tr.output_path / "stdout.txt").write_text(log_content) + + strategy = MegatronRunReportGenerationStrategy(slurm_system, tr) + + # Should extract 2 valid samples + step_time = 
strategy.get_metric("step-time") + assert step_time != METRIC_ERROR + # avg of 3.0 and 3.1 seconds = 3.05 + assert pytest.approx(step_time, rel=0.01) == 3.05 + + tflops = strategy.get_metric("tflops-per-gpu") + assert tflops != METRIC_ERROR + # avg of 480 and 490 = 485 + assert pytest.approx(tflops, rel=0.01) == 485.0 diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index c2af1373b..a1b2b68f3 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -53,6 +53,7 @@ from cloudai.workloads.megatron_run import ( CheckpointTimingReportGenerationStrategy, MegatronRunCmdArgs, + MegatronRunReportGenerationStrategy, MegatronRunTestDefinition, ) from cloudai.workloads.nccl_test import ( @@ -481,7 +482,7 @@ def test_default_reporters_size(self): (DeepEPTestDefinition, {DeepEPReportGenerationStrategy}), (GPTTestDefinition, {JaxToolboxReportGenerationStrategy}), (GrokTestDefinition, {JaxToolboxReportGenerationStrategy}), - (MegatronRunTestDefinition, {CheckpointTimingReportGenerationStrategy}), + (MegatronRunTestDefinition, {CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy}), (MegatronBridgeTestDefinition, {MegatronBridgeReportGenerationStrategy}), (NCCLTestDefinition, {NcclTestPerformanceReportGenerationStrategy}), (NeMoLauncherTestDefinition, {NeMoLauncherReportGenerationStrategy}), From 624d68fe048ce3b5f5ec92562a471e9e68b9e31d Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 9 Feb 2026 15:49:45 -0800 Subject: [PATCH 4/5] ruff --- .../test_megatron_run_report_generation_strategy.py | 2 +- tests/test_test_scenario.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py index dfecfc1a0..ba0333964 100644 --- a/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py +++ b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py @@ -57,7 +57,7 @@ def megatron_run_tr_many_iterations(tmp_path: Path) -> TestRun: elapsed_ms = 3000.0 + i * 10 tflops = 480.0 + i * 0.5 lines.append( - f" [2026-02-06 20:55:{i:02d}.000000] iteration {i}/100 | consumed samples: {i*32} | " + f" [2026-02-06 20:55:{i:02d}.000000] iteration {i}/100 | consumed samples: {i * 32} | " f"elapsed time per iteration (ms): {elapsed_ms:.1f} | throughput per GPU (TFLOP/s/GPU): {tflops:.1f} | " f"learning rate: 9.568256E-08 | global batch size: 32 | lm loss: 2.401035E-02 |\n" ) diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index a1b2b68f3..4b80a2a6d 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -482,7 +482,10 @@ def test_default_reporters_size(self): (DeepEPTestDefinition, {DeepEPReportGenerationStrategy}), (GPTTestDefinition, {JaxToolboxReportGenerationStrategy}), (GrokTestDefinition, {JaxToolboxReportGenerationStrategy}), - (MegatronRunTestDefinition, {CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy}), + ( + MegatronRunTestDefinition, + {CheckpointTimingReportGenerationStrategy, MegatronRunReportGenerationStrategy}, + ), (MegatronBridgeTestDefinition, {MegatronBridgeReportGenerationStrategy}), (NCCLTestDefinition, {NcclTestPerformanceReportGenerationStrategy}), (NeMoLauncherTestDefinition, {NeMoLauncherReportGenerationStrategy}), From a754d503b6250f6a09c6c64097b2953fe8f63a6d Mon Sep 17 00:00:00 2001 From: Srivatsan Krishnan Date: Mon, 9 Feb 2026 
16:08:21 -0800 Subject: [PATCH 5/5] remove dead fixtures --- ..._megatron_run_report_generation_strategy.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py index ba0333964..40e06774f 100644 --- a/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py +++ b/tests/report_generation_strategy/test_megatron_run_report_generation_strategy.py @@ -47,24 +47,6 @@ def megatron_run_tr(tmp_path: Path) -> TestRun: return tr -@pytest.fixture -def megatron_run_tr_many_iterations(tmp_path: Path) -> TestRun: - """Create a TestRun with many iterations to test last-10 sampling.""" - tr = TestRun(name="megatron_run", test=Mock(), num_nodes=1, nodes=[], output_path=tmp_path) - lines = [] - for i in range(20): - # Create 20 iterations with varying step times and tflops - elapsed_ms = 3000.0 + i * 10 - tflops = 480.0 + i * 0.5 - lines.append( - f" [2026-02-06 20:55:{i:02d}.000000] iteration {i}/100 | consumed samples: {i * 32} | " - f"elapsed time per iteration (ms): {elapsed_ms:.1f} | throughput per GPU (TFLOP/s/GPU): {tflops:.1f} | " - f"learning rate: 9.568256E-08 | global batch size: 32 | lm loss: 2.401035E-02 |\n" - ) - (tr.output_path / "stdout.txt").write_text("".join(lines)) - return tr - - @pytest.fixture def megatron_run_tr_no_metrics(tmp_path: Path) -> TestRun: """Create a TestRun with log file but no iteration metrics."""
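
For readers unfamiliar with the Megatron-LM log format the new strategy parses, the following minimal, self-contained sketch (not part of the patches above) illustrates the extraction behavior patch 2 implements: the regex is copied verbatim from the patch, the log lines are synthetic stand-ins for Megatron-LM stdout, and the last-10 truncation mirrors what `_extract()` does before computing statistics.

import re
from statistics import mean

# Copied verbatim from MegatronRunReportGenerationStrategy in patch 2.
ITERATION_LINE_RE = re.compile(
    r"elapsed time per iteration \(ms\):\s*([0-9]+(?:\.[0-9]+)?)"
    r".*?"
    r"throughput per GPU \(TFLOP/s/GPU\):\s*([0-9]+(?:\.[0-9]+)?)",
    re.IGNORECASE,
)

# Twenty synthetic iterations with slowly drifting step time and throughput.
log_lines = [
    f"iteration {i}/100 | elapsed time per iteration (ms): {3000 + i * 10:.1f} | "
    f"throughput per GPU (TFLOP/s/GPU): {480 + i * 0.5:.1f} |"
    for i in range(20)
]

step_times_s: list[float] = []
gpu_tflops: list[float] = []
for line in log_lines:
    m = ITERATION_LINE_RE.search(line)
    if m:
        step_times_s.append(float(m.group(1)) / 1000.0)  # ms -> s
        gpu_tflops.append(float(m.group(2)))

# Like _extract(), keep only the 10 most recent samples so long runs are
# summarized by their steady-state tail rather than warm-up iterations.
step_times_s = step_times_s[-10:]
gpu_tflops = gpu_tflops[-10:]

print(f"avg step time:   {mean(step_times_s):.4f} s")  # mean over iterations 10-19
print(f"avg TFLOP/s/GPU: {mean(gpu_tflops):.1f}")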