1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2021 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scipopt.org. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15
16 /**@file bandit.c
17 * @ingroup OTHER_CFILES
18 * @brief internal API of bandit algorithms and bandit virtual function tables
19 * @author Gregor Hendel
20 */
21
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23
24 #include <assert.h>
25 #include <string.h>
26 #include "scip/bandit.h"
27 #include "scip/pub_bandit.h"
28 #include "scip/struct_bandit.h"
29 #include "scip/struct_set.h"
30 #include "scip/set.h"
31
32 /** creates and resets bandit algorithm */
SCIPbanditCreate(SCIP_BANDIT ** bandit,SCIP_BANDITVTABLE * banditvtable,BMS_BLKMEM * blkmem,BMS_BUFMEM * bufmem,SCIP_Real * priorities,int nactions,unsigned int initseed,SCIP_BANDITDATA * banditdata)33 SCIP_RETCODE SCIPbanditCreate(
34 SCIP_BANDIT** bandit, /**< pointer to bandit algorithm data structure */
35 SCIP_BANDITVTABLE* banditvtable, /**< virtual table for this bandit algorithm */
36 BMS_BLKMEM* blkmem, /**< block memory for parameter settings */
37 BMS_BUFMEM* bufmem, /**< buffer memory */
38 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
39 int nactions, /**< the positive number of actions for this bandit */
40 unsigned int initseed, /**< initial seed for random number generation */
41 SCIP_BANDITDATA* banditdata /**< algorithm specific bandit data */
42 )
43 {
44 SCIP_BANDIT* banditptr;
45 assert(bandit != NULL);
46 assert(banditvtable != NULL);
47
48 /* the number of actions must be positive */
49 if( nactions <= 0 )
50 {
51 SCIPerrorMessage("Cannot create bandit selector with %d <= 0 actions\n", nactions);
52
53 return SCIP_INVALIDDATA;
54 }
55
56 SCIP_ALLOC( BMSallocBlockMemory(blkmem, bandit) );
57 assert(*bandit != NULL);
58 banditptr = *bandit;
59 banditptr->vtable = banditvtable;
60 banditptr->data = banditdata;
61 banditptr->nactions = nactions;
62
63 SCIP_CALL( SCIPrandomCreate(&banditptr->rng, blkmem, initseed) );
64
65 SCIP_CALL( SCIPbanditReset(bufmem, banditptr, priorities, initseed) );
66
67 return SCIP_OKAY;
68 }
69
70 /** calls destructor and frees memory of bandit algorithm */
SCIPbanditFree(BMS_BLKMEM * blkmem,SCIP_BANDIT ** bandit)71 SCIP_RETCODE SCIPbanditFree(
72 BMS_BLKMEM* blkmem, /**< block memory */
73 SCIP_BANDIT** bandit /**< pointer to bandit algorithm data structure */
74 )
75 {
76 SCIP_BANDIT* banditptr;
77 SCIP_BANDITVTABLE* vtable;
78 assert(bandit != NULL);
79 assert(*bandit != NULL);
80
81 banditptr = *bandit;
82 vtable = banditptr->vtable;
83 assert(vtable != NULL);
84
85 /* call bandit specific data destructor */
86 if( vtable->banditfree != NULL )
87 {
88 SCIP_CALL( vtable->banditfree(blkmem, banditptr) );
89 }
90
91 /* free random number generator */
92 SCIPrandomFree(&banditptr->rng, blkmem);
93
94 BMSfreeBlockMemory(blkmem, bandit);
95
96 return SCIP_OKAY;
97 }
98
99 /** reset the bandit algorithm */
SCIPbanditReset(BMS_BUFMEM * bufmem,SCIP_BANDIT * bandit,SCIP_Real * priorities,unsigned int seed)100 SCIP_RETCODE SCIPbanditReset(
101 BMS_BUFMEM* bufmem, /**< buffer memory */
102 SCIP_BANDIT* bandit, /**< pointer to bandit algorithm data structure */
103 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
104 unsigned int seed /**< initial seed for random number generation */
105 )
106 {
107 SCIP_BANDITVTABLE* vtable;
108
109 assert(bandit != NULL);
110 assert(bufmem != NULL);
111
112 vtable = bandit->vtable;
113 assert(vtable != NULL);
114 assert(vtable->banditreset != NULL);
115
116 /* test if the priorities are nonnegative */
117 if( priorities != NULL )
118 {
119 int i;
120
121 assert(SCIPbanditGetNActions(bandit) > 0);
122
123 for( i = 0; i < SCIPbanditGetNActions(bandit); ++i )
124 {
125 if( priorities[i] < 0 )
126 {
127 SCIPerrorMessage("Negative priority for action %d\n", i);
128
129 return SCIP_INVALIDDATA;
130 }
131 }
132 }
133
134 /* reset the random seed of the bandit algorithm */
135 SCIPrandomSetSeed(bandit->rng, seed);
136
137 /* call the reset callback of the bandit algorithm */
138 SCIP_CALL( vtable->banditreset(bufmem, bandit, priorities) );
139
140 return SCIP_OKAY;
141 }
142
143 /** select the next action */
SCIPbanditSelect(SCIP_BANDIT * bandit,int * action)144 SCIP_RETCODE SCIPbanditSelect(
145 SCIP_BANDIT* bandit, /**< bandit algorithm data structure */
146 int* action /**< pointer to store the selected action */
147 )
148 {
149 assert(bandit != NULL);
150 assert(action != NULL);
151
152 *action = -1;
153
154 assert(bandit->vtable->banditselect != NULL);
155
156 SCIP_CALL( bandit->vtable->banditselect(bandit, action) );
157
158 assert(*action >= 0);
159 assert(*action < SCIPbanditGetNActions(bandit));
160
161 return SCIP_OKAY;
162 }
163
164 /** update the score of the selected action */
SCIPbanditUpdate(SCIP_BANDIT * bandit,int action,SCIP_Real score)165 SCIP_RETCODE SCIPbanditUpdate(
166 SCIP_BANDIT* bandit, /**< bandit algorithm data structure */
167 int action, /**< index of action for which the score should be updated */
168 SCIP_Real score /**< observed gain of the i'th action */
169 )
170 {
171 assert(bandit != NULL);
172 assert(0 <= action && action < SCIPbanditGetNActions(bandit));
173 assert(bandit->vtable->banditupdate != NULL);
174
175 SCIP_CALL( bandit->vtable->banditupdate(bandit, action, score) );
176
177 return SCIP_OKAY;
178 }
179
180 /** get data of this bandit algorithm */
SCIPbanditGetData(SCIP_BANDIT * bandit)181 SCIP_BANDITDATA* SCIPbanditGetData(
182 SCIP_BANDIT* bandit /**< bandit algorithm data structure */
183 )
184 {
185 assert(bandit != NULL);
186
187 return bandit->data;
188 }
189
190 /** set the data of this bandit algorithm */
SCIPbanditSetData(SCIP_BANDIT * bandit,SCIP_BANDITDATA * banditdata)191 void SCIPbanditSetData(
192 SCIP_BANDIT* bandit, /**< bandit algorithm data structure */
193 SCIP_BANDITDATA* banditdata /**< bandit algorihm specific data, or NULL */
194 )
195 {
196 assert(bandit != NULL);
197
198 bandit->data = banditdata;
199 }
200
201 /** internal method to create a bandit VTable */
202 static
doBanditvtableCreate(SCIP_BANDITVTABLE ** banditvtable,const char * name,SCIP_DECL_BANDITFREE ((* banditfree)),SCIP_DECL_BANDITSELECT ((* banditselect)),SCIP_DECL_BANDITUPDATE ((* banditupdate)),SCIP_DECL_BANDITRESET ((* banditreset)))203 SCIP_RETCODE doBanditvtableCreate(
204 SCIP_BANDITVTABLE** banditvtable, /**< pointer to virtual table for bandit algorithm */
205 const char* name, /**< a name for the algorithm represented by this vtable */
206 SCIP_DECL_BANDITFREE ((*banditfree)), /**< callback to free bandit specific data structures */
207 SCIP_DECL_BANDITSELECT((*banditselect)), /**< selection callback for bandit selector */
208 SCIP_DECL_BANDITUPDATE((*banditupdate)), /**< update callback for bandit algorithms */
209 SCIP_DECL_BANDITRESET ((*banditreset)) /**< update callback for bandit algorithms */
210 )
211 {
212 SCIP_BANDITVTABLE* banditvtableptr;
213
214 assert(banditvtable != NULL);
215 assert(name != NULL);
216 assert(banditfree != NULL);
217 assert(banditselect != NULL);
218 assert(banditupdate != NULL);
219 assert(banditreset != NULL);
220
221 /* allocate memory for this virtual function table */
222 SCIP_ALLOC( BMSallocMemory(banditvtable) );
223 BMSclearMemory(*banditvtable);
224
225 SCIP_ALLOC( BMSduplicateMemoryArray(&(*banditvtable)->name, name, strlen(name)+1) );
226 banditvtableptr = *banditvtable;
227 banditvtableptr->banditfree = banditfree;
228 banditvtableptr->banditselect = banditselect;
229 banditvtableptr->banditupdate = banditupdate;
230 banditvtableptr->banditreset = banditreset;
231
232 return SCIP_OKAY;
233 }
234
235 /** create a bandit VTable for bandit algorithm callback functions */
SCIPbanditvtableCreate(SCIP_BANDITVTABLE ** banditvtable,const char * name,SCIP_DECL_BANDITFREE ((* banditfree)),SCIP_DECL_BANDITSELECT ((* banditselect)),SCIP_DECL_BANDITUPDATE ((* banditupdate)),SCIP_DECL_BANDITRESET ((* banditreset)))236 SCIP_RETCODE SCIPbanditvtableCreate(
237 SCIP_BANDITVTABLE** banditvtable, /**< pointer to virtual table for bandit algorithm */
238 const char* name, /**< a name for the algorithm represented by this vtable */
239 SCIP_DECL_BANDITFREE ((*banditfree)), /**< callback to free bandit specific data structures */
240 SCIP_DECL_BANDITSELECT((*banditselect)), /**< selection callback for bandit selector */
241 SCIP_DECL_BANDITUPDATE((*banditupdate)), /**< update callback for bandit algorithms */
242 SCIP_DECL_BANDITRESET ((*banditreset)) /**< update callback for bandit algorithms */
243 )
244 {
245 assert(banditvtable != NULL);
246 assert(name != NULL);
247 assert(banditfree != NULL);
248 assert(banditselect != NULL);
249 assert(banditupdate != NULL);
250 assert(banditreset != NULL);
251
252 SCIP_CALL_FINALLY( doBanditvtableCreate(banditvtable, name, banditfree, banditselect, banditupdate, banditreset),
253 SCIPbanditvtableFree(banditvtable) );
254
255 return SCIP_OKAY;
256 }
257
258
259 /** free a bandit virtual table for bandit algorithm callback functions */
SCIPbanditvtableFree(SCIP_BANDITVTABLE ** banditvtable)260 void SCIPbanditvtableFree(
261 SCIP_BANDITVTABLE** banditvtable /**< pointer to virtual table for bandit algorithm */
262 )
263 {
264 assert(banditvtable != NULL);
265 if( *banditvtable == NULL )
266 return;
267
268 BMSfreeMemoryArrayNull(&(*banditvtable)->name);
269 BMSfreeMemory(banditvtable);
270 }
271
272 /** return the name of this bandit virtual function table */
SCIPbanditvtableGetName(SCIP_BANDITVTABLE * banditvtable)273 const char* SCIPbanditvtableGetName(
274 SCIP_BANDITVTABLE* banditvtable /**< virtual table for bandit algorithm */
275 )
276 {
277 assert(banditvtable != NULL);
278
279 return banditvtable->name;
280 }
281
282
283 /** return the random number generator of a bandit algorithm */
SCIPbanditGetRandnumgen(SCIP_BANDIT * bandit)284 SCIP_RANDNUMGEN* SCIPbanditGetRandnumgen(
285 SCIP_BANDIT* bandit /**< bandit algorithm data structure */
286 )
287 {
288 assert(bandit != NULL);
289
290 return bandit->rng;
291 }
292
293 /** return number of actions of this bandit algorithm */
SCIPbanditGetNActions(SCIP_BANDIT * bandit)294 int SCIPbanditGetNActions(
295 SCIP_BANDIT* bandit /**< bandit algorithm data structure */
296 )
297 {
298 assert(bandit != NULL);
299
300 return bandit->nactions;
301 }
302