1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2021 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scipopt.org. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15
16 /**@file bandit_epsgreedy.c
17 * @ingroup OTHER_CFILES
18 * @brief implementation of epsilon greedy bandit algorithm
19 * @author Gregor Hendel
20 */
21
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23
24 #include "scip/bandit.h"
25 #include "scip/bandit_epsgreedy.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/scip_bandit.h"
30 #include "scip/scip_mem.h"
31 #include "scip/scip_randnumgen.h"
32
#define BANDIT_NAME           "eps-greedy"   /**< name under which the virtual function table is registered and looked up */
#define EPSGREEDY_SMALL       1.0e-6         /**< tolerance for weight ties and bound on the random perturbation of priorities */
35
36 /*
37 * Data structures
38 */
39
40 /** private data structure of epsilon greedy bandit algorithm */
41 struct SCIP_BanditData
42 {
43 SCIP_Real* weights; /**< weights for every action */
44 SCIP_Real* priorities; /**< saved priorities for tie breaking */
45 int* sels; /**< individual number of selections per action */
46 SCIP_Real eps; /**< epsilon parameter (between 0 and 1) to control epsilon greedy */
47 SCIP_Real decayfactor; /**< the factor to reduce the weight of older observations if exponential decay is enabled */
48 int avglim; /**< nonnegative limit on observation number before the exponential decay starts,
49 * only relevant if exponential decay is enabled
50 */
51 int nselections; /**< counter for the number of selection calls */
52 SCIP_Bool preferrecent; /**< should the weights be updated in an exponentially decaying way? */
53 };
54
55 /*
56 * Callback methods of bandit algorithm virtual function table
57 */
58
59 /** callback to free bandit specific data structures */
SCIP_DECL_BANDITFREE(SCIPbanditFreeEpsgreedy)60 SCIP_DECL_BANDITFREE(SCIPbanditFreeEpsgreedy)
61 { /*lint --e{715}*/
62 SCIP_BANDITDATA* banditdata;
63 int nactions;
64
65 assert(bandit != NULL);
66
67 banditdata = SCIPbanditGetData(bandit);
68 assert(banditdata != NULL);
69 assert(banditdata->weights != NULL);
70 nactions = SCIPbanditGetNActions(bandit);
71
72 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
73 BMSfreeBlockMemoryArray(blkmem, &banditdata->priorities, nactions);
74 BMSfreeBlockMemoryArray(blkmem, &banditdata->sels, nactions);
75 BMSfreeBlockMemory(blkmem, &banditdata);
76
77 SCIPbanditSetData(bandit, NULL);
78
79 return SCIP_OKAY;
80 }
81
82 /** selection callback for bandit algorithm */
SCIP_DECL_BANDITSELECT(SCIPbanditSelectEpsgreedy)83 SCIP_DECL_BANDITSELECT(SCIPbanditSelectEpsgreedy)
84 { /*lint --e{715}*/
85 SCIP_BANDITDATA* banditdata;
86 SCIP_Real randnr;
87 SCIP_Real curreps;
88 SCIP_RANDNUMGEN* rng;
89 int nactions;
90 assert(bandit != NULL);
91 assert(selection != NULL);
92
93 banditdata = SCIPbanditGetData(bandit);
94 assert(banditdata != NULL);
95 rng = SCIPbanditGetRandnumgen(bandit);
96 assert(rng != NULL);
97
98 nactions = SCIPbanditGetNActions(bandit);
99
100 /* roll the dice to check if the best element should be picked, or an element at random */
101 randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
102
103 /* make epsilon decrease with an increasing number of selections */
104 banditdata->nselections++;
105 curreps = banditdata->eps * sqrt((SCIP_Real)nactions/(SCIP_Real)banditdata->nselections);
106
107 /* select the best action seen so far */
108 if( randnr >= curreps )
109 {
110 SCIP_Real* weights = banditdata->weights;
111 SCIP_Real* priorities = banditdata->priorities;
112 int j;
113 SCIP_Real maxweight;
114
115 assert(weights != NULL);
116 assert(priorities != NULL);
117
118 /* pick the element with the largest reward */
119 maxweight = weights[0];
120 *selection = 0;
121
122 /* determine reward for every element */
123 for( j = 1; j < nactions; ++j )
124 {
125 SCIP_Real weight = weights[j];
126
127 /* select the action that maximizes the reward, breaking ties by action priorities */
128 if( maxweight < weight
129 || (weight >= maxweight - EPSGREEDY_SMALL && priorities[j] > priorities[*selection] ) )
130 {
131 *selection = j;
132 maxweight = weight;
133 }
134 }
135 }
136 else
137 {
138 /* play one of the actions at random */
139 *selection = SCIPrandomGetInt(rng, 0, nactions - 1);
140 }
141
142 return SCIP_OKAY;
143 }
144
145 /** update callback for bandit algorithm */
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateEpsgreedy)146 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateEpsgreedy)
147 { /*lint --e{715}*/
148 SCIP_BANDITDATA* banditdata;
149
150 assert(bandit != NULL);
151
152 banditdata = SCIPbanditGetData(bandit);
153 assert(banditdata != NULL);
154
155 /* increase the selection count */
156 ++banditdata->sels[selection];
157
158 /* the very first observation is directly stored as weight for both average or exponential decay */
159 if( banditdata->sels[selection] == 1 )
160 banditdata->weights[selection] = score;
161 else
162 {
163 /* use exponentially decreasing weights for older observations */
164 if( banditdata->preferrecent && banditdata->sels[selection] > banditdata->avglim )
165 {
166 /* decrease old weights by decay factor */
167 banditdata->weights[selection] *= banditdata->decayfactor;
168 banditdata->weights[selection] += (1.0 - banditdata->decayfactor) * score;
169 }
170 else
171 {
172 /* update average score */
173 SCIP_Real diff = score - banditdata->weights[selection];
174 banditdata->weights[selection] += diff / (SCIP_Real)(banditdata->sels[selection]);
175 }
176 }
177
178 return SCIP_OKAY;
179 }
180
181 /** reset callback for bandit algorithm */
SCIP_DECL_BANDITRESET(SCIPbanditResetEpsgreedy)182 SCIP_DECL_BANDITRESET(SCIPbanditResetEpsgreedy)
183 { /*lint --e{715}*/
184 SCIP_BANDITDATA* banditdata;
185 SCIP_Real* weights;
186 int w;
187 int nactions;
188 SCIP_RANDNUMGEN* rng;
189
190 assert(bandit != NULL);
191
192 banditdata = SCIPbanditGetData(bandit);
193 assert(banditdata != NULL);
194
195 weights = banditdata->weights;
196 nactions = SCIPbanditGetNActions(bandit);
197 assert(weights != NULL);
198 assert(banditdata->priorities != NULL);
199 assert(nactions > 0);
200
201 rng = SCIPbanditGetRandnumgen(bandit);
202 assert(rng != NULL);
203
204 /* alter priorities slightly to make them unique */
205 if( priorities != NULL )
206 {
207 for( w = 1; w < nactions; ++w )
208 {
209 assert(priorities[w] >= 0);
210 banditdata->priorities[w] = priorities[w] + SCIPrandomGetReal(rng, -EPSGREEDY_SMALL, EPSGREEDY_SMALL);
211 }
212 }
213 else
214 {
215 /* use random priorities */
216 for( w = 0; w < nactions; ++w )
217 banditdata->priorities[w] = SCIPrandomGetReal(rng, 0.0, 1.0);
218 }
219
220 /* reset weights and selection counters to 0 */
221 BMSclearMemoryArray(weights, nactions);
222 BMSclearMemoryArray(banditdata->sels, nactions);
223
224 banditdata->nselections = 0;
225
226 return SCIP_OKAY;
227 }
228
229 /*
230 * interface methods of the Epsilon Greedy bandit algorithm
231 */
232
233 /** internal method to create and reset epsilon greedy bandit algorithm */
SCIPbanditCreateEpsgreedy(BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_BANDITVTABLE * vtable,SCIP_BANDIT ** epsgreedy,SCIP_Real * priorities,SCIP_Real eps,SCIP_Bool preferrecent,SCIP_Real decayfactor,int avglim,int nactions,unsigned int initseed)234 SCIP_RETCODE SCIPbanditCreateEpsgreedy(
235 BMS_BLKMEM* blkmem, /**< block memory */
236 BMS_BUFMEM* bufmem, /**< buffer memory */
237 SCIP_BANDITVTABLE* vtable, /**< virtual function table with epsilon greedy callbacks */
238 SCIP_BANDIT** epsgreedy, /**< pointer to store the epsilon greedy bandit algorithm */
239 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
240 SCIP_Real eps, /**< parameter to increase probability for exploration between all actions */
241 SCIP_Bool preferrecent, /**< should the weights be updated in an exponentially decaying way? */
242 SCIP_Real decayfactor, /**< the factor to reduce the weight of older observations if exponential decay is enabled */
243 int avglim, /**< nonnegative limit on observation number before the exponential decay starts,
244 * only relevant if exponential decay is enabled */
245 int nactions, /**< the positive number of possible actions */
246 unsigned int initseed /**< initial random seed */
247 )
248 {
249 SCIP_BANDITDATA* banditdata;
250
251 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
252 assert(banditdata != NULL);
253 assert(eps >= 0.0);
254
255 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
256 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->priorities, nactions) );
257 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->sels, nactions) );
258 banditdata->eps = eps;
259 banditdata->nselections = 0;
260 banditdata->preferrecent = preferrecent;
261 banditdata->decayfactor = decayfactor;
262 banditdata->avglim = avglim;
263
264 SCIP_CALL( SCIPbanditCreate(epsgreedy, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
265
266 return SCIP_OKAY;
267 }
268
269 /** create and resets an epsilon greedy bandit algorithm */
SCIPcreateBanditEpsgreedy(SCIP * scip,SCIP_BANDIT ** epsgreedy,SCIP_Real * priorities,SCIP_Real eps,SCIP_Bool preferrecent,SCIP_Real decayfactor,int avglim,int nactions,unsigned int initseed)270 SCIP_RETCODE SCIPcreateBanditEpsgreedy(
271 SCIP* scip, /**< SCIP data structure */
272 SCIP_BANDIT** epsgreedy, /**< pointer to store the epsilon greedy bandit algorithm */
273 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
274 SCIP_Real eps, /**< parameter to increase probability for exploration between all actions */
275 SCIP_Bool preferrecent, /**< should the weights be updated in an exponentially decaying way? */
276 SCIP_Real decayfactor, /**< the factor to reduce the weight of older observations if exponential decay is enabled */
277 int avglim, /**< nonnegative limit on observation number before the exponential decay starts,
278 * only relevant if exponential decay is enabled */
279 int nactions, /**< the positive number of possible actions */
280 unsigned int initseed /**< initial seed for random number generation */
281 )
282 {
283 SCIP_BANDITVTABLE* vtable;
284 assert(scip != NULL);
285 assert(epsgreedy != NULL);
286
287 vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
288 if( vtable == NULL )
289 {
290 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
291 return SCIP_INVALIDDATA;
292 }
293
294 SCIP_CALL( SCIPbanditCreateEpsgreedy(SCIPblkmem(scip), SCIPbuffer(scip), vtable, epsgreedy,
295 priorities, eps, preferrecent, decayfactor, avglim, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
296
297 return SCIP_OKAY;
298 }
299
300 /** get weights array of epsilon greedy bandit algorithm */
SCIPgetWeightsEpsgreedy(SCIP_BANDIT * epsgreedy)301 SCIP_Real* SCIPgetWeightsEpsgreedy(
302 SCIP_BANDIT* epsgreedy /**< epsilon greedy bandit algorithm */
303 )
304 {
305 SCIP_BANDITDATA* banditdata;
306 assert(epsgreedy != NULL);
307 banditdata = SCIPbanditGetData(epsgreedy);
308 assert(banditdata != NULL);
309
310 return banditdata->weights;
311 }
312
313 /** set epsilon parameter of epsilon greedy bandit algorithm */
SCIPsetEpsilonEpsgreedy(SCIP_BANDIT * epsgreedy,SCIP_Real eps)314 void SCIPsetEpsilonEpsgreedy(
315 SCIP_BANDIT* epsgreedy, /**< epsilon greedy bandit algorithm */
316 SCIP_Real eps /**< parameter to increase probability for exploration between all actions */
317 )
318 {
319 SCIP_BANDITDATA* banditdata;
320 assert(epsgreedy != NULL);
321 assert(eps >= 0);
322
323 banditdata = SCIPbanditGetData(epsgreedy);
324
325 banditdata->eps = eps;
326 }
327
328
329 /** creates the epsilon greedy bandit algorithm includes it in SCIP */
SCIPincludeBanditvtableEpsgreedy(SCIP * scip)330 SCIP_RETCODE SCIPincludeBanditvtableEpsgreedy(
331 SCIP* scip /**< SCIP data structure */
332 )
333 {
334 SCIP_BANDITVTABLE* banditvtable;
335
336 SCIP_CALL( SCIPincludeBanditvtable(scip, &banditvtable, BANDIT_NAME,
337 SCIPbanditFreeEpsgreedy, SCIPbanditSelectEpsgreedy, SCIPbanditUpdateEpsgreedy, SCIPbanditResetEpsgreedy) );
338
339 return SCIP_OKAY;
340 }
341