File 3883-ssl-dtls-let-new-accept-process-handle-new-connectio.patch of Package erlang
From 44dcb4c3d900777493ce2a6129f451aa475811f9 Mon Sep 17 00:00:00 2001
From: Dan Gudmundsson <dgud@erlang.org>
Date: Mon, 9 Jan 2023 15:58:33 +0100
Subject: [PATCH 3/3] ssl: dtls let new accept process handle new connections
If a user process is listening for new connections let that handle
the "new" connection instead of re-useing the old one.
If the new connection is successfully connected, bring down the old
connection.
---
lib/ssl/src/dtls_connection.erl | 37 +++++---
lib/ssl/src/dtls_gen_connection.erl | 6 ++
lib/ssl/src/dtls_packet_demux.erl | 90 ++++++++++++++----
lib/ssl/test/dtls_api_SUITE.erl | 142 ++++++++++++++++++++++++----
lib/ssl/test/ssl_test_lib.erl | 16 ++--
5 files changed, 234 insertions(+), 57 deletions(-)
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index ff7fee47f0..a0b3820c58 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -512,13 +512,20 @@ cipher(Type, Event, State) ->
#hello_request{} | #client_hello{}| term(), #state{}) ->
gen_statem:state_function_result().
%%--------------------------------------------------------------------
-connection(enter, _, #state{connection_states = Cs0} = State0) ->
- State = case maps:is_key(previous_cs, Cs0) of
- false ->
- State0;
- true ->
- Cs = maps:remove(previous_cs, Cs0),
- State0#state{connection_states = Cs}
+connection(enter, _, #state{connection_states = Cs0,
+ static_env = Env} = State0) ->
+ State = case Env of
+ #static_env{socket = {Listener, {Client, _}}} ->
+ dtls_packet_demux:connection_setup(Listener, Client),
+ case maps:is_key(previous_cs, Cs0) of
+ false ->
+ State0;
+ true ->
+ Cs = maps:remove(previous_cs, Cs0),
+ State0#state{connection_states = Cs}
+ end;
+ _ -> %% client
+ State0
end,
{keep_state, State};
connection(info, Event, State) ->
@@ -572,14 +579,20 @@ connection(internal, #client_hello{}, #state{static_env = #static_env{role = ser
dtls_gen_connection:next_event(?FUNCTION_NAME, Record, State);
connection(internal, new_connection, #state{ssl_options=SSLOptions,
handshake_env=HsEnv,
+ static_env = #static_env{socket = {Listener, {Client, _}}},
connection_states = OldCs} = State) ->
case maps:get(previous_cs, OldCs, undefined) of
undefined ->
- BeastMitigation = maps:get(beast_mitigation, SSLOptions, disabled),
- ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
- ConnectionStates = ConnectionStates0#{previous_cs => OldCs},
- {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
- connection_states = ConnectionStates}};
+ case dtls_packet_demux:new_connection(Listener, Client) of
+ true ->
+ {keep_state, State};
+ false ->
+ BeastMitigation = maps:get(beast_mitigation, SSLOptions, disabled),
+ ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
+ ConnectionStates = ConnectionStates0#{previous_cs => OldCs},
+ {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
+ connection_states = ConnectionStates}}
+ end;
_ ->
%% Someone spamming new_connection, just drop them
{keep_state, State}
diff --git a/lib/ssl/src/dtls_gen_connection.erl b/lib/ssl/src/dtls_gen_connection.erl
index 4964d3d21f..aa634a3218 100644
--- a/lib/ssl/src/dtls_gen_connection.erl
+++ b/lib/ssl/src/dtls_gen_connection.erl
@@ -572,6 +572,12 @@ handle_info(new_cookie_secret, StateName,
{next_state, StateName, State#state{protocol_specific =
CookieInfo#{current_cookie_secret => dtls_v1:cookie_secret(),
previous_cookie_secret => Secret}}};
+handle_info({socket_reused, Client}, StateName,
+ #state{static_env = #static_env{socket = {_, {Client, _}}}} = State) ->
+ Alert = ?ALERT_REC(?FATAL, ?CLOSE_NOTIFY, transport_closed),
+ ssl_gen_statem:handle_normal_shutdown(Alert#alert{role = server}, StateName, State),
+ {stop, {shutdown, transport_closed}, State};
+
handle_info(Msg, StateName, State) ->
ssl_gen_statem:handle_info(Msg, StateName, State).
diff --git a/lib/ssl/src/dtls_packet_demux.erl b/lib/ssl/src/dtls_packet_demux.erl
index 5c4556fbb8..4ad0e9ee03 100644
--- a/lib/ssl/src/dtls_packet_demux.erl
+++ b/lib/ssl/src/dtls_packet_demux.erl
@@ -33,6 +33,8 @@
sockname/1,
close/1,
new_owner/1,
+ new_connection/2,
+ connection_setup/2,
get_all_opts/1,
set_all_opts/2,
get_sock_opts/2,
@@ -84,6 +86,12 @@ close(PacketSocket) ->
new_owner(PacketSocket) ->
call(PacketSocket, new_owner).
+new_connection(PacketSocket, Client) ->
+ call(PacketSocket, {new_connection, Client, self()}).
+
+connection_setup(PacketSocket, Client) ->
+ gen_server:cast(PacketSocket, {connection_setup, Client}).
+
get_sock_opts(PacketSocket, SplitSockOpts) ->
call(PacketSocket, {get_sock_opts, SplitSockOpts}).
get_all_opts(PacketSocket) ->
@@ -145,6 +153,18 @@ handle_call(close, _, #state{dtls_processes = Processes,
end;
handle_call(new_owner, _, State) ->
{reply, ok, State#state{close = false, first = true}};
+handle_call({new_connection, Old, _Pid}, _,
+ #state{accepters = Accepters, dtls_msq_queues = MsgQs0} = State) ->
+ case queue:is_empty(Accepters) of
+ false ->
+ OldQueue = kv_get(Old, MsgQs0),
+ MsgQs1 = kv_delete(Old, MsgQs0),
+ MsgQs = kv_insert({old,Old}, OldQueue, MsgQs1),
+ {reply, true, State#state{dtls_msq_queues = MsgQs}};
+ true ->
+ {reply, false, State}
+ end;
+
handle_call({get_sock_opts, {SocketOptNames, EmOptNames}}, _, #state{listener = Socket,
emulated_options = EmOpts} = State) ->
case get_socket_opts(Socket, SocketOptNames) of
@@ -169,7 +189,16 @@ handle_call({getstat, Options}, _, #state{listener = Socket, transport = {Tran
handle_cast({active_once, Client, Pid}, State0) ->
State = handle_active_once(Client, Pid, State0),
- {noreply, State}.
+ {noreply, State};
+handle_cast({connection_setup, Client}, #state{dtls_msq_queues = MsgQueues} = State) ->
+ case kv_lookup({old, Client}, MsgQueues) of
+ none ->
+ {noreply, State};
+ {value, {Pid, _}} ->
+ Pid ! {socket_reused, Client},
+ %% Will be deleted when handling DOWN message
+ {noreply, State}
+ end.
handle_info({Transport, Socket, IP, InPortNo, _} = Msg, #state{listener = Socket, transport = {_,Transport,_,_,_}} = State0) ->
State = handle_datagram({IP, InPortNo}, Msg, State0),
@@ -189,23 +218,40 @@ handle_info({udp_error, Socket, econnreset = Error}, #state{listener = Socket, t
?LOG_NOTICE(Report),
{noreply, State};
handle_info({ErrorTag, Socket, Error}, #state{listener = Socket, transport = {_,_,_, ErrorTag,_}} = State) ->
- Report = io_lib:format("SSL Packet muliplxer shutdown: Socket error: ~p ~n", [Error]),
+ Report = io_lib:format("SSL Packet muliplexer shutdown: Socket error: ~p ~n", [Error]),
?LOG_NOTICE(Report),
{noreply, State#state{close=true}};
handle_info({'DOWN', _, process, Pid, _},
#state{dtls_processes = Processes0,
dtls_msq_queues = MsgQueues0,
- close = ListenClosed} = State) ->
+ close = ListenClosed} = State0) ->
Client = kv_get(Pid, Processes0),
Processes = kv_delete(Pid, Processes0),
- MsgQueues = kv_delete(Client, MsgQueues0),
+ State = case kv_lookup(Client, MsgQueues0) of
+ none ->
+ MsgQueues1 = kv_delete({old, Client}, MsgQueues0),
+ State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues1};
+ {value, {Pid, _}} ->
+ MsgQueues1 = kv_delete(Client, MsgQueues0),
+ %% Restore old process if exists
+ case kv_lookup({old, Client}, MsgQueues1) of
+ none ->
+ State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues1};
+ {value, Old} ->
+ MsgQueues2 = kv_delete({old, Client}, MsgQueues1),
+ MsgQueues = kv_insert(Client, Old, MsgQueues2),
+ State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues}
+ end;
+ {value, _} -> %% Old process died (just delete its queue)
+ MsgQueues1 = kv_delete({old, Client}, MsgQueues0),
+ State0#state{dtls_processes = Processes, dtls_msq_queues = MsgQueues1}
+ end,
case ListenClosed andalso kv_empty(Processes) of
true ->
{stop, normal, State};
false ->
- {noreply, State#state{dtls_processes = Processes,
- dtls_msq_queues = MsgQueues}}
+ {noreply, State}
end.
terminate(_Reason, _State) ->
@@ -232,33 +278,39 @@ handle_datagram(Client, Msg, #state{dtls_msq_queues = MsgQueues, accepters = Acc
dispatch(Queue, Client, Msg, State)
end.
-dispatch(Queue0, Client, Msg, #state{dtls_msq_queues = MsgQueues} = State) ->
+dispatch({Pid, Queue0}, Client, Msg, #state{dtls_msq_queues = MsgQueues} = State) ->
case queue:out(Queue0) of
{{value, Pid}, Queue} when is_pid(Pid) ->
Pid ! Msg,
State#state{dtls_msq_queues =
- kv_update(Client, Queue, MsgQueues)};
+ kv_update(Client, {Pid, Queue}, MsgQueues)};
{{value, _UDP}, _Queue} ->
State#state{dtls_msq_queues =
- kv_update(Client, queue:in(Msg, Queue0), MsgQueues)};
+ kv_update(Client, {Pid, queue:in(Msg, Queue0)}, MsgQueues)};
{empty, Queue} ->
State#state{dtls_msq_queues =
- kv_update(Client, queue:in(Msg, Queue), MsgQueues)}
+ kv_update(Client, {Pid, queue:in(Msg, Queue)}, MsgQueues)}
end.
next_datagram(Socket, N) ->
inet:setopts(Socket, [{active, N}]).
handle_active_once(Client, Pid, #state{dtls_msq_queues = MsgQueues} = State0) ->
- Queue0 = kv_get(Client, MsgQueues),
+ {Key, Queue0} = case kv_lookup(Client, MsgQueues) of
+ {value, {Pid, Q0}} -> {Client, Q0};
+ _ ->
+ OldKey = {old, Client},
+ {Pid, Q0} = kv_get(OldKey, MsgQueues),
+ {OldKey, Q0}
+ end,
case queue:out(Queue0) of
- {{value, Pid}, _} when is_pid(Pid) ->
- State0;
- {{value, Msg}, Queue} ->
- Pid ! Msg,
- State0#state{dtls_msq_queues = kv_update(Client, Queue, MsgQueues)};
- {empty, Queue0} ->
- State0#state{dtls_msq_queues = kv_update(Client, queue:in(Pid, Queue0), MsgQueues)}
+ {{value, Pid}, _} when is_pid(Pid) ->
+ State0;
+ {{value, Msg}, Queue} ->
+ Pid ! Msg,
+ State0#state{dtls_msq_queues = kv_update(Key, {Pid, Queue}, MsgQueues)};
+ {empty, Queue0} ->
+ State0#state{dtls_msq_queues = kv_update(Key, {Pid, queue:in(Pid, Queue0)}, MsgQueues)}
end.
setup_new_connection(User, From, Client, Msg, #state{dtls_processes = Processes,
@@ -275,7 +327,7 @@ setup_new_connection(User, From, Client, Msg, #state{dtls_processes = Processes,
erlang:monitor(process, Pid),
gen_server:reply(From, {ok, Pid, {Client, Socket}}),
Pid ! Msg,
- State#state{dtls_msq_queues = kv_insert(Client, queue:new(), MsgQueues),
+ State#state{dtls_msq_queues = kv_insert(Client, {Pid, queue:new()}, MsgQueues),
dtls_processes = kv_insert(Pid, Client, Processes)};
{error, Reason} ->
gen_server:reply(From, {error, Reason}),
diff --git a/lib/ssl/test/dtls_api_SUITE.erl b/lib/ssl/test/dtls_api_SUITE.erl
index f6dab82bd0..9602612f5e 100644
--- a/lib/ssl/test/dtls_api_SUITE.erl
+++ b/lib/ssl/test/dtls_api_SUITE.erl
@@ -52,7 +52,8 @@
dtls_listen_two_sockets_5/1,
dtls_listen_two_sockets_6/0,
dtls_listen_two_sockets_6/1,
- client_restarts/0, client_restarts/1
+ client_restarts/0, client_restarts/1,
+ client_restarts_multiple_acceptors/1
]).
-include_lib("ssl/src/ssl_internal.hrl").
@@ -84,7 +85,8 @@ api_tests() ->
dtls_listen_two_sockets_4,
dtls_listen_two_sockets_5,
dtls_listen_two_sockets_6,
- client_restarts
+ client_restarts,
+ client_restarts_multiple_acceptors
].
init_per_suite(Config0) ->
@@ -354,59 +356,159 @@ client_restarts() ->
[{doc, "Test re-connection "}].
client_restarts(Config) ->
- ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config),
+ ClientOpts0 = ssl_test_lib:ssl_options(client_rsa_opts, Config),
ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config),
{ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config),
+ ClientOpts = [{verify, verify_none},{reuse_sessions, save} | ClientOpts0],
Server =
ssl_test_lib:start_server([{node, ServerNode}, {port, 0},
{from, self()},
{mfa, {ssl_test_lib, no_result, []}},
- {options, ServerOpts}]),
+ {options, [{verify, verify_none}|ServerOpts]}]),
Port = ssl_test_lib:inet_port(Server),
Client0 = ssl_test_lib:start_client([{node, ClientNode},
{port, Port}, {host, Hostname},
{mfa, {ssl_test_lib, no_result, []}},
{from, self()},
- {options, [{reuse_sessions, save} | ClientOpts]}]),
+ {options, ClientOpts}]),
+
+ ssl_test_lib:send(Client0, Msg1 = "from client 0"),
+ ssl_test_lib:send(Server, Msg2 = "from server to client 0"),
+
+ Server ! {active_receive, Msg1},
+ Client0 ! {active_receive, Msg2},
+
+ Msgs = lists:sort([{Server, Msg1}, {Client0, Msg2}]),
+ Msgs = lists:sort(flush()),
+
ReConnect = %% Whitebox re-connect test
fun({sslsocket, {gen_udp,_,dtls_gen_connection}, [Pid]} = Socket, ssl) ->
ct:log("~p Client Socket: ~p ~n", [self(), Socket]),
- {ok, {{Address,CPort},UDPSocket}=IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
-
+ {ok, IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
+ {{Address,CPort},UDPSocket}=IntSocket,
ct:log("Info: ~p~n", [inet:info(UDPSocket)]),
{ok, #config{transport_info = CbInfo, connection_cb = ConnectionCb,
- ssl = SslOpts0}} = ssl:handle_options(ClientOpts, client, Address),
+ ssl = SslOpts0}} =
+ ssl:handle_options(ClientOpts, client, Address),
SslOpts = {SslOpts0, #socket_options{}, undefined},
ct:sleep(250),
ct:log("Client second connect: ~p ~p~n", [Socket, CbInfo]),
- Res = ssl_gen_statem:connect(ConnectionCb, Address, CPort, IntSocket, SslOpts, self(), CbInfo, infinity),
- {Res, Pid}
+ {ok, NewSocket} = ssl_gen_statem:connect(ConnectionCb, Address, CPort, IntSocket,
+ SslOpts, self(), CbInfo, infinity),
+ {replace, NewSocket}
end,
Client0 ! {apply, self(), ReConnect},
receive
- {apply_res, {Res, _Prev}} ->
+ {apply_res, {replace, Res}} ->
ct:log("Apply res: ~p~n", [Res]),
ok;
- Msg ->
- ct:log("Unhandled: ~p~n", [Msg]),
- ct:fail({wrong_msg, Msg})
+ ErrMsg ->
+ ct:log("Unhandled: ~p~n", [ErrMsg]),
+ ct:fail({wrong_msg, ErrMsg})
end,
+ ssl_test_lib:send(Client0, Msg1 = "from client 0"),
+ ssl_test_lib:send(Server, Msg2 = "from server to client 0"),
+
+ Server ! {active_receive, Msg1},
+ Client0 ! {active_receive, Msg2},
+
+ Msgs = lists:sort(flush()),
+
+ ssl_test_lib:close(Server),
+ ssl_test_lib:close(Client0),
+ ok.
+
+
+flush() ->
+ receive Msg -> [Msg|flush()]
+ after 500 -> []
+ end.
+
+client_restarts_multiple_acceptors(Config) ->
+ %% Can also be tested with openssl by connecting a client and hit
+ %% Ctrl-C to kill openssl process, so that the connection is not
+ %% closed.
+ %% Then do a new openssl connect with the same client port.
+
+ ClientOpts0 = ssl_test_lib:ssl_options(client_rsa_opts, Config),
+ ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config),
+ {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config),
+
+ ClientOpts = [{verify, verify_none},{reuse_sessions, save} | ClientOpts0],
+ Server =
+ ssl_test_lib:start_server([{node, ServerNode}, {port, 0},
+ {from, self()},
+ {mfa, {ssl_test_lib, no_result, []}},
+ {accepters, 2},
+ {options, [{verify, verify_none}|ServerOpts]}]),
+ Port = ssl_test_lib:inet_port(Server),
+ Client0 = ssl_test_lib:start_client([{node, ClientNode},
+ {port, Port}, {host, Hostname},
+ {mfa, {ssl_test_lib, no_result, []}},
+ {from, self()},
+ {options, ClientOpts}]),
+
+ Server2 = receive {accepter, 2, Server2Pid} -> Server2Pid
+ after 5000 -> ct:fail(msg_timeout)
+ end,
+
+ ssl_test_lib:send(Client0, Msg1 = "from client 0"),
+ ssl_test_lib:send(Server, Msg2 = "from server to client 0"),
+
+ Server ! {active_receive, Msg1},
+ Client0 ! {active_receive, Msg2},
+
+ Msgs = lists:sort([{Server, Msg1}, {Client0, Msg2}]),
+ Msgs = lists:sort(flush()),
+
+ ReConnect = %% Whitebox re-connect test
+ fun({sslsocket, {gen_udp,_,dtls_gen_connection}, [Pid]} = Socket, ssl) ->
+ ct:log("~p Client Socket: ~p ~n", [self(), Socket]),
+ {ok, IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
+ {{Address,CPort},UDPSocket}=IntSocket,
+ ct:log("Info: ~p~n", [inet:info(UDPSocket)]),
+
+ {ok, #config{transport_info = CbInfo, connection_cb = ConnectionCb,
+ ssl = SslOpts0}} =
+ ssl:handle_options(ClientOpts, client, Address),
+ SslOpts = {SslOpts0, #socket_options{}, undefined},
+
+ ct:sleep(250),
+ ct:log("Client second connect: ~p ~p~n", [Socket, CbInfo]),
+ {ok, NewSocket} = ssl_gen_statem:connect(ConnectionCb, Address, CPort, IntSocket,
+ SslOpts, self(), CbInfo, infinity),
+ {replace, NewSocket}
+ end,
+
+ Client0 ! {apply, self(), ReConnect},
receive
- Msg2 ->
- ct:log("Unhandled: ~p~n", [Msg2]),
- ct:fail({wrong_msg, Msg2})
- after 200 ->
- ct:log("Nothing received~n", [])
+ {apply_res, {replace, Res}} ->
+ ct:log("Apply res: ~p~n", [Res]),
+ ok;
+ ErrMsg ->
+ ct:log("Unhandled: ~p~n", [ErrMsg]),
+ ct:fail({wrong_msg, ErrMsg})
end,
+ ok = ssl_test_lib:send(Client0, Msg3 = "from client 2"),
+ ok = ssl_test_lib:send(Server2, Msg4 = "from server 2 to client 2"),
+ {error, closed} = ssl_test_lib:send(Server, "Should be closed"),
+
+ Msgs2 = lists:sort([{Server2, Msg3}, {Client0, Msg4}]),
+
+ Server2 ! {active_receive, Msg3},
+ Client0 ! {active_receive, Msg4},
+
+ Msgs2 = lists:sort(flush()),
+
ssl_test_lib:close(Server),
+ ssl_test_lib:close(Server2),
ssl_test_lib:close(Client0),
-
ok.
%%--------------------------------------------------------------------
diff --git a/lib/ssl/test/ssl_test_lib.erl b/lib/ssl/test/ssl_test_lib.erl
index 5a15283e82..feeedca4ee 100644
--- a/lib/ssl/test/ssl_test_lib.erl
+++ b/lib/ssl/test/ssl_test_lib.erl
@@ -1107,13 +1107,17 @@ client_loop_core(Socket, Pid, Transport) ->
{gen_tcp, closed} ->
ok;
{apply, From, Fun} ->
- try
- Res = Fun(Socket, Transport),
- From ! {apply_res, Res}
+ try Fun(Socket, Transport) of
+ {replace, NewSocket} = Res ->
+ From ! {apply_res, Res},
+ client_loop_core(NewSocket, Pid, Transport);
+ Res ->
+ From ! {apply_res, Res},
+ client_loop_core(Socket, Pid, Transport)
catch E:R:ST ->
- From ! {apply_res, {E,R,ST}}
- end,
- client_loop_core(Socket, Pid, Transport)
+ From ! {apply_res, {E,R,ST}},
+ client_loop_core(Socket, Pid, Transport)
+ end
end.
client_cont_loop(_Node, Host, Port, Pid, Transport, Options, cancel, _Opts) ->
--
2.35.3