1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /*                                                                           */
3 /*                  This file is part of the program and library             */
4 /*         SCIP --- Solving Constraint Integer Programs                      */
5 /*                                                                           */
6 /*    Copyright (C) 2002-2021 Konrad-Zuse-Zentrum                            */
7 /*                            fuer Informationstechnik Berlin                */
8 /*                                                                           */
9 /*  SCIP is distributed under the terms of the ZIB Academic License.         */
10 /*                                                                           */
11 /*  You should have received a copy of the ZIB Academic License              */
12 /*  along with SCIP; see the file COPYING. If not visit scipopt.org.         */
13 /*                                                                           */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file   bandit.c
17  * @ingroup OTHER_CFILES
18  * @brief  internal API of bandit algorithms and bandit virtual function tables
19  * @author Gregor Hendel
20  */
21 
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23 
24 #include <assert.h>
25 #include <string.h>
26 #include "scip/bandit.h"
27 #include "scip/pub_bandit.h"
28 #include "scip/struct_bandit.h"
29 #include "scip/struct_set.h"
30 #include "scip/set.h"
31 
32 /** creates and resets bandit algorithm */
SCIPbanditCreate(SCIP_BANDIT ** bandit,SCIP_BANDITVTABLE * banditvtable,BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_Real * priorities,int nactions,unsigned int initseed,SCIP_BANDITDATA * banditdata)33 SCIP_RETCODE SCIPbanditCreate(
34    SCIP_BANDIT**         bandit,             /**< pointer to bandit algorithm data structure */
35    SCIP_BANDITVTABLE*    banditvtable,       /**< virtual table for this bandit algorithm */
36    BMS_BLKMEM*           blkmem,             /**< block memory for parameter settings */
37    BMS_BUFMEM*           bufmem,             /**< buffer memory */
38    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
39    int                   nactions,           /**< the positive number of actions for this bandit */
40    unsigned int          initseed,           /**< initial seed for random number generation */
41    SCIP_BANDITDATA*      banditdata          /**< algorithm specific bandit data */
42    )
43 {
44    SCIP_BANDIT* banditptr;
45    assert(bandit != NULL);
46    assert(banditvtable != NULL);
47 
48    /* the number of actions must be positive */
49    if( nactions <= 0 )
50    {
51       SCIPerrorMessage("Cannot create bandit selector with %d <= 0 actions\n", nactions);
52 
53       return SCIP_INVALIDDATA;
54    }
55 
56    SCIP_ALLOC( BMSallocBlockMemory(blkmem, bandit) );
57    assert(*bandit != NULL);
58    banditptr = *bandit;
59    banditptr->vtable = banditvtable;
60    banditptr->data = banditdata;
61    banditptr->nactions = nactions;
62 
63    SCIP_CALL( SCIPrandomCreate(&banditptr->rng, blkmem, initseed) );
64 
65    SCIP_CALL( SCIPbanditReset(bufmem, banditptr, priorities, initseed) );
66 
67    return SCIP_OKAY;
68 }
69 
70 /** calls destructor and frees memory of bandit algorithm */
SCIPbanditFree(BMS_BLKMEM * blkmem,SCIP_BANDIT ** bandit)71 SCIP_RETCODE SCIPbanditFree(
72    BMS_BLKMEM*           blkmem,             /**< block memory */
73    SCIP_BANDIT**         bandit              /**< pointer to bandit algorithm data structure */
74    )
75 {
76    SCIP_BANDIT* banditptr;
77    SCIP_BANDITVTABLE* vtable;
78    assert(bandit != NULL);
79    assert(*bandit != NULL);
80 
81    banditptr = *bandit;
82    vtable = banditptr->vtable;
83    assert(vtable != NULL);
84 
85    /* call bandit specific data destructor */
86    if( vtable->banditfree != NULL )
87    {
88       SCIP_CALL( vtable->banditfree(blkmem, banditptr) );
89    }
90 
91    /* free random number generator */
92    SCIPrandomFree(&banditptr->rng, blkmem);
93 
94    BMSfreeBlockMemory(blkmem, bandit);
95 
96    return SCIP_OKAY;
97 }
98 
99 /** reset the bandit algorithm */
SCIPbanditReset(BMS_BUFMEM * bufmem,SCIP_BANDIT * bandit,SCIP_Real * priorities,unsigned int seed)100 SCIP_RETCODE SCIPbanditReset(
101    BMS_BUFMEM*           bufmem,             /**< buffer memory */
102    SCIP_BANDIT*          bandit,             /**< pointer to bandit algorithm data structure */
103    SCIP_Real*            priorities,         /**< nonnegative priorities for each action, or NULL if not needed */
104    unsigned int          seed                /**< initial seed for random number generation */
105    )
106 {
107    SCIP_BANDITVTABLE* vtable;
108 
109    assert(bandit != NULL);
110    assert(bufmem != NULL);
111 
112    vtable = bandit->vtable;
113    assert(vtable != NULL);
114    assert(vtable->banditreset != NULL);
115 
116    /* test if the priorities are nonnegative */
117    if( priorities != NULL )
118    {
119       int i;
120 
121       assert(SCIPbanditGetNActions(bandit) > 0);
122 
123       for( i = 0; i < SCIPbanditGetNActions(bandit); ++i )
124       {
125          if( priorities[i] < 0 )
126          {
127             SCIPerrorMessage("Negative priority for action %d\n", i);
128 
129             return SCIP_INVALIDDATA;
130          }
131       }
132    }
133 
134    /* reset the random seed of the bandit algorithm */
135    SCIPrandomSetSeed(bandit->rng, seed);
136 
137    /* call the reset callback of the bandit algorithm */
138    SCIP_CALL( vtable->banditreset(bufmem, bandit, priorities) );
139 
140    return SCIP_OKAY;
141 }
142 
143 /** select the next action */
SCIPbanditSelect(SCIP_BANDIT * bandit,int * action)144 SCIP_RETCODE SCIPbanditSelect(
145    SCIP_BANDIT*          bandit,             /**< bandit algorithm data structure */
146    int*                  action              /**< pointer to store the selected action */
147    )
148 {
149    assert(bandit != NULL);
150    assert(action != NULL);
151 
152    *action = -1;
153 
154    assert(bandit->vtable->banditselect != NULL);
155 
156    SCIP_CALL( bandit->vtable->banditselect(bandit, action) );
157 
158    assert(*action >= 0);
159    assert(*action < SCIPbanditGetNActions(bandit));
160 
161    return SCIP_OKAY;
162 }
163 
164 /** update the score of the selected action */
SCIPbanditUpdate(SCIP_BANDIT * bandit,int action,SCIP_Real score)165 SCIP_RETCODE SCIPbanditUpdate(
166    SCIP_BANDIT*          bandit,             /**< bandit algorithm data structure */
167    int                   action,             /**< index of action for which the score should be updated */
168    SCIP_Real             score               /**< observed gain of the i'th action */
169    )
170 {
171    assert(bandit != NULL);
172    assert(0 <= action && action < SCIPbanditGetNActions(bandit));
173    assert(bandit->vtable->banditupdate != NULL);
174 
175    SCIP_CALL( bandit->vtable->banditupdate(bandit, action, score) );
176 
177    return SCIP_OKAY;
178 }
179 
180 /** get data of this bandit algorithm */
SCIPbanditGetData(SCIP_BANDIT * bandit)181 SCIP_BANDITDATA* SCIPbanditGetData(
182    SCIP_BANDIT*          bandit              /**< bandit algorithm data structure */
183    )
184 {
185    assert(bandit != NULL);
186 
187    return bandit->data;
188 }
189 
190 /** set the data of this bandit algorithm */
SCIPbanditSetData(SCIP_BANDIT * bandit,SCIP_BANDITDATA * banditdata)191 void SCIPbanditSetData(
192    SCIP_BANDIT*          bandit,             /**< bandit algorithm data structure */
193    SCIP_BANDITDATA*      banditdata          /**< bandit algorihm specific data, or NULL */
194    )
195 {
196    assert(bandit != NULL);
197 
198    bandit->data = banditdata;
199 }
200 
201 /** internal method to create a bandit VTable */
202 static
doBanditvtableCreate(SCIP_BANDITVTABLE ** banditvtable,const char * name,SCIP_DECL_BANDITFREE ((* banditfree)),SCIP_DECL_BANDITSELECT ((* banditselect)),SCIP_DECL_BANDITUPDATE ((* banditupdate)),SCIP_DECL_BANDITRESET ((* banditreset)))203 SCIP_RETCODE doBanditvtableCreate(
204    SCIP_BANDITVTABLE**   banditvtable,       /**< pointer to virtual table for bandit algorithm */
205    const char*           name,               /**< a name for the algorithm represented by this vtable */
206    SCIP_DECL_BANDITFREE  ((*banditfree)),    /**< callback to free bandit specific data structures */
207    SCIP_DECL_BANDITSELECT((*banditselect)),  /**< selection callback for bandit selector */
208    SCIP_DECL_BANDITUPDATE((*banditupdate)),  /**< update callback for bandit algorithms */
209    SCIP_DECL_BANDITRESET ((*banditreset))    /**< update callback for bandit algorithms */
210    )
211 {
212    SCIP_BANDITVTABLE* banditvtableptr;
213 
214    assert(banditvtable != NULL);
215    assert(name != NULL);
216    assert(banditfree != NULL);
217    assert(banditselect != NULL);
218    assert(banditupdate != NULL);
219    assert(banditreset != NULL);
220 
221    /* allocate memory for this virtual function table */
222    SCIP_ALLOC( BMSallocMemory(banditvtable) );
223    BMSclearMemory(*banditvtable);
224 
225    SCIP_ALLOC( BMSduplicateMemoryArray(&(*banditvtable)->name, name, strlen(name)+1) );
226    banditvtableptr = *banditvtable;
227    banditvtableptr->banditfree = banditfree;
228    banditvtableptr->banditselect = banditselect;
229    banditvtableptr->banditupdate = banditupdate;
230    banditvtableptr->banditreset = banditreset;
231 
232    return SCIP_OKAY;
233 }
234 
235 /** create a bandit VTable for bandit algorithm callback functions */
SCIPbanditvtableCreate(SCIP_BANDITVTABLE ** banditvtable,const char * name,SCIP_DECL_BANDITFREE ((* banditfree)),SCIP_DECL_BANDITSELECT ((* banditselect)),SCIP_DECL_BANDITUPDATE ((* banditupdate)),SCIP_DECL_BANDITRESET ((* banditreset)))236 SCIP_RETCODE SCIPbanditvtableCreate(
237    SCIP_BANDITVTABLE**   banditvtable,       /**< pointer to virtual table for bandit algorithm */
238    const char*           name,               /**< a name for the algorithm represented by this vtable */
239    SCIP_DECL_BANDITFREE  ((*banditfree)),    /**< callback to free bandit specific data structures */
240    SCIP_DECL_BANDITSELECT((*banditselect)),  /**< selection callback for bandit selector */
241    SCIP_DECL_BANDITUPDATE((*banditupdate)),  /**< update callback for bandit algorithms */
242    SCIP_DECL_BANDITRESET ((*banditreset))    /**< update callback for bandit algorithms */
243    )
244 {
245    assert(banditvtable != NULL);
246    assert(name != NULL);
247    assert(banditfree != NULL);
248    assert(banditselect != NULL);
249    assert(banditupdate != NULL);
250    assert(banditreset != NULL);
251 
252    SCIP_CALL_FINALLY( doBanditvtableCreate(banditvtable, name, banditfree, banditselect, banditupdate, banditreset),
253       SCIPbanditvtableFree(banditvtable) );
254 
255    return SCIP_OKAY;
256 }
257 
258 
259 /** free a bandit virtual table for bandit algorithm callback functions */
SCIPbanditvtableFree(SCIP_BANDITVTABLE ** banditvtable)260 void SCIPbanditvtableFree(
261    SCIP_BANDITVTABLE**   banditvtable        /**< pointer to virtual table for bandit algorithm */
262    )
263 {
264    assert(banditvtable != NULL);
265    if( *banditvtable == NULL )
266       return;
267 
268    BMSfreeMemoryArrayNull(&(*banditvtable)->name);
269    BMSfreeMemory(banditvtable);
270 }
271 
272 /** return the name of this bandit virtual function table */
SCIPbanditvtableGetName(SCIP_BANDITVTABLE * banditvtable)273 const char* SCIPbanditvtableGetName(
274    SCIP_BANDITVTABLE*    banditvtable        /**< virtual table for bandit algorithm */
275    )
276 {
277    assert(banditvtable != NULL);
278 
279    return banditvtable->name;
280 }
281 
282 
283 /** return the random number generator of a bandit algorithm */
SCIPbanditGetRandnumgen(SCIP_BANDIT * bandit)284 SCIP_RANDNUMGEN* SCIPbanditGetRandnumgen(
285    SCIP_BANDIT*          bandit              /**< bandit algorithm data structure */
286    )
287 {
288    assert(bandit != NULL);
289 
290    return bandit->rng;
291 }
292 
293 /** return number of actions of this bandit algorithm */
SCIPbanditGetNActions(SCIP_BANDIT * bandit)294 int SCIPbanditGetNActions(
295    SCIP_BANDIT*          bandit              /**< bandit algorithm data structure */
296    )
297 {
298    assert(bandit != NULL);
299 
300    return bandit->nactions;
301 }
302