File CVE-2024-35255.patch of Package python-azure-identity.37645
From 9ce8f49284f279f052418e4af5adfe6412159f16 Mon Sep 17 00:00:00 2001
From: Paul Van Eck <paulvaneck@microsoft.com>
Date: Fri, 24 May 2024 01:48:40 +0000
Subject: [PATCH] [Identity] Add Azure Arc key validation checks
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
---
.../azure/identity/_credentials/azure_arc.py | 66 ++++++++-
.../tests/test_managed_identity.py | 132 ++++++++++++++++--
.../tests/test_managed_identity_async.py | 123 +++++++++++++++-
3 files changed, 302 insertions(+), 19 deletions(-)
diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
index ff38ddbc7a..d2c8b4aecb 100644
--- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
+++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py
@@ -4,6 +4,7 @@
# ------------------------------------
import functools
import os
+import sys
from typing import TYPE_CHECKING
from azure.core.exceptions import ClientAuthenticationError
@@ -29,7 +30,7 @@ class AzureArcCredential(ManagedIdentityBase):
return ManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
- **kwargs
+ **kwargs,
)
return None
@@ -72,16 +73,71 @@ def _get_secret_key(response):
# expecting header with structure like 'Basic realm=<file path>'
try:
key_file = header.split("=")[1]
- except IndexError:
+ except IndexError as ex:
raise ClientAuthenticationError(
message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
- )
- with open(key_file, "r") as file:
+ ) from ex
+
+ try:
+ _validate_key_file(key_file)
+ except ValueError as ex:
+ raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex
+
+ with open(key_file, "r", encoding="utf-8") as file:
try:
return file.read()
except Exception as error: # pylint:disable=broad-except
# user is expected to have obtained read permission prior to this being called
- raise ClientAuthenticationError(message="Could not read file {} contents: {}".format(key_file, error))
+ raise ClientAuthenticationError(
+ message="Could not read file {} contents: {}".format(key_file, error)
+ ) from error
+
+
+def _get_key_file_path() -> str:
+ """Returns the expected path for the Azure Arc MSI key file based on the current platform.
+
+ Only Linux and Windows are supported.
+
+ :return: The expected path.
+ :rtype: str
+ :raises ValueError: If the current platform is not supported.
+ """
+ if sys.platform.startswith("linux"):
+ return "/var/opt/azcmagent/tokens"
+ if sys.platform.startswith("win"):
+ program_data_path = os.environ.get("PROGRAMDATA")
+ if not program_data_path:
+ raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
+ return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
+ raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")
+
+
+def _validate_key_file(file_path: str) -> None:
+ """Validates that a given Azure Arc MSI file path is valid for use.
+
+ A valid file will:
+ 1. Be in the expected path for the current platform.
+ 2. Have a `.key` extension.
+ 3. Be at most 4096 bytes in size.
+
+ :param str file_path: The path to the key file.
+ :raises ClientAuthenticationError: If the file path is invalid.
+ """
+ if not file_path:
+ raise ValueError("The file path must not be empty.")
+
+ if not os.path.exists(file_path):
+ raise ValueError(f"The file path does not exist: {file_path}")
+
+ expected_directory = _get_key_file_path()
+ if not os.path.dirname(file_path) == expected_directory:
+ raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")
+
+ if not file_path.endswith(".key"):
+ raise ValueError("The file path must have a '.key' extension.")
+
+ if os.path.getsize(file_path) > 4096:
+ raise ValueError("The file size must be less than or equal to 4096 bytes.")
class ArcChallengeAuthPolicy(HTTPPolicy):
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py
index 64afc210a2..b3108b58a5 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import os
+import sys
import time
try:
@@ -893,9 +894,10 @@ def test_azure_arc(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = ManagedIdentityCredential(transport=transport).get_token(scope)
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert token.token == access_token
+ assert token.expires_on == expires_on
def test_azure_arc_tenant_id(tmpdir):
@@ -946,9 +948,10 @@ def test_azure_arc_tenant_id(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+ assert token.token == access_token
+ assert token.expires_on == expires_on
def test_azure_arc_client_id():
@@ -960,10 +963,123 @@ def test_azure_arc_client_id():
EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
},
):
- credential = ManagedIdentityCredential(client_id="some-guid")
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ credential = ManagedIdentityCredential(client_id="some-guid")
- with pytest.raises(ClientAuthenticationError):
+ with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
+ assert "not supported" in str(ex.value)
+
+
+def test_azure_arc_key_too_large(tmp_path):
+
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ secret_key = "X" * 4097
+
+ key_file = tmp_path / "key_file.key"
+ key_file.write_text(secret_key)
+ assert key_file.read_text() == secret_key
+
+ transport = validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "file size" in str(ex.value)
+
+
+def test_azure_arc_key_not_exist(tmp_path):
+
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+
+ transport = validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "not exist" in str(ex.value)
+
+
+def test_azure_arc_key_invalid(tmp_path):
+
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ key_file = tmp_path / "key_file.txt"
+ key_file.write_text("secret")
+
+ transport = validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "Unexpected file path" in str(ex.value)
+
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "extension" in str(ex.value)
def test_token_exchange(tmpdir):
diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
index 577e5a9b90..0dfdfab167 100644
--- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py
+++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py
@@ -857,9 +857,10 @@ async def test_azure_arc(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = await ManagedIdentityCredential(transport=transport).get_token(scope)
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert token.token == access_token
+ assert token.expires_on == expires_on
@pytest.mark.asyncio
@@ -910,9 +911,10 @@ async def test_azure_arc_tenant_id(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
- token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
- assert token.token == access_token
- assert token.expires_on == expires_on
+ with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
+ token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
+ assert token.token == access_token
+ assert token.expires_on == expires_on
@pytest.mark.asyncio
@@ -931,6 +933,115 @@ async def test_azure_arc_client_id():
await credential.get_token("scope")
+@pytest.mark.asyncio
+async def test_azure_arc_key_too_large(tmp_path):
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ secret_key = "X" * 4097
+
+ key_file = tmp_path / "key_file.key"
+ key_file.write_text(secret_key)
+ assert key_file.read_text() == secret_key
+
+ transport = async_validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "file size" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_not_exist(tmp_path):
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+
+ transport = async_validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=key_file"}),
+ ],
+ )
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "not exist" in str(ex.value)
+
+
+@pytest.mark.asyncio
+async def test_azure_arc_key_invalid(tmp_path):
+ api_version = "2019-11-01"
+ identity_endpoint = "http://localhost:42/token"
+ imds_endpoint = "http://localhost:42"
+ scope = "scope"
+ key_file = tmp_path / "key_file.txt"
+ key_file.write_text("secret")
+
+ transport = async_validating_transport(
+ requests=[
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ Request(
+ base_url=identity_endpoint,
+ method="GET",
+ required_headers={"Metadata": "true"},
+ required_params={"api-version": api_version, "resource": scope},
+ ),
+ ],
+ responses=[
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
+ ],
+ )
+
+ with mock.patch(
+ "os.environ",
+ {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
+ ):
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "Unexpected file path" in str(ex.value)
+
+ with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
+ with pytest.raises(ClientAuthenticationError) as ex:
+ await ManagedIdentityCredential(transport=transport).get_token(scope)
+ assert "extension" in str(ex.value)
+
+
@pytest.mark.asyncio
async def test_token_exchange(tmpdir):
exchange_token = "exchange-token"
--
2.48.1