1 /*
2  * Based on public domain code available at: http://cr.yp.to/snuffle.html
3  *
4  * This therefore is public domain.
5  */
6 
7 #ifndef ZT_SALSA20_HPP
8 #define ZT_SALSA20_HPP
9 
10 #include <stdio.h>
11 #include <stdint.h>
12 #include <stdlib.h>
13 #include <string.h>
14 
15 #include "Constants.hpp"
16 #include "Utils.hpp"
17 
18 #if (!defined(ZT_SALSA20_SSE)) && (defined(__SSE2__) || defined(__WINDOWS__))
19 #define ZT_SALSA20_SSE 1
20 #endif
21 
22 #ifdef ZT_SALSA20_SSE
23 #include <emmintrin.h>
24 #endif // ZT_SALSA20_SSE
25 
26 namespace ZeroTier {
27 
28 /**
29  * Salsa20 stream cipher
30  */
31 class Salsa20
32 {
33 public:
Salsa20()34 	Salsa20() {}
~Salsa20()35 	~Salsa20() { Utils::burn(&_state,sizeof(_state)); }
36 
37 	/**
38 	 * XOR d with s
39 	 *
40 	 * This is done efficiently using e.g. SSE if available. It's used when
41 	 * alternative Salsa20 implementations are used in Packet and is here
42 	 * since this is where all the SSE stuff is already included.
43 	 *
44 	 * @param d Destination to XOR
45 	 * @param s Source bytes to XOR with destination
46 	 * @param len Length of s and d
47 	 */
memxor(uint8_t * d,const uint8_t * s,unsigned int len)48 	static inline void memxor(uint8_t *d,const uint8_t *s,unsigned int len)
49 	{
50 #ifdef ZT_SALSA20_SSE
51 		while (len >= 128) {
52 			__m128i s0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s));
53 			__m128i s1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 16));
54 			__m128i s2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 32));
55 			__m128i s3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 48));
56 			__m128i s4 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 64));
57 			__m128i s5 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 80));
58 			__m128i s6 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 96));
59 			__m128i s7 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 112));
60 			__m128i d0 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d));
61 			__m128i d1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 16));
62 			__m128i d2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 32));
63 			__m128i d3 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 48));
64 			__m128i d4 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 64));
65 			__m128i d5 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 80));
66 			__m128i d6 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 96));
67 			__m128i d7 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 112));
68 			d0 = _mm_xor_si128(d0,s0);
69 			d1 = _mm_xor_si128(d1,s1);
70 			d2 = _mm_xor_si128(d2,s2);
71 			d3 = _mm_xor_si128(d3,s3);
72 			d4 = _mm_xor_si128(d4,s4);
73 			d5 = _mm_xor_si128(d5,s5);
74 			d6 = _mm_xor_si128(d6,s6);
75 			d7 = _mm_xor_si128(d7,s7);
76 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d),d0);
77 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 16),d1);
78 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 32),d2);
79 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 48),d3);
80 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 64),d4);
81 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 80),d5);
82 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 96),d6);
83 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d + 112),d7);
84 			s += 128;
85 			d += 128;
86 			len -= 128;
87 		}
88 		while (len >= 16) {
89 			_mm_storeu_si128(reinterpret_cast<__m128i *>(d),_mm_xor_si128(_mm_loadu_si128(reinterpret_cast<__m128i *>(d)),_mm_loadu_si128(reinterpret_cast<const __m128i *>(s))));
90 			s += 16;
91 			d += 16;
92 			len -= 16;
93 		}
94 #else
95 #ifndef ZT_NO_TYPE_PUNNING
96 		while (len >= 16) {
97 			(*reinterpret_cast<uint64_t *>(d)) ^= (*reinterpret_cast<const uint64_t *>(s));
98 			s += 8;
99 			d += 8;
100 			(*reinterpret_cast<uint64_t *>(d)) ^= (*reinterpret_cast<const uint64_t *>(s));
101 			s += 8;
102 			d += 8;
103 			len -= 16;
104 		}
105 #endif
106 #endif
107 		while (len) {
108 			--len;
109 			*(d++) ^= *(s++);
110 		}
111 	}
112 
113 	/**
114 	 * @param key 256-bit (32 byte) key
115 	 * @param iv 64-bit initialization vector
116 	 */
Salsa20(const void * key,const void * iv)117 	Salsa20(const void *key,const void *iv)
118 	{
119 		init(key,iv);
120 	}
121 
122 	/**
123 	 * Initialize cipher
124 	 *
125 	 * @param key Key bits
126 	 * @param iv 64-bit initialization vector
127 	 */
128 	void init(const void *key,const void *iv);
129 
130 	/**
131 	 * Encrypt/decrypt data using Salsa20/12
132 	 *
133 	 * @param in Input data
134 	 * @param out Output buffer
135 	 * @param bytes Length of data
136 	 */
137 	void crypt12(const void *in,void *out,unsigned int bytes);
138 
139 	/**
140 	 * Encrypt/decrypt data using Salsa20/20
141 	 *
142 	 * @param in Input data
143 	 * @param out Output buffer
144 	 * @param bytes Length of data
145 	 */
146 	void crypt20(const void *in,void *out,unsigned int bytes);
147 
148 private:
149 	union {
150 #ifdef ZT_SALSA20_SSE
151 		__m128i v[4];
152 #endif // ZT_SALSA20_SSE
153 		uint32_t i[16];
154 	} _state;
155 };
156 
157 } // namespace ZeroTier
158 
159 #endif
160