1# -*- coding: utf-8 -*-
2
3# Copyright (c) 2011 - 2021 Detlev Offenbach <detlev@die-offenbachs.de>
4#
5
6"""
7Module implementing an interface to the Mercurial command server.
8"""
9
10import struct
11import io
12
13from PyQt5.QtCore import (
14    QProcess, QObject, QByteArray, QCoreApplication, QThread
15)
16from PyQt5.QtWidgets import QDialog
17
18from .HgUtilities import prepareProcess, getHgExecutable
19
20
class HgClient(QObject):
    """
    Class implementing the Mercurial command server interface.
    """
    InputFormat = ">I"
    OutputFormat = ">cI"
    OutputFormatSize = struct.calcsize(OutputFormat)
    ReturnFormat = ">i"

    Channels = (b"I", b"L", b"o", b"e", b"r", b"d")

    def __init__(self, repoPath, encoding, vcs, parent=None):
        """
        Constructor

        @param repoPath root directory of the repository
        @type str
        @param encoding encoding to be used by the command server. If it is
            empty, the encoding reported by the VCS object is used.
        @type str
        @param vcs reference to the VCS object
        @type Hg
        @param parent reference to the parent object
        @type QObject
        """
        super().__init__(parent)

        self.__server = None
        self.__started = False
        self.__version = None
        self.__capabilities = set()
        self.__encoding = vcs.getEncoding()
        self.__cancel = False
        self.__commandRunning = False
        self.__repoPath = repoPath

        # generate command line and environment
        self.__serverArgs = vcs.initCommand("serve")
        self.__serverArgs.append("--cmdserver")
        self.__serverArgs.append("pipe")
        self.__serverArgs.append("--config")
        self.__serverArgs.append("ui.interactive=True")
        if repoPath:
            self.__serverArgs.append("--repository")
            self.__serverArgs.append(repoPath)

        if encoding:
            self.__encoding = encoding
            if "--encoding" in self.__serverArgs:
                # use the defined encoding via the environment
                index = self.__serverArgs.index("--encoding")
                del self.__serverArgs[index:index + 2]

    def startServer(self):
        """
        Public method to start the command server.

        A server process created earlier (e.g. one that died or whose
        'hello' message could not be read) is stopped first, so that this
        client never owns more than one server process.

        @return tuple of flag indicating a successful start and an error
            message in case of failure
        @rtype tuple of (bool, str)
        """
        if self.__server is not None:
            # dispose of a stale process instead of orphaning it
            self.stopServer()

        self.__server = QProcess()
        if self.__repoPath:
            self.__server.setWorkingDirectory(self.__repoPath)

        # connect signals
        self.__server.finished.connect(self.__serverFinished)

        prepareProcess(self.__server, self.__encoding)

        exe = getHgExecutable()
        self.__server.start(exe, self.__serverArgs)
        serverStarted = self.__server.waitForStarted(15000)
        if not serverStarted:
            return False, self.tr(
                'The process {0} could not be started. '
                'Ensure, that it is in the search path.'
            ).format(exe)

        self.__server.setReadChannel(QProcess.ProcessChannel.StandardOutput)
        ok, error = self.__readHello()
        self.__started = ok
        return ok, error

    def stopServer(self):
        """
        Public method to stop the command server.

        The write channel is closed first. If the process does not exit, it
        is terminated and, as a last resort, killed.
        """
        if self.__server is not None:
            self.__server.closeWriteChannel()
            res = self.__server.waitForFinished(5000)
            if not res:
                self.__server.terminate()
                res = self.__server.waitForFinished(3000)
                if not res:
                    self.__server.kill()
                    self.__server.waitForFinished(3000)

            self.__started = False
            self.__server.deleteLater()
            self.__server = None

    def restartServer(self):
        """
        Public method to restart the command server.

        @return tuple of flag indicating a successful start and an error
            message in case of failure
        @rtype tuple of (bool, str)
        """
        self.stopServer()
        return self.startServer()

    def __readHello(self):
        """
        Private method to read the hello message sent by the command server.

        The message must arrive on the 'o' channel. Its first line must be
        'capabilities: ...' (including 'runcommand') and its second line
        'encoding: ...'. The reported encoding is adopted for decoding
        subsequent output.

        @return tuple of flag indicating success and an error message in case
            of failure
        @rtype tuple of (bool, str)
        """
        ch, msg = self.__readChannel()
        if not ch:
            return False, self.tr("Did not receive the 'hello' message.")
        elif ch != "o":
            return False, self.tr("Received data on unexpected channel.")

        lines = msg.split("\n")
        capabilitiesLine = lines[0]
        # a truncated hello message must produce an error, not an IndexError
        encodingLine = lines[1] if len(lines) > 1 else ""

        if not capabilitiesLine.startswith("capabilities: "):
            return False, self.tr(
                "Bad 'hello' message, expected 'capabilities: '"
                " but got '{0}'.").format(capabilitiesLine)
        capabilities = capabilitiesLine[len('capabilities: '):]
        if not capabilities:
            return False, self.tr("'capabilities' message did not contain"
                                  " any capability.")

        self.__capabilities = set(capabilities.split())
        if "runcommand" not in self.__capabilities:
            return False, self.tr(
                "'capabilities' did not contain 'runcommand'.")

        if not encodingLine.startswith("encoding: "):
            return False, self.tr(
                "Bad 'hello' message, expected 'encoding: '"
                " but got '{0}'.").format(encodingLine)
        encoding = encodingLine[len('encoding: '):]
        if not encoding:
            return False, self.tr("'encoding' message did not contain"
                                  " any encoding.")
        self.__encoding = encoding

        return True, ""

    def __serverFinished(self, exitCode, exitStatus):
        """
        Private slot connected to the finished signal.

        @param exitCode exit code of the process
        @type int
        @param exitStatus exit status of the process
        @type QProcess.ExitStatus
        """
        self.__started = False

    def __waitForBytes(self, count, timeout=10000):
        """
        Private method to wait until at least the given number of bytes can
        be read from the command server.

        @param count number of bytes required
        @type int
        @param timeout time in milliseconds to wait for each new chunk of
            data
        @type int
        @return flag indicating that enough data is available
        @rtype bool
        """
        while self.__server.bytesAvailable() < count:
            if not self.__server.waitForReadyRead(timeout):
                return False
        return True

    def __readChannel(self):
        """
        Private method to read data from the command server.

        A frame is only consumed once it is complete. If the data does not
        arrive in time, nothing is consumed and empty values are returned.

        @return tuple of channel designator and channel data. For the input
            channels 'I' and 'L' the data is the requested length, for the
            'r' channel the raw bytes, for all others the decoded text.
        @rtype tuple of (str, int or str or bytes)
        """
        # frames may arrive in several chunks; wait for the complete header
        if not self.__waitForBytes(HgClient.OutputFormatSize):
            return "", ""

        header = bytes(self.__server.peek(HgClient.OutputFormatSize))
        channel, length = struct.unpack(HgClient.OutputFormat, header)
        # the channel identifier is a single ASCII byte per protocol
        channel = channel.decode("ascii", "replace")

        if channel in ("I", "L"):
            # input requests carry no payload; 'length' is the request size
            self.__server.read(HgClient.OutputFormatSize)
            return channel, length

        # wait for the payload before consuming anything, so that a timeout
        # leaves the frame intact for a later retry
        if not self.__waitForBytes(HgClient.OutputFormatSize + length):
            return "", ""

        self.__server.read(HgClient.OutputFormatSize)
        data = bytes(self.__server.read(length))
        if channel == "r":
            return channel, data
        return channel, data.decode(self.__encoding, "replace")

    def __writeDataBlock(self, data):
        """
        Private method to write some data to the command server.

        The data is sent as a big-endian length prefix followed by the data
        itself.

        @param data data to be sent
        @type str or bytes
        """
        if not isinstance(data, bytes):
            data = data.encode(self.__encoding)
        self.__server.write(
            QByteArray(struct.pack(HgClient.InputFormat, len(data))))
        self.__server.write(QByteArray(data))
        self.__server.waitForBytesWritten()

    def __runcommand(self, args, inputChannels, outputChannels):
        """
        Private method to run a command in the server (low level).

        @param args list of arguments for the command
        @type list of str
        @param inputChannels dictionary of input channels. The dictionary must
            have the keys 'I' and 'L' and each entry must be a function
            receiving the number of bytes to write.
        @type dict
        @param outputChannels dictionary of output channels. The dictionary
            must have the keys 'o' and 'e' and each entry must be a function
            receiving the data.
        @type dict
        @return result code of the command, -1 if the command server wasn't
            started or stopped while the command was executing, or -10 if
            the command was canceled
        @rtype int
        @exception RuntimeError raised to indicate an unexpected command
            channel
        """
        if not self.__started:
            return -1

        self.__server.write(QByteArray(b'runcommand\n'))
        self.__writeDataBlock('\0'.join(args))

        while True:
            QCoreApplication.processEvents()

            if self.__cancel:
                return -10

            if self.__server is None:
                return -1

            if self.__server.bytesAvailable() == 0:
                if (
                    self.__server.state() ==
                    QProcess.ProcessState.NotRunning
                ):
                    # the server died; no result will ever arrive
                    return -1
                QThread.msleep(50)
                continue

            channel, data = self.__readChannel()

            # no complete frame could be read
            if not channel:
                if (
                    self.__server.state() ==
                    QProcess.ProcessState.NotRunning
                ):
                    # truncated frame from a dead server; avoid spinning
                    return -1
                continue

            # input channels
            if channel in inputChannels:
                if channel == "L":
                    inputData, isPassword = inputChannels[channel](data)
                    # echo the input to the output if it was a prompt
                    if not isPassword:
                        outputChannels["o"](inputData)
                else:
                    inputData = inputChannels[channel](data)
                self.__writeDataBlock(inputData)

            # output channels
            elif channel in outputChannels:
                outputChannels[channel](data)

            # result channel, command is finished
            elif channel == "r":
                return struct.unpack(HgClient.ReturnFormat, data)[0]

            # unexpected but required channel
            elif channel.isupper():
                raise RuntimeError(
                    "Unexpected but required channel '{0}'.".format(channel))

            # optional channels are ignored
            else:
                pass

    def __prompt(self, size, message):
        """
        Private method to prompt the user for some input.

        @param size maximum length of the requested input
        @type int
        @param message message sent by the server
        @type str
        @return tuple containing data entered by the user and
            a flag indicating a password input
        @rtype tuple of (str, bool)
        """
        from .HgClientPromptDialog import HgClientPromptDialog
        inputData = ""
        isPassword = False
        dlg = HgClientPromptDialog(size, message)
        if dlg.exec() == QDialog.DialogCode.Accepted:
            inputData = dlg.getInput() + '\n'
            isPassword = dlg.isPassword()
        return inputData, isPassword

    def runcommand(self, args, prompt=None, inputData=None, output=None,
                   error=None):
        """
        Public method to execute a command via the command server.

        @param args list of arguments for the command
        @type list of str
        @param prompt function to reply to prompts by the server. It
            receives the max number of bytes to return and the contents
            of the output channel received so far. If an output function is
            given as well, the prompt data is passed through the output
            function. The function must return the input data and a flag
            indicating a password input.
        @type func(int, str) -> (str, bool)
        @param inputData function to reply to bulk data requests by the
            server. It receives the max number of bytes to return.
        @type func(int) -> bytes
        @param output function receiving the data from the server. If a
            prompt function is given, it is assumed, that the prompt output
            is passed via this function.
        @type func(str)
        @param error function receiving error messages from the server
        @type func(str)
        @return tuple of output and errors of the command server. In case
            output and/or error functions were given, the respective return
            value will be an empty string.
        @rtype tuple of (str, str)
        @exception RuntimeError raised to indicate an unexpected command
            channel
        """
        if not self.__started:
            # try to start the Mercurial command server
            ok, startError = self.startServer()
            if not ok:
                return "", startError

        self.__commandRunning = True
        try:
            outputChannels = {}
            outputBuffer = None
            errorBuffer = None

            if output is None:
                outputBuffer = io.StringIO()
                outputChannels["o"] = outputBuffer.write
            else:
                outputChannels["o"] = output
            if error is None:
                errorBuffer = io.StringIO()
                outputChannels["e"] = errorBuffer.write
            else:
                outputChannels["e"] = error

            inputChannels = {}
            if prompt is not None:
                def func(size):
                    msg = ("" if outputBuffer is None
                           else outputBuffer.getvalue())
                    return prompt(size, msg)
                inputChannels["L"] = func
            else:
                def myprompt(size):
                    msg = (self.tr("For message see output dialog.")
                           if outputBuffer is None
                           else outputBuffer.getvalue())
                    return self.__prompt(size, msg)
                inputChannels["L"] = myprompt
            if inputData is not None:
                inputChannels["I"] = inputData

            self.__cancel = False
            self.__runcommand(args, inputChannels, outputChannels)

            out = outputBuffer.getvalue() if outputBuffer else ""
            err = errorBuffer.getvalue() if errorBuffer else ""
        finally:
            # never leave the client marked busy, even if a callback or the
            # protocol handling raised
            self.__commandRunning = False

        return out, err

    def cancel(self):
        """
        Public method to cancel the running command.

        The server is restarted so that a fresh server is available for the
        next command.
        """
        self.__cancel = True
        self.restartServer()

    def wasCanceled(self):
        """
        Public method to check, if the last command was canceled.

        @return flag indicating the cancel state
        @rtype bool
        """
        return self.__cancel

    def isExecuting(self):
        """
        Public method to check, if the server is executing a command.

        @return flag indicating the execution of a command
        @rtype bool
        """
        return self.__commandRunning

    def getRepository(self):
        """
        Public method to get the repository path this client is serving.

        @return repository path
        @rtype str
        """
        return self.__repoPath
431