/**
 * @file methods/reinforcement_learning/async_learning_impl.hpp
 * @author Shangtong Zhang
 *
 * This file is the implementation of the AsyncLearning class,
 * which is a wrapper for various asynchronous learning algorithms.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP
#define MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP

#include <mlpack/prereqs.hpp>
#include <queue>

namespace mlpack {
namespace rl {

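/**
 * Construct an AsyncLearning object with the given components.  All
 * arguments are taken by value and moved into the object.
 *
 * @param config Hyper-parameters for training.
 * @param network The network to be trained.
 * @param policy The behavior policy used by workers to select actions.
 * @param updater The optimizer used to update the network.
 * @param environment The reinforcement learning task to be solved.
 */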
template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
AsyncLearning<
  WorkerType,
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType
>::AsyncLearning(
    TrainingConfig config,
    NetworkType network,
    PolicyType policy,
    UpdaterType updater,
    EnvironmentType environment):
    config(std::move(config)),
    learningNetwork(std::move(network)),
    policy(std::move(policy)),
    updater(std::move(updater)),
    environment(std::move(environment))
{ /* Nothing to do here. */ }

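/**
 * Start the asynchronous training loop.  The given measure is invoked with
 * the return of each episode completed by the deterministic evaluation
 * worker (worker 0); training stops once the measure returns true.
 *
 * A minimal usage sketch (the `agent` instance and the stop threshold below
 * are illustrative, not part of this file):
 *
 * @code
 * auto measure = [](double episodeReturn)
 * {
 *   // Hypothetical stop condition: a good-enough evaluation episode.
 *   return episodeReturn >= 100;
 * };
 * agent.Train(measure);
 * @endcode
 *
 * @param measure The measure of training progress; return true to stop.
 */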
template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
template <typename Measure>
void AsyncLearning<
  WorkerType,
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType
>::Train(Measure& measure)
{
  /**
   * OpenMP doesn't support shared class member variables,
   * so we need to copy them into local variables first.
   */
  NetworkType learningNetwork = std::move(this->learningNetwork);
  if (learningNetwork.Parameters().is_empty())
    learningNetwork.ResetParameters();
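  // Workers use a separate target network when computing training targets;
  // it starts out as a copy of the learning network, and keeping it in sync
  // is left to the WorkerType implementation.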
  NetworkType targetNetwork = learningNetwork;
  size_t totalSteps = 0;
  PolicyType policy = this->policy;
  bool stop = false;

  // Set up the worker pool; worker 0 is an extra deterministic worker used
  // for evaluation.
  std::vector<WorkerType> workers;
  for (size_t i = 0; i <= config.NumWorkers(); ++i)
  {
    workers.push_back(WorkerType(updater, environment, config, !i));
    workers.back().Initialize(learningNetwork);
  }
  // Set up the task queue corresponding to the worker pool.
  std::queue<size_t> tasks;
  for (size_t i = 0; i <= config.NumWorkers(); ++i)
    tasks.push(i);

  /**
   * Compute the number of threads for the for-loop. In general we would use
   * OpenMP tasks rather than a for-loop; we use the latter here to stay
   * compatible with some compilers. We can switch to OpenMP tasks once MSVC
   * supports OpenMP 3.0.
   */
  size_t numThreads = 0;
  #pragma omp parallel reduction(+:numThreads)
  numThreads++;
  Log::Debug << numThreads << " threads will be used in total." << std::endl;

  #pragma omp parallel for shared(stop, workers, tasks, learningNetwork, \
      targetNetwork, totalSteps, policy)
  for (omp_size_t i = 0; i < numThreads; ++i)
  {
    #pragma omp critical
    {
      #ifdef HAS_OPENMP
      Log::Debug << "Thread " << omp_get_thread_num() <<
          " started." << std::endl;
      #endif
    }
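    // max() is a sentinel value meaning "this thread currently holds no
    // task".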
    size_t task = std::numeric_limits<size_t>::max();
    while (!stop)
    {
      // Assign a task to the current thread from the queue; any previously
      // held task is returned first, so workers rotate among threads.
      #pragma omp critical
      {
        if (task != std::numeric_limits<size_t>::max())
          tasks.push(task);

        if (!tasks.empty())
        {
          task = tasks.front();
          tasks.pop();
        }
      }

      // This can happen when there are more threads than workers.
      if (task == std::numeric_limits<size_t>::max())
        continue;

      // Get the corresponding worker.
      WorkerType& worker = workers[task];
      double episodeReturn;
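      // Only the deterministic evaluation worker (task 0) reports its
      // episode return to the measure, which in turn may stop all workers.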
      if (worker.Step(learningNetwork, targetNetwork, totalSteps,
          policy, episodeReturn) && !task)
      {
        stop = measure(episodeReturn);
      }
    }
  }

  // Write back the learning network.
  this->learningNetwork = std::move(learningNetwork);
}

} // namespace rl
} // namespace mlpack

#endif