1- from typing import Any , Optional
1+ from typing import Any
2+ from warnings import deprecated
23
34from rcrscore .entities import EntityID
45from rcrscore .entities .area import Area
@@ -39,7 +40,7 @@ def set_change_set(self, change_set: ChangeSet) -> None:
3940 """
4041 self ._change_set = change_set
4142
42- def get_entity (self , entity_id : EntityID ) -> Optional [ Entity ] :
43+ def get_entity (self , entity_id : EntityID ) -> Entity | None :
4344 """
4445 Get the entity
4546
@@ -50,7 +51,7 @@ def get_entity(self, entity_id: EntityID) -> Optional[Entity]:
5051
5152 Returns
5253 -------
53- Optional[ Entity]
54+ Entity | None
5455 Entity
5556 """
5657 return self ._world_model .get_entity (entity_id )
@@ -179,8 +180,8 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:
179180 ValueError
180181 If one or both entities are invalid or the location is invalid
181182 """
182- entity1 : Optional [ Entity ] = self .get_entity (entity_id1 )
183- entity2 : Optional [ Entity ] = self .get_entity (entity_id2 )
183+ entity1 : Entity | None = self .get_entity (entity_id1 )
184+ entity2 : Entity | None = self .get_entity (entity_id2 )
184185 if entity1 is None or entity2 is None :
185186 raise ValueError (
186187 f"One or both entities are invalid: entity_id1={ entity_id1 } , entity_id2={ entity_id2 } , entity1={ entity1 } , entity2={ entity2 } "
@@ -204,6 +205,9 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:
204205
205206 return distance
206207
208+ @deprecated (
209+ "get_entity_position is deprecated, use get_entity_position_entity_id or get_entity_position_entity instead."
210+ )
207211 def get_entity_position (self , entity_id : EntityID ) -> EntityID | None :
208212 """
209213 Get the entity position
@@ -215,7 +219,7 @@ def get_entity_position(self, entity_id: EntityID) -> EntityID | None:
215219
216220 Returns
217221 -------
218- EntityID
222+ EntityID | None
219223 Entity position
220224
221225 Raises
@@ -234,6 +238,50 @@ def get_entity_position(self, entity_id: EntityID) -> EntityID | None:
234238 return entity .get_position ()
235239 raise ValueError (f"Invalid entity type: entity_id={ entity_id } , entity={ entity } " )
236240
241+ def get_entity_position_entity_id (self , entity_id : EntityID ) -> EntityID | None :
242+ """
243+ Get the entity position EntityID
244+
245+ Parameters
246+ ----------
247+ entity_id : EntityID
248+ Entity ID
249+
250+ Returns
251+ -------
252+ EntityID | None
253+ Entity position EntityID
254+ """
255+ entity = self .get_entity (entity_id )
256+ if entity is None :
257+ return None
258+ if isinstance (entity , Area ):
259+ return entity .get_entity_id ()
260+ if isinstance (entity , Human ):
261+ return entity .get_position ()
262+ if isinstance (entity , Blockade ):
263+ return entity .get_position ()
264+ return None
265+
266+ def get_entity_position_entity (self , entity_id : EntityID ) -> Entity | None :
267+ """
268+ Get the entity position Entity
269+
270+ Parameters
271+ ----------
272+ entity_id : EntityID
273+ Entity ID
274+
275+ Returns
276+ -------
277+ Entity | None
278+ Entity position Entity
279+ """
280+ position_entity_id = self .get_entity_position_entity_id (entity_id )
281+ if position_entity_id is None :
282+ return None
283+ return self .get_entity (position_entity_id )
284+
237285 def get_change_set (self ) -> ChangeSet :
238286 """
239287 Get the change set
0 commit comments