Refactor load_pk_from_dns to reduce code duplication between async and non-async.

This commit is contained in:
Scott Kitterman
2019-11-05 08:34:13 -05:00
parent e8ee183a7f
commit 3de1dc0362
2 changed files with 8 additions and 48 deletions
+7 -2
View File
@@ -425,8 +425,7 @@ def fold(header, namelen=0, linesep=b'\r\n'):
return pre + header return pre + header
def load_pk_from_dns(name, dnsfunc, timeout=5): def evaluate_pk(name, s):
s = dnsfunc(name, timeout=timeout)
if not s: if not s:
raise KeyFormatError("missing public key: %s"%name) raise KeyFormatError("missing public key: %s"%name)
try: try:
@@ -475,6 +474,12 @@ def load_pk_from_dns(name, dnsfunc, timeout=5):
return pk, keysize, ktag, seqtlsrpt return pk, keysize, ktag, seqtlsrpt
def load_pk_from_dns(name, dnsfunc=get_txt, timeout=5):
s = dnsfunc(name, timeout=timeout)
pk, keysize, ktag, seqtlsrpt = evaluate_pk(name, s)
return pk, keysize, ktag, seqtlsrpt
#: Abstract base class for holding messages and options during DKIM/ARC signing and verification. #: Abstract base class for holding messages and options during DKIM/ARC signing and verification.
class DomainSigner(object): class DomainSigner(object):
# NOTE - the first 2 indentation levels are 2 instead of 4 # NOTE - the first 2 indentation levels are 2 instead of 4
+1 -46
View File
@@ -59,7 +59,6 @@ async def get_txt_async(name, timeout=5):
result = await query(name, 'TXT') result = await query(name, 'TXT')
except aiodns.error.DNSError: except aiodns.error.DNSError:
result = None result = None
print('result', result)
if result: if result:
return result[0].text return result[0].text
@@ -69,51 +68,7 @@ async def get_txt_async(name, timeout=5):
async def load_pk_from_dns_async(name, dnsfunc, timeout=5): async def load_pk_from_dns_async(name, dnsfunc, timeout=5):
s = await dnsfunc(name, timeout=timeout) s = await dnsfunc(name, timeout=timeout)
if not s: pk, keysize, ktag, seqtlsrpt = dkim.evaluate_pk(name, s)
raise dkim.KeyFormatError("missing public key: %s"%name)
try:
if type(s) is str:
s = s.encode('ascii')
pub = dkim.parse_tag_value(s)
except dkim.InvalidTagValueList as e:
raise dkim.KeyFormatError(e)
try:
if pub[b'v'] != b'DKIM1':
raise dkim.KeyFormatError("bad version")
except KeyError as e:
# Version not required in key record: RFC 6376 3.6.1
pass
try:
if pub[b'k'] == b'ed25519':
pk = nacl.signing.VerifyKey(pub[b'p'], encoder=nacl.encoding.Base64Encoder)
keysize = 256
ktag = b'ed25519'
except KeyError:
pub[b'k'] = b'rsa'
if pub[b'k'] == b'rsa':
try:
pk = dkim.parse_public_key(base64.b64decode(pub[b'p']))
keysize = dkim.bitsize(pk['modulus'])
except KeyError:
raise dkim.KeyFormatError("incomplete public key: %s" % s)
except (TypeError,dkim.UnparsableKeyError) as e:
raise dkim.KeyFormatError("could not parse public key (%s): %s" % (pub[b'p'],e))
ktag = b'rsa'
if pub[b'k'] != b'rsa' and pub[b'k'] != b'ed25519':
raise dkim.KeyFormatError('unknown algorithm in k= tag: {0}'.format(pub[b'k']))
seqtlsrpt = False
try:
# Ignore unknown service types, RFC 6376 3.6.1
if pub[b's'] != b'*' and pub[b's'] != b'email' and pub[b's'] != b'tlsrpt':
pk = None
keysize = None
ktag = None
raise dkim.KeyFormatError('unknown service type in s= tag: {0}'.format(pub[b's']))
elif pub[b's'] == b'tlsrpt':
seqtlsrpt = True
except:
# Default is '*' - all service types, so no error if missing from key record
pass
return pk, keysize, ktag, seqtlsrpt return pk, keysize, ktag, seqtlsrpt
class DKIM(dkim.DKIM): class DKIM(dkim.DKIM):