File 4595-Allow-re-connect-on-dtls-sockets.patch of Package erlang

From 8e592d2167ed4e99b22109f9081ba625843b2057 Mon Sep 17 00:00:00 2001
From: Dan Gudmundsson <dgud@erlang.org>
Date: Thu, 15 Jul 2021 13:36:38 +0200
Subject: [PATCH 05/10] Allow re-connect on dtls sockets

Can happen when a computer reboots and connects from the same client
port without the server noticing, should be allowed according to RFC.
---
 lib/ssl/src/dtls_connection.erl     | 18 ++++++++
 lib/ssl/src/dtls_gen_connection.erl | 67 ++++++++++++++++------------
 lib/ssl/src/ssl_gen_statem.erl      |  2 +
 lib/ssl/test/dtls_api_SUITE.erl     | 68 +++++++++++++++++++++++++++--
 lib/ssl/test/ssl_test_lib.erl       | 10 ++++-
 5 files changed, 134 insertions(+), 31 deletions(-)

diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index ae24dc31cc..a27b9a6213 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -541,12 +541,30 @@ connection(internal, #client_hello{}, #state{static_env = #static_env{role = ser
     State1 = dtls_gen_connection:send_alert(Alert, State0),
     {Record, State} = ssl_gen_statem:prepare_connection(State1, Connection),
     dtls_gen_connection:next_event(?FUNCTION_NAME, Record, State);
+connection(internal, new_connection, #state{ssl_options=SSLOptions,
+                                            handshake_env=HsEnv,
+                                            connection_states = OldCs} = State) ->
+    #{beast_mitigation := BeastMitigation} = SSLOptions,
+    ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
+    #{current_write:=CW, current_read:=CR} = OldCs,
+    ConnectionStates = ConnectionStates0#{saved_write:=CW, saved_read:=CR},
+    {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
+                                    connection_states = ConnectionStates}};
 connection({call, From}, {application_data, Data}, State) ->
     try
         send_application_data(Data, From, ?FUNCTION_NAME, State)
     catch throw:Error ->
             ssl_gen_statem:hibernate_after(?FUNCTION_NAME, State, [{reply, From, Error}])
     end;
+connection({call, From}, {downgrade, Pid},
+           #state{connection_env = CEnv,
+                  static_env = #static_env{transport_cb = Transport,
+                                           socket = {_Server, Socket} = DTLSSocket}} = State) ->
+    %% For testing purposes, downgrades without noticing the server
+    dtls_socket:setopts(Transport, Socket, [{active, false}, {packet, 0}, {mode, binary}]),
+    Transport:controlling_process(Socket, Pid),
+    {stop_and_reply, {shutdown, normal}, {reply, From, {ok, DTLSSocket}},
+     State#state{connection_env = CEnv#connection_env{socket_terminated = true}}};
 connection(Type, Event, State) ->
     try
         tls_dtls_connection:?FUNCTION_NAME(Type, Event, State)
diff --git a/lib/ssl/src/dtls_gen_connection.erl b/lib/ssl/src/dtls_gen_connection.erl
index 68f6ff0ec8..ebc9766645 100644
--- a/lib/ssl/src/dtls_gen_connection.erl
+++ b/lib/ssl/src/dtls_gen_connection.erl
@@ -105,30 +105,29 @@ next_record(#state{protocol_buffers =
     CurrentRead = dtls_record:get_connection_state_by_epoch(Epoch, ConnectionStates, read),
     case dtls_record:replay_detect(CT, CurrentRead) of
         false ->
-            decode_cipher_text(State#state{connection_states = ConnectionStates}) ;
+            decode_cipher_text(State) ;
         true ->
             %% Ignore replayed record
-            next_record(State#state{protocol_buffers =
-                                        Buffers#protocol_buffers{dtls_cipher_texts = Rest},
-                                    connection_states = ConnectionStates})
+            next_record(State#state{protocol_buffers = Buffers#protocol_buffers{dtls_cipher_texts = Rest}})
     end;
 next_record(#state{protocol_buffers =
 		       #protocol_buffers{dtls_cipher_texts = [#ssl_tls{epoch = Epoch} | Rest]}
 		   = Buffers,
-		   connection_states = #{current_read := #{epoch := CurrentEpoch}} = ConnectionStates} = State) 
+		   connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State)
   when Epoch > CurrentEpoch ->
     %% TODO Buffer later Epoch message, drop it for now
-    next_record(State#state{protocol_buffers =
-                                Buffers#protocol_buffers{dtls_cipher_texts = Rest},
-                            connection_states = ConnectionStates});
-next_record(#state{protocol_buffers =
-		       #protocol_buffers{dtls_cipher_texts = [ _ | Rest]}
-		   = Buffers,
-		   connection_states = ConnectionStates} = State) ->
-    %% Drop old epoch message
-    next_record(State#state{protocol_buffers =
-                                Buffers#protocol_buffers{dtls_cipher_texts = Rest},
-                            connection_states = ConnectionStates});
+    next_record(State#state{protocol_buffers = Buffers#protocol_buffers{dtls_cipher_texts = Rest}});
+next_record(#state{protocol_buffers = #protocol_buffers{dtls_cipher_texts =
+                                                            [#ssl_tls{epoch = Epoch} | Rest]
+                                                       } = Buffers
+                  } = State) ->
+    case Epoch of
+        0 -> %% A reconnect (client might have rebooted and re-connected)
+            decode_cipher_text(State);
+        _ ->
+            %% Drop old epoch message
+            next_record(State#state{protocol_buffers = Buffers#protocol_buffers{dtls_cipher_texts = Rest}})
+    end;
 next_record(#state{static_env = #static_env{role = server,
                                             socket = {Listener, {Client, _}}}} = State) ->
     dtls_packet_demux:active_once(Listener, Client, self()),
@@ -187,10 +186,11 @@ next_event(StateName, no_record,
 next_event(connection = StateName, Record,
 	   #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
     case Record of
-        #ssl_tls{epoch = CurrentEpoch,
+        #ssl_tls{epoch = Epoch,
                  type = ?HANDSHAKE,
-                 version = Version} = Record ->
-            State = dtls_version(StateName, Version, State0), 
+                 version = Version} = Record
+          when Epoch =:= CurrentEpoch; Epoch =:= 0 ->
+            State = dtls_version(StateName, Version, State0),
 	    {next_state, StateName, State,
 	     [{next_event, internal, {protocol_record, Record}} | Actions]};
 	#ssl_tls{epoch = CurrentEpoch} ->
@@ -330,9 +330,8 @@ handle_protocol_record(#ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, Stat
             ssl_gen_statem:hibernate_after(StateName, State, Actions)
     end;
 %%% DTLS record protocol level handshake messages 
-handle_protocol_record(#ssl_tls{type = ?HANDSHAKE,
-				       fragment = Data}, 
-		    StateName, 
+handle_protocol_record(#ssl_tls{type = ?HANDSHAKE, epoch = Epoch, fragment = Data},
+                       StateName,
                        #state{protocol_buffers = Buffers0,
                               connection_env = #connection_env{negotiated_version = Version},
                               ssl_options = Options} = State) ->
@@ -342,12 +341,17 @@ handle_protocol_record(#ssl_tls{type = ?HANDSHAKE,
 		next_event(StateName, no_record, State#state{protocol_buffers = Buffers});
 	    {Packets, Buffers} ->
 		HsEnv = State#state.handshake_env,
-		Events = dtls_handshake_events(Packets),
-                {next_state, StateName, 
+		HSEvents = dtls_handshake_events(Packets),
+                Events = case is_new_connection(Epoch, Packets, State) of
+                             true  -> [{next_event, internal, new_connection} | HSEvents];
+                             false -> HSEvents
+                         end,
+                {next_state, StateName,
                  State#state{protocol_buffers = Buffers,
-                             handshake_env = 
-                                 HsEnv#handshake_env{unprocessed_handshake_events 
-                                                     = unprocessed_events(Events)}}, Events}
+                             handshake_env =
+                                 HsEnv#handshake_env{
+                                   unprocessed_handshake_events = unprocessed_events(HSEvents)}
+                            }, Events}
 	end
     catch throw:#alert{} = Alert ->
 	    handle_own_alert(Alert, StateName, State)
@@ -550,6 +554,15 @@ handle_info(Msg, StateName, State) ->
 %% Internal functions 
 %%====================================================================
 
+is_new_connection(0, [{#client_hello{},_Raw}|_],
+                  #state{
+                     connection_states =
+                         #{current_read := #{epoch := CurrentEpoch}}})
+  when CurrentEpoch > 0 ->
+    true;
+is_new_connection(_, _, _) ->
+    false.
+
 dtls_handshake_events(Packets) ->
     lists:map(fun(Packet) ->
 		      {next_event, internal, {handshake, Packet}}
diff --git a/lib/ssl/src/ssl_gen_statem.erl b/lib/ssl/src/ssl_gen_statem.erl
index 1370036013..46926d15ef 100644
--- a/lib/ssl/src/ssl_gen_statem.erl
+++ b/lib/ssl/src/ssl_gen_statem.erl
@@ -720,6 +720,8 @@ handle_common_event({timeout, recv}, timeout, StateName, #state{start_or_recv_fr
 handle_common_event(internal, {recv, RecvFrom}, StateName, #state{start_or_recv_from = RecvFrom}) when
       StateName =/= connection ->
     {keep_state_and_data, [postpone]};
+handle_common_event(internal, new_connection, StateName, State) ->
+    {next_state, StateName, State};
 handle_common_event(Type, Msg, StateName, State) ->
     Alert =  ?ALERT_REC(?FATAL,?UNEXPECTED_MESSAGE, {unexpected_msg, {Type, Msg}}),
     handle_own_alert(Alert, StateName, State).
diff --git a/lib/ssl/test/dtls_api_SUITE.erl b/lib/ssl/test/dtls_api_SUITE.erl
index 572702af02..4117bec12e 100644
--- a/lib/ssl/test/dtls_api_SUITE.erl
+++ b/lib/ssl/test/dtls_api_SUITE.erl
@@ -51,9 +51,12 @@
          dtls_listen_two_sockets_5/0,
          dtls_listen_two_sockets_5/1,
          dtls_listen_two_sockets_6/0,
-         dtls_listen_two_sockets_6/1
+         dtls_listen_two_sockets_6/1,
+         client_restarts/0, client_restarts/1
         ]).
 
+-include_lib("ssl/src/ssl_internal.hrl").
+
 %%--------------------------------------------------------------------
 %% Common Test interface functions -----------------------------------
 %%--------------------------------------------------------------------
@@ -80,7 +83,8 @@ api_tests() ->
      dtls_listen_two_sockets_3,
      dtls_listen_two_sockets_4,
      dtls_listen_two_sockets_5,
-     dtls_listen_two_sockets_6
+     dtls_listen_two_sockets_6,
+     client_restarts
     ].
 
 init_per_suite(Config0) ->
@@ -300,7 +304,6 @@ dtls_listen_two_sockets_6(_Config) when is_list(_Config) ->
     ssl:close(S1),
     ok.
 
-
 replay_window() ->
     [{doc, "Whitebox test of replay window"}].
 replay_window(_Config) ->
@@ -347,6 +350,65 @@ bits_to_list(Bits, I, Acc) ->
         0 -> bits_to_list(Bits bsr 1, I+1, Acc)
     end.
 
+client_restarts() ->
+    [{doc, "Test re-connection "}].
+
+client_restarts(Config) ->
+    ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config),
+    ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config),
+    {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config),
+
+    Server =
+	ssl_test_lib:start_server([{node, ServerNode}, {port, 0},
+				   {from, self()},
+                                   {mfa, {ssl_test_lib, no_result, []}},
+				   {options, ServerOpts}]),
+    Port = ssl_test_lib:inet_port(Server),
+    Client0 = ssl_test_lib:start_client([{node, ClientNode},
+                                         {port, Port}, {host, Hostname},
+                                         {mfa, {ssl_test_lib, no_result, []}},
+                                         {from, self()},
+                                         {options, [{reuse_sessions, save} | ClientOpts]}]),
+    ReConnect =  %% Whitebox re-connect test
+        fun({sslsocket, {gen_udp,_,dtls_gen_connection}, [Pid]} = Socket, ssl) ->
+                ct:log("~p Client Socket: ~p ~n", [self(), Socket]),
+                {ok, {{Adress,CPort},UDPSocket}=IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
+                true = is_port(UDPSocket),
+                ct:log("Info: ~p~n", [inet:info(UDPSocket)]),
+
+                {ok, #config{transport_info = CbInfo, connection_cb = ConnectionCb,
+                             ssl = SslOpts0}} = ssl:handle_options(ClientOpts, client, Adress),
+                SslOpts = {SslOpts0, #socket_options{}, undefined},
+
+                ct:sleep(250),
+                ct:log("Client second connect: ~p ~p~n", [Socket, CbInfo]),
+                Res = ssl_gen_statem:connect(ConnectionCb, Adress, CPort, IntSocket, SslOpts, self(), CbInfo, infinity),
+                {Res, Pid}
+        end,
+
+    Client0 ! {apply, self(), ReConnect},
+    receive
+        {apply_res, {Res, _Prev}} ->
+            ct:log("Apply res: ~p~n", [Res]),
+            ok;
+        Msg ->
+            ct:log("Unhandled: ~p~n", [Msg]),
+            ct:fail({wrong_msg, Msg})
+    end,
+
+    receive
+        Msg2 ->
+            ct:log("Unhandled: ~p~n", [Msg2]),
+            ct:fail({wrong_msg, Msg2})
+    after 200 ->
+            ct:log("Nothing received~n", [])
+    end,
+
+    ssl_test_lib:close(Server),
+    ssl_test_lib:close(Client0),
+
+    ok.
+
 %%--------------------------------------------------------------------
 %% Internal functions ------------------------------------------------
 %%--------------------------------------------------------------------
diff --git a/lib/ssl/test/ssl_test_lib.erl b/lib/ssl/test/ssl_test_lib.erl
index e4c23c22cf..313eec3e22 100644
--- a/lib/ssl/test/ssl_test_lib.erl
+++ b/lib/ssl/test/ssl_test_lib.erl
@@ -1041,7 +1041,15 @@ client_loop_core(Socket, Pid, Transport) ->
         {ssl_closed, Socket} ->
             ok;
         {gen_tcp, closed} ->
-            ok
+            ok;
+        {apply, From, Fun} ->
+            try
+                Res = Fun(Socket, Transport),
+                From ! {apply_res, Res}
+            catch E:R:ST ->
+                    From ! {apply_res, {E,R,ST}}
+            end,
+            client_loop_core(Socket, Pid, Transport)
     end.
 
 client_cont_loop(_Node, Host, Port, Pid, Transport, Options, cancel, _Opts) ->
-- 
2.31.1

openSUSE Build Service is sponsored by