1#!/usr/bin/env python3
2# Copyright (c) 2017-2020 The Bitcoin Core developers
3# Distributed under the MIT software license, see the accompanying
4# file COPYING or http://www.opensource.org/licenses/mit-license.php.
5"""Test HD Wallet keypool restore function.
6
Four nodes. Nodes 1, 2 and 3 are under test, one per address type (legacy, p2sh-segwit and bech32 respectively). Node0 provides transactions and generates blocks.

For each node under test:
- Stop the node, back up its wallet and restart it.
- Generate 110 keys (enough to drain the keypool of 100). Store key 90 (in the initial keypool) and key 110 (beyond the initial keypool). Send funds to key 90 and key 110.
- Stop the node, copy the wallet backup back over its wallet file and restart the node.
- Connect the node to node0. Verify that they sync, that the node receives its funds, and that the next new address is derived at index 110 (i.e. all keys up to the used one were marked as used)."""
13import os
14import shutil
15
16from test_framework.blocktools import COINBASE_MATURITY
17from test_framework.test_framework import BitcoinTestFramework
18from test_framework.util import (
19    assert_equal,
20)
21
22
class KeypoolRestoreTest(BitcoinTestFramework):
    """Restore each wallet under test from a backup and check it recovers its funds.

    Node ``idx`` (1-based) exercises ``OUTPUT_TYPES[idx - 1]``; node 0 only
    mines blocks and sends funds.
    """

    # Address types under test, in the order nodes 1, 2 and 3 exercise them.
    OUTPUT_TYPES = ["legacy", "p2sh-segwit", "bech32"]
    # Expected (isscript, iswitness) flags reported by validateaddress per address type.
    EXPECTED_SCRIPT_FLAGS = {
        "legacy": (False, False),
        "p2sh-segwit": (True, False),
        "bech32": (False, True),
    }
    # First (purpose) component of the descriptor-wallet hdkeypath per address type.
    DESCRIPTOR_PURPOSE = {
        "legacy": 44,
        "p2sh-segwit": 49,
        "bech32": 84,
    }

    def set_test_params(self):
        self.setup_clean_chain = True
        self.num_nodes = 4
        self.extra_args = [[], ['-keypool=100'], ['-keypool=100'], ['-keypool=100']]

    def skip_test_if_missing_module(self):
        self.skip_if_no_wallet()

    def _wallet_paths(self, idx):
        """Return ``(wallet_path, wallet_backup_path)`` for node ``idx``."""
        datadir = self.nodes[idx].datadir
        wallet_path = os.path.join(datadir, self.chain, "wallets", self.default_wallet_name, self.wallet_data_filename)
        wallet_backup_path = os.path.join(datadir, "wallet.bak")
        return wallet_path, wallet_backup_path

    def run_test(self):
        self.nodes[0].generate(COINBASE_MATURITY + 1)

        # Back up every node's wallet before it hands out any keys. Each node keeps
        # its wallet in its own datadir, so each must later be restored from its own
        # backup; restoring node1's backup path for every node would leave nodes 2
        # and 3 unrestored and overwrite node1's wallet file while node1 is running.
        self.log.info("Make backup of wallets")
        for idx in range(1, len(self.OUTPUT_TYPES) + 1):
            wallet_path, wallet_backup_path = self._wallet_paths(idx)
            self.stop_node(idx)
            shutil.copyfile(wallet_path, wallet_backup_path)
            self.start_node(idx, self.extra_args[idx])
            self.connect_nodes(0, idx)

        for idx, output_type in enumerate(self.OUTPUT_TYPES, start=1):
            self.log.info("Generate keys for wallet with address type: {}".format(output_type))
            # Key 90 lies inside the initial keypool (-keypool=100); key 110 lies beyond it.
            for _ in range(90):
                addr_oldpool = self.nodes[idx].getnewaddress(address_type=output_type)
            for _ in range(20):
                addr_extpool = self.nodes[idx].getnewaddress(address_type=output_type)

            # Make sure we're creating the outputs we expect
            address_details = self.nodes[idx].validateaddress(addr_extpool)
            assert_equal(
                (bool(address_details["isscript"]), bool(address_details["iswitness"])),
                self.EXPECTED_SCRIPT_FLAGS[output_type],
            )

            self.log.info("Send funds to wallet")
            self.nodes[0].sendtoaddress(addr_oldpool, 10)
            self.nodes[0].generate(1)
            self.nodes[0].sendtoaddress(addr_extpool, 5)
            self.nodes[0].generate(1)
            self.sync_blocks()

            self.log.info("Restart node with wallet backup")
            wallet_path, wallet_backup_path = self._wallet_paths(idx)
            self.stop_node(idx)
            shutil.copyfile(wallet_backup_path, wallet_path)
            self.start_node(idx, self.extra_args[idx])
            self.connect_nodes(0, idx)
            self.sync_all()

            self.log.info("Verify keypool is restored and balance is correct")
            assert_equal(self.nodes[idx].getbalance(), 15)
            assert_equal(self.nodes[idx].listtransactions()[0]['category'], "receive")
            # Check that we have marked all keys up to the used keypool key as used:
            # the next new address must be derived at index 110.
            if self.options.descriptors:
                expected_path = "m/{}'/1'/0'/0/110".format(self.DESCRIPTOR_PURPOSE[output_type])
            else:
                expected_path = "m/0'/0'/110'"
            new_address = self.nodes[idx].getnewaddress(address_type=output_type)
            assert_equal(self.nodes[idx].getaddressinfo(new_address)['hdkeypath'], expected_path)
91
92
if __name__ == '__main__':
    # Entry point when the file is executed directly: build the test and run it.
    test = KeypoolRestoreTest()
    test.main()
95