import os
import numpy as np

from gym import utils, error
from gym.envs.robotics import rotations, hand_env
from gym.envs.robotics.utils import robot_get_obs

try:
    import mujoco_py
except ImportError as e:
    raise error.DependencyNotInstalled(
        "{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(
            e
        )
    )


def quat_from_angle_and_axis(angle, axis):
    """Build a unit quaternion for a rotation of `angle` about `axis`.

    Args:
        angle (float): rotation angle in radians.
        axis (np.ndarray of shape (3,)): rotation axis; it does not need to be
            normalized. The caller's array is left unmodified.

    Returns:
        np.ndarray of shape (4,): normalized quaternion, scalar component first.
    """
    assert axis.shape == (3,)
    # Normalize into a new array: an in-place `/=` would silently rescale the
    # caller's array (and fails outright for integer-typed axes).
    axis = axis / np.linalg.norm(axis)
    quat = np.concatenate([[np.cos(angle / 2.0)], np.sin(angle / 2.0) * axis])
    quat /= np.linalg.norm(quat)
    return quat


# Ensure we get the path separator correct on windows
MANIPULATE_BLOCK_XML = os.path.join("hand", "manipulate_block.xml")
MANIPULATE_EGG_XML = os.path.join("hand", "manipulate_egg.xml")
MANIPULATE_PEN_XML = os.path.join("hand", "manipulate_pen.xml")


class ManipulateEnv(hand_env.HandEnv):
    """Goal-based hand environment in which an object must reach a target pose.

    A goal is a 7-vector: 3 position components followed by a 4-component
    quaternion, matching the qpos of the "object:joint" joint.
    """

    def __init__(
        self,
        model_path,
        target_position,
        target_rotation,
        target_position_range,
        reward_type,
        initial_qpos=None,
        randomize_initial_position=True,
        randomize_initial_rotation=True,
        distance_threshold=0.01,
        rotation_threshold=0.1,
        n_substeps=20,
        relative_control=False,
        ignore_z_target_rotation=False,
    ):
        """Initializes a new Hand manipulation environment.

        Args:
            model_path (string): path to the environments XML file
            target_position (string): the type of target position:
                - ignore: target position is fully ignored, i.e. the object can be positioned arbitrarily
                - fixed: target position is set to the initial position of the object
                - random: target position is fully randomized according to target_position_range
            target_rotation (string): the type of target rotation:
                - ignore: target rotation is fully ignored, i.e. the object can be rotated arbitrarily
                - fixed: target rotation is set to the initial rotation of the object
                - xyz: fully randomized target rotation around the X, Y and Z axis
                - z: fully randomized target rotation around the Z axis
                - parallel: fully randomized target rotation around Z and axis-aligned rotation around X, Y
            target_position_range (np.array of shape (3, 2)): range of the target_position randomization
            reward_type ('sparse' or 'dense'): the reward type, i.e. sparse or dense
            initial_qpos (dict): a dictionary of joint names and values that define the initial configuration
            randomize_initial_position (boolean): whether or not to randomize the initial position of the object
            randomize_initial_rotation (boolean): whether or not to randomize the initial rotation of the object
            distance_threshold (float, in meters): the threshold after which the position of a goal is considered achieved
            rotation_threshold (float, in radians): the threshold after which the rotation of a goal is considered achieved
            n_substeps (int): number of substeps the simulation runs on every call to step
            relative_control (boolean): whether or not the hand is actuated in absolute joint positions or relative to the current state
            ignore_z_target_rotation (boolean): whether or not the Z axis of the target rotation is ignored
        """
        self.target_position = target_position
        self.target_rotation = target_rotation
        self.target_position_range = target_position_range
        self.parallel_quats = [
            rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
        ]
        self.randomize_initial_rotation = randomize_initial_rotation
        self.randomize_initial_position = randomize_initial_position
        self.distance_threshold = distance_threshold
        self.rotation_threshold = rotation_threshold
        self.reward_type = reward_type
        self.ignore_z_target_rotation = ignore_z_target_rotation

        assert self.target_position in ["ignore", "fixed", "random"]
        assert self.target_rotation in ["ignore", "fixed", "xyz", "z", "parallel"]
        initial_qpos = initial_qpos or {}

        # All attributes above are set before the base initializer runs.
        # NOTE(review): presumably the base class calls back into hooks such as
        # _env_setup / _sample_goal during construction -- confirm in HandEnv.
        hand_env.HandEnv.__init__(
            self,
            model_path,
            n_substeps=n_substeps,
            initial_qpos=initial_qpos,
            relative_control=relative_control,
        )

    def _get_achieved_goal(self):
        """Return the object's current 7-element qpos (position + quaternion)."""
        # Object position and rotation.
        object_qpos = self.sim.data.get_joint_qpos("object:joint")
        assert object_qpos.shape == (7,)
        return object_qpos

    def _goal_distance(self, goal_a, goal_b):
        """Compute positional and rotational distance between two goals.

        Args:
            goal_a (np.ndarray of shape (..., 7)): first goal(s).
            goal_b (np.ndarray of shape (..., 7)): second goal(s), same shape.

        Returns:
            tuple: (d_pos, d_rot), each of shape goal_a.shape[:-1]. d_pos is the
            Euclidean distance between the positions (zero when the target
            position is ignored); d_rot is the angle in radians between the
            rotations (zero when the target rotation is ignored).
        """
        assert goal_a.shape == goal_b.shape
        assert goal_a.shape[-1] == 7

        d_pos = np.zeros_like(goal_a[..., 0])
        d_rot = np.zeros_like(goal_b[..., 0])
        if self.target_position != "ignore":
            delta_pos = goal_a[..., :3] - goal_b[..., :3]
            d_pos = np.linalg.norm(delta_pos, axis=-1)

        if self.target_rotation != "ignore":
            quat_a, quat_b = goal_a[..., 3:], goal_b[..., 3:]

            if self.ignore_z_target_rotation:
                # Special case: We want to ignore the Z component of the rotation.
                # This code here assumes Euler angles with xyz convention. We first transform
                # to euler, then set the Z component to be equal between the two, and finally
                # transform back into quaternions.
                euler_a = rotations.quat2euler(quat_a)
                euler_b = rotations.quat2euler(quat_b)
                # Index the last axis so that batched goals are handled too; a
                # plain `[2]` would address the third batch entry instead of
                # the Z angle. For a single goal both forms are identical.
                euler_a[..., 2] = euler_b[..., 2]
                quat_a = rotations.euler2quat(euler_a)

            # Subtract quaternions and extract angle between them.
            quat_diff = rotations.quat_mul(quat_a, rotations.quat_conjugate(quat_b))
            # Clip guards arccos against values marginally outside [-1, 1].
            angle_diff = 2 * np.arccos(np.clip(quat_diff[..., 0], -1.0, 1.0))
            d_rot = angle_diff
        assert d_pos.shape == d_rot.shape
        return d_pos, d_rot

    # GoalEnv methods
    # ----------------------------

    def compute_reward(self, achieved_goal, goal, info):
        """Compute the reward for reaching `achieved_goal` when aiming for `goal`.

        With the "sparse" reward type the result is 0.0 on success and -1.0
        otherwise. For any other reward type it is the negative weighted sum
        of positional and rotational distance. `info` is unused.
        """
        if self.reward_type == "sparse":
            success = self._is_success(achieved_goal, goal).astype(np.float32)
            return success - 1.0
        else:
            d_pos, d_rot = self._goal_distance(achieved_goal, goal)
            # We weigh the difference in position to avoid that `d_pos` (in meters) is completely
            # dominated by `d_rot` (in radians).
            return -(10.0 * d_pos + d_rot)

    # RobotEnv methods
    # ----------------------------

    def _is_success(self, achieved_goal, desired_goal):
        """Return 1.0 where both distance thresholds are met, else 0.0 (float32)."""
        d_pos, d_rot = self._goal_distance(achieved_goal, desired_goal)
        achieved_pos = (d_pos < self.distance_threshold).astype(np.float32)
        achieved_rot = (d_rot < self.rotation_threshold).astype(np.float32)
        # The product acts as a logical AND on the 0/1 indicators.
        achieved_both = achieved_pos * achieved_rot
        return achieved_both

    def _env_setup(self, initial_qpos):
        """Write the given joint positions into the simulation and forward it."""
        for name, value in initial_qpos.items():
            self.sim.data.set_joint_qpos(name, value)
        self.sim.forward()

    def _reset_sim(self):
        """Restore the initial state and optionally randomize the object pose.

        Returns:
            bool: False if the simulation raised a MujocoException while
            settling; otherwise whether the "object:center" site is still
            above a height of 0.04 after settling.
        """
        self.sim.set_state(self.initial_state)
        self.sim.forward()

        # Work on a copy so the in-place updates below do not touch the
        # array returned by the simulator.
        initial_qpos = self.sim.data.get_joint_qpos("object:joint").copy()
        assert initial_qpos.shape == (7,)
        initial_pos, initial_quat = initial_qpos[:3], initial_qpos[3:]
        assert initial_pos.shape == (3,)
        assert initial_quat.shape == (4,)

        # Randomization initial rotation.
        if self.randomize_initial_rotation:
            if self.target_rotation == "z":
                angle = self.np_random.uniform(-np.pi, np.pi)
                axis = np.array([0.0, 0.0, 1.0])
                offset_quat = quat_from_angle_and_axis(angle, axis)
                initial_quat = rotations.quat_mul(initial_quat, offset_quat)
            elif self.target_rotation == "parallel":
                angle = self.np_random.uniform(-np.pi, np.pi)
                axis = np.array([0.0, 0.0, 1.0])
                z_quat = quat_from_angle_and_axis(angle, axis)
                parallel_quat = self.parallel_quats[
                    self.np_random.randint(len(self.parallel_quats))
                ]
                offset_quat = rotations.quat_mul(z_quat, parallel_quat)
                initial_quat = rotations.quat_mul(initial_quat, offset_quat)
            elif self.target_rotation in ["xyz", "ignore"]:
                angle = self.np_random.uniform(-np.pi, np.pi)
                axis = self.np_random.uniform(-1.0, 1.0, size=3)
                offset_quat = quat_from_angle_and_axis(angle, axis)
                initial_quat = rotations.quat_mul(initial_quat, offset_quat)
            elif self.target_rotation == "fixed":
                pass
            else:
                raise error.Error(
                    'Unknown target_rotation option "{}".'.format(self.target_rotation)
                )

        # Randomize initial position.
        if self.randomize_initial_position:
            if self.target_position != "fixed":
                initial_pos += self.np_random.normal(size=3, scale=0.005)

        initial_quat /= np.linalg.norm(initial_quat)
        initial_qpos = np.concatenate([initial_pos, initial_quat])
        self.sim.data.set_joint_qpos("object:joint", initial_qpos)

        def is_on_palm():
            # Forward the simulation so site positions reflect the current qpos.
            self.sim.forward()
            cube_middle_idx = self.sim.model.site_name2id("object:center")
            cube_middle_pos = self.sim.data.site_xpos[cube_middle_idx]
            # Height check on the Z coordinate of the object's center site.
            return cube_middle_pos[2] > 0.04

        # Run the simulation for a bunch of timesteps to let everything settle in.
        for _ in range(10):
            self._set_action(np.zeros(20))
            try:
                self.sim.step()
            except mujoco_py.MujocoException:
                return False
        return is_on_palm()

    def _sample_goal(self):
        """Sample a new goal according to the configured target modes.

        Returns:
            np.ndarray of shape (7,): target position followed by a normalized
            target quaternion.

        Raises:
            error.Error: if target_position or target_rotation holds an
                unknown option.
        """
        # Select a goal for the object position.
        target_pos = None
        if self.target_position == "random":
            assert self.target_position_range.shape == (3, 2)
            offset = self.np_random.uniform(
                self.target_position_range[:, 0], self.target_position_range[:, 1]
            )
            assert offset.shape == (3,)
            target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
        elif self.target_position in ["ignore", "fixed"]:
            target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
        else:
            raise error.Error(
                'Unknown target_position option "{}".'.format(self.target_position)
            )
        assert target_pos is not None
        assert target_pos.shape == (3,)

        # Select a goal for the object rotation.
        target_quat = None
        if self.target_rotation == "z":
            angle = self.np_random.uniform(-np.pi, np.pi)
            axis = np.array([0.0, 0.0, 1.0])
            target_quat = quat_from_angle_and_axis(angle, axis)
        elif self.target_rotation == "parallel":
            angle = self.np_random.uniform(-np.pi, np.pi)
            axis = np.array([0.0, 0.0, 1.0])
            target_quat = quat_from_angle_and_axis(angle, axis)
            parallel_quat = self.parallel_quats[
                self.np_random.randint(len(self.parallel_quats))
            ]
            target_quat = rotations.quat_mul(target_quat, parallel_quat)
        elif self.target_rotation == "xyz":
            angle = self.np_random.uniform(-np.pi, np.pi)
            axis = self.np_random.uniform(-1.0, 1.0, size=3)
            target_quat = quat_from_angle_and_axis(angle, axis)
        elif self.target_rotation in ["ignore", "fixed"]:
            # The joint qpos has 7 entries (position + quaternion); only the
            # last four are the rotation. Copy them so the in-place
            # normalization below cannot write into the array returned by the
            # simulator (in case it is a view of the simulation state).
            target_quat = self.sim.data.get_joint_qpos("object:joint")[3:].copy()
        else:
            raise error.Error(
                'Unknown target_rotation option "{}".'.format(self.target_rotation)
            )
        assert target_quat is not None
        assert target_quat.shape == (4,)

        target_quat /= np.linalg.norm(target_quat)  # normalized quaternion
        goal = np.concatenate([target_pos, target_quat])
        return goal

    def _render_callback(self):
        """Place the "target:joint" body at the current goal before rendering."""
        # Assign current state to target object but offset a bit so that the actual object
        # is not obscured.
        goal = self.goal.copy()
        assert goal.shape == (7,)
        if self.target_position == "ignore":
            # Move the object to the side since we do not care about its position.
            goal[0] += 0.15
        self.sim.data.set_joint_qpos("target:joint", goal)
        self.sim.data.set_joint_qvel("target:joint", np.zeros(6))

        if "object_hidden" in self.sim.model.geom_names:
            hidden_id = self.sim.model.geom_name2id("object_hidden")
            # Set the alpha channel of the geom to fully opaque.
            self.sim.model.geom_rgba[hidden_id, 3] = 1.0
        self.sim.forward()

    def _get_obs(self):
        """Assemble the goal-style observation dictionary.

        Returns:
            dict: "observation" (robot qpos, robot qvel, object qvel and the
            achieved goal concatenated), "achieved_goal" and "desired_goal".
            All values are copies.
        """
        robot_qpos, robot_qvel = robot_get_obs(self.sim)
        object_qvel = self.sim.data.get_joint_qvel("object:joint")
        achieved_goal = (
            self._get_achieved_goal().ravel()
        )  # this contains the object position + rotation
        observation = np.concatenate(
            [robot_qpos, robot_qvel, object_qvel, achieved_goal]
        )
        return {
            "observation": observation.copy(),
            "achieved_goal": achieved_goal.copy(),
            "desired_goal": self.goal.ravel().copy(),
        }


class HandBlockEnv(ManipulateEnv, utils.EzPickle):
    """Manipulation environment using the block model XML."""

    def __init__(
        self, target_position="random", target_rotation="xyz", reward_type="sparse"
    ):
        utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
        ManipulateEnv.__init__(
            self,
            model_path=MANIPULATE_BLOCK_XML,
            target_position=target_position,
            target_rotation=target_rotation,
            target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
            reward_type=reward_type,
        )


class HandEggEnv(ManipulateEnv, utils.EzPickle):
    """Manipulation environment using the egg model XML."""

    def __init__(
        self, target_position="random", target_rotation="xyz", reward_type="sparse"
    ):
        utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
        ManipulateEnv.__init__(
            self,
            model_path=MANIPULATE_EGG_XML,
            target_position=target_position,
            target_rotation=target_rotation,
            target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
            reward_type=reward_type,
        )


class HandPenEnv(ManipulateEnv, utils.EzPickle):
    """Manipulation environment using the pen model XML.

    Unlike the block and egg variants, the initial rotation is not randomized,
    the Z component of the target rotation is ignored, and the distance
    threshold is 0.05 instead of the default.
    """

    def __init__(
        self, target_position="random", target_rotation="xyz", reward_type="sparse"
    ):
        utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
        ManipulateEnv.__init__(
            self,
            model_path=MANIPULATE_PEN_XML,
            target_position=target_position,
            target_rotation=target_rotation,
            target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
            randomize_initial_rotation=False,
            reward_type=reward_type,
            ignore_z_target_rotation=True,
            distance_threshold=0.05,
        )