File 1291-ssh-KEX-strict-implementation-fixes.patch of Package erlang

From e4b56a9f4a511aa9990dd86c16c61439c828df83 Mon Sep 17 00:00:00 2001
From: Jakub Witczak <kuba@erlang.org>
Date: Tue, 6 May 2025 17:01:29 +0200
Subject: [PATCH 1/2] ssh: KEX strict implementation fixes

- fixed KEX strict implementation
- draft-miller-sshm-strict-kex-01.txt
- ssh_dbg added to ssh_fsm_kexinit module
- CVE-2025-46712
---
 lib/ssh/src/ssh_connection_handler.erl |  24 ++--
 lib/ssh/src/ssh_fsm_kexinit.erl        | 129 ++++++++++++++++++--
 lib/ssh/src/ssh_transport.erl          |  13 +-
 lib/ssh/test/ssh_protocol_SUITE.erl    | 158 ++++++++++++++++++++++---
 lib/ssh/test/ssh_trpt_test_lib.erl     |  39 +++++-
 5 files changed, 313 insertions(+), 50 deletions(-)

diff --git a/lib/ssh/src/ssh_connection_handler.erl b/lib/ssh/src/ssh_connection_handler.erl
index 5ddafa9975..15f98dfb5c 100644
--- a/lib/ssh/src/ssh_connection_handler.erl
+++ b/lib/ssh/src/ssh_connection_handler.erl
@@ -34,7 +34,6 @@
 -include("ssh_transport.hrl").
 -include("ssh_auth.hrl").
 -include("ssh_connect.hrl").
-
 -include("ssh_fsm.hrl").
 
 %%====================================================================
@@ -728,16 +727,6 @@ handle_event(internal, #ssh_msg_disconnect{description=Desc} = Msg, StateName, D
     disconnect_fun("Received disconnect: "++Desc, D),
     {stop_and_reply, {shutdown,Desc}, Actions, D};
 
-handle_event(internal, #ssh_msg_ignore{}, {_StateName, _Role, init},
-             #data{ssh_params = #ssh{kex_strict_negotiated = true,
-                                     send_sequence = SendSeq,
-                                     recv_sequence = RecvSeq}}) ->
-    ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
-                io_lib:format("strict KEX violation: unexpected SSH_MSG_IGNORE "
-                              "send_sequence = ~p  recv_sequence = ~p",
-                              [SendSeq, RecvSeq])
-               );
-
 handle_event(internal, #ssh_msg_ignore{}, _StateName, _) ->
     keep_state_and_data;
 
@@ -1141,11 +1130,14 @@ handle_event(info, {Proto, Sock, NewData}, StateName,
     of
 	{packet_decrypted, DecryptedBytes, EncryptedDataRest, Ssh1} ->
 	    D1 = D0#data{ssh_params =
-			    Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)},
-			decrypted_data_buffer = <<>>,
-                        undecrypted_packet_length = undefined,
-                        aead_data = <<>>,
-			encrypted_data_buffer = EncryptedDataRest},
+                             Ssh1#ssh{recv_sequence =
+                                          ssh_transport:next_seqnum(StateName,
+                                                                    Ssh1#ssh.recv_sequence,
+                                                                    SshParams)},
+                         decrypted_data_buffer = <<>>,
+                         undecrypted_packet_length = undefined,
+                         aead_data = <<>>,
+                         encrypted_data_buffer = EncryptedDataRest},
 	    try
 		ssh_message:decode(set_kex_overload_prefix(DecryptedBytes,D1))
 	    of
diff --git a/lib/ssh/src/ssh_fsm_kexinit.erl b/lib/ssh/src/ssh_fsm_kexinit.erl
index 05f7bdf22f..b8fdc29079 100644
--- a/lib/ssh/src/ssh_fsm_kexinit.erl
+++ b/lib/ssh/src/ssh_fsm_kexinit.erl
@@ -43,6 +43,11 @@
 -export([callback_mode/0, handle_event/4, terminate/3,
 	 format_status/2, code_change/4]).
 
+-behaviour(ssh_dbg).
+-export([ssh_dbg_trace_points/0, ssh_dbg_flags/1,
+         ssh_dbg_on/1, ssh_dbg_off/1,
+         ssh_dbg_format/2]).
+
 %%====================================================================
 %% gen_statem callbacks
 %%====================================================================
@@ -53,8 +58,13 @@ callback_mode() ->
 
 %%--------------------------------------------------------------------
 
-%%% ######## {kexinit, client|server, init|renegotiate} ####
 
+handle_event(Type, Event = prepare_next_packet, StateName, D) ->
+    ssh_connection_handler:handle_event(Type, Event, StateName, D);
+handle_event(Type, Event = {send_disconnect, _, _, _, _}, StateName, D) ->
+    ssh_connection_handler:handle_event(Type, Event, StateName, D);
+
+%%% ######## {kexinit, client|server, init|renegotiate} ####
 handle_event(internal, {#ssh_msg_kexinit{}=Kex, Payload}, {kexinit,Role,ReNeg},
 	     D = #data{key_exchange_init_msg = OwnKex}) ->
     Ssh1 = ssh_transport:key_init(peer_role(Role), D#data.ssh_params, Payload),
@@ -67,11 +77,10 @@ handle_event(internal, {#ssh_msg_kexinit{}=Kex, Payload}, {kexinit,Role,ReNeg},
 	  end,
     {next_state, {key_exchange,Role,ReNeg}, D#data{ssh_params=Ssh}};
 
-
 %%% ######## {key_exchange, client|server, init|renegotiate} ####
-
 %%%---- diffie-hellman
 handle_event(internal, #ssh_msg_kexdh_init{} = Msg, {key_exchange,server,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, KexdhReply, Ssh1} = ssh_transport:handle_kexdh_init(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(KexdhReply, D),
     {ok, NewKeys, Ssh2} = ssh_transport:new_keys_message(Ssh1),
@@ -81,6 +90,7 @@ handle_event(internal, #ssh_msg_kexdh_init{} = Msg, {key_exchange,server,ReNeg},
     {next_state, {new_keys,server,ReNeg}, D#data{ssh_params=Ssh}};
 
 handle_event(internal, #ssh_msg_kexdh_reply{} = Msg, {key_exchange,client,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, NewKeys, Ssh1} = ssh_transport:handle_kexdh_reply(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(NewKeys, D),
     {ok, ExtInfo, Ssh} = ssh_transport:ext_info_message(Ssh1),
@@ -89,24 +99,28 @@ handle_event(internal, #ssh_msg_kexdh_reply{} = Msg, {key_exchange,client,ReNeg}
 
 %%%---- diffie-hellman group exchange
 handle_event(internal, #ssh_msg_kex_dh_gex_request{} = Msg, {key_exchange,server,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, GexGroup, Ssh1} = ssh_transport:handle_kex_dh_gex_request(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(GexGroup, D),
     Ssh = ssh_transport:parallell_gen_key(Ssh1),
     {next_state, {key_exchange_dh_gex_init,server,ReNeg}, D#data{ssh_params=Ssh}};
 
 handle_event(internal, #ssh_msg_kex_dh_gex_request_old{} = Msg, {key_exchange,server,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, GexGroup, Ssh1} = ssh_transport:handle_kex_dh_gex_request(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(GexGroup, D),
     Ssh = ssh_transport:parallell_gen_key(Ssh1),
     {next_state, {key_exchange_dh_gex_init,server,ReNeg}, D#data{ssh_params=Ssh}};
 
 handle_event(internal, #ssh_msg_kex_dh_gex_group{} = Msg, {key_exchange,client,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, KexGexInit, Ssh} = ssh_transport:handle_kex_dh_gex_group(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(KexGexInit, D),
     {next_state, {key_exchange_dh_gex_reply,client,ReNeg}, D#data{ssh_params=Ssh}};
 
 %%%---- elliptic curve diffie-hellman
 handle_event(internal, #ssh_msg_kex_ecdh_init{} = Msg, {key_exchange,server,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, KexEcdhReply, Ssh1} = ssh_transport:handle_kex_ecdh_init(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(KexEcdhReply, D),
     {ok, NewKeys, Ssh2} = ssh_transport:new_keys_message(Ssh1),
@@ -116,16 +130,25 @@ handle_event(internal, #ssh_msg_kex_ecdh_init{} = Msg, {key_exchange,server,ReNe
     {next_state, {new_keys,server,ReNeg}, D#data{ssh_params=Ssh}};
 
 handle_event(internal, #ssh_msg_kex_ecdh_reply{} = Msg, {key_exchange,client,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, NewKeys, Ssh1} = ssh_transport:handle_kex_ecdh_reply(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(NewKeys, D),
     {ok, ExtInfo, Ssh} = ssh_transport:ext_info_message(Ssh1),
     ssh_connection_handler:send_bytes(ExtInfo, D),
     {next_state, {new_keys,client,ReNeg}, D#data{ssh_params=Ssh}};
 
+%%% ######## handle KEX strict
+handle_event(internal, _Event, {key_exchange,_Role,init},
+             #data{ssh_params = #ssh{algorithms = #alg{kex_strict_negotiated = true},
+                                     send_sequence = SendSeq,
+                                     recv_sequence = RecvSeq}}) ->
+    ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
+                io_lib:format("KEX strict violation: send_sequence = ~p  recv_sequence = ~p",
+                              [SendSeq, RecvSeq]));
 
 %%% ######## {key_exchange_dh_gex_init, server, init|renegotiate} ####
-
 handle_event(internal, #ssh_msg_kex_dh_gex_init{} = Msg, {key_exchange_dh_gex_init,server,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, KexGexReply, Ssh1} =  ssh_transport:handle_kex_dh_gex_init(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(KexGexReply, D),
     {ok, NewKeys, Ssh2} = ssh_transport:new_keys_message(Ssh1),
@@ -133,20 +156,33 @@ handle_event(internal, #ssh_msg_kex_dh_gex_init{} = Msg, {key_exchange_dh_gex_in
     {ok, ExtInfo, Ssh} = ssh_transport:ext_info_message(Ssh2),
     ssh_connection_handler:send_bytes(ExtInfo, D),
     {next_state, {new_keys,server,ReNeg}, D#data{ssh_params=Ssh}};
-
+%%% ######## handle KEX strict
+handle_event(internal, _Event, {key_exchange_dh_gex_init,_Role,init},
+             #data{ssh_params = #ssh{algorithms = #alg{kex_strict_negotiated = true},
+                                     send_sequence = SendSeq,
+                                     recv_sequence = RecvSeq}}) ->
+    ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
+                io_lib:format("KEX strict violation: send_sequence = ~p  recv_sequence = ~p",
+                              [SendSeq, RecvSeq]));
 
 %%% ######## {key_exchange_dh_gex_reply, client, init|renegotiate} ####
-
 handle_event(internal, #ssh_msg_kex_dh_gex_reply{} = Msg, {key_exchange_dh_gex_reply,client,ReNeg}, D) ->
+    ok = check_kex_strict(Msg, D),
     {ok, NewKeys, Ssh1} = ssh_transport:handle_kex_dh_gex_reply(Msg, D#data.ssh_params),
     ssh_connection_handler:send_bytes(NewKeys, D),
     {ok, ExtInfo, Ssh} = ssh_transport:ext_info_message(Ssh1),
     ssh_connection_handler:send_bytes(ExtInfo, D),
     {next_state, {new_keys,client,ReNeg}, D#data{ssh_params=Ssh}};
-
+%%% ######## handle KEX strict
+handle_event(internal, _Event, {key_exchange_dh_gex_reply,_Role,init},
+             #data{ssh_params = #ssh{algorithms = #alg{kex_strict_negotiated = true},
+                                     send_sequence = SendSeq,
+                                     recv_sequence = RecvSeq}}) ->
+    ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
+                io_lib:format("KEX strict violation: send_sequence = ~p  recv_sequence = ~p",
+                              [SendSeq, RecvSeq]));
 
 %%% ######## {new_keys, client|server} ####
-
 %% First key exchange round:
 handle_event(internal, #ssh_msg_newkeys{} = Msg, {new_keys,client,init}, D0) ->
     {ok, Ssh1} = ssh_transport:handle_new_keys(Msg, D0#data.ssh_params),
@@ -162,6 +198,15 @@ handle_event(internal, #ssh_msg_newkeys{} = Msg, {new_keys,server,init}, D) ->
     %% ssh_connection_handler:send_bytes(ExtInfo, D),
     {next_state, {ext_info,server,init}, D#data{ssh_params=Ssh}};
 
+%%% ######## handle KEX strict
+handle_event(internal, _Event, {new_keys,_Role,init},
+             #data{ssh_params = #ssh{algorithms = #alg{kex_strict_negotiated = true},
+                                     send_sequence = SendSeq,
+                                     recv_sequence = RecvSeq}}) ->
+    ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
+                io_lib:format("KEX strict violation (send_sequence = ~p recv_sequence = ~p)",
+                              [SendSeq, RecvSeq]));
+
 %% Subsequent key exchange rounds (renegotiation):
 handle_event(internal, #ssh_msg_newkeys{} = Msg, {new_keys,Role,renegotiate}, D) ->
     {ok, Ssh} = ssh_transport:handle_new_keys(Msg, D#data.ssh_params),
@@ -183,7 +228,6 @@ handle_event(internal, #ssh_msg_ext_info{}=Msg, {ext_info,Role,renegotiate}, D0)
 handle_event(internal, #ssh_msg_newkeys{}=Msg, {ext_info,_Role,renegotiate}, D) ->
     {ok, Ssh} = ssh_transport:handle_new_keys(Msg, D#data.ssh_params),
     {keep_state, D#data{ssh_params = Ssh}};
-    
 
 handle_event(internal, Msg, {ext_info,Role,init}, D) when is_tuple(Msg) ->
     %% If something else arrives, goto next state and handle the event in that one
@@ -217,3 +261,70 @@ code_change(_OldVsn, StateName, State, _Extra) ->
 peer_role(client) -> server;
 peer_role(server) -> client.
 
+check_kex_strict(Msg,
+                 #data{ssh_params =
+                           #ssh{algorithms =
+                                    #alg{
+                                       kex = Kex,
+                                       kex_strict_negotiated = KexStrictNegotiated},
+                                send_sequence = SendSeq,
+                                recv_sequence = RecvSeq}}) ->
+    case check_msg_group(Msg, get_alg_group(Kex), KexStrictNegotiated) of
+        ok ->
+            ok;
+        error ->
+            ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
+                        io_lib:format("KEX strict violation: send_sequence = ~p  recv_sequence = ~p",
+                                      [SendSeq, RecvSeq]))
+    end.
+
+get_alg_group(Kex) when Kex == 'diffie-hellman-group16-sha512';
+                        Kex == 'diffie-hellman-group18-sha512';
+                        Kex == 'diffie-hellman-group14-sha256';
+                        Kex == 'diffie-hellman-group14-sha1';
+                        Kex == 'diffie-hellman-group1-sha1' ->
+    dh_alg;
+get_alg_group(Kex) when Kex == 'diffie-hellman-group-exchange-sha256';
+                        Kex == 'diffie-hellman-group-exchange-sha1' ->
+    dh_gex_alg;
+get_alg_group(Kex) when Kex == 'curve25519-sha256';
+                        Kex == 'curve25519-sha256@libssh.org';
+                        Kex == 'curve448-sha512';
+                        Kex == 'ecdh-sha2-nistp521';
+                        Kex == 'ecdh-sha2-nistp384';
+                        Kex == 'ecdh-sha2-nistp256' ->
+    ecdh_alg.
+
+check_msg_group(_Msg, _AlgGroup, false) -> ok;
+check_msg_group(#ssh_msg_kexdh_init{},  dh_alg, true) -> ok;
+check_msg_group(#ssh_msg_kexdh_reply{}, dh_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_dh_gex_request_old{}, dh_gex_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_dh_gex_request{},     dh_gex_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_dh_gex_group{},       dh_gex_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_dh_gex_init{},        dh_gex_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_dh_gex_reply{},       dh_gex_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_ecdh_init{},  ecdh_alg, true) -> ok;
+check_msg_group(#ssh_msg_kex_ecdh_reply{}, ecdh_alg, true) -> ok;
+check_msg_group(_Msg, _AlgGroup, _) -> error.
+
+%%%################################################################
+%%%#
+%%%# Tracing
+%%%#
+
+ssh_dbg_trace_points() -> [connection_events].
+
+ssh_dbg_flags(connection_events) -> [c].
+
+ssh_dbg_on(connection_events) -> dbg:tp(?MODULE,   handle_event, 4, x).
+
+ssh_dbg_off(connection_events) -> dbg:ctpg(?MODULE, handle_event, 4).
+
+ssh_dbg_format(connection_events, {call, {?MODULE,handle_event, [EventType, EventContent, State, _Data]}}) ->
+    ["Connection event\n",
+     io_lib:format("[~w] EventType: ~p~nEventContent: ~p~nState: ~p~n", [?MODULE, EventType, EventContent, State])
+    ];
+ssh_dbg_format(connection_events, {return_from, {?MODULE,handle_event,4}, Ret}) ->
+    ["Connection event result\n",
+     io_lib:format("[~w] ~p~n", [?MODULE, ssh_dbg:reduce_state(Ret, #data{})])
+    ].
diff --git a/lib/ssh/src/ssh_transport.erl b/lib/ssh/src/ssh_transport.erl
index 3e96ca9402..e612ffd0fe 100644
--- a/lib/ssh/src/ssh_transport.erl
+++ b/lib/ssh/src/ssh_transport.erl
@@ -26,12 +26,11 @@
 
 -include_lib("public_key/include/public_key.hrl").
 -include_lib("kernel/include/inet.hrl").
-
 -include("ssh_transport.hrl").
 -include("ssh.hrl").
 
 -export([versions/2, hello_version_msg/1]).
--export([next_seqnum/1, 
+-export([next_seqnum/3,
 	 supported_algorithms/0, supported_algorithms/1,
 	 default_algorithms/0, default_algorithms/1,
          clear_default_algorithms_env/0,
@@ -295,7 +294,12 @@ random_id(Nlo, Nup) ->
 hello_version_msg(Data) ->
     [Data,"\r\n"].
 
-next_seqnum(SeqNum) ->
+next_seqnum({State, _Role, init}, 16#ffffffff,
+            #ssh{algorithms = #alg{kex_strict_negotiated = true}})
+  when State == kexinit; State == key_exchange; State == new_keys ->
+    ?DISCONNECT(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
+                io_lib:format("KEX strict violation: recv_sequence = 16#ffffffff", []));
+next_seqnum(_State, SeqNum, _) ->
     (SeqNum + 1) band 16#ffffffff.
 
 is_valid_mac(_, _ , #ssh{recv_mac_size = 0}) ->
@@ -1080,7 +1084,7 @@ known_host_key(#ssh{opts = Opts, peer = {PeerName,{IP,Port}}} = Ssh,
 %%   algorithm.  Each string MUST contain at least one algorithm name.
 select_algorithm(Role, Client, Server,
                  #ssh{opts = Opts,
-                         kex_strict_negotiated = KexStrictNegotiated0},
+                      kex_strict_negotiated = KexStrictNegotiated0},
                  ReNeg) ->
     KexStrictNegotiated =
         case ReNeg of
@@ -1105,7 +1109,6 @@ select_algorithm(Role, Client, Server,
             _ ->
                 KexStrictNegotiated0
         end,
-
     {Encrypt0, Decrypt0} = select_encrypt_decrypt(Role, Client, Server),
     {SendMac0, RecvMac0} = select_send_recv_mac(Role, Client, Server),
 
diff --git a/lib/ssh/test/ssh_protocol_SUITE.erl b/lib/ssh/test/ssh_protocol_SUITE.erl
index 537642cff5..2e1c2a6c76 100644
--- a/lib/ssh/test/ssh_protocol_SUITE.erl
+++ b/lib/ssh/test/ssh_protocol_SUITE.erl
@@ -55,7 +55,9 @@
          ext_info_c/1,
          ext_info_s/1,
          kex_strict_negotiated/1,
-         kex_strict_msg_ignore/1,
+         kex_strict_violation_key_exchange/1,
+         kex_strict_violation_new_keys/1,
+         kex_strict_violation/1,
          kex_strict_msg_unknown/1,
          gex_client_init_option_groups/1,
          gex_client_init_option_groups_file/1,
@@ -144,7 +146,9 @@ groups() ->
 		gex_client_old_request_exact,
 		gex_client_old_request_noexact,
                 kex_strict_negotiated,
-                kex_strict_msg_ignore,
+                kex_strict_violation_key_exchange,
+                kex_strict_violation_new_keys,
+                kex_strict_violation,
                 kex_strict_msg_unknown]},
      {service_requests, [], [bad_service_name,
 			     bad_long_service_name,
@@ -1007,22 +1011,145 @@ kex_strict_negotiated(Config0) ->
     ssh_test_lib:rm_log_handler(),
     ok.
 
-%% Connect to an erlang server and inject unexpected SSH ignore
-kex_strict_msg_ignore(Config) ->
-    ct:log("START: ~p~n=================================", [?FUNCTION_NAME]),
-    ExpectedReason = "strict KEX violation: unexpected SSH_MSG_IGNORE",
-    TestMessages =
-        [{send, ssh_msg_ignore},
-         {match, #ssh_msg_kexdh_reply{_='_'}, receive_msg},
-         {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}],
-    kex_strict_helper(Config, TestMessages, ExpectedReason).
+%% Connect to an erlang server and inject unexpected SSH message
+%% ssh_fsm_kexinit in key_exchange state
+kex_strict_violation_key_exchange(Config) ->
+    ExpectedReason = "KEX strict violation",
+    Injections = [ssh_msg_ignore, ssh_msg_debug, ssh_msg_unimplemented],
+    TestProcedure =
+        fun(M) ->
+                ct:log(
+                  "=================== START: ~p Message: ~p Expected Fail =================================",
+                  [?FUNCTION_NAME, M]),
+                [receive_hello,
+                 {send, hello},
+                 {send, ssh_msg_kexinit},
+                 {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+                 {send, M},
+                 {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]
+        end,
+    [kex_strict_helper(Config, TestProcedure(Msg), ExpectedReason) ||
+        Msg <- Injections],
+    ct:log("========== END ========"),
+    ok.
+
+%% Connect to an erlang server and inject unexpected SSH message
+%% ssh_fsm_kexinit in new_keys state
+kex_strict_violation_new_keys(Config) ->
+    ExpectedReason = "KEX strict violation",
+    Injections = [ssh_msg_ignore, ssh_msg_debug, ssh_msg_unimplemented],
+    TestProcedure =
+        fun(M) ->
+                ct:log(
+                  "=================== START: ~p Message: ~p Expected Fail =================================",
+                  [?FUNCTION_NAME, M]),
+                [receive_hello,
+                 {send, hello},
+                 {send, ssh_msg_kexinit},
+                 {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+                 {send, ssh_msg_kexdh_init},
+                 {send, M},
+                 {match, #ssh_msg_kexdh_reply{_='_'}, receive_msg},
+                 {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]
+        end,
+    [kex_strict_helper(Config, TestProcedure(Msg), ExpectedReason) ||
+        Msg <- Injections],
+    ct:log("========== END ========"),
+    ok.
+
+%% Connect to an erlang server and inject unexpected SSH message
+%% duplicated KEXINIT
+kex_strict_violation(Config) ->
+    KexDhReply =
+        #ssh_msg_kexdh_reply{
+           public_host_key = {{{'ECPoint',<<73,72,235,162,96,101,154,59,217,114,123,192,96,105,250,29,214,76,60,63,167,21,221,118,246,168,152,2,7,172,137,125>>},
+                               {namedCurve,{1,3,101,112}}},
+                              'ssh-ed25519'},
+           f = 18504393053016436370762156176197081926381112956345797067569792020930728564439992620494295053804030674742529174859108487694089045521619258420515443400605141150065440678508889060925968846155921972385560196703381004650914261218463420313738628465563288022895912907728767735629532940627575655703806353550720122093175255090704443612257683903495753071530605378193139909567971489952258218767352348904221407081210633467414579377014704081235998044497191940270966762124544755076128392259615566530695493013708460088312025006678879288856957348606386230195080105197251789635675011844976120745546472873505352732719507783227210178188,
+           h_sig = <<90,247,44,240,136,196,82,215,56,165,53,33,230,101,253,
+                     34,112,201,21,131,162,169,10,129,174,14,69,25,39,174,
+                     92,210,130,249,103,2,215,245,7,213,110,235,136,134,11,
+                     124,248,139,79,17,225,77,125,182,204,84,137,167,99,186,
+                     167,42,192,10>>},
+    TestFlows =
+        [
+         {kexinit, "KEX strict violation",
+          [receive_hello,
+           {send, hello},
+           {send, ssh_msg_kexinit},
+           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+           {send, ssh_msg_kexinit},
+           {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]},
+         {ssh_msg_kexdh_init, "KEX strict violation",
+          [receive_hello,
+           {send, hello},
+           {send, ssh_msg_kexinit},
+           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+           {send, ssh_msg_kexdh_init_dup},
+           {match,# ssh_msg_kexdh_reply{_='_'}, receive_msg},
+           {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]},
+         {new_keys, "Message ssh_msg_newkeys in wrong state",
+          [receive_hello,
+           {send, hello},
+           {send, ssh_msg_kexinit},
+           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+           {send, ssh_msg_kexdh_init},
+           {match,# ssh_msg_kexdh_reply{_='_'}, receive_msg},
+           {send, #ssh_msg_newkeys{}},
+           {match, #ssh_msg_newkeys{_='_'}, receive_msg},
+           {send, #ssh_msg_newkeys{}},
+           {match, disconnect(?SSH_DISCONNECT_PROTOCOL_ERROR), receive_msg}]},
+         {ssh_msg_unexpected_dh_gex, "KEX strict violation",
+          [receive_hello,
+           {send, hello},
+           {send, ssh_msg_kexinit},
+           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+           %% dh_alg is expected but dh_gex_alg is provided
+	   {send, #ssh_msg_kex_dh_gex_request{min = 1000, n = 3000, max = 4000}},
+           {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]},
+         {wrong_role, "KEX strict violation",
+          [receive_hello,
+           {send, hello},
+           {send, ssh_msg_kexinit},
+           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+           %% client should not send message below
+           {send, KexDhReply},
+           {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]},
+         {wrong_role2, "KEX strict violation",
+          [receive_hello,
+           {send, hello},
+           {send, ssh_msg_kexinit},
+           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+           {send, ssh_msg_kexdh_init},
+           {match,# ssh_msg_kexdh_reply{_='_'}, receive_msg},
+           %% client should not send message below
+           {send, KexDhReply},
+           {match, #ssh_msg_newkeys{_='_'}, receive_msg},
+           {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}]}
+        ],
+    TestProcedure =
+        fun({Msg, _, P}) ->
+                ct:log(
+                  "==== START: ~p (duplicated ~p) Expected Fail ====~n~p",
+                  [?FUNCTION_NAME, Msg, P]),
+                P
+        end,
+    [kex_strict_helper(Config, TestProcedure(Procedure), Reason) ||
+        Procedure = {_, Reason, _} <- TestFlows],
+    ct:log("==== END ====="),
+    ok.
 
 %% Connect to an erlang server and inject unexpected non-SSH binary
 kex_strict_msg_unknown(Config) ->
     ct:log("START: ~p~n=================================", [?FUNCTION_NAME]),
     ExpectedReason = "Bad packet: Size",
     TestMessages =
-        [{send, ssh_msg_unknown},
+        [receive_hello,
+         {send, hello},
+         {send, ssh_msg_kexinit},
+         {match, #ssh_msg_kexinit{_='_'}, receive_msg},
+         {send, ssh_msg_kexdh_init},
+         {send, ssh_msg_unknown},
          {match, #ssh_msg_kexdh_reply{_='_'}, receive_msg},
          {match, disconnect(?SSH_DISCONNECT_KEY_EXCHANGE_FAILED), receive_msg}],
     kex_strict_helper(Config, TestMessages, ExpectedReason).
@@ -1047,12 +1174,7 @@ kex_strict_helper(Config, TestMessages, ExpectedReason) ->
              {user_dir, user_dir(Config)},
              {user_interaction, false}
             | proplists:get_value(extra_options,Config,[])
-            ]},
-           receive_hello,
-           {send, hello},
-           {send, ssh_msg_kexinit},
-           {match, #ssh_msg_kexinit{_='_'}, receive_msg},
-           {send, ssh_msg_kexdh_init}] ++
+            ]}] ++
               TestMessages,
           InitialState),
     ct:sleep(100),
diff --git a/lib/ssh/test/ssh_trpt_test_lib.erl b/lib/ssh/test/ssh_trpt_test_lib.erl
index f03fee1662..e34db487e5 100644
--- a/lib/ssh/test/ssh_trpt_test_lib.erl
+++ b/lib/ssh/test/ssh_trpt_test_lib.erl
@@ -90,7 +90,8 @@ exec(Op, S0=#s{}) ->
 	    report_trace(throw, Term, S1),
 	    throw({Term,Op});
 
-	error:Error ->
+	error:Error:St ->
+            ct:log("Stacktrace=~n~p", [St]),
 	    report_trace(error, Error, S1),
 	    error({Error,Op});
 
@@ -335,6 +336,17 @@ send(S0, ssh_msg_ignore) ->
     Msg = #ssh_msg_ignore{data = "unexpected_ignore_message"},
     send(S0, Msg);
 
+send(S0, ssh_msg_debug) ->
+    Msg = #ssh_msg_debug{
+             always_display = true,
+             message = "some debug message",
+             language = "en"},
+    send(S0, Msg);
+
+send(S0, ssh_msg_unimplemented) ->
+    Msg = #ssh_msg_unimplemented{sequence = 123},
+    send(S0, Msg);
+
 send(S0, ssh_msg_unknown) ->
     Msg = binary:encode_hex(<<"0000000C060900000000000000000000">>),
     send(S0, Msg);
@@ -382,6 +394,26 @@ send(S0, ssh_msg_kexdh_init) when ?role(S0) == client ->
 	    end),
     send_bytes(NextKexMsgBin, S#s{ssh = C});
 
+send(S0, ssh_msg_kexdh_init_dup) when ?role(S0) == client ->
+    {OwnMsg, PeerMsg} = S0#s.alg_neg,
+    {ok, NextKexMsgBin, C} =
+	try ssh_transport:handle_kexinit_msg(PeerMsg, OwnMsg, S0#s.ssh, init)
+	catch
+	    Class:Exc ->
+		fail("Algorithm negotiation failed!",
+		     {"Algorithm negotiation failed at line ~p:~p~n~p:~s~nPeer: ~s~n Own: ~s",
+		      [?MODULE,?LINE,Class,format_msg(Exc),format_msg(PeerMsg),format_msg(OwnMsg)]},
+		     S0)
+	end,
+    S = opt(print_messages, S0,
+	    fun(X) when X==true;X==detail ->
+		    #ssh{keyex_key = {{_Private, Public}, {_G, _P}}} = C,
+		    Msg = #ssh_msg_kexdh_init{e = Public},
+		    {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
+	    end),
+    send_bytes(NextKexMsgBin, S#s{ssh = C}),
+    send_bytes(NextKexMsgBin, S#s{ssh = C});
+
 send(S0, ssh_msg_kexdh_reply) ->
     Bytes = proplists:get_value(ssh_msg_kexdh_reply, S0#s.reply),
     S = opt(print_messages, S0,
@@ -531,7 +563,10 @@ receive_binary_msg(S0=#s{}) ->
            S0#s.ssh)
      of
          {packet_decrypted, DecryptedBytes, EncryptedDataRest, Ssh1} ->
-             S1 = S0#s{ssh = Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)},
+             S1 = S0#s{ssh = Ssh1#ssh{recv_sequence =
+                                          ssh_transport:next_seqnum(undefined,
+                                                                    Ssh1#ssh.recv_sequence,
+                                                                    false)},
                        decrypted_data_buffer = <<>>,
                        undecrypted_packet_length = undefined,
                        aead_data = <<>>,
-- 
2.43.0

openSUSE Build Service is sponsored by