From 34333e43927f77f64e5a68462680abbb1611d04d Mon Sep 17 00:00:00 2001 From: Pranav Ghorpade Date: Sun, 23 Nov 2025 00:11:57 +0530 Subject: [PATCH] Fix generate_value() in BoundedArray to sample within bounds and add tests --- dm_env/specs.py | 8 ++++++-- dm_env/tests/test_bounded_array.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 dm_env/tests/test_bounded_array.py diff --git a/dm_env/specs.py b/dm_env/specs.py index 0dc989a..c0433fd 100644 --- a/dm_env/specs.py +++ b/dm_env/specs.py @@ -256,8 +256,12 @@ def validate(self, value): return value def generate_value(self): - return (np.ones(shape=self.shape, dtype=self.dtype) * - self.dtype.type(self.minimum)) + """Generate a random value within [minimum, maximum] that matches this spec.""" + return np.random.uniform( + low=self.minimum, + high=self.maximum, + size=self.shape + ).astype(self.dtype) def __reduce__(self): return BoundedArray, (self._shape, self._dtype, self._minimum, diff --git a/dm_env/tests/test_bounded_array.py b/dm_env/tests/test_bounded_array.py new file mode 100644 index 0000000..9024c7a --- /dev/null +++ b/dm_env/tests/test_bounded_array.py @@ -0,0 +1,16 @@ +import numpy as np +from dm_env import specs + +def test_generate_value_within_bounds(): + spec = specs.BoundedArray( + shape=(2, 2), + dtype=np.float32, + minimum=0.0, + maximum=5.0 + ) + value = spec.generate_value() + + assert value.shape == (2, 2) + assert value.dtype == np.float32 + assert np.all(value >= 0.0) + assert np.all(value <= 5.0)