1import random
2import math
3from typing import List, Tuple, Optional, Sequence, Dict, TYPE_CHECKING
4from collections import defaultdict
5
6from .util import profiler
7from .lnutil import NoPathFound
8
# Rating penalty added (squared) for every nonzero part; 1.0 results in
# avoiding splits.
PART_PENALTY = 1.0
# Smallest amount a single part may carry; we don't want to split indefinitely.
MIN_PART_MSAT = 10 ** 7
# A channel's funds divided by this value is the decay length of the
# saturation penalty, i.e. the fraction of the local balance that should be
# reserved if possible.
EXHAUST_DECAY_FRACTION = 10

# Granularity of newly suggested configurations: redistribution moves
# amount // REDISTRIBUTION_FRACTION at a time, a split moves
# part // SPLIT_FRACTION into a new channel.
REDISTRIBUTION_FRACTION = 50
SPLIT_FRACTION = 50

# Computational work in the probabilistic algorithm.
STARTING_CONFIGS = 50      # number of random starting configurations
CANDIDATES_PER_LEVEL = 10  # proposals generated per number of parts
REDISTRIBUTE = 20          # redistribution steps per proposal

# Maximum number of parts for splitting.
MAX_PARTS = 5
24
25
def unique_hierarchy(hierarchy: Dict[int, List[Dict[Tuple[bytes, bytes], int]]]) -> Dict[int, List[Dict[Tuple[bytes, bytes], int]]]:
    """Drop duplicate configurations from every level of *hierarchy*.

    Two configurations count as duplicates when they map the same channels to
    the same amounts, independent of dict insertion order. Survivors are
    returned as fresh dicts in sorted order, so the output is deterministic.
    Levels without any configuration do not appear in the result.
    """
    deduplicated = defaultdict(list)
    for level, configs in hierarchy.items():
        # canonical, hashable form: (channel, amount) pairs ordered by channel
        canonical_forms = {tuple(sorted(config.items())) for config in configs}
        for items in sorted(canonical_forms):
            deduplicated[level].append(dict(items))
    return deduplicated
37
38
def single_node_hierarchy(hierarchy: Dict[int, List[Dict[Tuple[bytes, bytes], int]]]) -> Dict[int, List[Dict[Tuple[bytes, bytes], int]]]:
    """Keep only configurations whose positive amounts go to at most one node.

    Configurations spreading positive amounts over more than one node id are
    discarded. Levels that end up empty are left out of the returned hierarchy.
    The kept configuration dicts are the same objects as in the input.
    """
    filtered = defaultdict(list)
    for level, configs in hierarchy.items():
        kept = [config for config in configs if number_nonzero_nodes(config) <= 1]
        if kept:
            filtered[level] = kept
    return filtered
48
49
def number_nonzero_parts(configuration: Dict[Tuple[bytes, bytes], int]) -> int:
    """Return how many channels in *configuration* carry a nonzero amount."""
    return sum(1 for amount in configuration.values() if amount)
52
53
def number_nonzero_nodes(configuration: Dict[Tuple[bytes, bytes], int]) -> int:
    """Return the number of distinct node ids that receive a positive amount.

    Keys are unpacked as (channel id, node id) pairs.
    """
    receiving_nodes = set()
    for (_channel_id, node_id), amount in configuration.items():
        if amount > 0:
            receiving_nodes.add(node_id)
    return len(receiving_nodes)
56
57
def create_starting_split_hierarchy(amount_msat: int, channels_with_funds: Dict[Tuple[bytes, bytes], int]):
    """Distribute the amount to send over one or more channels in several
    (random) ways.

    In each of STARTING_CONFIGS rounds the channel order is shuffled and the
    amount is filled greedily: every channel takes as much as its funds allow
    until the amount is covered, and all later channels get 0. The results
    are grouped by number of nonzero parts and deduplicated.

    Raises NoPathFound if the channels cannot cover amount_msat.
    """
    # TODO: find all possible starting configurations deterministically
    # could try all permutations

    hierarchy = defaultdict(list)
    order = list(channels_with_funds)

    for _ in range(STARTING_CONFIGS):
        random.shuffle(order)  # a different starting point each round

        config = {}
        remaining = amount_msat
        for channel in order:
            funds = channels_with_funds[channel]
            if remaining == 0:
                config[channel] = 0
                continue
            part = min(funds, remaining)
            config[channel] = part
            remaining -= part

        if remaining != 0:
            raise NoPathFound("Channels don't have enough sending capacity.")
        hierarchy[number_nonzero_parts(config)].append(config)

    return unique_hierarchy(hierarchy)
87
88
def balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
    """Return True if either proposed part amount is unacceptable.

    An amount is acceptable when it is at least MIN_PART_MSAT and does not
    exceed the funds of its channel in channels_with_funds. The receiving
    side (channel_to) is checked first.
    """
    both_within_limits = (
        MIN_PART_MSAT <= proposed_balance_to <= channels_with_funds[channel_to]
        and MIN_PART_MSAT <= proposed_balance_from <= channels_with_funds[channel_from]
    )
    return not both_within_limits
97
98
def propose_new_configuration(channels_with_funds: Dict[Tuple[bytes, bytes], int], configuration: Dict[Tuple[bytes, bytes], int],
                              amount_msat: int, preserve_number_parts=True) -> Dict[Tuple[bytes, bytes], int]:
    """Randomly alter a split configuration and return the altered copy.

    The caller's *configuration* is not modified. If preserve_number_parts,
    the configuration stays within the same class of number of splits;
    otherwise one split into an unused channel is attempted when the
    redistribution steps did not already change the number of parts.

    There are three basic operations to reach different split configurations:
    redistribute, split and swap. Each one leaves the configuration unchanged
    when the randomly chosen move would produce an invalid part.

    :param channels_with_funds: sending capacity per channel key.
    :param configuration: amount per channel key; expected to sum to
        amount_msat (asserted after redistribute and split).
    :param amount_msat: total amount being split.
    :param preserve_number_parts: if False, try to add one more part.
    :return: the altered configuration as a new dict.
    """

    def redistribute(config: dict) -> dict:
        # move a fixed fraction of the total between two channels that each
        # carry at least that fraction
        redistribution_amount = amount_msat // REDISTRIBUTION_FRACTION
        candidates = [ck for ck, cv in config.items() if cv >= redistribution_amount]
        if len(candidates) < 2:  # need two channels to move funds between
            return config

        channel_from = random.choice(candidates)
        channel_to = random.choice(candidates)
        if channel_from == channel_to:
            return config
        proposed_balance_from = config[channel_from] - redistribution_amount
        proposed_balance_to = config[channel_to] + redistribution_amount
        if balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
            return config
        config[channel_from] = proposed_balance_from
        config[channel_to] = proposed_balance_to
        assert sum(config.values()) == amount_msat
        return config

    def split(config: dict) -> dict:
        # split off a fraction of a nonzero channel and put it into a zero
        # channel
        nonzero = [ck for ck, cv in config.items() if cv != 0]
        zero = [ck for ck, cv in config.items() if cv == 0]
        try:
            channel_from = random.choice(nonzero)
            channel_to = random.choice(zero)
        except IndexError:  # no nonzero or no unused channel available
            return config
        delta = config[channel_from] // SPLIT_FRACTION
        proposed_balance_from = config[channel_from] - delta
        proposed_balance_to = config[channel_to] + delta
        if balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
            return config
        config[channel_from] = proposed_balance_from
        config[channel_to] = proposed_balance_to
        assert sum(config.values()) == amount_msat
        return config

    def swap(config: dict) -> dict:
        # swap the amount of a nonzero channel with that of any channel
        nonzero = [ck for ck, cv in config.items() if cv != 0]
        if not nonzero:  # all amounts are zero, nothing to move
            return config
        all_channels = list(config)

        channel_from = random.choice(nonzero)
        channel_to = random.choice(all_channels)

        proposed_balance_to = config[channel_from]
        proposed_balance_from = config[channel_to]
        if proposed_balance_from == 0:
            # channel_to was unused: the whole part moves over and
            # channel_from simply becomes unused, so only the receiving side
            # has to respect the part size and capacity limits (a zero amount
            # would otherwise always be rejected as below MIN_PART_MSAT)
            rejected = (proposed_balance_to < MIN_PART_MSAT or
                        proposed_balance_to > channels_with_funds[channel_to])
        else:
            rejected = balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds)
        if rejected:
            return config
        config[channel_to] = proposed_balance_to
        config[channel_from] = proposed_balance_from
        return config

    # work on a copy so the caller's configuration is left untouched
    configuration = dict(configuration)
    initial_number_parts = number_nonzero_parts(configuration)

    for _ in range(REDISTRIBUTE):
        configuration = redistribute(configuration)
    if not preserve_number_parts and number_nonzero_parts(configuration) == initial_number_parts:
        configuration = split(configuration)
    configuration = swap(configuration)

    return configuration
177
178
@profiler
def suggest_splits(amount_msat: int, channels_with_funds: Dict[Tuple[bytes, bytes], int],
                   exclude_single_parts=True, single_node=False) \
        -> Sequence[Tuple[Dict[Tuple[bytes, bytes], int], float]]:
    """Creates split configurations for a payment over channels.

    Single channel payments are excluded by default. channels_with_funds is
    keyed by (channelid, nodeid).

    Starting from randomly shuffled greedy distributions of the amount, new
    configurations are proposed for every number of parts from 2 up to
    min(MAX_PARTS, number of channels), both inclusive. All configurations
    found are deduplicated, rated and returned best first.

    :param amount_msat: amount to send.
    :param channels_with_funds: sending capacity per (channelid, nodeid).
    :param exclude_single_parts: if True, drop configurations with exactly
        one nonzero part.
    :param single_node: if True, keep only configurations whose nonzero
        parts all go to the same node.
    :return: (configuration, rating) pairs sorted by ascending rating; a
        lower rating is a better configuration.
    :raises NoPathFound: if the channels don't have enough sending capacity.
    """

    def rate_configuration(config: dict) -> float:
        """Defines an objective function to rate a split configuration.

        For each nonzero amount we add its squared share of the total (i.e.
        the squared, normalized L2 norm overall), a part penalty, and a
        saturation penalty exp((amount - funds) / (funds / EXHAUST_DECAY_FRACTION))
        that grows as the amount approaches the channel's funds. The
        consequence is that amounts that are equally distributed, have less
        parts and leave funds in their channels are rated lowest."""
        F = 0
        total_amount = sum(config.values())

        for channel, amount in config.items():
            funds = channels_with_funds[channel]
            if amount:
                F += amount * amount / (total_amount * total_amount)  # a penalty to favor distribution of amounts
                F += PART_PENALTY * PART_PENALTY  # a penalty for each part
                decay = funds / EXHAUST_DECAY_FRACTION
                F += math.exp((amount - funds) / decay)  # a penalty for channel saturation

        return F

    def rated_sorted_configurations(hierarchy: dict) -> Sequence[Tuple[Dict[Tuple[bytes, bytes], int], float]]:
        """Cleans up duplicate splittings, rates and sorts them according to
        the rating. A lower rating is a better configuration."""
        hierarchy = unique_hierarchy(hierarchy)
        rated_configs = [
            (config, rate_configuration(config))
            for configs in hierarchy.values()
            for config in configs
        ]
        return sorted(rated_configs, key=lambda rated: rated[1])

    # create initial guesses
    split_hierarchy = create_starting_split_hierarchy(amount_msat, channels_with_funds)

    # randomize initial guesses and generate splittings of different split
    # levels up to min(MAX_PARTS, number of channels), inclusive
    max_level = min(MAX_PARTS, len(channels_with_funds))
    for level in range(2, max_level + 1):
        # generate a set of random configurations for each level
        for _ in range(CANDIDATES_PER_LEVEL):
            # deduplicate once per iteration; unique_hierarchy builds fresh
            # dicts, so altering a chosen configuration can't touch split_hierarchy
            unique_configs = unique_hierarchy(split_hierarchy)
            configurations = unique_configs.get(level)
            if configurations:  # we have a splitting of the desired number of parts
                configuration = random.choice(configurations)
                # generate new splittings preserving the number of parts
                configuration = propose_new_configuration(
                    channels_with_funds, configuration, amount_msat,
                    preserve_number_parts=True)
            else:
                # go one level lower and look for valid splittings,
                # try to go one level higher by splitting a single outgoing amount
                configurations = unique_configs.get(level - 1)
                if not configurations:
                    continue
                configuration = random.choice(configurations)
                # generate new splittings going one level higher in the number of parts
                configuration = propose_new_configuration(
                    channels_with_funds, configuration, amount_msat,
                    preserve_number_parts=False)

            # add the newly found configuration (doesn't matter if nothing changed)
            split_hierarchy[number_nonzero_parts(configuration)].append(configuration)

    if exclude_single_parts:
        # we only want to return configurations that have at least two parts
        split_hierarchy.pop(1, None)

    if single_node:
        # we only take configurations that send to a single node
        split_hierarchy = single_node_hierarchy(split_hierarchy)

    return rated_sorted_configurations(split_hierarchy)
260