File CVE-2024-35255.patch of Package python-azure-identity.35708
From 82c32f6d48178406ba6747d3e0f414dbc33532d5 Mon Sep 17 00:00:00 2001
From: Paul Van Eck <paulvaneck@microsoft.com>
Date: Mon, 10 Jun 2024 14:45:20 -0700
Subject: [PATCH] [Identity] Managed identity bug fix (#36010)
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
---
.../azure/identity/_credentials/azure_arc.py | 56 +++++++-
.../tests/test_managed_identity.py | 132 ++++++++++++++++--
.../tests/test_managed_identity_async.py | 123 +++++++++++++++-
3 files changed, 296 insertions(+), 15 deletions(-)
diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
index 68034300b8..859f625e15 100644
--- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
+++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
@@ -4,6 +4,7 @@
# ------------------------------------
import functools
import os
+import sys
from typing import Any, Dict, Optional
from azure.core.exceptions import ClientAuthenticationError
@@ -24,7 +25,7 @@ class AzureArcCredential(ManagedIdentityBase):
return ManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
- **kwargs
+ **kwargs,
)
return None
@@ -70,6 +71,12 @@ def _get_secret_key(response: PipelineResponse) -> str:
raise ClientAuthenticationError(
message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
) from ex
+
+ try:
+ _validate_key_file(key_file)
+ except ValueError as ex:
+ raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex
+
with open(key_file, "r", encoding="utf-8") as file:
try:
return file.read()
@@ -80,6 +87,53 @@ def _get_secret_key(response: PipelineResponse) -> str:
) from error
+def _get_key_file_path() -> str:
+ """Returns the expected path for the Azure Arc MSI key file based on the current platform.
+
+ Only Linux and Windows are supported.
+
+ :return: The expected path.
+ :rtype: str
+ :raises ValueError: If the current platform is not supported.
+ """
+ if sys.platform.startswith("linux"):
+ return "/var/opt/azcmagent/tokens"
+ if sys.platform.startswith("win"):
+ program_data_path = os.environ.get("PROGRAMDATA")
+ if not program_data_path:
+ raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
+ return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
+ raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")
+
+
+def _validate_key_file(file_path: str) -> None:
+ """Validates that a given Azure Arc MSI file path is valid for use.
+
+ A valid file will:
+ 1. Be in the expected path for the current platform.
+ 2. Have a `.key` extension.
+ 3. Be at most 4096 bytes in size.
+
+ :param str file_path: The path to the key file.
+ :raises ClientAuthenticationError: If the file path is invalid.
+ """
+ if not file_path:
+ raise ValueError("The file path must not be empty.")
+
+ if not os.path.exists(file_path):
+ raise ValueError(f"The file path does not exist: {file_path}")
+
+ expected_directory = _get_key_file_path()
+ if not os.path.dirname(file_path) == expected_directory:
+ raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")
+
+ if not file_path.endswith(".key"):
+ raise ValueError("The file path must have a '.key' extension.")
+
+ if os.path.getsize(file_path) > 4096:
+ raise ValueError("The file size must be less than or equal to 4096 bytes.")
+
+
class ArcChallengeAuthPolicy(HTTPPolicy):
"""Policy for handling Azure Arc's challenge authentication"""
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py
index 805e36343b..b0be4b91ee 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import os
+import sys
import time
try:
@@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = ManagedIdentityCredential(transport=transport).get_token(scope)
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert token.token == access_token
+ assert token.expires_on == expires_on
def test_azure_arc_tenant_id(tmpdir):
@@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+ assert token.token == access_token
+ assert token.expires_on == expires_on
def test_azure_arc_client_id():
@@ -950,10 +953,123 @@ def test_azure_arc_client_id():
EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
},
):
- credential = ManagedIdentityCredential(client_id="some-guid")
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ credential = ManagedIdentityCredential(client_id="some-guid")
- with pytest.raises(ClientAuthenticationError):
+ with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
+ assert "not supported" in str(ex.value)
+
+
+def test_azure_arc_key_too_large(tmp_path):
+
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ secret_key = "X" * 4097
+
+ key_file = tmp_path / "key_file.key"
+ key_file.write_text(secret_key)
+ assert key_file.read_text() == secret_key
+
+ transport = validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "file size" in str(ex.value)
+
+
+def test_azure_arc_key_not_exist(tmp_path):
+
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+
+ transport = validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "not exist" in str(ex.value)
+
+
+def test_azure_arc_key_invalid(tmp_path):
+
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ key_file = tmp_path / "key_file.txt"
+ key_file.write_text("secret")
+
+ transport = validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "Unexpected file path" in str(ex.value)
+
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "extension" in str(ex.value)
def test_token_exchange(tmpdir):
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
index f9c1981158..79d2548487 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
@@ -848,9 +848,10 @@ async def test_azure_arc(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = await ManagedIdentityCredential(transport=transport).get_token(scope)
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert token.token == access_token
+ assert token.expires_on == expires_on
@pytest.mark.asyncio
@@ -901,9 +902,10 @@ async def test_azure_arc_tenant_id(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+ assert token.token == access_token
+ assert token.expires_on == expires_on
@pytest.mark.asyncio
@@ -922,6 +924,115 @@ async def test_azure_arc_client_id():
await credential.get_token("scope")
+@pytest.mark.asyncio
+async def test_azure_arc_key_too_large(tmp_path):
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ secret_key = "X" * 4097
+
+ key_file = tmp_path / "key_file.key"
+ key_file.write_text(secret_key)
+ assert key_file.read_text() == secret_key
+
+ transport = async_validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "file size" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_not_exist(tmp_path):
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+
+ transport = async_validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=key_file"}),
+ ],
+ )
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "not exist" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_invalid(tmp_path):
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ key_file = tmp_path / "key_file.txt"
+ key_file.write_text("secret")
+
+ transport = async_validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "Unexpected file path" in str(ex.value)
+
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "extension" in str(ex.value)
+
+
@pytest.mark.asyncio
async def test_token_exchange(tmpdir):
exchange_token = "exchange-token"
--
2.46.0