1# Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2# Copyright 2005-2007 Olivia Mackall <olivia@selenic.com>
3#
4# This software may be used and distributed according to the terms of the
5# GNU General Public License version 2 or any later version.
6
7from __future__ import absolute_import
8
9import contextlib
10import struct
11import threading
12
13from .i18n import _
14from . import (
15    encoding,
16    error,
17    pycompat,
18    util,
19    wireprototypes,
20    wireprotov1server,
21    wireprotov2server,
22)
23from .interfaces import util as interfaceutil
24from .utils import (
25    cborutil,
26    compression,
27    stringutil,
28)
29
30stringio = util.stringio
31
32urlerr = util.urlerr
33urlreq = util.urlreq
34
35HTTP_OK = 200
36
37HGTYPE = b'application/mercurial-0.1'
38HGTYPE2 = b'application/mercurial-0.2'
39HGERRTYPE = b'application/hg-error'
40
41SSHV1 = wireprototypes.SSHV1
42SSHV2 = wireprototypes.SSHV2
43
44
def decodevaluefromheaders(req, headerprefix):
    """Reassemble a long value spread across numbered HTTP request headers.

    Headers named ``<headerprefix>-1``, ``<headerprefix>-2``, ... are read in
    order until the first index that is missing; their contents are
    concatenated in that order.

    Returns the value as a bytes, not a str.
    """
    parts = []
    index = 1
    piece = req.headers.get(b'%s-%d' % (headerprefix, index))
    while piece is not None:
        parts.append(pycompat.bytesurl(piece))
        index += 1
        piece = req.headers.get(b'%s-%d' % (headerprefix, index))
    return b''.join(parts)
60
61
@interfaceutil.implementer(wireprototypes.baseprotocolhandler)
class httpv1protocolhandler(object):
    """Protocol handler for version 1 of the HTTP wire protocol.

    Wraps a parsed HTTP request and exposes its command arguments, protocol
    capabilities and upload payload to wire protocol commands.
    """

    def __init__(self, req, ui, checkperm):
        self._req = req
        self._ui = ui
        self._checkperm = checkperm
        # Lazily filled in by getprotocaps().
        self._protocaps = None

    @property
    def name(self):
        return b'http-v1'

    def getargs(self, args):
        """Return the values of the space-separated argument names in ``args``.

        Only the first value of each received argument is used. The special
        name ``*`` collects every received argument that was not explicitly
        requested (except ``cmd``) into a dict.
        """
        received = self._args()
        wanted = args.split()
        values = {}
        for argname in wanted:
            if argname != b'*':
                values[argname] = received[argname][0]
                continue
            values[b'*'] = {
                other: received[other][0]
                for other in received.keys()
                if other != b'cmd' and other not in wanted
            }
        return [values[argname] for argname in wanted]

    def _args(self):
        """Collect the command arguments sent with the request.

        Query string parameters are always included. They are extended either
        from the start of the request body (when the ``X-HgArgs-Post`` header
        gives a non-zero length) or from the ``X-HgArg-<N>`` headers.
        """
        collected = self._req.qsparams.asdictoflists()
        postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
        if postlen:
            encoded = self._req.bodyfh.read(postlen)
        else:
            encoded = decodevaluefromheaders(self._req, b'X-HgArg')
        collected.update(urlreq.parseqs(encoded, keep_blank_values=True))
        return collected

    def getprotocaps(self):
        """Return the set of capabilities from the ``X-HgProto`` headers.

        The headers are decoded on first use and the result is cached.
        """
        caps = self._protocaps
        if caps is None:
            raw = decodevaluefromheaders(self._req, b'X-HgProto')
            caps = self._protocaps = set(raw.split(b' '))
        return caps

    def getpayload(self):
        """Return an iterator over the request body payload.

        Existing clients *always* send Content-Length. When httppostargs is
        used, the bytes already consumed as arguments are not part of the
        payload and are subtracted from that length.
        """
        total = int(self._req.headers[b'Content-Length'])
        consumedbyargs = int(self._req.headers.get(b'X-HgArgs-Post', 0))
        return util.filechunkiter(
            self._req.bodyfh, limit=total - consumedbyargs
        )

    @contextlib.contextmanager
    def mayberedirectstdio(self):
        """Capture everything written to ``ui.fout`` and ``ui.ferr``.

        Yields the buffer receiving the output. The original streams are put
        back on exit, even when an exception escapes.
        """
        ui = self._ui
        saved = (ui.fout, ui.ferr)
        buf = util.stringio()
        try:
            ui.fout = ui.ferr = buf
            yield buf
        finally:
            ui.fout, ui.ferr = saved

    def client(self):
        """Describe the peer as ``remote:<scheme>:<host>:<user>``."""
        req = self._req
        host = urlreq.quote(req.remotehost or b'')
        user = urlreq.quote(req.remoteuser or b'')
        return b'remote:%s:%s:%s' % (req.urlscheme, host, user)

    def addcapabilities(self, repo, caps):
        """Append HTTP-specific capabilities to ``caps`` and return it."""
        ui = repo.ui
        caps.append(b'batch')
        caps.append(
            b'httpheader=%d' % ui.configint(b'server', b'maxhttpheaderlen')
        )
        if ui.configbool(b'experimental', b'httppostargs'):
            caps.append(b'httppostargs')

        # FUTURE advertise 0.2rx once support is implemented
        # FUTURE advertise minrx and mintx after consulting config option
        caps.append(b'httpmediatype=0.1rx,0.1tx,0.2tx')

        engines = wireprototypes.supportedcompengines(
            ui, compression.SERVERROLE
        )
        if engines:
            names = [urlreq.quote(e.wireprotosupport().name) for e in engines]
            caps.append(b'compression=%s' % b','.join(names))

        return caps

    def checkperm(self, perm):
        """Delegate the permission check to the callable given at creation."""
        return self._checkperm(perm)
167
168
169# This method exists mostly so that extensions like remotefilelog can
170# disable a kludgey legacy method only over http. As of early 2018,
171# there are no other known users, so with any luck we can discard this
172# hook if remotefilelog becomes a first-party extension.
def iscmd(cmd):
    """Return whether ``cmd`` is a registered version 1 wire protocol command."""
    registered = wireprotov1server.commands
    return cmd in registered
175
176
def handlewsgirequest(rctx, req, res, checkperm):
    """Possibly process a wire protocol request.

    A request is handled here only if its ``cmd`` query string parameter
    names a known wire protocol command. Such a request on a sub-path of the
    repository gets an HTTP 404. Otherwise the command is run, and an
    ``ErrorResponse`` raised while running it is turned into an error
    response.

    ``req`` is a ``parsedrequest`` instance.
    ``res`` is a ``wsgiresponse`` instance.

    Returns a bool indicating if the request was serviced. If set, the caller
    should stop processing the request, as a response has already been issued.
    """
    # Avoid cycle involving hg module.
    from .hgweb import common as hgwebcommon

    # HTTP version 1 wire protocol requests are denoted by a "cmd" query
    # string parameter. That parameter is shared with hgweb, so only values
    # that look like known wire protocol commands are routed here. This
    # keeps hgweb from reusing wire protocol command names and is less
    # confusing for machine clients.
    if b'cmd' not in req.qsparams or not iscmd(req.qsparams[b'cmd']):
        return False

    cmd = req.qsparams[b'cmd']
    repo = rctx.repo

    # "cmd" is only valid on the repository root (``/?cmd=foo`` or
    # ``/repo?cmd=foo``), not on paths inside it such as ``/blah?cmd=foo``.
    # Those get an HTTP 404 for backwards compatibility reasons.
    if req.dispatchpath:
        res.status = hgwebcommon.statusmessage(404)
        res.headers[b'Content-Type'] = HGTYPE
        # TODO This is not a good response to issue for this request. This
        # is mostly for BC for now.
        res.setbodybytes(b'0\nNot Found\n')
        return True

    def permcheck(perm):
        return checkperm(rctx, req, perm)

    proto = httpv1protocolhandler(req, repo.ui, permcheck)

    # The permissions checker should be the only thing that can raise an
    # ErrorResponse. Catching an hgweb exception here is a layer violation;
    # consider a dedicated wire protocol exception type instead.
    try:
        _callhttp(repo, req, res, proto, cmd)
    except hgwebcommon.ErrorResponse as e:
        message = stringutil.forcebytestr(e)
        for headername, headervalue in e.headers:
            res.headers[headername] = headervalue
        res.status = hgwebcommon.statusmessage(e.code, message)
        # TODO This response body assumes the failed command was
        # "unbundle." That assumption is not always valid.
        res.setbodybytes(b'0\n%s\n' % message)

    return True
245
246
def _availableapis(repo):
    """Return the set of registered API names enabled for ``repo``.

    Every entry in API_HANDLERS names, under ``b'config'``, the
    (section, option) pair of the boolean config option that enables it.
    """
    ui = repo.ui
    enabled = set()
    for apiname, meta in API_HANDLERS.items():
        section, option = meta[b'config']
        if not ui.configbool(section, option):
            continue
        enabled.add(apiname)
    return enabled
258
259
def handlewsgiapirequest(rctx, req, res, checkperm):
    """Handle requests to /api/*.

    URLs have the form ``/api/<protocol>/...``. A request for ``/api/`` lists
    the enabled APIs. Requests for unknown or disabled APIs, and every URL
    while the experimental API server is off, get a plain-text 404.
    Otherwise the request is passed to the registered handler together with
    the URL path components after ``/api/<protocol>``.
    """
    assert req.dispatchparts[0] == b'api'

    repo = rctx.repo

    def respondtext(status, body):
        res.status = status
        res.headers[b'Content-Type'] = b'text/plain'
        res.setbodybytes(body)

    # This whole URL space is experimental for now. But we want to
    # reserve the URL space. So, 404 all URLs if the feature isn't enabled.
    if not repo.ui.configbool(b'experimental', b'web.apiserver'):
        respondtext(
            b'404 Not Found',
            _(b'Experimental API server endpoint not enabled'),
        )
        return

    availableapis = _availableapis(repo)

    # Requests to /api/ list available APIs.
    if req.dispatchparts == [b'api']:
        lines = [
            _(
                b'APIs can be accessed at /api/<name>, where <name> can be '
                b'one of the following:\n'
            )
        ]
        lines.extend(sorted(availableapis) or [_(b'(no available APIs)\n')])
        respondtext(b'200 OK', b'\n'.join(lines))
        return

    proto = req.dispatchparts[1]

    if proto not in API_HANDLERS:
        respondtext(
            b'404 Not Found',
            _(b'Unknown API: %s\nKnown APIs: %s')
            % (proto, b', '.join(sorted(availableapis))),
        )
        return

    if proto not in availableapis:
        respondtext(b'404 Not Found', _(b'API %s not enabled\n') % proto)
        return

    handler = API_HANDLERS[proto][b'handler']
    handler(rctx, req, res, checkperm, req.dispatchparts[2:])
316
317
# Registry of the services served under /api/<name>/, keyed by API name.
# Custom APIs can be registered by adding entries. Each entry is a dict with
# these keys:
#
# config
#    (section, option) pair naming the boolean config option that controls
#    whether the service is enabled.
# handler
#    Callable receiving (rctx, req, res, checkperm, urlparts) that is called
#    when a request to this API is received; urlparts are the URL path
#    components following /api/<name>.
# apidescriptor
#    Callable receiving (req, repo) that is called to obtain an API
#    descriptor for this service. The response must be serializable to CBOR.
API_HANDLERS = {}
API_HANDLERS[wireprotov2server.HTTP_WIREPROTO_V2] = {
    b'config': (b'experimental', b'web.api.http-v2'),
    b'handler': wireprotov2server.handlehttpv2request,
    b'apidescriptor': wireprotov2server.httpv2apidescriptor,
}
336
337
def _httpresponsetype(ui, proto, prefer_uncompressed):
    """Determine the appropriate response type and compression settings.

    Clients advertising the ``0.2`` protocol capability get
    ``application/mercurial-0.2``. The body is uncompressed when
    ``prefer_uncompressed`` is set; otherwise it uses the first
    server-supported compression engine that the client also accepts. In all
    other cases the legacy media type with zlib is used.

    Returns a tuple of (mediatype, compengine, engineopts).
    """
    if b'0.2' in proto.getprotocaps():
        # All clients are expected to support uncompressed data.
        if prefer_uncompressed:
            return HGTYPE2, compression._noopengine(), {}

        clientformats = wireprotov1server.clientcompressionsupport(proto)
        serverengines = wireprototypes.supportedcompengines(
            ui, compression.SERVERROLE
        )
        for engine in serverengines:
            if engine.wireprotosupport().name not in clientformats:
                continue
            engineopts = {}
            level = ui.configint(b'server', b'%slevel' % engine.name())
            if level is not None:
                engineopts[b'level'] = level
            return HGTYPE2, engine, engineopts

        # No mutually supported compression format: fall through to the
        # legacy protocol below.

    # Don't allow untrusted settings because disabling compression or
    # setting a very high compression level could lead to flooding
    # the server's network or CPU.
    zlibopts = {b'level': ui.configint(b'server', b'zliblevel')}
    return HGTYPE, util.compengines[b'zlib'], zlibopts
372
373
def processcapabilitieshandshake(repo, req, res, proto):
    """Called during a ?cmd=capabilities request.

    If the API server is enabled and the client sends both ``X-HgUpgrade``
    (requested APIs) and ``X-HgProto`` headers including the ``cbor``
    capability, respond with a CBOR map. The map describes the requested APIs
    that are enabled, plus the version 1 capabilities.

    Returns True if a response was written, False if the caller should fall
    back to the version 1 capabilities response.
    """
    # Fall back to old behavior unless the API server is enabled.
    if not repo.ui.configbool(b'experimental', b'web.apiserver'):
        return False

    clientapis = decodevaluefromheaders(req, b'X-HgUpgrade')
    protocaps = decodevaluefromheaders(req, b'X-HgProto')
    if not (clientapis and protocaps):
        return False

    # We currently only support CBOR responses.
    if b'cbor' not in protocaps.split(b' '):
        return False

    requested = set(clientapis.split()) & _availableapis(repo)
    descriptors = {}
    for api in sorted(requested):
        descriptorfn = API_HANDLERS[api].get(b'apidescriptor')
        if descriptorfn:
            descriptors[api] = descriptorfn(req, repo)

    v1caps = wireprotov1server.dispatch(repo, proto, b'capabilities')
    assert isinstance(v1caps, wireprototypes.bytesresponse)

    payload = {
        # TODO allow this to be configurable.
        b'apibase': b'api/',
        b'apis': descriptors,
        b'v1capabilities': v1caps.data,
    }

    res.status = b'200 OK'
    res.headers[b'Content-Type'] = b'application/mercurial-cbor'
    res.setbodybytes(b''.join(cborutil.streamencode(payload)))
    return True
421
422
def _callhttp(repo, req, res, proto, cmd):
    """Run wire protocol command ``cmd`` and write its result into ``res``.

    Commands unavailable over HTTP get an ``application/hg-error`` response.
    Otherwise the command's permission is checked through ``proto`` (which
    may raise), the capabilities handshake for modern clients is attempted,
    and the command's response object is translated into an HTTP response.

    Raises error.ProgrammingError for an unrecognized response type.
    """
    # Avoid cycle involving hg module.
    from .hgweb import common as hgwebcommon

    def withenginepreamble(chunks, engine):
        # application/mercurial-0.2 always sends a payload header
        # identifying the compression engine: one length byte followed by
        # the engine name.
        enginename = engine.wireprotosupport().name
        assert 0 < len(enginename) < 256
        yield struct.pack(b'B', len(enginename))
        yield enginename
        for chunk in chunks:
            yield chunk

    def setresponse(code, contenttype, bodybytes=None, bodygen=None):
        if code == HTTP_OK:
            status = b'200 Script output follows'
        else:
            status = hgwebcommon.statusmessage(code)
        res.status = status
        res.headers[b'Content-Type'] = contenttype
        if bodybytes is not None:
            res.setbodybytes(bodybytes)
        if bodygen is not None:
            res.setbodygen(bodygen)

    commands = wireprotov1server.commands
    if not commands.commandavailable(cmd, proto):
        setresponse(
            HTTP_OK,
            HGERRTYPE,
            _(
                b'requested wire protocol command is not available over '
                b'HTTP'
            ),
        )
        return

    proto.checkperm(commands[cmd].permission)

    # Possibly handle a modern client wanting to switch protocols.
    if cmd == b'capabilities':
        if processcapabilitieshandshake(repo, req, res, proto):
            return

    rsp = wireprotov1server.dispatch(repo, proto, cmd)

    if isinstance(rsp, bytes):
        setresponse(HTTP_OK, HGTYPE, bodybytes=rsp)
    elif isinstance(rsp, wireprototypes.bytesresponse):
        setresponse(HTTP_OK, HGTYPE, bodybytes=rsp.data)
    elif isinstance(rsp, wireprototypes.streamreslegacy):
        setresponse(HTTP_OK, HGTYPE, bodygen=rsp.gen)
    elif isinstance(rsp, wireprototypes.streamres):
        chunks = rsp.gen
        # This code for compression should not be streamres specific. It
        # is here because we only compress streamres at the moment.
        mediatype, engine, engineopts = _httpresponsetype(
            repo.ui, proto, rsp.prefer_uncompressed
        )
        body = engine.compressstream(chunks, engineopts)
        if mediatype == HGTYPE2:
            body = withenginepreamble(body, engine)
        setresponse(HTTP_OK, mediatype, bodygen=body)
    elif isinstance(rsp, wireprototypes.pushres):
        body = b'%d\n%s' % (rsp.res, rsp.output)
        setresponse(HTTP_OK, HGTYPE, bodybytes=body)
    elif isinstance(rsp, wireprototypes.pusherr):
        body = b'0\n%s\n' % rsp.res
        # Ask the response machinery to drain the unread request body.
        res.drain = True
        setresponse(HTTP_OK, HGTYPE, bodybytes=body)
    elif isinstance(rsp, wireprototypes.ooberror):
        setresponse(HTTP_OK, HGERRTYPE, bodybytes=rsp.message)
    else:
        raise error.ProgrammingError(b'hgweb.protocol internal failure', rsp)
504
505
506def _sshv1respondbytes(fout, value):
507    """Send a bytes response for protocol version 1."""
508    fout.write(b'%d\n' % len(value))
509    fout.write(value)
510    fout.flush()
511
512
513def _sshv1respondstream(fout, source):
514    write = fout.write
515    for chunk in source.gen:
516        write(chunk)
517    fout.flush()
518
519
520def _sshv1respondooberror(fout, ferr, rsp):
521    ferr.write(b'%s\n-\n' % rsp)
522    ferr.flush()
523    fout.write(b'\n')
524    fout.flush()
525
526
@interfaceutil.implementer(wireprototypes.baseprotocolhandler)
class sshv1protocolhandler(object):
    """Protocol handler serving requests over version 1 of the SSH protocol."""

    def __init__(self, ui, fin, fout):
        self._ui = ui
        self._fin = fin
        self._fout = fout
        self._protocaps = set()

    @property
    def name(self):
        return wireprototypes.SSHV1

    def getargs(self, args):
        """Read the arguments named in ``args`` from the input stream.

        Each argument arrives as a ``<name> <length>`` line followed by
        ``length`` raw bytes. For ``*`` the number on the line is instead the
        count of nested name/value records that follow; they are collected
        into a dict.

        Raises error.Abort if the client sends an argument that was not
        requested.
        """
        fin = self._fin
        wanted = args.split()
        received = {}

        def readheader():
            # Strip the trailing newline and split into (name, number).
            return fin.readline()[:-1].split()

        for _i in pycompat.xrange(len(wanted)):
            argname, size = readheader()
            if argname not in wanted:
                raise error.Abort(_(b"unexpected parameter %r") % argname)
            if argname != b'*':
                received[argname] = fin.read(int(size))
                continue
            extra = {}
            for _j in pycompat.xrange(int(size)):
                subname, subsize = readheader()
                extra[subname] = fin.read(int(subsize))
            received[b'*'] = extra
        return [received[argname] for argname in wanted]

    def getprotocaps(self):
        return self._protocaps

    def getpayload(self):
        """Yield the chunks of a payload uploaded by the client.

        An empty response is sent first: it tells the client it may start
        sending, and the client treats any other response as an error. The
        upload is a series of records, each a chunk size line followed by
        that many bytes, terminated by a size of 0. Being a generator,
        nothing is sent or read until iteration begins.
        """
        _sshv1respondbytes(self._fout, b'')
        while True:
            size = int(self._fin.readline())
            if not size:
                break
            yield self._fin.read(size)

    @contextlib.contextmanager
    def mayberedirectstdio(self):
        # Output is not captured for SSH; nothing is redirected.
        yield None

    def client(self):
        """Describe the peer as ``remote:ssh:<address>``.

        The address is the first space-separated field of the ``SSH_CLIENT``
        environment variable, or empty when the variable is unset.
        """
        sshclient = encoding.environ.get(b'SSH_CLIENT', b'')
        address = sshclient.split(b' ', 1)[0]
        return b'remote:ssh:' + address

    def addcapabilities(self, repo, caps):
        """Append SSH-specific capabilities to ``caps`` and return it."""
        isv1 = self.name == wireprototypes.SSHV1
        if isv1:
            caps.append(b'protocaps')
        caps.append(b'batch')
        return caps

    def checkperm(self, perm):
        # No permission check is performed for SSH requests here.
        pass
597
598
class sshv2protocolhandler(sshv1protocolhandler):
    """Protocol handler for version 2 of the SSH protocol.

    Behaves like the version 1 handler except for its name and for leaving
    the capability list untouched.
    """

    @property
    def name(self):
        return wireprototypes.SSHV2

    def addcapabilities(self, repo, caps):
        # Version 2 adds no transport-specific capabilities.
        return caps
608
609
def _runsshserver(ui, repo, fin, fout, ev):
    """Serve wire protocol requests read from ``fin`` until shutdown.

    Responses are written to ``fout``; out-of-band errors go to ``ui.ferr``.
    The loop stops when ``ev`` is set, when an empty request line is read
    (including end of input), or when a protocol upgrade fails.

    Raises error.ProgrammingError if a command returns a response type this
    server cannot send, or if the state machine enters an unknown state.
    """
    # This function operates like a state machine of sorts. The following
    # states are defined:
    #
    # protov1-serving
    #    Server is in protocol version 1 serving mode. Commands arrive on
    #    new lines. These commands are processed in this state, one command
    #    after the other.
    #
    # protov2-serving
    #    Server is in protocol version 2 serving mode.
    #
    # upgrade-initial
    #    The server is going to process an upgrade request.
    #
    # upgrade-v2-filter-legacy-handshake
    #    The protocol is being upgraded to version 2. The server is expecting
    #    the legacy handshake from version 1.
    #
    # upgrade-v2-finish
    #    The upgrade to version 2 of the protocol is imminent.
    #
    # shutdown
    #    The server is shutting down, possibly in reaction to a client event.
    #
    # And here are their transitions:
    #
    # protov1-serving -> shutdown
    #    When server receives an empty request or encounters another
    #    error.
    #
    # protov1-serving -> upgrade-initial
    #    An upgrade request line was seen.
    #
    # upgrade-initial -> upgrade-v2-filter-legacy-handshake
    #    Upgrade to version 2 in progress. Server is expecting to
    #    process a legacy handshake.
    #
    # upgrade-v2-filter-legacy-handshake -> shutdown
    #    Client did not fulfill upgrade handshake requirements.
    #
    # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
    #    Client fulfilled version 2 upgrade requirements. Finishing that
    #    upgrade.
    #
    # upgrade-v2-finish -> protov2-serving
    #    Protocol upgrade to version 2 complete. Server can now speak protocol
    #    version 2.
    #
    # protov2-serving -> protov1-serving
    #    This happens by default since protocol version 2 is the same as
    #    version 1 except for the handshake.

    state = b'protov1-serving'
    proto = sshv1protocolhandler(ui, fin, fout)
    protoswitched = False

    while not ev.is_set():
        if state == b'protov1-serving':
            # Commands are issued on new lines.
            request = fin.readline()[:-1]

            # Empty lines signal to terminate the connection.
            if not request:
                state = b'shutdown'
                continue

            # It looks like a protocol upgrade request. Transition state to
            # handle it.
            if request.startswith(b'upgrade '):
                if protoswitched:
                    _sshv1respondooberror(
                        fout,
                        ui.ferr,
                        b'cannot upgrade protocols multiple times',
                    )
                    state = b'shutdown'
                    continue

                state = b'upgrade-initial'
                continue

            available = wireprotov1server.commands.commandavailable(
                request, proto
            )

            # This command isn't available. Send an empty response and go
            # back to waiting for a new command.
            if not available:
                _sshv1respondbytes(fout, b'')
                continue

            rsp = wireprotov1server.dispatch(repo, proto, request)
            repo.ui.fout.flush()
            repo.ui.ferr.flush()

            if isinstance(rsp, bytes):
                _sshv1respondbytes(fout, rsp)
            elif isinstance(rsp, wireprototypes.bytesresponse):
                _sshv1respondbytes(fout, rsp.data)
            elif isinstance(rsp, wireprototypes.streamres):
                _sshv1respondstream(fout, rsp)
            elif isinstance(rsp, wireprototypes.streamreslegacy):
                _sshv1respondstream(fout, rsp)
            elif isinstance(rsp, wireprototypes.pushres):
                _sshv1respondbytes(fout, b'')
                _sshv1respondbytes(fout, b'%d' % rsp.res)
            elif isinstance(rsp, wireprototypes.pusherr):
                _sshv1respondbytes(fout, rsp.res)
            elif isinstance(rsp, wireprototypes.ooberror):
                _sshv1respondooberror(fout, ui.ferr, rsp.message)
            else:
                # forcebytestr(): interpolating an arbitrary object into
                # bytes with %s raises TypeError on Python 3, which would
                # mask this error.
                raise error.ProgrammingError(
                    b'unhandled response type from '
                    b'wire protocol command: %s' % stringutil.forcebytestr(rsp)
                )

        # For now, protocol version 2 serving just goes back to version 1.
        elif state == b'protov2-serving':
            state = b'protov1-serving'
            continue

        elif state == b'upgrade-initial':
            # We should never transition into this state if we've switched
            # protocols.
            assert not protoswitched
            assert proto.name == wireprototypes.SSHV1

            # Expected: upgrade <token> <capabilities>
            # If we get something else, the request is malformed. It could be
            # from a future client that has altered the upgrade line content.
            # We treat this as an unknown command.
            try:
                token, caps = request.split(b' ')[1:]
            except ValueError:
                _sshv1respondbytes(fout, b'')
                state = b'protov1-serving'
                continue

            # Send empty response if we don't support upgrading protocols.
            if not ui.configbool(b'experimental', b'sshserver.support-v2'):
                _sshv1respondbytes(fout, b'')
                state = b'protov1-serving'
                continue

            try:
                caps = urlreq.parseqs(caps)
            except ValueError:
                _sshv1respondbytes(fout, b'')
                state = b'protov1-serving'
                continue

            # We don't see an upgrade request to protocol version 2. Ignore
            # the upgrade request.
            wantedprotos = caps.get(b'proto', [b''])[0]
            if SSHV2 not in wantedprotos:
                _sshv1respondbytes(fout, b'')
                state = b'protov1-serving'
                continue

            # It looks like we can honor this upgrade request to protocol 2.
            # Filter the rest of the handshake protocol request lines.
            state = b'upgrade-v2-filter-legacy-handshake'
            continue

        elif state == b'upgrade-v2-filter-legacy-handshake':
            # Client should have sent legacy handshake after an ``upgrade``
            # request. Expected lines:
            #
            #    hello
            #    between
            #    pairs 81
            #    0000...-0000...

            ok = True
            for line in (b'hello', b'between', b'pairs 81'):
                request = fin.readline()[:-1]

                if request != line:
                    _sshv1respondooberror(
                        fout,
                        ui.ferr,
                        b'malformed handshake protocol: missing %s' % line,
                    )
                    ok = False
                    state = b'shutdown'
                    break

            if not ok:
                continue

            # The "pairs" value: two 40-character null nodes joined by "-".
            request = fin.read(81)
            if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
                _sshv1respondooberror(
                    fout,
                    ui.ferr,
                    b'malformed handshake protocol: '
                    b'missing between argument value',
                )
                state = b'shutdown'
                continue

            state = b'upgrade-v2-finish'
            continue

        elif state == b'upgrade-v2-finish':
            # Send the upgrade response. ``token`` was bound in the
            # upgrade-initial state, which always precedes this one.
            fout.write(b'upgraded %s %s\n' % (token, SSHV2))
            servercaps = wireprotov1server.capabilities(repo, proto)
            rsp = b'capabilities: %s' % servercaps.data
            fout.write(b'%d\n%s\n' % (len(rsp), rsp))
            fout.flush()

            proto = sshv2protocolhandler(ui, fin, fout)
            protoswitched = True

            state = b'protov2-serving'
            continue

        elif state == b'shutdown':
            break

        else:
            raise error.ProgrammingError(
                b'unhandled ssh server state: %s' % state
            )
836
837
class sshserver(object):
    """Serve the SSH wire protocol for ``repo`` over the ui's I/O streams.

    The ui's input/output streams are protected via ``ui.protectfinout()``
    at construction and handed back via ``ui.restorefinout()`` once
    ``serve_forever()`` finishes.
    """

    def __init__(self, ui, repo, logfh=None):
        self._ui = ui
        self._repo = repo
        self._fin, self._fout = ui.protectfinout()
        # Keep the exact output stream protectfinout() returned: that is the
        # stream restorefinout() must be given back, not the logging wrapper
        # that may replace self._fout below.
        self._protectedfout = self._fout

        # Log write I/O to stdout and stderr if configured.
        if logfh:
            self._fout = util.makeloggingfileobject(
                logfh, self._fout, b'o', logdata=True
            )
            ui.ferr = util.makeloggingfileobject(
                logfh, ui.ferr, b'e', logdata=True
            )

    def serve_forever(self):
        """Serve until the server loop shuts down, then restore ui streams.

        The streams are restored even if serving raises.
        """
        try:
            self.serveuntil(threading.Event())
        finally:
            self._ui.restorefinout(self._fin, self._protectedfout)

    def serveuntil(self, ev):
        """Serve until a threading.Event is set."""
        _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
860