1#!/usr/bin/env python3
2# Copyright (c) 2014-2018 The Bitcoin Core developers
3# Distributed under the MIT software license, see the accompanying
4# file COPYING or http://www.opensource.org/licenses/mit-license.php.
5"""Helpful routines for regression testing."""
6
7from base64 import b64encode
8from binascii import hexlify, unhexlify
9from decimal import Decimal, ROUND_DOWN
10import hashlib
11import inspect
12import json
13import logging
14import os
15import random
16import re
17from subprocess import CalledProcessError
18import time
19
20from . import coverage
21from .authproxy import AuthServiceProxy, JSONRPCException
22
# Shared module-level logger for all test-framework utility helpers below.
logger = logging.getLogger("TestFramework.utils")
24
25# Assert functions
26##################
27
def assert_fee_amount(fee, tx_size, fee_per_kB):
    """Assert that `fee` is within the expected range for a tx of `tx_size`
    bytes at rate `fee_per_kB`, tolerating a 2-byte size estimation error."""
    target_fee = round(tx_size * fee_per_kB / 1000, 8)
    if fee < target_fee:
        raise AssertionError("Fee of %s LTC too low! (Should be %s LTC)" % (str(fee), str(target_fee)))
    # allow the wallet's estimation to be at most 2 bytes off
    upper_bound = (tx_size + 2) * fee_per_kB / 1000
    if fee > upper_bound:
        raise AssertionError("Fee of %s LTC too high! (Should be %s LTC)" % (str(fee), str(target_fee)))
36
def assert_equal(thing1, thing2, *args):
    """Raise AssertionError unless every argument compares equal to the first."""
    others = (thing2,) + args
    if any(thing1 != other for other in others):
        raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1,) + others))
40
def assert_greater_than(thing1, thing2):
    """Raise AssertionError unless thing1 is strictly greater than thing2."""
    if thing1 <= thing2:
        raise AssertionError("%s <= %s" % (thing1, thing2))
44
def assert_greater_than_or_equal(thing1, thing2):
    """Raise AssertionError unless thing1 is greater than or equal to thing2."""
    if thing1 < thing2:
        raise AssertionError("%s < %s" % (thing1, thing2))
48
def assert_raises(exc, fun, *args, **kwds):
    """Assert that fun(*args, **kwds) raises an exception of type exc
    (the exception message is not checked)."""
    assert_raises_message(exc, None, fun, *args, **kwds)
51
def assert_raises_message(exc, message, fun, *args, **kwds):
    """Assert that fun(*args, **kwds) raises an exception of type exc.

    If message is not None, additionally assert that it is a substring of
    the exception's string representation. JSONRPCExceptions are rejected
    outright: use assert_raises_rpc_error() for RPC failures.
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException:
        raise AssertionError("Use assert_raises_rpc_error() to test RPC failures")
    except exc as e:
        # Generic exceptions have no .error dict (that's JSONRPC-specific),
        # so match the expected substring against str(e) instead.
        if message is not None and message not in str(e):
            raise AssertionError("Expected substring not found:" + str(e))
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    else:
        raise AssertionError("No exception raised")
64
def assert_raises_process_error(returncode, output, fun, *args, **kwds):
    """Execute a process and asserts the process return code and output.

    Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError
    and verifies that the return code and output are as expected. Throws AssertionError if
    no CalledProcessError was raised or if the return code and output are not as expected.

    Args:
        returncode (int): the process return code.
        output (string): [a substring of] the process output.
        fun (function): the function to call. This should execute a process.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    try:
        fun(*args, **kwds)
    except CalledProcessError as e:
        if e.returncode != returncode:
            raise AssertionError("Unexpected returncode %i" % e.returncode)
        if output not in e.output:
            raise AssertionError("Expected substring not found:" + e.output)
    else:
        raise AssertionError("No exception raised")
88
def assert_raises_rpc_error(code, message, fun, *args, **kwds):
    """Run an RPC and verify it fails with the given JSONRPC error code and message.

    Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
    and verifies that the error code and message are as expected. Throws AssertionError if
    no JSONRPCException was raised or if the error code/message are not as expected.

    Args:
        code (int), optional: the error code returned by the RPC call (defined
            in src/rpc/protocol.h). Set to None if checking the error code is not required.
        message (string), optional: [a substring of] the error string returned by the
            RPC call. Set to None if checking the error string is not required.
        fun (function): the function to call. This should be the name of an RPC.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    assert try_rpc(code, message, fun, *args, **kwds), "No exception raised"
106
def try_rpc(code, message, fun, *args, **kwds):
    """Call an RPC and report whether it raised a JSONRPCException.

    When it does, check the error code against `code` and that `message`
    is a substring of the error string (each only if not None); mismatches
    raise AssertionError, as does any non-JSONRPC exception. Returns True
    when a JSONRPCException was raised, False when the call succeeded.
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException as e:
        # The expected failure happened; validate its code/message fields.
        if code is not None and code != e.error["code"]:
            raise AssertionError("Unexpected JSONRPC error code %i" % e.error["code"])
        if message is not None and message not in e.error['message']:
            raise AssertionError("Expected substring not found:" + e.error['message'])
        return True
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    else:
        return False
125
def assert_is_hex_string(string):
    """Raise AssertionError unless the argument parses as a base-16 number."""
    try:
        int(string, 16)
    except Exception as e:
        raise AssertionError("Couldn't interpret %r as hexadecimal; raised: %s" % (string, e))
132
def assert_is_hash_string(string, length=64):
    """Raise AssertionError unless `string` is a lowercase hex hash of the
    given length (pass length=None/0 to skip the length check)."""
    if not isinstance(string, str):
        raise AssertionError("Expected a string, got type %r" % type(string))
    if length and len(string) != length:
        raise AssertionError(
            "String of length %d expected; got %d" % (length, len(string)))
    if not re.match('[abcdef0-9]+$', string):
        raise AssertionError(
            "String %r contains invalid characters for a hash." % string)
142
def assert_array_result(object_array, to_match, expected, should_not_find=False):
    """
    Check that every object in object_array whose key/value pairs all match
    to_match also carries every key/value pair in expected, and that at
    least one object matched. With should_not_find=True, assert instead
    that no object matches to_match (expected must then be empty).
    """
    if should_not_find:
        assert_equal(expected, {})
    num_matched = 0
    for item in object_array:
        # Evaluate every key (not short-circuiting) so a missing key
        # raises KeyError just as a plain loop would.
        mismatches = [key for key, value in to_match.items() if item[key] != value]
        if mismatches:
            continue
        if should_not_find:
            num_matched += 1
        for key, value in expected.items():
            if item[key] != value:
                raise AssertionError("%s : expected %s=%s" % (str(item), str(key), str(value)))
            num_matched += 1
    if num_matched == 0 and not should_not_find:
        raise AssertionError("No objects matched %s" % (str(to_match)))
    if num_matched > 0 and should_not_find:
        raise AssertionError("Objects were found %s" % (str(to_match)))
171
172# Utility functions
173###################
174
def check_json_precision():
    """Make sure json library being used does not lose precision converting BTC values"""
    amount = Decimal("20000000.00000003")
    round_tripped = json.loads(json.dumps(float(amount)))
    if int(round_tripped * 1.0e8) != 2000000000000003:
        raise RuntimeError("JSON encode/decode loses precision")
181
def count_bytes(hex_string):
    """Return the number of bytes encoded by a hex string."""
    return len(bytes.fromhex(hex_string))
184
def bytes_to_hex_str(byte_str):
    """Return the lowercase hex representation of a bytes-like object."""
    return byte_str.hex()
187
def hash256(byte_str):
    """Return the double-SHA256 digest of byte_str, byte-reversed
    (the byte order used for hashes in RPC output)."""
    first_round = hashlib.sha256(byte_str).digest()
    return hashlib.sha256(first_round).digest()[::-1]
194
def hex_str_to_bytes(hex_str):
    """Decode an ASCII hex string into the bytes it encodes."""
    return unhexlify(bytes(hex_str, 'ascii'))
197
def str_to_b64str(string):
    """UTF-8 encode `string` and return its base64 form as ASCII text."""
    raw = string.encode('utf-8')
    return b64encode(raw).decode('ascii')
200
def satoshi_round(amount):
    """Round `amount` down to eight decimal places (one satoshi)."""
    one_satoshi = Decimal('0.00000001')
    return Decimal(amount).quantize(one_satoshi, rounding=ROUND_DOWN)
203
def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None):
    """Poll `predicate` every 50ms until it returns true, then return.

    Gives up after `attempts` tries or `timeout` seconds — when neither
    bound is supplied, a default timeout of 60 seconds applies — and
    raises AssertionError with the predicate's source. If `lock` is
    given, it is held while evaluating the predicate.
    """
    if attempts == float('inf') and timeout == float('inf'):
        timeout = 60
    attempt = 0
    time_end = time.time() + timeout

    while attempt < attempts and time.time() < time_end:
        if lock:
            with lock:
                if predicate():
                    return
        else:
            if predicate():
                return
        attempt += 1
        time.sleep(0.05)

    # Print the cause of the timeout.
    # (Fix: the opening quote run had four apostrophes to the closing run's
    # three, producing unbalanced quoting in the diagnostic output.)
    predicate_source = "'''\n" + inspect.getsource(predicate) + "'''"
    logger.error("wait_until() failed. Predicate: {}".format(predicate_source))
    if attempt >= attempts:
        raise AssertionError("Predicate {} not true after {} attempts".format(predicate_source, attempts))
    elif time.time() >= time_end:
        raise AssertionError("Predicate {} not true after {} seconds".format(predicate_source, timeout))
    raise RuntimeError('Unreachable')
229
# RPC/P2P connection constants and functions
############################################

# The maximum number of nodes a single test can spawn
MAX_NODES = 8
# Don't assign rpc or p2p ports lower than this
PORT_MIN = 11000
# The number of ports to "reserve" for p2p and rpc, each
PORT_RANGE = 5000

class PortSeed:
    """Process-wide seed used by p2p_port()/rpc_port() so that concurrently
    running test processes pick non-colliding port ranges."""
    # Must be initialized with a unique integer for each process
    n = None
243
def get_rpc_proxy(url, node_number, timeout=None, coveragedir=None):
    """Build an AuthServiceProxy for `url`, wrapped for RPC coverage tracking.

    Args:
        url (str): URL of the RPC server to call
        node_number (int): the node number (or id) that this calls to

    Kwargs:
        timeout (int): HTTP timeout in seconds
        coveragedir (str): directory for coverage logs, if any

    Returns:
        AuthServiceProxy. convenience object for making RPC calls.
    """
    proxy_kwargs = {}
    if timeout is not None:
        proxy_kwargs['timeout'] = timeout

    proxy = AuthServiceProxy(url, **proxy_kwargs)
    # Keep the URL on the proxy for informational purposes.
    proxy.url = url

    if coveragedir:
        coverage_logfile = coverage.get_filename(coveragedir, node_number)
    else:
        coverage_logfile = None
    return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile)
268
def p2p_port(n):
    """Return the deterministic p2p listen port for node n, offset by the
    per-process PortSeed to avoid collisions between test processes."""
    assert n <= MAX_NODES
    seed_offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + n + seed_offset
272
def rpc_port(n):
    """Return the deterministic rpc port for node n (the p2p port range
    shifted up by PORT_RANGE)."""
    seed_offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + PORT_RANGE + n + seed_offset
275
def rpc_url(datadir, i, rpchost=None):
    """Build the authenticated RPC URL for node i, reading credentials from
    its datadir; `rpchost` may override the host or "host:port"."""
    rpc_u, rpc_p = get_auth_cookie(datadir)
    host, port = '127.0.0.1', rpc_port(i)
    if rpchost:
        parts = rpchost.split(':')
        # Only a single "host:port" pair overrides the port; anything else
        # (including multi-colon strings) is treated as a bare host.
        if len(parts) == 2:
            host, port = parts
        else:
            host = rpchost
    return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port))
287
288# Node functions
289################
290
def initialize_datadir(dirname, n):
    """Create node n's datadir (plus stdout/stderr capture dirs) with a
    default regtest litecoin.conf, and return the datadir path."""
    datadir = get_datadir_path(dirname, n)
    if not os.path.isdir(datadir):
        os.makedirs(datadir)
    settings = [
        "regtest=1",
        "[regtest]",
        "port=" + str(p2p_port(n)),
        "rpcport=" + str(rpc_port(n)),
        "server=1",
        "keypool=1",
        "discover=0",
        "listenonion=0",
        "printtoconsole=0",
        "upnp=0",
    ]
    with open(os.path.join(datadir, "litecoin.conf"), 'w', encoding='utf8') as f:
        f.write("\n".join(settings) + "\n")
    os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True)
    os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True)
    return datadir
309
def get_datadir_path(dirname, n):
    """Return the datadir path for node n under `dirname`."""
    return os.path.join(dirname, "node{}".format(n))
312
def append_config(datadir, options):
    """Append each option string as its own line to the node's litecoin.conf."""
    conf_path = os.path.join(datadir, "litecoin.conf")
    with open(conf_path, 'a', encoding='utf8') as f:
        f.writelines(option + "\n" for option in options)
317
def get_auth_cookie(datadir):
    """Read RPC credentials for the node in `datadir`.

    Takes rpcuser/rpcpassword from litecoin.conf when present; a readable
    regtest .cookie file overrides both. Raises ValueError when no complete
    set of credentials is found.
    """
    user = None
    password = None
    if os.path.isfile(os.path.join(datadir, "litecoin.conf")):
        with open(os.path.join(datadir, "litecoin.conf"), 'r', encoding='utf8') as f:
            for line in f:
                if line.startswith("rpcuser="):
                    assert user is None  # Ensure that there is only one rpcuser line
                    # Split on the first '=' only, so values containing '='
                    # are not truncated.
                    user = line.split("=", 1)[1].strip("\n")
                if line.startswith("rpcpassword="):
                    assert password is None  # Ensure that there is only one rpcpassword line
                    password = line.split("=", 1)[1].strip("\n")
    if os.path.isfile(os.path.join(datadir, "regtest", ".cookie")) and os.access(os.path.join(datadir, "regtest", ".cookie"), os.R_OK):
        with open(os.path.join(datadir, "regtest", ".cookie"), 'r', encoding="ascii") as f:
            userpass = f.read()
            # The cookie is "user:password"; split on the first ':' only so a
            # password containing ':' survives intact.
            split_userpass = userpass.split(':', 1)
            user = split_userpass[0]
            password = split_userpass[1]
    if user is None or password is None:
        raise ValueError("No RPC credentials")
    return user, password
339
def delete_cookie_file(datadir):
    """Delete the regtest .cookie file in `datadir`, if one exists."""
    cookie_path = os.path.join(datadir, "regtest", ".cookie")
    if os.path.isfile(cookie_path):
        logger.debug("Deleting leftover cookie file")
        os.remove(cookie_path)
345
def get_bip9_status(node, key):
    """Return the bip9_softforks entry for `key` from the node's getblockchaininfo."""
    return node.getblockchaininfo()['bip9_softforks'][key]
349
def set_node_times(nodes, t):
    """Set mocktime `t` on every node in `nodes`."""
    for rpc in nodes:
        rpc.setmocktime(t)
353
def disconnect_nodes(from_connection, node_num):
    """Disconnect every peer of from_connection whose subver marks it as
    testnode<node_num>, then wait (up to 5s) for the disconnects to finish."""
    def matching_peer_ids():
        tag = "testnode%d" % node_num
        return [peer['id'] for peer in from_connection.getpeerinfo() if tag in peer['subver']]

    for peer_id in matching_peer_ids():
        try:
            from_connection.disconnectnode(nodeid=peer_id)
        except JSONRPCException as e:
            # If this node is disconnected between calculating the peer id
            # and issuing the disconnect, don't worry about it.
            # This avoids a race condition if we're mass-disconnecting peers.
            if e.error['code'] != -29:  # RPC_CLIENT_NODE_NOT_CONNECTED
                raise

    # wait to disconnect
    wait_until(lambda: matching_peer_ids() == [], timeout=5)
367
def connect_nodes(from_connection, node_num):
    """One-shot connect from_connection to local node node_num and wait for
    the version handshake to complete."""
    target = "127.0.0.1:%d" % p2p_port(node_num)
    from_connection.addnode(target, "onetry")
    # poll until version handshake complete to avoid race conditions
    # with transaction relaying
    wait_until(lambda: all(peer['version'] != 0 for peer in from_connection.getpeerinfo()))
374
def connect_nodes_bi(nodes, a, b):
    """Connect nodes[a] and nodes[b] to each other, in both directions."""
    connect_nodes(nodes[a], b)
    connect_nodes(nodes[b], a)
378
def sync_blocks(rpc_connections, *, wait=1, timeout=60):
    """
    Wait until everybody has the same tip.

    sync_blocks needs to be called with an rpc_connections set that has least
    one node already synced to the latest, stable tip, otherwise there's a
    chance it might return before all nodes are stably synced.
    """
    deadline = time.time() + timeout
    while time.time() <= deadline:
        tips = [conn.getbestblockhash() for conn in rpc_connections]
        if tips.count(tips[0]) == len(rpc_connections):
            return
        time.sleep(wait)
    raise AssertionError("Block sync timed out:{}".format("".join("\n  {!r}".format(b) for b in tips)))
394
def sync_mempools(rpc_connections, *, wait=1, timeout=60, flush_scheduler=True):
    """
    Wait until everybody has the same transactions in their memory
    pools
    """
    deadline = time.time() + timeout
    while time.time() <= deadline:
        pools = [set(conn.getrawmempool()) for conn in rpc_connections]
        if pools.count(pools[0]) == len(rpc_connections):
            if flush_scheduler:
                for conn in rpc_connections:
                    conn.syncwithvalidationinterfacequeue()
            return
        time.sleep(wait)
    raise AssertionError("Mempool sync timed out:{}".format("".join("\n  {!r}".format(m) for m in pools)))
410
411# Transaction/Block functions
412#############################
413
def find_output(node, txid, amount, *, blockhash=None):
    """
    Return index to output of txid with value amount
    Raises exception if there is none.
    """
    txdata = node.getrawtransaction(txid, 1, blockhash)
    for index, txout in enumerate(txdata["vout"]):
        if txout["value"] == amount:
            return index
    raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount)))
424
def gather_inputs(from_node, amount_needed, confirmations_required=1):
    """
    Return a random set of unspent txouts that are enough to pay amount_needed
    """
    assert confirmations_required >= 0
    utxo = from_node.listunspent(confirmations_required)
    random.shuffle(utxo)
    inputs = []
    total_in = Decimal("0.00000000")
    while utxo and total_in < amount_needed:
        txout = utxo.pop()
        total_in += txout["amount"]
        inputs.append({"txid": txout["txid"], "vout": txout["vout"], "address": txout["address"]})
    if total_in < amount_needed:
        raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in))
    return (total_in, inputs)
441
def make_change(from_node, amount_in, amount_out, fee):
    """
    Create change output(s), return them
    """
    outputs = {}
    spent = amount_out + fee
    change = amount_in - spent
    if change > 2 * spent:
        # Create an extra change output to break up big inputs
        extra_address = from_node.getnewaddress()
        # Split change in two, being careful of rounding:
        outputs[extra_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
        change = amount_in - spent - outputs[extra_address]
    if change > 0:
        outputs[from_node.getnewaddress()] = change
    return outputs
458
def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants):
    """
    Create a random transaction.
    Returns (txid, hex-encoded-transaction-data, fee)
    """
    sender = random.choice(nodes)
    recipient = random.choice(nodes)
    fee = min_fee + fee_increment * random.randint(0, fee_variants)

    total_in, inputs = gather_inputs(sender, amount + fee)
    outputs = make_change(sender, total_in, amount, fee)
    outputs[recipient.getnewaddress()] = float(amount)

    rawtx = sender.createrawtransaction(inputs, outputs)
    signresult = sender.signrawtransactionwithwallet(rawtx)
    txid = sender.sendrawtransaction(signresult["hex"], True)
    return (txid, signresult["hex"], fee)
477
def create_confirmed_utxos(fee, node, count):
    """Create at least `count` confirmed utxos on `node` by splitting existing
    ones, paying `fee` per splitting transaction (must be enough for relay
    and mining); return the resulting utxo list."""
    # Mine enough blocks to mature some coinbases to split.
    to_generate = int(0.5 * count) + 101
    while to_generate > 0:
        node.generate(min(25, to_generate))
        to_generate -= 25
    utxos = node.listunspent()
    iterations = count - len(utxos)
    addr1 = node.getnewaddress()
    addr2 = node.getnewaddress()
    if iterations <= 0:
        return utxos
    for _ in range(iterations):
        # Split one utxo into two halves, growing the set by one each pass.
        spend = utxos.pop()
        inputs = [{"txid": spend["txid"], "vout": spend["vout"]}]
        half = satoshi_round((spend['amount'] - fee) / 2)
        outputs = {addr1: half, addr2: half}
        raw_tx = node.createrawtransaction(inputs, outputs)
        signed_tx = node.signrawtransactionwithwallet(raw_tx)["hex"]
        node.sendrawtransaction(signed_tx)

    # Confirm everything before returning.
    while node.getmempoolinfo()['size'] > 0:
        node.generate(1)

    utxos = node.listunspent()
    assert len(utxos) >= count
    return utxos
509
def gen_return_txouts():
    """Return hex for a batch of large OP_RETURN txouts that can be spliced
    into a raw transaction to bloat it (helper for building big txs)."""
    # One script: OP_RETURN, OP_PUSHDATA2, then 512 bytes of 0x01.
    script_pubkey = "6a4d0200" + "01" * 512
    # Each txout entry: 8-byte zero value, script length 0x0204 (PUSHDATA2
    # encoded little-endian as fd0402), then the script itself.
    txout_entry = "0000000000000000" + "fd0402" + script_pubkey
    # Leading 0x81: output-count byte (presumably 128 big outputs plus the
    # change output spliced in by the caller — confirm against tx format).
    return "81" + txout_entry * 128
529
def create_lots_of_big_transactions(node, txouts, utxos, num, fee):
    """Spend `num` utxos, splicing `txouts` (see gen_return_txouts()) into
    each raw transaction to inflate it; return the list of txids."""
    addr = node.getnewaddress()
    txids = []
    for _ in range(num):
        spend = utxos.pop()
        inputs = [{"txid": spend["txid"], "vout": spend["vout"]}]
        outputs = {addr: satoshi_round(spend['amount'] - fee)}
        rawtx = node.createrawtransaction(inputs, outputs)
        # Splice the OP_RETURN outputs in, replacing the two hex characters
        # (one byte — presumably the output count) at offset 92.
        newtx = rawtx[0:92] + txouts + rawtx[94:]
        signresult = node.signrawtransactionwithwallet(newtx, None, "NONE")
        txids.append(node.sendrawtransaction(signresult["hex"], True))
    return txids
549
def mine_large_block(node, utxos=None):
    """Mine one block filled with 14 ~66k transactions — close to the 1MB
    block limit. Spends from `utxos` when enough are supplied, otherwise
    refills the list from the node's wallet."""
    num_txs = 14
    txouts = gen_return_txouts()
    if utxos is None:
        utxos = []
    if len(utxos) < num_txs:
        utxos.clear()
        utxos.extend(node.listunspent())
    fee = 100 * node.getnetworkinfo()["relayfee"]
    create_lots_of_big_transactions(node, txouts, utxos, num_txs, fee=fee)
    node.generate(1)
562
def find_vout_for_address(node, txid, addr):
    """
    Locate the vout index of the given transaction sending to the
    given address. Raises runtime error exception if not found.
    """
    tx = node.getrawtransaction(txid, True)
    for index, txout in enumerate(tx["vout"]):
        if addr in txout["scriptPubKey"]["addresses"]:
            return index
    raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr))
573