#!/usr/bin/env python3
# Copyright (c) 2014-2020 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test wallet import RPCs.

Test rescan behavior of importaddress, importpubkey, importprivkey, and
importmulti RPCs with different types of keys and rescan options.

In the first part of the test, node 1 creates an address for each type of
import RPC call and node 0 sends BTC to it. Then the remaining nodes import
the addresses, and the test makes listtransactions and listreceivedbyaddress
calls to confirm that the importing node either did or did not execute
rescans picking up the send transactions.

In the second part of the test, node 0 sends more BTC to each address, and the
test makes more listtransactions and listreceivedbyaddress calls to confirm
that the importing nodes pick up the new transactions regardless of whether
rescans happened previously.
"""

from test_framework.test_framework import BitcoinTestFramework
from test_framework.address import AddressType
from test_framework.util import (
    assert_equal,
    set_node_times,
)

import collections
from decimal import Decimal
import enum
import itertools
import random

# Which import RPC shape to use: a single-key RPC (importaddress /
# importpubkey / importprivkey), or importmulti keyed by address or by raw
# scriptPubKey.
Call = enum.Enum("Call", "single multiaddress multiscript")
# What is handed to the import RPC: the bare address, the public key, or the
# private key.
Data = enum.Enum("Data", "address pub priv")
# Rescan mode: no rescan, a rescan, or (importmulti only) a rescan requested
# with a key timestamp one second later than the window that reaches the
# funding block.
Rescan = enum.Enum("Rescan", "no yes late_timestamp")


class Variant(collections.namedtuple("Variant", "call data address_type rescan prune")):
    """Helper for importing one key and verifying scanned transactions.

    The namedtuple fields select the scenario. run_test() additionally
    attaches per-run state as plain attributes before do_import()/check()
    are called: label, address (a getaddressinfo result), key, node,
    timestamp, confirmation_height, initial/sent amounts and txids,
    expected_balance and expected_txs.
    """
    def do_import(self, timestamp):
        """Call one key import RPC.

        Args:
            timestamp: block time of the block that confirmed the initial
                send to this variant's address. importmulti is given a key
                timestamp of ``timestamp + TIMESTAMP_WINDOW`` (plus one
                second for Rescan.late_timestamp), which puts the funding
                block at the edge of the rescan window, or just outside it
                for late_timestamp.
        """
        # Only used by the single-key RPCs; importmulti computes its own
        # rescan option below.
        rescan = self.rescan == Rescan.yes

        # Sanity-check the getaddressinfo metadata against the address type
        # this variant was created with.
        assert_equal(self.address["solvable"], True)
        assert_equal(self.address["isscript"], self.address_type == AddressType.p2sh_segwit)
        assert_equal(self.address["iswitness"], self.address_type == AddressType.bech32)
        if self.address["isscript"]:
            # The p2sh-segwit address wraps a witness program, not another script.
            assert_equal(self.address["embedded"]["isscript"], False)
            assert_equal(self.address["embedded"]["iswitness"], True)

        if self.call == Call.single:
            if self.data == Data.address:
                response = self.node.importaddress(address=self.address["address"], label=self.label, rescan=rescan)
            elif self.data == Data.pub:
                response = self.node.importpubkey(pubkey=self.address["pubkey"], label=self.label, rescan=rescan)
            elif self.data == Data.priv:
                response = self.node.importprivkey(privkey=self.key, label=self.label, rescan=rescan)
            assert_equal(response, None)

        elif self.call in (Call.multiaddress, Call.multiscript):
            request = {
                "scriptPubKey": {
                    "address": self.address["address"]
                } if self.call == Call.multiaddress else self.address["scriptPubKey"],
                "timestamp": timestamp + TIMESTAMP_WINDOW + (1 if self.rescan == Rescan.late_timestamp else 0),
                "pubkeys": [self.address["pubkey"]] if self.data == Data.pub else [],
                "keys": [self.key] if self.data == Data.priv else [],
                "label": self.label,
                "watchonly": self.data != Data.priv
            }
            if self.address_type == AddressType.p2sh_segwit and self.data != Data.address:
                # We need solving data when providing a pubkey or privkey as data
                request.update({"redeemscript": self.address['embedded']['scriptPubKey']})
            response = self.node.importmulti(
                requests=[request],
                # late_timestamp still rescans, just from a start point that
                # should miss the funding block.
                options={"rescan": self.rescan in (Rescan.yes, Rescan.late_timestamp)},
            )
            assert_equal(response, [{"success": True}])

    def check(self, txid=None, amount=None, confirmation_height=None):
        """Verify that listtransactions/listreceivedbyaddress return expected values.

        Always checks that the importing node lists exactly
        ``self.expected_txs`` transactions under this variant's label (and
        the same number of txids for its address). When ``txid`` is given,
        also checks that transaction's details: the received ``amount``, the
        confirmation count implied by ``confirmation_height``, the address
        total against ``self.expected_balance``, and the watch-only flag.
        """

        txs = self.node.listtransactions(label=self.label, count=10000, include_watchonly=True)
        current_height = self.node.getblockcount()
        assert_equal(len(txs), self.expected_txs)

        addresses = self.node.listreceivedbyaddress(minconf=0, include_watchonly=True, address_filter=self.address['address'])
        if self.expected_txs:
            assert_equal(len(addresses[0]["txids"]), self.expected_txs)

        if txid is not None:
            tx, = [tx for tx in txs if tx["txid"] == txid]
            assert_equal(tx["label"], self.label)
            assert_equal(tx["address"], self.address["address"])
            assert_equal(tx["amount"], amount)
            assert_equal(tx["category"], "receive")
            assert_equal(tx["txid"], txid)
            assert_equal(tx["confirmations"], 1 + current_height - confirmation_height)
            # The transaction is confirmed here; "trusted" is presumably only
            # reported for unconfirmed transactions.
            assert_equal("trusted" not in tx, True)

            address, = [ad for ad in addresses if txid in ad["txids"]]
            assert_equal(address["address"], self.address["address"])
            assert_equal(address["amount"], self.expected_balance)
            assert_equal(address["confirmations"], 1 + current_height - confirmation_height)
            # Verify the transaction is correctly marked watchonly depending on
            # whether the transaction pays to an imported public key or
            # imported private key. The test setup ensures that transaction
            # inputs will not be from watchonly keys (important because
            # involvesWatchonly will be true if either the transaction output
            # or inputs are watchonly).
            if self.data != Data.priv:
                assert_equal(address["involvesWatchonly"], True)
            else:
                assert_equal("involvesWatchonly" not in address, True)


# List of Variants for each way a key or address could be imported. The final
# (False, True) factor is the `prune` field.
IMPORT_VARIANTS = [Variant(*variants) for variants in itertools.product(Call, Data, AddressType, Rescan, (False, True))]

# List of nodes to import keys to. Half the nodes will have pruning disabled,
# half will have it enabled. Different nodes will be used for imports that are
# expected to cause rescans, and imports that are not expected to cause
# rescans, in order to prevent rescans during later imports picking up
# transactions associated with earlier imports. This makes it easier to keep
# track of expected balances and transactions.
ImportNode = collections.namedtuple("ImportNode", "prune rescan")
IMPORT_NODES = [ImportNode(*fields) for fields in itertools.product((False, True), repeat=2)]

# Rescans start at the earliest block up to 2 hours before the key timestamp.
TIMESTAMP_WINDOW = 2 * 60 * 60

# Lower bound for random send amounts, in BTC (546 satoshis); presumably
# chosen to keep outputs above the dust threshold.
AMOUNT_DUST = 0.00000546


def get_rand_amount():
    """Return a random BTC amount in [AMOUNT_DUST, 1] as a Decimal.

    The float is rounded to 8 decimal places (satoshi precision) and passed
    through str() so the Decimal carries no binary floating-point residue.
    """
    r = random.uniform(AMOUNT_DUST, 1)
    return Decimal(str(round(r, 8)))


class ImportRescanTest(BitcoinTestFramework):
    def set_test_params(self):
        # Node 0 sends and mines, node 1 owns the keys, and nodes 2.. are the
        # import targets described by IMPORT_NODES.
        self.num_nodes = 2 + len(IMPORT_NODES)
        self.supports_cli = False
        self.rpc_timeout = 120

    def skip_test_if_missing_module(self):
        self.skip_if_no_wallet()

    def setup_network(self):
        """Configure pruning per import node and connect everyone to node 0.

        Nodes are first started without any extra args so the deterministic
        coinbase keys can be imported, then restarted with the extra_args
        registered in add_nodes (-prune=1 on pruned import nodes).
        """
        self.extra_args = [[] for _ in range(self.num_nodes)]
        for i, import_node in enumerate(IMPORT_NODES, 2):
            if import_node.prune:
                self.extra_args[i] += ["-prune=1"]

        self.add_nodes(self.num_nodes, extra_args=self.extra_args)

        # Import keys with pruning disabled
        self.start_nodes(extra_args=[[]] * self.num_nodes)
        self.import_deterministic_coinbase_privkeys()
        self.stop_nodes()

        self.start_nodes()
        for i in range(1, self.num_nodes):
            self.connect_nodes(i, 0)

    def run_test(self):
        # Create one transaction on node 0 with a random amount for
        # each possible type of wallet import RPC.
        for i, variant in enumerate(IMPORT_VARIANTS):
            variant.label = "label {} {}".format(i, variant)
            variant.address = self.nodes[1].getaddressinfo(self.nodes[1].getnewaddress(
                label=variant.label,
                address_type=variant.address_type.value,
            ))
            variant.key = self.nodes[1].dumpprivkey(variant.address["address"])
            variant.initial_amount = get_rand_amount()
            variant.initial_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.initial_amount)
            self.nodes[0].generate(1)  # Generate one block for each send
            variant.confirmation_height = self.nodes[0].getblockcount()
            variant.timestamp = self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"]
        self.sync_all()  # Conclude sync before calling setmocktime to avoid timeouts

        # Generate a block further in the future (past the rescan window).
        assert_equal(self.nodes[0].getrawmempool(), [])
        set_node_times(
            self.nodes,
            self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"] + TIMESTAMP_WINDOW + 1,
        )
        self.nodes[0].generate(1)
        self.sync_all()

        # For each variation of wallet key import, invoke the import RPC and
        # check the results from listtransactions and listreceivedbyaddress.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run import for variant {}'.format(variant))
            # Rescan.late_timestamp imports go to the non-rescan nodes: they do
            # rescan, but from a point past the funding block, so they are
            # expected to find nothing.
            expect_rescan = variant.rescan == Rescan.yes
            variant.node = self.nodes[2 + IMPORT_NODES.index(ImportNode(variant.prune, expect_rescan))]
            variant.do_import(variant.timestamp)
            if expect_rescan:
                variant.expected_balance = variant.initial_amount
                variant.expected_txs = 1
                variant.check(variant.initial_txid, variant.initial_amount, variant.confirmation_height)
            else:
                variant.expected_balance = 0
                variant.expected_txs = 0
                variant.check()

        # Create new transactions sending to each address.
        for variant in IMPORT_VARIANTS:
            variant.sent_amount = get_rand_amount()
            variant.sent_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.sent_amount)
            self.nodes[0].generate(1)  # Generate one block for each send
            variant.confirmation_height = self.nodes[0].getblockcount()

        assert_equal(self.nodes[0].getrawmempool(), [])
        self.sync_all()

        # Check the latest results from listtransactions and
        # listreceivedbyaddress; every importing node must now see the new send.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run check for variant {}'.format(variant))
            variant.expected_balance += variant.sent_amount
            variant.expected_txs += 1
            variant.check(variant.sent_txid, variant.sent_amount, variant.confirmation_height)


if __name__ == "__main__":
    ImportRescanTest().main()