aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp3
-rw-r--r--src/lib/rng/rng.h8
-rw-r--r--src/lib/tls/msg_client_hello.cpp6
-rw-r--r--src/lib/tls/tls_channel.cpp38
-rw-r--r--src/lib/tls/tls_channel.h17
-rw-r--r--src/lib/tls/tls_client.cpp53
-rw-r--r--src/lib/tls/tls_client.h15
-rw-r--r--src/lib/tls/tls_handshake_io.cpp97
-rw-r--r--src/lib/tls/tls_handshake_io.h16
-rw-r--r--src/lib/tls/tls_handshake_msg.h2
-rw-r--r--src/lib/tls/tls_handshake_state.cpp110
-rw-r--r--src/lib/tls/tls_handshake_state.h6
-rw-r--r--src/lib/tls/tls_magic.h2
-rw-r--r--src/lib/tls/tls_policy.cpp13
-rw-r--r--src/lib/tls/tls_policy.h15
-rw-r--r--src/lib/tls/tls_record.cpp85
-rw-r--r--src/lib/tls/tls_server.cpp58
-rw-r--r--src/lib/tls/tls_server.h14
-rw-r--r--src/lib/utils/ct_utils.h36
-rw-r--r--src/tests/unit_tls.cpp551
20 files changed, 855 insertions, 290 deletions
diff --git a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp
index 6b3bce0aa..5ff288db2 100644
--- a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp
+++ b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp
@@ -28,8 +28,7 @@ secure_vector<byte> EME_PKCS1v15::pad(const byte in[], size_t inlen,
out[0] = 0x02;
for(size_t j = 1; j != olen - inlen - 1; ++j)
- while(out[j] == 0)
- out[j] = rng.next_byte();
+ out[j] = rng.next_nonzero_byte();
buffer_insert(out, olen - inlen, in, inlen);
return out;
diff --git a/src/lib/rng/rng.h b/src/lib/rng/rng.h
index 6ee67f66f..261880d5d 100644
--- a/src/lib/rng/rng.h
+++ b/src/lib/rng/rng.h
@@ -75,6 +75,14 @@ class BOTAN_DLL RandomNumberGenerator
*/
byte next_byte() { return get_random<byte>(); }
+ byte next_nonzero_byte()
+ {
+ byte b = next_byte();
+ while(b == 0)
+ b = next_byte();
+ return b;
+ }
+
/**
* Check whether this RNG is seeded.
* @return true if this RNG was already seeded, false otherwise.
diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp
index 82ba6f4f6..77bdc5cf5 100644
--- a/src/lib/tls/msg_client_hello.cpp
+++ b/src/lib/tls/msg_client_hello.cpp
@@ -1,6 +1,6 @@
/*
* TLS Hello Request and Client Hello Messages
-* (C) 2004-2011 Jack Lloyd
+* (C) 2004-2011,2015 Jack Lloyd
*
* Botan is released under the Simplified BSD License (see license.txt)
*/
@@ -210,11 +210,11 @@ Client_Hello::Client_Hello(const std::vector<byte>& buf)
m_random = reader.get_fixed<byte>(32);
+ m_session_id = reader.get_range<byte>(1, 0, 32);
+
if(m_version.is_datagram_protocol())
m_hello_cookie = reader.get_range<byte>(1, 0, 255);
- m_session_id = reader.get_range<byte>(1, 0, 32);
-
m_suites = reader.get_range_vector<u16bit>(2, 1, 32767);
m_comp_methods = reader.get_range_vector<byte>(1, 1, 255);
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index e2b1aad9d..5dfcec34e 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -23,17 +23,21 @@ Channel::Channel(output_fn output_fn,
data_cb data_cb,
alert_cb alert_cb,
handshake_cb handshake_cb,
+ handshake_msg_cb handshake_msg_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
+ const Policy& policy,
bool is_datagram,
size_t reserved_io_buffer_size) :
m_is_datagram(is_datagram),
- m_handshake_cb(handshake_cb),
m_data_cb(data_cb),
m_alert_cb(alert_cb),
m_output_fn(output_fn),
- m_rng(rng),
- m_session_manager(session_manager)
+ m_handshake_cb(handshake_cb),
+ m_handshake_msg_cb(handshake_msg_cb),
+ m_session_manager(session_manager),
+ m_policy(policy),
+ m_rng(rng)
{
/* epoch 0 is plaintext, thus null cipher state */
m_write_cipher_states[0] = nullptr;
@@ -66,20 +70,16 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const
std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_epoch(u16bit epoch) const
{
auto i = m_read_cipher_states.find(epoch);
-
- BOTAN_ASSERT(i != m_read_cipher_states.end(),
- "Have a cipher state for the specified epoch");
-
+ if(i == m_read_cipher_states.end())
+ throw Internal_Error("TLS::Channel No read cipherstate for epoch " + std::to_string(epoch));
return i->second;
}
std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_epoch(u16bit epoch) const
{
auto i = m_write_cipher_states.find(epoch);
-
- BOTAN_ASSERT(i != m_write_cipher_states.end(),
- "Have a cipher state for the specified epoch");
-
+ if(i == m_write_cipher_states.end())
+ throw Internal_Error("TLS::Channel No write cipherstate for epoch " + std::to_string(epoch));
return i->second;
}
@@ -120,17 +120,17 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
std::unique_ptr<Handshake_IO> io;
if(version.is_datagram_protocol())
{
- // default MTU is IPv6 min MTU minus UDP/IP headers (TODO: make configurable)
- const u16bit mtu = 1280 - 40 - 8;
-
io.reset(new Datagram_Handshake_IO(
std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3),
sequence_numbers(),
- mtu));
+ m_policy.dtls_default_mtu(),
+ m_policy.dtls_initial_timeout(),
+ m_policy.dtls_maximum_timeout()));
}
else
- io.reset(new Stream_Handshake_IO(
- std::bind(&Channel::send_record, this, _1, _2)));
+ {
+ io.reset(new Stream_Handshake_IO(std::bind(&Channel::send_record, this, _1, _2)));
+ }
m_pending_state.reset(new_handshake_state(io.release()));
@@ -333,12 +333,13 @@ size_t Channel::received_data(const byte input[], size_t input_size)
if(record.size() > max_fragment_size)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
- "Plaintext record is too large");
+ "TLS input record is larger than allowed maximum");
if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC)
{
if(!m_pending_state)
{
+ // No pending handshake, possibly new:
if(record_version.is_datagram_protocol())
{
if(m_sequence_numbers)
@@ -374,6 +375,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
}
}
+ // May have been created in above conditional
if(m_pending_state)
{
m_pending_state->handshake_io().add_record(unlock(record),
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 4e6874a16..9ef2d17c4 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -24,6 +24,7 @@ namespace TLS {
class Connection_Cipher_State;
class Connection_Sequence_Numbers;
class Handshake_State;
+class Handshake_Message;
/**
* Generic interface for TLS endpoint
@@ -35,15 +36,18 @@ class BOTAN_DLL Channel
typedef std::function<void (const byte[], size_t)> data_cb;
typedef std::function<void (Alert, const byte[], size_t)> alert_cb;
typedef std::function<bool (const Session&)> handshake_cb;
+ typedef std::function<void (const Handshake_Message&)> handshake_msg_cb;
Channel(output_fn out,
data_cb app_data_cb,
alert_cb alert_cb,
handshake_cb hs_cb,
+ handshake_msg_cb hs_msg_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
+ const Policy& policy,
bool is_datagram,
- size_t reserved_io_buffer_size);
+ size_t io_buf_sz = 16*1024);
Channel(const Channel&) = delete;
@@ -196,6 +200,8 @@ class BOTAN_DLL Channel
Handshake_State& create_handshake_state(Protocol_Version version);
+ void inspect_handshake_message(const Handshake_Message& msg);
+
void activate_session();
void change_cipher_spec_reader(Connection_Side side);
@@ -214,8 +220,11 @@ class BOTAN_DLL Channel
Session_Manager& session_manager() { return m_session_manager; }
+ const Policy& policy() const { return m_policy; }
+
bool save_session(const Session& session) const { return m_handshake_cb(session); }
+ handshake_msg_cb get_handshake_msg_cb() const { return m_handshake_msg_cb; }
private:
size_t maximum_fragment_size() const;
@@ -245,14 +254,16 @@ class BOTAN_DLL Channel
bool m_is_datagram;
/* callbacks */
- handshake_cb m_handshake_cb;
data_cb m_data_cb;
alert_cb m_alert_cb;
output_fn m_output_fn;
+ handshake_cb m_handshake_cb;
+ handshake_msg_cb m_handshake_msg_cb;
/* external state */
- RandomNumberGenerator& m_rng;
Session_Manager& m_session_manager;
+ const Policy& m_policy;
+ RandomNumberGenerator& m_rng;
/* sequence number state */
std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers;
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index 9306092ce..82630b7fa 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -1,6 +1,6 @@
/*
* TLS Client
-* (C) 2004-2011,2012 Jack Lloyd
+* (C) 2004-2011,2012,2015 Jack Lloyd
*
* Botan is released under the Simplified BSD License (see license.txt)
*/
@@ -23,8 +23,7 @@ class Client_Handshake_State : public Handshake_State
public:
// using Handshake_State::Handshake_State;
- Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) :
- Handshake_State(io, cb) {}
+ Client_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {}
const Public_Key& get_server_public_Key() const
{
@@ -55,9 +54,31 @@ Client::Client(output_fn output_fn,
const Protocol_Version offer_version,
const std::vector<std::string>& next_protos,
size_t io_buf_sz) :
- Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng,
- offer_version.is_datagram_protocol(), io_buf_sz),
- m_policy(policy),
+ Channel(output_fn, proc_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(),
+ session_manager, rng, policy, offer_version.is_datagram_protocol(), io_buf_sz),
+ m_creds(creds),
+ m_info(info)
+ {
+ const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname());
+
+ Handshake_State& state = create_handshake_state(offer_version);
+ send_client_hello(state, false, offer_version, srp_identifier, next_protos);
+ }
+
+Client::Client(output_fn output_fn,
+ data_cb proc_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
+ handshake_msg_cb hs_msg_cb,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ const Server_Information& info,
+ const Protocol_Version offer_version,
+ const std::vector<std::string>& next_protos) :
+ Channel(output_fn, proc_cb, alert_cb, handshake_cb, hs_msg_cb,
+ session_manager, rng, policy, offer_version.is_datagram_protocol()),
m_creds(creds),
m_info(info)
{
@@ -69,7 +90,7 @@ Client::Client(output_fn output_fn,
Handshake_State* Client::new_handshake_state(Handshake_IO* io)
{
- return new Client_Handshake_State(io); // , m_hs_msg_cb);
+ return new Client_Handshake_State(io, get_handshake_msg_cb());
}
std::vector<X509_Certificate>
@@ -111,7 +132,7 @@ void Client::send_client_hello(Handshake_State& state_base,
state.client_hello(new Client_Hello(
state.handshake_io(),
state.hash(),
- m_policy,
+ policy(),
rng(),
secure_renegotiation_data_for_client_hello(),
session_info,
@@ -128,7 +149,7 @@ void Client::send_client_hello(Handshake_State& state_base,
state.handshake_io(),
state.hash(),
version,
- m_policy,
+ policy(),
rng(),
secure_renegotiation_data_for_client_hello(),
next_protocols,
@@ -157,9 +178,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
if(state.client_hello())
return;
- if(m_policy.allow_server_initiated_renegotiation())
+ if(policy().allow_server_initiated_renegotiation())
{
- if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false)
+ if(!secure_renegotiation_supported() && policy().allow_insecure_renegotiation() == false)
send_warning_alert(Alert::NO_RENEGOTIATION);
else
this->initiate_handshake(state, false);
@@ -263,7 +284,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
if(state.server_hello()->supports_session_ticket())
state.set_expected_next(NEW_SESSION_TICKET);
else
+ {
state.set_expected_next(HANDSHAKE_CCS);
+ }
}
else
{
@@ -282,7 +305,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
"Server replied with later version than in hello");
}
- if(!m_policy.acceptable_protocol_version(state.version()))
+ if(!policy().acceptable_protocol_version(state.version()))
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Server version " + state.version().to_string() +
@@ -406,7 +429,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
state.client_kex(
new Client_Key_Exchange(state.handshake_io(),
state,
- m_policy,
+ policy(),
m_creds,
state.server_public_key.get(),
m_info.hostname(),
@@ -426,7 +449,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
state.client_verify(
new Certificate_Verify(state.handshake_io(),
state,
- m_policy,
+ policy(),
rng(),
private_key)
);
@@ -477,7 +500,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
const std::vector<byte>& session_ticket = state.session_ticket();
if(session_id.empty() && !session_ticket.empty())
- session_id = make_hello_random(rng(), m_policy);
+ session_id = make_hello_random(rng(), policy());
Session session_info(
session_id,
diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h
index e4e0dc363..b835c013e 100644
--- a/src/lib/tls/tls_client.h
+++ b/src/lib/tls/tls_client.h
@@ -67,6 +67,20 @@ class BOTAN_DLL Client : public Channel
size_t reserved_io_buffer_size = 16*1024
);
+ Client(output_fn out,
+ data_cb app_data_cb,
+ alert_cb alert_cb,
+ handshake_cb hs_cb,
+ handshake_msg_cb hs_msg_cb,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ const Server_Information& server_info = Server_Information(),
+ const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
+ const std::vector<std::string>& next_protocols = {}
+ );
+
const std::string& application_protocol() const { return m_application_protocol; }
private:
std::vector<X509_Certificate>
@@ -88,7 +102,6 @@ class BOTAN_DLL Client : public Channel
Handshake_State* new_handshake_state(Handshake_IO* io) override;
- const Policy& m_policy;
Credentials_Manager& m_creds;
const Server_Information m_info;
std::string m_application_protocol;
diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp
index 6286eab08..f39c9f84e 100644
--- a/src/lib/tls/tls_handshake_io.cpp
+++ b/src/lib/tls/tls_handshake_io.cpp
@@ -1,6 +1,6 @@
/*
* TLS Handshake IO
-* (C) 2012,2014 Jack Lloyd
+* (C) 2012,2014,2015 Jack Lloyd
*
* Botan is released under the Simplified BSD License (see license.txt)
*/
@@ -33,6 +33,24 @@ void store_be24(byte out[3], size_t val)
out[2] = get_byte<u32bit>(3, val);
}
+u64bit steady_clock_ms()
+ {
+ return std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::steady_clock::now().time_since_epoch()).count();
+ }
+
+size_t split_for_mtu(size_t mtu, size_t msg_size)
+ {
+ const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers
+
+ const size_t parts = (msg_size + mtu) / mtu;
+
+ if(parts + DTLS_HEADERS_SIZE > mtu)
+ return parts + 1;
+
+ return parts;
+ }
+
}
Protocol_Version Stream_Handshake_IO::initial_record_version() const
@@ -123,41 +141,15 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const
return Protocol_Version::DTLS_V10;
}
-namespace {
-
-// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
-const u64bit INITIAL_TIMEOUT = 1*1000;
-const u64bit MAXIMUM_TIMEOUT = 60*1000;
-
-u64bit steady_clock_ms()
+void Datagram_Handshake_IO::retransmit_last_flight()
{
- return std::chrono::duration_cast<std::chrono::milliseconds>(
- std::chrono::steady_clock::now().time_since_epoch()).count();
+ const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
+ retransmit_flight(flight_idx);
}
-}
-
-bool Datagram_Handshake_IO::timeout_check()
+void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx)
{
- if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
- {
- /*
- If we haven't written anything yet obviously no timeout.
- Also no timeout possible if we are mid-flight,
- */
- return false;
- }
-
- const u64bit ms_since_write = steady_clock_ms() - m_last_write;
-
- if(ms_since_write < m_next_timeout)
- return false;
-
- std::vector<u16bit> flight;
- if(m_flights.size() == 1)
- flight = m_flights.at(0); // lost initial client hello
- else
- flight = m_flights.at(m_flights.size() - 2);
+ const std::vector<u16bit>& flight = m_flights.at(flight_idx);
BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit");
@@ -177,8 +169,27 @@ bool Datagram_Handshake_IO::timeout_check()
send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
epoch = msg.epoch;
}
+ }
+
+bool Datagram_Handshake_IO::timeout_check()
+ {
+ if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
+ {
+ /*
+ If we haven't written anything yet obviously no timeout.
+ Also no timeout possible if we are mid-flight,
+ */
+ return false;
+ }
+
+ const u64bit ms_since_write = steady_clock_ms() - m_last_write;
+
+ if(ms_since_write < m_next_timeout)
+ return false;
- m_next_timeout = std::min(2 * m_next_timeout, MAXIMUM_TIMEOUT);
+ retransmit_last_flight();
+
+ m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout);
return true;
}
@@ -251,7 +262,6 @@ Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
if(m_ccs_epochs.count(current_epoch))
return std::make_pair(HANDSHAKE_CCS, std::vector<byte>());
}
-
return std::make_pair(HANDSHAKE_NONE, std::vector<byte>());
}
@@ -376,21 +386,6 @@ Datagram_Handshake_IO::format(const std::vector<byte>& msg,
return format_w_seq(msg, type, m_in_message_seq - 1);
}
-namespace {
-
-size_t split_for_mtu(size_t mtu, size_t msg_size)
- {
- const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers
-
- const size_t parts = (msg_size + mtu) / mtu;
-
- if(parts + DTLS_HEADERS_SIZE > mtu)
- return parts + 1;
-
- return parts;
- }
-
-}
std::vector<byte>
Datagram_Handshake_IO::send(const Handshake_Message& msg)
@@ -411,7 +406,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg)
m_out_message_seq += 1;
m_last_write = steady_clock_ms();
- m_next_timeout = INITIAL_TIMEOUT;
+ m_next_timeout = m_initial_timeout;
return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
}
@@ -425,7 +420,9 @@ std::vector<byte> Datagram_Handshake_IO::send_message(u16bit msg_seq,
format_w_seq(msg_bits, msg_type, msg_seq);
if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
+ {
m_send_hs(epoch, HANDSHAKE, no_fragment);
+ }
else
{
const size_t parts = split_for_mtu(m_mtu, msg_bits.size());
diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h
index 00074a744..a1c1c5ce3 100644
--- a/src/lib/tls/tls_handshake_io.h
+++ b/src/lib/tls/tls_handshake_io.h
@@ -100,8 +100,14 @@ class Datagram_Handshake_IO : public Handshake_IO
Datagram_Handshake_IO(writer_fn writer,
class Connection_Sequence_Numbers& seq,
- u16bit mtu) :
- m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {}
+ u16bit mtu, u64bit initial_timeout_ms, u64bit max_timeout_ms) :
+ m_seqs(seq),
+ m_flights(1),
+ m_initial_timeout(initial_timeout_ms),
+ m_max_timeout(max_timeout_ms),
+ m_send_hs(writer),
+ m_mtu(mtu)
+ {}
Protocol_Version initial_record_version() const override;
@@ -120,6 +126,9 @@ class Datagram_Handshake_IO : public Handshake_IO
std::pair<Handshake_Type, std::vector<byte>>
get_next_record(bool expecting_ccs) override;
private:
+ void retransmit_flight(size_t flight);
+ void retransmit_last_flight();
+
std::vector<byte> format_fragment(
const byte fragment[],
size_t fragment_len,
@@ -183,6 +192,9 @@ class Datagram_Handshake_IO : public Handshake_IO
std::vector<std::vector<u16bit>> m_flights;
std::map<u16bit, Message_Info> m_flight_data;
+ u64bit m_initial_timeout = 0;
+ u64bit m_max_timeout = 0;
+
u64bit m_last_write = 0;
u64bit m_next_timeout = 0;
diff --git a/src/lib/tls/tls_handshake_msg.h b/src/lib/tls/tls_handshake_msg.h
index 6937d4f2c..7e527abf4 100644
--- a/src/lib/tls/tls_handshake_msg.h
+++ b/src/lib/tls/tls_handshake_msg.h
@@ -22,6 +22,8 @@ namespace TLS {
class BOTAN_DLL Handshake_Message
{
public:
+ std::string type_string() const;
+
virtual Handshake_Type type() const = 0;
virtual std::vector<byte> serialize() const = 0;
diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp
index cbbca3a0d..f885d3b08 100644
--- a/src/lib/tls/tls_handshake_state.cpp
+++ b/src/lib/tls/tls_handshake_state.cpp
@@ -1,6 +1,6 @@
/*
* TLS Handshaking
-* (C) 2004-2006,2011,2012 Jack Lloyd
+* (C) 2004-2006,2011,2012,2015 Jack Lloyd
*
* Botan is released under the Simplified BSD License (see license.txt)
*/
@@ -13,6 +13,67 @@ namespace Botan {
namespace TLS {
+std::string Handshake_Message::type_string() const
+ {
+ return handshake_type_to_string(type());
+ }
+
+const char* handshake_type_to_string(Handshake_Type type)
+ {
+ switch(type)
+ {
+ case HELLO_VERIFY_REQUEST:
+ return "hello_verify_request";
+
+ case HELLO_REQUEST:
+ return "hello_request";
+
+ case CLIENT_HELLO:
+ return "client_hello";
+
+ case SERVER_HELLO:
+ return "server_hello";
+
+ case CERTIFICATE:
+ return "certificate";
+
+ case CERTIFICATE_URL:
+ return "certificate_url";
+
+ case CERTIFICATE_STATUS:
+ return "certificate_status";
+
+ case SERVER_KEX:
+ return "server_key_exchange";
+
+ case CERTIFICATE_REQUEST:
+ return "certificate_request";
+
+ case SERVER_HELLO_DONE:
+ return "server_hello_done";
+
+ case CERTIFICATE_VERIFY:
+ return "certificate_verify";
+
+ case CLIENT_KEX:
+ return "client_key_exchange";
+
+ case NEW_SESSION_TICKET:
+ return "new_session_ticket";
+
+ case HANDSHAKE_CCS:
+ return "change_cipher_spec";
+
+ case FINISHED:
+ return "finished";
+
+ case HANDSHAKE_NONE:
+ return "invalid";
+ }
+
+ throw Internal_Error("Unknown TLS handshake message type " + std::to_string(type));
+ }
+
namespace {
u32bit bitmask_for_handshake_type(Handshake_Type type)
@@ -25,9 +86,6 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
case HELLO_REQUEST:
return (1 << 1);
- /*
- * Same code point for both client hello styles
- */
case CLIENT_HELLO:
return (1 << 2);
@@ -75,12 +133,48 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
throw Internal_Error("Unknown handshake type " + std::to_string(type));
}
+std::string handshake_mask_to_string(u32bit mask)
+ {
+ const Handshake_Type types[] = {
+ HELLO_VERIFY_REQUEST,
+ HELLO_REQUEST,
+ CLIENT_HELLO,
+ CERTIFICATE,
+ CERTIFICATE_URL,
+ CERTIFICATE_STATUS,
+ SERVER_KEX,
+ CERTIFICATE_REQUEST,
+ SERVER_HELLO_DONE,
+ CERTIFICATE_VERIFY,
+ CLIENT_KEX,
+ NEW_SESSION_TICKET,
+ HANDSHAKE_CCS,
+ FINISHED
+ };
+
+ std::ostringstream o;
+ bool empty = true;
+
+ for(auto&& t : types)
+ {
+ if(mask & bitmask_for_handshake_type(t))
+ {
+ if(!empty)
+ o << ",";
+ o << handshake_type_to_string(t);
+ empty = false;
+ }
+ }
+
+ return o.str();
+ }
+
}
/*
* Initialize the SSL/TLS Handshake State
*/
-Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) :
+Handshake_State::Handshake_State(Handshake_IO* io, handshake_msg_cb cb) :
m_msg_callback(cb),
m_handshake_io(io),
m_version(m_handshake_io->initial_record_version())
@@ -196,10 +290,10 @@ void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg)
const bool ok = (m_hand_expecting_mask & mask); // overlap?
if(!ok)
- throw Unexpected_Message("Unexpected state transition in handshake, got " +
+ throw Unexpected_Message("Unexpected state transition in handshake, got type " +
std::to_string(handshake_msg) +
- " expected " + std::to_string(m_hand_expecting_mask) +
- " received " + std::to_string(m_hand_received_mask));
+ " expected " + handshake_mask_to_string(m_hand_expecting_mask) +
+ " received " + handshake_mask_to_string(m_hand_received_mask));
/* We don't know what to expect next, so force a call to
set_expected_next; if it doesn't happen, the next transition
diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h
index 3b60178b4..6260b090f 100644
--- a/src/lib/tls/tls_handshake_state.h
+++ b/src/lib/tls/tls_handshake_state.h
@@ -45,9 +45,9 @@ class Finished;
class Handshake_State
{
public:
- typedef std::function<void (const Handshake_Message&)> hs_msg_cb;
+ typedef std::function<void (const Handshake_Message&)> handshake_msg_cb;
- Handshake_State(Handshake_IO* io, hs_msg_cb cb);
+ Handshake_State(Handshake_IO* io, handshake_msg_cb cb);
virtual ~Handshake_State();
@@ -170,7 +170,7 @@ class Handshake_State
private:
- hs_msg_cb m_msg_callback;
+ handshake_msg_cb m_msg_callback;
std::unique_ptr<Handshake_IO> m_handshake_io;
diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h
index 882e59158..6db908b08 100644
--- a/src/lib/tls/tls_magic.h
+++ b/src/lib/tls/tls_magic.h
@@ -57,6 +57,8 @@ enum Handshake_Type {
HANDSHAKE_NONE = 255 // Null value
};
+const char* handshake_type_to_string(Handshake_Type t);
+
enum Compression_Method {
NO_COMPRESSION = 0x00,
DEFLATE_COMPRESSION = 0x01
diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp
index f50cf1f3e..d8dd2c828 100644
--- a/src/lib/tls/tls_policy.cpp
+++ b/src/lib/tls/tls_policy.cpp
@@ -20,9 +20,9 @@ std::vector<std::string> Policy::allowed_ciphers() const
return {
//"AES-256/OCB(12)",
//"AES-128/OCB(12)",
- "ChaCha20Poly1305",
"AES-256/GCM",
"AES-128/GCM",
+ "ChaCha20Poly1305",
"AES-256/CCM",
"AES-128/CCM",
"AES-256/CCM(8)",
@@ -35,7 +35,6 @@ std::vector<std::string> Policy::allowed_ciphers() const
//"Camellia-128",
//"SEED"
//"3DES",
- //"RC4",
};
}
@@ -175,6 +174,16 @@ bool Policy::include_time_in_hello_random() const { return true; }
bool Policy::hide_unknown_users() const { return false; }
bool Policy::server_uses_own_ciphersuite_preferences() const { return true; }
+// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1
+size_t Policy::dtls_initial_timeout() const { return 1*1000; }
+size_t Policy::dtls_maximum_timeout() const { return 60*1000; }
+
+size_t Policy::dtls_default_mtu() const
+ {
+ // default MTU is IPv6 min MTU minus UDP/IP headers
+ return 1280 - 40 - 8;
+ }
+
std::vector<u16bit> Policy::srtp_profiles() const
{
return std::vector<u16bit>();
diff --git a/src/lib/tls/tls_policy.h b/src/lib/tls/tls_policy.h
index 581d04bcd..c3f8f1ee2 100644
--- a/src/lib/tls/tls_policy.h
+++ b/src/lib/tls/tls_policy.h
@@ -13,6 +13,7 @@
#include <botan/x509cert.h>
#include <botan/dl_group.h>
#include <vector>
+#include <sstream>
namespace Botan {
@@ -173,6 +174,12 @@ class BOTAN_DLL Policy
virtual std::vector<u16bit> ciphersuite_list(Protocol_Version version,
bool have_srp) const;
+ virtual size_t dtls_default_mtu() const;
+
+ virtual size_t dtls_initial_timeout() const;
+
+ virtual size_t dtls_maximum_timeout() const;
+
virtual void print(std::ostream& o) const;
virtual ~Policy() {}
@@ -299,6 +306,14 @@ class BOTAN_DLL Text_Policy : public Policy
return r;
}
+ void set(const std::string& k, const std::string& v) { m_kv[k] = v; }
+
+ Text_Policy(const std::string& s)
+ {
+ std::istringstream iss(s);
+ m_kv = read_cfg(iss);
+ }
+
Text_Policy(std::istream& in)
{
m_kv = read_cfg(in);
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 71542de16..e38b26547 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -12,6 +12,7 @@
#include <botan/internal/tls_seq_numbers.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/rounding.h>
+#include <botan/internal/ct_utils.h>
#include <botan/rng.h>
namespace Botan {
@@ -284,31 +285,26 @@ size_t fill_buffer_to(secure_vector<byte>& readbuf,
*
* Returning 0 in the error case should ensure the MAC check will fail.
* This approach is suggested in section 6.2.3.2 of RFC 5246.
-*
-* Also returns 0 if block_size == 0, so can be safely called with a
-* stream cipher in use.
-*
-* @fixme This should run in constant time
*/
-size_t tls_padding_check(const byte record[], size_t record_len)
+u16bit tls_padding_check(const byte record[], size_t record_len)
{
- const size_t padding_length = record[(record_len-1)];
-
- if(padding_length >= record_len)
- return 0;
-
/*
* TLS v1.0 and up require all the padding bytes be the same value
* and allows up to 255 bytes.
*/
- const size_t pad_start = record_len - padding_length - 1;
- volatile size_t cmp = 0;
+ const byte pad_byte = record[(record_len-1)];
- for(size_t i = 0; i != padding_length; ++i)
- cmp += record[pad_start + i] ^ padding_length;
+ byte pad_invalid = 0;
+ for(size_t i = 0; i != record_len; ++i)
+ {
+ const size_t left = record_len - i - 2;
+ const byte delim_mask = CT::is_less<u16bit>(left, pad_byte) & 0xFF;
+ pad_invalid |= (delim_mask & (record[i] ^ pad_byte));
+ }
- return cmp ? 0 : padding_length + 1;
+ u16bit pad_invalid_mask = CT::expand_mask<u16bit>(pad_invalid);
+ return CT::select<u16bit>(pad_invalid_mask, 0, pad_byte + 1);
}
void cbc_decrypt_record(byte record_contents[], size_t record_len,
@@ -375,38 +371,39 @@ void decrypt_record(secure_vector<byte>& output,
else
{
// GenericBlockCipher case
+ BlockCipher* bc = cs.block_cipher();
+ BOTAN_ASSERT(bc != nullptr, "No cipher state set but needed to decrypt");
- volatile bool padding_bad = false;
- size_t pad_size = 0;
+ const size_t mac_size = cs.mac_size();
+ const size_t iv_size = cs.iv_size();
- if(BlockCipher* bc = cs.block_cipher())
- {
- cbc_decrypt_record(record_contents, record_len, cs, *bc);
+ // This early exit does not leak info because all the values are public
+ if((record_len < mac_size + iv_size) || (record_len % cs.block_size() != 0))
+ throw Decoding_Error("Record sent with invalid length");
- pad_size = tls_padding_check(record_contents, record_len);
+ CT::poison(record_contents, record_len);
- padding_bad = (pad_size == 0);
- }
- else
- {
- throw Internal_Error("No cipher state set but needed to decrypt");
- }
+ cbc_decrypt_record(record_contents, record_len, cs, *bc);
- const size_t mac_size = cs.mac_size();
- const size_t iv_size = cs.iv_size();
+ // 0 if padding was invalid, otherwise 1 + padding_bytes
+ u16bit pad_size = tls_padding_check(record_contents, record_len);
- const size_t mac_pad_iv_size = mac_size + pad_size + iv_size;
+ // This mask is zero if there is not enough room in the packet
+ const u16bit size_ok_mask = CT::is_less<u16bit>(mac_size + pad_size + iv_size, record_len);
+ pad_size &= size_ok_mask;
- if(record_len < mac_pad_iv_size)
- throw Decoding_Error("Record sent with invalid length");
+ CT::unpoison(record_contents, record_len);
- const byte* plaintext_block = &record_contents[iv_size];
- const u16bit plaintext_length = record_len - mac_pad_iv_size;
+ /*
+ This is unpoisoned sooner than it should. The pad_size leaks to plaintext_length and
+ then to the timing channel in the MAC computation described in the Lucky 13 paper.
+ */
+ CT::unpoison(pad_size);
- cs.mac()->update(
- cs.format_ad(record_sequence, record_type, record_version, plaintext_length)
- );
+ const byte* plaintext_block = &record_contents[iv_size];
+ const u16bit plaintext_length = record_len - mac_size - iv_size - pad_size;
+ cs.mac()->update(cs.format_ad(record_sequence, record_type, record_version, plaintext_length));
cs.mac()->update(plaintext_block, plaintext_length);
std::vector<byte> mac_buf(mac_size);
@@ -414,12 +411,16 @@ void decrypt_record(secure_vector<byte>& output,
const size_t mac_offset = record_len - (mac_size + pad_size);
- const bool mac_bad = !same_mem(&record_contents[mac_offset], mac_buf.data(), mac_size);
+ const bool mac_ok = same_mem(&record_contents[mac_offset], mac_buf.data(), mac_size);
- if(mac_bad || padding_bad)
- throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure");
+ const u16bit ok_mask = size_ok_mask & CT::expand_mask<u16bit>(mac_ok) & CT::expand_mask<u16bit>(pad_size);
- output.assign(plaintext_block, plaintext_block + plaintext_length);
+ CT::unpoison(ok_mask);
+
+ if(ok_mask)
+ output.assign(plaintext_block, plaintext_block + plaintext_length);
+ else
+ throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure");
}
}
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index 330135e63..774827346 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -21,8 +21,7 @@ class Server_Handshake_State : public Handshake_State
public:
// using Handshake_State::Handshake_State;
- Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) :
- Handshake_State(io, cb) {}
+ Server_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {}
// Used by the server only, in case of RSA key exchange. Not owned
Private_Key* server_rsa_kex_key = nullptr;
@@ -215,9 +214,26 @@ Server::Server(output_fn output,
next_protocol_fn next_proto,
bool is_datagram,
size_t io_buf_sz) :
- Channel(output, data_cb, alert_cb, handshake_cb,
- session_manager, rng, is_datagram, io_buf_sz),
- m_policy(policy),
+ Channel(output, data_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(),
+ session_manager, rng, policy, is_datagram, io_buf_sz),
+ m_creds(creds),
+ m_choose_next_protocol(next_proto)
+ {
+ }
+
+Server::Server(output_fn output,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
+ handshake_msg_cb hs_msg_cb,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ next_protocol_fn next_proto,
+ bool is_datagram) :
+ Channel(output, data_cb, alert_cb, handshake_cb, hs_msg_cb,
+ session_manager, rng, policy, is_datagram),
m_creds(creds),
m_choose_next_protocol(next_proto)
{
@@ -225,7 +241,9 @@ Server::Server(output_fn output,
Handshake_State* Server::new_handshake_state(Handshake_IO* io)
{
- std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io));
+ std::unique_ptr<Handshake_State> state(
+ new Server_Handshake_State(io, get_handshake_msg_cb()));
+
state->set_expected_next(CLIENT_HELLO);
return state.release();
}
@@ -278,7 +296,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
{
const bool initial_handshake = !active_state;
- if(!m_policy.allow_insecure_renegotiation() &&
+ if(!policy().allow_insecure_renegotiation() &&
!(initial_handshake || secure_renegotiation_supported()))
{
send_warning_alert(Alert::NO_RENEGOTIATION);
@@ -292,7 +310,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
Protocol_Version negotiated_version;
const Protocol_Version latest_supported =
- m_policy.latest_supported_version(client_version.is_datagram_protocol());
+ policy().latest_supported_version(client_version.is_datagram_protocol());
if((initial_handshake && client_version.known_version()) ||
(!initial_handshake && client_version == active_state->version()))
@@ -334,7 +352,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
negotiated_version = latest_supported;
}
- if(!m_policy.acceptable_protocol_version(negotiated_version))
+ if(!policy().acceptable_protocol_version(negotiated_version))
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Client version " + negotiated_version.to_string() +
@@ -359,7 +377,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
session_manager(),
m_creds,
state.client_hello(),
- std::chrono::seconds(m_policy.session_ticket_lifetime()));
+ std::chrono::seconds(policy().session_ticket_lifetime()));
bool have_session_ticket_key = false;
@@ -387,7 +405,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
state.server_hello(new Server_Hello(
state.handshake_io(),
state.hash(),
- m_policy,
+ policy(),
rng(),
secure_renegotiation_data_for_server_hello(),
*state.client_hello(),
@@ -423,7 +441,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
new New_Session_Ticket(state.handshake_io(),
state.hash(),
session_info.encrypt(ticket_key, rng()),
- m_policy.session_ticket_lifetime())
+ policy().session_ticket_lifetime())
);
}
catch(...) {}
@@ -469,14 +487,14 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
state.server_hello(new Server_Hello(
state.handshake_io(),
state.hash(),
- m_policy,
+ policy(),
rng(),
secure_renegotiation_data_for_server_hello(),
*state.client_hello(),
- make_hello_random(rng(), m_policy), // new session ID
+ make_hello_random(rng(), policy()), // new session ID
state.version(),
- choose_ciphersuite(m_policy, state.version(), m_creds, cert_chains, state.client_hello()),
- choose_compression(m_policy, state.client_hello()->compression_methods()),
+ choose_ciphersuite(policy(), state.version(), m_creds, cert_chains, state.client_hello()),
+ choose_compression(policy(), state.client_hello()->compression_methods()),
have_session_ticket_key,
m_next_protocol)
);
@@ -516,7 +534,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
else
{
state.server_kex(new Server_Key_Exchange(state.handshake_io(),
- state, m_policy,
+ state, policy(),
m_creds, rng(), private_key));
}
@@ -534,7 +552,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
{
state.cert_req(
new Certificate_Req(state.handshake_io(), state.hash(),
- m_policy, client_auth_CAs, state.version()));
+ policy(), client_auth_CAs, state.version()));
state.set_expected_next(CERTIFICATE);
}
@@ -565,7 +583,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
state.client_kex(
new Client_Key_Exchange(contents, state,
state.server_rsa_kex_key,
- m_creds, m_policy, rng())
+ m_creds, policy(), rng())
);
state.compute_session_keys();
@@ -649,7 +667,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
new New_Session_Ticket(state.handshake_io(),
state.hash(),
session_info.encrypt(ticket_key, rng()),
- m_policy.session_ticket_lifetime())
+ policy().session_ticket_lifetime())
);
}
catch(...) {}
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index 4f2a11ba4..ffe1111bc 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -40,6 +40,19 @@ class BOTAN_DLL Server : public Channel
size_t reserved_io_buffer_size = 16*1024
);
+ Server(output_fn output,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
+ handshake_msg_cb hs_msg_cb,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ next_protocol_fn next_proto = next_protocol_fn(),
+ bool is_datagram = false
+ );
+
/**
* Return the protocol notification set by the client (using the
* NPN extension) for this connection, if any. This value is not
@@ -62,7 +75,6 @@ class BOTAN_DLL Server : public Channel
Handshake_State* new_handshake_state(Handshake_IO* io) override;
- const Policy& m_policy;
Credentials_Manager& m_creds;
next_protocol_fn m_choose_next_protocol;
diff --git a/src/lib/utils/ct_utils.h b/src/lib/utils/ct_utils.h
index 52a3bc388..2ea07b382 100644
--- a/src/lib/utils/ct_utils.h
+++ b/src/lib/utils/ct_utils.h
@@ -51,6 +51,12 @@ inline void unpoison(T* p, size_t n)
#endif
}
+template<typename T>
+inline void unpoison(T& p)
+ {
+ unpoison(&p, 1);
+ }
+
/*
* T should be an unsigned machine integer type
* Expand to a mask used for other operations
@@ -90,6 +96,16 @@ inline T is_equal(T x, T y)
}
template<typename T>
+inline T is_less(T x, T y)
+ {
+ /*
+ This expands to a constant time sequence with GCC 5.2.0 on x86-64
+ but something more complicated may be needed for portable const time.
+ */
+ return expand_mask<T>(x < y);
+ }
+
+template<typename T>
inline void conditional_copy_mem(T value,
T* to,
const T* from0,
@@ -102,6 +118,26 @@ inline void conditional_copy_mem(T value,
to[i] = CT::select(mask, from0[i], from1[i]);
}
+template<typename T>
+inline T expand_top_bit(T a)
+ {
+ return expand_mask<T>(a >> (sizeof(T)*8-1));
+ }
+
+template<typename T>
+inline T max(T a, T b)
+ {
+ const T a_larger = b - a; // negative if a is larger
+ return select(expand_top_bit(a), a, b);
+ }
+
+template<typename T>
+inline T min(T a, T b)
+ {
+ const T a_larger = b - a; // negative if a is larger
+ return select(expand_top_bit(b), b, a);
+ }
+
}
}
diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp
index 116eb2cdf..d4abef119 100644
--- a/src/tests/unit_tls.cpp
+++ b/src/tests/unit_tls.cpp
@@ -11,6 +11,7 @@
#include <botan/tls_server.h>
#include <botan/tls_client.h>
+#include <botan/tls_handshake_msg.h>
#include <botan/pkcs10.h>
#include <botan/x509self.h>
#include <botan/rsa.h>
@@ -21,6 +22,7 @@
#include <iostream>
#include <vector>
#include <memory>
+#include <thread>
using namespace Botan;
@@ -155,158 +157,441 @@ Credentials_Manager* create_creds()
return new Credentials_Manager_Test(server_cert, ca_cert, server_key);
}
-size_t basic_test_handshake(RandomNumberGenerator& rng,
- TLS::Protocol_Version offer_version,
- Credentials_Manager& creds,
- TLS::Policy& policy)
+std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q)
{
- TLS::Session_Manager_In_Memory server_sessions(rng);
- TLS::Session_Manager_In_Memory client_sessions(rng);
-
- std::vector<byte> c2s_q, s2c_q, c2s_data, s2c_data;
+ return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); };
+ }
- auto handshake_complete = [&](const TLS::Session& session) -> bool
+void print_alert(TLS::Alert alert, const byte[], size_t)
{
- if(session.version() != offer_version)
- std::cout << "Offered " << offer_version.to_string()
- << " got " << session.version().to_string() << std::endl;
- return true;
+ //std::cout << "Alert " << alert.type_string() << std::endl;
};
- auto print_alert = [&](TLS::Alert alert, const byte[], size_t)
+void mutate(std::vector<byte>& v, RandomNumberGenerator& rng)
{
- if(alert.is_valid())
- std::cout << "Server recvd alert " << alert.type_string() << std::endl;
- };
+ if(v.empty())
+ return;
- auto save_server_data = [&](const byte buf[], size_t sz)
- {
- c2s_data.insert(c2s_data.end(), buf, buf+sz);
- };
+ size_t voff = rng.get_random<size_t>() % v.size();
+ v[voff] ^= rng.next_nonzero_byte();
+ }
- auto save_client_data = [&](const byte buf[], size_t sz)
+size_t test_tls_handshake(RandomNumberGenerator& rng,
+ TLS::Protocol_Version offer_version,
+ Credentials_Manager& creds,
+ TLS::Policy& policy)
{
- s2c_data.insert(s2c_data.end(), buf, buf+sz);
- };
+ TLS::Session_Manager_In_Memory server_sessions(rng);
+ TLS::Session_Manager_In_Memory client_sessions(rng);
- auto next_protocol_chooser = [&](std::vector<std::string> protos) {
- if(protos.size() != 2)
- std::cout << "Bad protocol size" << std::endl;
- if(protos[0] != "test/1" || protos[1] != "test/2")
- std::cout << "Bad protocol values" << std::endl;
- return "test/3";
- };
- const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
-
- TLS::Server server([&](const byte buf[], size_t sz)
- { s2c_q.insert(s2c_q.end(), buf, buf+sz); },
- save_server_data,
- print_alert,
- handshake_complete,
- server_sessions,
- creds,
- policy,
- rng,
- next_protocol_chooser,
- offer_version.is_datagram_protocol());
-
- TLS::Client client([&](const byte buf[], size_t sz)
- { c2s_q.insert(c2s_q.end(), buf, buf+sz); },
- save_client_data,
- print_alert,
- handshake_complete,
- client_sessions,
- creds,
- policy,
- rng,
- TLS::Server_Information("server.example.com"),
- offer_version,
- protocols_offered);
-
- while(true)
+ for(size_t r = 1; r <= 4; ++r)
{
- if(client.is_closed() && server.is_closed())
- break;
+ //std::cout << offer_version.to_string() << " r " << r << "\n";
- if(client.is_active())
- client.send("1");
- if(server.is_active())
- {
- if(server.next_protocol() != "test/3")
- std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
- server.send("2");
- }
+ bool handshake_done = false;
- /*
- * Use this as a temp value to hold the queues as otherwise they
- * might end up appending more in response to messages during the
- * handshake.
- */
- std::vector<byte> input;
- std::swap(c2s_q, input);
+ auto handshake_complete = [&](const TLS::Session& session) -> bool {
+ handshake_done = true;
- try
- {
- server.received_data(input.data(), input.size());
- }
- catch(std::exception& e)
- {
- std::cout << "Server error - " << e.what() << std::endl;
- return 1;
- }
+ /*
+ std::cout << "Session established " << session.version().to_string() << " "
+ << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n";
+ */
+
+ if(session.version() != offer_version)
+ std::cout << "Offered " << offer_version.to_string()
+ << " got " << session.version().to_string() << std::endl;
+ return true;
+ };
- input.clear();
- std::swap(s2c_q, input);
+ auto next_protocol_chooser = [&](std::vector<std::string> protos) {
+ if(protos.size() != 2)
+ std::cout << "Bad protocol size" << std::endl;
+ if(protos[0] != "test/1" || protos[1] != "test/2")
+ std::cout << "Bad protocol values" << std::endl;
+ return "test/3";
+ };
+
+ const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
try
{
- client.received_data(input.data(), input.size());
+ std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent;
+
+ TLS::Server server(queue_inserter(s2c_traffic),
+ queue_inserter(server_recv),
+ print_alert,
+ handshake_complete,
+ server_sessions,
+ creds,
+ policy,
+ rng,
+ next_protocol_chooser,
+ false);
+
+ TLS::Client client(queue_inserter(c2s_traffic),
+ queue_inserter(client_recv),
+ print_alert,
+ handshake_complete,
+ client_sessions,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information("server.example.com"),
+ offer_version,
+ protocols_offered);
+
+ size_t rounds = 0;
+
+ while(true)
+ {
+ ++rounds;
+
+ if(rounds > 25)
+ {
+ std::cout << "Still here, something went wrong\n";
+ return 1;
+ }
+
+ if(handshake_done && (client.is_closed() || server.is_closed()))
+ break;
+
+ if(client.is_active() && client_sent.empty())
+ {
+ // Choose a len between 1 and 511
+ const size_t c_len = 1 + rng.next_byte() + rng.next_byte();
+ client_sent = unlock(rng.random_vec(c_len));
+
+ // TODO send in several records
+ client.send(client_sent);
+ }
+
+ if(server.is_active() && server_sent.empty())
+ {
+ if(server.next_protocol() != "test/3")
+ std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
+
+ const size_t s_len = 1 + rng.next_byte() + rng.next_byte();
+ server_sent = unlock(rng.random_vec(s_len));
+ server.send(server_sent);
+ }
+
+ const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1);
+ const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1);
+
+ try
+ {
+ /*
+ * Use this as a temp value to hold the queues as otherwise they
+ * might end up appending more in response to messages during the
+ * handshake.
+ */
+ //std::cout << "server recv " << c2s_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(c2s_traffic, input);
+
+ if(corrupt_server_data)
+ {
+ //std::cout << "Corrupting server data\n";
+ mutate(input, rng);
+ }
+ server.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Server error - " << e.what() << std::endl;
+ continue;
+ }
+
+ try
+ {
+ //std::cout << "client recv " << s2c_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(s2c_traffic, input);
+ if(corrupt_client_data)
+ {
+ //std::cout << "Corrupting client data\n";
+ mutate(input, rng);
+ }
+
+ client.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Client error - " << e.what() << std::endl;
+ continue;
+ }
+
+ if(client_recv.size())
+ {
+ if(client_recv != server_sent)
+ {
+ std::cout << "Error in client recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(server_recv.size())
+ {
+ if(server_recv != client_sent)
+ {
+ std::cout << "Error in server recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(client.is_closed() && server.is_closed())
+ break;
+
+ if(server_recv.size() && client_recv.size())
+ {
+ SymmetricKey client_key = client.key_material_export("label", "context", 32);
+ SymmetricKey server_key = server.key_material_export("label", "context", 32);
+
+ if(client_key != server_key)
+ {
+ std::cout << "TLS key material export mismatch: "
+ << client_key.as_string() << " != "
+ << server_key.as_string() << "\n";
+ return 1;
+ }
+
+ if(r % 2 == 0)
+ client.close();
+ else
+ server.close();
+ }
+ }
}
catch(std::exception& e)
{
- std::cout << "Client error - " << e.what() << std::endl;
+ std::cout << e.what() << "\n";
return 1;
}
+ }
- if(c2s_data.size())
- {
- if(c2s_data[0] != '1')
- {
- std::cout << "Error" << std::endl;
- return 1;
- }
- }
+ return 0;
+ }
+
+size_t test_dtls_handshake(RandomNumberGenerator& rng,
+ TLS::Protocol_Version offer_version,
+ Credentials_Manager& creds,
+ TLS::Policy& policy)
+ {
+ BOTAN_ASSERT(offer_version.is_datagram_protocol(), "Test is for datagram version");
+
+ TLS::Session_Manager_In_Memory server_sessions(rng);
+ TLS::Session_Manager_In_Memory client_sessions(rng);
+
+ for(size_t r = 1; r <= 2; ++r)
+ {
+ //std::cout << offer_version.to_string() << " round " << r << "\n";
+
+ bool handshake_done = false;
+
+ auto handshake_complete = [&](const TLS::Session& session) -> bool {
+ handshake_done = true;
- if(s2c_data.size())
+ /*
+ std::cout << "Session established " << session.version().to_string() << " "
+ << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n";
+ */
+
+ if(session.version() != offer_version)
+ std::cout << "Offered " << offer_version.to_string()
+ << " got " << session.version().to_string() << std::endl;
+ return true;
+ };
+
+ auto next_protocol_chooser = [&](std::vector<std::string> protos) {
+ if(protos.size() != 2)
+ std::cout << "Bad protocol size" << std::endl;
+ if(protos[0] != "test/1" || protos[1] != "test/2")
+ std::cout << "Bad protocol values" << std::endl;
+ return "test/3";
+ };
+
+ const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
+
+ try
{
- if(s2c_data[0] != '2')
+ std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent;
+
+ TLS::Server server(queue_inserter(s2c_traffic),
+ queue_inserter(server_recv),
+ print_alert,
+ handshake_complete,
+ server_sessions,
+ creds,
+ policy,
+ rng,
+ next_protocol_chooser,
+ true);
+
+ TLS::Client client(queue_inserter(c2s_traffic),
+ queue_inserter(client_recv),
+ print_alert,
+ handshake_complete,
+ client_sessions,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information("server.example.com"),
+ offer_version,
+ protocols_offered);
+
+ size_t rounds = 0;
+
+ while(true)
{
- std::cout << "Error" << std::endl;
- return 1;
- }
- }
+ // TODO: client and server should be in different threads
+ std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2));
+ ++rounds;
- if(s2c_data.size() && c2s_data.size())
- {
- SymmetricKey client_key = client.key_material_export("label", "context", 32);
- SymmetricKey server_key = server.key_material_export("label", "context", 32);
+ if(rounds > 100)
+ {
+ std::cout << "Still here, something went wrong\n";
+ return 1;
+ }
+
+ if(handshake_done && (client.is_closed() || server.is_closed()))
+ break;
+
+ if(client.is_active() && client_sent.empty())
+ {
+ // Choose a len between 1 and 511 and send random chunks:
+ const size_t c_len = 1 + rng.next_byte() + rng.next_byte();
+ client_sent = unlock(rng.random_vec(c_len));
+
+ // TODO send multiple parts
+ //std::cout << "Sending " << client_sent.size() << " bytes to server\n";
+ client.send(client_sent);
+ }
- if(client_key != server_key)
- return 1;
+ if(server.is_active() && server_sent.empty())
+ {
+ if(server.next_protocol() != "test/3")
+ std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
+
+ const size_t s_len = 1 + rng.next_byte() + rng.next_byte();
+ server_sent = unlock(rng.random_vec(s_len));
+ //std::cout << "Sending " << server_sent.size() << " bytes to client\n";
+ server.send(server_sent);
+ }
+
+ const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10);
+ const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10);
+
+ try
+ {
+ /*
+ * Use this as a temp value to hold the queues as otherwise they
+ * might end up appending more in response to messages during the
+ * handshake.
+ */
+ //std::cout << "server got " << c2s_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(c2s_traffic, input);
+
+ if(corrupt_client_data)
+ {
+ //std::cout << "Corrupting client data\n";
+ mutate(input, rng);
+ }
+
+ server.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Server error - " << e.what() << std::endl;
+ continue;
+ }
+
+ try
+ {
+ //std::cout << "client got " << s2c_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(s2c_traffic, input);
+
+ if(corrupt_server_data)
+ {
+ //std::cout << "Corrupting server data\n";
+ mutate(input, rng);
+ }
+ client.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Client error - " << e.what() << std::endl;
+ continue;
+ }
- server.close();
- client.close();
+ // If we corrupted a DTLS application message, resend it:
+ if(client.is_active() && corrupt_client_data && server_recv.empty())
+ client.send(client_sent);
+ if(server.is_active() && corrupt_server_data && client_recv.empty())
+ server.send(server_sent);
+
+ if(client_recv.size())
+ {
+ if(client_recv != server_sent)
+ {
+ std::cout << "Error in client recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(server_recv.size())
+ {
+ if(server_recv != client_sent)
+ {
+ std::cout << "Error in server recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(client.is_closed() && server.is_closed())
+ break;
+
+ if(server_recv.size() && client_recv.size())
+ {
+ SymmetricKey client_key = client.key_material_export("label", "context", 32);
+ SymmetricKey server_key = server.key_material_export("label", "context", 32);
+
+ if(client_key != server_key)
+ {
+ std::cout << "TLS key material export mismatch: "
+ << client_key.as_string() << " != "
+ << server_key.as_string() << "\n";
+ return 1;
+ }
+
+ if(r % 2 == 0)
+ client.close();
+ else
+ server.close();
+ }
+ }
+ }
+ catch(std::exception& e)
+ {
+ std::cout << e.what() << "\n";
+ return 1;
}
}
return 0;
}
-class Test_Policy : public TLS::Policy
+class Test_Policy : public TLS::Text_Policy
{
public:
+ Test_Policy() : Text_Policy("") {}
bool acceptable_protocol_version(TLS::Protocol_Version) const override { return true; }
bool send_fallback_scsv(TLS::Protocol_Version) const override { return false; }
+
+ size_t dtls_initial_timeout() const override { return 1; }
+ size_t dtls_maximum_timeout() const override { return 8; }
};
}
@@ -315,17 +600,43 @@ size_t test_tls()
{
size_t errors = 0;
- Test_Policy default_policy;
auto& rng = test_rng();
std::unique_ptr<Credentials_Manager> basic_creds(create_creds());
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, default_policy);
-
- test_report("TLS", 5, errors);
+ Test_Policy policy;
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("key_exchange_methods", "RSA");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("key_exchange_methods", "DH");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ policy.set("key_exchange_methods", "ECDH");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("ciphers", "AES-128");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("ciphers", "ChaCha20Poly1305");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ test_report("TLS", 22, errors);
return errors;
}