File prevent-oom-with-high-amount-of-batch-async-calls-bs.patch of Package salt

From 59ed379f4576b9e7ceb48e3c1b182b08e3279b9a Mon Sep 17 00:00:00 2001
From: Victor Zhestkov <vzhestkov@suse.com>
Date: Wed, 15 May 2024 11:47:35 +0200
Subject: [PATCH] Prevent OOM with high amount of batch async calls
 (bsc#1216063)

* Refactor batch_async implementation

* Fix batch_async tests after refactoring
---
 salt/cli/batch_async.py                    | 356 +++++++++++++++-----
 salt/master.py                             |   4 +
 tests/pytests/unit/cli/test_batch_async.py | 360 ++++++++++++---------
 3 files changed, 487 insertions(+), 233 deletions(-)

diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
index ddbd16870a..217102f8b4 100644
--- a/salt/cli/batch_async.py
+++ b/salt/cli/batch_async.py
@@ -4,15 +4,193 @@ Execute a job on the targeted minions by using a moving window of fixed size `ba
 
 # pylint: enable=import-error,no-name-in-module,redefined-builtin
 import logging
+import re
 
 import tornado
 
 import salt.client
+import salt.utils.event
 from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum
+from tornado.iostream import StreamClosedError
 
 log = logging.getLogger(__name__)
 
 
+__SHARED_EVENTS_CHANNEL = None
+
+
+def _get_shared_events_channel(opts, io_loop):
+    global __SHARED_EVENTS_CHANNEL
+    if __SHARED_EVENTS_CHANNEL is None:
+        __SHARED_EVENTS_CHANNEL = SharedEventsChannel(opts, io_loop)
+    return __SHARED_EVENTS_CHANNEL
+
+
+def _destroy_unused_shared_events_channel():
+    global __SHARED_EVENTS_CHANNEL
+    if __SHARED_EVENTS_CHANNEL is not None and __SHARED_EVENTS_CHANNEL.destroy_unused():
+        __SHARED_EVENTS_CHANNEL = None
+
+
+def batch_async_required(opts, minions, extra):
+    """
+    Check opts to identify if batch async is required for the operation.
+    """
+    if not isinstance(minions, list):
+        False
+    batch_async_opts = opts.get("batch_async", {})
+    batch_async_threshold = (
+        batch_async_opts.get("threshold", 1)
+        if isinstance(batch_async_opts, dict)
+        else 1
+    )
+    if batch_async_threshold == -1:
+        batch_size = get_bnum(extra, minions, True)
+        return len(minions) >= batch_size
+    elif batch_async_threshold > 0:
+        return len(minions) >= batch_async_threshold
+    return False
+
+
+class SharedEventsChannel:
+    def __init__(self, opts, io_loop):
+        self.io_loop = io_loop
+        self.local_client = salt.client.get_local_client(
+            opts["conf_file"], io_loop=self.io_loop
+        )
+        self.master_event = salt.utils.event.get_event(
+            "master",
+            sock_dir=self.local_client.opts["sock_dir"],
+            opts=self.local_client.opts,
+            listen=True,
+            io_loop=self.io_loop,
+            keep_loop=True,
+        )
+        self.master_event.set_event_handler(self.__handle_event)
+        if self.master_event.subscriber.stream:
+            self.master_event.subscriber.stream.set_close_callback(self.__handle_close)
+        self._re_tag_ret_event = re.compile(r"salt\/job\/(\d+)\/ret\/.*")
+        self._subscribers = {}
+        self._subscriptions = {}
+        self._used_by = set()
+        batch_async_opts = opts.get("batch_async", {})
+        if not isinstance(batch_async_opts, dict):
+            batch_async_opts = {}
+        self._subscriber_reconnect_tries = batch_async_opts.get(
+            "subscriber_reconnect_tries", 5
+        )
+        self._subscriber_reconnect_interval = batch_async_opts.get(
+            "subscriber_reconnect_interval", 1.0
+        )
+        self._reconnecting_subscriber = False
+
+    def subscribe(self, jid, op, subscriber_id, handler):
+        if subscriber_id not in self._subscribers:
+            self._subscribers[subscriber_id] = set()
+        if jid not in self._subscriptions:
+            self._subscriptions[jid] = []
+        self._subscribers[subscriber_id].add(jid)
+        if (op, subscriber_id, handler) not in self._subscriptions[jid]:
+            self._subscriptions[jid].append((op, subscriber_id, handler))
+        if not self.master_event.subscriber.connected():
+            self.__reconnect_subscriber()
+
+    def unsubscribe(self, jid, op, subscriber_id):
+        if subscriber_id not in self._subscribers:
+            return
+        jids = self._subscribers[subscriber_id].copy()
+        if jid is not None:
+            jids = set(jid)
+        for i_jid in jids:
+            self._subscriptions[i_jid] = list(
+                filter(
+                    lambda x: not (op in (x[0], None) and x[1] == subscriber_id),
+                    self._subscriptions.get(i_jid, []),
+                )
+            )
+            self._subscribers[subscriber_id].discard(i_jid)
+        self._subscriptions = dict(filter(lambda x: x[1], self._subscriptions.items()))
+        if not self._subscribers[subscriber_id]:
+            del self._subscribers[subscriber_id]
+
+    @tornado.gen.coroutine
+    def __handle_close(self):
+        if not self._subscriptions:
+            return
+        log.warning("Master Event Subscriber was closed. Trying to reconnect...")
+        yield self.__reconnect_subscriber()
+
+    @tornado.gen.coroutine
+    def __handle_event(self, raw):
+        if self.master_event is None:
+            return
+        try:
+            tag, data = self.master_event.unpack(raw)
+            tag_match = self._re_tag_ret_event.match(tag)
+            if tag_match:
+                jid = tag_match.group(1)
+                if jid in self._subscriptions:
+                    for op, _, handler in self._subscriptions[jid]:
+                        yield handler(tag, data, op)
+        except Exception as ex:  # pylint: disable=W0703
+            log.error(
+                "Exception occured while processing event: %s: %s",
+                tag,
+                ex,
+                exc_info=True,
+            )
+
+    @tornado.gen.coroutine
+    def __reconnect_subscriber(self):
+        if self.master_event.subscriber.connected() or self._reconnecting_subscriber:
+            return
+        self._reconnecting_subscriber = True
+        max_tries = max(1, int(self._subscriber_reconnect_tries))
+        _try = 1
+        while _try <= max_tries:
+            log.info(
+                "Trying to reconnect to event publisher (try %d of %d) ...",
+                _try,
+                max_tries,
+            )
+            try:
+                yield self.master_event.subscriber.connect()
+            except StreamClosedError:
+                log.warning(
+                    "Unable to reconnect to event publisher (try %d of %d)",
+                    _try,
+                    max_tries,
+                )
+            if self.master_event.subscriber.connected():
+                self.master_event.subscriber.stream.set_close_callback(
+                    self.__handle_close
+                )
+                log.info("Event publisher connection restored")
+                self._reconnecting_subscriber = False
+                return
+            if _try < max_tries:
+                yield tornado.gen.sleep(self._subscriber_reconnect_interval)
+            _try += 1
+        self._reconnecting_subscriber = False
+
+    def use(self, subscriber_id):
+        self._used_by.add(subscriber_id)
+        return self
+
+    def unuse(self, subscriber_id):
+        self._used_by.discard(subscriber_id)
+
+    def destroy_unused(self):
+        if self._used_by:
+            return False
+        self.master_event.remove_event_handler(self.__handle_event)
+        self.master_event.destroy()
+        self.master_event = None
+        self.local_client.destroy()
+        self.local_client = None
+        return True
+
+
 class BatchAsync:
     """
     Run a job on the targeted minions by using a moving window of fixed size `batch`.
@@ -44,17 +222,26 @@ class BatchAsync:
          }
     """
 
-    def __init__(self, parent_opts, jid_gen, clear_load):
-        ioloop = tornado.ioloop.IOLoop.current()
-        self.local = salt.client.get_local_client(
-            parent_opts["conf_file"], io_loop=ioloop
+    def __init__(self, opts, jid_gen, clear_load):
+        self.extra_job_kwargs = {}
+        kwargs = clear_load.get("kwargs", {})
+        for kwarg in ("module_executors", "executor_opts"):
+            if kwarg in kwargs:
+                self.extra_job_kwargs[kwarg] = kwargs[kwarg]
+            elif kwarg in opts:
+                self.extra_job_kwargs[kwarg] = opts[kwarg]
+        self.io_loop = tornado.ioloop.IOLoop.current()
+        self.events_channel = _get_shared_events_channel(opts, self.io_loop).use(
+            id(self)
         )
         if "gather_job_timeout" in clear_load["kwargs"]:
             clear_load["gather_job_timeout"] = clear_load["kwargs"].pop(
                 "gather_job_timeout"
             )
         else:
-            clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"]
+            clear_load["gather_job_timeout"] = self.events_channel.local_client.opts[
+                "gather_job_timeout"
+            ]
         self.batch_presence_ping_timeout = clear_load["kwargs"].get(
             "batch_presence_ping_timeout", None
         )
@@ -63,7 +250,7 @@ class BatchAsync:
             clear_load.pop("tgt"),
             clear_load.pop("fun"),
             clear_load["kwargs"].pop("batch"),
-            self.local.opts,
+            self.events_channel.local_client.opts,
             **clear_load,
         )
         self.eauth = batch_get_eauth(clear_load["kwargs"])
@@ -77,53 +264,45 @@ class BatchAsync:
         self.jid_gen = jid_gen
         self.ping_jid = jid_gen()
         self.batch_jid = jid_gen()
-        self.find_job_jid = jid_gen()
         self.find_job_returned = set()
+        self.metadata.update({"batch_jid": self.batch_jid, "ping_jid": self.ping_jid})
         self.ended = False
-        self.event = salt.utils.event.get_event(
-            "master",
-            sock_dir=self.opts["sock_dir"],
-            opts=self.opts,
-            listen=True,
-            io_loop=ioloop,
-            keep_loop=True,
-        )
+        self.event = self.events_channel.master_event
         self.scheduled = False
-        self.patterns = set()
 
     def __set_event_handler(self):
-        ping_return_pattern = f"salt/job/{self.ping_jid}/ret/*"
-        batch_return_pattern = f"salt/job/{self.batch_jid}/ret/*"
-        self.event.subscribe(ping_return_pattern, match_type="glob")
-        self.event.subscribe(batch_return_pattern, match_type="glob")
-        self.patterns = {
-            (ping_return_pattern, "ping_return"),
-            (batch_return_pattern, "batch_run"),
-        }
-        self.event.set_event_handler(self.__event_handler)
+        self.events_channel.subscribe(
+            self.ping_jid, "ping_return", id(self), self.__event_handler
+        )
+        self.events_channel.subscribe(
+            self.batch_jid, "batch_run", id(self), self.__event_handler
+        )
 
-    def __event_handler(self, raw):
+    @tornado.gen.coroutine
+    def __event_handler(self, tag, data, op):
         if not self.event:
             return
         try:
-            mtag, data = self.event.unpack(raw)
-            for pattern, op in self.patterns:
-                if mtag.startswith(pattern[:-1]):
-                    minion = data["id"]
-                    if op == "ping_return":
-                        self.minions.add(minion)
-                        if self.targeted_minions == self.minions:
-                            self.event.io_loop.spawn_callback(self.start_batch)
-                    elif op == "find_job_return":
-                        if data.get("return", None):
-                            self.find_job_returned.add(minion)
-                    elif op == "batch_run":
-                        if minion in self.active:
-                            self.active.remove(minion)
-                            self.done_minions.add(minion)
-                            self.event.io_loop.spawn_callback(self.schedule_next)
+            minion = data["id"]
+            if op == "ping_return":
+                self.minions.add(minion)
+                if self.targeted_minions == self.minions:
+                    yield self.start_batch()
+            elif op == "find_job_return":
+                if data.get("return", None):
+                    self.find_job_returned.add(minion)
+            elif op == "batch_run":
+                if minion in self.active:
+                    self.active.remove(minion)
+                    self.done_minions.add(minion)
+                    yield self.schedule_next()
         except Exception as ex:  # pylint: disable=W0703
-            log.error("Exception occured while processing event: %s", ex, exc_info=True)
+            log.error(
+                "Exception occured while processing event: %s: %s",
+                tag,
+                ex,
+                exc_info=True,
+            )
 
     def _get_next(self):
         to_run = (
@@ -144,9 +323,7 @@ class BatchAsync:
         """
         if not self.event:
             return
-        find_job_return_pattern = f"salt/job/{jid}/ret/*"
-        self.event.unsubscribe(find_job_return_pattern, match_type="glob")
-        self.patterns.remove((find_job_return_pattern, "find_job_return"))
+        self.events_channel.unsubscribe(jid, "find_job_return", id(self))
 
         timedout_minions = batch_minions.difference(self.find_job_returned).difference(
             self.done_minions
@@ -158,11 +335,11 @@ class BatchAsync:
         )
 
         if timedout_minions:
-            self.event.io_loop.spawn_callback(self.schedule_next)
+            yield self.schedule_next()
 
         if self.event and running:
             self.find_job_returned = self.find_job_returned.difference(running)
-            self.event.io_loop.spawn_callback(self.find_job, running)
+            yield self.find_job(running)
 
     @tornado.gen.coroutine
     def find_job(self, minions):
@@ -178,21 +355,23 @@ class BatchAsync:
             return
         try:
             jid = self.jid_gen()
-            find_job_return_pattern = f"salt/job/{jid}/ret/*"
-            self.patterns.add((find_job_return_pattern, "find_job_return"))
-            self.event.subscribe(find_job_return_pattern, match_type="glob")
-            ret = yield self.local.run_job_async(
+            self.events_channel.subscribe(
+                jid, "find_job_return", id(self), self.__event_handler
+            )
+            ret = yield self.events_channel.local_client.run_job_async(
                 not_done,
                 "saltutil.find_job",
                 [self.batch_jid],
                 "list",
                 gather_job_timeout=self.opts["gather_job_timeout"],
                 jid=jid,
+                io_loop=self.io_loop,
+                listen=False,
                 **self.eauth,
             )
             yield tornado.gen.sleep(self.opts["gather_job_timeout"])
             if self.event:
-                self.event.io_loop.spawn_callback(self.check_find_job, not_done, jid)
+                yield self.check_find_job(not_done, jid)
         except Exception as ex:  # pylint: disable=W0703
             log.error(
                 "Exception occured handling batch async: %s. Aborting execution.",
@@ -201,6 +380,34 @@ class BatchAsync:
             )
             self.close_safe()
 
+    @tornado.gen.coroutine
+    def start(self):
+        """
+        Start the batch execution
+        """
+        if not self.event:
+            return
+        self.__set_event_handler()
+        ping_return = yield self.events_channel.local_client.run_job_async(
+            self.opts["tgt"],
+            "test.ping",
+            [],
+            self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")),
+            gather_job_timeout=self.opts["gather_job_timeout"],
+            jid=self.ping_jid,
+            metadata=self.metadata,
+            io_loop=self.io_loop,
+            listen=False,
+            **self.eauth,
+        )
+        self.targeted_minions = set(ping_return["minions"])
+        # start batching even if not all minions respond to ping
+        yield tornado.gen.sleep(
+            self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
+        )
+        if self.event:
+            yield self.start_batch()
+
     @tornado.gen.coroutine
     def start(self):
         """
@@ -230,7 +437,7 @@ class BatchAsync:
     @tornado.gen.coroutine
     def start_batch(self):
         """
-        Start the next interation of batch execution
+        Fire `salt/batch/*/start` and continue batch with `run_next`
         """
         if self.initialized:
             return
@@ -241,9 +448,11 @@ class BatchAsync:
             "down_minions": self.targeted_minions.difference(self.minions),
             "metadata": self.metadata,
         }
-        ret = self.event.fire_event(data, f"salt/batch/{self.batch_jid}/start")
+        yield self.events_channel.master_event.fire_event_async(
+            data, f"salt/batch/{self.batch_jid}/start"
+        )
         if self.event:
-            self.event.io_loop.spawn_callback(self.run_next)
+            yield self.run_next()
 
     @tornado.gen.coroutine
     def end_batch(self):
@@ -265,7 +474,9 @@ class BatchAsync:
             "timedout_minions": self.timedout_minions,
             "metadata": self.metadata,
         }
-        self.event.fire_event(data, f"salt/batch/{self.batch_jid}/done")
+        yield self.events_channel.master_event.fire_event_async(
+            data, f"salt/batch/{self.batch_jid}/done"
+        )
 
         # release to the IOLoop to allow the event to be published
         # before closing batch async execution
@@ -273,14 +484,12 @@ class BatchAsync:
         self.close_safe()
 
     def close_safe(self):
-        if self.event:
-            for pattern, label in self.patterns:
-                self.event.unsubscribe(pattern, match_type="glob")
-            self.event.remove_event_handler(self.__event_handler)
-            self.event.destroy()
-            self.event = None
-        self.local = None
-        self.ioloop = None
+        if self.events_channel is not None:
+            self.events_channel.unsubscribe(None, None, id(self))
+            self.events_channel.unuse(id(self))
+            self.events_channel = None
+            _destroy_unused_shared_events_channel()
+        self.event = None
 
     @tornado.gen.coroutine
     def schedule_next(self):
@@ -290,10 +499,13 @@ class BatchAsync:
         # call later so that we maybe gather more returns
         yield tornado.gen.sleep(self.batch_delay)
         if self.event:
-            self.event.io_loop.spawn_callback(self.run_next)
+            yield self.run_next()
 
     @tornado.gen.coroutine
     def run_next(self):
+        """
+        Continue batch execution with the next targets
+        """
         self.scheduled = False
         next_batch = self._get_next()
         if not next_batch:
@@ -301,7 +513,7 @@ class BatchAsync:
             return
         self.active = self.active.union(next_batch)
         try:
-            ret = yield self.local.run_job_async(
+            ret = yield self.events_channel.local_client.run_job_async(
                 next_batch,
                 self.opts["fun"],
                 self.opts["arg"],
@@ -311,13 +523,17 @@ class BatchAsync:
                 gather_job_timeout=self.opts["gather_job_timeout"],
                 jid=self.batch_jid,
                 metadata=self.metadata,
+                io_loop=self.io_loop,
+                listen=False,
+                **self.eauth,
+                **self.extra_job_kwargs,
             )
 
             yield tornado.gen.sleep(self.opts["timeout"])
 
             # The batch can be done already at this point, which means no self.event
             if self.event:
-                self.event.io_loop.spawn_callback(self.find_job, set(next_batch))
+                yield self.find_job(set(next_batch))
         except Exception as ex:  # pylint: disable=W0703
             log.error(
                 "Error in scheduling next batch: %s. Aborting execution",
@@ -326,7 +542,3 @@ class BatchAsync:
             )
             self.active = self.active.difference(next_batch)
             self.close_safe()
-
-    # pylint: disable=W1701
-    def __del__(self):
-        self.close_safe()
diff --git a/salt/master.py b/salt/master.py
index 0c04e0721e..bdadd4a062 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -58,6 +58,7 @@ import salt.utils.user
 import salt.utils.verify
 import salt.utils.zeromq
 import salt.wheel
+from salt.cli.batch_async import batch_async_required
 from salt.config import DEFAULT_INTERVAL
 from salt.defaults import DEFAULT_TARGET_DELIM
 from salt.transport import TRANSPORTS
@@ -2496,6 +2497,9 @@ class ClearFuncs(TransportMethods):
                         ),
                     },
                 }
+        if extra.get("batch", None) and batch_async_required(self.opts, minions, extra):
+            return self.publish_batch(clear_load, minions, missing)
+
         jid = self._prep_jid(clear_load, extra)
         if jid is None:
             return {"enc": "clear", "load": {"error": "Master failed to assign jid"}}
diff --git a/tests/pytests/unit/cli/test_batch_async.py b/tests/pytests/unit/cli/test_batch_async.py
index 94a9e15747..56611c833f 100644
--- a/tests/pytests/unit/cli/test_batch_async.py
+++ b/tests/pytests/unit/cli/test_batch_async.py
@@ -1,7 +1,7 @@
 import pytest
 import tornado
 
-from salt.cli.batch_async import BatchAsync
+from salt.cli.batch_async import BatchAsync, batch_async_required
 from tests.support.mock import MagicMock, patch
 
 
@@ -22,16 +22,44 @@ def batch(temp_salt_master):
         with patch("salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)):
             batch = BatchAsync(
                 opts,
-                MagicMock(side_effect=["1234", "1235", "1236"]),
+                MagicMock(side_effect=["1234", "1235"]),
                 {
                     "tgt": "",
                     "fun": "",
-                    "kwargs": {"batch": "", "batch_presence_ping_timeout": 1},
+                    "kwargs": {
+                        "batch": "",
+                        "batch_presence_ping_timeout": 1,
+                        "metadata": {"mykey": "myvalue"},
+                    },
                 },
             )
             yield batch
 
 
+@pytest.mark.parametrize(
+    "threshold,minions,batch,expected",
+    [
+        (1, 2, 200, True),
+        (1, 500, 200, True),
+        (0, 2, 200, False),
+        (0, 500, 200, False),
+        (-1, 2, 200, False),
+        (-1, 500, 200, True),
+        (-1, 9, 10, False),
+        (-1, 11, 10, True),
+        (10, 9, 8, False),
+        (10, 9, 10, False),
+        (10, 11, 8, True),
+        (10, 11, 10, True),
+    ],
+)
+def test_batch_async_required(threshold, minions, batch, expected):
+    minions_list = [f"minion{i}.example.org" for i in range(minions)]
+    batch_async_opts = {"batch_async": {"threshold": threshold}}
+    extra = {"batch": batch}
+    assert batch_async_required(batch_async_opts, minions_list, extra) == expected
+
+
 def test_ping_jid(batch):
     assert batch.ping_jid == "1234"
 
@@ -40,10 +68,6 @@ def test_batch_jid(batch):
     assert batch.batch_jid == "1235"
 
 
-def test_find_job_jid(batch):
-    assert batch.find_job_jid == "1236"
-
-
 def test_batch_size(batch):
     """
     Tests passing batch value as a number
@@ -55,58 +79,74 @@ def test_batch_size(batch):
 
 
 def test_batch_start_on_batch_presence_ping_timeout(batch):
-    # batch_async = BatchAsyncMock();
-    batch.event = MagicMock()
+    future_ret = tornado.gen.Future()
+    future_ret.set_result({"minions": ["foo", "bar"]})
     future = tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
-    with patch("tornado.gen.sleep", return_value=future):
-        # ret = batch_async.start(batch)
+    future.set_result({})
+    with patch.object(batch, "events_channel", MagicMock()), patch(
+        "tornado.gen.sleep", return_value=future
+    ), patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.events_channel.local_client.run_job_async.return_value = future_ret
         ret = batch.start()
-        # assert start_batch is called later with batch_presence_ping_timeout as param
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert start_batch is called
+        start_batch_mock.assert_called_once()
         # assert test.ping called
-        assert batch.local.run_job_async.call_args[0] == ("*", "test.ping", [], "glob")
+        assert batch.events_channel.local_client.run_job_async.call_args[0] == (
+            "*",
+            "test.ping",
+            [],
+            "glob",
+        )
         # assert targeted_minions == all minions matched by tgt
         assert batch.targeted_minions == {"foo", "bar"}
 
 
 def test_batch_start_on_gather_job_timeout(batch):
-    # batch_async = BatchAsyncMock();
-    batch.event = MagicMock()
     future = tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
+    future.set_result({})
+    future_ret = tornado.gen.Future()
+    future_ret.set_result({"minions": ["foo", "bar"]})
     batch.batch_presence_ping_timeout = None
-    with patch("tornado.gen.sleep", return_value=future):
+    with patch.object(batch, "events_channel", MagicMock()), patch(
+        "tornado.gen.sleep", return_value=future
+    ), patch.object(
+        batch, "start_batch", return_value=future
+    ) as start_batch_mock, patch.object(
+        batch, "batch_presence_ping_timeout", None
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future_ret
         # ret = batch_async.start(batch)
         ret = batch.start()
-        # assert start_batch is called later with gather_job_timeout as param
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+        # assert start_batch is called
+        start_batch_mock.assert_called_once()
 
 
 def test_batch_fire_start_event(batch):
     batch.minions = {"foo", "bar"}
     batch.opts = {"batch": "2", "timeout": 5}
-    batch.event = MagicMock()
-    batch.metadata = {"mykey": "myvalue"}
-    batch.start_batch()
-    assert batch.event.fire_event.call_args[0] == (
-        {
-            "available_minions": {"foo", "bar"},
-            "down_minions": set(),
-            "metadata": batch.metadata,
-        },
-        "salt/batch/1235/start",
-    )
+    with patch.object(batch, "events_channel", MagicMock()):
+        batch.start_batch()
+        assert batch.events_channel.master_event.fire_event_async.call_args[0] == (
+            {
+                "available_minions": {"foo", "bar"},
+                "down_minions": set(),
+                "metadata": batch.metadata,
+            },
+            "salt/batch/1235/start",
+        )
 
 
 def test_start_batch_calls_next(batch):
-    batch.run_next = MagicMock(return_value=MagicMock())
-    batch.event = MagicMock()
-    batch.start_batch()
-    assert batch.initialized
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.run_next,)
+    batch.initialized = False
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "event", MagicMock()), patch.object(
+        batch, "events_channel", MagicMock()
+    ), patch.object(batch, "run_next", return_value=future) as run_next_mock:
+        batch.events_channel.master_event.fire_event_async.return_value = future
+        batch.start_batch()
+        assert batch.initialized
+        run_next_mock.assert_called_once()
 
 
 def test_batch_fire_done_event(batch):
@@ -114,69 +154,52 @@ def test_batch_fire_done_event(batch):
     batch.minions = {"foo", "bar"}
     batch.done_minions = {"foo"}
     batch.timedout_minions = {"bar"}
-    batch.event = MagicMock()
-    batch.metadata = {"mykey": "myvalue"}
-    old_event = batch.event
-    batch.end_batch()
-    assert old_event.fire_event.call_args[0] == (
-        {
-            "available_minions": {"foo", "bar"},
-            "done_minions": batch.done_minions,
-            "down_minions": {"baz"},
-            "timedout_minions": batch.timedout_minions,
-            "metadata": batch.metadata,
-        },
-        "salt/batch/1235/done",
-    )
-
-
-def test_batch__del__(batch):
-    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
-    event = MagicMock()
-    batch.event = event
-    batch.__del__()
-    assert batch.local is None
-    assert batch.event is None
-    assert batch.ioloop is None
+    with patch.object(batch, "events_channel", MagicMock()):
+        batch.end_batch()
+        assert batch.events_channel.master_event.fire_event_async.call_args[0] == (
+            {
+                "available_minions": {"foo", "bar"},
+                "done_minions": batch.done_minions,
+                "down_minions": {"baz"},
+                "timedout_minions": batch.timedout_minions,
+                "metadata": batch.metadata,
+            },
+            "salt/batch/1235/done",
+        )
 
 
 def test_batch_close_safe(batch):
-    batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
-    event = MagicMock()
-    batch.event = event
-    batch.patterns = {
-        ("salt/job/1234/ret/*", "find_job_return"),
-        ("salt/job/4321/ret/*", "find_job_return"),
-    }
-    batch.close_safe()
-    assert batch.local is None
-    assert batch.event is None
-    assert batch.ioloop is None
-    assert len(event.unsubscribe.mock_calls) == 2
-    assert len(event.remove_event_handler.mock_calls) == 1
+    with patch.object(
+        batch, "events_channel", MagicMock()
+    ) as events_channel_mock, patch.object(batch, "event", MagicMock()):
+        batch.close_safe()
+        batch.close_safe()
+        assert batch.events_channel is None
+        assert batch.event is None
+        events_channel_mock.unsubscribe.assert_called_once()
+        events_channel_mock.unuse.assert_called_once()
 
 
 def test_batch_next(batch):
-    batch.event = MagicMock()
     batch.opts["fun"] = "my.fun"
     batch.opts["arg"] = []
-    batch._get_next = MagicMock(return_value={"foo", "bar"})
     batch.batch_size = 2
     future = tornado.gen.Future()
-    future.set_result({"minions": ["foo", "bar"]})
-    batch.local.run_job_async.return_value = future
-    with patch("tornado.gen.sleep", return_value=future):
+    future.set_result({})
+    with patch("tornado.gen.sleep", return_value=future), patch.object(
+        batch, "events_channel", MagicMock()
+    ), patch.object(batch, "_get_next", return_value={"foo", "bar"}), patch.object(
+        batch, "find_job", return_value=future
+    ) as find_job_mock:
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.run_next()
-        assert batch.local.run_job_async.call_args[0] == (
+        assert batch.events_channel.local_client.run_job_async.call_args[0] == (
             {"foo", "bar"},
             "my.fun",
             [],
             "list",
         )
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.find_job,
-            {"foo", "bar"},
-        )
+        assert find_job_mock.call_args[0] == ({"foo", "bar"},)
         assert batch.active == {"bar", "foo"}
 
 
@@ -239,124 +262,132 @@ def test_next_batch_all_timedout(batch):
 
 def test_batch__event_handler_ping_return(batch):
     batch.targeted_minions = {"foo"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
     batch.start()
     assert batch.minions == set()
-    batch._BatchAsync__event_handler(MagicMock())
+    batch._BatchAsync__event_handler(
+        "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+    )
     assert batch.minions == {"foo"}
     assert batch.done_minions == set()
 
 
 def test_batch__event_handler_call_start_batch_when_all_pings_return(batch):
     batch.targeted_minions = {"foo"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch._BatchAsync__event_handler(MagicMock())
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.start_batch,)
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.start()
+        batch._BatchAsync__event_handler(
+            "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+        )
+        start_batch_mock.assert_called_once()
 
 
 def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(batch):
     batch.targeted_minions = {"foo", "bar"}
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch._BatchAsync__event_handler(MagicMock())
-    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "start_batch", return_value=future) as start_batch_mock:
+        batch.start()
+        batch._BatchAsync__event_handler(
+            "salt/job/1234/ret/foo", {"id": "foo"}, "ping_return"
+        )
+        start_batch_mock.assert_not_called()
 
 
 def test_batch__event_handler_batch_run_return(batch):
-    batch.event = MagicMock(
-        unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"}))
-    )
-    batch.start()
-    batch.active = {"foo"}
-    batch._BatchAsync__event_handler(MagicMock())
-    assert batch.active == set()
-    assert batch.done_minions == {"foo"}
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.schedule_next,)
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(
+        batch, "schedule_next", return_value=future
+    ) as schedule_next_mock:
+        batch.start()
+        batch.active = {"foo"}
+        batch._BatchAsync__event_handler(
+            "salt/job/1235/ret/foo", {"id": "foo"}, "batch_run"
+        )
+        assert batch.active == set()
+        assert batch.done_minions == {"foo"}
+        schedule_next_mock.assert_called_once()
 
 
 def test_batch__event_handler_find_job_return(batch):
-    batch.event = MagicMock(
-        unpack=MagicMock(
-            return_value=(
-                "salt/job/1236/ret/foo",
-                {"id": "foo", "return": "deadbeaf"},
-            )
-        )
-    )
     batch.start()
-    batch.patterns.add(("salt/job/1236/ret/*", "find_job_return"))
-    batch._BatchAsync__event_handler(MagicMock())
+    batch._BatchAsync__event_handler(
+        "salt/job/1236/ret/foo", {"id": "foo", "return": "deadbeaf"}, "find_job_return"
+    )
     assert batch.find_job_returned == {"foo"}
 
 
 def test_batch_run_next_end_batch_when_no_next(batch):
-    batch.end_batch = MagicMock()
-    batch._get_next = MagicMock(return_value={})
-    batch.run_next()
-    assert len(batch.end_batch.mock_calls) == 1
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(
+        batch, "_get_next", return_value={}
+    ), patch.object(
+        batch, "end_batch", return_value=future
+    ) as end_batch_mock:
+        batch.run_next()
+        end_batch_mock.assert_called_once()
 
 
 def test_batch_find_job(batch):
-    batch.event = MagicMock()
     future = tornado.gen.Future()
     future.set_result({})
-    batch.local.run_job_async.return_value = future
     batch.minions = {"foo", "bar"}
-    batch.jid_gen = MagicMock(return_value="1234")
-    with patch("tornado.gen.sleep", return_value=future):
+    with patch("tornado.gen.sleep", return_value=future), patch.object(
+        batch, "check_find_job", return_value=future
+    ) as check_find_job_mock, patch.object(
+        batch, "jid_gen", return_value="1236"
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.find_job({"foo", "bar"})
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.check_find_job,
+        assert check_find_job_mock.call_args[0] == (
             {"foo", "bar"},
-            "1234",
+            "1236",
         )
 
 
 def test_batch_find_job_with_done_minions(batch):
     batch.done_minions = {"bar"}
-    batch.event = MagicMock()
     future = tornado.gen.Future()
     future.set_result({})
-    batch.local.run_job_async.return_value = future
     batch.minions = {"foo", "bar"}
-    batch.jid_gen = MagicMock(return_value="1234")
-    with patch("tornado.gen.sleep", return_value=future):
+    with patch("tornado.gen.sleep", return_value=future), patch.object(
+        batch, "check_find_job", return_value=future
+    ) as check_find_job_mock, patch.object(
+        batch, "jid_gen", return_value="1236"
+    ):
+        batch.events_channel.local_client.run_job_async.return_value = future
         batch.find_job({"foo", "bar"})
-        assert batch.event.io_loop.spawn_callback.call_args[0] == (
-            batch.check_find_job,
+        assert check_find_job_mock.call_args[0] == (
             {"foo"},
-            "1234",
+            "1236",
         )
 
 
 def test_batch_check_find_job_did_not_return(batch):
-    batch.event = MagicMock()
     batch.active = {"foo"}
     batch.find_job_returned = set()
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job({"foo"}, jid="1234")
-    assert batch.find_job_returned == set()
-    assert batch.active == set()
-    assert len(batch.event.io_loop.add_callback.mock_calls) == 0
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "find_job", return_value=future) as find_job_mock:
+        batch.check_find_job({"foo"}, jid="1234")
+        assert batch.find_job_returned == set()
+        assert batch.active == set()
+        find_job_mock.assert_not_called()
 
 
 def test_batch_check_find_job_did_return(batch):
-    batch.event = MagicMock()
     batch.find_job_returned = {"foo"}
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job({"foo"}, jid="1234")
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+    future = tornado.gen.Future()
+    future.set_result({})
+    with patch.object(batch, "find_job", return_value=future) as find_job_mock:
+        batch.check_find_job({"foo"}, jid="1234")
+        find_job_mock.assert_called_once_with({"foo"})
 
 
 def test_batch_check_find_job_multiple_states(batch):
-    batch.event = MagicMock()
     # currently running minions
     batch.active = {"foo", "bar"}
 
@@ -372,21 +403,28 @@ def test_batch_check_find_job_multiple_states(batch):
     # both not yet done but only 'foo' responded to find_job
     not_done = {"foo", "bar"}
 
-    batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
-    batch.check_find_job(not_done, jid="1234")
+    future = tornado.gen.Future()
+    future.set_result({})
 
-    # assert 'bar' removed from active
-    assert batch.active == {"foo"}
+    with patch.object(batch, "schedule_next", return_value=future), patch.object(
+        batch, "find_job", return_value=future
+    ) as find_job_mock:
+        batch.check_find_job(not_done, jid="1234")
 
-    # assert 'bar' added to timedout_minions
-    assert batch.timedout_minions == {"bar", "faz"}
+        # assert 'bar' removed from active
+        assert batch.active == {"foo"}
 
-    # assert 'find_job' schedueled again only for 'foo'
-    assert batch.event.io_loop.spawn_callback.call_args[0] == (batch.find_job, {"foo"})
+        # assert 'bar' added to timedout_minions
+        assert batch.timedout_minions == {"bar", "faz"}
+
+        # assert 'find_job' schedueled again only for 'foo'
+        find_job_mock.assert_called_once_with({"foo"})
 
 
 def test_only_on_run_next_is_scheduled(batch):
-    batch.event = MagicMock()
+    future = tornado.gen.Future()
+    future.set_result({})
     batch.scheduled = True
-    batch.schedule_next()
-    assert len(batch.event.io_loop.spawn_callback.mock_calls) == 0
+    with patch.object(batch, "run_next", return_value=future) as run_next_mock:
+        batch.schedule_next()
+        run_next_mock.assert_not_called()
-- 
2.47.0

openSUSE Build Service is sponsored by