File 2476-dtls-Implement-replay-protection.patch of Package erlang

From 4fd3360bb68adb2ee942b3fbaeab7d766b6d3454 Mon Sep 17 00:00:00 2001
From: Ingela Anderton Andin <ingela@erlang.org>
Date: Wed, 10 May 2017 23:36:44 +0200
Subject: [PATCH] dtls: Implement replay protection

See RFC 6347 section 3.3
---
 lib/ssl/src/dtls_connection.erl | 32 +++++++++++++++++-------
 lib/ssl/src/dtls_record.erl     | 55 ++++++++++++++++++++++++++++++++++++-----
 2 files changed, 72 insertions(+), 15 deletions(-)

diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index 9937373e6..b52896a45 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -688,16 +688,19 @@ next_record(#state{unprocessed_handshake_events = N} = State) when N > 0 ->
     {no_record, State#state{unprocessed_handshake_events = N-1}};
 					 
 next_record(#state{protocol_buffers =
-		       #protocol_buffers{dtls_cipher_texts = [CT | Rest]}
+		       #protocol_buffers{dtls_cipher_texts = [#ssl_tls{epoch = Epoch} = CT | Rest]}
 		   = Buffers,
-		   connection_states = ConnStates0} = State) ->
-    case dtls_record:decode_cipher_text(CT, ConnStates0) of
-	{Plain, ConnStates} ->		      
-	    {Plain, State#state{protocol_buffers =
-				    Buffers#protocol_buffers{dtls_cipher_texts = Rest},
-				connection_states = ConnStates}};
-	#alert{} = Alert ->
-	    {Alert, State}
+		   connection_states = ConnectionStates} = State) ->
+    CurrentRead = dtls_record:get_connection_state_by_epoch(Epoch, ConnectionStates, read),
+    case dtls_record:replay_detect(CT, CurrentRead) of
+        false ->
+            decode_cipher_text(State#state{connection_states = ConnectionStates}) ;
+        true ->
+            ct:pal("Replay detect", []),            
+            %% Ignore replayed record
+            next_record(State#state{protocol_buffers =
+                                        Buffers#protocol_buffers{dtls_cipher_texts = Rest},
+                                    connection_states = ConnectionStates})
     end;
 next_record(#state{role = server,
 		   socket = {Listener, {Client, _}},
@@ -770,6 +773,17 @@ next_event(StateName, Record,
 	    {next_state, StateName, State, [{next_event, internal, Alert} | Actions]}
     end.
 
+decode_cipher_text(#state{protocol_buffers = #protocol_buffers{dtls_cipher_texts = [ CT | Rest]} = Buffers,
+                          connection_states = ConnStates0} = State) ->
+    case dtls_record:decode_cipher_text(CT, ConnStates0) of
+	{Plain, ConnStates} ->		      
+	    {Plain, State#state{protocol_buffers =
+				    Buffers#protocol_buffers{dtls_cipher_texts = Rest},
+				connection_states = ConnStates}};
+	#alert{} = Alert ->
+	    {Alert, State}
+    end.
+
 dtls_version(hello, Version, #state{role = server} = State) ->
     State#state{negotiated_version = Version}; %%Inital version
 dtls_version(_,_, State) ->
diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl
index 6a418c6fb..8a7f8c1d0 100644
--- a/lib/ssl/src/dtls_record.erl
+++ b/lib/ssl/src/dtls_record.erl
@@ -46,7 +46,7 @@
 	 is_higher/2, supported_protocol_versions/0,
 	 is_acceptable_version/2, hello_version/2]).
 
--export([save_current_connection_state/2, next_epoch/2]).
+-export([save_current_connection_state/2, next_epoch/2, get_connection_state_by_epoch/3, replay_detect/2]).
 
 -export([init_connection_state_seq/2, current_connection_state_epoch/2]).
 
@@ -55,6 +55,8 @@
 -type dtls_version()       :: ssl_record:ssl_version().
 -type dtls_atom_version()  :: dtlsv1 | 'dtlsv1.2'.
 
+-define(REPLAY_WINDOW_SIZE, 64).
+
 -compile(inline).
 
 %%====================================================================
@@ -73,7 +75,7 @@ init_connection_states(Role, BeastMitigation) ->
     Initial = initial_connection_state(ConnectionEnd, BeastMitigation),
     Current = Initial#{epoch := 0},
     InitialPending = ssl_record:empty_connection_state(ConnectionEnd, BeastMitigation),
-    Pending = InitialPending#{epoch => undefined},
+    Pending = InitialPending#{epoch => undefined, replay_window => init_replay_window(?REPLAY_WINDOW_SIZE)},
     #{saved_read  => Current,
       current_read  => Current,
       pending_read  => Pending,
@@ -96,11 +98,13 @@ save_current_connection_state(#{current_write := Current} = States, write) ->
 
 next_epoch(#{pending_read := Pending,
 	     current_read := #{epoch := Epoch}} = States, read) ->
-    States#{pending_read := Pending#{epoch := Epoch + 1}};
+    States#{pending_read := Pending#{epoch := Epoch + 1,
+                                     replay_window := init_replay_window(?REPLAY_WINDOW_SIZE)}};
 
 next_epoch(#{pending_write := Pending,
 	     current_write := #{epoch := Epoch}} = States, write) ->
-    States#{pending_write := Pending#{epoch := Epoch + 1}}.
+    States#{pending_write := Pending#{epoch := Epoch + 1,
+                                      replay_window := init_replay_window(?REPLAY_WINDOW_SIZE)}}.
 
 get_connection_state_by_epoch(Epoch, #{current_write := #{epoch := Epoch} = Current},
 			      write) ->
@@ -411,6 +415,7 @@ hello_version(Version, Versions) ->
             lowest_protocol_version(Versions)
     end.
 
+
 %%--------------------------------------------------------------------
 %%% Internal functions
 %%--------------------------------------------------------------------
@@ -419,6 +424,7 @@ initial_connection_state(ConnectionEnd, BeastMitigation) ->
 	  ssl_record:initial_security_params(ConnectionEnd),
       epoch => undefined,
       sequence_number => 0,
+      replay_window => init_replay_window(?REPLAY_WINDOW_SIZE),
       beast_mitigation => BeastMitigation,
       compression_state  => undefined,
       cipher_state  => undefined,
@@ -499,8 +505,9 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version,
 	{PlainFragment, CipherState} ->
 	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg,
 							   PlainFragment, CompressionS0),
-	    ReadState = ReadState0#{compression_state => CompressionS1,
+	    ReadState0 = ReadState0#{compression_state => CompressionS1,
                                     cipher_state => CipherState},
+            ReadState = update_replay_window(Seq, ReadState0),
 	    ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read),
 	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates};
 	  #alert{} = Alert ->
@@ -523,7 +530,8 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version,
 	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg,
 							   PlainFragment, CompressionS0),
 	    
-	    ReadState = ReadState1#{compression_state => CompressionS1},
+	    ReadState2 = ReadState1#{compression_state => CompressionS1},
+            ReadState = update_replay_window(Seq, ReadState2),
 	    ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read),
 	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates};
 	false ->
@@ -555,3 +563,38 @@ mac_hash({Major, Minor}, MacAlg, MacSecret, Epoch, SeqNo, Type, Length, Fragment
     
 calc_aad(Type, {MajVer, MinVer}, Epoch, SeqNo) ->
     <<?UINT16(Epoch), ?UINT48(SeqNo), ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>.
+
+init_replay_window(Size) ->
+    #{size => Size,
+      top => Size,
+      bottom => 0,
+      mask => 0 bsl 64
+     }.
+
+replay_detect(#ssl_tls{sequence_number = SequenceNumber}, #{replay_window := Window}) ->
+    is_replay(SequenceNumber, Window).
+
+
+is_replay(SequenceNumber, #{bottom := Bottom}) when SequenceNumber < Bottom ->
+    true;
+is_replay(SequenceNumber, #{size := Size,
+                            top := Top,
+                            bottom := Bottom,
+                            mask :=  Mask})  when (SequenceNumber >= Bottom) andalso (SequenceNumber =< Top) ->
+    Index = (SequenceNumber rem Size),
+    (Index band Mask) == 1;
+
+is_replay(_, _) ->
+    false.
+
+update_replay_window(SequenceNumber,  #{replay_window := #{size := Size,
+                                                           top := Top,
+                                                           bottom := Bottom,
+                                                           mask :=  Mask0} = Window0} = ConnectionStates) ->
+    NoNewBits = SequenceNumber - Top,
+    Index = SequenceNumber rem Size,
+    Mask = (Mask0 bsl NoNewBits) bor Index,
+    Window =  Window0#{top => SequenceNumber,
+                       bottom => Bottom + NoNewBits,
+                       mask => Mask},
+    ConnectionStates#{replay_window := Window}.
-- 
2.13.0

openSUSE Build Service is sponsored by