1# This software is provided 'as-is', without any express or implied 2# warranty. In no event will the author be held liable for any damages 3# arising from the use of this software. 4# 5# Permission is granted to anyone to use this software for any purpose, 6# including commercial applications, and to alter it and redistribute it 7# freely, subject to the following restrictions: 8# 9# 1. The origin of this software must not be misrepresented; you must not 10# claim that you wrote the original software. If you use this software 11# in a product, an acknowledgment in the product documentation would be 12# appreciated but is not required. 13# 2. Altered source versions must be plainly marked as such, and must not be 14# misrepresented as being the original software. 15# 3. This notice may not be removed or altered from any source distribution. 16# 17# Copyright (c) 2008 Greg Hewgill http://hewgill.com 18 19from __future__ import print_function 20import base64 21import hashlib 22import re 23import time 24from fuglu.localStringEncoding import force_bString, force_uString, forceBytesFromChar, forceCharFromBytes 25 26__all__ = [ 27 "Simple", 28 "Relaxed", 29 "InternalError", 30 "KeyFormatError", 31 "MessageFormatError", 32 "ParameterError", 33 "sign", 34 "verify", 35] 36 37 38class Simple(object): 39 40 """Class that represents the "simple" canonicalization algorithm.""" 41 42 name = "simple" 43 44 @staticmethod 45 def canonicalize_headers(headers): 46 # No changes to headers. 47 return headers 48 49 @staticmethod 50 def canonicalize_body(body): 51 # Ignore all empty lines at the end of the message body. 52 return re.sub("(\r\n)*$", "\r\n", body) 53 54 55class Relaxed(object): 56 57 """Class that represents the "relaxed" canonicalization algorithm.""" 58 59 name = "relaxed" 60 61 @staticmethod 62 def canonicalize_headers(headers): 63 # Convert all header field names to lowercase. 64 # Unfold all header lines. 65 # Compress WSP to single space. 
66 # Remove all WSP at the start or end of the field value (strip). 67 return [(x[0].lower(), re.sub(r"\s+", " ", re.sub("\r\n", "", x[1])).strip() + "\r\n") for x in headers] 68 69 @staticmethod 70 def canonicalize_body(body): 71 # Remove all trailing WSP at end of lines. 72 # Compress non-line-ending WSP to single space. 73 # Ignore all empty lines at the end of the message body. 74 return re.sub("(\r\n)*$", "\r\n", re.sub(r"[\x09\x20]+", " ", re.sub("[\\x09\\x20]+\r\n", "\r\n", body))) 75 76 77class DKIMException(Exception): 78 79 """Base class for DKIM errors.""" 80 pass 81 82 83class InternalError(DKIMException): 84 85 """Internal error in dkim module. Should never happen.""" 86 pass 87 88 89class KeyFormatError(DKIMException): 90 91 """Key format error while parsing an RSA public or private key.""" 92 pass 93 94 95class MessageFormatError(DKIMException): 96 97 """RFC822 message format error.""" 98 pass 99 100 101class ParameterError(DKIMException): 102 103 """Input parameter error.""" 104 pass 105 106 107def _remove(s, t): 108 i = s.find(t) 109 assert i >= 0 110 return s[:i] + s[i + len(t):] 111 112INTEGER = 0x02 113BIT_STRING = 0x03 114OCTET_STRING = 0x04 115NULL = 0x05 116OBJECT_IDENTIFIER = 0x06 117SEQUENCE = 0x30 118 119ASN1_Object = [ 120 (SEQUENCE, [ 121 (SEQUENCE, [ 122 (OBJECT_IDENTIFIER,), 123 (NULL,), 124 ]), 125 (BIT_STRING,), 126 ]) 127] 128 129ASN1_RSAPublicKey = [ 130 (SEQUENCE, [ 131 (INTEGER,), 132 (INTEGER,), 133 ]) 134] 135 136ASN1_RSAPrivateKey = [ 137 (SEQUENCE, [ 138 (INTEGER,), 139 (INTEGER,), 140 (INTEGER,), 141 (INTEGER,), 142 (INTEGER,), 143 (INTEGER,), 144 (INTEGER,), 145 (INTEGER,), 146 (INTEGER,), 147 ]) 148] 149 150 151def asn1_parse(template, data): 152 """Parse a data structure according to ASN.1 template. 153 154 @param template: A list of tuples comprising the ASN.1 template. 155 @param data: A list of bytes to parse. 
156 157 """ 158 159 r = [] 160 i = 0 161 for t in template: 162 tag = ord(data[i]) 163 i += 1 164 if tag == t[0]: 165 length = ord(data[i]) 166 i += 1 167 if length & 0x80: 168 n = length & 0x7f 169 length = 0 170 for j in range(n): 171 length = (length << 8) | ord(data[i]) 172 i += 1 173 if tag == INTEGER: 174 n = 0 175 for j in range(length): 176 n = (n << 8) | ord(data[i]) 177 i += 1 178 r.append(n) 179 elif tag == BIT_STRING: 180 r.append(data[i:i + length]) 181 i += length 182 elif tag == NULL: 183 assert length == 0 184 r.append(None) 185 elif tag == OBJECT_IDENTIFIER: 186 r.append(data[i:i + length]) 187 i += length 188 elif tag == SEQUENCE: 189 r.append(asn1_parse(t[1], data[i:i + length])) 190 i += length 191 else: 192 raise KeyFormatError("Unexpected tag in template: %02x" % tag) 193 else: 194 raise KeyFormatError( 195 "Unexpected tag (got %02x, expecting %02x)" % (tag, t[0])) 196 return r 197 198 199def asn1_length(n): 200 """Return a string representing a field length in ASN.1 format.""" 201 assert n >= 0 202 if n < 0x7f: 203 return chr(n) 204 r = "" 205 while n > 0: 206 r = chr(n & 0xff) + r 207 n >>= 8 208 return r 209 210 211def asn1_build(node): 212 """Build an ASN.1 data structure based on pairs of (type, data).""" 213 if node[0] == OCTET_STRING: 214 return chr(OCTET_STRING) + asn1_length(len(node[1])) + node[1] 215 if node[0] == NULL: 216 assert node[1] is None 217 return chr(NULL) + asn1_length(0) 218 elif node[0] == OBJECT_IDENTIFIER: 219 return chr(OBJECT_IDENTIFIER) + asn1_length(len(node[1])) + node[1] 220 elif node[0] == SEQUENCE: 221 r = "" 222 for x in node[1]: 223 r += asn1_build(x) 224 return chr(SEQUENCE) + asn1_length(len(r)) + r 225 else: 226 raise InternalError("Unexpected tag in template: %02x" % node[0]) 227 228# These values come from RFC 3447, section 9.2 Notes, page 43. 
# Encoded contents of the digest algorithm OBJECT IDENTIFIERs that go into
# the DigestInfo structure built by sign() and verify().
HASHID_SHA1 = "\x2b\x0e\x03\x02\x1a"
HASHID_SHA256 = "\x60\x86\x48\x01\x65\x03\x04\x02\x01"


def _debug(debuglog, *args):
    """Print args to debuglog, doing nothing when debuglog is None."""
    if debuglog is not None:
        print(*args, file=debuglog)


def str2int(s):
    """Convert an octet string to an integer. Octet string assumed to represent a positive integer.

    @param s: The octets, most significant first (characters, or bytes).

    @return The integer value; 0 for an empty string.

    """
    r = 0
    for c in s:
        # Iterating bytes on Python 3 yields ints, a str yields characters.
        r = (r << 8) | (c if isinstance(c, int) else ord(c))
    return r


def int2str(n, length=-1):
    """Convert an integer to a list of octets. Number must be positive.

    @param n: Number to convert.
    @param length: Exact number of octets to produce, or -1 to return the smallest number
                   of octets that represent the integer.  Short values are padded with leading
                   zero octets; octets that do not fit in length are dropped.

    @return A list of one-character strings, most significant octet first.  Callers join it
            with "".join() when a string is needed.

    Raises ParameterError for a negative number (it could never reach zero by shifting).

    """

    if n < 0:
        raise ParameterError("Number must not be negative")
    r = []
    while length < 0 or len(r) < length:
        r.append(chr(n & 0xff))
        n >>= 8
        if length < 0 and n == 0:
            break
    r.reverse()
    return r


def rfc822_parse(message):
    """Parse a message in RFC822 format.

    @param message: The message in RFC822 format. Either CRLF or LF is an accepted line separator.

    @return Returns a tuple of (headers, body) where headers is a list of (name, value) pairs.
    The body is a CRLF-separated string.

    Raises MessageFormatError for a header line that is neither "name:value", a continuation
    line nor a "From " line, and for a continuation line that precedes every header.

    """

    headers = []
    lines = re.split("\r?\n", message)
    body_start = len(lines)
    for i, line in enumerate(lines):
        if not line:
            # End of headers, return what we have plus the body, excluding
            # the blank line.
            body_start = i + 1
            break
        if line[0] in "\t ":
            # Continuation of the previous header.
            if not headers:
                raise MessageFormatError(
                    "Unexpected continuation line in RFC822 header: %s" % line)
            headers[-1][1] += line + "\r\n"
            continue
        m = re.match(r"([\x21-\x7e]+?):", line)
        if m is not None:
            headers.append([m.group(1), line[m.end(0):] + "\r\n"])
        elif not line.startswith("From "):
            # "From " lines are skipped without becoming a header.
            raise MessageFormatError(
                "Unexpected characters in RFC822 header: %s" % line)
    return headers, "\r\n".join(lines[body_start:])


def dnstxt(name):
    """Look up the TXT record for name.

    @param name: The DNS name to query.

    @return The character strings of the first TXT record joined into one string, or None
            when the answer holds no TXT record.  Resolver exceptions are not caught.

    """
    import dns.rdatatype
    import dns.resolver

    # dnspython 2.x deprecates query() in favour of resolve(); search=True
    # keeps the search-list behaviour query() has.
    resolve = getattr(dns.resolver, "resolve", None)
    if resolve is not None:
        answer = resolve(name, dns.rdatatype.TXT, search=True)
    else:
        answer = dns.resolver.query(name, dns.rdatatype.TXT)
    for rrset in answer.response.answer:
        if rrset.rdtype != dns.rdatatype.TXT:
            continue
        # Iterate instead of indexing rrset.items, which is not a list in
        # every dnspython release.
        for rdata in rrset:
            parts = []
            for part in rdata.strings:
                # The strings arrive as bytes on Python 3.
                if not isinstance(part, str):
                    part = part.decode("utf-8", "replace")
                parts.append(part)
            return "".join(parts)
    return None


def fold(header):
    """Fold a header line into multiple crlf-separated lines at column 72.

    @param header: The header line; folds already present in it are kept.

    @return The folded header.  Lines are only broken at a space, so a stretch without any
            space stays on one line even when it is longer than 72 characters.

    """
    i = header.rfind("\r\n ")
    if i == -1:
        pre = ""
    else:
        # Keep everything up to the last existing fold, fold only the rest.
        i += 3
        pre = header[:i]
        header = header[i:]
    while len(header) > 72:
        i = header.rfind(" ", 0, 72)
        if i == -1:
            # Nothing to break at within the limit: break at the next space
            # rather than in the middle of a token.
            i = header.find(" ", 72)
            if i == -1:
                break
        pre += header[:i] + "\r\n "
        header = header[i + 1:]
    return pre + header


def _select_headers(headers, names):
    """Pick the header instances referred to by an h= style list of names.

    Names are matched case-insensitively.  A name that occurs several times selects the
    matching headers from the bottom of the header block upwards; names without a (further)
    match select nothing.

    """
    selected = []
    lastindex = {}
    for name in names:
        key = name.lower()
        i = lastindex.get(key, len(headers))
        while i > 0:
            i -= 1
            if key == headers[i][0].lower():
                selected.append(headers[i])
                break
        lastindex[key] = i
    return selected


def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple, Simple), include_headers=None, length=False, debuglog=None):
    """Sign an RFC822 message and return the DKIM-Signature header line.

    @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings)
    @param selector: the DKIM selector value for the signature
    @param domain: the DKIM domain value for the signature
    @param privkey: a PKCS#1 private key in base64-encoded text form
    @param identity: the DKIM identity value for the signature (default "@"+domain)
    @param canonicalize: the canonicalization algorithms to use (default (Simple, Simple))
    @param include_headers: a list of strings indicating which headers are to be signed (default all headers)
    @param length: true if the l= tag should be included to indicate body length (default False)
    @param debuglog: a file-like object to which debug info will be written (default None)

    @return The complete "DKIM-Signature: ..." header line, terminated by CRLF.

    Raises KeyFormatError for an unusable private key and ParameterError when identity does
    not end with domain or the modulus is too small for the digest.

    """

    (headers, body) = rfc822_parse(message)

    # Accept key files with LF as well as CRLF line endings.
    m = re.search("--\r?\n(.*?)\r?\n--", privkey, re.DOTALL)
    if m is None:
        raise KeyFormatError("Private key not found")
    try:
        pkdata = base64.b64decode(m.group(1))
        pkdata = forceCharFromBytes(pkdata)
    except (TypeError, ValueError) as e:
        # Python 3 reports bad base64 as binascii.Error, a ValueError.
        raise KeyFormatError(str(e))
    _debug(debuglog, " ".join("%02x" % ord(x) for x in pkdata))
    pka = asn1_parse(ASN1_RSAPrivateKey, pkdata)
    pk = dict(zip((
        'version',
        'modulus',
        'publicExponent',
        'privateExponent',
        'prime1',
        'prime2',
        'exponent1',
        'exponent2',
        'coefficient',
    ), pka[0]))

    if identity is not None and not identity.endswith(domain):
        raise ParameterError("identity must end with domain")

    headers = canonicalize[0].canonicalize_headers(headers)

    if include_headers is None:
        wanted = set(x[0].lower() for x in headers)
    else:
        wanted = set(x.lower() for x in include_headers)
    sign_names = [x[0] for x in headers if x[0].lower() in wanted]
    # Hash repeated headers in the order verify() picks them for this h=
    # list (bottom-up), otherwise such messages can never verify.
    sign_headers = _select_headers(headers, sign_names)

    body = canonicalize[1].canonicalize_body(body)

    h = hashlib.sha256()
    h.update(force_bString(body))
    bodyhash = base64.b64encode(h.digest())
    bodyhash = forceCharFromBytes(bodyhash)

    sigfields = [
        ('v', "1"),
        ('a', "rsa-sha256"),
        ('c', "%s/%s" % (canonicalize[0].name, canonicalize[1].name)),
        ('d', domain),
        ('i', identity or "@" + domain),
    ]
    if length:
        sigfields.append(('l', len(body)))
    sigfields += [
        ('q', "dns/txt"),
        ('s', selector),
        ('t', str(int(time.time()))),
        ('h', " : ".join(sign_names)),
        ('bh', bodyhash),
        ('b', ""),
    ]
    sig = "DKIM-Signature: " + "; ".join("%s=%s" % x for x in sigfields)

    sig = fold(sig)

    # The signature header (with b= still empty) is hashed last, put through
    # the same header canonicalization and trailing-whitespace strip that
    # verify() applies to it.
    signame, _, sigvalue = sig.partition(":")
    hash_headers = list(sign_headers) + [
        (x[0], x[1].rstrip())
        for x in canonicalize[0].canonicalize_headers([(signame, sigvalue + "\r\n")])
    ]
    _debug(debuglog, "sign headers:", hash_headers)

    h = hashlib.sha256()
    for x in hash_headers:
        h.update(force_bString(x[0]))
        h.update(b":")
        h.update(force_bString(x[1]))
    d = h.digest()
    d = forceCharFromBytes(d)

    _debug(debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d))

    dinfo = asn1_build(
        (SEQUENCE, [
            (SEQUENCE, [
                (OBJECT_IDENTIFIER, HASHID_SHA256),
                (NULL, None),
            ]),
            (OCTET_STRING, d),
        ])
    )
    modlen = len(int2str(pk['modulus']))
    if len(dinfo) + 3 > modlen:
        raise ParameterError("Hash too large for modulus")
    # 00 01 ff..ff 00 DigestInfo, padded to the length of the modulus.
    signature = "\x00\x01" + "\xff" * (modlen - len(dinfo) - 3) + "\x00" + dinfo
    sig2 = int2str(pow(str2int(signature), pk['privateExponent'], pk['modulus']), modlen)
    sigEncoded = base64.b64encode(forceBytesFromChar(''.join(sig2)))
    sigEncoded = forceCharFromBytes(sigEncoded)
    sig += sigEncoded

    return sig + "\r\n"


def verify(message, debuglog=None):
    """Verify a DKIM signature on an RFC822 formatted message.

    @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings)
    @param debuglog: a file-like object to which debug info will be written (default None)

    @return True when the first DKIM-Signature header of the message verifies, False when
            there is none, when it is malformed or when it does not match.

    Raises KeyFormatError when the published public key cannot be decoded; exceptions from
    the DNS lookup are not caught.

    """

    (headers, body) = rfc822_parse(message)

    sigheaders = [x for x in headers if x[0].lower() == "dkim-signature"]
    if len(sigheaders) < 1:
        return False

    # Currently, we only validate the first DKIM-Signature line found.

    a = re.split(r"\s*;\s*", sigheaders[0][1].strip())
    _debug(debuglog, "a:", a)
    sig = {}
    for x in a:
        if x:
            m = re.match(r"(\w+)\s*=\s*(.*)", x, re.DOTALL)
            if m is None:
                _debug(debuglog, "invalid format of signature part: %s" % x)
                return False
            sig[m.group(1)] = m.group(2)
    _debug(debuglog, "sig:", sig)

    if 'v' not in sig:
        _debug(debuglog, "signature missing v=")
        return False
    if sig['v'] != "1":
        _debug(debuglog, "v= value is not 1 (%s)" % sig['v'])
        return False
    if 'a' not in sig:
        _debug(debuglog, "signature missing a=")
        return False
    for tag in ('b', 'bh'):
        if tag not in sig:
            _debug(debuglog, "signature missing %s=" % tag)
            return False
        if re.match(r"[\s0-9A-Za-z+/]+=*$", sig[tag]) is None:
            _debug(debuglog, "%s= value is not valid base64 (%s)" % (tag, sig[tag]))
            return False
    if 'd' not in sig:
        _debug(debuglog, "signature missing d=")
        return False
    if 'h' not in sig:
        _debug(debuglog, "signature missing h=")
        return False
    if 'i' in sig:
        # i= must be d= preceded by "@" or by a subdomain ending in ".".
        local = sig['i'][:len(sig['i']) - len(sig['d'])]
        if not sig['i'].endswith(sig['d']) or not local or local[-1] not in "@.":
            _debug(debuglog, "i= domain is not a subdomain of d= (i=%s d=%s)" % (
                sig['i'], sig['d']))
            return False
    if 'l' in sig and re.match(r"\d{1,76}$", sig['l']) is None:
        _debug(debuglog, "l= value is not a decimal integer (%s)" % sig['l'])
        return False
    if 'q' in sig and sig['q'] != "dns/txt":
        _debug(debuglog, "q= value is not dns/txt (%s)" % sig['q'])
        return False
    if 's' not in sig:
        _debug(debuglog, "signature missing s=")
        return False
    if 't' in sig and re.match(r"\d+$", sig['t']) is None:
        _debug(debuglog, "t= value is not a decimal integer (%s)" % sig['t'])
        return False
    if 'x' in sig:
        if re.match(r"\d+$", sig['x']) is None:
            _debug(debuglog, "x= value is not a decimal integer (%s)" % sig['x'])
            return False
        # t= is optional, so only compare when it is there.
        if 't' in sig and int(sig['x']) < int(sig['t']):
            _debug(debuglog, "x= value is less than t= value (x=%s t=%s)" % (
                sig['x'], sig['t']))
            return False

    # A signature without c= is treated as simple/simple.
    c_value = sig.get('c', "simple/simple")
    m = re.match(r"(\w+)(?:/(\w+))?$", c_value)
    if m is None:
        _debug(debuglog, "c= value is not in format method/method (%s)" % c_value)
        return False
    can_headers = m.group(1)
    if m.group(2) is not None:
        can_body = m.group(2)
    else:
        can_body = "simple"

    algorithms = {"simple": Simple, "relaxed": Relaxed}

    canonicalize_headers = algorithms.get(can_headers)
    if canonicalize_headers is None:
        _debug(debuglog, "Unknown header canonicalization (%s)" % can_headers)
        return False

    headers = canonicalize_headers.canonicalize_headers(headers)

    canonicalize_body = algorithms.get(can_body)
    if canonicalize_body is None:
        _debug(debuglog, "Unknown body canonicalization (%s)" % can_body)
        return False
    body = canonicalize_body.canonicalize_body(body)

    if sig['a'] == "rsa-sha1":
        hasher = hashlib.sha1
        hashid = HASHID_SHA1
    elif sig['a'] == "rsa-sha256":
        hasher = hashlib.sha256
        hashid = HASHID_SHA256
    else:
        _debug(debuglog, "Unknown signature algorithm (%s)" % sig['a'])
        return False

    if 'l' in sig:
        body = body[:int(sig['l'])]

    digest = hasher()
    digest.update(force_bString(body))
    bodyhash = digest.digest()
    _debug(debuglog, "bh:", base64.b64encode(bodyhash))
    try:
        expected_bodyhash = base64.b64decode(re.sub(r"\s+", "", sig['bh']))
    except (TypeError, ValueError):
        _debug(debuglog, "bh= value is not valid base64 (%s)" % sig['bh'])
        return False
    if bodyhash != expected_bodyhash:
        _debug(debuglog, "body hash mismatch (got %s, expected %s)" % (
            base64.b64encode(bodyhash), sig['bh']))
        return False

    s = dnstxt(sig['s'] + "._domainkey." + sig['d'] + ".")
    if not s:
        return False
    pub = {}
    for f in re.split(r"\s*;\s*", s.strip()):
        if not f:
            # A record ending in ";" leaves an empty last part.
            continue
        m = re.match(r"(\w+)\s*=\s*(.*)", f, re.DOTALL)
        if m is None:
            _debug(debuglog, "invalid format in _domainkey txt record")
            return False
        pub[m.group(1)] = m.group(2)
    if not pub.get('p'):
        _debug(debuglog, "_domainkey txt record has no p= key data")
        return False
    try:
        pkey = base64.b64decode(pub['p'])
    except (TypeError, ValueError) as e:
        raise KeyFormatError(str(e))
    pkey = forceCharFromBytes(pkey)

    x = asn1_parse(ASN1_Object, pkey)
    # The first octet of the BIT STRING content is skipped before the rest
    # is parsed as the key.  NOTE(review): presumably the "unused bits"
    # count that opens every encoded BIT STRING - confirm.
    pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
    pk = {
        'modulus': pkd[0][0],
        'publicExponent': pkd[0][1],
    }
    modlen = len(int2str(pk['modulus']))
    _debug(debuglog, "modlen:", modlen)

    include_headers = re.split(r"\s*:\s*", sig['h'])
    _debug(debuglog, "include_headers:", include_headers)
    sign_headers = _select_headers(headers, include_headers)
    # The call to _remove() assumes that the signature b= only appears once in
    # the signature header
    sign_headers += [(x[0], x[1].rstrip()) for x in canonicalize_headers.canonicalize_headers(
        [(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])]
    _debug(debuglog, "verify headers:", sign_headers)

    digest = hasher()
    for x in sign_headers:
        digest.update(force_bString(x[0]))
        digest.update(force_bString(":"))
        digest.update(force_bString(x[1]))
    d = digest.digest()
    d = forceCharFromBytes(d)

    _debug(debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d))

    dinfo = asn1_build(
        (SEQUENCE, [
            (SEQUENCE, [
                (OBJECT_IDENTIFIER, hashid),
                (NULL, None),
            ]),
            (OCTET_STRING, d),
        ])
    )
    _debug(debuglog, "dinfo:", " ".join("%02x" % ord(x) for x in dinfo))
    if len(dinfo) + 3 > modlen:
        _debug(debuglog, "Hash too large for modulus")
        return False
    # The padded block the signature has to decrypt to.
    sig2 = "\x00\x01" + "\xff" * (modlen - len(dinfo) - 3) + "\x00" + dinfo
    sig2 = forceCharFromBytes(sig2)
    if debuglog is not None:
        print("sig2:", " ".join("%02x" % ord(x) for x in sig2), file=debuglog)
        print(sig['b'], file=debuglog)
        print(re.sub(r"\s+", "", sig['b']), file=debuglog)

    try:
        sigEncoded = base64.b64decode(forceBytesFromChar(re.sub(r"\s+", "", sig['b'])))
    except (TypeError, ValueError):
        _debug(debuglog, "b= value is not valid base64 (%s)" % sig['b'])
        return False
    sigEncoded = forceCharFromBytes(sigEncoded)

    v = int2str(pow(str2int(sigEncoded), pk['publicExponent'], pk['modulus']), modlen)

    _debug(debuglog, "v:", " ".join("%02x" % ord(x) for x in v))
    if len(v) != len(sig2):
        return False
    # Byte-by-byte compare of signatures
    return all(got == want for got, want in zip(v, sig2))


if __name__ == "__main__":
    msg = """From: greg@hewgill.com\r\nSubject: test\r\n message\r\n\r\nHi.\r\n\r\nWe lost the game. Are you hungry yet?\r\n\r\nJoe.\r\n"""
    print(rfc822_parse(msg))
    with open("/home/greg/.domainkeys/rsa.private") as keyfile:
        # Not named "sign": that would shadow the function being demonstrated.
        signature = sign(msg, "greg", "hewgill.com", keyfile.read())
    print(signature)
    print(verify(signature + msg))