File fix-async-batch-race-conditions.patch of Package salt

From 33c5e10c2912f584243d29c764c2c6cca86edf4a Mon Sep 17 00:00:00 2001
From: Mihai Dinca <mdinca@suse.de>
Date: Thu, 11 Apr 2019 15:57:59 +0200
Subject: [PATCH] Fix async batch race conditions

Close batching when there is no next batch
---
 salt/cli/batch_async.py            | 80 +++++++++++++++---------------
 tests/unit/cli/test_batch_async.py | 35 ++++++-------
 2 files changed, 54 insertions(+), 61 deletions(-)

diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
index 3160d46d8b..9c20b2fc6e 100644
--- a/salt/cli/batch_async.py
+++ b/salt/cli/batch_async.py
@@ -37,14 +37,14 @@ class BatchAsync(object):
          - tag: salt/batch/<batch-jid>/start
          - data: {
              "available_minions": self.minions,
-             "down_minions": self.down_minions
+             "down_minions": targeted_minions - presence_ping_minions
            }
 
     When the batch ends, an `done` event is fired:
         - tag: salt/batch/<batch-jid>/done
         - data: {
              "available_minions": self.minions,
-             "down_minions": self.down_minions,
+             "down_minions": targeted_minions - presence_ping_minions
              "done_minions": self.done_minions,
              "timedout_minions": self.timedout_minions
          }
@@ -67,7 +67,7 @@ class BatchAsync(object):
         self.eauth = batch_get_eauth(clear_load['kwargs'])
         self.metadata = clear_load['kwargs'].get('metadata', {})
         self.minions = set()
-        self.down_minions = set()
+        self.targeted_minions = set()
         self.timedout_minions = set()
         self.done_minions = set()
         self.active = set()
@@ -108,8 +108,7 @@ class BatchAsync(object):
                 minion = data['id']
                 if op == 'ping_return':
                     self.minions.add(minion)
-                    self.down_minions.remove(minion)
-                    if not self.down_minions:
+                    if self.targeted_minions == self.minions:
                         self.event.io_loop.spawn_callback(self.start_batch)
                 elif op == 'find_job_return':
                     self.find_job_returned.add(minion)
@@ -120,9 +119,6 @@ class BatchAsync(object):
                         # call later so that we maybe gather more returns
                         self.event.io_loop.call_later(self.batch_delay, self.schedule_next)
 
-        if self.initialized and self.done_minions == self.minions.difference(self.timedout_minions):
-            self.end_batch()
-
     def _get_next(self):
         to_run = self.minions.difference(
             self.done_minions).difference(
@@ -135,16 +131,13 @@ class BatchAsync(object):
         return set(list(to_run)[:next_batch_size])
 
     @tornado.gen.coroutine
-    def check_find_job(self, minions):
-        did_not_return = minions.difference(self.find_job_returned)
-        if did_not_return:
-            for minion in did_not_return:
-                if minion in self.find_job_returned:
-                    self.find_job_returned.remove(minion)
-                if minion in self.active:
-                    self.active.remove(minion)
-                self.timedout_minions.add(minion)
-        running = minions.difference(did_not_return).difference(self.done_minions).difference(self.timedout_minions)
+    def check_find_job(self, batch_minions):
+        timedout_minions = batch_minions.difference(self.find_job_returned).difference(self.done_minions)
+        self.timedout_minions = self.timedout_minions.union(timedout_minions)
+        self.active = self.active.difference(self.timedout_minions)
+        running = batch_minions.difference(self.done_minions).difference(self.timedout_minions)
+        if timedout_minions:
+            self.event.io_loop.call_later(self.batch_delay, self.schedule_next)
         if running:
             self.event.io_loop.add_callback(self.find_job, running)
 
@@ -183,7 +176,7 @@ class BatchAsync(object):
             jid=self.ping_jid,
             metadata=self.metadata,
             **self.eauth)
-        self.down_minions = set(ping_return['minions'])
+        self.targeted_minions = set(ping_return['minions'])
 
     @tornado.gen.coroutine
     def start_batch(self):
@@ -192,36 +185,43 @@ class BatchAsync(object):
             self.initialized = True
             data = {
                 "available_minions": self.minions,
-                "down_minions": self.down_minions,
+                "down_minions": self.targeted_minions.difference(self.minions),
                 "metadata": self.metadata
             }
             self.event.fire_event(data, "salt/batch/{0}/start".format(self.batch_jid))
             yield self.schedule_next()
 
     def end_batch(self):
-        data = {
-            "available_minions": self.minions,
-            "down_minions": self.down_minions,
-            "done_minions": self.done_minions,
-            "timedout_minions": self.timedout_minions,
-            "metadata": self.metadata
-        }
-        self.event.fire_event(data, "salt/batch/{0}/done".format(self.batch_jid))
-        self.event.remove_event_handler(self.__event_handler)
+        left = self.minions.symmetric_difference(self.done_minions.union(self.timedout_minions))
+        if not left:
+            data = {
+                "available_minions": self.minions,
+                "down_minions": self.targeted_minions.difference(self.minions),
+                "done_minions": self.done_minions,
+                "timedout_minions": self.timedout_minions,
+                "metadata": self.metadata
+            }
+            self.event.fire_event(data, "salt/batch/{0}/done".format(self.batch_jid))
+            self.event.remove_event_handler(self.__event_handler)
 
     @tornado.gen.coroutine
     def schedule_next(self):
         next_batch = self._get_next()
         if next_batch:
-            yield self.local.run_job_async(
-                next_batch,
-                self.opts['fun'],
-                self.opts['arg'],
-                'list',
-                raw=self.opts.get('raw', False),
-                ret=self.opts.get('return', ''),
-                gather_job_timeout=self.opts['gather_job_timeout'],
-                jid=self.batch_jid,
-                metadata=self.metadata)
-            self.event.io_loop.call_later(self.opts['timeout'], self.find_job, set(next_batch))
             self.active = self.active.union(next_batch)
+            try:
+                yield self.local.run_job_async(
+                    next_batch,
+                    self.opts['fun'],
+                    self.opts['arg'],
+                    'list',
+                    raw=self.opts.get('raw', False),
+                    ret=self.opts.get('return', ''),
+                    gather_job_timeout=self.opts['gather_job_timeout'],
+                    jid=self.batch_jid,
+                    metadata=self.metadata)
+                self.event.io_loop.call_later(self.opts['timeout'], self.find_job, set(next_batch))
+            except Exception as ex:
+                self.active = self.active.difference(next_batch)
+        else:
+            self.end_batch()
diff --git a/tests/unit/cli/test_batch_async.py b/tests/unit/cli/test_batch_async.py
index f65b6a06c3..d519157d92 100644
--- a/tests/unit/cli/test_batch_async.py
+++ b/tests/unit/cli/test_batch_async.py
@@ -75,8 +75,8 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
             self.batch.local.run_job_async.call_args[0],
             ('*', 'test.ping', [], 'glob')
         )
-        # assert down_minions == all minions matched by tgt
-        self.assertEqual(self.batch.down_minions, set(['foo', 'bar']))
+        # assert targeted_minions == all minions matched by tgt
+        self.assertEqual(self.batch.targeted_minions, set(['foo', 'bar']))
 
     @tornado.testing.gen_test
     def test_batch_start_on_gather_job_timeout(self):
@@ -121,7 +121,10 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
         self.assertEqual(len(self.batch.schedule_next.mock_calls), 1)
 
     def test_batch_fire_done_event(self):
+        self.batch.targeted_minions = {'foo', 'baz', 'bar'}
         self.batch.minions = set(['foo', 'bar'])
+        self.batch.done_minions = {'foo'}
+        self.batch.timedout_minions = {'bar'}
         self.batch.event = MagicMock()
         self.batch.metadata = {'mykey': 'myvalue'}
         self.batch.end_batch()
@@ -130,9 +133,9 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
             (
                 {
                     'available_minions': set(['foo', 'bar']),
-                    'done_minions': set(),
-                    'down_minions': set(),
-                    'timedout_minions': set(),
+                    'done_minions': self.batch.done_minions,
+                    'down_minions': {'baz'},
+                    'timedout_minions': self.batch.timedout_minions,
                     'metadata': self.batch.metadata
                 },
                 "salt/batch/1235/done"
@@ -212,7 +215,7 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
         self.assertEqual(self.batch._get_next(), set())
 
     def test_batch__event_handler_ping_return(self):
-        self.batch.down_minions = {'foo'}
+        self.batch.targeted_minions = {'foo'}
         self.batch.event = MagicMock(
             unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'})))
         self.batch.start()
@@ -222,7 +225,7 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
         self.assertEqual(self.batch.done_minions, set())
 
     def test_batch__event_handler_call_start_batch_when_all_pings_return(self):
-        self.batch.down_minions = {'foo'}
+        self.batch.targeted_minions = {'foo'}
         self.batch.event = MagicMock(
             unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'})))
         self.batch.start()
@@ -232,7 +235,7 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
             (self.batch.start_batch,))
 
     def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(self):
-        self.batch.down_minions = {'foo', 'bar'}
+        self.batch.targeted_minions = {'foo', 'bar'}
         self.batch.event = MagicMock(
             unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'})))
         self.batch.start()
@@ -260,20 +263,10 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase):
         self.assertEqual(self.batch.find_job_returned, {'foo'})
 
     @tornado.testing.gen_test
-    def test_batch__event_handler_end_batch(self):
-        self.batch.event = MagicMock(
-            unpack=MagicMock(return_value=('salt/job/not-my-jid/ret/foo', {'id': 'foo'})))
-        future = tornado.gen.Future()
-        future.set_result({'minions': ['foo', 'bar', 'baz']})
-        self.batch.local.run_job_async.return_value = future
-        self.batch.start()
-        self.batch.initialized = True
-        self.assertEqual(self.batch.down_minions, {'foo', 'bar', 'baz'})
+    def test_batch_schedule_next_end_batch_when_no_next(self):
         self.batch.end_batch = MagicMock()
-        self.batch.minions = {'foo', 'bar', 'baz'}
-        self.batch.done_minions = {'foo', 'bar'}
-        self.batch.timedout_minions = {'baz'}
-        self.batch._BatchAsync__event_handler(MagicMock())
+        self.batch._get_next = MagicMock(return_value={})
+        self.batch.schedule_next()
         self.assertEqual(len(self.batch.end_batch.mock_calls), 1)
 
     @tornado.testing.gen_test
-- 
2.20.1