1__author__ = 'Frank Sehnke, sehnke@in.tum.de'
2
3from . import sensors
4import threading
5from pybrain.utilities import threaded
6from pybrain.tools.networking.udpconnection import UDPServer
7from pybrain.rl.environments.environment import Environment
8from scipy import ones, zeros, array, clip, arange, sqrt
9from time import sleep
10
class FlexCubeEnvironment(Environment):
    """Mass-spring simulation of a flexible cube whose 12 edges are actuated.

    The cube is made of 8 mass points. Every pair of points is coupled by a
    spring. The rest lengths of the 12 cube-edge springs are set by the
    action; all other rest lengths keep the values measured in reset().
    Gravity acts along -y and the ground plane is y = 0.

    When ``render`` is true, the state is streamed to the clients of a
    UDPServer bound to ``ip``/``port``.
    """

    def __init__(self, render=True, realtime=True, ip="127.0.0.1", port="21560"):
        # NOTE(review): the Environment base-class __init__ is not invoked here.
        self.render = render
        if self.render:
            # updateDone / updateLock coordinate the background renderer update.
            self.updateDone = True
            self.updateLock = threading.Lock()
            self.server = UDPServer(ip, port)
        self.actLen = 12  # one action value per cube edge
        self.mySensors = sensors.Sensors(["EdgesReal"])
        # Edge, face-diagonal and space-diagonal lengths of a cube with edge 20.
        self.dists = array([20.0, sqrt(2.0) * 20, sqrt(3.0) * 20])
        self.gravVect = array([0.0, -100.0, 0.0])
        self.centerOfGrav = zeros((1, 3), float)
        self.pos = ones((8, 3), float)
        self.vel = zeros((8, 3), float)
        # NOTE(review): SpringM (capital S) is never read in this class; the live
        # rest-length array is self.springM, created in reset(). Kept for
        # backward compatibility with any external readers.
        self.SpringM = ones((8, 8), float)
        self.d = 60.0  # spring stiffness (force per unit of length deviation)
        self.dt = 0.02  # Euler integration time step
        self.startHight = 10.0  # height of the cube centre after reset()
        self.dumping = 0.4  # velocity damping factor per unit time
        # Edge rest lengths are clipped to [fraktMin, fraktMax] * edge length.
        self.fraktMin = 0.7
        self.fraktMax = 1.3
        self.minAkt = self.dists[0] * self.fraktMin
        self.maxAkt = self.dists[0] * self.fraktMax
        self.reset()
        self.count = 0
        self.setEdges()
        # Start with every edge spring at its natural length.
        self.act(array([self.dists[0]] * self.actLen))
        self.euler()
        self.realtime = realtime
        self.step = 0

    def closeSocket(self):
        """Close the renderer's UDP input socket, then block for 10 seconds.

        Requires the environment to have been created with ``render=True``
        (otherwise ``self.server`` does not exist). The 10 s wait is presumably
        meant to let the socket/clients shut down -- TODO confirm.
        """
        self.server.UDPInSock.close()
        sleep(10)

    def setEdges(self):
        """Build ``self.edges``, the (12, 2) int array of cube-edge point pairs.

        Corner (i, j, k) with i, j, k in {0, 1} has point index i*4 + j*2 + k.
        An edge joins two corners that differ in exactly one coordinate. Each
        edge is listed once, as [lower, upper], where ``upper`` has a 1 in
        the coordinate where ``lower`` has a 0.
        """
        corners = [(i, j, k) for i in range(2) for j in range(2) for k in range(2)]
        pairs = [
            [lower, upper]
            for lower, cLow in enumerate(corners)
            for upper, cUp in enumerate(corners)
            # All coordinate steps non-negative and exactly one of them equal to 1.
            if all(a <= b for a, b in zip(cLow, cUp))
            and sum(b - a for a, b in zip(cLow, cUp)) == 1
        ]
        self.edges = array(pairs, int)

    def _updateDistances(self):
        """Recompute the pairwise geometry from ``self.pos``.

        Row i*8 + j of ``self.difM`` (shape (64, 3)) is pos[i] - pos[j]. The
        same row of ``self.distM`` (shape (64, 1)) is that vector's length.
        """
        idx0 = arange(8).repeat(8)          # 0,0,...,0,1,1,...,7
        idx1 = array(list(range(8)) * 8)    # 0,1,...,7,0,1,...,7
        self.difM = self.pos[idx0, :] - self.pos[idx1, :]
        self.distM = sqrt((self.difM ** 2).sum(axis=1)).reshape(64, 1)

    def reset(self):
        """Put the cube back in its axis-aligned start pose, at rest.

        Sets all spring rest lengths (``self.springM``) to the current
        pairwise distances, recomputes the centre of gravity and refreshes the
        sensors. If rendering and clients are connected, it also sends them
        the reset signal ``["r", "r"]``.
        """
        # NOTE(review): this action has shape (1, 12), while performAction
        # stores a 1-D action -- confirm the sensors accept both shapes.
        self.action = ones((1, self.actLen), float) * self.dists[0]

        edge = self.dists[0]
        half = edge / 2.0
        for i in range(2):
            for j in range(2):
                for k in range(2):
                    self.pos[i * 4 + j * 2 + k] = [i * edge - half,
                                                   j * edge - half + self.startHight,
                                                   k * edge - half]
        self.vel = zeros((8, 3), float)

        self._updateDistances()
        # The start pose is the relaxed state: every spring is at rest length.
        self.springM = self.distM.copy()
        # Recompute the centre of gravity so the sensors do not see the value
        # left over from the previous episode.
        self.centerOfGrav = self.pos.sum(axis=0) / 8.0
        self.step = 0
        self.mySensors.updateSensor(self.pos, self.vel, self.distM, self.centerOfGrav, self.step, self.action)
        if self.render:
            if self.server.clients > 0:
                # If there are clients send them reset signal
                self.server.send(["r", "r"])

    def performAction(self, action):
        """Apply ``action`` as edge rest lengths and advance the simulation one step.

        The action is clipped to [minAkt, maxAkt] and a copy is stored in
        ``self.action``. When rendering, a renderer update is started only if
        the previous one has finished. In realtime mode with connected
        clients, the call additionally sleeps 0.02 s.
        """
        action = self.normAct(action)
        self.action = action.copy()
        self.act(action)
        self.euler()
        self.step += 1

        if self.render:
            if self.updateDone:
                self.updateRenderer()
                if self.server.clients > 0 and self.realtime:
                    sleep(0.02)

    def getSensors(self):
        """Refresh the sensors from the current state and return a slice of their values."""
        self.mySensors.updateSensor(self.pos, self.vel, self.distM, self.centerOfGrav, self.step, self.action)
        return self.mySensors.getSensor()[:]

    def normAct(self, s):
        """Clip the requested edge lengths ``s`` to [minAkt, maxAkt]."""
        return clip(s, self.minAkt, self.maxAkt)

    def act(self, a):
        """Set the rest length of edge n (as listed in ``self.edges``) to ``a[n]``.

        Both directions of each edge (i->j and j->i) are updated so that the
        rest-length table stays symmetric.
        """
        for count, (p0, p1) in enumerate(self.edges):
            self.springM[p0 * 8 + p1] = a[count]
            self.springM[p1 * 8 + p0] = a[count]

    def euler(self):
        """Advance the mass-spring system by one explicit Euler step of size ``dt``.

        Spring forces are proportional to (rest length - current length) along
        the unit vector between the two points. Forces are added directly to
        the velocities, i.e. unit masses. Gravity and damping are then applied.
        Finally the positions are integrated and ground collisions resolved,
        and the centre of gravity and pairwise distances are recomputed.
        """
        self.count += 1
        # Inner forces: difference between wanted spring lengths and current ones.
        distM = self.distM.copy()
        disM = (self.springM - distM).reshape(64, 1)

        distM = distM + 0.0000000001  # hack to prevent divs by 0

        # Unit spring vectors scaled by the deformation force, times dt.
        springVel = self.difM / distM
        springVel *= disM * self.d * self.dt
        # Row i*8 + j is the contribution of spring (i, j) to point i; sum over j.
        self.vel += springVel.reshape(8, 8, 3).sum(axis=1)

        # Gravity
        self.vel += self.gravVect * self.dt

        # Damping
        self.vel -= self.vel * self.dumping * self.dt

        # Velocities to positions
        self.pos += self.dt * self.vel

        # Collisions and friction: points below the ground are put back on it.
        # Their vertical velocity is reflected and horizontal velocity zeroed.
        below = self.pos[:, 1] < 0.0
        self.pos[below, 1] = 0.0
        self.vel[below] *= [0.0, -1.0, 0.0]
        self.centerOfGrav = self.pos.sum(axis=0) / 8.0

        # Distances of the new state
        self._updateDistances()

    @threaded()
    def updateRenderer(self):
        """Push the current positions and centre of gravity to connected clients.

        Decorated with ``threaded()``; presumably this runs in a background
        thread -- TODO confirm against pybrain.utilities. The lock ensures only
        one update runs at a time. It is always released, and ``updateDone``
        is restored, even if the server call raises.
        """
        self.updateDone = False
        if not self.updateLock.acquire(False):
            # Another update holds the lock; its finally-block resets updateDone.
            return
        try:
            # Listen for clients
            self.server.listen()
            if self.server.clients > 0:
                # If there are clients send them the new data.
                # NOTE(review): self.pos may be mutated concurrently by euler()
                # on the main thread; the transmitted snapshot is unsynchronised.
                self.server.send(repr([self.pos, self.centerOfGrav]))
            sleep(0.02)
        finally:
            self.updateLock.release()
            self.updateDone = True
167
168