1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file implements the Socialist Millionaires Protocol as described in
6// http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
7// specification is required in order to understand this code and, where
8// possible, the variable names in the code match up with the spec.
9
10package otr
11
12import (
13	"bytes"
14	"crypto/sha256"
15	"errors"
16	"hash"
17	"math/big"
18)
19
// smpFailure is a string-based error type used for the SMP sentinel errors.
type smpFailure string

// Error returns the failure message unchanged.
func (f smpFailure) Error() string {
	msg := string(f)
	return msg
}
25
26var smpFailureError = smpFailure("otr: SMP protocol failed")
27var smpSecretMissingError = smpFailure("otr: mutual secret needed")
28
// smpVersion is the version byte hashed into the SMP secret by calcSMPSecret.
const smpVersion = 1

// States of the SMP state machine. In state smpStateN the next SMP message
// this side accepts is SMPN; smpState1 doubles as the idle state.
const (
	smpState1 = 0
	smpState2 = 1
	smpState3 = 2
	smpState4 = 3
)
37
38type smpState struct {
39	state                  int
40	a2, a3, b2, b3, pb, qb *big.Int
41	g2a, g3a               *big.Int
42	g2, g3                 *big.Int
43	g3b, papb, qaqb, ra    *big.Int
44	saved                  *tlv
45	secret                 *big.Int
46	question               string
47}
48
49func (c *Conversation) startSMP(question string) (tlvs []tlv) {
50	if c.smp.state != smpState1 {
51		tlvs = append(tlvs, c.generateSMPAbort())
52	}
53	tlvs = append(tlvs, c.generateSMP1(question))
54	c.smp.question = ""
55	c.smp.state = smpState2
56	return
57}
58
59func (c *Conversation) resetSMP() {
60	c.smp.state = smpState1
61	c.smp.secret = nil
62	c.smp.question = ""
63}
64
65func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
66	data := in.data
67
68	switch in.typ {
69	case tlvTypeSMPAbort:
70		if c.smp.state != smpState1 {
71			err = smpFailureError
72		}
73		c.resetSMP()
74		return
75	case tlvTypeSMP1WithQuestion:
76		// We preprocess this into a SMP1 message.
77		nulPos := bytes.IndexByte(data, 0)
78		if nulPos == -1 {
79			err = errors.New("otr: SMP message with question didn't contain a NUL byte")
80			return
81		}
82		c.smp.question = string(data[:nulPos])
83		data = data[nulPos+1:]
84	}
85
86	numMPIs, data, ok := getU32(data)
87	if !ok || numMPIs > 20 {
88		err = errors.New("otr: corrupt SMP message")
89		return
90	}
91
92	mpis := make([]*big.Int, numMPIs)
93	for i := range mpis {
94		var ok bool
95		mpis[i], data, ok = getMPI(data)
96		if !ok {
97			err = errors.New("otr: corrupt SMP message")
98			return
99		}
100	}
101
102	switch in.typ {
103	case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
104		if c.smp.state != smpState1 {
105			c.resetSMP()
106			out = c.generateSMPAbort()
107			return
108		}
109		if c.smp.secret == nil {
110			err = smpSecretMissingError
111			return
112		}
113		if err = c.processSMP1(mpis); err != nil {
114			return
115		}
116		c.smp.state = smpState3
117		out = c.generateSMP2()
118	case tlvTypeSMP2:
119		if c.smp.state != smpState2 {
120			c.resetSMP()
121			out = c.generateSMPAbort()
122			return
123		}
124		if out, err = c.processSMP2(mpis); err != nil {
125			out = c.generateSMPAbort()
126			return
127		}
128		c.smp.state = smpState4
129	case tlvTypeSMP3:
130		if c.smp.state != smpState3 {
131			c.resetSMP()
132			out = c.generateSMPAbort()
133			return
134		}
135		if out, err = c.processSMP3(mpis); err != nil {
136			return
137		}
138		c.smp.state = smpState1
139		c.smp.secret = nil
140		complete = true
141	case tlvTypeSMP4:
142		if c.smp.state != smpState4 {
143			c.resetSMP()
144			out = c.generateSMPAbort()
145			return
146		}
147		if err = c.processSMP4(mpis); err != nil {
148			out = c.generateSMPAbort()
149			return
150		}
151		c.smp.state = smpState1
152		c.smp.secret = nil
153		complete = true
154	default:
155		panic("unknown SMP message")
156	}
157
158	return
159}
160
161func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
162	h := sha256.New()
163	h.Write([]byte{smpVersion})
164	if weStarted {
165		h.Write(c.PrivateKey.PublicKey.Fingerprint())
166		h.Write(c.TheirPublicKey.Fingerprint())
167	} else {
168		h.Write(c.TheirPublicKey.Fingerprint())
169		h.Write(c.PrivateKey.PublicKey.Fingerprint())
170	}
171	h.Write(c.SSID[:])
172	h.Write(mutualSecret)
173	c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
174}
175
176func (c *Conversation) generateSMP1(question string) tlv {
177	var randBuf [16]byte
178	c.smp.a2 = c.randMPI(randBuf[:])
179	c.smp.a3 = c.randMPI(randBuf[:])
180	g2a := new(big.Int).Exp(g, c.smp.a2, p)
181	g3a := new(big.Int).Exp(g, c.smp.a3, p)
182	h := sha256.New()
183
184	r2 := c.randMPI(randBuf[:])
185	r := new(big.Int).Exp(g, r2, p)
186	c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
187	d2 := new(big.Int).Mul(c.smp.a2, c2)
188	d2.Sub(r2, d2)
189	d2.Mod(d2, q)
190	if d2.Sign() < 0 {
191		d2.Add(d2, q)
192	}
193
194	r3 := c.randMPI(randBuf[:])
195	r.Exp(g, r3, p)
196	c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
197	d3 := new(big.Int).Mul(c.smp.a3, c3)
198	d3.Sub(r3, d3)
199	d3.Mod(d3, q)
200	if d3.Sign() < 0 {
201		d3.Add(d3, q)
202	}
203
204	var ret tlv
205	if len(question) > 0 {
206		ret.typ = tlvTypeSMP1WithQuestion
207		ret.data = append(ret.data, question...)
208		ret.data = append(ret.data, 0)
209	} else {
210		ret.typ = tlvTypeSMP1
211	}
212	ret.data = appendU32(ret.data, 6)
213	ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
214	return ret
215}
216
217func (c *Conversation) processSMP1(mpis []*big.Int) error {
218	if len(mpis) != 6 {
219		return errors.New("otr: incorrect number of arguments in SMP1 message")
220	}
221	g2a := mpis[0]
222	c2 := mpis[1]
223	d2 := mpis[2]
224	g3a := mpis[3]
225	c3 := mpis[4]
226	d3 := mpis[5]
227	h := sha256.New()
228
229	r := new(big.Int).Exp(g, d2, p)
230	s := new(big.Int).Exp(g2a, c2, p)
231	r.Mul(r, s)
232	r.Mod(r, p)
233	t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
234	if c2.Cmp(t) != 0 {
235		return errors.New("otr: ZKP c2 incorrect in SMP1 message")
236	}
237	r.Exp(g, d3, p)
238	s.Exp(g3a, c3, p)
239	r.Mul(r, s)
240	r.Mod(r, p)
241	t.SetBytes(hashMPIs(h, 2, r))
242	if c3.Cmp(t) != 0 {
243		return errors.New("otr: ZKP c3 incorrect in SMP1 message")
244	}
245
246	c.smp.g2a = g2a
247	c.smp.g3a = g3a
248	return nil
249}
250
251func (c *Conversation) generateSMP2() tlv {
252	var randBuf [16]byte
253	b2 := c.randMPI(randBuf[:])
254	c.smp.b3 = c.randMPI(randBuf[:])
255	r2 := c.randMPI(randBuf[:])
256	r3 := c.randMPI(randBuf[:])
257	r4 := c.randMPI(randBuf[:])
258	r5 := c.randMPI(randBuf[:])
259	r6 := c.randMPI(randBuf[:])
260
261	g2b := new(big.Int).Exp(g, b2, p)
262	g3b := new(big.Int).Exp(g, c.smp.b3, p)
263
264	r := new(big.Int).Exp(g, r2, p)
265	h := sha256.New()
266	c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
267	d2 := new(big.Int).Mul(b2, c2)
268	d2.Sub(r2, d2)
269	d2.Mod(d2, q)
270	if d2.Sign() < 0 {
271		d2.Add(d2, q)
272	}
273
274	r.Exp(g, r3, p)
275	c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
276	d3 := new(big.Int).Mul(c.smp.b3, c3)
277	d3.Sub(r3, d3)
278	d3.Mod(d3, q)
279	if d3.Sign() < 0 {
280		d3.Add(d3, q)
281	}
282
283	c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
284	c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
285	c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
286	c.smp.qb = new(big.Int).Exp(g, r4, p)
287	r.Exp(c.smp.g2, c.smp.secret, p)
288	c.smp.qb.Mul(c.smp.qb, r)
289	c.smp.qb.Mod(c.smp.qb, p)
290
291	s := new(big.Int)
292	s.Exp(c.smp.g2, r6, p)
293	r.Exp(g, r5, p)
294	s.Mul(r, s)
295	s.Mod(s, p)
296	r.Exp(c.smp.g3, r5, p)
297	cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
298
299	// D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
300
301	s.Mul(r4, cp)
302	r.Sub(r5, s)
303	d5 := new(big.Int).Mod(r, q)
304	if d5.Sign() < 0 {
305		d5.Add(d5, q)
306	}
307
308	s.Mul(c.smp.secret, cp)
309	r.Sub(r6, s)
310	d6 := new(big.Int).Mod(r, q)
311	if d6.Sign() < 0 {
312		d6.Add(d6, q)
313	}
314
315	var ret tlv
316	ret.typ = tlvTypeSMP2
317	ret.data = appendU32(ret.data, 11)
318	ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
319	return ret
320}
321
322func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
323	if len(mpis) != 11 {
324		err = errors.New("otr: incorrect number of arguments in SMP2 message")
325		return
326	}
327	g2b := mpis[0]
328	c2 := mpis[1]
329	d2 := mpis[2]
330	g3b := mpis[3]
331	c3 := mpis[4]
332	d3 := mpis[5]
333	pb := mpis[6]
334	qb := mpis[7]
335	cp := mpis[8]
336	d5 := mpis[9]
337	d6 := mpis[10]
338	h := sha256.New()
339
340	r := new(big.Int).Exp(g, d2, p)
341	s := new(big.Int).Exp(g2b, c2, p)
342	r.Mul(r, s)
343	r.Mod(r, p)
344	s.SetBytes(hashMPIs(h, 3, r))
345	if c2.Cmp(s) != 0 {
346		err = errors.New("otr: ZKP c2 failed in SMP2 message")
347		return
348	}
349
350	r.Exp(g, d3, p)
351	s.Exp(g3b, c3, p)
352	r.Mul(r, s)
353	r.Mod(r, p)
354	s.SetBytes(hashMPIs(h, 4, r))
355	if c3.Cmp(s) != 0 {
356		err = errors.New("otr: ZKP c3 failed in SMP2 message")
357		return
358	}
359
360	c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
361	c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
362
363	r.Exp(g, d5, p)
364	s.Exp(c.smp.g2, d6, p)
365	r.Mul(r, s)
366	s.Exp(qb, cp, p)
367	r.Mul(r, s)
368	r.Mod(r, p)
369
370	s.Exp(c.smp.g3, d5, p)
371	t := new(big.Int).Exp(pb, cp, p)
372	s.Mul(s, t)
373	s.Mod(s, p)
374	t.SetBytes(hashMPIs(h, 5, s, r))
375	if cp.Cmp(t) != 0 {
376		err = errors.New("otr: ZKP cP failed in SMP2 message")
377		return
378	}
379
380	var randBuf [16]byte
381	r4 := c.randMPI(randBuf[:])
382	r5 := c.randMPI(randBuf[:])
383	r6 := c.randMPI(randBuf[:])
384	r7 := c.randMPI(randBuf[:])
385
386	pa := new(big.Int).Exp(c.smp.g3, r4, p)
387	r.Exp(c.smp.g2, c.smp.secret, p)
388	qa := new(big.Int).Exp(g, r4, p)
389	qa.Mul(qa, r)
390	qa.Mod(qa, p)
391
392	r.Exp(g, r5, p)
393	s.Exp(c.smp.g2, r6, p)
394	r.Mul(r, s)
395	r.Mod(r, p)
396
397	s.Exp(c.smp.g3, r5, p)
398	cp.SetBytes(hashMPIs(h, 6, s, r))
399
400	r.Mul(r4, cp)
401	d5 = new(big.Int).Sub(r5, r)
402	d5.Mod(d5, q)
403	if d5.Sign() < 0 {
404		d5.Add(d5, q)
405	}
406
407	r.Mul(c.smp.secret, cp)
408	d6 = new(big.Int).Sub(r6, r)
409	d6.Mod(d6, q)
410	if d6.Sign() < 0 {
411		d6.Add(d6, q)
412	}
413
414	r.ModInverse(qb, p)
415	qaqb := new(big.Int).Mul(qa, r)
416	qaqb.Mod(qaqb, p)
417
418	ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
419	r.Exp(qaqb, r7, p)
420	s.Exp(g, r7, p)
421	cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
422
423	r.Mul(c.smp.a3, cr)
424	d7 := new(big.Int).Sub(r7, r)
425	d7.Mod(d7, q)
426	if d7.Sign() < 0 {
427		d7.Add(d7, q)
428	}
429
430	c.smp.g3b = g3b
431	c.smp.qaqb = qaqb
432
433	r.ModInverse(pb, p)
434	c.smp.papb = new(big.Int).Mul(pa, r)
435	c.smp.papb.Mod(c.smp.papb, p)
436	c.smp.ra = ra
437
438	out.typ = tlvTypeSMP3
439	out.data = appendU32(out.data, 8)
440	out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
441	return
442}
443
444func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
445	if len(mpis) != 8 {
446		err = errors.New("otr: incorrect number of arguments in SMP3 message")
447		return
448	}
449	pa := mpis[0]
450	qa := mpis[1]
451	cp := mpis[2]
452	d5 := mpis[3]
453	d6 := mpis[4]
454	ra := mpis[5]
455	cr := mpis[6]
456	d7 := mpis[7]
457	h := sha256.New()
458
459	r := new(big.Int).Exp(g, d5, p)
460	s := new(big.Int).Exp(c.smp.g2, d6, p)
461	r.Mul(r, s)
462	s.Exp(qa, cp, p)
463	r.Mul(r, s)
464	r.Mod(r, p)
465
466	s.Exp(c.smp.g3, d5, p)
467	t := new(big.Int).Exp(pa, cp, p)
468	s.Mul(s, t)
469	s.Mod(s, p)
470	t.SetBytes(hashMPIs(h, 6, s, r))
471	if t.Cmp(cp) != 0 {
472		err = errors.New("otr: ZKP cP failed in SMP3 message")
473		return
474	}
475
476	r.ModInverse(c.smp.qb, p)
477	qaqb := new(big.Int).Mul(qa, r)
478	qaqb.Mod(qaqb, p)
479
480	r.Exp(qaqb, d7, p)
481	s.Exp(ra, cr, p)
482	r.Mul(r, s)
483	r.Mod(r, p)
484
485	s.Exp(g, d7, p)
486	t.Exp(c.smp.g3a, cr, p)
487	s.Mul(s, t)
488	s.Mod(s, p)
489	t.SetBytes(hashMPIs(h, 7, s, r))
490	if t.Cmp(cr) != 0 {
491		err = errors.New("otr: ZKP cR failed in SMP3 message")
492		return
493	}
494
495	var randBuf [16]byte
496	r7 := c.randMPI(randBuf[:])
497	rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
498
499	r.Exp(qaqb, r7, p)
500	s.Exp(g, r7, p)
501	cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
502
503	r.Mul(c.smp.b3, cr)
504	d7 = new(big.Int).Sub(r7, r)
505	d7.Mod(d7, q)
506	if d7.Sign() < 0 {
507		d7.Add(d7, q)
508	}
509
510	out.typ = tlvTypeSMP4
511	out.data = appendU32(out.data, 3)
512	out.data = appendMPIs(out.data, rb, cr, d7)
513
514	r.ModInverse(c.smp.pb, p)
515	r.Mul(pa, r)
516	r.Mod(r, p)
517	s.Exp(ra, c.smp.b3, p)
518	if r.Cmp(s) != 0 {
519		err = smpFailureError
520	}
521
522	return
523}
524
525func (c *Conversation) processSMP4(mpis []*big.Int) error {
526	if len(mpis) != 3 {
527		return errors.New("otr: incorrect number of arguments in SMP4 message")
528	}
529	rb := mpis[0]
530	cr := mpis[1]
531	d7 := mpis[2]
532	h := sha256.New()
533
534	r := new(big.Int).Exp(c.smp.qaqb, d7, p)
535	s := new(big.Int).Exp(rb, cr, p)
536	r.Mul(r, s)
537	r.Mod(r, p)
538
539	s.Exp(g, d7, p)
540	t := new(big.Int).Exp(c.smp.g3b, cr, p)
541	s.Mul(s, t)
542	s.Mod(s, p)
543	t.SetBytes(hashMPIs(h, 8, s, r))
544	if t.Cmp(cr) != 0 {
545		return errors.New("otr: ZKP cR failed in SMP4 message")
546	}
547
548	r.Exp(rb, c.smp.a3, p)
549	if r.Cmp(c.smp.papb) != 0 {
550		return smpFailureError
551	}
552
553	return nil
554}
555
556func (c *Conversation) generateSMPAbort() tlv {
557	return tlv{typ: tlvTypeSMPAbort}
558}
559
560func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
561	if h != nil {
562		h.Reset()
563	} else {
564		h = sha256.New()
565	}
566
567	h.Write([]byte{magic})
568	for _, mpi := range mpis {
569		h.Write(appendMPI(nil, mpi))
570	}
571	return h.Sum(nil)
572}
573