1 /* Copyright 2016, Ableton AG, Berlin. All rights reserved.
2  *
3  *  This program is free software: you can redistribute it and/or modify
4  *  it under the terms of the GNU General Public License as published by
5  *  the Free Software Foundation, either version 2 of the License, or
6  *  (at your option) any later version.
7  *
8  *  This program is distributed in the hope that it will be useful,
9  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
10  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11  *  GNU General Public License for more details.
12  *
13  *  You should have received a copy of the GNU General Public License
14  *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
15  *
16  *  If you would like to incorporate Link into a proprietary software application,
17  *  please contact <link-devs@ableton.com>.
18  */
19 
20 #pragma once
21 
22 #include <ableton/discovery/UdpMessenger.hpp>
23 #include <ableton/discovery/v1/Messages.hpp>
24 #include <ableton/platforms/asio/AsioService.hpp>
25 #include <ableton/util/SafeAsyncHandler.hpp>
26 #include <memory>
27 
28 namespace ableton
29 {
30 namespace discovery
31 {
32 
33 template <typename Messenger, typename PeerObserver, typename IoContext>
34 class PeerGateway
35 {
36 public:
37   // The peer types are defined by the observer but must match with those
38   // used by the Messenger
39   using ObserverT = typename util::Injected<PeerObserver>::type;
40   using NodeState = typename ObserverT::GatewayObserverNodeState;
41   using NodeId = typename ObserverT::GatewayObserverNodeId;
42   using Timer = typename util::Injected<IoContext>::type::Timer;
43   using TimerError = typename Timer::ErrorCode;
44 
45   PeerGateway(util::Injected<Messenger> messenger,
46     util::Injected<PeerObserver> observer,
47     util::Injected<IoContext> io)
48     : mpImpl(new Impl(std::move(messenger), std::move(observer), std::move(io)))
49   {
50     mpImpl->listen();
51   }
52 
53   PeerGateway(const PeerGateway&) = delete;
54   PeerGateway& operator=(const PeerGateway&) = delete;
55 
56   PeerGateway(PeerGateway&& rhs)
57     : mpImpl(std::move(rhs.mpImpl))
58   {
59   }
60 
61   void updateState(NodeState state)
62   {
63     mpImpl->updateState(std::move(state));
64   }
65 
66 private:
67   using PeerTimeout = std::pair<std::chrono::system_clock::time_point, NodeId>;
68   using PeerTimeouts = std::vector<PeerTimeout>;
69 
70   struct Impl : std::enable_shared_from_this<Impl>
71   {
72     Impl(util::Injected<Messenger> messenger,
73       util::Injected<PeerObserver> observer,
74       util::Injected<IoContext> io)
75       : mMessenger(std::move(messenger))
76       , mObserver(std::move(observer))
77       , mIo(std::move(io))
78       , mPruneTimer(mIo->makeTimer())
79     {
80     }
81 
82     void updateState(NodeState state)
83     {
84       mMessenger->updateState(std::move(state));
85       try
86       {
87         mMessenger->broadcastState();
88       }
89       catch (const std::runtime_error& err)
90       {
91         info(mIo->log()) << "State broadcast failed on gateway: " << err.what();
92       }
93     }
94 
95     void listen()
96     {
97       mMessenger->receive(util::makeAsyncSafe(this->shared_from_this()));
98     }
99 
100     // Operators for handling incoming messages
101     void operator()(const PeerState<NodeState>& msg)
102     {
103       onPeerState(msg.peerState, msg.ttl);
104       listen();
105     }
106 
107     void operator()(const ByeBye<NodeId>& msg)
108     {
109       onByeBye(msg.peerId);
110       listen();
111     }
112 
113     void onPeerState(const NodeState& nodeState, const int ttl)
114     {
115       using namespace std;
116       const auto peerId = nodeState.ident();
117       const auto existing = findPeer(peerId);
118       if (existing != end(mPeerTimeouts))
119       {
120         // If the peer is already present in our timeout list, remove it
121         // as it will be re-inserted below.
122         mPeerTimeouts.erase(existing);
123       }
124 
125       auto newTo = make_pair(mPruneTimer.now() + std::chrono::seconds(ttl), peerId);
126       mPeerTimeouts.insert(
127         upper_bound(begin(mPeerTimeouts), end(mPeerTimeouts), newTo, TimeoutCompare{}),
128         move(newTo));
129 
130       sawPeer(*mObserver, nodeState);
131       scheduleNextPruning();
132     }
133 
134     void onByeBye(const NodeId& peerId)
135     {
136       const auto it = findPeer(peerId);
137       if (it != mPeerTimeouts.end())
138       {
139         peerLeft(*mObserver, it->second);
140         mPeerTimeouts.erase(it);
141       }
142     }
143 
144     void pruneExpiredPeers()
145     {
146       using namespace std;
147 
148       const auto test = make_pair(mPruneTimer.now(), NodeId{});
149       debug(mIo->log()) << "pruning peers @ " << test.first.time_since_epoch().count();
150 
151       const auto endExpired =
152         lower_bound(begin(mPeerTimeouts), end(mPeerTimeouts), test, TimeoutCompare{});
153 
154       for_each(begin(mPeerTimeouts), endExpired, [this](const PeerTimeout& pto) {
155         info(mIo->log()) << "pruning peer " << pto.second;
156         peerTimedOut(*mObserver, pto.second);
157       });
158       mPeerTimeouts.erase(begin(mPeerTimeouts), endExpired);
159       scheduleNextPruning();
160     }
161 
162     void scheduleNextPruning()
163     {
164       // Find the next peer to expire and set the timer based on it
165       if (!mPeerTimeouts.empty())
166       {
167         // Add a second of padding to the timer to avoid over-eager timeouts
168         const auto t = mPeerTimeouts.front().first + std::chrono::seconds(1);
169 
170         debug(mIo->log()) << "scheduling next pruning for "
171                           << t.time_since_epoch().count() << " because of peer "
172                           << mPeerTimeouts.front().second;
173 
174         mPruneTimer.expires_at(t);
175         mPruneTimer.async_wait([this](const TimerError e) {
176           if (!e)
177           {
178             pruneExpiredPeers();
179           }
180         });
181       }
182     }
183 
184     struct TimeoutCompare
185     {
186       bool operator()(const PeerTimeout& lhs, const PeerTimeout& rhs) const
187       {
188         return lhs.first < rhs.first;
189       }
190     };
191 
192     typename PeerTimeouts::iterator findPeer(const NodeId& peerId)
193     {
194       return std::find_if(begin(mPeerTimeouts), end(mPeerTimeouts),
195         [&peerId](const PeerTimeout& pto) { return pto.second == peerId; });
196     }
197 
198     util::Injected<Messenger> mMessenger;
199     util::Injected<PeerObserver> mObserver;
200     util::Injected<IoContext> mIo;
201     Timer mPruneTimer;
202     PeerTimeouts mPeerTimeouts; // Invariant: sorted by time_point
203   };
204 
205   std::shared_ptr<Impl> mpImpl;
206 };
207 
208 template <typename Messenger, typename PeerObserver, typename IoContext>
209 PeerGateway<Messenger, PeerObserver, IoContext> makePeerGateway(
210   util::Injected<Messenger> messenger,
211   util::Injected<PeerObserver> observer,
212   util::Injected<IoContext> io)
213 {
214   return {std::move(messenger), std::move(observer), std::move(io)};
215 }
216 
// IpV4 gateway types
//
// IpV4Messenger: a UdpMessenger whose interface is an IpV4Interface. The
// interface is parameterized on a reference to the type wrapped by
// util::Injected<IoContext> and on v1::kMaxMessageSize as its message size
// limit.
template <typename StateQuery, typename IoContext>
using IpV4Messenger = UdpMessenger<
  IpV4Interface<typename util::Injected<IoContext>::type&, v1::kMaxMessageSize>,
  StateQuery,
  IoContext>;

// IpV4Gateway: a PeerGateway whose messenger is an IpV4Messenger. Note that
// the messenger is instantiated with a reference to the type wrapped by
// util::Injected<IoContext>, while the gateway itself is instantiated with
// IoContext as given.
template <typename PeerObserver, typename StateQuery, typename IoContext>
using IpV4Gateway =
  PeerGateway<IpV4Messenger<StateQuery, typename util::Injected<IoContext>::type&>,
    PeerObserver,
    IoContext>;
229 
230 // Factory function to bind a PeerGateway to an IpV4Interface with the given address.
231 template <typename PeerObserver, typename NodeState, typename IoContext>
232 IpV4Gateway<PeerObserver, NodeState, IoContext> makeIpV4Gateway(
233   util::Injected<IoContext> io,
234   const asio::ip::address_v4& addr,
235   util::Injected<PeerObserver> observer,
236   NodeState state)
237 {
238   using namespace std;
239   using namespace util;
240 
241   const uint8_t ttl = 5;
242   const uint8_t ttlRatio = 20;
243 
244   auto iface = makeIpV4Interface<v1::kMaxMessageSize>(injectRef(*io), addr);
245 
246   auto messenger =
247     makeUdpMessenger(injectVal(move(iface)), move(state), injectRef(*io), ttl, ttlRatio);
248   return {injectVal(move(messenger)), move(observer), move(io)};
249 }
250 
251 } // namespace discovery
252 } // namespace ableton
253