aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/c_hello.cpp8
-rw-r--r--src/tls/c_kex.cpp4
-rw-r--r--src/tls/cert_req.cpp6
-rw-r--r--src/tls/cert_ver.cpp4
-rw-r--r--src/tls/finished.cpp6
-rw-r--r--src/tls/info.txt2
-rw-r--r--src/tls/next_protocol.cpp4
-rw-r--r--src/tls/rec_wri.cpp19
-rw-r--r--src/tls/s_hello.cpp6
-rw-r--r--src/tls/s_kex.cpp4
-rw-r--r--src/tls/session_ticket.cpp6
-rw-r--r--src/tls/tls_channel.cpp12
-rw-r--r--src/tls/tls_channel.h2
-rw-r--r--src/tls/tls_client.cpp30
-rw-r--r--src/tls/tls_client.h2
-rw-r--r--src/tls/tls_handshake_state.cpp46
-rw-r--r--src/tls/tls_handshake_state.h49
-rw-r--r--src/tls/tls_handshake_writer.cpp38
-rw-r--r--src/tls/tls_handshake_writer.h52
-rw-r--r--src/tls/tls_messages.h31
-rw-r--r--src/tls/tls_policy.cpp10
-rw-r--r--src/tls/tls_record.h2
-rw-r--r--src/tls/tls_server.cpp69
-rw-r--r--src/tls/tls_server.h2
24 files changed, 248 insertions, 166 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp
index c9249ab9a..465e6714a 100644
--- a/src/tls/c_hello.cpp
+++ b/src/tls/c_hello.cpp
@@ -9,7 +9,7 @@
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/internal/stl_util.h>
#include <chrono>
@@ -36,7 +36,7 @@ std::vector<byte> make_hello_random(RandomNumberGenerator& rng)
/*
* Create a new Hello Request message
*/
-Hello_Request::Hello_Request(Record_Writer& writer)
+Hello_Request::Hello_Request(Handshake_Writer& writer)
{
writer.send(*this);
}
@@ -61,7 +61,7 @@ std::vector<byte> Hello_Request::serialize() const
/*
* Create a new Client Hello message
*/
-Client_Hello::Client_Hello(Record_Writer& writer,
+Client_Hello::Client_Hello(Handshake_Writer& writer,
Handshake_Hash& hash,
Protocol_Version version,
const Policy& policy,
@@ -98,7 +98,7 @@ Client_Hello::Client_Hello(Record_Writer& writer,
/*
* Create a new Client Hello message (session resumption case)
*/
-Client_Hello::Client_Hello(Record_Writer& writer,
+Client_Hello::Client_Hello(Handshake_Writer& writer,
Handshake_Hash& hash,
const Policy& policy,
RandomNumberGenerator& rng,
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index 28449e614..f1b2306f1 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/internal/assert.h>
#include <botan/credentials_manager.h>
#include <botan/pubkey.h>
@@ -47,7 +47,7 @@ secure_vector<byte> strip_leading_zeros(const secure_vector<byte>& input)
/*
* Create a new Client Key Exchange message
*/
-Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
+Client_Key_Exchange::Client_Key_Exchange(Handshake_Writer& writer,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index 4578148f5..0806f5f66 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/der_enc.h>
#include <botan/ber_dec.h>
#include <botan/loadstor.h>
@@ -51,7 +51,7 @@ byte cert_type_name_to_code(const std::string& name)
/**
* Create a new Certificate Request message
*/
-Certificate_Req::Certificate_Req(Record_Writer& writer,
+Certificate_Req::Certificate_Req(Handshake_Writer& writer,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_Certificate>& ca_certs,
@@ -166,7 +166,7 @@ std::vector<byte> Certificate_Req::serialize() const
/**
* Create a new Certificate message
*/
-Certificate::Certificate(Record_Writer& writer,
+Certificate::Certificate(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::vector<X509_Certificate>& cert_list) :
m_certs(cert_list)
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index 870d70951..4dbae9da3 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/internal/assert.h>
#include <memory>
@@ -19,7 +19,7 @@ namespace TLS {
/*
* Create a new Certificate Verify message
*/
-Certificate_Verify::Certificate_Verify(Record_Writer& writer,
+Certificate_Verify::Certificate_Verify(Handshake_Writer& writer,
Handshake_State* state,
const Policy& policy,
RandomNumberGenerator& rng,
diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp
index c8ae4a343..4dcc9e1ae 100644
--- a/src/tls/finished.cpp
+++ b/src/tls/finished.cpp
@@ -6,7 +6,7 @@
*/
#include <botan/internal/tls_messages.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <memory>
namespace Botan {
@@ -19,7 +19,7 @@ namespace {
* Compute the verify_data
*/
std::vector<byte> finished_compute_verify(Handshake_State* state,
- Connection_Side side)
+ Connection_Side side)
{
if(state->version() == Protocol_Version::SSL_V3)
{
@@ -66,7 +66,7 @@ std::vector<byte> finished_compute_verify(Handshake_State* state,
/*
* Create a new Finished message
*/
-Finished::Finished(Record_Writer& writer,
+Finished::Finished(Handshake_Writer& writer,
Handshake_State* state,
Connection_Side side)
{
diff --git a/src/tls/info.txt b/src/tls/info.txt
index 1863be577..212562373 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -27,6 +27,7 @@ tls_extensions.h
tls_handshake_hash.h
tls_handshake_reader.h
tls_handshake_state.h
+tls_handshake_writer.h
tls_heartbeats.h
tls_messages.h
tls_reader.h
@@ -54,6 +55,7 @@ tls_extensions.cpp
tls_handshake_hash.cpp
tls_handshake_reader.cpp
tls_handshake_state.cpp
+tls_handshake_writer.cpp
tls_heartbeats.cpp
tls_policy.cpp
tls_server.cpp
diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp
index adf9acbe9..a8989c5a9 100644
--- a/src/tls/next_protocol.cpp
+++ b/src/tls/next_protocol.cpp
@@ -8,13 +8,13 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_extensions.h>
#include <botan/internal/tls_reader.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
namespace Botan {
namespace TLS {
-Next_Protocol::Next_Protocol(Record_Writer& writer,
+Next_Protocol::Next_Protocol(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::string& protocol) :
m_protocol(protocol)
diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp
index b5b9e826c..2523f8229 100644
--- a/src/tls/rec_wri.cpp
+++ b/src/tls/rec_wri.cpp
@@ -148,25 +148,6 @@ void Record_Writer::activate(Connection_Side side,
throw Invalid_Argument("Record_Writer: Unknown hash " + mac_algo);
}
-std::vector<byte> Record_Writer::send(Handshake_Message& msg)
- {
- const std::vector<byte> buf = msg.serialize();
- std::vector<byte> send_buf(4);
-
- const size_t buf_size = buf.size();
-
- send_buf[0] = msg.type();
-
- for(size_t i = 1; i != 4; ++i)
- send_buf[i] = get_byte<u32bit>(i, buf_size);
-
- send_buf += buf;
-
- send(HANDSHAKE, &send_buf[0], send_buf.size());
-
- return send_buf;
- }
-
/*
* Send one or more records to the other side
*/
diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp
index 3b65b39f1..d34fa5e70 100644
--- a/src/tls/s_hello.cpp
+++ b/src/tls/s_hello.cpp
@@ -9,7 +9,7 @@
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/internal/stl_util.h>
namespace Botan {
@@ -19,7 +19,7 @@ namespace TLS {
/*
* Create a new Server Hello message
*/
-Server_Hello::Server_Hello(Record_Writer& writer,
+Server_Hello::Server_Hello(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::vector<byte>& session_id,
Protocol_Version ver,
@@ -149,7 +149,7 @@ std::vector<byte> Server_Hello::serialize() const
/*
* Create a new Server Hello Done message
*/
-Server_Hello_Done::Server_Hello_Done(Record_Writer& writer,
+Server_Hello_Done::Server_Hello_Done(Handshake_Writer& writer,
Handshake_Hash& hash)
{
hash.update(writer.send(*this));
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 834dff979..423497976 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/internal/assert.h>
#include <botan/credentials_manager.h>
#include <botan/loadstor.h>
@@ -27,7 +27,7 @@ namespace TLS {
/**
* Create a new Server Key Exchange message
*/
-Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
+Server_Key_Exchange::Server_Key_Exchange(Handshake_Writer& writer,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
diff --git a/src/tls/session_ticket.cpp b/src/tls/session_ticket.cpp
index 8cee2a454..3affe8fcf 100644
--- a/src/tls/session_ticket.cpp
+++ b/src/tls/session_ticket.cpp
@@ -8,14 +8,14 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_extensions.h>
#include <botan/internal/tls_reader.h>
-#include <botan/tls_record.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/loadstor.h>
namespace Botan {
namespace TLS {
-New_Session_Ticket::New_Session_Ticket(Record_Writer& writer,
+New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::vector<byte>& ticket,
u32bit lifetime) :
@@ -25,7 +25,7 @@ New_Session_Ticket::New_Session_Ticket(Record_Writer& writer,
hash.update(writer.send(*this));
}
-New_Session_Ticket::New_Session_Ticket(Record_Writer& writer,
+New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer,
Handshake_Hash& hash) :
m_ticket_lifetime_hint(0)
{
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 84ee69e04..d77f6dbcf 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -177,8 +177,8 @@ void Channel::read_handshake(byte rec_type,
if(rec_type == HANDSHAKE)
{
if(!m_state)
- m_state = new Handshake_State(this->new_handshake_reader());
- m_state->handshake_reader()->add_input(&rec_buf[0], rec_buf.size());
+ m_state = new_handshake_state();
+ m_state->handshake_reader().add_input(&rec_buf[0], rec_buf.size());
}
BOTAN_ASSERT_NONNULL(m_state);
@@ -189,10 +189,10 @@ void Channel::read_handshake(byte rec_type,
if(rec_type == HANDSHAKE)
{
- if(m_state->handshake_reader()->have_full_record())
+ if(m_state->handshake_reader().have_full_record())
{
std::pair<Handshake_Type, std::vector<byte> > msg =
- m_state->handshake_reader()->get_next_record();
+ m_state->handshake_reader().get_next_record();
process_handshake_msg(msg.first, msg.second);
}
else
@@ -200,7 +200,7 @@ void Channel::read_handshake(byte rec_type,
}
else if(rec_type == CHANGE_CIPHER_SPEC)
{
- if(m_state->handshake_reader()->empty() && rec_buf.size() == 1 && rec_buf[0] == 1)
+ if(m_state->handshake_reader().empty() && rec_buf.size() == 1 && rec_buf[0] == 1)
process_handshake_msg(HANDSHAKE_CCS, std::vector<byte>());
else
throw Decoding_Error("Malformed ChangeCipherSpec message");
@@ -208,7 +208,7 @@ void Channel::read_handshake(byte rec_type,
else
throw Decoding_Error("Unknown message type in handshake processing");
- if(type == HANDSHAKE_CCS || !m_state || !m_state->handshake_reader()->have_full_record())
+ if(type == HANDSHAKE_CCS || !m_state || !m_state->handshake_reader().have_full_record())
break;
}
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index ae4108e84..bd81a1745 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -111,7 +111,7 @@ class BOTAN_DLL Channel
virtual void alert_notify(const Alert& alert) = 0;
- virtual class Handshake_Reader* new_handshake_reader() const = 0;
+ virtual class Handshake_State* new_handshake_state() = 0;
class Secure_Renegotiation_State
{
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index 471cbefed..a62bcbba5 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -37,7 +37,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
{
m_writer.set_version(Protocol_Version::SSL_V3);
- m_state = new Handshake_State(this->new_handshake_reader());
+ m_state = new_handshake_state();
m_state->set_expected_next(SERVER_HELLO);
m_state->client_npn_cb = next_protocol;
@@ -54,7 +54,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
if(session_info.srp_identifier() == srp_identifier)
{
m_state->client_hello = new Client_Hello(
- m_writer,
+ m_state->handshake_writer(),
m_state->hash,
m_policy,
m_rng,
@@ -70,7 +70,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
if(!m_state->client_hello) // not resuming
{
m_state->client_hello = new Client_Hello(
- m_writer,
+ m_state->handshake_writer(),
m_state->hash,
m_policy.pref_version(),
m_policy,
@@ -84,9 +84,10 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
m_secure_renegotiation.update(m_state->client_hello);
}
-Handshake_Reader* Client::new_handshake_reader() const
+Handshake_State* Client::new_handshake_state()
{
- return new Stream_Handshake_Reader;
+ return new Handshake_State(new Stream_Handshake_Reader,
+ new Stream_Handshake_Writer(m_writer));
}
/*
@@ -98,7 +99,7 @@ void Client::renegotiate(bool force_full_renegotiation)
return; // currently in active handshake
delete m_state;
- m_state = new Handshake_State(this->new_handshake_reader());
+ m_state = new_handshake_state();
m_state->set_expected_next(SERVER_HELLO);
@@ -108,7 +109,7 @@ void Client::renegotiate(bool force_full_renegotiation)
if(m_session_manager.load_from_host_info(m_hostname, m_port, session_info))
{
m_state->client_hello = new Client_Hello(
- m_writer,
+ m_state->handshake_writer(),
m_state->hash,
m_policy,
m_rng,
@@ -122,7 +123,7 @@ void Client::renegotiate(bool force_full_renegotiation)
if(!m_state->client_hello)
{
m_state->client_hello = new Client_Hello(
- m_writer,
+ m_state->handshake_writer(),
m_state->hash,
m_reader.get_version(),
m_policy,
@@ -367,13 +368,13 @@ void Client::process_handshake_msg(Handshake_Type type,
"tls-client",
m_hostname);
- m_state->client_certs = new Certificate(m_writer,
+ m_state->client_certs = new Certificate(m_state->handshake_writer(),
m_state->hash,
client_certs);
}
m_state->client_kex =
- new Client_Key_Exchange(m_writer,
+ new Client_Key_Exchange(m_state->handshake_writer(),
m_state,
m_policy,
m_creds,
@@ -393,7 +394,7 @@ void Client::process_handshake_msg(Handshake_Type type,
"tls-client",
m_hostname);
- m_state->client_verify = new Certificate_Verify(m_writer,
+ m_state->client_verify = new Certificate_Verify(m_state->handshake_writer(),
m_state,
m_policy,
m_rng,
@@ -410,10 +411,10 @@ void Client::process_handshake_msg(Handshake_Type type,
const std::string protocol =
m_state->client_npn_cb(m_state->server_hello->next_protocols());
- m_state->next_protocol = new Next_Protocol(m_writer, m_state->hash, protocol);
+ m_state->next_protocol = new Next_Protocol(m_state->handshake_writer(), m_state->hash, protocol);
}
- m_state->client_finished = new Finished(m_writer, m_state, CLIENT);
+ m_state->client_finished = new Finished(m_state->handshake_writer(), m_state, CLIENT);
if(m_state->server_hello->supports_session_ticket())
m_state->set_expected_next(NEW_SESSION_TICKET);
@@ -452,7 +453,8 @@ void Client::process_handshake_msg(Handshake_Type type,
m_writer.activate(CLIENT, m_state->suite, m_state->keys,
m_state->server_hello->compression_method());
- m_state->client_finished = new Finished(m_writer, m_state, CLIENT);
+ m_state->client_finished = new Finished(m_state->handshake_writer(),
+ m_state, CLIENT);
}
m_secure_renegotiation.update(m_state->client_finished, m_state->server_finished);
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index b62e4aadf..ad13a94dc 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -73,7 +73,7 @@ class BOTAN_DLL Client : public Channel
void alert_notify(const Alert& alert) override;
- class Handshake_Reader* new_handshake_reader() const override;
+ class Handshake_State* new_handshake_state() override;
const Policy& m_policy;
RandomNumberGenerator& m_rng;
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 8bb251b73..304366719 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -85,33 +85,12 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
/*
* Initialize the SSL/TLS Handshake State
*/
-Handshake_State::Handshake_State(Handshake_Reader* reader)
+Handshake_State::Handshake_State(Handshake_Reader* reader,
+ Handshake_Writer* writer) :
+ m_handshake_reader(reader),
+ m_handshake_writer(writer),
+ m_version(Protocol_Version::SSL_V3)
{
- client_hello = nullptr;
- server_hello = nullptr;
- server_certs = nullptr;
- server_kex = nullptr;
- cert_req = nullptr;
- server_hello_done = nullptr;
- next_protocol = nullptr;
- new_session_ticket = nullptr;
-
- client_certs = nullptr;
- client_kex = nullptr;
- client_verify = nullptr;
- client_finished = nullptr;
- server_finished = nullptr;
-
- m_handshake_reader = reader;
-
- server_rsa_kex_key = nullptr;
-
- m_version = Protocol_Version::SSL_V3;
-
- hand_expecting_mask = 0;
- hand_received_mask = 0;
-
- allow_session_resumption = true;
}
void Handshake_State::set_version(const Protocol_Version& version)
@@ -123,33 +102,33 @@ void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg)
{
const u32bit mask = bitmask_for_handshake_type(handshake_msg);
- hand_received_mask |= mask;
+ m_hand_received_mask |= mask;
- const bool ok = (hand_expecting_mask & mask); // overlap?
+ const bool ok = (m_hand_expecting_mask & mask); // overlap?
if(!ok)
throw Unexpected_Message("Unexpected state transition in handshake, got " +
std::to_string(handshake_msg) +
- " expected " + std::to_string(hand_expecting_mask) +
- " received " + std::to_string(hand_received_mask));
+ " expected " + std::to_string(m_hand_expecting_mask) +
+ " received " + std::to_string(m_hand_received_mask));
/* We don't know what to expect next, so force a call to
set_expected_next; if it doesn't happen, the next transition
check will always fail which is what we want.
*/
- hand_expecting_mask = 0;
+ m_hand_expecting_mask = 0;
}
void Handshake_State::set_expected_next(Handshake_Type handshake_msg)
{
- hand_expecting_mask |= bitmask_for_handshake_type(handshake_msg);
+ m_hand_expecting_mask |= bitmask_for_handshake_type(handshake_msg);
}
bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const
{
const u32bit mask = bitmask_for_handshake_type(handshake_msg);
- return (hand_received_mask & mask);
+ return (m_hand_received_mask & mask);
}
std::string Handshake_State::srp_identifier() const
@@ -370,6 +349,7 @@ Handshake_State::~Handshake_State()
delete server_finished;
delete m_handshake_reader;
+ delete m_handshake_writer;
}
}
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 521da0205..0f48c976b 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -10,6 +10,7 @@
#include <botan/internal/tls_handshake_hash.h>
#include <botan/internal/tls_handshake_reader.h>
+#include <botan/internal/tls_handshake_writer.h>
#include <botan/internal/tls_session_key.h>
#include <botan/pk_keys.h>
#include <botan/pubkey.h>
@@ -31,7 +32,9 @@ class Policy;
class Handshake_State
{
public:
- Handshake_State(Handshake_Reader* reader);
+ Handshake_State(Handshake_Reader* reader,
+ Handshake_Writer* writer);
+
~Handshake_State();
Handshake_State(const Handshake_State&) = delete;
@@ -65,25 +68,25 @@ class Handshake_State
void set_version(const Protocol_Version& version);
- class Client_Hello* client_hello;
- class Server_Hello* server_hello;
- class Certificate* server_certs;
- class Server_Key_Exchange* server_kex;
- class Certificate_Req* cert_req;
- class Server_Hello_Done* server_hello_done;
+ class Client_Hello* client_hello = nullptr;
+ class Server_Hello* server_hello = nullptr;
+ class Certificate* server_certs = nullptr;
+ class Server_Key_Exchange* server_kex = nullptr;
+ class Certificate_Req* cert_req = nullptr;
+ class Server_Hello_Done* server_hello_done = nullptr;
- class Certificate* client_certs;
- class Client_Key_Exchange* client_kex;
- class Certificate_Verify* client_verify;
+ class Certificate* client_certs = nullptr;
+ class Client_Key_Exchange* client_kex = nullptr;
+ class Certificate_Verify* client_verify = nullptr;
- class Next_Protocol* next_protocol;
- class New_Session_Ticket* new_session_ticket;
+ class Next_Protocol* next_protocol = nullptr;
+ class New_Session_Ticket* new_session_ticket = nullptr;
- class Finished* client_finished;
- class Finished* server_finished;
+ class Finished* client_finished = nullptr;
+ class Finished* server_finished = nullptr;
// Used by the server only, in case of RSA key exchange
- Private_Key* server_rsa_kex_key;
+ Private_Key* server_rsa_kex_key = nullptr;
Ciphersuite suite;
Session_Keys keys;
@@ -95,19 +98,25 @@ class Handshake_State
secure_vector<byte> resume_master_secret;
/*
- *
+ * Used by the server to know if resumption should be allowed on
+ * a server-initiated renegotiation
*/
- bool allow_session_resumption;
+ bool allow_session_resumption = true;
/**
* Used by client using NPN
*/
std::function<std::string (std::vector<std::string>)> client_npn_cb;
- Handshake_Reader* handshake_reader() { return m_handshake_reader; }
+ Handshake_Reader& handshake_reader() { return *m_handshake_reader; }
+
+ Handshake_Writer& handshake_writer() { return *m_handshake_writer; }
private:
- Handshake_Reader* m_handshake_reader;
- u32bit hand_expecting_mask, hand_received_mask;
+ Handshake_Reader* m_handshake_reader = nullptr;
+ Handshake_Writer* m_handshake_writer = nullptr;
+
+ u32bit m_hand_expecting_mask = 0;
+ u32bit m_hand_received_mask = 0;
Protocol_Version m_version;
};
diff --git a/src/tls/tls_handshake_writer.cpp b/src/tls/tls_handshake_writer.cpp
new file mode 100644
index 000000000..b237e8f3a
--- /dev/null
+++ b/src/tls/tls_handshake_writer.cpp
@@ -0,0 +1,38 @@
+/*
+* Handshake Message Writer
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_messages.h>
+#include <botan/tls_record.h>
+#include <botan/exceptn.h>
+
+namespace Botan {
+
+namespace TLS {
+
+std::vector<byte> Stream_Handshake_Writer::send(Handshake_Message& msg)
+ {
+ const std::vector<byte> buf = msg.serialize();
+ std::vector<byte> send_buf(4);
+
+ const size_t buf_size = buf.size();
+
+ send_buf[0] = msg.type();
+
+ for(size_t i = 1; i != 4; ++i)
+ send_buf[i] = get_byte<u32bit>(i, buf_size);
+
+ send_buf += buf;
+
+ m_writer.send(HANDSHAKE, &send_buf[0], send_buf.size());
+
+ return send_buf;
+ }
+
+}
+
+}
diff --git a/src/tls/tls_handshake_writer.h b/src/tls/tls_handshake_writer.h
new file mode 100644
index 000000000..0d6ddb0a0
--- /dev/null
+++ b/src/tls/tls_handshake_writer.h
@@ -0,0 +1,52 @@
+/*
+* TLS Handshake Writer
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef BOTAN_TLS_HANDSHAKE_WRITER_H__
+#define BOTAN_TLS_HANDSHAKE_WRITER_H__
+
+#include <botan/tls_magic.h>
+#include <botan/loadstor.h>
+#include <vector>
+#include <deque>
+#include <utility>
+
+namespace Botan {
+
+namespace TLS {
+
+class Record_Writer;
+class Handshake_Message;
+
+/**
+* Handshake Writer
+*/
+class Handshake_Writer
+ {
+ public:
+ virtual std::vector<byte> send(Handshake_Message& msg) = 0;
+
+ virtual ~Handshake_Writer() {}
+ };
+
+/**
+* Stream Handshake Writer
+*/
+class Stream_Handshake_Writer : public Handshake_Writer
+ {
+ public:
+ Stream_Handshake_Writer(Record_Writer& writer) : m_writer(writer) {}
+
+ std::vector<byte> send(Handshake_Message& msg) override;
+ private:
+ Record_Writer& m_writer;
+ };
+
+}
+
+}
+
+#endif
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 2e8bf9ba3..a0e7d8630 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -25,8 +25,7 @@ class SRP6_Server_Session;
namespace TLS {
-class Record_Writer;
-class Record_Reader;
+class Handshake_Writer;
/**
* TLS Handshake Message Base Class
@@ -113,7 +112,7 @@ class Client_Hello : public Handshake_Message
bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; }
- Client_Hello(Record_Writer& writer,
+ Client_Hello(Handshake_Writer& writer,
Handshake_Hash& hash,
Protocol_Version version,
const Policy& policy,
@@ -123,7 +122,7 @@ class Client_Hello : public Handshake_Message
const std::string& hostname = "",
const std::string& srp_identifier = "");
- Client_Hello(Record_Writer& writer,
+ Client_Hello(Handshake_Writer& writer,
Handshake_Hash& hash,
const Policy& policy,
RandomNumberGenerator& rng,
@@ -197,7 +196,7 @@ class Server_Hello : public Handshake_Message
bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; }
- Server_Hello(Record_Writer& writer,
+ Server_Hello(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::vector<byte>& session_id,
Protocol_Version ver,
@@ -244,7 +243,7 @@ class Client_Key_Exchange : public Handshake_Message
const secure_vector<byte>& pre_master_secret() const
{ return pre_master; }
- Client_Key_Exchange(Record_Writer& output,
+ Client_Key_Exchange(Handshake_Writer& output,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
@@ -277,7 +276,7 @@ class Certificate : public Handshake_Message
size_t count() const { return m_certs.size(); }
bool empty() const { return m_certs.empty(); }
- Certificate(Record_Writer& writer,
+ Certificate(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::vector<X509_Certificate>& certs);
@@ -304,7 +303,7 @@ class Certificate_Req : public Handshake_Message
std::vector<std::pair<std::string, std::string> > supported_algos() const
{ return m_supported_algos; }
- Certificate_Req(Record_Writer& writer,
+ Certificate_Req(Handshake_Writer& writer,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_Certificate>& allowed_cas,
@@ -337,7 +336,7 @@ class Certificate_Verify : public Handshake_Message
bool verify(const X509_Certificate& cert,
Handshake_State* state);
- Certificate_Verify(Record_Writer& writer,
+ Certificate_Verify(Handshake_Writer& writer,
Handshake_State* state,
const Policy& policy,
RandomNumberGenerator& rng,
@@ -367,7 +366,7 @@ class Finished : public Handshake_Message
bool verify(Handshake_State* state,
Connection_Side side);
- Finished(Record_Writer& writer,
+ Finished(Handshake_Writer& writer,
Handshake_State* state,
Connection_Side side);
@@ -387,7 +386,7 @@ class Hello_Request : public Handshake_Message
public:
Handshake_Type type() const { return HELLO_REQUEST; }
- Hello_Request(Record_Writer& writer);
+ Hello_Request(Handshake_Writer& writer);
Hello_Request(const std::vector<byte>& buf);
private:
std::vector<byte> serialize() const;
@@ -412,7 +411,7 @@ class Server_Key_Exchange : public Handshake_Message
// Only valid for SRP negotiation
SRP6_Server_Session& server_srp_params();
- Server_Key_Exchange(Record_Writer& writer,
+ Server_Key_Exchange(Handshake_Writer& writer,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
@@ -446,7 +445,7 @@ class Server_Hello_Done : public Handshake_Message
public:
Handshake_Type type() const { return SERVER_HELLO_DONE; }
- Server_Hello_Done(Record_Writer& writer, Handshake_Hash& hash);
+ Server_Hello_Done(Handshake_Writer& writer, Handshake_Hash& hash);
Server_Hello_Done(const std::vector<byte>& buf);
private:
std::vector<byte> serialize() const;
@@ -462,7 +461,7 @@ class Next_Protocol : public Handshake_Message
std::string protocol() const { return m_protocol; }
- Next_Protocol(Record_Writer& writer,
+ Next_Protocol(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::string& protocol);
@@ -481,12 +480,12 @@ class New_Session_Ticket : public Handshake_Message
u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; }
const std::vector<byte>& ticket() const { return m_ticket; }
- New_Session_Ticket(Record_Writer& writer,
+ New_Session_Ticket(Handshake_Writer& writer,
Handshake_Hash& hash,
const std::vector<byte>& ticket,
u32bit lifetime);
- New_Session_Ticket(Record_Writer& writer,
+ New_Session_Ticket(Handshake_Writer& writer,
Handshake_Hash& hash);
New_Session_Ticket(const std::vector<byte>& buf);
diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp
index 99ac66369..76492a668 100644
--- a/src/tls/tls_policy.cpp
+++ b/src/tls/tls_policy.cpp
@@ -83,11 +83,11 @@ std::vector<std::string> Policy::allowed_ecc_curves() const
"secp256k1",
"secp224r1",
"secp224k1",
- //"secp192r1",
- //"secp192k1",
- //"secp160r2",
- //"secp160r1",
- //"secp160k1",
+ "secp192r1",
+ "secp192k1",
+ "secp160r2",
+ "secp160r1",
+ "secp160k1",
});
}
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 0b67f9a63..924d25f80 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -35,8 +35,6 @@ class BOTAN_DLL Record_Writer
void send(byte type, const std::vector<byte>& input)
{ send(type, &input[0], input.size()); }
- std::vector<byte> send(class Handshake_Message& msg);
-
void send_alert(const Alert& alert);
void activate(Connection_Side side,
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 97db6934e..d6d408db5 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -200,9 +200,10 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
{
}
-Handshake_Reader* Server::new_handshake_reader() const
+Handshake_State* Server::new_handshake_state()
{
- return new Stream_Handshake_Reader;
+ return new Handshake_State(new Stream_Handshake_Reader,
+ new Stream_Handshake_Writer(m_writer));
}
/*
@@ -213,11 +214,11 @@ void Server::renegotiate(bool force_full_renegotiation)
if(m_state)
return; // currently in handshake
- m_state = new Handshake_State(this->new_handshake_reader());
+ m_state = new_handshake_state();
m_state->allow_session_resumption = !force_full_renegotiation;
m_state->set_expected_next(CLIENT_HELLO);
- Hello_Request hello_req(m_writer);
+ Hello_Request hello_req(m_state->handshake_writer());
}
void Server::alert_notify(const Alert& alert)
@@ -240,7 +241,7 @@ void Server::read_handshake(byte rec_type,
{
if(rec_type == HANDSHAKE && !m_state)
{
- m_state = new Handshake_State(this->new_handshake_reader());
+ m_state = new_handshake_state();
m_state->set_expected_next(CLIENT_HELLO);
}
@@ -368,7 +369,7 @@ void Server::process_handshake_msg(Handshake_Type type,
// resume session
m_state->server_hello = new Server_Hello(
- m_writer,
+ m_state->handshake_writer(),
m_state->hash,
m_state->client_hello->session_id(),
Protocol_Version(session_info.version()),
@@ -402,7 +403,11 @@ void Server::process_handshake_msg(Handshake_Type type,
m_session_manager.remove_entry(session_info.session_id());
if(m_state->server_hello->supports_session_ticket()) // send an empty ticket
- m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash);
+ {
+ m_state->new_session_ticket =
+ new New_Session_Ticket(m_state->handshake_writer(),
+ m_state->hash);
+ }
}
if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket)
@@ -412,14 +417,19 @@ void Server::process_handshake_msg(Handshake_Type type,
const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
m_state->new_session_ticket =
- new New_Session_Ticket(m_writer, m_state->hash,
+ new New_Session_Ticket(m_state->handshake_writer(),
+ m_state->hash,
session_info.encrypt(ticket_key, m_rng),
m_policy.session_ticket_lifetime());
}
catch(...) {}
if(!m_state->new_session_ticket)
- m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash);
+ {
+ m_state->new_session_ticket =
+ new New_Session_Ticket(m_state->handshake_writer(),
+ m_state->hash);
+ }
}
m_writer.send(CHANGE_CIPHER_SPEC, 1);
@@ -427,7 +437,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_writer.activate(SERVER, m_state->suite, m_state->keys,
m_state->server_hello->compression_method());
- m_state->server_finished = new Finished(m_writer, m_state, SERVER);
+ m_state->server_finished = new Finished(m_state->handshake_writer(), m_state, SERVER);
m_state->set_expected_next(HANDSHAKE_CCS);
}
@@ -453,7 +463,7 @@ void Server::process_handshake_msg(Handshake_Type type,
}
m_state->server_hello = new Server_Hello(
- m_writer,
+ m_state->handshake_writer(),
m_state->hash,
make_hello_random(m_rng), // new session ID
m_state->version(),
@@ -486,9 +496,9 @@ void Server::process_handshake_msg(Handshake_Type type,
BOTAN_ASSERT(!cert_chains[sig_algo].empty(),
"Attempting to send empty certificate chain");
- m_state->server_certs = new Certificate(m_writer,
- m_state->hash,
- cert_chains[sig_algo]);
+ m_state->server_certs = new Certificate(m_state->handshake_writer(),
+ m_state->hash,
+ cert_chains[sig_algo]);
}
Private_Key* private_key = nullptr;
@@ -511,7 +521,12 @@ void Server::process_handshake_msg(Handshake_Type type,
else
{
m_state->server_kex =
- new Server_Key_Exchange(m_writer, m_state, m_policy, m_creds, m_rng, private_key);
+ new Server_Key_Exchange(m_state->handshake_writer(),
+ m_state,
+ m_policy,
+ m_creds,
+ m_rng,
+ private_key);
}
std::vector<X509_Certificate> client_auth_CAs =
@@ -519,11 +534,11 @@ void Server::process_handshake_msg(Handshake_Type type,
if(!client_auth_CAs.empty() && m_state->suite.sig_algo() != "")
{
- m_state->cert_req = new Certificate_Req(m_writer,
- m_state->hash,
- m_policy,
- client_auth_CAs,
- m_state->version());
+ m_state->cert_req = new Certificate_Req(m_state->handshake_writer(),
+ m_state->hash,
+ m_policy,
+ client_auth_CAs,
+ m_state->version());
m_state->set_expected_next(CERTIFICATE);
}
@@ -535,7 +550,8 @@ void Server::process_handshake_msg(Handshake_Type type,
*/
m_state->set_expected_next(CLIENT_KEX);
- m_state->server_hello_done = new Server_Hello_Done(m_writer, m_state->hash);
+ m_state->server_hello_done = new Server_Hello_Done(m_state->handshake_writer(),
+ m_state->hash);
}
}
else if(type == CERTIFICATE)
@@ -643,7 +659,8 @@ void Server::process_handshake_msg(Handshake_Type type,
const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
m_state->new_session_ticket =
- new New_Session_Ticket(m_writer, m_state->hash,
+ new New_Session_Ticket(m_state->handshake_writer(),
+ m_state->hash,
session_info.encrypt(ticket_key, m_rng),
m_policy.session_ticket_lifetime());
}
@@ -654,14 +671,18 @@ void Server::process_handshake_msg(Handshake_Type type,
}
if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket)
- m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash);
+ {
+ m_state->new_session_ticket = new New_Session_Ticket(m_state->handshake_writer(),
+ m_state->hash);
+
+ }
m_writer.send(CHANGE_CIPHER_SPEC, 1);
m_writer.activate(SERVER, m_state->suite, m_state->keys,
m_state->server_hello->compression_method());
- m_state->server_finished = new Finished(m_writer, m_state, SERVER);
+ m_state->server_finished = new Finished(m_state->handshake_writer(), m_state, SERVER);
}
m_secure_renegotiation.update(m_state->client_finished,
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 2d1502e1f..c0e687604 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -56,7 +56,7 @@ class BOTAN_DLL Server : public Channel
void alert_notify(const Alert& alert) override;
- class Handshake_Reader* new_handshake_reader() const override;
+ class Handshake_State* new_handshake_state() override;
const Policy& m_policy;
RandomNumberGenerator& m_rng;