1 using System;
2 using System.Runtime.InteropServices;
3 using Microsoft.VisualStudio.TestTools.UnitTesting;
4 using MultiWorldTesting;
5 using System.Collections.Generic;
6 using System.Linq;
7 
8 namespace ExploreTests
9 {
10     [TestClass]
11     public class MWTExploreTests
12     {
13         /*
14         ** C# Tests do not need to be as extensive as those for C++. These tests should ensure
15         ** the interactions between managed and native code are as expected.
16         */
17         [TestMethod]
EpsilonGreedy()18         public void EpsilonGreedy()
19         {
20             uint numActions = 10;
21             float epsilon = 0f;
22             string uniqueKey = "ManagedTestId";
23 
24             TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
25             TestPolicy policy = new TestPolicy();
26             MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
27             TestContext testContext = new TestContext();
28             testContext.Id = 100;
29 
30             var explorer = new EpsilonGreedyExplorer<TestContext>(policy, epsilon, numActions);
31 
32             uint expectedAction = policy.ChooseAction(testContext);
33 
34             uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
35             Assert.AreEqual(expectedAction, chosenAction);
36 
37             chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
38             Assert.AreEqual(expectedAction, chosenAction);
39 
40             var interactions = recorder.GetAllInteractions();
41             Assert.AreEqual(2, interactions.Count);
42 
43             Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
44         }
45 
46         [TestMethod]
TauFirst()47         public void TauFirst()
48         {
49             uint numActions = 10;
50             uint tau = 0;
51             string uniqueKey = "ManagedTestId";
52 
53             TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
54             TestPolicy policy = new TestPolicy();
55             MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
56             TestContext testContext = new TestContext() { Id = 100 };
57 
58             var explorer = new TauFirstExplorer<TestContext>(policy, tau, numActions);
59 
60             uint expectedAction = policy.ChooseAction(testContext);
61 
62             uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
63             Assert.AreEqual(expectedAction, chosenAction);
64 
65             var interactions = recorder.GetAllInteractions();
66             Assert.AreEqual(0, interactions.Count);
67         }
68 
69         [TestMethod]
Bootstrap()70         public void Bootstrap()
71         {
72             uint numActions = 10;
73             uint numbags = 2;
74             string uniqueKey = "ManagedTestId";
75 
76             TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
77             TestPolicy[] policies = new TestPolicy[numbags];
78             for (int i = 0; i < numbags; i++)
79             {
80                 policies[i] = new TestPolicy(i * 2);
81             }
82             TestContext testContext1 = new TestContext() { Id = 99 };
83             TestContext testContext2 = new TestContext() { Id = 100 };
84 
85             MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
86             var explorer = new BootstrapExplorer<TestContext>(policies, numActions);
87 
88             uint expectedAction = policies[0].ChooseAction(testContext1);
89 
90             uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext1);
91             Assert.AreEqual(expectedAction, chosenAction);
92 
93             chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext2);
94             Assert.AreEqual(expectedAction, chosenAction);
95 
96             var interactions = recorder.GetAllInteractions();
97             Assert.AreEqual(2, interactions.Count);
98 
99             Assert.AreEqual(testContext1.Id, interactions[0].Context.Id);
100             Assert.AreEqual(testContext2.Id, interactions[1].Context.Id);
101         }
102 
103         [TestMethod]
Softmax()104         public void Softmax()
105         {
106             uint numActions = 10;
107             float lambda = 0.5f;
108             uint numActionsCover = 100;
109             float C = 5;
110 
111             TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
112             TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions);
113 
114             MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
115             var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);
116 
117             uint numDecisions = (uint)(numActions * Math.Log(numActions * 1.0) + Math.Log(numActionsCover * 1.0 / numActions) * C * numActions);
118             uint[] actions = new uint[numActions];
119 
120             Random rand = new Random();
121             for (uint i = 0; i < numDecisions; i++)
122             {
123                 uint chosenAction = mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = (int)i });
124                 actions[chosenAction - 1]++; // action id is one-based
125             }
126 
127             for (uint i = 0; i < numActions; i++)
128             {
129                 Assert.IsTrue(actions[i] > 0);
130             }
131 
132             var interactions = recorder.GetAllInteractions();
133             Assert.AreEqual(numDecisions, (uint)interactions.Count);
134 
135             for (int i = 0; i < numDecisions; i++)
136             {
137                 Assert.AreEqual(i, interactions[i].Context.Id);
138             }
139         }
140 
141         [TestMethod]
SoftmaxScores()142         public void SoftmaxScores()
143         {
144             uint numActions = 10;
145             float lambda = 0.5f;
146             TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
147             TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions, uniform: false);
148 
149             MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
150             var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);
151 
152             Random rand = new Random();
153             mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = 100 });
154             mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = 101 });
155             mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = 102 });
156 
157             var interactions = recorder.GetAllInteractions();
158 
159             Assert.AreEqual(3, interactions.Count);
160 
161             for (int i = 0; i < interactions.Count; i++)
162             {
163                 // Scores are not equal therefore probabilities should not be uniform
164                 Assert.AreNotEqual(interactions[i].Probability, 1.0f / numActions);
165                 Assert.AreEqual(100 + i, interactions[i].Context.Id);
166             }
167         }
168 
169         [TestMethod]
Generic()170         public void Generic()
171         {
172             uint numActions = 10;
173             string uniqueKey = "ManagedTestId";
174             TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
175             TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions);
176 
177             MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
178             var explorer = new GenericExplorer<TestContext>(scorer, numActions);
179 
180             TestContext testContext = new TestContext() { Id = 100 };
181             uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
182 
183             var interactions = recorder.GetAllInteractions();
184             Assert.AreEqual(1, interactions.Count);
185             Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
186         }
187 
188         [TestInitialize]
TestInitialize()189         public void TestInitialize()
190         {
191         }
192 
193         [TestCleanup]
TestCleanup()194         public void TestCleanup()
195         {
196         }
197     }
198 
199     struct TestInteraction<Ctx>
200     {
201         public Ctx Context;
202         public UInt32 Action;
203         public float Probability;
204         public string UniqueKey;
205     }
206 
207     class TestContext
208     {
209         private int id;
210 
211         public int Id
212         {
213             get { return id; }
214             set { id = value; }
215         }
216     }
217 
218     class TestRecorder<Ctx> : IRecorder<Ctx>
219     {
Record(Ctx context, UInt32 action, float probability, string uniqueKey)220         public void Record(Ctx context, UInt32 action, float probability, string uniqueKey)
221         {
222             interactions.Add(new TestInteraction<Ctx>()
223             {
224                 Context = context,
225                 Action = action,
226                 Probability = probability,
227                 UniqueKey = uniqueKey
228             });
229         }
230 
GetAllInteractions()231         public List<TestInteraction<Ctx>> GetAllInteractions()
232         {
233             return interactions;
234         }
235 
236         private List<TestInteraction<Ctx>> interactions = new List<TestInteraction<Ctx>>();
237     }
238 
239     class TestPolicy : IPolicy<TestContext>
240     {
TestPolicy()241         public TestPolicy() : this(-1) { }
242 
TestPolicy(int index)243         public TestPolicy(int index)
244         {
245             this.index = index;
246         }
247 
ChooseAction(TestContext context)248         public uint ChooseAction(TestContext context)
249         {
250             return 5;
251         }
252 
253         private int index;
254     }
255 
256     class TestSimplePolicy : IPolicy<SimpleContext>
257     {
ChooseAction(SimpleContext context)258         public uint ChooseAction(SimpleContext context)
259         {
260             return 1;
261         }
262     }
263 
264     class StringPolicy : IPolicy<SimpleContext>
265     {
ChooseAction(SimpleContext context)266         public uint ChooseAction(SimpleContext context)
267         {
268             return 1;
269         }
270     }
271 
272     class TestScorer<Ctx> : IScorer<Ctx>
273     {
TestScorer(uint numActions, bool uniform = true)274         public TestScorer(uint numActions, bool uniform = true)
275         {
276             this.uniform = uniform;
277             this.numActions = numActions;
278         }
ScoreActions(Ctx context)279         public List<float> ScoreActions(Ctx context)
280         {
281             if (uniform)
282             {
283                 return Enumerable.Repeat<float>(1.0f / numActions, (int)numActions).ToList();
284             }
285             else
286             {
287                 return Array.ConvertAll<int, float>(Enumerable.Range(1, (int)numActions).ToArray(), Convert.ToSingle).ToList();
288             }
289         }
290         private uint numActions;
291         private bool uniform;
292     }
293 }
294