File 4595-Allow-re-connect-on-dtls-sockets.patch of Package erlang
From 8e592d2167ed4e99b22109f9081ba625843b2057 Mon Sep 17 00:00:00 2001
From: Dan Gudmundsson <dgud@erlang.org>
Date: Thu, 15 Jul 2021 13:36:38 +0200
Subject: [PATCH 05/10] Allow re-connect on dtls sockets
Can happen when a computer reboots and connects from the same client
port without the server noticing, should be allowed according to RFC.
---
lib/ssl/src/dtls_connection.erl | 18 ++++++++
lib/ssl/src/dtls_gen_connection.erl | 67 ++++++++++++++++------------
lib/ssl/src/ssl_gen_statem.erl | 2 +
lib/ssl/test/dtls_api_SUITE.erl | 68 +++++++++++++++++++++++++++--
lib/ssl/test/ssl_test_lib.erl | 10 ++++-
5 files changed, 134 insertions(+), 31 deletions(-)
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index ae24dc31cc..a27b9a6213 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -541,12 +541,30 @@ connection(internal, #client_hello{}, #state{static_env = #static_env{role = ser
State1 = dtls_gen_connection:send_alert(Alert, State0),
{Record, State} = ssl_gen_statem:prepare_connection(State1, Connection),
dtls_gen_connection:next_event(?FUNCTION_NAME, Record, State);
+connection(internal, new_connection, #state{ssl_options=SSLOptions,
+ handshake_env=HsEnv,
+ connection_states = OldCs} = State) ->
+ #{beast_mitigation := BeastMitigation} = SSLOptions,
+ ConnectionStates0 = dtls_record:init_connection_states(server, BeastMitigation),
+ #{current_write:=CW, current_read:=CR} = OldCs,
+ ConnectionStates = ConnectionStates0#{saved_write:=CW, saved_read:=CR},
+ {next_state, hello, State#state{handshake_env = HsEnv#handshake_env{renegotiation = {false, first}},
+ connection_states = ConnectionStates}};
connection({call, From}, {application_data, Data}, State) ->
try
send_application_data(Data, From, ?FUNCTION_NAME, State)
catch throw:Error ->
ssl_gen_statem:hibernate_after(?FUNCTION_NAME, State, [{reply, From, Error}])
end;
+connection({call, From}, {downgrade, Pid},
+ #state{connection_env = CEnv,
+ static_env = #static_env{transport_cb = Transport,
+ socket = {_Server, Socket} = DTLSSocket}} = State) ->
+ %% For testing purposes, downgrades without noticing the server
+ dtls_socket:setopts(Transport, Socket, [{active, false}, {packet, 0}, {mode, binary}]),
+ Transport:controlling_process(Socket, Pid),
+ {stop_and_reply, {shutdown, normal}, {reply, From, {ok, DTLSSocket}},
+ State#state{connection_env = CEnv#connection_env{socket_terminated = true}}};
connection(Type, Event, State) ->
try
tls_dtls_connection:?FUNCTION_NAME(Type, Event, State)
diff --git a/lib/ssl/src/dtls_gen_connection.erl b/lib/ssl/src/dtls_gen_connection.erl
index 68f6ff0ec8..ebc9766645 100644
--- a/lib/ssl/src/dtls_gen_connection.erl
+++ b/lib/ssl/src/dtls_gen_connection.erl
@@ -105,30 +105,29 @@ next_record(#state{protocol_buffers =
CurrentRead = dtls_record:get_connection_state_by_epoch(Epoch, ConnectionStates, read),
case dtls_record:replay_detect(CT, CurrentRead) of
false ->
- decode_cipher_text(State#state{connection_states = ConnectionStates}) ;
+ decode_cipher_text(State) ;
true ->
%% Ignore replayed record
- next_record(State#state{protocol_buffers =
- Buffers#protocol_buffers{dtls_cipher_texts = Rest},
- connection_states = ConnectionStates})
+ next_record(State#state{protocol_buffers = Buffers#protocol_buffers{dtls_cipher_texts = Rest}})
end;
next_record(#state{protocol_buffers =
#protocol_buffers{dtls_cipher_texts = [#ssl_tls{epoch = Epoch} | Rest]}
= Buffers,
- connection_states = #{current_read := #{epoch := CurrentEpoch}} = ConnectionStates} = State)
+ connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State)
when Epoch > CurrentEpoch ->
%% TODO Buffer later Epoch message, drop it for now
- next_record(State#state{protocol_buffers =
- Buffers#protocol_buffers{dtls_cipher_texts = Rest},
- connection_states = ConnectionStates});
-next_record(#state{protocol_buffers =
- #protocol_buffers{dtls_cipher_texts = [ _ | Rest]}
- = Buffers,
- connection_states = ConnectionStates} = State) ->
- %% Drop old epoch message
- next_record(State#state{protocol_buffers =
- Buffers#protocol_buffers{dtls_cipher_texts = Rest},
- connection_states = ConnectionStates});
+ next_record(State#state{protocol_buffers = Buffers#protocol_buffers{dtls_cipher_texts = Rest}});
+next_record(#state{protocol_buffers = #protocol_buffers{dtls_cipher_texts =
+ [#ssl_tls{epoch = Epoch} | Rest]
+ } = Buffers
+ } = State) ->
+ case Epoch of
+ 0 -> %% A reconnect (client might have rebooted and re-connected)
+ decode_cipher_text(State);
+ _ ->
+ %% Drop old epoch message
+ next_record(State#state{protocol_buffers = Buffers#protocol_buffers{dtls_cipher_texts = Rest}})
+ end;
next_record(#state{static_env = #static_env{role = server,
socket = {Listener, {Client, _}}}} = State) ->
dtls_packet_demux:active_once(Listener, Client, self()),
@@ -187,10 +186,11 @@ next_event(StateName, no_record,
next_event(connection = StateName, Record,
#state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
case Record of
- #ssl_tls{epoch = CurrentEpoch,
+ #ssl_tls{epoch = Epoch,
type = ?HANDSHAKE,
- version = Version} = Record ->
- State = dtls_version(StateName, Version, State0),
+ version = Version} = Record
+ when Epoch =:= CurrentEpoch; Epoch =:= 0 ->
+ State = dtls_version(StateName, Version, State0),
{next_state, StateName, State,
[{next_event, internal, {protocol_record, Record}} | Actions]};
#ssl_tls{epoch = CurrentEpoch} ->
@@ -330,9 +330,8 @@ handle_protocol_record(#ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, Stat
ssl_gen_statem:hibernate_after(StateName, State, Actions)
end;
%%% DTLS record protocol level handshake messages
-handle_protocol_record(#ssl_tls{type = ?HANDSHAKE,
- fragment = Data},
- StateName,
+handle_protocol_record(#ssl_tls{type = ?HANDSHAKE, epoch = Epoch, fragment = Data},
+ StateName,
#state{protocol_buffers = Buffers0,
connection_env = #connection_env{negotiated_version = Version},
ssl_options = Options} = State) ->
@@ -342,12 +341,17 @@ handle_protocol_record(#ssl_tls{type = ?HANDSHAKE,
next_event(StateName, no_record, State#state{protocol_buffers = Buffers});
{Packets, Buffers} ->
HsEnv = State#state.handshake_env,
- Events = dtls_handshake_events(Packets),
- {next_state, StateName,
+ HSEvents = dtls_handshake_events(Packets),
+ Events = case is_new_connection(Epoch, Packets, State) of
+ true -> [{next_event, internal, new_connection} | HSEvents];
+ false -> HSEvents
+ end,
+ {next_state, StateName,
State#state{protocol_buffers = Buffers,
- handshake_env =
- HsEnv#handshake_env{unprocessed_handshake_events
- = unprocessed_events(Events)}}, Events}
+ handshake_env =
+ HsEnv#handshake_env{
+ unprocessed_handshake_events = unprocessed_events(HSEvents)}
+ }, Events}
end
catch throw:#alert{} = Alert ->
handle_own_alert(Alert, StateName, State)
@@ -550,6 +554,15 @@ handle_info(Msg, StateName, State) ->
%% Internal functions
%%====================================================================
+is_new_connection(0, [{#client_hello{},_Raw}|_],
+ #state{
+ connection_states =
+ #{current_read := #{epoch := CurrentEpoch}}})
+ when CurrentEpoch > 0 ->
+ true;
+is_new_connection(_, _, _) ->
+ false.
+
dtls_handshake_events(Packets) ->
lists:map(fun(Packet) ->
{next_event, internal, {handshake, Packet}}
diff --git a/lib/ssl/src/ssl_gen_statem.erl b/lib/ssl/src/ssl_gen_statem.erl
index 1370036013..46926d15ef 100644
--- a/lib/ssl/src/ssl_gen_statem.erl
+++ b/lib/ssl/src/ssl_gen_statem.erl
@@ -720,6 +720,8 @@ handle_common_event({timeout, recv}, timeout, StateName, #state{start_or_recv_fr
handle_common_event(internal, {recv, RecvFrom}, StateName, #state{start_or_recv_from = RecvFrom}) when
StateName =/= connection ->
{keep_state_and_data, [postpone]};
+handle_common_event(internal, new_connection, StateName, State) ->
+ {next_state, StateName, State};
handle_common_event(Type, Msg, StateName, State) ->
Alert = ?ALERT_REC(?FATAL,?UNEXPECTED_MESSAGE, {unexpected_msg, {Type, Msg}}),
handle_own_alert(Alert, StateName, State).
diff --git a/lib/ssl/test/dtls_api_SUITE.erl b/lib/ssl/test/dtls_api_SUITE.erl
index 572702af02..4117bec12e 100644
--- a/lib/ssl/test/dtls_api_SUITE.erl
+++ b/lib/ssl/test/dtls_api_SUITE.erl
@@ -51,9 +51,12 @@
dtls_listen_two_sockets_5/0,
dtls_listen_two_sockets_5/1,
dtls_listen_two_sockets_6/0,
- dtls_listen_two_sockets_6/1
+ dtls_listen_two_sockets_6/1,
+ client_restarts/0, client_restarts/1
]).
+-include_lib("ssl/src/ssl_internal.hrl").
+
%%--------------------------------------------------------------------
%% Common Test interface functions -----------------------------------
%%--------------------------------------------------------------------
@@ -80,7 +83,8 @@ api_tests() ->
dtls_listen_two_sockets_3,
dtls_listen_two_sockets_4,
dtls_listen_two_sockets_5,
- dtls_listen_two_sockets_6
+ dtls_listen_two_sockets_6,
+ client_restarts
].
init_per_suite(Config0) ->
@@ -300,7 +304,6 @@ dtls_listen_two_sockets_6(_Config) when is_list(_Config) ->
ssl:close(S1),
ok.
-
replay_window() ->
[{doc, "Whitebox test of replay window"}].
replay_window(_Config) ->
@@ -347,6 +350,65 @@ bits_to_list(Bits, I, Acc) ->
0 -> bits_to_list(Bits bsr 1, I+1, Acc)
end.
+client_restarts() ->
+ [{doc, "Test re-connection "}].
+
+client_restarts(Config) ->
+ ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config),
+ ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config),
+ {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config),
+
+ Server =
+ ssl_test_lib:start_server([{node, ServerNode}, {port, 0},
+ {from, self()},
+ {mfa, {ssl_test_lib, no_result, []}},
+ {options, ServerOpts}]),
+ Port = ssl_test_lib:inet_port(Server),
+ Client0 = ssl_test_lib:start_client([{node, ClientNode},
+ {port, Port}, {host, Hostname},
+ {mfa, {ssl_test_lib, no_result, []}},
+ {from, self()},
+ {options, [{reuse_sessions, save} | ClientOpts]}]),
+ ReConnect = %% Whitebox re-connect test
+ fun({sslsocket, {gen_udp,_,dtls_gen_connection}, [Pid]} = Socket, ssl) ->
+ ct:log("~p Client Socket: ~p ~n", [self(), Socket]),
+ {ok, {{Adress,CPort},UDPSocket}=IntSocket} = gen_statem:call(Pid, {downgrade, self()}),
+ true = is_port(UDPSocket),
+ ct:log("Info: ~p~n", [inet:info(UDPSocket)]),
+
+ {ok, #config{transport_info = CbInfo, connection_cb = ConnectionCb,
+ ssl = SslOpts0}} = ssl:handle_options(ClientOpts, client, Adress),
+ SslOpts = {SslOpts0, #socket_options{}, undefined},
+
+ ct:sleep(250),
+ ct:log("Client second connect: ~p ~p~n", [Socket, CbInfo]),
+ Res = ssl_gen_statem:connect(ConnectionCb, Adress, CPort, IntSocket, SslOpts, self(), CbInfo, infinity),
+ {Res, Pid}
+ end,
+
+ Client0 ! {apply, self(), ReConnect},
+ receive
+ {apply_res, {Res, _Prev}} ->
+ ct:log("Apply res: ~p~n", [Res]),
+ ok;
+ Msg ->
+ ct:log("Unhandled: ~p~n", [Msg]),
+ ct:fail({wrong_msg, Msg})
+ end,
+
+ receive
+ Msg2 ->
+ ct:log("Unhandled: ~p~n", [Msg2]),
+ ct:fail({wrong_msg, Msg2})
+ after 200 ->
+ ct:log("Nothing received~n", [])
+ end,
+
+ ssl_test_lib:close(Server),
+ ssl_test_lib:close(Client0),
+
+ ok.
+
%%--------------------------------------------------------------------
%% Internal functions ------------------------------------------------
%%--------------------------------------------------------------------
diff --git a/lib/ssl/test/ssl_test_lib.erl b/lib/ssl/test/ssl_test_lib.erl
index e4c23c22cf..313eec3e22 100644
--- a/lib/ssl/test/ssl_test_lib.erl
+++ b/lib/ssl/test/ssl_test_lib.erl
@@ -1041,7 +1041,15 @@ client_loop_core(Socket, Pid, Transport) ->
{ssl_closed, Socket} ->
ok;
{gen_tcp, closed} ->
- ok
+ ok;
+ {apply, From, Fun} ->
+ try
+ Res = Fun(Socket, Transport),
+ From ! {apply_res, Res}
+ catch E:R:ST ->
+ From ! {apply_res, {E,R,ST}}
+ end,
+ client_loop_core(Socket, Pid, Transport)
end.
client_cont_loop(_Node, Host, Port, Pid, Transport, Options, cancel, _Opts) ->
--
2.31.1