File delete-bad-api-token-files.patch of Package salt.14915

From 2a3e94535e41339aa3163d89151cee44e04fe5b1 Mon Sep 17 00:00:00 2001
From: Jochen Breuer <jbreuer@suse.de>
Date: Wed, 29 Jan 2020 15:49:35 +0100
Subject: [PATCH] Delete bad API token files

Under the following conditions and API token should be considered
invalid:

- The file is empty.
- We cannot deserialize the token from the file.
- The token exists but has no expiration date.
- The token exists but has expired.

All of these conditions necessitate deleting the token file. Otherwise
we should simply return an empty token.
---
 salt/auth/__init__.py             | 27 ++++++++++-----
 salt/exceptions.py                |  6 ++++
 salt/payload.py                   | 10 ++++--
 tests/unit/test_auth.py           | 69 +++++++++++++++++++++++++++++++++++++++
 tests/unit/tokens/test_localfs.py | 47 +++++++++++++++++++++++++-
 5 files changed, 147 insertions(+), 12 deletions(-)

diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py
index a8aefa7091..e44c94fb37 100644
--- a/salt/auth/__init__.py
+++ b/salt/auth/__init__.py
@@ -25,7 +25,9 @@ from salt.ext import six
 
 # Import salt libs
 import salt.config
+import salt.exceptions
 import salt.loader
+import salt.payload
 import salt.transport.client
 import salt.utils.args
 import salt.utils.dictupdate
@@ -34,7 +36,6 @@ import salt.utils.minions
 import salt.utils.user
 import salt.utils.versions
 import salt.utils.zeromq
-import salt.payload
 
 log = logging.getLogger(__name__)
 
@@ -246,16 +247,24 @@ class LoadAuth(object):
         Return the name associated with the token, or False if the token is
         not valid
         '''
-        tdata = self.tokens["{0}.get_token".format(self.opts['eauth_tokens'])](self.opts, tok)
-        if not tdata:
-            return {}
-
-        rm_tok = False
-        if 'expire' not in tdata:
-            # invalid token, delete it!
+        tdata = {}
+        try:
+            tdata = self.tokens["{0}.get_token".format(self.opts['eauth_tokens'])](self.opts, tok)
+        except salt.exceptions.SaltDeserializationError:
+            log.warning("Failed to load token %r - removing broken/empty file.", tok)
             rm_tok = True
-        if tdata.get('expire', '0') < time.time():
+        else:
+            if not tdata:
+                return {}
+            rm_tok = False
+
+        if tdata.get('expire', 0) < time.time():
+            # If expire isn't present in the token it's invalid and needs
+            # to be removed. Also, if it's present and has expired - in
+            # other words, the expiration is before right now, it should
+            # be removed.
             rm_tok = True
+
         if rm_tok:
             self.rm_token(tok)
 
diff --git a/salt/exceptions.py b/salt/exceptions.py
index cc6980b289..d0d865527d 100644
--- a/salt/exceptions.py
+++ b/salt/exceptions.py
@@ -351,6 +351,12 @@ class TokenAuthenticationError(SaltException):
     '''
 
 
+class SaltDeserializationError(SaltException):
+    '''
+    Thrown when salt cannot deserialize data.
+    '''
+
+
 class AuthorizationError(SaltException):
     '''
     Thrown when runner or wheel execution fails due to permissions
diff --git a/salt/payload.py b/salt/payload.py
index ea569c9f73..dc34cf4dab 100644
--- a/salt/payload.py
+++ b/salt/payload.py
@@ -18,7 +18,7 @@ import salt.crypt
 import salt.transport.frame
 import salt.utils.immutabletypes as immutabletypes
 import salt.utils.stringutils
-from salt.exceptions import SaltReqTimeoutError
+from salt.exceptions import SaltReqTimeoutError, SaltDeserializationError
 from salt.utils.data import CaseInsensitiveDict
 
 # Import third party libs
@@ -164,7 +164,13 @@ class Serial(object):
             )
             log.debug('Msgpack deserialization failure on message: %s', msg)
             gc.collect()
-            raise
+            raise six.raise_from(
+                SaltDeserializationError(
+                    'Could not deserialize msgpack message.'
+                    ' See log for more info.'
+                ),
+                exc,
+            )
         finally:
             gc.enable()
         return ret
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 5d88e82077..54c3915144 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -6,6 +6,8 @@
 # Import pytohn libs
 from __future__ import absolute_import, print_function, unicode_literals
 
+import time
+
 # Import Salt Testing libs
 from tests.support.unit import TestCase, skipIf
 from tests.support.mock import patch, call, NO_MOCK, NO_MOCK_REASON, MagicMock
@@ -14,6 +16,7 @@ from tests.support.mock import patch, call, NO_MOCK, NO_MOCK_REASON, MagicMock
 import salt.master
 from tests.support.case import ModuleCase
 from salt import auth
+from salt.exceptions import SaltDeserializationError
 import salt.utils.platform
 
 
@@ -37,6 +40,72 @@ class LoadAuthTestCase(TestCase):
             self.addCleanup(patcher.stop)
         self.lauth = auth.LoadAuth({})  # Load with empty opts
 
+    def test_get_tok_with_broken_file_will_remove_bad_token(self):
+        fake_get_token = MagicMock(side_effect=SaltDeserializationError('hi'))
+        patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
+        patch_get_token = patch.dict(
+            self.lauth.tokens,
+            {
+                'testfs.get_token': fake_get_token
+            },
+        )
+        mock_rm_token = MagicMock()
+        patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
+        with patch_opts, patch_get_token, patch_rm_token:
+            expected_token = 'fnord'
+            self.lauth.get_tok(expected_token)
+            mock_rm_token.assert_called_with(expected_token)
+
+    def test_get_tok_with_no_expiration_should_remove_bad_token(self):
+        fake_get_token = MagicMock(return_value={'no_expire_here': 'Nope'})
+        patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
+        patch_get_token = patch.dict(
+            self.lauth.tokens,
+            {
+                'testfs.get_token': fake_get_token
+            },
+        )
+        mock_rm_token = MagicMock()
+        patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
+        with patch_opts, patch_get_token, patch_rm_token:
+            expected_token = 'fnord'
+            self.lauth.get_tok(expected_token)
+            mock_rm_token.assert_called_with(expected_token)
+
+    def test_get_tok_with_expire_before_current_time_should_remove_token(self):
+        fake_get_token = MagicMock(return_value={'expire': time.time()-1})
+        patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
+        patch_get_token = patch.dict(
+            self.lauth.tokens,
+            {
+                'testfs.get_token': fake_get_token
+            },
+        )
+        mock_rm_token = MagicMock()
+        patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
+        with patch_opts, patch_get_token, patch_rm_token:
+            expected_token = 'fnord'
+            self.lauth.get_tok(expected_token)
+            mock_rm_token.assert_called_with(expected_token)
+
+    def test_get_tok_with_valid_expiration_should_return_token(self):
+        expected_token = {'expire': time.time()+1}
+        fake_get_token = MagicMock(return_value=expected_token)
+        patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
+        patch_get_token = patch.dict(
+            self.lauth.tokens,
+            {
+                'testfs.get_token': fake_get_token
+            },
+        )
+        mock_rm_token = MagicMock()
+        patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
+        with patch_opts, patch_get_token, patch_rm_token:
+            token_name = 'fnord'
+            actual_token = self.lauth.get_tok(token_name)
+            mock_rm_token.assert_not_called()
+            assert expected_token is actual_token, 'Token was not returned'
+
     def test_load_name(self):
         valid_eauth_load = {'username': 'test_user',
                             'show_timeout': False,
diff --git a/tests/unit/tokens/test_localfs.py b/tests/unit/tokens/test_localfs.py
index f950091252..b7d86d9f23 100644
--- a/tests/unit/tokens/test_localfs.py
+++ b/tests/unit/tokens/test_localfs.py
@@ -1,10 +1,14 @@
 # -*- coding: utf-8 -*-
+'''
+Tests the localfs tokens interface.
+'''
 from __future__ import absolute_import, print_function, unicode_literals
 
 import os
 
-import salt.utils.files
+import salt.exceptions
 import salt.tokens.localfs
+import salt.utils.files
 
 from tests.support.unit import TestCase, skipIf
 from tests.support.helpers import with_tempdir
@@ -51,3 +55,44 @@ class WriteTokenTest(TestCase):
         assert rename.called_with == [
             ((temp_t_path, t_path), {})
         ], rename.called_with
+
+
+class TestLocalFS(TestCase):
+    def setUp(self):
+        # Default expected data
+        self.expected_data = {'this': 'is', 'some': 'token data'}
+
+    @with_tempdir()
+    def test_get_token_should_return_token_if_exists(self, tempdir):
+        opts = {'token_dir': tempdir}
+        tok = salt.tokens.localfs.mk_token(
+            opts=opts,
+            tdata=self.expected_data,
+        )['token']
+        actual_data = salt.tokens.localfs.get_token(opts=opts, tok=tok)
+        self.assertDictEqual(self.expected_data, actual_data)
+
+    @with_tempdir()
+    def test_get_token_should_raise_SaltDeserializationError_if_token_file_is_empty(self, tempdir):
+        opts = {'token_dir': tempdir}
+        tok = salt.tokens.localfs.mk_token(
+            opts=opts,
+            tdata=self.expected_data,
+        )['token']
+        with open(os.path.join(tempdir, tok), 'w') as f:
+            f.truncate()
+        with self.assertRaises(salt.exceptions.SaltDeserializationError) as e:
+            salt.tokens.localfs.get_token(opts=opts, tok=tok)
+
+    @with_tempdir()
+    def test_get_token_should_raise_SaltDeserializationError_if_token_file_is_malformed(self, tempdir):
+        opts = {'token_dir': tempdir}
+        tok = salt.tokens.localfs.mk_token(
+            opts=opts,
+            tdata=self.expected_data,
+        )['token']
+        with open(os.path.join(tempdir, tok), 'w') as f:
+            f.truncate()
+            f.write('this is not valid msgpack data')
+        with self.assertRaises(salt.exceptions.SaltDeserializationError) as e:
+            salt.tokens.localfs.get_token(opts=opts, tok=tok)
-- 
2.16.4