File CVE-2024-35255.patch of Package python-azure-identity.35708

From 82c32f6d48178406ba6747d3e0f414dbc33532d5 Mon Sep 17 00:00:00 2001
From: Paul Van Eck <paulvaneck@microsoft.com>
Date: Mon, 10 Jun 2024 14:45:20 -0700
Subject: [PATCH] [Identity] Managed identity bug fix (#36010)

Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
---
 .../azure/identity/_credentials/azure_arc.py  |  56 +++++++-
 .../tests/test_managed_identity.py            | 132 ++++++++++++++++--
 .../tests/test_managed_identity_async.py      | 123 +++++++++++++++-
 3 files changed, 296 insertions(+), 15 deletions(-)

diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
index 68034300b8..859f625e15 100644
--- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
+++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
@@ -4,6 +4,7 @@
 # ------------------------------------
 import functools
 import os
+import sys
 from typing import Any, Dict, Optional
 
 from azure.core.exceptions import ClientAuthenticationError
@@ -24,7 +25,7 @@ class AzureArcCredential(ManagedIdentityBase):
             return ManagedIdentityClient(
                 _per_retry_policies=[ArcChallengeAuthPolicy()],
                 request_factory=functools.partial(_get_request, url),
-                **kwargs
+                **kwargs,
             )
         return None
 
@@ -70,6 +71,12 @@ def _get_secret_key(response: PipelineResponse) -> str:
         raise ClientAuthenticationError(
             message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
         ) from ex
+
+    try:
+        _validate_key_file(key_file)
+    except ValueError as ex:
+        raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex
+
     with open(key_file, "r", encoding="utf-8") as file:
         try:
             return file.read()
@@ -80,6 +87,53 @@ def _get_secret_key(response: PipelineResponse) -> str:
             ) from error
 
 
+def _get_key_file_path() -> str:
+    """Returns the expected path for the Azure Arc MSI key file based on the current platform.
+
+    Only Linux and Windows are supported.
+
+    :return: The expected path.
+    :rtype: str
+    :raises ValueError: If the current platform is not supported.
+    """
+    if sys.platform.startswith("linux"):
+        return "/var/opt/azcmagent/tokens"
+    if sys.platform.startswith("win"):
+        program_data_path = os.environ.get("PROGRAMDATA")
+        if not program_data_path:
+            raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
+        return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
+    raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")
+
+
+def _validate_key_file(file_path: str) -> None:
+    """Validates that a given Azure Arc MSI file path is valid for use.
+
+    A valid file will:
+        1. Be in the expected path for the current platform.
+        2. Have a `.key` extension.
+        3. Be at most 4096 bytes in size.
+
+    :param str file_path: The path to the key file.
+    :raises ClientAuthenticationError: If the file path is invalid.
+    """
+    if not file_path:
+        raise ValueError("The file path must not be empty.")
+
+    if not os.path.exists(file_path):
+        raise ValueError(f"The file path does not exist: {file_path}")
+
+    expected_directory = _get_key_file_path()
+    if not os.path.dirname(file_path) == expected_directory:
+        raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")
+
+    if not file_path.endswith(".key"):
+        raise ValueError("The file path must have a '.key' extension.")
+
+    if os.path.getsize(file_path) > 4096:
+        raise ValueError("The file size must be less than or equal to 4096 bytes.")
+
+
 class ArcChallengeAuthPolicy(HTTPPolicy):
     """Policy for handling Azure Arc's challenge authentication"""
 
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py
index 805e36343b..b0be4b91ee 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # ------------------------------------
 import os
+import sys
 import time
 
 try:
@@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = ManagedIdentityCredential(transport=transport).get_token(scope)
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 def test_azure_arc_tenant_id(tmpdir):
@@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 def test_azure_arc_client_id():
@@ -950,10 +953,123 @@ def test_azure_arc_client_id():
             EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
         },
     ):
-        credential = ManagedIdentityCredential(client_id="some-guid")
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            credential = ManagedIdentityCredential(client_id="some-guid")
 
-    with pytest.raises(ClientAuthenticationError):
+    with pytest.raises(ClientAuthenticationError) as ex:
         credential.get_token("scope")
+    assert "not supported" in str(ex.value)
+
+
+def test_azure_arc_key_too_large(tmp_path):
+
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    secret_key = "X" * 4097
+
+    key_file = tmp_path / "key_file.key"
+    key_file.write_text(secret_key)
+    assert key_file.read_text() == secret_key
+
+    transport = validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "file size" in str(ex.value)
+
+
+def test_azure_arc_key_not_exist(tmp_path):
+
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+
+    transport = validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with pytest.raises(ClientAuthenticationError) as ex:
+            ManagedIdentityCredential(transport=transport).get_token(scope)
+        assert "not exist" in str(ex.value)
+
+
+def test_azure_arc_key_invalid(tmp_path):
+
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    key_file = tmp_path / "key_file.txt"
+    key_file.write_text("secret")
+
+    transport = validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "Unexpected file path" in str(ex.value)
+
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "extension" in str(ex.value)
 
 
 def test_token_exchange(tmpdir):
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
index f9c1981158..79d2548487 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
@@ -848,9 +848,10 @@ async def test_azure_arc(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = await ManagedIdentityCredential(transport=transport).get_token(scope)
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 @pytest.mark.asyncio
@@ -901,9 +902,10 @@ async def test_azure_arc_tenant_id(tmpdir):
         "os.environ",
         {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
     ):
-        token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
-        assert token.token == access_token
-        assert token.expires_on == expires_on
+        with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+            token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+            assert token.token == access_token
+            assert token.expires_on == expires_on
 
 
 @pytest.mark.asyncio
@@ -922,6 +924,115 @@ async def test_azure_arc_client_id():
         await credential.get_token("scope")
 
 
+@pytest.mark.asyncio
+async def test_azure_arc_key_too_large(tmp_path):
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    secret_key = "X" * 4097
+
+    key_file = tmp_path / "key_file.key"
+    key_file.write_text(secret_key)
+    assert key_file.read_text() == secret_key
+
+    transport = async_validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "file size" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_not_exist(tmp_path):
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+
+    transport = async_validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=key_file"}),
+        ],
+    )
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with pytest.raises(ClientAuthenticationError) as ex:
+            await ManagedIdentityCredential(transport=transport).get_token(scope)
+        assert "not exist" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_invalid(tmp_path):
+    api_version = "2019-11-01"
+    identity_endpoint = "http://localhost:42/token"
+    imds_endpoint = "http://localhost:42"
+    scope = "scope"
+    key_file = tmp_path / "key_file.txt"
+    key_file.write_text("secret")
+
+    transport = async_validating_transport(
+        requests=[
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+            Request(
+                base_url=identity_endpoint,
+                method="GET",
+                required_headers={"Metadata": "true"},
+                required_params={"api-version": api_version, "resource": scope},
+            ),
+        ],
+        responses=[
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+            mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+        ],
+    )
+
+    with mock.patch(
+        "os.environ",
+        {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+    ):
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "Unexpected file path" in str(ex.value)
+
+        with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+            with pytest.raises(ClientAuthenticationError) as ex:
+                await ManagedIdentityCredential(transport=transport).get_token(scope)
+            assert "extension" in str(ex.value)
+
+
 @pytest.mark.asyncio
 async def test_token_exchange(tmpdir):
     exchange_token = "exchange-token"
-- 
2.46.0

openSUSE Build Service is sponsored by