1 #include <csignal>
2 #include <cstdlib>
3 
4 #include "spdlog/sinks/stdout_color_sinks.h"
5 #include "spdlog/spdlog.h"
6 
7 #include <atomic>
8 #include <fstream>
9 #include <iostream>
10 #include <nlohmann/json.hpp>
11 #include <stdexcept>
12 #include <thread>
13 #include <variant>
14 
15 #include <mtx.hpp>
16 #include <mtx/identifiers.hpp>
17 
18 #include "mtxclient/crypto/client.hpp"
19 #include "mtxclient/http/client.hpp"
20 #include "mtxclient/http/errors.hpp"
21 
22 #include "mtxclient/utils.hpp"
23 
24 //
// Simple example bot that accepts every room invite and echoes back the messages it can decrypt in end-to-end encrypted rooms.
26 //
27 
28 using namespace std;
29 using namespace mtx::client;
30 using namespace mtx::crypto;
31 using namespace mtx::http;
32 using namespace mtx::events;
33 using namespace mtx::identifiers;
34 
35 using TimelineEvent = mtx::events::collections::TimelineEvents;
36 
37 constexpr auto OLM_ALGO    = "m.olm.v1.curve25519-aes-sha2";
38 constexpr auto STORAGE_KEY = "secret";
39 
//! One olm ciphertext addressed to a single recipient device.
struct OlmCipherContent
{
    //! Encrypted message body.
    std::string body;
    //! Olm message type; 0 marks a message that opens a new session.
    //! Zero-initialized so a default-constructed value never holds garbage.
    uint8_t type = 0;
};
45 
46 inline void
from_json(const nlohmann::json & obj,OlmCipherContent & msg)47 from_json(const nlohmann::json &obj, OlmCipherContent &msg)
48 {
49     msg.body = obj.at("body");
50     msg.type = obj.at("type");
51 }
52 
53 struct OlmMessage
54 {
55     std::string sender_key;
56     std::string sender;
57 
58     using RecipientKey = std::string;
59     std::map<RecipientKey, OlmCipherContent> ciphertext;
60 };
61 
62 inline void
from_json(const nlohmann::json & obj,OlmMessage & msg)63 from_json(const nlohmann::json &obj, OlmMessage &msg)
64 {
65     if (obj.at("type") != "m.room.encrypted") {
66         throw std::invalid_argument("invalid type for olm message");
67     }
68 
69     if (obj.at("content").at("algorithm") != OLM_ALGO)
70         throw std::invalid_argument("invalid algorithm for olm message");
71 
72     msg.sender     = obj.at("sender");
73     msg.sender_key = obj.at("content").at("sender_key");
74     msg.ciphertext =
75       obj.at("content").at("ciphertext").get<std::map<std::string, OlmCipherContent>>();
76 }
77 
//! Returns true when `item` is a key/element of `container`.
//! Works with any container exposing find()/end(), e.g. std::map or std::set.
template<class Container, class Item>
bool
exists(const Container &container, const Item &item)
{
    const auto last = container.end();
    return !(container.find(item) == last);
}
84 
85 void
86 get_device_keys(const UserId &user);
87 
88 void
89 save_device_keys(const mtx::responses::QueryKeys &res);
90 
91 void
92 mark_encrypted_room(const RoomId &id);
93 
94 void
95 handle_to_device_msgs(const mtx::responses::ToDevice &to_device);
96 
//! Metadata stored alongside an outbound megolm session.
struct OutboundSessionData
{
    //! Identifier of the megolm session.
    std::string session_id;
    //! Session key that is shared with room members.
    std::string session_key;
    //! Message index; starts at 0.
    uint64_t message_index{0};
};
103 
104 inline void
to_json(nlohmann::json & obj,const OutboundSessionData & msg)105 to_json(nlohmann::json &obj, const OutboundSessionData &msg)
106 {
107     obj["session_id"]    = msg.session_id;
108     obj["session_key"]   = msg.session_key;
109     obj["message_index"] = msg.message_index;
110 }
111 
112 inline void
from_json(const nlohmann::json & obj,OutboundSessionData & msg)113 from_json(const nlohmann::json &obj, OutboundSessionData &msg)
114 {
115     msg.session_id    = obj.at("session_id");
116     msg.session_key   = obj.at("session_key");
117     msg.message_index = obj.at("message_index");
118 }
119 
120 struct OutboundSessionDataRef
121 {
122     OlmOutboundGroupSession *session;
123     OutboundSessionData data;
124 };
125 
//! Identity keys published by a single device.
struct DevKeys
{
    //! ed25519 fingerprint/signing key.
    std::string ed25519;
    //! curve25519 identity key used to establish olm sessions.
    std::string curve25519;
};
131 
132 inline void
to_json(nlohmann::json & obj,const DevKeys & msg)133 to_json(nlohmann::json &obj, const DevKeys &msg)
134 {
135     obj["ed25519"]    = msg.ed25519;
136     obj["curve25519"] = msg.curve25519;
137 }
138 
139 inline void
from_json(const nlohmann::json & obj,DevKeys & msg)140 from_json(const nlohmann::json &obj, DevKeys &msg)
141 {
142     msg.ed25519    = obj.at("ed25519");
143     msg.curve25519 = obj.at("curve25519");
144 }
145 
146 auto console = spdlog::stdout_color_mt("console");
147 
148 std::shared_ptr<Client> client        = nullptr;
149 std::shared_ptr<OlmClient> olm_client = nullptr;
150 
151 struct Storage
152 {
153     //! Storage for the user_id -> list of devices mapping.
154     std::map<std::string, std::vector<std::string>> devices;
155     //! Storage for the identity key for a device.
156     std::map<std::string, DevKeys> device_keys;
157     //! Flag that indicate if a specific room has encryption enabled.
158     std::map<std::string, bool> encrypted_rooms;
159 
160     //! Keep track of members per room.
161     std::map<std::string, std::map<std::string, bool>> members;
162 
add_memberStorage163     void add_member(const std::string &room_id, const std::string &user_id)
164     {
165         members[room_id][user_id] = true;
166     }
167 
168     //! Mapping from curve25519 to session.
169     std::map<std::string, OlmSessionPtr> olm_inbound_sessions;
170     std::map<std::string, OlmSessionPtr> olm_outbound_sessions;
171 
172     // TODO: store message_index / event_id
173     std::map<std::string, InboundGroupSessionPtr> inbound_group_sessions;
174     // TODO: store rotation period
175     std::map<std::string, OutboundSessionData> outbound_group_session_data;
176     std::map<std::string, OutboundGroupSessionPtr> outbound_group_sessions;
177 
outbound_group_existsStorage178     bool outbound_group_exists(const std::string &room_id)
179     {
180         return (outbound_group_sessions.find(room_id) != outbound_group_sessions.end()) &&
181                (outbound_group_session_data.find(room_id) != outbound_group_session_data.end());
182     }
183 
set_outbound_group_sessionStorage184     void set_outbound_group_session(const std::string &room_id,
185                                     OutboundGroupSessionPtr session,
186                                     OutboundSessionData data)
187     {
188         outbound_group_session_data[room_id] = data;
189         outbound_group_sessions[room_id]     = std::move(session);
190     }
191 
get_outbound_group_sessionStorage192     OutboundSessionDataRef get_outbound_group_session(const std::string &room_id)
193     {
194         return OutboundSessionDataRef{outbound_group_sessions[room_id].get(),
195                                       outbound_group_session_data[room_id]};
196     }
197 
inbound_group_existsStorage198     bool inbound_group_exists(const std::string &room_id,
199                               const std::string &session_id,
200                               const std::string &sender_key)
201     {
202         const auto key = room_id + session_id + sender_key;
203         return inbound_group_sessions.find(key) != inbound_group_sessions.end();
204     }
205 
set_inbound_group_sessionStorage206     void set_inbound_group_session(const std::string &room_id,
207                                    const std::string &session_id,
208                                    const std::string &sender_key,
209                                    InboundGroupSessionPtr session)
210     {
211         const auto key              = room_id + session_id + sender_key;
212         inbound_group_sessions[key] = std::move(session);
213     }
214 
get_inbound_group_sessionStorage215     OlmInboundGroupSession *get_inbound_group_session(const std::string &room_id,
216                                                       const std::string &session_id,
217                                                       const std::string &sender_key)
218     {
219         const auto key = room_id + session_id + sender_key;
220         return inbound_group_sessions[key].get();
221     }
222 
loadStorage223     void load()
224     {
225         console->info("restoring storage");
226 
227         ifstream db("db.json");
228         string db_data((istreambuf_iterator<char>(db)), istreambuf_iterator<char>());
229 
230         if (db_data.empty())
231             return;
232 
233         json obj = json::parse(db_data);
234 
235         devices         = obj.at("devices").get<map<string, vector<string>>>();
236         device_keys     = obj.at("device_keys").get<map<string, DevKeys>>();
237         encrypted_rooms = obj.at("encrypted_rooms").get<map<string, bool>>();
238         members         = obj.at("members").get<map<string, map<string, bool>>>();
239 
240         if (obj.count("olm_inbound_sessions") != 0) {
241             auto sessions = obj.at("olm_inbound_sessions").get<map<string, string>>();
242             for (const auto &s : sessions)
243                 olm_inbound_sessions[s.first] = unpickle<SessionObject>(s.second, STORAGE_KEY);
244         }
245 
246         if (obj.count("olm_outbound_sessions") != 0) {
247             auto sessions = obj.at("olm_outbound_sessions").get<map<string, string>>();
248             for (const auto &s : sessions)
249                 olm_outbound_sessions[s.first] = unpickle<SessionObject>(s.second, STORAGE_KEY);
250         }
251 
252         if (obj.count("inbound_group_sessions") != 0) {
253             auto sessions = obj.at("inbound_group_sessions").get<map<string, string>>();
254             for (const auto &s : sessions)
255                 inbound_group_sessions[s.first] =
256                   unpickle<InboundSessionObject>(s.second, STORAGE_KEY);
257         }
258 
259         if (obj.count("outbound_group_sessions") != 0) {
260             auto sessions = obj.at("outbound_group_sessions").get<map<string, string>>();
261             for (const auto &s : sessions)
262                 outbound_group_sessions[s.first] =
263                   unpickle<OutboundSessionObject>(s.second, STORAGE_KEY);
264         }
265 
266         if (obj.count("outbound_group_session_data") != 0) {
267             auto sessions =
268               obj.at("outbound_group_session_data").get<map<string, OutboundSessionData>>();
269             for (const auto &s : sessions)
270                 outbound_group_session_data[s.first] = s.second;
271         }
272     }
273 
saveStorage274     void save()
275     {
276         console->info("saving storage");
277 
278         std::ofstream db("db.json");
279         if (!db.is_open()) {
280             console->error("couldn't open file to save keys");
281             return;
282         }
283 
284         json data;
285         data["devices"]         = devices;
286         data["device_keys"]     = device_keys;
287         data["encrypted_rooms"] = encrypted_rooms;
288         data["members"]         = members;
289 
290         // Save inbound sessions
291         for (const auto &s : olm_inbound_sessions)
292             data["olm_inbound_sessions"][s.first] =
293               mtx::crypto::pickle<SessionObject>(s.second.get(), STORAGE_KEY);
294 
295         for (const auto &s : olm_outbound_sessions)
296             data["olm_outbound_sessions"][s.first] =
297               mtx::crypto::pickle<SessionObject>(s.second.get(), STORAGE_KEY);
298 
299         for (const auto &s : inbound_group_sessions)
300             data["inbound_group_sessions"][s.first] =
301               mtx::crypto::pickle<InboundSessionObject>(s.second.get(), STORAGE_KEY);
302 
303         for (const auto &s : outbound_group_sessions)
304             data["outbound_group_sessions"][s.first] =
305               mtx::crypto::pickle<OutboundSessionObject>(s.second.get(), STORAGE_KEY);
306 
307         for (const auto &s : outbound_group_session_data)
308             data["outbound_group_session_data"][s.first] = s.second;
309 
310         // Save to file
311         db << data.dump(2);
312         db.close();
313     }
314 };
315 
316 namespace {
317 Storage storage;
318 }
319 
320 void
print_errors(RequestErr err)321 print_errors(RequestErr err)
322 {
323     if (err->status_code)
324         console->error("status code: {}", static_cast<uint16_t>(err->status_code));
325     if (!err->matrix_error.error.empty())
326         console->error("matrix error: {}", err->matrix_error.error);
327     if (err->error_code)
328         console->error("error code: {}", err->error_code);
329 }
330 
331 template<class T>
332 bool
is_room_encryption(const T & event)333 is_room_encryption(const T &event)
334 {
335     using namespace mtx::events;
336     using namespace mtx::events::state;
337     return std::holds_alternative<StateEvent<Encryption>>(event);
338 }
339 
340 void
send_group_message(OlmOutboundGroupSession * session,const std::string & session_id,const std::string & room_id,const std::string & msg)341 send_group_message(OlmOutboundGroupSession *session,
342                    const std::string &session_id,
343                    const std::string &room_id,
344                    const std::string &msg)
345 {
346     // Create event payload
347     json doc{{"type", "m.room.message"},
348              {"content", {{"type", "m.text"}, {"body", msg}}},
349              {"room_id", room_id}};
350 
351     auto payload = olm_client->encrypt_group_message(session, doc.dump());
352 
353     using namespace mtx::events;
354     using namespace mtx::identifiers;
355 
356     msg::Encrypted data;
357     data.ciphertext = std::string((char *)payload.data(), payload.size());
358     data.sender_key = olm_client->identity_keys().curve25519;
359     data.session_id = session_id;
360     data.device_id  = client->device_id();
361 
362     client->send_room_message<msg::Encrypted>(
363       room_id, data, [](const mtx::responses::EventId &res, RequestErr err) {
364           if (err) {
365               print_errors(err);
366               return;
367           }
368 
369           console->info("message sent with event_id: {}", res.event_id.to_string());
370       });
371 }
372 
373 void
create_outbound_megolm_session(const std::string & room_id,const std::string & reply_msg)374 create_outbound_megolm_session(const std::string &room_id, const std::string &reply_msg)
375 {
376     // Create an outbound session
377     auto outbound_session = olm_client->init_outbound_group_session();
378 
379     const auto session_id  = mtx::crypto::session_id(outbound_session.get());
380     const auto session_key = mtx::crypto::session_key(outbound_session.get());
381 
382     mtx::events::DeviceEvent<mtx::events::msg::RoomKey> megolm_payload;
383     megolm_payload.content.algorithm   = "m.megolm.v1.aes-sha2";
384     megolm_payload.content.room_id     = room_id;
385     megolm_payload.content.session_id  = session_id;
386     megolm_payload.content.session_key = session_key;
387     megolm_payload.type                = mtx::events::EventType::RoomKey;
388 
389     if (storage.members.find(room_id) == storage.members.end()) {
390         console->error("no members found for room {}", room_id);
391         return;
392     }
393 
394     const auto members = storage.members[room_id];
395 
396     for (const auto &member : members) {
397         const auto devices = storage.devices[member.first];
398 
399         // TODO: Figure out for which devices we don't have olm sessions.
400         for (const auto &dev : devices) {
401             // TODO: check if we have downloaded the keys
402             const auto device_keys = storage.device_keys[dev];
403 
404             auto to_device_cb = [](RequestErr err) {
405                 if (err) {
406                     print_errors(err);
407                 }
408             };
409 
410             if (storage.olm_outbound_sessions.find(device_keys.curve25519) !=
411                 storage.olm_outbound_sessions.end()) {
412                 console->info("found existing olm outbound session with device {}", dev);
413                 auto olm_session = storage.olm_outbound_sessions[device_keys.curve25519].get();
414 
415                 auto device_msg = olm_client->create_olm_encrypted_content(olm_session,
416                                                                            megolm_payload,
417                                                                            UserId(member.first),
418                                                                            device_keys.ed25519,
419                                                                            device_keys.curve25519);
420 
421                 json body{{"messages", {{member, {{dev, device_msg}}}}}};
422 
423                 client->send_to_device("m.room.encrypted", body, to_device_cb);
424                 // TODO: send message to device
425             } else {
426                 console->info("claiming one time keys for device {}", dev);
427                 auto cb = [member = member.first, dev, megolm_payload, to_device_cb](
428                             const mtx::responses::ClaimKeys &res, RequestErr err) {
429                     if (err) {
430                         print_errors(err);
431                         return;
432                     }
433 
434                     console->info("claimed keys for {} - {}", member, dev);
435                     console->info("room_key {}", json(megolm_payload).dump(4));
436 
437                     console->warn("signed one time keys");
438                     auto retrieved_devices = res.one_time_keys.at(member);
439                     for (const auto &rd : retrieved_devices) {
440                         console->info("{} : \n {}", rd.first, rd.second.dump(2));
441 
442                         // TODO: Verify signatures
443                         auto otk    = rd.second.begin()->at("key");
444                         auto id_key = storage.device_keys[dev].curve25519;
445 
446                         auto session = olm_client->create_outbound_session(id_key, otk);
447 
448                         auto device_msg = olm_client->create_olm_encrypted_content(
449                           session.get(),
450                           megolm_payload,
451                           UserId(member),
452                           storage.device_keys[dev].ed25519,
453                           storage.device_keys[dev].curve25519);
454 
455                         // TODO: saving should happen when the message is
456                         // sent.
457                         storage.olm_outbound_sessions[id_key] = std::move(session);
458 
459                         json body{{"messages", {{member, {{dev, device_msg}}}}}};
460 
461                         client->send_to_device("m.room.encrypted", body, to_device_cb);
462                     }
463                 };
464 
465                 mtx::requests::ClaimKeys claim_keys;
466                 claim_keys.one_time_keys[member.first][dev] = SIGNED_CURVE25519;
467 
468                 // TODO: we should bulk request device keys here
469                 client->claim_keys(claim_keys, cb);
470             }
471         }
472     }
473 
474     console->info("waiting to send sendToDevice messages");
475     std::this_thread::sleep_for(std::chrono::milliseconds(2000));
476     console->info("sending encrypted group message");
477 
478     // TODO: This should be done after all sendToDevice messages have been sent.
479     send_group_message(outbound_session.get(), session_id, room_id, reply_msg);
480 
481     // TODO: save message index also.
482     storage.set_outbound_group_session(
483       room_id, std::move(outbound_session), {session_id, session_key});
484 }
485 
486 bool
is_encrypted(const TimelineEvent & event)487 is_encrypted(const TimelineEvent &event)
488 {
489     using namespace mtx::events;
490     return std::holds_alternative<EncryptedEvent<msg::Encrypted>>(event);
491 }
492 
493 template<class T>
494 bool
is_member_event(const T & event)495 is_member_event(const T &event)
496 {
497     using namespace mtx::events;
498     using namespace mtx::events::state;
499     return std::holds_alternative<StateEvent<Member>>(event);
500 }
501 
502 // Check if the given event has a textual representation.
503 bool
is_room_message(const TimelineEvent & e)504 is_room_message(const TimelineEvent &e)
505 {
506     return (std::holds_alternative<mtx::events::RoomEvent<msg::Audio>>(e)) ||
507            (std::holds_alternative<mtx::events::RoomEvent<msg::Emote>>(e)) ||
508            (std::holds_alternative<mtx::events::RoomEvent<msg::File>>(e)) ||
509            (std::holds_alternative<mtx::events::RoomEvent<msg::Image>>(e)) ||
510            (std::holds_alternative<mtx::events::RoomEvent<msg::Notice>>(e)) ||
511            (std::holds_alternative<mtx::events::RoomEvent<msg::Text>>(e)) ||
512            (std::holds_alternative<mtx::events::RoomEvent<msg::Video>>(e));
513 }
514 
515 // Retrieves the fallback body value from the event.
516 std::string
get_body(const TimelineEvent & e)517 get_body(const TimelineEvent &e)
518 {
519     if (auto ev = std::get_if<RoomEvent<msg::Audio>>(&e); ev != nullptr)
520         return ev->content.body;
521     else if (auto ev = std::get_if<RoomEvent<msg::Emote>>(&e); ev != nullptr)
522         return ev->content.body;
523     else if (auto ev = std::get_if<RoomEvent<msg::File>>(&e); ev != nullptr)
524         return ev->content.body;
525     else if (auto ev = std::get_if<RoomEvent<msg::Image>>(&e); ev != nullptr)
526         return ev->content.body;
527     else if (auto ev = std::get_if<RoomEvent<msg::Notice>>(&e); ev != nullptr)
528         return ev->content.body;
529     else if (auto ev = std::get_if<RoomEvent<msg::Text>>(&e); ev != nullptr)
530         return ev->content.body;
531     else if (auto ev = std::get_if<RoomEvent<msg::Video>>(&e); ev != nullptr)
532         return ev->content.body;
533 
534     return "";
535 }
536 
537 // Retrieves the sender of the event.
538 std::string
get_sender(const TimelineEvent & event)539 get_sender(const TimelineEvent &event)
540 {
541     return std::visit([](auto e) { return e.sender; }, event);
542 }
543 
//! Pretty-prints (indent 2) whichever event the variant currently holds.
template<class T>
std::string
get_json(const T &event)
{
    // const& avoids copying the event before serializing it.
    return std::visit([](const auto &e) { return json(e).dump(2); }, event);
}
550 
551 void
keys_uploaded_cb(const mtx::responses::UploadKeys &,RequestErr err)552 keys_uploaded_cb(const mtx::responses::UploadKeys &, RequestErr err)
553 {
554     if (err) {
555         print_errors(err);
556         return;
557     }
558 
559     olm_client->mark_keys_as_published();
560     console->info("keys uploaded");
561 }
562 
563 void
mark_encrypted_room(const RoomId & id)564 mark_encrypted_room(const RoomId &id)
565 {
566     console->info("encryption is enabled for room: {}", id.get());
567     storage.encrypted_rooms[id.get()] = true;
568 }
569 
570 void
send_encrypted_reply(const std::string & room_id,const std::string & reply_msg)571 send_encrypted_reply(const std::string &room_id, const std::string &reply_msg)
572 {
573     console->info("sending reply");
574 
575     // Create a megolm session if it doesn't already exist.
576     if (storage.outbound_group_exists(room_id)) {
577         auto session_obj = storage.get_outbound_group_session(room_id);
578 
579         send_group_message(session_obj.session, session_obj.data.session_id, room_id, reply_msg);
580 
581     } else {
582         console->info("creating new megolm outbound session");
583         create_outbound_megolm_session(room_id, reply_msg);
584     }
585 }
586 
587 void
decrypt_olm_message(const OlmMessage & olm_msg)588 decrypt_olm_message(const OlmMessage &olm_msg)
589 {
590     console->info("OLM message");
591     console->info("sender: {}", olm_msg.sender);
592     console->info("sender_key: {}", olm_msg.sender_key);
593 
594     const auto my_id_key = olm_client->identity_keys().curve25519;
595     for (const auto &cipher : olm_msg.ciphertext) {
596         if (cipher.first == my_id_key) {
597             const auto msg_body = cipher.second.body;
598             const auto msg_type = cipher.second.type;
599 
600             console->info("the message is meant for us");
601             console->info("body: {}", msg_body);
602             console->info("type: {}", msg_type);
603 
604             if (msg_type == 0) {
605                 console->info("opening session with {}", olm_msg.sender);
606                 auto inbound_session = olm_client->create_inbound_session(msg_body);
607 
608                 auto ok =
609                   matches_inbound_session_from(inbound_session.get(), olm_msg.sender_key, msg_body);
610 
611                 if (!ok) {
612                     console->error("session could not be established");
613 
614                 } else {
615                     auto output =
616                       olm_client->decrypt_message(inbound_session.get(), msg_type, msg_body);
617 
618                     auto plaintext = json::parse(std::string((char *)output.data(), output.size()));
619                     console->info("decrypted message: \n {}", plaintext.dump(2));
620 
621                     storage.olm_inbound_sessions.emplace(olm_msg.sender_key,
622                                                          std::move(inbound_session));
623 
624                     std::string room_id     = plaintext.at("content").at("room_id");
625                     std::string session_id  = plaintext.at("content").at("session_id");
626                     std::string session_key = plaintext.at("content").at("session_key");
627 
628                     if (storage.inbound_group_exists(room_id, session_id, olm_msg.sender_key)) {
629                         console->warn("megolm session already exists");
630                     } else {
631                         auto megolm_session = olm_client->init_inbound_group_session(session_key);
632 
633                         storage.set_inbound_group_session(
634                           room_id, session_id, olm_msg.sender_key, std::move(megolm_session));
635 
636                         console->info("megolm_session saved");
637                     }
638                 }
639             }
640         }
641     }
642 }
643 
644 void
parse_messages(const mtx::responses::Sync & res)645 parse_messages(const mtx::responses::Sync &res)
646 {
647     for (const auto &room : res.rooms.invite) {
648         auto room_id = room.first;
649 
650         console->info("joining room {}", room_id);
651         client->join_room(room_id, [room_id](const mtx::responses::RoomId &, RequestErr e) {
652             if (e) {
653                 print_errors(e);
654                 console->error("failed to join room {}", room_id);
655                 return;
656             }
657         });
658     }
659 
660     // Check if we have any new m.room_key messages (i.e starting a new megolm session)
661     handle_to_device_msgs(res.to_device);
662 
663     // Check if the uploaded one time keys are enough
664     for (const auto &device : res.device_one_time_keys_count) {
665         if (device.second < 50) {
666             console->info("number of one time keys: {}", device.second);
667             olm_client->generate_one_time_keys(50 - device.second);
668             // TODO: Mark keys as sent
669             client->upload_keys(olm_client->create_upload_keys_request(), &keys_uploaded_cb);
670         }
671     }
672 
673     for (const auto &room : res.rooms.join) {
674         const std::string room_id = room.first;
675 
676         for (const auto &e : room.second.state.events) {
677             if (is_room_encryption(e)) {
678                 mark_encrypted_room(RoomId(room_id));
679                 console->debug("{}", get_json(e));
680             } else if (is_member_event(e)) {
681                 auto m = std::get<mtx::events::StateEvent<mtx::events::state::Member>>(e);
682 
683                 get_device_keys(UserId(m.state_key));
684                 storage.add_member(room_id, m.state_key);
685             }
686         }
687 
688         for (const auto &e : room.second.timeline.events) {
689             if (is_room_encryption(e)) {
690                 mark_encrypted_room(RoomId(room_id));
691                 console->debug("{}", get_json(e));
692             } else if (is_member_event(e)) {
693                 auto m = std::get<mtx::events::StateEvent<mtx::events::state::Member>>(e);
694 
695                 get_device_keys(UserId(m.state_key));
696                 storage.add_member(room_id, m.state_key);
697             } else if (is_encrypted(e)) {
698                 console->info("received an encrypted event: {}", room_id);
699                 console->info("{}", get_json(e));
700 
701                 auto msg = std::get<EncryptedEvent<msg::Encrypted>>(e);
702 
703                 if (storage.inbound_group_exists(
704                       room_id, msg.content.session_id, msg.content.sender_key)) {
705                     auto res = olm_client->decrypt_group_message(
706                       storage.get_inbound_group_session(
707                         room_id, msg.content.session_id, msg.content.sender_key),
708                       msg.content.ciphertext);
709 
710                     auto msg_str = std::string((char *)res.data.data(), res.data.size());
711                     const auto body =
712                       json::parse(msg_str).at("content").at("body").get<std::string>();
713 
714                     console->info("decrypted data: {}", body);
715                     console->info("decrypted message_index: {}", res.message_index);
716 
717                     if (msg.sender != client->user_id().to_string()) {
718                         // Send a reply back to the sender.
719                         std::string reply_txt(msg.sender + ": you said '" + body + "'");
720                         send_encrypted_reply(room_id, reply_txt);
721                     }
722 
723                 } else {
724                     console->warn("no megolm session found to decrypt the event");
725                 }
726             }
727         }
728     }
729 }
730 
731 // Callback to executed after a /sync request completes.
732 void
sync_handler(const mtx::responses::Sync & res,RequestErr err)733 sync_handler(const mtx::responses::Sync &res, RequestErr err)
734 {
735     SyncOpts opts;
736 
737     if (err) {
738         console->error("error during sync");
739         print_errors(err);
740         opts.since = client->next_batch_token();
741         client->sync(opts, &sync_handler);
742         return;
743     }
744 
745     parse_messages(res);
746 
747     opts.since = res.next_batch;
748     client->set_next_batch_token(res.next_batch);
749     client->sync(opts, &sync_handler);
750 }
751 
752 // Callback to executed after the first (initial) /sync request completes.
753 void
initial_sync_handler(const mtx::responses::Sync & res,RequestErr err)754 initial_sync_handler(const mtx::responses::Sync &res, RequestErr err)
755 {
756     SyncOpts opts;
757 
758     if (err) {
759         console->error("error during initial sync");
760         print_errors(err);
761 
762         if (err->status_code != 200) {
763             console->error("retrying initial sync ..");
764             opts.timeout = 0;
765             client->sync(opts, &initial_sync_handler);
766         }
767 
768         return;
769     }
770 
771     parse_messages(res);
772 
773     for (const auto &room : res.rooms.join) {
774         const auto room_id = room.first;
775 
776         for (const auto &e : room.second.state.events) {
777             if (is_member_event(e)) {
778                 auto m = std::get<mtx::events::StateEvent<mtx::events::state::Member>>(e);
779 
780                 get_device_keys(UserId(m.state_key));
781                 storage.add_member(room_id, m.state_key);
782             }
783         }
784     }
785 
786     opts.since = res.next_batch;
787     client->set_next_batch_token(res.next_batch);
788     client->sync(opts, &sync_handler);
789 }
790 
791 void
save_device_keys(const mtx::responses::QueryKeys & res)792 save_device_keys(const mtx::responses::QueryKeys &res)
793 {
794     for (const auto &entry : res.device_keys) {
795         const auto user_id = entry.first;
796 
797         if (!exists(storage.devices, user_id))
798             console->info("keys for {}", user_id);
799 
800         std::vector<std::string> device_list;
801         for (const auto &device : entry.second) {
802             const auto key_struct = device.second;
803 
804             const std::string device_id = key_struct.device_id;
805             const std::string index     = "curve25519:" + device_id;
806 
807             if (key_struct.keys.find(index) == key_struct.keys.end())
808                 continue;
809 
810             const auto key = key_struct.keys.at(index);
811 
812             if (!exists(storage.device_keys, device_id)) {
813                 console->info("{} => {}", device_id, key);
814                 storage.device_keys[device_id] = {key_struct.keys.at("ed25519:" + device_id),
815                                                   key_struct.keys.at("curve25519:" + device_id)};
816             }
817 
818             device_list.push_back(device_id);
819         }
820 
821         if (!exists(storage.devices, user_id)) {
822             storage.devices[user_id] = device_list;
823         }
824     }
825 }
826 
827 void
get_device_keys(const UserId & user)828 get_device_keys(const UserId &user)
829 {
830     // Retrieve all devices keys.
831     mtx::requests::QueryKeys query;
832     query.device_keys[user.get()] = {};
833 
834     client->query_keys(query, [](const mtx::responses::QueryKeys &res, RequestErr err) {
835         if (err) {
836             print_errors(err);
837             return;
838         }
839 
840         for (const auto &key : res.device_keys) {
841             const auto user_id = key.first;
842             const auto devices = key.second;
843 
844             for (const auto &device : devices) {
845                 const auto id   = device.first;
846                 const auto data = device.second;
847 
848                 try {
849                     auto ok = verify_identity_signature(json(data), DeviceId(id), UserId(user_id));
850 
851                     if (!ok) {
852                         console->warn("signature could not be verified");
853                         console->warn(json(data).dump(2));
854                     }
855                 } catch (const olm_exception &e) {
856                     console->warn(e.what());
857                 }
858             }
859         }
860 
861         save_device_keys(std::move(res));
862     });
863 }
864 
865 void
handle_to_device_msgs(const mtx::responses::ToDevice & msgs)866 handle_to_device_msgs(const mtx::responses::ToDevice &msgs)
867 {
868     if (!msgs.events.empty())
869         console->info("inspecting {} to_device messages", msgs.events.size());
870 
871     for (const auto &msg : msgs.events) {
872         console->info(std::visit([](const auto &e) { return json(e); }, msg).dump(2));
873 
874         try {
875             OlmMessage olm_msg = std::visit([](const auto &e) { return json(e); }, msg);
876             decrypt_olm_message(std::move(olm_msg));
877         } catch (const nlohmann::json::exception &e) {
878             console->warn("parsing error for olm message: {}", e.what());
879         } catch (const std::invalid_argument &e) {
880             console->warn("validation error for olm message: {}", e.what());
881         }
882     }
883 }
884 
885 void
login_cb(const mtx::responses::Login &,RequestErr err)886 login_cb(const mtx::responses::Login &, RequestErr err)
887 {
888     if (err) {
889         console->error("login error");
890         print_errors(err);
891         return;
892     }
893 
894     console->info("User ID: {}", client->user_id().to_string());
895     console->info("Device ID: {}", client->device_id());
896     console->info("ed25519: {}", olm_client->identity_keys().ed25519);
897     console->info("curve25519: {}", olm_client->identity_keys().curve25519);
898 
899     // Upload one time keys.
900     olm_client->set_user_id(client->user_id().to_string());
901     olm_client->set_device_id(client->device_id());
902     olm_client->generate_one_time_keys(50);
903 
904     client->upload_keys(olm_client->create_upload_keys_request(),
905                         [](const mtx::responses::UploadKeys &, RequestErr err) {
906                             if (err) {
907                                 print_errors(err);
908                                 return;
909                             }
910 
911                             olm_client->mark_keys_as_published();
912                             console->info("keys uploaded");
913                             console->debug("starting initial sync");
914 
915                             SyncOpts opts;
916                             opts.timeout = 0;
917                             client->sync(opts, &initial_sync_handler);
918                         });
919 }
920 
921 void
join_room_cb(const mtx::responses::RoomId &,RequestErr err)922 join_room_cb(const mtx::responses::RoomId &, RequestErr err)
923 {
924     if (err) {
925         print_errors(err);
926         return;
927     }
928 
929     // Fetch device list for all users.
930 }
931 
932 void
shutdown_handler(int sig)933 shutdown_handler(int sig)
934 {
935     console->warn("received {} signal", sig);
936     storage.save();
937 
938     std::ofstream db("account.json");
939     if (!db.is_open()) {
940         console->error("couldn't open file to save account keys");
941         return;
942     }
943 
944     json data;
945     data["account"] = olm_client->save(STORAGE_KEY);
946 
947     db << data.dump(2);
948     db.close();
949 
950     // The sync calls will stop.
951     client->shutdown();
952 }
953 
954 int
main()955 main()
956 {
957     spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v");
958 
959     std::signal(SIGINT, shutdown_handler);
960 
961     std::string username("alice");
962     std::string server("localhost");
963     std::string password("secret");
964 
965     client = std::make_shared<Client>(server);
966 
967     olm_client = make_shared<OlmClient>();
968 
969     ifstream db("account.json");
970     string db_data((istreambuf_iterator<char>(db)), istreambuf_iterator<char>());
971 
972     if (db_data.empty())
973         olm_client->create_new_account();
974     else
975         olm_client->load(json::parse(db_data).at("account").get<std::string>(), STORAGE_KEY);
976 
977     storage.load();
978 
979     client->login(username, password, login_cb);
980     client->close();
981 
982     console->info("exit");
983 
984     return 0;
985 }
986