aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--doc/manual/tls.rst68
-rw-r--r--src/cmd/tls_client.cpp28
-rw-r--r--src/lib/alloc/secmem.h2
-rw-r--r--src/lib/tls/tls_blocking.cpp16
-rw-r--r--src/lib/tls/tls_blocking.h38
-rw-r--r--src/lib/tls/tls_channel.cpp10
-rw-r--r--src/lib/tls/tls_channel.h21
-rw-r--r--src/lib/tls/tls_client.cpp37
-rw-r--r--src/lib/tls/tls_client.h21
-rw-r--r--src/lib/tls/tls_handshake_io.h15
-rw-r--r--src/lib/tls/tls_handshake_state.cpp5
-rw-r--r--src/lib/tls/tls_handshake_state.h8
-rw-r--r--src/lib/tls/tls_record.cpp6
-rw-r--r--src/lib/tls/tls_record.h5
-rw-r--r--src/lib/tls/tls_server.cpp14
-rw-r--r--src/lib/tls/tls_server.h12
16 files changed, 171 insertions, 135 deletions
diff --git a/doc/manual/tls.rst b/doc/manual/tls.rst
index b581c978c..5e1d48656 100644
--- a/doc/manual/tls.rst
+++ b/doc/manual/tls.rst
@@ -30,7 +30,7 @@ abstraction. This makes the library completely agnostic to how you
write your network layer, be it blocking sockets, libevent, asio, a
message queue, etc.
-The callbacks that TLS calls have the signatures
+The callbacks for TLS have the signatures
.. cpp:function:: void output_fn(const byte data[], size_t data_len)
@@ -81,6 +81,13 @@ available:
.. cpp:class:: TLS::Channel
+ .. cpp:type:: std::function<void (const byte[], size_t)> output_fn
+ .. cpp:type:: std::function<void (const byte[], size_t)> data_cb
+ .. cpp:type:: std::function<void (Alert, const byte[], size_t)> alert_cb
+ .. cpp:type:: std::function<bool (const Session&)> handshake_cb
+
+ Typedefs used in the code for the functions described above
+
.. cpp:function:: size_t received_data(const byte buf[], size_t buf_size)
.. cpp:function:: size_t received_data(const std::vector<byte>& buf)
@@ -185,18 +192,18 @@ TLS Clients
.. cpp:class:: TLS::Client
.. cpp:function:: TLS::Client( \
- std::function<void, const byte*, size_t> output_fn, \
- std::function<void, const byte*, size_t> data_cb, \
- std::function<TLS::Alert, const byte*, size_t> alert_cb, \
- std::function<bool, const TLS::Session&> handshake_cb, \
- TLS::Session_Manager& session_manager, \
- Credentials_Manager& credendials_manager, \
- const TLS::Policy& policy, \
- RandomNumberGenerator& rng, \
- const Server_Information& server_info, \
- const Protocol_Version offer_version, \
- std::function<std::string, std::vector<std::string> > next_protocol, \
- size_t reserved_io_buffer_size)
+ output_fn output, \
+ data_cb data, \
+ alert_cb alert, \
+ handshake_cb handshake_complete, \
+ TLS::Session_Manager& session_manager, \
+ Credentials_Manager& credendials_manager, \
+ const TLS::Policy& policy, \
+ RandomNumberGenerator& rng, \
+ const Server_Information& server_info, \
+ const Protocol_Version offer_version, \
+ next_protocol_fn npn, \
+ size_t reserved_io_buffer_size)
Initialize a new TLS client. The constructor will immediately
initiate a new session.
@@ -234,23 +241,20 @@ TLS Clients
The *credentials_manager* is an interface that will be called to
retrieve any certificates, secret keys, pre-shared keys, or SRP
- intformation; see :doc:`credentials_manager` for more information.
+ information; see :doc:`credentials_manager` for more information.
- Use *server_info* to specify the DNS name of the server you are
- attempting to connect to, if you know it. This helps the server
- select what certificate to use and helps the client validate the
- connection.
+ Use the optional *server_info* to specify the DNS name of the
+ server you are attempting to connect to, if you know it. This helps
+ the server select what certificate to use and helps the client
+ validate the connection.
- Use *offer_version* to control the version of TLS you wish the
- client to offer. Normally, you'll want to offer the most recent
- version of (D)TLS that is available, however some broken servers are
- intolerant of certain versions being offered, and for classes of
- applications that have to deal with such servers (typically web
- browsers) it may be necessary to implement a version backdown
- strategy if the initial attempt fails.
-
- Setting *offer_version* is also used to offer DTLS instead of TLS;
- use :cpp:func:`TLS::Protocol_Version::latest_dtls_version`.
+ Use the optional *offer_version* to control the version of TLS you
+ wish the client to offer. Normally, you'll want to offer the most
+ recent version of (D)TLS that is available, however some broken
+ servers are intolerant of certain versions being offered, and for
+ classes of applications that have to deal with such servers
+ (typically web browsers) it may be necessary to implement a version
+ backdown strategy if the initial attempt fails.
.. warning::
@@ -258,6 +262,9 @@ TLS Clients
downgrade your connection to the weakest protocol that both you
and the server support.
+ Setting *offer_version* is also used to offer DTLS instead of TLS;
+ use :cpp:func:`TLS::Protocol_Version::latest_dtls_version`.
+
The optional *next_protocol* callback is called if the server
indicates it supports the next protocol notification extension.
The callback wlil be called with a list of protocol names that the
@@ -270,7 +277,7 @@ TLS Clients
resized as needed to process inputs). Otherwise some reasonable
default is used.
-A TLS client example using BSD sockets is in `src/cmd/tls_client.cpp`
+Code for a TLS client using BSD sockets is in `src/cmd/tls_client.cpp`
TLS Servers
----------------------------------------
@@ -308,8 +315,7 @@ not until they actually receive a hello without this parameter.
renegotiation, but might change across different connections using
that session.
-An example TLS server implementation using asio is available in
-`src/cmd/tls_proxy.cpp`.
+Code for a TLS server using asio is in `src/cmd/tls_proxy.cpp`.
.. _tls_sessions:
diff --git a/src/cmd/tls_client.cpp b/src/cmd/tls_client.cpp
index 543119fb9..903824a78 100644
--- a/src/cmd/tls_client.cpp
+++ b/src/cmd/tls_client.cpp
@@ -39,7 +39,7 @@ using namespace std::placeholders;
namespace {
-int connect_to_host(const std::string& host, u16bit port, const std::string& transport)
+int connect_to_host(const std::string& host, u16bit port, bool tcp)
{
hostent* host_addr = ::gethostbyname(host.c_str());
@@ -49,7 +49,7 @@ int connect_to_host(const std::string& host, u16bit port, const std::string& tra
if(host_addr->h_addrtype != AF_INET) // FIXME
throw std::runtime_error(host + " has IPv6 address, not supported");
- int type = (transport == "tcp") ? SOCK_STREAM : SOCK_DGRAM;
+ int type = tcp ? SOCK_STREAM : SOCK_DGRAM;
int fd = ::socket(PF_INET, type, 0);
if(fd == -1)
@@ -130,13 +130,6 @@ void process_data(const byte buf[], size_t buf_size)
std::cout << buf[i];
}
-std::string protocol_chooser(const std::vector<std::string>& protocols)
- {
- for(size_t i = 0; i != protocols.size(); ++i)
- std::cout << "Protocol " << i << " = " << protocols[i] << "\n";
- return "http/1.1";
- }
-
int tls_client(int argc, char* argv[])
{
if(argc != 2 && argc != 3 && argc != 4)
@@ -145,6 +138,9 @@ int tls_client(int argc, char* argv[])
return 1;
}
+ const bool request_protocol = true;
+ const std::string use_protocol = "http/1.1";
+
try
{
AutoSeeded_RNG rng;
@@ -167,14 +163,22 @@ int tls_client(int argc, char* argv[])
u32bit port = argc >= 3 ? Botan::to_u32bit(argv[2]) : 443;
const std::string transport = argc >= 4 ? argv[3] : "tcp";
- int sockfd = connect_to_host(host, port, transport);
+ const bool use_tcp = (transport == "tcp");
+
+ int sockfd = connect_to_host(host, port, use_tcp);
+
+ auto protocol_chooser = [use_protocol](const std::vector<std::string>& protocols) -> std::string {
+ for(size_t i = 0; i != protocols.size(); ++i)
+ std::cout << "Server offered protocol " << i << " = " << protocols[i] << "\n";
+ return use_protocol;
+ };
auto socket_write =
- (transport == "tcp") ?
+ use_tcp ?
std::bind(stream_socket_write, sockfd, _1, _2) :
std::bind(dgram_socket_write, sockfd, _1, _2);
- auto version = policy.latest_supported_version(transport != "tcp");
+ auto version = policy.latest_supported_version(!use_tcp);
TLS::Client client(socket_write,
process_data,
diff --git a/src/lib/alloc/secmem.h b/src/lib/alloc/secmem.h
index 58d0734cb..82b4083ea 100644
--- a/src/lib/alloc/secmem.h
+++ b/src/lib/alloc/secmem.h
@@ -11,6 +11,7 @@
#include <botan/mem_ops.h>
#include <algorithm>
#include <vector>
+#include <deque>
#if defined(BOTAN_HAS_LOCKING_ALLOCATOR)
#include <botan/locking_allocator.h>
@@ -90,6 +91,7 @@ operator!=(const secure_allocator<T>&, const secure_allocator<T>&)
{ return false; }
template<typename T> using secure_vector = std::vector<T, secure_allocator<T>>;
+template<typename T> using secure_deque = std::deque<T, secure_allocator<T>>;
template<typename T>
std::vector<T> unlock(const secure_vector<T>& in)
diff --git a/src/lib/tls/tls_blocking.cpp b/src/lib/tls/tls_blocking.cpp
index dc5769e2c..b02c9ede1 100644
--- a/src/lib/tls/tls_blocking.cpp
+++ b/src/lib/tls/tls_blocking.cpp
@@ -13,17 +13,17 @@ namespace TLS {
using namespace std::placeholders;
-Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn,
- std::function<void (const byte[], size_t)> write_fn,
+Blocking_Client::Blocking_Client(read_fn reader,
+ write_fn writer,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
const Server_Information& server_info,
const Protocol_Version offer_version,
- std::function<std::string (std::vector<std::string>)> next_protocol) :
- m_read_fn(read_fn),
- m_channel(write_fn,
+ next_protocol_fn npn) :
+ m_read(reader),
+ m_channel(writer,
std::bind(&Blocking_Client::data_cb, this, _1, _2),
std::bind(&Blocking_Client::alert_cb, this, _1, _2, _3),
std::bind(&Blocking_Client::handshake_cb, this, _1),
@@ -33,7 +33,7 @@ Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn,
rng,
server_info,
offer_version,
- next_protocol)
+ npn)
{
}
@@ -58,7 +58,7 @@ void Blocking_Client::do_handshake()
while(!m_channel.is_closed() && !m_channel.is_active())
{
- const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size());
+ const size_t from_socket = m_read(&readbuf[0], readbuf.size());
m_channel.received_data(&readbuf[0], from_socket);
}
}
@@ -69,7 +69,7 @@ size_t Blocking_Client::read(byte buf[], size_t buf_len)
while(m_plaintext.empty() && !m_channel.is_closed())
{
- const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size());
+ const size_t from_socket = m_read(&readbuf[0], readbuf.size());
m_channel.received_data(&readbuf[0], from_socket);
}
diff --git a/src/lib/tls/tls_blocking.h b/src/lib/tls/tls_blocking.h
index 1226d9364..ca6906545 100644
--- a/src/lib/tls/tls_blocking.h
+++ b/src/lib/tls/tls_blocking.h
@@ -14,27 +14,35 @@
namespace Botan {
-template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>;
+//template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>;
namespace TLS {
/**
* Blocking TLS Client
+* Can be used directly, or subclass to get handshake and alert notifications
*/
class BOTAN_DLL Blocking_Client
{
public:
+ /*
+ * These functions are expected to block until completing entirely, or
+ * fail by throwing an exception.
+ */
+ typedef std::function<size_t (byte[], size_t)> read_fn;
+ typedef std::function<void (const byte[], size_t)> write_fn;
+
+ typedef Client::next_protocol_fn next_protocol_fn;
- Blocking_Client(std::function<size_t (byte[], size_t)> read_fn,
- std::function<void (const byte[], size_t)> write_fn,
- Session_Manager& session_manager,
- Credentials_Manager& creds,
- const Policy& policy,
- RandomNumberGenerator& rng,
- const Server_Information& server_info = Server_Information(),
- const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
- std::function<std::string (std::vector<std::string>)> next_protocol =
- std::function<std::string (std::vector<std::string>)>());
+ Blocking_Client(read_fn reader,
+ write_fn writer,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ const Server_Information& server_info = Server_Information(),
+ const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
+ next_protocol_fn npn = next_protocol_fn());
/**
* Completes full handshake then returns
@@ -68,12 +76,12 @@ class BOTAN_DLL Blocking_Client
protected:
/**
- * Can override to get the handshake complete notification
+ * Application can override to get the handshake complete notification
*/
virtual bool handshake_complete(const Session&) { return true; }
/**
- * Can override to get notification of alerts
+ * Application can override to get notification of alerts
*/
virtual void alert_notification(const Alert&) {}
@@ -85,9 +93,9 @@ class BOTAN_DLL Blocking_Client
void alert_cb(const Alert alert, const byte data[], size_t data_len);
- std::function<size_t (byte[], size_t)> m_read_fn;
+ read_fn m_read;
TLS::Client m_channel;
- secure_deque<byte> m_plaintext;
+ secure_vector<byte> m_plaintext;
};
}
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index a5e504f8b..e784566cd 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -19,10 +19,10 @@ namespace Botan {
namespace TLS {
-Channel::Channel(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+Channel::Channel(output_fn output_fn,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
bool is_datagram,
@@ -124,8 +124,8 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
const u16bit mtu = 1280 - 40 - 8;
io.reset(new Datagram_Handshake_IO(
- sequence_numbers(),
std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3),
+ sequence_numbers(),
mtu));
}
else
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 5b5a5d530..713d4c1b9 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -31,10 +31,15 @@ class Handshake_State;
class BOTAN_DLL Channel
{
public:
- Channel(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+ typedef std::function<void (const byte[], size_t)> output_fn;
+ typedef std::function<void (const byte[], size_t)> data_cb;
+ typedef std::function<void (Alert, const byte[], size_t)> alert_cb;
+ typedef std::function<bool (const Session&)> handshake_cb;
+
+ Channel(output_fn out,
+ data_cb app_data_cb,
+ alert_cb alert_cb,
+ handshake_cb hs_cb,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
bool is_datagram,
@@ -240,10 +245,10 @@ class BOTAN_DLL Channel
bool m_is_datagram;
/* callbacks */
- std::function<bool (const Session&)> m_handshake_cb;
- std::function<void (const byte[], size_t)> m_data_cb;
- std::function<void (Alert, const byte[], size_t)> m_alert_cb;
- std::function<void (const byte[], size_t)> m_output_fn;
+ handshake_cb m_handshake_cb;
+ data_cb m_data_cb;
+ alert_cb m_alert_cb;
+ output_fn m_output_fn;
/* external state */
RandomNumberGenerator& m_rng;
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index f68c9c614..75df6332a 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -22,10 +22,8 @@ class Client_Handshake_State : public Handshake_State
public:
// using Handshake_State::Handshake_State;
- Client_Handshake_State(Handshake_IO* io,
- std::function<void (const Handshake_Message&)> msg_callback =
- std::function<void (const Handshake_Message&)>()) :
- Handshake_State(io, msg_callback) {}
+ Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) :
+ Handshake_State(io, cb) {}
const Public_Key& get_server_public_Key() const
{
@@ -39,7 +37,7 @@ class Client_Handshake_State : public Handshake_State
std::unique_ptr<Public_Key> server_public_key;
// Used by client using NPN
- std::function<std::string (std::vector<std::string>)> client_npn_cb;
+ Client::next_protocol_fn client_npn_cb;
};
}
@@ -47,17 +45,17 @@ class Client_Handshake_State : public Handshake_State
/*
* TLS Client Constructor
*/
-Client::Client(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t)> proc_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+Client::Client(output_fn output_fn,
+ data_cb proc_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
const Server_Information& info,
const Protocol_Version offer_version,
- std::function<std::string (std::vector<std::string>)> next_protocol,
+ next_protocol_fn npn,
size_t io_buf_sz) :
Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng,
offer_version.is_datagram_protocol(), io_buf_sz),
@@ -68,12 +66,12 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname());
Handshake_State& state = create_handshake_state(offer_version);
- send_client_hello(state, false, offer_version, srp_identifier, next_protocol);
+ send_client_hello(state, false, offer_version, srp_identifier, npn);
}
Handshake_State* Client::new_handshake_state(Handshake_IO* io)
{
- return new Client_Handshake_State(io);
+ return new Client_Handshake_State(io); // , m_hs_msg_cb);
}
std::vector<X509_Certificate>
@@ -99,7 +97,7 @@ void Client::send_client_hello(Handshake_State& state_base,
bool force_full_renegotiation,
Protocol_Version version,
const std::string& srp_identifier,
- std::function<std::string (std::vector<std::string>)> next_protocol)
+ next_protocol_fn next_protocol)
{
Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base);
@@ -167,16 +165,19 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
if(state.client_hello())
return;
- if(!m_policy.allow_server_initiated_renegotiation() ||
- (!m_policy.allow_insecure_renegotiation() && !secure_renegotiation_supported()))
+ if(m_policy.allow_server_initiated_renegotiation())
+ {
+ if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false)
+ send_warning_alert(Alert::NO_RENEGOTIATION);
+ else
+ this->initiate_handshake(state, false);
+ }
+ else
{
// RFC 5746 section 4.2
send_warning_alert(Alert::NO_RENEGOTIATION);
- return;
}
- this->initiate_handshake(state, false);
-
return;
}
diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h
index 126dbb935..a548a32e0 100644
--- a/src/lib/tls/tls_client.h
+++ b/src/lib/tls/tls_client.h
@@ -25,9 +25,9 @@ class BOTAN_DLL Client : public Channel
/**
* Set up a new TLS client session
*
- * @param socket_output_fn is called with data for the outbound socket
+ * @param output_fn is called with data for the outbound socket
*
- * @param proc_cb is called when new application data is received
+ * @param app_data_cb is called when new application data is received
*
* @param alert_cb is called when a TLS alert is received
*
@@ -59,18 +59,20 @@ class BOTAN_DLL Client : public Channel
* be preallocated for the read and write buffers. Smaller
* values just mean reallocations and copies are more likely.
*/
- Client(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+
+ typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn;
+
+ Client(output_fn out,
+ data_cb app_data_cb,
+ alert_cb alert_cb,
+ handshake_cb hs_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
const Server_Information& server_info = Server_Information(),
const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
- std::function<std::string (std::vector<std::string>)> next_protocol =
- std::function<std::string (std::vector<std::string>)>(),
+ next_protocol_fn next_protocol = next_protocol_fn(),
size_t reserved_io_buffer_size = 16*1024
);
private:
@@ -84,8 +86,7 @@ class BOTAN_DLL Client : public Channel
bool force_full_renegotiation,
Protocol_Version version,
const std::string& srp_identifier = "",
- std::function<std::string (std::vector<std::string>)> next_protocol =
- std::function<std::string (std::vector<std::string>)>());
+ next_protocol_fn next_protocol = next_protocol_fn());
void process_handshake_msg(const Handshake_State* active_state,
Handshake_State& pending_state,
diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h
index 34873c3a6..00074a744 100644
--- a/src/lib/tls/tls_handshake_io.h
+++ b/src/lib/tls/tls_handshake_io.h
@@ -65,8 +65,9 @@ class Handshake_IO
class Stream_Handshake_IO : public Handshake_IO
{
public:
- Stream_Handshake_IO(std::function<void (byte, const std::vector<byte>&)> writer) :
- m_send_hs(writer) {}
+ typedef std::function<void (byte, const std::vector<byte>&)> writer_fn;
+
+ Stream_Handshake_IO(writer_fn writer) : m_send_hs(writer) {}
Protocol_Version initial_record_version() const override;
@@ -86,7 +87,7 @@ class Stream_Handshake_IO : public Handshake_IO
get_next_record(bool expecting_ccs) override;
private:
std::deque<byte> m_queue;
- std::function<void (byte, const std::vector<byte>&)> m_send_hs;
+ writer_fn m_send_hs;
};
/**
@@ -95,8 +96,10 @@ class Stream_Handshake_IO : public Handshake_IO
class Datagram_Handshake_IO : public Handshake_IO
{
public:
- Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq,
- std::function<void (u16bit, byte, const std::vector<byte>&)> writer,
+ typedef std::function<void (u16bit, byte, const std::vector<byte>&)> writer_fn;
+
+ Datagram_Handshake_IO(writer_fn writer,
+ class Connection_Sequence_Numbers& seq,
u16bit mtu) :
m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {}
@@ -186,7 +189,7 @@ class Datagram_Handshake_IO : public Handshake_IO
u16bit m_in_message_seq = 0;
u16bit m_out_message_seq = 0;
- std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs;
+ writer_fn m_send_hs;
u16bit m_mtu;
};
diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp
index 111087041..64d35fd6b 100644
--- a/src/lib/tls/tls_handshake_state.cpp
+++ b/src/lib/tls/tls_handshake_state.cpp
@@ -83,9 +83,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
/*
* Initialize the SSL/TLS Handshake State
*/
-Handshake_State::Handshake_State(Handshake_IO* io,
- std::function<void (const Handshake_Message&)> msg_callback) :
- m_msg_callback(msg_callback),
+Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) :
+ m_msg_callback(cb),
m_handshake_io(io),
m_version(m_handshake_io->initial_record_version())
{
diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h
index 3c8d856d7..bb2abc209 100644
--- a/src/lib/tls/tls_handshake_state.h
+++ b/src/lib/tls/tls_handshake_state.h
@@ -46,9 +46,9 @@ class Finished;
class Handshake_State
{
public:
- Handshake_State(Handshake_IO* io,
- std::function<void (const Handshake_Message&)> msg_callback =
- std::function<void (const Handshake_Message&)>());
+ typedef std::function<void (const Handshake_Message&)> hs_msg_cb;
+
+ Handshake_State(Handshake_IO* io, hs_msg_cb cb);
virtual ~Handshake_State();
@@ -176,7 +176,7 @@ class Handshake_State
private:
- std::function<void (const Handshake_Message&)> m_msg_callback;
+ hs_msg_cb m_msg_callback;
std::unique_ptr<Handshake_IO> m_handshake_io;
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 3edeab7e3..d5e3126f1 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -455,7 +455,7 @@ size_t read_tls_record(secure_vector<byte>& readbuf,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ get_cipherstate_fn get_cipherstate)
{
consumed = 0;
@@ -543,7 +543,7 @@ size_t read_dtls_record(secure_vector<byte>& readbuf,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ get_cipherstate_fn get_cipherstate)
{
consumed = 0;
@@ -642,7 +642,7 @@ size_t read_record(secure_vector<byte>& readbuf,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ get_cipherstate_fn get_cipherstate)
{
if(is_datagram)
return read_dtls_record(readbuf, input, input_sz, consumed,
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index c9bf8aade..46f87a9af 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -113,6 +113,9 @@ void write_record(secure_vector<byte>& write_buffer,
Connection_Cipher_State* cipherstate,
RandomNumberGenerator& rng);
+// epoch -> cipher state
+typedef std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate_fn;
+
/**
* Decode a TLS record
* @return zero if full message, else number of bytes still needed
@@ -127,7 +130,7 @@ size_t read_record(secure_vector<byte>& read_buffer,
Protocol_Version* record_version,
Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate);
+ get_cipherstate_fn get_cipherstate);
}
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index 1490fc2a4..515bd9e17 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -21,7 +21,8 @@ class Server_Handshake_State : public Handshake_State
public:
// using Handshake_State::Handshake_State;
- Server_Handshake_State(Handshake_IO* io) : Handshake_State(io) {}
+ Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) :
+ Handshake_State(io, cb) {}
// Used by the server only, in case of RSA key exchange. Not owned
Private_Key* server_rsa_kex_key = nullptr;
@@ -203,10 +204,10 @@ get_server_certs(const std::string& hostname,
/*
* TLS Server Constructor
*/
-Server::Server(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+Server::Server(output_fn output,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
@@ -214,7 +215,8 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
const std::vector<std::string>& next_protocols,
bool is_datagram,
size_t io_buf_sz) :
- Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz),
+ Channel(output, data_cb, alert_cb, handshake_cb,
+ session_manager, rng, is_datagram, io_buf_sz),
m_policy(policy),
m_creds(creds),
m_possible_protocols(next_protocols)
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index ce82e001d..4b15e837b 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -25,10 +25,10 @@ class BOTAN_DLL Server : public Channel
/**
* Server initialization
*/
- Server(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t)> data_cb,
- std::function<void (Alert, const byte[], size_t)> alert_cb,
- std::function<bool (const Session&)> handshake_cb,
+ Server(output_fn output,
+ data_cb data_cb,
+ alert_cb alert_cb,
+ handshake_cb handshake_cb,
Session_Manager& session_manager,
Credentials_Manager& creds,
const Policy& policy,
@@ -40,7 +40,9 @@ class BOTAN_DLL Server : public Channel
/**
* Return the protocol notification set by the client (using the
- * NPN extension) for this connection, if any
+ * NPN extension) for this connection, if any. This value is not
+ * tied to the session and a later renegotiation of the same
+ * session can choose a new protocol.
*/
std::string next_protocol() const { return m_next_protocol; }