We have some news to share for the request index beta feature. We’ve added more options to sort your requests, counters to the individual filters and documentation for the search functionality. Check out the blog post for more details.

File prevent-oom-with-high-amount-of-batch-async-calls-bs.patch of Package salt

From d57472b4fa2213ec551197ee2e147aef364fdcfe Mon Sep 17 00:00:00 2001
From: Victor Zhestkov <vzhestkov@suse.com>
Date: Wed, 15 May 2024 11:47:35 +0200
Subject: [PATCH] Prevent OOM with high amount of batch async calls
 (bsc#1216063)

* Refactor batch_async implementation

* Fix batch_async tests after refactoring
---
 salt/cli/batch_async.py                    | 584 ++++++++++++++-------
 salt/master.py                             |   9 +-
 tests/pytests/unit/cli/test_batch_async.py | 360 +++++++------
 3 files changed, 597 insertions(+), 356 deletions(-)

diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
index 1012ce37cc..5d49993faa 100644
--- a/salt/cli/batch_async.py
+++ b/salt/cli/batch_async.py
@@ -2,18 +2,193 @@
 Execute a job on the targeted minions by using a moving window of fixed size `batch`.
 """
 
-import gc
-
-# pylint: enable=import-error,no-name-in-module,redefined-builtin
 import logging
+import re
 
 import salt.client
 import salt.ext.tornado
+import salt.utils.event
 from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum
+from salt.ext.tornado.iostream import StreamClosedError
 
 log = logging.getLogger(__name__)
 
 
+__SHARED_EVENTS_CHANNEL = None
+
+
+def _get_shared_events_channel(opts, io_loop):
+    global __SHARED_EVENTS_CHANNEL
+    if __SHARED_EVENTS_CHANNEL is None:
+        __SHARED_EVENTS_CHANNEL = SharedEventsChannel(opts, io_loop)
+    return __SHARED_EVENTS_CHANNEL
+
+
+def _destroy_unused_shared_events_channel():
+    global __SHARED_EVENTS_CHANNEL
+    if __SHARED_EVENTS_CHANNEL is not None and __SHARED_EVENTS_CHANNEL.destroy_unused():
+        __SHARED_EVENTS_CHANNEL = None
+
+
+def batch_async_required(opts, minions, extra):
+    """
+    Check opts to identify if batch async is required for the operation.
+
+    Only a list of ``minions`` can qualify. ``batch_async: threshold`` defaults
+    to 1; -1 compares against the batch size and other values <= 0 disable it.
+    """
+    if not isinstance(minions, list):
+        return False
+    batch_async_opts = opts.get("batch_async", {})
+    if not isinstance(batch_async_opts, dict):
+        batch_async_opts = {}
+    batch_async_threshold = batch_async_opts.get("threshold", 1)
+    if batch_async_threshold == -1:
+        return len(minions) >= get_bnum(extra, minions, True)
+    elif batch_async_threshold > 0:
+        return len(minions) >= batch_async_threshold
+    return False
+
+
+class SharedEventsChannel:
+    def __init__(self, opts, io_loop):
+        self.io_loop = io_loop
+        self.local_client = salt.client.get_local_client(
+            opts["conf_file"], io_loop=self.io_loop
+        )
+        self.master_event = salt.utils.event.get_event(
+            "master",
+            sock_dir=self.local_client.opts["sock_dir"],
+            opts=self.local_client.opts,
+            listen=True,
+            io_loop=self.io_loop,
+            keep_loop=True,
+        )
+        self.master_event.set_event_handler(self.__handle_event)
+        if self.master_event.subscriber.stream:
+            self.master_event.subscriber.stream.set_close_callback(self.__handle_close)
+        self._re_tag_ret_event = re.compile(r"salt\/job\/(\d+)\/ret\/.*")
+        self._subscribers = {}
+        self._subscriptions = {}
+        self._used_by = set()
+        batch_async_opts = opts.get("batch_async", {})
+        if not isinstance(batch_async_opts, dict):
+            batch_async_opts = {}
+        self._subscriber_reconnect_tries = batch_async_opts.get(
+            "subscriber_reconnect_tries", 5
+        )
+        self._subscriber_reconnect_interval = batch_async_opts.get(
+            "subscriber_reconnect_interval", 1.0
+        )
+        self._reconnecting_subscriber = False
+
+    def subscribe(self, jid, op, subscriber_id, handler):
+        if subscriber_id not in self._subscribers:
+            self._subscribers[subscriber_id] = set()
+        if jid not in self._subscriptions:
+            self._subscriptions[jid] = []
+        self._subscribers[subscriber_id].add(jid)
+        if (op, subscriber_id, handler) not in self._subscriptions[jid]:
+            self._subscriptions[jid].append((op, subscriber_id, handler))
+        if not self.master_event.subscriber.connected():
+            self.__reconnect_subscriber()
+
+    def unsubscribe(self, jid, op, subscriber_id):
+        """Remove handlers of ``subscriber_id`` for ``jid`` (every jid if None)."""
+        if subscriber_id not in self._subscribers:
+            return
+        # ``jid`` is a single job id: wrap it as one element, since ``set(jid)``
+        # would split a string jid into characters and match no subscription.
+        jids = self._subscribers[subscriber_id].copy() if jid is None else {jid}
+        for i_jid in jids:
+            self._subscriptions[i_jid] = [
+                x
+                for x in self._subscriptions.get(i_jid, [])
+                if not (op in (x[0], None) and x[1] == subscriber_id)
+            ]
+            self._subscribers[subscriber_id].discard(i_jid)
+        self._subscriptions = {k: v for k, v in self._subscriptions.items() if v}
+        if not self._subscribers[subscriber_id]:
+            del self._subscribers[subscriber_id]
+
+    @salt.ext.tornado.gen.coroutine
+    def __handle_close(self):
+        if not self._subscriptions:
+            return
+        log.warning("Master Event Subscriber was closed. Trying to reconnect...")
+        yield self.__reconnect_subscriber()
+
+    @salt.ext.tornado.gen.coroutine
+    def __handle_event(self, raw):
+        """Unpack a raw master event and pass job returns to the jid's handlers."""
+        if self.master_event is None:
+            return
+        tag = None  # keeps the log call below valid if ``unpack`` itself raises
+        try:
+            tag, data = self.master_event.unpack(raw)
+            tag_match = self._re_tag_ret_event.match(tag)
+            if tag_match and tag_match.group(1) in self._subscriptions:
+                for op, _, handler in self._subscriptions[tag_match.group(1)]:
+                    yield handler(tag, data, op)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured while processing event: %s: %s",
+                tag,
+                ex,
+                exc_info=True,
+            )
+
+    @salt.ext.tornado.gen.coroutine
+    def __reconnect_subscriber(self):
+        if self.master_event.subscriber.connected() or self._reconnecting_subscriber:
+            return
+        self._reconnecting_subscriber = True
+        max_tries = max(1, int(self._subscriber_reconnect_tries))
+        _try = 1
+        while _try <= max_tries:
+            log.info(
+                "Trying to reconnect to event publisher (try %d of %d) ...",
+                _try,
+                max_tries,
+            )
+            try:
+                yield self.master_event.subscriber.connect()
+            except StreamClosedError:
+                log.warning(
+                    "Unable to reconnect to event publisher (try %d of %d)",
+                    _try,
+                    max_tries,
+                )
+            if self.master_event.subscriber.connected():
+                self.master_event.subscriber.stream.set_close_callback(
+                    self.__handle_close
+                )
+                log.info("Event publisher connection restored")
+                self._reconnecting_subscriber = False
+                return
+            if _try < max_tries:
+                yield salt.ext.tornado.gen.sleep(self._subscriber_reconnect_interval)
+            _try += 1
+        self._reconnecting_subscriber = False
+
+    def use(self, subscriber_id):
+        self._used_by.add(subscriber_id)
+        return self
+
+    def unuse(self, subscriber_id):
+        self._used_by.discard(subscriber_id)
+
+    def destroy_unused(self):
+        if self._used_by:
+            return False
+        self.master_event.remove_event_handler(self.__handle_event)
+        self.master_event.destroy()
+        self.master_event = None
+        self.local_client.destroy()
+        self.local_client = None
+        return True
+
+
 class BatchAsync:
     """
     Run a job on the targeted minions by using a moving window of fixed size `batch`.
@@ -28,14 +203,14 @@ class BatchAsync:
         - gather_job_timeout: `find_job` timeout
         - timeout: time to wait before firing a `find_job`
 
-    When the batch stars, a `start` event is fired:
+    When the batch starts, a `start` event is fired:
          - tag: salt/batch/<batch-jid>/start
          - data: {
              "available_minions": self.minions,
              "down_minions": targeted_minions - presence_ping_minions
            }
 
-    When the batch ends, an `done` event is fired:
+    When the batch ends, a `done` event is fired:
         - tag: salt/batch/<batch-jid>/done
         - data: {
              "available_minions": self.minions,
@@ -45,17 +220,26 @@ class BatchAsync:
          }
     """
 
-    def __init__(self, parent_opts, jid_gen, clear_load):
-        ioloop = salt.ext.tornado.ioloop.IOLoop.current()
-        self.local = salt.client.get_local_client(
-            parent_opts["conf_file"], io_loop=ioloop
+    def __init__(self, opts, jid_gen, clear_load):
+        self.extra_job_kwargs = {}
+        kwargs = clear_load.get("kwargs", {})
+        for kwarg in ("module_executors", "executor_opts"):
+            if kwarg in kwargs:
+                self.extra_job_kwargs[kwarg] = kwargs[kwarg]
+            elif kwarg in opts:
+                self.extra_job_kwargs[kwarg] = opts[kwarg]
+        self.io_loop = salt.ext.tornado.ioloop.IOLoop.current()
+        self.events_channel = _get_shared_events_channel(opts, self.io_loop).use(
+            id(self)
         )
         if "gather_job_timeout" in clear_load["kwargs"]:
             clear_load["gather_job_timeout"] = clear_load["kwargs"].pop(
                 "gather_job_timeout"
             )
         else:
-            clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"]
+            clear_load["gather_job_timeout"] = self.events_channel.local_client.opts[
+                "gather_job_timeout"
+            ]
         self.batch_presence_ping_timeout = clear_load["kwargs"].get(
             "batch_presence_ping_timeout", None
         )
@@ -64,8 +248,8 @@ class BatchAsync:
             clear_load.pop("tgt"),
             clear_load.pop("fun"),
             clear_load["kwargs"].pop("batch"),
-            self.local.opts,
-            **clear_load
+            self.events_channel.local_client.opts,
+            **clear_load,
         )
         self.eauth = batch_get_eauth(clear_load["kwargs"])
         self.metadata = clear_load["kwargs"].get("metadata", {})
@@ -78,54 +262,45 @@ class BatchAsync:
         self.jid_gen = jid_gen
         self.ping_jid = jid_gen()
         self.batch_jid = jid_gen()
-        self.find_job_jid = jid_gen()
         self.find_job_returned = set()
+        self.metadata.update({"batch_jid": self.batch_jid, "ping_jid": self.ping_jid})
         self.ended = False
-        self.event = salt.utils.event.get_event(
-            "master",
-            self.opts["sock_dir"],
-            self.opts["transport"],
-            opts=self.opts,
-            listen=True,
-            io_loop=ioloop,
-            keep_loop=True,
-        )
+        self.event = self.events_channel.master_event
         self.scheduled = False
-        self.patterns = set()
 
     def __set_event_handler(self):
-        ping_return_pattern = "salt/job/{}/ret/*".format(self.ping_jid)
-        batch_return_pattern = "salt/job/{}/ret/*".format(self.batch_jid)
-        self.event.subscribe(ping_return_pattern, match_type="glob")
-        self.event.subscribe(batch_return_pattern, match_type="glob")
-        self.patterns = {
-            (ping_return_pattern, "ping_return"),
-            (batch_return_pattern, "batch_run"),
-        }
-        self.event.set_event_handler(self.__event_handler)
+        self.events_channel.subscribe(
+            self.ping_jid, "ping_return", id(self), self.__event_handler
+        )
+        self.events_channel.subscribe(
+            self.batch_jid, "batch_run", id(self), self.__event_handler
+        )
 
-    def __event_handler(self, raw):
+    @salt.ext.tornado.gen.coroutine
+    def __event_handler(self, tag, data, op):
         if not self.event:
             return
         try:
-            mtag, data = self.event.unpack(raw)
-            for (pattern, op) in self.patterns:
-                if mtag.startswith(pattern[:-1]):
-                    minion = data["id"]
-                    if op == "ping_return":
-                        self.minions.add(minion)
-                        if self.targeted_minions == self.minions:
-                            self.event.io_loop.spawn_callback(self.start_batch)
-                    elif op == "find_job_return":
-                        if data.get("return", None):
-                            self.find_job_returned.add(minion)
-                    elif op == "batch_run":
-                        if minion in self.active:
-                            self.active.remove(minion)
-                            self.done_minions.add(minion)
-                            self.event.io_loop.spawn_callback(self.schedule_next)
-        except Exception as ex:
-            log.error("Exception occured while processing event: {}".format(ex))
+            minion = data["id"]
+            if op == "ping_return":
+                self.minions.add(minion)
+                if self.targeted_minions == self.minions:
+                    yield self.start_batch()
+            elif op == "find_job_return":
+                if data.get("return", None):
+                    self.find_job_returned.add(minion)
+            elif op == "batch_run":
+                if minion in self.active:
+                    self.active.remove(minion)
+                    self.done_minions.add(minion)
+                    yield self.schedule_next()
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured while processing event: %s: %s",
+                tag,
+                ex,
+                exc_info=True,
+            )
 
     def _get_next(self):
         to_run = (
@@ -139,176 +314,203 @@ class BatchAsync:
         )
         return set(list(to_run)[:next_batch_size])
 
+    @salt.ext.tornado.gen.coroutine
     def check_find_job(self, batch_minions, jid):
-        if self.event:
-            find_job_return_pattern = "salt/job/{}/ret/*".format(jid)
-            self.event.unsubscribe(find_job_return_pattern, match_type="glob")
-            self.patterns.remove((find_job_return_pattern, "find_job_return"))
-
-            timedout_minions = batch_minions.difference(
-                self.find_job_returned
-            ).difference(self.done_minions)
-            self.timedout_minions = self.timedout_minions.union(timedout_minions)
-            self.active = self.active.difference(self.timedout_minions)
-            running = batch_minions.difference(self.done_minions).difference(
-                self.timedout_minions
-            )
+        """
+        Check if the job with specified ``jid`` was finished on the minions
+        """
+        if not self.event:
+            return
+        self.events_channel.unsubscribe(jid, "find_job_return", id(self))
 
-            if timedout_minions:
-                self.schedule_next()
+        timedout_minions = batch_minions.difference(self.find_job_returned).difference(
+            self.done_minions
+        )
+        self.timedout_minions = self.timedout_minions.union(timedout_minions)
+        self.active = self.active.difference(self.timedout_minions)
+        running = batch_minions.difference(self.done_minions).difference(
+            self.timedout_minions
+        )
 
-            if self.event and running:
-                self.find_job_returned = self.find_job_returned.difference(running)
-                self.event.io_loop.spawn_callback(self.find_job, running)
+        if timedout_minions:
+            yield self.schedule_next()
+
+        if self.event and running:
+            self.find_job_returned = self.find_job_returned.difference(running)
+            yield self.find_job(running)
 
     @salt.ext.tornado.gen.coroutine
     def find_job(self, minions):
-        if self.event:
-            not_done = minions.difference(self.done_minions).difference(
-                self.timedout_minions
+        """
+        Find if the job was finished on the minions
+        """
+        if not self.event:
+            return
+        not_done = minions.difference(self.done_minions).difference(
+            self.timedout_minions
+        )
+        if not not_done:
+            return
+        try:
+            jid = self.jid_gen()
+            self.events_channel.subscribe(
+                jid, "find_job_return", id(self), self.__event_handler
             )
-            try:
-                if not_done:
-                    jid = self.jid_gen()
-                    find_job_return_pattern = "salt/job/{}/ret/*".format(jid)
-                    self.patterns.add((find_job_return_pattern, "find_job_return"))
-                    self.event.subscribe(find_job_return_pattern, match_type="glob")
-                    ret = yield self.local.run_job_async(
-                        not_done,
-                        "saltutil.find_job",
-                        [self.batch_jid],
-                        "list",
-                        gather_job_timeout=self.opts["gather_job_timeout"],
-                        jid=jid,
-                        **self.eauth
-                    )
-                    yield salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"])
-                    if self.event:
-                        self.event.io_loop.spawn_callback(
-                            self.check_find_job, not_done, jid
-                        )
-            except Exception as ex:
-                log.error(
-                    "Exception occured handling batch async: {}. Aborting execution.".format(
-                        ex
-                    )
-                )
-                self.close_safe()
+            ret = yield self.events_channel.local_client.run_job_async(
+                not_done,
+                "saltutil.find_job",
+                [self.batch_jid],
+                "list",
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=jid,
+                io_loop=self.io_loop,
+                listen=False,
+                **self.eauth,
+            )
+            yield salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"])
+            if self.event:
+                yield self.check_find_job(not_done, jid)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured handling batch async: %s. Aborting execution.",
+                ex,
+                exc_info=True,
+            )
+            self.close_safe()
 
     @salt.ext.tornado.gen.coroutine
     def start(self):
+        """
+        Start the batch execution
+        """
+        if not self.event:
+            return
+        self.__set_event_handler()
+        ping_return = yield self.events_channel.local_client.run_job_async(
+            self.opts["tgt"],
+            "test.ping",
+            [],
+            self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")),
+            gather_job_timeout=self.opts["gather_job_timeout"],
+            jid=self.ping_jid,
+            metadata=self.metadata,
+            io_loop=self.io_loop,
+            listen=False,
+            **self.eauth,
+        )
+        self.targeted_minions = set(ping_return["minions"])
+        # start batching even if not all minions respond to ping
+        yield salt.ext.tornado.gen.sleep(
+            self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
+        )
         if self.event:
-            self.__set_event_handler()
-            ping_return = yield self.local.run_job_async(
-                self.opts["tgt"],
-                "test.ping",
-                [],
-                self.opts.get(
-                    "selected_target_option", self.opts.get("tgt_type", "glob")
-                ),
-                gather_job_timeout=self.opts["gather_job_timeout"],
-                jid=self.ping_jid,
-                metadata=self.metadata,
-                **self.eauth
-            )
-            self.targeted_minions = set(ping_return["minions"])
-            # start batching even if not all minions respond to ping
-            yield salt.ext.tornado.gen.sleep(
-                self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
-            )
-            if self.event:
-                self.event.io_loop.spawn_callback(self.start_batch)
+            yield self.start_batch()
 
     @salt.ext.tornado.gen.coroutine
     def start_batch(self):
-        if not self.initialized:
-            self.batch_size = get_bnum(self.opts, self.minions, True)
-            self.initialized = True
-            data = {
-                "available_minions": self.minions,
-                "down_minions": self.targeted_minions.difference(self.minions),
-                "metadata": self.metadata,
-            }
-            ret = self.event.fire_event(
-                data, "salt/batch/{}/start".format(self.batch_jid)
-            )
-            if self.event:
-                self.event.io_loop.spawn_callback(self.run_next)
+        """
+        Fire `salt/batch/*/start` and continue batch with `run_next`
+        """
+        if self.initialized:
+            return
+        self.batch_size = get_bnum(self.opts, self.minions, True)
+        self.initialized = True
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.targeted_minions.difference(self.minions),
+            "metadata": self.metadata,
+        }
+        yield self.events_channel.master_event.fire_event_async(
+            data, f"salt/batch/{self.batch_jid}/start"
+        )
+        if self.event:
+            yield self.run_next()
 
     @salt.ext.tornado.gen.coroutine
     def end_batch(self):
+        """
+        End the batch and call safe closing
+        """
         left = self.minions.symmetric_difference(
             self.done_minions.union(self.timedout_minions)
         )
-        if not left and not self.ended:
-            self.ended = True
-            data = {
-                "available_minions": self.minions,
-                "down_minions": self.targeted_minions.difference(self.minions),
-                "done_minions": self.done_minions,
-                "timedout_minions": self.timedout_minions,
-                "metadata": self.metadata,
-            }
-            self.event.fire_event(data, "salt/batch/{}/done".format(self.batch_jid))
-
-            # release to the IOLoop to allow the event to be published
-            # before closing batch async execution
-            yield salt.ext.tornado.gen.sleep(1)
-            self.close_safe()
+        # Send salt/batch/*/done only if there is nothing left to do
+        # and the event hasn't been sent already
+        if left or self.ended:
+            return
+        self.ended = True
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.targeted_minions.difference(self.minions),
+            "done_minions": self.done_minions,
+            "timedout_minions": self.timedout_minions,
+            "metadata": self.metadata,
+        }
+        yield self.events_channel.master_event.fire_event_async(
+            data, f"salt/batch/{self.batch_jid}/done"
+        )
+
+        # release to the IOLoop to allow the event to be published
+        # before closing batch async execution
+        yield salt.ext.tornado.gen.sleep(1)
+        self.close_safe()
 
     def close_safe(self):
-        for (pattern, label) in self.patterns:
-            self.event.unsubscribe(pattern, match_type="glob")
-        self.event.remove_event_handler(self.__event_handler)
+        if self.events_channel is not None:
+            self.events_channel.unsubscribe(None, None, id(self))
+            self.events_channel.unuse(id(self))
+            self.events_channel = None
+            _destroy_unused_shared_events_channel()
         self.event = None
-        self.local = None
-        self.ioloop = None
-        del self
-        gc.collect()
 
     @salt.ext.tornado.gen.coroutine
     def schedule_next(self):
-        if not self.scheduled:
-            self.scheduled = True
-            # call later so that we maybe gather more returns
-            yield salt.ext.tornado.gen.sleep(self.batch_delay)
-            if self.event:
-                self.event.io_loop.spawn_callback(self.run_next)
+        if self.scheduled:
+            return
+        self.scheduled = True
+        # call later so that we maybe gather more returns
+        yield salt.ext.tornado.gen.sleep(self.batch_delay)
+        if self.event:
+            yield self.run_next()
 
     @salt.ext.tornado.gen.coroutine
     def run_next(self):
+        """
+        Continue batch execution with the next targets
+        """
         self.scheduled = False
         next_batch = self._get_next()
-        if next_batch:
-            self.active = self.active.union(next_batch)
-            try:
-                ret = yield self.local.run_job_async(
-                    next_batch,
-                    self.opts["fun"],
-                    self.opts["arg"],
-                    "list",
-                    raw=self.opts.get("raw", False),
-                    ret=self.opts.get("return", ""),
-                    gather_job_timeout=self.opts["gather_job_timeout"],
-                    jid=self.batch_jid,
-                    metadata=self.metadata,
-                )
-
-                yield salt.ext.tornado.gen.sleep(self.opts["timeout"])
-
-                # The batch can be done already at this point, which means no self.event
-                if self.event:
-                    self.event.io_loop.spawn_callback(self.find_job, set(next_batch))
-            except Exception as ex:
-                log.error("Error in scheduling next batch: %s. Aborting execution", ex)
-                self.active = self.active.difference(next_batch)
-                self.close_safe()
-        else:
+        if not next_batch:
             yield self.end_batch()
-        gc.collect()
+            return
+        self.active = self.active.union(next_batch)
+        try:
+            ret = yield self.events_channel.local_client.run_job_async(
+                next_batch,
+                self.opts["fun"],
+                self.opts["arg"],
+                "list",
+                raw=self.opts.get("raw", False),
+                ret=self.opts.get("return", ""),
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=self.batch_jid,
+                metadata=self.metadata,
+                io_loop=self.io_loop,
+                listen=False,
+                **self.eauth,
+                **self.extra_job_kwargs,
+            )
 
-    def __del__(self):
-        self.local = None
-        self.event = None
-        self.ioloop = None
-        gc.collect()
+            yield salt.ext.tornado.gen.sleep(self.opts["timeout"])
+
+            # The batch can be done already at this point, which means no self.event
+            if self.event:
+                yield self.find_job(set(next_batch))
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Error in scheduling next batch: %s. Aborting execution",
+                ex,
+                exc_info=True,
+            )
+            self.active = self.active.difference(next_batch)
+            self.close_safe()
diff --git a/salt/master.py b/salt/master.py
index 425b412148..d7182d10b5 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -2,6 +2,7 @@
 This module contains all of the routines needed to set up a master server, this
 involves preparing the three listeners and the workers needed by the master.
 """
+
 import collections
 import copy
 import ctypes
@@ -19,7 +20,6 @@ import time
 import salt.acl
 import salt.auth
 import salt.channel.server
-import salt.cli.batch_async
 import salt.client
 import salt.client.ssh.client
 import salt.crypt
@@ -55,6 +55,7 @@ import salt.utils.user
 import salt.utils.verify
 import salt.utils.zeromq
 import salt.wheel
+from salt.cli.batch_async import BatchAsync, batch_async_required
 from salt.config import DEFAULT_INTERVAL
 from salt.defaults import DEFAULT_TARGET_DELIM
 from salt.ext.tornado.stack_context import StackContext
@@ -2174,9 +2175,9 @@ class ClearFuncs(TransportMethods):
     def publish_batch(self, clear_load, minions, missing):
         batch_load = {}
         batch_load.update(clear_load)
-        batch = salt.cli.batch_async.BatchAsync(
+        batch = BatchAsync(
             self.local.opts,
-            functools.partial(self._prep_jid, clear_load, {}),
+            lambda: self._prep_jid(clear_load, {}),
             batch_load,
         )
         ioloop = salt.ext.tornado.ioloop.IOLoop.current()
@@ -2331,7 +2332,7 @@ class ClearFuncs(TransportMethods):
                         ),
                     },
                 }
-        if extra.get("batch", None):
+        if extra.get("batch", None) and batch_async_required(self.opts, minions, extra):
             return self.publish_batch(clear_load, minions, missing)
 
         jid = self._prep_jid(clear_load, extra)
diff --git a/tests/pytests/unit/cli/test_batch_async.py b/tests/pytests/unit/cli/test_batch_async.py
index e0774ffff3..bc871aba54 100644
--- a/tests/pytests/unit/cli/test_batch_async.py
+++ b/tests/pytests/unit/cli/test_batch_async.py
@@ -1,7 +1,7 @@
 import pytest
 
 import salt.ext.tornado
-from salt.cli.batch_async import BatchAsync
+from salt.cli.batch_async import BatchAsync, batch_async_required
 from tests.support.mock import MagicMock, patch
 
 
@@ -22,16 +22,44 @@ def batch(temp_salt_master):
         with patch("salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)):
             batch = BatchAsync(
                 opts,
-                MagicMock(side_effect=["1234", "1235", "1236"]),
+                MagicMock(side_effect=["1234", "1235"]),
                 {
                     "tgt": "",
                     "fun": "",
-                    "kwargs": {"batch": "", "batch_presence_ping_timeout": 1},
+                    "kwargs": {
+                        "batch": "",
+                        "batch_presence_ping_timeout": 1,
+                        "metadata": {"mykey": "myvalue"},
+                    },
                 },
             )
             yield batch
 
 
+@pytest.mark.parametrize(
+    "threshold,minions,batch,expected",
+    [
+        (1, 2, 200, True),
+        (1, 500, 200, True),
+        (0, 2, 200, False),
+        (0, 500, 200, False),
+        (-1, 2, 200, False),
+        (-1, 500, 200, True),
+        (-1, 9, 10, False),
+        (-1, 11, 10, True),
+        (10, 9, 8, False),
+        (10, 9, 10, False),
+        (10, 11, 8, True),
+        (10, 11, 10, True),
+    ],
+)
+def test_batch_async_required(threshold, minions, batch, expected):
+    minions_list = [f"minion{i}.example.org" for i in range(minions)]
+    batch_async_opts = {"batch_async": {"threshold": threshold}}
+    extra = {"batch": batch}
+    assert batch_async_required(batch_async_opts, minions_list, extra) == expected
+
+
 def test_ping_jid(batch):
     assert batch.ping_jid == "1234"
 
@@ -40,10 +68,6 @@ def test_batch_jid(batch):
     assert batch.batch_jid == "1235"
 
 
-def test_find_job_jid(batch):
-    assert batch.find_job_jid == "1236"
-
-
 def test_batch_size(batch):
     """
     Tests passing batch value as a number
@@ -55,58 +79,74 @@ def test_batch_size(batch):
 
 
 def test_batch_start_on_batch_presence_ping_timeout(batch):
-    # batch_async = BatchAsyncMock();
-    batch.event = MagicMock()
+    future_ret = salt.ext.tornado.gen.Future()
+    future_ret.set_result({"minions": ["foo", "bar"]})
     future = salt.ext.tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
-        # ret = batch_async.start(batch)
+    future.set_result({})
+    with patch.object(batch, "events_channel", MagicMock()), patch(
+        "salt.ext.tornado.gen.sleep", return_value=future
+    ), patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.events_channel.local_client.run_job_async.return_value = future_ret
         ret = batch.start()
-        # assert start_batch is called later with batch_presence_ping_timeout as param
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert start_batch is called
+        start_batch_mock.assert_called_once()
         # assert test.ping called
-        assert batch.local.run_job_async.call_args[0] == ("*", "test.ping", [], "glob")
+        assert batch.events_channel.local_client.run_job_async.call_args[0] == (
+            "*",
+            "test.ping",
+            [],
+            "glob",
+        )
         # assert targeted_minions == all minions matched by tgt
         assert batch.targeted_minions == {"foo", "bar"}
 
 
 def test_batch_start_on_gather_job_timeout(batch):
-    # batch_async = BatchAsyncMock();
-    batch.event = MagicMock()
     future = salt.ext.tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
+    future.set_result({})
+    future_ret = salt.ext.tornado.gen.Future()
+    future_ret.set_result({"minions": ["foo", "bar"]})
     batch.batch_presence_ping_timeout = None
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    with patch.object(batch, "events_channel", MagicMock()), patch(
+        "salt.ext.tornado.gen.sleep", return_value=future
+    ), patch.object(
+        batch, "start_batch", return_value=future
+    ) as start_batch_mock, patch.object(
+        batch, "batch_presence_ping_timeout", None
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future_ret
         # ret = batch_async.start(batch)
         ret = batch.start()
-        # assert start_batch is called later with gather_job_timeout as param
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert start_batch is called
+        start_batch_mock.assert_called_once()
 
 
 def test_batch_fire_start_event(batch):
     batch.minions = {"foo", "bar"}
     batch.opts = {"batch": "2", "timeout": 5}
-    batch.event = MagicMock()
-    batch.metadata = {"mykey": "myvalue"}
-    batch.start_batch()
-    assert batch.event.fire_event.call_args[0] == (
-        {
-            "available_minions": {"foo", "bar"},
-            "down_minions": set(),
-            "metadata": batch.metadata,
-        },
-        "salt/batch/1235/start",
-    )
+    with patch.object(batch, "events_channel", MagicMock()):
+        batch.start_batch()
+        assert batch.events_channel.master_event.fire_event_async.call_args[0] == (
+            {
+                "available_minions": {"foo", "bar"},
+                "down_minions": set(),
+                "metadata": batch.metadata,
+            },
+            "salt/batch/1235/start",
+        )
 
 
 def test_start_batch_calls_next(batch):
-    batch.run_next = MagicMock(return_value=MagicMock())
-    batch.event = MagicMock()
-    batch.start_batch()
-    assert batch.initialized
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.run_next,)
+    batch.initialized = False
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "event", MagicMock()), patch.object(
+        batch, "events_channel", MagicMock()
+    ), patch.object(batch, "run_next", return_value=future) as run_next_mock:
+        batch.events_channel.master_event.fire_event_async.return_value = future
+        batch.start_batch()
+        assert batch.initialized
+        run_next_mock.assert_called_once()
 
 
 def test_batch_fire_done_event(batch):
@@ -114,69 +154,52 @@ def test_batch_fire_done_event(batch):
     batch.minions = {"foo", "bar"}
     batch.done_minions = {"foo"}
     batch.timedout_minions = {"bar"}
-    batch.event = MagicMock()
-    batch.metadata = {"mykey": "myvalue"}
-    old_event = batch.event
-    batch.end_batch()
-    assert old_event.fire_event.call_args[0] == (
-        {
-            "available_minions": {"foo", "bar"},
-            "done_minions": batch.done_minions,
-            "down_minions": {"baz"},
-            "timedout_minions": batch.timedout_minions,
-            "metadata": batch.metadata,
-        },
-        "salt/batch/1235/done",
-    )
-
-
-def test_batch__del__(batch):
-    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
-    event = MagicMock()
-    batch.event = event
-    batch.__del__()
-    assert batch.local is None
-    assert batch.event is None
-    assert batch.ioloop is None
+    with patch.object(batch, "events_channel", MagicMock()):
+        batch.end_batch()
+        assert batch.events_channel.master_event.fire_event_async.call_args[0] == (
+            {
+                "available_minions": {"foo", "bar"},
+                "done_minions": batch.done_minions,
+                "down_minions": {"baz"},
+                "timedout_minions": batch.timedout_minions,
+                "metadata": batch.metadata,
+            },
+            "salt/batch/1235/done",
+        )
 
 
 def test_batch_close_safe(batch):
-    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
-    event = MagicMock()
-    batch.event = event
-    batch.patterns = {
-        ("salt/job/1234/ret/*", "find_job_return"),
-        ("salt/job/4321/ret/*", "find_job_return"),
-    }
-    batch.close_safe()
-    assert batch.local is None
-    assert batch.event is None
-    assert batch.ioloop is None
-    assert len(event.unsubscribe.mock_calls) == 2
-    assert len(event.remove_event_handler.mock_calls) == 1
+    with patch.object(
+        batch, "events_channel", MagicMock()
+    ) as events_channel_mock, patch.object(batch, "event", MagicMock()):
+        batch.close_safe()
+        batch.close_safe()
+        assert batch.events_channel is None
+        assert batch.event is None
+        events_channel_mock.unsubscribe.assert_called_once()
+        events_channel_mock.unuse.assert_called_once()
 
 
 def test_batch_next(batch):
-    batch.event = MagicMock()
     batch.opts["fun"] = "my.fun"
     batch.opts["arg"] = []
-    batch._get_next = MagicMock(return_value={"foo", "bar"})
     batch.batch_size = 2
     future = salt.ext.tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    future.set_result({})
+    with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object(
+        batch, "events_channel", MagicMock()
+    ), patch.object(batch, "_get_next", return_value={"foo", "bar"}), patch.object(
+        batch, "find_job", return_value=future
+    ) as find_job_mock:
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.run_next()
-        assert batch.local.run_job_async.call_args[0] == (
+        assert batch.events_channel.local_client.run_job_async.call_args[0] == (
             {"foo", "bar"},
             "my.fun",
             [],
             "list",
         )
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.find_job,
-            {"foo", "bar"},
-        )
+        assert find_job_mock.call_args[0] == ({"foo", "bar"},)
         assert batch.active == {"bar", "foo"}
 
 
@@ -239,124 +262,132 @@ def test_next_batch_all_timedout(batch):
 
 def test_batch__event_handler_ping_return(batch):
     batch.targeted_minions = {"foo"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
     batch.start()
     assert batch.minions == set()
-    batch._BatchAsync__event_handler(MagicMock())
+    batch._BatchAsync__event_handler(
+        "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+    )
     assert batch.minions == {"foo"}
     assert batch.done_minions == set()
 
 
 def test_batch__event_handler_call_start_batch_when_all_pings_return(batch):
     batch.targeted_minions = {"foo"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch._BatchAsync__event_handler(MagicMock())
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.start()
+        batch._BatchAsync__event_handler(
+            "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+        )
+        start_batch_mock.assert_called_once()
 
 
 def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(batch):
     batch.targeted_minions = {"foo", "bar"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch._BatchAsync__event_handler(MagicMock())
-    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.start()
+        batch._BatchAsync__event_handler(
+            "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+        )
+        start_batch_mock.assert_not_called()
 
 
 def test_batch__event_handler_batch_run_return(batch):
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch.active = {"foo"}
-    batch._BatchAsync__event_handler(MagicMock())
-    assert batch.active == set()
-    assert batch.done_minions == {"foo"}
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.schedule_next,)
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(
+        batch, "schedule_next", return_value=future
+    ) as schedule_next_mock:
+        batch.start()
+        batch.active = {"foo"}
+        batch._BatchAsync__event_handler(
+            "salt/job/1235/ret/foo", {"id": "foo"}, "batch_run"
+        )
+        assert batch.active == set()
+        assert batch.done_minions == {"foo"}
+        schedule_next_mock.assert_called_once()
 
 
 def test_batch__event_handler_find_job_return(batch):
-    batch.event = MagicMock(
-        unpack=MagicMock(
-            return_value=(
-                "salt/job/1236/ret/foo",
-                {"id": "foo", "return": "deadbeaf"},
-            )
-        )
-    )
     batch.start()
-    batch.patterns.add(("salt/job/1236/ret/*", "find_job_return"))
-    batch._BatchAsync__event_handler(MagicMock())
+    batch._BatchAsync__event_handler(
+        "salt/job/1236/ret/foo", {"id": "foo", "return": "deadbeaf"}, "find_job_return"
+    )
     assert batch.find_job_returned == {"foo"}
 
 
 def test_batch_run_next_end_batch_when_no_next(batch):
-    batch.end_batch = MagicMock()
-    batch._get_next = MagicMock(return_value={})
-    batch.run_next()
-    assert len(batch.end_batch.mock_calls) == 1
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(
+        batch, "_get_next", return_value={}
+    ), patch.object(
+        batch, "end_batch", return_value=future
+    ) as end_batch_mock:
+        batch.run_next()
+        end_batch_mock.assert_called_once()
 
 
 def test_batch_find_job(batch):
-    batch.event = MagicMock()
     future = salt.ext.tornado.gen.Future()
     future.set_result({})
-    batch.local.run_job_async.return_value = future
     batch.minions = {"foo", "bar"}
-    batch.jid_gen = MagicMock(return_value="1234")
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object(
+        batch, "check_find_job", return_value=future
+    ) as check_find_job_mock, patch.object(
+        batch, "jid_gen", return_value="1236"
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.find_job({"foo", "bar"})
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.check_find_job,
+        assert check_find_job_mock.call_args[0] == (
             {"foo", "bar"},
-            "1234",
+            "1236",
         )
 
 
 def test_batch_find_job_with_done_minions(batch):
     batch.done_minions = {"bar"}
-    batch.event = MagicMock()
     future = salt.ext.tornado.gen.Future()
     future.set_result({})
-    batch.local.run_job_async.return_value = future
     batch.minions = {"foo", "bar"}
-    batch.jid_gen = MagicMock(return_value="1234")
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object(
+        batch, "check_find_job", return_value=future
+    ) as check_find_job_mock, patch.object(
+        batch, "jid_gen", return_value="1236"
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.find_job({"foo", "bar"})
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.check_find_job,
+        assert check_find_job_mock.call_args[0] == (
             {"foo"},
-            "1234",
+            "1236",
         )
 
 
 def test_batch_check_find_job_did_not_return(batch):
-    batch.event = MagicMock()
     batch.active = {"foo"}
     batch.find_job_returned = set()
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job({"foo"}, jid="1234")
-    assert batch.find_job_returned == set()
-    assert batch.active == set()
-    assert len(batch.event.io_loop.add_callback.mock_calls) == 0
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "find_job", return_value=future) as find_job_mock:
+        batch.check_find_job({"foo"}, jid="1234")
+        assert batch.find_job_returned == set()
+        assert batch.active == set()
+        find_job_mock.assert_not_called()
 
 
 def test_batch_check_find_job_did_return(batch):
-    batch.event = MagicMock()
     batch.find_job_returned = {"foo"}
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job({"foo"}, jid="1234")
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "find_job", return_value=future) as find_job_mock:
+        batch.check_find_job({"foo"}, jid="1234")
+        find_job_mock.assert_called_once_with({"foo"})
 
 
 def test_batch_check_find_job_multiple_states(batch):
-    batch.event = MagicMock()
     # currently running minions
     batch.active = {"foo", "bar"}
 
@@ -372,21 +403,28 @@ def test_batch_check_find_job_multiple_states(batch):
     # both not yet done but only 'foo' responded to find_job
     not_done = {"foo", "bar"}
 
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job(not_done, jid="1234")
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
 
-    # assert 'bar' removed from active
-    assert batch.active == {"foo"}
+    with patch.object(batch, "schedule_next", return_value=future), patch.object(
+        batch, "find_job", return_value=future
+    ) as find_job_mock:
+        batch.check_find_job(not_done, jid="1234")
 
-    # assert 'bar' added to timedout_minions
-    assert batch.timedout_minions == {"bar", "faz"}
+        # assert 'bar' removed from active
+        assert batch.active == {"foo"}
 
-    # assert 'find_job' schedueled again only for 'foo'
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+        # assert 'bar' added to timedout_minions
+        assert batch.timedout_minions == {"bar", "faz"}
+
+        # assert 'find_job' scheduled again only for 'foo'
+        find_job_mock.assert_called_once_with({"foo"})
 
 
 def test_only_on_run_next_is_scheduled(batch):
-    batch.event = MagicMock()
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
     batch.scheduled = True
-    batch.schedule_next()
-    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+    with patch.object(batch, "run_next", return_value=future) as run_next_mock:
+        batch.schedule_next()
+        run_next_mock.assert_not_called()
-- 
2.45.0

openSUSE Build Service is sponsored by