1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2021 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scipopt.org. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15
16 /**@file bandit_ucb.c
17 * @ingroup OTHER_CFILES
18 * @brief methods for UCB bandit selection
19 * @author Gregor Hendel
20 */
21
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23
24 #include "scip/bandit.h"
25 #include "scip/bandit_ucb.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/pub_misc_sort.h"
30 #include "scip/scip_bandit.h"
31 #include "scip/scip_mem.h"
32 #include "scip/scip_randnumgen.h"
33
34
/* name under which the UCB virtual function table is included and later looked up */
#define BANDIT_NAME "ucb"
/* tolerance used when comparing upper confidence bounds, and half-width of the random perturbation that is
 * added to the priorities before sorting the start permutation
 */
#define NUMEPS 1e-6
37
38 /*
39 * Data structures
40 */
41
/** implementation specific data of UCB bandit algorithm
 *
 *  The three arrays are allocated with one entry per action when the bandit is created and are released
 *  again in the free callback.
 */
struct SCIP_BanditData
{
   int                   nselections;        /**< counter for the number of selections; set to 0 on reset and
                                              *   incremented by every update */
   int*                  counter;            /**< array of counters how often every action has been chosen */
   int*                  startperm;          /**< indices for starting permutation, i.e., the order in which the
                                              *   actions are returned during the first nactions selections */
   SCIP_Real*            meanscores;         /**< array of average scores for the actions, maintained
                                              *   incrementally by the update callback */
   SCIP_Real             alpha;              /**< parameter to increase confidence width; must be nonnegative */
};
51
52
53 /*
54 * Local methods
55 */
56
57 /** data reset method */
58 static
dataReset(BMS_BUFMEM * bufmem,SCIP_BANDIT * ucb,SCIP_BANDITDATA * banditdata,SCIP_Real * priorities,int nactions)59 SCIP_RETCODE dataReset(
60 BMS_BUFMEM* bufmem, /**< buffer memory */
61 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
62 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
63 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
64 int nactions /**< number of actions */
65 )
66 {
67 int i;
68 SCIP_RANDNUMGEN* rng;
69
70 assert(bufmem != NULL);
71 assert(ucb != NULL);
72 assert(nactions > 0);
73
74 /* clear counters and scores */
75 BMSclearMemoryArray(banditdata->counter, nactions);
76 BMSclearMemoryArray(banditdata->meanscores, nactions);
77 banditdata->nselections = 0;
78
79 rng = SCIPbanditGetRandnumgen(ucb);
80 assert(rng != NULL);
81
82 /* initialize start permutation as identity */
83 for( i = 0; i < nactions; ++i )
84 banditdata->startperm[i] = i;
85
86 /* prepare the start permutation in decreasing order of priority */
87 if( priorities != NULL )
88 {
89 SCIP_Real* prioritycopy;
90
91 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
92
93 /* randomly wiggle priorities a little bit to make them unique */
94 for( i = 0; i < nactions; ++i )
95 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
96
97 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
98
99 BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
100 }
101 else
102 {
103 /* use a random start permutation */
104 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
105 }
106
107 return SCIP_OKAY;
108 }
109
110
111 /*
112 * Callback methods of bandit algorithm
113 */
114
115 /** callback to free bandit specific data structures */
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)116 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
117 { /*lint --e{715}*/
118 SCIP_BANDITDATA* banditdata;
119 int nactions;
120 assert(bandit != NULL);
121
122 banditdata = SCIPbanditGetData(bandit);
123 assert(banditdata != NULL);
124 nactions = SCIPbanditGetNActions(bandit);
125
126 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
127 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
128 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
129 BMSfreeBlockMemory(blkmem, &banditdata);
130
131 SCIPbanditSetData(bandit, NULL);
132
133 return SCIP_OKAY;
134 }
135
136 /** selection callback for bandit selector */
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)137 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
138 { /*lint --e{715}*/
139 SCIP_BANDITDATA* banditdata;
140 int nactions;
141 int* counter;
142
143 assert(bandit != NULL);
144 assert(selection != NULL);
145
146 banditdata = SCIPbanditGetData(bandit);
147 assert(banditdata != NULL);
148 nactions = SCIPbanditGetNActions(bandit);
149
150 counter = banditdata->counter;
151 /* select the next uninitialized action from the start permutation */
152 if( banditdata->nselections < nactions )
153 {
154 *selection = banditdata->startperm[banditdata->nselections];
155 assert(counter[*selection] == 0);
156 }
157 else
158 {
159 /* select the action with the highest upper confidence bound */
160 SCIP_Real* meanscores;
161 SCIP_Real widthfactor;
162 SCIP_Real maxucb;
163 int i;
164 SCIP_RANDNUMGEN* rng = SCIPbanditGetRandnumgen(bandit);
165 meanscores = banditdata->meanscores;
166
167 assert(rng != NULL);
168 assert(meanscores != NULL);
169
170 /* compute the confidence width factor that is common for all actions */
171 /* cppcheck-suppress unpreciseMathCall */
172 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
173 widthfactor = sqrt(widthfactor);
174 maxucb = -1.0;
175
176 /* loop over the actions and determine the maximum upper confidence bound.
177 * The upper confidence bound of an action is the sum of its mean score
178 * plus a confidence term that decreases with increasing number of observations of
179 * this action.
180 */
181 for( i = 0; i < nactions; ++i )
182 {
183 SCIP_Real uppercb;
184 SCIP_Real rootcount;
185 assert(counter[i] > 0);
186
187 /* compute the upper confidence bound for action i */
188 uppercb = meanscores[i];
189 rootcount = sqrt((SCIP_Real)counter[i]);
190 uppercb += widthfactor / rootcount;
191 assert(uppercb > 0);
192
193 /* update maximum, breaking ties uniformly at random */
194 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
195 {
196 maxucb = uppercb;
197 *selection = i;
198 }
199 }
200 }
201
202 assert(*selection >= 0);
203 assert(*selection < nactions);
204
205 return SCIP_OKAY;
206 }
207
208 /** update callback for bandit algorithm */
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)209 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
210 { /*lint --e{715}*/
211 SCIP_BANDITDATA* banditdata;
212 SCIP_Real delta;
213
214 assert(bandit != NULL);
215
216 banditdata = SCIPbanditGetData(bandit);
217 assert(banditdata != NULL);
218 assert(selection >= 0);
219 assert(selection < SCIPbanditGetNActions(bandit));
220
221 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
222 delta = score - banditdata->meanscores[selection];
223 ++banditdata->counter[selection];
224 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
225
226 banditdata->nselections++;
227
228 return SCIP_OKAY;
229 }
230
231 /** reset callback for bandit algorithm */
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)232 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
233 { /*lint --e{715}*/
234 SCIP_BANDITDATA* banditdata;
235 int nactions;
236
237 assert(bufmem != NULL);
238 assert(bandit != NULL);
239
240 banditdata = SCIPbanditGetData(bandit);
241 assert(banditdata != NULL);
242 nactions = SCIPbanditGetNActions(bandit);
243
244 /* call the data reset for the given priorities */
245 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
246
247 return SCIP_OKAY;
248 }
249
250 /*
251 * bandit algorithm specific interface methods
252 */
253
254 /** returns the upper confidence bound of a selected action */
SCIPgetConfidenceBoundUcb(SCIP_BANDIT * ucb,int action)255 SCIP_Real SCIPgetConfidenceBoundUcb(
256 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
257 int action /**< index of the queried action */
258 )
259 {
260 SCIP_Real uppercb;
261 SCIP_BANDITDATA* banditdata;
262 int nactions;
263
264 assert(ucb != NULL);
265 banditdata = SCIPbanditGetData(ucb);
266 nactions = SCIPbanditGetNActions(ucb);
267 assert(action < nactions);
268
269 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
270 if( banditdata->nselections < nactions )
271 return 1.0;
272
273 /* the bandit algorithm must have picked every action once */
274 assert(banditdata->counter[action] > 0);
275 uppercb = banditdata->meanscores[action];
276
277 /* cppcheck-suppress unpreciseMathCall */
278 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
279
280 return uppercb;
281 }
282
283 /** return start permutation of the UCB bandit algorithm */
SCIPgetStartPermutationUcb(SCIP_BANDIT * ucb)284 int* SCIPgetStartPermutationUcb(
285 SCIP_BANDIT* ucb /**< UCB bandit algorithm */
286 )
287 {
288 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
289
290 assert(banditdata != NULL);
291
292 return banditdata->startperm;
293 }
294
295 /** internal method to create and reset UCB bandit algorithm */
SCIPbanditCreateUcb(BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_BANDITVTABLE * vtable,SCIP_BANDIT ** ucb,SCIP_Real * priorities,SCIP_Real alpha,int nactions,unsigned int initseed)296 SCIP_RETCODE SCIPbanditCreateUcb(
297 BMS_BLKMEM* blkmem, /**< block memory */
298 BMS_BUFMEM* bufmem, /**< buffer memory */
299 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
300 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
301 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
302 SCIP_Real alpha, /**< parameter to increase confidence width */
303 int nactions, /**< the positive number of actions for this bandit algorithm */
304 unsigned int initseed /**< initial random seed */
305 )
306 {
307 SCIP_BANDITDATA* banditdata;
308
309 if( alpha < 0.0 )
310 {
311 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
312 return SCIP_INVALIDDATA;
313 }
314
315 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
316 assert(banditdata != NULL);
317
318 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
319 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
320 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
321
322 banditdata->alpha = alpha;
323
324 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
325
326 return SCIP_OKAY;
327 }
328
329 /** create and reset UCB bandit algorithm */
SCIPcreateBanditUcb(SCIP * scip,SCIP_BANDIT ** ucb,SCIP_Real * priorities,SCIP_Real alpha,int nactions,unsigned int initseed)330 SCIP_RETCODE SCIPcreateBanditUcb(
331 SCIP* scip, /**< SCIP data structure */
332 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
333 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
334 SCIP_Real alpha, /**< parameter to increase confidence width */
335 int nactions, /**< the positive number of actions for this bandit algorithm */
336 unsigned int initseed /**< initial random number seed */
337 )
338 {
339 SCIP_BANDITVTABLE* vtable;
340
341 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
342 if( vtable == NULL )
343 {
344 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
345 return SCIP_INVALIDDATA;
346 }
347
348 SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
349 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
350
351 return SCIP_OKAY;
352 }
353
354 /** include virtual function table for UCB bandit algorithms */
SCIPincludeBanditvtableUcb(SCIP * scip)355 SCIP_RETCODE SCIPincludeBanditvtableUcb(
356 SCIP* scip /**< SCIP data structure */
357 )
358 {
359 SCIP_BANDITVTABLE* vtable;
360
361 SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
362 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
363 assert(vtable != NULL);
364
365 return SCIP_OKAY;
366 }
367