diff --git a/dkim/__init__.py b/dkim/__init__.py
index b694a07..71e906f 100644
--- a/dkim/__init__.py
+++ b/dkim/__init__.py
@@ -94,6 +94,139 @@ def _remove(s, t):
     assert i >= 0
     return s[:i] + s[i+len(t):]
 
+
+def EMSA_PKCS1_v1_5_encode(digest, modlen, hashid):
+    """Encode a digest with EMSA-PKCS1-v1_5.
+
+    Defined in RFC3447 section 9.2.
+
+    @param digest: A digest value to encode.
+    @param modlen: The desired message length.
+    @param hashid: The ID of the hash used to generate the digest.
+    @return: The encoded message, modlen bytes long.
+    @raise ParameterError: If modlen is too small for the encoding.
+    """
+    dinfo = asn1_build(
+        (SEQUENCE, [
+            (SEQUENCE, [
+                (OBJECT_IDENTIFIER, hashid),
+                (NULL, None),
+            ]),
+            (OCTET_STRING, digest),
+        ]))
+    if len(dinfo)+3 > modlen:
+        raise ParameterError("Hash too large for modulus")
+    return "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
+
+
+def hash_headers(hasher, canonicalize_headers, headers, include_headers,
+                 sigheaders, sig):
+    """Hash the signed header fields and the b=-stripped signature.
+
+    @param hasher: A hash object; it is updated in place.
+    @param canonicalize_headers: An object whose canonicalize_headers()
+        method is applied to the signature header.
+    @param headers: A list of (name, value) header tuples.
+    @param include_headers: The header names listed in the h= tag.
+    @param sigheaders: The signature header(s); only the first is used.
+    @param sig: A dict of signature tags; sig['b'] is removed from the
+        signature header before hashing.
+    """
+    sign_headers = []
+    lastindex = {}
+    for h in include_headers:
+        i = lastindex.get(h, len(headers))
+        while i > 0:
+            i -= 1
+            if h.lower() == headers[i][0].lower():
+                sign_headers.append(headers[i])
+                break
+        lastindex[h] = i
+    # The call to _remove() assumes that the signature b= only appears
+    # once in the signature header
+    cheaders = canonicalize_headers.canonicalize_headers(
+        [(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])
+    sign_headers += [(x[0], x[1].rstrip()) for x in cheaders]
+    for x in sign_headers:
+        hasher.update(x[0])
+        hasher.update(":")
+        hasher.update(x[1])
+
+
+def parse_public_key(data):
+    """Parse an RSA public key.
+
+    @param data: A DER-encoded X.509 subjectPublicKeyInfo
+        containing an RFC3447 RSAPublicKey.
+    @return: A dict with 'modulus' and 'publicExponent' entries.
+    """
+    x = asn1_parse(ASN1_Object, data)
+    # [1:] presumably skips the BIT STRING's leading "unused bits" octet.
+    pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
+    pk = {
+        'modulus': pkd[0][0],
+        'publicExponent': pkd[0][1],
+    }
+    return pk
+
+
+def validate_signature_fields(sig, debuglog=None):
+    """Validate DKIM-Signature fields.
+
+    Basic checks for presence and correct formatting of mandatory fields.
+
+    @param sig: A dict mapping field keys to values.
+    @param debuglog: A file-like object to which details will be written
+        on error.
+    @return: True if every check passes, False otherwise.
+    """
+    mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's')
+    for field in mandatory_fields:
+        if field not in sig:
+            if debuglog is not None:
+                print >>debuglog, "signature missing %s=" % field
+            return False
+
+    if sig['v'] != "1":
+        if debuglog is not None:
+            print >>debuglog, "v= value is not 1 (%s)" % sig['v']
+        return False
+    if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
+        if debuglog is not None:
+            print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
+        return False
+    if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
+        if debuglog is not None:
+            print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
+        return False
+    if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1:-len(sig['d'])] not in ("@", ".")):
+        if debuglog is not None:
+            print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%s)" % (sig['i'], sig['d'])
+        return False
+    if 'l' in sig and re.match(r"\d{1,76}$", sig['l']) is None:
+        if debuglog is not None:
+            print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
+        return False
+    if 'q' in sig and sig['q'] != "dns/txt":
+        if debuglog is not None:
+            print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
+        return False
+    if 't' in sig and re.match(r"\d+$", sig['t']) is None:
+        if debuglog is not None:
+            print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
+        return False
+    if 'x' in sig:
+        if re.match(r"\d+$", sig['x']) is None:
+            if debuglog is not None:
+                print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
+            return False
+        if 't' in sig and int(sig['x']) < int(sig['t']):
+            if debuglog is not None:
+                print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
+            return False
+    return True
+
+
 INTEGER = 0x02
 BIT_STRING = 0x03
 OCTET_STRING = 0x04
@@ -379,19 +512,9 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple
     if debuglog is not None:
         print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d)
 
-    dinfo = asn1_build(
-        (SEQUENCE, [
-            (SEQUENCE, [
-                (OBJECT_IDENTIFIER, HASHID_SHA256),
-                (NULL, None),
-            ]),
-            (OCTET_STRING, d),
-        ])
-    )
     modlen = len(int2str(pk['modulus']))
-    if len(dinfo)+3 > modlen:
-        raise ParameterError("Hash too large for modulus")
-    sig2 = int2str(pow(str2int("\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo), pk['privateExponent'], pk['modulus']), modlen)
+    encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256)
+    sig2 = int2str(pow(str2int(encoded), pk['privateExponent'], pk['modulus']), modlen)
     sig += base64.b64encode(''.join(sig2))
 
     return sig + "\r\n"
@@ -427,71 +550,8 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
     if debuglog is not None:
         print >>debuglog, "sig:", sig
 
-    if 'v' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing v="
+    if not validate_signature_fields(sig, debuglog):
         return False
-    if sig['v'] != "1":
-        if debuglog is not None:
-            print >>debuglog, "v= value is not 1 (%s)" % sig['v']
-        return False
-    if 'a' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing a="
-        return False
-    if 'b' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing b="
-        return False
-    if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
-        if debuglog is not None:
-            print >>debuglog, "b= value is not valid base64 (%s)" % sig['b']
-        return False
-    if 'bh' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing bh="
-        return False
-    if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
-        if debuglog is not None:
-            print >>debuglog, "bh= value is not valid base64 (%s)" % sig['bh']
-        return False
-    if 'd' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing d="
-        return False
-    if 'h' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing h="
-        return False
-    if 'i' in sig and (not sig['i'].endswith(sig['d']) or sig['i'][-len(sig['d'])-1] not in "@."):
-        if debuglog is not None:
-            print >>debuglog, "i= domain is not a subdomain of d= (i=%s d=%d)" % (sig['i'], sig['d'])
-        return False
-    if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
-        if debuglog is not None:
-            print >>debuglog, "l= value is not a decimal integer (%s)" % sig['l']
-        return False
-    if 'q' in sig and sig['q'] != "dns/txt":
-        if debuglog is not None:
-            print >>debuglog, "q= value is not dns/txt (%s)" % sig['q']
-        return False
-    if 's' not in sig:
-        if debuglog is not None:
-            print >>debuglog, "signature missing s="
-        return False
-    if 't' in sig and re.match(r"\d+$", sig['t']) is None:
-        if debuglog is not None:
-            print >>debuglog, "t= value is not a decimal integer (%s)" % sig['t']
-        return False
-    if 'x' in sig:
-        if re.match(r"\d+$", sig['x']) is None:
-            if debuglog is not None:
-                print >>debuglog, "x= value is not a decimal integer (%s)" % sig['x']
-            return False
-        if int(sig['x']) < int(sig['t']):
-            if debuglog is not None:
-                print >>debuglog, "x= value is less than t= value (x=%s t=%s)" % (sig['x'], sig['t'])
-            return False
 
     m = re.match("(\w+)(?:/(\w+))?$", sig['c'])
     if m is None:
@@ -565,60 +625,25 @@ def verify(message, debuglog=None, dnsfunc=dnstxt):
             if debuglog is not None:
                 print >>debuglog, "invalid format in _domainkey txt record"
             return False
-    x = asn1_parse(ASN1_Object, base64.b64decode(pub['p']))
-    # Not sure why the [1:] is necessary to skip a byte.
-    pkd = asn1_parse(ASN1_RSAPublicKey, x[0][1][1:])
-    pk = {
-        'modulus': pkd[0][0],
-        'publicExponent': pkd[0][1],
-    }
+    pk = parse_public_key(base64.b64decode(pub['p']))
     modlen = len(int2str(pk['modulus']))
     if debuglog is not None:
         print >>debuglog, "modlen:", modlen
 
     include_headers = re.split(r"\s*:\s*", sig['h'])
-    if debuglog is not None:
-        print >>debuglog, "include_headers:", include_headers
-    sign_headers = []
-    lastindex = {}
-    for h in include_headers:
-        i = lastindex.get(h, len(headers))
-        while i > 0:
-            i -= 1
-            if h.lower() == headers[i][0].lower():
-                sign_headers.append(headers[i])
-                break
-        lastindex[h] = i
-    # The call to _remove() assumes that the signature b= only appears once in the signature header
-    sign_headers += [(x[0], x[1].rstrip()) for x in canonicalize_headers.canonicalize_headers([(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])]
-    if debuglog is not None:
-        print >>debuglog, "verify headers:", sign_headers
 
     h = hasher()
-    for x in sign_headers:
-        h.update(x[0])
-        h.update(":")
-        h.update(x[1])
+    hash_headers(
+        h, canonicalize_headers, headers, include_headers, sigheaders, sig)
     d = h.digest()
     if debuglog is not None:
         print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d)
-
-    dinfo = asn1_build(
-        (SEQUENCE, [
-            (SEQUENCE, [
-                (OBJECT_IDENTIFIER, hashid),
-                (NULL, None),
-            ]),
-            (OCTET_STRING, d),
-        ])
-    )
-    if debuglog is not None:
-        print >>debuglog, "dinfo:", " ".join("%02x" % ord(x) for x in dinfo)
-    if len(dinfo)+3 > modlen:
+    try:
+        sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid)
+    except ParameterError:
         if debuglog is not None:
             print >>debuglog, "Hash too large for modulus"
         return False
-    sig2 = "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo
     if debuglog is not None:
         print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2)
         print >>debuglog, sig['b']