diff --git a/dm_env/specs.py b/dm_env/specs.py index 0dc989a..c0433fd 100644 --- a/dm_env/specs.py +++ b/dm_env/specs.py @@ -256,8 +256,12 @@ def validate(self, value): return value def generate_value(self): - return (np.ones(shape=self.shape, dtype=self.dtype) * - self.dtype.type(self.minimum)) + """Generate a random value within [minimum, maximum] that matches this spec.""" + return np.random.uniform( + low=self.minimum, + high=self.maximum, + size=self.shape + ).astype(self.dtype) def __reduce__(self): return BoundedArray, (self._shape, self._dtype, self._minimum, diff --git a/dm_env/tests/test_bounded_array.py b/dm_env/tests/test_bounded_array.py new file mode 100644 index 0000000..9024c7a --- /dev/null +++ b/dm_env/tests/test_bounded_array.py @@ -0,0 +1,16 @@ +import numpy as np +from dm_env import specs + +def test_generate_value_within_bounds(): + spec = specs.BoundedArray( + shape=(2, 2), + dtype=np.float32, + minimum=0.0, + maximum=5.0 + ) + value = spec.generate_value() + + assert value.shape == (2, 2) + assert value.dtype == np.float32 + assert np.all(value >= 0.0) + assert np.all(value <= 5.0)