1 // Copyright (C) 2009 Timothy Brownawell <tbrownaw@prjek.net>
2 //
3 // This program is made available under the GNU GPL version 2.0 or
4 // greater. See the accompanying file COPYING for details.
5 //
6 // This program is distributed WITHOUT ANY WARRANTY; without even the
7 // implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
8 // PURPOSE.
9 
10 #include "../base.hh"
11 #include "../app_state.hh"
12 #include "../key_store.hh"
13 #include "../database.hh"
14 #include "../keys.hh"
15 #include "../lazy_rng.hh"
16 #include "../lua_hooks.hh"
17 #include "automate_session.hh"
18 #include "netsync_session.hh"
19 #include "../options.hh"
20 #include "../project.hh"
21 #include "../vocab_cast.hh"
22 #include "session.hh"
23 #include "automate_session.hh"
24 #include "netsync_session.hh"
25 
26 using std::string;
27 
28 using boost::lexical_cast;
29 using boost::shared_ptr;
30 
31 
32 static const var_domain known_servers_domain = var_domain("known-servers");
33 
34 size_t session::session_num = 0;
35 
session(app_state & app,project_t & project,key_store & keys,protocol_voice voice,std::string const & peer,shared_ptr<Netxx::StreamBase> sock)36 session::session(app_state & app, project_t & project,
37                  key_store & keys,
38                  protocol_voice voice,
39                  std::string const & peer,
40                  shared_ptr<Netxx::StreamBase> sock) :
41   session_base(voice, peer, sock),
42   version(app.opts.max_netsync_version),
43   max_version(app.opts.max_netsync_version),
44   min_version(app.opts.min_netsync_version),
45   use_transport_auth(!app.opts.no_transport_auth),
46   signing_key(keys.signing_key),
47   cmd_in(0),
48   armed(false),
49   received_remote_key(false),
50   session_key(constants::netsync_key_initializer),
51   read_hmac(netsync_session_key(constants::netsync_key_initializer),
52             use_transport_auth),
53   write_hmac(netsync_session_key(constants::netsync_key_initializer),
54              use_transport_auth),
55   authenticated(false),
56   completed_hello(false),
57   error_code(0),
58   session_id(++session_num),
59   app(app),
60   project(project),
61   keys(keys),
62   peer(peer),
63   unnoted_bytes_in(0),
64   unnoted_bytes_out(0)
65 {
66   if (!app.opts.max_netsync_version_given)
67     {
68       max_version = constants::netcmd_current_protocol_version;
69       version = max_version;
70     }
71   if (!app.opts.min_netsync_version_given)
72     min_version = constants::netcmd_minimum_protocol_version;
73 }
74 
~session()75 session::~session()
76 {
77   if (wrapped)
78     wrapped->on_end(session_id);
79 }
80 
set_inner(shared_ptr<wrapped_session> wrapped)81 void session::set_inner(shared_ptr<wrapped_session> wrapped)
82 {
83   this->wrapped = wrapped;
84 }
85 
86 id
mk_nonce()87 session::mk_nonce()
88 {
89   I(this->saved_nonce().empty());
90   char buf[constants::merkle_hash_length_in_bytes];
91 
92 #if BOTAN_VERSION_CODE >= BOTAN_VERSION_CODE_FOR(1,7,7)
93   lazy_rng::get().randomize(reinterpret_cast<Botan::byte *>(buf),
94                             constants::merkle_hash_length_in_bytes);
95 #else
96   Botan::Global_RNG::randomize(reinterpret_cast<Botan::byte *>(buf),
97                                constants::merkle_hash_length_in_bytes);
98 #endif
99   this->saved_nonce = id(string(buf, buf + constants::merkle_hash_length_in_bytes),
100                          origin::internal);
101   I(this->saved_nonce().size() == constants::merkle_hash_length_in_bytes);
102   return this->saved_nonce;
103 }
104 
105 void
set_session_key(string const & key)106 session::set_session_key(string const & key)
107 {
108   session_key = netsync_session_key(key, origin::internal);
109   read_hmac.set_key(session_key);
110   write_hmac.set_key(session_key);
111 }
112 
113 void
set_session_key(rsa_oaep_sha_data const & hmac_key_encrypted)114 session::set_session_key(rsa_oaep_sha_data const & hmac_key_encrypted)
115 {
116   MM(use_transport_auth);
117   if (use_transport_auth)
118     {
119       MM(signing_key);
120       string hmac_key;
121       keys.decrypt_rsa(signing_key, hmac_key_encrypted, hmac_key);
122       set_session_key(hmac_key);
123     }
124 }
125 
arm()126 bool session::arm()
127 {
128   if (!armed)
129     {
130       // Don't pack the buffer unnecessarily.
131       if (output_overfull())
132         return false;
133 
134       if (cmd_in.read(min_version, max_version, inbuf, read_hmac))
135         {
136           L(FL("armed with netcmd having code '%d'") % cmd_in.get_cmd_code());
137           armed = true;
138         }
139     }
140   return armed;
141 }
142 
begin_service()143 void session::begin_service()
144 {
145   netcmd cmd(0);
146   cmd.write_usher_cmd(utf8("", origin::internal));
147   write_netcmd(cmd);
148 }
149 
do_work(transaction_guard & guard)150 bool session::do_work(transaction_guard & guard)
151 {
152   try
153     {
154       arm();
155       bool is_goodbye = armed && cmd_in.get_cmd_code() == bye_cmd;
156       bool is_error = armed && cmd_in.get_cmd_code() == error_cmd;
157       if (completed_hello && !is_goodbye && !is_error)
158         {
159           if (encountered_error)
160             return true;
161           else
162             {
163               if (armed)
164                 L(FL("doing work for peer '%s' with '%d' netcmd")
165                   % get_peer() % cmd_in.get_cmd_code());
166               else
167                 L(FL("doing work for peer '%s' with no netcmd")
168                   % get_peer());
169               bool ok = wrapped->do_work(guard, armed ? &cmd_in : 0);
170               armed = false;
171               if (ok)
172                 {
173                   if (voice == client_voice
174                       && protocol_state == working_state
175                       && wrapped->finished_working())
176                     {
177                       protocol_state = shutdown_state;
178                       guard.do_checkpoint();
179                       queue_bye_cmd(0);
180                     }
181                 }
182               return ok;
183             }
184         }
185       else
186         {
187           if (!armed)
188             return true;
189           armed = false;
190           switch (cmd_in.get_cmd_code())
191             {
192             case usher_cmd:
193               {
194                 utf8 msg;
195                 cmd_in.read_usher_cmd(msg);
196                 if (msg().size())
197                   {
198                     if (msg()[0] == '!')
199                       P(F("received warning from usher: %s") % msg().substr(1));
200                     else
201                       L(FL("received greeting from usher: %s") % msg().substr(1));
202                   }
203                 netcmd cmdout(version);
204                 cmdout.write_usher_reply_cmd(utf8(peer_id, origin::internal),
205                                              wrapped->usher_reply_data());
206                 write_netcmd(cmdout);
207                 L(FL("sent reply."));
208                 return true;
209               }
210             case usher_reply_cmd:
211               {
212                 u8 client_version;
213                 utf8 server;
214                 string pattern;
215                 cmd_in.read_usher_reply_cmd(client_version, server, pattern);
216 
217                 // netcmd::read() has already checked that the client isn't too old
218                 if (client_version < max_version)
219                   {
220                     version = client_version;
221                   }
222                 L(FL("client has maximum version %d, using %d")
223                   % widen<u32>(client_version) % widen<u32>(version));
224                 netcmd cmd(version);
225 
226                 key_name name;
227                 keypair kp;
228                 if (use_transport_auth)
229                   {
230                     keys.get_key_pair(signing_key, name, kp);
231                     cmd.write_hello_cmd(name, kp.pub, mk_nonce());
232                   }
233                 else
234                   {
235                     cmd.write_hello_cmd(name, kp.pub, mk_nonce());
236                   }
237                 write_netcmd(cmd);
238                 return true;
239               }
240             case hello_cmd:
241               { // need to ask wrapped what to reply with (we're a client)
242                 u8 server_version;
243                 key_name their_keyname;
244                 rsa_pub_key their_key;
245                 id nonce;
246                 cmd_in.read_hello_cmd(server_version, their_keyname,
247                                       their_key, nonce);
248                 hello_nonce = nonce;
249 
250                 I(!received_remote_key);
251                 I(saved_nonce().empty());
252 
253                 // version sanity has already been checked by netcmd::read()
254                 L(FL("received hello command; setting version from %d to %d")
255                   % widen<u32>(get_version())
256                   % widen<u32>(server_version));
257                 version = server_version;
258 
259                 if (use_transport_auth)
260                   {
261                     key_hash_code(their_keyname, their_key, remote_peer_key_id);
262 
263                     var_value printable_key_hash;
264                     {
265                       hexenc<id> encoded_key_hash;
266                       encode_hexenc(remote_peer_key_id.inner(), encoded_key_hash);
267                       printable_key_hash = typecast_vocab<var_value>(encoded_key_hash);
268                     }
269                     L(FL("server key has name %s, hash %s")
270                       % their_keyname % printable_key_hash);
271                     var_key their_key_key(known_servers_domain,
272                                           var_name(get_peer(), origin::internal));
273                     if (project.db.var_exists(their_key_key))
274                       {
275                         var_value expected_key_hash;
276                         project.db.get_var(their_key_key, expected_key_hash);
277                         if (expected_key_hash != printable_key_hash)
278                           {
279                             P(F("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
280                                 "@ WARNING: SERVER IDENTIFICATION HAS CHANGED              @\n"
281                                 "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
282                                 "IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY\n"
283                                 "It is also possible that the server key has just been changed.\n"
284                                 "Remote host sent key %s,\n"
285                                 "I expected %s\n"
286                                 "'%s unset %s %s' overrides this check")
287                               % printable_key_hash
288                               % expected_key_hash
289                               % prog_name % their_key_key.first % their_key_key.second);
290                             E(false, origin::network, F("server key changed"));
291                           }
292                       }
293                     else
294                       {
295                         P(F("first time connecting to server %s.\n"
296                             "I'll assume it's really them, but you might want to double-check\n"
297                             "their key's fingerprint: %s")
298                           % get_peer()
299                           % printable_key_hash);
300                         project.db.set_var(their_key_key, printable_key_hash);
301                       }
302 
303                     if (!project.db.public_key_exists(remote_peer_key_id))
304                       {
305                         // this should now always return true since we just checked
306                         // for the existence of this particular key
307                         I(project.db.put_key(their_keyname, their_key));
308                         W(F("saving public key for %s to database") % their_keyname);
309                       }
310                     {
311                       hexenc<id> hnonce;
312                       encode_hexenc(nonce, hnonce);
313                       L(FL("received 'hello' netcmd from server '%s' with nonce '%s'")
314                         % printable_key_hash % hnonce);
315                     }
316 
317                     I(project.db.public_key_exists(remote_peer_key_id));
318 
319                     // save their identity
320                     received_remote_key = true;
321                   }
322 
323                 wrapped->request_service();
324 
325               }
326               return true;
327 
328             case anonymous_cmd:
329             case auth_cmd:
330             case automate_cmd:
331               return handle_service_request();
332 
333             case confirm_cmd:
334               {
335                 authenticated = true; // maybe?
336                 completed_hello = true;
337                 wrapped->accept_service();
338               }
339               return true;
340 
341             case bye_cmd:
342               {
343                 u8 phase;
344                 cmd_in.read_bye_cmd(phase);
345                 return process_bye_cmd(phase, guard);
346               }
347             case error_cmd:
348               {
349                 string errmsg;
350                 cmd_in.read_error_cmd(errmsg);
351 
352                 // "xxx string" with xxx being digits means there's an error code
353                 if (errmsg.size() > 4 && errmsg.substr(3,1) == " ")
354                   {
355                     try
356                       {
357                         int err = boost::lexical_cast<int>(errmsg.substr(0,3));
358                         if (err >= 100)
359                           {
360                             error_code = err;
361                             throw bad_decode(F("received network error: %s")
362                                              % errmsg.substr(4));
363                           }
364                       }
365                     catch (boost::bad_lexical_cast)
366                       { // ok, so it wasn't a number
367                       }
368                   }
369                 throw bad_decode(F("received network error: %s") % errmsg);
370               }
371             default:
372               // ERROR
373               return false;
374             }
375         }
376     }
377   catch (netsync_error & err)
378     {
379       W(F("error: %s") % err.msg);
380       string const errmsg(lexical_cast<string>(error_code) + " " + err.msg);
381       L(FL("queueing 'error' command"));
382       netcmd cmd(get_version());
383       cmd.write_error_cmd(errmsg);
384       write_netcmd(cmd);
385       encountered_error = true;
386       return true; // Don't terminate until we've send the error_cmd.
387     }
388 }
389 
390 void
request_netsync(protocol_role role,globish const & our_include_pattern,globish const & our_exclude_pattern)391 session::request_netsync(protocol_role role,
392                          globish const & our_include_pattern,
393                          globish const & our_exclude_pattern)
394 {
395   MM(use_transport_auth);
396   id nonce2(mk_nonce());
397   netcmd request(version);
398   rsa_oaep_sha_data hmac_key_encrypted;
399   if (use_transport_auth)
400     project.db.encrypt_rsa(remote_peer_key_id, nonce2(), hmac_key_encrypted);
401 
402   if (use_transport_auth && signing_key.inner()() != "")
403     {
404       // get our key pair
405       load_key_pair(keys, signing_key);
406 
407       // make a signature with it;
408       // this also ensures our public key is in the database
409       rsa_sha1_signature sig;
410       keys.make_signature(project.db, signing_key, hello_nonce(), sig);
411 
412       request.write_auth_cmd(role, our_include_pattern, our_exclude_pattern,
413                              signing_key, hello_nonce,
414                              hmac_key_encrypted, sig);
415     }
416   else
417     {
418       request.write_anonymous_cmd(role, our_include_pattern, our_exclude_pattern,
419                                   hmac_key_encrypted);
420     }
421   write_netcmd(request);
422   set_session_key(nonce2());
423 
424   key_identity_info remote_key;
425   remote_key.id = remote_peer_key_id;
426   if (!remote_key.id.inner()().empty())
427     project.complete_key_identity_from_id(keys, app.lua, remote_key);
428 
429   wrapped->on_begin(session_id, remote_key);
430 }
431 
432 void
request_automate()433 session::request_automate()
434 {
435   MM(use_transport_auth);
436   id nonce2(mk_nonce());
437   rsa_oaep_sha_data hmac_key_encrypted;
438   rsa_sha1_signature sig;
439   if (use_transport_auth)
440     {
441       project.db.encrypt_rsa(remote_peer_key_id, nonce2(), hmac_key_encrypted);
442 
443       if (signing_key.inner()() != "")
444         {
445           keys.make_signature(project.db, signing_key, hello_nonce(), sig);
446         }
447     }
448 
449   netcmd request(version);
450   request.write_automate_cmd(signing_key, hello_nonce,
451                              hmac_key_encrypted, sig);
452   write_netcmd(request);
453   set_session_key(nonce2());
454 
455   key_identity_info remote_key;
456   remote_key.id = remote_peer_key_id;
457   if (!remote_key.id.inner()().empty())
458     project.complete_key_identity_from_id(keys, app.lua, remote_key);
459 
460   wrapped->on_begin(session_id, remote_key);
461 }
462 
463 void
queue_bye_cmd(u8 phase)464 session::queue_bye_cmd(u8 phase)
465 {
466   L(FL("queueing 'bye' command, phase %d")
467     % static_cast<size_t>(phase));
468   netcmd cmd(get_version());
469   cmd.write_bye_cmd(phase);
470   write_netcmd(cmd);
471 }
472 
473 bool
process_bye_cmd(u8 phase,transaction_guard & guard)474 session::process_bye_cmd(u8 phase,
475                          transaction_guard & guard)
476 {
477 
478 // Ideal shutdown
479 // ~~~~~~~~~~~~~~~
480 //
481 //             I/O events                 state transitions
482 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~   ~~~~~~~~~~~~~~~~~~~
483 //                                        client: C_WORKING
484 //                                        server: S_WORKING
485 // 0. [refinement, data, deltas, etc.]
486 //                                        client: C_SHUTDOWN
487 //                                        (client checkpoints here)
488 // 1. client -> "bye 0"
489 // 2.           "bye 0"  -> server
490 //                                        server: S_SHUTDOWN
491 //                                        (server checkpoints here)
492 // 3.           "bye 1"  <- server
493 // 4. client <- "bye 1"
494 //                                        client: C_CONFIRMED
495 // 5. client -> "bye 2"
496 // 6.           "bye 2"  -> server
497 //                                        server: S_CONFIRMED
498 // 7. [server drops connection]
499 //
500 //
501 // Affects of I/O errors or disconnections
502 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
503 //   C_WORKING: report error and fault
504 //   S_WORKING: report error and recover
505 //  C_SHUTDOWN: report error and fault
506 //  S_SHUTDOWN: report success and recover
507 //              (and warn that client might falsely see error)
508 // C_CONFIRMED: report success
509 // S_CONFIRMED: report success
510 
511   switch (phase)
512     {
513     case 0:
514       if (voice == server_voice &&
515           protocol_state == working_state)
516         {
517           protocol_state = shutdown_state;
518           guard.do_checkpoint();
519           queue_bye_cmd(1);
520         }
521       else
522         error(error_codes::bad_command,
523               "unexpected bye phase 0 received");
524       break;
525 
526     case 1:
527       if (voice == client_voice &&
528           protocol_state == shutdown_state)
529         {
530           protocol_state = confirmed_state;
531           queue_bye_cmd(2);
532         }
533       else
534         error(error_codes::bad_command, "unexpected bye phase 1 received");
535       break;
536 
537     case 2:
538       if (voice == server_voice &&
539           protocol_state == shutdown_state)
540         {
541           protocol_state = confirmed_state;
542           return false;
543         }
544       else
545         error(error_codes::bad_command, "unexpected bye phase 2 received");
546       break;
547 
548     default:
549       error(error_codes::bad_command,
550             (F("unknown bye phase %d received") % phase).str());
551     }
552 
553   return true;
554 }
555 
556 static
557 protocol_role
corresponding_role(protocol_role their_role)558 corresponding_role(protocol_role their_role)
559 {
560   switch (their_role)
561     {
562     case source_role:
563       return sink_role;
564     case source_and_sink_role:
565       return source_and_sink_role;
566     case sink_role:
567       return source_role;
568     }
569   I(false);
570 }
571 
// Server side: handle the client's service request (anonymous_cmd,
// auth_cmd or automate_cmd) that follows our hello.
//
// The steps are:
//   1. Decode the request.
//   2. Install the client-supplied HMAC session key.
//   3. For auth/automate requests, verify the client's identity: the nonce
//      must match the one we sent and the signature must check out.
//   4. Create the wrapped service (netsync_session or automate_session).
//   5. Notify that service and queue a confirm_cmd.
// Verification failures go through error(), which throws netsync_error.
// Otherwise the function always returns true.
bool session::handle_service_request()
{
  enum { is_netsync, is_automate } is_what;
  bool auth;

  // netsync parameters
  protocol_role role;
  globish their_include;
  globish their_exclude;

  // auth parameters
  key_id client_id;
  id nonce1;
  rsa_sha1_signature sig;

  rsa_oaep_sha_data hmac_encrypted;


  switch (cmd_in.get_cmd_code())
    {
    case anonymous_cmd:
      cmd_in.read_anonymous_cmd(role, their_include, their_exclude,
                                hmac_encrypted);
      L(FL("received 'anonymous' netcmd from client for pattern '%s' excluding '%s' "
           "in %s mode\n")
        % their_include % their_exclude
        % (role == source_and_sink_role ? _("source and sink") :
           (role == source_role ? _("source") : _("sink"))));

      is_what = is_netsync;
      auth = false;
      break;
    case auth_cmd:
      cmd_in.read_auth_cmd(role, their_include, their_exclude,
                           client_id, nonce1, hmac_encrypted, sig);
      L(FL("received 'auth(hmac)' netcmd from client '%s' for pattern '%s' "
           "exclude '%s' in %s mode with nonce1 '%s'\n")
        % client_id % their_include % their_exclude
        % (role == source_and_sink_role ? _("source and sink") :
           (role == source_role ? _("source") : _("sink")))
        % nonce1);
      is_what = is_netsync;
      auth = true;
      break;
    case automate_cmd:
      // `role` stays unset here; it is only read on the is_netsync path.
      cmd_in.read_automate_cmd(client_id, nonce1, hmac_encrypted, sig);
      is_what = is_automate;
      auth = true;
      break;
    default:
      // The visible caller (do_work) only routes the three commands above
      // here.
      I(false);
    }
  // Decrypt the client's HMAC key with our signing key and re-key both
  // directions with it (set_session_key does nothing without transport
  // auth).
  set_session_key(hmac_encrypted);

  // If the client's public key is not in the database, try our own key
  // store. If it is not there either, quietly drop to an unauthenticated
  // request.
  // NOTE(review): for automate_cmd this still creates an automate_session,
  // just without a verified identity — confirm this is intended.
  if (auth && !project.db.public_key_exists(client_id))
    {
      key_name their_name;
      keypair their_pair;
      if (keys.maybe_get_key_pair(client_id, their_name, their_pair))
        {
          project.db.put_key(their_name, their_pair.pub);
        }
      else
        {
          auth = false;
        }
    }
  if (auth)
    {
      // The client must echo back the nonce we generated for this session
      // (saved_nonce); anything else is treated as a replayed request.
      if (!(nonce1 == saved_nonce))
        {
          error(error_codes::failed_identification,
                F("detected replay attack in auth netcmd").str());
        }
      if (project.db.check_signature(client_id, nonce1(), sig) != cert_ok)
        {
          error(error_codes::failed_identification,
                F("bad client signature").str());
        }
      authenticated = true;
      remote_peer_key_id = client_id;
    }

  // Our side plays the role complementary to what the client asked for.
  switch (is_what)
    {
    case is_netsync:
      wrapped.reset(new netsync_session(this,
                                        app.opts, app.lua,
                                        project,
                                        keys,
                                        corresponding_role(role),
                                        their_include,
                                        their_exclude,
                                        connection_counts::create()));
      break;
    case is_automate:
      wrapped.reset(new automate_session(app, this, 0, 0));
      break;
    }

  // Only a verified client gets an identity; otherwise client_identity is
  // left default-constructed.
  key_identity_info client_identity;
  if (authenticated)
    {
      client_identity.id = client_id;
      if (!client_identity.id.inner()().empty())
        project.complete_key_identity_from_id(keys, app.lua, client_identity);
    }

  wrapped->on_begin(session_id, client_identity);
  wrapped->prepare_to_confirm(client_identity, use_transport_auth);

  netcmd cmd(get_version());
  cmd.write_confirm_cmd();
  write_netcmd(cmd);


  // NOTE(review): this sets `authenticated` unconditionally, so anonymous
  // clients (and clients whose key could not be found) are also reported
  // as authenticated by get_authenticated() from here on. The client-side
  // confirm_cmd handler in do_work does the same. Confirm whether the flag
  // is meant to mean "handshake completed" rather than "key verified".
  completed_hello = true;
  authenticated = true;
  return true;
}
692 
write_netcmd(netcmd const & cmd)693 void session::write_netcmd(netcmd const & cmd)
694 {
695   if (!encountered_error)
696   {
697     string buf;
698     cmd.write(buf, write_hmac);
699     queue_output(buf);
700     L(FL("queued outgoing netcmd of type '%d'") % cmd.get_cmd_code());
701   }
702   else
703     L(FL("dropping outgoing netcmd of type '%d' (because we're in error unwind mode)")
704       % cmd.get_cmd_code());
705 }
706 
get_version() const707 u8 session::get_version() const
708 {
709   return version;
710 }
711 
get_voice() const712 protocol_voice session::get_voice() const
713 {
714   return voice;
715 }
716 
get_peer() const717 string session::get_peer() const
718 {
719   return peer;
720 }
721 
get_error_code() const722 int session::get_error_code() const
723 {
724   return error_code;
725 }
726 
get_authenticated() const727 bool session::get_authenticated() const
728 {
729   return authenticated;
730 }
731 
732 void
error(int errcode,string const & errmsg)733 session::error(int errcode, string const & errmsg)
734 {
735   error_code = errcode;
736   throw netsync_error(errmsg);
737 }
738 
739 void
note_bytes_in(int count)740 session::note_bytes_in(int count)
741 {
742   if (wrapped)
743     {
744       wrapped->note_bytes_in(count + unnoted_bytes_in);
745       unnoted_bytes_in = 0;
746     }
747   else
748     unnoted_bytes_in += count;
749 }
750 
751 void
note_bytes_out(int count)752 session::note_bytes_out(int count)
753 {
754   if (wrapped)
755     {
756       wrapped->note_bytes_out(count + unnoted_bytes_out);
757       unnoted_bytes_out = 0;
758     }
759   else
760     unnoted_bytes_out += count;
761 }
762 
763 // Local Variables:
764 // mode: C++
765 // fill-column: 76
766 // c-file-style: "gnu"
767 // indent-tabs-mode: nil
768 // End:
769 // vim: et:sw=2:sts=2:ts=2:cino=>2s,{s,\:s,+s,t0,g0,^-2,e-2,n-2,p2s,(0,=s:
770