File enable-keepalive-probes-for-salt-ssh-executions-bsc-.patch of Package salt.35892
From 5303cc612bcbdb1ec45ede397ca1e2ca12ba3bd3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pablo=20Su=C3=A1rez=20Hern=C3=A1ndez?=
 <psuarezhernandez@suse.com>
Date: Fri, 1 Dec 2023 10:59:30 +0000
Subject: [PATCH] Enable "KeepAlive" probes for Salt SSH executions
 (bsc#1211649) (#610)
* Enable KeepAlive probes for Salt SSH connections (bsc#1211649)
* Add tests for Salt SSH keepalive options
* Add changelog file
* Make changes suggested by pre-commit
---
 changelog/65488.added.md                     |  1 +
 salt/client/ssh/__init__.py                  | 32 +++++++++---
 salt/client/ssh/client.py                    | 13 ++++-
 salt/client/ssh/shell.py                     | 12 +++++
 salt/config/__init__.py                      |  6 +++
 salt/utils/parsers.py                        | 19 +++++++
 tests/pytests/unit/client/ssh/test_single.py | 55 ++++++++++++++++++++
 tests/pytests/unit/client/ssh/test_ssh.py    |  3 ++
 8 files changed, 133 insertions(+), 8 deletions(-)
 create mode 100644 changelog/65488.added.md
diff --git a/changelog/65488.added.md b/changelog/65488.added.md
new file mode 100644
index 0000000000..78476cec11
--- /dev/null
+++ b/changelog/65488.added.md
@@ -0,0 +1 @@
+Enable "KeepAlive" probes for Salt SSH executions
diff --git a/salt/client/ssh/__init__.py b/salt/client/ssh/__init__.py
index 1e143f9e30..1d8426b7c2 100644
--- a/salt/client/ssh/__init__.py
+++ b/salt/client/ssh/__init__.py
@@ -50,8 +50,8 @@ import salt.utils.thin
 import salt.utils.url
 import salt.utils.verify
 from salt._logging import LOG_LEVELS
-from salt._logging.mixins import MultiprocessingStateMixin
 from salt._logging.impl import LOG_LOCK
+from salt._logging.mixins import MultiprocessingStateMixin
 from salt.template import compile_template
 from salt.utils.process import Process
 from salt.utils.zeromq import zmq
@@ -307,6 +307,18 @@ class SSH(MultiprocessingStateMixin):
                 "ssh_timeout", salt.config.DEFAULT_MASTER_OPTS["ssh_timeout"]
             )
             + self.opts.get("timeout", salt.config.DEFAULT_MASTER_OPTS["timeout"]),
+            "keepalive": self.opts.get(
+                "ssh_keepalive",
+                salt.config.DEFAULT_MASTER_OPTS["ssh_keepalive"],
+            ),
+            "keepalive_interval": self.opts.get(
+                "ssh_keepalive_interval",
+                salt.config.DEFAULT_MASTER_OPTS["ssh_keepalive_interval"],
+            ),
+            "keepalive_count_max": self.opts.get(
+                "ssh_keepalive_count_max",
+                salt.config.DEFAULT_MASTER_OPTS["ssh_keepalive_count_max"],
+            ),
             "sudo": self.opts.get(
                 "ssh_sudo", salt.config.DEFAULT_MASTER_OPTS["ssh_sudo"]
             ),
@@ -557,7 +569,7 @@ class SSH(MultiprocessingStateMixin):
             mods=self.mods,
             fsclient=self.fsclient,
             thin=self.thin,
-            **target
+            **target,
         )
         if salt.utils.path.which("ssh-copy-id"):
             # we have ssh-copy-id, use it!
@@ -573,7 +585,7 @@ class SSH(MultiprocessingStateMixin):
                 mods=self.mods,
                 fsclient=self.fsclient,
                 thin=self.thin,
-                **target
+                **target,
             )
             stdout, stderr, retcode = single.cmd_block()
             try:
@@ -601,7 +613,7 @@ class SSH(MultiprocessingStateMixin):
             fsclient=self.fsclient,
             thin=self.thin,
             mine=mine,
-            **target
+            **target,
         )
         ret = {"id": single.id}
         stdout, stderr, retcode = single.run()
@@ -1022,7 +1034,10 @@ class Single:
         remote_port_forwards=None,
         winrm=False,
         ssh_options=None,
-        **kwargs
+        keepalive=True,
+        keepalive_interval=60,
+        keepalive_count_max=3,
+        **kwargs,
     ):
         # Get mine setting and mine_functions if defined in kwargs (from roster)
         self.mine = mine
@@ -1081,6 +1096,9 @@ class Single:
             "priv": priv,
             "priv_passwd": priv_passwd,
             "timeout": timeout,
+            "keepalive": keepalive,
+            "keepalive_interval": keepalive_interval,
+            "keepalive_count_max": keepalive_count_max,
             "sudo": sudo,
             "tty": tty,
             "mods": self.mods,
@@ -1302,7 +1320,7 @@ class Single:
                 self.id,
                 fsclient=self.fsclient,
                 minion_opts=self.minion_opts,
-                **self.target
+                **self.target,
             )
 
             opts_pkg = pre_wrapper["test.opts_pkg"]()  # pylint: disable=E1102
@@ -1388,7 +1406,7 @@ class Single:
             self.id,
             fsclient=self.fsclient,
             minion_opts=self.minion_opts,
-            **self.target
+            **self.target,
         )
         wrapper.fsclient.opts["cachedir"] = opts["cachedir"]
         self.wfuncs = salt.loader.ssh_wrapper(opts, wrapper, self.context)
diff --git a/salt/client/ssh/client.py b/salt/client/ssh/client.py
index 0b67598fc6..a00f5de423 100644
--- a/salt/client/ssh/client.py
+++ b/salt/client/ssh/client.py
@@ -52,6 +52,9 @@ class SSHClient:
             ("ssh_priv_passwd", str),
             ("ssh_identities_only", bool),
             ("ssh_remote_port_forwards", str),
+            ("ssh_keepalive", bool),
+            ("ssh_keepalive_interval", int),
+            ("ssh_keepalive_count_max", int),
             ("ssh_options", list),
             ("ssh_max_procs", int),
             ("ssh_askpass", bool),
@@ -108,7 +111,15 @@ class SSHClient:
         return sane_kwargs
 
     def _prep_ssh(
-        self, tgt, fun, arg=(), timeout=None, tgt_type="glob", kwarg=None, context=None, **kwargs
+        self,
+        tgt,
+        fun,
+        arg=(),
+        timeout=None,
+        tgt_type="glob",
+        kwarg=None,
+        context=None,
+        **kwargs
     ):
         """
         Prepare the arguments
diff --git a/salt/client/ssh/shell.py b/salt/client/ssh/shell.py
index bc1ad034df..182e2c19e3 100644
--- a/salt/client/ssh/shell.py
+++ b/salt/client/ssh/shell.py
@@ -85,6 +85,9 @@ class Shell:
         remote_port_forwards=None,
         winrm=False,
         ssh_options=None,
+        keepalive=True,
+        keepalive_interval=None,
+        keepalive_count_max=None,
     ):
         self.opts = opts
         # ssh <ipv6>, but scp [<ipv6]:/path
@@ -95,6 +98,9 @@ class Shell:
         self.priv = priv
         self.priv_passwd = priv_passwd
         self.timeout = timeout
+        self.keepalive = keepalive
+        self.keepalive_interval = keepalive_interval
+        self.keepalive_count_max = keepalive_count_max
         self.sudo = sudo
         self.tty = tty
         self.mods = mods
@@ -130,6 +136,9 @@ class Shell:
         if self.opts.get("_ssh_version", (0,)) > (4, 9):
             options.append("GSSAPIAuthentication=no")
         options.append("ConnectTimeout={}".format(self.timeout))
+        if self.keepalive:
+            options.append(f"ServerAliveInterval={self.keepalive_interval}")
+            options.append(f"ServerAliveCountMax={self.keepalive_count_max}")
         if self.opts.get("ignore_host_keys"):
             options.append("StrictHostKeyChecking=no")
         if self.opts.get("no_host_keys"):
@@ -165,6 +174,9 @@ class Shell:
         if self.opts["_ssh_version"] > (4, 9):
             options.append("GSSAPIAuthentication=no")
         options.append("ConnectTimeout={}".format(self.timeout))
+        if self.keepalive:
+            options.append(f"ServerAliveInterval={self.keepalive_interval}")
+            options.append(f"ServerAliveCountMax={self.keepalive_count_max}")
         if self.opts.get("ignore_host_keys"):
             options.append("StrictHostKeyChecking=no")
         if self.opts.get("no_host_keys"):
diff --git a/salt/config/__init__.py b/salt/config/__init__.py
index d8258a4dbc..68f2b0f674 100644
--- a/salt/config/__init__.py
+++ b/salt/config/__init__.py
@@ -822,6 +822,9 @@ VALID_OPTS = immutabletypes.freeze(
         "ssh_scan_ports": str,
         "ssh_scan_timeout": float,
         "ssh_identities_only": bool,
+        "ssh_keepalive": bool,
+        "ssh_keepalive_interval": int,
+        "ssh_keepalive_count_max": int,
         "ssh_log_file": str,
         "ssh_config_file": str,
         "ssh_merge_pillar": bool,
@@ -1592,6 +1595,9 @@ DEFAULT_MASTER_OPTS = immutabletypes.freeze(
         "ssh_scan_ports": "22",
         "ssh_scan_timeout": 0.01,
         "ssh_identities_only": False,
+        "ssh_keepalive": True,
+        "ssh_keepalive_interval": 60,
+        "ssh_keepalive_count_max": 3,
         "ssh_log_file": os.path.join(salt.syspaths.LOGS_DIR, "ssh"),
         "ssh_config_file": os.path.join(salt.syspaths.HOME_DIR, ".ssh", "config"),
         "cluster_mode": False,
diff --git a/salt/utils/parsers.py b/salt/utils/parsers.py
index dc125de7d7..6c7f9f2f66 100644
--- a/salt/utils/parsers.py
+++ b/salt/utils/parsers.py
@@ -3383,6 +3383,25 @@ class SaltSSHOptionParser(
                 "-R parameters."
             ),
         )
+        ssh_group.add_option(
+            "--disable-keepalive",
+            default=True,
+            action="store_false",
+            dest="ssh_keepalive",
+            help=(
+                "Disable KeepAlive probes (ServerAliveInterval) for the SSH connection."
+            ),
+        )
+        ssh_group.add_option(
+            "--keepalive-interval",
+            dest="ssh_keepalive_interval",
+            help=("Define the value for ServerAliveInterval option."),
+        )
+        ssh_group.add_option(
+            "--keepalive-count-max",
+            dest="ssh_keepalive_count_max",
+            help=("Define the value for ServerAliveCountMax option."),
+        )
         ssh_group.add_option(
             "--ssh-option",
             dest="ssh_options",
diff --git a/tests/pytests/unit/client/ssh/test_single.py b/tests/pytests/unit/client/ssh/test_single.py
index c88a1c2127..8d87da8700 100644
--- a/tests/pytests/unit/client/ssh/test_single.py
+++ b/tests/pytests/unit/client/ssh/test_single.py
@@ -63,6 +63,61 @@ def test_single_opts(opts, target):
         **target,
     )
 
+    assert single.shell._ssh_opts() == ""
+    expected_cmd = (
+        "ssh login1 "
+        "-o KbdInteractiveAuthentication=no -o "
+        "PasswordAuthentication=yes -o ConnectTimeout=65 -o ServerAliveInterval=60 "
+        "-o ServerAliveCountMax=3 -o Port=22 "
+        "-o IdentityFile=/etc/salt/pki/master/ssh/salt-ssh.rsa "
+        "-o User=root  date +%s"
+    )
+    assert single.shell._cmd_str("date +%s") == expected_cmd
+
+
+def test_single_opts_custom_keepalive_options(opts, target):
+    """Sanity check for ssh.Single options with custom keepalive"""
+
+    single = ssh.Single(
+        opts,
+        opts["argv"],
+        "localhost",
+        mods={},
+        fsclient=None,
+        thin=salt.utils.thin.thin_path(opts["cachedir"]),
+        mine=False,
+        keepalive_interval=15,
+        keepalive_count_max=5,
+        **target,
+    )
+
+    assert single.shell._ssh_opts() == ""
+    expected_cmd = (
+        "ssh login1 "
+        "-o KbdInteractiveAuthentication=no -o "
+        "PasswordAuthentication=yes -o ConnectTimeout=65 -o ServerAliveInterval=15 "
+        "-o ServerAliveCountMax=5 -o Port=22 "
+        "-o IdentityFile=/etc/salt/pki/master/ssh/salt-ssh.rsa "
+        "-o User=root  date +%s"
+    )
+    assert single.shell._cmd_str("date +%s") == expected_cmd
+
+
+def test_single_opts_disable_keepalive(opts, target):
+    """Sanity check for ssh.Single options with custom keepalive"""
+
+    single = ssh.Single(
+        opts,
+        opts["argv"],
+        "localhost",
+        mods={},
+        fsclient=None,
+        thin=salt.utils.thin.thin_path(opts["cachedir"]),
+        mine=False,
+        keepalive=False,
+        **target,
+    )
+
     assert single.shell._ssh_opts() == ""
     expected_cmd = (
         "ssh login1 "
diff --git a/tests/pytests/unit/client/ssh/test_ssh.py b/tests/pytests/unit/client/ssh/test_ssh.py
index cece16026c..23223ba8ec 100644
--- a/tests/pytests/unit/client/ssh/test_ssh.py
+++ b/tests/pytests/unit/client/ssh/test_ssh.py
@@ -78,6 +78,9 @@ def roster():
         ("ssh_scan_ports", "test", True),
         ("ssh_scan_timeout", 1.0, True),
         ("ssh_timeout", 1, False),
+        ("ssh_keepalive", True, True),
+        ("ssh_keepalive_interval", 30, True),
+        ("ssh_keepalive_count_max", 3, True),
         ("ssh_log_file", "/tmp/test", True),
         ("raw_shell", True, True),
         ("refresh_cache", True, True),
-- 
2.42.0