File protobuf-CVE-2022-1941.patch of Package protobuf.29292

From 55815e423bb82cc828836bbd60c79c1f9a195763 Mon Sep 17 00:00:00 2001
From: Deanna Garcia <deannagarcia@google.com>
Date: Tue, 13 Sep 2022 17:20:00 +0000
Subject: [PATCH] Apply patch

---
 src/google/protobuf/extension_set_inl.h     |   25 ++++--
 src/google/protobuf/wire_format_lite.h      |   27 ++++---
 src/google/protobuf/wire_format_unittest.cc |  108 +++++++++++++++++++++++++---
 3 files changed, 134 insertions(+), 26 deletions(-)

--- a/src/google/protobuf/extension_set_inl.h
+++ b/src/google/protobuf/extension_set_inl.h
@@ -208,15 +208,21 @@ const char* ExtensionSet::ParseMessageSe
                                                   Metadata* metadata,
                                                   internal::ParseContext* ctx) {
   std::string payload;
-  uint32 type_id = 0;
+  uint32 type_id;
+  enum class State { kNoTag, kHasType, kHasPayload, kDone };
+  State state = State::kNoTag;
+
   while (!ctx->Done(&ptr)) {
     uint32 tag = static_cast<uint8>(*ptr++);
     if (tag == WireFormatLite::kMessageSetTypeIdTag) {
       uint64 tmp;
       ptr = ParseVarint64Inline(ptr, &tmp);
       GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
-      type_id = tmp;
-      if (!payload.empty()) {
+      if (state == State::kNoTag) {
+        type_id = tmp;
+        state = State::kHasType;
+      } else if (state == State::kHasPayload) {
+        type_id = tmp;
         ExtensionInfo extension;
         bool was_packed_on_wire;
         if (!FindExtension(2, type_id, containing_type, ctx, &extension,
@@ -242,19 +248,24 @@ const char* ExtensionSet::ParseMessageSe
           GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
                                          tmp_ctx.EndedAtLimit());
         }
-        type_id = 0;
+        state = State::kDone;
       }
     } else if (tag == WireFormatLite::kMessageSetMessageTag) {
-      if (type_id != 0) {
+      if (state == State::kHasType) {
         ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
                                     containing_type, metadata, ctx);
         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
-        type_id = 0;
+        state = State::kDone;
       } else {
+        std::string tmp;
         int32 size = ReadSize(&ptr);
         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
-        ptr = ctx->ReadString(ptr, size, &payload);
+        ptr = ctx->ReadString(ptr, size, &tmp);
         GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
+        if (state == State::kNoTag) {
+          payload = std::move(tmp);
+          state = State::kHasPayload;
+        }
       }
     } else {
       if (tag >= 128) {
--- a/src/google/protobuf/wire_format_lite.h
+++ b/src/google/protobuf/wire_format_lite.h
@@ -1796,6 +1796,9 @@ bool ParseMessageSetItemImpl(io::CodedIn
   // we can parse it later.
   std::string message_data;
 
+  enum class State { kNoTag, kHasType, kHasPayload, kDone };
+  State state = State::kNoTag;
+
   while (true) {
     const uint32 tag = input->ReadTagNoLastTag();
     if (tag == 0) return false;
@@ -1804,26 +1807,34 @@ bool ParseMessageSetItemImpl(io::CodedIn
       case WireFormatLite::kMessageSetTypeIdTag: {
         uint32 type_id;
         if (!input->ReadVarint32(&type_id)) return false;
-        last_type_id = type_id;
-
-        if (!message_data.empty()) {
+        if (state == State::kNoTag) {
+          last_type_id = type_id;
+          state = State::kHasType;
+        } else if (state == State::kHasPayload) {
           // We saw some message data before the type_id.  Have to parse it
           // now.
           io::CodedInputStream sub_input(
               reinterpret_cast<const uint8*>(message_data.data()),
               static_cast<int>(message_data.size()));
           sub_input.SetRecursionLimit(input->RecursionBudget());
-          if (!ms.ParseField(last_type_id, &sub_input)) {
+          if (!ms.ParseField(type_id, &sub_input)) {
             return false;
           }
           message_data.clear();
+          state = State::kDone;
         }
 
         break;
       }
 
       case WireFormatLite::kMessageSetMessageTag: {
-        if (last_type_id == 0) {
+        if (state == State::kHasType) {
+          // Already saw type_id, so we can parse this directly.
+          if (!ms.ParseField(last_type_id, input)) {
+            return false;
+          }
+          state = State::kDone;
+        } else if (state == State::kNoTag) {
           // We haven't seen a type_id yet.  Append this data to message_data.
           uint32 length;
           if (!input->ReadVarint32(&length)) return false;
@@ -1834,11 +1845,9 @@ bool ParseMessageSetItemImpl(io::CodedIn
           auto ptr = reinterpret_cast<uint8*>(&message_data[0]);
           ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
           if (!input->ReadRaw(ptr, length)) return false;
+          state = State::kHasPayload;
         } else {
-          // Already saw type_id, so we can parse this directly.
-          if (!ms.ParseField(last_type_id, input)) {
-            return false;
-          }
+          if (!ms.SkipField(tag, input)) return false;
         }
 
         break;
--- a/src/google/protobuf/wire_format_unittest.cc
+++ b/src/google/protobuf/wire_format_unittest.cc
@@ -40,6 +40,7 @@
 #include <google/protobuf/unittest_proto3_arena.pb.h>
 #include <google/protobuf/io/coded_stream.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/dynamic_message.h>
 #include <google/protobuf/descriptor.h>
 
 #include <google/protobuf/stubs/logging.h>
@@ -579,30 +580,57 @@ TEST(WireFormatTest, ParseMessageSet) {
   EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
 }
 
-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
+namespace {
+std::string BuildMessageSetItemStart() {
   std::string data;
   {
-    unittest::TestMessageSetExtension1 message;
-    message.set_i(123);
-    // Build a MessageSet manually with its message content put before its
-    // type_id.
     io::StringOutputStream output_stream(&data);
     io::CodedOutputStream coded_output(&output_stream);
     coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
+  }
+  return data;
+}
+std::string BuildMessageSetItemEnd() {
+  std::string data;
+  {
+    io::StringOutputStream output_stream(&data);
+    io::CodedOutputStream coded_output(&output_stream);
+    coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+  }
+  return data;
+}
+std::string BuildMessageSetTestExtension1(int value = 123) {
+  std::string data;
+  {
+    unittest::TestMessageSetExtension1 message;
+    message.set_i(value);
+    io::StringOutputStream output_stream(&data);
+    io::CodedOutputStream coded_output(&output_stream);
     // Write the message content first.
     WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
                              WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
                              &coded_output);
     coded_output.WriteVarint32(message.ByteSize());
     message.SerializeWithCachedSizes(&coded_output);
-    // Write the type id.
-    uint32 type_id = message.GetDescriptor()->extension(0)->number();
+  }
+  return data;
+}
+std::string BuildMessageSetItemTypeId(int extension_number) {
+  std::string data;
+  {
+    io::StringOutputStream output_stream(&data);
+    io::CodedOutputStream coded_output(&output_stream);
     WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
-                                type_id, &coded_output);
-    coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+                                extension_number, &coded_output);
   }
+  return data;
+}
+void ValidateTestMessageSet(const std::string& test_case,
+                            const std::string& data) {
+  SCOPED_TRACE(test_case);
   {
     proto2_wireformat_unittest::TestMessageSet message_set;
+    ::proto2_wireformat_unittest::TestMessageSet message_set;
     ASSERT_TRUE(message_set.ParseFromString(data));
 
     EXPECT_EQ(123,
@@ -610,10 +638,15 @@ TEST(WireFormatTest, ParseMessageSetWith
                   .GetExtension(
                       unittest::TestMessageSetExtension1::message_set_extension)
                   .i());
+
+    // Make sure it does not contain anything else.
+    message_set.ClearExtension(
+        unittest::TestMessageSetExtension1::message_set_extension);
+    EXPECT_EQ(message_set.SerializeAsString(), "");
   }
   {
     // Test parse the message via Reflection.
-    proto2_wireformat_unittest::TestMessageSet message_set;
+    ::proto2_wireformat_unittest::TestMessageSet message_set;
     io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()),
                                data.size());
     EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set));
@@ -625,6 +658,61 @@ TEST(WireFormatTest, ParseMessageSetWith
                       unittest::TestMessageSetExtension1::message_set_extension)
                   .i());
   }
+  {
+    // Test parse the message via DynamicMessage.
+    DynamicMessageFactory factory;
+    std::unique_ptr<Message> msg(
+        factory
+            .GetPrototype(
+                ::proto2_wireformat_unittest::TestMessageSet::descriptor())
+            ->New());
+    msg->ParseFromString(data);
+    auto* reflection = msg->GetReflection();
+    std::vector<const FieldDescriptor*> fields;
+    reflection->ListFields(*msg, &fields);
+    ASSERT_EQ(fields.size(), 1);
+    const auto& sub = reflection->GetMessage(*msg, fields[0]);
+    reflection = sub.GetReflection();
+    EXPECT_EQ(123, reflection->GetInt32(
+                       sub, sub.GetDescriptor()->FindFieldByName("i")));
+  }
+}
+}  // namespace
+
+TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
+  std::string start = BuildMessageSetItemStart();
+  std::string end = BuildMessageSetItemEnd();
+  std::string id = BuildMessageSetItemTypeId(
+      unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+  std::string message = BuildMessageSetTestExtension1();
+
+  ValidateTestMessageSet("id + message", start + id + message + end);
+  ValidateTestMessageSet("message + id", start + message + id + end);
+}
+
+TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
+  std::string start = BuildMessageSetItemStart();
+  std::string end = BuildMessageSetItemEnd();
+  std::string id = BuildMessageSetItemTypeId(
+      unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+  std::string other_id = BuildMessageSetItemTypeId(123456);
+  std::string message = BuildMessageSetTestExtension1();
+  std::string other_message = BuildMessageSetTestExtension1(321);
+
+  // Double id
+  ValidateTestMessageSet("id + other_id + message",
+                         start + id + other_id + message + end);
+  ValidateTestMessageSet("id + message + other_id",
+                         start + id + message + other_id + end);
+  ValidateTestMessageSet("message + id + other_id",
+                         start + message + id + other_id + end);
+  // Double message
+  ValidateTestMessageSet("id + message + other_message",
+                         start + id + message + other_message + end);
+  ValidateTestMessageSet("message + id + other_message",
+                         start + message + id + other_message + end);
+  ValidateTestMessageSet("message + other_message + id",
+                         start + message + other_message + id + end);
 }
 
 void SerializeReverseOrder(
openSUSE Build Service is sponsored by