diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/msg_hello_verify.cpp | 3 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 7 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 3 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 18 |
4 files changed, 22 insertions, 9 deletions
diff --git a/src/tls/msg_hello_verify.cpp b/src/tls/msg_hello_verify.cpp index 19597e9df..f8a117c03 100644 --- a/src/tls/msg_hello_verify.cpp +++ b/src/tls/msg_hello_verify.cpp @@ -29,8 +29,7 @@ Hello_Verify_Request::Hello_Verify_Request(const std::vector<byte>& buf) if(static_cast<size_t>(buf[2]) + 3 != buf.size()) throw Decoding_Error("Bad length in hello verify request"); - m_cookie.resize(buf.size() - 3); - copy_mem(&m_cookie[0], &buf[3], buf.size() - 3); + m_cookie.assign(&buf[3], &buf[buf.size()]); } Hello_Verify_Request::Hello_Verify_Request(const std::vector<byte>& client_hello_bits, diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 5e1e546a4..91aaae206 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -100,8 +100,13 @@ void Channel::change_cipher_spec_writer(Connection_Side side) A sequence number is incremented after each record: specifically, the first record transmitted under a particular connection state MUST use sequence number 0 + + For DTLS, increment the epoch */ - m_write_seq_no = 0; + if(current_protocol_version().is_datagram_protocol()) + m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48; + else + m_write_seq_no = 0; m_write_cipherstate.reset( new Connection_Cipher_State(current_protocol_version(), diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index fd2e7ea98..96329c55c 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -314,7 +314,6 @@ size_t split_for_mtu(size_t mtu, size_t msg_size) } - std::vector<byte> Datagram_Handshake_IO::send(const Handshake_Message& msg) { @@ -329,8 +328,6 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) const std::vector<byte> no_fragment = format_w_seq(msg_bits, msg.type(), m_out_message_seq); - m_mtu = 64; - if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) m_send_hs(HANDSHAKE, no_fragment); else diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index 6725310ab..4a031f626 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -87,6 +87,12 @@ void write_record(std::vector<byte>& output, output.push_back(version.major_version()); output.push_back(version.minor_version()); + if(version.is_datagram_protocol()) + { + for(size_t i = 0; i != 8; ++i) + output.push_back(get_byte(i, msg_sequence_number)); + } + if(!cipherstate) // initial unencrypted handshake records { output.push_back(get_byte<u16bit>(0, msg_length)); @@ -271,6 +277,9 @@ size_t read_record(std::vector<byte>& readbuf, { consumed = 0; + BOTAN_ASSERT(version.valid(), + "We know what version we are using"); + const size_t header_size = (version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; @@ -335,7 +344,11 @@ size_t read_record(std::vector<byte>& readbuf, " from counterparty"); } - const size_t record_len = make_u16bit(readbuf[3], readbuf[4]); + if(version.is_datagram_protocol()) + msg_sequence = load_be<u64bit>(&readbuf[3], 0); + + const size_t record_len = make_u16bit(readbuf[header_size-2], + readbuf[header_size-1]); if(version.major_version()) { @@ -372,8 +385,7 @@ size_t read_record(std::vector<byte>& readbuf, } msg_type = readbuf[0]; - msg.resize(record_len); - copy_mem(&msg[0], record_contents, record_len); + msg.assign(&record_contents[0], &record_contents[record_len]); readbuf_pos = 0; return 0; // got a full record |