File CVE-2025-4565.patch of Package protobuf.39291

From d31100c9195819edb0a12f44705dfc2da111ea9b Mon Sep 17 00:00:00 2001
From: shaod2 <shaod@google.com>
Date: Wed, 21 May 2025 14:30:53 -0400
Subject: [PATCH] Manually backport recursion limit enforcement to 25.x

---
 python/build_targets.bzl                      |   6 +
 python/google/protobuf/internal/decoder.py    | 110 ++++++++++++++----
 .../google/protobuf/internal/decoder_test.py  |  33 ++++++
 .../google/protobuf/internal/message_test.py  |  31 +++++
 .../protobuf/internal/python_message.py       |   6 +-
 .../protobuf/internal/self_recursive.proto    |  17 +++
 6 files changed, 176 insertions(+), 27 deletions(-)
 create mode 100644 python/google/protobuf/internal/decoder_test.py
 create mode 100644 python/google/protobuf/internal/self_recursive.proto

diff --git a/python/build_targets.bzl b/python/build_targets.bzl
index fff32b4a5..78329b18f 100644
--- a/python/build_targets.bzl
+++ b/python/build_targets.bzl
@@ -335,6 +335,12 @@ def build_targets(name):
         data = ["//src/google/protobuf:testdata"],
     )
 
+    internal_py_test(
+        name = "decoder_test",
+        srcs = ["google/protobuf/internal/decoder_test.py"],
+        data = ["//src/google/protobuf:testdata"],
+    )
+
     internal_py_test(
         name = "proto_builder_test",
         srcs = ["google/protobuf/internal/proto_builder_test.py"],
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index acb91aa02..2b07fdd43 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -172,7 +172,10 @@ def _SimpleDecoder(wire_type, decode_value):
                       clear_if_default=False):
     if is_packed:
       local_DecodeVarint = _DecodeVarint
-      def DecodePackedField(buffer, pos, end, message, field_dict):
+      def DecodePackedField(
+          buffer, pos, end, message, field_dict, current_depth=0
+      ):
+        del current_depth # unused
         value = field_dict.get(key)
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
@@ -191,7 +194,10 @@ def _SimpleDecoder(wire_type, decode_value):
     elif is_repeated:
       tag_bytes = encoder.TagBytes(field_number, wire_type)
       tag_len = len(tag_bytes)
-      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      def DecodeRepeatedField(
+          buffer, pos, end, message, field_dict, current_depth=0
+      ):
+        del current_depth # unused
         value = field_dict.get(key)
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
@@ -208,7 +214,8 @@ def _SimpleDecoder(wire_type, decode_value):
             return new_pos
       return DecodeRepeatedField
     else:
-      def DecodeField(buffer, pos, end, message, field_dict):
+      def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+        del current_depth # unused
         (new_value, pos) = decode_value(buffer, pos)
         if pos > end:
           raise _DecodeError('Truncated message.')
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
   enum_type = key.enum_type
   if is_packed:
     local_DecodeVarint = _DecodeVarint
-    def DecodePackedField(buffer, pos, end, message, field_dict):
+    def DecodePackedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       """Decode serialized packed enum to its value and a new position.
 
       Args:
@@ -365,6 +374,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
       Returns:
         int, new position in serialized data.
       """
+      del current_depth # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -405,7 +415,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
   elif is_repeated:
     tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       """Decode serialized repeated enum to its value and a new position.
 
       Args:
@@ -418,6 +430,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
       Returns:
         int, new position in serialized data.
       """
+      del current_depth # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -446,7 +459,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
           return new_pos
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
       """Decode serialized repeated enum to its value and a new position.
 
       Args:
@@ -459,6 +472,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
       Returns:
         int, new position in serialized data.
       """
+      del current_depth # unused
       value_start_pos = pos
       (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
       if pos > end:
@@ -540,7 +554,10 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
+      del current_depth # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -557,7 +574,8 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
           return new_pos
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+      del current_depth # unused
       (size, pos) = local_DecodeVarint(buffer, pos)
       new_pos = pos + size
       if new_pos > end:
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
+      del current_depth # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -598,7 +619,8 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
           return new_pos
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+      del current_depth # unused
       (size, pos) = local_DecodeVarint(buffer, pos)
       new_pos = pos + size
       if new_pos > end:
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_START_GROUP)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -632,7 +656,13 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
         # Read sub-message.
-        pos = value.add()._InternalParse(buffer, pos, end)
+        current_depth += 1
+        if current_depth > _recursion_limit:
+          raise _DecodeError(
+              'Error parsing message: too many levels of nesting.'
+          )
+        pos = value.add()._InternalParse(buffer, pos, end, current_depth)
+        current_depth -= 1
         # Read end tag.
         new_pos = pos+end_tag_len
         if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -644,12 +674,16 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
           return new_pos
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
       # Read sub-message.
-      pos = value._InternalParse(buffer, pos, end)
+      current_depth += 1
+      if current_depth > _recursion_limit:
+        raise _DecodeError('Error parsing message: too many levels of nesting.')
+      pos = value._InternalParse(buffer, pos, end, current_depth)
+      current_depth -= 1
       # Read end tag.
       new_pos = pos+end_tag_len
       if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -679,18 +715,27 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
         if new_pos > end:
           raise _DecodeError('Truncated message.')
         # Read sub-message.
-        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+        current_depth += 1
+        if current_depth > _recursion_limit:
+          raise _DecodeError(
+              'Error parsing message: too many levels of nesting.'
+          )
+        if (
+            value.add()._InternalParse(buffer, pos, new_pos, current_depth)
+            != new_pos
+        ):
           # The only reason _InternalParse would return early is if it
           # encountered an end-group tag.
           raise _DecodeError('Unexpected end-group tag.')
         # Predict that the next tag is another copy of the same repeated field.
+        current_depth -= 1
         pos = new_pos + tag_len
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
           # Prediction failed.  Return.
           return new_pos
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -699,11 +744,14 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
       new_pos = pos + size
       if new_pos > end:
         raise _DecodeError('Truncated message.')
-      # Read sub-message.
-      if value._InternalParse(buffer, pos, new_pos) != new_pos:
+      current_depth += 1
+      if current_depth > _recursion_limit:
+        raise _DecodeError('Error parsing message: too many levels of nesting.')
+      if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
         # The only reason _InternalParse would return early is if it encountered
         # an end-group tag.
         raise _DecodeError('Unexpected end-group tag.')
+      current_depth -= 1
       return new_pos
     return DecodeField
 
@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
   # Can't read _concrete_class yet; might not be initialized.
   message_type = field_descriptor.message_type
 
-  def DecodeMap(buffer, pos, end, message, field_dict):
+  def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
+    del current_depth # unused
     submsg = message_type._concrete_class()
     value = field_dict.get(key)
     if value is None:
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
       return pos
     pos = new_pos
 
+DEFAULT_RECURSION_LIMIT = 100
+_recursion_limit = DEFAULT_RECURSION_LIMIT
+
+
+def SetRecursionLimit(new_limit):
+  global _recursion_limit
+  _recursion_limit = new_limit
+
 
-def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
+def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
   """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position."""
 
   unknown_field_set = containers.UnknownFieldSet()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
     field_number, wire_type = wire_format.UnpackTag(tag)
     if wire_type == wire_format.WIRETYPE_END_GROUP:
       break
-    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
     # pylint: disable=protected-access
     unknown_field_set._add(field_number, wire_type, data)
 
   return (unknown_field_set, pos)
 
 
-def _DecodeUnknownField(buffer, pos, wire_type):
+def _DecodeUnknownField(buffer, pos, wire_type, current_depth=0):
   """Decode a unknown field.  Returns the UnknownField and new position."""
 
   if wire_type == wire_format.WIRETYPE_VARINT:
@@ -973,7 +1030,12 @@ def _DecodeUnknownField(buffer, pos, wire_type):
     data = buffer[pos:pos+size].tobytes()
     pos += size
   elif wire_type == wire_format.WIRETYPE_START_GROUP:
-    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
+    # Unknown groups nest recursively, so bound the depth before descending.
+    current_depth += 1
+    if current_depth >= _recursion_limit:
+      raise _DecodeError('Error parsing message: too many levels of nesting.')
+    (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
+    current_depth -= 1
   elif wire_type == wire_format.WIRETYPE_END_GROUP:
     return (0, -1)
   else:
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
new file mode 100644
index 000000000..613b73bcb
--- /dev/null
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc.  All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Test decoder."""
+
+import unittest
+
+from google.protobuf import message
+from google.protobuf.internal import decoder
+from google.protobuf.internal import testing_refleaks
+from google.protobuf.internal import wire_format
+
+@testing_refleaks.TestCase
+class DecoderTest(unittest.TestCase):
+  def test_decode_unknown_group_field_too_many_levels(self):
+    data = memoryview(b'\023' * 5_000_000)
+    self.assertRaisesRegex(
+        message.DecodeError,
+        'Error parsing message',
+        decoder._DecodeUnknownField,
+        data,
+        1,
+        wire_format.WIRETYPE_START_GROUP,
+        1
+    )
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 3af128592..637db3307 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -30,9 +30,11 @@ cmp = lambda x, y: (x > y) - (x < y)
 
 from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
 from google.protobuf.internal import encoder
+from google.protobuf.internal import decoder
 from google.protobuf.internal import more_extensions_pb2
 from google.protobuf.internal import more_messages_pb2
 from google.protobuf.internal import packed_field_test_pb2
+from google.protobuf.internal import self_recursive_pb2
 from google.protobuf.internal import test_proto3_optional_pb2
 from google.protobuf.internal import test_util
 from google.protobuf.internal import testing_refleaks
@@ -1261,6 +1263,35 @@ class MessageTest(unittest.TestCase):
     self.assertNotEqual(ComparesWithFoo(), m)
 
 
+@testing_refleaks.TestCase
+class TestRecursiveGroup(unittest.TestCase):
+
+  def _MakeRecursiveGroupMessage(self, n):
+    msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
+    sub = msg
+    for _ in range(n):
+      sub = sub.sub_group
+    sub.i = 1
+    return msg.SerializeToString()
+
+  def testRecursiveGroups(self):
+    recurse_msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
+    data = self._MakeRecursiveGroupMessage(100)
+    recurse_msg.ParseFromString(data)
+    self.assertTrue(recurse_msg.HasField('sub_group'))
+
+  def testRecursiveGroupsException(self):
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
+    recurse_msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
+    data = self._MakeRecursiveGroupMessage(300)
+    with self.assertRaises(message.DecodeError) as context:
+      recurse_msg.ParseFromString(data)
+    self.assertIn('Error parsing message', str(context.exception))
+    if api_implementation.Type() == 'python':
+      self.assertIn('too many levels of nesting', str(context.exception))
+
+
 # Class to test proto2-only features (required, extensions, etc.)
 @testing_refleaks.TestCase
 class Proto2Test(unittest.TestCase):
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 40c776453..8f20e9abb 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -1136,7 +1136,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
   fields_by_tag = cls._fields_by_tag
   message_set_decoders_by_tag = cls._message_set_decoders_by_tag
 
-  def InternalParse(self, buffer, pos, end):
+  def InternalParse(self, buffer, pos, end, current_depth=0):
     """Create a message from serialized bytes.
 
     Args:
@@ -1180,7 +1180,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
         # TODO: remove old_pos.
         old_pos = new_pos
         (data, new_pos) = decoder._DecodeUnknownField(
-            buffer, new_pos, wire_type)  # pylint: disable=protected-access
+            buffer, new_pos, wire_type, current_depth)  # pylint: disable=protected-access
         if new_pos == -1:
           return pos
         # pylint: disable=protected-access
@@ -1195,7 +1195,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
       else:
         _MaybeAddDecoder(cls, field_des)
         field_decoder = field_des._decoders[is_packed]
-        pos = field_decoder(buffer, new_pos, end, self, field_dict)
+        pos = field_decoder(buffer, new_pos, end, self, field_dict, current_depth)
         if field_des.containing_oneof:
           self._UpdateOneofState(field_des)
     return pos
diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto
new file mode 100644
index 000000000..2a7aacb0b
--- /dev/null
+++ b/python/google/protobuf/internal/self_recursive.proto
@@ -0,0 +1,17 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2024 Google Inc.  All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+syntax = "proto2";
+
+package google.protobuf.python.internal;
+
+message SelfRecursive {
+  optional group RecursiveGroup = 1 {
+    optional RecursiveGroup sub_group = 2;
+    optional int32 i = 3;
+  };
+}
\ No newline at end of file
-- 
2.49.0

openSUSE Build Service is sponsored by