Add gradient accumulation flag support#5

Open
manhbi18112005 wants to merge 5 commits into ductho-le:main from manhbi18112005:feats

Conversation

@manhbi18112005 manhbi18112005 commented Mar 2, 2026

There's no built-in CLI flag to enable gradient accumulation. Gradient accumulation simulates a larger batch size by accumulating gradients over multiple small batches before performing an optimizer step, a standard technique for training large models on memory-limited hardware. Without it, users on consumer GPUs are forced into small batch sizes that can destabilize training.

Recommendation

  • Add a gradient_accumulation_steps argument to the parser and update the optimizer step logic to only trigger after the specified number of steps.
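The recommendation rests on a simple identity: for a mean-style loss, averaging the gradients of several micro-batches equals the gradient of one large batch. A minimal framework-free sketch (illustrative names, not WaveDL's API):

```python
# Sketch: for a mean-squared loss, the average of micro-batch gradients
# equals the full-batch gradient -- the identity gradient accumulation
# relies on. No framework needed; plain Python floats.
def grad(w, batch):
    # d/dw of mean((w - x)^2) over the batch
    return sum(2 * (w - x) for x in batch) / len(batch)

data = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
w = 0.0

full = grad(w, data)  # one large batch of 8

accum_steps, micro = 4, 2
acc = 0.0
for i in range(accum_steps):
    # each micro-batch gradient is scaled by 1/accum_steps before summing,
    # mirroring what accumulation frameworks do before optimizer.step()
    acc += grad(w, data[i * micro:(i + 1) * micro]) / accum_steps

assert abs(full - acc) < 1e-12
```

In a real training loop, backward() runs every micro-batch while optimizer.step() and zero_grad() run only every accum_steps micro-batches, giving the larger effective batch at the memory cost of a single small batch.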

Summary by CodeRabbit

  • New Features

    • Introduced gradient accumulation support for memory-efficient training. Users can now configure gradient accumulation steps to increase effective batch size without proportionally increasing GPU memory requirements, selectable via CLI or configuration file.
  • Documentation

    • Updated documentation with gradient accumulation parameter details and usage guidance.

Copilot AI review requested due to automatic review settings March 2, 2026 07:48

Copilot AI left a comment

Pull request overview

Adds a user-facing way to enable gradient accumulation in the WaveDL training CLI, wiring it through config defaults and documentation so users can simulate larger effective batch sizes on limited-memory GPUs.

Changes:

  • Add --grad_accum_steps CLI flag and pass it into Accelerator(gradient_accumulation_steps=...).
  • Add grad_accum_steps to config validation/defaults and the example configs/config.yaml.
  • Update unit tests and README CLI flag table to document/verify the new option.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Show a summary per file

  • src/wavedl/train.py: Introduces --grad_accum_steps, wires it to Accelerate, and logs effective batch size.
  • src/wavedl/utils/config.py: Adds grad_accum_steps to known keys and default config generation.
  • unit_tests/test_cli.py: Adds coverage for default and custom --grad_accum_steps parsing.
  • configs/config.yaml: Documents the new config option with a default value.
  • README.md: Documents the new CLI flag in the arguments table.


Comment on lines +263 to +268
parser.add_argument(
    "--grad_accum_steps",
    type=int,
    default=1,
    help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
)

Copilot AI Mar 2, 2026

--grad_accum_steps is accepted as any int, including 0 or negative values, which will either break Accelerator(gradient_accumulation_steps=...) or lead to undefined accumulation behavior. Add input validation (argparse choices=range(1, ...) or a post-parse check) to enforce grad_accum_steps >= 1 and provide a clear error message.

Comment on lines 263 to 267
    "patience",
    "weight_decay",
    "grad_clip",
    "grad_accum_steps",
    # Loss

Copilot AI Mar 2, 2026


grad_accum_steps is added as a recognized config key, but validate_config() doesn’t validate its numeric range/type beyond the generic unknown-key check. Add a numeric check enforcing an integer >= 1 so invalid YAML values (0, negatives, floats/strings) are caught with a helpful warning.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
unit_tests/test_cli.py (1)

64-83: Add invalid-value tests for --grad_accum_steps.

This covers the happy path well; please also add cases like 0 and -1 to lock in boundary behavior once validation is enforced.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@unit_tests/test_cli.py` around lines 64 - 83, Add unit tests for invalid
values of the --grad_accum_steps CLI flag by extending or adding tests alongside
test_grad_accum_steps_custom_value to assert parse_args() rejects 0 and negative
numbers: call parse_args() with sys.argv containing "--grad_accum_steps", "0"
and separately "--grad_accum_steps", "-1" and assert it either raises the
expected exception (ValueError or SystemExit depending on validation) or that
args.grad_accum_steps is normalized/validated per the implementation; reference
the existing test_grad_accum_steps_custom_value and the parse_args function in
wavedl.train to locate where to add these cases.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/wavedl/train.py`:
- Around line 263-268: The new argparse entry for "--grad_accum_steps" accepts
non-positive values; add validation so only integers >=1 are allowed. Implement
this by replacing the current parser.add_argument(... type=int ...) with a
custom validator (e.g., positive_int) or by checking args.grad_accum_steps
immediately after parsing in the same module (train.py), and raise
argparse.ArgumentTypeError or SystemExit with a clear message if the value < 1;
reference the existing parser.add_argument call for "--grad_accum_steps" and the
args object used after parse_args to locate where to enforce the check.
- Line 999: The training loop currently steps the OneCycleLR scheduler every
micro-batch which shortens the intended schedule when using
gradient_accumulation_steps=args.grad_accum_steps; modify the loop so
scheduler.step() is only called when gradients are actually synchronized/updated
by wrapping the scheduler.step() call with a conditional check on
accelerator.sync_gradients (or equivalent flag/context used in the loop),
ensuring scheduler.step() runs only on the step where optimizer.step() is
executed (i.e., when accelerator.sync_gradients is True) to preserve the correct
per-optimizer-update LR schedule.
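The guarded scheduler step described above can be sketched schematically in plain Python, emulating the sync_gradients flag rather than calling Accelerate itself (the counter names are illustrative, not WaveDL's code):

```python
# Schematic sketch: sync_gradients is True only on the micro-batch where
# the optimizer actually steps. Guarding scheduler.step() with the same
# flag keeps one LR update per optimizer update, not per micro-batch.
def run_epoch(num_micro_batches: int, accum_steps: int):
    optimizer_steps = scheduler_steps = 0
    for i in range(num_micro_batches):
        # backward() would run on every micro-batch here
        sync_gradients = (i + 1) % accum_steps == 0  # mimics accelerator.sync_gradients
        if sync_gradients:
            optimizer_steps += 1   # optimizer.step() fires on the sync step
            scheduler_steps += 1   # scheduler.step() guarded by the same flag
    return optimizer_steps, scheduler_steps

# 32 micro-batches at accum_steps=4 -> 8 optimizer updates and 8 LR steps
assert run_epoch(32, 4) == (8, 8)
```

Without the guard, OneCycleLR would advance 32 times in this example and exhaust its schedule four times too early.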

In `@src/wavedl/utils/config.py`:
- Line 266: validate_config currently knows "grad_accum_steps" via
default_known_keys but doesn't validate its type/range; update the
numeric_checks dict used in validate_config() to include "grad_accum_steps": (1,
100000, "grad_accum_steps should be >= 1") and add an explicit type check in
validate_config() so that grad_accum_steps must be an int (reject floats or
strings like "8" or 8.5); reference the numeric_checks dict and the
validate_config function to implement both the range and int-type validation for
grad_accum_steps.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 811c7c3 and 92d5f0d.

📒 Files selected for processing (5)
  • README.md
  • configs/config.yaml
  • src/wavedl/train.py
  • src/wavedl/utils/config.py
  • unit_tests/test_cli.py

Comment on lines +263 to +268
parser.add_argument(
    "--grad_accum_steps",
    type=int,
    default=1,
    help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
)


⚠️ Potential issue | 🟠 Major

Validate --grad_accum_steps as strictly positive at parse time.

Line 263 introduces a new integer arg, but 0/negative values are currently accepted and can break accumulation semantics later.

🔧 Proposed fix
@@
 def parse_args() -> argparse.Namespace:
@@
+    def _positive_int(value: str) -> int:
+        v = int(value)
+        if v < 1:
+            raise argparse.ArgumentTypeError("grad_accum_steps must be >= 1")
+        return v
+
@@
     parser.add_argument(
         "--grad_accum_steps",
-        type=int,
+        type=_positive_int,
         default=1,
         help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
     )

    "patience",
    "weight_decay",
    "grad_clip",
    "grad_accum_steps",


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed (repository: ductho-le/WaveDL):

# Examine the config.py file structure
cd src/wavedl/utils && wc -l config.py

# Read validate_config() and the context around line 266
cd src/wavedl/utils && cat -n config.py | sed -n '250,290p'

# Show the full validate_config implementation
cd src/wavedl/utils && cat -n config.py | grep -A 50 "def validate_config"

# Look for existing numeric field validations
cd src/wavedl/utils && cat -n config.py | grep -B 5 -A 10 "numeric_checks"

🌐 Web query:

Hugging Face Accelerate gradient_accumulation_steps constraints valid values

💡 Result:

In 🤗 Accelerate, gradient_accumulation_steps is:

  • Type: int (it’s cast with int(...)). (huggingface.co)
  • Valid values (practically): positive integers >= 1. The documented default is 1. (huggingface.co)
  • Mutual exclusivity constraint: you must pass either gradient_accumulation_steps or a gradient_accumulation_plugin, not both (passing a plugin while setting gradient_accumulation_steps != 1 raises ValueError). (huggingface.co)
  • TPU constraint: Gradient accumulation is not supported on TPU, so it must be 1 (and you must not pass a GradientAccumulationPlugin). (huggingface.co)

Add validation for grad_accum_steps in config validation.

The field is recognized in default_known_keys (line 266) but validate_config() has no checks for its range or type. According to Hugging Face Accelerate, gradient_accumulation_steps must be a positive integer (>= 1). Invalid values (e.g., 0, -1, "8", 8.5) currently pass validation and fail only at runtime.

Add "grad_accum_steps": (1, 100000, "grad_accum_steps should be >= 1") to the numeric_checks dict and include a type check to ensure it is an int, not a float.
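The suggested check can be sketched as follows; this is an illustrative stand-in, since the actual shape of numeric_checks and validate_config() in src/wavedl/utils/config.py may differ:

```python
# Hypothetical sketch of the proposed range + type validation for
# grad_accum_steps; entry format (min, max, message) assumed.
numeric_checks = {
    "grad_accum_steps": (1, 100000, "grad_accum_steps should be >= 1"),
}

def validate(cfg: dict) -> list[str]:
    warnings = []
    for key, (lo, hi, msg) in numeric_checks.items():
        if key not in cfg:
            continue
        v = cfg[key]
        # bool is a subclass of int in Python, so exclude it explicitly;
        # floats and strings from YAML are rejected here too
        if isinstance(v, bool) or not isinstance(v, int):
            warnings.append(f"{key} must be an int, got {type(v).__name__}")
        elif not (lo <= v <= hi):
            warnings.append(msg)
    return warnings

assert validate({"grad_accum_steps": 4}) == []
assert validate({"grad_accum_steps": 0}) == ["grad_accum_steps should be >= 1"]
assert validate({"grad_accum_steps": 8.5})[0].startswith("grad_accum_steps must be an int")
```

This catches bad YAML values at config-load time instead of letting them surface as an opaque Accelerate error mid-startup.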


@ductho-le ductho-le self-requested a review March 2, 2026 23:23
Owner

Thanks for the contribution! The approach is clean -- leveraging Accelerate's built-in gradient_accumulation_steps with the existing accelerator.accumulate(model) context manager is the right call.

However, there are a couple of issues that need to be addressed before merging:

Bug: Missing default_known_keys Entry

grad_accum_steps is not added to the default_known_keys set in config.py (around line 255-321). This set is used to detect typos in YAML config files. Without the entry, users who set grad_accum_steps: 4 in their YAML config will get a spurious warning:

Unknown config key: 'grad_accum_steps' - check for typos or see wavedl-train --help

Fix: Add "grad_accum_steps" to the default_known_keys set near "grad_clip".

Missing Input Validation

There's no guard against grad_accum_steps <= 0, which would cause division-by-zero or undefined behavior in Accelerate. Please add validation, e.g.:

if args.grad_accum_steps < 1:
    raise ValueError(f"--grad_accum_steps must be >= 1, got {args.grad_accum_steps}")

Also consider adding a numeric range check in validate_config's numeric_checks dict:

"grad_accum_steps": (1, 256, "Gradient accumulation steps should be 1-256"),

Minor Suggestions (non-blocking)

  1. Loss reporting: When grad_accum_steps > 1, accelerator.accumulate() divides the loss by grad_accum_steps before .backward(). The reported training loss will be scaled down -- consider documenting this or compensating in the metric accumulation.
  2. Test coverage: Consider adding a test that verifies grad_accum_steps=0 raises an error (once validation is added), and a config YAML integration test.

Please fix the bug and validation issue, and the rest is at your discretion. Thanks!
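The boundary test the owner asks for can be sketched without importing WaveDL, using a stand-in parser that mirrors the suggested _positive_int validator (names are hypothetical, not the project's actual test code):

```python
# Self-contained sketch of the requested boundary tests: argparse turns
# an ArgumentTypeError from the type callable into a SystemExit(2).
import argparse

def _positive_int(value: str) -> int:
    v = int(value)
    if v < 1:
        raise argparse.ArgumentTypeError("grad_accum_steps must be >= 1")
    return v

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_accum_steps", type=_positive_int, default=1)
    return parser

def rejects(argv) -> bool:
    # returns True when parse_args exits with an error for these argv
    try:
        build_parser().parse_args(argv)
    except SystemExit:
        return True
    return False

assert rejects(["--grad_accum_steps", "0"])
assert rejects(["--grad_accum_steps", "-1"])
assert not rejects(["--grad_accum_steps", "4"])
```

In a real pytest suite these would be pytest.raises(SystemExit) cases alongside the existing happy-path parsing tests.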

Repository owner deleted a comment from coderabbitai bot Mar 2, 2026
Author

manhbi18112005 commented Mar 3, 2026

No fix needed for the loss reporting concern. Confirmed via the Accelerate source that accelerator.backward(loss) uses local rebinding (loss = loss / num_steps), not in-place mutation. The caller's loss tensor remains unscaled, so metric accumulation at line 1409 is already correct.
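The rebinding point generalizes beyond tensors; a tiny illustration with a hypothetical stand-in function (the same holds for torch tensors, where loss / n allocates a new tensor rather than mutating in place):

```python
# Illustration: dividing inside a function rebinds the *local* name only,
# so the caller's value is never scaled.
def backward_like(loss, num_steps):
    loss = loss / num_steps  # rebinds the local name; no in-place change
    return loss

outer = 8.0
inner = backward_like(outer, 4)
assert inner == 2.0
assert outer == 8.0  # caller's loss stays unscaled, so metrics are correct
```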

All other issues were addressed!

