File jwt2.patch of Package python-djangorestframework-simplejwt

From ae1ea58daacb76a1d72e202198734053022a6efe Mon Sep 17 00:00:00 2001
From: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
Date: Mon, 22 Feb 2021 16:51:38 +0100
Subject: [PATCH] Support jwt 2 (#376)

* Add PyJWT 2.0.0 support

* Many tests relied on passing a payload into encode and changing it payload by reference.
  In PyJWT 2.0.0, the payload is copied (`dict.copy()`).
* Verify was deprecated in favor of using the option verify_signature. This is reflected in backends.py. Unless you wrote verify=False, you are not effected by this change.

* Isort lint

Co-authored-by: Andrew-Chen-Wang <acwangpython@gmail.com>
---
 rest_framework_simplejwt/backends.py |  8 +++--
 tests/test_backends.py               | 44 ++++++++++++++++++----------
 2 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py
index 887ecac..de7128a 100644
--- a/rest_framework_simplejwt/backends.py
+++ b/rest_framework_simplejwt/backends.py
@@ -65,9 +65,11 @@ def decode(self, token, verify=True):
         signature check fails, or if its 'exp' claim indicates it has expired.
         """
         try:
-            return jwt.decode(token, self.verifying_key, algorithms=[self.algorithm], verify=verify,
-                              audience=self.audience, issuer=self.issuer,
-                              options={'verify_aud': self.audience is not None})
+            return jwt.decode(
+                token, self.verifying_key, algorithms=[self.algorithm], verify=verify,
+                audience=self.audience, issuer=self.issuer,
+                options={'verify_aud': self.audience is not None, "verify_signature": verify}
+            )
         except InvalidAlgorithmError as ex:
             raise TokenBackendError(_('Invalid algorithm specified')) from ex
         except InvalidTokenError:
diff --git a/tests/test_backends.py b/tests/test_backends.py
index f9a9a23..bfb8da6 100644
--- a/tests/test_backends.py
+++ b/tests/test_backends.py
@@ -3,11 +3,13 @@
 
 import jwt
 from django.test import TestCase
-from jwt import PyJWT, algorithms
+from jwt import PyJWS, algorithms
 
 from rest_framework_simplejwt.backends import TokenBackend
 from rest_framework_simplejwt.exceptions import TokenBackendError
-from rest_framework_simplejwt.utils import aware_utcnow, make_utc
+from rest_framework_simplejwt.utils import (
+    aware_utcnow, datetime_to_epoch, make_utc,
+)
 
 SECRET = 'not_secret'
 
@@ -163,9 +165,9 @@ def test_decode_hmac_with_expiry(self):
 
     def test_decode_hmac_with_invalid_sig(self):
         self.payload['exp'] = aware_utcnow() + timedelta(days=1)
-        token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+        token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256')
         self.payload['foo'] = 'baz'
-        token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+        token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256')
 
         token_2_payload = token_2.rsplit('.', 1)[0]
         token_1_sig = token_1.rsplit('.', 1)[-1]
@@ -176,9 +178,11 @@ def test_decode_hmac_with_invalid_sig(self):
 
     def test_decode_hmac_with_invalid_sig_no_verify(self):
         self.payload['exp'] = aware_utcnow() + timedelta(days=1)
-        token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+        token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256')
         self.payload['foo'] = 'baz'
-        token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+        token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256')
+        # Payload copied
+        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
 
         token_2_payload = token_2.rsplit('.', 1)[0]
         token_1_sig = token_1.rsplit('.', 1)[-1]
@@ -193,7 +197,9 @@ def test_decode_hmac_success(self):
         self.payload['exp'] = aware_utcnow() + timedelta(days=1)
         self.payload['foo'] = 'baz'
 
-        token = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+        token = jwt.encode(self.payload, SECRET, algorithm='HS256')
+        # Payload copied
+        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
 
         self.assertEqual(self.hmac_token_backend.decode(token), self.payload)
 
@@ -220,9 +226,9 @@ def test_decode_rsa_with_expiry(self):
 
     def test_decode_rsa_with_invalid_sig(self):
         self.payload['exp'] = aware_utcnow() + timedelta(days=1)
-        token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
         self.payload['foo'] = 'baz'
-        token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
 
         token_2_payload = token_2.rsplit('.', 1)[0]
         token_1_sig = token_1.rsplit('.', 1)[-1]
@@ -233,13 +239,15 @@ def test_decode_rsa_with_invalid_sig(self):
 
     def test_decode_rsa_with_invalid_sig_no_verify(self):
         self.payload['exp'] = aware_utcnow() + timedelta(days=1)
-        token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
         self.payload['foo'] = 'baz'
-        token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
 
         token_2_payload = token_2.rsplit('.', 1)[0]
         token_1_sig = token_1.rsplit('.', 1)[-1]
         invalid_token = token_2_payload + '.' + token_1_sig
+        # Payload copied
+        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
 
         self.assertEqual(
             self.hmac_token_backend.decode(invalid_token, verify=False),
@@ -250,7 +258,9 @@ def test_decode_rsa_success(self):
         self.payload['exp'] = aware_utcnow() + timedelta(days=1)
         self.payload['foo'] = 'baz'
 
-        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
+        # Payload copied
+        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
 
         self.assertEqual(self.rsa_token_backend.decode(token), self.payload)
 
@@ -260,21 +270,23 @@ def test_decode_aud_iss_success(self):
         self.payload['aud'] = AUDIENCE
         self.payload['iss'] = ISSUER
 
-        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
+        # Payload copied
+        self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
 
         self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload)
 
     def test_decode_when_algorithm_not_available(self):
-        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
 
-        pyjwt_without_rsa = PyJWT()
+        pyjwt_without_rsa = PyJWS()
         pyjwt_without_rsa.unregister_algorithm('RS256')
         with patch.object(jwt, 'decode', new=pyjwt_without_rsa.decode):
             with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'):
                 self.rsa_token_backend.decode(token)
 
     def test_decode_when_token_algorithm_does_not_match(self):
-        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+        token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
 
         with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'):
             self.hmac_token_backend.decode(token)
openSUSE Build Service is sponsored by