1% Copyright (C) 2012-2017,2018 John E. Davis
2%
3% This file is part of the S-Lang Library and may be distributed under the
4% terms of the GNU General Public License.  See the file COPYING for
5% more information.
6%---------------------------------------------------------------------------
7import ("rand");
8
% Parse the arguments of a rand_* wrapper of the form
%
%    r = rand_xxx ([Rand_Type,] p1, ..., pN [,num])
%
% The wrapper's own arguments are still on the stack when this is called.
% nargs is the wrapper's _NARGS and num_parms (N) is the number of required
% distribution parameters.  The accepted layouts, bottom of stack first, are:
%
%    p1..pN            -> @parmsp = {p1..pN}, @rtp = NULL, @nump = NULL
%    p1..pN, num       -> @parmsp = {p1..pN}, @rtp = NULL, @nump = num
%    rt, p1..pN        -> @parmsp = {p1..pN}, @rtp = rt,   @nump = NULL
%    rt, p1..pN, num   -> @parmsp = {p1..pN}, @rtp = rt,   @nump = num
%
% The results are written through the references parmsp (a List_Type),
% rtp and nump.  For any other argument count, or when the first of N+2
% arguments is not a Rand_Type, all nargs values are removed from the stack
% and usage(usage_str) is called.
private define get_generator_args (nargs, num_parms, parmsp, rtp, nump,
				   usage_str)
{
   % Default: no generator and no element count.
   @rtp = NULL;
   @nump = NULL;
   % Only the required parameters were given.
   if (nargs == num_parms)
     {
	@parmsp = __pop_list (num_parms);
	return;
     }

   % One extra argument: either a trailing count or a leading generator.
   if (nargs == num_parms + 1)
     {
	% Tentatively treat the top-of-stack value as the count.
	@nump = ();
	variable parms = __pop_list (num_parms);
	% If the bottom value is a generator, the "count" was actually the
	% last distribution parameter: remove the generator from the front
	% of the list and put that value back at the end.
	if (typeof (parms[0]) == Rand_Type)
	  {
	     @rtp = list_pop (parms);
	     list_append (parms, @nump);
	     @nump = NULL;
	  }
	@parmsp = parms;
	return;
     }

   % Two extra arguments: generator first, count last.
   if (nargs == num_parms + 2)
     {
	@nump = ();
	@parmsp = __pop_list (num_parms);
	variable rt = ();
	if (typeof (rt) == Rand_Type)
	  {
	     @rtp = rt;
	     return;
	  }
	% Not a generator: every argument has already been popped, so
	% fall through to the usage message.
     }
   % Unsupported argument count: discard the wrapper's arguments first.
   else _pop_n (nargs);

   usage (usage_str);
}
49
% Call the generator referenced by its first argument:
%
%    call_rand_func (&func, rt, arg1, ..., argK, num)
%
% rt (a Rand_Type or NULL) is passed as the first argument of func.
% num (an element count or NULL) is passed as the last argument.
% A NULL rt or num is simply left out.  Whatever func returns is returned.
private define call_rand_func ()
{
   variable count = ();
   variable fwd = __pop_list (_NARGS-3);
   variable fref, gen;
   (fref, gen) = ();

   % Assemble the complete argument list, then make a single call.
   if (gen != NULL)
     list_insert (fwd, gen);
   if (count != NULL)
     list_append (fwd, count);

   return (@fref) (__push_list (fwd));
}
70
% r = rand_flat ([Rand_Type,] xmin, xmax [,num])
%
% Scale the output of rand_uniform from the unit interval onto the
% interval from xmin to xmax, i.e. xmin + (xmax-xmin)*u.
define rand_flat ()
{
   variable parms, rt, num;

   get_generator_args (_NARGS, 2, &parms, &rt, &num,
		       "r = rand_flat ([Rand_Type,] xmin, xmax [,num])");

   variable lo = parms[0], hi = parms[1];
   variable u = call_rand_func (&rand_uniform, rt, num);

   return lo + (hi - lo)*__tmp(u);
}
82
% r = rand_chisq ([Rand_Type,] nu [,num])
%
% Chi-squared deviates with nu degrees of freedom, computed as
% 2*rand_gamma(nu/2, 1.0).
define rand_chisq ()
{
   variable parms, rt, num;

   get_generator_args (_NARGS, 1, &parms, &rt, &num,
		       "r = rand_chisq ([Rand_Type,] nu [,num])");

   variable half_nu = 0.5*parms[0];
   variable g = call_rand_func (&rand_gamma, rt, half_nu, 1.0, num);
   return 2.0 * __tmp(g);
}
91
% r = rand_fdist ([Rand_Type,] nu1, nu2 [,num])
%
% F-distributed deviates, formed as the ratio (G1/nu1)/(G2/nu2) where
% Gi = rand_gamma(nu_i/2, 1.0).  The factor of 2 relating each gamma
% deviate to a chi-squared one cancels in the ratio.
define rand_fdist ()
{
   variable parms, rt, num;

   get_generator_args (_NARGS, 2, &parms, &rt, &num,
		       "r = rand_fdist ([Rand_Type,] nu1, nu2 [,num])");

   variable nu1 = parms[0];
   variable nu2 = parms[1];

   % The numerator is drawn before the denominator.
   variable top = call_rand_func (&rand_gamma, rt, 0.5*nu1, 1.0, num)/nu1;
   variable bot = call_rand_func (&rand_gamma, rt, 0.5*nu2, 1.0, num)/nu2;

   return top / bot;
}
103
% r = rand_tdist ([Rand_Type,] nu [,num])
%
% Student-t deviates with nu degrees of freedom, formed as
% Z / sqrt(X/nu), where Z = rand_gauss(1.0) and X = rand_chisq(nu).
% The same generator rt (if any) and count num are used for both draws.
define rand_tdist ()
{
   variable parms, rt, num;

   % The usage string previously read "nu, [,num]" (stray comma).
   get_generator_args (_NARGS, 1, &parms, &rt, &num,
		       "r = rand_tdist ([Rand_Type,] nu [,num])");
   variable nu = parms[0];
   return call_rand_func (&rand_gauss, rt, 1.0, num)
     / sqrt(call_rand_func(&rand_chisq, rt, nu, num)/nu);
}
114
% r = rand_int ([Rand_Type,] imin, imax [,num])
%
% Return Int32 values in the closed range [imin, imax], obtained by reducing
% the output of rand() modulo the width of the range.
% Throws InvalidParmError if imin > imax.
%
% NOTE(review): reduction with `mod` favours low values slightly unless the
% range width divides the number of values rand() can produce (presumably
% 2^32 -- confirm against the rand module).
define rand_int ()
{
   variable parms, rt, num;

   get_generator_args (_NARGS, 2, &parms, &rt, &num,
		       "r = rand_int ([Rand_Type,] imin, imax [,num])");

   variable
     imin = typecast (parms[0], Int32_Type),
     imax = typecast (parms[1], Int32_Type), di;

   if (imin > imax)
     throw InvalidParmError, "rand_int: imax < imin";

   % Compute the width imax-imin in unsigned 32-bit arithmetic.  Subtracting
   % the Int32 values directly overflows a signed int whenever the range is
   % wider than INT32_MAX (e.g. imin=INT32_MIN, imax=INT32_MAX).  The
   % unsigned difference is well defined and has the same bit pattern.
   di = typecast (imax, UInt32_Type) - typecast (imin, UInt32_Type);

   variable r = call_rand_func (&rand, rt, num);

   % For the full 32-bit range, di+1 wraps to 0 and no reduction is needed.
   if (di + 1 != 0)		       % UINT32_MAX does not exist
     r = __tmp(r) mod (di + 1);

   % Assuming imin + r is evaluated in unsigned 32-bit arithmetic, it wraps
   % modulo 2^32.  Casting back to Int32 then yields a value in [imin, imax].
   return typecast (imin + r, Int32_Type);
}
138
% r = rand_exp ([Rand_Type,] beta [,num])
%
% Computed by inversion as -beta*log(u), with u drawn from rand_uniform_pos.
% Presumably u lies in (0,1], which makes r exponential with mean beta.
define rand_exp ()
{
   variable parms, rt, num;

   get_generator_args (_NARGS, 1, &parms, &rt, &num,
		       "r = rand_exp ([Rand_Type,] beta [,num])");

   variable mu = parms[0];
   variable u = call_rand_func (&rand_uniform_pos, rt, num);

   return (-mu) * log (__tmp(u));
}
148
% Push one index object per dimension of array a, for use inside an index
% expression such as a[make_indices (a, d, i)].  Dimension d receives i and
% every other dimension receives [:].
private define make_indices (a, d, i)
{
   variable ndims = length (array_shape (a));
   variable k;

   for (k = 0; k < ndims; k++)
     {
	if (k != d)
	  [:];
	else
	  i;
     }
}
160
% (B1 [,B2,...]) = rand_sample ([Rand_Type,] A1 [,A2,...], num)
%
% Select num entries along the leading dimension of each array Ai.  The
% same index set, taken from rand_permutation(n0), is used for every array,
% so the results stay aligned.  All arrays must share the leading
% dimension n0, otherwise TypeMismatchError is thrown.  num is clamped to
% [0, n0].  One result per input array is left on the stack, in input order.
define rand_sample ()
{
   variable usage_str = "(B1 [,B2,...]) = rand_sample ([Rand_Type,] A1 [,A2,...], num)";

   if (_NARGS < 2)
     {
	_pop_n (_NARGS);
	usage (usage_str);
     }

   variable num = ();
   variable arrays = __pop_list (_NARGS-1);
   variable rt = NULL;

   if (typeof (arrays[0]) == Rand_Type)
     rt = list_pop (arrays);

   % rand_sample (rt, num) leaves nothing to sample from.  Without this check,
   % n0 stays NULL and the comparison with num below fails obscurely.
   if (length (arrays) == 0)
     usage (usage_str);

   variable n0 = NULL, dim0;
   variable a, indices;

   foreach a (arrays)
     {
	dim0 = array_shape (a)[0];
	if (n0 == NULL)
	  {
	     n0 = dim0;
	     continue;
	  }
	if (n0 != dim0)
	  throw TypeMismatchError, "The arrays passed to rand_sample must have the same leading dimension";
     }

   if (num > n0)
     num = n0;
   % Clamp negative counts; they would produce a negative range endpoint below.
   if (num < 0)
     num = 0;

   if (rt == NULL)
     indices = rand_permutation (n0);
   else
     indices = rand_permutation (rt, n0);

   % Inside an index expression, the endpoint -1 that [0:num-1] has for
   % num == 0 can be taken relative to the end of the array, selecting
   % everything.  Use an explicit empty index array instead.
   if (num == 0)
     indices = indices[Int_Type[0]];
   else if (num < n0)
     indices = indices[[0:num-1]];

   foreach a (arrays)
     a[make_indices(a, 0, indices)];
}
204
205