1import contextlib
2import numbers
3from itertools import chain, product
4from numbers import Integral
5from operator import getitem
6
7import numpy as np
8
9from ..base import tokenize
10from ..highlevelgraph import HighLevelGraph
11from ..utils import _deprecated, derived_from, random_state_data, skip_doctest
12from .core import (
13    Array,
14    asarray,
15    broadcast_shapes,
16    broadcast_to,
17    normalize_chunks,
18    slices_from_chunks,
19)
20from .creation import arange
21
22
@_deprecated()
def doc_wraps(func):
    """Build a decorator that copies ``func``'s docstring onto another function.

    The returned decorator assigns ``skip_doctest(func.__doc__)`` to the
    decorated function's ``__doc__`` when ``func`` has a docstring, and leaves
    the decorated function untouched otherwise.  In both cases the decorated
    function itself is returned.
    """

    def _copy_doc(target):
        source_doc = func.__doc__
        if source_doc is None:
            # Nothing to copy; hand the function back unchanged.
            return target
        target.__doc__ = skip_doctest(source_doc)
        return target

    return _copy_doc
33
34
class RandomState:
    """
    Mersenne Twister pseudo-random number generator

    This object contains state to deterministically generate pseudo-random
    numbers from a variety of probability distributions.  It mirrors
    ``np.random.RandomState`` except that the sampling functions also take a
    ``chunks=`` keyword argument.

    Parameters
    ----------
    seed: Number
        Object to pass to ``np.random.RandomState`` to serve as deterministic
        seed.  This NumPy state is only used to derive one seed per output
        block.
    RandomState: Callable[seed] -> RandomState
        A callable that, when called with per-block seed data as its single
        positional argument, provides an object that operates identically to
        ``np.random.RandomState`` (the default).  This might also be a
        function that returns a ``randomgen.RandomState``, ``mkl_random``, or
        ``cupy.random.RandomState`` object.

    Examples
    --------
    >>> import dask.array as da
    >>> state = da.random.RandomState(1234)  # a seed
    >>> x = state.normal(10, 0.1, size=3, chunks=(2,))
    >>> x.compute()
    array([10.01867852, 10.04812289,  9.89649746])

    See Also
    --------
    np.random.RandomState
    """

    def __init__(self, seed=None, RandomState=None):
        # NumPy state that hands out the per-block seeds.
        self._numpy_state = np.random.RandomState(seed)
        # Optional factory used inside each task; ``None`` means NumPy's own.
        self._RandomState = RandomState

    def seed(self, seed=None):
        """Reseed the internal NumPy state that per-block seeds are drawn from."""
        self._numpy_state.seed(seed)

    def _wrap(
        self, funcname, *args, size=None, chunks="auto", extra_chunks=(), **kwargs
    ):
        """Wrap numpy random function to produce dask.array random function

        Parameters
        ----------
        funcname : str
            Name of the method to call on the per-block random state.
        *args, **kwargs
            Arguments forwarded to that method.  Dask arrays and NumPy arrays
            are broadcast to the output shape and passed block by block; any
            other value is passed through unchanged.
        size : int, tuple, list or None
            Requested output shape; it is broadcast against the shapes of all
            array arguments.
        chunks : chunks specification
            Passed to ``normalize_chunks`` together with the broadcast shape.
        extra_chunks : tuple
            A chunks tuple to append to the end of ``chunks``.  Each extra
            dimension is a single chunk (block index 0).

        Returns
        -------
        Array
            One task per output block, each using its own seed.
        """
        if size is not None and not isinstance(size, (tuple, list)):
            size = (size,)

        shapes = list(
            {
                ar.shape
                for ar in chain(args, kwargs.values())
                if isinstance(ar, (Array, np.ndarray))
            }
        )
        if size is not None:
            shapes.append(size)
        # broadcast to the final size(shape)
        size = broadcast_shapes(*shapes)
        chunks = normalize_chunks(
            chunks,
            size,
            # Use the requested dtype when one is given; float64 otherwise.
            dtype=kwargs.get("dtype", np.float64),
        )
        slices = slices_from_chunks(chunks)

        def _broadcast_any(ar, shape, chunks):
            if isinstance(ar, Array):
                return broadcast_to(ar, shape).rechunk(chunks)
            if isinstance(ar, np.ndarray):
                return np.ascontiguousarray(np.broadcast_to(ar, shape))

        # Broadcast all arguments, get tiny versions as well
        # Start adding the relevant bits to the graph
        dsk = {}
        lookup = {}
        dependencies = []

        def _register(key, ar):
            """Broadcast ``ar``, record it under ``key``, return one element of it."""
            res = _broadcast_any(ar, size, chunks)
            if isinstance(res, Array):
                dependencies.append(res)
                lookup[key] = res.name
            elif isinstance(res, np.ndarray):
                array_name = f"array-{tokenize(res)}"
                lookup[key] = array_name
                dsk[array_name] = res
            # Single-element sample used only for the ``meta`` computation.
            return ar[tuple(0 for _ in ar.shape)]

        def _block_ref(key, ar, block, slc):
            """Return what a single block's task should receive for ``ar``."""
            if key not in lookup:
                return ar
            if isinstance(ar, Array):
                # Reference the matching block of the rechunked dask array.
                return (lookup[key],) + block
            # np.ndarray: slice the broadcast array stored in the graph.
            return (getitem, lookup[key], slc)

        # Positional arguments are keyed by index, keyword arguments by name.
        small_args = [
            _register(i, ar) if isinstance(ar, (np.ndarray, Array)) else ar
            for i, ar in enumerate(args)
        ]
        small_kwargs = {
            key: _register(key, ar) if isinstance(ar, (np.ndarray, Array)) else ar
            for key, ar in kwargs.items()
        }

        sizes = list(product(*chunks))
        seeds = random_state_data(len(sizes), self._numpy_state)
        token = tokenize(seeds, size, chunks, args, kwargs)
        name = f"{funcname}-{token}"

        keys = product(
            [name], *([range(len(bd)) for bd in chunks] + [[0]] * len(extra_chunks))
        )
        blocks = product(*[range(len(bd)) for bd in chunks])

        vals = []
        # Distinct loop names so the full output ``size`` is not rebound here.
        for block_seed, block_size, slc, block in zip(seeds, sizes, slices, blocks):
            arg = [_block_ref(i, ar, block, slc) for i, ar in enumerate(args)]
            kwrg = {k: _block_ref(k, ar, block, slc) for k, ar in kwargs.items()}
            vals.append(
                (
                    _apply_random,
                    self._RandomState,
                    funcname,
                    block_seed,
                    block_size,
                    arg,
                    kwrg,
                )
            )

        # Zero-length draw (same number of dimensions as the output) that only
        # serves to discover the result's array type and dtype.  It reuses the
        # seed of the last block.
        meta = _apply_random(
            self._RandomState,
            funcname,
            block_seed,
            (0,) * len(size),
            small_args,
            small_kwargs,
        )

        dsk.update(dict(zip(keys, vals)))

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
        return Array(graph, name, chunks + extra_chunks, meta=meta)

    # The methods below are thin wrappers: each forwards its distribution
    # parameters to ``_wrap`` under the matching ``np.random.RandomState``
    # method name.  They are documented through ``derived_from`` and therefore
    # carry ``#`` comments rather than docstrings of their own.

    @derived_from(np.random.RandomState, skipblocks=1)
    def beta(self, a, b, size=None, chunks="auto", **kwargs):
        return self._wrap("beta", a, b, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return self._wrap("binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def chisquare(self, df, size=None, chunks="auto", **kwargs):
        return self._wrap("chisquare", df, size=size, chunks=chunks, **kwargs)

    # ``choice`` is only defined when ``np.random.RandomState`` provides it;
    # an AttributeError raised while defining it is suppressed.
    with contextlib.suppress(AttributeError):

        @derived_from(np.random.RandomState, skipblocks=1)
        def choice(self, a, size=None, replace=True, p=None, chunks="auto"):
            # Unlike the other distributions, ``a`` and ``p`` are not split
            # per block: each is rechunked into a single chunk and every
            # output block receives the whole of it.
            dependencies = []
            # Normalize and validate `a`
            if isinstance(a, Integral):
                # On windows the output dtype differs if p is provided or
                # absent, see https://github.com/numpy/numpy/issues/9867
                dummy_p = np.array([1]) if p is not None else p
                dtype = np.random.choice(1, size=(), p=dummy_p).dtype
                len_a = a
                if a < 0:
                    raise ValueError("a must be greater than 0")
            else:
                a = asarray(a)
                a = a.rechunk(a.shape)
                dtype = a.dtype
                if a.ndim != 1:
                    raise ValueError("a must be one dimensional")
                len_a = len(a)
                dependencies.append(a)
                # Key of the single chunk holding all of ``a``.
                a = a.__dask_keys__()[0]

            # Normalize and validate `p`
            if p is not None:
                if not isinstance(p, Array):
                    # If p is not a dask array, first check the sum is close
                    # to 1 before converting.
                    p = np.asarray(p)
                    if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
                        raise ValueError("probabilities do not sum to 1")
                    p = asarray(p)
                else:
                    p = p.rechunk(p.shape)

                if p.ndim != 1:
                    raise ValueError("p must be one dimensional")
                if len(p) != len_a:
                    raise ValueError("a and p must have the same size")

                dependencies.append(p)
                # Key of the single chunk holding all of ``p``.
                p = p.__dask_keys__()[0]

            if size is None:
                size = ()
            elif not isinstance(size, (tuple, list)):
                size = (size,)

            chunks = normalize_chunks(chunks, size, dtype=np.float64)
            # Blocks are sampled independently, so sampling without
            # replacement is only valid when the output is a single block.
            # Every axis is inspected; scalar output (``chunks == ()``) has
            # no axes and is therefore always accepted.
            if not replace and any(len(bd) > 1 for bd in chunks):
                err_msg = (
                    "replace=False is not currently supported for "
                    "dask.array.choice with multi-chunk output "
                    "arrays"
                )
                raise NotImplementedError(err_msg)
            sizes = list(product(*chunks))
            state_data = random_state_data(len(sizes), self._numpy_state)

            token = tokenize(state_data, size, chunks, a, replace, p)
            name = f"da.random.choice-{token}"
            keys = product([name], *(range(len(bd)) for bd in chunks))
            dsk = {
                k: (_choice, state, a, block_size, replace, p)
                for k, state, block_size in zip(keys, state_data, sizes)
            }

            graph = HighLevelGraph.from_collections(
                name, dsk, dependencies=dependencies
            )
            return Array(graph, name, chunks, dtype=dtype)

    # NOTE: ``dirichlet`` is not implemented; intended signature kept below.
    # @derived_from(np.random.RandomState, skipblocks=1)
    # def dirichlet(self, alpha, size=None, chunks="auto"):

    @derived_from(np.random.RandomState, skipblocks=1)
    def exponential(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("exponential", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def f(self, dfnum, dfden, size=None, chunks="auto", **kwargs):
        return self._wrap("f", dfnum, dfden, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gamma(self, shape, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("gamma", shape, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def geometric(self, p, size=None, chunks="auto", **kwargs):
        return self._wrap("geometric", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gumbel(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("gumbel", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def hypergeometric(self, ngood, nbad, nsample, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "hypergeometric", ngood, nbad, nsample, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def laplace(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("laplace", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logistic(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("logistic", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def lognormal(self, mean=0.0, sigma=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("lognormal", mean, sigma, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logseries(self, p, size=None, chunks="auto", **kwargs):
        return self._wrap("logseries", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def multinomial(self, n, pvals, size=None, chunks="auto", **kwargs):
        # Each draw yields ``len(pvals)`` counts, so the output gains one
        # trailing, single-chunk dimension of that length.  Extra keyword
        # arguments are forwarded like in every other distribution instead
        # of being silently discarded.
        return self._wrap(
            "multinomial",
            n,
            pvals,
            size=size,
            chunks=chunks,
            extra_chunks=((len(pvals),),),
            **kwargs,
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def negative_binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return self._wrap("negative_binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_chisquare(self, df, nonc, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "noncentral_chisquare", df, nonc, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_f(self, dfnum, dfden, nonc, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "noncentral_f", dfnum, dfden, nonc, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def normal(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("normal", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def pareto(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("pareto", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def permutation(self, x):
        # Imported here rather than at module level.
        from .slicing import shuffle_slice

        # A plain number means "permute ``arange(x)``".
        if isinstance(x, numbers.Number):
            x = arange(x, chunks="auto")

        # The shuffled index is built eagerly with the internal NumPy state
        # and then applied to ``x`` along its first axis.
        index = np.arange(len(x))
        self._numpy_state.shuffle(index)
        return shuffle_slice(x, index)

    @derived_from(np.random.RandomState, skipblocks=1)
    def poisson(self, lam=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("poisson", lam, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def power(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("power", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def randint(self, low, high=None, size=None, chunks="auto", dtype="l", **kwargs):
        return self._wrap(
            "randint", low, high, size=size, chunks=chunks, dtype=dtype, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_integers(self, low, high=None, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "random_integers", low, high, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_sample(self, size=None, chunks="auto", **kwargs):
        return self._wrap("random_sample", size=size, chunks=chunks, **kwargs)

    # Alias: ``random`` is the same function object as ``random_sample``.
    random = random_sample

    @derived_from(np.random.RandomState, skipblocks=1)
    def rayleigh(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("rayleigh", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_cauchy(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_cauchy", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_exponential(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_exponential", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_gamma(self, shape, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_gamma", shape, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_normal(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_normal", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_t(self, df, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_t", df, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def tomaxint(self, size=None, chunks="auto", **kwargs):
        return self._wrap("tomaxint", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def triangular(self, left, mode, right, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "triangular", left, mode, right, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def uniform(self, low=0.0, high=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("uniform", low, high, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def vonmises(self, mu, kappa, size=None, chunks="auto", **kwargs):
        return self._wrap("vonmises", mu, kappa, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def wald(self, mean, scale, size=None, chunks="auto", **kwargs):
        return self._wrap("wald", mean, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def weibull(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("weibull", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def zipf(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("zipf", a, size=size, chunks=chunks, **kwargs)
448
449
def _choice(state_data, a, size, replace, p):
    """Draw ``choice`` samples for one block from a NumPy state seeded with ``state_data``."""
    return np.random.RandomState(state_data).choice(
        a,
        size=size,
        replace=replace,
        p=p,
    )
453
454
def _apply_random(RandomState, funcname, state_data, size, args, kwargs):
    """Seed a random state with ``state_data`` and call its ``funcname`` method.

    ``RandomState`` is the factory used to build the state; ``None`` selects
    ``np.random.RandomState``.  The method is called with ``*args``, the given
    ``size`` as a keyword, and ``**kwargs``; its result is returned as is.
    """
    factory = np.random.RandomState if RandomState is None else RandomState
    sampler = getattr(factory(state_data), funcname)
    return sampler(*args, size=size, **kwargs)
462
463
# Module-level default generator, constructed without an explicit seed.
# Every function exported below is a bound method of this one shared
# instance, so they all draw their per-block seeds from the same state.
_state = RandomState()


# Reseeds the shared default generator.
seed = _state.seed


# Distribution functions, exposed under the same names as the methods.
beta = _state.beta
binomial = _state.binomial
chisquare = _state.chisquare
# ``choice`` is defined conditionally on the class, so only export it if present.
if hasattr(_state, "choice"):
    choice = _state.choice
exponential = _state.exponential
f = _state.f
gamma = _state.gamma
geometric = _state.geometric
gumbel = _state.gumbel
hypergeometric = _state.hypergeometric
laplace = _state.laplace
logistic = _state.logistic
lognormal = _state.lognormal
logseries = _state.logseries
multinomial = _state.multinomial
negative_binomial = _state.negative_binomial
noncentral_chisquare = _state.noncentral_chisquare
noncentral_f = _state.noncentral_f
normal = _state.normal
pareto = _state.pareto
permutation = _state.permutation
poisson = _state.poisson
power = _state.power
rayleigh = _state.rayleigh
random_sample = _state.random_sample
# Alias bound to the very same object as ``random_sample`` (must follow it).
random = random_sample
randint = _state.randint
random_integers = _state.random_integers
triangular = _state.triangular
uniform = _state.uniform
vonmises = _state.vonmises
wald = _state.wald
weibull = _state.weibull
zipf = _state.zipf

# The bare string below is a section marker only; it is evaluated and discarded.
"""
Standard distributions
"""

standard_cauchy = _state.standard_cauchy
standard_exponential = _state.standard_exponential
standard_gamma = _state.standard_gamma
standard_normal = _state.standard_normal
standard_t = _state.standard_t
515