File async-batch-implementation-60269.patch of Package salt

From 0cc69b4d24132195af0d33ac57463d6aea1298ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pablo=20Su=C3=A1rez=20Hern=C3=A1ndez?=
 <psuarezhernandez@suse.com>
Date: Fri, 25 Oct 2024 16:21:22 +0100
Subject: [PATCH] Async batch implementation #60269

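Add an asynchronous batch implementation. A new master-side helper,
`salt.cli.batch_async.BatchAsync`, executes a job on the targeted minions
while keeping the number of concurrently running minions within the
requested `batch` size. It first performs a presence ping (honouring
`batch_presence_ping_timeout` and `batch_presence_ping_gather_job_timeout`),
waits `batch_delay` between batches, and fires
`salt/batch/<batch-jid>/start` and `salt/batch/<batch-jid>/done` events
reporting available, down, done and timed-out minions. The shared helpers
`get_bnum`, `batch_get_opts` and `batch_get_eauth` are factored out of
`salt.cli.batch` so that `LocalClient.cmd_batch` and `BatchAsync` reuse the
same option handling, and `ClearFuncs.publish_batch` wires the new class
into the master.

As an illustration only (not introduced by this patch, and assuming the
default master configuration path), the progress events fired by
`BatchAsync` can be observed from the master event bus roughly like this:

    import salt.config
    import salt.utils.event

    opts = salt.config.client_config("/etc/salt/master")
    # Listen on the master event bus and print the batch progress events
    # (tags salt/batch/<jid>/start and salt/batch/<jid>/done).
    event_bus = salt.utils.event.get_event("master", opts=opts, listen=True)
    try:
        while True:
            evt = event_bus.get_event(tag="salt/batch/", wait=60, full=True)
            if evt is None:
                break
            print(evt["tag"], evt["data"].get("down_minions"))
    finally:
        event_bus.destroy()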
---
 changelog/60269.added.md                      |   1 +
 salt/auth/__init__.py                         |   2 +
 salt/cli/batch.py                             | 166 ++++++--
 salt/cli/batch_async.py                       | 332 +++++++++++++++
 salt/client/__init__.py                       |  44 +-
 salt/crypt.py                                 |   9 +-
 salt/master.py                                |  23 +
 salt/transport/ipc.py                         |  10 +-
 salt/utils/event.py                           |   9 +
 .../functional/channel/test_req_channel.py    |  34 ++
 tests/pytests/unit/cli/test_batch.py          |  66 ++-
 tests/pytests/unit/cli/test_batch_async.py    | 392 ++++++++++++++++++
 tests/unit/transport/test_ipc.py              |  12 +-
 13 files changed, 1033 insertions(+), 67 deletions(-)
 create mode 100644 changelog/60269.added.md
 create mode 100644 salt/cli/batch_async.py
 create mode 100644 tests/pytests/unit/cli/test_batch_async.py

diff --git a/changelog/60269.added.md b/changelog/60269.added.md
new file mode 100644
index 0000000000..0360f161ba
--- /dev/null
+++ b/changelog/60269.added.md
@@ -0,0 +1 @@
+Async batch implementation with presence ping
diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py
index a8e1470001..cb39365c0f 100644
--- a/salt/auth/__init__.py
+++ b/salt/auth/__init__.py
@@ -49,6 +49,8 @@ AUTH_INTERNAL_KEYWORDS = frozenset(
         "print_event",
         "raw",
         "yield_pub_data",
+        "batch",
+        "batch_delay",
     ]
 )
 
diff --git a/salt/cli/batch.py b/salt/cli/batch.py
index 2e43b0ee22..bd024877c2 100644
--- a/salt/cli/batch.py
+++ b/salt/cli/batch.py
@@ -13,9 +13,142 @@ import salt.exceptions
 import salt.output
 import salt.utils.stringutils
 
+# pylint: disable=import-error,no-name-in-module,redefined-builtin
+
 log = logging.getLogger(__name__)
 
 
+def get_bnum(opts, minions, quiet):
+    """
+    .. versionadded:: 3007.0
+
+    Return the active number of minions to maintain
+
+    :param dict opts:
+        The salt options dictionary.
+
+    :param minions:
+        The list of minions on which to base the calculation.
+
+    :param boolean quiet:
+        Suppress the output to the CLI.
+
+    """
+    partition = lambda x: float(x) / 100.0 * len(minions)
+    try:
+        if isinstance(opts["batch"], str) and "%" in opts["batch"]:
+            res = partition(float(opts["batch"].strip("%")))
+            if res < 1:
+                return int(math.ceil(res))
+            else:
+                return int(res)
+        else:
+            return int(opts["batch"])
+    except ValueError:
+        if not quiet:
+            salt.utils.stringutils.print_cli(
+                "Invalid batch data sent: {}\nData must be in the "
+                "form of %10, 10% or 3".format(opts["batch"])
+            )
+
+
+def batch_get_opts(
+    tgt, fun, batch, parent_opts, arg=(), tgt_type="glob", ret="", kwarg=None, **kwargs
+):
+    """
+    .. versionadded:: 3007.0
+
+    Return the dictionary with batch options populated
+
+    :param tgt:
+        Which minions to target for the execution.
+
+    :param str fun:
+        The function to run.
+
+    :param batch:
+        The batch size.
+
+    :param dict parent_opts:
+        The salt options dictionary.
+
+    :param list arg:
+        The arguments to place in the ``arg`` key of the resulting dictionary.
+
+    :param str tgt_type:
+        Default ``glob``. Target type to use with ``tgt``.
+
+    :param ret:
+        The ``ret`` parameter to place in the resulting dictionary.
+
+    :param dict kwarg:
+        Extra keyword arguments to place in the ``arg`` key of the resulting dictionary.
+
+    :param dict kwargs:
+        Extra keyword arguments.
+
+    """
+    # We need to re-import salt.utils.args here,
+    # even though it has already been imported,
+    # because when cmd_batch is called via the NetAPI
+    # the module is unavailable.
+    import salt.utils.args
+
+    arg = salt.utils.args.condition_input(arg, kwarg)
+    opts = {
+        "tgt": tgt,
+        "fun": fun,
+        "arg": arg,
+        "tgt_type": tgt_type,
+        "ret": ret,
+        "batch": batch,
+        "failhard": kwargs.get("failhard", parent_opts.get("failhard", False)),
+        "raw": kwargs.get("raw", False),
+    }
+
+    if "timeout" in kwargs:
+        opts["timeout"] = kwargs["timeout"]
+    if "gather_job_timeout" in kwargs:
+        opts["gather_job_timeout"] = kwargs["gather_job_timeout"]
+    if "batch_wait" in kwargs:
+        opts["batch_wait"] = int(kwargs["batch_wait"])
+
+    for key, val in parent_opts.items():
+        if key not in opts:
+            opts[key] = val
+
+    opts["batch_presence_ping_timeout"] = kwargs.get(
+        "batch_presence_ping_timeout", opts["timeout"]
+    )
+    opts["batch_presence_ping_gather_job_timeout"] = kwargs.get(
+        "batch_presence_ping_gather_job_timeout", opts["gather_job_timeout"]
+    )
+
+    return opts
+
+
+def batch_get_eauth(kwargs):
+    """
+    .. versionadded:: 3007.0
+
+    Return the dictionary with eauth information
+
+    :param dict kwargs:
+        Keyword arguments to extract eauth data from.
+
+    """
+    eauth = {}
+    if "eauth" in kwargs:
+        eauth["eauth"] = kwargs.pop("eauth")
+    if "username" in kwargs:
+        eauth["username"] = kwargs.pop("username")
+    if "password" in kwargs:
+        eauth["password"] = kwargs.pop("password")
+    if "token" in kwargs:
+        eauth["token"] = kwargs.pop("token")
+    return eauth
+
+
 class Batch:
     """
     Manage the execution of batch runs
@@ -39,6 +172,7 @@ class Batch:
         self.pub_kwargs = eauth if eauth else {}
         self.quiet = quiet
         self.options = _parser
+        self.minions = set()
         # Passing listen True to local client will prevent it from purging
         # cahced events while iterating over the batches.
         self.local = salt.client.get_local_client(opts["conf_file"], listen=True)
@@ -51,7 +185,7 @@ class Batch:
             self.opts["tgt"],
             "test.ping",
             [],
-            self.opts["timeout"],
+            self.opts.get("batch_presence_ping_timeout", self.opts["timeout"]),
         ]
 
         selected_target_option = self.opts.get("selected_target_option", None)
@@ -62,7 +196,12 @@ class Batch:
 
         self.pub_kwargs["yield_pub_data"] = True
         ping_gen = self.local.cmd_iter(
-            *args, gather_job_timeout=self.opts["gather_job_timeout"], **self.pub_kwargs
+            *args,
+            gather_job_timeout=self.opts.get(
+                "batch_presence_ping_gather_job_timeout",
+                self.opts["gather_job_timeout"],
+            ),
+            **self.pub_kwargs,
         )
 
         # Broadcast to targets
@@ -87,28 +226,7 @@ class Batch:
         return (list(fret), ping_gen, nret.difference(fret))
 
     def get_bnum(self):
-        """
-        Return the active number of minions to maintain
-        """
-
-        def partition(x):
-            return float(x) / 100.0 * len(self.minions)
-
-        try:
-            if isinstance(self.opts["batch"], str) and "%" in self.opts["batch"]:
-                res = partition(float(self.opts["batch"].strip("%")))
-                if res < 1:
-                    return int(math.ceil(res))
-                else:
-                    return int(res)
-            else:
-                return int(self.opts["batch"])
-        except ValueError:
-            if not self.quiet:
-                salt.utils.stringutils.print_cli(
-                    "Invalid batch data sent: {}\nData must be in the "
-                    "form of %10, 10% or 3".format(self.opts["batch"])
-                )
+        return get_bnum(self.opts, self.minions, self.quiet)
 
     def __update_wait(self, wait):
         now = datetime.now()
diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
new file mode 100644
index 0000000000..ddbd16870a
--- /dev/null
+++ b/salt/cli/batch_async.py
@@ -0,0 +1,332 @@
+"""
+Execute a job on the targeted minions by using a moving window of fixed size `batch`.
+"""
+
+# pylint: enable=import-error,no-name-in-module,redefined-builtin
+import logging
+
+import tornado
+
+import salt.client
+from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum
+
+log = logging.getLogger(__name__)
+
+
+class BatchAsync:
+    """
+    Run a job on the targeted minions by using a moving window of fixed size `batch`.
+
+    ``BatchAsync`` is used to execute a job on the targeted minions by keeping
+    the number of concurrently running minions within the size of the `batch` parameter.
+
+    The control parameters are:
+        - batch: number/percentage of concurrently running minions
+        - batch_delay: minimum wait time between batches
+        - batch_presence_ping_timeout: time to wait for presence pings before starting the batch
+        - gather_job_timeout: `find_job` timeout
+        - timeout: time to wait before firing a `find_job`
+
+    When the batch starts, a `start` event is fired:
+         - tag: salt/batch/<batch-jid>/start
+         - data: {
+             "available_minions": self.minions,
+             "down_minions": targeted_minions - presence_ping_minions
+           }
+
+    When the batch ends, a `done` event is fired:
+        - tag: salt/batch/<batch-jid>/done
+        - data: {
+             "available_minions": self.minions,
+             "down_minions": targeted_minions - presence_ping_minions
+             "done_minions": self.done_minions,
+             "timedout_minions": self.timedout_minions
+         }
+    """
+
+    def __init__(self, parent_opts, jid_gen, clear_load):
+        ioloop = tornado.ioloop.IOLoop.current()
+        self.local = salt.client.get_local_client(
+            parent_opts["conf_file"], io_loop=ioloop
+        )
+        if "gather_job_timeout" in clear_load["kwargs"]:
+            clear_load["gather_job_timeout"] = clear_load["kwargs"].pop(
+                "gather_job_timeout"
+            )
+        else:
+            clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"]
+        self.batch_presence_ping_timeout = clear_load["kwargs"].get(
+            "batch_presence_ping_timeout", None
+        )
+        self.batch_delay = clear_load["kwargs"].get("batch_delay", 1)
+        self.opts = batch_get_opts(
+            clear_load.pop("tgt"),
+            clear_load.pop("fun"),
+            clear_load["kwargs"].pop("batch"),
+            self.local.opts,
+            **clear_load,
+        )
+        self.eauth = batch_get_eauth(clear_load["kwargs"])
+        self.metadata = clear_load["kwargs"].get("metadata", {})
+        self.minions = set()
+        self.targeted_minions = set()
+        self.timedout_minions = set()
+        self.done_minions = set()
+        self.active = set()
+        self.initialized = False
+        self.jid_gen = jid_gen
+        self.ping_jid = jid_gen()
+        self.batch_jid = jid_gen()
+        self.find_job_jid = jid_gen()
+        self.find_job_returned = set()
+        self.ended = False
+        self.event = salt.utils.event.get_event(
+            "master",
+            sock_dir=self.opts["sock_dir"],
+            opts=self.opts,
+            listen=True,
+            io_loop=ioloop,
+            keep_loop=True,
+        )
+        self.scheduled = False
+        self.patterns = set()
+
+    def __set_event_handler(self):
+        ping_return_pattern = f"salt/job/{self.ping_jid}/ret/*"
+        batch_return_pattern = f"salt/job/{self.batch_jid}/ret/*"
+        self.event.subscribe(ping_return_pattern, match_type="glob")
+        self.event.subscribe(batch_return_pattern, match_type="glob")
+        self.patterns = {
+            (ping_return_pattern, "ping_return"),
+            (batch_return_pattern, "batch_run"),
+        }
+        self.event.set_event_handler(self.__event_handler)
+
+    def __event_handler(self, raw):
+        if not self.event:
+            return
+        try:
+            mtag, data = self.event.unpack(raw)
+            for pattern, op in self.patterns:
+                if mtag.startswith(pattern[:-1]):
+                    minion = data["id"]
+                    if op == "ping_return":
+                        self.minions.add(minion)
+                        if self.targeted_minions == self.minions:
+                            self.event.io_loop.spawn_callback(self.start_batch)
+                    elif op == "find_job_return":
+                        if data.get("return", None):
+                            self.find_job_returned.add(minion)
+                    elif op == "batch_run":
+                        if minion in self.active:
+                            self.active.remove(minion)
+                            self.done_minions.add(minion)
+                            self.event.io_loop.spawn_callback(self.schedule_next)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error("Exception occured while processing event: %s", ex, exc_info=True)
+
+    def _get_next(self):
+        to_run = (
+            self.minions.difference(self.done_minions)
+            .difference(self.active)
+            .difference(self.timedout_minions)
+        )
+        next_batch_size = min(
+            len(to_run),  # partial batch (all left)
+            self.batch_size - len(self.active),  # full batch or available slots
+        )
+        return set(list(to_run)[:next_batch_size])
+
+    @tornado.gen.coroutine
+    def check_find_job(self, batch_minions, jid):
+        """
+        Check if the job with the specified ``jid`` has finished on the minions
+        """
+        if not self.event:
+            return
+        find_job_return_pattern = f"salt/job/{jid}/ret/*"
+        self.event.unsubscribe(find_job_return_pattern, match_type="glob")
+        self.patterns.remove((find_job_return_pattern, "find_job_return"))
+
+        timedout_minions = batch_minions.difference(self.find_job_returned).difference(
+            self.done_minions
+        )
+        self.timedout_minions = self.timedout_minions.union(timedout_minions)
+        self.active = self.active.difference(self.timedout_minions)
+        running = batch_minions.difference(self.done_minions).difference(
+            self.timedout_minions
+        )
+
+        if timedout_minions:
+            self.event.io_loop.spawn_callback(self.schedule_next)
+
+        if self.event and running:
+            self.find_job_returned = self.find_job_returned.difference(running)
+            self.event.io_loop.spawn_callback(self.find_job, running)
+
+    @tornado.gen.coroutine
+    def find_job(self, minions):
+        """
+        Find out whether the job has finished on the minions
+        """
+        if not self.event:
+            return
+        not_done = minions.difference(self.done_minions).difference(
+            self.timedout_minions
+        )
+        if not not_done:
+            return
+        try:
+            jid = self.jid_gen()
+            find_job_return_pattern = f"salt/job/{jid}/ret/*"
+            self.patterns.add((find_job_return_pattern, "find_job_return"))
+            self.event.subscribe(find_job_return_pattern, match_type="glob")
+            ret = yield self.local.run_job_async(
+                not_done,
+                "saltutil.find_job",
+                [self.batch_jid],
+                "list",
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=jid,
+                **self.eauth,
+            )
+            yield tornado.gen.sleep(self.opts["gather_job_timeout"])
+            if self.event:
+                self.event.io_loop.spawn_callback(self.check_find_job, not_done, jid)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured handling batch async: %s. Aborting execution.",
+                ex,
+                exc_info=True,
+            )
+            self.close_safe()
+
+    @tornado.gen.coroutine
+    def start(self):
+        """
+        Start the batch execution
+        """
+        if not self.event:
+            return
+        self.__set_event_handler()
+        ping_return = yield self.local.run_job_async(
+            self.opts["tgt"],
+            "test.ping",
+            [],
+            self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")),
+            gather_job_timeout=self.opts["gather_job_timeout"],
+            jid=self.ping_jid,
+            metadata=self.metadata,
+            **self.eauth,
+        )
+        self.targeted_minions = set(ping_return["minions"])
+        # start batching even if not all minions respond to ping
+        yield tornado.gen.sleep(
+            self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
+        )
+        if self.event:
+            self.event.io_loop.spawn_callback(self.start_batch)
+
+    @tornado.gen.coroutine
+    def start_batch(self):
+        """
+        Start the next iteration of the batch execution
+        """
+        if self.initialized:
+            return
+        self.batch_size = get_bnum(self.opts, self.minions, True)
+        self.initialized = True
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.targeted_minions.difference(self.minions),
+            "metadata": self.metadata,
+        }
+        ret = self.event.fire_event(data, f"salt/batch/{self.batch_jid}/start")
+        if self.event:
+            self.event.io_loop.spawn_callback(self.run_next)
+
+    @tornado.gen.coroutine
+    def end_batch(self):
+        """
+        End the batch and call safe closing
+        """
+        left = self.minions.symmetric_difference(
+            self.done_minions.union(self.timedout_minions)
+        )
+        # Send salt/batch/*/done only if there is nothing to do
+        # and the event has not been sent already
+        if left or self.ended:
+            return
+        self.ended = True
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.targeted_minions.difference(self.minions),
+            "done_minions": self.done_minions,
+            "timedout_minions": self.timedout_minions,
+            "metadata": self.metadata,
+        }
+        self.event.fire_event(data, f"salt/batch/{self.batch_jid}/done")
+
+        # release to the IOLoop to allow the event to be published
+        # before closing batch async execution
+        yield tornado.gen.sleep(1)
+        self.close_safe()
+
+    def close_safe(self):
+        if self.event:
+            for pattern, label in self.patterns:
+                self.event.unsubscribe(pattern, match_type="glob")
+            self.event.remove_event_handler(self.__event_handler)
+            self.event.destroy()
+            self.event = None
+        self.local = None
+        self.ioloop = None
+
+    @tornado.gen.coroutine
+    def schedule_next(self):
+        if self.scheduled:
+            return
+        self.scheduled = True
+        # call later so that we may gather more returns
+        yield tornado.gen.sleep(self.batch_delay)
+        if self.event:
+            self.event.io_loop.spawn_callback(self.run_next)
+
+    @tornado.gen.coroutine
+    def run_next(self):
+        self.scheduled = False
+        next_batch = self._get_next()
+        if not next_batch:
+            yield self.end_batch()
+            return
+        self.active = self.active.union(next_batch)
+        try:
+            ret = yield self.local.run_job_async(
+                next_batch,
+                self.opts["fun"],
+                self.opts["arg"],
+                "list",
+                raw=self.opts.get("raw", False),
+                ret=self.opts.get("return", ""),
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=self.batch_jid,
+                metadata=self.metadata,
+            )
+
+            yield tornado.gen.sleep(self.opts["timeout"])
+
+            # The batch can be done already at this point, which means no self.event
+            if self.event:
+                self.event.io_loop.spawn_callback(self.find_job, set(next_batch))
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Error in scheduling next batch: %s. Aborting execution",
+                ex,
+                exc_info=True,
+            )
+            self.active = self.active.difference(next_batch)
+            self.close_safe()
+
+    # pylint: disable=W1701
+    def __del__(self):
+        self.close_safe()
diff --git a/salt/client/__init__.py b/salt/client/__init__.py
index fe8d2d62e9..20616a6527 100644
--- a/salt/client/__init__.py
+++ b/salt/client/__init__.py
@@ -594,38 +594,20 @@ class LocalClient:
         import salt.cli.batch
         import salt.utils.args
 
-        arg = salt.utils.args.condition_input(arg, kwarg)
-        opts = {
-            "tgt": tgt,
-            "fun": fun,
-            "arg": arg,
-            "tgt_type": tgt_type,
-            "ret": ret,
-            "batch": batch,
-            "failhard": kwargs.get("failhard", self.opts.get("failhard", False)),
-            "raw": kwargs.get("raw", False),
-        }
+        opts = salt.cli.batch.batch_get_opts(
+            tgt,
+            fun,
+            batch,
+            self.opts,
+            arg=arg,
+            tgt_type=tgt_type,
+            ret=ret,
+            kwarg=kwarg,
+            **kwargs,
+        )
+
+        eauth = salt.cli.batch.batch_get_eauth(kwargs)
 
-        if "timeout" in kwargs:
-            opts["timeout"] = kwargs["timeout"]
-        if "gather_job_timeout" in kwargs:
-            opts["gather_job_timeout"] = kwargs["gather_job_timeout"]
-        if "batch_wait" in kwargs:
-            opts["batch_wait"] = int(kwargs["batch_wait"])
-
-        eauth = {}
-        if "eauth" in kwargs:
-            eauth["eauth"] = kwargs.pop("eauth")
-        if "username" in kwargs:
-            eauth["username"] = kwargs.pop("username")
-        if "password" in kwargs:
-            eauth["password"] = kwargs.pop("password")
-        if "token" in kwargs:
-            eauth["token"] = kwargs.pop("token")
-
-        for key, val in self.opts.items():
-            if key not in opts:
-                opts[key] = val
         batch = salt.cli.batch.Batch(opts, eauth=eauth, quiet=True)
         for ret, _ in batch.run():
             yield ret
diff --git a/salt/crypt.py b/salt/crypt.py
index d0a8d232a9..a9d24c55c4 100644
--- a/salt/crypt.py
+++ b/salt/crypt.py
@@ -1680,8 +1680,13 @@ class Crypticle:
             log.debug("Failed to authenticate message")
             raise AuthenticationError("message authentication failed")
         result = 0
-        for zipped_x, zipped_y in zip(mac_bytes, sig):
-            result |= zipped_x ^ zipped_y
+
+        try:
+            for zipped_x, zipped_y in zip(mac_bytes, sig):
+                result |= zipped_x ^ zipped_y
+        except TypeError:
+            log.debug("Failed to authenticate message")
+            raise AuthenticationError("message authentication failed")
         if result != 0:
             log.debug("Failed to authenticate message")
             raise AuthenticationError("message authentication failed")
diff --git a/salt/master.py b/salt/master.py
index f2e3640525..0c04e0721e 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -17,11 +17,13 @@ import sys
 import threading
 import time
 
+import tornado
 import tornado.gen
 
 import salt.acl
 import salt.auth
 import salt.channel.server
+import salt.cli.batch_async
 import salt.client
 import salt.client.ssh.client
 import salt.crypt
@@ -2329,6 +2331,27 @@ class ClearFuncs(TransportMethods):
             return False
         return self.loadauth.get_tok(clear_load["token"])
 
+    def publish_batch(self, clear_load, minions, missing):
+        """
+        This method sends out publications to the minions when batch execution is requested
+
+        .. versionadded:: 3007.0
+        """
+        batch_load = {}
+        batch_load.update(clear_load)
+        batch = salt.cli.batch_async.BatchAsync(
+            self.local.opts,
+            lambda: self._prep_jid(clear_load, {}),
+            batch_load,
+        )
+        ioloop = tornado.ioloop.IOLoop.current()
+        ioloop.add_callback(batch.start)
+
+        return {
+            "enc": "clear",
+            "load": {"jid": batch.batch_jid, "minions": minions, "missing": missing},
+        }
+
     async def publish(self, clear_load):
         """
         This method sends out publications to the minions, it can only be used
diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py
index bd27a20752..a9036d9494 100644
--- a/salt/transport/ipc.py
+++ b/salt/transport/ipc.py
@@ -698,6 +698,7 @@ class IPCMessageSubscriber(IPCClient):
         self._read_stream_future = None
         self._saved_data = []
         self._read_in_progress = Lock()
+        self.callbacks = set()
 
     @tornado.gen.coroutine
     def _read(self, timeout, callback=None):
@@ -802,13 +803,18 @@ class IPCMessageSubscriber(IPCClient):
             return self._saved_data.pop(0)
         return self.io_loop.run_sync(lambda: self._read(timeout))
 
+    def __run_callbacks(self, raw):
+        for callback in self.callbacks:
+            self.io_loop.spawn_callback(callback, raw)
+
     @tornado.gen.coroutine
-    def read_async(self, callback):
+    def read_async(self):
         """
         Asynchronously read messages and invoke a callback when they are ready.
 
         :param callback: A callback with the received data
         """
+
         while not self.connected():
             try:
                 yield self.connect(timeout=5)
@@ -821,7 +827,7 @@ class IPCMessageSubscriber(IPCClient):
             except Exception as exc:  # pylint: disable=broad-except
                 log.error("Exception occurred while Subscriber connecting: %s", exc)
                 yield tornado.gen.sleep(1)
-        yield self._read(None, callback)
+        yield self._read(None, self.__run_callbacks)
 
     def close(self):
         """
diff --git a/salt/utils/event.py b/salt/utils/event.py
index d79c67186e..2687bec8a3 100644
--- a/salt/utils/event.py
+++ b/salt/utils/event.py
@@ -924,6 +924,15 @@ class SaltEvent:
                     # Minion fired a bad retcode, fire an event
                     self._fire_ret_load_specific_fun(load)
 
+    def remove_event_handler(self, event_handler):
+        """
+        Remove the event_handler callback
+
+        .. versionadded:: 3007.0
+        """
+        if event_handler in self.subscriber.callbacks:
+            self.subscriber.callbacks.remove(event_handler)
+
     def set_event_handler(self, event_handler):
         """
         Invoke the event_handler callback each time an event arrives.
diff --git a/tests/pytests/functional/channel/test_req_channel.py b/tests/pytests/functional/channel/test_req_channel.py
index 20d08d6266..691ce4f5ec 100644
--- a/tests/pytests/functional/channel/test_req_channel.py
+++ b/tests/pytests/functional/channel/test_req_channel.py
@@ -98,6 +98,10 @@ class ReqServerChannelProcess(salt.utils.process.SignalHandlingProcess):
 
     @tornado.gen.coroutine
     def _handle_payload(self, payload):
+        if payload.get("load", {}).get("raise_exception", False):
+            raise Exception(payload["load"]["raise_exception"])
+        if payload.get("load", {}).get("server_side_exception", False):
+            raise tornado.gen.Return(({}, {"fun": "madeup-fun"}))
         if self.req_channel_crypt == "clear":
             raise tornado.gen.Return((payload, {"fun": "send_clear"}))
         raise tornado.gen.Return((payload, {"fun": "send"}))
@@ -181,3 +185,33 @@ def test_badload(push_channel, req_channel_crypt):
         for msg in msgs:
             with pytest.raises(salt.exceptions.AuthenticationError):
                 push_channel.send(msg, timeout=5, tries=1)
+
+
+def test_payload_handling_exception(push_channel, req_channel_crypt):
+    """
+    Test getting an exception while handling the payload
+    """
+    if req_channel_crypt == "clear":
+        ret = push_channel.send(
+            {"raise_exception": "Test exception"}, timeout=5, tries=1
+        )
+        assert ret == "Some exception handling minion payload"
+    else:
+        with pytest.raises(salt.exceptions.AuthenticationError):
+            push_channel.send({"raise_exception": "Test exception"}, timeout=5, tries=1)
+
+
+def test_serverside_exception(push_channel, req_channel_crypt):
+    """
+    Test getting a server-side exception while handling the payload
+    """
+    if req_channel_crypt == "clear":
+        ret = push_channel.send(
+            {"server_side_exception": "Test exception"}, timeout=5, tries=1
+        )
+        assert ret == "Server-side exception handling payload"
+    else:
+        with pytest.raises(salt.exceptions.AuthenticationError):
+            push_channel.send(
+                {"server_side_exception": "Test exception"}, timeout=5, tries=1
+            )
diff --git a/tests/pytests/unit/cli/test_batch.py b/tests/pytests/unit/cli/test_batch.py
index f92d3a8c6c..3f156d9797 100644
--- a/tests/pytests/unit/cli/test_batch.py
+++ b/tests/pytests/unit/cli/test_batch.py
@@ -4,7 +4,7 @@ Unit Tests for the salt.cli.batch module
 
 import pytest
 
-from salt.cli.batch import Batch
+from salt.cli.batch import Batch, batch_get_opts
 from tests.support.mock import MagicMock, patch
 
 
@@ -135,3 +135,67 @@ def test_return_value_in_run_for_return(batch):
         verbose=False,
         gather_job_timeout=5,
     )
+
+
+def test_batch_presence_ping(batch):
+    """
+    Tests passing batch_presence_ping_timeout and batch_presence_ping_gather_job_timeout
+    """
+    ret = batch_get_opts("", "test.ping", "2", {}, timeout=20, gather_job_timeout=120)
+    assert ret["batch_presence_ping_timeout"] == 20
+    assert ret["batch_presence_ping_gather_job_timeout"] == 120
+    ret = batch_get_opts(
+        "",
+        "test.ping",
+        "2",
+        {},
+        timeout=20,
+        gather_job_timeout=120,
+        batch_presence_ping_timeout=4,
+        batch_presence_ping_gather_job_timeout=360,
+    )
+    assert ret["batch_presence_ping_timeout"] == 4
+    assert ret["batch_presence_ping_gather_job_timeout"] == 360
+
+
+def test_gather_minions_with_batch_presence_ping(batch):
+    """
+    Tests gather_minions with batch_presence_ping options
+    """
+    opts_no_pp = {
+        "batch": "2",
+        "conf_file": {},
+        "tgt": "",
+        "transport": "",
+        "timeout": 5,
+        "gather_job_timeout": 20,
+    }
+    opts_with_pp = {
+        "batch": "2",
+        "conf_file": {},
+        "tgt": "",
+        "transport": "",
+        "timeout": 5,
+        "gather_job_timeout": 20,
+        "batch_presence_ping_timeout": 3,
+        "batch_presence_ping_gather_job_timeout": 4,
+    }
+    local_client_mock = MagicMock()
+    with patch(
+        "salt.client.get_local_client", MagicMock(return_value=local_client_mock)
+    ), patch("salt.client.LocalClient.cmd_iter", MagicMock(return_value=[])):
+        Batch(opts_no_pp).gather_minions()
+        Batch(opts_with_pp).gather_minions()
+        assert local_client_mock.mock_calls[0][1][3] == opts_no_pp["timeout"]
+        assert (
+            local_client_mock.mock_calls[0][2]["gather_job_timeout"]
+            == opts_no_pp["gather_job_timeout"]
+        )
+        assert (
+            local_client_mock.mock_calls[2][1][3]
+            == opts_with_pp["batch_presence_ping_timeout"]
+        )
+        assert (
+            local_client_mock.mock_calls[2][2]["gather_job_timeout"]
+            == opts_with_pp["batch_presence_ping_gather_job_timeout"]
+        )
diff --git a/tests/pytests/unit/cli/test_batch_async.py b/tests/pytests/unit/cli/test_batch_async.py
new file mode 100644
index 0000000000..94a9e15747
--- /dev/null
+++ b/tests/pytests/unit/cli/test_batch_async.py
@@ -0,0 +1,392 @@
+import pytest
+import tornado
+
+from salt.cli.batch_async import BatchAsync
+from tests.support.mock import MagicMock, patch
+
+
+@pytest.fixture
+def batch(temp_salt_master):
+    opts = {
+        "batch": "1",
+        "conf_file": {},
+        "tgt": "*",
+        "timeout": 5,
+        "gather_job_timeout": 5,
+        "batch_presence_ping_timeout": 1,
+        "transport": None,
+        "sock_dir": "",
+    }
+
+    with patch("salt.client.get_local_client", MagicMock(return_value=MagicMock())):
+        with patch("salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)):
+            batch = BatchAsync(
+                opts,
+                MagicMock(side_effect=["1234", "1235", "1236"]),
+                {
+                    "tgt": "",
+                    "fun": "",
+                    "kwargs": {"batch": "", "batch_presence_ping_timeout": 1},
+                },
+            )
+            yield batch
+
+
+def test_ping_jid(batch):
+    assert batch.ping_jid == "1234"
+
+
+def test_batch_jid(batch):
+    assert batch.batch_jid == "1235"
+
+
+def test_find_job_jid(batch):
+    assert batch.find_job_jid == "1236"
+
+
+def test_batch_size(batch):
+    """
+    Tests passing batch value as a number
+    """
+    batch.opts = {"batch": "2", "timeout": 5}
+    batch.minions = {"foo", "bar"}
+    batch.start_batch()
+    assert batch.batch_size == 2
+
+
+def test_batch_start_on_batch_presence_ping_timeout(batch):
+    # batch_async = BatchAsyncMock();
+    batch.event = MagicMock()
+    future = tornado.gen.Future()
+    future.set_result({"minions": ["foo", "bar"]})
+    batch.local.run_job_async.return_value = future
+    with patch("tornado.gen.sleep", return_value=future):
+        # ret = batch_async.start(batch)
+        ret = batch.start()
+        # assert start_batch is called later with batch_presence_ping_timeout as param
+        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert test.ping called
+        assert batch.local.run_job_async.call_args[0] == ("*", "test.ping", [], "glob")
+        # assert targeted_minions == all minions matched by tgt
+        assert batch.targeted_minions == {"foo", "bar"}
+
+
+def test_batch_start_on_gather_job_timeout(batch):
+    # batch_async = BatchAsyncMock();
+    batch.event = MagicMock()
+    future = tornado.gen.Future()
+    future.set_result({"minions": ["foo", "bar"]})
+    batch.local.run_job_async.return_value = future
+    batch.batch_presence_ping_timeout = None
+    with patch("tornado.gen.sleep", return_value=future):
+        # ret = batch_async.start(batch)
+        ret = batch.start()
+        # assert start_batch is called later with gather_job_timeout as param
+        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+
+
+def test_batch_fire_start_event(batch):
+    batch.minions = {"foo", "bar"}
+    batch.opts = {"batch": "2", "timeout": 5}
+    batch.event = MagicMock()
+    batch.metadata = {"mykey": "myvalue"}
+    batch.start_batch()
+    assert batch.event.fire_event.call_args[0] == (
+        {
+            "available_minions": {"foo", "bar"},
+            "down_minions": set(),
+            "metadata": batch.metadata,
+        },
+        "salt/batch/1235/start",
+    )
+
+
+def test_start_batch_calls_next(batch):
+    batch.run_next = MagicMock(return_value=MagicMock())
+    batch.event = MagicMock()
+    batch.start_batch()
+    assert batch.initialized
+    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.run_next,)
+
+
+def test_batch_fire_done_event(batch):
+    batch.targeted_minions = {"foo", "baz", "bar"}
+    batch.minions = {"foo", "bar"}
+    batch.done_minions = {"foo"}
+    batch.timedout_minions = {"bar"}
+    batch.event = MagicMock()
+    batch.metadata = {"mykey": "myvalue"}
+    old_event = batch.event
+    batch.end_batch()
+    assert old_event.fire_event.call_args[0] == (
+        {
+            "available_minions": {"foo", "bar"},
+            "done_minions": batch.done_minions,
+            "down_minions": {"baz"},
+            "timedout_minions": batch.timedout_minions,
+            "metadata": batch.metadata,
+        },
+        "salt/batch/1235/done",
+    )
+
+
+def test_batch__del__(batch):
+    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
+    event = MagicMock()
+    batch.event = event
+    batch.__del__()
+    assert batch.local is None
+    assert batch.event is None
+    assert batch.ioloop is None
+
+
+def test_batch_close_safe(batch):
+    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
+    event = MagicMock()
+    batch.event = event
+    batch.patterns = {
+        ("salt/job/1234/ret/*", "find_job_return"),
+        ("salt/job/4321/ret/*", "find_job_return"),
+    }
+    batch.close_safe()
+    assert batch.local is None
+    assert batch.event is None
+    assert batch.ioloop is None
+    assert len(event.unsubscribe.mock_calls) == 2
+    assert len(event.remove_event_handler.mock_calls) == 1
+
+
+def test_batch_next(batch):
+    batch.event = MagicMock()
+    batch.opts["fun"] = "my.fun"
+    batch.opts["arg"] = []
+    batch._get_next = MagicMock(return_value={"foo", "bar"})
+    batch.batch_size = 2
+    future = tornado.gen.Future()
+    future.set_result({"minions": ["foo", "bar"]})
+    batch.local.run_job_async.return_value = future
+    with patch("tornado.gen.sleep", return_value=future):
+        batch.run_next()
+        assert batch.local.run_job_async.call_args[0] == (
+            {"foo", "bar"},
+            "my.fun",
+            [],
+            "list",
+        )
+        assert batch.event.io_loop.spawn_callback.call_args[0] == (
+            batch.find_job,
+            {"foo", "bar"},
+        )
+        assert batch.active == {"bar", "foo"}
+
+
+def test_next_batch(batch):
+    batch.minions = {"foo", "bar"}
+    batch.batch_size = 2
+    assert batch._get_next() == {"foo", "bar"}
+
+
+def test_next_batch_one_done(batch):
+    batch.minions = {"foo", "bar"}
+    batch.done_minions = {"bar"}
+    batch.batch_size = 2
+    assert batch._get_next() == {"foo"}
+
+
+def test_next_batch_one_done_one_active(batch):
+    batch.minions = {"foo", "bar", "baz"}
+    batch.done_minions = {"bar"}
+    batch.active = {"baz"}
+    batch.batch_size = 2
+    assert batch._get_next() == {"foo"}
+
+
+def test_next_batch_one_done_one_active_one_timedout(batch):
+    batch.minions = {"foo", "bar", "baz", "faz"}
+    batch.done_minions = {"bar"}
+    batch.active = {"baz"}
+    batch.timedout_minions = {"faz"}
+    batch.batch_size = 2
+    assert batch._get_next() == {"foo"}
+
+
+def test_next_batch_bigger_size(batch):
+    batch.minions = {"foo", "bar"}
+    batch.batch_size = 3
+    assert batch._get_next() == {"foo", "bar"}
+
+
+def test_next_batch_all_done(batch):
+    batch.minions = {"foo", "bar"}
+    batch.done_minions = {"foo", "bar"}
+    batch.batch_size = 2
+    assert batch._get_next() == set()
+
+
+def test_next_batch_all_active(batch):
+    batch.minions = {"foo", "bar"}
+    batch.active = {"foo", "bar"}
+    batch.batch_size = 2
+    assert batch._get_next() == set()
+
+
+def test_next_batch_all_timedout(batch):
+    batch.minions = {"foo", "bar"}
+    batch.timedout_minions = {"foo", "bar"}
+    batch.batch_size = 2
+    assert batch._get_next() == set()
+
+
+def test_batch__event_handler_ping_return(batch):
+    batch.targeted_minions = {"foo"}
+    batch.event = MagicMock(
+        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+    )
+    batch.start()
+    assert batch.minions == set()
+    batch._BatchAsync__event_handler(MagicMock())
+    assert batch.minions == {"foo"}
+    assert batch.done_minions == set()
+
+
+def test_batch__event_handler_call_start_batch_when_all_pings_return(batch):
+    batch.targeted_minions = {"foo"}
+    batch.event = MagicMock(
+        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+    )
+    batch.start()
+    batch._BatchAsync__event_handler(MagicMock())
+    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+
+
+def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(batch):
+    batch.targeted_minions = {"foo", "bar"}
+    batch.event = MagicMock(
+        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+    )
+    batch.start()
+    batch._BatchAsync__event_handler(MagicMock())
+    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+
+
+def test_batch__event_handler_batch_run_return(batch):
+    batch.event = MagicMock(
+        unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"}))
+    )
+    batch.start()
+    batch.active = {"foo"}
+    batch._BatchAsync__event_handler(MagicMock())
+    assert batch.active == set()
+    assert batch.done_minions == {"foo"}
+    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.schedule_next,)
+
+
+def test_batch__event_handler_find_job_return(batch):
+    batch.event = MagicMock(
+        unpack=MagicMock(
+            return_value=(
+                "salt/job/1236/ret/foo",
+                {"id": "foo", "return": "deadbeaf"},
+            )
+        )
+    )
+    batch.start()
+    batch.patterns.add(("salt/job/1236/ret/*", "find_job_return"))
+    batch._BatchAsync__event_handler(MagicMock())
+    assert batch.find_job_returned == {"foo"}
+
+
+def test_batch_run_next_end_batch_when_no_next(batch):
+    batch.end_batch = MagicMock()
+    batch._get_next = MagicMock(return_value={})
+    batch.run_next()
+    assert len(batch.end_batch.mock_calls) == 1
+
+
+def test_batch_find_job(batch):
+    batch.event = MagicMock()
+    future = tornado.gen.Future()
+    future.set_result({})
+    batch.local.run_job_async.return_value = future
+    batch.minions = {"foo", "bar"}
+    batch.jid_gen = MagicMock(return_value="1234")
+    with patch("tornado.gen.sleep", return_value=future):
+        batch.find_job({"foo", "bar"})
+        assert batch.event.io_loop.spawn_callback.call_args[0] == (
+            batch.check_find_job,
+            {"foo", "bar"},
+            "1234",
+        )
+
+
+def test_batch_find_job_with_done_minions(batch):
+    batch.done_minions = {"bar"}
+    batch.event = MagicMock()
+    future = tornado.gen.Future()
+    future.set_result({})
+    batch.local.run_job_async.return_value = future
+    batch.minions = {"foo", "bar"}
+    batch.jid_gen = MagicMock(return_value="1234")
+    with patch("tornado.gen.sleep", return_value=future):
+        batch.find_job({"foo", "bar"})
+        assert batch.event.io_loop.spawn_callback.call_args[0] == (
+            batch.check_find_job,
+            {"foo"},
+            "1234",
+        )
+
+
+def test_batch_check_find_job_did_not_return(batch):
+    batch.event = MagicMock()
+    batch.active = {"foo"}
+    batch.find_job_returned = set()
+    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
+    batch.check_find_job({"foo"}, jid="1234")
+    assert batch.find_job_returned == set()
+    assert batch.active == set()
+    assert len(batch.event.io_loop.add_callback.mock_calls) == 0
+
+
+def test_batch_check_find_job_did_return(batch):
+    batch.event = MagicMock()
+    batch.find_job_returned = {"foo"}
+    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
+    batch.check_find_job({"foo"}, jid="1234")
+    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+
+
+def test_batch_check_find_job_multiple_states(batch):
+    batch.event = MagicMock()
+    # currently running minions
+    batch.active = {"foo", "bar"}
+
+    # minion is running and find_job returns
+    batch.find_job_returned = {"foo"}
+
+    # minion started running but find_job did not return
+    batch.timedout_minions = {"faz"}
+
+    # minion finished
+    batch.done_minions = {"baz"}
+
+    # both not yet done but only 'foo' responded to find_job
+    not_done = {"foo", "bar"}
+
+    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
+    batch.check_find_job(not_done, jid="1234")
+
+    # assert 'bar' removed from active
+    assert batch.active == {"foo"}
+
+    # assert 'bar' added to timedout_minions
+    assert batch.timedout_minions == {"bar", "faz"}
+
+    # assert 'find_job' scheduled again only for 'foo'
+    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+
+
+def test_only_on_run_next_is_scheduled(batch):
+    batch.event = MagicMock()
+    batch.scheduled = True
+    batch.schedule_next()
+    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
diff --git a/tests/unit/transport/test_ipc.py b/tests/unit/transport/test_ipc.py
index eb469d6efb..45e39abeaf 100644
--- a/tests/unit/transport/test_ipc.py
+++ b/tests/unit/transport/test_ipc.py
@@ -108,8 +108,10 @@ class IPCMessagePubSubCase(tornado.testing.AsyncTestCase):
                 self.stop()
 
         # Now let both waiting data at once
-        client1.read_async(handler)
-        client2.read_async(handler)
+        client1.callbacks = {handler}
+        client2.callbacks = {handler}
+        client1.read_async()
+        client2.read_async()
         self.pub_channel.publish("TEST")
         self.wait()
         self.assertEqual(len(call_cnt), 2)
@@ -146,12 +148,8 @@ class IPCMessagePubSubCase(tornado.testing.AsyncTestCase):
         watchdog = threading.Thread(target=close_server)
         watchdog.start()
 
-        # Runs in ioloop thread so we're safe from race conditions here
-        def handler(raw):
-            pass
-
         try:
-            ret1 = yield client1.read_async(handler)
+            ret1 = yield client1.read_async()
             self.wait()
         except StreamClosedError as ex:
             assert False, "StreamClosedError was raised inside the Future"
-- 
2.47.0
