File prevent-oom-with-high-amount-of-batch-async-calls-bs.patch of Package salt

From d57472b4fa2213ec551197ee2e147aef364fdcfe Mon Sep 17 00:00:00 2001
From: Victor Zhestkov <vzhestkov@suse.com>
Date: Wed, 15 May 2024 11:47:35 +0200
Subject: [PATCH] Prevent OOM with high amount of batch async calls
 (bsc#1216063)

* Refactor batch_async implementation

* Fix batch_async tests after refactoring
---
 salt/cli/batch_async.py                    | 584 ++++++++++++++-------
 salt/master.py                             |   9 +-
 tests/pytests/unit/cli/test_batch_async.py | 360 +++++++------
 3 files changed, 597 insertions(+), 356 deletions(-)

diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
index 1012ce37cc..5d49993faa 100644
--- a/salt/cli/batch_async.py
+++ b/salt/cli/batch_async.py
@@ -2,18 +2,193 @@
 Execute a job on the targeted minions by using a moving window of fixed size `batch`.
 """
 
-import gc
-
-# pylint: enable=import-error,no-name-in-module,redefined-builtin
 import logging
+import re
 
 import salt.client
 import salt.ext.tornado
+import salt.utils.event
 from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum
+from salt.ext.tornado.iostream import StreamClosedError
 
 log = logging.getLogger(__name__)
 
 
+__SHARED_EVENTS_CHANNEL = None
+
+
+def _get_shared_events_channel(opts, io_loop):
+    global __SHARED_EVENTS_CHANNEL
+    if __SHARED_EVENTS_CHANNEL is None:
+        __SHARED_EVENTS_CHANNEL = SharedEventsChannel(opts, io_loop)
+    return __SHARED_EVENTS_CHANNEL
+
+
+def _destroy_unused_shared_events_channel():
+    global __SHARED_EVENTS_CHANNEL
+    if __SHARED_EVENTS_CHANNEL is not None and __SHARED_EVENTS_CHANNEL.destroy_unused():
+        __SHARED_EVENTS_CHANNEL = None
+
+
+def batch_async_required(opts, minions, extra):
+    """
+    Check opts to identify if batch async is required for the operation.
+    """
+    if not isinstance(minions, list):
+        False
+    batch_async_opts = opts.get("batch_async", {})
+    batch_async_threshold = (
+        batch_async_opts.get("threshold", 1)
+        if isinstance(batch_async_opts, dict)
+        else 1
+    )
+    if batch_async_threshold == -1:
+        batch_size = get_bnum(extra, minions, True)
+        return len(minions) >= batch_size
+    elif batch_async_threshold > 0:
+        return len(minions) >= batch_async_threshold
+    return False
+
+
+class SharedEventsChannel:
+    def __init__(self, opts, io_loop):
+        self.io_loop = io_loop
+        self.local_client = salt.client.get_local_client(
+            opts["conf_file"], io_loop=self.io_loop
+        )
+        self.master_event = salt.utils.event.get_event(
+            "master",
+            sock_dir=self.local_client.opts["sock_dir"],
+            opts=self.local_client.opts,
+            listen=True,
+            io_loop=self.io_loop,
+            keep_loop=True,
+        )
+        self.master_event.set_event_handler(self.__handle_event)
+        if self.master_event.subscriber.stream:
+            self.master_event.subscriber.stream.set_close_callback(self.__handle_close)
+        self._re_tag_ret_event = re.compile(r"salt\/job\/(\d+)\/ret\/.*")
+        self._subscribers = {}
+        self._subscriptions = {}
+        self._used_by = set()
+        batch_async_opts = opts.get("batch_async", {})
+        if not isinstance(batch_async_opts, dict):
+            batch_async_opts = {}
+        self._subscriber_reconnect_tries = batch_async_opts.get(
+            "subscriber_reconnect_tries", 5
+        )
+        self._subscriber_reconnect_interval = batch_async_opts.get(
+            "subscriber_reconnect_interval", 1.0
+        )
+        self._reconnecting_subscriber = False
+
+    def subscribe(self, jid, op, subscriber_id, handler):
+        if subscriber_id not in self._subscribers:
+            self._subscribers[subscriber_id] = set()
+        if jid not in self._subscriptions:
+            self._subscriptions[jid] = []
+        self._subscribers[subscriber_id].add(jid)
+        if (op, subscriber_id, handler) not in self._subscriptions[jid]:
+            self._subscriptions[jid].append((op, subscriber_id, handler))
+        if not self.master_event.subscriber.connected():
+            self.__reconnect_subscriber()
+
+    def unsubscribe(self, jid, op, subscriber_id):
+        if subscriber_id not in self._subscribers:
+            return
+        jids = self._subscribers[subscriber_id].copy()
+        if jid is not None:
+            jids = set(jid)
+        for i_jid in jids:
+            self._subscriptions[i_jid] = list(
+                filter(
+                    lambda x: not (op in (x[0], None) and x[1] == subscriber_id),
+                    self._subscriptions.get(i_jid, []),
+                )
+            )
+            self._subscribers[subscriber_id].discard(i_jid)
+        self._subscriptions = dict(filter(lambda x: x[1], self._subscriptions.items()))
+        if not self._subscribers[subscriber_id]:
+            del self._subscribers[subscriber_id]
+
+    @salt.ext.tornado.gen.coroutine
+    def __handle_close(self):
+        if not self._subscriptions:
+            return
+        log.warning("Master Event Subscriber was closed. Trying to reconnect...")
+        yield self.__reconnect_subscriber()
+
+    @salt.ext.tornado.gen.coroutine
+    def __handle_event(self, raw):
+        if self.master_event is None:
+            return
+        try:
+            tag, data = self.master_event.unpack(raw)
+            tag_match = self._re_tag_ret_event.match(tag)
+            if tag_match:
+                jid = tag_match.group(1)
+                if jid in self._subscriptions:
+                    for op, _, handler in self._subscriptions[jid]:
+                        yield handler(tag, data, op)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured while processing event: %s: %s",
+                tag,
+                ex,
+                exc_info=True,
+            )
+
+    @salt.ext.tornado.gen.coroutine
+    def __reconnect_subscriber(self):
+        if self.master_event.subscriber.connected() or self._reconnecting_subscriber:
+            return
+        self._reconnecting_subscriber = True
+        max_tries = max(1, int(self._subscriber_reconnect_tries))
+        _try = 1
+        while _try <= max_tries:
+            log.info(
+                "Trying to reconnect to event publisher (try %d of %d) ...",
+                _try,
+                max_tries,
+            )
+            try:
+                yield self.master_event.subscriber.connect()
+            except StreamClosedError:
+                log.warning(
+                    "Unable to reconnect to event publisher (try %d of %d)",
+                    _try,
+                    max_tries,
+                )
+            if self.master_event.subscriber.connected():
+                self.master_event.subscriber.stream.set_close_callback(
+                    self.__handle_close
+                )
+                log.info("Event publisher connection restored")
+                self._reconnecting_subscriber = False
+                return
+            if _try < max_tries:
+                yield salt.ext.tornado.gen.sleep(self._subscriber_reconnect_interval)
+            _try += 1
+        self._reconnecting_subscriber = False
+
+    def use(self, subscriber_id):
+        self._used_by.add(subscriber_id)
+        return self
+
+    def unuse(self, subscriber_id):
+        self._used_by.discard(subscriber_id)
+
+    def destroy_unused(self):
+        if self._used_by:
+            return False
+        self.master_event.remove_event_handler(self.__handle_event)
+        self.master_event.destroy()
+        self.master_event = None
+        self.local_client.destroy()
+        self.local_client = None
+        return True
+
+
 class BatchAsync:
     """
     Run a job on the targeted minions by using a moving window of fixed size `batch`.
@@ -28,14 +203,14 @@ class BatchAsync:
         - gather_job_timeout: `find_job` timeout
         - timeout: time to wait before firing a `find_job`
 
-    When the batch stars, a `start` event is fired:
+    When the batch starts, a `start` event is fired:
          - tag: salt/batch/<batch-jid>/start
          - data: {
              "available_minions": self.minions,
              "down_minions": targeted_minions - presence_ping_minions
            }
 
-    When the batch ends, an `done` event is fired:
+    When the batch ends, a `done` event is fired:
         - tag: salt/batch/<batch-jid>/done
         - data: {
              "available_minions": self.minions,
@@ -45,17 +220,26 @@ class BatchAsync:
          }
     """
 
-    def __init__(self, parent_opts, jid_gen, clear_load):
-        ioloop = salt.ext.tornado.ioloop.IOLoop.current()
-        self.local = salt.client.get_local_client(
-            parent_opts["conf_file"], io_loop=ioloop
+    def __init__(self, opts, jid_gen, clear_load):
+        self.extra_job_kwargs = {}
+        kwargs = clear_load.get("kwargs", {})
+        for kwarg in ("module_executors", "executor_opts"):
+            if kwarg in kwargs:
+                self.extra_job_kwargs[kwarg] = kwargs[kwarg]
+            elif kwarg in opts:
+                self.extra_job_kwargs[kwarg] = opts[kwarg]
+        self.io_loop = salt.ext.tornado.ioloop.IOLoop.current()
+        self.events_channel = _get_shared_events_channel(opts, self.io_loop).use(
+            id(self)
         )
         if "gather_job_timeout" in clear_load["kwargs"]:
             clear_load["gather_job_timeout"] = clear_load["kwargs"].pop(
                 "gather_job_timeout"
             )
         else:
-            clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"]
+            clear_load["gather_job_timeout"] = self.events_channel.local_client.opts[
+                "gather_job_timeout"
+            ]
         self.batch_presence_ping_timeout = clear_load["kwargs"].get(
             "batch_presence_ping_timeout", None
         )
@@ -64,8 +248,8 @@ class BatchAsync:
             clear_load.pop("tgt"),
             clear_load.pop("fun"),
             clear_load["kwargs"].pop("batch"),
-            self.local.opts,
-            **clear_load
+            self.events_channel.local_client.opts,
+            **clear_load,
         )
         self.eauth = batch_get_eauth(clear_load["kwargs"])
         self.metadata = clear_load["kwargs"].get("metadata", {})
@@ -78,54 +262,45 @@ class BatchAsync:
         self.jid_gen = jid_gen
         self.ping_jid = jid_gen()
         self.batch_jid = jid_gen()
-        self.find_job_jid = jid_gen()
         self.find_job_returned = set()
+        self.metadata.update({"batch_jid": self.batch_jid, "ping_jid": self.ping_jid})
         self.ended = False
-        self.event = salt.utils.event.get_event(
-            "master",
-            self.opts["sock_dir"],
-            self.opts["transport"],
-            opts=self.opts,
-            listen=True,
-            io_loop=ioloop,
-            keep_loop=True,
-        )
+        self.event = self.events_channel.master_event
         self.scheduled = False
-        self.patterns = set()
 
     def __set_event_handler(self):
-        ping_return_pattern = "salt/job/{}/ret/*".format(self.ping_jid)
-        batch_return_pattern = "salt/job/{}/ret/*".format(self.batch_jid)
-        self.event.subscribe(ping_return_pattern, match_type="glob")
-        self.event.subscribe(batch_return_pattern, match_type="glob")
-        self.patterns = {
-            (ping_return_pattern, "ping_return"),
-            (batch_return_pattern, "batch_run"),
-        }
-        self.event.set_event_handler(self.__event_handler)
+        self.events_channel.subscribe(
+            self.ping_jid, "ping_return", id(self), self.__event_handler
+        )
+        self.events_channel.subscribe(
+            self.batch_jid, "batch_run", id(self), self.__event_handler
+        )
 
-    def __event_handler(self, raw):
+    @salt.ext.tornado.gen.coroutine
+    def __event_handler(self, tag, data, op):
         if not self.event:
             return
         try:
-            mtag, data = self.event.unpack(raw)
-            for (pattern, op) in self.patterns:
-                if mtag.startswith(pattern[:-1]):
-                    minion = data["id"]
-                    if op == "ping_return":
-                        self.minions.add(minion)
-                        if self.targeted_minions == self.minions:
-                            self.event.io_loop.spawn_callback(self.start_batch)
-                    elif op == "find_job_return":
-                        if data.get("return", None):
-                            self.find_job_returned.add(minion)
-                    elif op == "batch_run":
-                        if minion in self.active:
-                            self.active.remove(minion)
-                            self.done_minions.add(minion)
-                            self.event.io_loop.spawn_callback(self.schedule_next)
-        except Exception as ex:
-            log.error("Exception occured while processing event: {}".format(ex))
+            minion = data["id"]
+            if op == "ping_return":
+                self.minions.add(minion)
+                if self.targeted_minions == self.minions:
+                    yield self.start_batch()
+            elif op == "find_job_return":
+                if data.get("return", None):
+                    self.find_job_returned.add(minion)
+            elif op == "batch_run":
+                if minion in self.active:
+                    self.active.remove(minion)
+                    self.done_minions.add(minion)
+                    yield self.schedule_next()
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured while processing event: %s: %s",
+                tag,
+                ex,
+                exc_info=True,
+            )
 
     def _get_next(self):
         to_run = (
@@ -139,176 +314,203 @@ class BatchAsync:
         )
         return set(list(to_run)[:next_batch_size])
 
+    @salt.ext.tornado.gen.coroutine
     def check_find_job(self, batch_minions, jid):
-        if self.event:
-            find_job_return_pattern = "salt/job/{}/ret/*".format(jid)
-            self.event.unsubscribe(find_job_return_pattern, match_type="glob")
-            self.patterns.remove((find_job_return_pattern, "find_job_return"))
-
-            timedout_minions = batch_minions.difference(
-                self.find_job_returned
-            ).difference(self.done_minions)
-            self.timedout_minions = self.timedout_minions.union(timedout_minions)
-            self.active = self.active.difference(self.timedout_minions)
-            running = batch_minions.difference(self.done_minions).difference(
-                self.timedout_minions
-            )
+        """
+        Check if the job with specified ``jid`` was finished on the minions
+        """
+        if not self.event:
+            return
+        self.events_channel.unsubscribe(jid, "find_job_return", id(self))
 
-            if timedout_minions:
-                self.schedule_next()
+        timedout_minions = batch_minions.difference(self.find_job_returned).difference(
+            self.done_minions
+        )
+        self.timedout_minions = self.timedout_minions.union(timedout_minions)
+        self.active = self.active.difference(self.timedout_minions)
+        running = batch_minions.difference(self.done_minions).difference(
+            self.timedout_minions
+        )
 
-            if self.event and running:
-                self.find_job_returned = self.find_job_returned.difference(running)
-                self.event.io_loop.spawn_callback(self.find_job, running)
+        if timedout_minions:
+            yield self.schedule_next()
+
+        if self.event and running:
+            self.find_job_returned = self.find_job_returned.difference(running)
+            yield self.find_job(running)
 
     @salt.ext.tornado.gen.coroutine
     def find_job(self, minions):
-        if self.event:
-            not_done = minions.difference(self.done_minions).difference(
-                self.timedout_minions
+        """
+        Find if the job was finished on the minions
+        """
+        if not self.event:
+            return
+        not_done = minions.difference(self.done_minions).difference(
+            self.timedout_minions
+        )
+        if not not_done:
+            return
+        try:
+            jid = self.jid_gen()
+            self.events_channel.subscribe(
+                jid, "find_job_return", id(self), self.__event_handler
             )
-            try:
-                if not_done:
-                    jid = self.jid_gen()
-                    find_job_return_pattern = "salt/job/{}/ret/*".format(jid)
-                    self.patterns.add((find_job_return_pattern, "find_job_return"))
-                    self.event.subscribe(find_job_return_pattern, match_type="glob")
-                    ret = yield self.local.run_job_async(
-                        not_done,
-                        "saltutil.find_job",
-                        [self.batch_jid],
-                        "list",
-                        gather_job_timeout=self.opts["gather_job_timeout"],
-                        jid=jid,
-                        **self.eauth
-                    )
-                    yield salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"])
-                    if self.event:
-                        self.event.io_loop.spawn_callback(
-                            self.check_find_job, not_done, jid
-                        )
-            except Exception as ex:
-                log.error(
-                    "Exception occured handling batch async: {}. Aborting execution.".format(
-                        ex
-                    )
-                )
-                self.close_safe()
+            ret = yield self.events_channel.local_client.run_job_async(
+                not_done,
+                "saltutil.find_job",
+                [self.batch_jid],
+                "list",
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=jid,
+                io_loop=self.io_loop,
+                listen=False,
+                **self.eauth,
+            )
+            yield salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"])
+            if self.event:
+                yield self.check_find_job(not_done, jid)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured handling batch async: %s. Aborting execution.",
+                ex,
+                exc_info=True,
+            )
+            self.close_safe()
 
     @salt.ext.tornado.gen.coroutine
     def start(self):
+        """
+        Start the batch execution
+        """
+        if not self.event:
+            return
+        self.__set_event_handler()
+        ping_return = yield self.events_channel.local_client.run_job_async(
+            self.opts["tgt"],
+            "test.ping",
+            [],
+            self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")),
+            gather_job_timeout=self.opts["gather_job_timeout"],
+            jid=self.ping_jid,
+            metadata=self.metadata,
+            io_loop=self.io_loop,
+            listen=False,
+            **self.eauth,
+        )
+        self.targeted_minions = set(ping_return["minions"])
+        # start batching even if not all minions respond to ping
+        yield salt.ext.tornado.gen.sleep(
+            self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
+        )
         if self.event:
-            self.__set_event_handler()
-            ping_return = yield self.local.run_job_async(
-                self.opts["tgt"],
-                "test.ping",
-                [],
-                self.opts.get(
-                    "selected_target_option", self.opts.get("tgt_type", "glob")
-                ),
-                gather_job_timeout=self.opts["gather_job_timeout"],
-                jid=self.ping_jid,
-                metadata=self.metadata,
-                **self.eauth
-            )
-            self.targeted_minions = set(ping_return["minions"])
-            # start batching even if not all minions respond to ping
-            yield salt.ext.tornado.gen.sleep(
-                self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
-            )
-            if self.event:
-                self.event.io_loop.spawn_callback(self.start_batch)
+            yield self.start_batch()
 
     @salt.ext.tornado.gen.coroutine
     def start_batch(self):
-        if not self.initialized:
-            self.batch_size = get_bnum(self.opts, self.minions, True)
-            self.initialized = True
-            data = {
-                "available_minions": self.minions,
-                "down_minions": self.targeted_minions.difference(self.minions),
-                "metadata": self.metadata,
-            }
-            ret = self.event.fire_event(
-                data, "salt/batch/{}/start".format(self.batch_jid)
-            )
-            if self.event:
-                self.event.io_loop.spawn_callback(self.run_next)
+        """
+        Fire `salt/batch/*/start` and continue batch with `run_next`
+        """
+        if self.initialized:
+            return
+        self.batch_size = get_bnum(self.opts, self.minions, True)
+        self.initialized = True
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.targeted_minions.difference(self.minions),
+            "metadata": self.metadata,
+        }
+        yield self.events_channel.master_event.fire_event_async(
+            data, f"salt/batch/{self.batch_jid}/start"
+        )
+        if self.event:
+            yield self.run_next()
 
     @salt.ext.tornado.gen.coroutine
     def end_batch(self):
+        """
+        End the batch and call safe closing
+        """
         left = self.minions.symmetric_difference(
             self.done_minions.union(self.timedout_minions)
         )
-        if not left and not self.ended:
-            self.ended = True
-            data = {
-                "available_minions": self.minions,
-                "down_minions": self.targeted_minions.difference(self.minions),
-                "done_minions": self.done_minions,
-                "timedout_minions": self.timedout_minions,
-                "metadata": self.metadata,
-            }
-            self.event.fire_event(data, "salt/batch/{}/done".format(self.batch_jid))
-
-            # release to the IOLoop to allow the event to be published
-            # before closing batch async execution
-            yield salt.ext.tornado.gen.sleep(1)
-            self.close_safe()
+        # Send salt/batch/*/done only if there is nothing to do
+        # and the event haven't been sent already
+        if left or self.ended:
+            return
+        self.ended = True
+        data = {
+            "available_minions": self.minions,
+            "down_minions": self.targeted_minions.difference(self.minions),
+            "done_minions": self.done_minions,
+            "timedout_minions": self.timedout_minions,
+            "metadata": self.metadata,
+        }
+        yield self.events_channel.master_event.fire_event_async(
+            data, f"salt/batch/{self.batch_jid}/done"
+        )
+
+        # release to the IOLoop to allow the event to be published
+        # before closing batch async execution
+        yield salt.ext.tornado.gen.sleep(1)
+        self.close_safe()
 
     def close_safe(self):
-        for (pattern, label) in self.patterns:
-            self.event.unsubscribe(pattern, match_type="glob")
-        self.event.remove_event_handler(self.__event_handler)
+        if self.events_channel is not None:
+            self.events_channel.unsubscribe(None, None, id(self))
+            self.events_channel.unuse(id(self))
+            self.events_channel = None
+            _destroy_unused_shared_events_channel()
         self.event = None
-        self.local = None
-        self.ioloop = None
-        del self
-        gc.collect()
 
     @salt.ext.tornado.gen.coroutine
     def schedule_next(self):
-        if not self.scheduled:
-            self.scheduled = True
-            # call later so that we maybe gather more returns
-            yield salt.ext.tornado.gen.sleep(self.batch_delay)
-            if self.event:
-                self.event.io_loop.spawn_callback(self.run_next)
+        if self.scheduled:
+            return
+        self.scheduled = True
+        # call later so that we maybe gather more returns
+        yield salt.ext.tornado.gen.sleep(self.batch_delay)
+        if self.event:
+            yield self.run_next()
 
     @salt.ext.tornado.gen.coroutine
     def run_next(self):
+        """
+        Continue batch execution with the next targets
+        """
         self.scheduled = False
         next_batch = self._get_next()
-        if next_batch:
-            self.active = self.active.union(next_batch)
-            try:
-                ret = yield self.local.run_job_async(
-                    next_batch,
-                    self.opts["fun"],
-                    self.opts["arg"],
-                    "list",
-                    raw=self.opts.get("raw", False),
-                    ret=self.opts.get("return", ""),
-                    gather_job_timeout=self.opts["gather_job_timeout"],
-                    jid=self.batch_jid,
-                    metadata=self.metadata,
-                )
-
-                yield salt.ext.tornado.gen.sleep(self.opts["timeout"])
-
-                # The batch can be done already at this point, which means no self.event
-                if self.event:
-                    self.event.io_loop.spawn_callback(self.find_job, set(next_batch))
-            except Exception as ex:
-                log.error("Error in scheduling next batch: %s. Aborting execution", ex)
-                self.active = self.active.difference(next_batch)
-                self.close_safe()
-        else:
+        if not next_batch:
             yield self.end_batch()
-        gc.collect()
+            return
+        self.active = self.active.union(next_batch)
+        try:
+            ret = yield self.events_channel.local_client.run_job_async(
+                next_batch,
+                self.opts["fun"],
+                self.opts["arg"],
+                "list",
+                raw=self.opts.get("raw", False),
+                ret=self.opts.get("return", ""),
+                gather_job_timeout=self.opts["gather_job_timeout"],
+                jid=self.batch_jid,
+                metadata=self.metadata,
+                io_loop=self.io_loop,
+                listen=False,
+                **self.eauth,
+                **self.extra_job_kwargs,
+            )
 
-    def __del__(self):
-        self.local = None
-        self.event = None
-        self.ioloop = None
-        gc.collect()
+            yield salt.ext.tornado.gen.sleep(self.opts["timeout"])
+
+            # The batch can be done already at this point, which means no self.event
+            if self.event:
+                yield self.find_job(set(next_batch))
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Error in scheduling next batch: %s. Aborting execution",
+                ex,
+                exc_info=True,
+            )
+            self.active = self.active.difference(next_batch)
+            self.close_safe()
diff --git a/salt/master.py b/salt/master.py
index 425b412148..d7182d10b5 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -2,6 +2,7 @@
 This module contains all of the routines needed to set up a master server, this
 involves preparing the three listeners and the workers needed by the master.
 """
+
 import collections
 import copy
 import ctypes
@@ -19,7 +20,6 @@ import time
 import salt.acl
 import salt.auth
 import salt.channel.server
-import salt.cli.batch_async
 import salt.client
 import salt.client.ssh.client
 import salt.crypt
@@ -55,6 +55,7 @@ import salt.utils.user
 import salt.utils.verify
 import salt.utils.zeromq
 import salt.wheel
+from salt.cli.batch_async import BatchAsync, batch_async_required
 from salt.config import DEFAULT_INTERVAL
 from salt.defaults import DEFAULT_TARGET_DELIM
 from salt.ext.tornado.stack_context import StackContext
@@ -2174,9 +2175,9 @@ class ClearFuncs(TransportMethods):
     def publish_batch(self, clear_load, minions, missing):
         batch_load = {}
         batch_load.update(clear_load)
-        batch = salt.cli.batch_async.BatchAsync(
+        batch = BatchAsync(
             self.local.opts,
-            functools.partial(self._prep_jid, clear_load, {}),
+            lambda: self._prep_jid(clear_load, {}),
             batch_load,
         )
         ioloop = salt.ext.tornado.ioloop.IOLoop.current()
@@ -2331,7 +2332,7 @@ class ClearFuncs(TransportMethods):
                         ),
                     },
                 }
-        if extra.get("batch", None):
+        if extra.get("batch", None) and batch_async_required(self.opts, minions, extra):
             return self.publish_batch(clear_load, minions, missing)
 
         jid = self._prep_jid(clear_load, extra)
diff --git a/tests/pytests/unit/cli/test_batch_async.py b/tests/pytests/unit/cli/test_batch_async.py
index e0774ffff3..bc871aba54 100644
--- a/tests/pytests/unit/cli/test_batch_async.py
+++ b/tests/pytests/unit/cli/test_batch_async.py
@@ -1,7 +1,7 @@
 import pytest
 
 import salt.ext.tornado
-from salt.cli.batch_async import BatchAsync
+from salt.cli.batch_async import BatchAsync, batch_async_required
 from tests.support.mock import MagicMock, patch
 
 
@@ -22,16 +22,44 @@ def batch(temp_salt_master):
         with patch("salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)):
             batch = BatchAsync(
                 opts,
-                MagicMock(side_effect=["1234", "1235", "1236"]),
+                MagicMock(side_effect=["1234", "1235"]),
                 {
                     "tgt": "",
                     "fun": "",
-                    "kwargs": {"batch": "", "batch_presence_ping_timeout": 1},
+                    "kwargs": {
+                        "batch": "",
+                        "batch_presence_ping_timeout": 1,
+                        "metadata": {"mykey": "myvalue"},
+                    },
                 },
             )
             yield batch
 
 
+@pytest.mark.parametrize(
+    "threshold,minions,batch,expected",
+    [
+        (1, 2, 200, True),
+        (1, 500, 200, True),
+        (0, 2, 200, False),
+        (0, 500, 200, False),
+        (-1, 2, 200, False),
+        (-1, 500, 200, True),
+        (-1, 9, 10, False),
+        (-1, 11, 10, True),
+        (10, 9, 8, False),
+        (10, 9, 10, False),
+        (10, 11, 8, True),
+        (10, 11, 10, True),
+    ],
+)
+def test_batch_async_required(threshold, minions, batch, expected):
+    minions_list = [f"minion{i}.example.org" for i in range(minions)]
+    batch_async_opts = {"batch_async": {"threshold": threshold}}
+    extra = {"batch": batch}
+    assert batch_async_required(batch_async_opts, minions_list, extra) == expected
+
+
 def test_ping_jid(batch):
     assert batch.ping_jid == "1234"
 
@@ -40,10 +68,6 @@ def test_batch_jid(batch):
     assert batch.batch_jid == "1235"
 
 
-def test_find_job_jid(batch):
-    assert batch.find_job_jid == "1236"
-
-
 def test_batch_size(batch):
     """
     Tests passing batch value as a number
@@ -55,58 +79,74 @@ def test_batch_size(batch):
 
 
 def test_batch_start_on_batch_presence_ping_timeout(batch):
-    # batch_async = BatchAsyncMock();
-    batch.event = MagicMock()
+    future_ret = salt.ext.tornado.gen.Future()
+    future_ret.set_result({"minions": ["foo", "bar"]})
     future = salt.ext.tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
-        # ret = batch_async.start(batch)
+    future.set_result({})
+    with patch.object(batch, "events_channel", MagicMock()), patch(
+        "salt.ext.tornado.gen.sleep", return_value=future
+    ), patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.events_channel.local_client.run_job_async.return_value = future_ret
         ret = batch.start()
-        # assert start_batch is called later with batch_presence_ping_timeout as param
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert start_batch is called
+        start_batch_mock.assert_called_once()
         # assert test.ping called
-        assert batch.local.run_job_async.call_args[0] == ("*", "test.ping", [], "glob")
+        assert batch.events_channel.local_client.run_job_async.call_args[0] == (
+            "*",
+            "test.ping",
+            [],
+            "glob",
+        )
         # assert targeted_minions == all minions matched by tgt
         assert batch.targeted_minions == {"foo", "bar"}
 
 
 def test_batch_start_on_gather_job_timeout(batch):
-    # batch_async = BatchAsyncMock();
-    batch.event = MagicMock()
     future = salt.ext.tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
+    future.set_result({})
+    future_ret = salt.ext.tornado.gen.Future()
+    future_ret.set_result({"minions": ["foo", "bar"]})
     batch.batch_presence_ping_timeout = None
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    with patch.object(batch, "events_channel", MagicMock()), patch(
+        "salt.ext.tornado.gen.sleep", return_value=future
+    ), patch.object(
+        batch, "start_batch", return_value=future
+    ) as start_batch_mock, patch.object(
+        batch, "batch_presence_ping_timeout", None
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future_ret
         # ret = batch_async.start(batch)
         ret = batch.start()
-        # assert start_batch is called later with gather_job_timeout as param
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert start_batch is called
+        start_batch_mock.assert_called_once()
 
 
 def test_batch_fire_start_event(batch):
     batch.minions = {"foo", "bar"}
     batch.opts = {"batch": "2", "timeout": 5}
-    batch.event = MagicMock()
-    batch.metadata = {"mykey": "myvalue"}
-    batch.start_batch()
-    assert batch.event.fire_event.call_args[0] == (
-        {
-            "available_minions": {"foo", "bar"},
-            "down_minions": set(),
-            "metadata": batch.metadata,
-        },
-        "salt/batch/1235/start",
-    )
+    with patch.object(batch, "events_channel", MagicMock()):
+        batch.start_batch()
+        assert batch.events_channel.master_event.fire_event_async.call_args[0] == (
+            {
+                "available_minions": {"foo", "bar"},
+                "down_minions": set(),
+                "metadata": batch.metadata,
+            },
+            "salt/batch/1235/start",
+        )
 
 
 def test_start_batch_calls_next(batch):
-    batch.run_next = MagicMock(return_value=MagicMock())
-    batch.event = MagicMock()
-    batch.start_batch()
-    assert batch.initialized
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.run_next,)
+    batch.initialized = False
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "event", MagicMock()), patch.object(
+        batch, "events_channel", MagicMock()
+    ), patch.object(batch, "run_next", return_value=future) as run_next_mock:
+        batch.events_channel.master_event.fire_event_async.return_value = future
+        batch.start_batch()
+        assert batch.initialized
+        run_next_mock.assert_called_once()
 
 
 def test_batch_fire_done_event(batch):
@@ -114,69 +154,52 @@ def test_batch_fire_done_event(batch):
     batch.minions = {"foo", "bar"}
     batch.done_minions = {"foo"}
     batch.timedout_minions = {"bar"}
-    batch.event = MagicMock()
-    batch.metadata = {"mykey": "myvalue"}
-    old_event = batch.event
-    batch.end_batch()
-    assert old_event.fire_event.call_args[0] == (
-        {
-            "available_minions": {"foo", "bar"},
-            "done_minions": batch.done_minions,
-            "down_minions": {"baz"},
-            "timedout_minions": batch.timedout_minions,
-            "metadata": batch.metadata,
-        },
-        "salt/batch/1235/done",
-    )
-
-
-def test_batch__del__(batch):
-    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
-    event = MagicMock()
-    batch.event = event
-    batch.__del__()
-    assert batch.local is None
-    assert batch.event is None
-    assert batch.ioloop is None
+    with patch.object(batch, "events_channel", MagicMock()):
+        batch.end_batch()
+        assert batch.events_channel.master_event.fire_event_async.call_args[0] == (
+            {
+                "available_minions": {"foo", "bar"},
+                "done_minions": batch.done_minions,
+                "down_minions": {"baz"},
+                "timedout_minions": batch.timedout_minions,
+                "metadata": batch.metadata,
+            },
+            "salt/batch/1235/done",
+        )
 
 
 def test_batch_close_safe(batch):
-    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
-    event = MagicMock()
-    batch.event = event
-    batch.patterns = {
-        ("salt/job/1234/ret/*", "find_job_return"),
-        ("salt/job/4321/ret/*", "find_job_return"),
-    }
-    batch.close_safe()
-    assert batch.local is None
-    assert batch.event is None
-    assert batch.ioloop is None
-    assert len(event.unsubscribe.mock_calls) == 2
-    assert len(event.remove_event_handler.mock_calls) == 1
+    with patch.object(
+        batch, "events_channel", MagicMock()
+    ) as events_channel_mock, patch.object(batch, "event", MagicMock()):
+        batch.close_safe()
+        batch.close_safe()
+        assert batch.events_channel is None
+        assert batch.event is None
+        events_channel_mock.unsubscribe.assert_called_once()
+        events_channel_mock.unuse.assert_called_once()
 
 
 def test_batch_next(batch):
-    batch.event = MagicMock()
     batch.opts["fun"] = "my.fun"
     batch.opts["arg"] = []
-    batch._get_next = MagicMock(return_value={"foo", "bar"})
     batch.batch_size = 2
     future = salt.ext.tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    future.set_result({})
+    with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object(
+        batch, "events_channel", MagicMock()
+    ), patch.object(batch, "_get_next", return_value={"foo", "bar"}), patch.object(
+        batch, "find_job", return_value=future
+    ) as find_job_mock:
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.run_next()
-        assert batch.local.run_job_async.call_args[0] == (
+        assert batch.events_channel.local_client.run_job_async.call_args[0] == (
             {"foo", "bar"},
             "my.fun",
             [],
             "list",
         )
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.find_job,
-            {"foo", "bar"},
-        )
+        assert find_job_mock.call_args[0] == ({"foo", "bar"},)
         assert batch.active == {"bar", "foo"}
 
 
@@ -239,124 +262,132 @@ def test_next_batch_all_timedout(batch):
 
 def test_batch__event_handler_ping_return(batch):
     batch.targeted_minions = {"foo"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
     batch.start()
     assert batch.minions == set()
-    batch._BatchAsync__event_handler(MagicMock())
+    batch._BatchAsync__event_handler(
+        "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+    )
     assert batch.minions == {"foo"}
     assert batch.done_minions == set()
 
 
 def test_batch__event_handler_call_start_batch_when_all_pings_return(batch):
     batch.targeted_minions = {"foo"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch._BatchAsync__event_handler(MagicMock())
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.start()
+        batch._BatchAsync__event_handler(
+            "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+        )
+        start_batch_mock.assert_called_once()
 
 
 def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(batch):
     batch.targeted_minions = {"foo", "bar"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch._BatchAsync__event_handler(MagicMock())
-    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.start()
+        batch._BatchAsync__event_handler(
+            "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+        )
+        start_batch_mock.assert_not_called()
 
 
 def test_batch__event_handler_batch_run_return(batch):
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch.active = {"foo"}
-    batch._BatchAsync__event_handler(MagicMock())
-    assert batch.active == set()
-    assert batch.done_minions == {"foo"}
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.schedule_next,)
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(
+        batch, "schedule_next", return_value=future
+    ) as schedule_next_mock:
+        batch.start()
+        batch.active = {"foo"}
+        batch._BatchAsync__event_handler(
+            "salt/job/1235/ret/foo", {"id": "foo"}, "batch_run"
+        )
+        assert batch.active == set()
+        assert batch.done_minions == {"foo"}
+        schedule_next_mock.assert_called_once()
 
 
 def test_batch__event_handler_find_job_return(batch):
-    batch.event = MagicMock(
-        unpack=MagicMock(
-            return_value=(
-                "salt/job/1236/ret/foo",
-                {"id": "foo", "return": "deadbeaf"},
-            )
-        )
-    )
     batch.start()
-    batch.patterns.add(("salt/job/1236/ret/*", "find_job_return"))
-    batch._BatchAsync__event_handler(MagicMock())
+    batch._BatchAsync__event_handler(
+        "salt/job/1236/ret/foo", {"id": "foo", "return": "deadbeaf"}, "find_job_return"
+    )
     assert batch.find_job_returned == {"foo"}
 
 
 def test_batch_run_next_end_batch_when_no_next(batch):
-    batch.end_batch = MagicMock()
-    batch._get_next = MagicMock(return_value={})
-    batch.run_next()
-    assert len(batch.end_batch.mock_calls) == 1
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(
+        batch, "_get_next", return_value={}
+    ), patch.object(
+        batch, "end_batch", return_value=future
+    ) as end_batch_mock:
+        batch.run_next()
+        end_batch_mock.assert_called_once()
 
 
 def test_batch_find_job(batch):
-    batch.event = MagicMock()
     future = salt.ext.tornado.gen.Future()
     future.set_result({})
-    batch.local.run_job_async.return_value = future
     batch.minions = {"foo", "bar"}
-    batch.jid_gen = MagicMock(return_value="1234")
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object(
+        batch, "check_find_job", return_value=future
+    ) as check_find_job_mock, patch.object(
+        batch, "jid_gen", return_value="1236"
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.find_job({"foo", "bar"})
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.check_find_job,
+        assert check_find_job_mock.call_args[0] == (
             {"foo", "bar"},
-            "1234",
+            "1236",
         )
 
 
 def test_batch_find_job_with_done_minions(batch):
     batch.done_minions = {"bar"}
-    batch.event = MagicMock()
     future = salt.ext.tornado.gen.Future()
     future.set_result({})
-    batch.local.run_job_async.return_value = future
     batch.minions = {"foo", "bar"}
-    batch.jid_gen = MagicMock(return_value="1234")
-    with patch("salt.ext.tornado.gen.sleep", return_value=future):
+    with patch("salt.ext.tornado.gen.sleep", return_value=future), patch.object(
+        batch, "check_find_job", return_value=future
+    ) as check_find_job_mock, patch.object(
+        batch, "jid_gen", return_value="1236"
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.find_job({"foo", "bar"})
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.check_find_job,
+        assert check_find_job_mock.call_args[0] == (
             {"foo"},
-            "1234",
+            "1236",
         )
 
 
 def test_batch_check_find_job_did_not_return(batch):
-    batch.event = MagicMock()
     batch.active = {"foo"}
     batch.find_job_returned = set()
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job({"foo"}, jid="1234")
-    assert batch.find_job_returned == set()
-    assert batch.active == set()
-    assert len(batch.event.io_loop.add_callback.mock_calls) == 0
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "find_job", return_value=future) as find_job_mock:
+        batch.check_find_job({"foo"}, jid="1234")
+        assert batch.find_job_returned == set()
+        assert batch.active == set()
+        find_job_mock.assert_not_called()
 
 
 def test_batch_check_find_job_did_return(batch):
-    batch.event = MagicMock()
     batch.find_job_returned = {"foo"}
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job({"foo"}, jid="1234")
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "find_job", return_value=future) as find_job_mock:
+        batch.check_find_job({"foo"}, jid="1234")
+        find_job_mock.assert_called_once_with({"foo"})
 
 
 def test_batch_check_find_job_multiple_states(batch):
-    batch.event = MagicMock()
     # currently running minions
     batch.active = {"foo", "bar"}
 
@@ -372,21 +403,28 @@ def test_batch_check_find_job_multiple_states(batch):
     # both not yet done but only 'foo' responded to find_job
     not_done = {"foo", "bar"}
 
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job(not_done, jid="1234")
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
 
-    # assert 'bar' removed from active
-    assert batch.active == {"foo"}
+    with patch.object(batch, "schedule_next", return_value=future), patch.object(
+        batch, "find_job", return_value=future
+    ) as find_job_mock:
+        batch.check_find_job(not_done, jid="1234")
 
-    # assert 'bar' added to timedout_minions
-    assert batch.timedout_minions == {"bar", "faz"}
+        # assert 'bar' removed from active
+        assert batch.active == {"foo"}
 
-    # assert 'find_job' schedueled again only for 'foo'
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+        # assert 'bar' added to timedout_minions
+        assert batch.timedout_minions == {"bar", "faz"}
+
+        # assert 'find_job' schedueled again only for 'foo'
+        find_job_mock.assert_called_once_with({"foo"})
 
 
 def test_only_on_run_next_is_scheduled(batch):
-    batch.event = MagicMock()
+    future = salt.ext.tornado.gen.Future()
+    future.set_result({})
     batch.scheduled = True
-    batch.schedule_next()
-    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+    with patch.object(batch, "run_next", return_value=future) as run_next_mock:
+        batch.schedule_next()
+        run_next_mock.assert_not_called()
-- 
2.45.0

openSUSE Build Service is sponsored by