From 8cf859db4f9b8c7578803d73d86af6c125471f63 Mon Sep 17 00:00:00 2001 From: William Grant Date: Sat, 19 Mar 2011 20:09:01 +1100 Subject: [PATCH] bytesify __init__. Tests now parse if dns.resolver is removed. --- dkim/__init__.py | 160 ++++++++++++++++++++++++----------------------- 1 file changed, 82 insertions(+), 78 deletions(-) diff --git a/dkim/__init__.py b/dkim/__init__.py index 6130bb7..9bc127a 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -56,7 +56,7 @@ __all__ = [ class Simple: """Class that represents the "simple" canonicalization algorithm.""" - name = "simple" + name = b"simple" @staticmethod def canonicalize_headers(headers): @@ -66,12 +66,12 @@ class Simple: @staticmethod def canonicalize_body(body): # Ignore all empty lines at the end of the message body. - return re.sub("(\r\n)*$", "\r\n", body) + return re.sub(b"(\r\n)*$", b"\r\n", body) class Relaxed: """Class that represents the "relaxed" canonicalization algorithm.""" - name = "relaxed" + name = b"relaxed" @staticmethod def canonicalize_headers(headers): @@ -79,14 +79,14 @@ class Relaxed: # Unfold all header lines. # Compress WSP to single space. # Remove all WSP at the start or end of the field value (strip). - return [(x[0].lower(), re.sub(r"\s+", " ", re.sub("\r\n", "", x[1])).strip()+"\r\n") for x in headers] + return [(x[0].lower(), re.sub(br"\s+", b" ", re.sub(b"\r\n", b"", x[1])).strip()+b"\r\n") for x in headers] @staticmethod def canonicalize_body(body): # Remove all trailing WSP at end of lines. # Compress non-line-ending WSP to single space. # Ignore all empty lines at the end of the message body. - return re.sub("(\r\n)*$", "\r\n", re.sub(r"[\x09\x20]+", " ", re.sub("[\\x09\\x20]+\r\n", "\r\n", body))) + return re.sub(b"(\r\n)*$", b"\r\n", re.sub(br"[\x09\x20]+", b" ", re.sub(b"[\\x09\\x20]+\r\n", b"\r\n", body))) class DKIMException(Exception): """Base class for DKIM errors.""" @@ -133,11 +133,11 @@ def hash_headers(hasher, canonicalize_headers, headers, include_headers, # The call to _remove() assumes that the signature b= only appears # once in the signature header cheaders = canonicalize_headers.canonicalize_headers( - [(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))]) + [(sigheaders[0][0], _remove(sigheaders[0][1], sig[b'b']))]) sign_headers += [(x[0], x[1].rstrip()) for x in cheaders] for x in sign_headers: hasher.update(x[0]) - hasher.update(":") + hasher.update(b":") hasher.update(x[1]) @@ -149,40 +149,43 @@ def validate_signature_fields(sig): @param sig: A dict mapping field keys to values. """ - mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's') + mandatory_fields = (b'v', b'a', b'b', b'bh', b'd', b'h', b's') for field in mandatory_fields: if field not in sig: raise ValidationError("signature missing %s=" % field) - if sig['v'] != "1": - raise ValidationError("v= value is not 1 (%s)" % sig['v']) - if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None: - raise ValidationError("b= value is not valid base64 (%s)" % sig['b']) - if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None: + if sig[b'v'] != b"1": + raise ValidationError("v= value is not 1 (%s)" % sig[b'v']) + if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None: + raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b']) + if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'bh']) is None: raise ValidationError( - "bh= value is not valid base64 (%s)" % sig['bh']) - if 'i' in sig and ( - not sig['i'].endswith(sig['d']) or - sig['i'][-len(sig['d'])-1] not in "@."): + "bh= value is not valid base64 (%s)" % sig[b'bh']) + # Nasty hack to support both str and bytes... check for both the + # character and integer values. + if b'i' in sig and ( + not sig[b'i'].endswith(sig[b'd']) or + sig[b'i'][-len(sig[b'd'])-1] not in ('@', '.', 64, 46)): raise ValidationError( "i= domain is not a subdomain of d= (i=%s d=%d)" % - (sig['i'], sig['d'])) - if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None: + (sig[b'i'], sig[b'd'])) + if b'l' in sig and re.match(br"\d{,76}$", sig['l']) is None: raise ValidationError( - "l= value is not a decimal integer (%s)" % sig['l']) - if 'q' in sig and sig['q'] != "dns/txt": - raise ValidationError("q= value is not dns/txt (%s)" % sig['q']) - if 't' in sig and re.match(r"\d+$", sig['t']) is None: + "l= value is not a decimal integer (%s)" % sig[b'l']) + if b'q' in sig and sig[b'q'] != b"dns/txt": + raise ValidationError("q= value is not dns/txt (%s)" % sig[b'q']) + if b't' in sig and re.match(br"\d+$", sig[b't']) is None: raise ValidationError( - "t= value is not a decimal integer (%s)" % sig['t']) - if 'x' in sig: - if re.match(r"\d+$", sig['x']) is None: + "t= value is not a decimal integer (%s)" % sig[b't']) + if b'x' in sig: + if re.match(br"\d+$", sig[b'x']) is None: raise ValidationError( - "x= value is not a decimal integer (%s)" % sig['x']) - if int(sig['x']) < int(sig['t']): + "x= value is not a decimal integer (%s)" % sig[b'x']) + if int(sig[b'x']) < int(sig[b't']): raise ValidationError( "x= value is less than t= value (x=%s t=%s)" % - (sig['x'], sig['t'])) + (sig[b'x'], sig[b't'])) + def rfc822_parse(message): """Parse a message in RFC822 format. @@ -193,52 +196,53 @@ def rfc822_parse(message): The body is a CRLF-separated string. """ - headers = [] - lines = re.split("\r?\n", message) + lines = re.split(b"\r?\n", message) i = 0 while i < len(lines): if len(lines[i]) == 0: # End of headers, return what we have plus the body, excluding the blank line. i += 1 break - if re.match(r"[\x09\x20]", lines[i][0]): - headers[-1][1] += lines[i]+"\r\n" + if lines[i][0] in ("\x09", "\x20", 0x09, 0x20): + headers[-1][1] += lines[i]+b"\r\n" else: - m = re.match(r"([\x21-\x7e]+?):", lines[i]) + m = re.match(br"([\x21-\x7e]+?):", lines[i]) if m is not None: - headers.append([m.group(1), lines[i][m.end(0):]+"\r\n"]) - elif lines[i].startswith("From "): + headers.append([m.group(1), lines[i][m.end(0):]+b"\r\n"]) + elif lines[i].startswith(b"From "): pass else: raise MessageFormatError("Unexpected characters in RFC822 header: %s" % lines[i]) i += 1 - return (headers, "\r\n".join(lines[i:])) + return (headers, b"\r\n".join(lines[i:])) + def dnstxt(name): """Return a TXT record associated with a DNS name.""" a = dns.resolver.query(name, dns.rdatatype.TXT) for r in a.response.answer: if r.rdtype == dns.rdatatype.TXT: - return "".join(r.items[0].strings) + return b"".join(r.items[0].strings) return None + def fold(header): """Fold a header line into multiple crlf-separated lines at column 72.""" - i = header.rfind("\r\n ") + i = header.rfind(b"\r\n ") if i == -1: - pre = "" + pre = b"" else: i += 3 pre = header[:i] header = header[i:] while len(header) > 72: - i = header[:72].rfind(" ") + i = header[:72].rfind(b" ") if i == -1: j = i else: j = i + 1 - pre += header[:i] + "\r\n " + pre += header[:i] + b"\r\n " header = header[j:] return pre + header @@ -286,26 +290,26 @@ def sign(message, selector, domain, privkey, identity=None, bodyhash = base64.b64encode(h.digest()) sigfields = [x for x in [ - ('v', "1"), - ('a', "rsa-sha256"), - ('c', "%s/%s" % (canonicalize[0].name, canonicalize[1].name)), - ('d', domain), - ('i', identity or "@"+domain), - length and ('l', len(body)), - ('q', "dns/txt"), - ('s', selector), - ('t', str(int(time.time()))), - ('h', " : ".join(x[0] for x in sign_headers)), - ('bh', bodyhash), - ('b', ""), + (b'v', b"1"), + (b'a', b"rsa-sha256"), + (b'c', b"/".join((canonicalize[0].name, canonicalize[1].name))), + (b'd', domain), + (b'i', identity or b"@"+domain), + length and (b'l', len(body)), + (b'q', b"dns/txt"), + (b's', selector), + (b't', str(int(time.time())).encode('ascii')), + (b'h', b" : ".join(x[0] for x in sign_headers)), + (b'bh', bodyhash), + (b'b', b""), ] if x] - sig_value = fold("; ".join("%s=%s" % x for x in sigfields)) + sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields)) dkim_header = canonicalize[0].canonicalize_headers([ - ['DKIM-Signature', ' ' + sig_value]])[0] + [b'DKIM-Signature', b' ' + sig_value]])[0] # the dkim sig is hashed with no trailing crlf, even if the # canonicalization algorithm would add one. - if dkim_header[1][-2:] == '\r\n': + if dkim_header[1][-2:] == b'\r\n': dkim_header = (dkim_header[0], dkim_header[1][:-2]) sign_headers.append(dkim_header) @@ -313,7 +317,7 @@ def sign(message, selector, domain, privkey, identity=None, h = hashlib.sha256() for x in sign_headers: h.update(x[0]) - h.update(":") + h.update(b":") h.update(x[1]) try: @@ -323,7 +327,7 @@ def sign(message, selector, domain, privkey, identity=None, raise ParameterError("digest too large for modulus") sig_value += base64.b64encode(sig2) - return 'DKIM-Signature: ' + sig_value + "\r\n" + return b'DKIM-Signature: ' + sig_value + b"\r\n" def verify(message, logger=None, dnsfunc=dnstxt): @@ -338,7 +342,7 @@ def verify(message, logger=None, dnsfunc=dnstxt): (headers, body) = rfc822_parse(message) - sigheaders = [x for x in headers if x[0].lower() == "dkim-signature"] + sigheaders = [x for x in headers if x[0].lower() == b"dkim-signature"] if len(sigheaders) < 1: return False @@ -355,20 +359,20 @@ def verify(message, logger=None, dnsfunc=dnstxt): logger.error("signature fields failed to validate: %s" % e) return False - m = re.match("(\w+)(?:/(\w+))?$", sig['c']) + m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c']) if m is None: logger.error( - "c= value is not in format method/method (%s)" % sig['c']) + "c= value is not in format method/method (%s)" % sig[b'c']) return False can_headers = m.group(1) if m.group(2) is not None: can_body = m.group(2) else: - can_body = "simple" + can_body = b"simple" - if can_headers == "simple": + if can_headers == b"simple": canonicalize_headers = Simple - elif can_headers == "relaxed": + elif can_headers == b"relaxed": canonicalize_headers = Relaxed else: logger.error("unknown header canonicalization (%s)" % can_headers) @@ -376,36 +380,36 @@ def verify(message, logger=None, dnsfunc=dnstxt): headers = canonicalize_headers.canonicalize_headers(headers) - if can_body == "simple": + if can_body == b"simple": body = Simple.canonicalize_body(body) - elif can_body == "relaxed": + elif can_body == b"relaxed": body = Relaxed.canonicalize_body(body) else: logger.error("unknown body canonicalization (%s)" % can_body) return False - if sig['a'] == "rsa-sha1": + if sig[b'a'] == b"rsa-sha1": hasher = hashlib.sha1 - elif sig['a'] == "rsa-sha256": + elif sig[b'a'] == b"rsa-sha256": hasher = hashlib.sha256 else: - logger.error("unknown signature algorithm (%s)" % sig['a']) + logger.error("unknown signature algorithm (%s)" % sig[b'a']) return False - if 'l' in sig: - body = body[:int(sig['l'])] + if b'l' in sig: + body = body[:int(sig[b'l'])] h = hasher() h.update(body) bodyhash = h.digest() logger.debug("bh: %s" % base64.b64encode(bodyhash)) - if bodyhash != base64.b64decode(re.sub(r"\s+", "", sig['bh'])): + if bodyhash != base64.b64decode(re.sub(br"\s+", "", sig[b'bh'])): logger.error( "body hash mismatch (got %s, expected %s)" % - (base64.b64encode(bodyhash), sig['bh'])) + (base64.b64encode(bodyhash), sig[b'bh'])) return False - s = dnsfunc(sig['s']+"._domainkey."+sig['d']+".") + s = dnsfunc(sig[b's']+b"._domainkey."+sig[b'd']+b".") if not s: return False try: @@ -413,16 +417,16 @@ def verify(message, logger=None, dnsfunc=dnstxt): except InvalidTagValueList: return False try: - pk = parse_public_key(base64.b64decode(pub['p'])) + pk = parse_public_key(base64.b64decode(pub[b'p'])) except UnparsableKeyError as e: logger.error("could not parse public key: %s" % e) return False - include_headers = re.split(r"\s*:\s*", sig['h']) + include_headers = re.split(br"\s*:\s*", sig[b'h']) h = hasher() hash_headers( h, canonicalize_headers, headers, include_headers, sigheaders, sig) - signature = base64.b64decode(re.sub(r"\s+", "", sig['b'])) + signature = base64.b64decode(re.sub(br"\s+", "", sig[b'b'])) try: return RSASSA_PKCS1_v1_5_verify( h, signature, pk['publicExponent'], pk['modulus'])