Implement setsymlist decorator and test framework

This commit is contained in:
Stuart D. Gathman
2016-12-01 23:59:31 -05:00
parent 207278479f
commit 381e906b6a
4 changed files with 86 additions and 9 deletions
+33 -4
View File
@@ -48,6 +48,12 @@ OPTIONAL_CALLBACKS = {
'header':(P_NR_HDR,P_NOHDRS) 'header':(P_NR_HDR,P_NOHDRS)
} }
MACRO_CALLBACKS = {
'connect': M_CONNECT,
'hello': M_HELO, 'envfrom': M_ENVFROM, 'envrcpt': M_ENVRCPT,
'data': M_DATA, 'eom': M_EOM, 'eoh': M_EOH
}
## @private ## @private
R = re.compile(r'%+') R = re.compile(r'%+')
@@ -141,6 +147,7 @@ def nocallback(func):
except KeyError: except KeyError:
raise ValueError( raise ValueError(
'@nocallback applied to non-optional method: '+func.__name__) '@nocallback applied to non-optional method: '+func.__name__)
@wraps(func)
def wrapper(self,*args): def wrapper(self,*args):
if func(self,*args) != CONTINUE: if func(self,*args) != CONTINUE:
raise RuntimeError('%s return code must be CONTINUE with @nocallback' raise RuntimeError('%s return code must be CONTINUE with @nocallback'
@@ -173,6 +180,19 @@ def noreply(func):
wrapper.milter_protocol = nr_mask wrapper.milter_protocol = nr_mask
return wrapper return wrapper
## Function decorator to set macros used in a callback.
# By default, the MTA sends all macros defined for a callback.
# If some or all of these are unused, the bandwidth can be saved
# by listing the ones that are used.
# @since 1.0.2
def symlist(func,*syms):
if func.__name__ not in MACRO_CALLBACKS:
raise ValueError('@symlist applied to non-symlist method: '+func.__name__)
if len(syms) > 5:
raise ValueError('@symlist limited to 5 macros by MTA: '+func.__name__)
func._symlist = syms
return func
## Disabled action exception. ## Disabled action exception.
# set_flags() can tell the MTA that this application will not use certain # set_flags() can tell the MTA that this application will not use certain
# features (such as CHGFROM). This can also be negotiated for each # features (such as CHGFROM). This can also be negotiated for each
@@ -393,6 +413,11 @@ class Base(object):
def negotiate(self,opts): def negotiate(self,opts):
try: try:
self._actions,p,f1,f2 = opts self._actions,p,f1,f2 = opts
for func,stage in MACRO_CALLBACKS.items():
func = getattr(self,func)
syms = getattr(func,'_symlist',None)
if syms is not None:
self.setsymlist(stage,syms)
opts[1] = self._protocol = p & ~self.protocol_mask() opts[1] = self._protocol = p & ~self.protocol_mask()
opts[2] = 0 opts[2] = 0
opts[3] = 0 opts[3] = 0
@@ -443,23 +468,27 @@ class Base(object):
# set. The protocol stages are M_CONNECT, M_HELO, M_ENVFROM, M_ENVRCPT, # set. The protocol stages are M_CONNECT, M_HELO, M_ENVFROM, M_ENVRCPT,
# M_DATA, M_EOM, M_EOH. # M_DATA, M_EOM, M_EOH.
# #
# May only be called from negotiate callback. # May only be called from negotiate callback. Hence, this is an advanced
# feature. Use the @@symlist function decorator to conviently set
# the macros used by a callback.
# @since 0.9.8, previous version was misspelled! # @since 0.9.8, previous version was misspelled!
# @param stage the protocol stage to set to macro list for, # @param stage the protocol stage to set to macro list for,
# one of the M_* constants defined in Milter # one of the M_* constants defined in Milter
# @param macros space separated and/or lists of strings # @param macros space separated and/or lists of strings
def setsymlist(self,stage,*macros): def setsymlist(self,stage,*macros):
if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST") if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST")
if len(macros) > 5:
raise ValueError('setsymlist limited to 5 macros by MTA')
a = [] a = []
for m in macros: for m in macros:
try: try:
m = m.encode('utf8') m = m.encode('utf8')
except: pass except: pass
try: try:
m = m.split(' ') m = m.split(b' ')
except: pass
a += m a += m
return self._ctx.setsymlist(stage,' '.join(a)) except: pass
return self._ctx.setsymlist(stage,b' '.join(a))
# Milter methods which can only be called from eom callback. # Milter methods which can only be called from eom callback.
+41 -4
View File
@@ -40,6 +40,9 @@ class TestBase(object):
self._reply = None self._reply = None
## The rfc822 message object for the current email being fed to the %milter. ## The rfc822 message object for the current email being fed to the %milter.
self._msg = None self._msg = None
## The protocol stage for macros returned
self._stage = None
## The macros returned by protocol stage
self._symlist = [ None, None, None, None, None, None, None ] self._symlist = [ None, None, None, None, None, None, None ]
def log(self,*msg): def log(self,*msg):
@@ -54,8 +57,12 @@ class TestBase(object):
self._macros[name] = val self._macros[name] = val
def getsymval(self,name): def getsymval(self,name):
# FIXME: track stage, and use _symlist stage = self._stage
return self._macros.get(name,'') if stage >= 0:
syms = self._symlist[stage]
if syms is not None and name not in syms:
return None
return self._macros.get(name,None)
def replacebody(self,chunk): def replacebody(self,chunk):
if self._body: if self._body:
@@ -113,7 +120,10 @@ class TestBase(object):
self._reply = (rcode,xcode) + msg self._reply = (rcode,xcode) + msg
def setsymlist(self,stage,macros): def setsymlist(self,stage,macros):
if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST") if not self._actions & SETSYMLIST:
raise DisabledAction("SETSYMLIST")
if self._stage != -1:
raise RuntimeError("setsymlist may only be called from negotiate")
# not used yet, but just for grins we save the data # not used yet, but just for grins we save the data
a = [] a = []
for m in macros: for m in macros:
@@ -121,9 +131,13 @@ class TestBase(object):
m = m.encode('utf8') m = m.encode('utf8')
except: pass except: pass
try: try:
m = m.split(' ') m = m.split(b' ')
except: pass except: pass
a += m a += m
if len(a) > 5:
raise ValueError('setsymlist limited to 5 macros by MTA')
if self._symlist[stage] is not None:
raise ValueError('setsymlist already called for stage:'+stage)
self._symlist[stage] = set(a) self._symlist[stage] = set(a)
## Feed a file like object to the %milter. Calls envfrom, envrcpt for ## Feed a file like object to the %milter. Calls envfrom, envrcpt for
@@ -144,16 +158,32 @@ class TestBase(object):
self._reply = None self._reply = None
self._sender = '<%s>'%sender self._sender = '<%s>'%sender
msg = mime.message_from_file(fp) msg = mime.message_from_file(fp)
# envfrom
self._stage = Milter.M_ENVFROM
rc = self.envfrom(self._sender) rc = self.envfrom(self._sender)
self._stage = None
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# envrcpt
for rcpt in (rcpt,) + rcpts: for rcpt in (rcpt,) + rcpts:
self._stage = Milter.M_ENVRCPT
rc = self.envrcpt('<%s>'%rcpt) rc = self.envrcpt('<%s>'%rcpt)
self._stage = None
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# data
self._stage = Milter.M_DATA
rc = self.data()
self._stage = None
if rc != Milter.CONTINUE: return rc
# header
for h,val in msg.items(): for h,val in msg.items():
rc = self.header(h,val) rc = self.header(h,val)
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# eoh
self._stage = Milter.M_EOH
rc = self.eoh() rc = self.eoh()
self._stage = None
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# body
header,body = msg.as_bytes().split(b'\n\n',1) header,body = msg.as_bytes().split(b'\n\n',1)
bfp = BytesIO(body) bfp = BytesIO(body)
while 1: while 1:
@@ -163,7 +193,9 @@ class TestBase(object):
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
self._msg = msg self._msg = msg
self._body = BytesIO() self._body = BytesIO()
self._stage = Milter.M_EOM
rc = self.eom() rc = self.eom()
self._stage = None
if self._bodyreplaced: if self._bodyreplaced:
body = self._body.getvalue() body = self._body.getvalue()
self._body = BytesIO() self._body = BytesIO()
@@ -189,12 +221,17 @@ class TestBase(object):
self._body = None self._body = None
self._bodyreplaced = False self._bodyreplaced = False
opts = [ Milter.CURR_ACTS,~0,0,0 ] opts = [ Milter.CURR_ACTS,~0,0,0 ]
self._stage = -1
rc = self.negotiate(opts) rc = self.negotiate(opts)
self._stage = Milter.M_CONNECT
rc = super(TestBase,self).connect(host,1,(ip,1234)) rc = super(TestBase,self).connect(host,1,(ip,1234))
if rc != Milter.CONTINUE: if rc != Milter.CONTINUE:
self._stage = None
self.close() self.close()
return rc return rc
self._stage = Milter.M_HELO
rc = self.hello(helo) rc = self.hello(helo)
self._stage = None
if rc != Milter.CONTINUE: if rc != Milter.CONTINUE:
self.close() self.close()
return rc return rc
+7 -1
View File
@@ -33,18 +33,24 @@ class sampleMilter(Milter.Milter):
self.fp = None self.fp = None
self.bodysize = 0 self.bodysize = 0
self.id = Milter.uniqueID() self.id = Milter.uniqueID()
self.user = None
# multiple messages can be received on a single connection # multiple messages can be received on a single connection
# envfrom (MAIL FROM in the SMTP protocol) seems to mark the start # envfrom (MAIL FROM in the SMTP protocol) seems to mark the start
# of each message. # of each message.
@Milter.symlist('{auth_authen}')
@Milter.noreply @Milter.noreply
def envfrom(self,f,*str): def envfrom(self,f,*str):
"start of MAIL transaction" "start of MAIL transaction"
self.log("mail from",f,str)
self.fp = BytesIO() self.fp = BytesIO()
self.tempname = None self.tempname = None
self.mailfrom = f self.mailfrom = f
self.bodysize = 0 self.bodysize = 0
self.user = self.getsymval('{auth_authen}')
if self.user:
self.log("user",self.user,"sent mail from",f,str)
else:
self.log("mail from",f,str)
return Milter.CONTINUE return Milter.CONTINUE
def envrcpt(self,to,*str): def envrcpt(self,to,*str):
+5
View File
@@ -13,9 +13,14 @@ class BMSMilterTestCase(unittest.TestCase):
def testDefang(self,fname='virus1'): def testDefang(self,fname='virus1'):
milter = TestMilter() milter = TestMilter()
milter.setsymval('{auth_authen}','batman')
milter.setsymval('{auth_type}','batcomputer')
milter.setsymval('j','mailhost')
rc = milter.connect() rc = milter.connect()
self.failUnless(rc == Milter.CONTINUE) self.failUnless(rc == Milter.CONTINUE)
rc = milter.feedMsg(fname) rc = milter.feedMsg(fname)
self.failUnless(milter.user == 'batman',"getsymval failed")
self.failUnless(milter.auth_type != 'batcomputer',"setsymlist failed")
self.failUnless(rc == Milter.ACCEPT) self.failUnless(rc == Milter.ACCEPT)
self.failUnless(milter._bodyreplaced,"Message body not replaced") self.failUnless(milter._bodyreplaced,"Message body not replaced")
fp = milter._body fp = milter._body