File CVE-2025-4565.patch of Package protobuf.39291
From d31100c9195819edb0a12f44705dfc2da111ea9b Mon Sep 17 00:00:00 2001
From: shaod2 <shaod@google.com>
Date: Wed, 21 May 2025 14:30:53 -0400
Subject: [PATCH] Manually backport recursion limit enforcement to 25.x
---
python/build_targets.bzl | 6 +
python/google/protobuf/internal/decoder.py | 110 ++++++++++++++----
.../google/protobuf/internal/decoder_test.py | 33 ++++++
.../google/protobuf/internal/message_test.py | 31 +++++
.../protobuf/internal/python_message.py | 6 +-
.../protobuf/internal/self_recursive.proto | 17 +++
6 files changed, 176 insertions(+), 27 deletions(-)
create mode 100644 python/google/protobuf/internal/decoder_test.py
create mode 100644 python/google/protobuf/internal/self_recursive.proto
diff --git a/python/build_targets.bzl b/python/build_targets.bzl
index fff32b4a5..78329b18f 100644
--- a/python/build_targets.bzl
+++ b/python/build_targets.bzl
@@ -335,6 +335,12 @@ def build_targets(name):
data = ["//src/google/protobuf:testdata"],
)
+ internal_py_test(
+ name = "decoder_test",
+ srcs = ["google/protobuf/internal/decoder_test.py"],
+ data = ["//src/google/protobuf:testdata"],
+ )
+
internal_py_test(
name = "proto_builder_test",
srcs = ["google/protobuf/internal/proto_builder_test.py"],
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index acb91aa02..2b07fdd43 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -172,7 +172,10 @@ def _SimpleDecoder(wire_type, decode_value):
clear_if_default=False):
if is_packed:
local_DecodeVarint = _DecodeVarint
- def DecodePackedField(buffer, pos, end, message, field_dict):
+ def DecodePackedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -191,7 +194,10 @@ def _SimpleDecoder(wire_type, decode_value):
elif is_repeated:
tag_bytes = encoder.TagBytes(field_number, wire_type)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -208,7 +214,8 @@ def _SimpleDecoder(wire_type, decode_value):
return new_pos
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
(new_value, pos) = decode_value(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
enum_type = key.enum_type
if is_packed:
local_DecodeVarint = _DecodeVarint
- def DecodePackedField(buffer, pos, end, message, field_dict):
+ def DecodePackedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
"""Decode serialized packed enum to its value and a new position.
Args:
@@ -365,6 +374,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
Returns:
int, new position in serialized data.
"""
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -405,7 +415,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
elif is_repeated:
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
"""Decode serialized repeated enum to its value and a new position.
Args:
@@ -418,6 +430,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
Returns:
int, new position in serialized data.
"""
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -446,7 +459,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
return new_pos
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
"""Decode serialized repeated enum to its value and a new position.
Args:
@@ -459,6 +472,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
Returns:
int, new position in serialized data.
"""
+ del current_depth # unused
value_start_pos = pos
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
@@ -540,7 +554,10 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -557,7 +574,8 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
return new_pos
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -598,7 +619,8 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
return new_pos
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_START_GROUP)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -632,7 +656,13 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
- pos = value.add()._InternalParse(buffer, pos, end)
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError(
+ 'Error parsing message: too many levels of nesting.'
+ )
+ pos = value.add()._InternalParse(buffer, pos, end, current_depth)
+ current_depth -= 1
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -644,12 +674,16 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
return new_pos
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
- pos = value._InternalParse(buffer, pos, end)
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError('Error parsing message: too many levels of nesting.')
+ pos = value._InternalParse(buffer, pos, end, current_depth)
+ current_depth -= 1
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -679,18 +715,27 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
- if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError(
+ 'Error parsing message: too many levels of nesting.'
+ )
+ if (
+ value.add()._InternalParse(buffer, pos, new_pos, current_depth)
+ != new_pos
+ ):
# The only reason _InternalParse would return early is if it
# encountered an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
# Predict that the next tag is another copy of the same repeated field.
+ current_depth -= 1
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -699,11 +744,14 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated message.')
- # Read sub-message.
- if value._InternalParse(buffer, pos, new_pos) != new_pos:
+ current_depth += 1
+ if current_depth > _recursion_limit:
+ raise _DecodeError('Error parsing message: too many levels of nesting.')
+ if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
+ current_depth -= 1
return new_pos
return DecodeField
@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
# Can't read _concrete_class yet; might not be initialized.
message_type = field_descriptor.message_type
- def DecodeMap(buffer, pos, end, message, field_dict):
+ def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
submsg = message_type._concrete_class()
value = field_dict.get(key)
if value is None:
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
return pos
pos = new_pos
+DEFAULT_RECURSION_LIMIT = 100
+_recursion_limit = DEFAULT_RECURSION_LIMIT
+
+
+def SetRecursionLimit(new_limit):
+ global _recursion_limit
+ _recursion_limit = new_limit
+
-def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
+def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
unknown_field_set = containers.UnknownFieldSet()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
field_number, wire_type = wire_format.UnpackTag(tag)
if wire_type == wire_format.WIRETYPE_END_GROUP:
break
- (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+ (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
# pylint: disable=protected-access
unknown_field_set._add(field_number, wire_type, data)
return (unknown_field_set, pos)
-def _DecodeUnknownField(buffer, pos, wire_type):
+def _DecodeUnknownField(buffer, pos, wire_type, current_depth=0):
"""Decode a unknown field. Returns the UnknownField and new position."""
if wire_type == wire_format.WIRETYPE_VARINT:
@@ -973,7 +1030,12 @@ def _DecodeUnknownField(buffer, pos, wire_type):
data = buffer[pos:pos+size].tobytes()
pos += size
elif wire_type == wire_format.WIRETYPE_START_GROUP:
- (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
+ # Each nested unknown group uses one nesting level; reject at the limit.
+ current_depth += 1
+ if current_depth >= _recursion_limit:
+ raise _DecodeError('Error parsing message: too many levels of nesting.')
+ (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
+ current_depth -= 1
elif wire_type == wire_format.WIRETYPE_END_GROUP:
return (0, -1)
else:
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
new file mode 100644
index 000000000..613b73bcb
--- /dev/null
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Test decoder."""
+
+import unittest
+
+from google.protobuf import message
+from google.protobuf.internal import decoder
+from google.protobuf.internal import testing_refleaks
+from google.protobuf.internal import wire_format
+
+@testing_refleaks.TestCase
+class DecoderTest(unittest.TestCase):
+ def test_decode_unknown_group_field_too_many_levels(self):
+ data = memoryview(b'\023' * 5_000_000)
+ self.assertRaisesRegex(
+ message.DecodeError,
+ 'Error parsing message',
+ decoder._DecodeUnknownField,
+ data,
+ 1,
+ wire_format.WIRETYPE_START_GROUP,
+ 1
+ )
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 3af128592..637db3307 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -30,9 +30,11 @@ cmp = lambda x, y: (x > y) - (x < y)
from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
from google.protobuf.internal import encoder
+from google.protobuf.internal import decoder
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import packed_field_test_pb2
+from google.protobuf.internal import self_recursive_pb2
from google.protobuf.internal import test_proto3_optional_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
@@ -1261,6 +1263,35 @@ class MessageTest(unittest.TestCase):
self.assertNotEqual(ComparesWithFoo(), m)
+@testing_refleaks.TestCase
+class TestRecursiveGroup(unittest.TestCase):
+
+ def _MakeRecursiveGroupMessage(self, n):
+ msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
+ sub = msg
+ for _ in range(n):
+ sub = sub.sub_group
+ sub.i = 1
+ return msg.SerializeToString()
+
+ def testRecursiveGroups(self):
+ recurse_msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
+ data = self._MakeRecursiveGroupMessage(100)
+ recurse_msg.ParseFromString(data)
+ self.assertTrue(recurse_msg.HasField('sub_group'))
+
+ def testRecursiveGroupsException(self):
+ if api_implementation.Type() != 'python':
+ api_implementation._c_module.SetAllowOversizeProtos(False)
+ recurse_msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
+ data = self._MakeRecursiveGroupMessage(300)
+ with self.assertRaises(message.DecodeError) as context:
+ recurse_msg.ParseFromString(data)
+ self.assertIn('Error parsing message', str(context.exception))
+ if api_implementation.Type() == 'python':
+ self.assertIn('too many levels of nesting', str(context.exception))
+
+
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
class Proto2Test(unittest.TestCase):
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 40c776453..8f20e9abb 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -1136,7 +1136,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
fields_by_tag = cls._fields_by_tag
message_set_decoders_by_tag = cls._message_set_decoders_by_tag
- def InternalParse(self, buffer, pos, end):
+ def InternalParse(self, buffer, pos, end, current_depth=0):
"""Create a message from serialized bytes.
Args:
@@ -1180,7 +1180,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
# TODO: remove old_pos.
old_pos = new_pos
(data, new_pos) = decoder._DecodeUnknownField(
- buffer, new_pos, wire_type) # pylint: disable=protected-access
+ buffer, new_pos, wire_type, current_depth) # pylint: disable=protected-access
if new_pos == -1:
return pos
# pylint: disable=protected-access
@@ -1195,7 +1195,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
else:
_MaybeAddDecoder(cls, field_des)
field_decoder = field_des._decoders[is_packed]
- pos = field_decoder(buffer, new_pos, end, self, field_dict)
+ pos = field_decoder(buffer, new_pos, end, self, field_dict, current_depth)
if field_des.containing_oneof:
self._UpdateOneofState(field_des)
return pos
diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto
new file mode 100644
index 000000000..2a7aacb0b
--- /dev/null
+++ b/python/google/protobuf/internal/self_recursive.proto
@@ -0,0 +1,17 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2024 Google Inc. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+syntax = "proto2";
+
+package google.protobuf.python.internal;
+
+message SelfRecursive {
+ optional group RecursiveGroup = 1 {
+ optional RecursiveGroup sub_group = 2;
+ optional int32 i = 3;
+ };
+}
\ No newline at end of file
--
2.49.0