1from __future__ import print_function
3import hashlib
4import os
5import shutil
6import sys
7import tempfile
8import zipfile
10import six
11from humanize import naturalsize
12from tqdm import tqdm
13from twisted.internet import reactor
14from twisted.internet.defer import inlineCallbacks, returnValue
15from twisted.python import log
16from wormhole import __version__, create, input_with_completion
18from ..errors import TransferError
19from ..transit import TransitReceiver
20from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
21                    estimate_free_space)
22from .welcome import handle_welcome
# Default application ID; Receiver.go() uses it unless args.appid is set.
APPID = u"lothar.com/wormhole/text-or-file-xfer"

# Seconds to wait before printing "still waiting" status messages. The
# environment variables (presumably meant for the test suite, judging by
# their names) override the 1.0-second default.
KEY_TIMER = float(os.getenv("_MAGIC_WORMHOLE_TEST_KEY_TIMER", 1.0))
VERIFY_TIMER = float(os.getenv("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0))
class RespondError(Exception):
    """An error whose ``response`` text should be reported to the sender.

    Receiver._go catches it, sends ``{"error": response}`` to the peer and
    then raises TransferError(response) locally.
    """

    def __init__(self, response):
        # Pass the response to Exception too, so str(exc), repr(exc) and
        # exc.args carry the message instead of being empty.
        Exception.__init__(self, response)
        self.response = response
class TransferRejectedError(RespondError):
    """Raised when the receiving side declines an offer.

    Used for user refusal, insufficient free space, and refusing to
    overwrite an existing destination. The response reported to the
    sender is always "transfer rejected".
    """

    def __init__(self):
        super(TransferRejectedError, self).__init__("transfer rejected")
def receive(args, reactor=reactor, _debug_stash_wormhole=None):
    """Implement 'wormhole receive'.

    Returns a Deferred that fires with None when the transfer succeeds, or
    errbacks with one of:

    * WrongPasswordError: the two sides didn't use matching passwords
    * Timeout: something didn't happen fast enough for our tastes
    * TransferError: the sender rejected the transfer: verifier mismatch
    * any other error: something unexpected happened

    If ``_debug_stash_wormhole`` is a list, the receiver's wormhole object
    is appended to it so that tests can observe its events.
    """
    receiver = Receiver(args, reactor)
    outcome = receiver.go()
    if _debug_stash_wormhole is not None:
        # NOTE: with args.tor set, go() waits for Tor before creating the
        # wormhole, so it may not have been created yet at this point.
        _debug_stash_wormhole.append(receiver._w)
    return outcome
class Receiver:
    """Receiving side of 'wormhole receive'.

    Connects to the relay (optionally through Tor), obtains the wormhole
    code, waits for key agreement and verification, then handles the
    sender's offer: a text message, a single file, or a directory that is
    transferred as a zipfile and unpacked locally.
    """

    def __init__(self, args, reactor=reactor):
        # args is the parsed command-line namespace; relay_url must be text
        assert isinstance(args.relay_url, type(u""))
        self.args = args
        self._reactor = reactor
        self._tor = None
        self._transit_receiver = None
        # The wormhole is created by go(); start with None so the attribute
        # always exists (receive() reads it right after calling go()).
        self._w = None

    def _msg(self, *args, **kwargs):
        """print() to the configured stderr stream (self.args.stderr)."""
        print(*args, file=self.args.stderr, **kwargs)

    @inlineCallbacks
    def go(self):
        """Run the whole receive operation, always closing the wormhole.

        Returns a Deferred that fires with None on success. On failure the
        original error is propagated, even if closing the wormhole fails too.
        """
        if self.args.tor:
            with self.args.timing.add("import", which="tor_manager"):
                from ..tor_manager import get_tor
            # For now, block everything until Tor has started. Soon: launch
            # tor in parallel with everything else, make sure the Tor object
            # can lazy-provide an endpoint, and overlap the startup process
            # with the user handing off the wormhole code
            self._tor = yield get_tor(
                self._reactor,
                self.args.launch_tor,
                self.args.tor_control_port,
                timing=self.args.timing)

        w = create(
            self.args.appid or APPID,
            self.args.relay_url,
            self._reactor,
            tor=self._tor,
            timing=self.args.timing)
        self._w = w  # so tests can wait on events too

        # I wanted to do this instead:
        #
        #    try:
        #        yield self._go(w, tor)
        #    finally:
        #        yield w.close()
        #
        # but when _go had a UsageError, the stacktrace was always displayed
        # as coming from the "yield self._go" line, which wasn't very useful
        # for tracking it down.
        d = self._go(w)

        # if we succeed, we should close and return the w.close results
        # (which might be an error)
        @inlineCallbacks
        def _good(res):
            yield w.close()  # wait for ack
            returnValue(res)

        # if we raise an error, we should close and then return the original
        # error (the close might give us an error, but it isn't as important
        # as the original one)
        @inlineCallbacks
        def _bad(f):
            try:
                yield w.close()  # might be an error too
            except Exception:
                pass
            returnValue(f)

        d.addCallbacks(_good, _bad)
        yield d

    @inlineCallbacks
    def _go(self, w):
        """Handle code entry, key exchange and the sender's messages on w."""
        welcome = yield w.get_welcome()
        handle_welcome(welcome, self.args.relay_url, __version__,
                       self.args.stderr)

        yield self._handle_code(w)

        def on_slow_key():
            print(u"Waiting for sender...", file=self.args.stderr)

        notify = self._reactor.callLater(KEY_TIMER, on_slow_key)
        try:
            # We wait here until we connect to the server and see the senders
            # PAKE message. If we used set_code() in the "human-selected
            # offline codes" mode, then the sender might not have even
            # started yet, so we might be sitting here for a while. Because
            # of that possibility, it's probably not appropriate to give up
            # automatically after some timeout. The user can express their
            # impatience by quitting the program with control-C.
            yield w.get_unverified_key()
        finally:
            if not notify.called:
                notify.cancel()

        def on_slow_verification():
            print(
                u"Key established, waiting for confirmation...",
                file=self.args.stderr)

        notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification)
        try:
            # We wait here until we've seen their VERSION message (which they
            # send after seeing our PAKE message, and has the side-effect of
            # verifying that we both share the same key). There is a
            # round-trip between these two events, and we could experience a
            # significant delay here if:
            # * the relay server is being restarted
            # * the network is very slow
            # * the sender is very slow
            # * the sender has quit (in which case we may wait forever)

            # It would be reasonable to give up after waiting here for too
            # long.
            verifier_bytes = yield w.get_verifier()
        finally:
            if not notify.called:
                notify.cancel()
        self._show_verifier(verifier_bytes)

        want_offer = True

        while True:
            them_d = yield self._get_data(w)
            # print("GOT", them_d)
            recognized = False
            if u"transit" in them_d:
                recognized = True
                yield self._parse_transit(them_d[u"transit"], w)
            if u"offer" in them_d:
                recognized = True
                if not want_offer:
                    raise TransferError("duplicate offer")
                want_offer = False
                try:
                    yield self._parse_offer(them_d[u"offer"], w)
                except RespondError as r:
                    # tell the sender why we are giving up, then fail locally
                    self._send_data({"error": r.response}, w)
                    raise TransferError(r.response)
                returnValue(None)
            if not recognized:
                log.msg("unrecognized message %r" % (them_d, ))

    def _send_data(self, data, w):
        """Serialize the dict ``data`` and send it as one wormhole message."""
        data_bytes = dict_to_bytes(data)
        w.send_message(data_bytes)

    @inlineCallbacks
    def _get_data(self, w):
        """Wait for the next message and decode it into a dict.

        Raises TransferError if the sender reported an "error".
        """
        # this may raise WrongPasswordError
        them_bytes = yield w.get_message()
        them_d = bytes_to_dict(them_bytes)
        if "error" in them_d:
            raise TransferError(them_d["error"])
        returnValue(them_d)

    @inlineCallbacks
    def _handle_code(self, w):
        """Set (or interactively read) the code, then wait until it is known."""
        code = self.args.code
        if self.args.zeromode:
            assert not code
            code = u"0-"
        if code:
            w.set_code(code)
        else:
            prompt = "Enter receive wormhole code: "
            used_completion = yield input_with_completion(
                prompt, w.input_code(), self._reactor)
            if not used_completion:
                print(
                    " (note: you can use <Tab> to complete words)",
                    file=self.args.stderr)
        yield w.get_code()

    def _show_verifier(self, verifier_bytes):
        """Print the verifier as hex when --verify was requested."""
        verifier_hex = bytes_to_hexstr(verifier_bytes)
        if self.args.verify:
            self._msg(u"Verifier %s." % verifier_hex)

    @inlineCallbacks
    def _parse_transit(self, sender_transit, w):
        """Build the TransitReceiver from the sender's first transit message."""
        if self._transit_receiver:
            # TODO: accept multiple messages, add the additional hints to the
            # existing TransitReceiver
            return
        yield self._build_transit(w, sender_transit)

    @inlineCallbacks
    def _build_transit(self, w, sender_transit):
        """Create a TransitReceiver, key it, and send our hints to the peer."""
        tr = TransitReceiver(
            self.args.transit_helper,
            no_listen=(not self.args.listen),
            tor=self._tor,
            reactor=self._reactor,
            timing=self.args.timing)
        self._transit_receiver = tr
        # When I made it possible to override APPID with a CLI argument
        # (issue #113), I forgot to also change this w.derive_key() (issue
        # #339). We're stuck with it now. Use a local constant to make this
        # clear.
        BUG339_APPID = u"lothar.com/wormhole/text-or-file-xfer"
        transit_key = w.derive_key(BUG339_APPID + u"/transit-key",
                                   tr.TRANSIT_KEY_LENGTH)
        tr.set_transit_key(transit_key)

        tr.add_connection_hints(sender_transit.get("hints-v1", []))
        receiver_abilities = tr.get_connection_abilities()
        receiver_hints = yield tr.get_connection_hints()
        receiver_transit = {
            "abilities-v1": receiver_abilities,
            "hints-v1": receiver_hints,
        }
        self._send_data({u"transit": receiver_transit}, w)
        # TODO: send more hints as the TransitReceiver produces them

    @inlineCallbacks
    def _parse_offer(self, them_d, w):
        """Handle the sender's offer: a text message, a file, or a directory.

        Raises RespondError (or TransferRejectedError) when the offer is
        refused; _go reports that response back to the sender.
        """
        if "message" in them_d:
            self._handle_text(them_d, w)
            returnValue(None)
        # transit will be created by this point, but not connected
        if "file" in them_d:
            f = self._handle_file(them_d)
            write = self._write_file
        elif "directory" in them_d:
            f = self._handle_directory(them_d)
            write = self._write_directory
        else:
            self._msg(u"I don't know what they're offering\n")
            self._msg(u"Offer details: %r" % (them_d, ))
            raise RespondError("unknown offer type")
        yield self._receive_payload(w, f, write)

    @inlineCallbacks
    def _receive_payload(self, w, f, write):
        """Accept the offer, stream the payload into f, then finish it.

        ``write`` (_write_file or _write_directory) takes ownership of f
        once all the data has arrived. If anything fails before that, f is
        closed here so the temporary file handle is not leaked.
        """
        try:
            self._send_permission(w)
            rp = yield self._establish_transit()
            datahash = yield self._transfer_data(rp, f)
        except Exception:
            f.close()
            raise
        write(f)
        yield self._close_transit(rp, datahash)

    def _handle_text(self, them_d, w):
        """Print a received text message to stdout and acknowledge it."""
        # we're receiving a text message
        self.args.timing.add("print")
        print(them_d["message"], file=self.args.stdout)
        self._send_data({"answer": {"message_ack": "ok"}}, w)

    def _handle_file(self, them_d):
        """Validate a file offer, ask permission, and open the temp file.

        Returns the file object for "<destination>.tmp", opened "wb".
        """
        file_data = them_d["file"]
        self.abs_destname = self._decide_destname("file",
                                                  file_data["filename"])
        self.xfersize = file_data["filesize"]
        free = estimate_free_space(self.abs_destname)
        if free is not None and free < self.xfersize:
            self._msg(u"Error: insufficient free space (%sB) for file (%sB)" %
                      (free, self.xfersize))
            raise TransferRejectedError()

        self._msg(u"Receiving file (%s) into: %s" %
                  (naturalsize(self.xfersize),
                   os.path.basename(self.abs_destname)))
        self._ask_permission()
        tmp_destname = self.abs_destname + ".tmp"
        return open(tmp_destname, "wb")

    def _handle_directory(self, them_d):
        """Validate a directory offer, ask permission, return a spool file.

        The zipfile is buffered in a SpooledTemporaryFile until it has been
        fully received.
        """
        file_data = them_d["directory"]
        zipmode = file_data["mode"]
        if zipmode != "zipfile/deflated":
            self._msg(u"Error: unknown directory-transfer mode '%s'" %
                      (zipmode, ))
            raise RespondError("unknown mode")
        self.abs_destname = self._decide_destname("directory",
                                                  file_data["dirname"])
        self.xfersize = file_data["zipsize"]
        free = estimate_free_space(self.abs_destname)
        if free is not None and free < file_data["numbytes"]:
            self._msg(
                u"Error: insufficient free space (%sB) for directory (%sB)" %
                (free, file_data["numbytes"]))
            raise TransferRejectedError()

        self._msg(u"Receiving directory (%s) into: %s/" %
                  (naturalsize(self.xfersize),
                   os.path.basename(self.abs_destname)))
        self._msg(u"%d files, %s (uncompressed)" %
                  (file_data["numfiles"], naturalsize(file_data["numbytes"])))
        self._ask_permission()
        f = tempfile.SpooledTemporaryFile()
        # workaround for https://bugs.python.org/issue26175 (STF doesn't
        # fully implement IOBase abstract class), which breaks the new
        # zipfile in py3.7.0 that expects .seekable
        if not hasattr(f, "seekable"):
            # AFAICT all the filetypes that STF wraps can seek
            f.seekable = lambda: True
        return f

    def _decide_destname(self, mode, destname):
        """Return the absolute local path to write the received data to.

        Uses --output-file when given, else the basename of the sender's
        name inside args.cwd. Refuses to overwrite an existing path unless
        --output-file was given explicitly.
        """
        # the basename() is intended to protect us against
        # "~/.ssh/authorized_keys" and other attacks
        destname = os.path.basename(destname)
        if self.args.output_file:
            destname = self.args.output_file  # override
        abs_destname = os.path.abspath(os.path.join(self.args.cwd, destname))

        # get confirmation from the user before writing to the local directory
        if os.path.exists(abs_destname):
            if self.args.output_file:  # overwrite is intentional
                self._msg(u"Overwriting '%s'" % destname)
                if self.args.accept_file:
                    self._remove_existing(abs_destname)
            else:
                self._msg(
                    u"Error: refusing to overwrite existing '%s'" % destname)
                raise TransferRejectedError()
        return abs_destname

    def _remove_existing(self, path):
        """Delete an existing file or directory tree at path."""
        if os.path.isfile(path):
            os.remove(path)
        if os.path.isdir(path):
            shutil.rmtree(path)

    def _ask_permission(self):
        """Confirm the transfer with the user unless --accept-file was given.

        An empty answer, or one starting with "y"/"Y", accepts and removes
        anything already at the destination. Any other answer raises
        TransferRejectedError.
        """
        with self.args.timing.add("permission", waiting="user") as t:
            if not self.args.accept_file:
                ok = six.moves.input("ok? (Y/n): ")
                if not (ok.lower().startswith("y") or len(ok) == 0):
                    # report on the configured stderr, like every other
                    # message from this class
                    self._msg(u"transfer rejected")
                    t.detail(answer="no")
                    raise TransferRejectedError()
                if os.path.exists(self.abs_destname):
                    self._remove_existing(self.abs_destname)
            t.detail(answer="yes")

    def _send_permission(self, w):
        """Tell the sender we accept the file/directory offer."""
        self._send_data({"answer": {"file_ack": "ok"}}, w)

    @inlineCallbacks
    def _establish_transit(self):
        """Connect the TransitReceiver; returns the connected record pipe."""
        record_pipe = yield self._transit_receiver.connect()
        self.args.timing.add("transit connected")
        returnValue(record_pipe)

    @inlineCallbacks
    def _transfer_data(self, record_pipe, f):
        """Receive self.xfersize bytes into f, showing a progress bar.

        Returns the SHA-256 digest of the received bytes. Raises
        TransferError if the connection drops before everything arrives.
        """
        # now receive the rest of the owl
        self._msg(u"Receiving (%s).." % record_pipe.describe())

        with self.args.timing.add("rx file"):
            progress = tqdm(
                file=self.args.stderr,
                disable=self.args.hide_progress,
                unit="B",
                unit_scale=True,
                total=self.xfersize)
            hasher = hashlib.sha256()
            with progress:
                received = yield record_pipe.writeToFile(
                    f, self.xfersize, progress.update, hasher.update)
            datahash = hasher.digest()

        # except TransitError
        if received < self.xfersize:
            self._msg()
            self._msg(u"Connection dropped before full file received")
            self._msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
            raise TransferError("Connection dropped before full file received")
        assert received == self.xfersize
        returnValue(datahash)

    def _write_file(self, f):
        """Close the temp file and rename it to the final destination."""
        tmp_name = f.name
        f.close()
        os.rename(tmp_name, self.abs_destname)
        self._msg(u"Received file written to %s" % os.path.basename(
            self.abs_destname))

    def _extract_file(self, zf, info, extract_dir):
        """Extract one zipfile member into extract_dir, restoring its mode.

        The zipfile module does not restore file permissions, so the Unix
        mode stored in the member's external_attr is applied manually.

        Raises ValueError if the member would land outside extract_dir.
        """
        extract_dir = os.path.abspath(extract_dir)
        out_path = os.path.abspath(os.path.join(extract_dir, info.filename))
        # Compare against extract_dir plus a trailing separator: a bare
        # prefix test would accept escapes into siblings such as
        # "<extract_dir>-other/...".
        inside = os.path.join(extract_dir, "")
        if out_path != extract_dir and not out_path.startswith(inside):
            raise ValueError(
                "malicious zipfile, %s outside of extract_dir %s" %
                (info.filename, extract_dir))

        # extract() sanitizes the member name and returns the path it really
        # wrote; chmod that path rather than our own computation of it.
        written = zf.extract(info, path=extract_dir)

        # Unix mode bits live in the high 16 bits of external_attr. Keep only
        # the rwx bits (drop file-type and setuid/setgid/sticky bits chosen
        # by the remote sender), and keep the default mode when the archive
        # recorded none: applying 0 would make the file unreadable.
        perm = (info.external_attr >> 16) & 0o777
        if perm:
            os.chmod(written, perm)

    def _write_directory(self, f):
        """Unpack the received zipfile f into self.abs_destname; closes f."""
        self._msg(u"Unpacking zipfile..")
        with self.args.timing.add("unpack zip"):
            try:
                with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
                    for info in zf.infolist():
                        self._extract_file(zf, info, self.abs_destname)
            finally:
                # release the spooled temp file even if extraction fails
                f.close()
            self._msg(u"Received files written to %s/" % os.path.basename(
                self.abs_destname))

    @inlineCallbacks
    def _close_transit(self, record_pipe, datahash):
        """Send the ack record (with our SHA-256 hex) and close the pipe."""
        datahash_hex = bytes_to_hexstr(datahash)
        ack = {u"ack": u"ok", u"sha256": datahash_hex}
        ack_bytes = dict_to_bytes(ack)
        with self.args.timing.add("send ack"):
            yield record_pipe.send_record(ack_bytes)
            yield record_pipe.close()