File CVE-2025-4565.patch of Package protobuf.40656
From 91ce18e074dbe300b982834fefff1297ab9d408f Mon Sep 17 00:00:00 2001
From: Protobuf Team Bot <protobuf-github-bot@google.com>
Date: Tue, 13 May 2025 14:42:18 -0700
Subject: [PATCH] Add recursion depth limits to pure python
PiperOrigin-RevId: 758382549
---
python/google/protobuf/internal/decoder.py | 38 +++++
.../google/protobuf/internal/decoder_test.py | 156 ++++++++++++++++++
.../google/protobuf/internal/message_test.py | 76 +++++++--
.../protobuf/internal/self_recursive.proto | 24 +++
4 files changed, 282 insertions(+), 12 deletions(-)
create mode 100644 python/google/protobuf/internal/decoder_test.py
create mode 100644 python/google/protobuf/internal/self_recursive.proto
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 845d77427..8c28509cf 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -652,7 +652,13 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError(
+ 'Error parsing message: too many levels of nesting.'
+ )
pos = value.add()._InternalParse(buffer, pos, end)
+ current_depth -= 1
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -669,7 +675,11 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError('Error parsing message: too many levels of nesting.')
pos = value._InternalParse(buffer, pos, end)
+ current_depth -= 1
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -699,10 +709,16 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError(
+ 'Error parsing message: too many levels of nesting.'
+ )
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
# The only reason _InternalParse would return early is if it
# encountered an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
+ current_depth -= 1
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -720,10 +736,14 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError('Error parsing message: too many levels of nesting.')
if value._InternalParse(buffer, pos, new_pos) != new_pos:
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
+ current_depth -= 1
return new_pos
return DecodeField
@@ -902,6 +922,14 @@ def _SkipLengthDelimited(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+DEFAULT_RECURSION_LIMIT = 100  # Default cap on group/sub-message nesting depth while decoding.
+_recursion_limit = DEFAULT_RECURSION_LIMIT  # Active cap, compared with current_depth by the decoders. NOTE(review): current_depth is not defined in any hunk shown - confirm.
+
+
+def SetRecursionLimit(new_limit):  # Replaces the module-wide nesting cap checked by the decoders.
+ global _recursion_limit
+ _recursion_limit = new_limit  # Stored as-is: no validation, and it stays in effect until set again.
+
def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
@@ -945,7 +973,17 @@ def _DecodeUnknownField(buffer, pos, wire_type):
data = buffer[pos:pos+size]
pos += size
elif wire_type == wire_format.WIRETYPE_START_GROUP:
+ end_tag_bytes = encoder.TagBytes(
+ field_number, wire_format.WIRETYPE_END_GROUP
+ )
+ current_depth += 1
+ if current_depth >= _recursion_limit:
+ raise _DecodeError('Error parsing message: too many levels of nesting.')
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
+ current_depth -= 1
+ # Check end tag.
+ if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
+ raise _DecodeError('Missing group end tag.')
elif wire_type == wire_format.WIRETYPE_END_GROUP:
return (0, -1)
else:
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
new file mode 100644
index 000000000..d2bc81b46
--- /dev/null
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -0,0 +1,156 @@
+# -*- coding: utf-8 -*-
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Test decoder."""
+
+import io
+import unittest
+
+from google.protobuf import message
+from google.protobuf.internal import api_implementation
+from google.protobuf.internal import decoder
+from google.protobuf.internal import message_set_extensions_pb2
+from google.protobuf.internal import testing_refleaks
+from google.protobuf.internal import wire_format
+
+from absl.testing import parameterized
+
+
+_INPUT_BYTES = b'\x84r\x12'  # Two varints back to back: \x84\x72 -> 14596, then \x12 -> 18.
+_EXPECTED = (14596, 18)  # Decoded values of _INPUT_BYTES, in stream order.
+
+
+@testing_refleaks.TestCase
+class DecoderTest(parameterized.TestCase):  # Tests for low-level helpers of google.protobuf.internal.decoder.
+
+ def test_decode_varint_bytes(self):
+ (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0)  # first varint spans bytes 0-1
+ self.assertEqual(size, _EXPECTED[0])
+ self.assertEqual(pos, 2)
+
+ (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 2)  # second varint is the single byte at offset 2
+ self.assertEqual(size, _EXPECTED[1])
+ self.assertEqual(pos, 3)
+
+ def test_decode_varint_bytes_empty(self):
+ with self.assertRaises(IndexError) as context:
+ decoder._DecodeVarint(b'', 0)
+ self.assertIn('index out of range', str(context.exception))
+
+ def test_decode_varint_bytesio(self):
+ index = 0
+ input_io = io.BytesIO(_INPUT_BYTES)
+ while True:
+ size = decoder._DecodeVarint(input_io)  # stream form: the test expects a bare value, no position
+ if size is None:  # None is treated as end of stream
+ break
+ self.assertEqual(size, _EXPECTED[index])
+ index += 1
+ self.assertEqual(index, len(_EXPECTED))
+
+ def test_decode_varint_bytesio_empty(self):
+ input_io = io.BytesIO(b'')
+ size = decoder._DecodeVarint(input_io)
+ self.assertIsNone(size)
+
+ def test_decode_unknown_group_field(self):
+ data = memoryview(b'\013\020\003\014\040\005')  # octal tags: 013=field 1 START_GROUP, 020=field 2 varint (3), 014=field 1 END_GROUP, 040=field 4 varint (5)
+ parsed, pos = decoder._DecodeUnknownField(  # NOTE(review): the decoder.py hunk context shows _DecodeUnknownField(buffer, pos, wire_type) - confirm this 5-argument form exists in the target tree
+ data, 1, len(data), 1, wire_format.WIRETYPE_START_GROUP
+ )
+
+ self.assertEqual(pos, 4)  # stops right after the END_GROUP tag at offset 3
+ self.assertEqual(len(parsed), 1)
+ self.assertEqual(parsed[0].field_number, 2)
+ self.assertEqual(parsed[0].data, 3)
+
+ def test_decode_unknown_group_field_nested(self):
+ data = memoryview(b'\013\023\013\030\004\014\024\014\050\006')  # field 1 group > field 2 group > field 1 group > field 3 varint (4), closed by END_GROUP tags 014, 024, 014
+ parsed, pos = decoder._DecodeUnknownField(
+ data, 1, len(data), 1, wire_format.WIRETYPE_START_GROUP
+ )
+
+ self.assertEqual(pos, 8)  # just past the outermost END_GROUP tag at offset 7
+ self.assertEqual(len(parsed), 1)
+ self.assertEqual(parsed[0].field_number, 2)
+ self.assertEqual(len(parsed[0].data), 1)
+ self.assertEqual(parsed[0].data[0].field_number, 1)
+ self.assertEqual(len(parsed[0].data[0].data), 1)
+ self.assertEqual(parsed[0].data[0].data[0].field_number, 3)
+ self.assertEqual(parsed[0].data[0].data[0].data, 4)
+
+ def test_decode_unknown_group_field_too_many_levels(self):
+ data = memoryview(b'\023' * 5_000_000)  # 5,000,000 consecutive field 2 START_GROUP tags
+ self.assertRaisesRegex(
+ message.DecodeError,
+ 'Error parsing message',
+ decoder._DecodeUnknownField,
+ data,
+ 1,
+ len(data),
+ 1,
+ wire_format.WIRETYPE_START_GROUP,
+ )
+
+ def test_decode_unknown_mismatched_end_group(self):
+ self.assertRaisesRegex(
+ message.DecodeError,
+ 'Missing group end tag.*',
+ decoder._DecodeUnknownField,
+ memoryview(b'\013\024'),  # field 1 START_GROUP answered by a field 2 END_GROUP
+ 1,
+ 2,
+ 1,
+ wire_format.WIRETYPE_START_GROUP,
+ )
+
+ def test_decode_unknown_mismatched_end_group_nested(self):
+ self.assertRaisesRegex(
+ message.DecodeError,
+ 'Missing group end tag.*',
+ decoder._DecodeUnknownField,
+ memoryview(b'\013\023\034\024\014'),  # the inner field 2 group is answered by a field 3 END_GROUP (034)
+ 1,
+ 5,
+ 1,
+ wire_format.WIRETYPE_START_GROUP,
+ )
+
+ def test_decode_message_set_unknown_mismatched_end_group(self):
+ proto = message_set_extensions_pb2.TestMessageSet()
+ self.assertRaisesRegex(
+ message.DecodeError,
+ 'Unexpected end-group tag.'  # exact text is asserted for the pure-Python implementation only
+ if api_implementation.Type() == 'python'
+ else '.*',
+ proto.ParseFromString,
+ b'\013\054\014',  # field 1 START_GROUP, field 5 END_GROUP (054), field 1 END_GROUP
+ )
+
+ def test_unknown_message_set_decoder_mismatched_end_group(self):
+ # This behavior isn't actually reachable in practice, but it's good to
+ # test anyway.
+ decode = decoder.UnknownMessageSetItemDecoder()
+ self.assertRaisesRegex(
+ message.DecodeError,
+ 'Unexpected end-group tag.',
+ decode,
+ memoryview(b'\054\014'),  # field 5 END_GROUP followed by field 1 END_GROUP
+ )
+
+ @parameterized.parameters(int(0), float(0.0), False, '')
+ def test_default_scalar(self, value):
+ self.assertTrue(decoder.IsDefaultScalarValue(value))
+
+ @parameterized.parameters(int(1), float(-0.0), float(1.0), True, 'a')  # negative zero is listed as non-default
+ def test_not_default_scalar(self, value):
+ self.assertFalse(decoder.IsDefaultScalarValue(value))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 9de3e7a01..368474cc8 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -82,6 +82,7 @@ from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import packed_field_test_pb2
+from google.protobuf.internal import self_recursive_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf import message
@@ -1299,6 +1300,52 @@ class MessageTest(unittest.TestCase):
self.assertEqual(True, m.repeated_bool[0])
+@testing_refleaks.TestCase
+class TestRecursiveGroup(unittest.TestCase):
+
+ def _MakeRecursiveGroupMessage(self, n):
+ """Returns serialized SelfRecursive bytes nested `n` levels deep via sub_group."""
+ msg = self_recursive_pb2.SelfRecursive()
+ sub = msg
+ for _ in range(n):
+ sub = sub.sub_group
+ sub.i = 1
+ return msg.SerializeToString()
+
+ def testRecursiveGroups(self):
+ recurse_msg = self_recursive_pb2.SelfRecursive()
+ data = self._MakeRecursiveGroupMessage(100)  # 100 levels: expected to parse under the default limit
+ recurse_msg.ParseFromString(data)
+ self.assertTrue(recurse_msg.HasField('sub_group'))
+
+ def testRecursiveGroupsException(self):
+ if api_implementation.Type() != 'python':
+ api_implementation._c_module.SetAllowOversizeProtos(False)
+ recurse_msg = self_recursive_pb2.SelfRecursive()
+ data = self._MakeRecursiveGroupMessage(300)  # well past the default limit
+ with self.assertRaises(message.DecodeError) as context:
+ recurse_msg.ParseFromString(data)
+ self.assertIn('Error parsing message', str(context.exception))
+ if api_implementation.Type() == 'python':
+ self.assertIn('too many levels of nesting', str(context.exception))
+
+ def testRecursiveGroupsUnknownFields(self):
+ if api_implementation.Type() != 'python':
+ api_implementation._c_module.SetAllowOversizeProtos(False)
+ test_msg = unittest_pb2.TestAllTypes()
+ data = self._MakeRecursiveGroupMessage(300) # unknown to test_msg
+ with self.assertRaises(message.DecodeError) as context:
+ test_msg.ParseFromString(data)
+ self.assertIn('Error parsing message', str(context.exception))
+ if api_implementation.Type() == 'python':
+ self.assertIn('too many levels of nesting', str(context.exception))
+ # The reset is registered first so a failing parse cannot leak the raised limit.
+ # NOTE(review): confirm `decoder` is imported here; the import hunk adds only self_recursive_pb2.
+ self.addCleanup(decoder.SetRecursionLimit, decoder.DEFAULT_RECURSION_LIMIT)
+ decoder.SetRecursionLimit(310)
+ test_msg.ParseFromString(data)
+
+
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
class Proto2Test(unittest.TestCase):
@@ -1646,6 +1693,7 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
+
def testAssignUnknownEnum(self):
"""Assigning an unknown enum value is allowed and preserves the value."""
m = unittest_proto3_arena_pb2.TestAllTypes()
@@ -2576,20 +2624,24 @@ class OversizeProtosTest(unittest.TestCase):
self.p_serialized = self.p.SerializeToString()
def testAssertOversizeProto(self):
- from google.protobuf.pyext._message import SetAllowOversizeProtos
- SetAllowOversizeProtos(False)
- q = self.proto_cls()
- try:
- q.ParseFromString(self.p_serialized)
- except message.DecodeError as e:
- self.assertEqual(str(e), 'Error parsing message')
+ if api_implementation.Type() != 'python':  # only the C-backed module exposes the oversize switch
+ api_implementation._c_module.SetAllowOversizeProtos(False)
+ msg = unittest_pb2.TestRecursiveMessage()
+ with self.assertRaises(message.DecodeError) as context:  # unlike the removed try/except, a parse that succeeds now fails the test
+ msg.ParseFromString(self.GenerateNestedProto(101))  # NOTE(review): presumably 101 nesting levels, one past DEFAULT_RECURSION_LIMIT (100); GenerateNestedProto is defined outside this hunk
+ self.assertIn('Error parsing message', str(context.exception))
def testSucceedOversizeProto(self):
- from google.protobuf.pyext._message import SetAllowOversizeProtos
- SetAllowOversizeProtos(True)
- q = self.proto_cls()
- q.ParseFromString(self.p_serialized)
- self.assertEqual(self.p.field.payload, q.field.payload)
+ # The reset is registered up front so a failing parse cannot leak the raised limit.
+ self.addCleanup(decoder.SetRecursionLimit, decoder.DEFAULT_RECURSION_LIMIT)
+ if api_implementation.Type() == 'python':
+ decoder.SetRecursionLimit(310)
+ else:
+ api_implementation._c_module.SetAllowOversizeProtos(True)
+
+ msg = unittest_pb2.TestRecursiveMessage()
+ msg.ParseFromString(self.GenerateNestedProto(101))
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto
new file mode 100644
index 000000000..d2a7f004b
--- /dev/null
+++ b/python/google/protobuf/internal/self_recursive.proto
@@ -0,0 +1,24 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2024 Google Inc. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+edition = "2023";
+
+package google.protobuf.python.internal;
+
+message SelfRecursive {  // Refers to itself, so instances can nest to any depth.
+ SelfRecursive sub = 1;
+ int32 i = 2;
+ SelfRecursive sub_group = 3 [features.message_encoding = DELIMITED];  // DELIMITED: group-style start/end tags on the wire.
+}
+
+message IndirectRecursive {  // Recurses through IntermediateRecursive rather than directly.
+ IntermediateRecursive intermediate = 1;
+}
+
+message IntermediateRecursive {  // Other half of the cycle with IndirectRecursive.
+ IndirectRecursive indirect = 1;
+}
--
2.51.0