File async-batch-implementation.patch of Package salt.20998
From 78faccbd063b8635550935057b8630262958f669 Mon Sep 17 00:00:00 2001
From: Mihai Dinca <mdinca@suse.de>
Date: Fri, 16 Nov 2018 17:05:29 +0100
Subject: [PATCH] Async batch implementation
Add find_job checks
Check if should close on all events
Make batch_delay a request parameter
Allow multiple event handlers
Use config value for gather_job_timeout when not in payload
Add async batch unittests
Allow metadata to pass
Pass metadata only to batch jobs
Add the metadata to the start/done events
Pass only metadata not all **kwargs
Add separate batch presence_ping timeout
---
 salt/auth/__init__.py              |   2 +
 salt/cli/batch.py                  | 115 +++++++---
 salt/cli/batch_async.py            | 240 +++++++++++++++++++
 salt/client/__init__.py            |  14 ++
 salt/master.py                     |  26 ++-
 salt/netapi/__init__.py            |   3 +-
 salt/transport/ipc.py              |  43 ++--
 salt/utils/event.py                |   8 +-
 tests/unit/cli/test_batch_async.py | 357 +++++++++++++++++++++++++++++
 9 files changed, 741 insertions(+), 67 deletions(-)
 create mode 100644 salt/cli/batch_async.py
 create mode 100644 tests/unit/cli/test_batch_async.py
diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py
index ee1eac7ce4..22c54e8048 100644
--- a/salt/auth/__init__.py
+++ b/salt/auth/__init__.py
@@ -52,6 +52,8 @@ AUTH_INTERNAL_KEYWORDS = frozenset(
         "print_event",
         "raw",
         "yield_pub_data",
+        "batch",
+        "batch_delay",
     ]
 )
 
diff --git a/salt/cli/batch.py b/salt/cli/batch.py
index 155dc734b7..527cffdeb7 100644
--- a/salt/cli/batch.py
+++ b/salt/cli/batch.py
@@ -1,10 +1,7 @@
-# -*- coding: utf-8 -*-
 """
 Execute batch runs
 """
 
-# Import python libs
-from __future__ import absolute_import, print_function, unicode_literals
 
 import copy
 
@@ -17,11 +14,8 @@ from datetime import datetime, timedelta
 import salt.client
 import salt.exceptions
 import salt.output
-
-# Import salt libs
 import salt.utils.stringutils
 
-# Import 3rd-party libs
 # pylint: disable=import-error,no-name-in-module,redefined-builtin
 from salt.ext import six
 from salt.ext.six.moves import range
@@ -29,7 +23,77 @@ from salt.ext.six.moves import range
 log = logging.getLogger(__name__)
 
 
-class Batch(object):
+def get_bnum(opts, minions, quiet):
+    """
+    Return the active number of minions to maintain
+    """
+    partition = lambda x: float(x) / 100.0 * len(minions)
+    try:
+        if isinstance(opts["batch"], str) and "%" in opts["batch"]:
+            res = partition(float(opts["batch"].strip("%")))
+            if res < 1:
+                return int(math.ceil(res))
+            else:
+                return int(res)
+        else:
+            return int(opts["batch"])
+    except ValueError:
+        if not quiet:
+            salt.utils.stringutils.print_cli(
+                "Invalid batch data sent: {}\nData must be in the "
+                "form of %10, 10% or 3".format(opts["batch"])
+            )
+
+
+def batch_get_opts(
+    tgt, fun, batch, parent_opts, arg=(), tgt_type="glob", ret="", kwarg=None, **kwargs
+):
+    # We need to re-import salt.utils.args here
+    # even though it has already been imported.
+    # when cmd_batch is called via the NetAPI
+    # the module is unavailable.
+    import salt.utils.args
+
+    arg = salt.utils.args.condition_input(arg, kwarg)
+    opts = {
+        "tgt": tgt,
+        "fun": fun,
+        "arg": arg,
+        "tgt_type": tgt_type,
+        "ret": ret,
+        "batch": batch,
+        "failhard": kwargs.get("failhard", parent_opts.get("failhard", False)),
+        "raw": kwargs.get("raw", False),
+    }
+
+    if "timeout" in kwargs:
+        opts["timeout"] = kwargs["timeout"]
+    if "gather_job_timeout" in kwargs:
+        opts["gather_job_timeout"] = kwargs["gather_job_timeout"]
+    if "batch_wait" in kwargs:
+        opts["batch_wait"] = int(kwargs["batch_wait"])
+
+    for key, val in parent_opts.items():
+        if key not in opts:
+            opts[key] = val
+
+    return opts
+
+
+def batch_get_eauth(kwargs):
+    eauth = {}
+    if "eauth" in kwargs:
+        eauth["eauth"] = kwargs.pop("eauth")
+    if "username" in kwargs:
+        eauth["username"] = kwargs.pop("username")
+    if "password" in kwargs:
+        eauth["password"] = kwargs.pop("password")
+    if "token" in kwargs:
+        eauth["token"] = kwargs.pop("token")
+    return eauth
+
+
+class Batch:
     """
     Manage the execution of batch runs
     """
@@ -75,7 +139,7 @@ class Batch(object):
                 continue
             else:
                 try:
-                    m = next(six.iterkeys(ret))
+                    m = next(iter(ret.keys()))
                 except StopIteration:
                     if not self.quiet:
                         salt.utils.stringutils.print_cli(
@@ -87,28 +151,7 @@ class Batch(object):
         return (list(fret), ping_gen, nret.difference(fret))
 
     def get_bnum(self):
-        """
-        Return the active number of minions to maintain
-        """
-        partition = lambda x: float(x) / 100.0 * len(self.minions)
-        try:
-            if (
-                isinstance(self.opts["batch"], six.string_types)
-                and "%" in self.opts["batch"]
-            ):
-                res = partition(float(self.opts["batch"].strip("%")))
-                if res < 1:
-                    return int(math.ceil(res))
-                else:
-                    return int(res)
-            else:
-                return int(self.opts["batch"])
-        except ValueError:
-            if not self.quiet:
-                salt.utils.stringutils.print_cli(
-                    "Invalid batch data sent: {0}\nData must be in the "
-                    "form of %10, 10% or 3".format(self.opts["batch"])
-                )
+        return get_bnum(self.opts, self.minions, self.quiet)
 
     def __update_wait(self, wait):
         now = datetime.now()
@@ -161,7 +204,7 @@ class Batch(object):
             # the user we won't be attempting to run a job on them
             for down_minion in self.down_minions:
                 salt.utils.stringutils.print_cli(
-                    "Minion {0} did not respond. No job will be sent.".format(
+                    "Minion {} did not respond. No job will be sent.".format(
                         down_minion
                     )
                 )
@@ -190,7 +233,7 @@ class Batch(object):
             if next_:
                 if not self.quiet:
                     salt.utils.stringutils.print_cli(
-                        "\nExecuting run on {0}\n".format(sorted(next_))
+                        "\nExecuting run on {}\n".format(sorted(next_))
                     )
                 # create a new iterator for this batch of minions
                 return_value = self.opts.get("return", self.opts.get("ret", ""))
@@ -218,7 +261,7 @@ class Batch(object):
             for ping_ret in self.ping_gen:
                 if ping_ret is None:
                     break
-                m = next(six.iterkeys(ping_ret))
+                m = next(iter(ping_ret.keys()))
                 if m not in self.minions:
                     self.minions.append(m)
                     to_run.append(m)
@@ -243,7 +286,7 @@ class Batch(object):
                                 )
                             else:
                                 salt.utils.stringutils.print_cli(
-                                    "minion {0} was already deleted from tracker, probably a duplicate key".format(
+                                    "minion {} was already deleted from tracker, probably a duplicate key".format(
                                         part["id"]
                                     )
                                 )
@@ -254,7 +297,7 @@ class Batch(object):
                                     minion_tracker[queue]["minions"].remove(id)
                                 else:
                                     salt.utils.stringutils.print_cli(
-                                        "minion {0} was already deleted from tracker, probably a duplicate key".format(
+                                        "minion {} was already deleted from tracker, probably a duplicate key".format(
                                             id
                                         )
                                     )
@@ -274,7 +317,7 @@ class Batch(object):
                                 parts[minion] = {}
                                 parts[minion]["ret"] = {}
 
-            for minion, data in six.iteritems(parts):
+            for minion, data in parts.items():
                 if minion in active:
                     active.remove(minion)
                     if bwait:
diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
new file mode 100644
index 0000000000..1557e5105b
--- /dev/null
+++ b/salt/cli/batch_async.py
@@ -0,0 +1,240 @@
+"""
+Execute a job on the targeted minions by using a moving window of fixed size `batch`.
+"""
+
+import fnmatch
+
+# pylint: enable=import-error,no-name-in-module,redefined-builtin
+import logging
+
+import salt.client
+import tornado
+from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum
+
+log = logging.getLogger(__name__)
+
+
+class BatchAsync:
+    """
+    Run a job on the targeted minions by using a moving window of fixed size `batch`.
+
+    ``BatchAsync`` is used to execute a job on the targeted minions by keeping
+    the number of concurrent running minions to the size of `batch` parameter.
+
+    The control parameters are:
+        - batch: number/percentage of concurrent running minions
+        - batch_delay: minimum wait time between batches
+        - batch_presence_ping_timeout: time to wait for presence pings before starting the batch
+        - gather_job_timeout: `find_job` timeout
+        - timeout: time to wait before firing a `find_job`
+
+    When the batch stars, a `start` event is fired:
+         - tag: salt/batch/<batch-jid>/start
+         - data: {
+             "available_minions": self.minions,
+             "down_minions": self.down_minions
+           }
+
+    When the batch ends, an `done` event is fired:
+        - tag: salt/batch/<batch-jid>/done
+        - data: {
+             "available_minions": self.minions,
+             "down_minions": self.down_minions,
+             "done_minions": self.done_minions,
+             "timedout_minions": self.timedout_minions
+         }
+    """
+
+    def __init__(self, parent_opts, jid_gen, clear_load):
+        ioloop = tornado.ioloop.IOLoop.current()
+        self.local = salt.client.get_local_client(parent_opts["conf_file"])
+        if "gather_job_timeout" in clear_load["kwargs"]:
+            clear_load["gather_job_timeout"] = clear_load["kwargs"].pop(
+                "gather_job_timeout"
+            )
+        else:
+            clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"]
+        self.batch_presence_ping_timeout = clear_load["kwargs"].get(
+            "batch_presence_ping_timeout", None
+        )
+        self.batch_delay = clear_load["kwargs"].get("batch_delay", 1)
+        self.opts = batch_get_opts(
+            clear_load.pop("tgt"),
+            clear_load.pop("fun"),
+            clear_load["kwargs"].pop("batch"),
+            self.local.opts,
+            **clear_load
+        )
+        self.eauth = batch_get_eauth(clear_load["kwargs"])
+        self.metadata = clear_load["kwargs"].get("metadata", {})
+        self.minions = set()
+        self.down_minions = set()
+        self.timedout_minions = set()
+        self.done_minions = set()
+        self.active = set()
+        self.initialized = False
+        self.ping_jid = jid_gen()
+        self.batch_jid = jid_gen()
+        self.find_job_jid = jid_gen()
+        self.find_job_returned = set()
+        self.event = salt.utils.event.get_event(
+            "master",
+            self.opts["sock_dir"],
+            self.opts["transport"],
+            opts=self.opts,
+            listen=True,
+            io_loop=ioloop,
+            keep_loop=True,
+        )
+
+    def __set_event_handler(self):
+        ping_return_pattern = "salt/job/{}/ret/*".format(self.ping_jid)
+        batch_return_pattern = "salt/job/{}/ret/*".format(self.batch_jid)
+        find_job_return_pattern = "salt/job/{}/ret/*".format(self.find_job_jid)
+        self.event.subscribe(ping_return_pattern, match_type="glob")
+        self.event.subscribe(batch_return_pattern, match_type="glob")
+        self.event.subscribe(find_job_return_pattern, match_type="glob")
+        self.event.patterns = {
+            (ping_return_pattern, "ping_return"),
+            (batch_return_pattern, "batch_run"),
+            (find_job_return_pattern, "find_job_return"),
+        }
+        self.event.set_event_handler(self.__event_handler)
+
+    def __event_handler(self, raw):
+        if not self.event:
+            return
+        mtag, data = self.event.unpack(raw, self.event.serial)
+        for (pattern, op) in self.event.patterns:
+            if fnmatch.fnmatch(mtag, pattern):
+                minion = data["id"]
+                if op == "ping_return":
+                    self.minions.add(minion)
+                    self.down_minions.remove(minion)
+                    if not self.down_minions:
+                        self.event.io_loop.spawn_callback(self.start_batch)
+                elif op == "find_job_return":
+                    self.find_job_returned.add(minion)
+                elif op == "batch_run":
+                    if minion in self.active:
+                        self.active.remove(minion)
+                        self.done_minions.add(minion)
+                        # call later so that we maybe gather more returns
+                        self.event.io_loop.call_later(
+                            self.batch_delay, self.schedule_next
+                        )
+
+        if self.initialized and self.done_minions == self.minions.difference(
+            self.timedout_minions
+        ):
+            self.end_batch()
+
+    def _get_next(self):
+        to_run = (
+            self.minions.difference(self.done_minions)
+            .difference(self.active)
+            .difference(self.timedout_minions)
+        )
+        next_batch_size = min(
+            len(to_run),  # partial batch (all left)
+            self.batch_size - len(self.active),  # full batch or available slots
+        )
+        return set(list(to_run)[:next_batch_size])
+
+    @tornado.gen.coroutine
+    def check_find_job(self, minions):
+        did_not_return = minions.difference(self.find_job_returned)
+        if did_not_return:
+            for minion in did_not_return:
+                if minion in self.find_job_returned:
+                    self.find_job_returned.remove(minion)
+                if minion in self.active:
+                    self.active.remove(minion)
+                self.timedout_minions.add(minion)
+        running = (
+            minions.difference(did_not_return)
+            .difference(self.done_minions)
+            .difference(self.timedout_minions)
+        )
+        if running:
+            self.event.io_loop.add_callback(self.find_job, running)
+
+    @tornado.gen.coroutine
+    def find_job(self, minions):
+        not_done = minions.difference(self.done_minions)
+        ping_return = yield self.local.run_job_async(
+            not_done,
+            "saltutil.find_job",
+            [self.batch_jid],
+            "list",
+            gather_job_timeout=self.opts["gather_job_timeout"],
+            jid=self.find_job_jid,
+            **self.eauth
+        )
+        self.event.io_loop.call_later(
+            self.opts["gather_job_timeout"], self.check_find_job, not_done
+        )
+
+    @tornado.gen.coroutine
+    def start(self):
+        self.__set_event_handler()
+        # start batching even if not all minions respond to ping
+        self.event.io_loop.call_later(
+            self.batch_presence_ping_timeout or self.opts["gather_job_timeout"],
+            self.start_batch,
+        )
+        ping_return = yield self.local.run_job_async(
+            self.opts["tgt"],
+            "test.ping",
+            [],
+            self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")),
+            gather_job_timeout=self.opts["gather_job_timeout"],
+            jid=self.ping_jid,
+            metadata=self.metadata,
+            **self.eauth
+        )
+        self.down_minions = set(ping_return["minions"])
+
+    @tornado.gen.coroutine
+    def start_batch(self):
+        if not self.initialized:
+            self.batch_size = get_bnum(self.opts, self.minions, True)
+            self.initialized = True
+            data = {
+                "available_minions": self.minions,
+                "down_minions": self.down_minions,
+                "metadata": self.metadata,
+            }
+            self.event.fire_event(data, "salt/batch/{}/start".format(self.batch_jid))
+            yield self.schedule_next()
+
+    def end_batch(self):
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.down_minions,
+            "done_minions": self.done_minions,
+            "timedout_minions": self.timedout_minions,
+            "metadata": self.metadata,
+        }
+        self.event.fire_event(data, "salt/batch/{}/done".format(self.batch_jid))
+        self.event.remove_event_handler(self.__event_handler)
+
+    @tornado.gen.coroutine
+    def schedule_next(self):
+        next_batch = self._get_next()
+        if next_batch:
+            yield self.local.run_job_async(
+                next_batch,
+                self.opts["fun"],
+                self.opts["arg"],
+                "list",
+                raw=self.opts.get("raw", False),
+                ret=self.opts.get("return", ""),
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=self.batch_jid,
+                metadata=self.metadata,
+            )
+            self.event.io_loop.call_later(
+                self.opts["timeout"], self.find_job, set(next_batch)
+            )
+            self.active = self.active.union(next_batch)
diff --git a/salt/client/__init__.py b/salt/client/__init__.py
index 6fab45fcbf..1e9f11df4c 100644
--- a/salt/client/__init__.py
+++ b/salt/client/__init__.py
@@ -543,6 +543,20 @@ class LocalClient:
         # Late import - not used anywhere else in this file
         import salt.cli.batch
 
+        opts = salt.cli.batch.batch_get_opts(
+            tgt,
+            fun,
+            batch,
+            self.opts,
+            arg=arg,
+            tgt_type=tgt_type,
+            ret=ret,
+            kwarg=kwarg,
+            **kwargs
+        )
+
+        eauth = salt.cli.batch.batch_get_eauth(kwargs)
+
         arg = salt.utils.args.condition_input(arg, kwarg)
         opts = {
             "tgt": tgt,
diff --git a/salt/master.py b/salt/master.py
index 1c91c28209..b9bc1a7a67 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -3,7 +3,6 @@ This module contains all of the routines needed to set up a master server, this
 involves preparing the three listeners and the workers needed by the master.
 """
 
-# Import python libs
 
 import collections
 import copy
@@ -21,10 +20,9 @@ import time
 
 import salt.acl
 import salt.auth
+import salt.cli.batch_async
 import salt.client
 import salt.client.ssh.client
-
-# Import salt libs
 import salt.crypt
 import salt.daemons.masterapi
 import salt.defaults.exitcodes
@@ -89,7 +87,6 @@ except ImportError:
     # resource is not available on windows
     HAS_RESOURCE = False
 
-# Import halite libs
 try:
     import halite  # pylint: disable=import-error
 
@@ -2232,6 +2229,24 @@ class ClearFuncs(TransportMethods):
             return False
         return self.loadauth.get_tok(clear_load["token"])
 
+    def publish_batch(self, clear_load, minions, missing):
+        batch_load = {}
+        batch_load.update(clear_load)
+        import salt.cli.batch_async
+
+        batch = salt.cli.batch_async.BatchAsync(
+            self.local.opts,
+            functools.partial(self._prep_jid, clear_load, {}),
+            batch_load,
+        )
+        ioloop = tornado.ioloop.IOLoop.current()
+        ioloop.add_callback(batch.start)
+
+        return {
+            "enc": "clear",
+            "load": {"jid": batch.batch_jid, "minions": minions, "missing": missing},
+        }
+
     def publish(self, clear_load):
         """
         This method sends out publications to the minions, it can only be used
@@ -2349,6 +2364,9 @@ class ClearFuncs(TransportMethods):
                         ),
                     },
                 }
+        if extra.get("batch", None):
+            return self.publish_batch(clear_load, minions, missing)
+
         jid = self._prep_jid(clear_load, extra)
         if jid is None:
             return {"enc": "clear", "load": {"error": "Master failed to assign jid"}}
diff --git a/salt/netapi/__init__.py b/salt/netapi/__init__.py
index 96f57f6c79..dec19b37ef 100644
--- a/salt/netapi/__init__.py
+++ b/salt/netapi/__init__.py
@@ -151,7 +151,8 @@ class NetapiClient:
         :return: job ID
         """
         local = salt.client.get_local_client(mopts=self.opts)
-        return local.run_job(*args, **kwargs)
+        ret = local.run_job(*args, **kwargs)
+        return ret
 
     def local(self, *args, **kwargs):
         """
diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py
index 041718d058..f411907da2 100644
--- a/salt/transport/ipc.py
+++ b/salt/transport/ipc.py
@@ -1,10 +1,7 @@
-# -*- coding: utf-8 -*-
 """
 IPC transport classes
 """
 
-# Import Python libs
-from __future__ import absolute_import, print_function, unicode_literals
 
 import errno
 import logging
@@ -12,15 +9,12 @@ import socket
 import sys
 import time
 
-# Import Tornado libs
 import salt.ext.tornado
 import salt.ext.tornado.concurrent
 import salt.ext.tornado.gen
 import salt.ext.tornado.netutil
 import salt.transport.client
 import salt.transport.frame
-
-# Import Salt libs
 import salt.utils.msgpack
 from salt.ext import six
 from salt.ext.tornado.ioloop import IOLoop
@@ -42,7 +36,7 @@ def future_with_timeout_callback(future):
 
 class FutureWithTimeout(salt.ext.tornado.concurrent.Future):
     def __init__(self, io_loop, future, timeout):
-        super(FutureWithTimeout, self).__init__()
+        super().__init__()
         self.io_loop = io_loop
         self._future = future
         if timeout is not None:
@@ -85,7 +79,7 @@ class FutureWithTimeout(salt.ext.tornado.concurrent.Future):
             self.set_exception(exc)
 
 
-class IPCServer(object):
+class IPCServer:
     """
     A Tornado IPC server very similar to Tornado's TCPServer class
     but using either UNIX domain sockets or TCP sockets
@@ -181,10 +175,7 @@ class IPCServer(object):
             # Under Py2 we still want raw to be set to True
             msgpack_kwargs = {"raw": six.PY2}
         else:
-            if six.PY2:
-                msgpack_kwargs = {"encoding": None}
-            else:
-                msgpack_kwargs = {"encoding": "utf-8"}
+            msgpack_kwargs = {"encoding": "utf-8"}
         unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
         while not stream.closed():
             try:
@@ -200,7 +191,7 @@ class IPCServer(object):
             except StreamClosedError:
                 log.trace("Client disconnected from IPC %s", self.socket_path)
                 break
-            except socket.error as exc:
+            except OSError as exc:
                 # On occasion an exception will occur with
                 # an error code of 0, it's a spurious exception.
                 if exc.errno == 0:
@@ -247,7 +238,7 @@ class IPCServer(object):
     # pylint: enable=W1701
 
 
-class IPCClient(object):
+class IPCClient:
     """
     A Tornado IPC client very similar to Tornado's TCPClient class
     but using either UNIX domain sockets or TCP sockets
@@ -282,10 +273,7 @@ class IPCClient(object):
             # Under Py2 we still want raw to be set to True
             msgpack_kwargs = {"raw": six.PY2}
         else:
-            if six.PY2:
-                msgpack_kwargs = {"encoding": None}
-            else:
-                msgpack_kwargs = {"encoding": "utf-8"}
+            msgpack_kwargs = {"encoding": "utf-8"}
         self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs)
 
     def connected(self):
@@ -385,10 +373,10 @@ class IPCClient(object):
         if self.stream is not None and not self.stream.closed():
             try:
                 self.stream.close()
-            except socket.error as exc:
+            except OSError as exc:
                 if exc.errno != errno.EBADF:
                     # If its not a bad file descriptor error, raise
-                    six.reraise(*sys.exc_info())
+                    raise
 
 
 class IPCMessageClient(IPCClient):
@@ -483,7 +471,7 @@ class IPCMessageServer(IPCServer):
     """
 
 
-class IPCMessagePublisher(object):
+class IPCMessagePublisher:
     """
     A Tornado IPC Publisher similar to Tornado's TCPServer class
     but using either UNIX domain sockets or TCP sockets
@@ -645,10 +633,11 @@ class IPCMessageSubscriber(IPCClient):
     """
 
     def __init__(self, socket_path, io_loop=None):
-        super(IPCMessageSubscriber, self).__init__(socket_path, io_loop=io_loop)
+        super().__init__(socket_path, io_loop=io_loop)
         self._read_stream_future = None
         self._saved_data = []
         self._read_in_progress = Lock()
+        self.callbacks = set()
 
     @salt.ext.tornado.gen.coroutine
     def _read(self, timeout, callback=None):
@@ -725,8 +714,12 @@ class IPCMessageSubscriber(IPCClient):
             return self._saved_data.pop(0)
         return self.io_loop.run_sync(lambda: self._read(timeout))
 
+    def __run_callbacks(self, raw):
+        for callback in self.callbacks:
+            self.io_loop.spawn_callback(callback, raw)
+
     @salt.ext.tornado.gen.coroutine
-    def read_async(self, callback):
+    def read_async(self):
         """
         Asynchronously read messages and invoke a callback when they are ready.
 
@@ -744,7 +737,7 @@ class IPCMessageSubscriber(IPCClient):
             except Exception as exc:  # pylint: disable=broad-except
                 log.error("Exception occurred while Subscriber connecting: %s", exc)
                 yield salt.ext.tornado.gen.sleep(1)
-        yield self._read(None, callback)
+        yield self._read(None, self.__run_callbacks)
 
     def close(self):
         """
@@ -754,7 +747,7 @@ class IPCMessageSubscriber(IPCClient):
         """
         if self._closing:
             return
-        super(IPCMessageSubscriber, self).close()
+        super().close()
         # This will prevent this message from showing up:
         # '[ERROR   ] Future exception was never retrieved:
         # StreamClosedError'
diff --git a/salt/utils/event.py b/salt/utils/event.py
index 6f7edef4e5..ae200f9dfa 100644
--- a/salt/utils/event.py
+++ b/salt/utils/event.py
@@ -867,6 +867,10 @@ class SaltEvent:
                     # Minion fired a bad retcode, fire an event
                     self._fire_ret_load_specific_fun(load)
 
+    def remove_event_handler(self, event_handler):
+        if event_handler in self.subscriber.callbacks:
+            self.subscriber.callbacks.remove(event_handler)
+
     def set_event_handler(self, event_handler):
         """
         Invoke the event_handler callback each time an event arrives.
@@ -875,8 +879,10 @@ class SaltEvent:
 
         if not self.cpub:
             self.connect_pub()
+
+        self.subscriber.callbacks.add(event_handler)
         # This will handle reconnects
-        return self.subscriber.read_async(event_handler)
+        return self.subscriber.read_async()
 
     # pylint: disable=W1701
     def __del__(self):
diff --git a/tests/unit/cli/test_batch_async.py b/tests/unit/cli/test_batch_async.py
new file mode 100644
index 0000000000..3f8626a2dd
--- /dev/null
+++ b/tests/unit/cli/test_batch_async.py
@@ -0,0 +1,357 @@
+import tornado
+from salt.cli.batch_async import BatchAsync
+from tests.support.mock import NO_MOCK, NO_MOCK_REASON, MagicMock, patch
+from tests.support.unit import TestCase, skipIf
+from tornado.testing import AsyncTestCase
+
+
+@skipIf(NO_MOCK, NO_MOCK_REASON)
+class AsyncBatchTestCase(AsyncTestCase, TestCase):
+    def setUp(self):
+        self.io_loop = self.get_new_ioloop()
+        opts = {
+            "batch": "1",
+            "conf_file": {},
+            "tgt": "*",
+            "timeout": 5,
+            "gather_job_timeout": 5,
+            "batch_presence_ping_timeout": 1,
+            "transport": None,
+            "sock_dir": "",
+        }
+
+        with patch("salt.client.get_local_client", MagicMock(return_value=MagicMock())):
+            with patch(
+                "salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)
+            ):
+                self.batch = BatchAsync(
+                    opts,
+                    MagicMock(side_effect=["1234", "1235", "1236"]),
+                    {
+                        "tgt": "",
+                        "fun": "",
+                        "kwargs": {"batch": "", "batch_presence_ping_timeout": 1},
+                    },
+                )
+
+    def test_ping_jid(self):
+        self.assertEqual(self.batch.ping_jid, "1234")
+
+    def test_batch_jid(self):
+        self.assertEqual(self.batch.batch_jid, "1235")
+
+    def test_find_job_jid(self):
+        self.assertEqual(self.batch.find_job_jid, "1236")
+
+    def test_batch_size(self):
+        """
+        Tests passing batch value as a number
+        """
+        self.batch.opts = {"batch": "2", "timeout": 5}
+        self.batch.minions = {"foo", "bar"}
+        self.batch.start_batch()
+        self.assertEqual(self.batch.batch_size, 2)
+
+    @tornado.testing.gen_test
+    def test_batch_start_on_batch_presence_ping_timeout(self):
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({"minions": ["foo", "bar"]})
+        self.batch.local.run_job_async.return_value = future
+        ret = self.batch.start()
+        # assert start_batch is called later with batch_presence_ping_timeout as param
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.batch_presence_ping_timeout, self.batch.start_batch),
+        )
+        # assert test.ping called
+        self.assertEqual(
+            self.batch.local.run_job_async.call_args[0], ("*", "test.ping", [], "glob")
+        )
+        # assert down_minions == all minions matched by tgt
+        self.assertEqual(self.batch.down_minions, {"foo", "bar"})
+
+    @tornado.testing.gen_test
+    def test_batch_start_on_gather_job_timeout(self):
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({"minions": ["foo", "bar"]})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.batch_presence_ping_timeout = None
+        ret = self.batch.start()
+        # assert start_batch is called later with gather_job_timeout as param
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts["gather_job_timeout"], self.batch.start_batch),
+        )
+
+    def test_batch_fire_start_event(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.opts = {"batch": "2", "timeout": 5}
+        self.batch.event = MagicMock()
+        self.batch.metadata = {"mykey": "myvalue"}
+        self.batch.start_batch()
+        self.assertEqual(
+            self.batch.event.fire_event.call_args[0],
+            (
+                {
+                    "available_minions": {"foo", "bar"},
+                    "down_minions": set(),
+                    "metadata": self.batch.metadata,
+                },
+                "salt/batch/1235/start",
+            ),
+        )
+
+    @tornado.testing.gen_test
+    def test_start_batch_calls_next(self):
+        self.batch.schedule_next = MagicMock(return_value=MagicMock())
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result(None)
+        self.batch.schedule_next = MagicMock(return_value=future)
+        self.batch.start_batch()
+        self.assertEqual(self.batch.initialized, True)
+        self.assertEqual(len(self.batch.schedule_next.mock_calls), 1)
+
+    def test_batch_fire_done_event(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.event = MagicMock()
+        self.batch.metadata = {"mykey": "myvalue"}
+        self.batch.end_batch()
+        self.assertEqual(
+            self.batch.event.fire_event.call_args[0],
+            (
+                {
+                    "available_minions": {"foo", "bar"},
+                    "done_minions": set(),
+                    "down_minions": set(),
+                    "timedout_minions": set(),
+                    "metadata": self.batch.metadata,
+                },
+                "salt/batch/1235/done",
+            ),
+        )
+        self.assertEqual(len(self.batch.event.remove_event_handler.mock_calls), 1)
+
+    @tornado.testing.gen_test
+    def test_batch_next(self):
+        self.batch.event = MagicMock()
+        self.batch.opts["fun"] = "my.fun"
+        self.batch.opts["arg"] = []
+        self.batch._get_next = MagicMock(return_value={"foo", "bar"})
+        self.batch.batch_size = 2
+        future = tornado.gen.Future()
+        future.set_result({"minions": ["foo", "bar"]})
+        self.batch.local.run_job_async.return_value = future
+        ret = self.batch.schedule_next().result()
+        self.assertEqual(
+            self.batch.local.run_job_async.call_args[0],
+            ({"foo", "bar"}, "my.fun", [], "list"),
+        )
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts["timeout"], self.batch.find_job, {"foo", "bar"}),
+        )
+        self.assertEqual(self.batch.active, {"bar", "foo"})
+
+    def test_next_batch(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {"foo", "bar"})
+
+    def test_next_batch_one_done(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.done_minions = {"bar"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {"foo"})
+
+    def test_next_batch_one_done_one_active(self):
+        self.batch.minions = {"foo", "bar", "baz"}
+        self.batch.done_minions = {"bar"}
+        self.batch.active = {"baz"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {"foo"})
+
+    def test_next_batch_one_done_one_active_one_timedout(self):
+        self.batch.minions = {"foo", "bar", "baz", "faz"}
+        self.batch.done_minions = {"bar"}
+        self.batch.active = {"baz"}
+        self.batch.timedout_minions = {"faz"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), {"foo"})
+
+    def test_next_batch_bigger_size(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.batch_size = 3
+        self.assertEqual(self.batch._get_next(), {"foo", "bar"})
+
+    def test_next_batch_all_done(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.done_minions = {"foo", "bar"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), set())
+
+    def test_next_batch_all_active(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.active = {"foo", "bar"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), set())
+
+    def test_next_batch_all_timedout(self):
+        self.batch.minions = {"foo", "bar"}
+        self.batch.timedout_minions = {"foo", "bar"}
+        self.batch.batch_size = 2
+        self.assertEqual(self.batch._get_next(), set())
+
+    def test_batch__event_handler_ping_return(self):
+        self.batch.down_minions = {"foo"}
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+        )
+        self.batch.start()
+        self.assertEqual(self.batch.minions, set())
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(self.batch.minions, {"foo"})
+        self.assertEqual(self.batch.done_minions, set())
+
+    def test_batch__event_handler_call_start_batch_when_all_pings_return(self):
+        self.batch.down_minions = {"foo"}
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+        )
+        self.batch.start()
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(
+            self.batch.event.io_loop.spawn_callback.call_args[0],
+            (self.batch.start_batch,),
+        )
+
+    def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(self):
+        self.batch.down_minions = {"foo", "bar"}
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+        )
+        self.batch.start()
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(len(self.batch.event.io_loop.spawn_callback.mock_calls), 0)
+
+    def test_batch__event_handler_batch_run_return(self):
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"}))
+        )
+        self.batch.start()
+        self.batch.active = {"foo"}
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(self.batch.active, set())
+        self.assertEqual(self.batch.done_minions, {"foo"})
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.batch_delay, self.batch.schedule_next),
+        )
+
+    def test_batch__event_handler_find_job_return(self):
+        self.batch.event = MagicMock(
+            unpack=MagicMock(return_value=("salt/job/1236/ret/foo", {"id": "foo"}))
+        )
+        self.batch.start()
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(self.batch.find_job_returned, {"foo"})
+
+    @tornado.testing.gen_test
+    def test_batch__event_handler_end_batch(self):
+        self.batch.event = MagicMock(
+            unpack=MagicMock(
+                return_value=("salt/job/not-my-jid/ret/foo", {"id": "foo"})
+            )
+        )
+        future = tornado.gen.Future()
+        future.set_result({"minions": ["foo", "bar", "baz"]})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.start()
+        self.batch.initialized = True
+        self.assertEqual(self.batch.down_minions, {"foo", "bar", "baz"})
+        self.batch.end_batch = MagicMock()
+        self.batch.minions = {"foo", "bar", "baz"}
+        self.batch.done_minions = {"foo", "bar"}
+        self.batch.timedout_minions = {"baz"}
+        self.batch._BatchAsync__event_handler(MagicMock())
+        self.assertEqual(len(self.batch.end_batch.mock_calls), 1)
+
+    @tornado.testing.gen_test
+    def test_batch_find_job(self):
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.find_job({"foo", "bar"})
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (
+                self.batch.opts["gather_job_timeout"],
+                self.batch.check_find_job,
+                {"foo", "bar"},
+            ),
+        )
+
+    @tornado.testing.gen_test
+    def test_batch_find_job_with_done_minions(self):
+        self.batch.done_minions = {"bar"}
+        self.batch.event = MagicMock()
+        future = tornado.gen.Future()
+        future.set_result({})
+        self.batch.local.run_job_async.return_value = future
+        self.batch.find_job({"foo", "bar"})
+        self.assertEqual(
+            self.batch.event.io_loop.call_later.call_args[0],
+            (self.batch.opts["gather_job_timeout"], self.batch.check_find_job, {"foo"}),
+        )
+
+    def test_batch_check_find_job_did_not_return(self):
+        self.batch.event = MagicMock()
+        self.batch.active = {"foo"}
+        self.batch.find_job_returned = set()
+        self.batch.check_find_job({"foo"})
+        self.assertEqual(self.batch.find_job_returned, set())
+        self.assertEqual(self.batch.active, set())
+        self.assertEqual(len(self.batch.event.io_loop.add_callback.mock_calls), 0)
+
+    def test_batch_check_find_job_did_return(self):
+        self.batch.event = MagicMock()
+        self.batch.find_job_returned = {"foo"}
+        self.batch.check_find_job({"foo"})
+        self.assertEqual(
+            self.batch.event.io_loop.add_callback.call_args[0],
+            (self.batch.find_job, {"foo"}),
+        )
+
+    def test_batch_check_find_job_multiple_states(self):
+        self.batch.event = MagicMock()
+        # currently running minions
+        self.batch.active = {"foo", "bar"}
+
+        # minion is running and find_job returns
+        self.batch.find_job_returned = {"foo"}
+
+        # minion started running but find_job did not return
+        self.batch.timedout_minions = {"faz"}
+
+        # minion finished
+        self.batch.done_minions = {"baz"}
+
+        # both not yet done but only 'foo' responded to find_job
+        not_done = {"foo", "bar"}
+
+        self.batch.check_find_job(not_done)
+
+        # assert 'bar' removed from active
+        self.assertEqual(self.batch.active, {"foo"})
+
+        # assert 'bar' added to timedout_minions
+        self.assertEqual(self.batch.timedout_minions, {"bar", "faz"})
+
+        # assert 'find_job' schedueled again only for 'foo'
+        self.assertEqual(
+            self.batch.event.io_loop.add_callback.call_args[0],
+            (self.batch.find_job, {"foo"}),
+        )
-- 
2.29.2