1 /* Copyright 2016, Ableton AG, Berlin. All rights reserved.
2 *
3 * This program is free software: you can redistribute it and/or modify
4 * it under the terms of the GNU General Public License as published by
5 * the Free Software Foundation, either version 2 of the License, or
6 * (at your option) any later version.
7 *
8 * This program is distributed in the hope that it will be useful,
9 * but WITHOUT ANY WARRANTY; without even the implied warranty of
10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 * GNU General Public License for more details.
12 *
13 * You should have received a copy of the GNU General Public License
14 * along with this program. If not, see <http://www.gnu.org/licenses/>.
15 *
16 * If you would like to incorporate Link into a proprietary software application,
17 * please contact <link-devs@ableton.com>.
18 */
19
20 #pragma once
21
#include <ableton/discovery/InterfaceScanner.hpp>
#include <ableton/platforms/asio/AsioWrapper.hpp>
#include <algorithm>
#include <chrono>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
25
26 namespace ableton
27 {
28 namespace discovery
29 {
30
31 // GatewayFactory must have an operator()(NodeState, IoRef, asio::ip::address)
32 // that constructs a new PeerGateway on a given interface address.
33 template <typename NodeState, typename GatewayFactory, typename IoContext>
34 class PeerGateways
35 {
36 public:
37 using IoType = typename util::Injected<IoContext>::type;
38 using Gateway = typename std::result_of<GatewayFactory(
39 NodeState, util::Injected<IoType&>, asio::ip::address)>::type;
40 using GatewayMap = std::map<asio::ip::address, Gateway>;
41
PeerGateways(const std::chrono::seconds rescanPeriod,NodeState state,GatewayFactory factory,util::Injected<IoContext> io)42 PeerGateways(const std::chrono::seconds rescanPeriod,
43 NodeState state,
44 GatewayFactory factory,
45 util::Injected<IoContext> io)
46 : mIo(std::move(io))
47 {
48 mpScannerCallback =
49 std::make_shared<Callback>(std::move(state), std::move(factory), *mIo);
50 mpScanner = std::make_shared<Scanner>(
51 rescanPeriod, util::injectShared(mpScannerCallback), util::injectRef(*mIo));
52 }
53
~PeerGateways()54 ~PeerGateways()
55 {
56 // Release the callback in the io thread so that gateway cleanup
57 // doesn't happen in the client thread
58 mIo->async(Deleter{*this});
59 }
60
61 PeerGateways(const PeerGateways&) = delete;
62 PeerGateways& operator=(const PeerGateways&) = delete;
63
64 PeerGateways(PeerGateways&&) = delete;
65 PeerGateways& operator=(PeerGateways&&) = delete;
66
enable(const bool bEnable)67 void enable(const bool bEnable)
68 {
69 auto pCallback = mpScannerCallback;
70 auto pScanner = mpScanner;
71
72 if (pCallback && pScanner)
73 {
74 mIo->async([pCallback, pScanner, bEnable] {
75 pCallback->mGateways.clear();
76 pScanner->enable(bEnable);
77 });
78 }
79 }
80
81 template <typename Handler>
withGatewaysAsync(Handler handler)82 void withGatewaysAsync(Handler handler)
83 {
84 auto pCallback = mpScannerCallback;
85 if (pCallback)
86 {
87 mIo->async([pCallback, handler] {
88 handler(pCallback->mGateways.begin(), pCallback->mGateways.end());
89 });
90 }
91 }
92
updateNodeState(const NodeState & state)93 void updateNodeState(const NodeState& state)
94 {
95 auto pCallback = mpScannerCallback;
96 if (pCallback)
97 {
98 mIo->async([pCallback, state] {
99 pCallback->mState = state;
100 for (const auto& entry : pCallback->mGateways)
101 {
102 entry.second->updateNodeState(state);
103 }
104 });
105 }
106 }
107
108 // If a gateway has become non-responsive or is throwing exceptions,
109 // this method can be invoked to either fix it or discard it.
repairGateway(const asio::ip::address & gatewayAddr)110 void repairGateway(const asio::ip::address& gatewayAddr)
111 {
112 auto pCallback = mpScannerCallback;
113 auto pScanner = mpScanner;
114 if (pCallback && pScanner)
115 {
116 mIo->async([pCallback, pScanner, gatewayAddr] {
117 if (pCallback->mGateways.erase(gatewayAddr))
118 {
119 // If we erased a gateway, rescan again immediately so that
120 // we will re-initialize it if it's still present
121 pScanner->scan();
122 }
123 });
124 }
125 }
126
127 private:
128 struct Callback
129 {
Callbackableton::discovery::PeerGateways::Callback130 Callback(NodeState state, GatewayFactory factory, IoType& io)
131 : mState(std::move(state))
132 , mFactory(std::move(factory))
133 , mIo(io)
134 {
135 }
136
137 template <typename AddrRange>
operator ()ableton::discovery::PeerGateways::Callback138 void operator()(const AddrRange& range)
139 {
140 using namespace std;
141 // Get the set of current addresses.
142 vector<asio::ip::address> curAddrs;
143 curAddrs.reserve(mGateways.size());
144 transform(std::begin(mGateways), std::end(mGateways), back_inserter(curAddrs),
145 [](const typename GatewayMap::value_type& vt) { return vt.first; });
146
147 // Now use set_difference to determine the set of addresses that
148 // are new and the set of cur addresses that are no longer there
149 vector<asio::ip::address> newAddrs;
150 set_difference(std::begin(range), std::end(range), std::begin(curAddrs),
151 std::end(curAddrs), back_inserter(newAddrs));
152
153 vector<asio::ip::address> staleAddrs;
154 set_difference(std::begin(curAddrs), std::end(curAddrs), std::begin(range),
155 std::end(range), back_inserter(staleAddrs));
156
157 // Remove the stale addresses
158 for (const auto& addr : staleAddrs)
159 {
160 mGateways.erase(addr);
161 }
162
163 // Add the new addresses
164 for (const auto& addr : newAddrs)
165 {
166 try
167 {
168 // Only handle v4 for now
169 if (addr.is_v4())
170 {
171 info(mIo.log()) << "initializing peer gateway on interface " << addr;
172 mGateways.emplace(addr, mFactory(mState, util::injectRef(mIo), addr.to_v4()));
173 }
174 }
175 catch (const runtime_error& e)
176 {
177 warning(mIo.log()) << "failed to init gateway on interface " << addr
178 << " reason: " << e.what();
179 }
180 }
181 }
182
183 NodeState mState;
184 GatewayFactory mFactory;
185 IoType& mIo;
186 GatewayMap mGateways;
187 };
188
189 using Scanner = InterfaceScanner<std::shared_ptr<Callback>, IoType&>;
190
191 struct Deleter
192 {
Deleterableton::discovery::PeerGateways::Deleter193 Deleter(PeerGateways& gateways)
194 : mpScannerCallback(std::move(gateways.mpScannerCallback))
195 , mpScanner(std::move(gateways.mpScanner))
196 {
197 }
198
operator ()ableton::discovery::PeerGateways::Deleter199 void operator()()
200 {
201 mpScanner.reset();
202 mpScannerCallback.reset();
203 }
204
205 std::shared_ptr<Callback> mpScannerCallback;
206 std::shared_ptr<Scanner> mpScanner;
207 };
208
209 std::shared_ptr<Callback> mpScannerCallback;
210 std::shared_ptr<Scanner> mpScanner;
211 util::Injected<IoContext> mIo;
212 };
213
214 // Factory function
215 template <typename NodeState, typename GatewayFactory, typename IoContext>
makePeerGateways(const std::chrono::seconds rescanPeriod,NodeState state,GatewayFactory factory,util::Injected<IoContext> io)216 std::unique_ptr<PeerGateways<NodeState, GatewayFactory, IoContext>> makePeerGateways(
217 const std::chrono::seconds rescanPeriod,
218 NodeState state,
219 GatewayFactory factory,
220 util::Injected<IoContext> io)
221 {
222 using namespace std;
223 using Gateways = PeerGateways<NodeState, GatewayFactory, IoContext>;
224 return unique_ptr<Gateways>{
225 new Gateways{rescanPeriod, move(state), move(factory), move(io)}};
226 }
227
228 } // namespace discovery
229 } // namespace ableton
230