File 3831-Introduce-kernel-TLS-dist-module.patch of Package erlang
From 2c4e83212d1c989cabe3dd0839f65a7b5de9eeb8 Mon Sep 17 00:00:00 2001
From: Zeyu Zhang <zeyu@fb.com>
Date: Wed, 15 Jun 2022 16:16:33 -0700
Subject: [PATCH 01/11] Introduce kernel TLS dist module
---
lib/ssl/src/inet_tls_dist.erl | 156 ++++++++++++++++++++++++-----
lib/ssl/src/ssl.erl | 3 +
lib/ssl/src/ssl_gen_statem.erl | 53 +++++++++-
lib/ssl/src/ssl_internal.hrl | 1 +
lib/ssl/src/tls_connection_1_3.erl | 12 ++-
lib/ssl/src/tls_gen_connection.erl | 7 ++
lib/ssl/test/ssl_dist_SUITE.erl | 89 ++++++++++++++++
7 files changed, 291 insertions(+), 30 deletions(-)
diff --git a/lib/ssl/src/inet_tls_dist.erl b/lib/ssl/src/inet_tls_dist.erl
index f0b5da43b3..89db0bab11 100644
--- a/lib/ssl/src/inet_tls_dist.erl
+++ b/lib/ssl/src/inet_tls_dist.erl
@@ -41,6 +41,8 @@
-include_lib("public_key/include/public_key.hrl").
-include("ssl_api.hrl").
+-include("ssl_cipher.hrl").
+-include("ssl_internal.hrl").
-include_lib("kernel/include/logger.hrl").
%% -------------------------------------------------------------------------
@@ -70,8 +72,55 @@ is_node_name(Node) ->
%% -------------------------------------------------------------------------
+hs_data_common(Socket) when is_port(Socket) ->
+ {ok, {Ip, Port}} = inet:peername(Socket),
+ #hs_data{
+ socket = Socket,
+ f_send = fun inet_tcp:send/2,
+ f_recv = fun inet_tcp:recv/3,
+ f_setopts_pre_nodeup =
+ fun(S) ->
+ inet:setopts(
+ S,
+ [
+ {active, false},
+ {packet, 4},
+ nodelay()
+ ]
+ )
+ end,
+ f_setopts_post_nodeup =
+ fun(S) ->
+ inet:setopts(
+ S,
+ [
+ {active, true},
+ {deliver, port},
+ {packet, 4},
+ nodelay()
+ ]
+ )
+ end,
+
+ f_getll = fun inet:getll/1,
+ f_address =
+ fun(_, Node) ->
+ {node, _, Host} = dist_util:split_node(Node),
+ #net_address{
+ address = {Ip, Port},
+ host = Host,
+ protocol = tls,
+ family = inet
+ }
+ end,
+ mf_tick = fun(S) -> inet_tcp_dist:tick(inet_tcp, S) end,
+ mf_getstat = fun inet_tcp_dist:getstat/1,
+ mf_setopts = fun inet_tcp_dist:setopts/2,
+ mf_getopts = fun inet_tcp_dist:getopts/2
+ };
hs_data_common(#sslsocket{pid = [_, DistCtrl|_]} = SslSocket) ->
#hs_data{
+ socket = DistCtrl,
f_send =
fun (_Ctrl, Packet) ->
f_send(SslSocket, Packet)
@@ -272,6 +321,7 @@ spawn_accept({Driver, Listen, Kernel}) ->
accept_one(Driver, Kernel, Socket) ->
Opts = setup_verify_client(Socket, get_ssl_options(server)),
+ KTLS = proplists:get_value(ktls, Opts, false),
wait_for_code_server(),
case
ssl:handshake(
@@ -279,14 +329,28 @@ accept_one(Driver, Kernel, Socket) ->
trace([{active, false},{packet, 4}|Opts]),
net_kernel:connecttime())
of
- {ok, #sslsocket{pid = [_, DistCtrl| _]} = SslSocket} ->
+ {ok, #sslsocket{pid = [Receiver, Sender| _]} = SslSocket} ->
+ DistCtrl = case KTLS of
+ true ->
+ {ok, KtlsInfo} = ssl_gen_statem:ktls_handover(Receiver),
+ set_ktls(KtlsInfo),
+ Socket;
+ false ->
+ Sender
+ end,
trace(
Kernel !
{accept, self(), DistCtrl,
Driver:family(), tls}),
receive
{Kernel, controller, Pid} ->
- case ssl:controlling_process(SslSocket, Pid) of
+ ChangeOwner = case KTLS of
+ true ->
+ inet_tcp:controlling_process(Socket, Pid);
+ false ->
+ ssl:controlling_process(SslSocket, Pid)
+ end,
+ case ChangeOwner of
ok ->
trace(Pid ! {self(), controller});
Error ->
@@ -448,19 +512,22 @@ do_accept(
receive
{AcceptPid, controller} ->
erlang:demonitor(MRef, [flush]),
- {ok, SslSocket} = tls_sender:dist_tls_socket(DistCtrl),
- Timer = dist_util:start_timer(SetupTime),
- NewAllowed = allowed_nodes(SslSocket, Allowed),
- HSData0 = hs_data_common(SslSocket),
+ Timer = dist_util:start_timer(SetupTime),
+ {HSData0, NewAllowed} = case is_port(DistCtrl) of
+ true ->
+ {hs_data_common(DistCtrl), Allowed};
+ false ->
+ {ok, SslSocket} = tls_sender:dist_tls_socket(DistCtrl),
+ link(DistCtrl),
+ {hs_data_common(SslSocket), allowed_nodes(SslSocket, Allowed)}
+ end,
HSData =
HSData0#hs_data{
kernel_pid = Kernel,
this_node = MyNode,
- socket = DistCtrl,
timer = Timer,
this_flags = 0,
allowed = NewAllowed},
- link(DistCtrl),
dist_util:handshake_other_started(trace(HSData));
{AcceptPid, exit} ->
%% this can happen when connection was initiated, but dropped
@@ -579,35 +646,43 @@ do_setup(Driver, Kernel, Node, Type, MyNode, LongOrShortNames, SetupTime) ->
do_setup_connect(Driver, Kernel, Node, Address, Ip, TcpPort, Version, Type, MyNode, Timer) ->
Opts = trace(connect_options(get_ssl_options(client))),
+ KTLS = proplists:get_value(ktls, Opts, false),
dist_util:reset_timer(Timer),
case ssl:connect(
Ip, TcpPort,
[binary, {active, false}, {packet, 4}, {server_name_indication, Address},
Driver:family(), {nodelay, true}] ++ Opts,
- net_kernel:connecttime()) of
- {ok, #sslsocket{pid = [_, DistCtrl| _]} = SslSocket} ->
- _ = monitor_pid(DistCtrl),
- ok = ssl:controlling_process(SslSocket, self()),
- HSData0 = hs_data_common(SslSocket),
- HSData =
+ net_kernel:connecttime()
+ ) of
+ {ok, #sslsocket{pid = [Receiver, Sender| _]} = SslSocket} ->
+ HSData0 = case KTLS of
+ true ->
+ {ok, KtlsInfo} = ssl_gen_statem:ktls_handover(Receiver),
+ set_ktls(KtlsInfo),
+ #{socket := Socket} = KtlsInfo,
+ hs_data_common(Socket);
+ false ->
+ _ = monitor_pid(Sender),
+ ok = ssl:controlling_process(SslSocket, self()),
+ link(Sender),
+ hs_data_common(SslSocket)
+ end,
+ HSData =
HSData0#hs_data{
kernel_pid = Kernel,
other_node = Node,
this_node = MyNode,
- socket = DistCtrl,
timer = Timer,
this_flags = 0,
other_version = Version,
request_type = Type},
- link(DistCtrl),
- dist_util:handshake_we_started(trace(HSData));
- Other ->
- %% Other Node may have closed since
- %% port_please !
- ?shutdown2(
- Node,
- trace(
- {ssl_connect_failed, Ip, TcpPort, Other}))
+ dist_util:handshake_we_started(trace(HSData));
+ Other ->
+ %% Other Node may have closed since
+ %% port_please !
+ ?shutdown2(
+ Node,
+ trace({ssl_connect_failed, Ip, TcpPort, Other}))
end.
close(Socket) ->
@@ -897,6 +972,39 @@ verify_fun(Value) ->
error(malformed_ssl_dist_opt, [Value])
end.
+set_ktls(#{
+ socket := Socket,
+ tls_version := {3, 4},
+ cipher_suite := ?TLS_AES_256_GCM_SHA384,
+ socket_options := #socket_options{
+ mode = _Mode,
+ packet = Packet,
+ packet_size = PacketSize,
+ header = Header,
+ active = Active
+ },
+ write_state := #cipher_state{
+ iv = <<WriteSalt:4/bytes, WriteIV:8/bytes>>, key = WriteKey
+ },
+ write_seq := WriteSeq,
+ read_state := #cipher_state{
+ iv = <<ReadSalt:4/bytes, ReadIV:8/bytes>>, key = ReadKey
+ },
+ read_seq := ReadSeq
+}) ->
+ % SOL_TCP = 6, TCP_ULP = 31
+ inet:setopts(Socket, [{raw, 6, 31, <<"tls">>}]),
+ % SOL_TLS = 282, TLS_TX = 1, TLS_RX = 2, TLS_1_3_VERSION = <<4, 3>>, TLS_CIPHER_AES_GCM_256 = <<52, 0>>
+ inet:setopts(Socket, [
+ {raw, 282, 1, <<4, 3, 52, 0, WriteIV/binary, WriteKey/binary, WriteSalt/binary, WriteSeq:64>>}
+ ]),
+ inet:setopts(Socket, [
+ {raw, 282, 2, <<4, 3, 52, 0, ReadIV/binary, ReadKey/binary, ReadSalt/binary, ReadSeq:64>>}
+ ]),
+ inet:setopts(Socket, [
+ list, {packet, Packet}, {packet_size, PacketSize}, {header, Header}, {active, Active}
+ ]).
+
%% -------------------------------------------------------------------------
%% Trace point
diff --git a/lib/ssl/src/ssl.erl b/lib/ssl/src/ssl.erl
index c16c076afd..665449be8c 100644
--- a/lib/ssl/src/ssl.erl
+++ b/lib/ssl/src/ssl.erl
@@ -2204,6 +2204,9 @@ validate_option(early_data = Option, Value, client) ->
validate_option(erl_dist, Value, _)
when is_boolean(Value) ->
Value;
+validate_option(ktls, Value, _)
+ when is_boolean(Value) ->
+ Value;
validate_option(fail_if_no_peer_cert, Value, _)
when is_boolean(Value) ->
Value;
diff --git a/lib/ssl/src/ssl_gen_statem.erl b/lib/ssl/src/ssl_gen_statem.erl
index 0b4d032f78..d5d83612b2 100644
--- a/lib/ssl/src/ssl_gen_statem.erl
+++ b/lib/ssl/src/ssl_gen_statem.erl
@@ -62,7 +62,8 @@
set_opts/2,
peer_certificate/1,
negotiated_protocol/1,
- connection_information/2
+ connection_information/2,
+ ktls_handover/1
]).
%% Erlang Distribution export
@@ -422,6 +423,14 @@ peer_certificate(ConnectionPid) ->
negotiated_protocol(ConnectionPid) ->
call(ConnectionPid, negotiated_protocol).
+%%--------------------------------------------------------------------
+-spec ktls_handover(pid()) -> {ok, map()} | {error, reason()}.
+%%
+%% Description: Returns the negotiated protocol
+%%--------------------------------------------------------------------
+ktls_handover(ConnectionPid) ->
+ call(ConnectionPid, ktls_handover).
+
dist_handshake_complete(ConnectionPid, DHandle) ->
gen_statem:cast(ConnectionPid, {dist_handshake_complete, DHandle}).
@@ -648,6 +657,45 @@ connection({call, From},
{error, timeout} ->
{stop_and_reply, {shutdown, downgrade_fail}, [{reply, From, {error, timeout}}]}
end;
+connection({call, From}, ktls_handover, #state{
+ static_env = #static_env{
+ transport_cb = Transport,
+ socket = Socket
+ },
+ connection_env = #connection_env{
+ user_application = {_Mon, Pid},
+ negotiated_version = TlsVersion
+ },
+ ssl_options = #{ktls := true},
+ socket_options = SocketOpts,
+ connection_states = #{
+ current_write := #{
+ security_parameters := #security_parameters{cipher_suite = CipherSuite},
+ cipher_state := WriteState,
+ sequence_number := WriteSeq
+ },
+ current_read := #{
+ cipher_state := ReadState,
+ sequence_number := ReadSeq
+ }
+ }
+}) ->
+ Reply = case Transport:controlling_process(Socket, Pid) of
+ ok ->
+ {ok, #{
+ socket => Socket,
+ tls_version => TlsVersion,
+ cipher_suite => CipherSuite,
+ socket_options => SocketOpts,
+ write_state => WriteState,
+ write_seq => WriteSeq,
+ read_state => ReadState,
+ read_seq => ReadSeq
+ }};
+ {error, Reason} ->
+ {error, Reason}
+ end,
+ {stop_and_reply, {shutdown, ktls}, [{reply, From, Reply}]};
connection({call, From}, Msg, State) ->
handle_call(Msg, From, ?FUNCTION_NAME, State);
connection(cast, {dist_handshake_complete, DHandle},
@@ -1129,6 +1177,9 @@ maybe_invalidate_session({false, first}, server = Role, Host, Port, Session) ->
maybe_invalidate_session(_, _, _, _, _) ->
ok.
+terminate({shutdown, ktls}, connection, State) ->
+ %% Socket shall not be closed as it should be returned to user
+ handle_trusted_certs_db(State);
terminate({shutdown, downgrade}, downgrade, State) ->
%% Socket shall not be closed as it should be returned to user
handle_trusted_certs_db(State);
diff --git a/lib/ssl/src/ssl_internal.hrl b/lib/ssl/src/ssl_internal.hrl
index 93d7c2456e..86f55c4601 100644
--- a/lib/ssl/src/ssl_internal.hrl
+++ b/lib/ssl/src/ssl_internal.hrl
@@ -160,6 +160,7 @@
keyfile => {undefined, [versions,
certfile]},
key_update_at => {?KEY_USAGE_LIMIT_AES_GCM, [versions]},
+ ktls => {false, [versions]},
log_level => {notice, [versions]},
max_handshake_size => {?DEFAULT_MAX_HANDSHAKE_SIZE, [versions]},
middlebox_comp_mode => {true, [versions]},
diff --git a/lib/ssl/src/tls_connection_1_3.erl b/lib/ssl/src/tls_connection_1_3.erl
index 90eb9f2474..e59b9693ca 100644
--- a/lib/ssl/src/tls_connection_1_3.erl
+++ b/lib/ssl/src/tls_connection_1_3.erl
@@ -516,8 +516,7 @@ do_client_start(ServerHello, State0) ->
initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trackers}, User,
{CbModule, DataTag, CloseTag, ErrorTag, PassiveTag}) ->
put(log_level, maps:get(log_level, SSLOptions)),
- #{erl_dist := IsErlDist,
- %% Use highest supported version for client/server random nonce generation
+ #{%% Use highest supported version for client/server random nonce generation
versions := [Version|_],
client_renegotiation := ClientRenegotiation} = SSLOptions,
MaxEarlyDataSize = init_max_early_data_size(Role),
@@ -557,12 +556,15 @@ initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trac
start_or_recv_from = undefined,
flight_buffer = [],
protocol_specific = #{sender => Sender,
- active_n => internal_active_n(IsErlDist),
+ active_n => internal_active_n(SSLOptions, Socket),
active_n_toggle => true
}
}.
-internal_active_n(true) ->
+internal_active_n(#{ktls := true}, Socket) ->
+ inet:setopts(Socket, [{packet, ssl_tls}]),
+ 1;
+internal_active_n(#{erl_dist := true}, _) ->
%% Start with a random number between 1 and ?INTERNAL_ACTIVE_N
%% In most cases distribution connections are established all at
%% the same time, and flow control engages with ?INTERNAL_ACTIVE_N for
@@ -571,7 +573,7 @@ internal_active_n(true) ->
%% a random number between 1 and ?INTERNAL_ACTIVE_N helps to spread the
%% spike.
erlang:system_time() rem ?INTERNAL_ACTIVE_N + 1;
-internal_active_n(false) ->
+internal_active_n(#{erl_dist := false}, _) ->
case application:get_env(ssl, internal_active_n) of
{ok, N} when is_integer(N) ->
N;
diff --git a/lib/ssl/src/tls_gen_connection.erl b/lib/ssl/src/tls_gen_connection.erl
index 1442c69927..38d2c36d3f 100644
--- a/lib/ssl/src/tls_gen_connection.erl
+++ b/lib/ssl/src/tls_gen_connection.erl
@@ -325,6 +325,11 @@ handle_info({CloseTag, Socket}, StateName,
%% is called after all data has been deliver.
{next_state, StateName, State#state{protocol_specific = PS#{active_n_toggle => true}}, []}
end;
+handle_info({ssl_tls, Port, Type, {Major, Minor}, Data}, StateName,
+ #state{static_env = #static_env{data_tag = Protocol},
+ ssl_options = #{ktls := true}} = State0) ->
+ Len = size(Data),
+ handle_info({Protocol, Port, <<Type, Major, Minor, Len:16, Data/binary>>}, StateName, State0);
handle_info(Msg, StateName, State) ->
ssl_gen_statem:handle_info(Msg, StateName, State).
@@ -632,6 +637,8 @@ next_record(_, #state{protocol_buffers = #protocol_buffers{tls_cipher_texts = []
next_record(_, State) ->
{no_record, State}.
+flow_ctrl(#state{ssl_options = #{ktls := true}} = State) ->
+ {no_record, State};
%%% bytes_to_read equals the integer Length arg of ssl:recv
%%% the actual value is only relevant for packet = raw | 0
%%% bytes_to_read = undefined means no recv call is ongoing
diff --git a/lib/ssl/test/ssl_dist_SUITE.erl b/lib/ssl/test/ssl_dist_SUITE.erl
index 79e4859b8c..767f37d875 100644
--- a/lib/ssl/test/ssl_dist_SUITE.erl
+++ b/lib/ssl/test/ssl_dist_SUITE.erl
@@ -37,6 +37,8 @@
%% Test cases
-export([basic/0,
basic/1,
+ ktls_basic/0,
+ ktls_basic/1,
monitor_nodes/1,
payload/0,
payload/1,
@@ -68,6 +70,7 @@
%% Apply export
-export([basic_test/3,
+ ktls_basic_test/3,
monitor_nodes_test/3,
payload_test/3,
plain_options_test/3,
@@ -105,6 +108,7 @@ start_ssl_node_name(Name, Args) ->
%%--------------------------------------------------------------------
all() ->
[basic,
+ ktls_basic,
monitor_nodes,
payload,
dist_port_overload,
@@ -153,6 +157,44 @@ init_per_testcase(plain_verify_options = Case, Config) when is_list(Config) ->
end,
common_init(Case, [{old_flags, Flags} | Config]);
+init_per_testcase(ktls_basic = Case, Config) when is_list(Config) ->
+ try
+ {ok, Listen} = gen_tcp:listen(0, [{active, false}]),
+ {ok, Port} = inet:port(Listen),
+ {ok, Client} = gen_tcp:connect("localhost", Port, [{active, false}]),
+ {ok, Server} = gen_tcp:accept(Listen),
+ ServerTx = <<4,3,52,0,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
+ 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,3,3,3,3,0,0,0,0,0,0,0,0>>,
+ ServerRx = <<4,3,52,0,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,
+ 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,6,6,6,6,0,0,0,0,0,0,0,0>>,
+ ClientTx = ServerRx,
+ ClientRx = ServerTx,
+ inet:setopts(Server, [{raw, 6, 31, <<"tls">>}]),
+ inet:setopts(Server, [{raw, 282, 1, ServerTx}]),
+ inet:setopts(Server, [{raw, 282, 2, ServerRx}]),
+ inet:setopts(Client, [{raw, 6, 31, <<"tls">>}]),
+ inet:setopts(Client, [{raw, 282, 1, ClientTx}]),
+ inet:setopts(Client, [{raw, 282, 2, ClientRx}]),
+ {ok, [{raw, 6, 31, <<"tls">>}]} = inet:getopts(Server, [{raw, 6, 31, 3}]),
+ {ok, [{raw, 282, 1, ServerTx}]} = inet:getopts(Server, [{raw, 282, 1, 56}]),
+ {ok, [{raw, 6, 31, <<"tls">>}]} = inet:getopts(Client, [{raw, 6, 31, 3}]),
+ {ok, [{raw, 282, 1, ClientTx}]} = inet:getopts(Client, [{raw, 282, 1, 56}]),
+ ok = gen_tcp:send(Client, "client"),
+ {ok, "client"} = gen_tcp:recv(Server, 6, 1000),
+ ok = gen_tcp:send(Server, "server"),
+ {ok, "server"} = gen_tcp:recv(Client, 6, 1000),
+ gen_tcp:close(Server),
+ gen_tcp:close(Client),
+ gen_tcp:close(Listen),
+ common_init(Case, Config)
+ catch
+ Class:Reason:Stacktrace ->
+ {skip, lists:flatten(io_lib:format(
+ "ktls not supported, ~p:~p:~0p",
+ [Class, Reason, Stacktrace]
+ ))}
+ end;
+
init_per_testcase(Case, Config) when is_list(Config) ->
common_init(Case, Config).
@@ -177,6 +219,12 @@ basic() ->
basic(Config) when is_list(Config) ->
gen_dist_test(basic_test, Config).
+%%--------------------------------------------------------------------
+ktls_basic() ->
+ [{doc,"Test that two nodes can connect via ssl distribution"}].
+ktls_basic(Config) when is_list(Config) ->
+ gen_dist_test(ktls_basic_test, Config).
+
%%--------------------------------------------------------------------
%% Test net_kernel:monitor_nodes with nodedown_reason (OTP-17838)
monitor_nodes(Config) when is_list(Config) ->
@@ -558,6 +606,47 @@ basic_test(NH1, NH2, _) ->
end)
end.
+ktls_basic_test(NH1, NH2, Config) ->
+ PrivDir = proplists:get_value(priv_dir, Config),
+ SslOpts = [
+ {
+ server,
+ [
+ {certfile, filename:join([PrivDir, "rsa_server_cert.pem"])},
+ {keyfile, filename:join([PrivDir, "rsa_server_key.pem"])},
+ {cacertfile, filename:join([PrivDir, "rsa_server_cacerts.pem"])},
+ {verify, verify_peer},
+ {fail_if_no_peer_cert, true},
+ {versions, ['tlsv1.3']},
+ {ciphers, [#{cipher => aes_256_gcm, key_exchange => any, mac => aead, prf => sha384}]},
+ {ktls, true}
+ ]
+ },
+ {
+ client,
+ [
+ {certfile, filename:join([PrivDir, "rsa_client_cert.pem"])},
+ {keyfile, filename:join([PrivDir, "rsa_client_key.pem"])},
+ {cacertfile, filename:join([PrivDir, "rsa_client_cacerts.pem"])},
+ {verify, verify_peer},
+ {customize_hostname_check, [{match_fun, fun(_, _) -> true end}]},
+ {versions, ['tlsv1.3']},
+ {ciphers, [#{cipher => aes_256_gcm, key_exchange => any, mac => aead, prf => sha384}]},
+ {ktls, true}
+ ]
+ }
+ ],
+ SetEtsOpts = fun () ->
+ spawn(fun () ->
+ ets:new(ssl_dist_opts, [named_table, public]),
+ ets:insert(ssl_dist_opts, SslOpts),
+ timer:sleep(infinity)
+ end)
+ end,
+ apply_on_ssl_node(NH1, SetEtsOpts),
+ apply_on_ssl_node(NH2, SetEtsOpts),
+ basic_test(NH1, NH2, Config).
+
monitor_nodes_test(NH1, NH2, _) ->
Node2 = NH2#node_handle.nodename,
--
2.35.3