File CVE-2024-35255.patch of Package python-azure-identity.37645

From 9ce8f49284f279f052418e4af5adfe6412159f16 Mon Sep 17 00:00:00 2001
From: Paul Van Eck <paulvaneck@microsoft.com>
Date: Fri, 24 May 2024 01:48:40 +0000
Subject: [PATCH] [Identity] Add Azure Arc key validation checks

Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
---
 .../azure/identity/_credentials/azure_arc.py  |  66 ++++++++-
 .../tests/test_managed_identity.py            | 132 ++++++++++++++++--
 .../tests/test_managed_identity_async.py      | 123 +++++++++++++++-
 3 files changed, 302 insertions(+), 19 deletions(-)

diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
index ff38ddbc7a..d2c8b4aecb 100644
--- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
+++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
@@ -4,6 +4,7 @@
 # ------------------------------------
 import functools
 import os
+import sys
 from typing import TYPE_CHECKING
 
 from azure.core.exceptions import ClientAuthenticationError
@@ -29,7 +30,7 @@ class AzureArcCredential(ManagedIdentityBase):
             return ManagedIdentityClient(
                 _per_retry_policies=[ArcChallengeAuthPolicy()],
                 request_factory=functools.partial(_get_request, url),
-                **kwargs
+                **kwargs,
             )
         return None
 
@@ -72,16 +73,71 @@ def _get_secret_key(response):
     # expecting header with structure like 'Basic realm=<file path>'
     try:
         key_file = header.split("=")[1]
-    except IndexError:
+    except IndexError as ex:
         raise ClientAuthenticationError(
             message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
-        )
-    with open(key_file, "r") as file:
+        ) from ex
+
+    try:
+        _validate_key_file(key_file)
+    except ValueError as ex:
+        raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex
+
+    with open(key_file, "r", encoding="utf-8") as file:
         try:
             return file.read()
         except Exception as error:  # pylint:disable=broad-except
             # user is expected to have obtained read permission prior to this being called
-            raise ClientAuthenticationError(message="Could not read file {} contents: {}".format(key_file, error))
+            raise ClientAuthenticationError(
+                message="Could not read file {} contents: {}".format(key_file, error)
+            ) from error
+
+
+def _get_key_file_path() -> str:
+    """Returns the expected path for the Azure Arc MSI key file based on the current platform.
+
+    Only Linux and Windows are supported.
+
+    :return: The expected path.
+    :rtype: str
+    :raises ValueError: If the current platform is not supported.
+    """
+    if sys.platform.startswith("linux"):
+        return "/var/opt/azcmagent/tokens"
+    if sys.platform.startswith("win"):
+        program_data_path = os.environ.get("PROGRAMDATA")
+        if not program_data_path:
+            raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
+        return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
+    raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")
+
+
+def _validate_key_file(file_path: str) -> None:
+    """Validates that a given Azure Arc MSI file path is valid for use.
+
+    A valid file will:
+        1. Be in the expected path for the current platform.
+        2. Have a `.key` extension.
+        3. Be at most 4096 bytes in size.
+
+    :param str file_path: The path to the key file.
+    :raises ClientAuthenticationError: If the file path is invalid.
+    """
+    if not file_path:
+        raise ValueError("The file path must not be empty.")
+
+    if not os.path.exists(file_path):
+        raise ValueError(f"The file path does not exist: {file_path}")
+
+    expected_directory = _get_key_file_path()
+    if not os.path.dirname(file_path) == expected_directory:
+        raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")
+
+    if not file_path.endswith(".key"):
+        raise ValueError("The file path must have a '.key' extension.")
+
+    if os.path.getsize(file_path) > 4096:
+        raise ValueError("The file size must be less than or equal to 4096 bytes.")
 
 
 class ArcChallengeAuthPolicy(HTTPPolicy):
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py
index 64afc210a2..b3108b58a5 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # ------------------------------------
 import os
+import sys
 import time
 
 try:
@@ -893,9 +894,10 @@ def test_azure_arc(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = ManagedIdentityCredential(transport=transport).get_token(scope)
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 def test_azure_arc_tenant_id(tmpdir):
@@ -946,9 +948,10 @@ def test_azure_arc_tenant_id(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 def test_azure_arc_client_id():
@@ -960,10 +963,123 @@ def test_azure_arc_client_id():
             EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
         },
     ):
-        credential = ManagedIdentityCredential(client_id="some-guid")
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            credential = ManagedIdentityCredential(client_id="some-guid")
 
-    with pytest.raises(ClientAuthenticationError):
+    with pytest.raises(ClientAuthenticationError) as ex:
         credential.get_token("scope")
+    assert "not supported" in str(ex.value)
+
+
+def test_azure_arc_key_too_large(tmp_path):
+
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    secret_key = "X" * 4097
+
+    key_file = tmp_path / "key_file.key"
+    key_file.write_text(secret_key)
+    assert key_file.read_text() == secret_key
+
+    transport = validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "file size" in str(ex.value)
+
+
+def test_azure_arc_key_not_exist(tmp_path):
+
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+
+    transport = validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with pytest.raises(ClientAuthenticationError) as ex:
+            ManagedIdentityCredential(transport=transport).get_token(scope)
+        assert "not exist" in str(ex.value)
+
+
+def test_azure_arc_key_invalid(tmp_path):
+
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    key_file = tmp_path / "key_file.txt"
+    key_file.write_text("secret")
+
+    transport = validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "Unexpected file path" in str(ex.value)
+
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "extension" in str(ex.value)
 
 
 def test_token_exchange(tmpdir):
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
index 577e5a9b90..0dfdfab167 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
@@ -857,9 +857,10 @@ async def test_azure_arc(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = await ManagedIdentityCredential(transport=transport).get_token(scope)
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 @pytest.mark.asyncio
@@ -910,9 +911,10 @@ async def test_azure_arc_tenant_id(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 @pytest.mark.asyncio
@@ -931,6 +933,115 @@ async def test_azure_arc_client_id():
         await credential.get_token("scope")
 
 
+@pytest.mark.asyncio
+async def test_azure_arc_key_too_large(tmp_path):
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    secret_key = "X" * 4097
+
+    key_file = tmp_path / "key_file.key"
+    key_file.write_text(secret_key)
+    assert key_file.read_text() == secret_key
+
+    transport = async_validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "file size" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_not_exist(tmp_path):
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+
+    transport = async_validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=key_file"}),
+        ],
+    )
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with pytest.raises(ClientAuthenticationError) as ex:
+            await ManagedIdentityCredential(transport=transport).get_token(scope)
+        assert "not exist" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_invalid(tmp_path):
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    key_file = tmp_path / "key_file.txt"
+    key_file.write_text("secret")
+
+    transport = async_validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "Unexpected file path" in str(ex.value)
+
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "extension" in str(ex.value)
+
+
 @pytest.mark.asyncio
 async def test_token_exchange(tmpdir):
     exchange_token = "exchange-token"
-- 
2.48.1

openSUSE Build Service is sponsored by