diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/info.txt | 1 | ||||
-rw-r--r-- | src/tls/msg_client_hello.cpp | 7 | ||||
-rw-r--r-- | src/tls/msg_server_hello.cpp | 4 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 9 | ||||
-rw-r--r-- | src/tls/tls_handshake_msg.h | 36 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 17 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 14 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 93 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 2 |
9 files changed, 117 insertions, 66 deletions
diff --git a/src/tls/info.txt b/src/tls/info.txt index 5294d3026..ffd0f5ad1 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -13,6 +13,7 @@ tls_channel.h tls_ciphersuite.h tls_client.h tls_exceptn.h +tls_handshake_msg.h tls_magic.h tls_policy.h tls_record.h diff --git a/src/tls/msg_client_hello.cpp b/src/tls/msg_client_hello.cpp index 52536e79c..30d34ef78 100644 --- a/src/tls/msg_client_hello.cpp +++ b/src/tls/msg_client_hello.cpp @@ -143,13 +143,6 @@ Client_Hello::Client_Hello(Handshake_IO& io, */ Client_Hello::Client_Hello(const std::vector<byte>& buf, Handshake_Type type) { - m_next_protocol = false; - m_secure_renegotiation = false; - m_supports_session_ticket = false; - m_supports_heartbeats = false; - m_peer_can_send_heartbeats = false; - m_fragment_size = 0; - if(type == CLIENT_HELLO) deserialize(buf); else diff --git a/src/tls/msg_server_hello.cpp b/src/tls/msg_server_hello.cpp index 8d151b2b0..941fee8e4 100644 --- a/src/tls/msg_server_hello.cpp +++ b/src/tls/msg_server_hello.cpp @@ -55,10 +55,6 @@ Server_Hello::Server_Hello(Handshake_IO& io, */ Server_Hello::Server_Hello(const std::vector<byte>& buf) { - m_secure_renegotiation = false; - m_supports_session_ticket = false; - m_next_protocol = false; - if(buf.size() < 38) throw Decoding_Error("Server_Hello: Packet corrupted"); diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 61156d71b..aed524dbe 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -20,7 +20,12 @@ namespace { class Client_Handshake_State : public Handshake_State { public: - Client_Handshake_State(Handshake_IO* io) : Handshake_State(io) {} + // using Handshake_State::Handshake_State; + + Client_Handshake_State(Handshake_IO* io, + std::function<void (const Handshake_Message&)> msg_callback = + std::function<void (const Handshake_Message&)>()) : + Handshake_State(io, msg_callback) {} // Used during session resumption secure_vector<byte> resume_master_secret; @@ -186,6 +191,8 @@ void Client::process_handshake_msg(Handshake_Type type, Hello_Verify_Request hello_verify_request(contents); + m_state->note_message(hello_verify_request); + std::unique_ptr<Client_Hello> client_hello_w_cookie( new Client_Hello(m_state->handshake_io(), m_state->hash(), diff --git a/src/tls/tls_handshake_msg.h b/src/tls/tls_handshake_msg.h new file mode 100644 index 000000000..1c44554d3 --- /dev/null +++ b/src/tls/tls_handshake_msg.h @@ -0,0 +1,36 @@ +/* +* TLS Handshake Message +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_HANDSHAKE_MSG_H__ +#define BOTAN_TLS_HANDSHAKE_MSG_H__ + +#include <botan/tls_magic.h> +#include <vector> +#include <string> + +namespace Botan { + +namespace TLS { + +/** +* TLS Handshake Message Base Class +*/ +class BOTAN_DLL Handshake_Message + { + public: + virtual Handshake_Type type() const = 0; + + virtual std::vector<byte> serialize() const = 0; + + virtual ~Handshake_Message() {} + }; + +} + +} + +#endif diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 4be0c58e7..082461dc9 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -85,8 +85,10 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_IO* io) : +Handshake_State::Handshake_State(Handshake_IO* io, + std::function<void (const Handshake_Message&)> msg_callback) : m_handshake_io(io), + m_msg_callback(msg_callback), m_version(m_handshake_io->initial_record_version()) { } @@ -96,67 +98,80 @@ Handshake_State::~Handshake_State() {} void Handshake_State::client_hello(Client_Hello* client_hello) { m_client_hello.reset(client_hello); + note_message(*m_client_hello); } void Handshake_State::server_hello(Server_Hello* server_hello) { m_server_hello.reset(server_hello); m_ciphersuite = Ciphersuite::by_id(m_server_hello->ciphersuite()); + note_message(*m_server_hello); } void Handshake_State::server_certs(Certificate* server_certs) { m_server_certs.reset(server_certs); + note_message(*m_server_certs); } void Handshake_State::server_kex(Server_Key_Exchange* server_kex) { m_server_kex.reset(server_kex); + note_message(*m_server_kex); } void Handshake_State::cert_req(Certificate_Req* cert_req) { m_cert_req.reset(cert_req); + note_message(*m_cert_req); } void Handshake_State::server_hello_done(Server_Hello_Done* server_hello_done) { m_server_hello_done.reset(server_hello_done); + note_message(*m_server_hello_done); } void Handshake_State::client_certs(Certificate* client_certs) { m_client_certs.reset(client_certs); + note_message(*m_client_certs); } void Handshake_State::client_kex(Client_Key_Exchange* client_kex) { m_client_kex.reset(client_kex); + note_message(*m_client_kex); } void Handshake_State::client_verify(Certificate_Verify* client_verify) { m_client_verify.reset(client_verify); + note_message(*m_client_verify); } void Handshake_State::next_protocol(Next_Protocol* next_protocol) { m_next_protocol.reset(next_protocol); + note_message(*m_next_protocol); } void Handshake_State::new_session_ticket(New_Session_Ticket* new_session_ticket) { m_new_session_ticket.reset(new_session_ticket); + note_message(*m_new_session_ticket); } void Handshake_State::server_finished(Finished* server_finished) { m_server_finished.reset(server_finished); + note_message(*m_server_finished); } void Handshake_State::client_finished(Finished* client_finished) { m_client_finished.reset(client_finished); + note_message(*m_client_finished); } void Handshake_State::set_version(const Protocol_Version& version) diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index d0a03e2d9..81a603c6f 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -11,6 +11,7 @@ #include <botan/internal/tls_handshake_hash.h> #include <botan/internal/tls_handshake_io.h> #include <botan/internal/tls_session_key.h> +#include <botan/tls_handshake_msg.h> #include <botan/pk_keys.h> #include <botan/pubkey.h> #include <functional> @@ -44,7 +45,9 @@ class Finished; class Handshake_State { public: - Handshake_State(Handshake_IO* io); + Handshake_State(Handshake_IO* io, + std::function<void (const Handshake_Message&)> msg_callback = + std::function<void (const Handshake_Message&)>()); virtual ~Handshake_State(); @@ -146,9 +149,18 @@ class Handshake_State const Handshake_Hash& hash() const { return m_handshake_hash; } + void note_message(const Handshake_Message& msg) + { + if(m_msg_callback) + m_msg_callback(msg); + } + private: + std::unique_ptr<Handshake_IO> m_handshake_io; + std::function<void (const Handshake_Message&)> m_msg_callback; + u32bit m_hand_expecting_mask = 0; u32bit m_hand_received_mask = 0; Protocol_Version m_version; diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 29f75c58e..f162a8cce 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -9,15 +9,16 @@ #define BOTAN_TLS_MESSAGES_H__ #include <botan/internal/tls_handshake_state.h> +#include <botan/tls_handshake_msg.h> #include <botan/tls_session.h> #include <botan/tls_policy.h> -#include <botan/tls_magic.h> #include <botan/tls_ciphersuite.h> #include <botan/bigint.h> #include <botan/pkcs8.h> #include <botan/x509cert.h> #include <vector> #include <memory> +#include <string> namespace Botan { @@ -28,19 +29,6 @@ namespace TLS { class Handshake_IO; -/** -* TLS Handshake Message Base Class -*/ -class Handshake_Message - { - public: - virtual std::vector<byte> serialize() const = 0; - virtual Handshake_Type type() const = 0; - - Handshake_Message() {} - virtual ~Handshake_Message() {} - }; - std::vector<byte> make_hello_random(RandomNumberGenerator& rng); /** @@ -49,8 +37,8 @@ std::vector<byte> make_hello_random(RandomNumberGenerator& rng); class Hello_Verify_Request : public Handshake_Message { public: - std::vector<byte> serialize() const; - Handshake_Type type() const { return HELLO_VERIFY_REQUEST; } + std::vector<byte> serialize() const override; + Handshake_Type type() const override { return HELLO_VERIFY_REQUEST; } std::vector<byte> cookie() const { return m_cookie; } @@ -69,7 +57,7 @@ class Hello_Verify_Request : public Handshake_Message class Client_Hello : public Handshake_Message { public: - Handshake_Type type() const { return CLIENT_HELLO; } + Handshake_Type type() const override { return CLIENT_HELLO; } Protocol_Version version() const { return m_version; } @@ -137,7 +125,7 @@ class Client_Hello : public Handshake_Message Handshake_Type type); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; void deserialize(const std::vector<byte>& buf); void deserialize_sslv2(const std::vector<byte>& buf); @@ -147,22 +135,22 @@ class Client_Hello : public Handshake_Message std::vector<byte> m_comp_methods; std::string m_hostname; std::string m_srp_identifier; - bool m_next_protocol; + bool m_next_protocol = false; - size_t m_fragment_size; - bool m_secure_renegotiation; + size_t m_fragment_size = 0; + bool m_secure_renegotiation = false; std::vector<byte> m_renegotiation_info; std::vector<std::pair<std::string, std::string> > m_supported_algos; std::vector<std::string> m_supported_curves; - bool m_supports_session_ticket; + bool m_supports_session_ticket = false; std::vector<byte> m_session_ticket; std::vector<byte> m_hello_cookie; - bool m_supports_heartbeats; - bool m_peer_can_send_heartbeats; + bool m_supports_heartbeats = false; + bool m_peer_can_send_heartbeats = false; }; /** @@ -171,7 +159,7 @@ class Client_Hello : public Handshake_Message class Server_Hello : public Handshake_Message { public: - Handshake_Type type() const { return SERVER_HELLO; } + Handshake_Type type() const override { return SERVER_HELLO; } Protocol_Version version() const { return m_version; } @@ -218,23 +206,23 @@ class Server_Hello : public Handshake_Message Server_Hello(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; Protocol_Version m_version; std::vector<byte> m_session_id, m_random; u16bit m_ciphersuite; byte m_comp_method; - size_t m_fragment_size; - bool m_secure_renegotiation; + size_t m_fragment_size = 0; + bool m_secure_renegotiation = false; std::vector<byte> m_renegotiation_info; - bool m_next_protocol; + bool m_next_protocol = false; std::vector<std::string> m_next_protocols; - bool m_supports_session_ticket; + bool m_supports_session_ticket = false; - bool m_supports_heartbeats; - bool m_peer_can_send_heartbeats; + bool m_supports_heartbeats = false; + bool m_peer_can_send_heartbeats = false; }; /** @@ -243,7 +231,7 @@ class Server_Hello : public Handshake_Message class Client_Key_Exchange : public Handshake_Message { public: - Handshake_Type type() const { return CLIENT_KEX; } + Handshake_Type type() const override { return CLIENT_KEX; } const secure_vector<byte>& pre_master_secret() const { return m_pre_master; } @@ -264,7 +252,8 @@ class Client_Key_Exchange : public Handshake_Message RandomNumberGenerator& rng); private: - std::vector<byte> serialize() const { return m_key_material; } + std::vector<byte> serialize() const override + { return m_key_material; } std::vector<byte> m_key_material; secure_vector<byte> m_pre_master; @@ -276,7 +265,7 @@ class Client_Key_Exchange : public Handshake_Message class Certificate : public Handshake_Message { public: - Handshake_Type type() const { return CERTIFICATE; } + Handshake_Type type() const override { return CERTIFICATE; } const std::vector<X509_Certificate>& cert_chain() const { return m_certs; } size_t count() const { return m_certs.size(); } @@ -288,7 +277,7 @@ class Certificate : public Handshake_Message Certificate(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; std::vector<X509_Certificate> m_certs; }; @@ -299,7 +288,7 @@ class Certificate : public Handshake_Message class Certificate_Req : public Handshake_Message { public: - Handshake_Type type() const { return CERTIFICATE_REQUEST; } + Handshake_Type type() const override { return CERTIFICATE_REQUEST; } const std::vector<std::string>& acceptable_cert_types() const { return m_cert_key_types; } @@ -318,7 +307,7 @@ class Certificate_Req : public Handshake_Message Certificate_Req(const std::vector<byte>& buf, Protocol_Version version); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; std::vector<X509_DN> m_names; std::vector<std::string> m_cert_key_types; @@ -332,7 +321,7 @@ class Certificate_Req : public Handshake_Message class Certificate_Verify : public Handshake_Message { public: - Handshake_Type type() const { return CERTIFICATE_VERIFY; } + Handshake_Type type() const override { return CERTIFICATE_VERIFY; } /** * Check the signature on a certificate verify message @@ -351,7 +340,7 @@ class Certificate_Verify : public Handshake_Message Certificate_Verify(const std::vector<byte>& buf, Protocol_Version version); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; std::string m_sig_algo; // sig algo used to create signature std::string m_hash_algo; // hash used to create signature @@ -364,7 +353,7 @@ class Certificate_Verify : public Handshake_Message class Finished : public Handshake_Message { public: - Handshake_Type type() const { return FINISHED; } + Handshake_Type type() const override { return FINISHED; } std::vector<byte> verify_data() const { return m_verification_data; } @@ -378,7 +367,7 @@ class Finished : public Handshake_Message Finished(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; Connection_Side m_side; std::vector<byte> m_verification_data; @@ -390,12 +379,12 @@ class Finished : public Handshake_Message class Hello_Request : public Handshake_Message { public: - Handshake_Type type() const { return HELLO_REQUEST; } + Handshake_Type type() const override { return HELLO_REQUEST; } Hello_Request(Handshake_IO& io); Hello_Request(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; }; /** @@ -404,7 +393,7 @@ class Hello_Request : public Handshake_Message class Server_Key_Exchange : public Handshake_Message { public: - Handshake_Type type() const { return SERVER_KEX; } + Handshake_Type type() const override { return SERVER_KEX; } const std::vector<byte>& params() const { return m_params; } @@ -431,7 +420,7 @@ class Server_Key_Exchange : public Handshake_Message ~Server_Key_Exchange(); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; std::unique_ptr<Private_Key> m_kex_key; std::unique_ptr<SRP6_Server_Session> m_srp_params; @@ -449,12 +438,12 @@ class Server_Key_Exchange : public Handshake_Message class Server_Hello_Done : public Handshake_Message { public: - Handshake_Type type() const { return SERVER_HELLO_DONE; } + Handshake_Type type() const override { return SERVER_HELLO_DONE; } Server_Hello_Done(Handshake_IO& io, Handshake_Hash& hash); Server_Hello_Done(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; }; /** @@ -463,7 +452,7 @@ class Server_Hello_Done : public Handshake_Message class Next_Protocol : public Handshake_Message { public: - Handshake_Type type() const { return NEXT_PROTOCOL; } + Handshake_Type type() const override { return NEXT_PROTOCOL; } std::string protocol() const { return m_protocol; } @@ -473,7 +462,7 @@ class Next_Protocol : public Handshake_Message Next_Protocol(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; std::string m_protocol; }; @@ -481,7 +470,7 @@ class Next_Protocol : public Handshake_Message class New_Session_Ticket : public Handshake_Message { public: - Handshake_Type type() const { return NEW_SESSION_TICKET; } + Handshake_Type type() const override { return NEW_SESSION_TICKET; } u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; } const std::vector<byte>& ticket() const { return m_ticket; } @@ -496,7 +485,7 @@ class New_Session_Ticket : public Handshake_Message New_Session_Ticket(const std::vector<byte>& buf); private: - std::vector<byte> serialize() const; + std::vector<byte> serialize() const override; u32bit m_ticket_lifetime_hint; std::vector<byte> m_ticket; diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 61a7642df..aabafaaaa 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -21,6 +21,8 @@ namespace { class Server_Handshake_State : public Handshake_State { public: + // using Handshake_State::Handshake_State; + Server_Handshake_State(Handshake_IO* io) : Handshake_State(io) {} // Used by the server only, in case of RSA key exchange. Not owned |