1# -*- coding: utf-8 -*-
2# Eclipse SUMO, Simulation of Urban MObility; see https://eclipse.org/sumo
3# Copyright (C) 2008-2019 German Aerospace Center (DLR) and others.
4# This program and the accompanying materials
5# are made available under the terms of the Eclipse Public License v2.0
6# which accompanies this distribution, and is available at
7# http://www.eclipse.org/legal/epl-v20.html
8# SPDX-License-Identifier: EPL-2.0
9
10# @file    connection.py
11# @author  Michael Behrisch
12# @author  Lena Kalleske
13# @author  Mario Krumnow
14# @author  Daniel Krajzewicz
15# @author  Jakob Erdmann
16# @date    2008-10-09
17# @version $Id$
18
19from __future__ import print_function
20from __future__ import absolute_import
21import socket
22import struct
23import sys
24import warnings
25
26try:
27    import traciemb
28    _embedded = True
29except ImportError:
30    _embedded = False
31from . import constants as tc
32from .exceptions import TraCIException, FatalTraCIError
33from .domain import _defaultDomains
34from .storage import Storage
35
# Text labels for the status byte of a command response; looked up by
# Connection._sendExact when it reports a failed command.
_RESULTS = {
    0x00: "OK",
    0x01: "Not implemented",
    0xFF: "Error",
}
37
38
39class Connection:
40
41    """Contains the socket, the composed message string
42    together with a list of TraCI commands which are inside.
43    """
44
45    def __init__(self, host, port, process):
46        if not _embedded:
47            if sys.platform.startswith('java'):
48                # working around jython 2.7.0 bug #2273
49                self._socket = socket.socket(
50                    socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
51            else:
52                self._socket = socket.socket()
53            self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
54            self._socket.connect((host, port))
55            self._process = process
56        self._string = bytes()
57        self._queue = []
58        self._subscriptionMapping = {}
59        for domain in _defaultDomains:
60            domain._register(self, self._subscriptionMapping)
61
62    def _packString(self, s, pre=tc.TYPE_STRING):
63        self._string += struct.pack("!Bi", pre, len(s)) + s.encode("latin1")
64
65    def _packStringList(self, l):
66        self._string += struct.pack("!Bi", tc.TYPE_STRINGLIST, len(l))
67        for s in l:
68            self._string += struct.pack("!i", len(s)) + s.encode("latin1")
69
70    def _packDoubleList(self, l):
71        self._string += struct.pack("!Bi", tc.TYPE_DOUBLELIST, len(l))
72        for x in l:
73            self._string += struct.pack("!d", x)
74
75    def _recvExact(self):
76        try:
77            result = bytes()
78            while len(result) < 4:
79                t = self._socket.recv(4 - len(result))
80                if not t:
81                    return None
82                result += t
83            length = struct.unpack("!i", result)[0] - 4
84            result = bytes()
85            while len(result) < length:
86                t = self._socket.recv(length - len(result))
87                if not t:
88                    return None
89                result += t
90            return Storage(result)
91        except socket.error:
92            return None
93
94    def _sendExact(self):
95        if _embedded:
96            result = Storage(traciemb.execute(self._string))
97        else:
98            length = struct.pack("!i", len(self._string) + 4)
99            # print("python_sendExact: '%s'" % ' '.join(map(lambda x : "%X" % ord(x), self._string)))
100            self._socket.send(length + self._string)
101            result = self._recvExact()
102        if not result:
103            self._socket.close()
104            del self._socket
105            raise FatalTraCIError("connection closed by SUMO")
106        for command in self._queue:
107            prefix = result.read("!BBB")
108            err = result.readString()
109            if prefix[2] or err:
110                self._string = bytes()
111                self._queue = []
112                raise TraCIException(err, prefix[1], _RESULTS[prefix[2]])
113            elif prefix[1] != command:
114                raise FatalTraCIError("Received answer %s for command %s." % (prefix[1],
115                                                                              command))
116            elif prefix[1] == tc.CMD_STOP:
117                length = result.read("!B")[0] - 1
118                result.read("!%sx" % length)
119        self._string = bytes()
120        self._queue = []
121        return result
122
123    def _beginMessage(self, cmdID, varID, objID, length=0):
124        self._queue.append(cmdID)
125        length += 1 + 1 + 1 + 4 + len(objID)
126        if length <= 255:
127            self._string += struct.pack("!BB", length, cmdID)
128        else:
129            self._string += struct.pack("!BiB", 0, length + 4, cmdID)
130        self._packString(objID, varID)
131
132    def _sendReadOneStringCmd(self, cmdID, varID, objID):
133        self._beginMessage(cmdID, varID, objID)
134        return self._checkResult(cmdID, varID, objID)
135
136    def _sendIntCmd(self, cmdID, varID, objID, value):
137        self._beginMessage(cmdID, varID, objID, 1 + 4)
138        self._string += struct.pack("!Bi", tc.TYPE_INTEGER, value)
139        self._sendExact()
140
141    def _sendDoubleCmd(self, cmdID, varID, objID, value):
142        self._beginMessage(cmdID, varID, objID, 1 + 8)
143        self._string += struct.pack("!Bd", tc.TYPE_DOUBLE, value)
144        self._sendExact()
145
146    def _sendByteCmd(self, cmdID, varID, objID, value):
147        self._beginMessage(cmdID, varID, objID, 1 + 1)
148        self._string += struct.pack("!BB", tc.TYPE_BYTE, value)
149        self._sendExact()
150
151    def _sendUByteCmd(self, cmdID, varID, objID, value):
152        self._beginMessage(cmdID, varID, objID, 1 + 1)
153        self._string += struct.pack("!BB", tc.TYPE_UBYTE, value)
154        self._sendExact()
155
156    def _sendStringCmd(self, cmdID, varID, objID, value):
157        self._beginMessage(cmdID, varID, objID, 1 + 4 + len(value))
158        self._packString(value)
159        self._sendExact()
160
161    def _checkResult(self, cmdID, varID, objID):
162        result = self._sendExact()
163        result.readLength()
164        response, retVarID = result.read("!BB")
165        objectID = result.readString()
166        if response - cmdID != 16 or retVarID != varID or objectID != objID:
167            raise FatalTraCIError("Received answer %s,%s,%s for command %s,%s,%s."
168                                  % (response, retVarID, objectID, cmdID, varID, objID))
169        result.read("!B")     # Return type of the variable
170        return result
171
172    def _readSubscription(self, result):
173        # to enable this you also need to set _DEBUG to True in storage.py
174        # result.printDebug()
175        result.readLength()
176        response = result.read("!B")[0]
177        isVariableSubscription = (response >= tc.RESPONSE_SUBSCRIBE_INDUCTIONLOOP_VARIABLE and
178                                  response <= tc.RESPONSE_SUBSCRIBE_PERSON_VARIABLE)
179        objectID = result.readString()
180        if not isVariableSubscription:
181            domain = result.read("!B")[0]
182        numVars = result.read("!B")[0]
183        if isVariableSubscription:
184            while numVars > 0:
185                varID = result.read("!B")[0]
186                status, _ = result.read("!BB")
187                if status:
188                    print("Error!", result.readString())
189                elif response in self._subscriptionMapping:
190                    self._subscriptionMapping[response].add(objectID, varID, result)
191                else:
192                    raise FatalTraCIError(
193                        "Cannot handle subscription response %02x for %s." % (response, objectID))
194                numVars -= 1
195        else:
196            objectNo = result.read("!i")[0]
197            for _ in range(objectNo):
198                oid = result.readString()
199                if numVars == 0:
200                    self._subscriptionMapping[response].addContext(
201                        objectID, self._subscriptionMapping[domain], oid)
202                for __ in range(numVars):
203                    varID = result.read("!B")[0]
204                    status, ___ = result.read("!BB")
205                    if status:
206                        print("Error!", result.readString())
207                    elif response in self._subscriptionMapping:
208                        self._subscriptionMapping[response].addContext(
209                            objectID, self._subscriptionMapping[domain], oid, varID, result)
210                    else:
211                        raise FatalTraCIError(
212                            "Cannot handle subscription response %02x for %s." % (response, objectID))
213        return objectID, response
214
215    def _subscribe(self, cmdID, begin, end, objID, varIDs, parameters=None):
216        self._queue.append(cmdID)
217        length = 1 + 1 + 8 + 8 + 4 + len(objID) + 1 + len(varIDs)
218        if parameters:
219            for v in varIDs:
220                if v in parameters:
221                    length += len(parameters[v])
222        if length <= 255:
223            self._string += struct.pack("!B", length)
224        else:
225            self._string += struct.pack("!Bi", 0, length + 4)
226        self._string += struct.pack("!Bddi",
227                                    cmdID, begin, end, len(objID)) + objID.encode("latin1")
228        self._string += struct.pack("!B", len(varIDs))
229        for v in varIDs:
230            self._string += struct.pack("!B", v)
231            if parameters and v in parameters:
232                self._string += parameters[v]
233        result = self._sendExact()
234        if varIDs:
235            objectID, response = self._readSubscription(result)
236            if response - cmdID != 16 or objectID != objID:
237                raise FatalTraCIError("Received answer %02x,%s for subscription command %02x,%s." % (
238                    response, objectID, cmdID, objID))
239
240    def _getSubscriptionResults(self, cmdID):
241        return self._subscriptionMapping[cmdID]
242
243    def _subscribeContext(self, cmdID, begin, end, objID, domain, dist, varIDs):
244        self._queue.append(cmdID)
245        length = 1 + 1 + 8 + 8 + 4 + len(objID) + 1 + 8 + 1 + len(varIDs)
246        if length <= 255:
247            self._string += struct.pack("!B", length)
248        else:
249            self._string += struct.pack("!Bi", 0, length + 4)
250        self._string += struct.pack("!Bddi",
251                                    cmdID, begin, end, len(objID)) + objID.encode("latin1")
252        self._string += struct.pack("!BdB", domain, dist, len(varIDs))
253        for v in varIDs:
254            self._string += struct.pack("!B", v)
255        result = self._sendExact()
256        if varIDs:
257            objectID, response = self._readSubscription(result)
258            if response - cmdID != 16 or objectID != objID:
259                raise FatalTraCIError("Received answer %02x,%s for context subscription command %02x,%s." % (
260                    response, objectID, cmdID, objID))
261
262    def _addSubscriptionFilter(self, filterType, params=None):
263        command = tc.CMD_ADD_SUBSCRIPTION_FILTER
264        self._queue.append(command)
265        if filterType in (tc.FILTER_TYPE_NONE, tc.FILTER_TYPE_NOOPPOSITE,
266                          tc.FILTER_TYPE_TURN, tc.FILTER_TYPE_LEAD_FOLLOW):
267            # filter without parameter
268            assert(params is None)
269            length = 1 + 1 + 1  # length + CMD + FILTER_ID
270            self._string += struct.pack("!BBB", length, command, filterType)
271        elif filterType in (tc.FILTER_TYPE_DOWNSTREAM_DIST, tc.FILTER_TYPE_UPSTREAM_DIST):
272            # filter with float parameter
273            assert(type(params) is float)
274            length = 1 + 1 + 1 + 1 + 8  # length + CMD + FILTER_ID + floattype + float
275            self._string += struct.pack("!BBBBd", length, command, filterType, tc.TYPE_DOUBLE, params)
276        elif filterType in (tc.FILTER_TYPE_VCLASS, tc.FILTER_TYPE_VTYPE):
277            # filter with list(string) parameter
278            length = 1 + 1 + 1 + 1 + 4  # length + CMD + FILTER_ID + TYPE_STRINGLIST + length(stringlist)
279            try:
280                for s in params:
281                    length += 4 + len(s)  # length(s) + s
282            except Exception:
283                raise TraCIException("Filter type %s requires identifier list as parameter." % filterType)
284            if length <= 255:
285                self._string += struct.pack("!BBB", length, command, filterType)
286            else:
287                length += 4  # extended msg length
288                self._string += struct.pack("!BiBB", 0, length, command, filterType)
289            self._packStringList(params)
290        elif filterType == tc.FILTER_TYPE_LANES:
291            # filter with list(byte) parameter
292            # check uniqueness of given lanes in list
293            lanes = set(list(params))
294            if len(lanes) < len(list(params)):
295                warnings.warn("Ignoring duplicate lane specification for subscription filter.")
296            length = 1 + 1 + 1 + 1 + len(lanes)  # length + CMD + FILTER_ID + length(list) as ubyte + lane-indices
297            self._string += struct.pack("!BBBB", length, command, filterType, len(lanes))
298            for i in lanes:
299                if not type(i) is int:
300                    raise TraCIException("Filter type lanes requires numeric index list as parameter.")
301                if i <= -128 or i >= 128:
302                    raise TraCIException("Filter type lanes: maximal lane index is 127.")
303                if i < 0:
304                    i += 256
305                self._string += struct.pack("!B", i)
306
307    def isEmbedded(self):
308        return _embedded
309
310    def load(self, args):
311        """
312        Load a simulation from the given arguments.
313        """
314        self._queue.append(tc.CMD_LOAD)
315        self._string += struct.pack("!BiB", 0, 1 + 4 + 1 + 1 + 4 + sum(map(len, args)) + 4 * len(args), tc.CMD_LOAD)
316        self._packStringList(args)
317        self._sendExact()
318
319    def simulationStep(self, step=0.):
320        """
321        Make a simulation step and simulate up to the given second in sim time.
322        If the given value is 0 or absent, exactly one step is performed.
323        Values smaller than or equal to the current sim time result in no action.
324        """
325        if type(step) is int and step >= 1000:
326            warnings.warn("API change now handles step as floating point seconds", stacklevel=2)
327        self._queue.append(tc.CMD_SIMSTEP)
328        self._string += struct.pack("!BBd", 1 + 1 + 8, tc.CMD_SIMSTEP, step)
329        result = self._sendExact()
330        for subscriptionResults in self._subscriptionMapping.values():
331            subscriptionResults.reset()
332        numSubs = result.readInt()
333        responses = []
334        while numSubs > 0:
335            responses.append(self._readSubscription(result))
336            numSubs -= 1
337        return responses
338
339    def getVersion(self):
340        command = tc.CMD_GETVERSION
341        self._queue.append(command)
342        self._string += struct.pack("!BB", 1 + 1, command)
343        result = self._sendExact()
344        result.readLength()
345        response = result.read("!B")[0]
346        if response != command:
347            raise FatalTraCIError(
348                "Received answer %s for command %s." % (response, command))
349        return result.readInt(), result.readString()
350
351    def setOrder(self, order):
352        self._queue.append(tc.CMD_SETORDER)
353        self._string += struct.pack("!BBi", 1 + 1 + 4, tc.CMD_SETORDER, order)
354        self._sendExact()
355
356    def close(self, wait=True):
357        if not _embedded:
358            if hasattr(self, "_socket"):
359                self._queue.append(tc.CMD_CLOSE)
360                self._string += struct.pack("!BB", 1 + 1, tc.CMD_CLOSE)
361                self._sendExact()
362                self._socket.close()
363                del self._socket
364            if wait and self._process is not None:
365                self._process.wait()
366