aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cmd/tls_server.cpp3
-rw-r--r--src/lib/tls/tls_channel.cpp50
-rw-r--r--src/lib/tls/tls_channel.h3
-rw-r--r--src/lib/tls/tls_client.cpp3
-rw-r--r--src/lib/tls/tls_handshake_io.cpp1
-rw-r--r--src/lib/tls/tls_magic.h4
-rw-r--r--src/lib/tls/tls_policy.cpp4
-rw-r--r--src/lib/tls/tls_record.cpp186
-rw-r--r--src/lib/tls/tls_record.h1
-rw-r--r--src/lib/tls/tls_server.cpp3
-rw-r--r--src/lib/tls/tls_server.h1
11 files changed, 183 insertions, 76 deletions
diff --git a/src/cmd/tls_server.cpp b/src/cmd/tls_server.cpp
index fd9f7ef5d..a892835dc 100644
--- a/src/cmd/tls_server.cpp
+++ b/src/cmd/tls_server.cpp
@@ -205,7 +205,8 @@ int tls_server(int argc, char* argv[])
creds,
policy,
rng,
- protocols);
+ protocols,
+ (transport != "tcp"));
while(!server.is_closed())
{
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index 25307166b..76332f7d2 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -25,7 +25,9 @@ Channel::Channel(std::function<void (const byte[], size_t)> output_fn,
std::function<bool (const Session&)> handshake_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
+ bool is_datagram,
size_t reserved_io_buffer_size) :
+ m_is_datagram(is_datagram),
m_handshake_cb(handshake_cb),
m_data_cb(data_cb),
m_alert_cb(alert_cb),
@@ -142,6 +144,8 @@ bool Channel::timeout_check()
{
if(m_pending_state)
return m_pending_state->handshake_io().timeout_check();
+
+ //FIXME: scan cipher suites and remove epochs older than 2*MSL
return false;
}
@@ -252,11 +256,7 @@ void Channel::activate_session()
std::swap(m_active_state, m_pending_state);
m_pending_state.reset();
- if(m_active_state->version().is_datagram_protocol())
- {
- // FIXME, remove old states when we are sure not needed anymore
- }
- else
+ if(!m_active_state->version().is_datagram_protocol())
{
// TLS is easy just remove all but the current state
auto current_epoch = sequence_numbers().current_write_epoch();
@@ -307,6 +307,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
read_record(m_readbuf,
input,
input_size,
+ m_is_datagram,
consumed,
record,
&record_sequence,
@@ -340,24 +341,31 @@ size_t Channel::received_data(const byte input[], size_t input_size)
{
if(record_version.is_datagram_protocol())
{
- sequence_numbers().read_accept(record_sequence);
-
- /*
- * Might be a peer retransmit under epoch - 1 in which
- * case we must retransmit last flight
- */
-
- const u16bit epoch = record_sequence >> 48;
-
- if(epoch == sequence_numbers().current_read_epoch())
+ if(m_sequence_numbers)
{
- create_handshake_state(record_version);
+ /*
+ * Might be a peer retransmit under epoch - 1 in which
+ * case we must retransmit last flight
+ */
+ sequence_numbers().read_accept(record_sequence);
+
+ const u16bit epoch = record_sequence >> 48;
+
+ if(epoch == sequence_numbers().current_read_epoch())
+ {
+ create_handshake_state(record_version);
+ }
+ else if(epoch == sequence_numbers().current_read_epoch() - 1)
+ {
+ BOTAN_ASSERT(m_active_state, "Have active state here");
+ m_active_state->handshake_io().add_record(unlock(record),
+ record_type,
+ record_sequence);
+ }
}
- else if(epoch == sequence_numbers().current_read_epoch() - 1)
+ else if(record_sequence == 0)
{
- m_active_state->handshake_io().add_record(unlock(record),
- record_type,
- record_sequence);
+ create_handshake_state(record_version);
}
}
else
@@ -445,7 +453,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
return 0;
}
}
- else
+ else if(record_type != NO_RECORD)
throw Unexpected_Message("Unexpected record type " +
std::to_string(record_type) +
" from counterparty");
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 3cdfe3d5e..8aea2dab0 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -164,6 +164,7 @@ class BOTAN_DLL Channel
std::function<bool (const Session&)> handshake_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
+ bool is_datagram,
size_t reserved_io_buffer_size);
Channel(const Channel&) = delete;
@@ -234,6 +235,8 @@ class BOTAN_DLL Channel
const Handshake_State* pending_state() const { return m_pending_state.get(); }
+ bool m_is_datagram;
+
/* callbacks */
std::function<bool (const Session&)> m_handshake_cb;
std::function<void (const byte[], size_t)> m_data_cb;
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index 86d1998e1..7c3e48ca6 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -59,7 +59,8 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
const Protocol_Version offer_version,
std::function<std::string (std::vector<std::string>)> next_protocol,
size_t io_buf_sz) :
- Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz),
+ Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng,
+ offer_version.is_datagram_protocol(), io_buf_sz),
m_policy(policy),
m_creds(creds),
m_info(info)
diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp
index da27cc4ce..659818139 100644
--- a/src/lib/tls/tls_handshake_io.cpp
+++ b/src/lib/tls/tls_handshake_io.cpp
@@ -70,6 +70,7 @@ Stream_Handshake_IO::get_next_record(bool)
if(m_queue.size() >= length + 4)
{
Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
+ BOTAN_ASSERT(type < 250, "Not in reserved range");
std::vector<byte> contents(m_queue.begin() + 4,
m_queue.begin() + 4 + length);
diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h
index 51f1fce47..e22ab7248 100644
--- a/src/lib/tls/tls_magic.h
+++ b/src/lib/tls/tls_magic.h
@@ -27,13 +27,13 @@ enum Size_Limits {
enum Connection_Side { CLIENT = 1, SERVER = 2 };
enum Record_Type {
- NO_RECORD = 0,
-
CHANGE_CIPHER_SPEC = 20,
ALERT = 21,
HANDSHAKE = 22,
APPLICATION_DATA = 23,
HEARTBEAT = 24,
+
+ NO_RECORD = 256
};
enum Handshake_Type {
diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp
index c4867d81a..0f2190562 100644
--- a/src/lib/tls/tls_policy.cpp
+++ b/src/lib/tls/tls_policy.cpp
@@ -146,10 +146,8 @@ bool Policy::send_fallback_scsv(Protocol_Version version) const
bool Policy::acceptable_protocol_version(Protocol_Version version) const
{
- // By default require TLS to minimize surprise
if(version.is_datagram_protocol())
- return false;
-
+ return (version >= Protocol_Version::DTLS_V12);
return (version >= Protocol_Version::TLS_V10);
}
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 925961764..0b356fad3 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -464,18 +464,16 @@ void decrypt_record(secure_vector<byte>& output,
}
}
-}
-
-size_t read_record(secure_vector<byte>& readbuf,
- const byte input[],
- size_t input_sz,
- size_t& consumed,
- secure_vector<byte>& record,
- u64bit* record_sequence,
- Protocol_Version* record_version,
- Record_Type* record_type,
- Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+size_t read_tls_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
{
consumed = 0;
@@ -529,23 +527,10 @@ size_t read_record(secure_vector<byte>& readbuf,
*record_version = Protocol_Version(readbuf[1], readbuf[2]);
- const bool is_dtls = record_version->is_datagram_protocol();
+ BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS");
- if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE)
- {
- if(size_t needed = fill_buffer_to(readbuf,
- input, input_sz, consumed,
- DTLS_HEADER_SIZE))
- return needed;
-
- BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE,
- "Have an entire header");
- }
-
- const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
-
- const size_t record_len = make_u16bit(readbuf[header_size-2],
- readbuf[header_size-1]);
+ const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2],
+ readbuf[TLS_HEADER_SIZE-1]);
if(record_len > MAX_CIPHERTEXT_SIZE)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
@@ -553,10 +538,10 @@ size_t read_record(secure_vector<byte>& readbuf,
if(size_t needed = fill_buffer_to(readbuf,
input, input_sz, consumed,
- header_size + record_len))
- return needed; // wrong for DTLS?
+ TLS_HEADER_SIZE + record_len))
+ return needed;
- BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len,
+ BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len,
readbuf.size(),
"Have the full record");
@@ -564,12 +549,7 @@ size_t read_record(secure_vector<byte>& readbuf,
u16bit epoch = 0;
- if(is_dtls)
- {
- *record_sequence = load_be<u64bit>(&readbuf[3], 0);
- epoch = (*record_sequence >> 48);
- }
- else if(sequence_numbers)
+ if(sequence_numbers)
{
*record_sequence = sequence_numbers->next_read_sequence();
epoch = sequence_numbers->current_read_epoch();
@@ -581,17 +561,11 @@ size_t read_record(secure_vector<byte>& readbuf,
epoch = 0;
}
- if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
- {
- readbuf.clear();
- return 0;
- }
-
- byte* record_contents = &readbuf[header_size];
+ byte* record_contents = &readbuf[TLS_HEADER_SIZE];
if(epoch == 0) // Unencrypted initial handshake
{
- record.assign(&readbuf[header_size], &readbuf[header_size + record_len]);
+ record.assign(&readbuf[TLS_HEADER_SIZE], &readbuf[TLS_HEADER_SIZE + record_len]);
readbuf.clear();
return 0; // got a full record
}
@@ -599,8 +573,6 @@ size_t read_record(secure_vector<byte>& readbuf,
// Otherwise, decrypt, check MAC, return plaintext
auto cipherstate = get_cipherstate(epoch);
- // FIXME: DTLS reordering might cause us not to have the cipher state
-
BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
decrypt_record(record,
@@ -618,6 +590,126 @@ size_t read_record(secure_vector<byte>& readbuf,
return 0;
}
+size_t read_dtls_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ {
+ consumed = 0;
+
+ if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete?
+ {
+ if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE))
+ {
+ readbuf.clear();
+ return 0;
+ }
+
+ BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header");
+ }
+
+ *record_version = Protocol_Version(readbuf[1], readbuf[2]);
+
+ BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS");
+
+ const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2],
+ readbuf[DTLS_HEADER_SIZE-1]);
+
+ if(record_len > MAX_CIPHERTEXT_SIZE)
+ throw TLS_Exception(Alert::RECORD_OVERFLOW,
+ "Got message that exceeds maximum size");
+
+ if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len))
+ {
+ // Truncated packet?
+ readbuf.clear();
+ return 0;
+ }
+
+ BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(),
+ "Have the full record");
+
+ *record_type = static_cast<Record_Type>(readbuf[0]);
+
+ u16bit epoch = 0;
+
+ *record_sequence = load_be<u64bit>(&readbuf[3], 0);
+ epoch = (*record_sequence >> 48);
+
+ if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
+ {
+ readbuf.clear();
+ return 0;
+ }
+
+ byte* record_contents = &readbuf[DTLS_HEADER_SIZE];
+
+ if(epoch == 0) // Unencrypted initial handshake
+ {
+ record.assign(&readbuf[DTLS_HEADER_SIZE], &readbuf[DTLS_HEADER_SIZE + record_len]);
+ readbuf.clear();
+ return 0; // got a full record
+ }
+
+ try
+ {
+ // Otherwise, decrypt, check MAC, return plaintext
+ auto cipherstate = get_cipherstate(epoch);
+
+ BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
+
+ decrypt_record(record,
+ record_contents,
+ record_len,
+ *record_sequence,
+ *record_version,
+ *record_type,
+ *cipherstate);
+ }
+ catch(std::exception)
+ {
+ readbuf.clear();
+ *record_type = NO_RECORD;
+ return 0;
+ }
+
+ if(sequence_numbers)
+ sequence_numbers->read_accept(*record_sequence);
+
+ readbuf.clear();
+ return 0;
+ }
+
+}
+
+size_t read_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ bool is_datagram,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ {
+ if(is_datagram)
+ return read_dtls_record(readbuf, input, input_sz, consumed,
+ record, record_sequence, record_version, record_type,
+ sequence_numbers, get_cipherstate);
+ else
+ return read_tls_record(readbuf, input, input_sz, consumed,
+ record, record_sequence, record_version, record_type,
+ sequence_numbers, get_cipherstate);
+ }
+
}
}
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index 8431e68c0..2dae96164 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -122,6 +122,7 @@ void write_record(secure_vector<byte>& write_buffer,
size_t read_record(secure_vector<byte>& read_buffer,
const byte input[],
size_t input_length,
+ bool is_datagram,
size_t& input_consumed,
secure_vector<byte>& record,
u64bit* record_sequence,
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index ff285881a..9b8a0d811 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -216,8 +216,9 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
const Policy& policy,
RandomNumberGenerator& rng,
const std::vector<std::string>& next_protocols,
+ bool is_datagram,
size_t io_buf_sz) :
- Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz),
+ Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz),
m_policy(policy),
m_creds(creds),
m_possible_protocols(next_protocols)
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index a514607ba..c0646bdbc 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -34,6 +34,7 @@ class BOTAN_DLL Server : public Channel
const Policy& policy,
RandomNumberGenerator& rng,
const std::vector<std::string>& protocols = std::vector<std::string>(),
+ bool is_datagram = false,
size_t reserved_io_buffer_size = 16*1024
);