import contextlib
import numbers
from itertools import chain, product
from numbers import Integral
from operator import getitem

import numpy as np

from ..base import tokenize
from ..highlevelgraph import HighLevelGraph
from ..utils import _deprecated, derived_from, random_state_data, skip_doctest
from .core import (
    Array,
    asarray,
    broadcast_shapes,
    broadcast_to,
    normalize_chunks,
    slices_from_chunks,
)
from .creation import arange


@_deprecated()
def doc_wraps(func):
    """Copy docstring from one function to another

    Returns a decorator. When ``func`` has a docstring, the decorated
    function receives ``skip_doctest(func.__doc__)`` as its own docstring;
    otherwise the decorated function is returned untouched.
    """

    def _(func2):
        if func.__doc__ is not None:
            func2.__doc__ = skip_doctest(func.__doc__)
        return func2

    return _


class RandomState:
    """
    Mersenne Twister pseudo-random number generator

    This object contains state to deterministically generate pseudo-random
    numbers from a variety of probability distributions.  It is identical to
    ``np.random.RandomState`` except that all functions also take a ``chunks=``
    keyword argument.

    Parameters
    ----------
    seed: Number
        Object to pass to RandomState to serve as deterministic seed
    RandomState: Callable[seed] -> RandomState
        A callable that, when provided with a ``seed`` keyword provides an
        object that operates identically to ``np.random.RandomState`` (the
        default).  This might also be a function that returns a
        ``randomgen.RandomState``, ``mkl_random``, or
        ``cupy.random.RandomState`` object.

    Examples
    --------
    >>> import dask.array as da
    >>> state = da.random.RandomState(1234)  # a seed
    >>> x = state.normal(10, 0.1, size=3, chunks=(2,))
    >>> x.compute()
    array([10.01867852, 10.04812289,  9.89649746])

    See Also
    --------
    np.random.RandomState
    """

    def __init__(self, seed=None, RandomState=None):
        # ``_numpy_state`` is only used locally to derive per-chunk seed data
        # (and to shuffle in ``permutation``); the per-chunk generators are
        # built later from ``_RandomState`` inside ``_apply_random``.
        self._numpy_state = np.random.RandomState(seed)
        self._RandomState = RandomState

    def seed(self, seed=None):
        """Re-seed the local NumPy state that per-chunk seeds are drawn from."""
        self._numpy_state.seed(seed)

    def _wrap(
        self, funcname, *args, size=None, chunks="auto", extra_chunks=(), **kwargs
    ):
        """Wrap numpy random function to produce dask.array random function

        Parameters
        ----------
        funcname : str
            Name of the method to call on the per-chunk random state object.
        *args, **kwargs
            Arguments forwarded to that method.  Values that are
            ``np.ndarray`` or dask ``Array`` instances are broadcast to the
            output shape and handed to each task chunk by chunk; everything
            else is passed through unchanged.
        size : int, tuple, list or None
            Requested output shape.  A non-sequence value is wrapped in a
            1-tuple.  The final shape is the broadcast of ``size`` with the
            shapes of all array-valued arguments.
        chunks
            Chunk specification, normalized with ``normalize_chunks``.
        extra_chunks : tuple
            A chunks tuple to append to the end of chunks.

        Returns
        -------
        Array
            One task per output chunk, each seeded independently from
            ``self._numpy_state``.
        """
        if size is not None and not isinstance(size, (tuple, list)):
            size = (size,)

        shapes = list(
            {
                ar.shape
                for ar in chain(args, kwargs.values())
                if isinstance(ar, (Array, np.ndarray))
            }
        )
        if size is not None:
            shapes.append(size)
        # broadcast to the final size(shape)
        size = broadcast_shapes(*shapes)
        chunks = normalize_chunks(
            chunks,
            size,  # ideally would use dtype here
            dtype=kwargs.get("dtype", np.float64),
        )
        slices = slices_from_chunks(chunks)

        def _broadcast_any(ar, shape, chunks):
            # Dask arrays are broadcast lazily and aligned with the output
            # chunking; NumPy arrays are materialized at the full shape so
            # they can later be sliced per chunk with ``getitem``.
            if isinstance(ar, Array):
                return broadcast_to(ar, shape).rechunk(chunks)
            if isinstance(ar, np.ndarray):
                return np.ascontiguousarray(np.broadcast_to(ar, shape))

        # Broadcast all arguments, get tiny versions as well
        # Start adding the relevant bits to the graph
        dsk = {}
        lookup = {}  # positional index / keyword name -> graph key or name
        small_args = []  # single-element stand-ins used to compute ``meta``
        dependencies = []
        for i, ar in enumerate(args):
            if isinstance(ar, (np.ndarray, Array)):
                res = _broadcast_any(ar, size, chunks)
                if isinstance(res, Array):
                    dependencies.append(res)
                    lookup[i] = res.name
                elif isinstance(res, np.ndarray):
                    name = f"array-{tokenize(res)}"
                    lookup[i] = name
                    dsk[name] = res
                small_args.append(ar[tuple(0 for _ in ar.shape)])
            else:
                small_args.append(ar)

        small_kwargs = {}
        for key, ar in kwargs.items():
            if isinstance(ar, (np.ndarray, Array)):
                res = _broadcast_any(ar, size, chunks)
                if isinstance(res, Array):
                    dependencies.append(res)
                    lookup[key] = res.name
                elif isinstance(res, np.ndarray):
                    name = f"array-{tokenize(res)}"
                    lookup[key] = name
                    dsk[name] = res
                small_kwargs[key] = ar[tuple(0 for _ in ar.shape)]
            else:
                small_kwargs[key] = ar

        sizes = list(product(*chunks))
        seeds = random_state_data(len(sizes), self._numpy_state)
        token = tokenize(seeds, size, chunks, args, kwargs)
        name = f"{funcname}-{token}"

        # Output keys carry one extra ``0`` index per entry of extra_chunks.
        keys = product(
            [name], *([range(len(bd)) for bd in chunks] + [[0]] * len(extra_chunks))
        )
        blocks = product(*[range(len(bd)) for bd in chunks])

        vals = []
        # ``chunk_shape`` is deliberately a distinct name so that the
        # broadcast output shape held in ``size`` is not overwritten.
        for seed, chunk_shape, slc, block in zip(seeds, sizes, slices, blocks):
            arg = []
            for i, ar in enumerate(args):
                if i not in lookup:
                    arg.append(ar)
                else:
                    if isinstance(ar, Array):
                        arg.append((lookup[i],) + block)
                    else:  # np.ndarray
                        arg.append((getitem, lookup[i], slc))
            kwrg = {}
            for k, ar in kwargs.items():
                if k not in lookup:
                    kwrg[k] = ar
                else:
                    if isinstance(ar, Array):
                        kwrg[k] = (lookup[k],) + block
                    else:  # np.ndarray
                        kwrg[k] = (getitem, lookup[k], slc)
            vals.append(
                (
                    _apply_random,
                    self._RandomState,
                    funcname,
                    seed,
                    chunk_shape,
                    arg,
                    kwrg,
                )
            )

        # Compute ``meta`` by calling the function with a zero-length shape of
        # the same rank as the output, using the tiny argument stand-ins.
        # The last per-chunk seed is reused for this call.
        meta = _apply_random(
            self._RandomState,
            funcname,
            seeds[-1],
            (0,) * len(size),
            small_args,
            small_kwargs,
        )

        dsk.update(dict(zip(keys, vals)))

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
        return Array(graph, name, chunks + extra_chunks, meta=meta)

    # ------------------------------------------------------------------
    # Distribution methods.  Unless noted otherwise, each one forwards its
    # arguments to ``_wrap`` together with the name of the random state
    # method that should be invoked for every chunk.
    # ------------------------------------------------------------------

    @derived_from(np.random.RandomState, skipblocks=1)
    def beta(self, a, b, size=None, chunks="auto", **kwargs):
        return self._wrap("beta", a, b, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return self._wrap("binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def chisquare(self, df, size=None, chunks="auto", **kwargs):
        return self._wrap("chisquare", df, size=size, chunks=chunks, **kwargs)

    # NOTE(review): the suppress() presumably guards against ``derived_from``
    # raising AttributeError when ``np.random.RandomState`` has no ``choice``
    # attribute; in that case the method is simply not defined (see the
    # matching ``hasattr`` check at module level).
    with contextlib.suppress(AttributeError):

        @derived_from(np.random.RandomState, skipblocks=1)
        def choice(self, a, size=None, replace=True, p=None, chunks="auto"):
            dependencies = []
            # Normalize and validate `a`
            if isinstance(a, Integral):
                # On windows the output dtype differs if p is provided or
                # absent, see https://github.com/numpy/numpy/issues/9867
                dummy_p = np.array([1]) if p is not None else p
                dtype = np.random.choice(1, size=(), p=dummy_p).dtype
                len_a = a
                if a < 0:
                    raise ValueError("a must be greater than 0")
            else:
                a = asarray(a)
                # Rechunk to the full shape so the first key can stand in
                # for the whole population inside each task.
                a = a.rechunk(a.shape)
                dtype = a.dtype
                if a.ndim != 1:
                    raise ValueError("a must be one dimensional")
                len_a = len(a)
                dependencies.append(a)
                a = a.__dask_keys__()[0]

            # Normalize and validate `p`
            if p is not None:
                if not isinstance(p, Array):
                    # If p is not a dask array, first check the sum is close
                    # to 1 before converting.
                    p = np.asarray(p)
                    if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
                        raise ValueError("probabilities do not sum to 1")
                    p = asarray(p)
                else:
                    p = p.rechunk(p.shape)

                if p.ndim != 1:
                    raise ValueError("p must be one dimensional")
                if len(p) != len_a:
                    raise ValueError("a and p must have the same size")

                dependencies.append(p)
                p = p.__dask_keys__()[0]

            if size is None:
                size = ()
            elif not isinstance(size, (tuple, list)):
                size = (size,)

            chunks = normalize_chunks(chunks, size, dtype=np.float64)
            # Every chunk is sampled by its own independent generator, so
            # sampling without replacement is only meaningful when the whole
            # output is a single chunk.  Inspect every axis (not just the
            # first), which also keeps a 0-d output (``chunks == ()``) from
            # raising IndexError.
            if not replace and any(len(bd) > 1 for bd in chunks):
                err_msg = (
                    "replace=False is not currently supported for "
                    "dask.array.choice with multi-chunk output "
                    "arrays"
                )
                raise NotImplementedError(err_msg)
            sizes = list(product(*chunks))
            state_data = random_state_data(len(sizes), self._numpy_state)

            name = "da.random.choice-%s" % tokenize(
                state_data, size, chunks, a, replace, p
            )
            keys = product([name], *(range(len(bd)) for bd in chunks))
            dsk = {
                k: (_choice, state, a, size, replace, p)
                for k, state, size in zip(keys, state_data, sizes)
            }

            graph = HighLevelGraph.from_collections(
                name, dsk, dependencies=dependencies
            )
            return Array(graph, name, chunks, dtype=dtype)

    # @derived_from(np.random.RandomState, skipblocks=1)
    # def dirichlet(self, alpha, size=None, chunks="auto"):

    @derived_from(np.random.RandomState, skipblocks=1)
    def exponential(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("exponential", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def f(self, dfnum, dfden, size=None, chunks="auto", **kwargs):
        return self._wrap("f", dfnum, dfden, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gamma(self, shape, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("gamma", shape, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def geometric(self, p, size=None, chunks="auto", **kwargs):
        return self._wrap("geometric", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gumbel(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("gumbel", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def hypergeometric(self, ngood, nbad, nsample, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "hypergeometric", ngood, nbad, nsample, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def laplace(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("laplace", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logistic(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("logistic", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def lognormal(self, mean=0.0, sigma=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("lognormal", mean, sigma, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logseries(self, p, size=None, chunks="auto", **kwargs):
        return self._wrap("logseries", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def multinomial(self, n, pvals, size=None, chunks="auto", **kwargs):
        # The trailing axis (one entry per category in ``pvals``) is appended
        # as a single extra chunk.  Extra keyword arguments are forwarded, as
        # in every other distribution wrapper, instead of being dropped.
        return self._wrap(
            "multinomial",
            n,
            pvals,
            size=size,
            chunks=chunks,
            extra_chunks=((len(pvals),),),
            **kwargs,
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def negative_binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return self._wrap("negative_binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_chisquare(self, df, nonc, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "noncentral_chisquare", df, nonc, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_f(self, dfnum, dfden, nonc, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "noncentral_f", dfnum, dfden, nonc, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def normal(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("normal", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def pareto(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("pareto", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def permutation(self, x):
        from .slicing import shuffle_slice

        # A bare number means "permute arange(x)".
        if isinstance(x, numbers.Number):
            x = arange(x, chunks="auto")

        # The permutation index is drawn eagerly from the local NumPy state
        # and then applied to ``x`` via ``shuffle_slice``.
        index = np.arange(len(x))
        self._numpy_state.shuffle(index)
        return shuffle_slice(x, index)

    @derived_from(np.random.RandomState, skipblocks=1)
    def poisson(self, lam=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("poisson", lam, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def power(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("power", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def randint(self, low, high=None, size=None, chunks="auto", dtype="l", **kwargs):
        return self._wrap(
            "randint", low, high, size=size, chunks=chunks, dtype=dtype, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_integers(self, low, high=None, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "random_integers", low, high, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_sample(self, size=None, chunks="auto", **kwargs):
        return self._wrap("random_sample", size=size, chunks=chunks, **kwargs)

    random = random_sample

    @derived_from(np.random.RandomState, skipblocks=1)
    def rayleigh(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("rayleigh", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_cauchy(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_cauchy", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_exponential(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_exponential", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_gamma(self, shape, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_gamma", shape, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_normal(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_normal", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_t(self, df, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_t", df, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def tomaxint(self, size=None, chunks="auto", **kwargs):
        return self._wrap("tomaxint", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def triangular(self, left, mode, right, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "triangular", left, mode, right, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def uniform(self, low=0.0, high=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("uniform", low, high, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def vonmises(self, mu, kappa, size=None, chunks="auto", **kwargs):
        return self._wrap("vonmises", mu, kappa, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def wald(self, mean, scale, size=None, chunks="auto", **kwargs):
        return self._wrap("wald", mean, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def weibull(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("weibull", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def zipf(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("zipf", a, size=size, chunks=chunks, **kwargs)


def _choice(state_data, a, size, replace, p):
    """Draw one chunk of ``choice`` samples from a freshly seeded NumPy state.

    ``state_data`` seeds a new ``np.random.RandomState``; the remaining
    arguments are passed straight to its ``choice`` method.
    """
    state = np.random.RandomState(state_data)
    return state.choice(a, size=size, replace=replace, p=p)


def _apply_random(RandomState, funcname, state_data, size, args, kwargs):
    """Apply RandomState method with seed

    Parameters
    ----------
    RandomState : callable or None
        Factory called with ``state_data`` to build the generator;
        ``None`` selects ``np.random.RandomState``.
    funcname : str
        Name of the generator method to call.
    state_data
        Seed material handed to the factory.
    size
        Shape passed to the method as ``size=``.
    args, kwargs
        Extra positional and keyword arguments for the method.
    """
    if RandomState is None:
        RandomState = np.random.RandomState
    state = RandomState(state_data)
    func = getattr(state, funcname)
    return func(*args, size=size, **kwargs)


# Module-level default generator; the functions below are its bound methods.
_state = RandomState()


seed = _state.seed


beta = _state.beta
binomial = _state.binomial
chisquare = _state.chisquare
if hasattr(_state, "choice"):
    choice = _state.choice
exponential = _state.exponential
f = _state.f
gamma = _state.gamma
geometric = _state.geometric
gumbel = _state.gumbel
hypergeometric = _state.hypergeometric
laplace = _state.laplace
logistic = _state.logistic
lognormal = _state.lognormal
logseries = _state.logseries
multinomial = _state.multinomial
negative_binomial = _state.negative_binomial
noncentral_chisquare = _state.noncentral_chisquare
noncentral_f = _state.noncentral_f
normal = _state.normal
pareto = _state.pareto
permutation = _state.permutation
poisson = _state.poisson
power = _state.power
rayleigh = _state.rayleigh
random_sample = _state.random_sample
random = random_sample
randint = _state.randint
random_integers = _state.random_integers
triangular = _state.triangular
uniform = _state.uniform
vonmises = _state.vonmises
wald = _state.wald
weibull = _state.weibull
zipf = _state.zipf

"""
Standard distributions
"""

standard_cauchy = _state.standard_cauchy
standard_exponential = _state.standard_exponential
standard_gamma = _state.standard_gamma
standard_normal = _state.standard_normal
standard_t = _state.standard_t