1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /*                                                                           */
3 /*                  This file is part of the program and library             */
4 /*         SCIP --- Solving Constraint Integer Programs                      */
5 /*                                                                           */
6 /*    Copyright (C) 2002-2021 Konrad-Zuse-Zentrum                            */
7 /*                            fuer Informationstechnik Berlin                */
8 /*                                                                           */
9 /*  SCIP is distributed under the terms of the ZIB Academic License.         */
10 /*                                                                           */
11 /*  You should have received a copy of the ZIB Academic License              */
12 /*  along with SCIP; see the file COPYING. If not visit scipopt.org.         */
13 /*                                                                           */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file   bandit_epsgreedy.c
17  * @ingroup OTHER_CFILES
18  * @brief  implementation of epsilon greedy bandit algorithm
19  * @author Gregor Hendel
20  */
21 
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23 
24 #include "scip/bandit.h"
25 #include "scip/bandit_epsgreedy.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/scip_bandit.h"
30 #include "scip/scip_mem.h"
31 #include "scip/scip_randnumgen.h"
32 
/** name under which the epsilon greedy virtual function table is included in SCIP and looked up again */
#define BANDIT_NAME           "eps-greedy"

/** bound on the random perturbation added to user priorities, and tolerance within which two weights count as tied */
#define EPSGREEDY_SMALL       1e-6
35 
36 /*
37  * Data structures
38  */
39 
40 /** private data structure of epsilon greedy bandit algorithm */
41 struct SCIP_BanditData
42 {
43    SCIP_Real*            weights;            /**< weights for every action */
44    SCIP_Real*            priorities;         /**< saved priorities for tie breaking */
45    int*                  sels;               /**< individual number of selections per action */
46    SCIP_Real             eps;                /**< epsilon parameter (between 0 and 1) to control epsilon greedy */
47    SCIP_Real             decayfactor;        /**< the factor to reduce the weight of older observations if exponential decay is enabled */
48    int                   avglim;             /**< nonnegative limit on observation number before the exponential decay starts,
49                                                *  only relevant if exponential decay is enabled
50                                                */
51    int                   nselections;        /**< counter for the number of selection calls */
52    SCIP_Bool             preferrecent;       /**< should the weights be updated in an exponentially decaying way? */
53 };
54 
55 /*
56  * Callback methods of bandit algorithm virtual function table
57  */
58 
59 /** callback to free bandit specific data structures */
SCIP_DECL_BANDITFREE(SCIPbanditFreeEpsgreedy)60 SCIP_DECL_BANDITFREE(SCIPbanditFreeEpsgreedy)
61 {  /*lint --e{715}*/
62    SCIP_BANDITDATA* banditdata;
63    int nactions;
64 
65    assert(bandit != NULL);
66 
67    banditdata = SCIPbanditGetData(bandit);
68    assert(banditdata != NULL);
69    assert(banditdata->weights != NULL);
70    nactions = SCIPbanditGetNActions(bandit);
71 
72    BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
73    BMSfreeBlockMemoryArray(blkmem, &banditdata->priorities, nactions);
74    BMSfreeBlockMemoryArray(blkmem, &banditdata->sels, nactions);
75    BMSfreeBlockMemory(blkmem, &banditdata);
76 
77    SCIPbanditSetData(bandit, NULL);
78 
79    return SCIP_OKAY;
80 }
81 
82 /** selection callback for bandit algorithm */
SCIP_DECL_BANDITSELECT(SCIPbanditSelectEpsgreedy)83 SCIP_DECL_BANDITSELECT(SCIPbanditSelectEpsgreedy)
84 {  /*lint --e{715}*/
85    SCIP_BANDITDATA* banditdata;
86    SCIP_Real randnr;
87    SCIP_Real curreps;
88    SCIP_RANDNUMGEN* rng;
89    int nactions;
90    assert(bandit != NULL);
91    assert(selection != NULL);
92 
93    banditdata = SCIPbanditGetData(bandit);
94    assert(banditdata != NULL);
95    rng = SCIPbanditGetRandnumgen(bandit);
96    assert(rng != NULL);
97 
98    nactions = SCIPbanditGetNActions(bandit);
99 
100    /* roll the dice to check if the best element should be picked, or an element at random */
101    randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
102 
103    /* make epsilon decrease with an increasing number of selections */
104    banditdata->nselections++;
105    curreps = banditdata->eps * sqrt((SCIP_Real)nactions/(SCIP_Real)banditdata->nselections);
106 
107    /* select the best action seen so far */
108    if( randnr >= curreps )
109    {
110       SCIP_Real* weights = banditdata->weights;
111       SCIP_Real* priorities = banditdata->priorities;
112       int j;
113       SCIP_Real maxweight;
114 
115       assert(weights != NULL);
116       assert(priorities != NULL);
117 
118       /* pick the element with the largest reward */
119       maxweight = weights[0];
120       *selection = 0;
121 
122       /* determine reward for every element */
123       for( j = 1; j < nactions; ++j )
124       {
125          SCIP_Real weight = weights[j];
126 
127          /* select the action that maximizes the reward, breaking ties by action priorities */
128          if( maxweight < weight
129                || (weight >= maxweight - EPSGREEDY_SMALL && priorities[j] > priorities[*selection] ) )
130          {
131             *selection = j;
132             maxweight = weight;
133          }
134       }
135    }
136    else
137    {
138       /* play one of the actions at random */
139       *selection = SCIPrandomGetInt(rng, 0, nactions - 1);
140    }
141 
142    return SCIP_OKAY;
143 }
144 
145 /** update callback for bandit algorithm */
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateEpsgreedy)146 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateEpsgreedy)
147 {  /*lint --e{715}*/
148    SCIP_BANDITDATA* banditdata;
149 
150    assert(bandit != NULL);
151 
152    banditdata = SCIPbanditGetData(bandit);
153    assert(banditdata != NULL);
154 
155    /* increase the selection count */
156    ++banditdata->sels[selection];
157 
158    /* the very first observation is directly stored as weight for both average or exponential decay */
159    if( banditdata->sels[selection] == 1 )
160       banditdata->weights[selection] = score;
161    else
162    {
163       /* use exponentially decreasing weights for older observations */
164       if( banditdata->preferrecent && banditdata->sels[selection] > banditdata->avglim )
165       {
166          /* decrease old weights by decay factor */
167          banditdata->weights[selection] *= banditdata->decayfactor;
168          banditdata->weights[selection] += (1.0 - banditdata->decayfactor) * score;
169       }
170       else
171       {
172          /* update average score */
173          SCIP_Real diff = score - banditdata->weights[selection];
174          banditdata->weights[selection] += diff / (SCIP_Real)(banditdata->sels[selection]);
175       }
176    }
177 
178    return SCIP_OKAY;
179 }
180 
181 /** reset callback for bandit algorithm */
SCIP_DECL_BANDITRESET(SCIPbanditResetEpsgreedy)182 SCIP_DECL_BANDITRESET(SCIPbanditResetEpsgreedy)
183 {  /*lint --e{715}*/
184    SCIP_BANDITDATA* banditdata;
185    SCIP_Real* weights;
186    int w;
187    int nactions;
188    SCIP_RANDNUMGEN* rng;
189 
190    assert(bandit != NULL);
191 
192    banditdata = SCIPbanditGetData(bandit);
193    assert(banditdata != NULL);
194 
195    weights = banditdata->weights;
196    nactions = SCIPbanditGetNActions(bandit);
197    assert(weights != NULL);
198    assert(banditdata->priorities != NULL);
199    assert(nactions > 0);
200 
201    rng = SCIPbanditGetRandnumgen(bandit);
202    assert(rng != NULL);
203 
204    /* alter priorities slightly to make them unique */
205    if( priorities != NULL )
206    {
207       for( w = 1; w < nactions; ++w )
208       {
209          assert(priorities[w] >= 0);
210          banditdata->priorities[w] = priorities[w] + SCIPrandomGetReal(rng, -EPSGREEDY_SMALL, EPSGREEDY_SMALL);
211       }
212    }
213    else
214    {
215       /* use random priorities */
216       for( w = 0; w < nactions; ++w )
217          banditdata->priorities[w] = SCIPrandomGetReal(rng, 0.0, 1.0);
218    }
219 
220    /* reset weights and selection counters to 0 */
221    BMSclearMemoryArray(weights, nactions);
222    BMSclearMemoryArray(banditdata->sels, nactions);
223 
224    banditdata->nselections = 0;
225 
226    return SCIP_OKAY;
227 }
228 
229 /*
230  * interface methods of the Epsilon Greedy bandit algorithm
231  */
232 
233 /** internal method to create and reset epsilon greedy bandit algorithm */
SCIPbanditCreateEpsgreedy(BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_BANDITVTABLE * vtable,SCIP_BANDIT ** epsgreedy,SCIP_Real * priorities,SCIP_Real eps,SCIP_Bool preferrecent,SCIP_Real decayfactor,int avglim,int nactions,unsigned int initseed)234 SCIP_RETCODE SCIPbanditCreateEpsgreedy(
235    BMS_BLKMEM*           blkmem,             /**< block memory */
236    BMS_BUFMEM*           bufmem,             /**< buffer memory */
237    SCIP_BANDITVTABLE*    vtable,             /**< virtual function table with epsilon greedy callbacks */
238    SCIP_BANDIT**         epsgreedy,          /**< pointer to store the epsilon greedy bandit algorithm */
239    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
240    SCIP_Real             eps,                /**< parameter to increase probability for exploration between all actions */
241    SCIP_Bool             preferrecent,       /**< should the weights be updated in an exponentially decaying way? */
242    SCIP_Real             decayfactor,        /**< the factor to reduce the weight of older observations if exponential decay is enabled */
243    int                   avglim,             /**< nonnegative limit on observation number before the exponential decay starts,
244                                               *   only relevant if exponential decay is enabled */
245    int                   nactions,           /**< the positive number of possible actions */
246    unsigned int          initseed            /**< initial random seed */
247    )
248 {
249    SCIP_BANDITDATA* banditdata;
250 
251    SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
252    assert(banditdata != NULL);
253    assert(eps >= 0.0);
254 
255    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
256    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->priorities, nactions) );
257    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->sels, nactions) );
258    banditdata->eps = eps;
259    banditdata->nselections = 0;
260    banditdata->preferrecent = preferrecent;
261    banditdata->decayfactor = decayfactor;
262    banditdata->avglim = avglim;
263 
264    SCIP_CALL( SCIPbanditCreate(epsgreedy, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
265 
266    return SCIP_OKAY;
267 }
268 
269 /** create and resets an epsilon greedy bandit algorithm */
SCIPcreateBanditEpsgreedy(SCIP * scip,SCIP_BANDIT ** epsgreedy,SCIP_Real * priorities,SCIP_Real eps,SCIP_Bool preferrecent,SCIP_Real decayfactor,int avglim,int nactions,unsigned int initseed)270 SCIP_RETCODE SCIPcreateBanditEpsgreedy(
271    SCIP*                 scip,               /**< SCIP data structure */
272    SCIP_BANDIT**         epsgreedy,          /**< pointer to store the epsilon greedy bandit algorithm */
273    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
274    SCIP_Real             eps,                /**< parameter to increase probability for exploration between all actions */
275    SCIP_Bool             preferrecent,       /**< should the weights be updated in an exponentially decaying way? */
276    SCIP_Real             decayfactor,        /**< the factor to reduce the weight of older observations if exponential decay is enabled */
277    int                   avglim,             /**< nonnegative limit on observation number before the exponential decay starts,
278                                               *   only relevant if exponential decay is enabled */
279    int                   nactions,           /**< the positive number of possible actions */
280    unsigned int          initseed            /**< initial seed for random number generation */
281    )
282 {
283    SCIP_BANDITVTABLE* vtable;
284    assert(scip != NULL);
285    assert(epsgreedy != NULL);
286 
287    vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
288    if( vtable == NULL )
289    {
290       SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
291       return SCIP_INVALIDDATA;
292    }
293 
294    SCIP_CALL( SCIPbanditCreateEpsgreedy(SCIPblkmem(scip), SCIPbuffer(scip), vtable, epsgreedy,
295          priorities, eps, preferrecent, decayfactor, avglim, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
296 
297    return SCIP_OKAY;
298 }
299 
300 /** get weights array of epsilon greedy bandit algorithm */
SCIPgetWeightsEpsgreedy(SCIP_BANDIT * epsgreedy)301 SCIP_Real* SCIPgetWeightsEpsgreedy(
302    SCIP_BANDIT*          epsgreedy           /**< epsilon greedy bandit algorithm */
303    )
304 {
305    SCIP_BANDITDATA* banditdata;
306    assert(epsgreedy != NULL);
307    banditdata = SCIPbanditGetData(epsgreedy);
308    assert(banditdata != NULL);
309 
310    return banditdata->weights;
311 }
312 
313 /** set epsilon parameter of epsilon greedy bandit algorithm */
SCIPsetEpsilonEpsgreedy(SCIP_BANDIT * epsgreedy,SCIP_Real eps)314 void SCIPsetEpsilonEpsgreedy(
315    SCIP_BANDIT*          epsgreedy,          /**< epsilon greedy bandit algorithm */
316    SCIP_Real             eps                 /**< parameter to increase probability for exploration between all actions */
317    )
318 {
319    SCIP_BANDITDATA* banditdata;
320    assert(epsgreedy != NULL);
321    assert(eps >= 0);
322 
323    banditdata = SCIPbanditGetData(epsgreedy);
324 
325    banditdata->eps = eps;
326 }
327 
328 
329 /** creates the epsilon greedy bandit algorithm includes it in SCIP */
SCIPincludeBanditvtableEpsgreedy(SCIP * scip)330 SCIP_RETCODE SCIPincludeBanditvtableEpsgreedy(
331    SCIP*                 scip                /**< SCIP data structure */
332    )
333 {
334    SCIP_BANDITVTABLE* banditvtable;
335 
336    SCIP_CALL( SCIPincludeBanditvtable(scip, &banditvtable, BANDIT_NAME,
337          SCIPbanditFreeEpsgreedy, SCIPbanditSelectEpsgreedy, SCIPbanditUpdateEpsgreedy, SCIPbanditResetEpsgreedy) );
338 
339    return SCIP_OKAY;
340 }
341