1# commandserver.py - communicate with Mercurial's API over a pipe
2#
3#  Copyright Olivia Mackall <olivia@selenic.com>
4#
5# This software may be used and distributed according to the terms of the
6# GNU General Public License version 2 or any later version.
7
8from __future__ import absolute_import
9
10import errno
11import gc
12import os
13import random
14import signal
15import socket
16import struct
17import traceback
18
19try:
20    import selectors
21
22    selectors.BaseSelector
23except ImportError:
24    from .thirdparty import selectors2 as selectors
25
26from .i18n import _
27from .pycompat import getattr
28from . import (
29    encoding,
30    error,
31    loggingutil,
32    pycompat,
33    repocache,
34    util,
35    vfs as vfsmod,
36)
37from .utils import (
38    cborutil,
39    procutil,
40)
41
42
class channeledoutput(object):
    """
    File-like wrapper that frames every write() for a single channel of a
    shared output stream.

    Each non-empty write produces one frame on out:

    channel identifier (1 byte),
    data length (unsigned int, big endian),
    data
    """

    def __init__(self, out, channel):
        self.out = out
        self.channel = channel

    @property
    def name(self):
        return b'<%c-channel>' % self.channel

    def write(self, data):
        if data:
            header = struct.pack(b'>cI', self.channel, len(data))
            # a single write() gives the frame the same atomicity as the
            # underlying file
            self.out.write(header + data)
            self.out.flush()

    def __getattr__(self, attr):
        # a channel is not a terminal nor a seekable file; hide those
        # capabilities instead of forwarding them to the shared stream
        if attr not in ('isatty', 'fileno', 'tell', 'seek'):
            return getattr(self.out, attr)
        raise AttributeError(attr)
70
71
class channeledmessage(object):
    """
    Structured-message writer layered on top of a channeledoutput.

    Every write() turns its keyword options (plus the positional data, when
    not None, stored under the b'data' key) into one flat key-value dict,
    encodes it with encodefn and emits the result as a single frame:

    data length (unsigned int),
    encoded message and metadata

    Each message should have 'type' attribute. Messages of unknown type
    should be ignored.
    """

    # teach ui that write() can take **opts
    structured = True

    def __init__(self, out, channel, encodename, encodefn):
        self.encoding = encodename
        self._encodefn = encodefn
        self._cout = channeledoutput(out, channel)

    def write(self, data, **opts):
        message = pycompat.byteskwargs(opts)
        if data is not None:
            message[b'data'] = data
        encoded = self._encodefn(message)
        self._cout.write(encoded)

    def __getattr__(self, attr):
        # everything else (name, flush, ...) behaves like the framed channel
        return getattr(self._cout, attr)
99
100
class channeledinput(object):
    """
    Read data from in_.

    Requests for input are written to out in the following format:
    channel identifier - 'I' for plain input, 'L' line based (1 byte)
    how many bytes to send at most (unsigned int),

    The client replies with:
    data length (unsigned int), 0 meaning EOF
    data
    """

    # request size used when the caller wants "everything" (size < 0); asking
    # in bounded chunks keeps the pipe from filling up and deadlocking
    maxchunksize = 4 * 1024

    def __init__(self, in_, out, channel):
        self.in_ = in_
        self.out = out
        self.channel = channel

    @property
    def name(self):
        return b'<%c-channel>' % self.channel

    def read(self, size=-1):
        """Read at most size bytes from the client.

        With size < 0, keep asking the client for chunks until it reports
        EOF and return everything received.
        """
        if size < 0:
            # if we need to consume all the clients input, ask for 4k chunks
            # so the pipe doesn't fill up risking a deadlock
            return self._readchunks(self.channel, stopatnewline=False)
        return self._read(size, self.channel)

    def _readchunks(self, channel, stopatnewline):
        """Request maxchunksize-byte chunks on channel and concatenate them.

        Stops after an empty reply (EOF) or, when stopatnewline is set, after
        a chunk that ends with a newline.
        """
        # collect and join once: repeated bytes concatenation would be
        # quadratic in the total amount of input
        chunks = []
        while True:
            chunk = self._read(self.maxchunksize, channel)
            chunks.append(chunk)
            if not chunk or (stopatnewline and chunk.endswith(b'\n')):
                break
        return b''.join(chunks)

    def _read(self, size, channel):
        """Perform one request/reply exchange on channel.

        Returns b'' without contacting the client when size is 0, and b''
        when the client replies with a zero length.
        """
        if not size:
            return b''
        assert size > 0

        # tell the client we need at most size bytes
        self.out.write(struct.pack(b'>cI', channel, size))
        self.out.flush()

        length = self.in_.read(4)
        length = struct.unpack(b'>I', length)[0]
        if not length:
            return b''
        else:
            return self.in_.read(length)

    def readline(self, size=-1):
        """Read a line using the line-based 'L' request channel.

        With size < 0, keep asking until there's either no more input or a
        chunk ending with a newline was received.
        """
        if size < 0:
            return self._readchunks(b'L', stopatnewline=True)
        return self._read(size, b'L')

    def __iter__(self):
        return self

    def next(self):
        l = self.readline()
        if not l:
            raise StopIteration
        return l

    __next__ = next

    def __getattr__(self, attr):
        # a channel is not a terminal nor a seekable file
        if attr in ('isatty', 'fileno', 'tell', 'seek'):
            raise AttributeError(attr)
        return getattr(self.in_, attr)
186
187
# Encoders for the 'm' (message) channel, keyed by the encoding names that may
# appear in cmdserver.message-encodings. Each one turns a message dict into
# bytes.
_messageencoders = {
    b'cbor': lambda message: b''.join(cborutil.streamencode(message)),
}
191
192
def _selectmessageencoder(ui):
    """Pick the first known encoder listed in cmdserver.message-encodings.

    Returns a (name, encodefn) pair. Raises error.Abort when none of the
    configured names has an encoder.
    """
    requested = ui.configlist(b'cmdserver', b'message-encodings')
    usable = [
        (name, _messageencoders[name])
        for name in requested
        if _messageencoders.get(name)
    ]
    if not usable:
        raise error.Abort(
            b'no supported message encodings: %s' % b' '.join(requested)
        )
    return usable[0]
202
203
class server(object):
    """
    Listens for commands on fin, runs them and writes the output on a channel
    based stream to fout.
    """

    def __init__(self, ui, repo, fin, fout, prereposetups=None):
        """Prepare one server session talking to a client over fin/fout.

        If repo is given, its baseui becomes the server ui (so the repo's
        local configuration does not leak) and repo.ui is remembered as
        repoui. prereposetups is forwarded to every dispatch.request().
        Unless cmdserver.shutdown-on-interrupt is set, SIGINT is ignored until
        cleanup() is called.
        """
        # remembered so runcommand() can come back here after --cwd
        self.cwd = encoding.getcwd()

        if repo:
            # the ui here is really the repo ui so take its baseui so we don't
            # end up with its local configuration
            self.ui = repo.baseui
            self.repo = repo
            self.repoui = repo.ui
        else:
            self.ui = ui
            self.repo = self.repoui = None
        self._prereposetups = prereposetups

        # all output channels are multiplexed onto fout, one identifier each
        self.cdebug = channeledoutput(fout, b'd')
        self.cerr = channeledoutput(fout, b'e')
        self.cout = channeledoutput(fout, b'o')
        self.cin = channeledinput(fin, fout, b'I')
        self.cresult = channeledoutput(fout, b'r')

        if self.ui.config(b'cmdserver', b'log') == b'-':
            # switch log stream of server's ui to the 'd' (debug) channel
            # (don't touch repo.ui as its lifetime is longer than the server)
            self.ui = self.ui.copy()
            setuplogging(self.ui, repo=None, fp=self.cdebug)

        self.cmsg = None
        if ui.config(b'ui', b'message-output') == b'channel':
            encname, encfn = _selectmessageencoder(ui)
            self.cmsg = channeledmessage(fout, b'm', encname, encfn)

        self.client = fin

        # If shutdown-on-interrupt is off, the default SIGINT handler is
        # removed so that client-server communication wouldn't be interrupted.
        # For example, 'runcommand' handler will issue three short read()s.
        # If one of the first two read()s were interrupted, the communication
        # channel would be left at dirty state and the subsequent request
        # wouldn't be parsed. So catching KeyboardInterrupt isn't enough.
        self._shutdown_on_interrupt = ui.configbool(
            b'cmdserver', b'shutdown-on-interrupt'
        )
        self._old_inthandler = None
        if not self._shutdown_on_interrupt:
            self._old_inthandler = signal.signal(signal.SIGINT, signal.SIG_IGN)

    def cleanup(self):
        """release and restore resources taken during server session"""
        if not self._shutdown_on_interrupt:
            signal.signal(signal.SIGINT, self._old_inthandler)

    def _read(self, size):
        """Read up to size bytes from the client.

        Raises EOFError when the client has closed its end.
        """
        if not size:
            return b''

        data = self.client.read(size)

        # is the other end closed?
        if not data:
            raise EOFError

        return data

    def _readstr(self):
        """read a string from the channel

        format:
        data length (uint32), data
        """
        length = struct.unpack(b'>I', self._read(4))[0]
        if not length:
            return b''
        return self._read(length)

    def _readlist(self):
        """read a list of NULL separated strings from the channel"""
        s = self._readstr()
        if s:
            return s.split(b'\0')
        else:
            return []

    @staticmethod
    def _cwdoptgiven(args):
        """Return True if args contain the --cwd option.

        Both the ``--cwd DIR`` and the ``--cwd=DIR`` spellings are recognized,
        since either one may have changed the process working directory.
        """
        return any(a == b'--cwd' or a.startswith(b'--cwd=') for a in args)

    def _dispatchcommand(self, req):
        """Run req through dispatch.dispatch() and return its result.

        Unless shutdown-on-interrupt is set, the original SIGINT handler is
        active only while the command runs; an interrupt is reported as
        'interrupted!' and turned into a -1 return code.
        """
        from . import dispatch  # avoid cycle

        if self._shutdown_on_interrupt:
            # no need to restore SIGINT handler as it is unmodified.
            return dispatch.dispatch(req)

        try:
            signal.signal(signal.SIGINT, self._old_inthandler)
            return dispatch.dispatch(req)
        except error.SignalInterrupt:
            # propagate SIGBREAK, SIGHUP, or SIGTERM.
            raise
        except KeyboardInterrupt:
            # SIGINT may be received out of the try-except block of dispatch(),
            # so catch it as last ditch. Another KeyboardInterrupt may be
            # raised while handling exceptions here, but there's no way to
            # avoid that except for doing everything in C.
            pass
        finally:
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        # On KeyboardInterrupt, print error message and exit *after* SIGINT
        # handler removed.
        req.ui.error(_(b'interrupted!\n'))
        return -1

    def runcommand(self):
        """read a NUL separated argument list, execute it and write the
        return code to the result channel"""
        from . import dispatch  # avoid cycle

        args = self._readlist()

        # copy the uis so changes (e.g. --config or --verbose) don't
        # persist between requests
        copiedui = self.ui.copy()
        uis = [copiedui]
        if self.repo:
            self.repo.baseui = copiedui
            # clone ui without using ui.copy because this is protected
            repoui = self.repoui.__class__(self.repoui)
            repoui.copy = copiedui.copy  # redo copy protection
            uis.append(repoui)
            self.repo.ui = self.repo.dirstate._ui = repoui
            self.repo.invalidateall()

        for ui in uis:
            ui.resetstate()
            # any kind of interaction must use server channels, but chg may
            # replace channels by fully functional tty files. so nontty is
            # enforced only if cin is a channel.
            if not util.safehasattr(self.cin, b'fileno'):
                ui.setconfig(b'ui', b'nontty', b'true', b'commandserver')

        req = dispatch.request(
            args[:],
            copiedui,
            self.repo,
            self.cin,
            self.cout,
            self.cerr,
            self.cmsg,
            prereposetups=self._prereposetups,
        )

        try:
            ret = self._dispatchcommand(req) & 255
            # If shutdown-on-interrupt is off, it's important to write the
            # result code *after* SIGINT handler removed. If the result code
            # were lost, the client wouldn't be able to continue processing.
            self.cresult.write(struct.pack(b'>i', int(ret)))
        finally:
            # restore old cwd; check both "--cwd DIR" and "--cwd=DIR" so the
            # next request doesn't run in the previous request's directory
            if self._cwdoptgiven(args):
                os.chdir(self.cwd)

    def getencoding(self):
        """writes the current encoding to the result channel"""
        self.cresult.write(encoding.encoding)

    def serveone(self):
        """Handle one request line; return False once the client sends an
        empty line or closes its end"""
        cmd = self.client.readline()[:-1]
        if cmd:
            handler = self.capabilities.get(cmd)
            if handler:
                handler(self)
            else:
                # clients are expected to check what commands are supported by
                # looking at the servers capabilities
                raise error.Abort(_(b'unknown command %s') % cmd)

        return cmd != b''

    capabilities = {b'runcommand': runcommand, b'getencoding': getencoding}

    def serve(self):
        """Send the hello message, then serve requests until the client is done.

        Returns 0 on an orderly end and 1 if the client disconnected while a
        request was being read.
        """
        hellomsg = b'capabilities: ' + b' '.join(sorted(self.capabilities))
        hellomsg += b'\n'
        hellomsg += b'encoding: ' + encoding.encoding
        hellomsg += b'\n'
        if self.cmsg:
            hellomsg += b'message-encoding: %s\n' % self.cmsg.encoding
        hellomsg += b'pid: %d' % procutil.getpid()
        if util.safehasattr(os, b'getpgid'):
            hellomsg += b'\n'
            hellomsg += b'pgid: %d' % os.getpgid(0)

        # write the hello msg in -one- chunk
        self.cout.write(hellomsg)

        try:
            while self.serveone():
                pass
        except EOFError:
            # we'll get here if the client disconnected while we were reading
            # its request
            return 1

        return 0
411
412
def setuplogging(ui, repo=None, fp=None):
    """Install the b'cmdserver' logger described by cmdserver.log

    Does nothing if cmdserver.log is unset. If it is '-', messages go to the
    file object fp when one is given (the 'd' channel while a client is
    connected), otherwise to ui.ferr, the stderr of the server process. Any
    other value is the path of a log file. The logger is installed on ui and,
    when repo is given, on repo.baseui and repo.ui as well.
    """
    # developer config: cmdserver.log
    logpath = ui.config(b'cmdserver', b'log')
    if not logpath:
        return
    # developer config: cmdserver.track-log
    tracked = set(ui.configlist(b'cmdserver', b'track-log'))

    if logpath == b'-':
        logger = loggingutil.fileobjectlogger(fp or ui.ferr, tracked)
    else:
        fullpath = util.abspath(util.expandpath(logpath))
        # developer config: cmdserver.max-log-files
        maxfiles = ui.configint(b'cmdserver', b'max-log-files')
        # developer config: cmdserver.max-log-size
        maxsize = ui.configbytes(b'cmdserver', b'max-log-size')
        logdir, logname = os.path.split(fullpath)
        logger = loggingutil.filelogger(
            vfsmod.vfs(logdir),
            logname,
            tracked,
            maxfiles=maxfiles,
            maxsize=maxsize,
        )

    targets = {ui}
    if repo:
        targets.update((repo.baseui, repo.ui))
    for target in targets:
        target.setlogger(b'cmdserver', logger)
452
453
class pipeservice(object):
    """Service running a single command server session over the ui's own
    input/output streams"""

    def __init__(self, ui, repo, opts):
        # opts is accepted for the same constructor shape as the other
        # services; it is not used here
        self.ui, self.repo = ui, repo

    def init(self):
        """No preparation is needed before run()."""

    def run(self):
        """Serve one session and return the value of server.serve()."""
        # redirect stdio to null device so that broken extensions or in-process
        # hooks will never cause corruption of channel protocol.
        with self.ui.protectedfinout() as (fin, fout):
            session = server(self.ui, self.repo, fin, fout)
            try:
                return session.serve()
            finally:
                session.cleanup()
472
473
def _initworkerprocess():
    """Prepare a freshly forked worker process.

    The worker moves into a process group of its own, in order to:

    1. make the current process group no longer "orphaned" (because the
       parent of this process is in a different process group while
       remains in a same session). According to POSIX 2.2.2.52, an orphaned
       process group will ignore terminal-generated stop signals like SIGTSTP
       (Ctrl+Z), which will cause trouble for things like ncurses.
    2. let the client use kill(-pgid, sig) to simulate terminal-generated
       SIGINT (Ctrl+C) and process-exit-generated SIGHUP, so our child
       processes like ssh are killed properly without affecting unrelated
       processes.

    The random module is then reseeded; otherwise forked request handlers
    would all share the state inherited from the parent.
    """
    os.setpgid(0, 0)
    random.seed()
490
491
def _serverequest(ui, repo, conn, createcmdserver, prereposetups):
    """Serve one client connection with a command server.

    createcmdserver(repo, conn, fin, fout, prereposetups) builds the server
    instance. error.Abort is reported through ui.error(); EPIPE IOErrors and
    KeyboardInterrupt raised while serving are swallowed. Any other exception
    is also written as a traceback to the client's 'e' channel and then
    re-raised. The file objects wrapping conn are always closed; conn itself
    is not closed here.
    """
    # binary file objects over the socket, one per direction
    fin = conn.makefile('rb')
    fout = conn.makefile('wb')
    sv = None
    try:
        sv = createcmdserver(repo, conn, fin, fout, prereposetups)
        try:
            sv.serve()
        # handle exceptions that may be raised by command server. most of
        # known exceptions are caught by dispatch.
        except error.Abort as inst:
            ui.error(_(b'abort: %s\n') % inst.message)
        except IOError as inst:
            # the client went away (broken pipe); anything else is a real error
            if inst.errno != errno.EPIPE:
                raise
        except KeyboardInterrupt:
            pass
        finally:
            # restore resources (e.g. the SIGINT handler) taken by the server
            sv.cleanup()
    except:  # re-raises
        # also write traceback to error channel. otherwise client cannot
        # see it because it is written to server's stderr by default.
        if sv:
            cerr = sv.cerr
        else:
            # the server couldn't be created; frame the error channel by hand
            cerr = channeledoutput(fout, b'e')
        cerr.write(encoding.strtolocal(traceback.format_exc()))
        raise
    finally:
        fin.close()
        try:
            fout.close()  # implicit flush() may cause another EPIPE
        except IOError as inst:
            if inst.errno != errno.EPIPE:
                raise
527
528
class unixservicehandler(object):
    """Pluggable hooks used by the unix-socket forking service

    Except for createcmdserver(), which runs in the worker process serving a
    connection, these methods are called in the main process. Nothing mutable
    can be handed back from createcmdserver() to the main process.
    """

    # seconds between shouldexit() checks; None makes select() block
    pollinterval = None

    def __init__(self, ui):
        self.ui = ui

    def bindsocket(self, sock, address):
        """Bind sock to address, start listening and announce the address."""
        util.bindunixsocket(sock, address)
        sock.listen(socket.SOMAXCONN)
        ui = self.ui
        ui.status(_(b'listening at %s\n') % address)
        # flush right away so the status message isn't held in a buffer
        ui.flush()

    def unlinksocket(self, address):
        """Remove the socket file at address."""
        os.unlink(address)

    def shouldexit(self):
        """Tell whether the service should shut down (never, by default).

        Polled every pollinterval seconds by the main loop.
        """
        return False

    def newconnection(self):
        """Hook invoked in the main process after a connection is accepted."""

    def createcmdserver(self, repo, conn, fin, fout, prereposetups):
        """Build the command server for one connection (worker process)."""
        return server(self.ui, repo, fin, fout, prereposetups)
561
562
class unixforkingservice(object):
    """
    Listens on unix domain socket and forks server per connection

    The main process accepts connections and forks one worker process per
    connection. Workers send the root path of each local repository they
    close back over a datagram socketpair; the main process passes those
    paths to its repoloader, whose cached repos later workers can reuse.
    """

    def __init__(self, ui, repo, opts, handler=None):
        """Validate options; sockets and signal handlers are set up by init().

        Raises error.Abort if the platform lacks AF_UNIX, opts[b'address']
        is empty, or cmdserver.max-repo-cache is negative. handler defaults
        to unixservicehandler(ui).
        """
        self.ui = ui
        self.repo = repo
        self.address = opts[b'address']
        if not util.safehasattr(socket, b'AF_UNIX'):
            raise error.Abort(_(b'unsupported platform'))
        if not self.address:
            raise error.Abort(_(b'no socket path specified with --address'))
        self._servicehandler = handler or unixservicehandler(ui)
        self._sock = None
        self._mainipc = None
        self._workeripc = None
        self._oldsigchldhandler = None
        self._workerpids = set()  # updated by signal handler; do not iterate
        self._socketunlinked = None
        # experimental config: cmdserver.max-repo-cache
        maxlen = ui.configint(b'cmdserver', b'max-repo-cache')
        if maxlen < 0:
            raise error.Abort(_(b'negative max-repo-cache size not allowed'))
        self._repoloader = repocache.repoloader(ui, maxlen)
        # attempt to avoid crash in CoreFoundation when using chg after fix in
        # a89381e04c58
        if pycompat.isdarwin:
            procutil.gui()

    def init(self):
        """Create and bind the listening socket and the worker->main IPC
        channel, install the SIGCHLD handler and start the repo loader."""
        self._sock = socket.socket(socket.AF_UNIX)
        # IPC channel from many workers to one main process; this is actually
        # a uni-directional pipe, but is backed by a DGRAM socket so each
        # message can be easily separated.
        o = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
        self._mainipc, self._workeripc = o
        self._servicehandler.bindsocket(self._sock, self.address)
        if util.safehasattr(procutil, b'unblocksignal'):
            procutil.unblocksignal(signal.SIGCHLD)
        o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
        # restored by _cleanup() in the main process and by _runworker()
        # in each worker
        self._oldsigchldhandler = o
        self._socketunlinked = False
        self._repoloader.start()

    def _unlinksocket(self):
        # may be called both when starting to exit and from _cleanup();
        # only the first call removes the socket file
        if not self._socketunlinked:
            self._servicehandler.unlinksocket(self.address)
            self._socketunlinked = True

    def _cleanup(self):
        """Undo init() and wait (blocking) for the remaining workers."""
        signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
        self._sock.close()
        self._mainipc.close()
        self._workeripc.close()
        self._unlinksocket()
        self._repoloader.stop()
        # don't kill child processes as they have active clients, just wait
        self._reapworkers(0)

    def run(self):
        """Run the main loop; resources are released even if it fails."""
        try:
            self._mainloop()
        finally:
            self._cleanup()

    def _mainloop(self):
        """Dispatch socket events until the handler asks to exit and no
        queued events remain."""
        exiting = False
        h = self._servicehandler
        selector = selectors.DefaultSelector()
        # each registration carries its callback as the key's data
        selector.register(
            self._sock, selectors.EVENT_READ, self._acceptnewconnection
        )
        selector.register(
            self._mainipc, selectors.EVENT_READ, self._handlemainipc
        )
        while True:
            if not exiting and h.shouldexit():
                # clients can no longer connect() to the domain socket, so
                # we stop queuing new requests.
                # for requests that are queued (connect()-ed, but haven't been
                # accept()-ed), handle them before exit. otherwise, clients
                # waiting for recv() will receive ECONNRESET.
                self._unlinksocket()
                exiting = True
            try:
                events = selector.select(timeout=h.pollinterval)
            except OSError as inst:
                # selectors2 raises ETIMEDOUT if timeout exceeded while
                # handling signal interrupt. That's probably wrong, but
                # we can easily get around it.
                if inst.errno != errno.ETIMEDOUT:
                    raise
                events = []
            if not events:
                # only exit if we completed all queued requests
                if exiting:
                    break
                continue
            for key, _mask in events:
                # calls _acceptnewconnection() or _handlemainipc()
                key.data(key.fileobj, selector)
        selector.close()

    def _acceptnewconnection(self, sock, selector):
        """Accept a connection and fork a worker process to serve it.

        The parent records the worker pid and closes its copy of the
        connection. The child closes the main process's sockets, serves the
        connection and exits via os._exit() (status 0, or 255 on exception).
        """
        h = self._servicehandler
        try:
            conn, _addr = sock.accept()
        except socket.error as inst:
            if inst.args[0] == errno.EINTR:
                return
            raise

        # Future improvement: On Python 3.7, maybe gc.freeze() can be used
        # to prevent COW memory from being touched by GC.
        # https://instagram-engineering.com/
        #   copy-on-write-friendly-python-garbage-collection-ad6ed5233ddf
        pid = os.fork()
        if pid:
            try:
                self.ui.log(
                    b'cmdserver', b'forked worker process (pid=%d)\n', pid
                )
                self._workerpids.add(pid)
                h.newconnection()
            finally:
                conn.close()  # release handle in parent process
        else:
            try:
                selector.close()
                sock.close()
                self._mainipc.close()
                self._runworker(conn)
                conn.close()
                self._workeripc.close()
                os._exit(0)
            except:  # never return, hence no re-raises
                try:
                    self.ui.traceback(force=True)
                finally:
                    os._exit(255)

    def _handlemainipc(self, sock, selector):
        """Process messages sent from a worker"""
        try:
            # each datagram is a repo root path sent by a worker's repo close()
            path = sock.recv(32768)  # large enough to receive path
        except socket.error as inst:
            if inst.args[0] == errno.EINTR:
                return
            raise
        self._repoloader.load(path)

    def _sigchldhandler(self, signal, frame):
        # SIGCHLD handler installed by init(); reaps without blocking.
        # NOTE: the 'signal' parameter shadows the signal module in here.
        self._reapworkers(os.WNOHANG)

    def _reapworkers(self, options):
        """Reap exited children with os.waitpid(-1, options).

        options is os.WNOHANG from the SIGCHLD handler, or 0 (blocking) from
        _cleanup(). Returns once no tracked worker pids remain, when no child
        is ready to be reaped, or when there are no children at all.
        """
        while self._workerpids:
            try:
                pid, _status = os.waitpid(-1, options)
            except OSError as inst:
                if inst.errno == errno.EINTR:
                    continue
                if inst.errno != errno.ECHILD:
                    raise
                # no child processes at all (reaped by other waitpid()?)
                self._workerpids.clear()
                return
            if pid == 0:
                # no waitable child processes
                return
            self.ui.log(b'cmdserver', b'worker process exited (pid=%d)\n', pid)
            self._workerpids.discard(pid)

    def _runworker(self, conn):
        """Serve conn inside a forked worker process."""
        # the inherited handler reaps the main process's workers; put the
        # original SIGCHLD handler back in the worker
        signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
        _initworkerprocess()
        h = self._servicehandler
        try:
            _serverequest(
                self.ui,
                self.repo,
                conn,
                h.createcmdserver,
                prereposetups=[self._reposetup],
            )
        finally:
            gc.collect()  # trigger __del__ since worker process uses os._exit

    def _reposetup(self, ui, repo):
        """Repo setup hook passed to dispatch via prereposetups (worker side).

        For local repos only: make close() report repo.root to the main
        process over the IPC socket, and copy in the cached state when the
        repo loader has a cached copy of this repo.
        """
        if not repo.local():
            return

        class unixcmdserverrepo(repo.__class__):
            def close(self):
                super(unixcmdserverrepo, self).close()
                try:
                    # received by the main process in _handlemainipc()
                    self._cmdserveripc.send(self.root)
                except socket.error:
                    self.ui.log(
                        b'cmdserver', b'failed to send repo root to master\n'
                    )

        repo.__class__ = unixcmdserverrepo
        repo._cmdserveripc = self._workeripc

        cachedrepo = self._repoloader.get(repo.root)
        if cachedrepo is None:
            return
        repo.ui.log(b'repocache', b'repo from cache: %s\n', repo.root)
        repocache.copycache(cachedrepo, repo)
772