bytesify __init__. Tests now parse if dns.resolver is removed.

This commit is contained in:
William Grant
2011-03-19 20:09:01 +11:00
parent e176c8fd4b
commit 8cf859db4f
+82 -78
View File
@@ -56,7 +56,7 @@ __all__ = [
class Simple: class Simple:
"""Class that represents the "simple" canonicalization algorithm.""" """Class that represents the "simple" canonicalization algorithm."""
name = "simple" name = b"simple"
@staticmethod @staticmethod
def canonicalize_headers(headers): def canonicalize_headers(headers):
@@ -66,12 +66,12 @@ class Simple:
@staticmethod @staticmethod
def canonicalize_body(body): def canonicalize_body(body):
# Ignore all empty lines at the end of the message body. # Ignore all empty lines at the end of the message body.
return re.sub("(\r\n)*$", "\r\n", body) return re.sub(b"(\r\n)*$", b"\r\n", body)
class Relaxed: class Relaxed:
"""Class that represents the "relaxed" canonicalization algorithm.""" """Class that represents the "relaxed" canonicalization algorithm."""
name = "relaxed" name = b"relaxed"
@staticmethod @staticmethod
def canonicalize_headers(headers): def canonicalize_headers(headers):
@@ -79,14 +79,14 @@ class Relaxed:
# Unfold all header lines. # Unfold all header lines.
# Compress WSP to single space. # Compress WSP to single space.
# Remove all WSP at the start or end of the field value (strip). # Remove all WSP at the start or end of the field value (strip).
return [(x[0].lower(), re.sub(r"\s+", " ", re.sub("\r\n", "", x[1])).strip()+"\r\n") for x in headers] return [(x[0].lower(), re.sub(br"\s+", b" ", re.sub(b"\r\n", b"", x[1])).strip()+b"\r\n") for x in headers]
@staticmethod @staticmethod
def canonicalize_body(body): def canonicalize_body(body):
# Remove all trailing WSP at end of lines. # Remove all trailing WSP at end of lines.
# Compress non-line-ending WSP to single space. # Compress non-line-ending WSP to single space.
# Ignore all empty lines at the end of the message body. # Ignore all empty lines at the end of the message body.
return re.sub("(\r\n)*$", "\r\n", re.sub(r"[\x09\x20]+", " ", re.sub("[\\x09\\x20]+\r\n", "\r\n", body))) return re.sub(b"(\r\n)*$", b"\r\n", re.sub(br"[\x09\x20]+", b" ", re.sub(b"[\\x09\\x20]+\r\n", b"\r\n", body)))
class DKIMException(Exception): class DKIMException(Exception):
"""Base class for DKIM errors.""" """Base class for DKIM errors."""
@@ -133,11 +133,11 @@ def hash_headers(hasher, canonicalize_headers, headers, include_headers,
# The call to _remove() assumes that the signature b= only appears # The call to _remove() assumes that the signature b= only appears
# once in the signature header # once in the signature header
cheaders = canonicalize_headers.canonicalize_headers( cheaders = canonicalize_headers.canonicalize_headers(
[(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))]) [(sigheaders[0][0], _remove(sigheaders[0][1], sig[b'b']))])
sign_headers += [(x[0], x[1].rstrip()) for x in cheaders] sign_headers += [(x[0], x[1].rstrip()) for x in cheaders]
for x in sign_headers: for x in sign_headers:
hasher.update(x[0]) hasher.update(x[0])
hasher.update(":") hasher.update(b":")
hasher.update(x[1]) hasher.update(x[1])
@@ -149,40 +149,43 @@ def validate_signature_fields(sig):
@param sig: A dict mapping field keys to values. @param sig: A dict mapping field keys to values.
""" """
mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's') mandatory_fields = (b'v', b'a', b'b', b'bh', b'd', b'h', b's')
for field in mandatory_fields: for field in mandatory_fields:
if field not in sig: if field not in sig:
raise ValidationError("signature missing %s=" % field) raise ValidationError("signature missing %s=" % field)
if sig['v'] != "1": if sig[b'v'] != b"1":
raise ValidationError("v= value is not 1 (%s)" % sig['v']) raise ValidationError("v= value is not 1 (%s)" % sig[b'v'])
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None: if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None:
raise ValidationError("b= value is not valid base64 (%s)" % sig['b']) raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b'])
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None: if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'bh']) is None:
raise ValidationError( raise ValidationError(
"bh= value is not valid base64 (%s)" % sig['bh']) "bh= value is not valid base64 (%s)" % sig[b'bh'])
if 'i' in sig and ( # Nasty hack to support both str and bytes... check for both the
not sig['i'].endswith(sig['d']) or # character and integer values.
sig['i'][-len(sig['d'])-1] not in "@."): if b'i' in sig and (
not sig[b'i'].endswith(sig[b'd']) or
sig[b'i'][-len(sig[b'd'])-1] not in ('@', '.', 64, 46)):
raise ValidationError( raise ValidationError(
"i= domain is not a subdomain of d= (i=%s d=%d)" % "i= domain is not a subdomain of d= (i=%s d=%d)" %
(sig['i'], sig['d'])) (sig[b'i'], sig[b'd']))
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None: if b'l' in sig and re.match(br"\d{,76}$", sig['l']) is None:
raise ValidationError( raise ValidationError(
"l= value is not a decimal integer (%s)" % sig['l']) "l= value is not a decimal integer (%s)" % sig[b'l'])
if 'q' in sig and sig['q'] != "dns/txt": if b'q' in sig and sig[b'q'] != b"dns/txt":
raise ValidationError("q= value is not dns/txt (%s)" % sig['q']) raise ValidationError("q= value is not dns/txt (%s)" % sig[b'q'])
if 't' in sig and re.match(r"\d+$", sig['t']) is None: if b't' in sig and re.match(br"\d+$", sig[b't']) is None:
raise ValidationError( raise ValidationError(
"t= value is not a decimal integer (%s)" % sig['t']) "t= value is not a decimal integer (%s)" % sig[b't'])
if 'x' in sig: if b'x' in sig:
if re.match(r"\d+$", sig['x']) is None: if re.match(br"\d+$", sig[b'x']) is None:
raise ValidationError( raise ValidationError(
"x= value is not a decimal integer (%s)" % sig['x']) "x= value is not a decimal integer (%s)" % sig[b'x'])
if int(sig['x']) < int(sig['t']): if int(sig[b'x']) < int(sig[b't']):
raise ValidationError( raise ValidationError(
"x= value is less than t= value (x=%s t=%s)" % "x= value is less than t= value (x=%s t=%s)" %
(sig['x'], sig['t'])) (sig[b'x'], sig[b't']))
def rfc822_parse(message): def rfc822_parse(message):
"""Parse a message in RFC822 format. """Parse a message in RFC822 format.
@@ -193,52 +196,53 @@ def rfc822_parse(message):
The body is a CRLF-separated string. The body is a CRLF-separated string.
""" """
headers = [] headers = []
lines = re.split("\r?\n", message) lines = re.split(b"\r?\n", message)
i = 0 i = 0
while i < len(lines): while i < len(lines):
if len(lines[i]) == 0: if len(lines[i]) == 0:
# End of headers, return what we have plus the body, excluding the blank line. # End of headers, return what we have plus the body, excluding the blank line.
i += 1 i += 1
break break
if re.match(r"[\x09\x20]", lines[i][0]): if lines[i][0] in ("\x09", "\x20", 0x09, 0x20):
headers[-1][1] += lines[i]+"\r\n" headers[-1][1] += lines[i]+b"\r\n"
else: else:
m = re.match(r"([\x21-\x7e]+?):", lines[i]) m = re.match(br"([\x21-\x7e]+?):", lines[i])
if m is not None: if m is not None:
headers.append([m.group(1), lines[i][m.end(0):]+"\r\n"]) headers.append([m.group(1), lines[i][m.end(0):]+b"\r\n"])
elif lines[i].startswith("From "): elif lines[i].startswith(b"From "):
pass pass
else: else:
raise MessageFormatError("Unexpected characters in RFC822 header: %s" % lines[i]) raise MessageFormatError("Unexpected characters in RFC822 header: %s" % lines[i])
i += 1 i += 1
return (headers, "\r\n".join(lines[i:])) return (headers, b"\r\n".join(lines[i:]))
def dnstxt(name): def dnstxt(name):
"""Return a TXT record associated with a DNS name.""" """Return a TXT record associated with a DNS name."""
a = dns.resolver.query(name, dns.rdatatype.TXT) a = dns.resolver.query(name, dns.rdatatype.TXT)
for r in a.response.answer: for r in a.response.answer:
if r.rdtype == dns.rdatatype.TXT: if r.rdtype == dns.rdatatype.TXT:
return "".join(r.items[0].strings) return b"".join(r.items[0].strings)
return None return None
def fold(header): def fold(header):
"""Fold a header line into multiple crlf-separated lines at column 72.""" """Fold a header line into multiple crlf-separated lines at column 72."""
i = header.rfind("\r\n ") i = header.rfind(b"\r\n ")
if i == -1: if i == -1:
pre = "" pre = b""
else: else:
i += 3 i += 3
pre = header[:i] pre = header[:i]
header = header[i:] header = header[i:]
while len(header) > 72: while len(header) > 72:
i = header[:72].rfind(" ") i = header[:72].rfind(b" ")
if i == -1: if i == -1:
j = i j = i
else: else:
j = i + 1 j = i + 1
pre += header[:i] + "\r\n " pre += header[:i] + b"\r\n "
header = header[j:] header = header[j:]
return pre + header return pre + header
@@ -286,26 +290,26 @@ def sign(message, selector, domain, privkey, identity=None,
bodyhash = base64.b64encode(h.digest()) bodyhash = base64.b64encode(h.digest())
sigfields = [x for x in [ sigfields = [x for x in [
('v', "1"), (b'v', b"1"),
('a', "rsa-sha256"), (b'a', b"rsa-sha256"),
('c', "%s/%s" % (canonicalize[0].name, canonicalize[1].name)), (b'c', b"/".join((canonicalize[0].name, canonicalize[1].name))),
('d', domain), (b'd', domain),
('i', identity or "@"+domain), (b'i', identity or b"@"+domain),
length and ('l', len(body)), length and (b'l', len(body)),
('q', "dns/txt"), (b'q', b"dns/txt"),
('s', selector), (b's', selector),
('t', str(int(time.time()))), (b't', str(int(time.time())).encode('ascii')),
('h', " : ".join(x[0] for x in sign_headers)), (b'h', b" : ".join(x[0] for x in sign_headers)),
('bh', bodyhash), (b'bh', bodyhash),
('b', ""), (b'b', b""),
] if x] ] if x]
sig_value = fold("; ".join("%s=%s" % x for x in sigfields)) sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields))
dkim_header = canonicalize[0].canonicalize_headers([ dkim_header = canonicalize[0].canonicalize_headers([
['DKIM-Signature', ' ' + sig_value]])[0] [b'DKIM-Signature', b' ' + sig_value]])[0]
# the dkim sig is hashed with no trailing crlf, even if the # the dkim sig is hashed with no trailing crlf, even if the
# canonicalization algorithm would add one. # canonicalization algorithm would add one.
if dkim_header[1][-2:] == '\r\n': if dkim_header[1][-2:] == b'\r\n':
dkim_header = (dkim_header[0], dkim_header[1][:-2]) dkim_header = (dkim_header[0], dkim_header[1][:-2])
sign_headers.append(dkim_header) sign_headers.append(dkim_header)
@@ -313,7 +317,7 @@ def sign(message, selector, domain, privkey, identity=None,
h = hashlib.sha256() h = hashlib.sha256()
for x in sign_headers: for x in sign_headers:
h.update(x[0]) h.update(x[0])
h.update(":") h.update(b":")
h.update(x[1]) h.update(x[1])
try: try:
@@ -323,7 +327,7 @@ def sign(message, selector, domain, privkey, identity=None,
raise ParameterError("digest too large for modulus") raise ParameterError("digest too large for modulus")
sig_value += base64.b64encode(sig2) sig_value += base64.b64encode(sig2)
return 'DKIM-Signature: ' + sig_value + "\r\n" return b'DKIM-Signature: ' + sig_value + b"\r\n"
def verify(message, logger=None, dnsfunc=dnstxt): def verify(message, logger=None, dnsfunc=dnstxt):
@@ -338,7 +342,7 @@ def verify(message, logger=None, dnsfunc=dnstxt):
(headers, body) = rfc822_parse(message) (headers, body) = rfc822_parse(message)
sigheaders = [x for x in headers if x[0].lower() == "dkim-signature"] sigheaders = [x for x in headers if x[0].lower() == b"dkim-signature"]
if len(sigheaders) < 1: if len(sigheaders) < 1:
return False return False
@@ -355,20 +359,20 @@ def verify(message, logger=None, dnsfunc=dnstxt):
logger.error("signature fields failed to validate: %s" % e) logger.error("signature fields failed to validate: %s" % e)
return False return False
m = re.match("(\w+)(?:/(\w+))?$", sig['c']) m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c'])
if m is None: if m is None:
logger.error( logger.error(
"c= value is not in format method/method (%s)" % sig['c']) "c= value is not in format method/method (%s)" % sig[b'c'])
return False return False
can_headers = m.group(1) can_headers = m.group(1)
if m.group(2) is not None: if m.group(2) is not None:
can_body = m.group(2) can_body = m.group(2)
else: else:
can_body = "simple" can_body = b"simple"
if can_headers == "simple": if can_headers == b"simple":
canonicalize_headers = Simple canonicalize_headers = Simple
elif can_headers == "relaxed": elif can_headers == b"relaxed":
canonicalize_headers = Relaxed canonicalize_headers = Relaxed
else: else:
logger.error("unknown header canonicalization (%s)" % can_headers) logger.error("unknown header canonicalization (%s)" % can_headers)
@@ -376,36 +380,36 @@ def verify(message, logger=None, dnsfunc=dnstxt):
headers = canonicalize_headers.canonicalize_headers(headers) headers = canonicalize_headers.canonicalize_headers(headers)
if can_body == "simple": if can_body == b"simple":
body = Simple.canonicalize_body(body) body = Simple.canonicalize_body(body)
elif can_body == "relaxed": elif can_body == b"relaxed":
body = Relaxed.canonicalize_body(body) body = Relaxed.canonicalize_body(body)
else: else:
logger.error("unknown body canonicalization (%s)" % can_body) logger.error("unknown body canonicalization (%s)" % can_body)
return False return False
if sig['a'] == "rsa-sha1": if sig[b'a'] == b"rsa-sha1":
hasher = hashlib.sha1 hasher = hashlib.sha1
elif sig['a'] == "rsa-sha256": elif sig[b'a'] == b"rsa-sha256":
hasher = hashlib.sha256 hasher = hashlib.sha256
else: else:
logger.error("unknown signature algorithm (%s)" % sig['a']) logger.error("unknown signature algorithm (%s)" % sig[b'a'])
return False return False
if 'l' in sig: if b'l' in sig:
body = body[:int(sig['l'])] body = body[:int(sig[b'l'])]
h = hasher() h = hasher()
h.update(body) h.update(body)
bodyhash = h.digest() bodyhash = h.digest()
logger.debug("bh: %s" % base64.b64encode(bodyhash)) logger.debug("bh: %s" % base64.b64encode(bodyhash))
if bodyhash != base64.b64decode(re.sub(r"\s+", "", sig['bh'])): if bodyhash != base64.b64decode(re.sub(br"\s+", "", sig[b'bh'])):
logger.error( logger.error(
"body hash mismatch (got %s, expected %s)" % "body hash mismatch (got %s, expected %s)" %
(base64.b64encode(bodyhash), sig['bh'])) (base64.b64encode(bodyhash), sig[b'bh']))
return False return False
s = dnsfunc(sig['s']+"._domainkey."+sig['d']+".") s = dnsfunc(sig[b's']+b"._domainkey."+sig[b'd']+b".")
if not s: if not s:
return False return False
try: try:
@@ -413,16 +417,16 @@ def verify(message, logger=None, dnsfunc=dnstxt):
except InvalidTagValueList: except InvalidTagValueList:
return False return False
try: try:
pk = parse_public_key(base64.b64decode(pub['p'])) pk = parse_public_key(base64.b64decode(pub[b'p']))
except UnparsableKeyError as e: except UnparsableKeyError as e:
logger.error("could not parse public key: %s" % e) logger.error("could not parse public key: %s" % e)
return False return False
include_headers = re.split(r"\s*:\s*", sig['h']) include_headers = re.split(br"\s*:\s*", sig[b'h'])
h = hasher() h = hasher()
hash_headers( hash_headers(
h, canonicalize_headers, headers, include_headers, sigheaders, sig) h, canonicalize_headers, headers, include_headers, sigheaders, sig)
signature = base64.b64decode(re.sub(r"\s+", "", sig['b'])) signature = base64.b64decode(re.sub(br"\s+", "", sig[b'b']))
try: try:
return RSASSA_PKCS1_v1_5_verify( return RSASSA_PKCS1_v1_5_verify(
h, signature, pk['publicExponent'], pk['modulus']) h, signature, pk['publicExponent'], pk['modulus'])