diff --git a/dkim/__init__.py b/dkim/__init__.py index dee2b7a..7224b4d 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -25,7 +25,7 @@ import logging import re import time -from dkim.canonicalization import algorithms +from dkim.canonicalization import CanonicalizationPolicy from dkim.crypto import ( DigestTooLargeError, HASH_ALGORITHMS, @@ -231,7 +231,9 @@ def sign(message, selector, domain, privkey, identity=None, if identity is not None and not identity.endswith(domain): raise ParameterError("identity must end with domain") - headers = algorithms[canonicalize[0]].canonicalize_headers(headers) + canon_policy = CanonicalizationPolicy.from_c_value( + b'/'.join(canonicalize)) + headers = canon_policy.canonicalize_headers(headers) if include_headers is None: include_headers = [x[0].lower() for x in headers] @@ -239,7 +241,7 @@ def sign(message, selector, domain, privkey, identity=None, include_headers = [x.lower() for x in include_headers] sign_headers = [x for x in headers if x[0].lower() in include_headers] - body = algorithms[canonicalize[1]].canonicalize_body(body) + body = canon_policy.canonicalize_body(body) h = hashlib.sha256() h.update(body) @@ -248,9 +250,7 @@ def sign(message, selector, domain, privkey, identity=None, sigfields = [x for x in [ (b'v', b"1"), (b'a', signature_algorithm), - (b'c', b"/".join( - (algorithms[canonicalize[0]].name, - algorithms[canonicalize[1]].name))), + (b'c', canon_policy.to_c_value()), (b'd', domain), (b'i', identity or b"@"+domain), length and (b'l', len(body)), @@ -263,7 +263,7 @@ def sign(message, selector, domain, privkey, identity=None, ] if x] sig_value = fold(b"; ".join(b"=".join(x) for x in sigfields)) - dkim_header = algorithms[canonicalize[0]].canonicalize_headers([ + dkim_header = canon_policy.canonicalize_headers([ [b'DKIM-Signature', b' ' + sig_value]])[0] # the dkim sig is hashed with no trailing crlf, even if the # canonicalization algorithm would add one. @@ -317,25 +317,11 @@ def verify(message, logger=None, dnsfunc=get_txt): logger.error("signature fields failed to validate: %s" % e) return False - m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c']) - if m is None: - logger.error( - "c= value is not in format method/method (%s)" % sig[b'c']) + canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'), logger) + if canon_policy is None: return False - can_headers = m.group(1) - if m.group(2) is not None: - can_body = m.group(2) - else: - can_body = b"simple" - - try: - header_algorithm = algorithms[can_headers] - body_algorithm = algorithms[can_body] - except KeyError as e: - logger.error("unknown canonicalization algorithm: %s" % e.message) - return False - headers = header_algorithm.canonicalize_headers(headers) - body = body_algorithm.canonicalize_body(body) + headers = canon_policy.canonicalize_headers(headers) + body = canon_policy.canonicalize_body(body) try: hasher = HASH_ALGORITHMS[sig[b'a']] @@ -372,8 +358,7 @@ def verify(message, logger=None, dnsfunc=get_txt): include_headers = re.split(br"\s*:\s*", sig[b'h']) h = hasher() - hash_headers( - h, header_algorithm, headers, include_headers, sigheaders, sig) + hash_headers(h, canon_policy, headers, include_headers, sigheaders, sig) signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) try: return RSASSA_PKCS1_v1_5_verify( diff --git a/dkim/canonicalization.py b/dkim/canonicalization.py index 9b191f2..0e05981 100644 --- a/dkim/canonicalization.py +++ b/dkim/canonicalization.py @@ -22,7 +22,7 @@ import re __all__ = [ - 'algorithms', + 'CanonicalizationPolicy', ] @@ -83,4 +83,50 @@ class Relaxed: compress_whitespace(strip_trailing_whitespace(body))) -algorithms = dict((c.name, c) for c in (Simple, Relaxed)) +class CanonicalizationPolicy: + + def __init__(self, header_algorithm, body_algorithm): + self.header_algorithm = header_algorithm + self.body_algorithm = body_algorithm + + @classmethod + def from_c_value(cls, c, logger=None): + """Construct the canonicalization policy described by a c= value. + + @param c: c= value from a DKIM-Signature header field + @return: a C{CanonicalizationPolicy}, or C{None} if the value is + invalid + """ + if c is None: + c = b'simple/simple' + m = c.split(b'/') + if len(m) not in (1, 2): + if logger: + logger.error( + "c= value is not in format method/method: %s" % c) + return None + if len(m) == 1: + m.append(b'simple') + can_headers, can_body = m + try: + header_algorithm = ALGORITHMS[can_headers] + body_algorithm = ALGORITHMS[can_body] + except KeyError as e: + if logger: + logger.error( + "unknown canonicalization algorithm: %s" % e.message) + return None + return cls(header_algorithm, body_algorithm) + + def to_c_value(self): + return b'/'.join( + (self.header_algorithm.name, self.body_algorithm.name)) + + def canonicalize_headers(self, headers): + return self.header_algorithm.canonicalize_headers(headers) + + def canonicalize_body(self, body): + return self.body_algorithm.canonicalize_body(body) + + +ALGORITHMS = dict((c.name, c) for c in (Simple, Relaxed)) diff --git a/dkim/tests/test_canonicalization.py b/dkim/tests/test_canonicalization.py index 5269f72..ce0685d 100644 --- a/dkim/tests/test_canonicalization.py +++ b/dkim/tests/test_canonicalization.py @@ -18,7 +18,11 @@ import unittest -from dkim.canonicalization import Simple, Relaxed +from dkim.canonicalization import ( + CanonicalizationPolicy, + Simple, + Relaxed, + ) class BaseCanonicalizationTest(unittest.TestCase): @@ -94,6 +98,59 @@ class TestRelaxedAlgorithmBody(BaseCanonicalizationTest): b'Foo\r\nbar\r\n\r\n\r\n') +class TestCanonicalizationPolicyFromCValue(unittest.TestCase): + + def assertAlgorithms(self, header_algo, body_algo, c_value): + p = CanonicalizationPolicy.from_c_value(c_value) + self.assertEqual( + (header_algo, body_algo), + (p.header_algorithm, p.body_algorithm)) + + def assertValueDoesNotParse(self, c_value): + self.assertIs(None, CanonicalizationPolicy.from_c_value(c_value)) + + def test_both_default_to_simple(self): + self.assertAlgorithms(Simple, Simple, None) + + def test_relaxed_headers(self): + self.assertAlgorithms(Relaxed, Simple, b'relaxed') + + def test_relaxed_body(self): + self.assertAlgorithms(Simple, Relaxed, b'simple/relaxed') + + def test_relaxed_both(self): + self.assertAlgorithms(Relaxed, Relaxed, b'relaxed/relaxed') + + def test_explict_simple_both(self): + self.assertAlgorithms(Simple, Simple, b'simple/simple') + + def test_corruption_is_ignored(self): + self.assertValueDoesNotParse(b'') + self.assertValueDoesNotParse(b'simple/simple/simple') + self.assertValueDoesNotParse(b'relaxed/stressed') + self.assertValueDoesNotParse(b'worried') + + +class TestCanonicalizationpolicyToCValue(unittest.TestCase): + + def assertCValue(self, c_value, header_algo, body_algo): + self.assertEqual( + c_value, + CanonicalizationPolicy(header_algo, body_algo).to_c_value()) + + def test_both_simple(self): + self.assertCValue(b'simple/simple', Simple, Simple) + + def test_relaxed_body(self): + self.assertCValue(b'simple/relaxed', Simple, Relaxed) + + def test_both_relaxed(self): + self.assertCValue(b'relaxed/relaxed', Relaxed, Relaxed) + + def test_relaxed_headers(self): + self.assertCValue(b'relaxed/simple', Relaxed, Simple) + + def test_suite(): from unittest import TestLoader return TestLoader().loadTestsFromName(__name__)