1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /*                                                                           */
3 /*                  This file is part of the program and library             */
4 /*         SCIP --- Solving Constraint Integer Programs                      */
5 /*                                                                           */
6 /*    Copyright (C) 2002-2021 Konrad-Zuse-Zentrum                            */
7 /*                            fuer Informationstechnik Berlin                */
8 /*                                                                           */
9 /*  SCIP is distributed under the terms of the ZIB Academic License.         */
10 /*                                                                           */
11 /*  You should have received a copy of the ZIB Academic License              */
12 /*  along with SCIP; see the file COPYING. If not visit scipopt.org.         */
13 /*                                                                           */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file   bandit_ucb.c
17  * @ingroup OTHER_CFILES
18  * @brief  methods for UCB bandit selection
19  * @author Gregor Hendel
20  */
21 
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23 
24 #include "scip/bandit.h"
25 #include "scip/bandit_ucb.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/pub_misc_sort.h"
30 #include "scip/scip_bandit.h"
31 #include "scip/scip_mem.h"
32 #include "scip/scip_randnumgen.h"
33 
34 
/** name under which the UCB virtual function table is included and looked up */
#define BANDIT_NAME "ucb"
/** tolerance for comparing upper confidence bounds; also the half-width of the random perturbation of priorities */
#define NUMEPS 1e-6
37 
38 /*
39  * Data structures
40  */
41 
/** implementation specific data of UCB bandit algorithm
 *
 *  The three arrays are allocated with one entry per action.
 */
struct SCIP_BanditData
{
   int                   nselections;        /**< total number of updates recorded since the last reset */
   int*                  counter;            /**< array of counters how often every action has been updated */
   int*                  startperm;          /**< order in which the actions are selected as long as nselections is
                                              *   smaller than the number of actions */
   SCIP_Real*            meanscores;         /**< array of running average scores for the actions */
   SCIP_Real             alpha;              /**< nonnegative parameter to increase confidence width */
};
51 
52 
53 /*
54  * Local methods
55  */
56 
57 /** data reset method */
58 static
dataReset(BMS_BUFMEM * bufmem,SCIP_BANDIT * ucb,SCIP_BANDITDATA * banditdata,SCIP_Real * priorities,int nactions)59 SCIP_RETCODE dataReset(
60    BMS_BUFMEM*           bufmem,             /**< buffer memory */
61    SCIP_BANDIT*          ucb,                /**< ucb bandit algorithm */
62    SCIP_BANDITDATA*      banditdata,         /**< UCB bandit data structure */
63    SCIP_Real*            priorities,         /**< priorities for start permutation, or NULL */
64    int                   nactions            /**< number of actions */
65    )
66 {
67    int i;
68    SCIP_RANDNUMGEN* rng;
69 
70    assert(bufmem != NULL);
71    assert(ucb != NULL);
72    assert(nactions > 0);
73 
74    /* clear counters and scores */
75    BMSclearMemoryArray(banditdata->counter, nactions);
76    BMSclearMemoryArray(banditdata->meanscores, nactions);
77    banditdata->nselections = 0;
78 
79    rng = SCIPbanditGetRandnumgen(ucb);
80    assert(rng != NULL);
81 
82    /* initialize start permutation as identity */
83    for( i = 0; i < nactions; ++i )
84       banditdata->startperm[i] = i;
85 
86    /* prepare the start permutation in decreasing order of priority */
87    if( priorities != NULL )
88    {
89       SCIP_Real* prioritycopy;
90 
91       SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
92 
93       /* randomly wiggle priorities a little bit to make them unique */
94       for( i = 0; i < nactions; ++i )
95          prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
96 
97       SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
98 
99       BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
100    }
101    else
102    {
103       /* use a random start permutation */
104       SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
105    }
106 
107    return SCIP_OKAY;
108 }
109 
110 
111 /*
112  * Callback methods of bandit algorithm
113  */
114 
115 /** callback to free bandit specific data structures */
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)116 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
117 {  /*lint --e{715}*/
118    SCIP_BANDITDATA* banditdata;
119    int nactions;
120    assert(bandit != NULL);
121 
122    banditdata = SCIPbanditGetData(bandit);
123    assert(banditdata != NULL);
124    nactions = SCIPbanditGetNActions(bandit);
125 
126    BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
127    BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
128    BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
129    BMSfreeBlockMemory(blkmem, &banditdata);
130 
131    SCIPbanditSetData(bandit, NULL);
132 
133    return SCIP_OKAY;
134 }
135 
136 /** selection callback for bandit selector */
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)137 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
138 {  /*lint --e{715}*/
139    SCIP_BANDITDATA* banditdata;
140    int nactions;
141    int* counter;
142 
143    assert(bandit != NULL);
144    assert(selection != NULL);
145 
146    banditdata = SCIPbanditGetData(bandit);
147    assert(banditdata != NULL);
148    nactions = SCIPbanditGetNActions(bandit);
149 
150    counter = banditdata->counter;
151    /* select the next uninitialized action from the start permutation */
152    if( banditdata->nselections < nactions )
153    {
154       *selection = banditdata->startperm[banditdata->nselections];
155       assert(counter[*selection] == 0);
156    }
157    else
158    {
159       /* select the action with the highest upper confidence bound */
160       SCIP_Real* meanscores;
161       SCIP_Real widthfactor;
162       SCIP_Real maxucb;
163       int i;
164       SCIP_RANDNUMGEN* rng = SCIPbanditGetRandnumgen(bandit);
165       meanscores = banditdata->meanscores;
166 
167       assert(rng != NULL);
168       assert(meanscores != NULL);
169 
170       /* compute the confidence width factor that is common for all actions */
171       /* cppcheck-suppress unpreciseMathCall */
172       widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
173       widthfactor = sqrt(widthfactor);
174       maxucb = -1.0;
175 
176       /* loop over the actions and determine the maximum upper confidence bound.
177        * The upper confidence bound of an action is the sum of its mean score
178        * plus a confidence term that decreases with increasing number of observations of
179        * this action.
180        */
181       for( i = 0; i < nactions; ++i )
182       {
183          SCIP_Real uppercb;
184          SCIP_Real rootcount;
185          assert(counter[i] > 0);
186 
187          /* compute the upper confidence bound for action i */
188          uppercb = meanscores[i];
189          rootcount = sqrt((SCIP_Real)counter[i]);
190          uppercb += widthfactor / rootcount;
191          assert(uppercb > 0);
192 
193          /* update maximum, breaking ties uniformly at random */
194          if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
195          {
196             maxucb = uppercb;
197             *selection = i;
198          }
199       }
200    }
201 
202    assert(*selection >= 0);
203    assert(*selection < nactions);
204 
205    return SCIP_OKAY;
206 }
207 
208 /** update callback for bandit algorithm */
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)209 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
210 {  /*lint --e{715}*/
211    SCIP_BANDITDATA* banditdata;
212    SCIP_Real delta;
213 
214    assert(bandit != NULL);
215 
216    banditdata = SCIPbanditGetData(bandit);
217    assert(banditdata != NULL);
218    assert(selection >= 0);
219    assert(selection < SCIPbanditGetNActions(bandit));
220 
221    /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
222    delta = score - banditdata->meanscores[selection];
223    ++banditdata->counter[selection];
224    banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
225 
226    banditdata->nselections++;
227 
228    return SCIP_OKAY;
229 }
230 
231 /** reset callback for bandit algorithm */
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)232 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
233 {  /*lint --e{715}*/
234    SCIP_BANDITDATA* banditdata;
235    int nactions;
236 
237    assert(bufmem != NULL);
238    assert(bandit != NULL);
239 
240    banditdata = SCIPbanditGetData(bandit);
241    assert(banditdata != NULL);
242    nactions = SCIPbanditGetNActions(bandit);
243 
244    /* call the data reset for the given priorities */
245    SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
246 
247    return SCIP_OKAY;
248 }
249 
250 /*
251  * bandit algorithm specific interface methods
252  */
253 
254 /** returns the upper confidence bound of a selected action */
SCIPgetConfidenceBoundUcb(SCIP_BANDIT * ucb,int action)255 SCIP_Real SCIPgetConfidenceBoundUcb(
256    SCIP_BANDIT*          ucb,                /**< UCB bandit algorithm */
257    int                   action              /**< index of the queried action */
258    )
259 {
260    SCIP_Real uppercb;
261    SCIP_BANDITDATA* banditdata;
262    int nactions;
263 
264    assert(ucb != NULL);
265    banditdata = SCIPbanditGetData(ucb);
266    nactions = SCIPbanditGetNActions(ucb);
267    assert(action < nactions);
268 
269    /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
270    if( banditdata->nselections < nactions )
271       return 1.0;
272 
273    /* the bandit algorithm must have picked every action once */
274    assert(banditdata->counter[action] > 0);
275    uppercb = banditdata->meanscores[action];
276 
277    /* cppcheck-suppress unpreciseMathCall */
278    uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
279 
280    return uppercb;
281 }
282 
283 /** return start permutation of the UCB bandit algorithm */
SCIPgetStartPermutationUcb(SCIP_BANDIT * ucb)284 int* SCIPgetStartPermutationUcb(
285    SCIP_BANDIT*          ucb                 /**< UCB bandit algorithm */
286    )
287 {
288    SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
289 
290    assert(banditdata != NULL);
291 
292    return banditdata->startperm;
293 }
294 
295 /** internal method to create and reset UCB bandit algorithm */
SCIPbanditCreateUcb(BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_BANDITVTABLE * vtable,SCIP_BANDIT ** ucb,SCIP_Real * priorities,SCIP_Real alpha,int nactions,unsigned int initseed)296 SCIP_RETCODE SCIPbanditCreateUcb(
297    BMS_BLKMEM*           blkmem,             /**< block memory */
298    BMS_BUFMEM*           bufmem,             /**< buffer memory */
299    SCIP_BANDITVTABLE*    vtable,             /**< virtual function table for UCB bandit algorithm */
300    SCIP_BANDIT**         ucb,                /**< pointer to store bandit algorithm */
301    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
302    SCIP_Real             alpha,              /**< parameter to increase confidence width */
303    int                   nactions,           /**< the positive number of actions for this bandit algorithm */
304    unsigned int          initseed            /**< initial random seed */
305    )
306 {
307    SCIP_BANDITDATA* banditdata;
308 
309    if( alpha < 0.0 )
310    {
311       SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
312       return SCIP_INVALIDDATA;
313    }
314 
315    SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
316    assert(banditdata != NULL);
317 
318    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
319    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
320    SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
321 
322    banditdata->alpha = alpha;
323 
324    SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
325 
326    return SCIP_OKAY;
327 }
328 
329 /** create and reset UCB bandit algorithm */
SCIPcreateBanditUcb(SCIP * scip,SCIP_BANDIT ** ucb,SCIP_Real * priorities,SCIP_Real alpha,int nactions,unsigned int initseed)330 SCIP_RETCODE SCIPcreateBanditUcb(
331    SCIP*                 scip,               /**< SCIP data structure */
332    SCIP_BANDIT**         ucb,                /**< pointer to store bandit algorithm */
333    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
334    SCIP_Real             alpha,              /**< parameter to increase confidence width */
335    int                   nactions,           /**< the positive number of actions for this bandit algorithm */
336    unsigned int          initseed            /**< initial random number seed */
337    )
338 {
339    SCIP_BANDITVTABLE* vtable;
340 
341    vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
342    if( vtable == NULL )
343    {
344       SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
345       return SCIP_INVALIDDATA;
346    }
347 
348    SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
349          priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
350 
351    return SCIP_OKAY;
352 }
353 
354 /** include virtual function table for UCB bandit algorithms */
SCIPincludeBanditvtableUcb(SCIP * scip)355 SCIP_RETCODE SCIPincludeBanditvtableUcb(
356    SCIP*                 scip                /**< SCIP data structure */
357    )
358 {
359    SCIP_BANDITVTABLE* vtable;
360 
361    SCIP_CALL( SCIPincludeBanditvtable(scip, &vtable, BANDIT_NAME,
362          SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
363    assert(vtable != NULL);
364 
365    return SCIP_OKAY;
366 }
367