1from __future__ import print_function
2
3import hashlib
4import os
5import shutil
6import sys
7import tempfile
8import zipfile
9
10import six
11from humanize import naturalsize
12from tqdm import tqdm
13from twisted.internet import reactor
14from twisted.internet.defer import inlineCallbacks, returnValue
15from twisted.python import log
16from wormhole import __version__, create, input_with_completion
17
18from ..errors import TransferError
19from ..transit import TransitReceiver
20from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
21                    estimate_free_space)
22from .welcome import handle_welcome
23
# Default application ID, used by Receiver.go() when args.appid is not set.
APPID = u"lothar.com/wormhole/text-or-file-xfer"

# Seconds to wait before printing a "still waiting" hint while the key
# exchange (KEY_TIMER) or the verification round-trip (VERIFY_TIMER) is
# pending. Both can be overridden through environment variables whose names
# suggest they exist for the test suite.
KEY_TIMER, VERIFY_TIMER = (
    float(os.environ.get(env_name, 1.0))
    for env_name in ("_MAGIC_WORMHOLE_TEST_KEY_TIMER",
                     "_MAGIC_WORMHOLE_TEST_VERIFY_TIMER"))
28
29
class RespondError(Exception):
    """Abort handling of an offer and tell the sender why.

    ``response`` is the text sent back to the peer inside an ``error``
    message; the receive then fails with TransferError(response).
    """

    def __init__(self, response):
        super(RespondError, self).__init__(response)
        self.response = response
33
34
class TransferRejectedError(RespondError):
    """RespondError for a transfer refused on the receiving side.

    Raised when the user declines, the destination already exists, or there
    is not enough free space; the sender is told "transfer rejected".
    """

    def __init__(self):
        super(TransferRejectedError, self).__init__("transfer rejected")
38
39
def receive(args, reactor=reactor, _debug_stash_wormhole=None):
    """Implement 'wormhole receive'.

    Returns a Deferred that fires with None on success, or fails with one
    of the following errors:

    * WrongPasswordError: the two sides didn't use matching passwords
    * Timeout: something didn't happen fast enough for our tastes
    * TransferError: the sender rejected the transfer (e.g. verifier
      mismatch), the offer was refused locally, or the connection dropped
      before the whole payload arrived
    * any other error: something unexpected happened

    When ``_debug_stash_wormhole`` is given (a list), the Receiver's wormhole
    object is appended to it so tests can wait on its events too.
    """
    receiver = Receiver(args, reactor)
    result = receiver.go()
    if _debug_stash_wormhole is not None:
        # NOTE(review): go() only assigns ``_w`` once it reaches create();
        # with args.tor it first waits on get_tor(), so ``_w`` may not exist
        # yet at this point.
        _debug_stash_wormhole.append(receiver._w)
    return result
53
54
class Receiver:
    """Receiving side of a 'wormhole receive' transfer.

    Drives the wormhole protocol (code entry, key exchange, verification),
    then handles the sender's offer: a text message is printed to
    ``args.stdout``; a file or directory is fetched over a transit
    connection and written to a path resolved against ``args.cwd``
    (``args.output_file`` overrides the sender-supplied name).
    """

    def __init__(self, args, reactor=reactor):
        assert isinstance(args.relay_url, type(u""))
        self.args = args
        self._reactor = reactor
        self._tor = None
        self._transit_receiver = None

    def _msg(self, *args, **kwargs):
        """print() to the user-facing stderr stream in ``self.args``."""
        print(*args, file=self.args.stderr, **kwargs)

    @inlineCallbacks
    def go(self):
        """Run the whole receive operation, closing the wormhole afterwards.

        On success, waits for w.close() (whose error, if any, propagates).
        On failure, closes the wormhole best-effort and propagates the
        original error rather than any error from close().
        """
        if self.args.tor:
            with self.args.timing.add("import", which="tor_manager"):
                from ..tor_manager import get_tor
            # For now, block everything until Tor has started. Soon: launch
            # tor in parallel with everything else, make sure the Tor object
            # can lazy-provide an endpoint, and overlap the startup process
            # with the user handing off the wormhole code
            self._tor = yield get_tor(
                self._reactor,
                self.args.launch_tor,
                self.args.tor_control_port,
                timing=self.args.timing)

        w = create(
            self.args.appid or APPID,
            self.args.relay_url,
            self._reactor,
            tor=self._tor,
            timing=self.args.timing)
        self._w = w  # so tests can wait on events too

        # I wanted to do this instead:
        #
        #    try:
        #        yield self._go(w, tor)
        #    finally:
        #        yield w.close()
        #
        # but when _go had a UsageError, the stacktrace was always displayed
        # as coming from the "yield self._go" line, which wasn't very useful
        # for tracking it down.
        d = self._go(w)

        # if we succeed, we should close and return the w.close results
        # (which might be an error)
        @inlineCallbacks
        def _good(res):
            yield w.close()  # wait for ack
            returnValue(res)

        # if we raise an error, we should close and then return the original
        # error (the close might give us an error, but it isn't as important
        # as the original one)
        @inlineCallbacks
        def _bad(f):
            try:
                yield w.close()  # might be an error too
            except Exception:
                pass
            returnValue(f)

        d.addCallbacks(_good, _bad)
        yield d

    @inlineCallbacks
    def _go(self, w):
        """Run the protocol on an already-created wormhole ``w``.

        Shows the welcome message, obtains the code, waits for the key and
        the verifier, then processes incoming messages until one offer has
        been handled. A RespondError from the offer handler is reported to
        the sender and re-raised as TransferError.
        """
        welcome = yield w.get_welcome()
        handle_welcome(welcome, self.args.relay_url, __version__,
                       self.args.stderr)

        yield self._handle_code(w)

        def on_slow_key():
            print(u"Waiting for sender...", file=self.args.stderr)

        notify = self._reactor.callLater(KEY_TIMER, on_slow_key)
        try:
            # We wait here until we connect to the server and see the senders
            # PAKE message. If we used set_code() in the "human-selected
            # offline codes" mode, then the sender might not have even
            # started yet, so we might be sitting here for a while. Because
            # of that possibility, it's probably not appropriate to give up
            # automatically after some timeout. The user can express their
            # impatience by quitting the program with control-C.
            yield w.get_unverified_key()
        finally:
            if not notify.called:
                notify.cancel()

        def on_slow_verification():
            print(
                u"Key established, waiting for confirmation...",
                file=self.args.stderr)

        notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification)
        try:
            # We wait here until we've seen their VERSION message (which they
            # send after seeing our PAKE message, and has the side-effect of
            # verifying that we both share the same key). There is a
            # round-trip between these two events, and we could experience a
            # significant delay here if:
            # * the relay server is being restarted
            # * the network is very slow
            # * the sender is very slow
            # * the sender has quit (in which case we may wait forever)

            # It would be reasonable to give up after waiting here for too
            # long.
            verifier_bytes = yield w.get_verifier()
        finally:
            if not notify.called:
                notify.cancel()
        self._show_verifier(verifier_bytes)

        want_offer = True

        while True:
            them_d = yield self._get_data(w)
            # print("GOT", them_d)
            recognized = False
            if u"transit" in them_d:
                recognized = True
                yield self._parse_transit(them_d[u"transit"], w)
            if u"offer" in them_d:
                recognized = True
                if not want_offer:
                    raise TransferError("duplicate offer")
                want_offer = False
                try:
                    yield self._parse_offer(them_d[u"offer"], w)
                except RespondError as r:
                    self._send_data({"error": r.response}, w)
                    raise TransferError(r.response)
                returnValue(None)
            if not recognized:
                log.msg("unrecognized message %r" % (them_d, ))

    def _send_data(self, data, w):
        """Serialize ``data`` with dict_to_bytes and send it over ``w``."""
        data_bytes = dict_to_bytes(data)
        w.send_message(data_bytes)

    @inlineCallbacks
    def _get_data(self, w):
        """Receive the next message from ``w`` and decode it to a dict.

        Raises TransferError if the peer's message contains an ``error`` key.
        """
        # this may raise WrongPasswordError
        them_bytes = yield w.get_message()
        them_d = bytes_to_dict(them_bytes)
        if "error" in them_d:
            raise TransferError(them_d["error"])
        returnValue(them_d)

    @inlineCallbacks
    def _handle_code(self, w):
        """Give ``w`` its code, then wait until the code is established.

        The code comes from ``args.code``, the fixed ``0-`` code in zero
        mode, or interactive input with tab completion.
        """
        code = self.args.code
        if self.args.zeromode:
            assert not code
            code = u"0-"
        if code:
            w.set_code(code)
        else:
            prompt = "Enter receive wormhole code: "
            used_completion = yield input_with_completion(
                prompt, w.input_code(), self._reactor)
            if not used_completion:
                print(
                    " (note: you can use <Tab> to complete words)",
                    file=self.args.stderr)
        yield w.get_code()

    def _show_verifier(self, verifier_bytes):
        """Print the verifier as hex when ``args.verify`` is set."""
        verifier_hex = bytes_to_hexstr(verifier_bytes)
        if self.args.verify:
            self._msg(u"Verifier %s." % verifier_hex)

    @inlineCallbacks
    def _parse_transit(self, sender_transit, w):
        """Build the TransitReceiver from the sender's first transit message.

        Subsequent transit messages are currently ignored.
        """
        if self._transit_receiver:
            # TODO: accept multiple messages, add the additional hints to the
            # existing TransitReceiver
            return
        yield self._build_transit(w, sender_transit)

    @inlineCallbacks
    def _build_transit(self, w, sender_transit):
        """Create the TransitReceiver and answer with our own transit info.

        Installs the transit key, adds the sender's ``hints-v1`` and sends
        back our abilities and connection hints in a ``transit`` message.
        """
        tr = TransitReceiver(
            self.args.transit_helper,
            no_listen=(not self.args.listen),
            tor=self._tor,
            reactor=self._reactor,
            timing=self.args.timing)
        self._transit_receiver = tr
        # When I made it possible to override APPID with a CLI argument
        # (issue #113), I forgot to also change this w.derive_key() (issue
        # #339). We're stuck with it now. Use a local constant to make this
        # clear.
        BUG339_APPID = u"lothar.com/wormhole/text-or-file-xfer"
        transit_key = w.derive_key(BUG339_APPID + u"/transit-key",
                                   tr.TRANSIT_KEY_LENGTH)
        tr.set_transit_key(transit_key)

        tr.add_connection_hints(sender_transit.get("hints-v1", []))
        receiver_abilities = tr.get_connection_abilities()
        receiver_hints = yield tr.get_connection_hints()
        receiver_transit = {
            "abilities-v1": receiver_abilities,
            "hints-v1": receiver_hints,
        }
        self._send_data({u"transit": receiver_transit}, w)
        # TODO: send more hints as the TransitReceiver produces them

    @inlineCallbacks
    def _parse_offer(self, them_d, w):
        """Handle the sender's offer: a text message, a file or a directory.

        Raises RespondError (or its subclass TransferRejectedError) when the
        offer is unknown or refused; the caller reports it to the sender.
        """
        if "message" in them_d:
            self._handle_text(them_d, w)
            returnValue(None)
        # transit will be created by this point, but not connected
        if "file" in them_d:
            f = self._handle_file(them_d)
            self._send_permission(w)
            rp = yield self._establish_transit()
            datahash = yield self._transfer_data(rp, f)
            self._write_file(f)
            yield self._close_transit(rp, datahash)
        elif "directory" in them_d:
            f = self._handle_directory(them_d)
            self._send_permission(w)
            rp = yield self._establish_transit()
            datahash = yield self._transfer_data(rp, f)
            self._write_directory(f)
            yield self._close_transit(rp, datahash)
        else:
            self._msg(u"I don't know what they're offering\n")
            self._msg(u"Offer details: %r" % (them_d, ))
            raise RespondError("unknown offer type")

    def _handle_text(self, them_d, w):
        """Print a received text message to stdout and acknowledge it."""
        # we're receiving a text message
        self.args.timing.add("print")
        print(them_d["message"], file=self.args.stdout)
        self._send_data({"answer": {"message_ack": "ok"}}, w)

    def _handle_file(self, them_d):
        """Prepare to receive a single file.

        Picks the destination, checks free space, asks the user, and returns
        a file object opened for writing at ``<destination>.tmp``.
        """
        file_data = them_d["file"]
        self.abs_destname = self._decide_destname("file",
                                                  file_data["filename"])
        self.xfersize = file_data["filesize"]
        free = estimate_free_space(self.abs_destname)
        if free is not None and free < self.xfersize:
            self._msg(u"Error: insufficient free space (%sB) for file (%sB)" %
                      (free, self.xfersize))
            raise TransferRejectedError()

        self._msg(u"Receiving file (%s) into: %s" %
                  (naturalsize(self.xfersize),
                   os.path.basename(self.abs_destname)))
        self._ask_permission()
        tmp_destname = self.abs_destname + ".tmp"
        return open(tmp_destname, "wb")

    def _handle_directory(self, them_d):
        """Prepare to receive a directory sent as a deflated zipfile.

        Validates the transfer mode, picks the destination, checks free space
        against the uncompressed size, asks the user, and returns a spooled
        temporary file to hold the zip data.
        """
        file_data = them_d["directory"]
        zipmode = file_data["mode"]
        if zipmode != "zipfile/deflated":
            self._msg(u"Error: unknown directory-transfer mode '%s'" %
                      (zipmode, ))
            raise RespondError("unknown mode")
        self.abs_destname = self._decide_destname("directory",
                                                  file_data["dirname"])
        self.xfersize = file_data["zipsize"]
        free = estimate_free_space(self.abs_destname)
        if free is not None and free < file_data["numbytes"]:
            self._msg(
                u"Error: insufficient free space (%sB) for directory (%sB)" %
                (free, file_data["numbytes"]))
            raise TransferRejectedError()

        self._msg(u"Receiving directory (%s) into: %s/" %
                  (naturalsize(self.xfersize),
                   os.path.basename(self.abs_destname)))
        self._msg(u"%d files, %s (uncompressed)" %
                  (file_data["numfiles"], naturalsize(file_data["numbytes"])))
        self._ask_permission()
        f = tempfile.SpooledTemporaryFile()
        # workaround for https://bugs.python.org/issue26175 (STF doesn't
        # fully implement IOBase abstract class), which breaks the new
        # zipfile in py3.7.0 that expects .seekable
        if not hasattr(f, "seekable"):
            # AFAICT all the filetypes that STF wraps can seek
            f.seekable = lambda: True
        return f

    def _decide_destname(self, mode, destname):
        """Return the absolute destination path for an incoming transfer.

        Only the basename of the sender-supplied ``destname`` is used, and
        ``args.output_file`` overrides it; the result is resolved against
        ``args.cwd``. ``mode`` is currently unused.

        Raises TransferRejectedError if the path exists and was not named
        explicitly via ``args.output_file``. With both ``output_file`` and
        ``accept_file`` set, an existing path is removed immediately.
        """
        # the basename() is intended to protect us against
        # "~/.ssh/authorized_keys" and other attacks
        destname = os.path.basename(destname)
        if self.args.output_file:
            destname = self.args.output_file  # override
        abs_destname = os.path.abspath(os.path.join(self.args.cwd, destname))

        # get confirmation from the user before writing to the local directory
        if os.path.exists(abs_destname):
            if self.args.output_file:  # overwrite is intentional
                self._msg(u"Overwriting '%s'" % destname)
                if self.args.accept_file:
                    self._remove_existing(abs_destname)
            else:
                self._msg(
                    u"Error: refusing to overwrite existing '%s'" % destname)
                raise TransferRejectedError()
        return abs_destname

    def _remove_existing(self, path):
        """Delete ``path`` whether it is a regular file or a directory tree."""
        if os.path.isfile(path):
            os.remove(path)
        if os.path.isdir(path):
            shutil.rmtree(path)

    def _ask_permission(self):
        """Ask the user whether to accept the transfer.

        Skipped when ``args.accept_file`` is set. An empty answer or one
        starting with "y"/"Y" accepts, and any existing file or directory at
        ``self.abs_destname`` is removed so it can be replaced. Anything
        else raises TransferRejectedError.
        """
        with self.args.timing.add("permission", waiting="user") as t:
            if not self.args.accept_file:
                ok = six.moves.input("ok? (Y/n): ")
                if ok.lower().startswith("y") or len(ok) == 0:
                    if os.path.exists(self.abs_destname):
                        self._remove_existing(self.abs_destname)
                else:
                    # report on the same stream as every other user message
                    self._msg(u"transfer rejected")
                    t.detail(answer="no")
                    raise TransferRejectedError()
            t.detail(answer="yes")

    def _send_permission(self, w):
        """Tell the sender the file/directory offer was accepted."""
        self._send_data({"answer": {"file_ack": "ok"}}, w)

    @inlineCallbacks
    def _establish_transit(self):
        """Connect the TransitReceiver and return its record pipe."""
        record_pipe = yield self._transit_receiver.connect()
        self.args.timing.add("transit connected")
        returnValue(record_pipe)

    @inlineCallbacks
    def _transfer_data(self, record_pipe, f):
        """Copy ``self.xfersize`` bytes from ``record_pipe`` into ``f``.

        Shows a progress bar and returns the SHA-256 digest of the received
        bytes. Raises TransferError if fewer bytes than expected arrive.
        """
        # now receive the rest of the owl
        self._msg(u"Receiving (%s).." % record_pipe.describe())

        with self.args.timing.add("rx file"):
            progress = tqdm(
                file=self.args.stderr,
                disable=self.args.hide_progress,
                unit="B",
                unit_scale=True,
                total=self.xfersize)
            hasher = hashlib.sha256()
            with progress:
                received = yield record_pipe.writeToFile(
                    f, self.xfersize, progress.update, hasher.update)
            datahash = hasher.digest()

        # except TransitError
        if received < self.xfersize:
            self._msg()
            self._msg(u"Connection dropped before full file received")
            self._msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
            raise TransferError("Connection dropped before full file received")
        assert received == self.xfersize
        returnValue(datahash)

    def _write_file(self, f):
        """Close the ``.tmp`` file and rename it onto the final destination."""
        tmp_name = f.name
        f.close()
        os.rename(tmp_name, self.abs_destname)
        self._msg(u"Received file written to %s" % os.path.basename(
            self.abs_destname))

    def _extract_file(self, zf, info, extract_dir):
        """
        Extract one member of ``zf`` into ``extract_dir``.

        The zipfile module does not restore file permissions, so we apply
        them manually from the member's Unix mode bits.

        Raises ValueError if the member name resolves to a path outside
        ``extract_dir`` (a malicious zipfile).
        """
        extract_dir = os.path.abspath(extract_dir)
        out_path = os.path.abspath(os.path.join(extract_dir, info.filename))
        # Require a path-separator boundary: a bare startswith(extract_dir)
        # would also accept siblings such as "<extract_dir>2/evil".
        prefix = extract_dir
        if not prefix.endswith(os.sep):
            prefix += os.sep
        if out_path != extract_dir and not out_path.startswith(prefix):
            raise ValueError(
                "malicious zipfile, %s outside of extract_dir %s" %
                (info.filename, extract_dir))

        # extract() returns the path it actually created, which can differ
        # from out_path because zipfile sanitizes member names (it removes
        # ".." components), so chmod that path rather than out_path.
        extracted = zf.extract(info, path=extract_dir)

        # Unix permissions are stored in the high 16 bits of external_attr.
        # Keep only the rwx bits (never setuid/setgid/sticky from an
        # untrusted sender), and skip members with no Unix mode at all
        # (e.g. archives built on Windows), which would otherwise be left
        # with mode 000.
        perm = (info.external_attr >> 16) & 0o777
        if perm:
            os.chmod(extracted, perm)

    def _write_directory(self, f):
        """Unpack the received zipfile ``f`` into ``self.abs_destname``.

        ``f`` is closed afterwards, even if extraction fails.
        """
        self._msg(u"Unpacking zipfile..")
        with self.args.timing.add("unpack zip"):
            try:
                with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
                    for info in zf.infolist():
                        self._extract_file(zf, info, self.abs_destname)
            finally:
                f.close()

            self._msg(u"Received files written to %s/" % os.path.basename(
                self.abs_destname))

    @inlineCallbacks
    def _close_transit(self, record_pipe, datahash):
        """Send the sender an ack carrying our SHA-256, then close the pipe."""
        datahash_hex = bytes_to_hexstr(datahash)
        ack = {u"ack": u"ok", u"sha256": datahash_hex}
        ack_bytes = dict_to_bytes(ack)
        with self.args.timing.add("send ack"):
            yield record_pipe.send_record(ack_bytes)
            yield record_pipe.close()
468