1# This software is provided 'as-is', without any express or implied
2# warranty.  In no event will the author be held liable for any damages
3# arising from the use of this software.
4#
5# Permission is granted to anyone to use this software for any purpose,
6# including commercial applications, and to alter it and redistribute it
7# freely, subject to the following restrictions:
8#
9# 1. The origin of this software must not be misrepresented; you must not
10#    claim that you wrote the original software. If you use this software
11#    in a product, an acknowledgment in the product documentation would be
12#    appreciated but is not required.
13# 2. Altered source versions must be plainly marked as such, and must not be
14#    misrepresented as being the original software.
15# 3. This notice may not be removed or altered from any source distribution.
16#
17# Copyright (c) 2008 Greg Hewgill http://hewgill.com
18
19from __future__ import print_function
20import base64
21import hashlib
22import re
23import time
24from fuglu.localStringEncoding import force_bString, force_uString, forceBytesFromChar, forceCharFromBytes
25
# Names exported by a star-import of this module.
__all__ = [
    # Canonicalization algorithms.
    "Simple", "Relaxed",
    # Exceptions.
    "InternalError", "KeyFormatError", "MessageFormatError", "ParameterError",
    # Entry points.
    "sign", "verify",
]
36
37
class Simple(object):

    """Class that represents the "simple" canonicalization algorithm."""

    name = "simple"

    @staticmethod
    def canonicalize_headers(headers):
        """Return the header list unchanged.

        @param headers: a list of (name, value) pairs.
        @return: the very same object that was passed in.

        """
        # No changes to headers.
        return headers

    @staticmethod
    def canonicalize_body(body):
        """Return the body with all trailing empty lines reduced to one CRLF.

        @param body: the message body as a CRLF-separated string.
        @return: the body ending in exactly one CRLF; an empty body becomes
        a single CRLF.

        """
        # Ignore all empty lines at the end of the message body.
        # A regex substitution of "(CRLF)*$" is deliberately avoided: since
        # Python 3.7, re.sub() also replaces the empty match that follows a
        # non-empty one, which leaves two CRLFs and breaks the body hash.
        end = len(body)
        while body.endswith("\r\n", 0, end):
            end -= 2
        return body[:end] + "\r\n"
53
54
class Relaxed(object):

    """Class that represents the "relaxed" canonicalization algorithm."""

    name = "relaxed"

    @staticmethod
    def canonicalize_headers(headers):
        """Return a new list of canonicalized (name, value) tuples.

        @param headers: a list of (name, value) pairs.
        @return: a list of tuples with lowercased names and values that are
        unfolded, whitespace-compressed, stripped and terminated by CRLF.

        """
        # Convert all header field names to lowercase.
        # Unfold all header lines.
        # Compress WSP to single space.
        # Remove all WSP at the start or end of the field value (strip).
        return [(x[0].lower(), re.sub(r"\s+", " ", re.sub("\r\n", "", x[1])).strip() + "\r\n") for x in headers]

    @staticmethod
    def canonicalize_body(body):
        """Return the body in "relaxed" canonical form.

        @param body: the message body as a CRLF-separated string.
        @return: the canonical body. A body that is empty once trailing empty
        lines are ignored canonicalizes to the empty string (RFC 6376,
        section 3.4.4); any other body ends in exactly one CRLF.

        """
        # Remove all trailing WSP at end of lines.
        body = re.sub("[\\x09\\x20]+\r\n", "\r\n", body)
        # Compress non-line-ending WSP to single space.
        body = re.sub(r"[\x09\x20]+", " ", body)
        # Ignore all empty lines at the end of the message body. This is done
        # without re.sub("(CRLF)*$", ...) because Python 3.7+ also replaces
        # the trailing empty match, which would append a second CRLF.
        end = len(body)
        while body.endswith("\r\n", 0, end):
            end -= 2
        body = body[:end]
        if not body:
            # Nothing but empty lines: the relaxed canonical form is empty.
            return ""
        return body + "\r\n"
75
76
class DKIMException(Exception):
    """Common ancestor of every error class defined by this DKIM module."""
81
82
class InternalError(DKIMException):
    """Signals a fault inside the dkim module itself; not expected to occur."""
87
88
class KeyFormatError(DKIMException):
    """Raised when an RSA public or private key cannot be parsed."""
93
94
class MessageFormatError(DKIMException):
    """Raised when a message does not follow the RFC822 format."""
99
100
class ParameterError(DKIMException):
    """Raised when a caller-supplied parameter is not acceptable."""
105
106
107def _remove(s, t):
108    i = s.find(t)
109    assert i >= 0
110    return s[:i] + s[i + len(t):]
111
# ASN.1 tag octets understood by asn1_parse() and asn1_build().
INTEGER = 0x02
BIT_STRING = 0x03
OCTET_STRING = 0x04
NULL = 0x05
OBJECT_IDENTIFIER = 0x06
SEQUENCE = 0x30

# Parse templates: each entry is a tuple of (tag,) or, for a SEQUENCE,
# (tag, [child templates]).

# Wrapper parsed around a public key: an (OID, NULL) pair followed by a
# BIT STRING that carries the key itself.
ASN1_Object = [
    (SEQUENCE, [
        (SEQUENCE, [(OBJECT_IDENTIFIER,), (NULL,)]),
        (BIT_STRING,),
    ]),
]

# RSA public key: two integers (modulus, public exponent).
ASN1_RSAPublicKey = [(SEQUENCE, [(INTEGER,), (INTEGER,)])]

# RSA private key: nine integers (version, modulus, public exponent,
# private exponent, prime1, prime2, exponent1, exponent2, coefficient).
ASN1_RSAPrivateKey = [(SEQUENCE, [(INTEGER,)] * 9)]
149
150
def asn1_parse(template, data):
    """Parse a data structure according to ASN.1 template.

    @param template: A list of tuples comprising the ASN.1 template.
    @param data: The encoded data, either as a string with one character per
    octet or as a bytes object.

    @return: A list with one entry per template element: an int for INTEGER,
    a slice of data for BIT STRING and OBJECT IDENTIFIER, None for NULL and a
    nested list for SEQUENCE.

    @raise KeyFormatError: if a tag does not match the template, a NULL has
    a non-zero length, or the data ends before a tag, length or INTEGER
    octet could be read.

    """

    def octet_at(pos):
        # Read one octet as an int. Running off the end of the data (a
        # truncated or malformed key) is reported as a key format problem
        # rather than leaking a bare IndexError to the caller.
        try:
            c = data[pos]
        except IndexError:
            raise KeyFormatError("Unexpected end of ASN.1 data")
        # Indexing bytes on Python 3 already yields an int.
        return c if isinstance(c, int) else ord(c)

    r = []
    i = 0
    for t in template:
        tag = octet_at(i)
        i += 1
        if tag != t[0]:
            raise KeyFormatError(
                "Unexpected tag (got %02x, expecting %02x)" % (tag, t[0]))
        length = octet_at(i)
        i += 1
        if length & 0x80:
            # Long form: the low 7 bits give the number of length octets.
            n = length & 0x7f
            length = 0
            for _ in range(n):
                length = (length << 8) | octet_at(i)
                i += 1
        if tag == INTEGER:
            # Big-endian, treated as unsigned.
            n = 0
            for _ in range(length):
                n = (n << 8) | octet_at(i)
                i += 1
            r.append(n)
        elif tag == BIT_STRING or tag == OBJECT_IDENTIFIER:
            # Returned raw; interpretation is left to the caller.
            r.append(data[i:i + length])
            i += length
        elif tag == NULL:
            # Explicit check instead of assert: the data may come from an
            # untrusted source and asserts vanish under "python -O".
            if length != 0:
                raise KeyFormatError(
                    "Unexpected length for NULL: %d" % length)
            r.append(None)
        elif tag == SEQUENCE:
            r.append(asn1_parse(t[1], data[i:i + length]))
            i += length
        else:
            raise KeyFormatError("Unexpected tag in template: %02x" % tag)
    return r
197
198
def asn1_length(n):
    """Return a string representing a field length in ASN.1 format.

    @param n: The non-negative length to encode.

    @return: A single octet for lengths up to 127 (short form); otherwise an
    octet 0x80|k followed by the k big-endian octets of the length (long
    form), which is the layout asn1_parse() decodes.

    @raise ValueError: if n is negative.

    """
    if n < 0:
        raise ValueError("ASN.1 length must not be negative")
    if n < 0x80:
        return chr(n)
    r = ""
    while n > 0:
        r = chr(n & 0xff) + r
        n >>= 8
    # The long form must announce how many length octets follow.
    return chr(0x80 | len(r)) + r
209
210
def asn1_build(node):
    """Serialize a (tag, value) node into its ASN.1 encoding.

    Supported tags are OCTET_STRING and OBJECT_IDENTIFIER (value is the raw
    content string), NULL (value must be None) and SEQUENCE (value is a list
    of child nodes, encoded recursively). Any other tag raises InternalError.
    """
    tag = node[0]
    if tag == NULL:
        assert node[1] is None
        return chr(NULL) + asn1_length(0)
    if tag == SEQUENCE:
        payload = "".join(asn1_build(child) for child in node[1])
        return chr(SEQUENCE) + asn1_length(len(payload)) + payload
    if tag in (OCTET_STRING, OBJECT_IDENTIFIER):
        payload = node[1]
        return chr(tag) + asn1_length(len(payload)) + payload
    raise InternalError("Unexpected tag in template: %02x" % tag)
227
# Encoded object identifiers of the hash algorithms, as listed in
# RFC 3447, section 9.2 Notes, page 43.
# SHA-1: 1.3.14.3.2.26
HASHID_SHA1 = "".join(chr(octet) for octet in (0x2b, 0x0e, 0x03, 0x02, 0x1a))
# SHA-256: 2.16.840.1.101.3.4.2.1
HASHID_SHA256 = "".join(
    chr(octet) for octet in (0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01))
231
232
def str2int(s):
    """Interpret an octet string as a big-endian, non-negative integer.

    An empty string yields 0.
    """
    value = 0
    for octet in map(ord, s):
        value = (value << 8) | octet
    return value
239
240
def int2str(n, length=-1):
    """Split a non-negative integer into its big-endian octets.

    @param n: Number to convert.
    @param length: Number of octets to produce, or -1 (any negative value)
    for the smallest number of octets that represent the integer. With a
    non-negative length the result is zero-padded on the left, and octets
    that do not fit are dropped from the high end.

    @return: A list of one-character strings, most significant octet first.

    """

    assert n >= 0
    unbounded = length < 0
    octets = []
    while unbounded or len(octets) < length:
        octets.append(chr(n & 0xff))
        n >>= 8
        # In unbounded mode at least one octet is always emitted (0 -> "\x00").
        if unbounded and n == 0:
            break
    assert unbounded or len(octets) == length
    return octets[::-1]
259
260
def rfc822_parse(message):
    """Parse a message in RFC822 format.

    @param message: The message in RFC822 format. Either CRLF or LF is an accepted line separator.

    @return Returns a tuple of (headers, body) where headers is a list of [name, value] pairs.
    Each value keeps its leading whitespace and folded continuation lines and ends with CRLF.
    The body is a CRLF-separated string.

    @raise MessageFormatError: if a header line is neither a "name:" field,
    a continuation of a previous field, nor an mbox "From " line.

    """

    headers = []
    lines = re.split("\r?\n", message)
    i = 0
    while i < len(lines):
        line = lines[i]
        if len(line) == 0:
            # End of headers, return what we have plus the body, excluding the
            # blank line.
            i += 1
            break
        if line[0] in "\t ":
            # Continuation (folded) line: it belongs to the previous header.
            if not headers:
                # There is no header to continue; report it as a format
                # error instead of failing with an IndexError.
                raise MessageFormatError(
                    "Unexpected continuation line before any RFC822 header: %s" % line)
            headers[-1][1] += line + "\r\n"
        else:
            m = re.match(r"([\x21-\x7e]+?):", line)
            if m is not None:
                headers.append([m.group(1), line[m.end(0):] + "\r\n"])
            elif line.startswith("From "):
                # mbox-style envelope line: silently skipped.
                pass
            else:
                raise MessageFormatError(
                    "Unexpected characters in RFC822 header: %s" % line)
        i += 1
    return headers, "\r\n".join(lines[i:])
293
294
def dnstxt(name):
    """Look up the TXT record for a DNS name.

    @param name: the DNS name to query.

    @return: the character-strings of the first TXT record found in the
    answer, concatenated into one native string, or None if the answer
    contains no TXT record. Exceptions raised by the resolver propagate to
    the caller.

    """
    import dns.rdatatype
    import dns.resolver
    a = dns.resolver.query(name, dns.rdatatype.TXT)
    for r in a.response.answer:
        if r.rdtype == dns.rdatatype.TXT:
            # Iterate the rdataset instead of indexing r.items, whose
            # container type is not a list in every dnspython release.
            for rdata in r:
                # The character-strings are bytes on Python 3, so join them
                # as bytes and convert to the native string type afterwards.
                parts = [p if isinstance(p, bytes) else p.encode("utf-8")
                         for p in rdata.strings]
                txt = b"".join(parts)
                if not isinstance(txt, str):
                    txt = txt.decode("utf-8", "replace")
                return txt
    return None
302
303
def fold(header):
    """Fold a header line into multiple crlf-separated lines at column 72.

    Lines are only broken at a space, which is replaced by CRLF plus one
    space. Anything up to and including the last existing fold is left
    untouched. When the first 72 characters contain no space, the break is
    made at the next space further on; if there is none, the remainder is
    left unfolded, so a token is never split.

    @param header: the header line to fold.
    @return: the folded header.

    """
    i = header.rfind("\r\n ")
    if i == -1:
        pre = ""
    else:
        i += 3
        pre = header[:i]
        header = header[i:]
    while len(header) > 72:
        i = header[:72].rfind(" ")
        if i == -1:
            # No break point within the limit: use the first one after it
            # rather than slicing with -1, which would insert the fold in
            # front of the last character of the header.
            i = header.find(" ", 72)
            if i == -1:
                break
        pre += header[:i] + "\r\n "
        header = header[i + 1:]
    return pre + header
322
323
def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple, Simple), include_headers=None, length=False, debuglog=None):
    """Sign an RFC822 message and return the DKIM-Signature header line.

    @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings)
    @param selector: the DKIM selector value for the signature
    @param domain: the DKIM domain value for the signature
    @param privkey: a PKCS#1 private key in base64-encoded text form (with either \\n or \\r\\n line endings)
    @param identity: the DKIM identity value for the signature (default "@"+domain)
    @param canonicalize: the canonicalization algorithms to use (default (Simple, Simple))
    @param include_headers: a list of strings indicating which headers are to be signed (default all headers)
    @param length: true if the l= tag should be included to indicate body length (default False)
    @param debuglog: a file-like object to which debug info will be written (default None)

    @return: the complete DKIM-Signature header line, terminated by CRLF
    @raise KeyFormatError: if no private key is found in privkey or it cannot be decoded
    @raise ParameterError: if identity does not end with domain, or the key modulus is too small for the hash

    """

    (headers, body) = rfc822_parse(message)

    # Take the base64 payload between the PEM delimiter lines.
    m = re.search("--\r?\n(.*?)\r?\n--", privkey, re.DOTALL)
    if m is None:
        raise KeyFormatError("Private key not found")
    try:
        pkdata = base64.b64decode(m.group(1))
        pkdata = forceCharFromBytes(pkdata)
    except (TypeError, ValueError) as e:
        # Python 2 reports bad base64 as TypeError, Python 3 as
        # binascii.Error, which is a ValueError.
        raise KeyFormatError(str(e))
    if debuglog is not None:
        # NOTE(review): this writes the raw private key to the debug log.
        print(" ".join("%02x" % ord(x) for x in pkdata), file=debuglog)
    pka = asn1_parse(ASN1_RSAPrivateKey, pkdata)
    pk = {
        'version': pka[0][0],
        'modulus': pka[0][1],
        'publicExponent': pka[0][2],
        'privateExponent': pka[0][3],
        'prime1': pka[0][4],
        'prime2': pka[0][5],
        'exponent1': pka[0][6],
        'exponent2': pka[0][7],
        'coefficient': pka[0][8],
    }

    if identity is not None and not identity.endswith(domain):
        raise ParameterError("identity must end with domain")

    headers = canonicalize[0].canonicalize_headers(headers)

    if include_headers is None:
        include_headers = [x[0].lower() for x in headers]
    else:
        include_headers = [x.lower() for x in include_headers]
    selected = [x for x in headers if x[0].lower() in include_headers]

    # verify() resolves every name listed in h= to the bottom-most instance
    # of that header not used yet. Hash repeated headers (e.g. Received) in
    # that same order, otherwise the signature cannot be verified. Headers
    # that occur only once keep their position, and the sequence of names
    # written to h= is the same as the selection order.
    pending = {}
    for x in selected:
        pending.setdefault(x[0].lower(), []).append(x)
    sign_headers = [pending[x[0].lower()].pop() for x in selected]

    body = canonicalize[1].canonicalize_body(body)

    h = hashlib.sha256()
    h.update(force_bString(body))
    bodyhash = base64.b64encode(h.digest())
    bodyhash = forceCharFromBytes(bodyhash)

    # Falsy entries (the l= tag when length is not requested) are dropped.
    sigfields = [x for x in [
        ('v', "1"),
        ('a', "rsa-sha256"),
        ('c', "%s/%s" % (canonicalize[0].name, canonicalize[1].name)),
        ('d', domain),
        ('i', identity or "@" + domain),
        length and ('l', len(body)),
        ('q', "dns/txt"),
        ('s', selector),
        ('t', str(int(time.time()))),
        ('h', " : ".join(x[0] for x in sign_headers)),
        ('bh', bodyhash),
        ('b', ""),
    ] if x]
    sig = "DKIM-Signature: " + "; ".join("%s=%s" % x for x in sigfields)

    sig = fold(sig)

    if debuglog is not None:
        print("sign headers:", sign_headers +
              [("DKIM-Signature", " " + "; ".join("%s=%s" %
                                                  x for x in sigfields))], file=debuglog)
    # The header hash covers the selected headers followed by the
    # DKIM-Signature header itself with an empty b= value.
    h = hashlib.sha256()
    for x in sign_headers:
        h.update(force_bString(x[0]))
        h.update(b":")
        h.update(force_bString(x[1]))
    h.update(force_bString(sig))
    d = h.digest()
    d = forceCharFromBytes(d)

    if debuglog is not None:
        print("sign digest:", " ".join("%02x" % ord(x)
                                       for x in d), file=debuglog)

    dinfo = asn1_build(
        (SEQUENCE, [
            (SEQUENCE, [
                (OBJECT_IDENTIFIER, HASHID_SHA256),
                (NULL, None),
            ]),
            (OCTET_STRING, d),
        ])
    )
    modlen = len(int2str(pk['modulus']))
    if len(dinfo) + 3 > modlen:
        raise ParameterError("Hash too large for modulus")
    # Pad the digest info to the modulus length: 00 01 FF..FF 00 || dinfo.
    signature = "\x00\x01" + "\xff" * (modlen - len(dinfo) - 3) + "\x00" + dinfo
    sig2 = int2str(pow(str2int(signature), pk['privateExponent'], pk['modulus']), modlen)
    sigEncoded = base64.b64encode(forceBytesFromChar(''.join(sig2)))
    sigEncoded = forceCharFromBytes(sigEncoded)
    sig += sigEncoded

    return sig + "\r\n"
437
438
def verify(message, debuglog=None):
    """Verify a DKIM signature on an RFC822 formatted message.

    Only the first DKIM-Signature header found is checked. The public key is
    fetched from DNS through dnstxt(); errors raised by the resolver and
    KeyFormatError from parsing the key propagate to the caller.

    @param message: an RFC822 formatted message (with either \\n or \\r\\n line endings)
    @param debuglog: a file-like object to which debug info will be written (default None)

    @return: True if the signature verifies; False if there is no signature,
    it is malformed, or it does not match the message.

    """

    def log(*args):
        # Write one line of debug output, if a debug log was supplied.
        if debuglog is not None:
            print(*args, file=debuglog)

    (headers, body) = rfc822_parse(message)

    sigheaders = [x for x in headers if x[0].lower() == "dkim-signature"]
    if len(sigheaders) < 1:
        return False

    # Currently, we only validate the first DKIM-Signature line found.

    a = re.split(r"\s*;\s*", sigheaders[0][1].strip())
    log("a:", a)
    sig = {}
    for x in a:
        if x:
            m = re.match(r"(\w+)\s*=\s*(.*)", x, re.DOTALL)
            if m is None:
                log("invalid format of signature part: %s" % x)
                return False
            sig[m.group(1)] = m.group(2)
    log("sig:", sig)

    if 'v' not in sig:
        log("signature missing v=")
        return False
    if sig['v'] != "1":
        log("v= value is not 1 (%s)" % sig['v'])
        return False
    if 'a' not in sig:
        log("signature missing a=")
        return False
    if 'b' not in sig:
        log("signature missing b=")
        return False
    if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
        log("b= value is not valid base64 (%s)" % sig['b'])
        return False
    if 'bh' not in sig:
        log("signature missing bh=")
        return False
    if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
        log("bh= value is not valid base64 (%s)" % sig['bh'])
        return False
    if 'd' not in sig:
        log("signature missing d=")
        return False
    if 'h' not in sig:
        log("signature missing h=")
        return False
    if 'i' in sig:
        # i= must end with d=, and the character in front of the domain
        # part must be "@" or ".". Slicing (rather than indexing) keeps an
        # i= that has nothing in front of the domain from raising IndexError.
        before_domain = sig['i'][:len(sig['i']) - len(sig['d'])]
        if not sig['i'].endswith(sig['d']) or before_domain[-1:] not in ("@", "."):
            log("i= domain is not a subdomain of d= (i=%s d=%s)" % (
                sig['i'], sig['d']))
            return False
    # At least one digit is required; an empty l= cannot be converted below.
    if 'l' in sig and re.match(r"\d{1,76}$", sig['l']) is None:
        log("l= value is not a decimal integer (%s)" % sig['l'])
        return False
    if 'q' in sig and sig['q'] != "dns/txt":
        log("q= value is not dns/txt (%s)" % sig['q'])
        return False
    if 's' not in sig:
        log("signature missing s=")
        return False
    if 't' in sig and re.match(r"\d+$", sig['t']) is None:
        log("t= value is not a decimal integer (%s)" % sig['t'])
        return False
    if 'x' in sig:
        if re.match(r"\d+$", sig['x']) is None:
            log("x= value is not a decimal integer (%s)" % sig['x'])
            return False
        # t= is optional, so x= can only be compared when t= is present.
        if 't' in sig and int(sig['x']) < int(sig['t']):
            log("x= value is less than t= value (x=%s t=%s)" % (
                sig['x'], sig['t']))
            return False

    # c= is optional and defaults to simple/simple.
    c = sig.get('c', "simple/simple")
    m = re.match(r"(\w+)(?:/(\w+))?$", c)
    if m is None:
        log("c= value is not in format method/method (%s)" % c)
        return False
    can_headers = m.group(1)
    if m.group(2) is not None:
        can_body = m.group(2)
    else:
        can_body = "simple"

    if can_headers == "simple":
        canonicalize_headers = Simple
    elif can_headers == "relaxed":
        canonicalize_headers = Relaxed
    else:
        log("Unknown header canonicalization (%s)" % can_headers)
        return False

    headers = canonicalize_headers.canonicalize_headers(headers)

    if can_body == "simple":
        body = Simple.canonicalize_body(body)
    elif can_body == "relaxed":
        body = Relaxed.canonicalize_body(body)
    else:
        log("Unknown body canonicalization (%s)" % can_body)
        return False

    if sig['a'] == "rsa-sha1":
        hasher = hashlib.sha1
        hashid = HASHID_SHA1
    elif sig['a'] == "rsa-sha256":
        hasher = hashlib.sha256
        hashid = HASHID_SHA256
    else:
        log("Unknown signature algorithm (%s)" % sig['a'])
        return False

    if 'l' in sig:
        body = body[:int(sig['l'])]

    h = hasher()
    h.update(force_bString(body))
    bodyhash = h.digest()
    log("bh:", base64.b64encode(bodyhash))
    try:
        expected_bodyhash = base64.b64decode(re.sub(r"\s+", "", sig['bh']))
    except (TypeError, ValueError):
        # The character check above does not catch e.g. wrong padding.
        log("bh= value cannot be base64-decoded (%s)" % sig['bh'])
        return False
    if bodyhash != expected_bodyhash:
        log("body hash mismatch (got %s, expected %s)" % (
            base64.b64encode(bodyhash), sig['bh']))
        return False

    s = dnstxt(sig['s'] + "._domainkey." + sig['d'] + ".")
    if not s:
        return False
    a = re.split(r"\s*;\s*", s.strip())
    pub = {}
    for f in a:
        if not f:
            # A tag list may end with ";", which leaves an empty last part.
            continue
        m = re.match(r"(\w+)\s*=\s*(.*)", f)
        if m is not None:
            pub[m.group(1)] = m.group(2)
        else:
            log("invalid format in _domainkey txt record")
            return False
    if not pub.get('p'):
        log("_domainkey txt record contains no public key (p=)")
        return False
    try:
        pkey = base64.b64decode(pub['p'])
    except (TypeError, ValueError):
        log("p= value in _domainkey txt record cannot be base64-decoded")
        return False
    pkey = forceCharFromBytes(pkey)

    spki = asn1_parse(ASN1_Object, pkey)
    # The first octet of a BIT STRING's content gives the number of unused
    # bits in its last octet; the encoded RSA key starts right after it.
    pkd = asn1_parse(ASN1_RSAPublicKey, spki[0][1][1:])
    pk = {
        'modulus': pkd[0][0],
        'publicExponent': pkd[0][1],
    }
    modlen = len(int2str(pk['modulus']))
    log("modlen:", modlen)

    include_headers = re.split(r"\s*:\s*", sig['h'])
    log("include_headers:", include_headers)
    sign_headers = []
    # For every header name: index of the instance selected last. Each
    # further mention of the same name in h= selects the next instance
    # further up. Names are matched case-insensitively, so the bookkeeping
    # is keyed on the lowercased name as well.
    lastindex = {}
    for name in include_headers:
        key = name.lower()
        i = lastindex.get(key, len(headers))
        while i > 0:
            i -= 1
            if key == headers[i][0].lower():
                sign_headers.append(headers[i])
                break
        lastindex[key] = i
    # The call to _remove() assumes that the signature b= only appears once in
    # the signature header
    sign_headers += [(x[0], x[1].rstrip()) for x in canonicalize_headers.canonicalize_headers(
        [(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])]
    log("verify headers:", sign_headers)

    h = hasher()
    for x in sign_headers:
        h.update(force_bString(x[0]))
        h.update(force_bString(":"))
        h.update(force_bString(x[1]))
    d = h.digest()
    d = forceCharFromBytes(d)

    if debuglog is not None:
        print("verify digest:", " ".join(
            "%02x" % ord(x) for x in d), file=debuglog)

    dinfo = asn1_build(
        (SEQUENCE, [
            (SEQUENCE, [
                (OBJECT_IDENTIFIER, hashid),
                (NULL, None),
            ]),
            (OCTET_STRING, d),
        ])
    )
    if debuglog is not None:
        print("dinfo:", " ".join("%02x" % ord(x)
                                 for x in dinfo), file=debuglog)
    if len(dinfo) + 3 > modlen:
        log("Hash too large for modulus")
        return False
    # The padded digest info a valid signature must decrypt to.
    sig2 = "\x00\x01" + "\xff" * (modlen - len(dinfo) - 3) + "\x00" + dinfo
    sig2 = forceCharFromBytes(sig2)
    if debuglog is not None:
        print("sig2:", " ".join("%02x" % ord(x) for x in sig2), file=debuglog)
        print(sig['b'], file=debuglog)
        print(re.sub(r"\s+", "", sig['b']), file=debuglog)

    try:
        sigEncoded = base64.b64decode(forceBytesFromChar(re.sub(r"\s+", "", sig['b'])))
    except (TypeError, ValueError):
        log("b= value cannot be base64-decoded (%s)" % sig['b'])
        return False
    sigEncoded = forceCharFromBytes((sigEncoded))

    v = int2str(pow(str2int(sigEncoded), pk['publicExponent'], pk['modulus']), modlen)

    if debuglog is not None:
        print("v:", " ".join("%02x" % ord(x) for x in v), file=debuglog)
    # zip() below stops at the shorter sequence, so a length difference has
    # to be rejected explicitly (an assert would vanish under "python -O").
    if len(v) != len(sig2):
        return False
    # Byte-by-byte compare of signatures
    return all(got == expected for got, expected in zip(v, sig2))
697
if __name__ == "__main__":
    # Manual smoke test: parse a sample message, sign it with a private key
    # read from a fixed path, then verify the freshly signed message.
    msg = """From: greg@hewgill.com\r\nSubject: test\r\n message\r\n\r\nHi.\r\n\r\nWe lost the game. Are you hungry yet?\r\n\r\nJoe.\r\n"""
    print(rfc822_parse(msg))
    # Read the key inside a with-block so the file is closed again.
    with open("/home/greg/.domainkeys/rsa.private") as keyfile:
        privkey = keyfile.read()
    # A separate name keeps the sign() function itself usable afterwards.
    signature = sign(msg, "greg", "hewgill.com", privkey)
    print(signature)
    print(verify(signature + msg))
    # To sign a message stored in a file instead, pass the contents of e.g.
    # /home/greg/tmp/message as the first argument to sign().
707