1#!/usr/bin/env python3
2# Copyright (c) 2014-2020 The Bitcoin Core developers
3# Distributed under the MIT software license, see the accompanying
4# file COPYING or http://www.opensource.org/licenses/mit-license.php.
5"""Test the wallet backup features.
6
Test case is:
4 nodes. Nodes 0, 1 and 2 send transactions between each other;
node 3 is a miner.
Nodes 0, 1 and 2 each mine a block to start, then
the miner creates 100 blocks so nodes 0, 1 and 2 each have 50 mature
coins to spend.
Then 5 iterations of nodes 0/1/2 sending coins amongst
themselves to get transactions in the wallets,
and the miner mining one block.

Wallets are backed up using backupwallet (and, for legacy
wallets, also dumpwallet).
Then 5 more iterations of transactions and mining a block.

The miner then generates 101 more blocks, so any
transaction fees paid mature.

Sanity check:
  Sum(node 0, 1, 2, 3 balances) == 114*50

Nodes 0/1/2 are shut down, and their wallets erased.
Then restore using the wallet.dat backup, and
confirm that the balances of nodes 0/1/2 are the same as before.

For legacy wallets only: shut down again, restore using importwallet,
and confirm again that the balances are correct.

Finally, check that backupwallet refuses to write a backup over the
node's own wallet file or wallet directories.
32"""
33from decimal import Decimal
34import os
35from random import randint
36import shutil
37
38from test_framework.blocktools import COINBASE_MATURITY
39from test_framework.test_framework import BitcoinTestFramework
40from test_framework.util import (
41    assert_equal,
42    assert_raises_rpc_error,
43)
44
45
class WalletBackupTest(BitcoinTestFramework):
    """Back up spender wallets, wipe them, and verify restores recover balances.

    Nodes 0, 1 and 2 are spenders that pay each other; node 3 only mines.
    """

    # Indices of the nodes that hold spending wallets (node 3 is the miner).
    _SPENDER_NODES = (0, 1, 2)

    def set_test_params(self):
        self.num_nodes = 4
        self.setup_clean_chain = True
        # Nodes 0, 1 and 2 are spenders, so give them a keypool of 100.
        # Whitelist all peers to speed up tx relay / mempool sync.
        self.extra_args = [
            ["-whitelist=noban@127.0.0.1", "-keypool=100"],
            ["-whitelist=noban@127.0.0.1", "-keypool=100"],
            ["-whitelist=noban@127.0.0.1", "-keypool=100"],
            ["-whitelist=noban@127.0.0.1"],
        ]
        self.rpc_timeout = 120

    def skip_test_if_missing_module(self):
        self.skip_if_no_wallet()

    def _connect_test_topology(self):
        """Connect every spender to the miner, plus node 2 to node 0.

        This mirrors the network layout of the original bash test.
        """
        self.connect_nodes(0, 3)
        self.connect_nodes(1, 3)
        self.connect_nodes(2, 3)
        self.connect_nodes(2, 0)

    def _default_wallet_file(self, node):
        """Return the path of node `node`'s default wallet data file."""
        return os.path.join(
            self.nodes[node].datadir, self.chain, 'wallets',
            self.default_wallet_name, self.wallet_data_filename,
        )

    def _datadir_file(self, node, filename):
        """Return the path of `filename` at the top of node `node`'s datadir."""
        return os.path.join(self.nodes[node].datadir, filename)

    def _wipe_chain_data(self, node):
        """Delete node `node`'s blocks and chainstate so it restarts with no chain."""
        for subdir in ('blocks', 'chainstate'):
            shutil.rmtree(os.path.join(self.nodes[node].datadir, self.chain, subdir))

    def _assert_spender_balances(self, expected):
        """Assert that spender node i reports getbalance() == expected[i]."""
        for i in self._SPENDER_NODES:
            assert_equal(self.nodes[i].getbalance(), expected[i])

    def setup_network(self):
        self.setup_nodes()
        self._connect_test_topology()
        self.sync_all()

    def one_send(self, from_node, to_address):
        """With probability 1/2, send 0.1-1.0 (in steps of 0.1) from `from_node` to `to_address`."""
        if randint(1, 2) == 1:
            amount = Decimal(randint(1, 10)) / Decimal(10)
            self.nodes[from_node].sendtoaddress(to_address, amount)

    def do_one_round(self):
        """Let the spenders randomly pay each other, then have node 3 mine one block."""
        a0 = self.nodes[0].getnewaddress()
        a1 = self.nodes[1].getnewaddress()
        a2 = self.nodes[2].getnewaddress()

        # The call order is kept fixed so the sequence of randint() draws is stable.
        self.one_send(0, a1)
        self.one_send(0, a2)
        self.one_send(1, a0)
        self.one_send(1, a2)
        self.one_send(2, a0)
        self.one_send(2, a1)

        # Have the miner (node3) mine a block.
        # Must sync mempools before mining.
        self.sync_mempools()
        self.nodes[3].generate(1)
        self.sync_blocks()

    def start_three(self, args=()):
        """Start spender nodes 0-2 with their extra_args plus `args`, then reconnect them.

        The reconnection uses the same topology as setup_network (the original bash test layout).
        """
        for i in self._SPENDER_NODES:
            self.start_node(i, self.extra_args[i] + list(args))
        self._connect_test_topology()

    def stop_three(self):
        """Stop spender nodes 0-2."""
        for i in self._SPENDER_NODES:
            self.stop_node(i)

    def erase_three(self):
        """Delete the default wallet data file of spender nodes 0-2."""
        for i in self._SPENDER_NODES:
            os.remove(self._default_wallet_file(i))

    def init_three(self):
        """Call init_wallet on spender nodes 0-2."""
        for i in self._SPENDER_NODES:
            self.init_wallet(i)

    def run_test(self):
        self.log.info("Generating initial blockchain")
        for i in self._SPENDER_NODES:
            self.nodes[i].generate(1)
            self.sync_blocks()
        self.nodes[3].generate(COINBASE_MATURITY)
        self.sync_blocks()

        self._assert_spender_balances((50, 50, 50))
        assert_equal(self.nodes[3].getbalance(), 0)

        self.log.info("Creating transactions")
        # Five rounds of sending each other transactions.
        for _ in range(5):
            self.do_one_round()

        self.log.info("Backing up")
        for i in self._SPENDER_NODES:
            self.nodes[i].backupwallet(self._datadir_file(i, 'wallet.bak'))

        # dumpwallet/importwallet are only exercised for non-descriptor wallets.
        if not self.options.descriptors:
            for i in self._SPENDER_NODES:
                self.nodes[i].dumpwallet(self._datadir_file(i, 'wallet.dump'))

        self.log.info("More transactions")
        for _ in range(5):
            self.do_one_round()

        # Generate 101 more blocks, so any fees paid mature
        self.nodes[3].generate(COINBASE_MATURITY + 1)
        self.sync_all()

        # Balances of nodes 0-3, in node order.
        balances = [node.getbalance() for node in self.nodes]

        # At this point, there are 214 blocks (103 for setup, then 10 rounds, then 101.)
        # 114 are mature, so the sum of all wallets should be 114 * 50 = 5700.
        assert_equal(sum(balances), 114 * 50)

        ##
        # Test restoring spender wallets from backups
        ##
        self.log.info("Restoring using wallet.dat")
        self.stop_three()
        self.erase_three()

        # Start node2 with no chain
        self._wipe_chain_data(2)

        # Restore wallets from backup
        for i in self._SPENDER_NODES:
            shutil.copyfile(self._datadir_file(i, 'wallet.bak'), self._default_wallet_file(i))

        self.log.info("Re-starting nodes")
        self.start_three()
        self.sync_blocks()

        self._assert_spender_balances(balances)

        if not self.options.descriptors:
            self.log.info("Restoring using dumped wallet")
            self.stop_three()
            self.erase_three()

            # Start node2 with no chain
            self._wipe_chain_data(2)

            self.start_three(["-nowallet"])
            self.init_three()

            self._assert_spender_balances((0, 0, 0))

            for i in self._SPENDER_NODES:
                self.nodes[i].importwallet(self._datadir_file(i, 'wallet.dump'))

            self.sync_blocks()

            self._assert_spender_balances(balances)

        # Backup to source wallet file must fail
        source_paths = [
            self._default_wallet_file(0),
            os.path.join(self.nodes[0].datadir, self.chain, '.', 'wallets', self.default_wallet_name, self.wallet_data_filename),
            os.path.join(self.nodes[0].datadir, self.chain, 'wallets', self.default_wallet_name),
            os.path.join(self.nodes[0].datadir, self.chain, 'wallets'),
        ]

        for source_path in source_paths:
            assert_raises_rpc_error(-4, "backup failed", self.nodes[0].backupwallet, source_path)
228
229
if __name__ == '__main__':
    # Script entry point: build the test object and hand control to the framework.
    test = WalletBackupTest()
    test.main()
232