1 /*
2  *  UFTP - UDP based FTP with multicast
3  *
4  *  Copyright (C) 2001-2020   Dennis A. Bush, Jr.   bush@tcnj.edu
5  *
6  *  This program is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This program is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  *  Additional permission under GNU GPL version 3 section 7
20  *
21  *  If you modify this program, or any covered work, by linking or
22  *  combining it with the OpenSSL project's OpenSSL library (or a
23  *  modified version of that library), containing parts covered by the
24  *  terms of the OpenSSL or SSLeay licenses, the copyright holder
25  *  grants you additional permission to convey the resulting work.
26  *  Corresponding Source for a non-source form of such a combination
27  *  shall include the source code for the parts of OpenSSL used as well
28  *  as that of the covered work.
29  */
30 
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <stdarg.h>
35 #include <sys/types.h>
36 #include <sys/stat.h>
37 #include <fcntl.h>
38 #include <time.h>
39 #include <errno.h>
40 #include <math.h>
41 
42 #ifdef WINDOWS
43 
44 #include <io.h>
45 #include <winsock2.h>
46 #include <ws2tcpip.h>
47 #include <iphlpapi.h>
48 #include <Mswsock.h>
49 
50 #include "uftp.h"
51 #include "uftp_common.h"
52 #include "encryption.h"
53 #include "win_func.h"
54 
getiflist(struct iflist * list,int * len)55 void getiflist(struct iflist *list, int *len)
56 {
57     IP_ADAPTER_ADDRESSES *head, *curr;
58     IP_ADAPTER_UNICAST_ADDRESS *uni;
59     char *buf;
60     int buflen, err, i;
61 
62     buflen = 100000;
63     buf = safe_calloc(buflen, 1);
64     head = (IP_ADAPTER_ADDRESSES *)buf;
65     if ((err = GetAdaptersAddresses(AF_UNSPEC, 0, NULL, head,
66                                     &buflen)) != ERROR_SUCCESS) {
67         char errbuf[300];
68         FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL, err,
69                       0, errbuf, sizeof(errbuf), NULL);
70         log0(0, 0, 0, "GetAdaptersAddresses failed: (%d) %s", err, errbuf);
71         free(buf);
72         return;
73     }
74     for (*len = 0, curr = head; curr; curr = curr->Next) {
75         if (curr->IfType == IF_TYPE_TUNNEL) continue;
76         for (uni = curr->FirstUnicastAddress; uni; uni = uni->Next) {
77             if (curr->OperStatus == IfOperStatusUp) {
78                 memset(&list[*len], 0, sizeof(struct iflist));
79                 strncpy(list[*len].name, (char *)curr->AdapterName,
80                         sizeof(list[i].name) - 1);
81                 memcpy(&list[*len].su, uni->Address.lpSockaddr,
82                         uni->Address.iSockaddrLength);
83                 list[*len].isloopback =
84                         (curr->IfType == IF_TYPE_SOFTWARE_LOOPBACK);
85                 list[*len].ismulti =
86                         ((curr->Flags & IP_ADAPTER_NO_MULTICAST) == 0);
87                 if (uni->Address.lpSockaddr->sa_family == AF_INET6) {
88                     list[*len].ifidx = curr->Ipv6IfIndex;
89                 } else {
90                     list[*len].ifidx = curr->IfIndex;
91                 }
92                 (*len)++;
93             }
94         }
95     }
96     free(buf);
97 }
98 
99 #else  /*if WINDOWS*/
100 
101 #include <libgen.h>
102 #include <netinet/in.h>
103 #include <unistd.h>
104 #include <sys/ioctl.h>
105 #include <sys/types.h>
106 #include <sys/socket.h>
107 #include <sys/time.h>
108 #include <net/if.h>
109 #include <sys/statvfs.h>
110 
111 #include "uftp.h"
112 #include "uftp_common.h"
113 #include "encryption.h"
114 
115 #ifdef HAS_GETIFADDRS
116 
117 #include <ifaddrs.h>
118 
getiflist(struct iflist * list,int * len)119 void getiflist(struct iflist *list, int *len)
120 {
121     struct ifaddrs *ifa, *ifa_tmp;
122     int count;
123     unsigned ifidx;
124 
125     if (getifaddrs(&ifa) == -1) {
126         syserror(0, 0, 0, "getifaddrs failed");
127         *len = 0;
128         return;
129     }
130     ifa_tmp = ifa;
131     count = *len;
132     *len = 0;
133     while (ifa_tmp && (*len < count)) {
134         if ((ifidx = if_nametoindex(ifa_tmp->ifa_name)) == 0) {
135             syserror(0, 0, 0, "Error getting interface index for interface %s",
136                               ifa_tmp->ifa_name);
137         } else if (ifa_tmp->ifa_addr &&
138                    ((ifa_tmp->ifa_addr->sa_family == AF_INET) ||
139                     (ifa_tmp->ifa_addr->sa_family == AF_INET6)) &&
140                    ((ifa_tmp->ifa_flags & IFF_UP) != 0)) {
141             memset(&list[*len], 0, sizeof(struct iflist));
142             strncpy(list[*len].name, ifa_tmp->ifa_name,
143                     sizeof(list[*len].name) - 1);
144             memcpy(&list[*len].su, ifa_tmp->ifa_addr,
145                     sizeof(struct sockaddr_storage));
146             list[*len].isloopback = (ifa_tmp->ifa_flags & IFF_LOOPBACK) != 0;
147             list[*len].ismulti = (ifa_tmp->ifa_flags & IFF_MULTICAST) != 0;
148             list[*len].ifidx = ifidx;
149             (*len)++;
150         }
151         ifa_tmp = ifa_tmp->ifa_next;
152     }
153     freeifaddrs(ifa);
154 }
155 
156 #else
157 
getiflist(struct iflist * list,int * len)158 void getiflist(struct iflist *list, int *len)
159 {
160     int s, i, count;
161     struct lifconf ifc;
162     struct lifreq *ifr, ifr_tmp_flags, ifr_tmp_ifidx;
163 
164     if (*len <= 0) return;
165     count = *len;
166     ifr = safe_malloc(sizeof(struct lifreq) * count);
167     ifc.lifc_family = AF_UNSPEC;
168     ifc.lifc_flags = 0;
169     ifc.lifc_len = sizeof(struct lifreq) * count;
170     ifc.lifc_req = ifr;
171 
172     if ((s = socket(AF_INET, SOCK_DGRAM, 0)) == -1) {
173         sockerror(0, 0, 0, "Error creating socket for interface list");
174         free(ifr);
175         *len = 0;
176         return;
177     }
178     if (ioctl(s, SIOCGLIFCONF, &ifc) == -1) {
179         syserror(0, 0, 0, "Error getting interface list");
180         free(ifr);
181         close(s);
182         *len = 0;
183         return;
184     }
185     count = ifc.lifc_len / sizeof(struct lifreq);
186     for (i = 0, *len = 0; i < count; i++) {
187         strcpy(ifr_tmp_flags.lifr_name, ifr[i].lifr_name);
188         if (ioctl(s, SIOCGLIFFLAGS, &ifr_tmp_flags) == -1) {
189             syserror(0, 0, 0, "Error getting flags for interface %s",
190                               ifr[i].lifr_name);
191             continue;
192         }
193         strcpy(ifr_tmp_ifidx.lifr_name, ifr[i].lifr_name);
194         if (ioctl(s, SIOCGLIFINDEX, &ifr_tmp_ifidx) == -1) {
195             syserror(0, 0, 0, "Error getting interface index for interface %s",
196                               ifr[i].lifr_name);
197             continue;
198         }
199         if (((ifr[i].lifr_addr.ss_family == AF_INET) ||
200                 (ifr[i].lifr_addr.ss_family == AF_INET6)) &&
201                 ((ifr_tmp_flags.lifr_flags & IFF_UP) != 0)) {
202             memset(&list[*len], 0, sizeof(struct iflist));
203             strncpy(list[*len].name,ifr[i].lifr_name, sizeof(list[i].name) - 1);
204             memcpy(&list[*len].su, &ifr[i].lifr_addr,
205                     sizeof(struct sockaddr_storage));
206             list[*len].isloopback =
207                     (ifr_tmp_flags.lifr_flags & IFF_LOOPBACK) != 0;
208             list[*len].ismulti = (ifr_tmp_flags.lifr_flags & IFF_MULTICAST)!=0;
209             list[*len].ifidx = ifr_tmp_ifidx.lifr_index;
210             (*len)++;
211         }
212     }
213     free(ifr);
214     close(s);
215 }
216 
#endif /* HAS_GETIFADDRS; the #else branch is the Solaris SIOCGLIFCONF version */
218 
219 #ifdef VMS
/* VMS build stub; always returns 0 (presumably stands in for setsid() -- confirm) */
pid_t GENERIC_SETSID(void)
{
    const pid_t session_id = 0;
    return session_id;
}
221 #endif
222 
223 #endif /*if WINDOWS*/
224 
/**
 * Converts a count of microseconds to a struct timeval.
 * The result is always normalized so that 0 <= tv_usec < 1000000, also for
 * negative inputs (e.g. -1 usec becomes tv_sec = -1, tv_usec = 999999).
 * tv_to_usec() of the result gives back the original value.
 */
struct timeval usec_to_tv(int64_t t)
{
    struct timeval tv;
    int64_t sec = t / 1000000;
    int64_t usec = t % 1000000;

    // C division truncates toward zero, so a negative t leaves a negative
    // remainder; borrow one second to bring tv_usec back into range
    if (usec < 0) {
        usec += 1000000;
        sec--;
    }
    // TODO: Y2038 issue, switch to timespec / clock_gettime
    tv.tv_sec = (long)sec;
    tv.tv_usec = usec;
    return tv;
}
233 
/**
 * Converts a struct timeval to a total count of microseconds.
 */
int64_t tv_to_usec(struct timeval tv)
{
    const int64_t whole_secs = (int64_t)tv.tv_sec;

    // Widen before scaling so large second counts don't overflow
    return whole_secs * 1000000 + tv.tv_usec;
}
238 
/**
 * Returns the difference t2 - t1 in whole seconds; tv_usec is ignored.
 */
int32_t diff_sec(struct timeval t2, struct timeval t1)
{
    int32_t delta;

    delta = t2.tv_sec - t1.tv_sec;
    return delta;
}
243 
/**
 * Returns the difference t2 - t1 in microseconds.
 */
int64_t diff_usec(struct timeval t2, struct timeval t1)
{
    const int64_t sec_part = (int64_t)1000000 * (t2.tv_sec - t1.tv_sec);
    const int64_t usec_part = t2.tv_usec - t1.tv_usec;

    return usec_part + sec_part;
}
249 
/**
 * Three-way comparison of two timestamps.
 * Returns 1 if t1 is later than t2, -1 if it is earlier, 0 if they are equal.
 * Seconds are compared first and microseconds only break ties.
 */
int cmptimestamp(struct timeval t1, struct timeval t2)
{
    if (t1.tv_sec != t2.tv_sec) {
        return (t1.tv_sec > t2.tv_sec) ? 1 : -1;
    }
    if (t1.tv_usec != t2.tv_usec) {
        return (t1.tv_usec > t2.tv_usec) ? 1 : -1;
    }
    return 0;
}
264 
/**
 * Returns the sum t2 + t1.
 * Any tv_usec overflow of 1000000 or more is carried into tv_sec; a
 * negative tv_usec is not normalized.
 */
struct timeval add_timeval(struct timeval t2, struct timeval t1)
{
    struct timeval sum;

    sum.tv_sec = t1.tv_sec + t2.tv_sec;
    sum.tv_usec = t1.tv_usec + t2.tv_usec;
    for (; sum.tv_usec >= 1000000; sum.tv_usec -= 1000000) {
        sum.tv_sec++;
    }
    return sum;
}
277 
add_timeval_d(struct timeval * t2,double t1)278 void add_timeval_d(struct timeval *t2, double t1)
279 {
280     t2->tv_sec += (long)(floor(t1) + 0);
281     t2->tv_usec += (long)((t1 - floor(t1)) * 1000000);
282     while (t2->tv_usec >= 1000000) {
283         t2->tv_usec -= 1000000;
284         t2->tv_sec++;
285     }
286 }
287 
/**
 * Returns t2 - t1, borrowing from tv_sec until tv_usec is non-negative.
 */
struct timeval diff_timeval(struct timeval t2, struct timeval t1)
{
    struct timeval delta;

    delta.tv_sec = t2.tv_sec - t1.tv_sec;
    delta.tv_usec = t2.tv_usec - t1.tv_usec;
    for (; delta.tv_usec < 0; delta.tv_sec--) {
        delta.tv_usec += 1000000;
    }
    return delta;
}
300 
301 /**
302  * Gets the name of the UFTP message type for the given message type constant
303  */
func_name(int func)304 const char *func_name(int func)
305 {
306     switch (func) {
307     case ANNOUNCE:
308         return "ANNOUNCE";
309     case REGISTER:
310         return "REGISTER";
311     case CLIENT_KEY:
312         return "CLIENT_KEY";
313     case REG_CONF:
314         return "REG_CONF";
315     case KEYINFO:
316         return "KEYINFO";
317     case KEYINFO_ACK:
318         return "KEYINFO_ACK";
319     case FILEINFO:
320         return "FILEINFO";
321     case FILEINFO_ACK:
322         return "FILEINFO_ACK";
323     case FILESEG:
324         return "FILESEG";
325     case DONE:
326         return "DONE";
327     case STATUS:
328         return "STATUS";
329     case COMPLETE:
330         return "COMPLETE";
331     case DONE_CONF:
332         return "DONE_CONF";
333     case HB_REQ:
334         return "HB_REQ";
335     case HB_RESP:
336         return "HB_RESP";
337     case KEY_REQ:
338         return "KEY_REQ";
339     case PROXY_KEY:
340         return "PROXY_KEY";
341     case ENCRYPTED:
342         return "ENCRYPTED";
343     case ABORT:
344         return "ABORT";
345     case CONG_CTRL:
346         return "CONG_CTRL";
347     case CC_ACK:
348         return "CC_ACK";
349     default:
350         return "UNKNOWN";
351   }
352 }
353 
354 /**
355  * Gets the name of the EC curve for the given EC curve constant.
356  */
curve_name(int curve)357 const char *curve_name(int curve)
358 {
359     switch (curve) {
360     case CURVE_secp256r1:
361         return "prime256v1";
362     case CURVE_secp384r1:
363         return "secp384r1";
364     case CURVE_secp521r1:
365         return "secp521r1";
366     default:
367         return "UNKNOWN";
368     }
369 }
370 
371 /**
372  * Gets the EC curve constant for the given curve name.
373  * Returns 0 if the name is invalid
374  */
get_curve(const char * name)375 uint8_t get_curve(const char *name)
376 {
377     if (!strcmp(name, "secp256r1")) {
378         return CURVE_secp256r1;
379     } else if (!strcmp(name, "secp384r1")) {
380         return CURVE_secp384r1;
381     } else if (!strcmp(name, "secp521r1")) {
382         return CURVE_secp521r1;
383     } else if (!strcmp(name, "prime256v1")) {
384         return CURVE_prime256v1;
385     } else {
386         return 0;
387     }
388 }
389 
390 char logfile[MAXPATHNAME];
391 int showtime;
392 FILE *applog;
393 int log_level, init_log_mux, use_log_mux, max_log_count;
394 f_offset_t log_size, max_log_size;
395 mux_t log_mux;
396 
397 static int rolling = 0;
398 
399 /**
400  * Initialize the log file.
401  */
init_log(int _debug)402 void init_log(int _debug)
403 {
404     use_log_mux = 0;
405     if (init_log_mux) {
406         if (mux_create(log_mux)) {
407             perror("Failed to create log mutex");
408             exit(ERR_LOGGING);
409         }
410     }
411 
412     if (strcmp(logfile, "") && !_debug) {
413         int fd;
414         stat_struct statbuf;
415 
416         if ((lstat_func(logfile, &statbuf) != -1) && S_ISREG(statbuf.st_mode)) {
417             log_size = statbuf.st_size;
418         } else {
419             log_size = 0;
420         }
421         if ((fd = open(logfile, O_WRONLY | O_APPEND | O_CREAT, 0644)) == -1) {
422             perror("Can't open log file");
423             exit(ERR_LOGGING);
424         }
425         dup2(fd, 2);
426         close(fd);
427 
428         showtime = 1;
429     } else {
430         log_size = 0;
431         max_log_size = 0;
432         max_log_count = 0;
433     }
434     applog = stderr;
435 }
436 
437 /**
438  * Close log file
439  */
close_log()440 void close_log()
441 {
442     if (init_log_mux) {
443         mux_destroy(log_mux);
444     }
445     fclose(applog);
446 }
447 
448 /**
449  * Rolls the log file.
450  */
roll_log()451 void roll_log()
452 {
453     char oldname[MAXPATHNAME], newname[MAXPATHNAME];
454     int rval, fd, i;
455 
456     if (rolling) return;
457     rolling = 1;
458     log2(0, 0, 0, "Rolling logs");
459     for (i = max_log_count; i >=0; i--) {
460         if (i == 0) {
461             rval = snprintf(oldname, sizeof(oldname), "%s", logfile);
462             if  (rval >= sizeof(oldname)) {
463                 log0(0, 0, 0, "Old log name too long");
464                 rolling = 0;
465                 return;
466             }
467         } else {
468             rval = snprintf(oldname, sizeof(oldname), "%s.%d", logfile, i);
469             if  (rval >= sizeof(oldname)) {
470                 log0(0, 0, 0, "Old log name too long");
471                 rolling = 0;
472                 return;
473             }
474         }
475         rval = snprintf(newname, sizeof(newname), "%s.%d", logfile, i + 1);
476         if  (rval >= sizeof(oldname)) {
477             log0(0, 0, 0, "New log name too long");
478             rolling = 0;
479             return;
480         }
481         if (i == max_log_count) {
482             if (unlink(oldname) == -1) {
483                 syserror(0, 0, 0, "Couldn't remove log %s", oldname);
484             }
485         } else if (i == 0) {
486 #ifdef WINDOWS
487             log2(0, 0, 0, "Switching to new log");
488             close(2);
489             if (rename(oldname, newname) == -1) {
490                 printf("Couldn't rename log %s to %s", oldname, newname);
491                 exit(ERR_LOGGING);
492             }
493             if ((fd=open(logfile, O_WRONLY | O_APPEND | O_CREAT, 0644)) == -1) {
494                 printf("Can't open log file");
495                 exit(ERR_LOGGING);
496             }
497             log_size = 0;
498             log2(0, 0, 0, "Switch to new log complete");
499 #else
500             if (rename(oldname, newname) == -1) {
501                 syserror(0, 0, 0, "Couldn't rename log %s to %s",
502                                   oldname, newname);
503             }
504             log2(0, 0, 0, "Opening new log");
505             if ((fd=open(logfile, O_WRONLY | O_APPEND | O_CREAT, 0644)) == -1) {
506                 syserror(0, 0, 0, "Can't open log file");
507                 exit(ERR_LOGGING);
508             }
509             log2(0, 0, 0, "Switching to new log");
510             dup2(fd, 2);
511             close(fd);
512             log_size = 0;
513             log2(0, 0, 0, "Switch to new log complete");
514 #endif
515         } else {
516             if (rename(oldname, newname) == -1) {
517                 syserror(0, 0, 0, "Couldn't rename log %s to %s",
518                                   oldname, newname);
519             }
520         }
521     }
522     rolling = 0;
523 }
524 
525 /**
526  * The main logging function.
527  * Called via a series of macros for a particular log level or output format.
528  */
logfunc(uint32_t group_id,uint8_t group_inst,uint16_t file_id,int level,int _showtime,int newline,int err,int sockerr,const char * str,...)529 void logfunc(uint32_t group_id, uint8_t group_inst, uint16_t file_id,
530              int level, int _showtime, int newline, int err, int sockerr,
531              const char *str, ...)
532 {
533     struct tm *timeval;
534     struct timeval tv;
535     time_t t;
536     va_list args;
537     int write_len;
538 
539     if (level > log_level) return;
540     if (use_log_mux && !rolling) {
541         if (mux_lock(log_mux)) {
542             write_len = fprintf(applog, "Failed to lock log mutex\n");
543             if (write_len != -1) log_size += write_len;
544         }
545     }
546     if (_showtime) {
547         gettimeofday(&tv, NULL);
548         // In Windows, tv.tv_sec is long, not time_t
549         t = tv.tv_sec;
550         timeval = localtime(&t);
551         write_len = fprintf(applog, "%04d/%02d/%02d %02d:%02d:%02d.%06d: ",
552                 timeval->tm_year + 1900, timeval->tm_mon + 1, timeval->tm_mday,
553                 timeval->tm_hour, timeval->tm_min, timeval->tm_sec,
554                 (int)tv.tv_usec);
555         if (write_len != -1) log_size += write_len;
556         if (group_id) {
557             if (file_id) {
558                 write_len = fprintf(applog, "[%08X/%02X:%04X]: ",
559                                     group_id, group_inst, file_id);
560             } else {
561                 write_len = fprintf(applog, "[%08X/%02X:0]: ",
562                                     group_id, group_inst);
563             }
564             if (write_len != -1) log_size += write_len;
565         }
566     }
567     va_start(args, str);
568     write_len = vfprintf(applog, str, args);
569     if (write_len != -1) log_size += write_len;
570     va_end(args);
571     if (sockerr) {
572 #ifdef WINDOWS
573         char errbuf[300];
574         FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL, WSAGetLastError(),
575                       0, errbuf, sizeof(errbuf), NULL);
576         write_len = fprintf(applog, ": (%d) %s", WSAGetLastError(), errbuf);
577         newline = 0;
578 #else
579         write_len = fprintf(applog, ": %s", strerror(err));
580 #endif
581         if (write_len != -1) log_size += write_len;
582     } else if (err) {
583         write_len = fprintf(applog, ": %s", strerror(err));
584         if (write_len != -1) log_size += write_len;
585     }
586     if (newline) {
587         write_len = fprintf(applog, "\n");
588         if (write_len != -1) log_size += write_len;
589     }
590     fflush(applog);
591     if ((max_log_size > 0) && (log_size > max_log_size)) {
592         roll_log();
593     }
594     if (use_log_mux && !rolling) {
595         if (mux_unlock(log_mux)) {
596             write_len = fprintf(applog, "Failed to unlock log mutex\n");
597             if (write_len != -1) log_size += write_len;
598             fflush(applog);
599         }
600     }
601 }
602 
603 /**
604  * Takes a pathname and splits it into a directory part and file part.
605  * The caller is expected to clean up *dir and *file.
606  */
split_path(const char * path,char ** dir,char ** file)607 void split_path(const char *path, char **dir, char **file)
608 {
609 #ifdef WINDOWS
610     char *result, *filename;
611     DWORD len, len2;
612 
613     if (strlen(path) == 0) {
614         *dir = NULL;
615         *file = NULL;
616         return;
617     }
618 
619     // GetFullPathNameA doens't handle trailing slashes well, so disallow
620     if ((path[strlen(path)-1] == '/') || (path[strlen(path)-1] == '\\')) {
621         log0(0, 0, 0, "bad path, trailing / or \\ not allowed");
622         *dir = NULL;
623         *file = NULL;
624         return;
625     }
626 
627     len = GetFullPathNameA(path, 0, NULL, &filename);
628     if (len == 0) {
629         char errbuf[300];
630         FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL,
631                 GetLastError(), 0, errbuf, sizeof(errbuf), NULL);
632         log0(0, 0, 0, "Error in GetFullPathNameA: %s", errbuf);
633         *dir = NULL;
634         *file = NULL;
635         return;
636     }
637 
638     *dir = NULL;
639     *file = NULL;
640     result = safe_malloc(len);
641     if ((len2 = GetFullPathNameA(path, len, result, &filename)) <= len) {
642         *dir = strdup(result);
643         *file = strdup(filename);
644         if (!*dir || (filename && !*file)) {
645             syserror(0, 0, 0, "strdup failed!");
646             exit(ERR_ALLOC);
647         }
648         (*dir)[strlen(*dir) - strlen(*file) - 1] = '\x0';
649     }
650     free(result);
651 #else
652     char *dirc, *filec;
653 
654     dirc = strdup(path);
655     filec = strdup(path);
656     if (!dirc || !filec) {
657         syserror(0, 0, 0, "strdup failed!");
658         exit(ERR_ALLOC);
659     }
660     *dir = strdup(dirname(dirc));
661     *file = strdup(basename(filec));
662     if (!*dir || !*file) {
663         syserror(0, 0, 0, "strdup failed!");
664         exit(ERR_ALLOC);
665     }
666     free(dirc);
667     free(filec);
668 #endif
669 }
670 
671 /**
672  * Parses a key fingerprint string and saves it to the specified buffer
673  * Returns 1 on success, 0 on fail
674  */
parse_fingerprint(unsigned char * fingerprint,const char * fingerprint_str)675 int parse_fingerprint(unsigned char *fingerprint, const char *fingerprint_str)
676 {
677     char *p, *tmp, *saveptr;
678     int num, len;
679 
680     if (fingerprint_str == NULL) {
681         return 0;
682     }
683     tmp = strdup(fingerprint_str);
684     len = 0;
685     saveptr = NULL;
686     p = strtok_r(tmp, ":", &saveptr);
687     if (p == NULL) {
688         log1(0, 0, 0, "Invalid fingerprint %s", fingerprint_str);
689         free(tmp);
690         return 0;
691     }
692     do {
693         if (len >= HMAC_LEN) {
694             log1(0, 0, 0, "Key fingerprint %s too long", fingerprint_str);
695             free(tmp);
696             return 0;
697         }
698         errno = 0;
699         num = strtol(p, NULL, 16);
700         if (errno) {
701             syserror(0, 0, 0, "Parse of host key fingerprint %s failed",
702                               fingerprint_str);
703             free(tmp);
704             return 0;
705         } else if ((num > 255) || (num < 0)) {
706             log1(0, 0, 0, "Parse of host key fingerprint %s failed",
707                           fingerprint_str);
708             free(tmp);
709             return 0;
710         }
711         fingerprint[len++] = (uint8_t)num;
712         p = strtok_r(NULL, ":", &saveptr);
713     } while (p);
714     free(tmp);
715     return 1;
716 }
717 
718 /**
719  * Looks up a host in a list of fingerprint structs
720  * Returns NULL if not found
721  */
fp_lookup(uint32_t id,struct fp_list_t * list,int count)722 struct fp_list_t *fp_lookup(uint32_t id, struct fp_list_t* list, int count)
723 {
724     int i;
725 
726     for (i = 0; i < count; i++) {
727         if (list[i].uid == id) {
728             return &list[i];
729         }
730     }
731     return NULL;
732 }
733 
734 /**
735  * Tests a sockaddr_u union to see if it's a valid multicast address
736  */
is_multicast(const union sockaddr_u * addr,int ssm)737 int is_multicast(const union sockaddr_u *addr, int ssm)
738 {
739     int val;
740 
741     if (addr->ss.ss_family == AF_INET6) {
742         if (addr->sin6.sin6_addr.s6_addr[0] == 0xff) {
743             if (ssm && ((addr->sin6.sin6_addr.s6_addr[1] & 0x30) == 0)) {
744                 return 0;
745             } else {
746                 return 1;
747             }
748         } else {
749             return 0;
750         }
751     } else if (addr->ss.ss_family == AF_INET) {
752         val = ntohl(addr->sin.sin_addr.s_addr) >> 24;
753         if (ssm && (val != 232)) {
754             return 0;
755         } else if ((val >= 224) && (val < 240)) {
756             return 1;
757         } else {
758             return 0;
759         }
760     } else {
761         return 0;
762     }
763 }
764 
765 /**
766  * Compares two sockaddr_u unions for equality
767  * Returns 1 if address family, address, and port are equal, 0 otherwise
768  */
addr_equal(const union sockaddr_u * addr1,const union sockaddr_u * addr2)769 int addr_equal(const union sockaddr_u *addr1, const union sockaddr_u *addr2)
770 {
771     if (addr1->ss.ss_family != addr2->ss.ss_family) {
772         return 0;
773     }
774     if (addr1->ss.ss_family == AF_INET6) {
775         if ((!memcmp(&addr1->sin6.sin6_addr, &addr2->sin6.sin6_addr,
776                     sizeof(struct in6_addr))) &&
777                 (addr1->sin6.sin6_port == addr2->sin6.sin6_port)) {
778             return 1;
779         } else {
780             return 0;
781         }
782     } else {
783         if ((addr1->sin.sin_addr.s_addr == addr2->sin.sin_addr.s_addr) &&
784                 (addr1->sin.sin_port == addr2->sin.sin_port)) {
785             return 1;
786         } else {
787             return 0;
788         }
789     }
790 }
791 
792 /**
793  * Checks to see if a sockaddr_u union has a zero address
794  * Returns 1 if the address is zero (for the given family), 0 otherwise.
795  */
addr_blank(const union sockaddr_u * addr)796 int addr_blank(const union sockaddr_u *addr)
797 {
798     if (addr->ss.ss_family == AF_INET6) {
799         return (memcmp(&addr->sin6.sin6_addr, &in6addr_any,
800                          sizeof(struct in6_addr)) == 0);
801     } else if (addr->ss.ss_family == AF_INET) {
802         return (addr->sin.sin_addr.s_addr == INADDR_ANY);
803     } else {
804         return 1;
805     }
806 }
807 
/**
 * Converts a 64-bit value from host to network (big-endian) byte order.
 * Works byte-by-byte, so it is correct regardless of host endianness.
 */
uint64_t uftp_htonll(uint64_t val)
{
    uint64_t out;
    unsigned char *bytes = (unsigned char *)&out;
    int pos;

    // Fill from the last byte backwards with successively higher-order
    // bytes of val, leaving the most significant byte first in memory
    for (pos = 7; pos >= 0; pos--) {
        bytes[pos] = (unsigned char)(val & 0xFF);
        val >>= 8;
    }
    return out;
}
823 
/**
 * Converts a 64-bit value from network (big-endian) to host byte order.
 * Works byte-by-byte, so it is correct regardless of host endianness.
 */
uint64_t uftp_ntohll(uint64_t val)
{
    const unsigned char *bytes = (const unsigned char *)&val;
    uint64_t result = 0;
    int pos;

    // The first byte in memory is the most significant one
    for (pos = 0; pos < 8; pos++) {
        result = (result << 8) | bytes[pos];
    }
    return result;
}
839 
840 /**
841  * Returns the effective length of a sockaddr type struct
842  * based on the address family
843  */
family_len(union sockaddr_u addr)844 int family_len(union sockaddr_u addr)
845 {
846     if (addr.ss.ss_family == AF_INET6) {
847         return sizeof(struct sockaddr_in6);
848     } else if (addr.ss.ss_family == AF_INET) {
849         return sizeof(struct sockaddr_in);
850     } else {
851         return sizeof(struct sockaddr_storage);
852     }
853 }
854 
/**
 * Returns whether the last socket operation would have blocked.
 * On POSIX systems a non-blocking socket may report this as either EAGAIN
 * or EWOULDBLOCK, which are distinct values on some platforms, so both are
 * accepted.
 */
int would_block_err()
{
#ifdef WINDOWS
    return (WSAGetLastError() == WSAEWOULDBLOCK);
#elif defined(EWOULDBLOCK) && (EWOULDBLOCK != EAGAIN)
    return ((errno == EAGAIN) || (errno == EWOULDBLOCK));
#else
    return (errno == EAGAIN);
#endif
}
866 
/**
 * Returns nonzero if the last socket operation failed with a connection
 * reset error (WSAECONNRESET on Windows, ECONNRESET elsewhere).
 */
int conn_reset_err(void)
{
#ifdef WINDOWS
    const int last_error = WSAGetLastError();
    const int reset_code = WSAECONNRESET;
#else
    const int last_error = errno;
    const int reset_code = ECONNRESET;
#endif
    return last_error == reset_code;
}
878 
/**
 * Returns nonzero if the last call was interrupted by a signal
 * (WSAEINTR on Windows, EINTR elsewhere).
 */
int interrupted_err(void)
{
    int code;

#ifdef WINDOWS
    code = WSAGetLastError();
    return (code == WSAEINTR) ? 1 : 0;
#else
    code = errno;
    return (code == EINTR) ? 1 : 0;
#endif
}
890 
891 /**
892  * Calls sendto, retrying if the send would block.
893  * The calling function should check for and log any other errors.
894  */
nb_sendto(SOCKET s,const void * msg,int len,int flags,const struct sockaddr * to,int tolen)895 int nb_sendto(SOCKET s, const void *msg, int len, int flags,
896               const struct sockaddr *to, int tolen)
897 {
898     int retry, sentlen;
899 
900     retry = 1;
901     while (retry) {
902         if ((sentlen = sendto(s, msg, len, flags, to, tolen)) == SOCKET_ERROR) {
903             if (!would_block_err()) {
904                 return -1;
905             }
906         } else {
907             retry = 0;
908         }
909     }
910     return sentlen;
911 }
912 
913 /**
914  * Reads a packet off the network with a possible timeout.
915  * The socket must be non-blocking.
916  * Returns 1 on success, 0 on timeout, -1 on fail.
917  */
int read_packet(SOCKET sock, union sockaddr_u *su, unsigned char *buffer,
                int *len, int bsize, const struct timeval *timeout,
                uint8_t *tos)
{
    fd_set fdin;
    struct timeval tv;
    int rval;
#ifdef WINDOWS
    GUID WSARecvMsg_GUID = WSAID_WSARECVMSG;
    // WSARecvMsg is obtained at runtime through WSAIoctl; the pointer is
    // cached in this static after the first successful lookup.
    static LPFN_WSARECVMSG WSARecvMsg;
    DWORD nbytes;   // WSAIoctl reports its output size through an LPDWORD
    WSAMSG mhdr;
    WSABUF iov;
    WSACMSGHDR *cmhdr;
    char control[1000];
#elif defined NO_RECVMSG
    socklen_t addr_len;
#else
    struct msghdr mhdr;
    struct iovec iov;
    struct cmsghdr *cmhdr;
    char control[1000];
#endif

    // Attempt a read first; if it would block, wait in select() and then
    // loop back to read again.
    while (1) {
#ifdef WINDOWS
        if (WSARecvMsg == NULL) {
            rval = WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER,
                    &WSARecvMsg_GUID, sizeof WSARecvMsg_GUID,
                    &WSARecvMsg, sizeof WSARecvMsg, &nbytes, NULL, NULL);
            if (rval == SOCKET_ERROR) {
                sockerror(0, 0, 0, "WSAIoctl for WSARecvMsg failed");
                exit(ERR_SOCKET);
            }
        }
        mhdr.name = (LPSOCKADDR)su;
        mhdr.namelen = sizeof(union sockaddr_u);
        mhdr.lpBuffers = &iov;
        mhdr.dwBufferCount = 1;
        mhdr.Control.buf = control;
        mhdr.Control.len = sizeof(control);
        mhdr.dwFlags = 0;
        iov.buf = buffer;
        iov.len = bsize;
        if ((rval = WSARecvMsg(sock, &mhdr, (LPDWORD)len, NULL, NULL)) ==
                SOCKET_ERROR) {
            if (!would_block_err()) {
                if (!conn_reset_err()) {
                    sockerror(0, 0, 0, "Error receiving");
                }
                return -1;
            }
        } else {
            // Pull the IPv4 TOS / IPv6 traffic class out of the ancillary
            // data; 0 if none was delivered.
            *tos = 0;
            cmhdr = WSA_CMSG_FIRSTHDR(&mhdr);
            while (cmhdr) {
                if ((cmhdr->cmsg_level == IPPROTO_IP &&
                        cmhdr->cmsg_type == IP_TCLASS) ||
                        (cmhdr->cmsg_level == IPPROTO_IPV6 &&
                         cmhdr->cmsg_type == IPV6_TCLASS)) {
                    *tos = ((uint8_t *)WSA_CMSG_DATA(cmhdr))[0];
                }
                cmhdr = WSA_CMSG_NXTHDR(&mhdr, cmhdr);
            }
            log5(0, 0, 0, "tos / traffic class byte = %02X", *tos);
            return 1;
        }
#elif defined NO_RECVMSG
        // Without recvmsg() there is no ancillary data, so *tos is not set.
        addr_len = sizeof(union sockaddr_u);
        if ((*len = recvfrom(sock, buffer, bsize, 0, (struct sockaddr *)su,
                             &addr_len)) == SOCKET_ERROR) {
            if (!would_block_err()) {
                if (!conn_reset_err()) {
                    sockerror(0, 0, 0, "Error receiving");
                }
                return -1;
            }
        } else {
            return 1;
        }
#else
        mhdr.msg_name = su;
        mhdr.msg_namelen = sizeof(union sockaddr_u);
        mhdr.msg_iov = &iov;
        mhdr.msg_iovlen = 1;
        mhdr.msg_control = &control;
        mhdr.msg_controllen = sizeof(control);
        iov.iov_base = buffer;
        iov.iov_len = bsize;
        if ((*len = recvmsg(sock, &mhdr, 0)) == SOCKET_ERROR) {
            if (!would_block_err()) {
                if (!conn_reset_err()) {
                    sockerror(0, 0, 0, "Error receiving");
                }
                return -1;
            }
        } else {
            // Pull the IPv4 TOS / IPv6 traffic class out of the ancillary
            // data; 0 if none was delivered.  Which cmsg types carry it
            // depends on what the platform defines.
            *tos = 0;
            cmhdr = CMSG_FIRSTHDR(&mhdr);
            while (cmhdr) {
                int istos;
#ifdef IP_RECVTOS
#if defined IPV6_TCLASS && !defined WINDOWS
                istos = ((cmhdr->cmsg_level == IPPROTO_IP &&
                        (cmhdr->cmsg_type == IP_TOS ||
                        cmhdr->cmsg_type == IP_RECVTOS)) ||
                        (cmhdr->cmsg_level == IPPROTO_IPV6 &&
                         cmhdr->cmsg_type == IPV6_TCLASS));
#else
                istos = (cmhdr->cmsg_level == IPPROTO_IP &&
                        (cmhdr->cmsg_type == IP_TOS ||
                        cmhdr->cmsg_type == IP_RECVTOS));
#endif
#else  // IP_RECVTOS
#if defined IPV6_TCLASS && !defined WINDOWS
                istos = ((cmhdr->cmsg_level == IPPROTO_IP &&
                        cmhdr->cmsg_type == IP_TOS) ||
                        (cmhdr->cmsg_level == IPPROTO_IPV6 &&
                         cmhdr->cmsg_type == IPV6_TCLASS));
#else
                istos = (cmhdr->cmsg_level == IPPROTO_IP &&
                        cmhdr->cmsg_type == IP_TOS);
#endif
#endif  // IP_RECVTOS
                if (istos) {
                    *tos = ((uint8_t *)CMSG_DATA(cmhdr))[0];
                }
                cmhdr = CMSG_NXTHDR(&mhdr, cmhdr);
            }
            log5(0, 0, 0, "tos / traffic class byte = %02X", *tos);
            return 1;
        }
#endif

        FD_ZERO(&fdin);
        FD_SET(sock, &fdin);
        // select() may modify tv, so refresh it from *timeout on each pass.
        if (timeout) tv = *timeout;
        // nfds must be one more than the highest descriptor in the set.
        // FD_SETSIZE covers every descriptor FD_SET() can hold, whereas
        // FD_SETSIZE-1 would never poll descriptor FD_SETSIZE-1.
        // Winsock ignores this argument.
        if ((rval = select(FD_SETSIZE, &fdin, NULL, NULL,
                           (timeout ? &tv : NULL))) == SOCKET_ERROR) {
            if (!interrupted_err()) {
                sockerror(0, 0, 0, "Select failed");
            }
            return -1;
        }
        if (rval == 0) {
            return 0;
        } else if (!FD_ISSET(sock, &fdin)) {
            log0(0, 0, 0, "Unknown select error");
            return -1;
        }
    }
}
1069 
/**
 * XORs the first len bytes of p2 into p1 in place (p1[i] ^= p2[i]).
 * A len of zero or less leaves p1 unchanged.
 */
void memxor(void *p1, const void *p2, int len)
{
    unsigned char *dst = p1;
    const unsigned char *src = p2;
    int remaining;

    for (remaining = len; remaining > 0; remaining--) {
        *dst++ ^= *src++;
    }
}
1081 
1082 /**
1083  * Constructs an initialization vector (IV) as follows:
1084  * For a 128-bit IV (AES non-auth): IV = S + src_ID + ctr
1085  * For a 96-bit IV (AES auth):      IV = (S XOR src_ID) + ctr
1086  * For a 64-bit IV (DES, 3DES):     IV = (S + src_ID) XOR ctr
1087  * All values other should be in network byte order.
1088  */
build_iv4(uint8_t * iv,const uint8_t * salt,int ivlen,uint64_t ivctr,uint32_t src_id)1089 void build_iv4(uint8_t *iv, const uint8_t *salt, int ivlen, uint64_t ivctr,
1090                uint32_t src_id)
1091 {
1092     char tmp[16], tmp2[16];
1093     int tmplen, tmp2len;
1094 
1095     memset(tmp, 0, sizeof(tmp));
1096     tmplen = 0;
1097     if (ivlen == 8) {
1098         memcpy(tmp, salt, SALT_LEN);
1099         tmplen = SALT_LEN;
1100         memcpy(tmp + tmplen, &src_id, sizeof(uint32_t));
1101         tmplen += sizeof(uint32_t);
1102         memcpy(tmp2, &ivctr, sizeof(uint64_t));
1103         tmp2len = sizeof(uint64_t);
1104         memxor(tmp, tmp2, tmp2len);
1105     } else if (ivlen == 12) {
1106         memcpy(tmp, salt, SALT_LEN);
1107         tmplen = SALT_LEN;
1108         memcpy(tmp2, &src_id, sizeof(uint32_t));
1109         tmp2len = sizeof(uint32_t);
1110         memxor(tmp, tmp2, tmp2len);
1111         memcpy(tmp + tmplen, &ivctr, sizeof(uint64_t));
1112         tmplen += sizeof(uint64_t);
1113     } else if (ivlen == 16) {
1114         memcpy(tmp, salt, SALT_LEN);
1115         tmplen = SALT_LEN;
1116         memcpy(tmp + tmplen, &src_id, sizeof(uint32_t));
1117         tmplen += sizeof(uint32_t);
1118         memcpy(tmp + tmplen, &ivctr, sizeof(uint64_t));
1119         tmplen += sizeof(uint64_t);
1120     }
1121     memcpy(iv, tmp, tmplen);
1122 }
1123 
/**
 * Builds an IV as: S XOR CTR.
 * The first ivlen bytes of salt are copied to iv, then the 8 bytes of
 * ivctr are XORed into the last 8 bytes of iv.
 * ivlen is expected to be at least sizeof(uint64_t).
 * All values should be in network byte order.
 */
void build_iv(uint8_t *iv, const uint8_t *salt, int ivlen, uint64_t ivctr)
{
    const unsigned char *ctr_bytes = (const unsigned char *)&ivctr;
    uint8_t *tail;
    size_t k;

    memcpy(iv, salt, ivlen);
    tail = iv + (ivlen - sizeof(ivctr));
    for (k = 0; k < sizeof(ivctr); k++) {
        tail[k] ^= ctr_bytes[k];
    }
}
1133 
/**
 * Writes len bytes of data to the log via sclog2 as " XX" uppercase hex
 * pairs after a "name:" prefix, breaking the line after every 16 bytes
 * and finishing with a newline.
 * Used only for debugging
 */
void printhex(const char *name, const unsigned char *data, int len)
{
    int idx;

    sclog2("%s:", name);
    for (idx = 0; idx < len; idx++) {
        sclog2(" %02X", data[idx]);
        // Wrap after each full row of 16 bytes
        if ((idx + 1) % 16 == 0) {
            sclog2("\n");
        }
    }
    sclog2("\n");
}
1149 
1150 /**
1151  * Returns 1 if the specified keytype is an authentication mode cipher
1152  */
is_auth_enc(int keytype)1153 int is_auth_enc(int keytype)
1154 {
1155     return ((keytype == KEY_AES128_GCM) || (keytype == KEY_AES256_GCM) ||
1156             (keytype == KEY_AES128_CCM) || (keytype == KEY_AES256_CCM));
1157 }
1158 
1159 /**
1160  * Returns 1 if the specified keytype is a GCM mode cipher
1161  */
is_gcm_mode(int keytype)1162 int is_gcm_mode(int keytype)
1163 {
1164     return ((keytype == KEY_AES128_GCM) || (keytype == KEY_AES256_GCM));
1165 }
1166 
1167 /**
1168  * Returns 1 if the specified keytype is a CCM mode cipher
1169  */
is_ccm_mode(int keytype)1170 int is_ccm_mode(int keytype)
1171 {
1172     return ((keytype == KEY_AES128_CCM) || (keytype == KEY_AES256_CCM));
1173 }
1174 
1175 /**
1176  * If the specified keytype is for an authentication cipher,
1177  * return the keytype for the same cipher in CBC mode.
1178  */
unauth_key(int keytype)1179 int unauth_key(int keytype)
1180 {
1181     switch (keytype) {
1182     case KEY_AES128_GCM:
1183     case KEY_AES128_CCM:
1184         return KEY_AES128_CBC;
1185     case KEY_AES256_GCM:
1186     case KEY_AES256_CCM:
1187         return KEY_AES256_CBC;
1188     default:
1189         return keytype;
1190     }
1191 }
1192 
1193 /**
1194  * Verify the signature of an encrypted message and decrypt.
1195  * The decrypted message is returned without a uftp_h header.
1196  * Returns 1 on success, 0 on fail
1197  */
validate_and_decrypt(unsigned char * encpacket,unsigned int enclen,unsigned char ** decpacket,unsigned int * declen,int keytype,const uint8_t * key,const uint8_t * salt,int ivlen)1198 int validate_and_decrypt(unsigned char *encpacket, unsigned int enclen,
1199                          unsigned char **decpacket, unsigned int *declen,
1200                          int keytype, const uint8_t *key,
1201                          const uint8_t *salt, int ivlen)
1202 {
1203     struct uftp_h *header;
1204     struct encrypted_h *encrypted;
1205     unsigned char *payload, *iv;
1206     unsigned int rval, allocdec;
1207     uint64_t ivctr;
1208 
1209     header = (struct uftp_h *)encpacket;
1210     encrypted = (struct encrypted_h *)(encpacket + sizeof(struct uftp_h));
1211     payload = (unsigned char *)encrypted + sizeof(struct encrypted_h);
1212 
1213     if (header->func != ENCRYPTED) {
1214         log0(0, 0, 0, "Attempt to decrypt non-encrypted message");
1215         return 0;
1216     }
1217     if (enclen != (sizeof(struct uftp_h) + sizeof(struct encrypted_h) +
1218             ntohs(encrypted->payload_len))) {
1219         log0(0, 0, 0, "Invalid signature and/or encrypted payload length");
1220         return 0;
1221     }
1222 
1223     iv = safe_calloc(ivlen, 1);
1224     allocdec = 0;
1225     if (*decpacket == NULL) {
1226         allocdec = 1;
1227         *decpacket = safe_calloc(MAXMTU + KEYBLSIZE, 1);
1228     }
1229     ivctr = ntohl(encrypted->iv_ctr_lo);
1230     ivctr |= (uint64_t)ntohl(encrypted->iv_ctr_hi) << 32;
1231     if (header->version == UFTP4_VER_NUM) {
1232         build_iv4(iv, salt, ivlen, uftp_htonll(ivctr), header->src_id);
1233     } else {
1234         build_iv(iv, salt, ivlen, uftp_htonll(ivctr));
1235     }
1236     if (!decrypt_block(keytype, iv, key, encpacket,
1237             sizeof(struct uftp_h) + sizeof(struct encrypted_h),
1238             payload, ntohs(encrypted->payload_len), *decpacket, declen)) {
1239         log0(0, 0, 0, "Decrypt failed");
1240         if (allocdec) {
1241             free(*decpacket);
1242             *decpacket = NULL;
1243         }
1244         rval = 0;
1245         goto end;
1246     }
1247 
1248     rval = 1;
1249 
1250 end:
1251     free(iv);
1252     return rval;
1253 }
1254 
1255 /**
1256  * Encrypts a message and attaches a signature to the encrypted message.
1257  * The incoming message should include a uftp_h header.
1258  * Returns 1 on success, 0 on fail
1259  */
encrypt_and_sign(const unsigned char * decpacket,unsigned char ** encpacket,int declen,int * enclen,int keytype,uint8_t * key,const uint8_t * salt,uint64_t * ivctr,int ivlen)1260 int encrypt_and_sign(const unsigned char *decpacket, unsigned char **encpacket,
1261                      int declen, int *enclen, int keytype, uint8_t *key,
1262                      const uint8_t *salt, uint64_t *ivctr, int ivlen)
1263 {
1264     struct uftp_h *header;
1265     struct encrypted_h *encrypted;
1266     const unsigned char *mheader;
1267     unsigned char *payload, *iv;
1268     unsigned int payloadlen, allocenc;
1269 
1270     allocenc = 0;
1271     if (*encpacket == NULL) {
1272         allocenc = 1;
1273         *encpacket = safe_calloc(MAXMTU + KEYBLSIZE, 1);
1274     }
1275     iv = safe_calloc(ivlen, 1);
1276 
1277     mheader = decpacket + sizeof(struct uftp_h);
1278     header = (struct uftp_h *)*encpacket;
1279     encrypted = (struct encrypted_h *)(*encpacket + sizeof(struct uftp_h));
1280     payload = (unsigned char *)encrypted + sizeof(struct encrypted_h);
1281 
1282     (*ivctr)++;
1283     memcpy(*encpacket, decpacket, sizeof(struct uftp_h));
1284     header->func = ENCRYPTED;
1285     encrypted->iv_ctr_hi = htonl((*ivctr & 0xFFFFFFFF00000000ULL) >> 32);
1286     encrypted->iv_ctr_lo = htonl(*ivctr & 0x00000000FFFFFFFFULL);
1287     if (is_gcm_mode(keytype)) {
1288         encrypted->payload_len = htons(declen + GCM_TAG_LEN);
1289     } else if (is_ccm_mode(keytype)) {
1290         encrypted->payload_len = htons(declen + CCM_TAG_LEN);
1291     } else {
1292         log0(0, 0, 0, "Invalid cipher mode for keytype %d", keytype);
1293         return 0;
1294     }
1295 
1296     if (header->version == UFTP4_VER_NUM) {
1297         build_iv4(iv, salt, ivlen, uftp_htonll(*ivctr), header->src_id);
1298     } else {
1299         build_iv(iv, salt, ivlen, uftp_htonll(*ivctr));
1300     }
1301     if (!encrypt_block(keytype, iv, key, *encpacket,
1302             sizeof(struct uftp_h) + sizeof(struct encrypted_h),
1303             mheader, declen, payload, &payloadlen)) {
1304         // Called function should log
1305         free(iv);
1306         if (allocenc) {
1307             free(*encpacket);
1308             *encpacket = NULL;
1309         }
1310         return 0;
1311     }
1312     free(iv);
1313     if (payloadlen != ntohs(encrypted->payload_len)) {
1314         log0(0, 0, 0, "Invalid payloadlen: got %d, expected %d",
1315                       payloadlen, ntohs(encrypted->payload_len));
1316         return 0;
1317     }
1318     *enclen = sizeof(struct encrypted_h) + payloadlen;
1319 
1320     return 1;
1321 }
1322 
/**
 * Pseudo-random function for an individual hashing algorithm,
 * modelled on P_hash from RFC 5246 section 5, with label || seed used as
 * the seed.
 * Output is produced in whole HMAC-sized chunks until at least `bytes`
 * bytes have been written, so outbuf may receive more than `bytes` bytes;
 * the number actually written is stored in *outbuf_len.
 * NOTE(review): each round's HMAC input is the previous round's
 * A(i) || label || seed, whereas RFC 5246 defines A(i+1) = HMAC(A(i)).
 * Changing this would alter every derived value, so it is kept as-is.
 */
static void P_hash(int hashtype, int bytes,
                   const unsigned char *secret, int secret_len,
                   const char *label, const unsigned char *seed, int seed_len,
                   unsigned char *outbuf, int *outbuf_len)
{
    const size_t label_len = strlen(label);
    const size_t hlen = get_hash_len(hashtype);
    unsigned char *fullseed, *a_in, *a_out;
    unsigned fullseed_len, a_in_len;
    unsigned int a_out_len, chunk_len;

    fullseed = safe_calloc(label_len + seed_len, 1);
    a_in = safe_calloc(hlen + label_len + seed_len, 1);
    a_out = safe_calloc(hlen + label_len + seed_len, 1);

    *outbuf_len = 0;

    // fullseed = label || seed
    memcpy(fullseed, label, label_len);
    memcpy(fullseed + label_len, seed, seed_len);
    fullseed_len = (unsigned)label_len + seed_len;

    // First HMAC input, A(0), is label || seed
    memcpy(a_in, fullseed, fullseed_len);
    a_in_len = fullseed_len;

    while (*outbuf_len < bytes) {
        // A(i) = HMAC(secret, previous input)
        create_hmac(hashtype, secret, secret_len, a_in, a_in_len,
                    a_out, &a_out_len);
        // Output chunk = HMAC(secret, A(i) || label || seed)
        memcpy(a_out + a_out_len, fullseed, fullseed_len);
        a_out_len += fullseed_len;
        create_hmac(hashtype, secret, secret_len, a_out, a_out_len,
                    outbuf + *outbuf_len, &chunk_len);
        *outbuf_len += chunk_len;
        // Next round hashes A(i) || label || seed (see NOTE above)
        memcpy(a_in, a_out, a_out_len);
        a_in_len = a_out_len;
    }

    free(fullseed);
    free(a_in);
    free(a_out);
}
1366 
/**
 * Pseudo-random function modelled on RFC 5246, which defines
 * PRF(secret, label, seed) = P_<hash>(secret, label + seed) for a single
 * hash algorithm.  This simply delegates to P_hash: outbuf receives
 * whole hash-output-sized chunks totalling at least `bytes` bytes, and
 * the count written is stored in *outbuf_len.
 */
void PRF(int hashtype, int bytes, const unsigned char *secret, int secret_len,
         const char *label, const unsigned char *seed, int seed_len,
         unsigned char *outbuf, int *outbuf_len)
{
    P_hash(hashtype, bytes,
           secret, secret_len,
           label,
           seed, seed_len,
           outbuf, outbuf_len);
}
1378 
/**
  * Creates Server_HS_Context as the concatenation
  * group_id || group_inst || server_id || first extlen bytes of encinfo.
  * The buffer is allocated with safe_malloc and returned in *context
  * (the caller frees it); its length is returned in *context_len.
  * Values are copied as given, so all integer values within the context
  * should already be in network byte order.
  */
void create_server_context(uint32_t group_id, uint8_t group_inst,
                           uint32_t server_id,const struct enc_info_he *encinfo,
                           int extlen, uint8_t **context, int *context_len)
{
    uint8_t *buf, *pos;

    buf = safe_malloc(sizeof(group_id) + sizeof(group_inst) +
                      sizeof(server_id) + extlen);
    pos = buf;
    memcpy(pos, &group_id, sizeof group_id);
    pos += sizeof group_id;
    memcpy(pos, &group_inst, sizeof group_inst);
    pos += sizeof group_inst;
    memcpy(pos, &server_id, sizeof server_id);
    pos += sizeof server_id;
    memcpy(pos, encinfo, extlen);
    pos += extlen;

    *context = buf;
    *context_len = (int)(pos - buf);
}
1399 
1400 /**
1401   * Creates Proxy_HS_Context
1402   * All integer values within the context should be in network byte order
1403   */
create_proxy_context(uint32_t proxy_id,const struct proxy_key_h * proxykey,uint8_t ** context,int * context_len)1404 void create_proxy_context(uint32_t proxy_id, const struct proxy_key_h *proxykey,
1405                           uint8_t **context, int *context_len)
1406 {
1407     uint8_t *newcontext = safe_malloc(sizeof(proxy_id) + proxykey->hlen * 4);
1408     *context = newcontext;
1409     memcpy(*context + *context_len, &proxy_id, sizeof(proxy_id));
1410     *context_len += sizeof(proxy_id);
1411     memcpy(*context + *context_len, proxykey, proxykey->hlen * 4);
1412     *context_len += proxykey->hlen * 4;
1413 }
1414 
1415 /**
1416   * Creates Client_HS_Context1
1417   * All integer within the context values should be in network byte order
1418   * client_dh is an EC keyblob
1419   */
create_client_context_1(const uint8_t * s_context,int s_context_len,const uint8_t * p_context,int p_context_len,uint32_t client_id,const uint8_t * client_dh,int client_dh_len,const uint8_t * client_rand,uint8_t ** context,int * context_len)1420 void create_client_context_1(const uint8_t *s_context, int s_context_len,
1421                              const uint8_t *p_context, int p_context_len,
1422                              uint32_t client_id, const uint8_t *client_dh,
1423                              int client_dh_len, const uint8_t *client_rand,
1424                              uint8_t **context, int *context_len)
1425 {
1426     *context = safe_malloc(s_context_len + p_context_len + sizeof(client_id) +
1427                            client_dh_len + RAND_LEN);
1428     *context_len = 0;
1429     memcpy(*context + *context_len, s_context, s_context_len);
1430     *context_len += s_context_len;
1431     if (p_context_len) {
1432         memcpy(*context + *context_len, p_context, p_context_len);
1433         *context_len += p_context_len;
1434     }
1435     memcpy(*context + *context_len, &client_id, sizeof(client_id));
1436     *context_len += sizeof(client_id);
1437     memcpy(*context + *context_len, client_dh, client_dh_len);
1438     *context_len += client_dh_len;
1439     memcpy(*context + *context_len, client_rand, RAND_LEN);
1440     *context_len += RAND_LEN;
1441 }
1442 
/**
  * Creates Client_HS_Context2 as c_context1 followed by the first
  * header_len bytes of ckheader (skipped when header_len is 0).
  * The buffer is allocated with safe_malloc and returned in *context
  * (the caller frees it); its length is returned in *context_len.
  * All integer values within the context should be in network byte order
  */
void create_client_context_2(const uint8_t *c_context1, int c_context1_len,
                             const struct client_key_h *ckheader,int header_len,
                             uint8_t **context, int *context_len)
{
    uint8_t *buf, *pos;

    buf = safe_malloc(c_context1_len + header_len);
    pos = buf;
    memcpy(pos, c_context1, c_context1_len);
    pos += c_context1_len;
    if (header_len != 0) {
        memcpy(pos, ckheader, header_len);
        pos += header_len;
    }

    *context = buf;
    *context_len = (int)(pos - buf);
}
1460 
1461 /**
1462  * HMAC based Key Derivation Function (HKDF) - Extract
1463  * as defined in RFC 5869
1464  */
HKDF_Extract(int hashtype,const unsigned char * salt,unsigned int salt_len,const unsigned char * secret,unsigned int secret_len,unsigned char * outbuf,unsigned int * outbuf_len)1465 void HKDF_Extract(int hashtype,
1466                   const unsigned char *salt, unsigned int salt_len,
1467                   const unsigned char *secret, unsigned int secret_len,
1468                   unsigned char *outbuf, unsigned int *outbuf_len)
1469 {
1470     create_hmac(hashtype, salt ? salt : (const unsigned char *)"",
1471                 salt ? salt_len : 0, secret, secret_len, outbuf, outbuf_len);
1472 }
1473 
/**
 * HMAC based Key Derivation Function (HKDF) - Expand
 * as defined in RFC 5869:
 *   T(0) = empty, T(i) = HMAC(secret, T(i-1) || info || i), OKM = T(1) || T(2) ...
 * Whole HMAC outputs are written to outbuf until at least `bytes` bytes
 * have been produced, so this can generate more bytes than requested
 * (less than one extra hash length); the count is stored in *outbuf_len.
 * The block counter is a single byte, so at most 255 blocks are distinct,
 * matching RFC 5869's limit of 255 * HashLen output bytes.
 */
void HKDF_Expand(int hashtype, unsigned int bytes,
                 const unsigned char *secret, unsigned int secret_len,
                 const unsigned char *info, unsigned int info_len,
                 unsigned char *outbuf, unsigned int *outbuf_len)
{
    unsigned char *msg, *block;
    unsigned msg_len, block_len;
    unsigned char counter;

    msg = safe_malloc(get_hash_len(hashtype) + info_len + 1);
    block = safe_malloc(get_hash_len(hashtype) + info_len + 1);

    *outbuf_len = 0;
    block_len = 0;   // T(0) is empty
    for (counter = 1; *outbuf_len < bytes; counter++) {
        // msg = T(i-1) || info || i
        memcpy(msg, block, block_len);
        msg_len = block_len;
        memcpy(msg + msg_len, info, info_len);
        msg_len += info_len;
        msg[msg_len++] = counter;
        // T(i) = HMAC(secret, msg), appended to the output
        create_hmac(hashtype, secret, secret_len, msg, msg_len,
                    block, &block_len);
        memcpy(outbuf + *outbuf_len, block, block_len);
        *outbuf_len += block_len;
    }
    free(msg);
    free(block);
}
1509 
/**
 * HKDF Expand with label, derived from RFC 8446.
 * info = "UFTP5 " || label || context, as a plain concatenation (unlike
 * RFC 8446's HkdfLabel there are no length fields), then HKDF_Expand is
 * run with that info.
 * This can generate more bytes than requested, up to the hash length
 */
void HKDF_Expand_Label(int hashtype, unsigned int bytes, const char *label,
                       const unsigned char *secret, unsigned int secret_len,
                       const unsigned char *context, unsigned int context_len,
                       unsigned char *outbuf, unsigned int *outbuf_len)
{
    static const char prefix[] = "UFTP5 ";
    const unsigned int prefix_len = (unsigned)(sizeof(prefix) - 1);
    const unsigned int lbl_len = (unsigned)strlen(label);
    const unsigned int info_len = prefix_len + lbl_len + context_len;
    unsigned char *info = safe_malloc(info_len);
    unsigned char *cursor = info;

    memcpy(cursor, prefix, prefix_len);
    cursor += prefix_len;
    memcpy(cursor, label, lbl_len);
    cursor += lbl_len;
    memcpy(cursor, context, context_len);

    HKDF_Expand(hashtype, bytes, secret, secret_len, info, info_len,
                outbuf, outbuf_len);
    free(info);
}
1533 
1534 /**
1535  * Creates the handshake keys for the client and server
1536  */
calculate_hs_keys(int hashtype,uint8_t * premaster,int premaster_len,uint8_t * client_context1,unsigned int client_context1_len,unsigned int key_len,unsigned int iv_len,uint8_t * server_hs_key,uint8_t * server_hs_iv,uint8_t * client_hs_key,uint8_t * client_hs_iv)1537 void calculate_hs_keys(int hashtype, uint8_t *premaster, int premaster_len,
1538                        uint8_t *client_context1,
1539                        unsigned int client_context1_len,
1540                        unsigned int key_len, unsigned int iv_len,
1541                        uint8_t *server_hs_key, uint8_t *server_hs_iv,
1542                        uint8_t *client_hs_key, uint8_t *client_hs_iv)
1543 {
1544     uint8_t zeros[HASH_LEN] = { 0 };
1545     uint8_t context_hash[HASH_LEN];
1546     uint8_t hs_secret[2*HASH_LEN];
1547     uint8_t server_hs_secret[2*HASH_LEN];
1548     uint8_t client_hs_secret[2*HASH_LEN];
1549     uint8_t tmp_out[2*HASH_LEN];
1550     unsigned int hash_len, hs_secret_len, out_len;
1551 
1552     hash_len = get_hash_len(hashtype);
1553     hash(hashtype, client_context1, client_context1_len, context_hash,&out_len);
1554 
1555     HKDF_Extract(hashtype, zeros, hash_len, premaster, premaster_len,
1556                  hs_secret, &hs_secret_len);
1557     HKDF_Expand_Label(hashtype, hash_len, "s hs traffic", hs_secret,
1558             hs_secret_len, context_hash, hash_len, server_hs_secret, &out_len);
1559     HKDF_Expand_Label(hashtype, hash_len, "c hs traffic", hs_secret,
1560             hs_secret_len, context_hash, hash_len, client_hs_secret, &out_len);
1561 
1562     HKDF_Expand_Label(hashtype, key_len, "key", server_hs_secret, hash_len,
1563                       (const unsigned char *)"", 0, tmp_out, &out_len);
1564     memcpy(server_hs_key, tmp_out, key_len);
1565     HKDF_Expand_Label(hashtype, iv_len, "iv", server_hs_secret, hash_len,
1566                       (const unsigned char *)"", 0, tmp_out, &out_len);
1567     memcpy(server_hs_iv, tmp_out, iv_len);
1568     HKDF_Expand_Label(hashtype, key_len, "key", client_hs_secret, hash_len,
1569                       (const unsigned char *)"", 0, tmp_out, &out_len);
1570     memcpy(client_hs_key, tmp_out, key_len);
1571     HKDF_Expand_Label(hashtype, iv_len, "iv", client_hs_secret, hash_len,
1572                       (const unsigned char *)"", 0, tmp_out, &out_len);
1573     memcpy(client_hs_iv, tmp_out, iv_len);
1574 }
1575 
1576 /**
1577  * Creates the application keys for the server
1578  */
calculate_server_app_keys(int hashtype,uint8_t * groupmaster,int groupmaster_len,uint8_t * server_context,unsigned int server_context_len,unsigned int key_len,unsigned int iv_len,uint8_t * server_app_key,uint8_t * server_app_iv)1579 void calculate_server_app_keys(int hashtype, uint8_t *groupmaster,
1580                                int groupmaster_len, uint8_t *server_context,
1581                                unsigned int server_context_len,
1582                                unsigned int key_len, unsigned int iv_len,
1583                                uint8_t *server_app_key, uint8_t *server_app_iv)
1584 {
1585     uint8_t zeros[HASH_LEN] = { 0 };
1586     uint8_t context_hash[HASH_LEN];
1587     uint8_t app_secret[HASH_LEN];
1588     uint8_t server_app_secret[2*HASH_LEN];
1589     uint8_t tmp_out[2*HASH_LEN];
1590     unsigned int hash_len, app_secret_len, out_len;
1591 
1592     hash_len = get_hash_len(hashtype);
1593     hash(hashtype, server_context, server_context_len, context_hash, &out_len);
1594 
1595     HKDF_Extract(hashtype, zeros, hash_len, groupmaster, groupmaster_len,
1596                  app_secret, &app_secret_len);
1597     HKDF_Expand_Label(hashtype, hash_len, "s app traffic", app_secret,
1598             app_secret_len, context_hash, hash_len, server_app_secret,&out_len);
1599 
1600     HKDF_Expand_Label(hashtype, key_len, "key", server_app_secret, hash_len,
1601                       (const unsigned char *)"", 0, tmp_out, &out_len);
1602     memcpy(server_app_key, tmp_out, key_len);
1603     HKDF_Expand_Label(hashtype, iv_len, "iv", server_app_secret, hash_len,
1604                       (const unsigned char *)"", 0, tmp_out, &out_len);
1605     memcpy(server_app_iv, tmp_out, iv_len);
1606 }
1607 
1608 /**
1609  * Creates the application keys for a client
1610  * Also calculates the finished hash for the client
1611  */
calculate_client_app_keys(int hashtype,uint8_t * groupmaster,int groupmaster_len,uint8_t * client_context2,unsigned int client_context2_len,unsigned int key_len,unsigned int iv_len,uint8_t * client_app_key,uint8_t * client_app_iv,uint8_t * finished_key,uint8_t * verify_data)1612 void calculate_client_app_keys(int hashtype, uint8_t *groupmaster,
1613                                int groupmaster_len, uint8_t *client_context2,
1614                                unsigned int client_context2_len,
1615                                unsigned int key_len, unsigned int iv_len,
1616                                uint8_t *client_app_key, uint8_t *client_app_iv,
1617                                uint8_t *finished_key, uint8_t *verify_data)
1618 {
1619     uint8_t zeros[HASH_LEN] = { 0 };
1620     uint8_t context_hash[HASH_LEN];
1621     uint8_t app_secret[HASH_LEN];
1622     uint8_t client_app_secret[2*HASH_LEN];
1623     uint8_t tmp_out[2*HASH_LEN];
1624     uint8_t *verify_context;
1625     unsigned int hash_len, app_secret_len, verify_context_len, out_len;
1626 
1627     hash_len = get_hash_len(hashtype);
1628     hash(hashtype, client_context2, client_context2_len, context_hash,&out_len);
1629 
1630     HKDF_Extract(hashtype, zeros, hash_len, groupmaster, groupmaster_len,
1631                  app_secret, &app_secret_len);
1632     HKDF_Expand_Label(hashtype, hash_len, "c app traffic", app_secret,
1633             app_secret_len, context_hash, hash_len, client_app_secret,&out_len);
1634 
1635     HKDF_Expand_Label(hashtype, key_len, "key", client_app_secret, hash_len,
1636                       (const unsigned char *)"", 0, tmp_out, &out_len);
1637     memcpy(client_app_key, tmp_out, key_len);
1638     HKDF_Expand_Label(hashtype, iv_len, "iv", client_app_secret, hash_len,
1639                       (const unsigned char *)"", 0, tmp_out, &out_len);
1640     memcpy(client_app_iv, tmp_out, iv_len);
1641     HKDF_Expand_Label(hashtype, hash_len, "finished", client_app_secret,
1642             hash_len, (const unsigned char *)"", 0, tmp_out, &out_len);
1643     memcpy(finished_key, tmp_out, hash_len);
1644 
1645     verify_context_len = 0;
1646     verify_context = safe_malloc(client_context2_len + groupmaster_len);
1647     memcpy(verify_context + verify_context_len, client_context2,
1648             client_context2_len);
1649     verify_context_len += client_context2_len;
1650     memcpy(verify_context + verify_context_len, groupmaster, groupmaster_len);
1651     verify_context_len += groupmaster_len;
1652 
1653     create_hmac(hashtype, finished_key, hash_len, verify_context,
1654                 verify_context_len, verify_data, &out_len);
1655     free(verify_context);
1656 }
1657 
1658 /**
1659  * Outputs a key's fingerprint
1660  */
print_key_fingerprint(const union key_t key,int keytype)1661 const char *print_key_fingerprint(const union key_t key, int keytype)
1662 {
1663     static char fpstr[100];
1664     char *p;
1665     unsigned char *keyblob, fingerprint[HMAC_LEN];
1666     uint16_t bloblen;
1667     unsigned int fplen, i, cnt;
1668 
1669     keyblob = safe_calloc(PUBKEY_LEN, 1);
1670 
1671     if (keytype == KEYBLOB_RSA) {
1672         if (!export_RSA_key(key.rsa, keyblob, &bloblen)) {
1673             free(keyblob);
1674             return NULL;
1675         }
1676     } else {
1677         if (!export_EC_key(key.ec, keyblob, &bloblen)) {
1678             free(keyblob);
1679             return NULL;
1680         }
1681     }
1682     hash(HASH_SHA1, keyblob, bloblen, fingerprint, &fplen);
1683 
1684     for (i = 0, p = fpstr; i < fplen; i++) {
1685         if (i != 0) {
1686             *p = ':';
1687             p++;
1688         }
1689         cnt = snprintf(p, 3, "%02X", fingerprint[i]);
1690         p += cnt;
1691     }
1692 
1693     free(keyblob);
1694     return fpstr;
1695 }
1696 
1697 #if ((!defined WINDOWS) && (defined MCAST_JOIN_GROUP))
1698 
1699 /**
1700  * Join the specified multicast group on the specified list of interfaces.
1701  * If source specific multicast is supported and we're given a list of servers,
1702  * join source specific multicast groups for those servers.
1703  * Returns 1 on success, 0 on fail
1704  */
multicast_join(SOCKET s,uint32_t group_id,const union sockaddr_u * multi,const struct iflist * addrlist,int addrlen,const struct fp_list_t * fplist,int fplist_len)1705 int multicast_join(SOCKET s, uint32_t group_id, const union sockaddr_u *multi,
1706                    const struct iflist *addrlist, int addrlen,
1707                    const struct fp_list_t *fplist, int fplist_len)
1708 {
1709     struct group_req greq = { 0 };
1710     struct group_source_req gsreq = { 0 };
1711     int level = 0, i, j;
1712 
1713     for (i = 0; i < addrlen; i++) {
1714         if (!addrlist[i].ismulti) {
1715             continue;
1716         }
1717         if (addrlist[i].su.ss.ss_family != multi->ss.ss_family) {
1718             continue;
1719         }
1720         if (addrlist[i].su.ss.ss_family == AF_INET6) {
1721             level = IPPROTO_IPV6;
1722         } else if (addrlist[i].su.ss.ss_family == AF_INET) {
1723             level = IPPROTO_IP;
1724         }
1725         if (fplist_len == 0) {
1726             greq.gr_interface = addrlist[i].ifidx;
1727             greq.gr_group = multi->ss;
1728             if (setsockopt(s, level, MCAST_JOIN_GROUP,
1729                     (char *)&greq, sizeof(greq)) == -1) {
1730                 sockerror(group_id, 0, 0, "Error joining multicast group");
1731                 return 0;
1732             }
1733         } else {
1734             for (j = 0; j < fplist_len; j++) {
1735                 if (addrlist[i].su.ss.ss_family!=fplist[j].addr.ss.ss_family) {
1736                     continue;
1737                 }
1738                 gsreq.gsr_interface = addrlist[i].ifidx;
1739                 gsreq.gsr_source = fplist[j].addr.ss;
1740                 gsreq.gsr_group = multi->ss;
1741                 if (setsockopt(s, level, MCAST_JOIN_SOURCE_GROUP,
1742                         (char *)&gsreq, sizeof(gsreq)) == -1) {
1743                     sockerror(group_id, 0, 0, "Error joining multicast group");
1744                     return 0;
1745                 }
1746             }
1747         }
1748     }
1749     return 1;
1750 }
1751 
1752 /**
1753  * Leave the specified multicast group on the specified list of interfaces.
1754  * If source specific multicast is supported and we're given a list of servers,
1755  * leave source specific multicast groups for those servers.
1756  */
multicast_leave(SOCKET s,uint32_t group_id,const union sockaddr_u * multi,const struct iflist * addrlist,int addrlen,const struct fp_list_t * fplist,int fplist_len)1757 void multicast_leave(SOCKET s, uint32_t group_id, const union sockaddr_u *multi,
1758                      const struct iflist *addrlist, int addrlen,
1759                      const struct fp_list_t *fplist, int fplist_len)
1760 {
1761     struct group_req greq = { 0 };
1762     struct group_source_req gsreq = { 0 };
1763     int level = 0, i, j;
1764 
1765     for (i = 0; i < addrlen; i++) {
1766         if (!addrlist[i].ismulti) {
1767             continue;
1768         }
1769         if (addrlist[i].su.ss.ss_family != multi->ss.ss_family) {
1770             continue;
1771         }
1772         if (addrlist[i].su.ss.ss_family == AF_INET6) {
1773             level = IPPROTO_IPV6;
1774         } else if (addrlist[i].su.ss.ss_family == AF_INET) {
1775             level = IPPROTO_IP;
1776         }
1777         if (fplist_len == 0) {
1778             greq.gr_interface = addrlist[i].ifidx;
1779             greq.gr_group = multi->ss;
1780             if (setsockopt(s, level, MCAST_LEAVE_GROUP,
1781                     (char *)&greq, sizeof(greq)) == -1) {
1782                 sockerror(group_id, 0, 0, "Error leaving multicast group");
1783             }
1784         } else {
1785             for (j = 0; j < fplist_len; j++) {
1786                 if (addrlist[i].su.ss.ss_family!=fplist[j].addr.ss.ss_family) {
1787                     continue;
1788                 }
1789                 gsreq.gsr_interface = addrlist[i].ifidx;
1790                 gsreq.gsr_source = fplist[j].addr.ss;
1791                 gsreq.gsr_group = multi->ss;
1792                 if (setsockopt(s, level, MCAST_LEAVE_SOURCE_GROUP,
1793                         (char *)&gsreq, sizeof(gsreq)) == -1) {
1794                     sockerror(group_id, 0, 0, "Error leaving multicast group");
1795                 }
1796             }
1797         }
1798     }
1799 }
1800 
1801 #else
1802 
1803 /**
1804  * Join the specified multicast group on the specified list of interfaces.
1805  * If source specific multicast is supported and we're given a list of servers,
1806  * join source specific multicast groups for those servers.
1807  * Returns 1 on success, 0 on fail
1808  */
multicast_join(SOCKET s,uint32_t group_id,const union sockaddr_u * multi,const struct iflist * addrlist,int addrlen,const struct fp_list_t * fplist,int fplist_len)1809 int multicast_join(SOCKET s, uint32_t group_id, const union sockaddr_u *multi,
1810                    const struct iflist *addrlist, int addrlen,
1811                    const struct fp_list_t *fplist, int fplist_len)
1812 {
1813     struct ip_mreq mreq = { 0 };
1814     struct ipv6_mreq mreq6 = { 0 };
1815     int i;
1816 
1817     for (i = 0; i < addrlen; i++) {
1818         if (!addrlist[i].ismulti) {
1819             continue;
1820         }
1821         if (addrlist[i].su.ss.ss_family != multi->ss.ss_family) {
1822             continue;
1823         }
1824         if (multi->ss.ss_family == AF_INET6) {
1825             mreq6.ipv6mr_multiaddr = multi->sin6.sin6_addr;
1826             mreq6.ipv6mr_interface = addrlist[i].ifidx;
1827             if (setsockopt(s, IPPROTO_IPV6, IPV6_JOIN_GROUP,
1828                            (char *)&mreq6, sizeof(mreq6)) == SOCKET_ERROR) {
1829                 sockerror(group_id, 0, 0, "Error joining multicast group");
1830                 return 0;
1831             }
1832         } else {
1833 #ifdef IP_ADD_SOURCE_MEMBERSHIP
1834             if (fplist_len != 0) {
1835                 int j;
1836                 for (j = 0; j < fplist_len; j++) {
1837                     struct ip_mreq_source srcmreq;
1838                     srcmreq.imr_multiaddr = multi->sin.sin_addr;
1839                     srcmreq.imr_sourceaddr = fplist[j].addr.sin.sin_addr;
1840                     srcmreq.imr_interface = addrlist[i].su.sin.sin_addr;
1841                     if (setsockopt(s, IPPROTO_IP, IP_ADD_SOURCE_MEMBERSHIP,
1842                            (char *)&srcmreq, sizeof(srcmreq)) == SOCKET_ERROR) {
1843                         sockerror(group_id, 0, 0,
1844                                   "Error joining multicast group");
1845                         return 0;
1846                     }
1847                 }
1848             } else {
1849                 mreq.imr_multiaddr = multi->sin.sin_addr;
1850                 mreq.imr_interface = addrlist[i].su.sin.sin_addr;
1851                 if (setsockopt(s, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1852                                (char *)&mreq, sizeof(mreq)) == SOCKET_ERROR) {
1853                     sockerror(group_id, 0, 0, "Error joining multicast group");
1854                     return 0;
1855                 }
1856             }
1857 #else
1858             mreq.imr_multiaddr = multi->sin.sin_addr;
1859             mreq.imr_interface = addrlist[i].su.sin.sin_addr;
1860             if (setsockopt(s, IPPROTO_IP, IP_ADD_MEMBERSHIP,
1861                            (char *)&mreq, sizeof(mreq)) == SOCKET_ERROR) {
1862                 sockerror(group_id, 0, 0, "Error joining multicast group");
1863                 return 0;
1864             }
1865 #endif
1866         }
1867     }
1868     return 1;
1869 }
1870 
1871 /**
1872  * Leave the specified multicast group on the specified list of interfaces.
1873  * If source specific multicast is supported and we're given a list of servers,
1874  * leave source specific multicast groups for those servers.
1875  */
multicast_leave(SOCKET s,uint32_t group_id,const union sockaddr_u * multi,const struct iflist * addrlist,int addrlen,const struct fp_list_t * fplist,int fplist_len)1876 void multicast_leave(SOCKET s, uint32_t group_id, const union sockaddr_u *multi,
1877                      const struct iflist *addrlist, int addrlen,
1878                      const struct fp_list_t *fplist, int fplist_len)
1879 {
1880     struct ip_mreq mreq = { 0 };
1881     struct ipv6_mreq mreq6 = { 0 };
1882     int i;
1883 
1884     for (i = 0; i < addrlen; i++) {
1885         if (!addrlist[i].ismulti) {
1886             continue;
1887         }
1888         if (addrlist[i].su.ss.ss_family != multi->ss.ss_family) {
1889             continue;
1890         }
1891         if (multi->ss.ss_family == AF_INET6) {
1892             mreq6.ipv6mr_multiaddr = multi->sin6.sin6_addr;
1893             mreq6.ipv6mr_interface = addrlist[i].ifidx;
1894             if (setsockopt(s, IPPROTO_IPV6, IPV6_LEAVE_GROUP,
1895                            (char *)&mreq6, sizeof(mreq6)) == SOCKET_ERROR) {
1896                 sockerror(group_id, 0, 0, "Error leaving multicast group");
1897             }
1898         } else {
1899 #ifdef IP_DROP_SOURCE_MEMBERSHIP
1900             if (fplist_len != 0) {
1901                 int j;
1902                 for (j = 0; j < fplist_len; j++) {
1903                     struct ip_mreq_source srcmreq;
1904                     srcmreq.imr_multiaddr = multi->sin.sin_addr;
1905                     srcmreq.imr_sourceaddr = fplist[j].addr.sin.sin_addr;
1906                     srcmreq.imr_interface = addrlist[i].su.sin.sin_addr;
1907                     if (setsockopt(s, IPPROTO_IP, IP_DROP_SOURCE_MEMBERSHIP,
1908                            (char *)&srcmreq, sizeof(srcmreq)) == SOCKET_ERROR) {
1909                         sockerror(group_id, 0, 0,
1910                                   "Error leaving multicast group");
1911                     }
1912                 }
1913             } else {
1914                 mreq.imr_multiaddr = multi->sin.sin_addr;
1915                 mreq.imr_interface = addrlist[i].su.sin.sin_addr;
1916                 if (setsockopt(s, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1917                                (char *)&mreq, sizeof(mreq)) == SOCKET_ERROR) {
1918                     sockerror(group_id, 0, 0, "Error leaving multicast group");
1919                 }
1920             }
1921 #else
1922             mreq.imr_multiaddr = multi->sin.sin_addr;
1923             mreq.imr_interface = addrlist[i].su.sin.sin_addr;
1924             if (setsockopt(s, IPPROTO_IP, IP_DROP_MEMBERSHIP,
1925                            (char *)&mreq, sizeof(mreq)) == SOCKET_ERROR) {
1926                 sockerror(group_id, 0, 0, "Error leaving multicast group");
1927             }
1928 #endif
1929         }
1930     }
1931 }
1932 
1933 #endif // MCAST_JOIN_GROUP
1934 
1935 /**
1936  * Search for a network interface in a list with the matching name or index.
1937  * The name is formatted as interface/ip_version, ex. eth0/6, 2/4.
1938  * If ip_version is not given, defaults to IPv4.
1939  * Returns the index in the list if found, -1 if not found.
1940  */
getifbyname(const char * name,const struct iflist * list,int len)1941 int getifbyname(const char *name, const struct iflist *list, int len)
1942 {
1943     char *tmpname, *p, *ptr;
1944     int family, idx, i;
1945 
1946     tmpname = strdup(name);
1947     if (tmpname == NULL) {
1948         syserror(0, 0, 0, "strdup failed!");
1949         exit(ERR_ALLOC);
1950     }
1951 
1952     p = strchr(tmpname, '/');
1953     if (p == NULL) {
1954         family = AF_INET;
1955     } else {
1956         p[0] = 0;
1957         if (p[1] == '6') {
1958             family = AF_INET6;
1959         } else if (p[1] == '4') {
1960             family = AF_INET;
1961         } else {
1962             free(tmpname);
1963             return -1;
1964         }
1965     }
1966 
1967     errno = 0;
1968     idx = strtoul(tmpname, &ptr, 10);
1969     if ((errno == 0) && (*ptr == '\x0')) {
1970         for (i = 0; i < len; i++) {
1971             if ((idx == list[i].ifidx) && (list[i].su.ss.ss_family == family)) {
1972                 free(tmpname);
1973                 return i;
1974             }
1975         }
1976     } else {
1977         for (i = 0; i < len; i++) {
1978             if ((!strcmp(tmpname, list[i].name)) &&
1979                     (list[i].su.ss.ss_family == family)) {
1980                 free(tmpname);
1981                 return i;
1982             }
1983         }
1984     }
1985     free(tmpname);
1986     return -1;
1987 }
1988 
1989 /**
1990  * Search for a network interface in a list with the matching IP address.
1991  * Returns the index in the list if found, -1 if not found.
1992  */
getifbyaddr(union sockaddr_u * su,const struct iflist * list,int len)1993 int getifbyaddr(union sockaddr_u *su, const struct iflist *list, int len)
1994 {
1995     int i;
1996 
1997     for (i = 0; i < len; i++) {
1998         if (su->ss.ss_family == list[i].su.ss.ss_family) {
1999             if (su->ss.ss_family == AF_INET) {
2000                 if (su->sin.sin_addr.s_addr == list[i].su.sin.sin_addr.s_addr) {
2001                     return i;
2002                 }
2003             } else if (su->ss.ss_family == AF_INET6) {
2004                 if (!memcmp(&su->sin6.sin6_addr, &list[i].su.sin6.sin6_addr,
2005                         sizeof(struct in6_addr))) {
2006                     return i;
2007                 }
2008             }
2009         }
2010     }
2011     return -1;
2012 }
2013 
2014 /**
2015  * Reads buflen bytes into buf from the given file descriptor.
2016  * If buflen bytes are read, returns buflen.
2017  * If 0 bytes are read, returns 0 if allow_eof is true, otherwise returns -1.
2018  * If less that buflen bytes are read, or on error, returns -1.
2019  */
file_read(int fd,void * buf,int buflen,int allow_eof)2020 int file_read(int fd, void *buf, int buflen, int allow_eof)
2021 {
2022     int read_len;
2023 
2024     if ((read_len = read(fd, buf, buflen)) == -1) {
2025         syserror(0, 0, 0, "Read failed");
2026         return -1;
2027     }
2028     if ((read_len != buflen) && (!allow_eof || (read_len != 0))) {
2029         log0(0,0,0, "Read error: read %d bytes, expected %d", read_len, buflen);
2030         return -1;
2031     }
2032     return read_len;
2033 }
2034 /**
2035  * Writes buflen bytes from buf to the given file descriptor.
2036  * If buflen bytes are written, returns buflen.
2037  * If less that buflen bytes are written, or on error, returns -1.
2038  */
file_write(int fd,const void * buf,int buflen)2039 int file_write(int fd, const void *buf, int buflen)
2040 {
2041     int write_len;
2042 
2043     if ((write_len = write(fd, buf, buflen)) == -1) {
2044         syserror(0, 0, 0, "Write failed");
2045         return -1;
2046     }
2047     if (write_len != buflen) {
2048         log0(0,0,0,"Write error: wrote %d bytes, expected %d",write_len,buflen);
2049         return -1;
2050     }
2051     return write_len;
2052 }
2053 
2054 /**
2055  * Returns the free disk space in bytes of the filesystem that contains
2056  * the given file.  Returns 2^63-1 on error.
2057  */
free_space(const char * file)2058 uint64_t free_space(const char *file)
2059 {
2060 #ifdef WINDOWS
2061     ULARGE_INTEGER bytes_free;
2062     char *dirname, *filename;
2063 
2064     split_path(file, &dirname, &filename);
2065     if (dirname == NULL) {
2066         free(dirname);
2067         free(filename);
2068         return 0x7FFFFFFFFFFFFFFFULL;
2069     }
2070     if (!GetDiskFreeSpaceEx(dirname, &bytes_free, NULL, NULL)) {
2071         char errbuf[300];
2072         FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL,
2073                 GetLastError(), 0, errbuf, sizeof(errbuf), NULL);
2074         log0(0, 0, 0, "Error in GetDiskFreeSpaceEx: %s", errbuf);
2075         free(dirname);
2076         free(filename);
2077         return 0x7FFFFFFFFFFFFFFFULL;
2078     } else {
2079         log3(0, 0, 0, "Free space: " F_i64, bytes_free.QuadPart);
2080         free(dirname);
2081         free(filename);
2082         return bytes_free.QuadPart;
2083     }
2084 #else
2085     struct statvfs buf;
2086 
2087     if (statvfs(file, &buf) == -1) {
2088         syserror(0, 0, 0, "statvfs failed");
2089         return 0x7FFFFFFFFFFFFFFFULL;
2090     } else {
2091         log3(0, 0, 0, "Free space: " F_i64,
2092                       (uint64_t)buf.f_bsize * buf.f_bavail);
2093         return (uint64_t)buf.f_bsize * buf.f_bavail;
2094     }
2095 #endif
2096 }
2097 
2098 /**
2099  * Determines if the priority value passed in is valid.
2100  * Returns 1 on success, 0 on fail
2101  */
valid_priority(int priority)2102 int valid_priority(int priority)
2103 {
2104 #ifdef WINDOWS
2105     if ((priority >= -2) && (priority <= 2)) {
2106         return 1;
2107     } else {
2108         return 0;
2109     }
2110 #else
2111     if ((priority >= -20) && (priority <= 19)) {
2112         return 1;
2113     } else {
2114         return 0;
2115     }
2116 #endif
2117 }
2118 
2119 /**
2120  * Returns a 32-bit random number.
2121  * Some implementations of rand() generate values from 0 to 32767,
2122  * so this guarantees we get a full 32 bits.
2123  */
rand32()2124 uint32_t rand32()
2125 {
2126     return((rand() & 0x7FFF) << 17) | ((rand() & 0x7FFF) << 2) | (rand() & 0x3);
2127 }
2128 
2129 /**
2130  * Safe malloc routine that always returns non-NULL
2131  * On error, exit()
2132  */
safe_malloc(size_t size)2133 void *safe_malloc(size_t size)
2134 {
2135     void *p = malloc(size);
2136     if (p == NULL) {
2137         syserror(0, 0, 0, "malloc failed!");
2138         exit(ERR_ALLOC);
2139     }
2140     return p;
2141 }
2142 
2143 /**
2144  * Safe calloc routine that always returns non-NULL
2145  * On error, exit()
2146  */
safe_calloc(size_t num,size_t size)2147 void *safe_calloc(size_t num, size_t size)
2148 {
2149     void *p = calloc(num, size);
2150     if (p == NULL) {
2151         syserror(0, 0, 0, "calloc failed!");
2152         exit(ERR_ALLOC);
2153     }
2154     return p;
2155 }
2156 
2157 #define RTT_MIN 1.0e-6
2158 #define RTT_MAX 1000.0
2159 
2160 /**
2161  * Convert grtt from a double to a single byte.
2162  * As defined in RFC 5401
2163  */
quantize_grtt(double rtt)2164 uint8_t quantize_grtt(double rtt)
2165 {
2166     if (rtt > RTT_MAX) {
2167         rtt = RTT_MAX;
2168     } else if (rtt < RTT_MIN) {
2169         rtt = RTT_MIN;
2170     }
2171     if (rtt < (33.0 * RTT_MIN)) {
2172         return ((uint8_t)(rtt / RTT_MIN) - 1);
2173     } else {
2174         return ((uint8_t)(0 + ceil(255.0 - (13.0 * log(RTT_MAX/rtt)))));
2175     }
2176 }
2177 
2178 /**
2179  * Convert grtt from a single byte to a double
2180  * As defined in RFC 5401
2181  */
unquantize_grtt(uint8_t rtt)2182 double unquantize_grtt(uint8_t rtt)
2183 {
2184     return ((rtt <= 31) ?
2185             (((double)(rtt + 1)) * (double)RTT_MIN) :
2186             (RTT_MAX / exp(((double)(255 - rtt)) / (double)13.0)));
2187 }
2188 
2189 /**
2190  * Convert the group size from an int to an 8-bit float (5 bit M, 3 bit E)
2191  */
quantize_gsize(int size)2192 uint8_t quantize_gsize(int size)
2193 {
2194     double M;
2195     int E;
2196     int rval;
2197 
2198     M = size;
2199     E = 0;
2200     while (M >= 10) {
2201         M /= 10;
2202         E++;
2203     }
2204     rval = ((int)((M * 32.0 / 10.0) + 0.5)) << 3;
2205     if (rval > 0xFF) {
2206         M /= 10;
2207         E++;
2208         rval = ((int)((M * 32.0 / 10.0) + 0.5)) << 3;
2209     }
2210     rval |= E;
2211 
2212     return rval;
2213 }
2214 
2215 /**
2216  * Convert the group size from an 8-bit float to an int (5 bit M, 3 bit E)
2217  */
unquantize_gsize(uint8_t size)2218 int unquantize_gsize(uint8_t size)
2219 {
2220     int E, i;
2221     double rval;
2222 
2223     E = size & 0x7;
2224     rval = (size >> 3) * (10.0 / 32.0);
2225     for (i = 0; i < E; i++) {
2226         rval *= 10;
2227     }
2228 
2229     return (int)(rval + 0.5);
2230 }
2231 
2232 /**
2233  * Convert rate from an int to a 16-bit float
2234  * As defined in RFC 5740
2235  */
quantize_rate(int64_t rate)2236 uint16_t quantize_rate(int64_t rate)
2237 {
2238     int E;
2239     double M;
2240     int rval;
2241 
2242     M = (double)rate;
2243     E = 0;
2244     while (M > 10) {
2245         M /= 10;
2246         E++;
2247     }
2248     rval = (((int)(M * 4096.0 / 10.0 + 0.5)) << 4);
2249     if (rval > 0xFFFF) {
2250         M /= 10;
2251         E++;
2252         rval = (((int)(M * 4096.0 / 10.0 + 0.5)) << 4);
2253     }
2254     rval |= E;
2255 
2256     return rval;
2257 }
2258 
2259 /**
2260  * Convert rate in B/s from a 16-bit float to an int
2261  * As defined in RFC 5740
2262  */
unquantize_rate(uint16_t rate)2263 int64_t unquantize_rate(uint16_t rate)
2264 {
2265     int E, i;
2266     double rval;
2267 
2268     E = rate & 0xF;
2269     rval = (rate >> 4) * (10.0 / 4096.0);
2270     for (i = 0; i < E; i++) {
2271         rval *= 10;
2272     }
2273 
2274     return (int64_t)rval;
2275 }
2276 
2277