aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorlloyd <[email protected]>2014-11-15 23:39:24 +0000
committerlloyd <[email protected]>2014-11-15 23:39:24 +0000
commit060df7809a64d1b589554169443c48bc428ca726 (patch)
tree74ca96453ddb4bd3a8abca43fb81d67859c9f6f8 /src
parent9751f1a9084aadbfebbc7f7e67fcd5806ead6492 (diff)
A TLS Server can now process either TLS or DTLS but not either,
with the setting set in the constructor. This prevents various surprising things from happening to applications and simplifies record processing.
Diffstat (limited to 'src')
-rw-r--r--src/cmd/tls_server.cpp3
-rw-r--r--src/lib/tls/tls_channel.cpp50
-rw-r--r--src/lib/tls/tls_channel.h3
-rw-r--r--src/lib/tls/tls_client.cpp3
-rw-r--r--src/lib/tls/tls_handshake_io.cpp1
-rw-r--r--src/lib/tls/tls_magic.h4
-rw-r--r--src/lib/tls/tls_policy.cpp4
-rw-r--r--src/lib/tls/tls_record.cpp186
-rw-r--r--src/lib/tls/tls_record.h1
-rw-r--r--src/lib/tls/tls_server.cpp3
-rw-r--r--src/lib/tls/tls_server.h1
11 files changed, 183 insertions, 76 deletions
diff --git a/src/cmd/tls_server.cpp b/src/cmd/tls_server.cpp
index fd9f7ef5d..a892835dc 100644
--- a/src/cmd/tls_server.cpp
+++ b/src/cmd/tls_server.cpp
@@ -205,7 +205,8 @@ int tls_server(int argc, char* argv[])
creds,
policy,
rng,
- protocols);
+ protocols,
+ (transport != "tcp"));
while(!server.is_closed())
{
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index 25307166b..76332f7d2 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -25,7 +25,9 @@ Channel::Channel(std::function<void (const byte[], size_t)> output_fn,
std::function<bool (const Session&)> handshake_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
+ bool is_datagram,
size_t reserved_io_buffer_size) :
+ m_is_datagram(is_datagram),
m_handshake_cb(handshake_cb),
m_data_cb(data_cb),
m_alert_cb(alert_cb),
@@ -142,6 +144,8 @@ bool Channel::timeout_check()
{
if(m_pending_state)
return m_pending_state->handshake_io().timeout_check();
+
+ //FIXME: scan cipher suites and remove epochs older than 2*MSL
return false;
}
@@ -252,11 +256,7 @@ void Channel::activate_session()
std::swap(m_active_state, m_pending_state);
m_pending_state.reset();
- if(m_active_state->version().is_datagram_protocol())
- {
- // FIXME, remove old states when we are sure not needed anymore
- }
- else
+ if(!m_active_state->version().is_datagram_protocol())
{
// TLS is easy just remove all but the current state
auto current_epoch = sequence_numbers().current_write_epoch();
@@ -307,6 +307,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
read_record(m_readbuf,
input,
input_size,
+ m_is_datagram,
consumed,
record,
&record_sequence,
@@ -340,24 +341,31 @@ size_t Channel::received_data(const byte input[], size_t input_size)
{
if(record_version.is_datagram_protocol())
{
- sequence_numbers().read_accept(record_sequence);
-
- /*
- * Might be a peer retransmit under epoch - 1 in which
- * case we must retransmit last flight
- */
-
- const u16bit epoch = record_sequence >> 48;
-
- if(epoch == sequence_numbers().current_read_epoch())
+ if(m_sequence_numbers)
{
- create_handshake_state(record_version);
+ /*
+ * Might be a peer retransmit under epoch - 1 in which
+ * case we must retransmit last flight
+ */
+ sequence_numbers().read_accept(record_sequence);
+
+ const u16bit epoch = record_sequence >> 48;
+
+ if(epoch == sequence_numbers().current_read_epoch())
+ {
+ create_handshake_state(record_version);
+ }
+ else if(epoch == sequence_numbers().current_read_epoch() - 1)
+ {
+ BOTAN_ASSERT(m_active_state, "Have active state here");
+ m_active_state->handshake_io().add_record(unlock(record),
+ record_type,
+ record_sequence);
+ }
}
- else if(epoch == sequence_numbers().current_read_epoch() - 1)
+ else if(record_sequence == 0)
{
- m_active_state->handshake_io().add_record(unlock(record),
- record_type,
- record_sequence);
+ create_handshake_state(record_version);
}
}
else
@@ -445,7 +453,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
return 0;
}
}
- else
+ else if(record_type != NO_RECORD)
throw Unexpected_Message("Unexpected record type " +
std::to_string(record_type) +
" from counterparty");
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 3cdfe3d5e..8aea2dab0 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -164,6 +164,7 @@ class BOTAN_DLL Channel
std::function<bool (const Session&)> handshake_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
+ bool is_datagram,
size_t reserved_io_buffer_size);
Channel(const Channel&) = delete;
@@ -234,6 +235,8 @@ class BOTAN_DLL Channel
const Handshake_State* pending_state() const { return m_pending_state.get(); }
+ bool m_is_datagram;
+
/* callbacks */
std::function<bool (const Session&)> m_handshake_cb;
std::function<void (const byte[], size_t)> m_data_cb;
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index 86d1998e1..7c3e48ca6 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -59,7 +59,8 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
const Protocol_Version offer_version,
std::function<std::string (std::vector<std::string>)> next_protocol,
size_t io_buf_sz) :
- Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz),
+ Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng,
+ offer_version.is_datagram_protocol(), io_buf_sz),
m_policy(policy),
m_creds(creds),
m_info(info)
diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp
index da27cc4ce..659818139 100644
--- a/src/lib/tls/tls_handshake_io.cpp
+++ b/src/lib/tls/tls_handshake_io.cpp
@@ -70,6 +70,7 @@ Stream_Handshake_IO::get_next_record(bool)
if(m_queue.size() >= length + 4)
{
Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]);
+ BOTAN_ASSERT(type < 250, "Not in reserved range");
std::vector<byte> contents(m_queue.begin() + 4,
m_queue.begin() + 4 + length);
diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h
index 51f1fce47..e22ab7248 100644
--- a/src/lib/tls/tls_magic.h
+++ b/src/lib/tls/tls_magic.h
@@ -27,13 +27,13 @@ enum Size_Limits {
enum Connection_Side { CLIENT = 1, SERVER = 2 };
enum Record_Type {
- NO_RECORD = 0,
-
CHANGE_CIPHER_SPEC = 20,
ALERT = 21,
HANDSHAKE = 22,
APPLICATION_DATA = 23,
HEARTBEAT = 24,
+
+ NO_RECORD = 256
};
enum Handshake_Type {
diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp
index c4867d81a..0f2190562 100644
--- a/src/lib/tls/tls_policy.cpp
+++ b/src/lib/tls/tls_policy.cpp
@@ -146,10 +146,8 @@ bool Policy::send_fallback_scsv(Protocol_Version version) const
bool Policy::acceptable_protocol_version(Protocol_Version version) const
{
- // By default require TLS to minimize surprise
if(version.is_datagram_protocol())
- return false;
-
+ return (version >= Protocol_Version::DTLS_V12);
return (version >= Protocol_Version::TLS_V10);
}
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 925961764..0b356fad3 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -464,18 +464,16 @@ void decrypt_record(secure_vector<byte>& output,
}
}
-}
-
-size_t read_record(secure_vector<byte>& readbuf,
- const byte input[],
- size_t input_sz,
- size_t& consumed,
- secure_vector<byte>& record,
- u64bit* record_sequence,
- Protocol_Version* record_version,
- Record_Type* record_type,
- Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+size_t read_tls_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
{
consumed = 0;
@@ -529,23 +527,10 @@ size_t read_record(secure_vector<byte>& readbuf,
*record_version = Protocol_Version(readbuf[1], readbuf[2]);
- const bool is_dtls = record_version->is_datagram_protocol();
+ BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS");
- if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE)
- {
- if(size_t needed = fill_buffer_to(readbuf,
- input, input_sz, consumed,
- DTLS_HEADER_SIZE))
- return needed;
-
- BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE,
- "Have an entire header");
- }
-
- const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
-
- const size_t record_len = make_u16bit(readbuf[header_size-2],
- readbuf[header_size-1]);
+ const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2],
+ readbuf[TLS_HEADER_SIZE-1]);
if(record_len > MAX_CIPHERTEXT_SIZE)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
@@ -553,10 +538,10 @@ size_t read_record(secure_vector<byte>& readbuf,
if(size_t needed = fill_buffer_to(readbuf,
input, input_sz, consumed,
- header_size + record_len))
- return needed; // wrong for DTLS?
+ TLS_HEADER_SIZE + record_len))
+ return needed;
- BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len,
+ BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len,
readbuf.size(),
"Have the full record");
@@ -564,12 +549,7 @@ size_t read_record(secure_vector<byte>& readbuf,
u16bit epoch = 0;
- if(is_dtls)
- {
- *record_sequence = load_be<u64bit>(&readbuf[3], 0);
- epoch = (*record_sequence >> 48);
- }
- else if(sequence_numbers)
+ if(sequence_numbers)
{
*record_sequence = sequence_numbers->next_read_sequence();
epoch = sequence_numbers->current_read_epoch();
@@ -581,17 +561,11 @@ size_t read_record(secure_vector<byte>& readbuf,
epoch = 0;
}
- if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
- {
- readbuf.clear();
- return 0;
- }
-
- byte* record_contents = &readbuf[header_size];
+ byte* record_contents = &readbuf[TLS_HEADER_SIZE];
if(epoch == 0) // Unencrypted initial handshake
{
- record.assign(&readbuf[header_size], &readbuf[header_size + record_len]);
+ record.assign(&readbuf[TLS_HEADER_SIZE], &readbuf[TLS_HEADER_SIZE + record_len]);
readbuf.clear();
return 0; // got a full record
}
@@ -599,8 +573,6 @@ size_t read_record(secure_vector<byte>& readbuf,
// Otherwise, decrypt, check MAC, return plaintext
auto cipherstate = get_cipherstate(epoch);
- // FIXME: DTLS reordering might cause us not to have the cipher state
-
BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
decrypt_record(record,
@@ -618,6 +590,126 @@ size_t read_record(secure_vector<byte>& readbuf,
return 0;
}
+size_t read_dtls_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ {
+ consumed = 0;
+
+ if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete?
+ {
+ if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE))
+ {
+ readbuf.clear();
+ return 0;
+ }
+
+ BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header");
+ }
+
+ *record_version = Protocol_Version(readbuf[1], readbuf[2]);
+
+ BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS");
+
+ const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2],
+ readbuf[DTLS_HEADER_SIZE-1]);
+
+ if(record_len > MAX_CIPHERTEXT_SIZE)
+ throw TLS_Exception(Alert::RECORD_OVERFLOW,
+ "Got message that exceeds maximum size");
+
+ if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len))
+ {
+ // Truncated packet?
+ readbuf.clear();
+ return 0;
+ }
+
+ BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(),
+ "Have the full record");
+
+ *record_type = static_cast<Record_Type>(readbuf[0]);
+
+ u16bit epoch = 0;
+
+ *record_sequence = load_be<u64bit>(&readbuf[3], 0);
+ epoch = (*record_sequence >> 48);
+
+ if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
+ {
+ readbuf.clear();
+ return 0;
+ }
+
+ byte* record_contents = &readbuf[DTLS_HEADER_SIZE];
+
+ if(epoch == 0) // Unencrypted initial handshake
+ {
+ record.assign(&readbuf[DTLS_HEADER_SIZE], &readbuf[DTLS_HEADER_SIZE + record_len]);
+ readbuf.clear();
+ return 0; // got a full record
+ }
+
+ try
+ {
+ // Otherwise, decrypt, check MAC, return plaintext
+ auto cipherstate = get_cipherstate(epoch);
+
+ BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
+
+ decrypt_record(record,
+ record_contents,
+ record_len,
+ *record_sequence,
+ *record_version,
+ *record_type,
+ *cipherstate);
+ }
+ catch(std::exception)
+ {
+ readbuf.clear();
+ *record_type = NO_RECORD;
+ return 0;
+ }
+
+ if(sequence_numbers)
+ sequence_numbers->read_accept(*record_sequence);
+
+ readbuf.clear();
+ return 0;
+ }
+
+}
+
+size_t read_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ bool is_datagram,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ {
+ if(is_datagram)
+ return read_dtls_record(readbuf, input, input_sz, consumed,
+ record, record_sequence, record_version, record_type,
+ sequence_numbers, get_cipherstate);
+ else
+ return read_tls_record(readbuf, input, input_sz, consumed,
+ record, record_sequence, record_version, record_type,
+ sequence_numbers, get_cipherstate);
+ }
+
}
}
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index 8431e68c0..2dae96164 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -122,6 +122,7 @@ void write_record(secure_vector<byte>& write_buffer,
size_t read_record(secure_vector<byte>& read_buffer,
const byte input[],
size_t input_length,
+ bool is_datagram,
size_t& input_consumed,
secure_vector<byte>& record,
u64bit* record_sequence,
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index ff285881a..9b8a0d811 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -216,8 +216,9 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
const Policy& policy,
RandomNumberGenerator& rng,
const std::vector<std::string>& next_protocols,
+ bool is_datagram,
size_t io_buf_sz) :
- Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz),
+ Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz),
m_policy(policy),
m_creds(creds),
m_possible_protocols(next_protocols)
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index a514607ba..c0646bdbc 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -34,6 +34,7 @@ class BOTAN_DLL Server : public Channel
const Policy& policy,
RandomNumberGenerator& rng,
const std::vector<std::string>& protocols = std::vector<std::string>(),
+ bool is_datagram = false,
size_t reserved_io_buffer_size = 16*1024
);