From 206c86089022c9aee00666591193321eaf4d42b4 Mon Sep 17 00:00:00 2001 From: William Grant Date: Sat, 4 Jun 2011 14:30:19 +1000 Subject: [PATCH] Pull c= value parsing out into Canonicalizationpolicy.from_c_value. --- dkim/__init__.py | 22 +++------------- dkim/canonicalization.py | 29 +++++++++++++++++++++ dkim/tests/test_canonicalization.py | 39 ++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 20 deletions(-) diff --git a/dkim/__init__.py b/dkim/__init__.py index 6311ab1..cee4ae0 100644 --- a/dkim/__init__.py +++ b/dkim/__init__.py @@ -320,24 +320,9 @@ def verify(message, logger=None, dnsfunc=get_txt): logger.error("signature fields failed to validate: %s" % e) return False - m = re.match(b"(\w+)(?:/(\w+))?$", sig[b'c']) - if m is None: - logger.error( - "c= value is not in format method/method (%s)" % sig[b'c']) + canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'), logger) + if canon_policy is None: return False - can_headers = m.group(1) - if m.group(2) is not None: - can_body = m.group(2) - else: - can_body = b"simple" - - try: - header_algorithm = algorithms[can_headers] - body_algorithm = algorithms[can_body] - except KeyError as e: - logger.error("unknown canonicalization algorithm: %s" % e.message) - return False - canon_policy = CanonicalizationPolicy(header_algorithm, body_algorithm) headers = canon_policy.canonicalize_headers(headers) body = canon_policy.canonicalize_body(body) @@ -376,8 +361,7 @@ def verify(message, logger=None, dnsfunc=get_txt): include_headers = re.split(br"\s*:\s*", sig[b'h']) h = hasher() - hash_headers( - h, header_algorithm, headers, include_headers, sigheaders, sig) + hash_headers(h, canon_policy, headers, include_headers, sigheaders, sig) signature = base64.b64decode(re.sub(br"\s+", b"", sig[b'b'])) try: return RSASSA_PKCS1_v1_5_verify( diff --git a/dkim/canonicalization.py b/dkim/canonicalization.py index 7023dc8..7395dd4 100644 --- a/dkim/canonicalization.py +++ b/dkim/canonicalization.py @@ -90,6 +90,35 @@ class CanonicalizationPolicy: self.header_algorithm = header_algorithm self.body_algorithm = body_algorithm + @classmethod + def from_c_value(cls, c, logger=None): + """Construct the canonicalization policy described by a c= value. + + @param c: c= value from a DKIM-Signature header field + @return: a C{CanonicalizationPolicy}, or C{None} if the value is + invalid + """ + if c is None: + c = b'simple/simple' + m = c.split(b'/') + if len(m) not in (1, 2): + if logger: + logger.error( + "c= value is not in format method/method: %s" % c) + return None + if len(m) == 1: + m.append(b'simple') + can_headers, can_body = m + try: + header_algorithm = algorithms[can_headers] + body_algorithm = algorithms[can_body] + except KeyError as e: + if logger: + logger.error( + "unknown canonicalization algorithm: %s" % e.message) + return None + return cls(header_algorithm, body_algorithm) + def canonicalize_headers(self, headers): return self.header_algorithm.canonicalize_headers(headers) diff --git a/dkim/tests/test_canonicalization.py b/dkim/tests/test_canonicalization.py index 5269f72..fb97d44 100644 --- a/dkim/tests/test_canonicalization.py +++ b/dkim/tests/test_canonicalization.py @@ -18,7 +18,11 @@ import unittest -from dkim.canonicalization import Simple, Relaxed +from dkim.canonicalization import ( + CanonicalizationPolicy, + Simple, + Relaxed, + ) class BaseCanonicalizationTest(unittest.TestCase): @@ -94,6 +98,39 @@ class TestRelaxedAlgorithmBody(BaseCanonicalizationTest): b'Foo\r\nbar\r\n\r\n\r\n') +class TestCanonicalizationPolicyFromCValue(unittest.TestCase): + + def assertAlgorithms(self, header_algo, body_algo, c_value): + p = CanonicalizationPolicy.from_c_value(c_value) + self.assertEqual( + (header_algo, body_algo), + (p.header_algorithm, p.body_algorithm)) + + def assertValueDoesNotParse(self, c_value): + self.assertIs(None, CanonicalizationPolicy.from_c_value(c_value)) + + def test_both_default_to_simple(self): + self.assertAlgorithms(Simple, Simple, None) + + def test_relaxed_headers(self): + self.assertAlgorithms(Relaxed, Simple, b'relaxed') + + def test_relaxed_body(self): + self.assertAlgorithms(Simple, Relaxed, b'simple/relaxed') + + def test_relaxed_both(self): + self.assertAlgorithms(Relaxed, Relaxed, b'relaxed/relaxed') + + def test_explict_simple_both(self): + self.assertAlgorithms(Simple, Simple, b'simple/simple') + + def test_corruption_is_ignored(self): + self.assertValueDoesNotParse(b'') + self.assertValueDoesNotParse(b'simple/simple/simple') + self.assertValueDoesNotParse(b'relaxed/stressed') + self.assertValueDoesNotParse(b'worried') + + def test_suite(): from unittest import TestLoader return TestLoader().loadTestsFromName(__name__)