1// Copyright 2012 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// This file implements the Socialist Millionaires Protocol as described in 6// http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol 7// specification is required in order to understand this code and, where 8// possible, the variable names in the code match up with the spec. 9 10package otr 11 12import ( 13 "bytes" 14 "crypto/sha256" 15 "errors" 16 "hash" 17 "math/big" 18) 19 20type smpFailure string 21 22func (s smpFailure) Error() string { 23 return string(s) 24} 25 26var smpFailureError = smpFailure("otr: SMP protocol failed") 27var smpSecretMissingError = smpFailure("otr: mutual secret needed") 28 29const smpVersion = 1 30 31const ( 32 smpState1 = iota 33 smpState2 34 smpState3 35 smpState4 36) 37 38type smpState struct { 39 state int 40 a2, a3, b2, b3, pb, qb *big.Int 41 g2a, g3a *big.Int 42 g2, g3 *big.Int 43 g3b, papb, qaqb, ra *big.Int 44 saved *tlv 45 secret *big.Int 46 question string 47} 48 49func (c *Conversation) startSMP(question string) (tlvs []tlv) { 50 if c.smp.state != smpState1 { 51 tlvs = append(tlvs, c.generateSMPAbort()) 52 } 53 tlvs = append(tlvs, c.generateSMP1(question)) 54 c.smp.question = "" 55 c.smp.state = smpState2 56 return 57} 58 59func (c *Conversation) resetSMP() { 60 c.smp.state = smpState1 61 c.smp.secret = nil 62 c.smp.question = "" 63} 64 65func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) { 66 data := in.data 67 68 switch in.typ { 69 case tlvTypeSMPAbort: 70 if c.smp.state != smpState1 { 71 err = smpFailureError 72 } 73 c.resetSMP() 74 return 75 case tlvTypeSMP1WithQuestion: 76 // We preprocess this into a SMP1 message. 77 nulPos := bytes.IndexByte(data, 0) 78 if nulPos == -1 { 79 err = errors.New("otr: SMP message with question didn't contain a NUL byte") 80 return 81 } 82 c.smp.question = string(data[:nulPos]) 83 data = data[nulPos+1:] 84 } 85 86 numMPIs, data, ok := getU32(data) 87 if !ok || numMPIs > 20 { 88 err = errors.New("otr: corrupt SMP message") 89 return 90 } 91 92 mpis := make([]*big.Int, numMPIs) 93 for i := range mpis { 94 var ok bool 95 mpis[i], data, ok = getMPI(data) 96 if !ok { 97 err = errors.New("otr: corrupt SMP message") 98 return 99 } 100 } 101 102 switch in.typ { 103 case tlvTypeSMP1, tlvTypeSMP1WithQuestion: 104 if c.smp.state != smpState1 { 105 c.resetSMP() 106 out = c.generateSMPAbort() 107 return 108 } 109 if c.smp.secret == nil { 110 err = smpSecretMissingError 111 return 112 } 113 if err = c.processSMP1(mpis); err != nil { 114 return 115 } 116 c.smp.state = smpState3 117 out = c.generateSMP2() 118 case tlvTypeSMP2: 119 if c.smp.state != smpState2 { 120 c.resetSMP() 121 out = c.generateSMPAbort() 122 return 123 } 124 if out, err = c.processSMP2(mpis); err != nil { 125 out = c.generateSMPAbort() 126 return 127 } 128 c.smp.state = smpState4 129 case tlvTypeSMP3: 130 if c.smp.state != smpState3 { 131 c.resetSMP() 132 out = c.generateSMPAbort() 133 return 134 } 135 if out, err = c.processSMP3(mpis); err != nil { 136 return 137 } 138 c.smp.state = smpState1 139 c.smp.secret = nil 140 complete = true 141 case tlvTypeSMP4: 142 if c.smp.state != smpState4 { 143 c.resetSMP() 144 out = c.generateSMPAbort() 145 return 146 } 147 if err = c.processSMP4(mpis); err != nil { 148 out = c.generateSMPAbort() 149 return 150 } 151 c.smp.state = smpState1 152 c.smp.secret = nil 153 complete = true 154 default: 155 panic("unknown SMP message") 156 } 157 158 return 159} 160 161func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) { 162 h := sha256.New() 163 h.Write([]byte{smpVersion}) 164 if weStarted { 165 h.Write(c.PrivateKey.PublicKey.Fingerprint()) 166 h.Write(c.TheirPublicKey.Fingerprint()) 167 } else { 168 h.Write(c.TheirPublicKey.Fingerprint()) 169 h.Write(c.PrivateKey.PublicKey.Fingerprint()) 170 } 171 h.Write(c.SSID[:]) 172 h.Write(mutualSecret) 173 c.smp.secret = new(big.Int).SetBytes(h.Sum(nil)) 174} 175 176func (c *Conversation) generateSMP1(question string) tlv { 177 var randBuf [16]byte 178 c.smp.a2 = c.randMPI(randBuf[:]) 179 c.smp.a3 = c.randMPI(randBuf[:]) 180 g2a := new(big.Int).Exp(g, c.smp.a2, p) 181 g3a := new(big.Int).Exp(g, c.smp.a3, p) 182 h := sha256.New() 183 184 r2 := c.randMPI(randBuf[:]) 185 r := new(big.Int).Exp(g, r2, p) 186 c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r)) 187 d2 := new(big.Int).Mul(c.smp.a2, c2) 188 d2.Sub(r2, d2) 189 d2.Mod(d2, q) 190 if d2.Sign() < 0 { 191 d2.Add(d2, q) 192 } 193 194 r3 := c.randMPI(randBuf[:]) 195 r.Exp(g, r3, p) 196 c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r)) 197 d3 := new(big.Int).Mul(c.smp.a3, c3) 198 d3.Sub(r3, d3) 199 d3.Mod(d3, q) 200 if d3.Sign() < 0 { 201 d3.Add(d3, q) 202 } 203 204 var ret tlv 205 if len(question) > 0 { 206 ret.typ = tlvTypeSMP1WithQuestion 207 ret.data = append(ret.data, question...) 208 ret.data = append(ret.data, 0) 209 } else { 210 ret.typ = tlvTypeSMP1 211 } 212 ret.data = appendU32(ret.data, 6) 213 ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3) 214 return ret 215} 216 217func (c *Conversation) processSMP1(mpis []*big.Int) error { 218 if len(mpis) != 6 { 219 return errors.New("otr: incorrect number of arguments in SMP1 message") 220 } 221 g2a := mpis[0] 222 c2 := mpis[1] 223 d2 := mpis[2] 224 g3a := mpis[3] 225 c3 := mpis[4] 226 d3 := mpis[5] 227 h := sha256.New() 228 229 r := new(big.Int).Exp(g, d2, p) 230 s := new(big.Int).Exp(g2a, c2, p) 231 r.Mul(r, s) 232 r.Mod(r, p) 233 t := new(big.Int).SetBytes(hashMPIs(h, 1, r)) 234 if c2.Cmp(t) != 0 { 235 return errors.New("otr: ZKP c2 incorrect in SMP1 message") 236 } 237 r.Exp(g, d3, p) 238 s.Exp(g3a, c3, p) 239 r.Mul(r, s) 240 r.Mod(r, p) 241 t.SetBytes(hashMPIs(h, 2, r)) 242 if c3.Cmp(t) != 0 { 243 return errors.New("otr: ZKP c3 incorrect in SMP1 message") 244 } 245 246 c.smp.g2a = g2a 247 c.smp.g3a = g3a 248 return nil 249} 250 251func (c *Conversation) generateSMP2() tlv { 252 var randBuf [16]byte 253 b2 := c.randMPI(randBuf[:]) 254 c.smp.b3 = c.randMPI(randBuf[:]) 255 r2 := c.randMPI(randBuf[:]) 256 r3 := c.randMPI(randBuf[:]) 257 r4 := c.randMPI(randBuf[:]) 258 r5 := c.randMPI(randBuf[:]) 259 r6 := c.randMPI(randBuf[:]) 260 261 g2b := new(big.Int).Exp(g, b2, p) 262 g3b := new(big.Int).Exp(g, c.smp.b3, p) 263 264 r := new(big.Int).Exp(g, r2, p) 265 h := sha256.New() 266 c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r)) 267 d2 := new(big.Int).Mul(b2, c2) 268 d2.Sub(r2, d2) 269 d2.Mod(d2, q) 270 if d2.Sign() < 0 { 271 d2.Add(d2, q) 272 } 273 274 r.Exp(g, r3, p) 275 c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r)) 276 d3 := new(big.Int).Mul(c.smp.b3, c3) 277 d3.Sub(r3, d3) 278 d3.Mod(d3, q) 279 if d3.Sign() < 0 { 280 d3.Add(d3, q) 281 } 282 283 c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p) 284 c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p) 285 c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p) 286 c.smp.qb = new(big.Int).Exp(g, r4, p) 287 r.Exp(c.smp.g2, c.smp.secret, p) 288 c.smp.qb.Mul(c.smp.qb, r) 289 c.smp.qb.Mod(c.smp.qb, p) 290 291 s := new(big.Int) 292 s.Exp(c.smp.g2, r6, p) 293 r.Exp(g, r5, p) 294 s.Mul(r, s) 295 s.Mod(s, p) 296 r.Exp(c.smp.g3, r5, p) 297 cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s)) 298 299 // D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q 300 301 s.Mul(r4, cp) 302 r.Sub(r5, s) 303 d5 := new(big.Int).Mod(r, q) 304 if d5.Sign() < 0 { 305 d5.Add(d5, q) 306 } 307 308 s.Mul(c.smp.secret, cp) 309 r.Sub(r6, s) 310 d6 := new(big.Int).Mod(r, q) 311 if d6.Sign() < 0 { 312 d6.Add(d6, q) 313 } 314 315 var ret tlv 316 ret.typ = tlvTypeSMP2 317 ret.data = appendU32(ret.data, 11) 318 ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6) 319 return ret 320} 321 322func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) { 323 if len(mpis) != 11 { 324 err = errors.New("otr: incorrect number of arguments in SMP2 message") 325 return 326 } 327 g2b := mpis[0] 328 c2 := mpis[1] 329 d2 := mpis[2] 330 g3b := mpis[3] 331 c3 := mpis[4] 332 d3 := mpis[5] 333 pb := mpis[6] 334 qb := mpis[7] 335 cp := mpis[8] 336 d5 := mpis[9] 337 d6 := mpis[10] 338 h := sha256.New() 339 340 r := new(big.Int).Exp(g, d2, p) 341 s := new(big.Int).Exp(g2b, c2, p) 342 r.Mul(r, s) 343 r.Mod(r, p) 344 s.SetBytes(hashMPIs(h, 3, r)) 345 if c2.Cmp(s) != 0 { 346 err = errors.New("otr: ZKP c2 failed in SMP2 message") 347 return 348 } 349 350 r.Exp(g, d3, p) 351 s.Exp(g3b, c3, p) 352 r.Mul(r, s) 353 r.Mod(r, p) 354 s.SetBytes(hashMPIs(h, 4, r)) 355 if c3.Cmp(s) != 0 { 356 err = errors.New("otr: ZKP c3 failed in SMP2 message") 357 return 358 } 359 360 c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p) 361 c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p) 362 363 r.Exp(g, d5, p) 364 s.Exp(c.smp.g2, d6, p) 365 r.Mul(r, s) 366 s.Exp(qb, cp, p) 367 r.Mul(r, s) 368 r.Mod(r, p) 369 370 s.Exp(c.smp.g3, d5, p) 371 t := new(big.Int).Exp(pb, cp, p) 372 s.Mul(s, t) 373 s.Mod(s, p) 374 t.SetBytes(hashMPIs(h, 5, s, r)) 375 if cp.Cmp(t) != 0 { 376 err = errors.New("otr: ZKP cP failed in SMP2 message") 377 return 378 } 379 380 var randBuf [16]byte 381 r4 := c.randMPI(randBuf[:]) 382 r5 := c.randMPI(randBuf[:]) 383 r6 := c.randMPI(randBuf[:]) 384 r7 := c.randMPI(randBuf[:]) 385 386 pa := new(big.Int).Exp(c.smp.g3, r4, p) 387 r.Exp(c.smp.g2, c.smp.secret, p) 388 qa := new(big.Int).Exp(g, r4, p) 389 qa.Mul(qa, r) 390 qa.Mod(qa, p) 391 392 r.Exp(g, r5, p) 393 s.Exp(c.smp.g2, r6, p) 394 r.Mul(r, s) 395 r.Mod(r, p) 396 397 s.Exp(c.smp.g3, r5, p) 398 cp.SetBytes(hashMPIs(h, 6, s, r)) 399 400 r.Mul(r4, cp) 401 d5 = new(big.Int).Sub(r5, r) 402 d5.Mod(d5, q) 403 if d5.Sign() < 0 { 404 d5.Add(d5, q) 405 } 406 407 r.Mul(c.smp.secret, cp) 408 d6 = new(big.Int).Sub(r6, r) 409 d6.Mod(d6, q) 410 if d6.Sign() < 0 { 411 d6.Add(d6, q) 412 } 413 414 r.ModInverse(qb, p) 415 qaqb := new(big.Int).Mul(qa, r) 416 qaqb.Mod(qaqb, p) 417 418 ra := new(big.Int).Exp(qaqb, c.smp.a3, p) 419 r.Exp(qaqb, r7, p) 420 s.Exp(g, r7, p) 421 cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r)) 422 423 r.Mul(c.smp.a3, cr) 424 d7 := new(big.Int).Sub(r7, r) 425 d7.Mod(d7, q) 426 if d7.Sign() < 0 { 427 d7.Add(d7, q) 428 } 429 430 c.smp.g3b = g3b 431 c.smp.qaqb = qaqb 432 433 r.ModInverse(pb, p) 434 c.smp.papb = new(big.Int).Mul(pa, r) 435 c.smp.papb.Mod(c.smp.papb, p) 436 c.smp.ra = ra 437 438 out.typ = tlvTypeSMP3 439 out.data = appendU32(out.data, 8) 440 out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7) 441 return 442} 443 444func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) { 445 if len(mpis) != 8 { 446 err = errors.New("otr: incorrect number of arguments in SMP3 message") 447 return 448 } 449 pa := mpis[0] 450 qa := mpis[1] 451 cp := mpis[2] 452 d5 := mpis[3] 453 d6 := mpis[4] 454 ra := mpis[5] 455 cr := mpis[6] 456 d7 := mpis[7] 457 h := sha256.New() 458 459 r := new(big.Int).Exp(g, d5, p) 460 s := new(big.Int).Exp(c.smp.g2, d6, p) 461 r.Mul(r, s) 462 s.Exp(qa, cp, p) 463 r.Mul(r, s) 464 r.Mod(r, p) 465 466 s.Exp(c.smp.g3, d5, p) 467 t := new(big.Int).Exp(pa, cp, p) 468 s.Mul(s, t) 469 s.Mod(s, p) 470 t.SetBytes(hashMPIs(h, 6, s, r)) 471 if t.Cmp(cp) != 0 { 472 err = errors.New("otr: ZKP cP failed in SMP3 message") 473 return 474 } 475 476 r.ModInverse(c.smp.qb, p) 477 qaqb := new(big.Int).Mul(qa, r) 478 qaqb.Mod(qaqb, p) 479 480 r.Exp(qaqb, d7, p) 481 s.Exp(ra, cr, p) 482 r.Mul(r, s) 483 r.Mod(r, p) 484 485 s.Exp(g, d7, p) 486 t.Exp(c.smp.g3a, cr, p) 487 s.Mul(s, t) 488 s.Mod(s, p) 489 t.SetBytes(hashMPIs(h, 7, s, r)) 490 if t.Cmp(cr) != 0 { 491 err = errors.New("otr: ZKP cR failed in SMP3 message") 492 return 493 } 494 495 var randBuf [16]byte 496 r7 := c.randMPI(randBuf[:]) 497 rb := new(big.Int).Exp(qaqb, c.smp.b3, p) 498 499 r.Exp(qaqb, r7, p) 500 s.Exp(g, r7, p) 501 cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r)) 502 503 r.Mul(c.smp.b3, cr) 504 d7 = new(big.Int).Sub(r7, r) 505 d7.Mod(d7, q) 506 if d7.Sign() < 0 { 507 d7.Add(d7, q) 508 } 509 510 out.typ = tlvTypeSMP4 511 out.data = appendU32(out.data, 3) 512 out.data = appendMPIs(out.data, rb, cr, d7) 513 514 r.ModInverse(c.smp.pb, p) 515 r.Mul(pa, r) 516 r.Mod(r, p) 517 s.Exp(ra, c.smp.b3, p) 518 if r.Cmp(s) != 0 { 519 err = smpFailureError 520 } 521 522 return 523} 524 525func (c *Conversation) processSMP4(mpis []*big.Int) error { 526 if len(mpis) != 3 { 527 return errors.New("otr: incorrect number of arguments in SMP4 message") 528 } 529 rb := mpis[0] 530 cr := mpis[1] 531 d7 := mpis[2] 532 h := sha256.New() 533 534 r := new(big.Int).Exp(c.smp.qaqb, d7, p) 535 s := new(big.Int).Exp(rb, cr, p) 536 r.Mul(r, s) 537 r.Mod(r, p) 538 539 s.Exp(g, d7, p) 540 t := new(big.Int).Exp(c.smp.g3b, cr, p) 541 s.Mul(s, t) 542 s.Mod(s, p) 543 t.SetBytes(hashMPIs(h, 8, s, r)) 544 if t.Cmp(cr) != 0 { 545 return errors.New("otr: ZKP cR failed in SMP4 message") 546 } 547 548 r.Exp(rb, c.smp.a3, p) 549 if r.Cmp(c.smp.papb) != 0 { 550 return smpFailureError 551 } 552 553 return nil 554} 555 556func (c *Conversation) generateSMPAbort() tlv { 557 return tlv{typ: tlvTypeSMPAbort} 558} 559 560func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte { 561 if h != nil { 562 h.Reset() 563 } else { 564 h = sha256.New() 565 } 566 567 h.Write([]byte{magic}) 568 for _, mpi := range mpis { 569 h.Write(appendMPI(nil, mpi)) 570 } 571 return h.Sum(nil) 572} 573