File 3832-Rewrite-set_ktls-for-better-error-checking.patch of Package erlang

From c46bb39a02f6e4032199d00867022488cf919903 Mon Sep 17 00:00:00 2001
From: Raimo Niskanen <raimo@erlang.org>
Date: Thu, 11 Aug 2022 20:58:26 +0200
Subject: [PATCH 02/11] Rewrite set_ktls for better error checking

---
 lib/ssl/src/inet_tls_dist.erl   | 249 ++++++++++++++++++++++----------
 lib/ssl/test/ssl_dist_SUITE.erl |  63 ++++----
 2 files changed, 199 insertions(+), 113 deletions(-)

diff --git a/lib/ssl/src/inet_tls_dist.erl b/lib/ssl/src/inet_tls_dist.erl
index 89db0bab11..638eca2339 100644
--- a/lib/ssl/src/inet_tls_dist.erl
+++ b/lib/ssl/src/inet_tls_dist.erl
@@ -330,37 +330,23 @@ accept_one(Driver, Kernel, Socket) ->
           net_kernel:connecttime())
     of
         {ok, #sslsocket{pid = [Receiver, Sender| _]} = SslSocket} ->
-            DistCtrl = case KTLS of
+            case KTLS of
                 true ->
                     {ok, KtlsInfo} = ssl_gen_statem:ktls_handover(Receiver),
-                    set_ktls(KtlsInfo),
-                    Socket;
-                false ->
-                    Sender
-            end,
-            trace(
-              Kernel !
-                  {accept, self(), DistCtrl,
-                   Driver:family(), tls}),
-            receive
-                {Kernel, controller, Pid} ->
-                    ChangeOwner = case KTLS of
-                        true ->
-                            inet_tcp:controlling_process(Socket, Pid);
-                        false ->
-                            ssl:controlling_process(SslSocket, Pid)
-                    end,
-                    case ChangeOwner of
+                    case set_ktls(KtlsInfo) of
                         ok ->
-                            trace(Pid ! {self(), controller});
-                        Error ->
-                            trace(Pid ! {self(), exit}),
+                            accept_one(
+                              Driver, Kernel, Socket, KTLS, Socket);
+                        {error, KtlsReason} ->
                             ?LOG_ERROR(
-                                "Cannot control TLS distribution connection: ~p~n",
-                                [Error])
+                               [{slogan, set_ktls_failed},
+                                {reason, KtlsReason},
+                                {pid, self()}]),
+                            gen_tcp:close(Socket),
+                            trace({ktls_error, KtlsReason})
                     end;
-                {Kernel, unsupported_protocol} ->
-                    trace(unsupported_protocol)
+                false ->
+                    accept_one(Driver, Kernel, SslSocket, KTLS, Sender)
             end;
         {error, {options, _}} = Error ->
             %% Bad options: that's probably our fault.
@@ -374,6 +360,31 @@ accept_one(Driver, Kernel, Socket) ->
             gen_tcp:close(Socket),
             trace(Other)
     end.
+%%
+accept_one(Driver, Kernel, DistSocket, KTLS, DistCtrl) ->
+    trace(Kernel ! {accept, self(), DistCtrl, Driver:family(), tls}),
+    receive
+        {Kernel, controller, Pid} ->
+            case
+                case KTLS of
+                    true ->
+                        inet_tcp:controlling_process(DistSocket, Pid);
+                    false ->
+                        ssl:controlling_process(DistSocket, Pid)
+                end
+            of
+                ok ->
+                    trace(Pid ! {self(), controller});
+                {error, Reason} ->
+                    trace(Pid ! {self(), exit}),
+                    ?LOG_ERROR(
+                       [{slogan, controlling_process_failed},
+                        {reason, Reason},
+                        {pid, self()}])
+            end;
+        {Kernel, unsupported_protocol} ->
+            trace(unsupported_protocol)
+    end.
 
 
 %% {verify_fun,{fun ?MODULE:verify_client/3,_}} is used
@@ -655,27 +666,36 @@ do_setup_connect(Driver, Kernel, Node, Address, Ip, TcpPort, Version, Type, MyNo
         net_kernel:connecttime()
     ) of
         {ok, #sslsocket{pid = [Receiver, Sender| _]} = SslSocket} ->
-            HSData0 = case KTLS of
-                true ->
-                    {ok, KtlsInfo} = ssl_gen_statem:ktls_handover(Receiver),
-                    set_ktls(KtlsInfo),
-                    #{socket := Socket} = KtlsInfo,
-                    hs_data_common(Socket);
-                false ->
-                    _ = monitor_pid(Sender),
-                    ok = ssl:controlling_process(SslSocket, self()),
-                    link(Sender),
-                    hs_data_common(SslSocket)
-            end,
+            DistSocket =
+                case KTLS of
+                    true ->
+                        {ok, KtlsInfo} =
+                            ssl_gen_statem:ktls_handover(Receiver),
+                        case set_ktls(KtlsInfo) of
+                            ok ->
+                                #{socket := Socket} = KtlsInfo,
+                                Socket;
+                            {error, KtlsReason} ->
+                                ?shutdown2(
+                                   Node,
+                                   trace({set_ktls_falied, KtlsReason}))
+                        end;
+                    false ->
+                        _ = monitor_pid(Sender),
+                        ok = ssl:controlling_process(SslSocket, self()),
+                        link(Sender),
+                        SslSocket
+                end,
             HSData =
-                HSData0#hs_data{
-                kernel_pid = Kernel,
-                other_node = Node,
-                this_node = MyNode,
-                timer = Timer,
-                this_flags = 0,
-                other_version = Version,
-                request_type = Type},
+                (hs_data_common(DistSocket))
+                #hs_data{
+                  kernel_pid = Kernel,
+                  other_node = Node,
+                  this_node = MyNode,
+                  timer = Timer,
+                  this_flags = 0,
+                  other_version = Version,
+                  request_type = Type},
             dist_util:handshake_we_started(trace(HSData));
         Other ->
         %% Other Node may have closed since
@@ -972,38 +992,111 @@ verify_fun(Value) ->
 	    error(malformed_ssl_dist_opt, [Value])
     end.
 
-set_ktls(#{
-    socket := Socket,
-    tls_version := {3, 4},
-    cipher_suite := ?TLS_AES_256_GCM_SHA384,
-    socket_options := #socket_options{
-        mode = _Mode,
-        packet = Packet,
-        packet_size = PacketSize,
-        header = Header,
-        active = Active
-    },
-    write_state := #cipher_state{
-        iv = <<WriteSalt:4/bytes, WriteIV:8/bytes>>, key = WriteKey
-    },
-    write_seq := WriteSeq,
-    read_state := #cipher_state{
-        iv = <<ReadSalt:4/bytes, ReadIV:8/bytes>>, key = ReadKey
-    },
-    read_seq := ReadSeq
-}) ->
-    % SOL_TCP = 6, TCP_ULP = 31
-    inet:setopts(Socket, [{raw, 6, 31, <<"tls">>}]),
-    % SOL_TLS = 282, TLS_TX = 1, TLS_RX = 2, TLS_1_3_VERSION = <<4, 3>>, TLS_CIPHER_AES_GCM_256 = <<52, 0>>
-    inet:setopts(Socket, [
-        {raw, 282, 1, <<4, 3, 52, 0, WriteIV/binary, WriteKey/binary, WriteSalt/binary, WriteSeq:64>>}
-    ]),
-    inet:setopts(Socket, [
-        {raw, 282, 2, <<4, 3, 52, 0, ReadIV/binary, ReadKey/binary, ReadSalt/binary, ReadSeq:64>>}
-    ]),
-    inet:setopts(Socket, [
-        list, {packet, Packet}, {packet_size, PacketSize}, {header, Header}, {active, Active}
-    ]).
+set_ktls(#{socket := Socket} = KtlsInfo) ->
+    %% Check OS support
+    case {os:type(), os:version()} of
+        {{unix,linux}, {Major,Minor,_}}
+          when 5 == Major, 2 =< Minor;
+               5 < Major ->
+            set_ktls(KtlsInfo, Socket);
+        OsTypeVersion ->
+            {error, {ktls_invalid_os, OsTypeVersion}}
+    end.
+%%
+%% Check TLS version and cipher suite
+set_ktls(
+  #{tls_version := {3,4}, % 'tlsv1.3'
+    cipher_suite := CipherSuite} = KtlsInfo,
+  Socket)
+  when CipherSuite =:= ?TLS_AES_256_GCM_SHA384 ->
+    %%
+    %% See include/netinet/tcp.h
+    %%
+    SOL_TCP = 6,
+    TCP_ULP = 31,
+    _ = inet:setopts(Socket, [{raw, SOL_TCP, TCP_ULP, <<"tls">>}]),
+    set_ktls(
+      KtlsInfo, Socket,
+      inet:getopts(Socket, [{raw, SOL_TCP, TCP_ULP, 4}]),
+      {raw, SOL_TCP, TCP_ULP, <<"tls",0>>});
+set_ktls(
+  #{tls_version := TLSVersion, cipher_suite := CipherSuite},
+ _Socket) ->
+    {error, {ktls_invalid_cipher, TLSVersion, CipherSuite}}.
+%%
+%% Check if kernel module loaded,
+%% i.e if getopts SOL_TCP,TCP_ULP returned "tls"
+set_ktls(
+  KtlsInfo, Socket,
+  {ok, [ULP]},
+  ULP) ->
+    #{write_state :=
+          #cipher_state{
+             key = <<WriteKey:32/bytes>>,
+             iv = <<WriteSalt:4/bytes, WriteIV:8/bytes>>
+            },
+      write_seq := WriteSeq,
+      read_state :=
+          #cipher_state{
+             key = <<ReadKey:32/bytes>>,
+             iv = <<ReadSalt:4/bytes, ReadIV:8/bytes>>
+            },
+      read_seq := ReadSeq,
+      socket_options := SocketOptions} = KtlsInfo,
+    %%
+    %% See include/linux/tls.h
+    %%
+    TLS_1_3_VERSION_MAJOR = 3,
+    TLS_1_3_VERSION_MINOR = 4,
+    TLS_1_3_VERSION =
+        (TLS_1_3_VERSION_MAJOR bsl 8) bor TLS_1_3_VERSION_MINOR,
+    TLS_CIPHER_AES_GCM_256 = 52,
+    TLS_crypto_info_TX =
+        <<TLS_1_3_VERSION:16/native,
+          TLS_CIPHER_AES_GCM_256:16/native,
+          WriteIV/bytes, WriteKey/bytes,
+          WriteSalt/bytes, WriteSeq:64/native>>,
+    TLS_crypto_info_RX =
+        <<TLS_1_3_VERSION:16/native,
+          TLS_CIPHER_AES_GCM_256:16/native,
+          ReadIV/bytes, ReadKey/bytes,
+          ReadSalt/bytes, ReadSeq:64/native>>,
+    SOL_TLS = 282,
+    TLS_TX = 1,
+    TLS_RX = 2,
+    RawOptTX = {raw, SOL_TLS, TLS_TX, TLS_crypto_info_TX},
+    RawOptRX = {raw, SOL_TLS, TLS_RX, TLS_crypto_info_RX},
+    _ = inet:setopts(Socket, [RawOptTX]),
+    _ = inet:setopts(Socket, [RawOptRX]),
+    %%
+    %% Check if cipher could be set
+    case
+        inet:getopts(
+          Socket, [{raw, SOL_TLS, TLS_TX, byte_size(TLS_crypto_info_TX)}])
+    of
+        {ok, [RawOptTX]} ->
+            #socket_options{
+               mode = _Mode,
+               packet = Packet,
+               packet_size = PacketSize,
+               header = Header,
+               active = Active
+              } = SocketOptions,
+            case
+                inet:setopts(
+                  Socket,
+                  [list, {packet, Packet}, {packet_size, PacketSize},
+                   {header, Header}, {active, Active}])
+            of
+                ok -> ok;
+                {error, SetoptError} ->
+                    {error, {ktls_setopt_failed, SetoptError}}
+            end;
+        Other ->
+            {error, {ktls_set_cipher_failed, Other}}
+    end;
+set_ktls(_KtlsInfo, _Socket, BadGetoptULP, _ULP) ->
+    {error, {ktls_not_supported, BadGetoptULP}}.
 
 %% -------------------------------------------------------------------------
 
diff --git a/lib/ssl/test/ssl_dist_SUITE.erl b/lib/ssl/test/ssl_dist_SUITE.erl
index 767f37d875..b7abc3afad 100644
--- a/lib/ssl/test/ssl_dist_SUITE.erl
+++ b/lib/ssl/test/ssl_dist_SUITE.erl
@@ -158,41 +158,34 @@ init_per_testcase(plain_verify_options = Case, Config) when is_list(Config) ->
     common_init(Case, [{old_flags, Flags} | Config]);
 
 init_per_testcase(ktls_basic = Case, Config) when is_list(Config) ->
-    try
-        {ok, Listen} = gen_tcp:listen(0, [{active, false}]),
-        {ok, Port} = inet:port(Listen),
-        {ok, Client} = gen_tcp:connect("localhost", Port, [{active, false}]),
-        {ok, Server} = gen_tcp:accept(Listen),
-        ServerTx = <<4,3,52,0,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
-                     2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,3,3,3,3,0,0,0,0,0,0,0,0>>,
-        ServerRx = <<4,3,52,0,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,
-                     5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,6,6,6,6,0,0,0,0,0,0,0,0>>,
-        ClientTx = ServerRx,
-        ClientRx = ServerTx,
-        inet:setopts(Server, [{raw, 6, 31, <<"tls">>}]),
-        inet:setopts(Server, [{raw, 282, 1, ServerTx}]),
-        inet:setopts(Server, [{raw, 282, 2, ServerRx}]),
-        inet:setopts(Client, [{raw, 6, 31, <<"tls">>}]),
-        inet:setopts(Client, [{raw, 282, 1, ClientTx}]),
-        inet:setopts(Client, [{raw, 282, 2, ClientRx}]),
-        {ok, [{raw, 6, 31, <<"tls">>}]} = inet:getopts(Server, [{raw, 6, 31, 3}]),
-        {ok, [{raw, 282, 1, ServerTx}]} = inet:getopts(Server, [{raw, 282, 1, 56}]),
-        {ok, [{raw, 6, 31, <<"tls">>}]} = inet:getopts(Client, [{raw, 6, 31, 3}]),
-        {ok, [{raw, 282, 1, ClientTx}]} = inet:getopts(Client, [{raw, 282, 1, 56}]),
-        ok = gen_tcp:send(Client, "client"),
-        {ok, "client"} = gen_tcp:recv(Server, 6, 1000),
-        ok = gen_tcp:send(Server, "server"),
-        {ok, "server"} = gen_tcp:recv(Client, 6, 1000),
-        gen_tcp:close(Server),
-        gen_tcp:close(Client),
-        gen_tcp:close(Listen),
-        common_init(Case, Config)
-    catch
-        Class:Reason:Stacktrace ->
-            {skip, lists:flatten(io_lib:format(
-                "ktls not supported, ~p:~p:~0p",
-                [Class, Reason, Stacktrace]
-            ))}
+    case {os:type(), os:version()} of
+        {{unix,linux}, {Major,Minor,_}}
+          when 5 == Major, 2 =< Minor;
+               5 < Major ->
+            {ok, Listen} = gen_tcp:listen(0, [{active, false}]),
+            {ok, Port} = inet:port(Listen),
+            {ok, Client} =
+                gen_tcp:connect({127,0,0,1}, Port, [{active, false}]),
+            {ok, Server} = gen_tcp:accept(Listen),
+            %%
+            _ = inet:setopts(Server, [{raw, 6, 31, <<"tls">>}]),
+            Result = inet:getopts(Server, [{raw, 6, 31, 4}]),
+            %%
+            _ = gen_tcp:close(Server),
+            _ = gen_tcp:close(Client),
+            _ = gen_tcp:close(Listen),
+            case Result of
+                {ok, [{raw, 6, 31, <<"tls",0>>}]} ->
+                    common_init(Case, Config);
+                Other ->
+                    {skip,
+                     lists:flatten(
+                       io_lib:format("kTLS not supported, ~p", [Other]))}
+            end;
+        OS ->
+            {skip,
+             lists:flatten(
+               io_lib:format("kTLS not supported by OS: ~p", [OS]))}
     end;
 
 init_per_testcase(Case, Config) when is_list(Config) ->
-- 
2.35.3

openSUSE Build Service is sponsored by