1#!/usr/bin/env python3
2# Copyright (c) 2014-2019 The Bitcoin Core developers
3# Distributed under the MIT software license, see the accompanying
4# file COPYING or http://www.opensource.org/licenses/mit-license.php.
5"""Test wallet import RPCs.
6
7Test rescan behavior of importaddress, importpubkey, importprivkey, and
8importmulti RPCs with different types of keys and rescan options.
9
In the first part of the test, node 1 creates an address for each type of
import RPC call and node 0 sends BTC to it. Then other nodes import the
addresses, and the test makes listtransactions and listreceivedbyaddress calls
to confirm that the importing node either did or did not execute rescans
picking up the send transactions.
15
In the second part of the test, node 0 sends more BTC to each address, and the
test makes more listtransactions and listreceivedbyaddress calls to confirm
that the importing nodes pick up the new transactions regardless of whether
rescans happened previously.
20"""
21
22from test_framework.test_framework import BitcoinTestFramework
23from test_framework.address import AddressType
24from test_framework.util import (
25    connect_nodes,
26    assert_equal,
27    set_node_times,
28)
29
30import collections
31from decimal import Decimal
32import enum
33import itertools
34import random
35
class Call(enum.Enum):
    """Which import RPC is used: a single-key RPC or importmulti (by address or by script)."""
    single = 1
    multiaddress = 2
    multiscript = 3


class Data(enum.Enum):
    """What is handed to the import RPC: an address, a public key or a private key."""
    address = 1
    pub = 2
    priv = 3


class Rescan(enum.Enum):
    """Rescan option: none, a normal rescan, or a rescan whose timestamp is too late."""
    no = 1
    yes = 2
    late_timestamp = 3
39
40
class Variant(collections.namedtuple("Variant", "call data address_type rescan prune")):
    """Helper for importing one key and verifying scanned transactions.

    The namedtuple fields select the import RPC (``call``), what is imported
    (``data``), the address type, the rescan option and whether the importing
    node prunes. run_test attaches further attributes to each instance
    (``label``, ``address``, ``key``, ``node``, ``expected_balance``,
    ``expected_txs``, ...) which the methods below read.
    """

    def do_import(self, timestamp):
        """Call one key import RPC.

        Args:
            timestamp: base key timestamp. Only the importmulti path uses it;
                it is shifted forward by TIMESTAMP_WINDOW, plus one extra
                second for Rescan.late_timestamp.

        Raises:
            AssertionError: if an RPC response or the address info does not
                match expectations, or if ``call``/``data`` is not a handled
                enum member.
        """
        # The single-key RPCs only rescan for Rescan.yes.
        rescan = self.rescan == Rescan.yes

        # Sanity-check the address info gathered in run_test against the
        # requested address type.
        assert_equal(self.address["solvable"], True)
        assert_equal(self.address["isscript"], self.address_type == AddressType.p2sh_segwit)
        assert_equal(self.address["iswitness"], self.address_type == AddressType.bech32)
        if self.address["isscript"]:
            assert_equal(self.address["embedded"]["isscript"], False)
            assert_equal(self.address["embedded"]["iswitness"], True)

        if self.call == Call.single:
            if self.data == Data.address:
                response = self.node.importaddress(address=self.address["address"], label=self.label, rescan=rescan)
            elif self.data == Data.pub:
                response = self.node.importpubkey(pubkey=self.address["pubkey"], label=self.label, rescan=rescan)
            elif self.data == Data.priv:
                response = self.node.importprivkey(privkey=self.key, label=self.label, rescan=rescan)
            else:
                # Fail loudly instead of hitting an UnboundLocalError on `response`.
                raise AssertionError("Unexpected import data type: {}".format(self.data))
            assert_equal(response, None)

        elif self.call in (Call.multiaddress, Call.multiscript):
            request = {
                "scriptPubKey": {
                    "address": self.address["address"]
                } if self.call == Call.multiaddress else self.address["scriptPubKey"],
                # Shift the timestamp forward by the rescan window so the rescan
                # starts around `timestamp`; late_timestamp adds one more second.
                "timestamp": timestamp + TIMESTAMP_WINDOW + (1 if self.rescan == Rescan.late_timestamp else 0),
                "pubkeys": [self.address["pubkey"]] if self.data == Data.pub else [],
                "keys": [self.key] if self.data == Data.priv else [],
                "label": self.label,
                "watchonly": self.data != Data.priv
            }
            if self.address_type == AddressType.p2sh_segwit and self.data != Data.address:
                # We need solving data when providing a pubkey or privkey as data
                request.update({"redeemscript": self.address['embedded']['scriptPubKey']})
            # late_timestamp still requests a rescan; only its timestamp differs.
            response = self.node.importmulti(
                requests=[request],
                options={"rescan": self.rescan in (Rescan.yes, Rescan.late_timestamp)},
            )
            assert_equal(response, [{"success": True}])

        else:
            # Previously an unknown call type silently imported nothing.
            raise AssertionError("Unexpected import call type: {}".format(self.call))

    def check(self, txid=None, amount=None, confirmation_height=None):
        """Verify that listtransactions/listreceivedbyaddress return expected values.

        Always checks the number of transactions under this variant's label
        against ``self.expected_txs``. When ``txid`` is given, also checks the
        matching listtransactions entry (label, address, ``amount``, category,
        confirmations relative to ``confirmation_height``) and the matching
        listreceivedbyaddress entry (total amount equal to
        ``self.expected_balance``, confirmations, watch-only flag).
        """
        txs = self.node.listtransactions(label=self.label, count=10000, include_watchonly=True)
        current_height = self.node.getblockcount()
        assert_equal(len(txs), self.expected_txs)

        addresses = self.node.listreceivedbyaddress(minconf=0, include_watchonly=True, address_filter=self.address['address'])
        if self.expected_txs:
            assert_equal(len(addresses[0]["txids"]), self.expected_txs)

        if txid is not None:
            # Exactly one listtransactions entry must match the txid.
            tx, = [tx for tx in txs if tx["txid"] == txid]
            assert_equal(tx["label"], self.label)
            assert_equal(tx["address"], self.address["address"])
            assert_equal(tx["amount"], amount)
            assert_equal(tx["category"], "receive")
            assert_equal(tx["txid"], txid)
            assert_equal(tx["confirmations"], 1 + current_height - confirmation_height)
            # The test expects no "trusted" field on these confirmed entries.
            assert_equal("trusted" not in tx, True)

            address, = [ad for ad in addresses if txid in ad["txids"]]
            assert_equal(address["address"], self.address["address"])
            assert_equal(address["amount"], self.expected_balance)
            assert_equal(address["confirmations"], 1 + current_height - confirmation_height)
            # Verify the transaction is correctly marked watchonly depending on
            # whether the transaction pays to an imported public key or
            # imported private key. The test setup ensures that transaction
            # inputs will not be from watchonly keys (important because
            # involvesWatchonly will be true if either the transaction output
            # or inputs are watchonly).
            if self.data != Data.priv:
                assert_equal(address["involvesWatchonly"], True)
            else:
                assert_equal("involvesWatchonly" not in address, True)
120
121
# One Variant per combination of import call, imported data, address type,
# rescan option and pruning setting (False, then True).
IMPORT_VARIANTS = list(itertools.starmap(
    Variant,
    itertools.product(Call, Data, AddressType, Rescan, (False, True)),
))
124
# Nodes that keys are imported into: one node for every combination of
# pruning (off/on) and "rescan expected" (no/yes). Imports that should trigger
# a rescan go to different nodes than imports that should not, so a rescan
# caused by a later import cannot pick up transactions that belong to an
# earlier import. This keeps expected balances and transaction lists simple
# to track.
ImportNode = collections.namedtuple("ImportNode", ["prune", "rescan"])
IMPORT_NODES = list(itertools.starmap(ImportNode, itertools.product((False, True), repeat=2)))

# Rescans start at the earliest block up to 2 hours before the key timestamp.
TIMESTAMP_WINDOW = 2 * 60 * 60

# Lower bound for the random amounts produced by get_rand_amount().
AMOUNT_DUST = 0.000546
138
139
def get_rand_amount():
    """Return a random amount between AMOUNT_DUST and 1 as a Decimal.

    The float sample is rounded to 8 decimal places and converted through its
    string form, so the Decimal holds the short rounded value rather than the
    float's exact binary expansion.
    """
    sample = random.uniform(AMOUNT_DUST, 1)
    rounded_text = str(round(sample, 8))
    return Decimal(rounded_text)
143
144
class ImportRescanTest(BitcoinTestFramework):
    """Check which wallet import RPC variants pick up past and future transactions."""

    def set_test_params(self):
        # Node 0 funds the addresses and mines, node 1 generates the
        # addresses/keys, and nodes 2.. (one per IMPORT_NODES entry) import them.
        self.num_nodes = 2 + len(IMPORT_NODES)
        self.supports_cli = False

    def skip_test_if_missing_module(self):
        self.skip_if_no_wallet()

    def setup_network(self):
        """Add the nodes (with -prune on pruning import nodes), import coinbase keys, and connect to node 0."""
        # Independent per-node argument lists; import nodes sit at index 2..
        # in IMPORT_NODES order.
        self.extra_args = [[] for _ in range(self.num_nodes)]
        for i, import_node in enumerate(IMPORT_NODES, 2):
            if import_node.prune:
                self.extra_args[i] += ["-prune=1"]

        self.add_nodes(self.num_nodes, extra_args=self.extra_args)

        # Import keys with pruning disabled.
        # NOTE(review): presumably because importprivkey rescans by default,
        # which a pruned node may refuse — confirm.
        self.start_nodes(extra_args=[[] for _ in range(self.num_nodes)])
        for n in self.nodes:
            n.importprivkey(privkey=n.get_deterministic_priv_key().key, label='coinbase')
        self.stop_nodes()

        # Restart without overriding extra_args; presumably this applies the
        # per-node -prune flags given to add_nodes above.
        self.start_nodes()
        for i in range(1, self.num_nodes):
            connect_nodes(self.nodes[i], 0)

    def run_test(self):
        """Import every variant, then check results before and after new sends."""
        # Create one transaction on node 0 with a random amount for
        # each possible type of wallet import RPC.
        for i, variant in enumerate(IMPORT_VARIANTS):
            variant.label = "label {} {}".format(i, variant)
            variant.address = self.nodes[1].getaddressinfo(self.nodes[1].getnewaddress(
                label=variant.label,
                address_type=variant.address_type.value,
            ))
            variant.key = self.nodes[1].dumpprivkey(variant.address["address"])
            variant.initial_amount = get_rand_amount()
            variant.initial_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.initial_amount)
            self.nodes[0].generate(1)  # Generate one block for each send
            variant.confirmation_height = self.nodes[0].getblockcount()
            variant.timestamp = self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"]

        # Generate a block further in the future (past the rescan window).
        assert_equal(self.nodes[0].getrawmempool(), [])
        set_node_times(
            self.nodes,
            self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"] + TIMESTAMP_WINDOW + 1,
        )
        self.nodes[0].generate(1)
        self.sync_all()

        # For each variation of wallet key import, invoke the import RPC and
        # check the results from listtransactions and listreceivedbyaddress.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run import for variant {}'.format(variant))
            expect_rescan = variant.rescan == Rescan.yes
            # Rescanning and non-rescanning imports go to different nodes so a
            # later rescan cannot pick up an earlier import's transactions.
            variant.node = self.nodes[2 + IMPORT_NODES.index(ImportNode(variant.prune, expect_rescan))]
            variant.do_import(variant.timestamp)
            if expect_rescan:
                variant.expected_balance = variant.initial_amount
                variant.expected_txs = 1
                variant.check(variant.initial_txid, variant.initial_amount, variant.confirmation_height)
            else:
                variant.expected_balance = 0
                variant.expected_txs = 0
                variant.check()

        # Create new transactions sending to each address.
        for variant in IMPORT_VARIANTS:
            variant.sent_amount = get_rand_amount()
            variant.sent_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.sent_amount)
            self.nodes[0].generate(1)  # Generate one block for each send
            variant.confirmation_height = self.nodes[0].getblockcount()

        assert_equal(self.nodes[0].getrawmempool(), [])
        self.sync_all()

        # Check the latest results from listtransactions and listreceivedbyaddress;
        # every import must see the new send whether or not it rescanned earlier.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run check for variant {}'.format(variant))
            variant.expected_balance += variant.sent_amount
            variant.expected_txs += 1
            variant.check(variant.sent_txid, variant.sent_amount, variant.confirmation_height)
228
if __name__ == "__main__":
    # Run the functional test when this file is executed as a script.
    test = ImportRescanTest()
    test.main()
231