1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /*                                                                           */
3 /*                  This file is part of the program and library             */
4 /*         SCIP --- Solving Constraint Integer Programs                      */
5 /*                                                                           */
6 /*    Copyright (C) 2002-2021 Konrad-Zuse-Zentrum                            */
7 /*                            fuer Informationstechnik Berlin                */
8 /*                                                                           */
9 /*  SCIP is distributed under the terms of the ZIB Academic License.         */
10 /*                                                                           */
11 /*  You should have received a copy of the ZIB Academic License              */
12 /*  along with SCIP; see the file COPYING. If not visit scipopt.org.         */
13 /*                                                                           */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file   bandit_exp3.c
17  * @ingroup OTHER_CFILES
18  * @brief  methods for Exp.3 bandit selection
19  * @author Gregor Hendel
20  */
21 
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23 
24 #include "scip/bandit.h"
25 #include "scip/bandit_exp3.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/scip_bandit.h"
30 #include "scip/scip_mem.h"
31 #include "scip/scip_randnumgen.h"
32 
#define BANDIT_NAME "exp3"                   /**< name under which the Exp.3 virtual function table is registered and looked up */
#define NUMTOL 1e-6                          /**< small tolerance: added to every initial weight and used to test whether beta is zero */

/*
 * Data structures
 */

/** implementation specific data of Exp.3 bandit algorithm
 *
 *  Action i is played with probability (1 - gamma) * weights[i] / weightsum + gamma / nactions,
 *  i.e., a convex combination of the weight-driven and the uniform distribution.
 */
struct SCIP_BanditData
{
   SCIP_Real*            weights;            /**< exponential weight for each arm (one entry per action) */
   SCIP_Real             weightsum;          /**< the sum of all weights; set by the reset callback and maintained by the update callback */
   SCIP_Real             gamma;              /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
   SCIP_Real             beta;               /**< gain offset between 0 and 1 at every observation */
};
48 
49 /*
50  * Local methods
51  */
52 
53 /*
54  * Callback methods of bandit algorithm
55  */
56 
57 /** callback to free bandit specific data structures */
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)58 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
59 {  /*lint --e{715}*/
60    SCIP_BANDITDATA* banditdata;
61    int nactions;
62    assert(bandit != NULL);
63 
64    banditdata = SCIPbanditGetData(bandit);
65    assert(banditdata != NULL);
66    nactions = SCIPbanditGetNActions(bandit);
67 
68    BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
69 
70    BMSfreeBlockMemory(blkmem, &banditdata);
71 
72    SCIPbanditSetData(bandit, NULL);
73 
74    return SCIP_OKAY;
75 }
76 
77 /** selection callback for bandit selector */
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)78 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
79 {  /*lint --e{715}*/
80    SCIP_BANDITDATA* banditdata;
81    SCIP_RANDNUMGEN* rng;
82    SCIP_Real randnr;
83    SCIP_Real psum;
84    SCIP_Real gammaoverk;
85    SCIP_Real oneminusgamma;
86    SCIP_Real* weights;
87    SCIP_Real weightsum;
88    int i;
89    int nactions;
90 
91    assert(bandit != NULL);
92    assert(selection != NULL);
93 
94    banditdata = SCIPbanditGetData(bandit);
95    assert(banditdata != NULL);
96    rng = SCIPbanditGetRandnumgen(bandit);
97    assert(rng != NULL);
98    nactions = SCIPbanditGetNActions(bandit);
99 
100    /* draw a random number between 0 and 1 */
101    randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
102 
103    /* initialize some local variables to speed up probability computations */
104    oneminusgamma = 1 - banditdata->gamma;
105    gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
106    weightsum = banditdata->weightsum;
107    weights = banditdata->weights;
108    psum = 0.0;
109 
110    /* loop over probability distribution until rand is reached
111     * the loop terminates without looking at the last action,
112     * which is then selected automatically if the target probability
113     * is not reached earlier
114     */
115    for( i = 0; i < nactions - 1; ++i )
116    {
117       SCIP_Real prob;
118 
119       /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
120       prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
121       psum += prob;
122 
123       /* break and select element if target probability is reached */
124       if( randnr <= psum )
125          break;
126    }
127 
128    /* select element i, which is the last action in case that the break statement hasn't been reached */
129    *selection = i;
130 
131    return SCIP_OKAY;
132 }
133 
134 /** update callback for bandit algorithm */
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)135 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
136 {  /*lint --e{715}*/
137    SCIP_BANDITDATA* banditdata;
138    SCIP_Real eta;
139    SCIP_Real gainestim;
140    SCIP_Real beta;
141    SCIP_Real weightsum;
142    SCIP_Real newweightsum;
143    SCIP_Real* weights;
144    SCIP_Real oneminusgamma;
145    SCIP_Real gammaoverk;
146    int nactions;
147 
148    assert(bandit != NULL);
149 
150    banditdata = SCIPbanditGetData(bandit);
151    assert(banditdata != NULL);
152    nactions = SCIPbanditGetNActions(bandit);
153 
154    assert(selection >= 0);
155    assert(selection < nactions);
156 
157    /* the learning rate eta */
158    eta = 1.0 / (SCIP_Real)nactions;
159 
160    beta = banditdata->beta;
161    oneminusgamma = 1.0 - banditdata->gamma;
162    gammaoverk = banditdata->gamma * eta;
163    weights = banditdata->weights;
164    weightsum = banditdata->weightsum;
165    newweightsum = weightsum;
166 
167    /* if beta is zero, only the observation for the current arm needs an update */
168    if( EPSZ(beta, NUMTOL) )
169    {
170       SCIP_Real probai;
171       probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
172 
173       assert(probai > 0.0);
174 
175       gainestim = score / probai;
176       newweightsum -= weights[selection];
177       weights[selection] *= exp(eta * gainestim);
178       newweightsum += weights[selection];
179    }
180    else
181    {
182       int j;
183       newweightsum = 0.0;
184 
185       /* loop over all items and update their weights based on the influence of the beta parameter */
186       for( j = 0; j < nactions; ++j )
187       {
188          SCIP_Real probaj;
189          probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
190 
191          assert(probaj > 0.0);
192 
193          /* consider the score only for the chosen arm i, use constant beta offset otherwise */
194          if( j == selection )
195             gainestim = (score + beta) / probaj;
196          else
197             gainestim = beta / probaj;
198 
199          weights[j] *= exp(eta * gainestim);
200          newweightsum += weights[j];
201       }
202    }
203 
204    banditdata->weightsum = newweightsum;
205 
206    return SCIP_OKAY;
207 }
208 
209 /** reset callback for bandit algorithm */
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)210 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
211 {  /*lint --e{715}*/
212    SCIP_BANDITDATA* banditdata;
213    SCIP_Real* weights;
214    int nactions;
215    int i;
216 
217    assert(bandit != NULL);
218 
219    banditdata = SCIPbanditGetData(bandit);
220    assert(banditdata != NULL);
221    nactions = SCIPbanditGetNActions(bandit);
222    weights = banditdata->weights;
223 
224    assert(nactions > 0);
225 
226    banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
227 
228    /* in case of priorities, weights are normalized to sum up to nactions */
229    if( priorities != NULL )
230    {
231       SCIP_Real normalization;
232       SCIP_Real priosum;
233       priosum = 0.0;
234 
235       /* compute sum of priorities */
236       for( i = 0; i < nactions; ++i )
237       {
238          assert(priorities[i] >= 0);
239          priosum += priorities[i];
240       }
241 
242       /* if there are positive priorities, normalize the weights */
243       if( priosum > 0.0 )
244       {
245          normalization = nactions / priosum;
246          for( i = 0; i < nactions; ++i )
247             weights[i] = (priorities[i] * normalization) + NUMTOL;
248       }
249       else
250       {
251          /* use uniform distribution in case of all priorities being 0.0 */
252          for( i = 0; i < nactions; ++i )
253             weights[i] = 1.0 + NUMTOL;
254       }
255    }
256    else
257    {
258       /* use uniform distribution in case of unspecified priorities */
259       for( i = 0; i < nactions; ++i )
260          weights[i] = 1.0 + NUMTOL;
261    }
262 
263    return SCIP_OKAY;
264 }
265 
266 
267 /*
268  * bandit algorithm specific interface methods
269  */
270 
271 /** direct bandit creation method for the core where no SCIP pointer is available */
SCIPbanditCreateExp3(BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_BANDITVTABLE * vtable,SCIP_BANDIT ** exp3,SCIP_Real * priorities,SCIP_Real gammaparam,SCIP_Real beta,int nactions,unsigned int initseed)272 SCIP_RETCODE SCIPbanditCreateExp3(
273    BMS_BLKMEM*           blkmem,             /**< block memory data structure */
274    BMS_BUFMEM*           bufmem,             /**< buffer memory */
275    SCIP_BANDITVTABLE*    vtable,             /**< virtual function table for callback functions of Exp.3 */
276    SCIP_BANDIT**         exp3,               /**< pointer to store bandit algorithm */
277    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
278    SCIP_Real             gammaparam,         /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
279    SCIP_Real             beta,               /**< gain offset between 0 and 1 at every observation */
280    int                   nactions,           /**< the positive number of actions for this bandit algorithm */
281    unsigned int          initseed            /**< initial random seed */
282    )
283 {
284    SCIP_BANDITDATA* banditdata;
285 
286    SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
287    assert(banditdata != NULL);
288 
289    banditdata->gamma = gammaparam;
290    banditdata->beta = beta;
291    assert(gammaparam >= 0 && gammaparam <= 1);
292    assert(beta >= 0 && beta <= 1);
293 
294    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
295 
296    SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
297 
298    return SCIP_OKAY;
299 }
300 
301 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
SCIPcreateBanditExp3(SCIP * scip,SCIP_BANDIT ** exp3,SCIP_Real * priorities,SCIP_Real gammaparam,SCIP_Real beta,int nactions,unsigned int initseed)302 SCIP_RETCODE SCIPcreateBanditExp3(
303    SCIP*                 scip,               /**< SCIP data structure */
304    SCIP_BANDIT**         exp3,               /**< pointer to store bandit algorithm */
305    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
306    SCIP_Real             gammaparam,         /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
307    SCIP_Real             beta,               /**< gain offset between 0 and 1 at every observation */
308    int                   nactions,           /**< the positive number of actions for this bandit algorithm */
309    unsigned int          initseed            /**< initial seed for random number generation */
310    )
311 {
312    SCIP_BANDITVTABLE* vtable;
313 
314    vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
315    if( vtable == NULL )
316    {
317       SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
318       return SCIP_INVALIDDATA;
319    }
320 
321    SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
322          priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
323 
324    return SCIP_OKAY;
325 }
326 
327 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
SCIPsetGammaExp3(SCIP_BANDIT * exp3,SCIP_Real gammaparam)328 void SCIPsetGammaExp3(
329    SCIP_BANDIT*          exp3,               /**< bandit algorithm */
330    SCIP_Real             gammaparam          /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
331    )
332 {
333    SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
334 
335    assert(gammaparam >= 0 && gammaparam <= 1);
336 
337    banditdata->gamma = gammaparam;
338 }
339 
340 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
SCIPsetBetaExp3(SCIP_BANDIT * exp3,SCIP_Real beta)341 void SCIPsetBetaExp3(
342    SCIP_BANDIT*          exp3,               /**< bandit algorithm */
343    SCIP_Real             beta                /**< gain offset between 0 and 1 at every observation */
344    )
345 {
346    SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
347 
348    assert(beta >= 0 && beta <= 1);
349 
350    banditdata->beta = beta;
351 }
352 
353 /** returns probability to play an action */
SCIPgetProbabilityExp3(SCIP_BANDIT * exp3,int action)354 SCIP_Real SCIPgetProbabilityExp3(
355    SCIP_BANDIT*          exp3,               /**< bandit algorithm */
356    int                   action              /**< index of the requested action */
357    )
358 {
359    SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
360 
361    assert(banditdata->weightsum > 0.0);
362    assert(SCIPbanditGetNActions(exp3) > 0);
363 
364    return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
365 }
366 
367 /** include virtual function table for Exp.3 bandit algorithms */
SCIPincludeBanditvtableExp3(SCIP * scip)368 SCIP_RETCODE SCIPincludeBanditvtableExp3(
369    SCIP*                 scip                /**< SCIP data structure */
370    )
371 {
372    SCIP_BANDITVTABLE* vtable;
373 
374    SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
375          SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
376    assert(vtable != NULL);
377 
378    return SCIP_OKAY;
379 }
380