File 4596-ssl-Reset-to-old-connection-state.patch of Package erlang

From 9c7a4732562ce4313bf9b77b0ef7b3a2be18ab5d Mon Sep 17 00:00:00 2001
From: Dan Gudmundsson <dgud@erlang.org>
Date: Wed, 22 Sep 2021 14:46:42 +0200
Subject: [PATCH 06/10] ssl: Reset to old connection state

In case of fake takeover, reset connection to old working one,
i.e. if any alert is generated or coming on the "new" connection,
reset connection_states and state to the saved one.
---
 lib/ssl/src/dtls_connection.erl     | 91 +++++++++++++++++++----------
 lib/ssl/src/dtls_gen_connection.erl | 28 +++++++--
 2 files changed, 81 insertions(+), 38 deletions(-)

diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index a27b9a6213..0ee01515a1 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -132,6 +132,8 @@
 
 -export([renegotiate/2]).
 
+-export([alert_or_reset_connection/3]).  %% Code re-use from dtls_gen_connection.
+
 %% gen_statem state functions
 -export([initial_hello/3,
          config_error/3,
@@ -287,7 +289,7 @@ hello(internal, #client_hello{cookie = <<>>,
                                                                              ssl_handshake:init_handshake_history()}},
                                            Actions)
     catch throw:#alert{} = Alert ->
-            ssl_gen_statem:handle_own_alert(Alert,?FUNCTION_NAME, State0)
+            alert_or_reset_connection(Alert, ?FUNCTION_NAME, State0)
     end;
 hello(internal, #hello_verify_request{cookie = Cookie}, 
       #state{static_env = #static_env{role = client,
@@ -327,17 +329,8 @@ hello(internal, #client_hello{extensions = Extensions} = Hello,
                                                  handshake_env = HsEnv#handshake_env{hello = Hello}},
              [{reply, From, {ok, Extensions}}]}
     catch throw:#alert{} = Alert ->
-            ssl_gen_statem:handle_own_alert(Alert, ?FUNCTION_NAME, State0)
+            alert_or_reset_connection(Alert, ?FUNCTION_NAME, State0)
     end;
-hello(internal, #server_hello{extensions = Extensions} = Hello, 
-      #state{ssl_options = #{
-                             handshake := hello},
-             handshake_env = HsEnv,
-             start_or_recv_from = From} = State) ->
-    {next_state, user_hello, State#state{start_or_recv_from = undefined,
-                                         handshake_env = HsEnv#handshake_env{
-                                                           hello = Hello}},
-     [{reply, From, {ok, Extensions}}]};       
 hello(internal, #client_hello{cookie = Cookie} = Hello, #state{static_env = #static_env{role = server,
                                                                                         transport_cb = Transport,
                                                                                         socket = Socket},
@@ -350,13 +343,23 @@ hello(internal, #client_hello{cookie = Cookie} = Hello, #state{static_env = #sta
 	    handle_client_hello(Hello, State);
 	_ ->
             case dtls_handshake:cookie(PSecret, IP, Port, Hello) of
-               	Cookie -> 
+               	Cookie ->
                     handle_client_hello(Hello, State);
                 _ ->
                     %% Handle bad cookie as new cookie request RFC 6347 4.1.2
                     hello(internal, Hello#client_hello{cookie = <<>>}, State)
             end
     end;
+hello(internal, #server_hello{extensions = Extensions} = Hello, 
+      #state{ssl_options = #{
+                 handshake := hello},
+             handshake_env = HsEnv,
+             start_or_recv_from = From} = State) ->
+    {next_state, user_hello, State#state{start_or_recv_from = undefined,
+                                         handshake_env = HsEnv#handshake_env{
+                                                           hello = Hello}},
+     [{reply, From, {ok, Extensions}}]};       
+
 hello(internal, #server_hello{} = Hello,
       #state{
          static_env = #static_env{role = client},
@@ -487,12 +490,19 @@ cipher(Type, Event, State) ->
      gen_handshake(?FUNCTION_NAME, Type, Event, State).
 
 %%--------------------------------------------------------------------
--spec connection(gen_statem:event_type(),  
+-spec connection(gen_statem:event_type(),
 		 #hello_request{} | #client_hello{}| term(), #state{}) ->
 			gen_statem:state_function_result().
 %%--------------------------------------------------------------------
-connection(enter, _, State) ->
-    {keep_state, State};     
+connection(enter, _, #state{connection_states = Cs0} = State0) ->
+    State = case maps:is_key(previous_cs, Cs0) of
+                false ->
+                    State0;
+                true ->
+                    Cs = maps:remove(previous_cs, Cs0),
+                    State0#state{connection_states = Cs}
+            end,
+    {keep_state, State};
 connection(info, Event, State) ->
     gen_info(Event, ?FUNCTION_NAME, State);
 connection(internal, #hello_request{}, #state{static_env = #static_env{host = Host,
@@ -544,12 +554,18 @@ connection(internal, #client_hello{}, #state{static_env = #static_env{role = ser
 connection(internal, new_connection, #state{ssl_options=SSLOptions,
                                             handshake_env=HsEnv,
                                             connection_states = OldCs} = State) ->
-    #{beast_mitigation := BeastMitigation} = SSLOptions,
-    ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
-    #{current_write:=CW, current_read:=CR} = OldCs,
-    ConnectionStates = ConnectionStates0#{saved_write:=CW, saved_read:=CR},
-    {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
-                                    connection_states = ConnectionStates}};
+    case maps:get(previous_cs, OldCs, undefined) of
+        undefined ->
+            #{beast_mitigation := BeastMitigation} = SSLOptions,
+            ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
+            ConnectionStates = ConnectionStates0#{previous_cs => OldCs},
+            {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
+                                            connection_states = ConnectionStates}};
+        _ ->
+            %% Someone spamming new_connection, just drop them
+            {keep_state, State}
+    end;
+
 connection({call, From}, {application_data, Data}, State) ->
     try
         send_application_data(Data, From, ?FUNCTION_NAME, State)
@@ -680,7 +696,7 @@ handle_client_hello(#client_hello{client_version = ClientVersion} = Hello, State
                                             session = Session}),
         {next_state, hello, State, [{next_event, internal, {common_client_hello, Type, ServerHelloExt}}]}
     catch #alert{} = Alert ->
-            ssl_gen_statem:handle_own_alert(Alert, hello, State0)
+            alert_or_reset_connection(Alert, hello, State0)
     end.
 
 
@@ -694,32 +710,43 @@ handle_state_timeout(flight_retransmission_timeout, StateName,
     %% This will reset the retransmission timer by repeating the enter state event
     {repeat_state, State, Actions}.
 
-
+alert_or_reset_connection(Alert, StateName, #state{connection_states = Cs} = State) ->
+    case maps:get(previous_cs, Cs, undefined) of
+        undefined ->
+            ssl_gen_statem:handle_own_alert(Alert, StateName, State);
+        PreviousConn ->
+            %% There exists an old connection and the new one failed,
+            %% reset to the old working one.
+            %% The next alert will be sent
+            HsEnv0 = State#state.handshake_env,
+            HsEnv  = HsEnv0#handshake_env{renegotiation = undefined},
+            NewState = State#state{connection_states = PreviousConn,
+                                   handshake_env = HsEnv
+                                  },
+            {next_state, connection, NewState}
+    end.
 
 gen_handshake(StateName, Type, Event, State) ->
     try tls_dtls_connection:StateName(Type, Event, State)
     catch
         throw:#alert{}=Alert ->
-            ssl_gen_statem:handle_own_alert(Alert, StateName,State);
+            alert_or_reset_connection(Alert, StateName, State);
         error:_ ->
             Alert = ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data),
-            ssl_gen_statem:handle_own_alert(Alert, StateName,State)
+            alert_or_reset_connection(Alert, StateName, State)
     end.
 
 gen_info(Event, connection = StateName, State) ->
     try dtls_gen_connection:handle_info(Event, StateName, State)
     catch error:_ ->
-	    ssl_gen_statem:handle_own_alert(?ALERT_REC(?FATAL, ?INTERNAL_ERROR,
-						       malformed_data), 
-					    StateName, State)  
+            Alert = ?ALERT_REC(?FATAL, ?INTERNAL_ERROR, malformed_data),
+            alert_or_reset_connection(Alert, StateName, State)
     end;
-
 gen_info(Event, StateName, State) ->
     try dtls_gen_connection:handle_info(Event, StateName, State)
     catch error:_ ->
-	    ssl_gen_statem:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,
-						       malformed_handshake_data), 
-					    StateName, State)  
+            Alert = ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,malformed_handshake_data),
+            alert_or_reset_connection(Alert, StateName, State)
     end.
 
 prepare_flight(#state{flight_buffer = Flight,
diff --git a/lib/ssl/src/dtls_gen_connection.erl b/lib/ssl/src/dtls_gen_connection.erl
index ebc9766645..12aa094191 100644
--- a/lib/ssl/src/dtls_gen_connection.erl
+++ b/lib/ssl/src/dtls_gen_connection.erl
@@ -629,10 +629,26 @@ handle_alerts([], Result) ->
     Result;
 handle_alerts(_, {stop, _, _} = Stop) ->
     Stop;
-handle_alerts([Alert | Alerts], {next_state, StateName, State}) ->
-     handle_alerts(Alerts, ssl_gen_statem:handle_alert(Alert, StateName, State));
-handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->
-     handle_alerts(Alerts, ssl_gen_statem:handle_alert(Alert, StateName, State)).
+handle_alerts(Alerts, {next_state, StateName, State}) ->
+    handle_alerts_or_reset(Alerts, StateName, State);
+handle_alerts(Alerts, {next_state, StateName, State, _Actions}) ->
+    handle_alerts_or_reset(Alerts, StateName, State).
+
+handle_alerts_or_reset([Alert|Alerts], StateName, #state{connection_states = Cs} = State) ->
+    case maps:get(previous_cs, Cs, undefined) of
+        undefined ->
+            handle_alerts(Alerts, ssl_gen_statem:handle_alert(Alert, StateName, State));
+        PreviousConn ->
+            %% There exists an old connection and the new one sent alerts,
+            %% reset to the old working one.
+            HsEnv0 = State#state.handshake_env,
+            HsEnv  = HsEnv0#handshake_env{renegotiation = undefined},
+            NewState = State#state{connection_states = PreviousConn,
+                                   handshake_env = HsEnv
+                                  },
+            {next_state, connection, NewState}
+    end.
+
 
 handle_own_alert(Alert, StateName,
                  #state{static_env = #static_env{data_tag = udp,
@@ -643,10 +659,10 @@ handle_own_alert(Alert, StateName,
             log_ignore_alert(LogLevel, StateName, Alert, Role),
             {next_state, StateName, State};
         {false, State} ->
-            ssl_gen_statem:handle_own_alert(Alert, StateName, State)
+            dtls_connection:alert_or_reset_connection(Alert, StateName, State)
     end;
 handle_own_alert(Alert, StateName, State) ->
-    ssl_gen_statem:handle_own_alert(Alert, StateName, State).
+    dtls_connection:alert_or_reset_connection(Alert, StateName, State).
 
 ignore_alert(#alert{level = ?FATAL}, #state{protocol_specific = #{ignored_alerts := N,
                                                   max_ignored_alerts := N}} = State) ->
-- 
2.31.1

openSUSE Build Service is sponsored by