bytesify __init__. Tests now parse if dns.resolver is removed.
This commit is contained in:
+82
-78
@@ -56,7 +56,7 @@ __all__ = [
|
||||
class Simple:
|
||||
"""Class that represents the "simple" canonicalization algorithm."""
|
||||
|
||||
name = "simple"
|
||||
name = b"simple"
|
||||
|
||||
@staticmethod
|
||||
def canonicalize_headers(headers):
|
||||
@@ -66,12 +66,12 @@ class Simple:
|
||||
@staticmethod
|
||||
def canonicalize_body(body):
|
||||
# Ignore all empty lines at the end of the message body.
|
||||
return re.sub("(\r\n)*$", "\r\n", body)
|
||||
return re.sub(b"(\r\n)*$", b"\r\n", body)
|
||||
|
||||
class Relaxed:
|
||||
"""Class that represents the "relaxed" canonicalization algorithm."""
|
||||
|
||||
name = "relaxed"
|
||||
name = b"relaxed"
|
||||
|
||||
@staticmethod
|
||||
def canonicalize_headers(headers):
|
||||
@@ -79,14 +79,14 @@ class Relaxed:
|
||||
# Unfold all header lines.
|
||||
# Compress WSP to single space.
|
||||
# Remove all WSP at the start or end of the field value (strip).
|
||||
return [(x[0].lower(), re.sub(r"\s+", " ", re.sub("\r\n", "", x[1])).strip()+"\r\n") for x in headers]
|
||||
return [(x[0].lower(), re.sub(br"\s+", b" ", re.sub(b"\r\n", b"", x[1])).strip()+b"\r\n") for x in headers]
|
||||
|
||||
@staticmethod
|
||||
def canonicalize_body(body):
|
||||
# Remove all trailing WSP at end of lines.
|
||||
# Compress non-line-ending WSP to single space.
|
||||
# Ignore all empty lines at the end of the message body.
|
||||
return re.sub("(\r\n)*$", "\r\n", re.sub(r"[\x09\x20]+", " ", re.sub("[\\x09\\x20]+\r\n", "\r\n", body)))
|
||||
return re.sub(b"(\r\n)*$", b"\r\n", re.sub(br"[\x09\x20]+", b" ", re.sub(b"[\\x09\\x20]+\r\n", b"\r\n", body)))
|
||||
|
||||
class DKIMException(Exception):
|
||||
"""Base class for DKIM errors."""
|
||||
@@ -133,11 +133,11 @@ def hash_headers(hasher, canonicalize_headers, headers, include_headers,
|
||||
# The call to _remove() assumes that the signature b= only appears
|
||||
# once in the signature header
|
||||
cheaders = canonicalize_headers.canonicalize_headers(
|
||||
[(sigheaders[0][0], _remove(sigheaders[0][1], sig['b']))])
|
||||
[(sigheaders[0][0], _remove(sigheaders[0][1], sig[b'b']))])
|
||||
sign_headers += [(x[0], x[1].rstrip()) for x in cheaders]
|
||||
for x in sign_headers:
|
||||
hasher.update(x[0])
|
||||
hasher.update(":")
|
||||
hasher.update(b":")
|
||||
hasher.update(x[1])
|
||||
|
||||
|
||||
@@ -149,40 +149,43 @@ def validate_signature_fields(sig):
|
||||
|
||||
@param sig: A dict mapping field keys to values.
|
||||
"""
|
||||
mandatory_fields = ('v', 'a', 'b', 'bh', 'd', 'h', 's')
|
||||
mandatory_fields = (b'v', b'a', b'b', b'bh', b'd', b'h', b's')
|
||||
for field in mandatory_fields:
|
||||
if field not in sig:
|
||||
raise ValidationError("signature missing %s=" % field)
|
||||
|
||||
if sig['v'] != "1":
|
||||
raise ValidationError("v= value is not 1 (%s)" % sig['v'])
|
||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['b']) is None:
|
||||
raise ValidationError("b= value is not valid base64 (%s)" % sig['b'])
|
||||
if re.match(r"[\s0-9A-Za-z+/]+=*$", sig['bh']) is None:
|
||||
if sig[b'v'] != b"1":
|
||||
raise ValidationError("v= value is not 1 (%s)" % sig[b'v'])
|
||||
if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'b']) is None:
|
||||
raise ValidationError("b= value is not valid base64 (%s)" % sig[b'b'])
|
||||
if re.match(br"[\s0-9A-Za-z+/]+=*$", sig[b'bh']) is None:
|
||||
raise ValidationError(
|
||||
"bh= value is not valid base64 (%s)" % sig['bh'])
|
||||
if 'i' in sig and (
|
||||
not sig['i'].endswith(sig['d']) or
|
||||
sig['i'][-len(sig['d'])-1] not in "@."):
|
||||
"bh= value is not valid base64 (%s)" % sig[b'bh'])
|
||||
# Nasty hack to support both str and bytes... check for both the
|
||||
# character and integer values.
|
||||
if b'i' in sig and (
|
||||
not sig[b'i'].endswith(sig[b'd']) or
|
||||
sig[b'i'][-len(sig[b'd'])-1] not in ('@', '.', 64, 46)):
|
||||
raise ValidationError(
|
||||
"i= domain is not a subdomain of d= (i=%s d=%d)" %
|
||||
(sig['i'], sig['d']))
|
||||
if 'l' in sig and re.match(r"\d{,76}$", sig['l']) is None:
|
||||
(sig[b'i'], sig[b'd']))
|
||||
if b'l' in sig and re.match(br"\d{,76}$", sig['l']) is None:
|
||||
raise ValidationError(
|
||||
"l= value is not a decimal integer (%s)" % sig['l'])
|
||||
if 'q' in sig and sig['q'] != "dns/txt":
|
||||
raise ValidationError("q= value is not dns/txt (%s)" % sig['q'])
|
||||
if 't' in sig and re.match(r"\d+$", sig['t']) is None:
|
||||
"l= value is not a decimal integer (%s)" % sig[b'l'])
|
||||
if b'q' in sig and sig[b'q'] != b"dns/txt":
|
||||
raise ValidationError("q= value is not dns/txt (%s)" % sig[b'q'])
|
||||
if b't' in sig and re.match(br"\d+$", sig[b't']) is None:
|
||||
raise ValidationError(
|
||||
"t= value is not a decimal integer (%s)" % sig['t'])
|
||||
if 'x' in sig:
|
||||
if re.match(r"\d+$", sig['x']) is None:
|
||||
"t= value is not a decimal integer (%s)" % sig[b't'])
|
||||
if b'x' in sig:
|
||||
if re.match(br"\d+$", sig[b'x']) is None:
|
||||
raise ValidationError(
|
||||
"x= value is not a decimal integer (%s)" % sig['x'])
|
||||
if int(sig['x']) < int(sig['t']):
|
||||
"x= value is not a decimal integer (%s)" % sig[b'x'])
|
||||
if int(sig[b'x']) < int(sig[b't']):
|
||||
raise ValidationError(
|
||||
"x= value is less than t= value (x=%s t=%s)" %
|
||||
(sig['x'], sig['t']))
|
||||
(sig[b'x'], sig[b't']))
|
||||
|
||||
|
||||
def rfc822_parse(message):
|
||||
"""Parse a message in RFC822 format.
|
||||
@@ -193,52 +196,53 @@ def rfc822_parse(message):
|
||||
The body is a CRLF-separated string.
|
||||
|
||||
"""
|
||||
|
||||
headers = []
|
||||
lines = re.split("\r?\n", message)
|
||||
lines = re.split(b"\r?\n", message)
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
if len(lines[i]) == 0:
|
||||
# End of headers, return what we have plus the body, excluding the blank line.
|
||||
i += 1
|
||||
break
|
||||
if re.match(r"[\x09\x20]", lines[i][0]):
|
||||
headers[-1][1] += lines[i]+"\r\n"
|
||||
if lines[i][0] in ("\x09", "\x20", 0x09, 0x20):
|
||||
headers[-1][1] += lines[i]+b"\r\n"
|
||||
else:
|
||||
m = re.match(r"([\x21-\x7e]+?):", lines[i])
|
||||
m = re.match(br"([\x21-\x7e]+?):", lines[i])
|
||||
if m is not None:
|
||||
headers.append([m.group(1), lines[i][m.end(0):]+"\r\n"])
|
||||
elif lines[i].startswith("From "):
|
||||
headers.append([m.group(1), lines[i][m.end(0):]+b"\r\n"])
|
||||
elif lines[i].startswith(b"From "):
|
||||
pass
|
||||
else:
|
||||
raise MessageFormatError("Unexpected characters in RFC822 header: %s" % lines[i])
|
||||
i += 1
|
||||
return (headers, "\r\n".join(lines[i:]))
|
||||
return (headers, b"\r\n".join(lines[i:]))
|
||||
|
||||
|
||||
def dnstxt(name):
|
||||
"""Return a TXT record associated with a DNS name."""
|
||||
a = dns.resolver.query(name, dns.rdatatype.TXT)
|
||||
for r in a.response.answer:
|
||||
if r.rdtype == dns.rdatatype.TXT:
|
||||
return "".join(r.items[0].strings)
|
||||
return b"".join(r.items[0].strings)
|
||||
return None
|
||||
|
||||
|
||||
def fold(header):
|
||||
"""Fold a header line into multiple crlf-separated lines at column 72."""
|
||||
i = header.rfind("\r\n ")
|
||||
i = header.rfind(b"\r\n ")
|
||||
if i == -1:
|
||||
pre = ""
|
||||
pre = b""
|
||||
else:
|
||||
i += 3
|
||||
pre = header[:i]
|
||||
header = header[i:]
|
||||
while len(header) > 72:
|
||||
i = header[:72].rfind(" ")
|
||||
i = header[:72].rfind(b" ")
|
||||
if i == -1:
|
||||
j = i
|
||||
else:
|
||||
j = i + 1
|
||||
pre += header[:i] + "\r\n "
|
||||
pre += header[:i] + b"\r\n "
|
||||
header = header[j:]
|
||||
return pre + header
|
||||
|
||||
@@ -286,26 +290,26 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
bodyhash = base64.b64encode(h.digest())
|
||||
|
||||
sigfields = [x for x in [
|
||||
('v', "1"),
|
||||
('a', "rsa-sha256"),
|
||||
('c', "%s/%s" % (canonicalize[0].name, canonicalize[1].name)),
|
||||
('d', domain),
|
||||
('i', identity or "@"+domain),
|
||||
length and ('l', len(body)),
|
||||
('q', "dns/txt"),
|
||||
('s', selector),
|
||||
('t', str(int(time.time()))),
|
||||
('h', " : ".join(x[0] for x in sign_headers)),
|
||||
('bh', bodyhash),
|
||||
('b', ""),
|
||||
(b'v', b"1"),
|
||||
(b'a', b"rsa-sha256"),
|
||||
(b'c', b"/".join((canonicalize[0].name, canonicalize[1].name))),
|
||||
(b'd', domain),
|
||||
(b'i', identity or b"@"+domain),
|
||||
length and (b'l', len(body)),
|
||||
(b'q', b"dns/txt"),
|
||||
(b's', selector),
|
||||
(b't', str(int(time.time())).encode('ascii')),
|
||||
(b'h', b" : ".join(x[0] for x in sign_headers)),
|
||||
(b'bh', bodyhash),
|
||||
(b'b', b""),
|
||||
] if x]
|
||||
|
||||
sig_value = fold("; ".join("%s=%s" % x for x in sigfields))
|
||||
sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields))
|
||||
dkim_header = canonicalize[0].canonicalize_headers([
|
||||
['DKIM-Signature', ' ' + sig_value]])[0]
|
||||
[b'DKIM-Signature', b' ' + sig_value]])[0]
|
||||
# the dkim sig is hashed with no trailing crlf, even if the
|
||||
# canonicalization algorithm would add one.
|
||||
if dkim_header[1][-2:] == '\r\n':
|
||||
if dkim_header[1][-2:] == b'\r\n':
|
||||
dkim_header = (dkim_header[0], dkim_header[1][:-2])
|
||||
sign_headers.append(dkim_header)
|
||||
|
||||
@@ -313,7 +317,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
h = hashlib.sha256()
|
||||
for x in sign_headers:
|
||||
h.update(x[0])
|
||||
h.update(":")
|
||||
h.update(b":")
|
||||
h.update(x[1])
|
||||
|
||||
try:
|
||||
@@ -323,7 +327,7 @@ def sign(message, selector, domain, privkey, identity=None,
|
||||
raise ParameterError("digest too large for modulus")
|
||||
sig_value += base64.b64encode(sig2)
|
||||
|
||||
return 'DKIM-Signature: ' + sig_value + "\r\n"
|
||||
return b'DKIM-Signature: ' + sig_value + b"\r\n"
|
||||
|
||||
|
||||
def verify(message, logger=None, dnsfunc=dnstxt):
|
||||
@@ -338,7 +342,7 @@ def verify(message, logger=None, dnsfunc=dnstxt):
|
||||
|
||||
(headers, body) = rfc822_parse(message)
|
||||
|
||||
sigheaders = [x for x in headers if x[0].lower() == "dkim-signature"]
|
||||
sigheaders = [x for x in headers if x[0].lower() == b"dkim-signature"]
|
||||
if len(sigheaders) < 1:
|
||||
return False
|
||||
|
||||
@@ -355,20 +359,20 @@ def verify(message, logger=None, dnsfunc=dnstxt):
|
||||
logger.error("signature fields failed to validate: %s" % e)
|
||||
return False
|
||||
|
||||
m = re.match("(\w+)(?:/(\w+))?$", sig['c'])
|
||||
m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c'])
|
||||
if m is None:
|
||||
logger.error(
|
||||
"c= value is not in format method/method (%s)" % sig['c'])
|
||||
"c= value is not in format method/method (%s)" % sig[b'c'])
|
||||
return False
|
||||
can_headers = m.group(1)
|
||||
if m.group(2) is not None:
|
||||
can_body = m.group(2)
|
||||
else:
|
||||
can_body = "simple"
|
||||
can_body = b"simple"
|
||||
|
||||
if can_headers == "simple":
|
||||
if can_headers == b"simple":
|
||||
canonicalize_headers = Simple
|
||||
elif can_headers == "relaxed":
|
||||
elif can_headers == b"relaxed":
|
||||
canonicalize_headers = Relaxed
|
||||
else:
|
||||
logger.error("unknown header canonicalization (%s)" % can_headers)
|
||||
@@ -376,36 +380,36 @@ def verify(message, logger=None, dnsfunc=dnstxt):
|
||||
|
||||
headers = canonicalize_headers.canonicalize_headers(headers)
|
||||
|
||||
if can_body == "simple":
|
||||
if can_body == b"simple":
|
||||
body = Simple.canonicalize_body(body)
|
||||
elif can_body == "relaxed":
|
||||
elif can_body == b"relaxed":
|
||||
body = Relaxed.canonicalize_body(body)
|
||||
else:
|
||||
logger.error("unknown body canonicalization (%s)" % can_body)
|
||||
return False
|
||||
|
||||
if sig['a'] == "rsa-sha1":
|
||||
if sig[b'a'] == b"rsa-sha1":
|
||||
hasher = hashlib.sha1
|
||||
elif sig['a'] == "rsa-sha256":
|
||||
elif sig[b'a'] == b"rsa-sha256":
|
||||
hasher = hashlib.sha256
|
||||
else:
|
||||
logger.error("unknown signature algorithm (%s)" % sig['a'])
|
||||
logger.error("unknown signature algorithm (%s)" % sig[b'a'])
|
||||
return False
|
||||
|
||||
if 'l' in sig:
|
||||
body = body[:int(sig['l'])]
|
||||
if b'l' in sig:
|
||||
body = body[:int(sig[b'l'])]
|
||||
|
||||
h = hasher()
|
||||
h.update(body)
|
||||
bodyhash = h.digest()
|
||||
logger.debug("bh: %s" % base64.b64encode(bodyhash))
|
||||
if bodyhash != base64.b64decode(re.sub(r"\s+", "", sig['bh'])):
|
||||
if bodyhash != base64.b64decode(re.sub(br"\s+", "", sig[b'bh'])):
|
||||
logger.error(
|
||||
"body hash mismatch (got %s, expected %s)" %
|
||||
(base64.b64encode(bodyhash), sig['bh']))
|
||||
(base64.b64encode(bodyhash), sig[b'bh']))
|
||||
return False
|
||||
|
||||
s = dnsfunc(sig['s']+"._domainkey."+sig['d']+".")
|
||||
s = dnsfunc(sig[b's']+b"._domainkey."+sig[b'd']+b".")
|
||||
if not s:
|
||||
return False
|
||||
try:
|
||||
@@ -413,16 +417,16 @@ def verify(message, logger=None, dnsfunc=dnstxt):
|
||||
except InvalidTagValueList:
|
||||
return False
|
||||
try:
|
||||
pk = parse_public_key(base64.b64decode(pub['p']))
|
||||
pk = parse_public_key(base64.b64decode(pub[b'p']))
|
||||
except UnparsableKeyError as e:
|
||||
logger.error("could not parse public key: %s" % e)
|
||||
return False
|
||||
|
||||
include_headers = re.split(r"\s*:\s*", sig['h'])
|
||||
include_headers = re.split(br"\s*:\s*", sig[b'h'])
|
||||
h = hasher()
|
||||
hash_headers(
|
||||
h, canonicalize_headers, headers, include_headers, sigheaders, sig)
|
||||
signature = base64.b64decode(re.sub(r"\s+", "", sig['b']))
|
||||
signature = base64.b64decode(re.sub(br"\s+", "", sig[b'b']))
|
||||
try:
|
||||
return RSASSA_PKCS1_v1_5_verify(
|
||||
h, signature, pk['publicExponent'], pk['modulus'])
|
||||
|
||||
Reference in New Issue
Block a user