1#!/usr/bin/env python3
2# Copyright (c) 2014-2020 The Bitcoin Core developers
3# Distributed under the MIT software license, see the accompanying
4# file COPYING or http://www.opensource.org/licenses/mit-license.php.
5"""Test wallet import RPCs.
6
7Test rescan behavior of importaddress, importpubkey, importprivkey, and
8importmulti RPCs with different types of keys and rescan options.
9
In the first part of the test, node 1 creates an address for each type of
import RPC call and node 0 sends BTC to it. Then other nodes import the
addresses, and the test makes listtransactions and listreceivedbyaddress calls
to confirm that the importing node either did or did not execute rescans
picking up the send transactions.
15
In the second part of the test, node 0 sends more BTC to each address, and the
test makes more listtransactions and listreceivedbyaddress calls to confirm
that the importing nodes pick up the new transactions regardless of whether
rescans happened previously.
20"""
21
22from test_framework.test_framework import BitcoinTestFramework
23from test_framework.address import AddressType
24from test_framework.util import (
25    assert_equal,
26    set_node_times,
27)
28
29import collections
30from decimal import Decimal
31import enum
32import itertools
33import random
34
class Call(enum.Enum):
    """Which import RPC style a variant uses.

    ``single`` uses importaddress/importpubkey/importprivkey, ``multiaddress``
    uses importmulti with the scriptPubKey given as an address object, and
    ``multiscript`` uses importmulti with the raw scriptPubKey.
    """
    single = 1
    multiaddress = 2
    multiscript = 3


class Data(enum.Enum):
    """What is handed to the import RPC: an address, a public key or a private key."""
    address = 1
    pub = 2
    priv = 3


class Rescan(enum.Enum):
    """Rescan mode requested for an import.

    ``late_timestamp`` only affects importmulti requests, where it pushes the
    request timestamp one second later.
    """
    no = 1
    yes = 2
    late_timestamp = 3
38
39
class Variant(collections.namedtuple("Variant", "call data address_type rescan prune")):
    """Helper for importing one key and verifying scanned transactions.

    The tuple fields select the import RPC style (``call``), what is handed to
    it (``data``), the address type, the rescan mode and the pruning setting of
    the importing node. The test driver attaches the remaining state as plain
    instance attributes before the methods are used: ``node``, ``address`` (a
    getaddressinfo result), ``key``, ``label``, ``expected_txs`` and
    ``expected_balance``.
    """

    def do_import(self, timestamp):
        """Call one key import RPC on ``self.node``.

        timestamp -- base time for an importmulti request. The request uses
            ``timestamp + TIMESTAMP_WINDOW``, plus one extra second when the
            rescan mode is ``Rescan.late_timestamp``. The single-key RPCs do
            not use it.

        Fails if the address description or the RPC response is not what the
        test expects, or if ``call``/``data`` is not a supported value.
        """
        info = self.address

        # The address must be solvable and its script/witness flags must agree
        # with the address type this variant was created for.
        assert_equal(info["solvable"], True)
        assert_equal(info["isscript"], self.address_type == AddressType.p2sh_segwit)
        assert_equal(info["iswitness"], self.address_type == AddressType.bech32)
        if info["isscript"]:
            # The script is expected to wrap a witness program, not another script.
            assert_equal(info["embedded"]["isscript"], False)
            assert_equal(info["embedded"]["iswitness"], True)

        if self.call == Call.single:
            self._import_single()
        elif self.call in (Call.multiaddress, Call.multiscript):
            self._import_multi(timestamp)
        else:
            # Fail loudly instead of silently importing nothing.
            raise AssertionError("Unsupported import call: {!r}".format(self.call))

    def _import_single(self):
        """Import through importaddress, importpubkey or importprivkey, chosen by ``self.data``."""
        # The single-key RPCs are only asked to rescan for Rescan.yes.
        rescan = self.rescan == Rescan.yes

        if self.data == Data.address:
            response = self.node.importaddress(address=self.address["address"], label=self.label, rescan=rescan)
        elif self.data == Data.pub:
            response = self.node.importpubkey(pubkey=self.address["pubkey"], label=self.label, rescan=rescan)
        elif self.data == Data.priv:
            response = self.node.importprivkey(privkey=self.key, label=self.label, rescan=rescan)
        else:
            # Without this branch an unknown value would surface as an
            # UnboundLocalError on ``response`` below.
            raise AssertionError("Unsupported import data: {!r}".format(self.data))
        assert_equal(response, None)

    def _import_multi(self, timestamp):
        """Import through importmulti, addressing the key by address or by raw scriptPubKey."""
        if self.call == Call.multiaddress:
            script_pub_key = {"address": self.address["address"]}
        else:
            script_pub_key = self.address["scriptPubKey"]

        # Rescan.late_timestamp moves the request timestamp one second later.
        late_offset = 1 if self.rescan == Rescan.late_timestamp else 0

        request = {
            "scriptPubKey": script_pub_key,
            "timestamp": timestamp + TIMESTAMP_WINDOW + late_offset,
            "pubkeys": [self.address["pubkey"]] if self.data == Data.pub else [],
            "keys": [self.key] if self.data == Data.priv else [],
            "label": self.label,
            "watchonly": self.data != Data.priv,
        }
        if self.address_type == AddressType.p2sh_segwit and self.data != Data.address:
            # We need solving data when providing a pubkey or privkey as data
            request["redeemscript"] = self.address["embedded"]["scriptPubKey"]

        response = self.node.importmulti(
            requests=[request],
            options={"rescan": self.rescan in (Rescan.yes, Rescan.late_timestamp)},
        )
        assert_equal(response, [{"success": True}])

    def check(self, txid=None, amount=None, confirmation_height=None):
        """Verify that listtransactions/listreceivedbyaddress return expected values.

        The number of transactions under ``self.label`` is always compared with
        ``self.expected_txs``. When ``txid`` is given, the matching
        listtransactions entry and listreceivedbyaddress entry are also checked
        against ``amount``, ``confirmation_height`` and ``self.expected_balance``.
        """
        txs = self.node.listtransactions(label=self.label, count=10000, include_watchonly=True)
        current_height = self.node.getblockcount()
        assert_equal(len(txs), self.expected_txs)

        addresses = self.node.listreceivedbyaddress(minconf=0, include_watchonly=True, address_filter=self.address['address'])
        if self.expected_txs:
            assert_equal(len(addresses[0]["txids"]), self.expected_txs)

        if txid is None:
            return

        # Exactly one listed transaction must carry the requested txid; the
        # unpacking raises otherwise.
        tx, = [entry for entry in txs if entry["txid"] == txid]
        assert_equal(tx["label"], self.label)
        assert_equal(tx["address"], self.address["address"])
        assert_equal(tx["amount"], amount)
        assert_equal(tx["category"], "receive")
        assert_equal(tx["txid"], txid)
        expected_confirmations = 1 + current_height - confirmation_height
        assert_equal(tx["confirmations"], expected_confirmations)
        assert_equal("trusted" not in tx, True)

        address, = [entry for entry in addresses if txid in entry["txids"]]
        assert_equal(address["address"], self.address["address"])
        assert_equal(address["amount"], self.expected_balance)
        assert_equal(address["confirmations"], expected_confirmations)
        # Verify the transaction is correctly marked watchonly depending on
        # whether the transaction pays to an imported public key or
        # imported private key. The test setup ensures that transaction
        # inputs will not be from watchonly keys (important because
        # involvesWatchonly will be true if either the transaction output
        # or inputs are watchonly).
        if self.data != Data.priv:
            assert_equal(address["involvesWatchonly"], True)
        else:
            assert_equal("involvesWatchonly" not in address, True)
118
119
# One Variant for every combination of import call, key data, address type,
# rescan mode and pruning setting.
IMPORT_VARIANTS = [
    Variant(call, data, address_type, rescan, prune)
    for call, data, address_type, rescan, prune in itertools.product(
        Call, Data, AddressType, Rescan, (False, True))
]
122
# Nodes that keys are imported to. Half of them run with pruning disabled and
# half with it enabled. Imports that are expected to cause rescans go to
# different nodes than imports that are not, so that a rescan triggered by a
# later import cannot pick up transactions belonging to an earlier import.
# That keeps the expected balances and transaction counts easy to track.
ImportNode = collections.namedtuple("ImportNode", ["prune", "rescan"])
IMPORT_NODES = [
    ImportNode(prune, rescan)
    for prune in (False, True)
    for rescan in (False, True)
]
131
# Rescans start at the earliest block up to 2 hours before the key timestamp;
# this is that two-hour window, added to block times in the test.
TIMESTAMP_WINDOW = 2 * 3600
134
# Lower bound of the random amounts sent by the test.
AMOUNT_DUST = 0.00000546


def get_rand_amount():
    """Return a random Decimal between AMOUNT_DUST and 1, rounded to 8 decimal places."""
    raw_amount = random.uniform(AMOUNT_DUST, 1)
    rounded_amount = round(raw_amount, 8)
    # Go through str() so the Decimal carries the rounded digits rather than
    # the float's full binary expansion.
    return Decimal(str(rounded_amount))
141
142
class ImportRescanTest(BitcoinTestFramework):
    """Exercise every import variant and verify which transactions each importing node sees."""

    def set_test_params(self):
        """Set the node count, CLI support flag and RPC timeout."""
        # Nodes 0 and 1 send funds and create addresses; the rest are import nodes.
        self.num_nodes = 2 + len(IMPORT_NODES)
        self.supports_cli = False
        self.rpc_timeout = 120

    def skip_test_if_missing_module(self):
        """Skip the test when wallet support is not available."""
        self.skip_if_no_wallet()

    def setup_network(self):
        """Add and start the nodes, then connect every node to node 0."""
        self.extra_args = [[] for _ in range(self.num_nodes)]
        # Import nodes start at index 2; give the pruned ones the prune flag.
        for node_index, import_node in enumerate(IMPORT_NODES, 2):
            if import_node.prune:
                self.extra_args[node_index].append("-prune=1")

        self.add_nodes(self.num_nodes, extra_args=self.extra_args)

        # Import keys with pruning disabled
        unpruned_args = [[]] * self.num_nodes
        self.start_nodes(extra_args=unpruned_args)
        self.import_deterministic_coinbase_privkeys()
        self.stop_nodes()

        # Restart without overriding the arguments.
        self.start_nodes()
        for peer_index in range(1, self.num_nodes):
            self.connect_nodes(peer_index, 0)

    def run_test(self):
        """Fund one address per variant, import it, and check results before and after a second send."""
        funding_node = self.nodes[0]
        address_node = self.nodes[1]

        # For each possible type of wallet import RPC, create an address on
        # node 1 and send it a random amount from node 0, one block per send.
        for index, variant in enumerate(IMPORT_VARIANTS):
            variant.label = "label {} {}".format(index, variant)
            new_address = address_node.getnewaddress(
                label=variant.label,
                address_type=variant.address_type.value,
            )
            variant.address = address_node.getaddressinfo(new_address)
            variant.key = address_node.dumpprivkey(variant.address["address"])
            variant.initial_amount = get_rand_amount()
            variant.initial_txid = funding_node.sendtoaddress(variant.address["address"], variant.initial_amount)
            funding_node.generate(1)  # Generate one block for each send
            variant.confirmation_height = funding_node.getblockcount()
            tip_hash = funding_node.getbestblockhash()
            variant.timestamp = funding_node.getblockheader(tip_hash)["time"]
        self.sync_all()  # Conclude sync before calling setmocktime to avoid timeouts

        # Generate a block further in the future (past the rescan window).
        assert_equal(funding_node.getrawmempool(), [])
        tip_time = funding_node.getblockheader(funding_node.getbestblockhash())["time"]
        set_node_times(self.nodes, tip_time + TIMESTAMP_WINDOW + 1)
        funding_node.generate(1)
        self.sync_all()

        # For each variation of wallet key import, invoke the import RPC on the
        # node reserved for that prune/rescan combination and check what
        # listtransactions and listreceivedbyaddress report.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run import for variant {}'.format(variant))
            expect_rescan = variant.rescan == Rescan.yes
            node_offset = IMPORT_NODES.index(ImportNode(variant.prune, expect_rescan))
            variant.node = self.nodes[2 + node_offset]
            variant.do_import(variant.timestamp)
            if not expect_rescan:
                # Without a rescan the initial transaction must not show up.
                variant.expected_balance = 0
                variant.expected_txs = 0
                variant.check()
                continue
            variant.expected_balance = variant.initial_amount
            variant.expected_txs = 1
            variant.check(variant.initial_txid, variant.initial_amount, variant.confirmation_height)

        # Create new transactions sending to each address.
        for variant in IMPORT_VARIANTS:
            variant.sent_amount = get_rand_amount()
            variant.sent_txid = funding_node.sendtoaddress(variant.address["address"], variant.sent_amount)
            funding_node.generate(1)  # Generate one block for each send
            variant.confirmation_height = funding_node.getblockcount()

        assert_equal(funding_node.getrawmempool(), [])
        self.sync_all()

        # Every importing node must now see the new transaction, whether or
        # not it rescanned earlier.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run check for variant {}'.format(variant))
            variant.expected_balance += variant.sent_amount
            variant.expected_txs += 1
            variant.check(variant.sent_txid, variant.sent_amount, variant.confirmation_height)
227
228
# Run the test when this file is executed as a script.
if __name__ == "__main__":
    rescan_test = ImportRescanTest()
    rescan_test.main()
231