1 #include <csignal>
2 #include <cstdlib>
3
4 #include "spdlog/sinks/stdout_color_sinks.h"
5 #include "spdlog/spdlog.h"
6
7 #include <atomic>
8 #include <fstream>
9 #include <iostream>
10 #include <nlohmann/json.hpp>
11 #include <stdexcept>
12 #include <thread>
13 #include <variant>
14
15 #include <mtx.hpp>
16 #include <mtx/identifiers.hpp>
17
18 #include "mtxclient/crypto/client.hpp"
19 #include "mtxclient/http/client.hpp"
20 #include "mtxclient/http/errors.hpp"
21
22 #include "mtxclient/utils.hpp"
23
24 //
// Simple example bot that accepts any invite and, in joined rooms, replies (encrypted) to every encrypted message it can decrypt.
26 //
27
28 using namespace std;
29 using namespace mtx::client;
30 using namespace mtx::crypto;
31 using namespace mtx::http;
32 using namespace mtx::events;
33 using namespace mtx::identifiers;
34
35 using TimelineEvent = mtx::events::collections::TimelineEvents;
36
//! Algorithm id an incoming Olm to-device message must carry in its content.
constexpr const char *OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
//! Passphrase used to pickle/unpickle the persisted Olm account and sessions.
constexpr const char *STORAGE_KEY = "secret";
39
//! One per-recipient entry of an Olm `m.room.encrypted` ciphertext map.
struct OlmCipherContent
{
    //! Encrypted payload; handed unchanged to the Olm session calls.
    std::string body;
    //! Olm message type as received. Zero-initialised so a default-constructed
    //! value never holds an indeterminate type.
    uint8_t type = 0;
};
45
46 inline void
from_json(const nlohmann::json & obj,OlmCipherContent & msg)47 from_json(const nlohmann::json &obj, OlmCipherContent &msg)
48 {
49 msg.body = obj.at("body");
50 msg.type = obj.at("type");
51 }
52
53 struct OlmMessage
54 {
55 std::string sender_key;
56 std::string sender;
57
58 using RecipientKey = std::string;
59 std::map<RecipientKey, OlmCipherContent> ciphertext;
60 };
61
62 inline void
from_json(const nlohmann::json & obj,OlmMessage & msg)63 from_json(const nlohmann::json &obj, OlmMessage &msg)
64 {
65 if (obj.at("type") != "m.room.encrypted") {
66 throw std::invalid_argument("invalid type for olm message");
67 }
68
69 if (obj.at("content").at("algorithm") != OLM_ALGO)
70 throw std::invalid_argument("invalid algorithm for olm message");
71
72 msg.sender = obj.at("sender");
73 msg.sender_key = obj.at("content").at("sender_key");
74 msg.ciphertext =
75 obj.at("content").at("ciphertext").get<std::map<std::string, OlmCipherContent>>();
76 }
77
//! Returns true when `container` (anything with find()/end(), e.g. std::map or
//! std::set) has an entry for `item`.
template<class Container, class Item>
bool
exists(const Container &container, const Item &item)
{
    const auto position = container.find(item);
    const auto past_end = container.end();
    return !(position == past_end);
}
84
85 void
86 get_device_keys(const UserId &user);
87
88 void
89 save_device_keys(const mtx::responses::QueryKeys &res);
90
91 void
92 mark_encrypted_room(const RoomId &id);
93
94 void
95 handle_to_device_msgs(const mtx::responses::ToDevice &to_device);
96
//! Metadata persisted next to a pickled outbound Megolm session.
struct OutboundSessionData
{
    std::string session_id;    //!< Megolm session id.
    std::string session_key;   //!< Session key (the value shared in the room key event).
    uint64_t message_index{0}; //!< Message index; defaults to 0.
};
103
104 inline void
to_json(nlohmann::json & obj,const OutboundSessionData & msg)105 to_json(nlohmann::json &obj, const OutboundSessionData &msg)
106 {
107 obj["session_id"] = msg.session_id;
108 obj["session_key"] = msg.session_key;
109 obj["message_index"] = msg.message_index;
110 }
111
112 inline void
from_json(const nlohmann::json & obj,OutboundSessionData & msg)113 from_json(const nlohmann::json &obj, OutboundSessionData &msg)
114 {
115 msg.session_id = obj.at("session_id");
116 msg.session_key = obj.at("session_key");
117 msg.message_index = obj.at("message_index");
118 }
119
120 struct OutboundSessionDataRef
121 {
122 OlmOutboundGroupSession *session;
123 OutboundSessionData data;
124 };
125
//! Published identity keys of a single device.
struct DevKeys
{
    std::string ed25519;    //!< Ed25519 signing key.
    std::string curve25519; //!< Curve25519 identity key (also used to index Olm sessions).
};
131
132 inline void
to_json(nlohmann::json & obj,const DevKeys & msg)133 to_json(nlohmann::json &obj, const DevKeys &msg)
134 {
135 obj["ed25519"] = msg.ed25519;
136 obj["curve25519"] = msg.curve25519;
137 }
138
139 inline void
from_json(const nlohmann::json & obj,DevKeys & msg)140 from_json(const nlohmann::json &obj, DevKeys &msg)
141 {
142 msg.ed25519 = obj.at("ed25519");
143 msg.curve25519 = obj.at("curve25519");
144 }
145
146 auto console = spdlog::stdout_color_mt("console");
147
148 std::shared_ptr<Client> client = nullptr;
149 std::shared_ptr<OlmClient> olm_client = nullptr;
150
151 struct Storage
152 {
153 //! Storage for the user_id -> list of devices mapping.
154 std::map<std::string, std::vector<std::string>> devices;
155 //! Storage for the identity key for a device.
156 std::map<std::string, DevKeys> device_keys;
157 //! Flag that indicate if a specific room has encryption enabled.
158 std::map<std::string, bool> encrypted_rooms;
159
160 //! Keep track of members per room.
161 std::map<std::string, std::map<std::string, bool>> members;
162
add_memberStorage163 void add_member(const std::string &room_id, const std::string &user_id)
164 {
165 members[room_id][user_id] = true;
166 }
167
168 //! Mapping from curve25519 to session.
169 std::map<std::string, OlmSessionPtr> olm_inbound_sessions;
170 std::map<std::string, OlmSessionPtr> olm_outbound_sessions;
171
172 // TODO: store message_index / event_id
173 std::map<std::string, InboundGroupSessionPtr> inbound_group_sessions;
174 // TODO: store rotation period
175 std::map<std::string, OutboundSessionData> outbound_group_session_data;
176 std::map<std::string, OutboundGroupSessionPtr> outbound_group_sessions;
177
outbound_group_existsStorage178 bool outbound_group_exists(const std::string &room_id)
179 {
180 return (outbound_group_sessions.find(room_id) != outbound_group_sessions.end()) &&
181 (outbound_group_session_data.find(room_id) != outbound_group_session_data.end());
182 }
183
set_outbound_group_sessionStorage184 void set_outbound_group_session(const std::string &room_id,
185 OutboundGroupSessionPtr session,
186 OutboundSessionData data)
187 {
188 outbound_group_session_data[room_id] = data;
189 outbound_group_sessions[room_id] = std::move(session);
190 }
191
get_outbound_group_sessionStorage192 OutboundSessionDataRef get_outbound_group_session(const std::string &room_id)
193 {
194 return OutboundSessionDataRef{outbound_group_sessions[room_id].get(),
195 outbound_group_session_data[room_id]};
196 }
197
inbound_group_existsStorage198 bool inbound_group_exists(const std::string &room_id,
199 const std::string &session_id,
200 const std::string &sender_key)
201 {
202 const auto key = room_id + session_id + sender_key;
203 return inbound_group_sessions.find(key) != inbound_group_sessions.end();
204 }
205
set_inbound_group_sessionStorage206 void set_inbound_group_session(const std::string &room_id,
207 const std::string &session_id,
208 const std::string &sender_key,
209 InboundGroupSessionPtr session)
210 {
211 const auto key = room_id + session_id + sender_key;
212 inbound_group_sessions[key] = std::move(session);
213 }
214
get_inbound_group_sessionStorage215 OlmInboundGroupSession *get_inbound_group_session(const std::string &room_id,
216 const std::string &session_id,
217 const std::string &sender_key)
218 {
219 const auto key = room_id + session_id + sender_key;
220 return inbound_group_sessions[key].get();
221 }
222
loadStorage223 void load()
224 {
225 console->info("restoring storage");
226
227 ifstream db("db.json");
228 string db_data((istreambuf_iterator<char>(db)), istreambuf_iterator<char>());
229
230 if (db_data.empty())
231 return;
232
233 json obj = json::parse(db_data);
234
235 devices = obj.at("devices").get<map<string, vector<string>>>();
236 device_keys = obj.at("device_keys").get<map<string, DevKeys>>();
237 encrypted_rooms = obj.at("encrypted_rooms").get<map<string, bool>>();
238 members = obj.at("members").get<map<string, map<string, bool>>>();
239
240 if (obj.count("olm_inbound_sessions") != 0) {
241 auto sessions = obj.at("olm_inbound_sessions").get<map<string, string>>();
242 for (const auto &s : sessions)
243 olm_inbound_sessions[s.first] = unpickle<SessionObject>(s.second, STORAGE_KEY);
244 }
245
246 if (obj.count("olm_outbound_sessions") != 0) {
247 auto sessions = obj.at("olm_outbound_sessions").get<map<string, string>>();
248 for (const auto &s : sessions)
249 olm_outbound_sessions[s.first] = unpickle<SessionObject>(s.second, STORAGE_KEY);
250 }
251
252 if (obj.count("inbound_group_sessions") != 0) {
253 auto sessions = obj.at("inbound_group_sessions").get<map<string, string>>();
254 for (const auto &s : sessions)
255 inbound_group_sessions[s.first] =
256 unpickle<InboundSessionObject>(s.second, STORAGE_KEY);
257 }
258
259 if (obj.count("outbound_group_sessions") != 0) {
260 auto sessions = obj.at("outbound_group_sessions").get<map<string, string>>();
261 for (const auto &s : sessions)
262 outbound_group_sessions[s.first] =
263 unpickle<OutboundSessionObject>(s.second, STORAGE_KEY);
264 }
265
266 if (obj.count("outbound_group_session_data") != 0) {
267 auto sessions =
268 obj.at("outbound_group_session_data").get<map<string, OutboundSessionData>>();
269 for (const auto &s : sessions)
270 outbound_group_session_data[s.first] = s.second;
271 }
272 }
273
saveStorage274 void save()
275 {
276 console->info("saving storage");
277
278 std::ofstream db("db.json");
279 if (!db.is_open()) {
280 console->error("couldn't open file to save keys");
281 return;
282 }
283
284 json data;
285 data["devices"] = devices;
286 data["device_keys"] = device_keys;
287 data["encrypted_rooms"] = encrypted_rooms;
288 data["members"] = members;
289
290 // Save inbound sessions
291 for (const auto &s : olm_inbound_sessions)
292 data["olm_inbound_sessions"][s.first] =
293 mtx::crypto::pickle<SessionObject>(s.second.get(), STORAGE_KEY);
294
295 for (const auto &s : olm_outbound_sessions)
296 data["olm_outbound_sessions"][s.first] =
297 mtx::crypto::pickle<SessionObject>(s.second.get(), STORAGE_KEY);
298
299 for (const auto &s : inbound_group_sessions)
300 data["inbound_group_sessions"][s.first] =
301 mtx::crypto::pickle<InboundSessionObject>(s.second.get(), STORAGE_KEY);
302
303 for (const auto &s : outbound_group_sessions)
304 data["outbound_group_sessions"][s.first] =
305 mtx::crypto::pickle<OutboundSessionObject>(s.second.get(), STORAGE_KEY);
306
307 for (const auto &s : outbound_group_session_data)
308 data["outbound_group_session_data"][s.first] = s.second;
309
310 // Save to file
311 db << data.dump(2);
312 db.close();
313 }
314 };
315
316 namespace {
317 Storage storage;
318 }
319
320 void
print_errors(RequestErr err)321 print_errors(RequestErr err)
322 {
323 if (err->status_code)
324 console->error("status code: {}", static_cast<uint16_t>(err->status_code));
325 if (!err->matrix_error.error.empty())
326 console->error("matrix error: {}", err->matrix_error.error);
327 if (err->error_code)
328 console->error("error code: {}", err->error_code);
329 }
330
331 template<class T>
332 bool
is_room_encryption(const T & event)333 is_room_encryption(const T &event)
334 {
335 using namespace mtx::events;
336 using namespace mtx::events::state;
337 return std::holds_alternative<StateEvent<Encryption>>(event);
338 }
339
340 void
send_group_message(OlmOutboundGroupSession * session,const std::string & session_id,const std::string & room_id,const std::string & msg)341 send_group_message(OlmOutboundGroupSession *session,
342 const std::string &session_id,
343 const std::string &room_id,
344 const std::string &msg)
345 {
346 // Create event payload
347 json doc{{"type", "m.room.message"},
348 {"content", {{"type", "m.text"}, {"body", msg}}},
349 {"room_id", room_id}};
350
351 auto payload = olm_client->encrypt_group_message(session, doc.dump());
352
353 using namespace mtx::events;
354 using namespace mtx::identifiers;
355
356 msg::Encrypted data;
357 data.ciphertext = std::string((char *)payload.data(), payload.size());
358 data.sender_key = olm_client->identity_keys().curve25519;
359 data.session_id = session_id;
360 data.device_id = client->device_id();
361
362 client->send_room_message<msg::Encrypted>(
363 room_id, data, [](const mtx::responses::EventId &res, RequestErr err) {
364 if (err) {
365 print_errors(err);
366 return;
367 }
368
369 console->info("message sent with event_id: {}", res.event_id.to_string());
370 });
371 }
372
373 void
create_outbound_megolm_session(const std::string & room_id,const std::string & reply_msg)374 create_outbound_megolm_session(const std::string &room_id, const std::string &reply_msg)
375 {
376 // Create an outbound session
377 auto outbound_session = olm_client->init_outbound_group_session();
378
379 const auto session_id = mtx::crypto::session_id(outbound_session.get());
380 const auto session_key = mtx::crypto::session_key(outbound_session.get());
381
382 mtx::events::DeviceEvent<mtx::events::msg::RoomKey> megolm_payload;
383 megolm_payload.content.algorithm = "m.megolm.v1.aes-sha2";
384 megolm_payload.content.room_id = room_id;
385 megolm_payload.content.session_id = session_id;
386 megolm_payload.content.session_key = session_key;
387 megolm_payload.type = mtx::events::EventType::RoomKey;
388
389 if (storage.members.find(room_id) == storage.members.end()) {
390 console->error("no members found for room {}", room_id);
391 return;
392 }
393
394 const auto members = storage.members[room_id];
395
396 for (const auto &member : members) {
397 const auto devices = storage.devices[member.first];
398
399 // TODO: Figure out for which devices we don't have olm sessions.
400 for (const auto &dev : devices) {
401 // TODO: check if we have downloaded the keys
402 const auto device_keys = storage.device_keys[dev];
403
404 auto to_device_cb = [](RequestErr err) {
405 if (err) {
406 print_errors(err);
407 }
408 };
409
410 if (storage.olm_outbound_sessions.find(device_keys.curve25519) !=
411 storage.olm_outbound_sessions.end()) {
412 console->info("found existing olm outbound session with device {}", dev);
413 auto olm_session = storage.olm_outbound_sessions[device_keys.curve25519].get();
414
415 auto device_msg = olm_client->create_olm_encrypted_content(olm_session,
416 megolm_payload,
417 UserId(member.first),
418 device_keys.ed25519,
419 device_keys.curve25519);
420
421 json body{{"messages", {{member, {{dev, device_msg}}}}}};
422
423 client->send_to_device("m.room.encrypted", body, to_device_cb);
424 // TODO: send message to device
425 } else {
426 console->info("claiming one time keys for device {}", dev);
427 auto cb = [member = member.first, dev, megolm_payload, to_device_cb](
428 const mtx::responses::ClaimKeys &res, RequestErr err) {
429 if (err) {
430 print_errors(err);
431 return;
432 }
433
434 console->info("claimed keys for {} - {}", member, dev);
435 console->info("room_key {}", json(megolm_payload).dump(4));
436
437 console->warn("signed one time keys");
438 auto retrieved_devices = res.one_time_keys.at(member);
439 for (const auto &rd : retrieved_devices) {
440 console->info("{} : \n {}", rd.first, rd.second.dump(2));
441
442 // TODO: Verify signatures
443 auto otk = rd.second.begin()->at("key");
444 auto id_key = storage.device_keys[dev].curve25519;
445
446 auto session = olm_client->create_outbound_session(id_key, otk);
447
448 auto device_msg = olm_client->create_olm_encrypted_content(
449 session.get(),
450 megolm_payload,
451 UserId(member),
452 storage.device_keys[dev].ed25519,
453 storage.device_keys[dev].curve25519);
454
455 // TODO: saving should happen when the message is
456 // sent.
457 storage.olm_outbound_sessions[id_key] = std::move(session);
458
459 json body{{"messages", {{member, {{dev, device_msg}}}}}};
460
461 client->send_to_device("m.room.encrypted", body, to_device_cb);
462 }
463 };
464
465 mtx::requests::ClaimKeys claim_keys;
466 claim_keys.one_time_keys[member.first][dev] = SIGNED_CURVE25519;
467
468 // TODO: we should bulk request device keys here
469 client->claim_keys(claim_keys, cb);
470 }
471 }
472 }
473
474 console->info("waiting to send sendToDevice messages");
475 std::this_thread::sleep_for(std::chrono::milliseconds(2000));
476 console->info("sending encrypted group message");
477
478 // TODO: This should be done after all sendToDevice messages have been sent.
479 send_group_message(outbound_session.get(), session_id, room_id, reply_msg);
480
481 // TODO: save message index also.
482 storage.set_outbound_group_session(
483 room_id, std::move(outbound_session), {session_id, session_key});
484 }
485
486 bool
is_encrypted(const TimelineEvent & event)487 is_encrypted(const TimelineEvent &event)
488 {
489 using namespace mtx::events;
490 return std::holds_alternative<EncryptedEvent<msg::Encrypted>>(event);
491 }
492
493 template<class T>
494 bool
is_member_event(const T & event)495 is_member_event(const T &event)
496 {
497 using namespace mtx::events;
498 using namespace mtx::events::state;
499 return std::holds_alternative<StateEvent<Member>>(event);
500 }
501
502 // Check if the given event has a textual representation.
503 bool
is_room_message(const TimelineEvent & e)504 is_room_message(const TimelineEvent &e)
505 {
506 return (std::holds_alternative<mtx::events::RoomEvent<msg::Audio>>(e)) ||
507 (std::holds_alternative<mtx::events::RoomEvent<msg::Emote>>(e)) ||
508 (std::holds_alternative<mtx::events::RoomEvent<msg::File>>(e)) ||
509 (std::holds_alternative<mtx::events::RoomEvent<msg::Image>>(e)) ||
510 (std::holds_alternative<mtx::events::RoomEvent<msg::Notice>>(e)) ||
511 (std::holds_alternative<mtx::events::RoomEvent<msg::Text>>(e)) ||
512 (std::holds_alternative<mtx::events::RoomEvent<msg::Video>>(e));
513 }
514
515 // Retrieves the fallback body value from the event.
516 std::string
get_body(const TimelineEvent & e)517 get_body(const TimelineEvent &e)
518 {
519 if (auto ev = std::get_if<RoomEvent<msg::Audio>>(&e); ev != nullptr)
520 return ev->content.body;
521 else if (auto ev = std::get_if<RoomEvent<msg::Emote>>(&e); ev != nullptr)
522 return ev->content.body;
523 else if (auto ev = std::get_if<RoomEvent<msg::File>>(&e); ev != nullptr)
524 return ev->content.body;
525 else if (auto ev = std::get_if<RoomEvent<msg::Image>>(&e); ev != nullptr)
526 return ev->content.body;
527 else if (auto ev = std::get_if<RoomEvent<msg::Notice>>(&e); ev != nullptr)
528 return ev->content.body;
529 else if (auto ev = std::get_if<RoomEvent<msg::Text>>(&e); ev != nullptr)
530 return ev->content.body;
531 else if (auto ev = std::get_if<RoomEvent<msg::Video>>(&e); ev != nullptr)
532 return ev->content.body;
533
534 return "";
535 }
536
537 // Retrieves the sender of the event.
538 std::string
get_sender(const TimelineEvent & event)539 get_sender(const TimelineEvent &event)
540 {
541 return std::visit([](auto e) { return e.sender; }, event);
542 }
543
//! Serialises whichever alternative `event` holds to JSON, pretty-printed with
//! a 2-space indent.
template<class T>
std::string
get_json(const T &event)
{
    // Visit by const reference: serialisation only reads the event.
    return std::visit([](const auto &e) { return json(e).dump(2); }, event);
}
550
551 void
keys_uploaded_cb(const mtx::responses::UploadKeys &,RequestErr err)552 keys_uploaded_cb(const mtx::responses::UploadKeys &, RequestErr err)
553 {
554 if (err) {
555 print_errors(err);
556 return;
557 }
558
559 olm_client->mark_keys_as_published();
560 console->info("keys uploaded");
561 }
562
563 void
mark_encrypted_room(const RoomId & id)564 mark_encrypted_room(const RoomId &id)
565 {
566 console->info("encryption is enabled for room: {}", id.get());
567 storage.encrypted_rooms[id.get()] = true;
568 }
569
570 void
send_encrypted_reply(const std::string & room_id,const std::string & reply_msg)571 send_encrypted_reply(const std::string &room_id, const std::string &reply_msg)
572 {
573 console->info("sending reply");
574
575 // Create a megolm session if it doesn't already exist.
576 if (storage.outbound_group_exists(room_id)) {
577 auto session_obj = storage.get_outbound_group_session(room_id);
578
579 send_group_message(session_obj.session, session_obj.data.session_id, room_id, reply_msg);
580
581 } else {
582 console->info("creating new megolm outbound session");
583 create_outbound_megolm_session(room_id, reply_msg);
584 }
585 }
586
587 void
decrypt_olm_message(const OlmMessage & olm_msg)588 decrypt_olm_message(const OlmMessage &olm_msg)
589 {
590 console->info("OLM message");
591 console->info("sender: {}", olm_msg.sender);
592 console->info("sender_key: {}", olm_msg.sender_key);
593
594 const auto my_id_key = olm_client->identity_keys().curve25519;
595 for (const auto &cipher : olm_msg.ciphertext) {
596 if (cipher.first == my_id_key) {
597 const auto msg_body = cipher.second.body;
598 const auto msg_type = cipher.second.type;
599
600 console->info("the message is meant for us");
601 console->info("body: {}", msg_body);
602 console->info("type: {}", msg_type);
603
604 if (msg_type == 0) {
605 console->info("opening session with {}", olm_msg.sender);
606 auto inbound_session = olm_client->create_inbound_session(msg_body);
607
608 auto ok =
609 matches_inbound_session_from(inbound_session.get(), olm_msg.sender_key, msg_body);
610
611 if (!ok) {
612 console->error("session could not be established");
613
614 } else {
615 auto output =
616 olm_client->decrypt_message(inbound_session.get(), msg_type, msg_body);
617
618 auto plaintext = json::parse(std::string((char *)output.data(), output.size()));
619 console->info("decrypted message: \n {}", plaintext.dump(2));
620
621 storage.olm_inbound_sessions.emplace(olm_msg.sender_key,
622 std::move(inbound_session));
623
624 std::string room_id = plaintext.at("content").at("room_id");
625 std::string session_id = plaintext.at("content").at("session_id");
626 std::string session_key = plaintext.at("content").at("session_key");
627
628 if (storage.inbound_group_exists(room_id, session_id, olm_msg.sender_key)) {
629 console->warn("megolm session already exists");
630 } else {
631 auto megolm_session = olm_client->init_inbound_group_session(session_key);
632
633 storage.set_inbound_group_session(
634 room_id, session_id, olm_msg.sender_key, std::move(megolm_session));
635
636 console->info("megolm_session saved");
637 }
638 }
639 }
640 }
641 }
642 }
643
644 void
parse_messages(const mtx::responses::Sync & res)645 parse_messages(const mtx::responses::Sync &res)
646 {
647 for (const auto &room : res.rooms.invite) {
648 auto room_id = room.first;
649
650 console->info("joining room {}", room_id);
651 client->join_room(room_id, [room_id](const mtx::responses::RoomId &, RequestErr e) {
652 if (e) {
653 print_errors(e);
654 console->error("failed to join room {}", room_id);
655 return;
656 }
657 });
658 }
659
660 // Check if we have any new m.room_key messages (i.e starting a new megolm session)
661 handle_to_device_msgs(res.to_device);
662
663 // Check if the uploaded one time keys are enough
664 for (const auto &device : res.device_one_time_keys_count) {
665 if (device.second < 50) {
666 console->info("number of one time keys: {}", device.second);
667 olm_client->generate_one_time_keys(50 - device.second);
668 // TODO: Mark keys as sent
669 client->upload_keys(olm_client->create_upload_keys_request(), &keys_uploaded_cb);
670 }
671 }
672
673 for (const auto &room : res.rooms.join) {
674 const std::string room_id = room.first;
675
676 for (const auto &e : room.second.state.events) {
677 if (is_room_encryption(e)) {
678 mark_encrypted_room(RoomId(room_id));
679 console->debug("{}", get_json(e));
680 } else if (is_member_event(e)) {
681 auto m = std::get<mtx::events::StateEvent<mtx::events::state::Member>>(e);
682
683 get_device_keys(UserId(m.state_key));
684 storage.add_member(room_id, m.state_key);
685 }
686 }
687
688 for (const auto &e : room.second.timeline.events) {
689 if (is_room_encryption(e)) {
690 mark_encrypted_room(RoomId(room_id));
691 console->debug("{}", get_json(e));
692 } else if (is_member_event(e)) {
693 auto m = std::get<mtx::events::StateEvent<mtx::events::state::Member>>(e);
694
695 get_device_keys(UserId(m.state_key));
696 storage.add_member(room_id, m.state_key);
697 } else if (is_encrypted(e)) {
698 console->info("received an encrypted event: {}", room_id);
699 console->info("{}", get_json(e));
700
701 auto msg = std::get<EncryptedEvent<msg::Encrypted>>(e);
702
703 if (storage.inbound_group_exists(
704 room_id, msg.content.session_id, msg.content.sender_key)) {
705 auto res = olm_client->decrypt_group_message(
706 storage.get_inbound_group_session(
707 room_id, msg.content.session_id, msg.content.sender_key),
708 msg.content.ciphertext);
709
710 auto msg_str = std::string((char *)res.data.data(), res.data.size());
711 const auto body =
712 json::parse(msg_str).at("content").at("body").get<std::string>();
713
714 console->info("decrypted data: {}", body);
715 console->info("decrypted message_index: {}", res.message_index);
716
717 if (msg.sender != client->user_id().to_string()) {
718 // Send a reply back to the sender.
719 std::string reply_txt(msg.sender + ": you said '" + body + "'");
720 send_encrypted_reply(room_id, reply_txt);
721 }
722
723 } else {
724 console->warn("no megolm session found to decrypt the event");
725 }
726 }
727 }
728 }
729 }
730
731 // Callback to executed after a /sync request completes.
732 void
sync_handler(const mtx::responses::Sync & res,RequestErr err)733 sync_handler(const mtx::responses::Sync &res, RequestErr err)
734 {
735 SyncOpts opts;
736
737 if (err) {
738 console->error("error during sync");
739 print_errors(err);
740 opts.since = client->next_batch_token();
741 client->sync(opts, &sync_handler);
742 return;
743 }
744
745 parse_messages(res);
746
747 opts.since = res.next_batch;
748 client->set_next_batch_token(res.next_batch);
749 client->sync(opts, &sync_handler);
750 }
751
752 // Callback to executed after the first (initial) /sync request completes.
753 void
initial_sync_handler(const mtx::responses::Sync & res,RequestErr err)754 initial_sync_handler(const mtx::responses::Sync &res, RequestErr err)
755 {
756 SyncOpts opts;
757
758 if (err) {
759 console->error("error during initial sync");
760 print_errors(err);
761
762 if (err->status_code != 200) {
763 console->error("retrying initial sync ..");
764 opts.timeout = 0;
765 client->sync(opts, &initial_sync_handler);
766 }
767
768 return;
769 }
770
771 parse_messages(res);
772
773 for (const auto &room : res.rooms.join) {
774 const auto room_id = room.first;
775
776 for (const auto &e : room.second.state.events) {
777 if (is_member_event(e)) {
778 auto m = std::get<mtx::events::StateEvent<mtx::events::state::Member>>(e);
779
780 get_device_keys(UserId(m.state_key));
781 storage.add_member(room_id, m.state_key);
782 }
783 }
784 }
785
786 opts.since = res.next_batch;
787 client->set_next_batch_token(res.next_batch);
788 client->sync(opts, &sync_handler);
789 }
790
791 void
save_device_keys(const mtx::responses::QueryKeys & res)792 save_device_keys(const mtx::responses::QueryKeys &res)
793 {
794 for (const auto &entry : res.device_keys) {
795 const auto user_id = entry.first;
796
797 if (!exists(storage.devices, user_id))
798 console->info("keys for {}", user_id);
799
800 std::vector<std::string> device_list;
801 for (const auto &device : entry.second) {
802 const auto key_struct = device.second;
803
804 const std::string device_id = key_struct.device_id;
805 const std::string index = "curve25519:" + device_id;
806
807 if (key_struct.keys.find(index) == key_struct.keys.end())
808 continue;
809
810 const auto key = key_struct.keys.at(index);
811
812 if (!exists(storage.device_keys, device_id)) {
813 console->info("{} => {}", device_id, key);
814 storage.device_keys[device_id] = {key_struct.keys.at("ed25519:" + device_id),
815 key_struct.keys.at("curve25519:" + device_id)};
816 }
817
818 device_list.push_back(device_id);
819 }
820
821 if (!exists(storage.devices, user_id)) {
822 storage.devices[user_id] = device_list;
823 }
824 }
825 }
826
827 void
get_device_keys(const UserId & user)828 get_device_keys(const UserId &user)
829 {
830 // Retrieve all devices keys.
831 mtx::requests::QueryKeys query;
832 query.device_keys[user.get()] = {};
833
834 client->query_keys(query, [](const mtx::responses::QueryKeys &res, RequestErr err) {
835 if (err) {
836 print_errors(err);
837 return;
838 }
839
840 for (const auto &key : res.device_keys) {
841 const auto user_id = key.first;
842 const auto devices = key.second;
843
844 for (const auto &device : devices) {
845 const auto id = device.first;
846 const auto data = device.second;
847
848 try {
849 auto ok = verify_identity_signature(json(data), DeviceId(id), UserId(user_id));
850
851 if (!ok) {
852 console->warn("signature could not be verified");
853 console->warn(json(data).dump(2));
854 }
855 } catch (const olm_exception &e) {
856 console->warn(e.what());
857 }
858 }
859 }
860
861 save_device_keys(std::move(res));
862 });
863 }
864
865 void
handle_to_device_msgs(const mtx::responses::ToDevice & msgs)866 handle_to_device_msgs(const mtx::responses::ToDevice &msgs)
867 {
868 if (!msgs.events.empty())
869 console->info("inspecting {} to_device messages", msgs.events.size());
870
871 for (const auto &msg : msgs.events) {
872 console->info(std::visit([](const auto &e) { return json(e); }, msg).dump(2));
873
874 try {
875 OlmMessage olm_msg = std::visit([](const auto &e) { return json(e); }, msg);
876 decrypt_olm_message(std::move(olm_msg));
877 } catch (const nlohmann::json::exception &e) {
878 console->warn("parsing error for olm message: {}", e.what());
879 } catch (const std::invalid_argument &e) {
880 console->warn("validation error for olm message: {}", e.what());
881 }
882 }
883 }
884
885 void
login_cb(const mtx::responses::Login &,RequestErr err)886 login_cb(const mtx::responses::Login &, RequestErr err)
887 {
888 if (err) {
889 console->error("login error");
890 print_errors(err);
891 return;
892 }
893
894 console->info("User ID: {}", client->user_id().to_string());
895 console->info("Device ID: {}", client->device_id());
896 console->info("ed25519: {}", olm_client->identity_keys().ed25519);
897 console->info("curve25519: {}", olm_client->identity_keys().curve25519);
898
899 // Upload one time keys.
900 olm_client->set_user_id(client->user_id().to_string());
901 olm_client->set_device_id(client->device_id());
902 olm_client->generate_one_time_keys(50);
903
904 client->upload_keys(olm_client->create_upload_keys_request(),
905 [](const mtx::responses::UploadKeys &, RequestErr err) {
906 if (err) {
907 print_errors(err);
908 return;
909 }
910
911 olm_client->mark_keys_as_published();
912 console->info("keys uploaded");
913 console->debug("starting initial sync");
914
915 SyncOpts opts;
916 opts.timeout = 0;
917 client->sync(opts, &initial_sync_handler);
918 });
919 }
920
921 void
join_room_cb(const mtx::responses::RoomId &,RequestErr err)922 join_room_cb(const mtx::responses::RoomId &, RequestErr err)
923 {
924 if (err) {
925 print_errors(err);
926 return;
927 }
928
929 // Fetch device list for all users.
930 }
931
932 void
shutdown_handler(int sig)933 shutdown_handler(int sig)
934 {
935 console->warn("received {} signal", sig);
936 storage.save();
937
938 std::ofstream db("account.json");
939 if (!db.is_open()) {
940 console->error("couldn't open file to save account keys");
941 return;
942 }
943
944 json data;
945 data["account"] = olm_client->save(STORAGE_KEY);
946
947 db << data.dump(2);
948 db.close();
949
950 // The sync calls will stop.
951 client->shutdown();
952 }
953
954 int
main()955 main()
956 {
957 spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v");
958
959 std::signal(SIGINT, shutdown_handler);
960
961 std::string username("alice");
962 std::string server("localhost");
963 std::string password("secret");
964
965 client = std::make_shared<Client>(server);
966
967 olm_client = make_shared<OlmClient>();
968
969 ifstream db("account.json");
970 string db_data((istreambuf_iterator<char>(db)), istreambuf_iterator<char>());
971
972 if (db_data.empty())
973 olm_client->create_new_account();
974 else
975 olm_client->load(json::parse(db_data).at("account").get<std::string>(), STORAGE_KEY);
976
977 storage.load();
978
979 client->login(username, password, login_cb);
980 client->close();
981
982 console->info("exit");
983
984 return 0;
985 }
986