Add gradient accumulation flag support #5
manhbi18112005 wants to merge 5 commits into ductho-le:main from
Conversation
Pull request overview
Adds a user-facing way to enable gradient accumulation in the WaveDL training CLI, wiring it through config defaults and documentation so users can simulate larger effective batch sizes on limited-memory GPUs.
Changes:
- Add a `--grad_accum_steps` CLI flag and pass it into `Accelerator(gradient_accumulation_steps=...)`.
- Add `grad_accum_steps` to config validation/defaults and the example `configs/config.yaml`.
- Update unit tests and the README CLI flag table to document and verify the new option.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

| File | Description |
|---|---|
| src/wavedl/train.py | Introduces `--grad_accum_steps`, wires it to Accelerate, and logs effective batch size. |
| src/wavedl/utils/config.py | Adds `grad_accum_steps` to known keys and default config generation. |
| unit_tests/test_cli.py | Adds coverage for default and custom `--grad_accum_steps` parsing. |
| configs/config.yaml | Documents the new config option with a default value. |
| README.md | Documents the new CLI flag in the arguments table. |
```python
parser.add_argument(
    "--grad_accum_steps",
    type=int,
    default=1,
    help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
)
```
--grad_accum_steps is accepted as any int, including 0 or negative values, which will either break Accelerator(gradient_accumulation_steps=...) or lead to undefined accumulation behavior. Add input validation (argparse choices=range(1, ...) or a post-parse check) to enforce grad_accum_steps >= 1 and provide a clear error message.
| "patience", | ||
| "weight_decay", | ||
| "grad_clip", | ||
| "grad_accum_steps", | ||
| # Loss |
grad_accum_steps is added as a recognized config key, but validate_config() doesn’t validate its numeric range/type beyond the generic unknown-key check. Add a numeric check enforcing an integer >= 1 so invalid YAML values (0, negatives, floats/strings) are caught with a helpful warning.
Actionable comments posted: 3
🧹 Nitpick comments (1)
unit_tests/test_cli.py (1)
Lines 64-83: Add invalid-value tests for `--grad_accum_steps`. This covers the happy path well; please also add cases like 0 and -1 to lock in boundary behavior once validation is enforced.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@unit_tests/test_cli.py` around lines 64 - 83, Add unit tests for invalid values of the --grad_accum_steps CLI flag by extending or adding tests alongside test_grad_accum_steps_custom_value to assert parse_args() rejects 0 and negative numbers: call parse_args() with sys.argv containing "--grad_accum_steps", "0" and separately "--grad_accum_steps", "-1" and assert it either raises the expected exception (ValueError or SystemExit depending on validation) or that args.grad_accum_steps is normalized/validated per the implementation; reference the existing test_grad_accum_steps_custom_value and the parse_args function in wavedl.train to locate where to add these cases.
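A minimal, self-contained sketch of the boundary tests requested above. It assumes the validator is implemented as a custom argparse type, so argparse exits with `SystemExit` on bad values; the names `_positive_int`, `build_parser`, and `rejects` are illustrative stand-ins, not WaveDL's actual code.

```python
import argparse


def _positive_int(value: str) -> int:
    # Sketch of the validator the review asks for (hypothetical name).
    v = int(value)
    if v < 1:
        raise argparse.ArgumentTypeError("grad_accum_steps must be >= 1")
    return v


def build_parser() -> argparse.ArgumentParser:
    # Minimal stand-in for wavedl.train's parser; only the flag under test.
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_accum_steps", type=_positive_int, default=1)
    return parser


def rejects(value: str) -> bool:
    # argparse converts ArgumentTypeError into SystemExit at parse time.
    try:
        build_parser().parse_args(["--grad_accum_steps", value])
    except SystemExit:
        return True
    return False
```

In a pytest suite these checks would typically become parametrized tests asserting `pytest.raises(SystemExit)` for "0" and "-1".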
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/wavedl/train.py`:
- Around line 263-268: The new argparse entry for "--grad_accum_steps" accepts
non-positive values; add validation so only integers >=1 are allowed. Implement
this by replacing the current parser.add_argument(... type=int ...) with a
custom validator (e.g., positive_int) or by checking args.grad_accum_steps
immediately after parsing in the same module (train.py), and raise
argparse.ArgumentTypeError or SystemExit with a clear message if the value < 1;
reference the existing parser.add_argument call for "--grad_accum_steps" and the
args object used after parse_args to locate where to enforce the check.
- Line 999: The training loop currently steps the OneCycleLR scheduler every
micro-batch which shortens the intended schedule when using
gradient_accumulation_steps=args.grad_accum_steps; modify the loop so
scheduler.step() is only called when gradients are actually synchronized/updated
by wrapping the scheduler.step() call with a conditional check on
accelerator.sync_gradients (or equivalent flag/context used in the loop),
ensuring scheduler.step() runs only on the step where optimizer.step() is
executed (i.e., when accelerator.sync_gradients is True) to preserve the correct
per-optimizer-update LR schedule.
In `@src/wavedl/utils/config.py`:
- Line 266: validate_config currently knows "grad_accum_steps" via
default_known_keys but doesn't validate its type/range; update the
numeric_checks dict used in validate_config() to include "grad_accum_steps": (1,
100000, "grad_accum_steps should be >= 1") and add an explicit type check in
validate_config() so that grad_accum_steps must be an int (reject floats or
strings like "8" or 8.5); reference the numeric_checks dict and the
validate_config function to implement both the range and int-type validation for
grad_accum_steps.
---
Nitpick comments:
In `@unit_tests/test_cli.py`:
- Around line 64-83: Add unit tests for invalid values of the --grad_accum_steps
CLI flag by extending or adding tests alongside
test_grad_accum_steps_custom_value to assert parse_args() rejects 0 and negative
numbers: call parse_args() with sys.argv containing "--grad_accum_steps", "0"
and separately "--grad_accum_steps", "-1" and assert it either raises the
expected exception (ValueError or SystemExit depending on validation) or that
args.grad_accum_steps is normalized/validated per the implementation; reference
the existing test_grad_accum_steps_custom_value and the parse_args function in
wavedl.train to locate where to add these cases.
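The scheduler fix described in the train.py prompt above can be illustrated with a framework-free sketch: the LR scheduler steps only on micro-batches where gradients synchronize, i.e. once per optimizer update. `sync_gradients` here stands in for `accelerator.sync_gradients`; all names are illustrative, not WaveDL's code.

```python
def run_epoch(num_micro_batches: int, grad_accum_steps: int) -> int:
    """Count scheduler steps when stepping only on gradient-sync boundaries."""
    scheduler_steps = 0
    for step in range(num_micro_batches):
        # With Accelerate, accelerator.sync_gradients is True only on the
        # micro-batch where optimizer.step() actually runs.
        sync_gradients = (step + 1) % grad_accum_steps == 0
        # ... forward pass and accelerator.backward(loss) would go here ...
        if sync_gradients:
            # optimizer.step(); optimizer.zero_grad()
            scheduler_steps += 1  # scheduler.step() belongs here, not outside
    return scheduler_steps
```

With 100 micro-batches and `grad_accum_steps=4`, the scheduler advances 25 times, matching the number of optimizer updates, which preserves the intended OneCycleLR schedule length.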
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
README.md, configs/config.yaml, src/wavedl/train.py, src/wavedl/utils/config.py, unit_tests/test_cli.py
```python
parser.add_argument(
    "--grad_accum_steps",
    type=int,
    default=1,
    help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
)
```
Validate --grad_accum_steps as strictly positive at parse time.
Line 263 introduces a new integer arg, but 0/negative values are currently accepted and can break accumulation semantics later.
🔧 Proposed fix

```diff
@@
 def parse_args() -> argparse.Namespace:
@@
+    def _positive_int(value: str) -> int:
+        v = int(value)
+        if v < 1:
+            raise argparse.ArgumentTypeError("grad_accum_steps must be >= 1")
+        return v
+
@@
     parser.add_argument(
         "--grad_accum_steps",
-        type=int,
+        type=_positive_int,
         default=1,
         help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
     )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _positive_int(value: str) -> int:
    v = int(value)
    if v < 1:
        raise argparse.ArgumentTypeError("grad_accum_steps must be >= 1")
    return v

parser.add_argument(
    "--grad_accum_steps",
    type=_positive_int,
    default=1,
    help="Gradient accumulation steps. Effective batch = batch_size x grad_accum_steps x num_gpus",
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/wavedl/train.py` around lines 263 - 268, The new argparse entry for
"--grad_accum_steps" accepts non-positive values; add validation so only
integers >=1 are allowed. Implement this by replacing the current
parser.add_argument(... type=int ...) with a custom validator (e.g.,
positive_int) or by checking args.grad_accum_steps immediately after parsing in
the same module (train.py), and raise argparse.ArgumentTypeError or SystemExit
with a clear message if the value < 1; reference the existing
parser.add_argument call for "--grad_accum_steps" and the args object used after
parse_args to locate where to enforce the check.
| "patience", | ||
| "weight_decay", | ||
| "grad_clip", | ||
| "grad_accum_steps", |
🧩 Analysis chain

🏁 Script executed:

```shell
# First, let's examine the config.py file structure
cd src/wavedl/utils && wc -l config.py
```

🏁 Script executed:

```shell
# Read the config.py file to understand validate_config() and the context around line 266
cd src/wavedl/utils && cat -n config.py | sed -n '250,290p'
```

🏁 Script executed:

```shell
# Search for validate_config function to understand its full implementation
cd src/wavedl/utils && cat -n config.py | grep -A 50 "def validate_config"
```

🏁 Script executed:

```shell
# Check if there are any existing validations for numeric fields
cd src/wavedl/utils && cat -n config.py | grep -B 5 -A 10 "numeric_checks"
```

Repository: ductho-le/WaveDL
🌐 Web query:
Hugging Face Accelerate gradient_accumulation_steps constraints valid values

💡 Result:
In 🤗 Accelerate, `gradient_accumulation_steps` is:
- Type: `int` (it's cast with `int(...)`). (huggingface.co)
- Valid values (practically): positive integers >= 1. The documented default is 1. (huggingface.co)
- Mutual exclusivity constraint: you must pass either `gradient_accumulation_steps` or `gradient_accumulation_plugin`, not both (passing a plugin while setting `gradient_accumulation_steps != 1` raises `ValueError`). (huggingface.co)
- TPU constraint: gradient accumulation is not supported on TPU, so it must be 1 (and you must not pass a `GradientAccumulationPlugin`). (huggingface.co)

Citations:
- 1, 3, 4: https://huggingface.co/datasets/yyyzzzzyyy/envss/blob/2f63a9a85556113e08a02df73c9fc7a6ce075b17/Emu/lib/python3.10/site-packages/accelerate/accelerator.py
- 2: https://huggingface.co/docs/accelerate/en/package_reference/accelerator
Add validation for grad_accum_steps in config validation.
The field is recognized in default_known_keys (line 266) but validate_config() has no checks for its range or type. According to Hugging Face Accelerate, gradient_accumulation_steps must be a positive integer (>= 1). Invalid values (e.g., 0, -1, "8", 8.5) currently pass validation and fail only at runtime.
Add "grad_accum_steps": (1, 100000, "grad_accum_steps should be >= 1") to the numeric_checks dict and include a type check to ensure it is an int, not a float.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/wavedl/utils/config.py` at line 266, validate_config currently knows
"grad_accum_steps" via default_known_keys but doesn't validate its type/range;
update the numeric_checks dict used in validate_config() to include
"grad_accum_steps": (1, 100000, "grad_accum_steps should be >= 1") and add an
explicit type check in validate_config() so that grad_accum_steps must be an int
(reject floats or strings like "8" or 8.5); reference the numeric_checks dict
and the validate_config function to implement both the range and int-type
validation for grad_accum_steps.
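A hedged sketch of what the suggested check could look like. The `numeric_checks` table structure and the warning wording are guesses based on the review text above, not WaveDL's actual `validate_config()` implementation.

```python
def check_grad_accum_steps(config: dict) -> list:
    """Return warnings for an invalid grad_accum_steps value (illustrative)."""
    warnings = []
    # (min, max, message) tuples, mirroring the numeric_checks style the
    # review describes.
    numeric_checks = {
        "grad_accum_steps": (1, 100000, "grad_accum_steps should be >= 1"),
    }
    for key, (lo, hi, msg) in numeric_checks.items():
        if key not in config:
            continue
        value = config[key]
        # bool is a subclass of int in Python, so exclude it explicitly;
        # floats and strings like 8.5 or "8" are also rejected.
        if not isinstance(value, int) or isinstance(value, bool):
            warnings.append(f"{key} must be an int, got {type(value).__name__}")
        elif not lo <= value <= hi:
            warnings.append(msg)
    return warnings
```

Called during config validation, this catches `0`, `-1`, `8.5`, and `"8"` before they reach `Accelerator(...)` at runtime.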
Thanks for the contribution! The approach is clean -- leveraging Accelerate's built-in gradient accumulation support. However, there are a couple of issues that need to be addressed before merging. Bug: Missing
No fix needed for the loss reporting concern. Confirmed via the Accelerate source that accelerator.backward(loss) uses local rebinding (loss = loss / num_steps), not in-place mutation. The caller's loss tensor remains unscaled, so metric accumulation at line 1409 is already correct. All other issues were addressed!
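The rebinding point above can be shown in plain Python: assigning `loss = loss / n` inside a function binds the local name to a new object, leaving the caller's value untouched, whereas an in-place operation would mutate it. A one-element list stands in for a tensor here; the function names are illustrative.

```python
def backward_rebinding(loss, num_steps):
    # Rebinds the local name to a new list; the caller's list is unchanged.
    # This mirrors what Accelerate's backward does with loss / num_steps.
    loss = [loss[0] / num_steps]  # local rebinding only
    return loss


def backward_inplace(loss, num_steps):
    # Mutates the caller's object in place -- the behavior the reviewer
    # was worried about, which Accelerate does NOT do.
    loss[0] /= num_steps
    return loss
```

Only the in-place variant changes what the caller later accumulates into its metrics.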
There's no built-in CLI flag to enable gradient accumulation. Gradient accumulation allows for simulating a larger batch size by accumulating gradients over multiple small batches before performing an optimizer step. This is a standard technique for training large models on hardware with limited memory. Without this, users can't effectively train large models on consumer GPUs with limited memory, as they are forced to use small batch sizes that might lead to unstable training.
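The equivalence behind this technique can be checked numerically: for a mean-reduced loss, averaging the gradients of k equal-sized micro-batches equals the gradient of the full batch. A plain-Python sketch with a one-parameter linear model (purely illustrative, not WaveDL's training code):

```python
def grad_mse(w, xs, ys):
    """d/dw of the mean squared error (1/N) * sum((w*x - y)^2)."""
    n = len(xs)
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, grad_accum_steps):
    """Average per-micro-batch gradients, as accumulating loss / steps does."""
    size = len(xs) // grad_accum_steps
    total = 0.0
    for i in range(grad_accum_steps):
        chunk = slice(i * size, (i + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk]) / grad_accum_steps
    return total
```

Because the micro-batches are equal-sized, the accumulated gradient matches the full-batch gradient exactly, which is why one optimizer step after k micro-batches behaves like a k-times-larger batch.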