1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2021 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scipopt.org. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15
16 /**@file bandit_exp3.c
17 * @ingroup OTHER_CFILES
18 * @brief methods for Exp.3 bandit selection
19 * @author Gregor Hendel
20 */
21
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23
24 #include "scip/bandit.h"
25 #include "scip/bandit_exp3.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/scip_bandit.h"
30 #include "scip/scip_mem.h"
31 #include "scip/scip_randnumgen.h"
32
/* name under which the Exp.3 virtual function table is registered and looked up */
#define BANDIT_NAME            "exp3"

/* numerical tolerance: added to every initial weight and used to decide whether beta counts as zero */
#define NUMTOL                 1e-6
35
36 /*
37 * Data structures
38 */
39
40 /** implementation specific data of Exp.3 bandit algorithm */
41 struct SCIP_BanditData
42 {
43 SCIP_Real* weights; /**< exponential weight for each arm */
44 SCIP_Real weightsum; /**< the sum of all weights */
45 SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
46 SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */
47 };
48
49 /*
50 * Local methods
51 */
52
53 /*
54 * Callback methods of bandit algorithm
55 */
56
57 /** callback to free bandit specific data structures */
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)58 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
59 { /*lint --e{715}*/
60 SCIP_BANDITDATA* banditdata;
61 int nactions;
62 assert(bandit != NULL);
63
64 banditdata = SCIPbanditGetData(bandit);
65 assert(banditdata != NULL);
66 nactions = SCIPbanditGetNActions(bandit);
67
68 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
69
70 BMSfreeBlockMemory(blkmem, &banditdata);
71
72 SCIPbanditSetData(bandit, NULL);
73
74 return SCIP_OKAY;
75 }
76
77 /** selection callback for bandit selector */
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)78 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
79 { /*lint --e{715}*/
80 SCIP_BANDITDATA* banditdata;
81 SCIP_RANDNUMGEN* rng;
82 SCIP_Real randnr;
83 SCIP_Real psum;
84 SCIP_Real gammaoverk;
85 SCIP_Real oneminusgamma;
86 SCIP_Real* weights;
87 SCIP_Real weightsum;
88 int i;
89 int nactions;
90
91 assert(bandit != NULL);
92 assert(selection != NULL);
93
94 banditdata = SCIPbanditGetData(bandit);
95 assert(banditdata != NULL);
96 rng = SCIPbanditGetRandnumgen(bandit);
97 assert(rng != NULL);
98 nactions = SCIPbanditGetNActions(bandit);
99
100 /* draw a random number between 0 and 1 */
101 randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
102
103 /* initialize some local variables to speed up probability computations */
104 oneminusgamma = 1 - banditdata->gamma;
105 gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
106 weightsum = banditdata->weightsum;
107 weights = banditdata->weights;
108 psum = 0.0;
109
110 /* loop over probability distribution until rand is reached
111 * the loop terminates without looking at the last action,
112 * which is then selected automatically if the target probability
113 * is not reached earlier
114 */
115 for( i = 0; i < nactions - 1; ++i )
116 {
117 SCIP_Real prob;
118
119 /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
120 prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
121 psum += prob;
122
123 /* break and select element if target probability is reached */
124 if( randnr <= psum )
125 break;
126 }
127
128 /* select element i, which is the last action in case that the break statement hasn't been reached */
129 *selection = i;
130
131 return SCIP_OKAY;
132 }
133
134 /** update callback for bandit algorithm */
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)135 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
136 { /*lint --e{715}*/
137 SCIP_BANDITDATA* banditdata;
138 SCIP_Real eta;
139 SCIP_Real gainestim;
140 SCIP_Real beta;
141 SCIP_Real weightsum;
142 SCIP_Real newweightsum;
143 SCIP_Real* weights;
144 SCIP_Real oneminusgamma;
145 SCIP_Real gammaoverk;
146 int nactions;
147
148 assert(bandit != NULL);
149
150 banditdata = SCIPbanditGetData(bandit);
151 assert(banditdata != NULL);
152 nactions = SCIPbanditGetNActions(bandit);
153
154 assert(selection >= 0);
155 assert(selection < nactions);
156
157 /* the learning rate eta */
158 eta = 1.0 / (SCIP_Real)nactions;
159
160 beta = banditdata->beta;
161 oneminusgamma = 1.0 - banditdata->gamma;
162 gammaoverk = banditdata->gamma * eta;
163 weights = banditdata->weights;
164 weightsum = banditdata->weightsum;
165 newweightsum = weightsum;
166
167 /* if beta is zero, only the observation for the current arm needs an update */
168 if( EPSZ(beta, NUMTOL) )
169 {
170 SCIP_Real probai;
171 probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
172
173 assert(probai > 0.0);
174
175 gainestim = score / probai;
176 newweightsum -= weights[selection];
177 weights[selection] *= exp(eta * gainestim);
178 newweightsum += weights[selection];
179 }
180 else
181 {
182 int j;
183 newweightsum = 0.0;
184
185 /* loop over all items and update their weights based on the influence of the beta parameter */
186 for( j = 0; j < nactions; ++j )
187 {
188 SCIP_Real probaj;
189 probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
190
191 assert(probaj > 0.0);
192
193 /* consider the score only for the chosen arm i, use constant beta offset otherwise */
194 if( j == selection )
195 gainestim = (score + beta) / probaj;
196 else
197 gainestim = beta / probaj;
198
199 weights[j] *= exp(eta * gainestim);
200 newweightsum += weights[j];
201 }
202 }
203
204 banditdata->weightsum = newweightsum;
205
206 return SCIP_OKAY;
207 }
208
209 /** reset callback for bandit algorithm */
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)210 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
211 { /*lint --e{715}*/
212 SCIP_BANDITDATA* banditdata;
213 SCIP_Real* weights;
214 int nactions;
215 int i;
216
217 assert(bandit != NULL);
218
219 banditdata = SCIPbanditGetData(bandit);
220 assert(banditdata != NULL);
221 nactions = SCIPbanditGetNActions(bandit);
222 weights = banditdata->weights;
223
224 assert(nactions > 0);
225
226 banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
227
228 /* in case of priorities, weights are normalized to sum up to nactions */
229 if( priorities != NULL )
230 {
231 SCIP_Real normalization;
232 SCIP_Real priosum;
233 priosum = 0.0;
234
235 /* compute sum of priorities */
236 for( i = 0; i < nactions; ++i )
237 {
238 assert(priorities[i] >= 0);
239 priosum += priorities[i];
240 }
241
242 /* if there are positive priorities, normalize the weights */
243 if( priosum > 0.0 )
244 {
245 normalization = nactions / priosum;
246 for( i = 0; i < nactions; ++i )
247 weights[i] = (priorities[i] * normalization) + NUMTOL;
248 }
249 else
250 {
251 /* use uniform distribution in case of all priorities being 0.0 */
252 for( i = 0; i < nactions; ++i )
253 weights[i] = 1.0 + NUMTOL;
254 }
255 }
256 else
257 {
258 /* use uniform distribution in case of unspecified priorities */
259 for( i = 0; i < nactions; ++i )
260 weights[i] = 1.0 + NUMTOL;
261 }
262
263 return SCIP_OKAY;
264 }
265
266
267 /*
268 * bandit algorithm specific interface methods
269 */
270
271 /** direct bandit creation method for the core where no SCIP pointer is available */
SCIPbanditCreateExp3(BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_BANDITVTABLE * vtable,SCIP_BANDIT ** exp3,SCIP_Real * priorities,SCIP_Real gammaparam,SCIP_Real beta,int nactions,unsigned int initseed)272 SCIP_RETCODE SCIPbanditCreateExp3(
273 BMS_BLKMEM* blkmem, /**< block memory data structure */
274 BMS_BUFMEM* bufmem, /**< buffer memory */
275 SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */
276 SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
277 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
278 SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
279 SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
280 int nactions, /**< the positive number of actions for this bandit algorithm */
281 unsigned int initseed /**< initial random seed */
282 )
283 {
284 SCIP_BANDITDATA* banditdata;
285
286 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
287 assert(banditdata != NULL);
288
289 banditdata->gamma = gammaparam;
290 banditdata->beta = beta;
291 assert(gammaparam >= 0 && gammaparam <= 1);
292 assert(beta >= 0 && beta <= 1);
293
294 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
295
296 SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
297
298 return SCIP_OKAY;
299 }
300
301 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
SCIPcreateBanditExp3(SCIP * scip,SCIP_BANDIT ** exp3,SCIP_Real * priorities,SCIP_Real gammaparam,SCIP_Real beta,int nactions,unsigned int initseed)302 SCIP_RETCODE SCIPcreateBanditExp3(
303 SCIP* scip, /**< SCIP data structure */
304 SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
305 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
306 SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
307 SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
308 int nactions, /**< the positive number of actions for this bandit algorithm */
309 unsigned int initseed /**< initial seed for random number generation */
310 )
311 {
312 SCIP_BANDITVTABLE* vtable;
313
314 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
315 if( vtable == NULL )
316 {
317 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
318 return SCIP_INVALIDDATA;
319 }
320
321 SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
322 priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
323
324 return SCIP_OKAY;
325 }
326
327 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
SCIPsetGammaExp3(SCIP_BANDIT * exp3,SCIP_Real gammaparam)328 void SCIPsetGammaExp3(
329 SCIP_BANDIT* exp3, /**< bandit algorithm */
330 SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
331 )
332 {
333 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
334
335 assert(gammaparam >= 0 && gammaparam <= 1);
336
337 banditdata->gamma = gammaparam;
338 }
339
340 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
SCIPsetBetaExp3(SCIP_BANDIT * exp3,SCIP_Real beta)341 void SCIPsetBetaExp3(
342 SCIP_BANDIT* exp3, /**< bandit algorithm */
343 SCIP_Real beta /**< gain offset between 0 and 1 at every observation */
344 )
345 {
346 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
347
348 assert(beta >= 0 && beta <= 1);
349
350 banditdata->beta = beta;
351 }
352
353 /** returns probability to play an action */
SCIPgetProbabilityExp3(SCIP_BANDIT * exp3,int action)354 SCIP_Real SCIPgetProbabilityExp3(
355 SCIP_BANDIT* exp3, /**< bandit algorithm */
356 int action /**< index of the requested action */
357 )
358 {
359 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
360
361 assert(banditdata->weightsum > 0.0);
362 assert(SCIPbanditGetNActions(exp3) > 0);
363
364 return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
365 }
366
367 /** include virtual function table for Exp.3 bandit algorithms */
SCIPincludeBanditvtableExp3(SCIP * scip)368 SCIP_RETCODE SCIPincludeBanditvtableExp3(
369 SCIP* scip /**< SCIP data structure */
370 )
371 {
372 SCIP_BANDITVTABLE* vtable;
373
374 SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
375 SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
376 assert(vtable != NULL);
377
378 return SCIP_OKAY;
379 }
380