#!/usr/bin/env python3
# Copyright (c) 2014-2019 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test wallet import RPCs.

Test rescan behavior of importaddress, importpubkey, importprivkey, and
importmulti RPCs with different types of keys and rescan options.

In the first part of the test, node 1 creates an address for each type of
import RPC call and node 0 sends BTC to it. Then other nodes import the
addresses, and the test makes listtransactions and listreceivedbyaddress calls
to confirm that the importing node either did or did not execute rescans
picking up the send transactions.

In the second part of the test, node 0 sends more BTC to each address, and the
test makes more listtransactions and listreceivedbyaddress calls to confirm
that the importing nodes pick up the new transactions regardless of whether
rescans happened previously.
"""

from test_framework.test_framework import BitcoinTestFramework
from test_framework.address import AddressType
from test_framework.util import (
    connect_nodes,
    assert_equal,
    set_node_times,
)

import collections
from decimal import Decimal
import enum
import itertools
import random

Call = enum.Enum("Call", "single multiaddress multiscript")
Data = enum.Enum("Data", "address pub priv")
Rescan = enum.Enum("Rescan", "no yes late_timestamp")


class Variant(collections.namedtuple("Variant", "call data address_type rescan prune")):
    """Helper for importing one key and verifying scanned transactions.

    The namedtuple fields select the variant: which import RPC is used
    (``call``), what is imported (``data``), the address type, the rescan
    option, and whether the importing node is pruned. The test attaches the
    remaining per-variant state as plain instance attributes (``label``,
    ``address``, ``key``, ``node``, ``timestamp``, sent amounts and txids,
    ``confirmation_height``, ``expected_balance`` and ``expected_txs``)
    before calling ``do_import`` and ``check``.
    """

    def do_import(self, timestamp):
        """Call one key import RPC.

        Args:
            timestamp: time of the block containing this variant's initial
                send; importmulti requests derive their key timestamp from it.
        """
        # Only the single-key RPCs take a plain rescan flag; importmulti
        # controls its rescan through the request timestamp and options.
        rescan = self.rescan == Rescan.yes

        assert_equal(self.address["solvable"], True)
        assert_equal(self.address["isscript"], self.address_type == AddressType.p2sh_segwit)
        assert_equal(self.address["iswitness"], self.address_type == AddressType.bech32)
        if self.address["isscript"]:
            assert_equal(self.address["embedded"]["isscript"], False)
            assert_equal(self.address["embedded"]["iswitness"], True)

        if self.call == Call.single:
            if self.data == Data.address:
                response = self.node.importaddress(address=self.address["address"], label=self.label, rescan=rescan)
            elif self.data == Data.pub:
                response = self.node.importpubkey(pubkey=self.address["pubkey"], label=self.label, rescan=rescan)
            elif self.data == Data.priv:
                response = self.node.importprivkey(privkey=self.key, label=self.label, rescan=rescan)
            assert_equal(response, None)

        elif self.call in (Call.multiaddress, Call.multiscript):
            request = {
                "scriptPubKey": {
                    "address": self.address["address"]
                } if self.call == Call.multiaddress else self.address["scriptPubKey"],
                # Rescans begin up to TIMESTAMP_WINDOW before the key timestamp,
                # so adding the window makes the rescan start at the send
                # block's time; late_timestamp adds one more second so the
                # rescan starts just after it.
                "timestamp": timestamp + TIMESTAMP_WINDOW + (1 if self.rescan == Rescan.late_timestamp else 0),
                "pubkeys": [self.address["pubkey"]] if self.data == Data.pub else [],
                "keys": [self.key] if self.data == Data.priv else [],
                "label": self.label,
                "watchonly": self.data != Data.priv
            }
            if self.address_type == AddressType.p2sh_segwit and self.data != Data.address:
                # We need solving data when providing a pubkey or privkey as data
                request.update({"redeemscript": self.address['embedded']['scriptPubKey']})
            response = self.node.importmulti(
                requests=[request],
                options={"rescan": self.rescan in (Rescan.yes, Rescan.late_timestamp)},
            )
            assert_equal(response, [{"success": True}])

    def check(self, txid=None, amount=None, confirmation_height=None):
        """Verify that listtransactions/listreceivedbyaddress return expected values.

        Always checks the number of transactions listed under this variant's
        label against ``self.expected_txs``. When ``txid`` is given, also
        checks that transaction's details (``amount``, and confirmations
        relative to ``confirmation_height``) and the matching
        listreceivedbyaddress entry (``self.expected_balance``).
        """

        txs = self.node.listtransactions(label=self.label, count=10000, include_watchonly=True)
        current_height = self.node.getblockcount()
        assert_equal(len(txs), self.expected_txs)

        addresses = self.node.listreceivedbyaddress(minconf=0, include_watchonly=True, address_filter=self.address['address'])
        if self.expected_txs:
            assert_equal(len(addresses[0]["txids"]), self.expected_txs)

        if txid is not None:
            tx, = [tx for tx in txs if tx["txid"] == txid]
            assert_equal(tx["label"], self.label)
            assert_equal(tx["address"], self.address["address"])
            assert_equal(tx["amount"], amount)
            assert_equal(tx["category"], "receive")
            assert_equal(tx["txid"], txid)
            assert_equal(tx["confirmations"], 1 + current_height - confirmation_height)
            assert_equal("trusted" not in tx, True)

            address, = [ad for ad in addresses if txid in ad["txids"]]
            assert_equal(address["address"], self.address["address"])
            assert_equal(address["amount"], self.expected_balance)
            assert_equal(address["confirmations"], 1 + current_height - confirmation_height)
            # Verify the transaction is correctly marked watchonly depending on
            # whether the transaction pays to an imported public key or
            # imported private key. The test setup ensures that transaction
            # inputs will not be from watchonly keys (important because
            # involvesWatchonly will be true if either the transaction output
            # or inputs are watchonly).
            if self.data != Data.priv:
                assert_equal(address["involvesWatchonly"], True)
            else:
                assert_equal("involvesWatchonly" not in address, True)


# List of Variants for each way a key or address could be imported.
IMPORT_VARIANTS = [Variant(*variants) for variants in itertools.product(Call, Data, AddressType, Rescan, (False, True))]

# List of nodes to import keys to. Half the nodes will have pruning disabled,
# half will have it enabled. Different nodes will be used for imports that are
# expected to cause rescans, and imports that are not expected to cause
# rescans, in order to prevent rescans during later imports picking up
# transactions associated with earlier imports. This makes it easier to keep
# track of expected balances and transactions.
ImportNode = collections.namedtuple("ImportNode", "prune rescan")
IMPORT_NODES = [ImportNode(*fields) for fields in itertools.product((False, True), repeat=2)]

# Rescans start at the earliest block up to 2 hours before the key timestamp.
TIMESTAMP_WINDOW = 2 * 60 * 60

# Lower bound for the random amounts sent to each address.
AMOUNT_DUST = 0.00054600


def get_rand_amount():
    """Return a random amount between AMOUNT_DUST and 1 as a Decimal rounded to 8 places."""
    r = random.uniform(AMOUNT_DUST, 1)
    # Round-trip through str so the Decimal holds the rounded value rather
    # than the float's full binary expansion.
    return Decimal(str(round(r, 8)))


class ImportRescanTest(BitcoinTestFramework):
    """Exercise every IMPORT_VARIANTS entry against the IMPORT_NODES nodes."""

    def set_test_params(self):
        # Node 0 sends funds, node 1 creates the addresses, and the remaining
        # nodes (one per IMPORT_NODES entry) perform the imports.
        self.num_nodes = 2 + len(IMPORT_NODES)
        self.supports_cli = False

    def skip_test_if_missing_module(self):
        self.skip_if_no_wallet()

    def setup_network(self):
        """Configure pruning per import node, import coinbase keys, then connect all nodes to node 0."""
        self.extra_args = [[] for _ in range(self.num_nodes)]
        # Import nodes start at index 2, in IMPORT_NODES order.
        for i, import_node in enumerate(IMPORT_NODES, 2):
            if import_node.prune:
                self.extra_args[i] += ["-prune=1"]

        self.add_nodes(self.num_nodes, extra_args=self.extra_args)

        # Import keys with pruning disabled
        self.start_nodes(extra_args=[[]] * self.num_nodes)
        for n in self.nodes:
            n.importprivkey(privkey=n.get_deterministic_priv_key().key, label='coinbase')
        self.stop_nodes()

        # Restart with the per-node extra_args configured above.
        self.start_nodes()
        for node in self.nodes[1:]:
            connect_nodes(node, 0)

    def run_test(self):
        # Create one transaction on node 0 with a unique amount for
        # each possible type of wallet import RPC.
        for i, variant in enumerate(IMPORT_VARIANTS):
            variant.label = "label {} {}".format(i, variant)
            variant.address = self.nodes[1].getaddressinfo(self.nodes[1].getnewaddress(
                label=variant.label,
                address_type=variant.address_type.value,
            ))
            variant.key = self.nodes[1].dumpprivkey(variant.address["address"])
            variant.initial_amount = get_rand_amount()
            variant.initial_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.initial_amount)
            self.nodes[0].generate(1)  # Generate one block for each send
            variant.confirmation_height = self.nodes[0].getblockcount()
            variant.timestamp = self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"]

        # Generate a block further in the future (past the rescan window).
        assert_equal(self.nodes[0].getrawmempool(), [])
        set_node_times(
            self.nodes,
            self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"] + TIMESTAMP_WINDOW + 1,
        )
        self.nodes[0].generate(1)
        self.sync_all()

        # For each variation of wallet key import, invoke the import RPC and
        # check the results from listtransactions and listreceivedbyaddress.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run import for variant {}'.format(variant))
            expect_rescan = variant.rescan == Rescan.yes
            # Rescanning and non-rescanning imports go to separate nodes so a
            # later rescan cannot pick up an earlier variant's transactions.
            variant.node = self.nodes[2 + IMPORT_NODES.index(ImportNode(variant.prune, expect_rescan))]
            variant.do_import(variant.timestamp)
            if expect_rescan:
                variant.expected_balance = variant.initial_amount
                variant.expected_txs = 1
                variant.check(variant.initial_txid, variant.initial_amount, variant.confirmation_height)
            else:
                variant.expected_balance = 0
                variant.expected_txs = 0
                variant.check()

        # Create new transactions sending to each address.
        for variant in IMPORT_VARIANTS:
            variant.sent_amount = get_rand_amount()
            variant.sent_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.sent_amount)
            self.nodes[0].generate(1)  # Generate one block for each send
            variant.confirmation_height = self.nodes[0].getblockcount()

        assert_equal(self.nodes[0].getrawmempool(), [])
        self.sync_all()

        # Check the latest results from listtransactions and listreceivedbyaddress.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run check for variant {}'.format(variant))
            variant.expected_balance += variant.sent_amount
            variant.expected_txs += 1
            variant.check(variant.sent_txid, variant.sent_amount, variant.confirmation_height)


if __name__ == "__main__":
    ImportRescanTest().main()