1 #include "numpy/random/distributions.h"
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <stdbool.h>
5 
6 
7 /*
8  *  random_multivariate_hypergeometric_count
9  *
10  *  Draw variates from the multivariate hypergeometric distribution--
11  *  the "count" algorithm.
12  *
13  *  Parameters
14  *  ----------
15  *  bitgen_t *bitgen_state
16  *      Pointer to a `bitgen_t` instance.
17  *  int64_t total
18  *      The sum of the values in the array `colors`.  (This is redundant
19  *      information, but we know the caller has already computed it, so
20  *      we might as well use it.)
21  *  size_t num_colors
22  *      The length of the `colors` array.
23  *  int64_t *colors
24  *      The array of colors (i.e. the number of each type in the collection
25  *      from which the random variate is drawn).
26  *  int64_t nsample
27  *      The number of objects drawn without replacement for each variate.
28  *      `nsample` must not exceed sum(colors).  This condition is not checked;
29  *      it is assumed that the caller has already validated the value.
30  *  size_t num_variates
31  *      The number of variates to be produced and put in the array
32  *      pointed to by `variates`.  One variate is a vector of length
33  *      `num_colors`, so the array pointed to by `variates` must have length
34  *      `num_variates * num_colors`.
35  *  int64_t *variates
36  *      The array that will hold the result.  It must have length
37  *      `num_variates * num_colors`.
38  *      The array is not initialized in the function; it is expected that the
39  *      array has been initialized with zeros when the function is called.
40  *
41  *  Notes
42  *  -----
43  *  The "count" algorithm for drawing one variate is roughly equivalent to the
44  *  following numpy code:
45  *
46  *      choices = np.repeat(np.arange(len(colors)), colors)
47  *      selection = np.random.choice(choices, nsample, replace=False)
48  *      variate = np.bincount(selection, minlength=len(colors))
49  *
50  *  This function uses a temporary array with length sum(colors).
51  *
52  *  Assumptions on the arguments (not checked in the function):
53  *    *  colors[k] >= 0  for k in range(num_colors)
54  *    *  total = sum(colors)
55  *    *  0 <= nsample <= total
56  *    *  the product total * sizeof(size_t) does not exceed SIZE_MAX
57  *    *  the product num_variates * num_colors does not overflow
58  */
59 
random_multivariate_hypergeometric_count(bitgen_t * bitgen_state,int64_t total,size_t num_colors,int64_t * colors,int64_t nsample,size_t num_variates,int64_t * variates)60 int random_multivariate_hypergeometric_count(bitgen_t *bitgen_state,
61                       int64_t total,
62                       size_t num_colors, int64_t *colors,
63                       int64_t nsample,
64                       size_t num_variates, int64_t *variates)
65 {
66     size_t *choices;
67     bool more_than_half;
68 
69     if ((total == 0) || (nsample == 0) || (num_variates == 0)) {
70         // Nothing to do.
71         return 0;
72     }
73 
74     choices = malloc(total * (sizeof *choices));
75     if (choices == NULL) {
76         return -1;
77     }
78 
79     /*
80      *  If colors contains, for example, [3 2 5], then choices
81      *  will contain [0 0 0 1 1 2 2 2 2 2].
82      */
83     for (size_t i = 0, k = 0; i < num_colors; ++i) {
84         for (int64_t j = 0; j < colors[i]; ++j) {
85             choices[k] = i;
86             ++k;
87         }
88     }
89 
90     more_than_half = nsample > (total / 2);
91     if (more_than_half) {
92         nsample = total - nsample;
93     }
94 
95     for (size_t i = 0; i < num_variates * num_colors; i += num_colors) {
96         /*
97          *  Fisher-Yates shuffle, but only loop through the first
98          *  `nsample` entries of `choices`.  After the loop,
99          *  choices[:nsample] contains a random sample from the
100          *  the full array.
101          */
102         for (size_t j = 0; j < (size_t) nsample; ++j) {
103             size_t tmp, k;
104             // Note: nsample is not greater than total, so there is no danger
105             // of integer underflow in `(size_t) total - j - 1`.
106             k = j + (size_t) random_interval(bitgen_state,
107                                              (size_t) total - j - 1);
108             tmp = choices[k];
109             choices[k] = choices[j];
110             choices[j] = tmp;
111         }
112         /*
113          *  Count the number of occurrences of each value in choices[:nsample].
114          *  The result, stored in sample[i:i+num_colors], is the sample from
115          *  the multivariate hypergeometric distribution.
116          */
117         for (size_t j = 0; j < (size_t) nsample; ++j) {
118             variates[i + choices[j]] += 1;
119         }
120 
121         if (more_than_half) {
122             for (size_t k = 0; k < num_colors; ++k) {
123                 variates[i + k] = colors[k] - variates[i + k];
124             }
125         }
126     }
127 
128     free(choices);
129 
130     return 0;
131 }
132