aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/tls
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/tls')
-rw-r--r--src/lib/tls/tls_blocking.cpp16
-rw-r--r--src/lib/tls/tls_blocking.h38
-rw-r--r--src/lib/tls/tls_channel.cpp10
-rw-r--r--src/lib/tls/tls_channel.h21
-rw-r--r--src/lib/tls/tls_client.cpp37
-rw-r--r--src/lib/tls/tls_client.h21
-rw-r--r--src/lib/tls/tls_handshake_io.h15
-rw-r--r--src/lib/tls/tls_handshake_state.cpp5
-rw-r--r--src/lib/tls/tls_handshake_state.h8
-rw-r--r--src/lib/tls/tls_record.cpp6
-rw-r--r--src/lib/tls/tls_record.h5
-rw-r--r--src/lib/tls/tls_server.cpp14
-rw-r--r--src/lib/tls/tls_server.h12
13 files changed, 116 insertions, 92 deletions
diff --git a/src/lib/tls/tls_blocking.cpp b/src/lib/tls/tls_blocking.cpp
index dc5769e2c..b02c9ede1 100644
--- a/src/lib/tls/tls_blocking.cpp
+++ b/src/lib/tls/tls_blocking.cpp
@@ -13,17 +13,17 @@ namespace TLS {
using namespace std::placeholders;
-Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn,
- std::function<void (const byte[], size_t)> write_fn,
+Blocking_Client::Blocking_Client(read_fn reader,
+ write_fn writer,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
const Server_Information& server_info,
const Protocol_Version offer_version,
- std::function<std::string (std::vector<std::string>)> next_protocol) :
- m_read_fn(read_fn),
- m_channel(write_fn,
+ next_protocol_fn npn) :
+ m_read(reader),
+ m_channel(writer,
std::bind(&Blocking_Client::data_cb, this, _1, _2),
std::bind(&Blocking_Client::alert_cb, this, _1, _2, _3),
std::bind(&Blocking_Client::handshake_cb, this, _1),
@@ -33,7 +33,7 @@ Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn,
rng,
server_info,
offer_version,
- next_protocol)
+ npn)
{
}
@@ -58,7 +58,7 @@ void Blocking_Client::do_handshake()
while(!m_channel.is_closed() && !m_channel.is_active())
{
- const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size());
+ const size_t from_socket = m_read(&readbuf[0], readbuf.size());
m_channel.received_data(&readbuf[0], from_socket);
}
}
@@ -69,7 +69,7 @@ size_t Blocking_Client::read(byte buf[], size_t buf_len)
while(m_plaintext.empty() && !m_channel.is_closed())
{
- const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size());
+ const size_t from_socket = m_read(&readbuf[0], readbuf.size());
m_channel.received_data(&readbuf[0], from_socket);
}
diff --git a/src/lib/tls/tls_blocking.h b/src/lib/tls/tls_blocking.h
index 1226d9364..ca6906545 100644
--- a/src/lib/tls/tls_blocking.h
+++ b/src/lib/tls/tls_blocking.h
@@ -14,27 +14,35 @@
namespace Botan {
-template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>;
+//template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>;
namespace TLS {
/**
* Blocking TLS Client
+* Can be used directly, or subclass to get handshake and alert notifications
*/
class BOTAN_DLL Blocking_Client
{
public:
+ /*
+ * These functions are expected to block until completing entirely, or
+ * fail by throwing an exception.
+ */
+ typedef std::function<size_t (byte[], size_t)> read_fn;
+ typedef std::function<void (const byte[], size_t)> write_fn;
+
+ typedef Client::next_protocol_fn next_protocol_fn;
- Blocking_Client(std::function<size_t (byte[], size_t)> read_fn,
- std::function<void (const byte[], size_t)> write_fn,
- Session_Manager& session_manager,
- Credentials_Manager& creds,
- const Policy& policy,
- RandomNumberGenerator& rng,
- const Server_Information& server_info = Server_Information(),
- const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
- std::function<std::string (std::vector<std::string>)> next_protocol =
- std::function<std::string (std::vector<std::string>)>());
+ Blocking_Client(read_fn reader,
+ write_fn writer,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ const Server_Information& server_info = Server_Information(),
+ const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
+ next_protocol_fn npn = next_protocol_fn());
/**
* Completes full handshake then returns
@@ -68,12 +76,12 @@ class BOTAN_DLL Blocking_Client
protected:
/**
- * Can override to get the handshake complete notification
+ * Application can override to get the handshake complete notification
*/
virtual bool handshake_complete(const Session&) { return true; }
/**
- * Can override to get notification of alerts
+ * Application can override to get notification of alerts
*/
virtual void alert_notification(const Alert&) {}
@@ -85,9 +93,9 @@ class BOTAN_DLL Blocking_Client
void alert_cb(const Alert alert, const byte data[], size_t data_len);
- std::function<size_t (byte[], size_t)> m_read_fn;
+ read_fn m_read;
TLS::Client m_channel;
- secure_deque<byte> m_plaintext;
+ secure_vector<byte> m_plaintext;
};
}
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index a5e504f8b..e784566cd 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -19,10 +19,10 @@ namespace Botan {
namespace TLS {
-Channel::Channel(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+Channel::Channel(output_fn output_fn,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
bool is_datagram,
@@ -124,8 +124,8 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
const u16bit mtu = 1280 - 40 - 8;
io.reset(new Datagram_Handshake_IO(
- sequence_numbers(),
std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3),
+ sequence_numbers(),
mtu));
}
else
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 5b5a5d530..713d4c1b9 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -31,10 +31,15 @@ class Handshake_State;
class BOTAN_DLL Channel
{
public:
- Channel(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+ typedef std::function<void (const byte[], size_t)> output_fn;
+ typedef std::function<void (const byte[], size_t)> data_cb;
+ typedef std::function<void (Alert, const byte[], size_t)> alert_cb;
+ typedef std::function<bool (const Session&)> handshake_cb;
+
+ Channel(output_fn out,
+ data_cb app_data_cb,
+ alert_cb alert_cb,
+ handshake_cb hs_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
bool is_datagram,
@@ -240,10 +245,10 @@ class BOTAN_DLL Channel
bool m_is_datagram;
/* callbacks */
- std::function<bool (const Session&)> m_handshake_cb;
- std::function<void (const byte[], size_t)> m_data_cb;
- std::function<void (Alert, const byte[], size_t)> m_alert_cb;
- std::function<void (const byte[], size_t)> m_output_fn;
+ handshake_cb m_handshake_cb;
+ data_cb m_data_cb;
+ alert_cb m_alert_cb;
+ output_fn m_output_fn;
/* external state */
RandomNumberGenerator& m_rng;
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index f68c9c614..75df6332a 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -22,10 +22,8 @@ class Client_Handshake_State : public Handshake_State
public:
// using Handshake_State::Handshake_State;
- Client_Handshake_State(Handshake_IO* io,
- std::function<void (const Handshake_Message&)> msg_callback =
- std::function<void (const Handshake_Message&)>()) :
- Handshake_State(io, msg_callback) {}
+ Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) :
+ Handshake_State(io, cb) {}
const Public_Key& get_server_public_Key() const
{
@@ -39,7 +37,7 @@ class Client_Handshake_State : public Handshake_State
std::unique_ptr<Public_Key> server_public_key;
// Used by client using NPN
- std::function<std::string (std::vector<std::string>)> client_npn_cb;
+ Client::next_protocol_fn client_npn_cb;
};
}
@@ -47,17 +45,17 @@ class Client_Handshake_State : public Handshake_State
/*
* TLS Client Constructor
*/
-Client::Client(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t)> proc_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+Client::Client(output_fn output_fn,
+ data_cb proc_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
const Server_Information& info,
const Protocol_Version offer_version,
- std::function<std::string (std::vector<std::string>)> next_protocol,
+ next_protocol_fn npn,
size_t io_buf_sz) :
Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng,
offer_version.is_datagram_protocol(), io_buf_sz),
@@ -68,12 +66,12 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname());
Handshake_State& state = create_handshake_state(offer_version);
- send_client_hello(state, false, offer_version, srp_identifier, next_protocol);
+ send_client_hello(state, false, offer_version, srp_identifier, npn);
}
Handshake_State* Client::new_handshake_state(Handshake_IO* io)
{
- return new Client_Handshake_State(io);
+ return new Client_Handshake_State(io); // , m_hs_msg_cb);
}
std::vector<X509_Certificate>
@@ -99,7 +97,7 @@ void Client::send_client_hello(Handshake_State& state_base,
bool force_full_renegotiation,
Protocol_Version version,
const std::string& srp_identifier,
- std::function<std::string (std::vector<std::string>)> next_protocol)
+ next_protocol_fn next_protocol)
{
Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base);
@@ -167,16 +165,19 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
if(state.client_hello())
return;
- if(!m_policy.allow_server_initiated_renegotiation() ||
- (!m_policy.allow_insecure_renegotiation() && !secure_renegotiation_supported()))
+ if(m_policy.allow_server_initiated_renegotiation())
+ {
+ if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false)
+ send_warning_alert(Alert::NO_RENEGOTIATION);
+ else
+ this->initiate_handshake(state, false);
+ }
+ else
{
// RFC 5746 section 4.2
send_warning_alert(Alert::NO_RENEGOTIATION);
- return;
}
- this->initiate_handshake(state, false);
-
return;
}
diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h
index 126dbb935..a548a32e0 100644
--- a/src/lib/tls/tls_client.h
+++ b/src/lib/tls/tls_client.h
@@ -25,9 +25,9 @@ class BOTAN_DLL Client : public Channel
/**
* Set up a new TLS client session
*
- * @param socket_output_fn is called with data for the outbound socket
+ * @param output_fn is called with data for the outbound socket
*
- * @param proc_cb is called when new application data is received
+ * @param app_data_cb is called when new application data is received
*
* @param alert_cb is called when a TLS alert is received
*
@@ -59,18 +59,20 @@ class BOTAN_DLL Client : public Channel
* be preallocated for the read and write buffers. Smaller
* values just mean reallocations and copies are more likely.
*/
- Client(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+
+ typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn;
+
+ Client(output_fn out,
+ data_cb app_data_cb,
+ alert_cb alert_cb,
+ handshake_cb hs_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
const Server_Information& server_info = Server_Information(),
const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
- std::function<std::string (std::vector<std::string>)> next_protocol =
- std::function<std::string (std::vector<std::string>)>(),
+ next_protocol_fn next_protocol = next_protocol_fn(),
size_t reserved_io_buffer_size = 16*1024
);
private:
@@ -84,8 +86,7 @@ class BOTAN_DLL Client : public Channel
bool force_full_renegotiation,
Protocol_Version version,
const std::string& srp_identifier = "",
- std::function<std::string (std::vector<std::string>)> next_protocol =
- std::function<std::string (std::vector<std::string>)>());
+ next_protocol_fn next_protocol = next_protocol_fn());
void process_handshake_msg(const Handshake_State* active_state,
Handshake_State& pending_state,
diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h
index 34873c3a6..00074a744 100644
--- a/src/lib/tls/tls_handshake_io.h
+++ b/src/lib/tls/tls_handshake_io.h
@@ -65,8 +65,9 @@ class Handshake_IO
class Stream_Handshake_IO : public Handshake_IO
{
public:
- Stream_Handshake_IO(std::function<void (byte, const std::vector<byte>&)> writer) :
- m_send_hs(writer) {}
+ typedef std::function<void (byte, const std::vector<byte>&)> writer_fn;
+
+ Stream_Handshake_IO(writer_fn writer) : m_send_hs(writer) {}
Protocol_Version initial_record_version() const override;
@@ -86,7 +87,7 @@ class Stream_Handshake_IO : public Handshake_IO
get_next_record(bool expecting_ccs) override;
private:
std::deque<byte> m_queue;
- std::function<void (byte, const std::vector<byte>&)> m_send_hs;
+ writer_fn m_send_hs;
};
/**
@@ -95,8 +96,10 @@ class Stream_Handshake_IO : public Handshake_IO
class Datagram_Handshake_IO : public Handshake_IO
{
public:
- Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq,
- std::function<void (u16bit, byte, const std::vector<byte>&)> writer,
+ typedef std::function<void (u16bit, byte, const std::vector<byte>&)> writer_fn;
+
+ Datagram_Handshake_IO(writer_fn writer,
+ class Connection_Sequence_Numbers& seq,
u16bit mtu) :
m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {}
@@ -186,7 +189,7 @@ class Datagram_Handshake_IO : public Handshake_IO
u16bit m_in_message_seq = 0;
u16bit m_out_message_seq = 0;
- std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs;
+ writer_fn m_send_hs;
u16bit m_mtu;
};
diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp
index 111087041..64d35fd6b 100644
--- a/src/lib/tls/tls_handshake_state.cpp
+++ b/src/lib/tls/tls_handshake_state.cpp
@@ -83,9 +83,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
/*
* Initialize the SSL/TLS Handshake State
*/
-Handshake_State::Handshake_State(Handshake_IO* io,
- std::function<void (const Handshake_Message&)> msg_callback) :
- m_msg_callback(msg_callback),
+Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) :
+ m_msg_callback(cb),
m_handshake_io(io),
m_version(m_handshake_io->initial_record_version())
{
diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h
index 3c8d856d7..bb2abc209 100644
--- a/src/lib/tls/tls_handshake_state.h
+++ b/src/lib/tls/tls_handshake_state.h
@@ -46,9 +46,9 @@ class Finished;
class Handshake_State
{
public:
- Handshake_State(Handshake_IO* io,
- std::function<void (const Handshake_Message&)> msg_callback =
- std::function<void (const Handshake_Message&)>());
+ typedef std::function<void (const Handshake_Message&)> hs_msg_cb;
+
+ Handshake_State(Handshake_IO* io, hs_msg_cb cb);
virtual ~Handshake_State();
@@ -176,7 +176,7 @@ class Handshake_State
private:
- std::function<void (const Handshake_Message&)> m_msg_callback;
+ hs_msg_cb m_msg_callback;
std::unique_ptr<Handshake_IO> m_handshake_io;
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 3edeab7e3..d5e3126f1 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -455,7 +455,7 @@ size_t read_tls_record(secure_vector<byte>& readbuf,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ get_cipherstate_fn get_cipherstate)
{
consumed = 0;
@@ -543,7 +543,7 @@ size_t read_dtls_record(secure_vector<byte>& readbuf,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ get_cipherstate_fn get_cipherstate)
{
consumed = 0;
@@ -642,7 +642,7 @@ size_t read_record(secure_vector<byte>& readbuf,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ get_cipherstate_fn get_cipherstate)
{
if(is_datagram)
return read_dtls_record(readbuf, input, input_sz, consumed,
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index c9bf8aade..46f87a9af 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -113,6 +113,9 @@ void write_record(secure_vector<byte>& write_buffer,
Connection_Cipher_State* cipherstate,
RandomNumberGenerator& rng);
+// epoch -> cipher state
+typedef std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate_fn;
+
/**
* Decode a TLS record
* @return zero if full message, else number of bytes still needed
@@ -127,7 +130,7 @@ size_t read_record(secure_vector<byte>& read_buffer,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate);
+ get_cipherstate_fn get_cipherstate);
}
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index 1490fc2a4..515bd9e17 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -21,7 +21,8 @@ class Server_Handshake_State : public Handshake_State
public:
// using Handshake_State::Handshake_State;
- Server_Handshake_State(Handshake_IO* io) : Handshake_State(io) {}
+ Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) :
+ Handshake_State(io, cb) {}
// Used by the server only, in case of RSA key exchange. Not owned
Private_Key* server_rsa_kex_key = nullptr;
@@ -203,10 +204,10 @@ get_server_certs(const std::string& hostname,
/*
* TLS Server Constructor
*/
-Server::Server(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+Server::Server(output_fn output,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
@@ -214,7 +215,8 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
const std::vector<std::string>& next_protocols,
bool is_datagram,
size_t io_buf_sz) :
- Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz),
+ Channel(output, data_cb, alert_cb, handshake_cb,
+ session_manager, rng, is_datagram, io_buf_sz),
m_policy(policy),
m_creds(creds),
m_possible_protocols(next_protocols)
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index ce82e001d..4b15e837b 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -25,10 +25,10 @@ class BOTAN_DLL Server : public Channel
/**
* Server initialization
*/
- Server(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+ Server(output_fn output,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
@@ -40,7 +40,9 @@ class BOTAN_DLL Server : public Channel
/**
* Return the protocol notification set by the client (using the
- * NPN extension) for this connection, if any
+ * NPN extension) for this connection, if any. This value is not
+ * tied to the session and a later renegotiation of the same
+ * session can choose a new protocol.
*/
std::string next_protocol() const { return m_next_protocol; }