Canonicalizationpolicy.from_c_value() now raise InvalidCanonicalizationPolicyErrors instead of logging and returning None.

This commit is contained in:
William Grant
2011-06-04 15:51:14 +10:00
parent 79eff489d4
commit e9d01800de
3 changed files with 27 additions and 17 deletions
+9 -4
View File
@@ -25,7 +25,10 @@ import logging
import re import re
import time import time
from dkim.canonicalization import CanonicalizationPolicy from dkim.canonicalization import (
CanonicalizationPolicy,
InvalidCanonicalizationPolicyError,
)
from dkim.crypto import ( from dkim.crypto import (
DigestTooLargeError, DigestTooLargeError,
HASH_ALGORITHMS, HASH_ALGORITHMS,
@@ -317,8 +320,10 @@ def verify(message, logger=None, dnsfunc=get_txt):
logger.error("signature fields failed to validate: %s" % e) logger.error("signature fields failed to validate: %s" % e)
return False return False
canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'), logger) try:
if canon_policy is None: canon_policy = CanonicalizationPolicy.from_c_value(sig.get(b'c'))
except InvalidCanonicalizationPolicyError as e:
logger.error("invalid c= value: %s" % e.args[0])
return False return False
headers = canon_policy.canonicalize_headers(headers) headers = canon_policy.canonicalize_headers(headers)
body = canon_policy.canonicalize_body(body) body = canon_policy.canonicalize_body(body)
@@ -326,7 +331,7 @@ def verify(message, logger=None, dnsfunc=get_txt):
try: try:
hasher = HASH_ALGORITHMS[sig[b'a']] hasher = HASH_ALGORITHMS[sig[b'a']]
except KeyError as e: except KeyError as e:
logger.error("unknown signature algorithm: %s" % e.message) logger.error("unknown signature algorithm: %s" % e.args[0])
return False return False
if b'l' in sig: if b'l' in sig:
+13 -11
View File
@@ -23,9 +23,15 @@ import re
__all__ = [ __all__ = [
'CanonicalizationPolicy', 'CanonicalizationPolicy',
'InvalidCanonicalizationPolicyError',
] ]
class InvalidCanonicalizationPolicyError(Exception):
"""The c= value could not be parsed."""
pass
def strip_trailing_whitespace(content): def strip_trailing_whitespace(content):
return re.sub(b"[\t ]+\r\n", b"\r\n", content) return re.sub(b"[\t ]+\r\n", b"\r\n", content)
@@ -90,21 +96,20 @@ class CanonicalizationPolicy:
self.body_algorithm = body_algorithm self.body_algorithm = body_algorithm
@classmethod @classmethod
def from_c_value(cls, c, logger=None): def from_c_value(cls, c):
"""Construct the canonicalization policy described by a c= value. """Construct the canonicalization policy described by a c= value.
May raise an C{InvalidCanonicalizationPolicyError} if the given
value is invalid
@param c: c= value from a DKIM-Signature header field @param c: c= value from a DKIM-Signature header field
@return: a C{CanonicalizationPolicy}, or C{None} if the value is @return: a C{CanonicalizationPolicy}
invalid
""" """
if c is None: if c is None:
c = b'simple/simple' c = b'simple/simple'
m = c.split(b'/') m = c.split(b'/')
if len(m) not in (1, 2): if len(m) not in (1, 2):
if logger: raise InvalidCanonicalizationPolicyError(c)
logger.error(
"c= value is not in format method/method: %s" % c)
return None
if len(m) == 1: if len(m) == 1:
m.append(b'simple') m.append(b'simple')
can_headers, can_body = m can_headers, can_body = m
@@ -112,10 +117,7 @@ class CanonicalizationPolicy:
header_algorithm = ALGORITHMS[can_headers] header_algorithm = ALGORITHMS[can_headers]
body_algorithm = ALGORITHMS[can_body] body_algorithm = ALGORITHMS[can_body]
except KeyError as e: except KeyError as e:
if logger: raise InvalidCanonicalizationPolicyError(e.args[0])
logger.error(
"unknown canonicalization algorithm: %s" % e.message)
return None
return cls(header_algorithm, body_algorithm) return cls(header_algorithm, body_algorithm)
def to_c_value(self): def to_c_value(self):
+5 -2
View File
@@ -20,6 +20,7 @@ import unittest
from dkim.canonicalization import ( from dkim.canonicalization import (
CanonicalizationPolicy, CanonicalizationPolicy,
InvalidCanonicalizationPolicyError,
Simple, Simple,
Relaxed, Relaxed,
) )
@@ -107,7 +108,9 @@ class TestCanonicalizationPolicyFromCValue(unittest.TestCase):
(p.header_algorithm, p.body_algorithm)) (p.header_algorithm, p.body_algorithm))
def assertValueDoesNotParse(self, c_value): def assertValueDoesNotParse(self, c_value):
self.assertIs(None, CanonicalizationPolicy.from_c_value(c_value)) self.assertRaises(
InvalidCanonicalizationPolicyError,
CanonicalizationPolicy.from_c_value, c_value)
def test_both_default_to_simple(self): def test_both_default_to_simple(self):
self.assertAlgorithms(Simple, Simple, None) self.assertAlgorithms(Simple, Simple, None)
@@ -131,7 +134,7 @@ class TestCanonicalizationPolicyFromCValue(unittest.TestCase):
self.assertValueDoesNotParse(b'worried') self.assertValueDoesNotParse(b'worried')
class TestCanonicalizationpolicyToCValue(unittest.TestCase): class TestCanonicalizationPolicyToCValue(unittest.TestCase):
def assertCValue(self, c_value, header_algo, body_algo): def assertCValue(self, c_value, header_algo, body_algo):
self.assertEqual( self.assertEqual(