Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dimos/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,23 @@
load_dotenv()


def _has_ros() -> bool:
try:
import rclpy # noqa: F401

return True
except ImportError:
return False


def pytest_configure(config):
config.addinivalue_line("markers", "tool: dev tooling")
config.addinivalue_line("markers", "slow: tests that are too slow for the fast loop")
config.addinivalue_line("markers", "mujoco: tests which open mujoco")
config.addinivalue_line("markers", "skipif_in_ci: skip when CI env var is set")
config.addinivalue_line("markers", "skipif_no_openai: skip when OPENAI_API_KEY is not set")
config.addinivalue_line("markers", "skipif_no_alibaba: skip when ALIBABA_API_KEY is not set")
config.addinivalue_line("markers", "skipif_no_ros: skip when ROS dependencies are not present")


@pytest.hookimpl()
Expand All @@ -39,6 +49,7 @@ def pytest_collection_modifyitems(config, items):
"skipif_in_ci": (bool(os.getenv("CI")), "Skipped in CI"),
"skipif_no_openai": (not os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY not set"),
"skipif_no_alibaba": (not os.getenv("ALIBABA_API_KEY"), "ALIBABA_API_KEY not set"),
"skipif_no_ros": (not _has_ros(), "ROS dependencies are not present"),
}
for marker_name, (condition, reason) in _skipif_markers.items():
if condition:
Expand Down
6 changes: 6 additions & 0 deletions dimos/protocol/pubsub/impl/test_rospubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def subscriber() -> Generator[DimosROS, None, None]:
yield from ros_node()


@pytest.mark.skipif_no_ros
def test_basic_conversion(publisher, subscriber):
"""Test Vector3 publish/subscribe through ROS.

Expand All @@ -75,6 +76,7 @@ def callback(msg, t):
assert msg.z == 3.0


@pytest.mark.skipif_no_ros
@pytest.mark.slow
def test_pointcloud2_pubsub(publisher, subscriber):
"""Test PointCloud2 publish/subscribe through ROS.
Expand Down Expand Up @@ -132,6 +134,7 @@ def callback(msg, t):
assert abs(original.ts - converted.ts) < 0.001


@pytest.mark.skipif_no_ros
def test_pointcloud2_empty_pubsub(publisher, subscriber):
"""Test empty PointCloud2 publish/subscribe.

Expand Down Expand Up @@ -160,6 +163,7 @@ def callback(msg, t):
assert len(received[0]) == 0


@pytest.mark.skipif_no_ros
def test_posestamped_pubsub(publisher, subscriber):
"""Test PoseStamped publish/subscribe through ROS.

Expand Down Expand Up @@ -200,6 +204,7 @@ def callback(msg, t):
np.testing.assert_allclose(converted.orientation.w, original.orientation.w, rtol=1e-5)


@pytest.mark.skipif_no_ros
def test_pointstamped_pubsub(publisher, subscriber):
"""Test PointStamped publish/subscribe through ROS.

Expand Down Expand Up @@ -242,6 +247,7 @@ def callback(msg, t):
assert converted.point.z == original.point.z


@pytest.mark.skipif_no_ros
def test_twist_pubsub(publisher, subscriber):
"""Test Twist publish/subscribe through ROS.

Expand Down
Loading