File async-batch-implementation.patch of Package salt.12957
From dfd16dd5968aae96e36e0dee412864fc765b62fb Mon Sep 17 00:00:00 2001
From: Mihai Dinca <mdinca@suse.de>
Date: Fri, 16 Nov 2018 17:05:29 +0100
Subject: [PATCH] Async batch implementation
Add find_job checks
Check if should close on all events
Make batch_delay a request parameter
Allow multiple event handlers
Use config value for gather_job_timeout when not in payload
Add async batch unittests
Allow metadata to pass
Pass metadata only to batch jobs
Add the metadata to the start/done events
Pass only metadata not all **kwargs
Add separate batch presence_ping timeout
---
 salt/auth/__init__.py              |   4 +-
 salt/cli/batch.py                  |  91 ++++++--
 salt/cli/batch_async.py            | 227 +++++++++++++++++++
 salt/client/__init__.py            |  44 +---
 salt/master.py                     |  25 ++
 salt/netapi/__init__.py            |   3 +-
 salt/transport/ipc.py              |  11 +-
 salt/utils/event.py                |  11 +-
 tests/unit/cli/test_batch_async.py | 351 +++++++++++++++++++++++++++++
 9 files changed, 707 insertions(+), 60 deletions(-)
 create mode 100644 salt/cli/batch_async.py
 create mode 100644 tests/unit/cli/test_batch_async.py
diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py
index 61fbb018fd..a8aefa7091 100644
--- a/salt/auth/__init__.py
+++ b/salt/auth/__init__.py
@@ -51,7 +51,9 @@ AUTH_INTERNAL_KEYWORDS = frozenset([
     'metadata',
     'print_event',
     'raw',
-    'yield_pub_data'
+    'yield_pub_data',
+    'batch',
+    'batch_delay'
 ])
 
 
diff --git a/salt/cli/batch.py b/salt/cli/batch.py
index e3a7bf9bcf..4bd07f584a 100644
--- a/salt/cli/batch.py
+++ b/salt/cli/batch.py
@@ -26,6 +26,79 @@ import logging
 log = logging.getLogger(__name__)
 
 
+def get_bnum(opts, minions, quiet):
+    '''
+    Return the active number of minions to maintain
+    '''
+    partition = lambda x: float(x) / 100.0 * len(minions)
+    try:
+        if '%' in opts['batch']:
+            res = partition(float(opts['batch'].strip('%')))
+            if res < 1:
+                return int(math.ceil(res))
+            else:
+                return int(res)
+        else:
+            return int(opts['batch'])
+    except ValueError:
+        if not quiet:
+            salt.utils.stringutils.print_cli('Invalid batch data sent: {0}\nData must be in the '
+                      'form of %10, 10% or 3'.format(opts['batch']))
+
+
+def batch_get_opts(
+        tgt,
+        fun,
+        batch,
+        parent_opts,
+        arg=(),
+        tgt_type='glob',
+        ret='',
+        kwarg=None,
+        **kwargs):
+    # We need to re-import salt.utils.args here
+    # even though it has already been imported.
+    # when cmd_batch is called via the NetAPI
+    # the module is unavailable.
+    import salt.utils.args
+
+    arg = salt.utils.args.condition_input(arg, kwarg)
+    opts = {'tgt': tgt,
+            'fun': fun,
+            'arg': arg,
+            'tgt_type': tgt_type,
+            'ret': ret,
+            'batch': batch,
+            'failhard': kwargs.get('failhard', False),
+            'raw': kwargs.get('raw', False)}
+
+    if 'timeout' in kwargs:
+        opts['timeout'] = kwargs['timeout']
+    if 'gather_job_timeout' in kwargs:
+        opts['gather_job_timeout'] = kwargs['gather_job_timeout']
+    if 'batch_wait' in kwargs:
+        opts['batch_wait'] = int(kwargs['batch_wait'])
+
+    for key, val in six.iteritems(parent_opts):
+        if key not in opts:
+            opts[key] = val
+
+    return opts
+
+
+def batch_get_eauth(kwargs):
+    eauth = {}
+    if 'eauth' in kwargs:
+        eauth['eauth'] = kwargs.pop('eauth')
+    if 'username' in kwargs:
+        eauth['username'] = kwargs.pop('username')
+    if 'password' in kwargs:
+        eauth['password'] = kwargs.pop('password')
+    if 'token' in kwargs:
+        eauth['token'] = kwargs.pop('token')
+    return eauth
+
+
 class Batch(object):
     '''
     Manage the execution of batch runs
@@ -80,23 +153,7 @@ class Batch(object):
         return (list(fret), ping_gen, nret.difference(fret))
 
     def get_bnum(self):
-        '''
-        Return the active number of minions to maintain
-        '''
-        partition = lambda x: float(x) / 100.0 * len(self.minions)
-        try:
-            if '%' in self.opts['batch']:
-                res = partition(float(self.opts['batch'].strip('%')))
-                if res < 1:
-                    return int(math.ceil(res))
-                else:
-                    return int(res)
-            else:
-                return int(self.opts['batch'])
-        except ValueError:
-            if not self.quiet:
-                salt.utils.stringutils.print_cli('Invalid batch data sent: {0}\nData must be in the '
-                          'form of %10, 10% or 3'.format(self.opts['batch']))
+        return get_bnum(self.opts, self.minions, self.quiet)
 
     def __update_wait(self, wait):
         now = datetime.now()
diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
new file mode 100644
index 0000000000..3160d46d8b
--- /dev/null
+++ b/salt/cli/batch_async.py
@@ -0,0 +1,227 @@
+# -*- coding: utf-8 -*-
+'''
+Execute a job on the targeted minions by using a moving window of fixed size `batch`.
+'''
+
+# Import python libs
+from __future__ import absolute_import, print_function, unicode_literals
+import tornado
+
+# Import salt libs
+import salt.client
+
+# pylint: enable=import-error,no-name-in-module,redefined-builtin
+import logging
+import fnmatch
+
+log = logging.getLogger(__name__)
+
+from salt.cli.batch import get_bnum, batch_get_opts, batch_get_eauth
+
+
+class BatchAsync(object):
+    '''
+    Run a job on the targeted minions by using a moving window of fixed size `batch`.
+
+    ``BatchAsync`` is used to execute a job on the targeted minions by keeping
+    the number of concurrent running minions to the size of `batch` parameter.
+
+    The control parameters are:
+        - batch: number/percentage of concurrent running minions
+        - batch_delay: minimum wait time between batches
+        - batch_presence_ping_timeout: time to wait for presence pings before starting the batch
+        - gather_job_timeout: `find_job` timeout
+        - timeout: time to wait before firing a `find_job`
+
+    When the batch stars, a `start` event is fired:
+         - tag: salt/batch/<batch-jid>/start
+         - data: {
+             "available_minions": self.minions,
+             "down_minions": self.down_minions
+           }
+
+    When the batch ends, an `done` event is fired:
+        - tag: salt/batch/<batch-jid>/done
+        - data: {
+             "available_minions": self.minions,
+             "down_minions": self.down_minions,
+             "done_minions": self.done_minions,
+             "timedout_minions": self.timedout_minions
+         }
+    '''
+    def __init__(self, parent_opts, jid_gen, clear_load):
+        ioloop = tornado.ioloop.IOLoop.current()
+        self.local = salt.client.get_local_client(parent_opts['conf_file'])
+        if 'gather_job_timeout' in clear_load['kwargs']:
+            clear_load['gather_job_timeout'] = clear_load['kwargs'].pop('gather_job_timeout')
+        else:
+            clear_load['gather_job_timeout'] = self.local.opts['gather_job_timeout']
+        self.batch_presence_ping_timeout = clear_load['kwargs'].get('batch_presence_ping_timeout', None)
+        self.batch_delay = clear_load['kwargs'].get('batch_delay', 1)
+        self.opts = batch_get_opts(
+            clear_load.pop('tgt'),
+            clear_load.pop('fun'),
+            clear_load['kwargs'].pop('batch'),
+            self.local.opts,
+            **clear_load)
+        self.eauth = batch_get_eauth(clear_load['kwargs'])
+        self.metadata = clear_load['kwargs'].get('metadata', {})
+        self.minions = set()
+        self.down_minions = set()
+        self.timedout_minions = set()
+        self.done_minions = set()
+        self.active = set()
+        self.initialized = False
+        self.ping_jid = jid_gen()
+        self.batch_jid = jid_gen()
+        self.find_job_jid = jid_gen()
+        self.find_job_returned = set()
+        self.event = salt.utils.event.get_event(
+            'master',
+            self.opts['sock_dir'],
+            self.opts['transport'],
+            opts=self.opts,
+            listen=True,
+            io_loop=ioloop,
+            keep_loop=True)
+
+    def __set_event_handler(self):
+        ping_return_pattern = 'salt/job/{0}/ret/*'.format(self.ping_jid)
+        batch_return_pattern = 'salt/job/{0}/ret/*'.format(self.batch_jid)
+        find_job_return_pattern = 'salt/job/{0}/ret/*'.format(self.find_job_jid)
+        self.event.subscribe(ping_return_pattern, match_type='glob')
+        self.event.subscribe(batch_return_pattern, match_type='glob')
+        self.event.subscribe(find_job_return_pattern, match_type='glob')
+        self.event.patterns = {
+            (ping_return_pattern, 'ping_return'),
+            (batch_return_pattern, 'batch_run'),
+            (find_job_return_pattern, 'find_job_return')
+        }
+        self.event.set_event_handler(self.__event_handler)
+
+    def __event_handler(self, raw):
+        if not self.event:
+            return
+        mtag, data = self.event.unpack(raw, self.event.serial)
+        for (pattern, op) in self.event.patterns:
+            if fnmatch.fnmatch(mtag, pattern):
+                minion = data['id']
+                if op == 'ping_return':
+                    self.minions.add(minion)
+                    self.down_minions.remove(minion)
+                    if not self.down_minions:
+                        self.event.io_loop.spawn_callback(self.start_batch)
+                elif op == 'find_job_return':
+                    self.find_job_returned.add(minion)
+                elif op == 'batch_run':
+                    if minion in self.active:
+                        self.active.remove(minion)
+                        self.done_minions.add(minion)
+                        # call later so that we maybe gather more returns
+                        self.event.io_loop.call_later(self.batch_delay, self.schedule_next)
+
+        if self.initialized and self.done_minions == self.minions.difference(self.timedout_minions):
+            self.end_batch()
+
+    def _get_next(self):
+        to_run = self.minions.difference(
+            self.done_minions).difference(
+            self.active).difference(
+            self.timedout_minions)
+        next_batch_size = min(
+            len(to_run),                   # partial batch (all left)
+            self.batch_size - len(self.active)  # full batch or available slots
+        )
+        return set(list(to_run)[:next_batch_size])
+
+    @tornado.gen.coroutine
+    def check_find_job(self, minions):
+        did_not_return = minions.difference(self.find_job_returned)
+        if did_not_return:
+            for minion in did_not_return:
+                if minion in self.find_job_returned:
+                    self.find_job_returned.remove(minion)
+                if minion in self.active:
+                    self.active.remove(minion)
+                self.timedout_minions.add(minion)
+        running = minions.difference(did_not_return).difference(self.done_minions).difference(self.timedout_minions)
+        if running:
+            self.event.io_loop.add_callback(self.find_job, running)
+
+    @tornado.gen.coroutine
+    def find_job(self, minions):
+        not_done = minions.difference(self.done_minions)
+        ping_return = yield self.local.run_job_async(
+            not_done,
+            'saltutil.find_job',
+            [self.batch_jid],
+            'list',
+            gather_job_timeout=self.opts['gather_job_timeout'],
+            jid=self.find_job_jid,
+            **self.eauth)
+        self.event.io_loop.call_later(
+            self.opts['gather_job_timeout'],
+            self.check_find_job,
+            not_done)
+
+    @tornado.gen.coroutine
+    def start(self):
+        self.__set_event_handler()
+        #start batching even if not all minions respond to ping
+        self.event.io_loop.call_later(
+            self.batch_presence_ping_timeout or self.opts['gather_job_timeout'],
+            self.start_batch)
+        ping_return = yield self.local.run_job_async(
+            self.opts['tgt'],
+            'test.ping',
+            [],
+            self.opts.get(
+                'selected_target_option',
+                self.opts.get('tgt_type', 'glob')
+            ),
+            gather_job_timeout=self.opts['gather_job_timeout'],
+            jid=self.ping_jid,
+            metadata=self.metadata,
+            **self.eauth)
+        self.down_minions = set(ping_return['minions'])
+
+    @tornado.gen.coroutine
+    def start_batch(self):
+        if not self.initialized:
+            self.batch_size = get_bnum(self.opts, self.minions, True)
+            self.initialized = True
+            data = {
+                "available_minions": self.minions,
+                "down_minions": self.down_minions,
+                "metadata": self.metadata
+            }
+            self.event.fire_event(data, "salt/batch/{0}/start".format(self.batch_jid))
+            yield self.schedule_next()
+
+    def end_batch(self):
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.down_minions,
+            "done_minions": self.done_minions,
+            "timedout_minions": self.timedout_minions,
+            "metadata": self.metadata
+        }
+        self.event.fire_event(data, "salt/batch/{0}/done".format(self.batch_jid))
+        self.event.remove_event_handler(self.__event_handler)
+
+    @tornado.gen.coroutine
+    def schedule_next(self):
+        next_batch = self._get_next()
+        if next_batch:
+            yield self.local.run_job_async(
+                next_batch,
+                self.opts['fun'],
+                self.opts['arg'],
+                'list',
+                raw=self.opts.get('raw', False),
+                ret=self.opts.get('return', ''),
+                gather_job_timeout=self.opts['gather_job_timeout'],
+                jid=self.batch_jid,
+                metadata=self.metadata)
+            self.event.io_loop.call_later(self.opts['timeout'], self.find_job, set(next_batch))
+            self.active = self.active.union(next_batch)
diff --git a/salt/client/__init__.py b/salt/client/__init__.py
index 9f0903c7f0..8b37422cbf 100644
--- a/salt/client/__init__.py
+++ b/salt/client/__init__.py
@@ -531,45 +531,14 @@ class LocalClient(object):
             {'dave': {...}}
             {'stewart': {...}}
         '''
-        # We need to re-import salt.utils.args here
-        # even though it has already been imported.
-        # when cmd_batch is called via the NetAPI
-        # the module is unavailable.
-        import salt.utils.args
-
         # Late import - not used anywhere else in this file
         import salt.cli.batch
+        opts = salt.cli.batch.batch_get_opts(
+            tgt, fun, batch, self.opts,
+            arg=arg, tgt_type=tgt_type, ret=ret, kwarg=kwarg, **kwargs)
+
+        eauth = salt.cli.batch.batch_get_eauth(kwargs)
 
-        arg = salt.utils.args.condition_input(arg, kwarg)
-        opts = {'tgt': tgt,
-                'fun': fun,
-                'arg': arg,
-                'tgt_type': tgt_type,
-                'ret': ret,
-                'batch': batch,
-                'failhard': kwargs.get('failhard', False),
-                'raw': kwargs.get('raw', False)}
-
-        if 'timeout' in kwargs:
-            opts['timeout'] = kwargs['timeout']
-        if 'gather_job_timeout' in kwargs:
-            opts['gather_job_timeout'] = kwargs['gather_job_timeout']
-        if 'batch_wait' in kwargs:
-            opts['batch_wait'] = int(kwargs['batch_wait'])
-
-        eauth = {}
-        if 'eauth' in kwargs:
-            eauth['eauth'] = kwargs.pop('eauth')
-        if 'username' in kwargs:
-            eauth['username'] = kwargs.pop('username')
-        if 'password' in kwargs:
-            eauth['password'] = kwargs.pop('password')
-        if 'token' in kwargs:
-            eauth['token'] = kwargs.pop('token')
-
-        for key, val in six.iteritems(self.opts):
-            if key not in opts:
-                opts[key] = val
         batch = salt.cli.batch.Batch(opts, eauth=eauth, quiet=True)
         for ret in batch.run():
             yield ret
@@ -1732,7 +1701,8 @@ class LocalClient(object):
             if listen and not self.event.connect_pub(timeout=timeout):
                 raise SaltReqTimeoutError()
             payload = channel.send(payload_kwargs, timeout=timeout)
-        except SaltReqTimeoutError:
+        except SaltReqTimeoutError as err:
+            log.error(err)
             raise SaltReqTimeoutError(
                 'Salt request timed out. The master is not responding. You '
                 'may need to run your command with `--async` in order to '
diff --git a/salt/master.py b/salt/master.py
index 6881aae137..f08c126280 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -32,6 +32,7 @@ import tornado.gen  # pylint: disable=F0401
 
 # Import salt libs
 import salt.crypt
+import salt.cli.batch_async
 import salt.client
 import salt.client.ssh.client
 import salt.exceptions
@@ -2039,6 +2040,27 @@ class ClearFuncs(object):
             return False
         return self.loadauth.get_tok(clear_load['token'])
 
+    def publish_batch(self, clear_load, minions, missing):
+        batch_load = {}
+        batch_load.update(clear_load)
+        import salt.cli.batch_async
+        batch = salt.cli.batch_async.BatchAsync(
+            self.local.opts,
+            functools.partial(self._prep_jid, clear_load, {}),
+            batch_load
+        )
+        ioloop = tornado.ioloop.IOLoop.current()
+        ioloop.add_callback(batch.start)
+
+        return {
+            'enc': 'clear',
+            'load': {
+                'jid': batch.batch_jid,
+                'minions': minions,
+                'missing': missing
+            }
+        }
+
     def publish(self, clear_load):
         '''
         This method sends out publications to the minions, it can only be used
@@ -2130,6 +2152,9 @@ class ClearFuncs(object):
                         'error': 'Master could not resolve minions for target {0}'.format(clear_load['tgt'])
                     }
                 }
+        if extra.get('batch', None):
+            return self.publish_batch(clear_load, minions, missing)
+
         jid = self._prep_jid(clear_load, extra)
         if jid is None:
             return {'enc': 'clear',
diff --git a/salt/netapi/__init__.py b/salt/netapi/__init__.py
index 95f6384889..43b6e943a7 100644
--- a/salt/netapi/__init__.py
+++ b/salt/netapi/__init__.py
@@ -88,7 +88,8 @@ class NetapiClient(object):
         :return: job ID
         '''
         local = salt.client.get_local_client(mopts=self.opts)
-        return local.run_job(*args, **kwargs)
+        ret = local.run_job(*args, **kwargs)
+        return ret
 
     def local(self, *args, **kwargs):
         '''
diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py
index 40a172991d..8235f104ef 100644
--- a/salt/transport/ipc.py
+++ b/salt/transport/ipc.py
@@ -669,6 +669,8 @@ class IPCMessageSubscriber(IPCClient):
         self._sync_ioloop_running = False
         self.saved_data = []
         self._sync_read_in_progress = Semaphore()
+        self.callbacks = set()
+        self.reading = False
 
     @tornado.gen.coroutine
     def _read_sync(self, timeout):
@@ -756,6 +758,7 @@ class IPCMessageSubscriber(IPCClient):
         while not self.stream.closed():
             try:
                 self._read_stream_future = self.stream.read_bytes(4096, partial=True)
+                self.reading = True
                 wire_bytes = yield self._read_stream_future
                 self._read_stream_future = None
                 self.unpacker.feed(wire_bytes)
@@ -768,8 +771,12 @@ class IPCMessageSubscriber(IPCClient):
             except Exception as exc:
                 log.error('Exception occurred while Subscriber handling stream: %s', exc)
 
+    def __run_callbacks(self, raw):
+        for callback in self.callbacks:
+            self.io_loop.spawn_callback(callback, raw)
+
     @tornado.gen.coroutine
-    def read_async(self, callback):
+    def read_async(self):
         '''
         Asynchronously read messages and invoke a callback when they are ready.
 
@@ -784,7 +791,7 @@ class IPCMessageSubscriber(IPCClient):
             except Exception as exc:
                 log.error('Exception occurred while Subscriber connecting: %s', exc)
                 yield tornado.gen.sleep(1)
-        yield self._read_async(callback)
+        yield self._read_async(self.__run_callbacks)
 
     def close(self):
         '''
diff --git a/salt/utils/event.py b/salt/utils/event.py
index 296a296084..d2700bd2a0 100644
--- a/salt/utils/event.py
+++ b/salt/utils/event.py
@@ -863,6 +863,10 @@ class SaltEvent(object):
                     # Minion fired a bad retcode, fire an event
                     self._fire_ret_load_specific_fun(load)
 
+    def remove_event_handler(self, event_handler):
+        if event_handler in self.subscriber.callbacks:
+            self.subscriber.callbacks.remove(event_handler)
+
     def set_event_handler(self, event_handler):
         '''
         Invoke the event_handler callback each time an event arrives.
@@ -871,8 +875,11 @@ class SaltEvent(object):
 
         if not self.cpub:
             self.connect_pub()
-        # This will handle reconnects
-        return self.subscriber.read_async(event_handler)
+
+        self.subscriber.callbacks.add(event_handler)
+        if not self.subscriber.reading:
+            # This will handle reconnects
+            self.subscriber.read_async()
 
     def __del__(self):
         # skip exceptions in destroy-- since destroy() doesn't cover interpreter
diff --git a/tests/unit/cli/test_batch_async.py b/tests/unit/cli/test_batch_async.py
new file mode 100644
index 0000000000..f65b6a06c3
--- /dev/null
+++ b/tests/unit/cli/test_batch_async.py
@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+
+# Import Salt Libs
+from salt.cli.batch_async import BatchAsync
+
+import tornado
+from tornado.testing import AsyncTestCase
+from tests.support.unit import skipIf, TestCase
+from tests.support.mock import MagicMock, patch, NO_MOCK, NO_MOCK_REASON
+
+
+@skipIf(NO_MOCK, NO_MOCK_REASON)
+class AsyncBatchTestCase(AsyncTestCase, TestCase):
+
+    def setUp(self):
+        self.io_loop = self.get_new_ioloop()
+        opts = {'batch': '1',
+                'conf_file': {},
+                'tgt': '*',
+                'timeout': 5,
+                'gather_job_timeout': 5,
+                'batch_presence_ping_timeout': 1,
+                'transport': None,
+                'sock_dir': ''}
+
+        with patch('salt.client.get_local_client', MagicMock(return_value=MagicMock())):
+            with patch('salt.cli.batch_async.batch_get_opts',
+                MagicMock(return_value=opts)
+            ):
+                self.batch = BatchAsync(
+                    opts,
+                    MagicMock(side_effect=['1234', '1235', '1236']),
+                    {
+                        'tgt': '',
+                        'fun': '',
+                        'kwargs': {
+                            'batch': '',
+                            'batch_presence_ping_timeout': 1
+                        }
+                    })
+
+    def test_ping_jid(self):
+        self.assertEqual(self.batch.ping_jid, '1234')
+
+    def test_batch_jid(self):
+        self.assertEqual(self.batch.batch_jid, '1235')
+
+    def test_find_job_jid(self):
+        self.assertEqual(self.batch.find_job_jid, '1236')
+
+    def test_batch_size(self):
+        '''
+        Tests passing batch value as a number
+        '''
+        self.batch.opts = {'batch': '2', 'timeout': 5}
+        self.batch.minions = set(['foo', 'bar'])
+        self.batch.start_batch()
+        self.assertEqual(self.batch.batch_size, 2)
+
+    @tornado.testing.gen_test
+    def test_batch_start_on_batch_presence_ping_timeout(self):
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({'minions': ['foo', 'bar']})
+        self.batch.local.run_job_async.return_value = future
+        ret = self.batch.start()
+        # assert start_batch is called later with batch_presence_ping_timeout as param
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.batch_presence_ping_timeout, self.batch.start_batch))
+        # assert test.ping called
+        self.assertEqual(
+            self.batch.local.run_job_async.call_args[0],
+            ('*', 'test.ping', [], 'glob')
+        )
+        # assert down_minions == all minions matched by tgt
+        self.assertEqual(self.batch.down_minions, set(['foo', 'bar']))
+
+    @tornado.testing.gen_test
+    def test_batch_start_on_gather_job_timeout(self):
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({'minions': ['foo', 'bar']})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.batch_presence_ping_timeout = None
+        ret = self.batch.start()
+        # assert start_batch is called later with gather_job_timeout as param
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts['gather_job_timeout'], self.batch.start_batch))
+
+    def test_batch_fire_start_event(self):
+        self.batch.minions = set(['foo', 'bar'])
+        self.batch.opts = {'batch': '2', 'timeout': 5}
+        self.batch.event = MagicMock()
+        self.batch.metadata = {'mykey': 'myvalue'}
+        self.batch.start_batch()
+        self.assertEqual(
+            self.batch.event.fire_event.call_args[0],
+            (
+                {
+                    'available_minions': set(['foo', 'bar']),
+                    'down_minions': set(),
+                    'metadata': self.batch.metadata
+                },
+                "salt/batch/1235/start"
+            )
+        )
+
+    @tornado.testing.gen_test
+    def test_start_batch_calls_next(self):
+        self.batch.schedule_next = MagicMock(return_value=MagicMock())
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result(None)
+        self.batch.schedule_next = MagicMock(return_value=future)
+        self.batch.start_batch()
+        self.assertEqual(self.batch.initialized, True)
+        self.assertEqual(len(self.batch.schedule_next.mock_calls), 1)
+
+    def test_batch_fire_done_event(self):
+        self.batch.minions = set(['foo', 'bar'])
+        self.batch.event = MagicMock()
+        self.batch.metadata = {'mykey': 'myvalue'}
+        self.batch.end_batch()
+        self.assertEqual(
+            self.batch.event.fire_event.call_args[0],
+            (
+                {
+                    'available_minions': set(['foo', 'bar']),
+                    'done_minions': set(),
+                    'down_minions': set(),
+                    'timedout_minions': set(),
+                    'metadata': self.batch.metadata
+                },
+                "salt/batch/1235/done"
+            )
+        )
+        self.assertEqual(
+            len(self.batch.event.remove_event_handler.mock_calls), 1)
+
+    @tornado.testing.gen_test
+    def test_batch_next(self):
+        self.batch.event = MagicMock()
+        self.batch.opts['fun'] = 'my.fun'
+        self.batch.opts['arg'] = []
+        self.batch._get_next = MagicMock(return_value={'foo', 'bar'})
+        self.batch.batch_size = 2
+        future = tornado.gen.Future()
+        future.set_result({'minions': ['foo', 'bar']})
+        self.batch.local.run_job_async.return_value = future
+        ret = self.batch.schedule_next().result()
+        self.assertEqual(
+            self.batch.local.run_job_async.call_args[0],
+            ({'foo', 'bar'}, 'my.fun', [], 'list')
+        )
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts['timeout'], self.batch.find_job, {'foo', 'bar'})
+        )
+        self.assertEqual(self.batch.active, {'bar', 'foo'})
+
+    def test_next_batch(self):
+        self.batch.minions = {'foo', 'bar'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {'foo', 'bar'})
+
+    def test_next_batch_one_done(self):
+        self.batch.minions = {'foo', 'bar'}
+        self.batch.done_minions = {'bar'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {'foo'})
+
+    def test_next_batch_one_done_one_active(self):
+        self.batch.minions = {'foo', 'bar', 'baz'}
+        self.batch.done_minions = {'bar'}
+        self.batch.active = {'baz'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {'foo'})
+
+    def test_next_batch_one_done_one_active_one_timedout(self):
+        self.batch.minions = {'foo', 'bar', 'baz', 'faz'}
+        self.batch.done_minions = {'bar'}
+        self.batch.active = {'baz'}
+        self.batch.timedout_minions = {'faz'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {'foo'})
+
+    def test_next_batch_bigger_size(self):
+        self.batch.minions = {'foo', 'bar'}
+        self.batch.batch_size = 3
+        self.assertEqual(self.batch._get_next(), {'foo', 'bar'})
+
+    def test_next_batch_all_done(self):
+        self.batch.minions = {'foo', 'bar'}
+        self.batch.done_minions = {'foo', 'bar'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), set())
+
+    def test_next_batch_all_active(self):
+        self.batch.minions = {'foo', 'bar'}
+        self.batch.active = {'foo', 'bar'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), set())
+
+    def test_next_batch_all_timedout(self):
+        self.batch.minions = {'foo', 'bar'}
+        self.batch.timedout_minions = {'foo', 'bar'}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), set())
+
+    def test_batch__event_handler_ping_return(self):
+        self.batch.down_minions = {'foo'}
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'})))
+        self.batch.start()
+        self.assertEqual(self.batch.minions, set())
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(self.batch.minions, {'foo'})
+        self.assertEqual(self.batch.done_minions, set())
+
+    def test_batch__event_handler_call_start_batch_when_all_pings_return(self):
+        self.batch.down_minions = {'foo'}
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'})))
+        self.batch.start()
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(
+            self.batch.event.io_loop.spawn_callback.call_args[0],
+            (self.batch.start_batch,))
+
+    def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(self):
+        self.batch.down_minions = {'foo', 'bar'}
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'})))
+        self.batch.start()
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(
+            len(self.batch.event.io_loop.spawn_callback.mock_calls), 0)
+
+    def test_batch__event_handler_batch_run_return(self):
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=('salt/job/1235/ret/foo', {'id': 'foo'})))
+        self.batch.start()
+        self.batch.active = {'foo'}
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(self.batch.active, set())
+        self.assertEqual(self.batch.done_minions, {'foo'})
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.batch_delay, self.batch.schedule_next))
+
+    def test_batch__event_handler_find_job_return(self):
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=('salt/job/1236/ret/foo', {'id': 'foo'})))
+        self.batch.start()
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(self.batch.find_job_returned, {'foo'})
+
+    @tornado.testing.gen_test
+    def test_batch__event_handler_end_batch(self):
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=('salt/job/not-my-jid/ret/foo', {'id': 'foo'})))
+        future = tornado.gen.Future()
+        future.set_result({'minions': ['foo', 'bar', 'baz']})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.start()
+        self.batch.initialized = True
+        self.assertEqual(self.batch.down_minions, {'foo', 'bar', 'baz'})
+        self.batch.end_batch = MagicMock()
+        self.batch.minions = {'foo', 'bar', 'baz'}
+        self.batch.done_minions = {'foo', 'bar'}
+        self.batch.timedout_minions = {'baz'}
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(len(self.batch.end_batch.mock_calls), 1)
+
+    @tornado.testing.gen_test
+    def test_batch_find_job(self):
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.find_job({'foo', 'bar'})
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts['gather_job_timeout'], self.batch.check_find_job, {'foo', 'bar'})
+        )
+
+    @tornado.testing.gen_test
+    def test_batch_find_job_with_done_minions(self):
+        self.batch.done_minions = {'bar'}
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.find_job({'foo', 'bar'})
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts['gather_job_timeout'], self.batch.check_find_job, {'foo'})
+        )
+
+    def test_batch_check_find_job_did_not_return(self):
+        self.batch.event = MagicMock()
+        self.batch.active = {'foo'}
+        self.batch.find_job_returned = set()
+        self.batch.check_find_job({'foo'})
+        self.assertEqual(self.batch.find_job_returned, set())
+        self.assertEqual(self.batch.active, set())
+        self.assertEqual(len(self.batch.event.io_loop.add_callback.mock_calls), 0)
+
+    def test_batch_check_find_job_did_return(self):
+        self.batch.event = MagicMock()
+        self.batch.find_job_returned = {'foo'}
+        self.batch.check_find_job({'foo'})
+        self.assertEqual(
+            self.batch.event.io_loop.add_callback.call_args[0],
+            (self.batch.find_job, {'foo'})
+        )
+
+    def test_batch_check_find_job_multiple_states(self):
+        self.batch.event = MagicMock()
+        # currently running minions
+        self.batch.active = {'foo', 'bar'}
+
+        # minion is running and find_job returns
+        self.batch.find_job_returned = {'foo'}
+
+        # minion started running but find_job did not return
+        self.batch.timedout_minions = {'faz'}
+
+        # minion finished
+        self.batch.done_minions = {'baz'}
+
+        # both not yet done but only 'foo' responded to find_job
+        not_done = {'foo', 'bar'}
+
+        self.batch.check_find_job(not_done)
+
+        # assert 'bar' removed from active
+        self.assertEqual(self.batch.active, {'foo'})
+
+        # assert 'bar' added to timedout_minions
+        self.assertEqual(self.batch.timedout_minions, {'bar', 'faz'})
+
+        # assert 'find_job' schedueled again only for 'foo'
+        self.assertEqual(
+            self.batch.event.io_loop.add_callback.call_args[0],
+            (self.batch.find_job, {'foo'})
+        )
-- 
2.20.1