From 34333e43927f77f64e5a68462680abbb1611d04d Mon Sep 17 00:00:00 2001 From: Pranav Ghorpade Date: Sun, 23 Nov 2025 00:11:57 +0530 Subject: [PATCH 1/2] Fix generate_value() in BoundedArray to sample within bounds and add tests --- dm_env/specs.py | 8 ++++++-- dm_env/tests/test_bounded_array.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 dm_env/tests/test_bounded_array.py diff --git a/dm_env/specs.py b/dm_env/specs.py index 0dc989a..c0433fd 100644 --- a/dm_env/specs.py +++ b/dm_env/specs.py @@ -256,8 +256,12 @@ def validate(self, value): return value def generate_value(self): - return (np.ones(shape=self.shape, dtype=self.dtype) * - self.dtype.type(self.minimum)) + """Generate a random value within [minimum, maximum] that matches this spec.""" + return np.random.uniform( + low=self.minimum, + high=self.maximum, + size=self.shape + ).astype(self.dtype) def __reduce__(self): return BoundedArray, (self._shape, self._dtype, self._minimum, diff --git a/dm_env/tests/test_bounded_array.py b/dm_env/tests/test_bounded_array.py new file mode 100644 index 0000000..9024c7a --- /dev/null +++ b/dm_env/tests/test_bounded_array.py @@ -0,0 +1,16 @@ +import numpy as np +from dm_env import specs + +def test_generate_value_within_bounds(): + spec = specs.BoundedArray( + shape=(2, 2), + dtype=np.float32, + minimum=0.0, + maximum=5.0 + ) + value = spec.generate_value() + + assert value.shape == (2, 2) + assert value.dtype == np.float32 + assert np.all(value >= 0.0) + assert np.all(value <= 5.0) From 3c7da4f253546f3655097832f6729c36f0b56fde Mon Sep 17 00:00:00 2001 From: Pranav Ghorpade Date: Sun, 23 Nov 2025 00:35:51 +0530 Subject: [PATCH 2/2] Add TreeSpec to support dataclass-based nested specs (Issue #13) --- dm_env/__init__.py | 1 + dm_env/specs.py | 43 ++++++++++++++++++++++++++++++++++ dm_env/tests/test_tree_spec.py | 24 +++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 dm_env/tests/test_tree_spec.py diff --git a/dm_env/__init__.py b/dm_env/__init__.py index 37ed80a..c0f19a0 100644 --- a/dm_env/__init__.py +++ b/dm_env/__init__.py @@ -27,3 +27,4 @@ termination = _environment.termination transition = _environment.transition truncation = _environment.truncation +from dm_env import specs \ No newline at end of file diff --git a/dm_env/specs.py b/dm_env/specs.py index c0433fd..9d9bd5a 100644 --- a/dm_env/specs.py +++ b/dm_env/specs.py @@ -406,3 +406,46 @@ def __repr__(self): def __reduce__(self): return type(self), (self.shape, self.string_type, self.name) + +from dataclasses import is_dataclass, fields + +class TreeSpec: + """A container for nested spec-like structures, including dataclasses.""" + + def __init__(self, structure): + self.structure = structure + + def generate_value(self): + return _generate_tree_value(self.structure) + + def __repr__(self): + return f"TreeSpec({self.structure!r})" + + +def _generate_tree_value(structure): + """Recursively generate test values for nested spec structures.""" + + # Case 1 — Array, BoundedArray, DiscreteArray, StringArray + if isinstance(structure, Array): + return structure.generate_value() + + # Case 2 — dataclass + if is_dataclass(structure): + return type(structure)(**{ + f.name: _generate_tree_value(getattr(structure, f.name)) + for f in fields(structure) + }) + + # Case 3 — dict + if isinstance(structure, dict): + return {k: _generate_tree_value(v) for k, v in structure.items()} + + # Case 4 — tuple + if isinstance(structure, tuple): + return tuple(_generate_tree_value(v) for v in structure) + + # Case 5 — list + if isinstance(structure, list): + return [_generate_tree_value(v) for v in structure] + + raise TypeError(f"Unsupported element in TreeSpec: {type(structure)}") diff --git a/dm_env/tests/test_tree_spec.py b/dm_env/tests/test_tree_spec.py new file mode 100644 index 0000000..270321a --- /dev/null +++ b/dm_env/tests/test_tree_spec.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from dm_env import specs + +@dataclass +class MyAction: + a: specs.Array + b: specs.Array + +def test_tree_spec_generate_value(): + spec = specs.TreeSpec( + MyAction( + a=specs.Array(shape=(2,), dtype=float), + b=specs.Array(shape=(3,), dtype=float) + ) + ) + + value = spec.generate_value() + + # type check + assert isinstance(value, MyAction) + + # shape checks + assert value.a.shape == (2,) + assert value.b.shape == (3,)