diff options
author | lloyd <[email protected]> | 2015-01-27 14:10:37 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2015-01-27 14:10:37 +0000 |
commit | b8fa304ec981d273c45d7ef31705d65ccfb00cc1 (patch) | |
tree | 86a0c03ddcf3f6b331a73170167bbf1e429e3d79 | |
parent | 5ca89c642f19b747b965a22db87e7af2d13d0f35 (diff) |
Add typedefs for function signatures/types used in TLS for easier reading
-rw-r--r-- | doc/manual/tls.rst | 68 | ||||
-rw-r--r-- | src/cmd/tls_client.cpp | 28 | ||||
-rw-r--r-- | src/lib/alloc/secmem.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_blocking.cpp | 16 | ||||
-rw-r--r-- | src/lib/tls/tls_blocking.h | 38 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 10 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 21 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 37 | ||||
-rw-r--r-- | src/lib/tls/tls_client.h | 21 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.cpp | 5 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.h | 8 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 5 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 14 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 12 |
16 files changed, 171 insertions, 135 deletions
diff --git a/doc/manual/tls.rst b/doc/manual/tls.rst index b581c978c..5e1d48656 100644 --- a/doc/manual/tls.rst +++ b/doc/manual/tls.rst @@ -30,7 +30,7 @@ abstraction. This makes the library completely agnostic to how you write your network layer, be it blocking sockets, libevent, asio, a message queue, etc. -The callbacks that TLS calls have the signatures +The callbacks for TLS have the signatures .. cpp:function:: void output_fn(const byte data[], size_t data_len) @@ -81,6 +81,13 @@ available: .. cpp:class:: TLS::Channel + .. cpp:type:: std::function<void (const byte[], size_t)> output_fn + .. cpp:type:: std::function<void (const byte[], size_t)> data_cb + .. cpp:type:: std::function<void (Alert, const byte[], size_t)> alert_cb + .. cpp:type:: std::function<bool (const Session&)> handshake_cb + + Typedefs used in the code for the functions described above + .. cpp:function:: size_t received_data(const byte buf[], size_t buf_size) .. cpp:function:: size_t received_data(const std::vector<byte>& buf) @@ -185,18 +192,18 @@ TLS Clients .. cpp:class:: TLS::Client .. cpp:function:: TLS::Client( \ - std::function<void, const byte*, size_t> output_fn, \ - std::function<void, const byte*, size_t> data_cb, \ - std::function<TLS::Alert, const byte*, size_t> alert_cb, \ - std::function<bool, const TLS::Session&> handshake_cb, \ - TLS::Session_Manager& session_manager, \ - Credentials_Manager& credendials_manager, \ - const TLS::Policy& policy, \ - RandomNumberGenerator& rng, \ - const Server_Information& server_info, \ - const Protocol_Version offer_version, \ - std::function<std::string, std::vector<std::string> > next_protocol, \ - size_t reserved_io_buffer_size) + output_fn output, \ + data_cb data, \ + alert_cb alert, \ + handshake_cb handshake_complete, \ + TLS::Session_Manager& session_manager, \ + Credentials_Manager& credendials_manager, \ + const TLS::Policy& policy, \ + RandomNumberGenerator& rng, \ + const Server_Information& server_info, \ + const Protocol_Version offer_version, \ + next_protocol_fn npn, \ + size_t reserved_io_buffer_size) Initialize a new TLS client. The constructor will immediately initiate a new session. @@ -234,23 +241,20 @@ TLS Clients The *credentials_manager* is an interface that will be called to retrieve any certificates, secret keys, pre-shared keys, or SRP - intformation; see :doc:`credentials_manager` for more information. + information; see :doc:`credentials_manager` for more information. - Use *server_info* to specify the DNS name of the server you are - attempting to connect to, if you know it. This helps the server - select what certificate to use and helps the client validate the - connection. + Use the optional *server_info* to specify the DNS name of the + server you are attempting to connect to, if you know it. This helps + the server select what certificate to use and helps the client + validate the connection. - Use *offer_version* to control the version of TLS you wish the - client to offer. Normally, you'll want to offer the most recent - version of (D)TLS that is available, however some broken servers are - intolerant of certain versions being offered, and for classes of - applications that have to deal with such servers (typically web - browsers) it may be necessary to implement a version backdown - strategy if the initial attempt fails. - - Setting *offer_version* is also used to offer DTLS instead of TLS; - use :cpp:func:`TLS::Protocol_Version::latest_dtls_version`. + Use the optional *offer_version* to control the version of TLS you + wish the client to offer. Normally, you'll want to offer the most + recent version of (D)TLS that is available, however some broken + servers are intolerant of certain versions being offered, and for + classes of applications that have to deal with such servers + (typically web browsers) it may be necessary to implement a version + backdown strategy if the initial attempt fails. .. warning:: @@ -258,6 +262,9 @@ TLS Clients downgrade your connection to the weakest protocol that both you and the server support. + Setting *offer_version* is also used to offer DTLS instead of TLS; + use :cpp:func:`TLS::Protocol_Version::latest_dtls_version`. + The optional *next_protocol* callback is called if the server indicates it supports the next protocol notification extension. The callback wlil be called with a list of protocol names that the @@ -270,7 +277,7 @@ TLS Clients resized as needed to process inputs). Otherwise some reasonable default is used. -A TLS client example using BSD sockets is in `src/cmd/tls_client.cpp` +Code for a TLS client using BSD sockets is in `src/cmd/tls_client.cpp` TLS Servers ---------------------------------------- @@ -308,8 +315,7 @@ not until they actually receive a hello without this parameter. renegotiation, but might change across different connections using that session. -An example TLS server implementation using asio is available in -`src/cmd/tls_proxy.cpp`. +Code for a TLS server using asio is in `src/cmd/tls_proxy.cpp`. .. _tls_sessions: diff --git a/src/cmd/tls_client.cpp b/src/cmd/tls_client.cpp index 543119fb9..903824a78 100644 --- a/src/cmd/tls_client.cpp +++ b/src/cmd/tls_client.cpp @@ -39,7 +39,7 @@ using namespace std::placeholders; namespace { -int connect_to_host(const std::string& host, u16bit port, const std::string& transport) +int connect_to_host(const std::string& host, u16bit port, bool tcp) { hostent* host_addr = ::gethostbyname(host.c_str()); @@ -49,7 +49,7 @@ int connect_to_host(const std::string& host, u16bit port, const std::string& tra if(host_addr->h_addrtype != AF_INET) // FIXME throw std::runtime_error(host + " has IPv6 address, not supported"); - int type = (transport == "tcp") ? SOCK_STREAM : SOCK_DGRAM; + int type = tcp ? SOCK_STREAM : SOCK_DGRAM; int fd = ::socket(PF_INET, type, 0); if(fd == -1) @@ -130,13 +130,6 @@ void process_data(const byte buf[], size_t buf_size) std::cout << buf[i]; } -std::string protocol_chooser(const std::vector<std::string>& protocols) - { - for(size_t i = 0; i != protocols.size(); ++i) - std::cout << "Protocol " << i << " = " << protocols[i] << "\n"; - return "http/1.1"; - } - int tls_client(int argc, char* argv[]) { if(argc != 2 && argc != 3 && argc != 4) @@ -145,6 +138,9 @@ int tls_client(int argc, char* argv[]) return 1; } + const bool request_protocol = true; + const std::string use_protocol = "http/1.1"; + try { AutoSeeded_RNG rng; @@ -167,14 +163,22 @@ int tls_client(int argc, char* argv[]) u32bit port = argc >= 3 ? Botan::to_u32bit(argv[2]) : 443; const std::string transport = argc >= 4 ? argv[3] : "tcp"; - int sockfd = connect_to_host(host, port, transport); + const bool use_tcp = (transport == "tcp"); + + int sockfd = connect_to_host(host, port, use_tcp); + + auto protocol_chooser = [use_protocol](const std::vector<std::string>& protocols) -> std::string { + for(size_t i = 0; i != protocols.size(); ++i) + std::cout << "Server offered protocol " << i << " = " << protocols[i] << "\n"; + return use_protocol; + }; auto socket_write = - (transport == "tcp") ? + use_tcp ? std::bind(stream_socket_write, sockfd, _1, _2) : std::bind(dgram_socket_write, sockfd, _1, _2); - auto version = policy.latest_supported_version(transport != "tcp"); + auto version = policy.latest_supported_version(!use_tcp); TLS::Client client(socket_write, process_data, diff --git a/src/lib/alloc/secmem.h b/src/lib/alloc/secmem.h index 58d0734cb..82b4083ea 100644 --- a/src/lib/alloc/secmem.h +++ b/src/lib/alloc/secmem.h @@ -11,6 +11,7 @@ #include <botan/mem_ops.h> #include <algorithm> #include <vector> +#include <deque> #if defined(BOTAN_HAS_LOCKING_ALLOCATOR) #include <botan/locking_allocator.h> @@ -90,6 +91,7 @@ operator!=(const secure_allocator<T>&, const secure_allocator<T>&) { return false; } template<typename T> using secure_vector = std::vector<T, secure_allocator<T>>; +template<typename T> using secure_deque = std::deque<T, secure_allocator<T>>; template<typename T> std::vector<T> unlock(const secure_vector<T>& in) diff --git a/src/lib/tls/tls_blocking.cpp b/src/lib/tls/tls_blocking.cpp index dc5769e2c..b02c9ede1 100644 --- a/src/lib/tls/tls_blocking.cpp +++ b/src/lib/tls/tls_blocking.cpp @@ -13,17 +13,17 @@ namespace TLS { using namespace std::placeholders; -Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn, - std::function<void (const byte[], size_t)> write_fn, +Blocking_Client::Blocking_Client(read_fn reader, + write_fn writer, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, const Server_Information& server_info, const Protocol_Version offer_version, - std::function<std::string (std::vector<std::string>)> next_protocol) : - m_read_fn(read_fn), - m_channel(write_fn, + next_protocol_fn npn) : + m_read(reader), + m_channel(writer, std::bind(&Blocking_Client::data_cb, this, _1, _2), std::bind(&Blocking_Client::alert_cb, this, _1, _2, _3), std::bind(&Blocking_Client::handshake_cb, this, _1), @@ -33,7 +33,7 @@ Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn, rng, server_info, offer_version, - next_protocol) + npn) { } @@ -58,7 +58,7 @@ void Blocking_Client::do_handshake() while(!m_channel.is_closed() && !m_channel.is_active()) { - const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size()); + const size_t from_socket = m_read(&readbuf[0], readbuf.size()); m_channel.received_data(&readbuf[0], from_socket); } } @@ -69,7 +69,7 @@ size_t Blocking_Client::read(byte buf[], size_t buf_len) while(m_plaintext.empty() && !m_channel.is_closed()) { - const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size()); + const size_t from_socket = m_read(&readbuf[0], readbuf.size()); m_channel.received_data(&readbuf[0], from_socket); } diff --git a/src/lib/tls/tls_blocking.h b/src/lib/tls/tls_blocking.h index 1226d9364..ca6906545 100644 --- a/src/lib/tls/tls_blocking.h +++ b/src/lib/tls/tls_blocking.h @@ -14,27 +14,35 @@ namespace Botan { -template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>; +//template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>; namespace TLS { /** * Blocking TLS Client +* Can be used directly, or subclass to get handshake and alert notifications */ class BOTAN_DLL Blocking_Client { public: + /* + * These functions are expected to block until completing entirely, or + * fail by throwing an exception. + */ + typedef std::function<size_t (byte[], size_t)> read_fn; + typedef std::function<void (const byte[], size_t)> write_fn; + + typedef Client::next_protocol_fn next_protocol_fn; - Blocking_Client(std::function<size_t (byte[], size_t)> read_fn, - std::function<void (const byte[], size_t)> write_fn, - Session_Manager& session_manager, - Credentials_Manager& creds, - const Policy& policy, - RandomNumberGenerator& rng, - const Server_Information& server_info = Server_Information(), - const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), - std::function<std::string (std::vector<std::string>)> next_protocol = - std::function<std::string (std::vector<std::string>)>()); + Blocking_Client(read_fn reader, + write_fn writer, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& server_info = Server_Information(), + const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), + next_protocol_fn npn = next_protocol_fn()); /** * Completes full handshake then returns @@ -68,12 +76,12 @@ class BOTAN_DLL Blocking_Client protected: /** - * Can override to get the handshake complete notification + * Application can override to get the handshake complete notification */ virtual bool handshake_complete(const Session&) { return true; } /** - * Can override to get notification of alerts + * Application can override to get notification of alerts */ virtual void alert_notification(const Alert&) {} @@ -85,9 +93,9 @@ class BOTAN_DLL Blocking_Client void alert_cb(const Alert alert, const byte data[], size_t data_len); - std::function<size_t (byte[], size_t)> m_read_fn; + read_fn m_read; TLS::Client m_channel; - secure_deque<byte> m_plaintext; + secure_vector<byte> m_plaintext; }; } diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index a5e504f8b..e784566cd 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -19,10 +19,10 @@ namespace Botan { namespace TLS { -Channel::Channel(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, +Channel::Channel(output_fn output_fn, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, bool is_datagram, @@ -124,8 +124,8 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) const u16bit mtu = 1280 - 40 - 8; io.reset(new Datagram_Handshake_IO( - sequence_numbers(), std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3), + sequence_numbers(), mtu)); } else diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 5b5a5d530..713d4c1b9 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -31,10 +31,15 @@ class Handshake_State; class BOTAN_DLL Channel { public: - Channel(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, + typedef std::function<void (const byte[], size_t)> output_fn; + typedef std::function<void (const byte[], size_t)> data_cb; + typedef std::function<void (Alert, const byte[], size_t)> alert_cb; + typedef std::function<bool (const Session&)> handshake_cb; + + Channel(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, bool is_datagram, @@ -240,10 +245,10 @@ class BOTAN_DLL Channel bool m_is_datagram; /* callbacks */ - std::function<bool (const Session&)> m_handshake_cb; - std::function<void (const byte[], size_t)> m_data_cb; - std::function<void (Alert, const byte[], size_t)> m_alert_cb; - std::function<void (const byte[], size_t)> m_output_fn; + handshake_cb m_handshake_cb; + data_cb m_data_cb; + alert_cb m_alert_cb; + output_fn m_output_fn; /* external state */ RandomNumberGenerator& m_rng; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index f68c9c614..75df6332a 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -22,10 +22,8 @@ class Client_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Client_Handshake_State(Handshake_IO* io, - std::function<void (const Handshake_Message&)> msg_callback = - std::function<void (const Handshake_Message&)>()) : - Handshake_State(io, msg_callback) {} + Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : + Handshake_State(io, cb) {} const Public_Key& get_server_public_Key() const { @@ -39,7 +37,7 @@ class Client_Handshake_State : public Handshake_State std::unique_ptr<Public_Key> server_public_key; // Used by client using NPN - std::function<std::string (std::vector<std::string>)> client_npn_cb; + Client::next_protocol_fn client_npn_cb; }; } @@ -47,17 +45,17 @@ class Client_Handshake_State : public Handshake_State /* * TLS Client Constructor */ -Client::Client(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t)> proc_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, +Client::Client(output_fn output_fn, + data_cb proc_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, const Server_Information& info, const Protocol_Version offer_version, - std::function<std::string (std::vector<std::string>)> next_protocol, + next_protocol_fn npn, size_t io_buf_sz) : Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, offer_version.is_datagram_protocol(), io_buf_sz), @@ -68,12 +66,12 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); Handshake_State& state = create_handshake_state(offer_version); - send_client_hello(state, false, offer_version, srp_identifier, next_protocol); + send_client_hello(state, false, offer_version, srp_identifier, npn); } Handshake_State* Client::new_handshake_state(Handshake_IO* io) { - return new Client_Handshake_State(io); + return new Client_Handshake_State(io); // , m_hs_msg_cb); } std::vector<X509_Certificate> @@ -99,7 +97,7 @@ void Client::send_client_hello(Handshake_State& state_base, bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier, - std::function<std::string (std::vector<std::string>)> next_protocol) + next_protocol_fn next_protocol) { Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base); @@ -167,16 +165,19 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(state.client_hello()) return; - if(!m_policy.allow_server_initiated_renegotiation() || - (!m_policy.allow_insecure_renegotiation() && !secure_renegotiation_supported())) + if(m_policy.allow_server_initiated_renegotiation()) + { + if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false) + send_warning_alert(Alert::NO_RENEGOTIATION); + else + this->initiate_handshake(state, false); + } + else { // RFC 5746 section 4.2 send_warning_alert(Alert::NO_RENEGOTIATION); - return; } - this->initiate_handshake(state, false); - return; } diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index 126dbb935..a548a32e0 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -25,9 +25,9 @@ class BOTAN_DLL Client : public Channel /** * Set up a new TLS client session * - * @param socket_output_fn is called with data for the outbound socket + * @param output_fn is called with data for the outbound socket * - * @param proc_cb is called when new application data is received + * @param app_data_cb is called when new application data is received * * @param alert_cb is called when a TLS alert is received * @@ -59,18 +59,20 @@ class BOTAN_DLL Client : public Channel * be preallocated for the read and write buffers. Smaller * values just mean reallocations and copies are more likely. */ - Client(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, + + typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn; + + Client(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, const Server_Information& server_info = Server_Information(), const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), - std::function<std::string (std::vector<std::string>)> next_protocol = - std::function<std::string (std::vector<std::string>)>(), + next_protocol_fn next_protocol = next_protocol_fn(), size_t reserved_io_buffer_size = 16*1024 ); private: @@ -84,8 +86,7 @@ class BOTAN_DLL Client : public Channel bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier = "", - std::function<std::string (std::vector<std::string>)> next_protocol = - std::function<std::string (std::vector<std::string>)>()); + next_protocol_fn next_protocol = next_protocol_fn()); void process_handshake_msg(const Handshake_State* active_state, Handshake_State& pending_state, diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h index 34873c3a6..00074a744 100644 --- a/src/lib/tls/tls_handshake_io.h +++ b/src/lib/tls/tls_handshake_io.h @@ -65,8 +65,9 @@ class Handshake_IO class Stream_Handshake_IO : public Handshake_IO { public: - Stream_Handshake_IO(std::function<void (byte, const std::vector<byte>&)> writer) : - m_send_hs(writer) {} + typedef std::function<void (byte, const std::vector<byte>&)> writer_fn; + + Stream_Handshake_IO(writer_fn writer) : m_send_hs(writer) {} Protocol_Version initial_record_version() const override; @@ -86,7 +87,7 @@ class Stream_Handshake_IO : public Handshake_IO get_next_record(bool expecting_ccs) override; private: std::deque<byte> m_queue; - std::function<void (byte, const std::vector<byte>&)> m_send_hs; + writer_fn m_send_hs; }; /** @@ -95,8 +96,10 @@ class Stream_Handshake_IO : public Handshake_IO class Datagram_Handshake_IO : public Handshake_IO { public: - Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq, - std::function<void (u16bit, byte, const std::vector<byte>&)> writer, + typedef std::function<void (u16bit, byte, const std::vector<byte>&)> writer_fn; + + Datagram_Handshake_IO(writer_fn writer, + class Connection_Sequence_Numbers& seq, u16bit mtu) : m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {} @@ -186,7 +189,7 @@ class Datagram_Handshake_IO : public Handshake_IO u16bit m_in_message_seq = 0; u16bit m_out_message_seq = 0; - std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs; + writer_fn m_send_hs; u16bit m_mtu; }; diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp index 111087041..64d35fd6b 100644 --- a/src/lib/tls/tls_handshake_state.cpp +++ b/src/lib/tls/tls_handshake_state.cpp @@ -83,9 +83,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_IO* io, - std::function<void (const Handshake_Message&)> msg_callback) : - m_msg_callback(msg_callback), +Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) : + m_msg_callback(cb), m_handshake_io(io), m_version(m_handshake_io->initial_record_version()) { diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h index 3c8d856d7..bb2abc209 100644 --- a/src/lib/tls/tls_handshake_state.h +++ b/src/lib/tls/tls_handshake_state.h @@ -46,9 +46,9 @@ class Finished; class Handshake_State { public: - Handshake_State(Handshake_IO* io, - std::function<void (const Handshake_Message&)> msg_callback = - std::function<void (const Handshake_Message&)>()); + typedef std::function<void (const Handshake_Message&)> hs_msg_cb; + + Handshake_State(Handshake_IO* io, hs_msg_cb cb); virtual ~Handshake_State(); @@ -176,7 +176,7 @@ class Handshake_State private: - std::function<void (const Handshake_Message&)> m_msg_callback; + hs_msg_cb m_msg_callback; std::unique_ptr<Handshake_IO> m_handshake_io; diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 3edeab7e3..d5e3126f1 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -455,7 +455,7 @@ size_t read_tls_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + get_cipherstate_fn get_cipherstate) { consumed = 0; @@ -543,7 +543,7 @@ size_t read_dtls_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + get_cipherstate_fn get_cipherstate) { consumed = 0; @@ -642,7 +642,7 @@ size_t read_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + get_cipherstate_fn get_cipherstate) { if(is_datagram) return read_dtls_record(readbuf, input, input_sz, consumed, diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index c9bf8aade..46f87a9af 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -113,6 +113,9 @@ void write_record(secure_vector<byte>& write_buffer, Connection_Cipher_State* cipherstate, RandomNumberGenerator& rng); +// epoch -> cipher state +typedef std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate_fn; + /** * Decode a TLS record * @return zero if full message, else number of bytes still needed @@ -127,7 +130,7 @@ size_t read_record(secure_vector<byte>& read_buffer, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate); + get_cipherstate_fn get_cipherstate); } diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 1490fc2a4..515bd9e17 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -21,7 +21,8 @@ class Server_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Server_Handshake_State(Handshake_IO* io) : Handshake_State(io) {} + Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : + Handshake_State(io, cb) {} // Used by the server only, in case of RSA key exchange. Not owned Private_Key* server_rsa_kex_key = nullptr; @@ -203,10 +204,10 @@ get_server_certs(const std::string& hostname, /* * TLS Server Constructor */ -Server::Server(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, +Server::Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, @@ -214,7 +215,8 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, const std::vector<std::string>& next_protocols, bool is_datagram, size_t io_buf_sz) : - Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz), + Channel(output, data_cb, alert_cb, handshake_cb, + session_manager, rng, is_datagram, io_buf_sz), m_policy(policy), m_creds(creds), m_possible_protocols(next_protocols) diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index ce82e001d..4b15e837b 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -25,10 +25,10 @@ class BOTAN_DLL Server : public Channel /** * Server initialization */ - Server(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, + Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, @@ -40,7 +40,9 @@ class BOTAN_DLL Server : public Channel /** * Return the protocol notification set by the client (using the - * NPN extension) for this connection, if any + * NPN extension) for this connection, if any. This value is not + * tied to the session and a later renegotiation of the same + * session can choose a new protocol. */ std::string next_protocol() const { return m_next_protocol; } |