/**
 * @file methods/reinforcement_learning/async_learning_impl.hpp
 * @author Shangtong Zhang
 *
 * This file is the implementation of the AsyncLearning class,
 * which is a wrapper for various asynchronous learning algorithms.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP
#define MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP

#include <mlpack/prereqs.hpp>
#include <queue>

namespace mlpack {
namespace rl {

template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
AsyncLearning<
  WorkerType,
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType
>::AsyncLearning(
    TrainingConfig config,
    NetworkType network,
    PolicyType policy,
    UpdaterType updater,
    EnvironmentType environment):
    config(std::move(config)),
    learningNetwork(std::move(network)),
    policy(std::move(policy)),
    updater(std::move(updater)),
    environment(std::move(environment))
{ /* Nothing to do here. */ }
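
/**
 * Below is an illustrative construction sketch, not part of the
 * implementation.  It assumes the OneStepQLearning convenience alias, the
 * CartPole environment, GreedyPolicy, the FFN network from mlpack's ANN
 * module, and ens::VanillaUpdate; the layer sizes and policy parameters are
 * placeholders and may need adjustment.
 *
 * @code
 * TrainingConfig config;
 * config.NumWorkers() = 4;
 *
 * FFN<MeanSquaredError<>, GaussianInitialization> network;
 * network.Add<Linear<>>(4, 64);
 * network.Add<ReLULayer<>>();
 * network.Add<Linear<>>(64, 2);
 *
 * GreedyPolicy<CartPole> policy(1.0, 1000, 0.1);
 *
 * OneStepQLearning<CartPole, decltype(network), ens::VanillaUpdate,
 *     decltype(policy)> agent(std::move(config), std::move(network),
 *     std::move(policy), ens::VanillaUpdate(), CartPole());
 * @endcode
 */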

template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
template <typename Measure>
void AsyncLearning<
  WorkerType,
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType
>::Train(Measure& measure)
{
  /**
   * OpenMP data-sharing clauses cannot refer to class member variables, so we
   * copy them into local variables here, then write the results back at the
   * end of training.
   */
  NetworkType learningNetwork = std::move(this->learningNetwork);
  if (learningNetwork.Parameters().is_empty())
    learningNetwork.ResetParameters();
  NetworkType targetNetwork = learningNetwork;
  size_t totalSteps = 0;
  PolicyType policy = this->policy;
  bool stop = false;

  // Set up the worker pool; worker 0 is deterministic and is used for
  // evaluation, so NumWorkers() + 1 workers are created in total.
  std::vector<WorkerType> workers;
  for (size_t i = 0; i <= config.NumWorkers(); ++i)
  {
    workers.push_back(WorkerType(updater, environment, config, !i));
    workers.back().Initialize(learningNetwork);
  }
  // Set up the task queue corresponding to the worker pool.
  std::queue<size_t> tasks;
  for (size_t i = 0; i <= config.NumWorkers(); ++i)
    tasks.push(i);

  /**
   * Compute the number of threads for the for-loop.  Ideally we would use
   * OpenMP tasks rather than a parallel for-loop, but we use the for-loop
   * here for compatibility with compilers (notably MSVC) that do not yet
   * support OpenMP 3.0 tasks.
   */
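  // Each thread in the parallel region below increments its own copy of
  // numThreads exactly once; the reduction sums those copies, yielding the
  // number of threads OpenMP actually provides.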
  size_t numThreads = 0;
  #pragma omp parallel reduction(+:numThreads)
  numThreads++;
  Log::Debug << numThreads << " threads will be used in total." << std::endl;

  #pragma omp parallel for shared(stop, workers, tasks, learningNetwork, \
      targetNetwork, totalSteps, policy)
  for (omp_size_t i = 0; i < numThreads; ++i)
  {
    #pragma omp critical
    {
      #ifdef HAS_OPENMP
        Log::Debug << "Thread " << omp_get_thread_num() <<
            " started." << std::endl;
      #endif
    }
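    // A task value of size_t's maximum is a sentinel meaning that this
    // thread does not currently hold a worker.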
    size_t task = std::numeric_limits<size_t>::max();
    while (!stop)
    {
      // Return the previously held task (if any) to the queue, then claim
      // the next available task for this thread.
      #pragma omp critical
      {
        if (task != std::numeric_limits<size_t>::max())
          tasks.push(task);

        if (!tasks.empty())
        {
          task = tasks.front();
          tasks.pop();
        }
      }

      // This may happen when there are more threads than workers.
      if (task == std::numeric_limits<size_t>::max())
        continue;

      // Get the corresponding worker.
      WorkerType& worker = workers[task];
      double episodeReturn;
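      // Only the deterministic worker (task 0) reports finished episodes to
      // the measure callback, which decides whether training should stop.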
      if (worker.Step(learningNetwork, targetNetwork, totalSteps,
          policy, episodeReturn) && !task)
      {
        stop = measure(episodeReturn);
      }
    }
  }

  // Write back the learning network.
  this->learningNetwork = std::move(learningNetwork);
}
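
/**
 * Below is an illustrative sketch of driving Train() with a measure callback;
 * it is not part of the implementation.  The `agent` object is assumed to be
 * an AsyncLearning instance (for example, constructed as in the sketch after
 * the constructor above), and the reward threshold is an arbitrary
 * placeholder.
 *
 * @code
 * arma::running_stat<double> averageReturn;
 * auto measure = [&averageReturn](double episodeReturn)
 * {
 *   averageReturn(episodeReturn);
 *   // Stop training once the running mean of episode returns is high enough.
 *   return averageReturn.mean() > 190.0;
 * };
 * agent.Train(measure);
 * @endcode
 */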

} // namespace rl
} // namespace mlpack

#endif