using System;
using System.Runtime.InteropServices;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MultiWorldTesting;
using System.Collections.Generic;
using System.Linq;

namespace ExploreTests
{
    /// <summary>
    /// Managed-side tests that drive the explorers through <c>MwtExplorer</c> and
    /// inspect what reaches the recorder.
    /// </summary>
    [TestClass]
    public class MWTExploreTests
    {
        /*
        ** C# Tests do not need to be as extensive as those for C++. These tests should ensure
        ** the interactions between managed and native code are as expected.
        */

        // Fixed seed so the generated unique keys (and therefore the test runs) are reproducible.
        private const int UniqueKeySeed = 12345;

        /// <summary>
        /// Produces the next unique key from <paramref name="rand"/>, formatted with the
        /// invariant culture so the key text does not depend on the machine's locale.
        /// </summary>
        private static string NextUniqueKey(Random rand)
        {
            return rand.NextDouble().ToString(System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// With epsilon = 0 the test expects the explorer to return the default policy's action
        /// on every call and to record one interaction per call.
        /// </summary>
        [TestMethod]
        public void EpsilonGreedy()
        {
            uint numActions = 10;
            float epsilon = 0f;
            string uniqueKey = "ManagedTestId";

            TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
            TestPolicy policy = new TestPolicy();
            MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            TestContext testContext = new TestContext();
            testContext.Id = 100;

            var explorer = new EpsilonGreedyExplorer<TestContext>(policy, epsilon, numActions);

            uint expectedAction = policy.ChooseAction(testContext);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.AreEqual(expectedAction, chosenAction);

            chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.AreEqual(expectedAction, chosenAction);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(2, interactions.Count);

            Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
        }

        /// <summary>
        /// With tau = 0 the test expects the default policy's action to be returned and
        /// nothing to be handed to the recorder.
        /// </summary>
        [TestMethod]
        public void TauFirst()
        {
            uint numActions = 10;
            uint tau = 0;
            string uniqueKey = "ManagedTestId";

            TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
            TestPolicy policy = new TestPolicy();
            MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            TestContext testContext = new TestContext() { Id = 100 };

            var explorer = new TauFirstExplorer<TestContext>(policy, tau, numActions);

            uint expectedAction = policy.ChooseAction(testContext);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.AreEqual(expectedAction, chosenAction);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(0, interactions.Count);
        }

        /// <summary>
        /// Runs the bootstrap explorer over two policies and two contexts, checking the chosen
        /// action and that both contexts are recorded in call order.
        /// </summary>
        [TestMethod]
        public void Bootstrap()
        {
            uint numActions = 10;
            uint numbags = 2;
            string uniqueKey = "ManagedTestId";

            TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
            TestPolicy[] policies = new TestPolicy[numbags];
            for (int i = 0; i < numbags; i++)
            {
                policies[i] = new TestPolicy(i * 2);
            }
            TestContext testContext1 = new TestContext() { Id = 99 };
            TestContext testContext2 = new TestContext() { Id = 100 };

            MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            var explorer = new BootstrapExplorer<TestContext>(policies, numActions);

            // Every TestPolicy returns the same action, so one expected value serves both calls.
            uint expectedAction = policies[0].ChooseAction(testContext1);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext1);
            Assert.AreEqual(expectedAction, chosenAction);

            chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext2);
            Assert.AreEqual(expectedAction, chosenAction);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(2, interactions.Count);

            Assert.AreEqual(testContext1.Id, interactions[0].Context.Id);
            Assert.AreEqual(testContext2.Id, interactions[1].Context.Id);
        }

        /// <summary>
        /// Draws many actions from the softmax explorer with uniform scores and checks that
        /// every action was chosen at least once and every decision was recorded in order.
        /// </summary>
        [TestMethod]
        public void Softmax()
        {
            uint numActions = 10;
            float lambda = 0.5f;
            uint numActionsCover = 100;
            float C = 5;

            TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
            TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions);

            MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);

            // NOTE(review): presumably sized so that every action is drawn at least once with
            // high probability -- confirm against the native test this mirrors.
            uint numDecisions = (uint)(numActions * Math.Log(numActions * 1.0)
                + Math.Log(numActionsCover * 1.0 / numActions) * C * numActions);
            uint[] actions = new uint[numActions];

            Random rand = new Random(UniqueKeySeed);
            for (uint i = 0; i < numDecisions; i++)
            {
                uint chosenAction = mwtt.ChooseAction(explorer, NextUniqueKey(rand), new TestContext() { Id = (int)i });
                actions[chosenAction - 1]++; // action id is one-based
            }

            for (uint i = 0; i < numActions; i++)
            {
                Assert.IsTrue(actions[i] > 0);
            }

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(numDecisions, (uint)interactions.Count);

            for (int i = 0; i < numDecisions; i++)
            {
                Assert.AreEqual(i, interactions[i].Context.Id);
            }
        }

        /// <summary>
        /// Uses non-uniform scores and checks that the recorded probabilities are not the
        /// uniform value and that contexts are recorded in call order.
        /// </summary>
        [TestMethod]
        public void SoftmaxScores()
        {
            uint numActions = 10;
            float lambda = 0.5f;
            TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
            TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions, uniform: false);

            MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);

            Random rand = new Random(UniqueKeySeed);
            mwtt.ChooseAction(explorer, NextUniqueKey(rand), new TestContext() { Id = 100 });
            mwtt.ChooseAction(explorer, NextUniqueKey(rand), new TestContext() { Id = 101 });
            mwtt.ChooseAction(explorer, NextUniqueKey(rand), new TestContext() { Id = 102 });

            var interactions = recorder.GetAllInteractions();

            Assert.AreEqual(3, interactions.Count);

            for (int i = 0; i < interactions.Count; i++)
            {
                // Scores are not equal therefore probabilities should not be uniform.
                // The not-expected value goes first so a failure message reads correctly.
                Assert.AreNotEqual(1.0f / numActions, interactions[i].Probability);
                Assert.AreEqual(100 + i, interactions[i].Context.Id);
            }
        }

        /// <summary>
        /// Makes a single decision with the generic explorer and checks that exactly one
        /// interaction, carrying the supplied context, was recorded.
        /// </summary>
        [TestMethod]
        public void Generic()
        {
            uint numActions = 10;
            string uniqueKey = "ManagedTestId";
            TestRecorder<TestContext> recorder = new TestRecorder<TestContext>();
            TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions);

            MwtExplorer<TestContext> mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            var explorer = new GenericExplorer<TestContext>(scorer, numActions);

            TestContext testContext = new TestContext() { Id = 100 };
            // The returned action is not asserted on here; only the recording is checked.
            mwtt.ChooseAction(explorer, uniqueKey, testContext);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(1, interactions.Count);
            Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
        }

        /// <summary>Per-test setup hook; intentionally empty.</summary>
        [TestInitialize]
        public void TestInitialize()
        {
        }

        /// <summary>Per-test teardown hook; intentionally empty.</summary>
        [TestCleanup]
        public void TestCleanup()
        {
        }
    }

    /// <summary>
    /// One recorded decision: the context, the action, its probability and the unique key,
    /// exactly as passed to <c>TestRecorder.Record</c>.
    /// </summary>
    struct TestInteraction<Ctx>
    {
        public Ctx Context;
        public UInt32 Action;
        public float Probability;
        public string UniqueKey;
    }

    /// <summary>Minimal context type carrying only an integer identifier.</summary>
    class TestContext
    {
        public int Id { get; set; }
    }

    /// <summary>
    /// Recorder that keeps every interaction in memory, in the order received.
    /// </summary>
    class TestRecorder<Ctx> : IRecorder<Ctx>
    {
        private readonly List<TestInteraction<Ctx>> interactions = new List<TestInteraction<Ctx>>();

        /// <summary>Appends the given decision to the in-memory list.</summary>
        public void Record(Ctx context, UInt32 action, float probability, string uniqueKey)
        {
            interactions.Add(new TestInteraction<Ctx>()
            {
                Context = context,
                Action = action,
                Probability = probability,
                UniqueKey = uniqueKey
            });
        }

        /// <summary>
        /// Returns the live backing list (not a copy) of everything recorded so far.
        /// </summary>
        public List<TestInteraction<Ctx>> GetAllInteractions()
        {
            return interactions;
        }
    }

    /// <summary>
    /// Policy that always chooses action 5, regardless of context or index.
    /// </summary>
    class TestPolicy : IPolicy<TestContext>
    {
        // Stored at construction; ChooseAction does not consult it.
        private readonly int index;

        /// <summary>Creates a policy with index -1.</summary>
        public TestPolicy() : this(-1) { }

        /// <summary>Creates a policy tagged with <paramref name="index"/>.</summary>
        public TestPolicy(int index)
        {
            this.index = index;
        }

        /// <summary>Returns the constant action 5.</summary>
        public uint ChooseAction(TestContext context)
        {
            return 5;
        }
    }

    /// <summary>Policy over <c>SimpleContext</c> that always chooses action 1.</summary>
    class TestSimplePolicy : IPolicy<SimpleContext>
    {
        /// <summary>Returns the constant action 1.</summary>
        public uint ChooseAction(SimpleContext context)
        {
            return 1;
        }
    }

    /// <summary>Policy over <c>SimpleContext</c> that always chooses action 1.</summary>
    class StringPolicy : IPolicy<SimpleContext>
    {
        /// <summary>Returns the constant action 1.</summary>
        public uint ChooseAction(SimpleContext context)
        {
            return 1;
        }
    }

    /// <summary>
    /// Scorer that returns either equal scores for every action or the increasing
    /// scores 1, 2, ..., numActions.
    /// </summary>
    class TestScorer<Ctx> : IScorer<Ctx>
    {
        private readonly uint numActions;
        private readonly bool uniform;

        /// <summary>
        /// Creates a scorer for <paramref name="numActions"/> actions.
        /// </summary>
        /// <param name="numActions">Number of scores to produce per call.</param>
        /// <param name="uniform">
        /// When true every score is 1 / numActions; when false the scores are 1..numActions.
        /// </param>
        public TestScorer(uint numActions, bool uniform = true)
        {
            this.uniform = uniform;
            this.numActions = numActions;
        }

        /// <summary>
        /// Returns one score per action; the context is not consulted.
        /// </summary>
        public List<float> ScoreActions(Ctx context)
        {
            if (uniform)
            {
                return Enumerable.Repeat<float>(1.0f / numActions, (int)numActions).ToList();
            }

            return Enumerable.Range(1, (int)numActions).Select(score => (float)score).ToList();
        }
    }
}