1 /* Copyright 2016, Ableton AG, Berlin. All rights reserved.
2  *
3  *  This program is free software: you can redistribute it and/or modify
4  *  it under the terms of the GNU General Public License as published by
5  *  the Free Software Foundation, either version 2 of the License, or
6  *  (at your option) any later version.
7  *
8  *  This program is distributed in the hope that it will be useful,
9  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
10  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11  *  GNU General Public License for more details.
12  *
13  *  You should have received a copy of the GNU General Public License
14  *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
15  *
16  *  If you would like to incorporate Link into a proprietary software application,
17  *  please contact <link-devs@ableton.com>.
18  */
19 
20 #pragma once
21 
#include <ableton/discovery/InterfaceScanner.hpp>
#include <ableton/platforms/asio/AsioWrapper.hpp>
#include <algorithm>
#include <chrono>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
25 
26 namespace ableton
27 {
28 namespace discovery
29 {
30 
31 // GatewayFactory must have an operator()(NodeState, IoRef, asio::ip::address)
32 // that constructs a new PeerGateway on a given interface address.
33 template <typename NodeState, typename GatewayFactory, typename IoContext>
34 class PeerGateways
35 {
36 public:
37   using IoType = typename util::Injected<IoContext>::type;
38   using Gateway = typename std::result_of<GatewayFactory(
39     NodeState, util::Injected<IoType&>, asio::ip::address)>::type;
40   using GatewayMap = std::map<asio::ip::address, Gateway>;
41 
PeerGateways(const std::chrono::seconds rescanPeriod,NodeState state,GatewayFactory factory,util::Injected<IoContext> io)42   PeerGateways(const std::chrono::seconds rescanPeriod,
43     NodeState state,
44     GatewayFactory factory,
45     util::Injected<IoContext> io)
46     : mIo(std::move(io))
47   {
48     mpScannerCallback =
49       std::make_shared<Callback>(std::move(state), std::move(factory), *mIo);
50     mpScanner = std::make_shared<Scanner>(
51       rescanPeriod, util::injectShared(mpScannerCallback), util::injectRef(*mIo));
52   }
53 
~PeerGateways()54   ~PeerGateways()
55   {
56     // Release the callback in the io thread so that gateway cleanup
57     // doesn't happen in the client thread
58     mIo->async(Deleter{*this});
59   }
60 
61   PeerGateways(const PeerGateways&) = delete;
62   PeerGateways& operator=(const PeerGateways&) = delete;
63 
64   PeerGateways(PeerGateways&&) = delete;
65   PeerGateways& operator=(PeerGateways&&) = delete;
66 
enable(const bool bEnable)67   void enable(const bool bEnable)
68   {
69     auto pCallback = mpScannerCallback;
70     auto pScanner = mpScanner;
71 
72     if (pCallback && pScanner)
73     {
74       mIo->async([pCallback, pScanner, bEnable] {
75         pCallback->mGateways.clear();
76         pScanner->enable(bEnable);
77       });
78     }
79   }
80 
81   template <typename Handler>
withGatewaysAsync(Handler handler)82   void withGatewaysAsync(Handler handler)
83   {
84     auto pCallback = mpScannerCallback;
85     if (pCallback)
86     {
87       mIo->async([pCallback, handler] {
88         handler(pCallback->mGateways.begin(), pCallback->mGateways.end());
89       });
90     }
91   }
92 
updateNodeState(const NodeState & state)93   void updateNodeState(const NodeState& state)
94   {
95     auto pCallback = mpScannerCallback;
96     if (pCallback)
97     {
98       mIo->async([pCallback, state] {
99         pCallback->mState = state;
100         for (const auto& entry : pCallback->mGateways)
101         {
102           entry.second->updateNodeState(state);
103         }
104       });
105     }
106   }
107 
108   // If a gateway has become non-responsive or is throwing exceptions,
109   // this method can be invoked to either fix it or discard it.
repairGateway(const asio::ip::address & gatewayAddr)110   void repairGateway(const asio::ip::address& gatewayAddr)
111   {
112     auto pCallback = mpScannerCallback;
113     auto pScanner = mpScanner;
114     if (pCallback && pScanner)
115     {
116       mIo->async([pCallback, pScanner, gatewayAddr] {
117         if (pCallback->mGateways.erase(gatewayAddr))
118         {
119           // If we erased a gateway, rescan again immediately so that
120           // we will re-initialize it if it's still present
121           pScanner->scan();
122         }
123       });
124     }
125   }
126 
127 private:
128   struct Callback
129   {
Callbackableton::discovery::PeerGateways::Callback130     Callback(NodeState state, GatewayFactory factory, IoType& io)
131       : mState(std::move(state))
132       , mFactory(std::move(factory))
133       , mIo(io)
134     {
135     }
136 
137     template <typename AddrRange>
operator ()ableton::discovery::PeerGateways::Callback138     void operator()(const AddrRange& range)
139     {
140       using namespace std;
141       // Get the set of current addresses.
142       vector<asio::ip::address> curAddrs;
143       curAddrs.reserve(mGateways.size());
144       transform(std::begin(mGateways), std::end(mGateways), back_inserter(curAddrs),
145         [](const typename GatewayMap::value_type& vt) { return vt.first; });
146 
147       // Now use set_difference to determine the set of addresses that
148       // are new and the set of cur addresses that are no longer there
149       vector<asio::ip::address> newAddrs;
150       set_difference(std::begin(range), std::end(range), std::begin(curAddrs),
151         std::end(curAddrs), back_inserter(newAddrs));
152 
153       vector<asio::ip::address> staleAddrs;
154       set_difference(std::begin(curAddrs), std::end(curAddrs), std::begin(range),
155         std::end(range), back_inserter(staleAddrs));
156 
157       // Remove the stale addresses
158       for (const auto& addr : staleAddrs)
159       {
160         mGateways.erase(addr);
161       }
162 
163       // Add the new addresses
164       for (const auto& addr : newAddrs)
165       {
166         try
167         {
168           // Only handle v4 for now
169           if (addr.is_v4())
170           {
171             info(mIo.log()) << "initializing peer gateway on interface " << addr;
172             mGateways.emplace(addr, mFactory(mState, util::injectRef(mIo), addr.to_v4()));
173           }
174         }
175         catch (const runtime_error& e)
176         {
177           warning(mIo.log()) << "failed to init gateway on interface " << addr
178                              << " reason: " << e.what();
179         }
180       }
181     }
182 
183     NodeState mState;
184     GatewayFactory mFactory;
185     IoType& mIo;
186     GatewayMap mGateways;
187   };
188 
189   using Scanner = InterfaceScanner<std::shared_ptr<Callback>, IoType&>;
190 
191   struct Deleter
192   {
Deleterableton::discovery::PeerGateways::Deleter193     Deleter(PeerGateways& gateways)
194       : mpScannerCallback(std::move(gateways.mpScannerCallback))
195       , mpScanner(std::move(gateways.mpScanner))
196     {
197     }
198 
operator ()ableton::discovery::PeerGateways::Deleter199     void operator()()
200     {
201       mpScanner.reset();
202       mpScannerCallback.reset();
203     }
204 
205     std::shared_ptr<Callback> mpScannerCallback;
206     std::shared_ptr<Scanner> mpScanner;
207   };
208 
209   std::shared_ptr<Callback> mpScannerCallback;
210   std::shared_ptr<Scanner> mpScanner;
211   util::Injected<IoContext> mIo;
212 };
213 
214 // Factory function
215 template <typename NodeState, typename GatewayFactory, typename IoContext>
makePeerGateways(const std::chrono::seconds rescanPeriod,NodeState state,GatewayFactory factory,util::Injected<IoContext> io)216 std::unique_ptr<PeerGateways<NodeState, GatewayFactory, IoContext>> makePeerGateways(
217   const std::chrono::seconds rescanPeriod,
218   NodeState state,
219   GatewayFactory factory,
220   util::Injected<IoContext> io)
221 {
222   using namespace std;
223   using Gateways = PeerGateways<NodeState, GatewayFactory, IoContext>;
224   return unique_ptr<Gateways>{
225     new Gateways{rescanPeriod, move(state), move(factory), move(io)}};
226 }
227 
228 } // namespace discovery
229 } // namespace ableton
230