__author__ = 'Frank Sehnke, sehnke@in.tum.de'

from . import sensors
import threading
from pybrain.utilities import threaded
from pybrain.tools.networking.udpconnection import UDPServer
from pybrain.rl.environments.environment import Environment
# Imported from NumPy rather than SciPy: the NumPy aliases that used to live
# in the top-level ``scipy`` namespace were deprecated and later removed.
from numpy import ones, zeros, array, clip, arange, sqrt
from time import sleep


class FlexCubeEnvironment(Environment):
    """Mass-spring simulation of a flexible cube.

    The cube has 8 corner points (``self.pos``, shape (8, 3)). A spring joins
    every ordered pair of corners (``self.springM`` holds their 64 rest
    lengths). The rest lengths of the 12 cube edges are driven by the 12-entry
    action vector; the diagonal springs keep the lengths they get in
    ``reset``. Gravity acts along -y, and the plane y = 0 acts as the ground.
    The simulation state can optionally be streamed to render clients over
    UDP.
    """

    def __init__(self, render=True, realtime=True, ip="127.0.0.1", port="21560"):
        """Set up the physical model and, optionally, the UDP render server.

        :arg render: if True, create a ``UDPServer`` on ``ip``/``port``.
            ``performAction`` then pushes the state to connected clients.
        :arg realtime: if True, ``performAction`` sleeps 0.02 s after
            dispatching a render update while clients are connected.
        :arg ip: address passed to ``UDPServer``.
        :arg port: port passed to ``UDPServer``.
        """
        self.render = render
        if self.render:
            self.updateDone = True  # False while a render update is in flight
            self.updateLock = threading.Lock()
            self.server = UDPServer(ip, port)
        self.actLen = 12
        self.mySensors = sensors.Sensors(["EdgesReal"])
        # Edge, face-diagonal and space-diagonal lengths of a cube of side 20.
        self.dists = array([20.0, sqrt(2.0) * 20, sqrt(3.0) * 20])
        self.gravVect = array([0.0, -100.0, 0.0])
        self.centerOfGrav = zeros((1, 3), float)  # recomputed by reset()
        self.pos = ones((8, 3), float)
        self.vel = zeros((8, 3), float)
        self.SpringM = ones((8, 8), float)  # NOTE(review): unused in this file
        self.d = 60.0  # spring stiffness factor used in euler()
        self.dt = 0.02  # integration time step
        self.startHight = 10.0  # initial y offset of the cube's centre
        self.dumping = 0.4  # velocity damping coefficient
        # Edge actions are clipped to [fraktMin, fraktMax] * edge length.
        self.fraktMin = 0.7
        self.fraktMax = 1.3
        self.minAkt = self.dists[0] * self.fraktMin
        self.maxAkt = self.dists[0] * self.fraktMax
        self.reset()
        self.count = 0
        self.setEdges()
        self.act(array([20.0] * 12))
        self.euler()
        self.realtime = realtime
        self.step = 0

    def closeSocket(self):
        """Close the render server's UDP input socket, then block for 10 s.

        Only valid when the environment was created with ``render=True``.
        The reason for the sleep is not visible here; presumably it gives
        the OS time to release the port (TODO confirm).
        """
        self.server.UDPInSock.close()
        sleep(10)

    def setEdges(self):
        """Build ``self.edges``: the 12 cube edges as (corner, corner) pairs.

        Corner ``c`` has grid coordinates ``(c // 4, (c // 2) % 2, c % 2)``.
        This matches the ``i * 4 + j * 2 + k`` numbering that ``reset`` uses
        for ``self.pos``. An edge joins two corners that differ in exactly one
        coordinate, and it is stored once with the lower index first. Rows
        are ordered by first corner and then by second corner; ``act`` relies
        on this order to map action entry ``n`` to edge ``n``.
        """
        self.edges = zeros((12, 2), int)
        count = 0
        for c1 in range(8):
            i, j, k = c1 // 4, (c1 // 2) % 2, c1 % 2
            for c2 in range(8):
                i2, j2, k2 = c2 // 4, (c2 // 2) % 2, c2 % 2
                manhattan = abs(i - i2) + abs(j - j2) + abs(k - k2)
                if manhattan == 1 and i <= i2 and j <= j2 and k <= k2:
                    self.edges[count] = [c1, c2]
                    count += 1

    def _updateGeometry(self):
        """Recompute pairwise difference vectors and distances from ``self.pos``.

        Row ``i * 8 + j`` of ``self.difM`` (64, 3) is ``pos[i] - pos[j]``.
        The same row of ``self.distM`` (64, 1) holds its Euclidean length.
        """
        idx0 = arange(8).repeat(8)
        idx1 = array(list(range(8)) * 8)
        self.difM = self.pos[idx0, :] - self.pos[idx1, :]
        self.distM = sqrt((self.difM ** 2).sum(axis=1)).reshape(64, 1)

    def reset(self):
        """Put the cube back into its initial axis-aligned pose and refresh sensors.

        All spring rest lengths are set to the current corner distances, so
        the cube starts unstressed. The velocity, step counter and centre of
        gravity are reset as well. Connected render clients are sent a reset
        signal.
        """
        self.action = ones((1, 12), float) * self.dists[0]

        side = self.dists[0]
        half = side / 2.0
        for i in range(2):
            for j in range(2):
                for k in range(2):
                    self.pos[i * 4 + j * 2 + k] = [
                        i * side - half,
                        j * side - half + self.startHight,
                        k * side - half,
                    ]
        self.vel = zeros((8, 3), float)

        self._updateGeometry()
        self.springM = self.distM.copy()  # rest lengths = initial distances
        # Previously left stale from the last episode; keep it consistent
        # with the freshly reset positions.
        self.centerOfGrav = self.pos.sum(axis=0) / 8.0
        self.step = 0
        self.mySensors.updateSensor(self.pos, self.vel, self.distM, self.centerOfGrav, self.step, self.action)
        if self.render:
            if self.server.clients > 0:
                # If there are clients send them reset signal
                self.server.send(["r", "r"])

    def performAction(self, action):
        """Clip ``action`` to the allowed edge lengths, apply it and advance one step.

        :arg action: 12 target rest lengths, one per edge in ``self.edges``.
        """
        action = self.normAct(action)
        self.action = action.copy()
        self.act(action)
        self.euler()
        self.step += 1

        if self.render:
            if self.updateDone:
                self.updateRenderer()
                if self.server.clients > 0 and self.realtime:
                    sleep(0.02)

    def getSensors(self):
        """Refresh the sensor object from the current state and return a copy of its readings."""
        self.mySensors.updateSensor(self.pos, self.vel, self.distM, self.centerOfGrav, self.step, self.action)
        return self.mySensors.getSensor()[:]

    def normAct(self, s):
        """Clip ``s`` element-wise to ``[self.minAkt, self.maxAkt]``."""
        return clip(s, self.minAkt, self.maxAkt)

    def act(self, a):
        """Set the rest length of each edge spring (both directions) from ``a``.

        :arg a: sequence of 12 lengths; ``a[n]`` drives ``self.edges[n]``.
        """
        for n, (p, q) in enumerate(self.edges):
            self.springM[p * 8 + q] = a[n]
            self.springM[q * 8 + p] = a[n]

    def euler(self):
        """Advance the mass-spring system by one explicit Euler step of ``self.dt``."""
        self.count += 1
        # Inner forces: deviation of each spring from its rest length.
        distM = self.distM.copy()
        disM = (self.springM - distM).reshape(64, 1)
        distM = distM + 0.0000000001  # avoid 0/0 for the zero-length self pairs

        # Unit spring direction times the deformation force, as a velocity change.
        dVel = self.difM / distM
        dVel *= disM * self.d * self.dt
        # Rows i * 8 .. i * 8 + 7 are the contributions of all springs (i, j)
        # to point i; sum them per point.
        self.vel += dVel.reshape(8, 8, 3).sum(axis=1)

        # Gravity
        self.vel += self.gravVect * self.dt

        # Damping
        self.vel -= self.vel * self.dumping * self.dt

        # Velocities to positions
        self.pos += self.dt * self.vel

        # Ground collision at y = 0: clamp the point to the ground, reflect
        # its y velocity and zero its x/z velocity (friction).
        below = self.pos[:, 1] < 0.0
        self.pos[below, 1] = 0.0
        self.vel[below] *= array([0.0, -1.0, 0.0])
        self.centerOfGrav = self.pos.sum(axis=0) / 8.0

        # Distances of new state
        self._updateGeometry()

    @threaded()
    def updateRenderer(self):
        """Send the current positions and centre of gravity to connected render clients.

        Decorated with pybrain's ``threaded()``, presumably so that it runs
        off the simulation thread (TODO confirm). Calls that find
        ``updateLock`` already held return immediately. The lock is always
        released, and ``updateDone`` is always set back to True, even if the
        server raises.
        """
        self.updateDone = False
        if not self.updateLock.acquire(False):
            return
        try:
            # Listen for clients
            self.server.listen()
            if self.server.clients > 0:
                # If there are clients send them the new data
                self.server.send(repr([self.pos, self.centerOfGrav]))
            sleep(0.02)
        finally:
            self.updateLock.release()
            self.updateDone = True