aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-08-03 14:40:08 +0000
committerlloyd <[email protected]>2012-08-03 14:40:08 +0000
commitdb2a5f10716f69a58f8c554c8e65d21e198ffbc5 (patch)
tree1c7a302a3f34fb46f201bf6b658884421609a559
parentba0e7cc86e7fa6606a04c3ae34be354d8ed801b3 (diff)
Combine Handshake_Writer and Handshake_Reader into Handshake_IO.
This is mostly just a minor code savings for TLS, but it actually seems important for DTLS because getting a handshake message can be a trigger for retransmitting previously sent handshake messages in some circumstances. Having the reading and writing all in one layer makes it a bit easier to accomplish that.
-rw-r--r--src/tls/c_hello.cpp14
-rw-r--r--src/tls/c_kex.cpp6
-rw-r--r--src/tls/cert_req.cpp10
-rw-r--r--src/tls/cert_ver.cpp6
-rw-r--r--src/tls/finished.cpp6
-rw-r--r--src/tls/info.txt6
-rw-r--r--src/tls/next_protocol.cpp6
-rw-r--r--src/tls/s_hello.cpp10
-rw-r--r--src/tls/s_kex.cpp6
-rw-r--r--src/tls/session_ticket.cpp10
-rw-r--r--src/tls/tls_channel.cpp6
-rw-r--r--src/tls/tls_client.cpp23
-rw-r--r--src/tls/tls_handshake_io.cpp (renamed from src/tls/tls_handshake_reader.cpp)55
-rw-r--r--src/tls/tls_handshake_io.h (renamed from src/tls/tls_handshake_reader.h)40
-rw-r--r--src/tls/tls_handshake_state.cpp9
-rw-r--r--src/tls/tls_handshake_state.h13
-rw-r--r--src/tls/tls_handshake_writer.cpp56
-rw-r--r--src/tls/tls_handshake_writer.h66
-rw-r--r--src/tls/tls_messages.h30
-rw-r--r--src/tls/tls_server.cpp37
20 files changed, 170 insertions, 245 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp
index eacbecc6c..2d2e03752 100644
--- a/src/tls/c_hello.cpp
+++ b/src/tls/c_hello.cpp
@@ -9,7 +9,7 @@
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/stl_util.h>
#include <chrono>
@@ -36,9 +36,9 @@ std::vector<byte> make_hello_random(RandomNumberGenerator& rng)
/*
* Create a new Hello Request message
*/
-Hello_Request::Hello_Request(Handshake_Writer& writer)
+Hello_Request::Hello_Request(Handshake_IO& io)
{
- writer.send(*this);
+ io.send(*this);
}
/*
@@ -61,7 +61,7 @@ std::vector<byte> Hello_Request::serialize() const
/*
* Create a new Client Hello message
*/
-Client_Hello::Client_Hello(Handshake_Writer& writer,
+Client_Hello::Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
Protocol_Version version,
const Policy& policy,
@@ -92,13 +92,13 @@ Client_Hello::Client_Hello(Handshake_Writer& writer,
for(size_t j = 0; j != sigs.size(); ++j)
m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j]));
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
/*
* Create a new Client Hello message (session resumption case)
*/
-Client_Hello::Client_Hello(Handshake_Writer& writer,
+Client_Hello::Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
RandomNumberGenerator& rng,
@@ -135,7 +135,7 @@ Client_Hello::Client_Hello(Handshake_Writer& writer,
for(size_t j = 0; j != sigs.size(); ++j)
m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j]));
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
/*
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index f1b2306f1..1bf41cff2 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/assert.h>
#include <botan/credentials_manager.h>
#include <botan/pubkey.h>
@@ -47,7 +47,7 @@ secure_vector<byte> strip_leading_zeros(const secure_vector<byte>& input)
/*
* Create a new Client Key Exchange message
*/
-Client_Key_Exchange::Client_Key_Exchange(Handshake_Writer& writer,
+Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
@@ -259,7 +259,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_Writer& writer,
pub_key->algo_name());
}
- state->hash.update(writer.send(*this));
+ state->hash.update(io.send(*this));
}
/*
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index 0806f5f66..5087865d4 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/der_enc.h>
#include <botan/ber_dec.h>
#include <botan/loadstor.h>
@@ -51,7 +51,7 @@ byte cert_type_name_to_code(const std::string& name)
/**
* Create a new Certificate Request message
*/
-Certificate_Req::Certificate_Req(Handshake_Writer& writer,
+Certificate_Req::Certificate_Req(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_Certificate>& ca_certs,
@@ -74,7 +74,7 @@ Certificate_Req::Certificate_Req(Handshake_Writer& writer,
m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j]));
}
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
/**
@@ -166,12 +166,12 @@ std::vector<byte> Certificate_Req::serialize() const
/**
* Create a new Certificate message
*/
-Certificate::Certificate(Handshake_Writer& writer,
+Certificate::Certificate(Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<X509_Certificate>& cert_list) :
m_certs(cert_list)
{
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
/**
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index 4dbae9da3..7a58ea28a 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/assert.h>
#include <memory>
@@ -19,7 +19,7 @@ namespace TLS {
/*
* Create a new Certificate Verify message
*/
-Certificate_Verify::Certificate_Verify(Handshake_Writer& writer,
+Certificate_Verify::Certificate_Verify(Handshake_IO& io,
Handshake_State* state,
const Policy& policy,
RandomNumberGenerator& rng,
@@ -47,7 +47,7 @@ Certificate_Verify::Certificate_Verify(Handshake_Writer& writer,
signature = signer.sign_message(state->hash.get_contents(), rng);
}
- state->hash.update(writer.send(*this));
+ state->hash.update(io.send(*this));
}
/*
diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp
index 4dcc9e1ae..9205331de 100644
--- a/src/tls/finished.cpp
+++ b/src/tls/finished.cpp
@@ -6,7 +6,7 @@
*/
#include <botan/internal/tls_messages.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <memory>
namespace Botan {
@@ -66,12 +66,12 @@ std::vector<byte> finished_compute_verify(Handshake_State* state,
/*
* Create a new Finished message
*/
-Finished::Finished(Handshake_Writer& writer,
+Finished::Finished(Handshake_IO& io,
Handshake_State* state,
Connection_Side side)
{
verification_data = finished_compute_verify(state, side);
- state->hash.update(writer.send(*this));
+ state->hash.update(io.send(*this));
}
/*
diff --git a/src/tls/info.txt b/src/tls/info.txt
index 212562373..bc2bc41c3 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -25,9 +25,8 @@ tls_version.h
<header:internal>
tls_extensions.h
tls_handshake_hash.h
-tls_handshake_reader.h
+tls_handshake_io.h
tls_handshake_state.h
-tls_handshake_writer.h
tls_heartbeats.h
tls_messages.h
tls_reader.h
@@ -53,9 +52,8 @@ tls_ciphersuite.cpp
tls_client.cpp
tls_extensions.cpp
tls_handshake_hash.cpp
-tls_handshake_reader.cpp
+tls_handshake_io.cpp
tls_handshake_state.cpp
-tls_handshake_writer.cpp
tls_heartbeats.cpp
tls_policy.cpp
tls_server.cpp
diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp
index a8989c5a9..71bb0eb9e 100644
--- a/src/tls/next_protocol.cpp
+++ b/src/tls/next_protocol.cpp
@@ -8,18 +8,18 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_extensions.h>
#include <botan/internal/tls_reader.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
namespace Botan {
namespace TLS {
-Next_Protocol::Next_Protocol(Handshake_Writer& writer,
+Next_Protocol::Next_Protocol(Handshake_IO& io,
Handshake_Hash& hash,
const std::string& protocol) :
m_protocol(protocol)
{
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
Next_Protocol::Next_Protocol(const std::vector<byte>& buf)
diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp
index d34fa5e70..8d151b2b0 100644
--- a/src/tls/s_hello.cpp
+++ b/src/tls/s_hello.cpp
@@ -9,7 +9,7 @@
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/stl_util.h>
namespace Botan {
@@ -19,7 +19,7 @@ namespace TLS {
/*
* Create a new Server Hello message
*/
-Server_Hello::Server_Hello(Handshake_Writer& writer,
+Server_Hello::Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<byte>& session_id,
Protocol_Version ver,
@@ -47,7 +47,7 @@ Server_Hello::Server_Hello(Handshake_Writer& writer,
m_supports_heartbeats(client_has_heartbeat),
m_peer_can_send_heartbeats(true)
{
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
/*
@@ -149,10 +149,10 @@ std::vector<byte> Server_Hello::serialize() const
/*
* Create a new Server Hello Done message
*/
-Server_Hello_Done::Server_Hello_Done(Handshake_Writer& writer,
+Server_Hello_Done::Server_Hello_Done(Handshake_IO& io,
Handshake_Hash& hash)
{
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
/*
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 423497976..2d27c6f0d 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -8,7 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/assert.h>
#include <botan/credentials_manager.h>
#include <botan/loadstor.h>
@@ -27,7 +27,7 @@ namespace TLS {
/**
* Create a new Server Key Exchange message
*/
-Server_Key_Exchange::Server_Key_Exchange(Handshake_Writer& writer,
+Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
@@ -136,7 +136,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_Writer& writer,
m_signature = signer.signature(rng);
}
- state->hash.update(writer.send(*this));
+ state->hash.update(io.send(*this));
}
/**
diff --git a/src/tls/session_ticket.cpp b/src/tls/session_ticket.cpp
index 3affe8fcf..2bb9987a9 100644
--- a/src/tls/session_ticket.cpp
+++ b/src/tls/session_ticket.cpp
@@ -8,28 +8,28 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_extensions.h>
#include <botan/internal/tls_reader.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/loadstor.h>
namespace Botan {
namespace TLS {
-New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer,
+New_Session_Ticket::New_Session_Ticket(Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<byte>& ticket,
u32bit lifetime) :
m_ticket_lifetime_hint(lifetime),
m_ticket(ticket)
{
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
-New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer,
+New_Session_Ticket::New_Session_Ticket(Handshake_IO& io,
Handshake_Hash& hash) :
m_ticket_lifetime_hint(0)
{
- hash.update(writer.send(*this));
+ hash.update(io.send(*this));
}
New_Session_Ticket::New_Session_Ticket(const std::vector<byte>& buf) :
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 4c9c12d92..0c1f9fd09 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -174,12 +174,12 @@ void Channel::read_handshake(byte rec_type,
if(!m_state)
m_state.reset(new_handshake_state());
- m_state->handshake_reader().add_input(rec_type, &rec_buf[0], rec_buf.size());
+ m_state->handshake_io().add_input(rec_type, &rec_buf[0], rec_buf.size());
- while(m_state && m_state->handshake_reader().have_full_record())
+ while(m_state && m_state->handshake_io().have_full_record())
{
std::pair<Handshake_Type, std::vector<byte> > msg =
- m_state->handshake_reader().get_next_record();
+ m_state->handshake_io().get_next_record();
process_handshake_msg(msg.first, msg.second);
}
}
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index 17a7879d6..77ff010f3 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -45,8 +45,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
Handshake_State* Client::new_handshake_state()
{
- return new Handshake_State(new Stream_Handshake_Reader,
- new Stream_Handshake_Writer(m_writer));
+ return new Handshake_State(new Stream_Handshake_IO(m_writer));
}
/*
@@ -84,7 +83,7 @@ void Client::initiate_handshake(bool force_full_renegotiation,
if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier)
{
m_state->client_hello = new Client_Hello(
- m_state->handshake_writer(),
+ m_state->handshake_io(),
m_state->hash,
m_policy,
m_rng,
@@ -100,7 +99,7 @@ void Client::initiate_handshake(bool force_full_renegotiation,
if(!m_state->client_hello) // not resuming
{
m_state->client_hello = new Client_Hello(
- m_state->handshake_writer(),
+ m_state->handshake_io(),
m_state->hash,
version,
m_policy,
@@ -157,7 +156,7 @@ void Client::process_handshake_msg(Handshake_Type type,
m_state->confirm_transition_to(type);
if(type != HANDSHAKE_CCS && type != FINISHED)
- m_state->hash.update(m_state->handshake_writer().format(contents, type));
+ m_state->hash.update(m_state->handshake_io().format(contents, type));
if(type == SERVER_HELLO)
{
@@ -344,13 +343,13 @@ void Client::process_handshake_msg(Handshake_Type type,
"tls-client",
m_hostname);
- m_state->client_certs = new Certificate(m_state->handshake_writer(),
+ m_state->client_certs = new Certificate(m_state->handshake_io(),
m_state->hash,
client_certs);
}
m_state->client_kex =
- new Client_Key_Exchange(m_state->handshake_writer(),
+ new Client_Key_Exchange(m_state->handshake_io(),
m_state.get(),
m_policy,
m_creds,
@@ -370,7 +369,7 @@ void Client::process_handshake_msg(Handshake_Type type,
"tls-client",
m_hostname);
- m_state->client_verify = new Certificate_Verify(m_state->handshake_writer(),
+ m_state->client_verify = new Certificate_Verify(m_state->handshake_io(),
m_state.get(),
m_policy,
m_rng,
@@ -389,10 +388,10 @@ void Client::process_handshake_msg(Handshake_Type type,
const std::string protocol =
m_state->client_npn_cb(m_state->server_hello->next_protocols());
- m_state->next_protocol = new Next_Protocol(m_state->handshake_writer(), m_state->hash, protocol);
+ m_state->next_protocol = new Next_Protocol(m_state->handshake_io(), m_state->hash, protocol);
}
- m_state->client_finished = new Finished(m_state->handshake_writer(),
+ m_state->client_finished = new Finished(m_state->handshake_io(),
m_state.get(), CLIENT);
if(m_state->server_hello->supports_session_ticket())
@@ -425,7 +424,7 @@ void Client::process_handshake_msg(Handshake_Type type,
throw TLS_Exception(Alert::DECRYPT_ERROR,
"Finished message didn't verify");
- m_state->hash.update(m_state->handshake_writer().format(contents, type));
+ m_state->hash.update(m_state->handshake_io().format(contents, type));
if(!m_state->client_finished) // session resume case
{
@@ -436,7 +435,7 @@ void Client::process_handshake_msg(Handshake_Type type,
m_state->keys,
m_state->server_hello->compression_method());
- m_state->client_finished = new Finished(m_state->handshake_writer(),
+ m_state->client_finished = new Finished(m_state->handshake_io(),
m_state.get(), CLIENT);
}
diff --git a/src/tls/tls_handshake_reader.cpp b/src/tls/tls_handshake_io.cpp
index 3721ec5b5..fe1b9c790 100644
--- a/src/tls/tls_handshake_reader.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -1,11 +1,13 @@
/*
-* TLS Handshake Reader
+* TLS Handshake IO
* (C) 2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
-#include <botan/internal/tls_handshake_reader.h>
+#include <botan/internal/tls_handshake_io.h>
+#include <botan/internal/tls_messages.h>
+#include <botan/tls_record.h>
#include <botan/exceptn.h>
namespace Botan {
@@ -22,12 +24,18 @@ inline size_t load_be24(const byte q[3])
q[2]);
}
-}
+void store_be24(byte out[3], size_t val)
+ {
+ out[0] = get_byte<u32bit>(1, val);
+ out[1] = get_byte<u32bit>(2, val);
+ out[2] = get_byte<u32bit>(3, val);
+ }
+}
-void Stream_Handshake_Reader::add_input(const byte rec_type,
- const byte record[],
- size_t record_size)
+void Stream_Handshake_IO::add_input(const byte rec_type,
+ const byte record[],
+ size_t record_size)
{
if(rec_type == HANDSHAKE)
{
@@ -45,12 +53,12 @@ void Stream_Handshake_Reader::add_input(const byte rec_type,
throw Decoding_Error("Unknown message type in handshake processing");
}
-bool Stream_Handshake_Reader::empty() const
+bool Stream_Handshake_IO::empty() const
{
return m_queue.empty();
}
-bool Stream_Handshake_Reader::have_full_record() const
+bool Stream_Handshake_IO::have_full_record() const
{
if(m_queue.size() >= 4)
{
@@ -62,7 +70,8 @@ bool Stream_Handshake_Reader::have_full_record() const
return false;
}
-std::pair<Handshake_Type, std::vector<byte> > Stream_Handshake_Reader::get_next_record()
+std::pair<Handshake_Type, std::vector<byte> >
+Stream_Handshake_IO::get_next_record()
{
if(m_queue.size() >= 4)
{
@@ -81,7 +90,33 @@ std::pair<Handshake_Type, std::vector<byte> > Stream_Handshake_Reader::get_next_
}
}
- throw Internal_Error("Stream_Handshake_Reader::get_next_record called without a full record");
+ throw Internal_Error("Stream_Handshake_IO::get_next_record called without a full record");
+ }
+
+std::vector<byte>
+Stream_Handshake_IO::format(const std::vector<byte>& msg,
+ Handshake_Type type)
+ {
+ std::vector<byte> send_buf(4 + msg.size());
+
+ const size_t buf_size = msg.size();
+
+ send_buf[0] = type;
+
+ store_be24(&send_buf[1], buf_size);
+
+ copy_mem(&send_buf[4], &msg[0], msg.size());
+
+ return send_buf;
+ }
+
+std::vector<byte> Stream_Handshake_IO::send(Handshake_Message& msg)
+ {
+ const std::vector<byte> buf = format(msg.serialize(), msg.type());
+
+ m_writer.send(HANDSHAKE, &buf[0], buf.size());
+
+ return buf;
}
}
diff --git a/src/tls/tls_handshake_reader.h b/src/tls/tls_handshake_io.h
index 791a2628a..f71b2c034 100644
--- a/src/tls/tls_handshake_reader.h
+++ b/src/tls/tls_handshake_io.h
@@ -1,12 +1,12 @@
/*
-* TLS Handshake Reader
+* TLS Handshake Serialization
* (C) 2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
-#ifndef BOTAN_TLS_HANDSHAKE_READER_H__
-#define BOTAN_TLS_HANDSHAKE_READER_H__
+#ifndef BOTAN_TLS_HANDSHAKE_IO_H__
+#define BOTAN_TLS_HANDSHAKE_IO_H__
#include <botan/tls_magic.h>
#include <botan/loadstor.h>
@@ -18,12 +18,21 @@ namespace Botan {
namespace TLS {
+class Record_Writer;
+class Handshake_Message;
+
/**
-* Handshake Reader Interface
+* Handshake IO Interface
*/
-class Handshake_Reader
+class Handshake_IO
{
public:
+ virtual std::vector<byte> send(Handshake_Message& msg) = 0;
+
+ virtual std::vector<byte> format(
+ const std::vector<byte>& handshake_msg,
+ Handshake_Type handshake_type) = 0;
+
virtual void add_input(byte record_type,
const byte record[],
size_t record_size) = 0;
@@ -34,15 +43,29 @@ class Handshake_Reader
virtual std::pair<Handshake_Type, std::vector<byte> > get_next_record() = 0;
- virtual ~Handshake_Reader() {}
+ Handshake_IO() {}
+
+ Handshake_IO(const Handshake_IO&) = delete;
+
+ Handshake_IO& operator=(const Handshake_IO&) = delete;
+
+ virtual ~Handshake_IO() {}
};
/**
-* Reader of TLS handshake messages
+* Handshake IO for stream-based handshakes
*/
-class Stream_Handshake_Reader : public Handshake_Reader
+class Stream_Handshake_IO : public Handshake_IO
{
public:
+ Stream_Handshake_IO(Record_Writer& writer) : m_writer(writer) {}
+
+ std::vector<byte> send(Handshake_Message& msg) override;
+
+ std::vector<byte> format(
+ const std::vector<byte>& handshake_msg,
+ Handshake_Type handshake_type) override;
+
void add_input(byte record_type,
const byte record[],
size_t record_size) override;
@@ -54,6 +77,7 @@ class Stream_Handshake_Reader : public Handshake_Reader
std::pair<Handshake_Type, std::vector<byte> > get_next_record() override;
private:
std::deque<byte> m_queue;
+ Record_Writer& m_writer;
};
}
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 023b1816a..77a1b52fc 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -85,10 +85,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
/*
* Initialize the SSL/TLS Handshake State
*/
-Handshake_State::Handshake_State(Handshake_Reader* reader,
- Handshake_Writer* writer) :
- m_handshake_reader(reader),
- m_handshake_writer(writer),
+Handshake_State::Handshake_State(Handshake_IO* io) :
+ m_handshake_io(io),
m_version(Protocol_Version::SSL_V3)
{
}
@@ -345,8 +343,7 @@ Handshake_State::~Handshake_State()
delete client_finished;
delete server_finished;
- delete m_handshake_reader;
- delete m_handshake_writer;
+ delete m_handshake_io;
}
}
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 0f48c976b..49470fecb 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -9,8 +9,7 @@
#define BOTAN_TLS_HANDSHAKE_STATE_H__
#include <botan/internal/tls_handshake_hash.h>
-#include <botan/internal/tls_handshake_reader.h>
-#include <botan/internal/tls_handshake_writer.h>
+#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/tls_session_key.h>
#include <botan/pk_keys.h>
#include <botan/pubkey.h>
@@ -32,8 +31,7 @@ class Policy;
class Handshake_State
{
public:
- Handshake_State(Handshake_Reader* reader,
- Handshake_Writer* writer);
+ Handshake_State(Handshake_IO* io);
~Handshake_State();
@@ -108,12 +106,9 @@ class Handshake_State
*/
std::function<std::string (std::vector<std::string>)> client_npn_cb;
- Handshake_Reader& handshake_reader() { return *m_handshake_reader; }
-
- Handshake_Writer& handshake_writer() { return *m_handshake_writer; }
+ Handshake_IO& handshake_io() { return *m_handshake_io; }
private:
- Handshake_Reader* m_handshake_reader = nullptr;
- Handshake_Writer* m_handshake_writer = nullptr;
+ Handshake_IO* m_handshake_io = nullptr;
u32bit m_hand_expecting_mask = 0;
u32bit m_hand_received_mask = 0;
diff --git a/src/tls/tls_handshake_writer.cpp b/src/tls/tls_handshake_writer.cpp
deleted file mode 100644
index 7af9a3f52..000000000
--- a/src/tls/tls_handshake_writer.cpp
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
-* Handshake Message Writer
-* (C) 2012 Jack Lloyd
-*
-* Released under the terms of the Botan license
-*/
-
-#include <botan/internal/tls_handshake_writer.h>
-#include <botan/internal/tls_messages.h>
-#include <botan/tls_record.h>
-#include <botan/exceptn.h>
-
-namespace Botan {
-
-namespace TLS {
-
-namespace {
-
-void store_be24(byte* out, size_t val)
- {
- out[0] = get_byte<u32bit>(1, val);
- out[1] = get_byte<u32bit>(2, val);
- out[2] = get_byte<u32bit>(3, val);
- }
-
-}
-
-std::vector<byte>
-Stream_Handshake_Writer::format(const std::vector<byte>& msg,
- Handshake_Type type)
- {
- std::vector<byte> send_buf(4 + msg.size());
-
- const size_t buf_size = msg.size();
-
- send_buf[0] = type;
-
- store_be24(&send_buf[1], buf_size);
-
- copy_mem(&send_buf[4], &msg[0], msg.size());
-
- return send_buf;
- }
-
-std::vector<byte> Stream_Handshake_Writer::send(Handshake_Message& msg)
- {
- const std::vector<byte> buf = format(msg.serialize(), msg.type());
-
- m_writer.send(HANDSHAKE, &buf[0], buf.size());
-
- return buf;
- }
-
-}
-
-}
diff --git a/src/tls/tls_handshake_writer.h b/src/tls/tls_handshake_writer.h
deleted file mode 100644
index 3bbb1c93e..000000000
--- a/src/tls/tls_handshake_writer.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
-* TLS Handshake Writer
-* (C) 2012 Jack Lloyd
-*
-* Released under the terms of the Botan license
-*/
-
-#ifndef BOTAN_TLS_HANDSHAKE_WRITER_H__
-#define BOTAN_TLS_HANDSHAKE_WRITER_H__
-
-#include <botan/tls_magic.h>
-#include <botan/loadstor.h>
-#include <vector>
-#include <deque>
-#include <utility>
-
-namespace Botan {
-
-namespace TLS {
-
-class Record_Writer;
-class Handshake_Message;
-
-/**
-* Handshake Writer
-*/
-class Handshake_Writer
- {
- public:
- virtual std::vector<byte> send(Handshake_Message& msg) = 0;
-
- virtual std::vector<byte> format(
- const std::vector<byte>& handshake_msg,
- Handshake_Type handshake_type) = 0;
-
- Handshake_Writer() {}
-
- Handshake_Writer(const Handshake_Writer&) = delete;
-
- Handshake_Writer& operator=(const Handshake_Writer&) = delete;
-
- virtual ~Handshake_Writer() {}
- };
-
-/**
-* Stream Handshake Writer
-*/
-class Stream_Handshake_Writer : public Handshake_Writer
- {
- public:
- Stream_Handshake_Writer(Record_Writer& writer) : m_writer(writer) {}
-
- std::vector<byte> send(Handshake_Message& msg) override;
-
- std::vector<byte> format(
- const std::vector<byte>& handshake_msg,
- Handshake_Type handshake_type) override;
- private:
- Record_Writer& m_writer;
- };
-
-}
-
-}
-
-#endif
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index a0e7d8630..0969aea06 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -25,7 +25,7 @@ class SRP6_Server_Session;
namespace TLS {
-class Handshake_Writer;
+class Handshake_IO;
/**
* TLS Handshake Message Base Class
@@ -112,7 +112,7 @@ class Client_Hello : public Handshake_Message
bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; }
- Client_Hello(Handshake_Writer& writer,
+ Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
Protocol_Version version,
const Policy& policy,
@@ -122,7 +122,7 @@ class Client_Hello : public Handshake_Message
const std::string& hostname = "",
const std::string& srp_identifier = "");
- Client_Hello(Handshake_Writer& writer,
+ Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
RandomNumberGenerator& rng,
@@ -196,7 +196,7 @@ class Server_Hello : public Handshake_Message
bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; }
- Server_Hello(Handshake_Writer& writer,
+ Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<byte>& session_id,
Protocol_Version ver,
@@ -243,7 +243,7 @@ class Client_Key_Exchange : public Handshake_Message
const secure_vector<byte>& pre_master_secret() const
{ return pre_master; }
- Client_Key_Exchange(Handshake_Writer& output,
+ Client_Key_Exchange(Handshake_IO& io,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
@@ -276,7 +276,7 @@ class Certificate : public Handshake_Message
size_t count() const { return m_certs.size(); }
bool empty() const { return m_certs.empty(); }
- Certificate(Handshake_Writer& writer,
+ Certificate(Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<X509_Certificate>& certs);
@@ -303,7 +303,7 @@ class Certificate_Req : public Handshake_Message
std::vector<std::pair<std::string, std::string> > supported_algos() const
{ return m_supported_algos; }
- Certificate_Req(Handshake_Writer& writer,
+ Certificate_Req(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_Certificate>& allowed_cas,
@@ -336,7 +336,7 @@ class Certificate_Verify : public Handshake_Message
bool verify(const X509_Certificate& cert,
Handshake_State* state);
- Certificate_Verify(Handshake_Writer& writer,
+ Certificate_Verify(Handshake_IO& io,
Handshake_State* state,
const Policy& policy,
RandomNumberGenerator& rng,
@@ -366,7 +366,7 @@ class Finished : public Handshake_Message
bool verify(Handshake_State* state,
Connection_Side side);
- Finished(Handshake_Writer& writer,
+ Finished(Handshake_IO& io,
Handshake_State* state,
Connection_Side side);
@@ -386,7 +386,7 @@ class Hello_Request : public Handshake_Message
public:
Handshake_Type type() const { return HELLO_REQUEST; }
- Hello_Request(Handshake_Writer& writer);
+ Hello_Request(Handshake_IO& io);
Hello_Request(const std::vector<byte>& buf);
private:
std::vector<byte> serialize() const;
@@ -411,7 +411,7 @@ class Server_Key_Exchange : public Handshake_Message
// Only valid for SRP negotiation
SRP6_Server_Session& server_srp_params();
- Server_Key_Exchange(Handshake_Writer& writer,
+ Server_Key_Exchange(Handshake_IO& io,
Handshake_State* state,
const Policy& policy,
Credentials_Manager& creds,
@@ -445,7 +445,7 @@ class Server_Hello_Done : public Handshake_Message
public:
Handshake_Type type() const { return SERVER_HELLO_DONE; }
- Server_Hello_Done(Handshake_Writer& writer, Handshake_Hash& hash);
+ Server_Hello_Done(Handshake_IO& io, Handshake_Hash& hash);
Server_Hello_Done(const std::vector<byte>& buf);
private:
std::vector<byte> serialize() const;
@@ -461,7 +461,7 @@ class Next_Protocol : public Handshake_Message
std::string protocol() const { return m_protocol; }
- Next_Protocol(Handshake_Writer& writer,
+ Next_Protocol(Handshake_IO& io,
Handshake_Hash& hash,
const std::string& protocol);
@@ -480,12 +480,12 @@ class New_Session_Ticket : public Handshake_Message
u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; }
const std::vector<byte>& ticket() const { return m_ticket; }
- New_Session_Ticket(Handshake_Writer& writer,
+ New_Session_Ticket(Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<byte>& ticket,
u32bit lifetime);
- New_Session_Ticket(Handshake_Writer& writer,
+ New_Session_Ticket(Handshake_IO& io,
Handshake_Hash& hash);
New_Session_Ticket(const std::vector<byte>& buf);
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 0f1b24045..9c6250273 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -207,8 +207,7 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
Handshake_State* Server::new_handshake_state()
{
- return new Handshake_State(new Stream_Handshake_Reader,
- new Stream_Handshake_Writer(m_writer));
+ return new Handshake_State(new Stream_Handshake_IO(m_writer));
}
/*
@@ -223,7 +222,7 @@ void Server::renegotiate(bool force_full_renegotiation)
m_state->allow_session_resumption = !force_full_renegotiation;
m_state->set_expected_next(CLIENT_HELLO);
- Hello_Request hello_req(m_state->handshake_writer());
+ Hello_Request hello_req(m_state->handshake_io());
}
void Server::alert_notify(const Alert& alert)
@@ -273,7 +272,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(type == CLIENT_HELLO_SSLV2)
m_state->hash.update(contents);
else
- m_state->hash.update(m_state->handshake_writer().format(contents, type));
+ m_state->hash.update(m_state->handshake_io().format(contents, type));
}
if(type == CLIENT_HELLO || type == CLIENT_HELLO_SSLV2)
@@ -374,7 +373,7 @@ void Server::process_handshake_msg(Handshake_Type type,
// resume session
m_state->server_hello = new Server_Hello(
- m_state->handshake_writer(),
+ m_state->handshake_io(),
m_state->hash,
m_state->client_hello->session_id(),
Protocol_Version(session_info.version()),
@@ -410,7 +409,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(m_state->server_hello->supports_session_ticket()) // send an empty ticket
{
m_state->new_session_ticket =
- new New_Session_Ticket(m_state->handshake_writer(),
+ new New_Session_Ticket(m_state->handshake_io(),
m_state->hash);
}
}
@@ -422,7 +421,7 @@ void Server::process_handshake_msg(Handshake_Type type,
const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
m_state->new_session_ticket =
- new New_Session_Ticket(m_state->handshake_writer(),
+ new New_Session_Ticket(m_state->handshake_io(),
m_state->hash,
session_info.encrypt(ticket_key, m_rng),
m_policy.session_ticket_lifetime());
@@ -432,7 +431,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(!m_state->new_session_ticket)
{
m_state->new_session_ticket =
- new New_Session_Ticket(m_state->handshake_writer(),
+ new New_Session_Ticket(m_state->handshake_io(),
m_state->hash);
}
}
@@ -444,7 +443,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_state->keys,
m_state->server_hello->compression_method());
- m_state->server_finished = new Finished(m_state->handshake_writer(),
+ m_state->server_finished = new Finished(m_state->handshake_io(),
m_state.get(), SERVER);
m_state->set_expected_next(HANDSHAKE_CCS);
@@ -471,7 +470,7 @@ void Server::process_handshake_msg(Handshake_Type type,
}
m_state->server_hello = new Server_Hello(
- m_state->handshake_writer(),
+ m_state->handshake_io(),
m_state->hash,
make_hello_random(m_rng), // new session ID
m_state->version(),
@@ -508,7 +507,7 @@ void Server::process_handshake_msg(Handshake_Type type,
BOTAN_ASSERT(!cert_chains[sig_algo].empty(),
"Attempting to send empty certificate chain");
- m_state->server_certs = new Certificate(m_state->handshake_writer(),
+ m_state->server_certs = new Certificate(m_state->handshake_io(),
m_state->hash,
cert_chains[sig_algo]);
}
@@ -533,7 +532,7 @@ void Server::process_handshake_msg(Handshake_Type type,
else
{
m_state->server_kex =
- new Server_Key_Exchange(m_state->handshake_writer(),
+ new Server_Key_Exchange(m_state->handshake_io(),
m_state.get(),
m_policy,
m_creds,
@@ -546,7 +545,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(!client_auth_CAs.empty() && m_state->suite.sig_algo() != "")
{
- m_state->cert_req = new Certificate_Req(m_state->handshake_writer(),
+ m_state->cert_req = new Certificate_Req(m_state->handshake_io(),
m_state->hash,
m_policy,
client_auth_CAs,
@@ -562,7 +561,7 @@ void Server::process_handshake_msg(Handshake_Type type,
*/
m_state->set_expected_next(CLIENT_KEX);
- m_state->server_hello_done = new Server_Hello_Done(m_state->handshake_writer(),
+ m_state->server_hello_done = new Server_Hello_Done(m_state->handshake_io(),
m_state->hash);
}
}
@@ -599,7 +598,7 @@ void Server::process_handshake_msg(Handshake_Type type,
const bool sig_valid =
m_state->client_verify->verify(m_peer_certs[0], m_state.get());
- m_state->hash.update(m_state->handshake_writer().format(contents, type));
+ m_state->hash.update(m_state->handshake_io().format(contents, type));
/*
* Using DECRYPT_ERROR looks weird here, but per RFC 4346 is for
@@ -654,7 +653,7 @@ void Server::process_handshake_msg(Handshake_Type type,
{
// already sent finished if resuming, so this is a new session
- m_state->hash.update(m_state->handshake_writer().format(contents, type));
+ m_state->hash.update(m_state->handshake_io().format(contents, type));
Session session_info(
m_state->server_hello->session_id(),
@@ -680,7 +679,7 @@ void Server::process_handshake_msg(Handshake_Type type,
const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
m_state->new_session_ticket =
- new New_Session_Ticket(m_state->handshake_writer(),
+ new New_Session_Ticket(m_state->handshake_io(),
m_state->hash,
session_info.encrypt(ticket_key, m_rng),
m_policy.session_ticket_lifetime());
@@ -693,7 +692,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket)
{
- m_state->new_session_ticket = new New_Session_Ticket(m_state->handshake_writer(),
+ m_state->new_session_ticket = new New_Session_Ticket(m_state->handshake_io(),
m_state->hash);
}
@@ -705,7 +704,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_state->keys,
m_state->server_hello->compression_method());
- m_state->server_finished = new Finished(m_state->handshake_writer(),
+ m_state->server_finished = new Finished(m_state->handshake_io(),
m_state.get(), SERVER);
}