diff --git a/dkim/__init__.py b/dkim/__init__.py index 8d4a366..c24144e 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -9,6 +9,8 @@ from dkim.crypto import ( EMSA_PKCS1_v1_5_encode, parse_private_key, parse_public_key, + RSASSA_PKCS1_v1_5_sign, + RSASSA_PKCS1_v1_5_verify, ) __all__ = [ @@ -152,38 +154,10 @@ def validate_signature_fields(sig, debuglog=None): return False return True -def perform_rsa(input, exponent, modulus, modlen): - return int2str(pow(str2int(input), exponent, modulus), modlen) - # These values come from RFC 3447, section 9.2 Notes, page 43. HASHID_SHA1 = "\x2b\x0e\x03\x02\x1a" HASHID_SHA256 = "\x60\x86\x48\x01\x65\x03\x04\x02\x01" -def str2int(s): - """Convert an octet string to an integer. Octet string assumed to represent a positive integer.""" - r = 0 - for c in s: - r = (r << 8) | ord(c) - return r - -def int2str(n, length = -1): - """Convert an integer to an octet string. Number must be positive. - - @param n: Number to convert. - @param length: Minimum length, or -1 to return the smallest number of bytes that represent the integer. - - """ - - assert n >= 0 - r = [] - while length < 0 or len(r) < length: - r.append(chr(n & 0xff)) - n >>= 8 - if length < 0 and n == 0: break - r.reverse() - assert length < 0 or len(r) == length - return r - def rfc822_parse(message): """Parse a message in RFC822 format. @@ -316,9 +290,8 @@ def sign(message, selector, domain, privkey, identity=None, canonicalize=(Simple if debuglog is not None: print >>debuglog, "sign digest:", " ".join("%02x" % ord(x) for x in d) - modlen = len(int2str(pk['modulus'])) - encoded = EMSA_PKCS1_v1_5_encode(d, modlen, HASHID_SHA256) - sig2 = perform_rsa(encoded, pk['privateExponent'], pk['modulus'], modlen) + sig2 = RSASSA_PKCS1_v1_5_sign( + d, HASHID_SHA256, pk['privateExponent'], pk['modulus']) sig += base64.b64encode(''.join(sig2)) return sig + "\r\n" @@ -430,9 +403,6 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): print >>debuglog, "invalid format in _domainkey txt record" return False pk = parse_public_key(base64.b64decode(pub['p'])) - modlen = len(int2str(pk['modulus'])) - if debuglog is not None: - print >>debuglog, "modlen:", modlen include_headers = re.split(r"\s*:\s*", sig['h']) h = hasher() @@ -441,18 +411,9 @@ def verify(message, debuglog=None, dnsfunc=dnstxt): d = h.digest() if debuglog is not None: print >>debuglog, "verify digest:", " ".join("%02x" % ord(x) for x in d) + signature = base64.b64decode(re.sub(r"\s+", "", sig['b'])) try: - sig2 = EMSA_PKCS1_v1_5_encode(d, modlen, hashid) + return RSASSA_PKCS1_v1_5_verify( + d, hashid, signature, pk['publicExponent'], pk['modulus']) except ParameterError: return False - if debuglog is not None: - print >>debuglog, "sig2:", " ".join("%02x" % ord(x) for x in sig2) - print >>debuglog, sig['b'] - print >>debuglog, re.sub(r"\s+", "", sig['b']) - signature = base64.b64decode(re.sub(r"\s+", "", sig['b'])) - v = perform_rsa(signature, pk['publicExponent'], pk['modulus'], modlen) - if debuglog is not None: - print >>debuglog, "v:", " ".join("%02x" % ord(x) for x in v) - assert len(v) == len(sig2) - # Byte-by-byte compare of signatures - return not [1 for x in zip(v, sig2) if x[0] != x[1]] diff --git a/dkim/crypto.py b/dkim/crypto.py index 217318d..a43bef0 100644 --- a/dkim/crypto.py +++ b/dkim/crypto.py @@ -21,6 +21,8 @@ __all__ = [ 'EMSA_PKCS1_v1_5_encode', 'parse_private_key', 'parse_public_key', + 'RSASSA_PKCS1_v1_5_sign', + 'RSASSA_PKCS1_v1_5_verify', ] from dkim.asn1 import ( @@ -108,3 +110,51 @@ def EMSA_PKCS1_v1_5_encode(digest, modlen, hashid): raise Exception("Hash too large for modulus") # XXX: DKIMException return "\x00\x01"+"\xff"*(modlen-len(dinfo)-3)+"\x00"+dinfo + +def str2int(s): + """Convert an octet string to an integer. + + Octet string assumed to represent a positive integer. + """ + r = 0 + for c in s: + r = (r << 8) | ord(c) + return r + + +def int2str(n, length = -1): + """Convert an integer to an octet string. Number must be positive. + + @param n: Number to convert. + @param length: Minimum length, or -1 to return the smallest number of + bytes that represent the integer. + """ + + assert n >= 0 + r = [] + while length < 0 or len(r) < length: + r.append(chr(n & 0xff)) + n >>= 8 + if length < 0 and n == 0: + break + r.reverse() + assert length < 0 or len(r) == length + return ''.join(r) + + +def perform_rsa(input, exponent, modulus, modlen): + return int2str(pow(str2int(input), exponent, modulus), modlen) + + +def RSASSA_PKCS1_v1_5_sign(digest, hashid, private_exponent, modulus): + modlen = len(int2str(modulus)) + encoded_digest = EMSA_PKCS1_v1_5_encode(digest, modlen, hashid) + return perform_rsa(encoded_digest, private_exponent, modulus, modlen) + + +def RSASSA_PKCS1_v1_5_verify(digest, hashid, signature, public_exponent, + modulus): + modlen = len(int2str(modulus)) + encoded_digest = EMSA_PKCS1_v1_5_encode(digest, modlen, hashid) + signed_digest = perform_rsa(signature, public_exponent, modulus, modlen) + return encoded_digest == signed_digest