diff options
Diffstat (limited to 'src/lib/tls')
-rw-r--r-- | src/lib/tls/tls_blocking.cpp | 16 | ||||
-rw-r--r-- | src/lib/tls/tls_blocking.h | 38 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 10 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 21 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 37 | ||||
-rw-r--r-- | src/lib/tls/tls_client.h | 21 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.cpp | 5 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.h | 8 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 5 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 14 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 12 |
13 files changed, 116 insertions, 92 deletions
diff --git a/src/lib/tls/tls_blocking.cpp b/src/lib/tls/tls_blocking.cpp index dc5769e2c..b02c9ede1 100644 --- a/src/lib/tls/tls_blocking.cpp +++ b/src/lib/tls/tls_blocking.cpp @@ -13,17 +13,17 @@ namespace TLS { using namespace std::placeholders; -Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn, - std::function<void (const byte[], size_t)> write_fn, +Blocking_Client::Blocking_Client(read_fn reader, + write_fn writer, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, const Server_Information& server_info, const Protocol_Version offer_version, - std::function<std::string (std::vector<std::string>)> next_protocol) : - m_read_fn(read_fn), - m_channel(write_fn, + next_protocol_fn npn) : + m_read(reader), + m_channel(writer, std::bind(&Blocking_Client::data_cb, this, _1, _2), std::bind(&Blocking_Client::alert_cb, this, _1, _2, _3), std::bind(&Blocking_Client::handshake_cb, this, _1), @@ -33,7 +33,7 @@ Blocking_Client::Blocking_Client(std::function<size_t (byte[], size_t)> read_fn, rng, server_info, offer_version, - next_protocol) + npn) { } @@ -58,7 +58,7 @@ void Blocking_Client::do_handshake() while(!m_channel.is_closed() && !m_channel.is_active()) { - const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size()); + const size_t from_socket = m_read(&readbuf[0], readbuf.size()); m_channel.received_data(&readbuf[0], from_socket); } } @@ -69,7 +69,7 @@ size_t Blocking_Client::read(byte buf[], size_t buf_len) while(m_plaintext.empty() && !m_channel.is_closed()) { - const size_t from_socket = m_read_fn(&readbuf[0], readbuf.size()); + const size_t from_socket = m_read(&readbuf[0], readbuf.size()); m_channel.received_data(&readbuf[0], from_socket); } diff --git a/src/lib/tls/tls_blocking.h b/src/lib/tls/tls_blocking.h index 1226d9364..ca6906545 100644 --- a/src/lib/tls/tls_blocking.h +++ b/src/lib/tls/tls_blocking.h @@ -14,27 +14,35 @@ namespace Botan { -template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>; +//template<typename T> using secure_deque = std::vector<T, secure_allocator<T>>; namespace TLS { /** * Blocking TLS Client +* Can be used directly, or subclass to get handshake and alert notifications */ class BOTAN_DLL Blocking_Client { public: + /* + * These functions are expected to block until completing entirely, or + * fail by throwing an exception. + */ + typedef std::function<size_t (byte[], size_t)> read_fn; + typedef std::function<void (const byte[], size_t)> write_fn; + + typedef Client::next_protocol_fn next_protocol_fn; - Blocking_Client(std::function<size_t (byte[], size_t)> read_fn, - std::function<void (const byte[], size_t)> write_fn, - Session_Manager& session_manager, - Credentials_Manager& creds, - const Policy& policy, - RandomNumberGenerator& rng, - const Server_Information& server_info = Server_Information(), - const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), - std::function<std::string (std::vector<std::string>)> next_protocol = - std::function<std::string (std::vector<std::string>)>()); + Blocking_Client(read_fn reader, + write_fn writer, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& server_info = Server_Information(), + const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), + next_protocol_fn npn = next_protocol_fn()); /** * Completes full handshake then returns @@ -68,12 +76,12 @@ class BOTAN_DLL Blocking_Client protected: /** - * Can override to get the handshake complete notification + * Application can override to get the handshake complete notification */ virtual bool handshake_complete(const Session&) { return true; } /** - * Can override to get notification of alerts + * Application can override to get notification of alerts */ virtual void alert_notification(const Alert&) {} @@ -85,9 +93,9 @@ class BOTAN_DLL Blocking_Client void alert_cb(const Alert alert, const byte data[], size_t data_len); - std::function<size_t (byte[], size_t)> m_read_fn; + read_fn m_read; TLS::Client m_channel; - secure_deque<byte> m_plaintext; + secure_vector<byte> m_plaintext; }; } diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index a5e504f8b..e784566cd 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -19,10 +19,10 @@ namespace Botan { namespace TLS { -Channel::Channel(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, +Channel::Channel(output_fn output_fn, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, bool is_datagram, @@ -124,8 +124,8 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) const u16bit mtu = 1280 - 40 - 8; io.reset(new Datagram_Handshake_IO( - sequence_numbers(), std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3), + sequence_numbers(), mtu)); } else diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 5b5a5d530..713d4c1b9 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -31,10 +31,15 @@ class Handshake_State; class BOTAN_DLL Channel { public: - Channel(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, + typedef std::function<void (const byte[], size_t)> output_fn; + typedef std::function<void (const byte[], size_t)> data_cb; + typedef std::function<void (Alert, const byte[], size_t)> alert_cb; + typedef std::function<bool (const Session&)> handshake_cb; + + Channel(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, bool is_datagram, @@ -240,10 +245,10 @@ class BOTAN_DLL Channel bool m_is_datagram; /* callbacks */ - std::function<bool (const Session&)> m_handshake_cb; - std::function<void (const byte[], size_t)> m_data_cb; - std::function<void (Alert, const byte[], size_t)> m_alert_cb; - std::function<void (const byte[], size_t)> m_output_fn; + handshake_cb m_handshake_cb; + data_cb m_data_cb; + alert_cb m_alert_cb; + output_fn m_output_fn; /* external state */ RandomNumberGenerator& m_rng; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index f68c9c614..75df6332a 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -22,10 +22,8 @@ class Client_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Client_Handshake_State(Handshake_IO* io, - std::function<void (const Handshake_Message&)> msg_callback = - std::function<void (const Handshake_Message&)>()) : - Handshake_State(io, msg_callback) {} + Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : + Handshake_State(io, cb) {} const Public_Key& get_server_public_Key() const { @@ -39,7 +37,7 @@ class Client_Handshake_State : public Handshake_State std::unique_ptr<Public_Key> server_public_key; // Used by client using NPN - std::function<std::string (std::vector<std::string>)> client_npn_cb; + Client::next_protocol_fn client_npn_cb; }; } @@ -47,17 +45,17 @@ class Client_Handshake_State : public Handshake_State /* * TLS Client Constructor */ -Client::Client(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t)> proc_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, +Client::Client(output_fn output_fn, + data_cb proc_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, const Server_Information& info, const Protocol_Version offer_version, - std::function<std::string (std::vector<std::string>)> next_protocol, + next_protocol_fn npn, size_t io_buf_sz) : Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, offer_version.is_datagram_protocol(), io_buf_sz), @@ -68,12 +66,12 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); Handshake_State& state = create_handshake_state(offer_version); - send_client_hello(state, false, offer_version, srp_identifier, next_protocol); + send_client_hello(state, false, offer_version, srp_identifier, npn); } Handshake_State* Client::new_handshake_state(Handshake_IO* io) { - return new Client_Handshake_State(io); + return new Client_Handshake_State(io); // , m_hs_msg_cb); } std::vector<X509_Certificate> @@ -99,7 +97,7 @@ void Client::send_client_hello(Handshake_State& state_base, bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier, - std::function<std::string (std::vector<std::string>)> next_protocol) + next_protocol_fn next_protocol) { Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base); @@ -167,16 +165,19 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(state.client_hello()) return; - if(!m_policy.allow_server_initiated_renegotiation() || - (!m_policy.allow_insecure_renegotiation() && !secure_renegotiation_supported())) + if(m_policy.allow_server_initiated_renegotiation()) + { + if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false) + send_warning_alert(Alert::NO_RENEGOTIATION); + else + this->initiate_handshake(state, false); + } + else { // RFC 5746 section 4.2 send_warning_alert(Alert::NO_RENEGOTIATION); - return; } - this->initiate_handshake(state, false); - return; } diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index 126dbb935..a548a32e0 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -25,9 +25,9 @@ class BOTAN_DLL Client : public Channel /** * Set up a new TLS client session * - * @param socket_output_fn is called with data for the outbound socket + * @param output_fn is called with data for the outbound socket * - * @param proc_cb is called when new application data is received + * @param app_data_cb is called when new application data is received * * @param alert_cb is called when a TLS alert is received * @@ -59,18 +59,20 @@ class BOTAN_DLL Client : public Channel * be preallocated for the read and write buffers. Smaller * values just mean reallocations and copies are more likely. */ - Client(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, + + typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn; + + Client(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, const Server_Information& server_info = Server_Information(), const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), - std::function<std::string (std::vector<std::string>)> next_protocol = - std::function<std::string (std::vector<std::string>)>(), + next_protocol_fn next_protocol = next_protocol_fn(), size_t reserved_io_buffer_size = 16*1024 ); private: @@ -84,8 +86,7 @@ class BOTAN_DLL Client : public Channel bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier = "", - std::function<std::string (std::vector<std::string>)> next_protocol = - std::function<std::string (std::vector<std::string>)>()); + next_protocol_fn next_protocol = next_protocol_fn()); void process_handshake_msg(const Handshake_State* active_state, Handshake_State& pending_state, diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h index 34873c3a6..00074a744 100644 --- a/src/lib/tls/tls_handshake_io.h +++ b/src/lib/tls/tls_handshake_io.h @@ -65,8 +65,9 @@ class Handshake_IO class Stream_Handshake_IO : public Handshake_IO { public: - Stream_Handshake_IO(std::function<void (byte, const std::vector<byte>&)> writer) : - m_send_hs(writer) {} + typedef std::function<void (byte, const std::vector<byte>&)> writer_fn; + + Stream_Handshake_IO(writer_fn writer) : m_send_hs(writer) {} Protocol_Version initial_record_version() const override; @@ -86,7 +87,7 @@ class Stream_Handshake_IO : public Handshake_IO get_next_record(bool expecting_ccs) override; private: std::deque<byte> m_queue; - std::function<void (byte, const std::vector<byte>&)> m_send_hs; + writer_fn m_send_hs; }; /** @@ -95,8 +96,10 @@ class Stream_Handshake_IO : public Handshake_IO class Datagram_Handshake_IO : public Handshake_IO { public: - Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq, - std::function<void (u16bit, byte, const std::vector<byte>&)> writer, + typedef std::function<void (u16bit, byte, const std::vector<byte>&)> writer_fn; + + Datagram_Handshake_IO(writer_fn writer, + class Connection_Sequence_Numbers& seq, u16bit mtu) : m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {} @@ -186,7 +189,7 @@ class Datagram_Handshake_IO : public Handshake_IO u16bit m_in_message_seq = 0; u16bit m_out_message_seq = 0; - std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs; + writer_fn m_send_hs; u16bit m_mtu; }; diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp index 111087041..64d35fd6b 100644 --- a/src/lib/tls/tls_handshake_state.cpp +++ b/src/lib/tls/tls_handshake_state.cpp @@ -83,9 +83,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_IO* io, - std::function<void (const Handshake_Message&)> msg_callback) : - m_msg_callback(msg_callback), +Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) : + m_msg_callback(cb), m_handshake_io(io), m_version(m_handshake_io->initial_record_version()) { diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h index 3c8d856d7..bb2abc209 100644 --- a/src/lib/tls/tls_handshake_state.h +++ b/src/lib/tls/tls_handshake_state.h @@ -46,9 +46,9 @@ class Finished; class Handshake_State { public: - Handshake_State(Handshake_IO* io, - std::function<void (const Handshake_Message&)> msg_callback = - std::function<void (const Handshake_Message&)>()); + typedef std::function<void (const Handshake_Message&)> hs_msg_cb; + + Handshake_State(Handshake_IO* io, hs_msg_cb cb); virtual ~Handshake_State(); @@ -176,7 +176,7 @@ class Handshake_State private: - std::function<void (const Handshake_Message&)> m_msg_callback; + hs_msg_cb m_msg_callback; std::unique_ptr<Handshake_IO> m_handshake_io; diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 3edeab7e3..d5e3126f1 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -455,7 +455,7 @@ size_t read_tls_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + get_cipherstate_fn get_cipherstate) { consumed = 0; @@ -543,7 +543,7 @@ size_t read_dtls_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + get_cipherstate_fn get_cipherstate) { consumed = 0; @@ -642,7 +642,7 @@ size_t read_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + get_cipherstate_fn get_cipherstate) { if(is_datagram) return read_dtls_record(readbuf, input, input_sz, consumed, diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index c9bf8aade..46f87a9af 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -113,6 +113,9 @@ void write_record(secure_vector<byte>& write_buffer, Connection_Cipher_State* cipherstate, RandomNumberGenerator& rng); +// epoch -> cipher state +typedef std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate_fn; + /** * Decode a TLS record * @return zero if full message, else number of bytes still needed @@ -127,7 +130,7 @@ size_t read_record(secure_vector<byte>& read_buffer, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate); + get_cipherstate_fn get_cipherstate); } diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 1490fc2a4..515bd9e17 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -21,7 +21,8 @@ class Server_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Server_Handshake_State(Handshake_IO* io) : Handshake_State(io) {} + Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : + Handshake_State(io, cb) {} // Used by the server only, in case of RSA key exchange. Not owned Private_Key* server_rsa_kex_key = nullptr; @@ -203,10 +204,10 @@ get_server_certs(const std::string& hostname, /* * TLS Server Constructor */ -Server::Server(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, +Server::Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, @@ -214,7 +215,8 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, const std::vector<std::string>& next_protocols, bool is_datagram, size_t io_buf_sz) : - Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz), + Channel(output, data_cb, alert_cb, handshake_cb, + session_manager, rng, is_datagram, io_buf_sz), m_policy(policy), m_creds(creds), m_possible_protocols(next_protocols) diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index ce82e001d..4b15e837b 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -25,10 +25,10 @@ class BOTAN_DLL Server : public Channel /** * Server initialization */ - Server(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t)> data_cb, - std::function<void (Alert, const byte[], size_t)> alert_cb, - std::function<bool (const Session&)> handshake_cb, + Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, Session_Manager& session_manager, Credentials_Manager& creds, const Policy& policy, @@ -40,7 +40,9 @@ class BOTAN_DLL Server : public Channel /** * Return the protocol notification set by the client (using the - * NPN extension) for this connection, if any + * NPN extension) for this connection, if any. This value is not + * tied to the session and a later renegotiation of the same + * session can choose a new protocol. */ std::string next_protocol() const { return m_next_protocol; } |