1import os
2import numpy as np
3
4from gym import utils, error
5from gym.envs.robotics import rotations, hand_env
6from gym.envs.robotics.utils import robot_get_obs
7
8try:
9    import mujoco_py
10except ImportError as e:
11    raise error.DependencyNotInstalled(
12        "{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(
13            e
14        )
15    )
16
17
def quat_from_angle_and_axis(angle, axis):
    """Build a unit quaternion for a rotation of `angle` about `axis`.

    Args:
        angle (float): rotation angle in radians
        axis (array-like of shape (3,)): rotation axis; it does not need to be
            unit length and is left unmodified

    Returns:
        np.ndarray of shape (4,): normalized quaternion, scalar component first
    """
    # Work on a private float copy: normalizing in place would silently rescale
    # the caller's array and fails outright for integer-typed axes.
    axis = np.array(axis, dtype=np.float64)
    assert axis.shape == (3,)
    axis /= np.linalg.norm(axis)
    half_angle = angle / 2.0
    quat = np.concatenate([[np.cos(half_angle)], np.sin(half_angle) * axis])
    # Re-normalize to remove accumulated floating point error.
    quat /= np.linalg.norm(quat)
    return quat
24
25
# Relative model paths inside the "hand" asset directory. They are built with
# os.path.join so the path separator is also correct on Windows.
MANIPULATE_BLOCK_XML, MANIPULATE_EGG_XML, MANIPULATE_PEN_XML = (
    os.path.join("hand", "manipulate_{}.xml".format(shape))
    for shape in ("block", "egg", "pen")
)
30
31
class ManipulateEnv(hand_env.HandEnv):
    """Goal-based environment in which a hand manipulates a free-floating object.

    A goal is a 7-vector: the object's position (3 values) followed by its
    rotation as a quaternion (4 values), matching the qpos layout of the
    "object:joint" joint.
    """

    def __init__(
        self,
        model_path,
        target_position,
        target_rotation,
        target_position_range,
        reward_type,
        initial_qpos=None,
        randomize_initial_position=True,
        randomize_initial_rotation=True,
        distance_threshold=0.01,
        rotation_threshold=0.1,
        n_substeps=20,
        relative_control=False,
        ignore_z_target_rotation=False,
    ):
        """Initializes a new Hand manipulation environment.

        Args:
            model_path (string): path to the environments XML file
            target_position (string): the type of target position:
                - ignore: target position is fully ignored, i.e. the object can be positioned arbitrarily
                - fixed: target position is set to the initial position of the object
                - random: target position is fully randomized according to target_position_range
            target_rotation (string): the type of target rotation:
                - ignore: target rotation is fully ignored, i.e. the object can be rotated arbitrarily
                - fixed: target rotation is set to the initial rotation of the object
                - xyz: fully randomized target rotation around the X, Y and Z axis
                - z: fully randomized target rotation around the Z axis
                - parallel: fully randomized target rotation around Z and axis-aligned rotation around X, Y
            ignore_z_target_rotation (boolean): whether or not the Z axis of the target rotation is ignored
            target_position_range (np.array of shape (3, 2)): range of the target_position randomization
            reward_type ('sparse' or 'dense'): the reward type, i.e. sparse or dense
            initial_qpos (dict): a dictionary of joint names and values that define the initial configuration
            randomize_initial_position (boolean): whether or not to randomize the initial position of the object
            randomize_initial_rotation (boolean): whether or not to randomize the initial rotation of the object
            distance_threshold (float, in meters): the threshold after which the position of a goal is considered achieved
            rotation_threshold (float, in radians): the threshold after which the rotation of a goal is considered achieved
            n_substeps (int): number of substeps the simulation runs on every call to step
            relative_control (boolean): whether or not the hand is actuated in absolute joint positions or relative to the current state
        """
        self.target_position = target_position
        self.target_rotation = target_rotation
        self.target_position_range = target_position_range
        # Pre-computed quaternions of the axis-aligned rotations used by the
        # "parallel" target rotation mode.
        self.parallel_quats = [
            rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
        ]
        self.randomize_initial_rotation = randomize_initial_rotation
        self.randomize_initial_position = randomize_initial_position
        self.distance_threshold = distance_threshold
        self.rotation_threshold = rotation_threshold
        self.reward_type = reward_type
        self.ignore_z_target_rotation = ignore_z_target_rotation

        assert self.target_position in ["ignore", "fixed", "random"]
        assert self.target_rotation in ["ignore", "fixed", "xyz", "z", "parallel"]
        initial_qpos = initial_qpos or {}

        # All attributes above are assigned before the base initializer runs.
        hand_env.HandEnv.__init__(
            self,
            model_path,
            n_substeps=n_substeps,
            initial_qpos=initial_qpos,
            relative_control=relative_control,
        )

    def _get_achieved_goal(self):
        """Return the object's current qpos (position + quaternion, shape (7,))."""
        # Object position and rotation.
        object_qpos = self.sim.data.get_joint_qpos("object:joint")
        assert object_qpos.shape == (7,)
        return object_qpos

    def _goal_distance(self, goal_a, goal_b):
        """Compute positional and rotational distance between two goals.

        Args:
            goal_a (np.ndarray of shape (..., 7)): first goal(s)
            goal_b (np.ndarray of shape (..., 7)): second goal(s), same shape

        Returns:
            tuple (d_pos, d_rot): Euclidean distance between the positions and
            angle (in radians) between the rotations, each of shape (...,).
            A component stays zero when the corresponding target is "ignore".
        """
        assert goal_a.shape == goal_b.shape
        assert goal_a.shape[-1] == 7

        d_pos = np.zeros_like(goal_a[..., 0])
        d_rot = np.zeros_like(goal_b[..., 0])
        if self.target_position != "ignore":
            delta_pos = goal_a[..., :3] - goal_b[..., :3]
            d_pos = np.linalg.norm(delta_pos, axis=-1)

        if self.target_rotation != "ignore":
            quat_a, quat_b = goal_a[..., 3:], goal_b[..., 3:]

            if self.ignore_z_target_rotation:
                # Special case: We want to ignore the Z component of the rotation.
                # This code here assumes Euler angles with xyz convention. We first transform
                # to euler, then set the Z component to be equal between the two, and finally
                # transform back into quaternions.
                euler_a = rotations.quat2euler(quat_a)
                euler_b = rotations.quat2euler(quat_b)
                # Index the last axis so batched goals of shape (N, 7) are handled
                # correctly; plain `[2]` would overwrite the third row instead of
                # the Z angle.
                euler_a[..., 2] = euler_b[..., 2]
                quat_a = rotations.euler2quat(euler_a)

            # Subtract quaternions and extract angle between them.
            quat_diff = rotations.quat_mul(quat_a, rotations.quat_conjugate(quat_b))
            # Clip guards arccos against values marginally outside [-1, 1].
            angle_diff = 2 * np.arccos(np.clip(quat_diff[..., 0], -1.0, 1.0))
            d_rot = angle_diff
        assert d_pos.shape == d_rot.shape
        return d_pos, d_rot

    # GoalEnv methods
    # ----------------------------

    def compute_reward(self, achieved_goal, goal, info):
        """Compute the reward for reaching `achieved_goal` when `goal` was desired.

        With the "sparse" reward type the result is 0 on success and -1
        otherwise; any other reward type yields the negative weighted sum of
        the position and rotation distances. `info` is unused.
        """
        if self.reward_type == "sparse":
            success = self._is_success(achieved_goal, goal).astype(np.float32)
            return success - 1.0
        else:
            d_pos, d_rot = self._goal_distance(achieved_goal, goal)
            # We weigh the difference in position to avoid that `d_pos` (in meters) is completely
            # dominated by `d_rot` (in radians).
            return -(10.0 * d_pos + d_rot)

    # RobotEnv methods
    # ----------------------------

    def _is_success(self, achieved_goal, desired_goal):
        """Return 1.0 where both distance thresholds are met, else 0.0 (float32)."""
        d_pos, d_rot = self._goal_distance(achieved_goal, desired_goal)
        achieved_pos = (d_pos < self.distance_threshold).astype(np.float32)
        achieved_rot = (d_rot < self.rotation_threshold).astype(np.float32)
        # The product acts as a logical AND on the 0/1 indicators.
        achieved_both = achieved_pos * achieved_rot
        return achieved_both

    def _env_setup(self, initial_qpos):
        """Apply the given joint name -> qpos mapping and forward the simulation."""
        for name, value in initial_qpos.items():
            self.sim.data.set_joint_qpos(name, value)
        self.sim.forward()

    def _reset_sim(self):
        """Reset the simulation and (optionally) randomize the object's pose.

        Returns:
            bool: False if stepping the simulation raised a MujocoException;
            otherwise whether the "object:center" site is higher than 0.04
            along Z after the settling steps.
        """
        self.sim.set_state(self.initial_state)
        self.sim.forward()

        # Copy so the in-place updates below do not touch simulator state
        # before the final set_joint_qpos call.
        initial_qpos = self.sim.data.get_joint_qpos("object:joint").copy()
        initial_pos, initial_quat = initial_qpos[:3], initial_qpos[3:]
        assert initial_qpos.shape == (7,)
        assert initial_pos.shape == (3,)
        assert initial_quat.shape == (4,)

        # Randomization initial rotation.
        if self.randomize_initial_rotation:
            if self.target_rotation == "z":
                angle = self.np_random.uniform(-np.pi, np.pi)
                axis = np.array([0.0, 0.0, 1.0])
                offset_quat = quat_from_angle_and_axis(angle, axis)
                initial_quat = rotations.quat_mul(initial_quat, offset_quat)
            elif self.target_rotation == "parallel":
                angle = self.np_random.uniform(-np.pi, np.pi)
                axis = np.array([0.0, 0.0, 1.0])
                z_quat = quat_from_angle_and_axis(angle, axis)
                parallel_quat = self.parallel_quats[
                    self.np_random.randint(len(self.parallel_quats))
                ]
                offset_quat = rotations.quat_mul(z_quat, parallel_quat)
                initial_quat = rotations.quat_mul(initial_quat, offset_quat)
            elif self.target_rotation in ["xyz", "ignore"]:
                angle = self.np_random.uniform(-np.pi, np.pi)
                axis = self.np_random.uniform(-1.0, 1.0, size=3)
                offset_quat = quat_from_angle_and_axis(angle, axis)
                initial_quat = rotations.quat_mul(initial_quat, offset_quat)
            elif self.target_rotation == "fixed":
                pass
            else:
                raise error.Error(
                    'Unknown target_rotation option "{}".'.format(self.target_rotation)
                )

        # Randomize initial position.
        if self.randomize_initial_position:
            if self.target_position != "fixed":
                initial_pos += self.np_random.normal(size=3, scale=0.005)

        initial_quat /= np.linalg.norm(initial_quat)
        initial_qpos = np.concatenate([initial_pos, initial_quat])
        self.sim.data.set_joint_qpos("object:joint", initial_qpos)

        def is_on_palm():
            # The object counts as resting on the palm while its center site
            # stays above a fixed height.
            self.sim.forward()
            cube_middle_idx = self.sim.model.site_name2id("object:center")
            cube_middle_pos = self.sim.data.site_xpos[cube_middle_idx]
            is_on_palm = cube_middle_pos[2] > 0.04
            return is_on_palm

        # Run the simulation for a bunch of timesteps to let everything settle in.
        for _ in range(10):
            self._set_action(np.zeros(20))
            try:
                self.sim.step()
            except mujoco_py.MujocoException:
                return False
        return is_on_palm()

    def _sample_goal(self):
        """Sample a goal according to `target_position` and `target_rotation`.

        Returns:
            np.ndarray of shape (7,): target position followed by a normalized
            target quaternion

        Raises:
            error.Error: if either target option has an unknown value
        """
        # Select a goal for the object position.
        target_pos = None
        if self.target_position == "random":
            assert self.target_position_range.shape == (3, 2)
            offset = self.np_random.uniform(
                self.target_position_range[:, 0], self.target_position_range[:, 1]
            )
            assert offset.shape == (3,)
            target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
        elif self.target_position in ["ignore", "fixed"]:
            target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
        else:
            raise error.Error(
                'Unknown target_position option "{}".'.format(self.target_position)
            )
        assert target_pos is not None
        assert target_pos.shape == (3,)

        # Select a goal for the object rotation.
        target_quat = None
        if self.target_rotation == "z":
            angle = self.np_random.uniform(-np.pi, np.pi)
            axis = np.array([0.0, 0.0, 1.0])
            target_quat = quat_from_angle_and_axis(angle, axis)
        elif self.target_rotation == "parallel":
            angle = self.np_random.uniform(-np.pi, np.pi)
            axis = np.array([0.0, 0.0, 1.0])
            target_quat = quat_from_angle_and_axis(angle, axis)
            parallel_quat = self.parallel_quats[
                self.np_random.randint(len(self.parallel_quats))
            ]
            target_quat = rotations.quat_mul(target_quat, parallel_quat)
        elif self.target_rotation == "xyz":
            angle = self.np_random.uniform(-np.pi, np.pi)
            axis = self.np_random.uniform(-1.0, 1.0, size=3)
            target_quat = quat_from_angle_and_axis(angle, axis)
        elif self.target_rotation in ["ignore", "fixed"]:
            # Take only the quaternion part of the 7-element qpos (the full
            # vector would fail the shape check below), and copy it so the
            # in-place normalization cannot write into simulator state.
            target_quat = self.sim.data.get_joint_qpos("object:joint")[3:].copy()
        else:
            raise error.Error(
                'Unknown target_rotation option "{}".'.format(self.target_rotation)
            )
        assert target_quat is not None
        assert target_quat.shape == (4,)

        target_quat /= np.linalg.norm(target_quat)  # normalized quaternion
        goal = np.concatenate([target_pos, target_quat])
        return goal

    def _render_callback(self):
        """Place the target object at the goal pose before rendering."""
        # Assign current state to target object but offset a bit so that the actual object
        # is not obscured.
        goal = self.goal.copy()
        assert goal.shape == (7,)
        if self.target_position == "ignore":
            # Move the object to the side since we do not care about its position.
            goal[0] += 0.15
        self.sim.data.set_joint_qpos("target:joint", goal)
        self.sim.data.set_joint_qvel("target:joint", np.zeros(6))

        if "object_hidden" in self.sim.model.geom_names:
            # Make the "object_hidden" geom fully opaque (alpha channel = 1).
            hidden_id = self.sim.model.geom_name2id("object_hidden")
            self.sim.model.geom_rgba[hidden_id, 3] = 1.0
        self.sim.forward()

    def _get_obs(self):
        """Assemble the goal-environment observation dictionary.

        Returns:
            dict with keys "observation" (robot qpos/qvel, object qvel and
            achieved goal concatenated), "achieved_goal" and "desired_goal";
            every value is a copy.
        """
        robot_qpos, robot_qvel = robot_get_obs(self.sim)
        object_qvel = self.sim.data.get_joint_qvel("object:joint")
        achieved_goal = (
            self._get_achieved_goal().ravel()
        )  # this contains the object position + rotation
        observation = np.concatenate(
            [robot_qpos, robot_qvel, object_qvel, achieved_goal]
        )
        return {
            "observation": observation.copy(),
            "achieved_goal": achieved_goal.copy(),
            "desired_goal": self.goal.ravel().copy(),
        }
307
308
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
    """Manipulation environment that loads the block model."""

    def __init__(
        self, target_position="random", target_rotation="xyz", reward_type="sparse"
    ):
        """Create the block environment.

        Args:
            target_position (string): target position mode passed to ManipulateEnv
            target_rotation (string): target rotation mode passed to ManipulateEnv
            reward_type ('sparse' or 'dense'): the reward type
        """
        # EzPickle receives the constructor arguments (presumably so the
        # environment can be re-created when unpickled -- confirm in gym.utils).
        utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)

        # Per-axis (low, high) bounds of the random target position offset.
        position_range = np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)])
        ManipulateEnv.__init__(
            self,
            MANIPULATE_BLOCK_XML,
            target_position,
            target_rotation,
            position_range,
            reward_type,
        )
322
323
class HandEggEnv(ManipulateEnv, utils.EzPickle):
    """Manipulation environment that loads the egg model."""

    def __init__(
        self, target_position="random", target_rotation="xyz", reward_type="sparse"
    ):
        """Create the egg environment.

        Args:
            target_position (string): target position mode passed to ManipulateEnv
            target_rotation (string): target rotation mode passed to ManipulateEnv
            reward_type ('sparse' or 'dense'): the reward type
        """
        # EzPickle receives the constructor arguments (presumably so the
        # environment can be re-created when unpickled -- confirm in gym.utils).
        utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)

        # Per-axis (low, high) bounds of the random target position offset.
        position_range = np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)])
        ManipulateEnv.__init__(
            self,
            MANIPULATE_EGG_XML,
            target_position,
            target_rotation,
            position_range,
            reward_type,
        )
337
338
class HandPenEnv(ManipulateEnv, utils.EzPickle):
    """Manipulation environment that loads the pen model.

    Unlike the block and egg variants, it disables initial rotation
    randomization, ignores the Z component of the target rotation and uses a
    distance threshold of 0.05.
    """

    def __init__(
        self, target_position="random", target_rotation="xyz", reward_type="sparse"
    ):
        """Create the pen environment.

        Args:
            target_position (string): target position mode passed to ManipulateEnv
            target_rotation (string): target rotation mode passed to ManipulateEnv
            reward_type ('sparse' or 'dense'): the reward type
        """
        # EzPickle receives the constructor arguments (presumably so the
        # environment can be re-created when unpickled -- confirm in gym.utils).
        utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)

        # Settings that differ from the ManipulateEnv defaults for the pen.
        pen_overrides = dict(
            randomize_initial_rotation=False,
            ignore_z_target_rotation=True,
            distance_threshold=0.05,
        )
        # Per-axis (low, high) bounds of the random target position offset.
        position_range = np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)])
        ManipulateEnv.__init__(
            self,
            MANIPULATE_PEN_XML,
            target_position,
            target_rotation,
            position_range,
            reward_type,
            **pen_overrides
        )
355