File CVE-2024-33663.patch of Package python-python-jose.18359

From 34bd82c43ea31da5b9deaa25ff591905a180bdf7 Mon Sep 17 00:00:00 2001
From: Daniel Garcia Moreno <daniel.garcia@suse.com>
Date: Thu, 2 May 2024 09:29:54 +0200
Subject: [PATCH 1/4] Improve asymmetric key check in CryptographyHMACKey

This change should fix https://github.com/mpdavis/python-jose/issues/346
security issue.

The code is based on pyjwt change:
https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc
---
 jose/backends/cryptography_backend.py | 72 ++++++++++++++++++++++++---
 tests/test_jwt.py                     | 35 ++++++++++++-
 2 files changed, 98 insertions(+), 9 deletions(-)

Index: python-jose-3.0.1/tests/test_jwt.py
===================================================================
--- python-jose-3.0.1.orig/tests/test_jwt.py
+++ python-jose-3.0.1/tests/test_jwt.py
@@ -4,7 +4,8 @@ import json
 
 from jose import jws
 from jose import jwt
-from jose.exceptions import JWTError
+from jose.constants import ALGORITHMS
+from jose.exceptions import JWTError, JWKError
 
 from datetime import datetime
 from datetime import timedelta
@@ -101,7 +102,7 @@ class TestJWT:
 
     def test_non_default_headers(self, claims, key, headers):
         encoded = jwt.encode(claims, key, headers=headers)
-        decoded = jwt.decode(encoded, key)
+        decoded = jwt.decode(encoded, key, algorithms=ALGORITHMS.HS256)
         assert claims == decoded
         all_headers = jwt.get_unverified_headers(encoded)
         for k, v in headers.items():
@@ -134,7 +135,7 @@ class TestJWT:
             '.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
         )
 
-        decoded = jwt.decode(token, key)
+        decoded = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED)
 
         assert decoded == claims
 
@@ -155,7 +156,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, options=options)
+        jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
 
     def test_iat_not_int(self, key):
 
@@ -166,7 +167,7 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_nbf_not_int(self, key):
 
@@ -177,7 +178,7 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_nbf_datetime(self, key):
 
@@ -188,7 +189,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key)
+        jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_nbf_with_leeway(self, key):
 
@@ -203,7 +204,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, options=options)
+        jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
 
     def test_nbf_in_future(self, key):
 
@@ -216,7 +217,7 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_nbf_skip(self, key):
 
@@ -229,13 +230,13 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
         options = {
             'verify_nbf': False
         }
 
-        jwt.decode(token, key, options=options)
+        jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
 
     def test_exp_not_int(self, key):
 
@@ -246,7 +247,7 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_exp_datetime(self, key):
 
@@ -257,7 +258,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key)
+        jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_exp_with_leeway(self, key):
 
@@ -272,7 +273,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, options=options)
+        jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
 
     def test_exp_in_past(self, key):
 
@@ -285,7 +286,7 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_exp_skip(self, key):
 
@@ -298,13 +299,13 @@ class TestJWT:
         token = jwt.encode(claims, key)
 
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
         options = {
             'verify_exp': False
         }
 
-        jwt.decode(token, key, options=options)
+        jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
 
     def test_aud_string(self, key):
 
@@ -315,7 +316,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, audience=aud)
+        jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
 
     def test_aud_list(self, key):
 
@@ -326,7 +327,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, audience=aud)
+        jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
 
     def test_aud_list_multiple(self, key):
 
@@ -337,7 +338,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, audience=aud)
+        jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
 
     def test_aud_list_is_strings(self, key):
 
@@ -349,7 +350,7 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key, audience=aud)
+            jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
 
     def test_aud_case_sensitive(self, key):
 
@@ -361,14 +362,14 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key, audience='AUDIENCE')
+            jwt.decode(token, key, audience='AUDIENCE', algorithms=ALGORITHMS.HS256)
 
     def test_aud_empty_claim(self, claims, key):
 
         aud = 'audience'
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, audience=aud)
+        jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
 
     def test_aud_not_string_or_list(self, key):
 
@@ -380,7 +381,7 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_aud_given_number(self, key):
 
@@ -392,7 +393,7 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key, audience=1)
+            jwt.decode(token, key, audience=1, algorithms=ALGORITHMS.HS256)
 
     def test_iss_string(self, key):
 
@@ -403,7 +404,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, issuer=iss)
+        jwt.decode(token, key, issuer=iss, algorithms=ALGORITHMS.HS256)
 
     def test_iss_list(self, key):
 
@@ -414,7 +415,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, issuer=['https://issuer', 'issuer'])
+        jwt.decode(token, key, issuer=['https://issuer', 'issuer'], algorithms=ALGORITHMS.HS256)
 
     def test_iss_tuple(self, key):
 
@@ -425,7 +426,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, issuer=('https://issuer', 'issuer'))
+        jwt.decode(token, key, issuer=('https://issuer', 'issuer'), algorithms=ALGORITHMS.HS256)
 
     def test_iss_invalid(self, key):
 
@@ -437,7 +438,7 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key, issuer='another')
+            jwt.decode(token, key, issuer='another', algorithms=ALGORITHMS.HS256)
 
     def test_sub_string(self, key):
 
@@ -448,7 +449,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key)
+        jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_sub_invalid(self, key):
 
@@ -460,7 +461,7 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_sub_correct(self, key):
 
@@ -471,7 +472,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key, subject=sub)
+        jwt.decode(token, key, subject=sub, algorithms=ALGORITHMS.HS256)
 
     def test_sub_incorrect(self, key):
 
@@ -483,7 +484,7 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key, subject='another')
+            jwt.decode(token, key, subject='another', algorithms=ALGORITHMS.HS256)
 
     def test_jti_string(self, key):
 
@@ -494,7 +495,7 @@ class TestJWT:
         }
 
         token = jwt.encode(claims, key)
-        jwt.decode(token, key)
+        jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_jti_invalid(self, key):
 
@@ -506,33 +507,33 @@ class TestJWT:
 
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_at_hash(self, claims, key):
         access_token = '<ACCESS_TOKEN>'
         token = jwt.encode(claims, key, access_token=access_token)
-        payload = jwt.decode(token, key, access_token=access_token)
+        payload = jwt.decode(token, key, access_token=access_token, algorithms=ALGORITHMS.HS256)
         assert 'at_hash' in payload
 
     def test_at_hash_invalid(self, claims, key):
         token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
         with pytest.raises(JWTError):
-            jwt.decode(token, key, access_token='<OTHER_TOKEN>')
+            jwt.decode(token, key, access_token='<OTHER_TOKEN>', algorithms=ALGORITHMS.HS256)
 
     def test_at_hash_missing_access_token(self, claims, key):
         token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
         with pytest.raises(JWTError):
-            jwt.decode(token, key)
+            jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
 
     def test_at_hash_missing_claim(self, claims, key):
         token = jwt.encode(claims, key)
         with pytest.raises(JWTError):
-            jwt.decode(token, key, access_token='<ACCESS_TOKEN>')
+            jwt.decode(token, key, access_token='<ACCESS_TOKEN>', algorithms=ALGORITHMS.HS256)
 
     def test_at_hash_unable_to_calculate(self, claims, key):
         token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
         with pytest.raises(JWTError):
-            jwt.decode(token, key, access_token='\xe2')
+            jwt.decode(token, key, access_token='\xe2', algorithms=ALGORITHMS.HS256)
 
     def test_bad_claims(self):
         bad_token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck'
@@ -552,3 +553,42 @@ class TestJWT:
     def test_unverified_claims_object(self, claims, key):
         token = jwt.encode(claims, key)
         assert jwt.get_unverified_claims(token) == claims
+
+    def test_CVE_2024_33663(self):
+        """Test based on https://github.com/mpdavis/python-jose/issues/346"""
+        try:
+            from Crypto.PublicKey import ECC
+            from Crypto.Hash import HMAC, SHA256
+        except ModuleNotFoundError:
+            pytest.skip("pycryptodome module not installed")
+
+        # ----- SETUP -----
+        # generate an asymmetric ECC keypair
+        # !! signing should only be possible with the private key !!
+        KEY = ECC.generate(curve='P-256')
+
+        # PUBLIC KEY, AVAILABLE TO USER
+        # CAN BE RECOVERED THROUGH E.G. PUBKEY RECOVERY WITH TWO SIGNATURES:
+        # https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm#Public_key_recovery
+        # https://github.com/FlorianPicca/JWT-Key-Recovery
+        PUBKEY = KEY.public_key().export_key(format='OpenSSH').encode()
+
+        # ---- CLIENT SIDE -----
+        # without knowing the private key, a valid token can be constructed
+        # YIKES!!
+
+        b64 = lambda x:base64.urlsafe_b64encode(x).replace(b'=',b'')
+        payload = b64(b'{"alg":"HS256"}') + b'.' + b64(b'{"pwned":true}')
+        hasher = HMAC.new(PUBKEY, digestmod=SHA256)
+        hasher.update(payload)
+        evil_token = payload + b'.' + b64(hasher.digest())
+
+        # ---- SERVER SIDE -----
+        # verify and decode the token using the public key, as is custom
+        # algorithm field is left unspecified
+        # but the library will happily still verify without warning, trusting the user-controlled alg field of the token header
+        with pytest.raises(JWKError):
+            data = jwt.decode(evil_token, PUBKEY.decode(), algorithms=ALGORITHMS.HS256)
+
+        with pytest.raises(JWTError, match='.*required.*"algorithms".*'):
+            data = jwt.decode(evil_token, PUBKEY.decode())
Index: python-jose-3.0.1/jose/jwt.py
===================================================================
--- python-jose-3.0.1.orig/jose/jwt.py
+++ python-jose-3.0.1/jose/jwt.py
@@ -130,6 +130,13 @@ def decode(token, key, algorithms=None,
         defaults.update(options)
 
     verify_signature = defaults.get('verify_signature', True)
+    # Forbid the usage of the jwt.decode without alogrightms parameter
+    # See https://github.com/mpdavis/python-jose/issues/346 for more
+    # information CVE-2024-33663
+    if verify_signature and algorithms is None:
+        raise JWTError("It is required that you pass in a value for "
+                       'the "algorithms" argument when calling '
+                       "decode().")
 
     try:
         payload = jws.verify(token, key, algorithms, verify=verify_signature)
Index: python-jose-3.0.1/jose/utils.py
===================================================================
--- python-jose-3.0.1.orig/jose/utils.py
+++ python-jose-3.0.1/jose/utils.py
@@ -1,4 +1,4 @@
-
+import re
 import base64
 import hmac
 import six
@@ -132,3 +132,75 @@ def constant_time_string_compare(a, b):
             result |= ord(x) ^ ord(y)
 
         return result == 0
+
+
+# Based on https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc
+# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
+_PEMS = {
+    b"CERTIFICATE",
+    b"TRUSTED CERTIFICATE",
+    b"PRIVATE KEY",
+    b"PUBLIC KEY",
+    b"ENCRYPTED PRIVATE KEY",
+    b"OPENSSH PRIVATE KEY",
+    b"DSA PRIVATE KEY",
+    b"RSA PRIVATE KEY",
+    b"RSA PUBLIC KEY",
+    b"EC PRIVATE KEY",
+    b"DH PARAMETERS",
+    b"NEW CERTIFICATE REQUEST",
+    b"CERTIFICATE REQUEST",
+    b"SSH2 PUBLIC KEY",
+    b"SSH2 ENCRYPTED PRIVATE KEY",
+    b"X509 CRL",
+}
+
+
+_PEM_RE = re.compile(
+    b"----[- ]BEGIN ("
+    + b"|".join(_PEMS)
+    + b""")[- ]----\r?
+.+?\r?
+----[- ]END \\1[- ]----\r?\n?""",
+    re.DOTALL,
+)
+
+
+def is_pem_format(key):
+    """
+    Return True if the key is PEM format
+    This function uses the list of valid PEM headers defined in
+    _PEMS dict.
+    """
+    return bool(_PEM_RE.search(key))
+
+
+# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
+_CERT_SUFFIX = b"-cert-v01@openssh.com"
+_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
+_SSH_KEY_FORMATS = [
+    b"ssh-ed25519",
+    b"ssh-rsa",
+    b"ssh-dss",
+    b"ecdsa-sha2-nistp256",
+    b"ecdsa-sha2-nistp384",
+    b"ecdsa-sha2-nistp521",
+]
+
+
+def is_ssh_key(key):
+    """
+    Return True if the key is a SSH key
+    This function uses the list of valid SSH key format defined in
+    _SSH_KEY_FORMATS dict.
+    """
+    if any(string_value in key for string_value in _SSH_KEY_FORMATS):
+        return True
+
+    ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
+    if ssh_pubkey_match:
+        key_type = ssh_pubkey_match.group(1)
+        if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
+            return True
+
+    return False
Index: python-jose-3.0.1/tests/algorithms/test_HMAC.py
===================================================================
--- python-jose-3.0.1.orig/tests/algorithms/test_HMAC.py
+++ python-jose-3.0.1/tests/algorithms/test_HMAC.py
@@ -14,14 +14,17 @@ class TestHMACAlgorithm:
 
     def test_RSA_key(self):
         key = "-----BEGIN PUBLIC KEY-----"
+        key += "\n\n\n-----END PUBLIC KEY-----"
         with pytest.raises(JOSEError):
             HMACKey(key, ALGORITHMS.HS256)
 
         key = "-----BEGIN RSA PUBLIC KEY-----"
+        key += "\n\n\n-----END RSA PUBLIC KEY-----"
         with pytest.raises(JOSEError):
             HMACKey(key, ALGORITHMS.HS256)
 
         key = "-----BEGIN CERTIFICATE-----"
+        key += "\n\n\n-----END CERTIFICATE-----"
         with pytest.raises(JOSEError):
             HMACKey(key, ALGORITHMS.HS256)
 
Index: python-jose-3.0.1/jose/jwk.py
===================================================================
--- python-jose-3.0.1.orig/jose/jwk.py
+++ python-jose-3.0.1/jose/jwk.py
@@ -7,6 +7,7 @@ from jose.constants import ALGORITHMS
 from jose.exceptions import JWKError
 from jose.utils import base64url_decode, base64url_encode
 from jose.utils import constant_time_string_compare
+from jose.utils import is_pem_format, is_ssh_key
 from jose.backends.base import Key
 
 try:
@@ -103,17 +104,11 @@ class HMACKey(Key):
         if isinstance(key, six.text_type):
             key = key.encode('utf-8')
 
-        invalid_strings = [
-            b'-----BEGIN PUBLIC KEY-----',
-            b'-----BEGIN RSA PUBLIC KEY-----',
-            b'-----BEGIN CERTIFICATE-----',
-            b'ssh-rsa'
-        ]
-
-        if any(string_value in key for string_value in invalid_strings):
+        if is_pem_format(key) or is_ssh_key(key):
             raise JWKError(
-                'The specified key is an asymmetric key or x509 certificate and'
-                ' should not be used as an HMAC secret.')
+                "The specified key is an asymmetric key or x509 certificate and"
+                " should not be used as an HMAC secret."
+            )
 
         self.prepared_key = key
 
openSUSE Build Service is sponsored by