1# Lint as: python3
2"""Crowd objects/human controllers module."""
3
4import abc
5import collections
6from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Sequence, Text
7
8from absl import logging
9import dataclasses
10import gin
11import numpy as np
12#import rvo2
13
14from pybullet_envs.minitaur.envs_v2.sensors import base_position_sensor
15from pybullet_envs.minitaur.envs_v2.sensors import sensor as generic_sensor
16from pybullet_envs.minitaur.robots import autonomous_object
17from pybullet_envs.minitaur.robots import object_controller
18
19
# Suffix appended to a crowd instance name to build the name of its base
# position sensor, which is also the observation key the crowd controllers
# read positions from (e.g. "crowd_0" -> "crowd_0_pos").
POSITION_SENSOR_POSTFIX = "_pos"
21
22
@dataclasses.dataclass
class MovingObjectRecord:
  """Bookkeeping for a moving object that ORCA avoids but does not control.

  Attributes:
    position_key: Observation key under which the object's position is found.
    agent_id: Identifier of the stand-in agent registered in the ORCA
      simulator for this object.
    radius: Collision radius used for the stand-in agent.
    last_position: Position observed at the previous control step, or None
      when no position has been recorded yet.
  """
  position_key: Text
  agent_id: int
  radius: float
  last_position: Optional[np.ndarray] = dataclasses.field(default=None)
29
30
@gin.configurable
def sample_start_target_position(scene,
                                 start=None,
                                 start_circles=None,
                                 target_circles=None,
                                 num_sampling_retries=1,
                                 min_wall_distance=0.0,
                                 min_goal_euclidean_distance=0.0,
                                 max_goal_euclidean_distance=np.inf,
                                 min_path_clearance=None):
  """Sample valid start and target position reachable from start.

  Args:
    scene: a SceneBase instance implementing get_random_valid_position. If
      min_path_clearance is given, it must also implement find_shortest_path.
    start: a position (x, y, ...) used as the start. If specified, no start is
      sampled.
    start_circles: a list of circle specification. Each circle is specified as
      a tuple ((x, y), r) of a center (x, y) and radius r. If specified, start
      position is sampled from within one of the start_circles.
    target_circles: same as start_circles. If specified, target position is
      sampled from within one of the target_circles.
    num_sampling_retries: a positive int, number of attempts to sample a
      start, target pair.
    min_wall_distance: a float, the minimum distance to a wall.
    min_goal_euclidean_distance: a positive float, the minimum distance between
      start and target.
    max_goal_euclidean_distance: a positive float, the maximum distance between
      start and target.
    min_path_clearance: float, clearance of shortest path to walls.

  Returns:
    A 4 tuple (start, target, shortest_path, is_valid). start and target are
    start and target positions, shortest_path is a list of 2-tuples specifying
    the shortest path from start to target, is_valid is bool specifying whether
    the start, target pair is valid. If min_path_clearance is not specified,
    or no valid pair was found, then shortest_path is None.

  Raises:
    ValueError: if the scene lacks a required method, if both start and
      start_circles are given, or if num_sampling_retries is not positive.
  """
  if not hasattr(scene, "get_random_valid_position"):
    raise ValueError(
        "Incompatible scene {}. Expected to have `get_random_valid_position` "
        "method.".format(scene))
  # Argument checks are done once, up front, instead of on every retry.
  if start is not None and start_circles is not None:
    raise ValueError("At most one of the arguments start and start_circles "
                     "can be not None.")
  if min_path_clearance is not None and not hasattr(scene,
                                                    "find_shortest_path"):
    raise ValueError(
        f"Incompatible scene {scene}. Expected to have `find_shortest_path` "
        "method.")
  # Without at least one attempt there is no start/target to return.
  if num_sampling_retries < 1:
    raise ValueError(
        "num_sampling_retries must be a positive int, got "
        f"{num_sampling_retries}.")

  def _print_counters(counters):
    for name, value in counters.items():
      logging.info("  %s: %d", name, value)

  sampling_counters = collections.Counter()
  for _ in range(num_sampling_retries):
    if start is None:
      start_pos = scene.get_random_valid_position(
          min_wall_distance, inclusion_circles=start_circles)
    else:
      start_pos = start
    target_pos = scene.get_random_valid_position(
        min_wall_distance, inclusion_circles=target_circles)
    sampling_counters["attempts"] += 1

    # asarray allows `start` to be given as a plain tuple.
    euclidean_distance = np.linalg.norm(
        np.asarray(target_pos) - np.asarray(start_pos))
    if euclidean_distance < min_goal_euclidean_distance:
      sampling_counters["min_euclidean"] += 1
      continue
    if euclidean_distance > max_goal_euclidean_distance:
      sampling_counters["max_euclidean"] += 1
      continue

    # Skip the path computation if no path clearance is provided.
    if min_path_clearance is None:
      logging.info("Valid goal with no minimum path clearance checking.")
      _print_counters(sampling_counters)
      return start_pos, target_pos, None, True

    # Check the goal clearance along the shortest path. This is a slow
    # process.
    shortest_path = scene.find_shortest_path(
        start_pos[:2], target_pos[:2], min_path_clearance)
    # No path exists between start and goal satisfying the clearance.
    if shortest_path is None:
      sampling_counters["path_clearance"] += 1
      continue

    logging.info("Valid start/target with path clearance checking.")
    _print_counters(sampling_counters)
    return start_pos, target_pos, shortest_path, True

  logging.info("No valid start/target found.")
  _print_counters(sampling_counters)
  return start_pos, target_pos, None, False
126
127
class CrowdController(metaclass=abc.ABCMeta):
  """Interface for controllers that drive a whole crowd at once.

  Subclasses compute commands for every instance together in
  `_recalculate_actions` and expose the per-instance result through
  `_get_action_of_instance`. Individual crowd objects use the adapters
  returned by `instance_controller`.
  """

  def __init__(self, names: Iterable[Text],
               position_key_formatter="%s" + POSITION_SENSOR_POSTFIX):
    """Constructor.

    Args:
      names: Names of the crowd instances (dynamic objects or humans).
      position_key_formatter: %-style format string turning an instance name
        into the observation key of its position sensor.
    """
    self._names = [name for name in names]
    self._num_instance = len(self._names)
    self._position_key_formatter = position_key_formatter
    # Time of the most recent crowd-wide recalculation.
    self._current_time = 0

  def _validate_instance_id(self, instance_id):
    in_range = 0 <= instance_id < self._num_instance
    if not in_range:
      raise ValueError(
          f"instance_id must be an integer in [0, {self.num_instance}), "
          f"got {instance_id}.")

  @property
  def num_instance(self):
    """Number of instances in the crowd."""
    return self._num_instance

  def instance_name(self, instance_id: int) -> Text:
    """Returns the name of the given instance."""
    self._validate_instance_id(instance_id)
    names = self._names
    return names[instance_id]

  def instance_controller(
      self, instance_id: int) -> object_controller.ControllerBase:
    """Returns a ControllerBase adapter bound to one instance of this crowd."""
    self._validate_instance_id(instance_id)
    adapter = _IndividualController(self, instance_id)
    return adapter

  def instance_get_action(
      self, instance_id: int, time_sec: float,
      observations: Dict[Text, Any]) -> object_controller.ControllerOutput:
    """Returns the action of one instance, advancing the crowd if needed.

    Called by _IndividualController. The whole crowd is recalculated at most
    once per new time value; further queries at the same time reuse it.

    Args:
      instance_id: Identifier of an object in the crowd.
      time_sec: Time since simulation reset in seconds. If negative, the crowd
        is reset to its initial state and observations are ignored.
      observations: A dict of all observations.

    Returns:
      Position, orientation and an extra info dict for robot joints, human
        skeletal pose, etc.
    """
    is_reset = time_sec < 0
    if is_reset:
      self._recalculate_actions(object_controller.INIT_TIME, {})
      self._current_time = object_controller.INIT_TIME
    elif time_sec > self._current_time:
      self._current_time = time_sec
      self._recalculate_actions(time_sec, observations)

    self._validate_instance_id(instance_id)
    return self._get_action_of_instance(instance_id)

  @abc.abstractmethod
  def _recalculate_actions(
      self, time_sec: float, observations: Dict[Text, Any]) -> None:
    """Computes commands for every instance of the crowd."""
    raise NotImplementedError(
        "_recalculate_actions() should be implemented by subclass.")

  @abc.abstractmethod
  def _get_action_of_instance(
      self, instance_id: int) -> object_controller.ControllerOutput:
    """Returns the already computed command of one instance."""
    raise NotImplementedError(
        "_get_action_of_instance() should be implemented by subclass.")

  def set_scene(self, scene) -> None:
    """Hook for subclasses that need scene information; ignored here."""
    _ = scene
212
213
class _IndividualController(object_controller.ControllerBase):
  """Exposes a single member of a CrowdController as a ControllerBase."""

  def __init__(self, crowd_controller: CrowdController, instance_id: int):
    """Constructor.

    Args:
      crowd_controller: Controller of the crowd this instance belongs to.
      instance_id: Index of the instance within the crowd.
    """
    self._crowd_controller = crowd_controller
    self._instance_id = instance_id

  def get_action(
      self, time_sec: float,
      observations: Dict[Text, Any]) -> object_controller.ControllerOutput:
    """Delegates to the crowd controller for this instance.

    Args:
      time_sec: Time since simulation reset in seconds. If negative, initial
        values are returned and observations are ignored.
      observations: A dict of all observations.

    Returns:
      Position, orientation and an extra info dict for robot joints, human
        skeletal pose, etc.
    """
    crowd = self._crowd_controller
    return crowd.instance_get_action(
        self._instance_id, time_sec, observations)
243
244
@gin.configurable
class StationaryController(CrowdController):
  """Crowd controller that keeps every instance at a fixed pose."""

  def __init__(
      self, positions: Sequence[Sequence[float]],
      orientations: Optional[Sequence[Sequence[float]]] = None, **kwargs):
    """Constructor.

    Args:
      positions: One fixed 3D position per crowd instance.
      orientations: One fixed quaternion per crowd instance. Defaults to the
        identity quaternion (0, 0, 0, 1) for every instance.
      **kwargs: Keyword arguments to pass on to base class.

    Raises:
      ValueError: if positions or orientations do not have one entry per
        instance.
    """
    super().__init__(**kwargs)

    if orientations is None:
      identity_quaternion = (0, 0, 0, 1)
      orientations = np.array([identity_quaternion] * self.num_instance)

    num_positions = len(positions)
    num_orientations = len(orientations)
    if (num_positions != self.num_instance or
        num_orientations != self.num_instance):
      raise ValueError(
          f"positions and orientations should all have the same length "
          f"{self.num_instance}. Got len(positions) = {num_positions}, "
          f"len(orientations) = {num_orientations}.")

    self._positions = positions
    self._orientations = orientations

  def _recalculate_actions(
      self, time_sec: float, observations: Dict[Text, Any]) -> None:
    """Nothing to compute: poses never change."""
    del time_sec, observations

  def _get_action_of_instance(
      self, instance_id: int) -> object_controller.ControllerOutput:
    """Returns the fixed pose of the given instance."""
    self._validate_instance_id(instance_id)
    position = self._positions[instance_id]
    orientation = self._orientations[instance_id]
    return position, orientation, {}
284
285
@gin.configurable
class OrcaController(CrowdController):
  """A crowd controller that controls crowd instances using ORCA algorithm.

  Crowd instances are initialized at a start position and move towards a
  target position while avoiding collisions with each other, with static
  obstacles of the scene (when the scene provides `vectorized_map`) and with
  moving objects that are not controlled by this controller. The ORCA solver
  comes from the optional `rvo2` package, imported when the controller is
  constructed.
  """

  _DEFAULT_NEIGHBOR_DISTANCE_M = 5
  _DEFAULT_MAX_NEIGHBORS = 10
  _DEFAULT_RADIUS_M = 0.5
  _DEFAULT_MAX_SPEED_MPS = 2
  _DEFAULT_TIME_HORIZON_SEC = 1.0
  _DEFAULT_OBSTACLE_TIME_HORIZON_SEC = 0.3
  # With use_position_generator, a new target is sampled once an instance is
  # closer than this to its current target.
  _TARGET_REACHED_DISTANCE_M = 2.0
  # A path waypoint counts as covered once the instance is within this
  # distance of it.
  _WAYPOINT_COVERAGE_DISTANCE_M = 1.0
  # Gain from distance-to-waypoint to preferred speed (capped at max speed).
  _PREFERRED_SPEED_GAIN = 1.0

  def __init__(
      self,
      timestep: float,
      start_positions: Optional[Sequence[Sequence[float]]] = None,
      target_positions: Optional[Sequence[Sequence[float]]] = None,
      use_position_generator: Optional[bool] = False,
      group_sizes: Optional[Sequence[int]] = None,
      radius: float = _DEFAULT_RADIUS_M,
      max_speed_mps: float = _DEFAULT_MAX_SPEED_MPS,
      time_horizon_sec: float = _DEFAULT_TIME_HORIZON_SEC,
      obstacle_time_horizon_sec: float = _DEFAULT_OBSTACLE_TIME_HORIZON_SEC,
      neighbor_distance_m: float = _DEFAULT_NEIGHBOR_DISTANCE_M,
      max_neighbors: int = _DEFAULT_MAX_NEIGHBORS,
      workaround_erp_issue: bool = True,
      moving_objects_pos_key: Sequence[Text] = (),
      moving_objects_radius: Union[float, Sequence[float]] = _DEFAULT_RADIUS_M,
      endless_trajectory: bool = True,
      **kwargs):
    """Constructor.

    Args:
      timestep: Timestep of simulation.
      start_positions: A list of position (x, y, z) for crowd instances as
        their starting position.
      target_positions: A list of position (x, y, z) for crowd instances as
        their target position.
      use_position_generator: a boolean, if True then the start and target
        positions are sampled from the scene and start_positions /
        target_positions are ignored.
      group_sizes: If set, then crowd is split in groups randomly, whose sizes
        are picked in random from this group_size list. In this way, the
        crowd simulator simulates clusters of objects moving around.
      radius: Radius of crowd instances.
      max_speed_mps: Maximum crowd instance speed.
      time_horizon_sec: Time horizon in second.
      obstacle_time_horizon_sec: Time horizon for static obstacle in second.
      neighbor_distance_m: Neighbor distance in meters. Instances closer than
        this distance are considered neighbors.
      max_neighbors: Max number of neighbors.
      workaround_erp_issue: There is an issue with pybullet constraint that the
        constraint is solved only 20% per timestep. Need to amplify position
        delta by 5x to workaround this issue.
      moving_objects_pos_key: Position observation key of moving objects not
        controlled by the ORCA controller.
      moving_objects_radius: Radius of moving objects. Should be a number,
        which applies to all moving objects, or a sequence of float, which
        should be of the same length as moving_objects_pos_key.
      endless_trajectory: Only valid if use_position_generator is True. Agent
        returns to starting point after reaching goal to achieve endless motion.
      **kwargs: Keyword arguments to pass on to base class.

    Raises:
      ValueError: if use_position_generator is False and start/target
        positions are missing or not of length num_instance, or if
        moving_objects_radius does not match moving_objects_pos_key.
      ImportError: if the `rvo2` package is not installed.
    """
    super().__init__(**kwargs)

    if not use_position_generator:
      if start_positions is None or target_positions is None:
        raise ValueError(
            "start_positions and target_positions are required unless "
            "use_position_generator is True.")
      if not len(start_positions) == len(target_positions) == self.num_instance:
        raise ValueError(
            f"start_positions and target_positions should both have length "
            f"equals {self.num_instance}: "
            f"len(start_positions) = {len(start_positions)}, "
            f"len(target_positions) = {len(target_positions)}.")

    self._timestep = timestep
    self._radius = radius
    self._max_speed_mps = max_speed_mps
    self._time_horizon_sec = time_horizon_sec
    self._obstacle_time_horizon_sec = obstacle_time_horizon_sec
    self._neighbor_distance_m = neighbor_distance_m
    self._max_neighbors = max_neighbors
    self._use_position_generator = use_position_generator
    self._endless_trajectory = endless_trajectory
    self._scene = None

    # Accept ints (e.g. `1`) as well as floats for a shared radius.
    if isinstance(moving_objects_radius, (int, float, np.number)):
      moving_objects_radius = [float(moving_objects_radius)] * len(
          moving_objects_pos_key)
    if len(moving_objects_radius) != len(moving_objects_pos_key):
      raise ValueError(
          "moving_objects_radius should be either a float or a sequence of "
          "float with the same length as moving_objects_pos_key.")
    self._moving_objects = [
        MovingObjectRecord(position_key=key, agent_id=-1, radius=obj_radius)
        for key, obj_radius in zip(moving_objects_pos_key,
                                   moving_objects_radius)]

    # Per-instance waypoint arrays and progress along them; only used with
    # use_position_generator.
    self._paths = None
    self._path_indices = None
    if self._use_position_generator:
      self._start_positions = None
      self._target_positions = None
    else:
      self._start_positions = np.array(start_positions, dtype=np.float64)
      self._target_positions = np.array(target_positions, dtype=np.float64)
    # A guard against multiple initializations. See _recalculate_actions.
    self._already_initialized = False
    self._group_sizes = [1] if group_sizes is None else group_sizes

    # The following variables are initialized in _recalculate_actions()
    self._current_positions = None
    self._command_positions = None
    self._command_orientations = None

    # rvo2 is an optional dependency: import it lazily so this module can be
    # imported without it, but fail with a clear message when ORCA is used.
    try:
      import rvo2  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      raise ImportError(
          "OrcaController requires the optional `rvo2` package "
          "(Python-RVO2).") from e

    self._orca = rvo2.PyRVOSimulator(
        self._timestep,  # timestep
        self._neighbor_distance_m,  # neighborDist
        self._max_neighbors,  # maxNeighbors
        self._time_horizon_sec,  # timeHorizon
        self._obstacle_time_horizon_sec,  # timeHorizonObst
        self._radius,  # radius
        self._max_speed_mps  # maxSpeed
    )
    for i in range(self.num_instance):
      if self._use_position_generator:
        # Real start positions are set on reset in _recalculate_actions.
        start_position = (0, 0)
      else:
        start_position = self._start_positions[i, :2]
      agent_id = self._orca.addAgent(
          tuple(start_position),
          self._neighbor_distance_m,  # neighborDist
          self._max_neighbors,  # maxNeighbors
          self._time_horizon_sec,  # timeHorizon
          self._obstacle_time_horizon_sec,  # timeHorizonObst
          self._radius,  # radius
          self._max_speed_mps,  # maxSpeed
          (0.0, 0.0))  # velocity
      # Crowd instance i must map to ORCA agent i.
      assert agent_id == i

    for obj in self._moving_objects:
      obj.agent_id = self._orca.addAgent(
          (0.0, 0.0),  # position (will adjust after simulation starts)
          self._neighbor_distance_m,  # neighborDist
          self._max_neighbors,  # maxNeighbors
          self._timestep,  # timeHorizon
          self._timestep,  # timeHorizonObst
          obj.radius,  # radius
          self._max_speed_mps,  # maxSpeed
          (0.0, 0.0))  # velocity

    self._workaround_erp_issue = workaround_erp_issue

  def _subsample_path(self, path, subsample_step=1.0):
    """Keeps path points roughly every `subsample_step` of travelled length.

    The first and last points of `path` are always kept.
    """
    subsampled_path = [path[0]]
    traveled_dist = 0.0
    for i, (s, t) in enumerate(zip(path[:-1], path[1:])):
      traveled_dist += np.sqrt(
          np.square(s[0] - t[0]) + np.square(s[1] - t[1]))
      if traveled_dist > subsample_step or i >= len(path) - 2:
        subsampled_path.append(t)
        traveled_dist = 0.0
    return subsampled_path

  def _build_path(self, start_pos, target_pos, path):
    """Returns the subsampled float32 waypoint array for one instance.

    When no shortest path was computed (path is None, which happens when
    sample_start_target_position runs without min_path_clearance), the
    straight segment from start to target is used instead.
    """
    if path is None:
      path = [tuple(start_pos[:2]), tuple(target_pos[:2])]
    return np.array(self._subsample_path(path), dtype=np.float32)

  def _generate_start_target_positions(self):
    """Generates start and target positions using goal generators."""
    assert self._scene is not None
    self._start_positions = np.zeros((self.num_instance, 3), dtype=np.float64)
    self._target_positions = np.zeros((self.num_instance, 3), dtype=np.float64)

    self._paths = []
    self._path_indices = []
    start_circles, target_circles = None, None
    group_radius = 1.0
    current_group_size = np.random.choice(self._group_sizes)
    index_in_current_group = 0
    for i in range(self._num_instance):
      start_pos, target_pos, path, is_valid = sample_start_target_position(
          self._scene,
          start_circles=start_circles,
          target_circles=target_circles)
      if index_in_current_group == current_group_size - 1:
        # Group complete: the next instance starts a new, unconstrained group.
        start_circles, target_circles = None, None
        index_in_current_group = 0
        current_group_size = np.random.choice(self._group_sizes)
      else:
        # Keep group members near each other's start and target.
        if start_circles is None:
          start_circles = [(start_pos[:2], group_radius)]
          target_circles = [(target_pos[:2], group_radius)]
        else:
          start_circles += [(start_pos[:2], group_radius)]
          target_circles += [(target_pos[:2], group_radius)]
        index_in_current_group += 1
      if not is_valid:
        raise ValueError("No valid start/target positions.")
      self._start_positions[i, :] = start_pos
      self._target_positions[i, :] = target_pos

      self._paths.append(self._build_path(start_pos, target_pos, path))
      self._path_indices.append(0)

  def _next_waypoint(self, instance_id, position_xy):
    """Returns the next uncovered waypoint on the path of one instance.

    Waypoints within _WAYPOINT_COVERAGE_DISTANCE_M of `position_xy` are
    skipped. At the end of the path, the path is reversed when
    endless_trajectory is set; otherwise the last waypoint is returned.
    """
    path = self._paths[instance_id]
    distances = np.sqrt(np.sum(np.square(path - position_xy), axis=1))
    index = self._path_indices[instance_id]
    while True:
      if index >= len(path) - 1:
        if self._endless_trajectory:
          path = path[::-1]
          self._paths[instance_id] = path
          index = 0
        break
      elif distances[index] > self._WAYPOINT_COVERAGE_DISTANCE_M:
        break
      else:
        index += 1
    self._path_indices[instance_id] = index
    return path[index, :]

  def _recalculate_actions(
      self, time_sec: float, observations: Dict[Text, Any]) -> None:
    """Calculates crowd command for all crowd instances."""
    if self._use_position_generator:
      if (time_sec == object_controller.INIT_TIME and
          self._start_positions is None and
          not self._already_initialized):
        self._generate_start_target_positions()
        # Initialize only once per initial time even if recalculate actions
        # is called multiple times.
        self._already_initialized = True
    if time_sec == object_controller.INIT_TIME:
      # Resets orca simulator.
      for i in range(self.num_instance):
        self._orca.setAgentPosition(i, tuple(self._start_positions[i, :2]))
      # Forget moving-object positions from the previous episode so the first
      # velocity estimate after a reset does not span the reset.
      for obj in self._moving_objects:
        obj.last_position = None

      self._command_positions = self._start_positions.copy()
      self._current_positions = self._start_positions.copy()
      self._command_orientations = np.repeat(
          ((0.0, 0.0, 0.0, 1.0),), self.num_instance, axis=0)
      self._last_target_recalculation_sec = time_sec
      return
    # The moment we step beyond initial time, we can initialize again.
    self._already_initialized = False

    if self._use_position_generator:
      for i in range(self._num_instance):
        dist = np.linalg.norm(
            self._current_positions[i, :] - self._target_positions[i, :])
        if dist < self._TARGET_REACHED_DISTANCE_M:
          _, target_pos, path, is_valid = sample_start_target_position(
              self._scene, self._current_positions[i, :])
          if is_valid:
            self._target_positions[i, :] = target_pos
            # Replace this instance's path so the new target is followed.
            self._paths[i] = self._build_path(
                self._current_positions[i, :], target_pos, path)
            self._path_indices[i] = 0

    # Sets agent position and preferred velocity based on target.
    for i, agent_name in enumerate(self._names):
      position = observations[self._position_key_formatter % agent_name]
      self._orca.setAgentPosition(
          i, tuple(position[:2]))  # ORCA uses 2D position.
      self._current_positions[i, :2] = position[:2]

      if self._paths is not None:
        target_position = self._next_waypoint(i, position[:2])
      else:
        target_position = self._target_positions[i][:2]

      goal_vector = target_position - position[:2]
      goal_vector_norm = np.linalg.norm(goal_vector) + np.finfo(np.float32).eps
      goal_unit_vector = goal_vector / goal_vector_norm

      # Respect the configured max speed, not the class default.
      speed = min(self._PREFERRED_SPEED_GAIN * goal_vector_norm,
                  self._max_speed_mps)
      self._orca.setAgentPrefVelocity(i, tuple(speed * goal_unit_vector))

    # Externally controlled objects: place them at their observed position and
    # use a finite-difference velocity estimate as their preferred velocity.
    for obj in self._moving_objects:
      position = observations[obj.position_key]
      self._orca.setAgentPosition(obj.agent_id, tuple(position[:2]))
      if obj.last_position is None:
        self._orca.setAgentPrefVelocity(obj.agent_id, (0.0, 0.0))
      else:
        velocity = (position - obj.last_position) / self._timestep
        self._orca.setAgentPrefVelocity(obj.agent_id, tuple(velocity[:2]))
      obj.last_position = position.copy()

    # Advances orca simulator.
    self._orca.doStep()

    # Retrieve agent position and save in buffer; face the direction of travel.
    for i in range(self.num_instance):
      x, y = self._orca.getAgentPosition(i)
      self._command_positions[i, :2] = (x, y)

      yaw = np.arctan2(y - self._current_positions[i, 1],
                       x - self._current_positions[i, 0])
      self._command_orientations[i] = (0, 0, np.sin(yaw / 2), np.cos(yaw / 2))

  def _get_action_of_instance(
      self, instance_id) -> object_controller.ControllerOutput:
    """Returns calculated actions of specific instance.

    Raises:
      RuntimeError: if called before the first _recalculate_actions().
    """
    if self._command_positions is None:
      raise RuntimeError(
          "Attempted to get action of instance before _recalculate_actions().")

    self._validate_instance_id(instance_id)

    if self._workaround_erp_issue:
      # The constraint only closes ~20% of the gap per step; amplify the delta.
      k_erp = 1 / 0.2
      delta_position = (
          self._command_positions[instance_id] -
          self._current_positions[instance_id])
      command_position = (
          self._current_positions[instance_id] + k_erp * delta_position)
    else:
      command_position = self._command_positions[instance_id].copy()
    return command_position, self._command_orientations[instance_id], {}

  def set_scene(self, scene) -> None:
    """Sets the scene used for position sampling and static obstacles."""
    # The scene is needed for start/target sampling even when it cannot
    # provide obstacle polygons, so keep it unconditionally.
    self._scene = scene
    try:
      polygons = scene.vectorized_map
    except NotImplementedError:
      logging.exception("Scene does not implement vectorized_map property. "
                        "Crowd agent cannot avoid static obstacles.")
      return
    for polygon in polygons:
      self._orca.addObstacle([tuple(point) for point in polygon])
    self._orca.processObstacles()
622
623
@gin.configurable
def uniform_object_factory(
    instance_id: int,
    object_factory: Callable[..., autonomous_object.AutonomousObject],
    *args, **kwargs) -> autonomous_object.AutonomousObject:
  """Builds a crowd object with one shared factory, ignoring `instance_id`.

  Args:
    instance_id: Index of the crowd instance being built; unused.
    object_factory: Callable that constructs the object.
    *args: Positional arguments forwarded to `object_factory`.
    **kwargs: Keyword arguments forwarded to `object_factory`.

  Returns:
    The object returned by `object_factory(*args, **kwargs)`.
  """
  _ = instance_id  # Every instance is built the same way.
  build = object_factory
  return build(*args, **kwargs)
632
633
@gin.configurable
def random_object_factory(
    instance_id: int,
    object_factories: Iterable[
        Callable[..., autonomous_object.AutonomousObject]],
    *args, **kwargs) -> autonomous_object.AutonomousObject:
  """Builds a crowd object using a factory picked uniformly at random.

  Args:
    instance_id: Index of the crowd instance being built; unused.
    object_factories: Non-empty iterable of callables constructing objects.
    *args: Positional arguments forwarded to the chosen factory.
    **kwargs: Keyword arguments forwarded to the chosen factory.

  Returns:
    The object returned by the chosen factory.

  Raises:
    ValueError: if object_factories is empty.
  """
  del instance_id
  # Materialize so any iterable (including generators) is accepted.
  factories = list(object_factories)
  if not factories:
    raise ValueError("object_factories must contain at least one factory.")
  # Draw an index instead of handing the callables to np.random.choice, which
  # would have to convert them into a numpy array first.
  object_factory = factories[np.random.choice(len(factories))]
  return object_factory(*args, **kwargs)
644
645
@gin.configurable
def sensor_factory(instance_id: int,
                   sensor: Callable[..., generic_sensor.Sensor],
                   *args, **kwargs) -> generic_sensor.Sensor:
  """Builds a per-instance sensor, ignoring `instance_id`.

  Args:
    instance_id: Index of the crowd instance the sensor is for; unused.
    sensor: Callable that constructs the sensor.
    *args: Positional arguments forwarded to `sensor`.
    **kwargs: Keyword arguments forwarded to `sensor`.

  Returns:
    The sensor returned by `sensor(*args, **kwargs)`.
  """
  _ = instance_id
  build = sensor
  return build(*args, **kwargs)
652
653
@gin.configurable
class CrowdBuilder(object):
  """A helper class to construct a crowd."""

  def __init__(
      self,
      num_instance: int,
      crowd_controller_factory: Callable[..., CrowdController],
      object_factory: Callable[..., autonomous_object.AutonomousObject],
      sensor_factories: Optional[
          Iterable[Callable[..., generic_sensor.Sensor]]] = None):
    """Constructor.

    Instances are named "crowd_<i>". Each gets a base position sensor named
    "<name>" + POSITION_SENSOR_POSTFIX plus one sensor per entry of
    sensor_factories, and is driven by a per-instance adapter of the shared
    crowd controller.

    Args:
      num_instance: Number of autonomous objects in the crowd.
      crowd_controller_factory: A callable that returns a crowd controller
        object; called with the keyword argument `names`.
      object_factory: Callable that returns an autonomous object; called with
        `instance_id`, `sensors` and `controller` keyword arguments.
      sensor_factories: Optional iterable of sensor callables. Each is called
        (through sensor_factory) once per instance.
    """
    # Materialize once: a one-shot iterable such as a generator would
    # otherwise be exhausted by the first instance, leaving the rest without
    # their extra sensors.
    sensor_factories = ([] if sensor_factories is None
                        else list(sensor_factories))

    self._objects = []
    names = [f"crowd_{i}" for i in range(num_instance)]

    self._controller = crowd_controller_factory(names=names)

    for i, name in enumerate(names):
      position_sensor = base_position_sensor.BasePositionSensor(
          name=name + POSITION_SENSOR_POSTFIX)

      # Additional per agent sensors (e.g. camera, occupancy, etc.).
      extra_sensors = tuple(
          sensor_factory(instance_id=i, sensor=s, name=name + "_" + s.__name__)
          for s in sensor_factories)

      an_object = object_factory(
          instance_id=i,
          sensors=(position_sensor,) + extra_sensors,
          controller=self._controller.instance_controller(i))

      self._objects.append(an_object)

  @property
  def crowd_objects(self) -> List[autonomous_object.AutonomousObject]:
    """Returns list of AutonomousObjects in the crowd."""
    return self._objects

  @property
  def crowd_controller(self) -> CrowdController:
    """Returns the crowd controller."""
    return self._controller
707