diff --git a/dkim/__init__.py b/dkim/__init__.py index 7224b4d..fba63b5 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -25,7 +25,10 @@ import logging import re import time -from dkim.canonicalization import CanonicalizationPolicy +from dkim.canonicalization import ( + CanonicalizationPolicy, + InvalidCanonicalizationPolicyError, + ) from dkim.crypto import ( DigestTooLargeError, HASH_ALGORITHMS, @@ -317,8 +320,10 @@ def verify(message, logger=None, dnsfunc=get_txt): logger.error("signature fields failed to validate: %s" % e) return False - canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'), logger) - if canon_policy is None: + try: + canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c')) + except InvalidCanonicalizationPolicyError as e: + logger.error("invalid c= value: %s" % e.args[0]) return False headers = canon_policy.canonicalize_headers(headers) body = canon_policy.canonicalize_body(body) @@ -326,7 +331,7 @@ def verify(message, logger=None, dnsfunc=get_txt): try: hasher = HASH_ALGORITHMS[sig[b'a']] except KeyError as e: - logger.error("unknown signature algorithm: %s" % e.message) + logger.error("unknown signature algorithm: %s" % e.args[0]) return False if b'l' in sig: diff --git a/dkim/canonicalization.py b/dkim/canonicalization.py index 0e05981..a674e25 100644 --- a/dkim/canonicalization.py +++ b/dkim/canonicalization.py @@ -23,9 +23,15 @@ import re __all__ = [ 'CanonicalizationPolicy', + 'InvalidCanonicalizationPolicyError', ] +class InvalidCanonicalizationPolicyError(Exception): + """The c= value could not be parsed.""" + pass + + def strip_trailing_whitespace(content): return re.sub(b"[\t ]+\r\n", b"\r\n", content) @@ -90,21 +96,20 @@ class CanonicalizationPolicy: self.body_algorithm = body_algorithm @classmethod - def from_c_value(cls, c, logger=None): + def from_c_value(cls, c): """Construct the canonicalization policy described by a c= value. + May raise an C{InvalidCanonicalizationPolicyError} if the given + value is invalid + @param c: c= value from a DKIM-Signature header field - @return: a C{CanonicalizationPolicy}, or C{None} if the value is - invalid + @return: a C{CanonicalizationPolicy} """ if c is None: c = b'simple/simple' m = c.split(b'/') if len(m) not in (1, 2): - if logger: - logger.error( - "c= value is not in format method/method: %s" % c) - return None + raise InvalidCanonicalizationPolicyError(c) if len(m) == 1: m.append(b'simple') can_headers, can_body = m @@ -112,10 +117,7 @@ class CanonicalizationPolicy: header_algorithm = ALGORITHMS[can_headers] body_algorithm = ALGORITHMS[can_body] except KeyError as e: - if logger: - logger.error( - "unknown canonicalization algorithm: %s" % e.message) - return None + raise InvalidCanonicalizationPolicyError(e.args[0]) return cls(header_algorithm, body_algorithm) def to_c_value(self): diff --git a/dkim/tests/test_canonicalization.py b/dkim/tests/test_canonicalization.py index ce0685d..84505e7 100644 --- a/dkim/tests/test_canonicalization.py +++ b/dkim/tests/test_canonicalization.py @@ -20,6 +20,7 @@ import unittest from dkim.canonicalization import ( CanonicalizationPolicy, + InvalidCanonicalizationPolicyError, Simple, Relaxed, ) @@ -107,7 +108,9 @@ class TestCanonicalizationPolicyFromCValue(unittest.TestCase): (p.header_algorithm, p.body_algorithm)) def assertValueDoesNotParse(self, c_value): - self.assertIs(None, CanonicalizationPolicy.from_c_value(c_value)) + self.assertRaises( + InvalidCanonicalizationPolicyError, + CanonicalizationPolicy.from_c_value, c_value) def test_both_default_to_simple(self): self.assertAlgorithms(Simple, Simple, None) @@ -131,7 +134,7 @@ class TestCanonicalizationPolicyFromCValue(unittest.TestCase): self.assertValueDoesNotParse(b'worried') -class TestCanonicalizationpolicyToCValue(unittest.TestCase): +class TestCanonicalizationPolicyToCValue(unittest.TestCase): def assertCValue(self, c_value, header_algo, body_algo): self.assertEqual(