File 3883-ssl-dtls-let-new-accept-process-handle-new-connectio.patch of Package erlang

From 44dcb4c3d900777493ce2a6129f451aa475811f9 Mon Sep 17 00:00:00 2001
From: Dan Gudmundsson <dgud@erlang.org>
Date: Mon, 9 Jan 2023 15:58:33 +0100
Subject: [PATCH 3/3] ssl: dtls let new accept process handle new connections

If a user process is listening for new connections let that handle
the "new" connection instead of re-useing the old one.

If the new connection is successfully connected, bring down the old
connection.
---
 lib/ssl/src/dtls_connection.erl     |  37 +++++---
 lib/ssl/src/dtls_gen_connection.erl |   6 ++
 lib/ssl/src/dtls_packet_demux.erl   |  90 ++++++++++++++----
 lib/ssl/test/dtls_api_SUITE.erl     | 142 ++++++++++++++++++++++++----
 lib/ssl/test/ssl_test_lib.erl       |  16 ++--
 5 files changed, 234 insertions(+), 57 deletions(-)

diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index ff7fee47f0..a0b3820c58 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -512,13 +512,20 @@ cipher(Type, Event, State) ->
 		 #hello_request{} | #client_hello{}| term(), #state{}) ->
 			gen_statem:state_function_result().
 %%--------------------------------------------------------------------
-connection(enter, _, #state{connection_states = Cs0} = State0) ->
-    State = case maps:is_key(previous_cs, Cs0) of
-                false ->
-                    State0;
-                true ->
-                    Cs = maps:remove(previous_cs, Cs0),
-                    State0#state{connection_states = Cs}
+connection(enter, _, #state{connection_states = Cs0,
+                            static_env = Env} = State0) ->
+    State = case Env of
+                #static_env{socket = {Listener, {Client, _}}} ->
+                    dtls_packet_demux:connection_setup(Listener, Client),
+                    case maps:is_key(previous_cs, Cs0) of
+                        false ->
+                            State0;
+                        true ->
+                            Cs = maps:remove(previous_cs, Cs0),
+                            State0#state{connection_states = Cs}
+                    end;
+                _ -> %% client
+                    State0
             end,
     {keep_state, State};
 connection(info, Event, State) ->
@@ -572,14 +579,20 @@ connection(internal, #client_hello{}, #state{static_env = #static_env{role = ser
     dtls_gen_connection:next_event(?FUNCTION_NAME, Record, State);
 connection(internal, new_connection, #state{ssl_options=SSLOptions,
                                             handshake_env=HsEnv,
+                                            static_env = #static_env{socket = {Listener, {Client, _}}},
                                             connection_states = OldCs} = State) ->
     case maps:get(previous_cs, OldCs, undefined) of
         undefined ->
-            BeastMitigation = maps:get(beast_mitigation, SSLOptions, disabled),
-            ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
-            ConnectionStates = ConnectionStates0#{previous_cs => OldCs},
-            {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
-                                            connection_states = ConnectionStates}};
+            case dtls_packet_demux:new_connection(Listener, Client) of
+                true ->
+                    {keep_state, State};
+                false ->
+                    BeastMitigation = maps:get(beast_mitigation, SSLOptions, disabled),
+                    ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
+                    ConnectionStates = ConnectionStates0#{previous_cs => OldCs},
+                    {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
+                                                    connection_states = ConnectionStates}}
+            end;
         _ ->
             %% Someone spamming new_connection, just drop them
             {keep_state, State}
diff --git a/lib/ssl/src/dtls_gen_connection.erl b/lib/ssl/src/dtls_gen_connection.erl
index 4964d3d21f..aa634a3218 100644
--- a/lib/ssl/src/dtls_gen_connection.erl
+++ b/lib/ssl/src/dtls_gen_connection.erl
@@ -572,6 +572,12 @@ handle_info(new_cookie_secret, StateName,
     {next_state, StateName, State#state{protocol_specific = 
                                             CookieInfo#{current_cookie_secret => dtls_v1:cookie_secret(),
                                                         previous_cookie_secret => Secret}}};
+handle_info({socket_reused, Client}, StateName,
+            #state{static_env = #static_env{socket = {_, {Client, _}}}} = State) ->
+    Alert = ?ALERT_REC(?FATAL, ?CLOSE_NOTIFY, transport_closed),
+    ssl_gen_statem:handle_normal_shutdown(Alert#alert{role = server}, StateName, State),
+    {stop, {shutdown, transport_closed}, State};
+
 handle_info(Msg, StateName, State) ->
     ssl_gen_statem:handle_info(Msg, StateName, State).
 
diff --git a/lib/ssl/src/dtls_packet_demux.erl b/lib/ssl/src/dtls_packet_demux.erl
index 5c4556fbb8..4ad0e9ee03 100644
--- a/lib/ssl/src/dtls_packet_demux.erl
+++ b/lib/ssl/src/dtls_packet_demux.erl
@@ -33,6 +33,8 @@
          sockname/1,
          close/1,
          new_owner/1,
+         new_connection/2,
+         connection_setup/2,
          get_all_opts/1,
          set_all_opts/2,
          get_sock_opts/2,
@@ -84,6 +86,12 @@ close(PacketSocket) ->
 new_owner(PacketSocket) ->
     call(PacketSocket, new_owner).
 
+new_connection(PacketSocket, Client) ->
+    call(PacketSocket, {new_connection, Client, self()}).
+
+connection_setup(PacketSocket, Client) ->
+    gen_server:cast(PacketSocket, {connection_setup, Client}).
+
 get_sock_opts(PacketSocket, SplitSockOpts) ->
     call(PacketSocket,  {get_sock_opts, SplitSockOpts}).
 get_all_opts(PacketSocket) ->
@@ -145,6 +153,18 @@ handle_call(close, _, #state{dtls_processes = Processes,
     end;
 handle_call(new_owner, _, State) ->
     {reply, ok,  State#state{close = false, first = true}};
+handle_call({new_connection, Old, _Pid}, _,
+            #state{accepters = Accepters, dtls_msq_queues = MsgQs0} = State) ->
+    case queue:is_empty(Accepters) of
+        false ->
+            OldQueue = kv_get(Old, MsgQs0),
+            MsgQs1 = kv_delete(Old, MsgQs0),
+            MsgQs = kv_insert({old,Old}, OldQueue, MsgQs1),
+            {reply, true, State#state{dtls_msq_queues = MsgQs}};
+        true ->
+            {reply, false, State}
+    end;
+
 handle_call({get_sock_opts, {SocketOptNames, EmOptNames}}, _, #state{listener = Socket,
                                                                emulated_options = EmOpts} = State) ->
     case get_socket_opts(Socket, SocketOptNames) of
@@ -169,7 +189,16 @@ handle_call({getstat, Options}, _,  #state{listener = Socket, transport =  {Tran
 
 handle_cast({active_once, Client, Pid}, State0) ->
     State = handle_active_once(Client, Pid, State0),
-    {noreply, State}.
+    {noreply, State};
+handle_cast({connection_setup, Client}, #state{dtls_msq_queues = MsgQueues} = State) ->
+    case kv_lookup({old, Client}, MsgQueues) of
+        none ->
+            {noreply, State};
+        {value, {Pid, _}} ->
+            Pid ! {socket_reused, Client},
+            %% Will be deleted when handling DOWN message
+            {noreply, State}
+    end.
 
 handle_info({Transport, Socket, IP, InPortNo, _} = Msg, #state{listener = Socket, transport = {_,Transport,_,_,_}} = State0) ->
     State = handle_datagram({IP, InPortNo}, Msg, State0),
@@ -189,23 +218,40 @@ handle_info({udp_error, Socket, econnreset = Error}, #state{listener = Socket, t
     ?LOG_NOTICE(Report),
     {noreply, State};
 handle_info({ErrorTag, Socket, Error}, #state{listener = Socket, transport = {_,_,_, ErrorTag,_}} = State) ->
-    Report = io_lib:format("SSL Packet muliplxer shutdown: Socket error: ~p ~n", [Error]),
+    Report = io_lib:format("SSL Packet muliplexer shutdown: Socket error: ~p ~n", [Error]),
     ?LOG_NOTICE(Report),
     {noreply, State#state{close=true}};
 
 handle_info({'DOWN', _, process, Pid, _},
             #state{dtls_processes = Processes0,
                    dtls_msq_queues = MsgQueues0,
-                   close = ListenClosed} = State) ->
+                   close = ListenClosed} = State0) ->
     Client = kv_get(Pid, Processes0),
     Processes = kv_delete(Pid, Processes0),
-    MsgQueues = kv_delete(Client, MsgQueues0),
+    State = case kv_lookup(Client, MsgQueues0) of
+                none ->
+                    MsgQueues1 = kv_delete({old, Client}, MsgQueues0),
+                    State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues1};
+                {value, {Pid, _}} ->
+                    MsgQueues1 = kv_delete(Client, MsgQueues0),
+                    %% Restore old process if exists
+                    case kv_lookup({old, Client}, MsgQueues1) of
+                        none ->
+                            State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues1};
+                        {value, Old} ->
+                            MsgQueues2 = kv_delete({old, Client}, MsgQueues1),
+                            MsgQueues = kv_insert(Client, Old, MsgQueues2),
+                            State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues}
+                    end;
+                {value, _} -> %% Old process died (just delete its queue)
+                    MsgQueues1 = kv_delete({old, Client}, MsgQueues0),
+                    State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues1}
+            end,
     case ListenClosed andalso kv_empty(Processes) of
         true ->
             {stop, normal, State};
         false ->
-            {noreply, State#state{dtls_processes = Processes,
-                                  dtls_msq_queues = MsgQueues}}
+            {noreply, State}
     end.
 
 terminate(_Reason, _State) ->
@@ -232,33 +278,39 @@ handle_datagram(Client, Msg, #state{dtls_msq_queues = MsgQueues, accepters = Acc
 	    dispatch(Queue, Client, Msg, State)
     end.
 
-dispatch(Queue0, Client, Msg, #state{dtls_msq_queues = MsgQueues} = State) ->
+dispatch({Pid, Queue0}, Client, Msg, #state{dtls_msq_queues = MsgQueues} = State) ->
     case queue:out(Queue0) of
         {{value, Pid}, Queue} when is_pid(Pid) ->
             Pid ! Msg,
             State#state{dtls_msq_queues =
-                            kv_update(Client, Queue, MsgQueues)};
+                            kv_update(Client, {Pid, Queue}, MsgQueues)};
         {{value, _UDP}, _Queue} ->
             State#state{dtls_msq_queues =
-                            kv_update(Client, queue:in(Msg, Queue0), MsgQueues)};
+                            kv_update(Client, {Pid, queue:in(Msg, Queue0)}, MsgQueues)};
         {empty, Queue} ->
             State#state{dtls_msq_queues =
-                            kv_update(Client, queue:in(Msg, Queue), MsgQueues)}
+                            kv_update(Client, {Pid, queue:in(Msg, Queue)}, MsgQueues)}
     end.
 
 next_datagram(Socket, N) ->
     inet:setopts(Socket, [{active, N}]).
 
 handle_active_once(Client, Pid, #state{dtls_msq_queues = MsgQueues} = State0) ->
-    Queue0 = kv_get(Client, MsgQueues),
+    {Key, Queue0} = case kv_lookup(Client, MsgQueues) of
+                        {value, {Pid, Q0}} -> {Client, Q0};
+                        _ ->
+                            OldKey = {old, Client},
+                            {Pid, Q0} = kv_get(OldKey, MsgQueues),
+                            {OldKey, Q0}
+                    end,
     case queue:out(Queue0) of
-	{{value, Pid}, _} when is_pid(Pid) ->
-	    State0;
-	{{value, Msg}, Queue} ->
-	    Pid ! Msg,
-	    State0#state{dtls_msq_queues = kv_update(Client, Queue, MsgQueues)};
-	{empty, Queue0} ->
-	    State0#state{dtls_msq_queues = kv_update(Client, queue:in(Pid, Queue0), MsgQueues)}
+        {{value, Pid}, _} when is_pid(Pid) ->
+            State0;
+        {{value, Msg}, Queue} ->
+            Pid ! Msg,
+            State0#state{dtls_msq_queues = kv_update(Key, {Pid, Queue}, MsgQueues)};
+        {empty, Queue0} ->
+            State0#state{dtls_msq_queues = kv_update(Key, {Pid, queue:in(Pid, Queue0)}, MsgQueues)}
     end.
 
 setup_new_connection(User, From, Client, Msg, #state{dtls_processes = Processes,
@@ -275,7 +327,7 @@ setup_new_connection(User, From, Client, Msg, #state{dtls_processes = Processes,
 	    erlang:monitor(process, Pid),
 	    gen_server:reply(From, {ok, Pid, {Client, Socket}}),
 	    Pid ! Msg,
-	    State#state{dtls_msq_queues = kv_insert(Client, queue:new(), MsgQueues),
+	    State#state{dtls_msq_queues = kv_insert(Client, {Pid, queue:new()}, MsgQueues),
 			dtls_processes = kv_insert(Pid, Client, Processes)};
 	{error, Reason} ->
 	    gen_server:reply(From, {error, Reason}),
diff --git a/lib/ssl/test/dtls_api_SUITE.erl b/lib/ssl/test/dtls_api_SUITE.erl
index f6dab82bd0..9602612f5e 100644
--- a/lib/ssl/test/dtls_api_SUITE.erl
+++ b/lib/ssl/test/dtls_api_SUITE.erl
@@ -52,7 +52,8 @@
          dtls_listen_two_sockets_5/1,
          dtls_listen_two_sockets_6/0,
          dtls_listen_two_sockets_6/1,
-         client_restarts/0, client_restarts/1
+         client_restarts/0, client_restarts/1,
+         client_restarts_multiple_acceptors/1
         ]).
 
 -include_lib("ssl/src/ssl_internal.hrl").
@@ -84,7 +85,8 @@ api_tests() ->
      dtls_listen_two_sockets_4,
      dtls_listen_two_sockets_5,
      dtls_listen_two_sockets_6,
-     client_restarts
+     client_restarts,
+     client_restarts_multiple_acceptors
     ].
 
 init_per_suite(Config0) ->
@@ -354,59 +356,159 @@ client_restarts() ->
     [{doc, "Test re-connection "}].
 
 client_restarts(Config) ->
-    ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config),
+    ClientOpts0 = ssl_test_lib:ssl_options(client_rsa_opts, Config),
     ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config),
     {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config),
 
+    ClientOpts = [{verify, verify_none},{reuse_sessions, save} | ClientOpts0],
     Server =
 	ssl_test_lib:start_server([{node, ServerNode}, {port, 0},
 				   {from, self()},
                                    {mfa, {ssl_test_lib, no_result, []}},
-				   {options, ServerOpts}]),
+				   {options, [{verify, verify_none}|ServerOpts]}]),
     Port = ssl_test_lib:inet_port(Server),
     Client0 = ssl_test_lib:start_client([{node, ClientNode},
                                          {port, Port}, {host, Hostname},
                                          {mfa, {ssl_test_lib, no_result, []}},
                                          {from, self()},
-                                         {options, [{reuse_sessions, save} | ClientOpts]}]),
+                                         {options, ClientOpts}]),
+
+    ssl_test_lib:send(Client0, Msg1 = "from client 0"),
+    ssl_test_lib:send(Server, Msg2 = "from server to client 0"),
+
+    Server ! {active_receive, Msg1},
+    Client0 ! {active_receive, Msg2},
+
+    Msgs = lists:sort([{Server, Msg1}, {Client0, Msg2}]),
+    Msgs = lists:sort(flush()),
+
     ReConnect =  %% Whitebox re-connect test
         fun({sslsocket, {gen_udp,_,dtls_gen_connection}, [Pid]} = Socket, ssl) ->
                 ct:log("~p Client Socket: ~p ~n", [self(), Socket]),
-                {ok, {{Address,CPort},UDPSocket}=IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
-
+                {ok, IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
+                {{Address,CPort},UDPSocket}=IntSocket,
                 ct:log("Info: ~p~n", [inet:info(UDPSocket)]),
 
                 {ok, #config{transport_info = CbInfo, connection_cb = ConnectionCb,
-                             ssl = SslOpts0}} = ssl:handle_options(ClientOpts, client, Address),
+                             ssl = SslOpts0}} =
+                    ssl:handle_options(ClientOpts, client, Address),
                 SslOpts = {SslOpts0, #socket_options{}, undefined},
 
                 ct:sleep(250),
                 ct:log("Client second connect: ~p ~p~n", [Socket, CbInfo]),
-                Res = ssl_gen_statem:connect(ConnectionCb, Address, CPort, IntSocket, SslOpts, self(), CbInfo, infinity),
-                {Res, Pid}
+                {ok, NewSocket} = ssl_gen_statem:connect(ConnectionCb, Address, CPort, IntSocket,
+                                                         SslOpts, self(), CbInfo, infinity),
+                {replace, NewSocket}
         end,
 
     Client0 ! {apply, self(), ReConnect},
     receive
-        {apply_res, {Res, _Prev}} ->
+        {apply_res, {replace, Res}} ->
             ct:log("Apply res: ~p~n", [Res]),
             ok;
-        Msg ->
-            ct:log("Unhandled: ~p~n", [Msg]),
-            ct:fail({wrong_msg, Msg})
+        ErrMsg ->
+            ct:log("Unhandled: ~p~n", [ErrMsg]),
+            ct:fail({wrong_msg, ErrMsg})
     end,
 
+    ssl_test_lib:send(Client0, Msg1 = "from client 0"),
+    ssl_test_lib:send(Server, Msg2 = "from server to client 0"),
+
+    Server ! {active_receive, Msg1},
+    Client0 ! {active_receive, Msg2},
+
+    Msgs = lists:sort(flush()),
+
+    ssl_test_lib:close(Server),
+    ssl_test_lib:close(Client0),
+    ok.
+
+
+flush() ->
+    receive Msg -> [Msg|flush()]
+    after 500 -> []
+    end.
+
+client_restarts_multiple_acceptors(Config) ->
+    %% Can also be tested with openssl by connecting a client and hit
+    %% Ctrl-C to kill openssl process, so that the connection is not
+    %% closed.
+    %% Then do a new openssl connect with the same client port.
+
+    ClientOpts0 = ssl_test_lib:ssl_options(client_rsa_opts, Config),
+    ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config),
+    {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config),
+
+    ClientOpts = [{verify, verify_none},{reuse_sessions, save} | ClientOpts0],
+    Server =
+	ssl_test_lib:start_server([{node, ServerNode}, {port, 0},
+				   {from, self()},
+                                   {mfa, {ssl_test_lib, no_result, []}},
+                                   {accepters, 2},
+				   {options, [{verify, verify_none}|ServerOpts]}]),
+    Port = ssl_test_lib:inet_port(Server),
+    Client0 = ssl_test_lib:start_client([{node, ClientNode},
+                                         {port, Port}, {host, Hostname},
+                                         {mfa, {ssl_test_lib, no_result, []}},
+                                         {from, self()},
+                                         {options, ClientOpts}]),
+
+    Server2 = receive {accepter, 2, Server2Pid} -> Server2Pid
+              after 5000 -> ct:fail(msg_timeout)
+              end,
+
+    ssl_test_lib:send(Client0, Msg1 = "from client 0"),
+    ssl_test_lib:send(Server, Msg2 = "from server to client 0"),
+
+    Server ! {active_receive, Msg1},
+    Client0 ! {active_receive, Msg2},
+
+    Msgs = lists:sort([{Server, Msg1}, {Client0, Msg2}]),
+    Msgs = lists:sort(flush()),
+
+    ReConnect =  %% Whitebox re-connect test
+        fun({sslsocket, {gen_udp,_,dtls_gen_connection}, [Pid]} = Socket, ssl) ->
+                ct:log("~p Client Socket: ~p ~n", [self(), Socket]),
+                {ok, IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
+                {{Address,CPort},UDPSocket}=IntSocket,
+                ct:log("Info: ~p~n", [inet:info(UDPSocket)]),
+
+                {ok, #config{transport_info = CbInfo, connection_cb = ConnectionCb,
+                             ssl = SslOpts0}} =
+                    ssl:handle_options(ClientOpts, client, Address),
+                SslOpts = {SslOpts0, #socket_options{}, undefined},
+
+                ct:sleep(250),
+                ct:log("Client second connect: ~p ~p~n", [Socket, CbInfo]),
+                {ok, NewSocket} = ssl_gen_statem:connect(ConnectionCb, Address, CPort, IntSocket,
+                                                         SslOpts, self(), CbInfo, infinity),
+                {replace, NewSocket}
+        end,
+
+    Client0 ! {apply, self(), ReConnect},
     receive
-        Msg2 ->
-            ct:log("Unhandled: ~p~n", [Msg2]),
-            ct:fail({wrong_msg, Msg2})
-    after 200 ->
-            ct:log("Nothing received~n", [])
+        {apply_res, {replace, Res}} ->
+            ct:log("Apply res: ~p~n", [Res]),
+            ok;
+        ErrMsg ->
+            ct:log("Unhandled: ~p~n", [ErrMsg]),
+            ct:fail({wrong_msg, ErrMsg})
     end,
 
+    ok = ssl_test_lib:send(Client0, Msg3 = "from client 2"),
+    ok = ssl_test_lib:send(Server2, Msg4 = "from server 2 to client 2"),
+    {error, closed} = ssl_test_lib:send(Server,  "Should be closed"),
+
+    Msgs2 = lists:sort([{Server2, Msg3}, {Client0, Msg4}]),
+
+    Server2 ! {active_receive, Msg3},
+    Client0 ! {active_receive, Msg4},
+
+    Msgs2 = lists:sort(flush()),
+
     ssl_test_lib:close(Server),
+    ssl_test_lib:close(Server2),
     ssl_test_lib:close(Client0),
-
     ok.
 
 %%--------------------------------------------------------------------
diff --git a/lib/ssl/test/ssl_test_lib.erl b/lib/ssl/test/ssl_test_lib.erl
index 5a15283e82..feeedca4ee 100644
--- a/lib/ssl/test/ssl_test_lib.erl
+++ b/lib/ssl/test/ssl_test_lib.erl
@@ -1107,13 +1107,17 @@ client_loop_core(Socket, Pid, Transport) ->
         {gen_tcp, closed} ->
             ok;
         {apply, From, Fun} ->
-            try
-                Res = Fun(Socket, Transport),
-                From ! {apply_res, Res}
+            try Fun(Socket, Transport) of
+                {replace, NewSocket} = Res ->
+                    From ! {apply_res, Res},
+                    client_loop_core(NewSocket, Pid, Transport);
+                Res ->
+                    From ! {apply_res, Res},
+                    client_loop_core(Socket, Pid, Transport)
             catch E:R:ST ->
-                    From ! {apply_res, {E,R,ST}}
-            end,
-            client_loop_core(Socket, Pid, Transport)
+                    From ! {apply_res, {E,R,ST}},
+                    client_loop_core(Socket, Pid, Transport)
+            end
     end.
 
 client_cont_loop(_Node, Host, Port, Pid, Transport, Options, cancel, _Opts) ->
-- 
2.35.3

openSUSE Build Service is sponsored by