diff options
Diffstat (limited to 'src/tls/tls_messages.h')
-rw-r--r-- | src/tls/tls_messages.h | 227 |
1 files changed, 126 insertions, 101 deletions
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 94e17cb9b..617b03813 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -8,11 +8,11 @@ #ifndef BOTAN_TLS_MESSAGES_H__ #define BOTAN_TLS_MESSAGES_H__ -#include <botan/internal/tls_handshake_hash.h> +#include <botan/internal/tls_handshake_state.h> #include <botan/tls_session.h> #include <botan/tls_policy.h> #include <botan/tls_magic.h> -#include <botan/tls_suites.h> +#include <botan/tls_ciphersuite.h> #include <botan/bigint.h> #include <botan/pkcs8.h> #include <botan/x509cert.h> @@ -20,6 +20,10 @@ namespace Botan { +class Credentials_Manager; + +namespace TLS { + class Record_Writer; class Record_Reader; @@ -29,27 +33,46 @@ class Record_Reader; class Handshake_Message { public: - void send(Record_Writer& writer, TLS_Handshake_Hash& hash) const; - + virtual MemoryVector<byte> serialize() const = 0; virtual Handshake_Type type() const = 0; + Handshake_Message() {} virtual ~Handshake_Message() {} private: + Handshake_Message(const Handshake_Message&) {} Handshake_Message& operator=(const Handshake_Message&) { return (*this); } - virtual MemoryVector<byte> serialize() const = 0; - virtual void deserialize(const MemoryRegion<byte>&) = 0; }; MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng); /** +* DTLS Hello Verify Request +*/ +class Hello_Verify_Request : public Handshake_Message + { + public: + MemoryVector<byte> serialize() const; + Handshake_Type type() const { return HELLO_VERIFY_REQUEST; } + + MemoryVector<byte> cookie() const { return m_cookie; } + + Hello_Verify_Request(const MemoryRegion<byte>& buf); + + Hello_Verify_Request(const MemoryVector<byte>& client_hello_bits, + const std::string& client_identity, + const SymmetricKey& secret_key); + private: + MemoryVector<byte> m_cookie; + }; + +/** * Client Hello Message */ class Client_Hello : public Handshake_Message { public: Handshake_Type type() const { return CLIENT_HELLO; } - Version_Code version() const { return m_version; } + Protocol_Version version() const { return m_version; } const MemoryVector<byte>& session_id() const { return m_session_id; } std::vector<byte> session_id_vector() const @@ -59,6 +82,12 @@ class Client_Hello : public Handshake_Message return v; } + const std::vector<std::pair<std::string, std::string> >& supported_algos() const + { return m_supported_algos; } + + const std::vector<std::string>& supported_ecc_curves() const + { return m_supported_curves; } + std::vector<u16bit> ciphersuites() const { return m_suites; } std::vector<byte> compression_methods() const { return m_comp_methods; } @@ -85,8 +114,8 @@ class Client_Hello : public Handshake_Message { return m_session_ticket; } Client_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const TLS_Policy& policy, + Handshake_Hash& hash, + const Policy& policy, RandomNumberGenerator& rng, const MemoryRegion<byte>& reneg_info, bool next_protocol = false, @@ -94,26 +123,20 @@ class Client_Hello : public Handshake_Message const std::string& srp_identifier = ""); Client_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, RandomNumberGenerator& rng, - const TLS_Session& resumed_session, + const Session& resumed_session, bool next_protocol = false); Client_Hello(const MemoryRegion<byte>& buf, - Handshake_Type type) - { - if(type == CLIENT_HELLO) - deserialize(buf); - else - deserialize_sslv2(buf); - } + Handshake_Type type); private: MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>& buf); void deserialize_sslv2(const MemoryRegion<byte>& buf); - Version_Code m_version; + Protocol_Version m_version; MemoryVector<byte> m_session_id, m_random; std::vector<u16bit> m_suites; std::vector<byte> m_comp_methods; @@ -125,6 +148,9 @@ class Client_Hello : public Handshake_Message bool m_secure_renegotiation; MemoryVector<byte> m_renegotiation_info; + std::vector<std::pair<std::string, std::string> > m_supported_algos; + std::vector<std::string> m_supported_curves; + bool m_supports_session_ticket; MemoryVector<byte> m_session_ticket; }; @@ -136,7 +162,7 @@ class Server_Hello : public Handshake_Message { public: Handshake_Type type() const { return SERVER_HELLO; } - Version_Code version() { return s_version; } + Protocol_Version version() { return s_version; } const MemoryVector<byte>& session_id() const { return m_session_id; } u16bit ciphersuite() const { return suite; } byte compression_method() const { return comp_method; } @@ -163,11 +189,11 @@ class Server_Hello : public Handshake_Message const MemoryVector<byte>& random() const { return s_random; } Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - Version_Code version, + Handshake_Hash& hash, + Protocol_Version version, const Client_Hello& other, - const std::vector<X509_Certificate>& certs, - const TLS_Policy& policies, + const std::vector<std::string>& available_cert_types, + const Policy& policies, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, bool client_has_npn, @@ -175,9 +201,9 @@ class Server_Hello : public Handshake_Message RandomNumberGenerator& rng); Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const MemoryRegion<byte>& session_id, - Version_Code ver, + Protocol_Version ver, u16bit ciphersuite, byte compression, size_t max_fragment_size, @@ -187,12 +213,11 @@ class Server_Hello : public Handshake_Message const std::vector<std::string>& next_protocols, RandomNumberGenerator& rng); - Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - Version_Code s_version; + Protocol_Version s_version; MemoryVector<byte> m_session_id, s_random; u16bit suite; byte comp_method; @@ -216,26 +241,22 @@ class Client_Key_Exchange : public Handshake_Message const SecureVector<byte>& pre_master_secret() const { return pre_master; } - SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng, - const Private_Key* key, - Version_Code version); - Client_Key_Exchange(Record_Writer& output, - TLS_Handshake_Hash& hash, - RandomNumberGenerator& rng, - const Public_Key* my_key, - Version_Code using_version, - Version_Code pref_version); + Handshake_State* state, + Credentials_Manager& creds, + const std::vector<X509_Certificate>& peer_certs, + RandomNumberGenerator& rng); Client_Key_Exchange(const MemoryRegion<byte>& buf, - const TLS_Cipher_Suite& suite, - Version_Code using_version); + const Handshake_State* state, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng); + private: - MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); + MemoryVector<byte> serialize() const { return key_material; } SecureVector<byte> key_material, pre_master; - bool include_length; }; /** @@ -245,20 +266,20 @@ class Certificate : public Handshake_Message { public: Handshake_Type type() const { return CERTIFICATE; } - const std::vector<X509_Certificate>& cert_chain() const { return certs; } + const std::vector<X509_Certificate>& cert_chain() const { return m_certs; } - size_t count() const { return certs.size(); } - bool empty() const { return certs.empty(); } + size_t count() const { return m_certs.size(); } + bool empty() const { return m_certs.empty(); } Certificate(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const std::vector<X509_Certificate>& certs); - Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - std::vector<X509_Certificate> certs; + + std::vector<X509_Certificate> m_certs; }; /** @@ -269,22 +290,29 @@ class Certificate_Req : public Handshake_Message public: Handshake_Type type() const { return CERTIFICATE_REQUEST; } - std::vector<Certificate_Type> acceptable_types() const { return types; } + const std::vector<std::string>& acceptable_cert_types() const + { return cert_key_types; } + std::vector<X509_DN> acceptable_CAs() const { return names; } + std::vector<std::pair<std::string, std::string> > supported_algos() const + { return m_supported_algos; } + Certificate_Req(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, + const Policy& policy, const std::vector<X509_Certificate>& allowed_cas, - const std::vector<Certificate_Type>& types = - std::vector<Certificate_Type>()); + Protocol_Version version); - Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Req(const MemoryRegion<byte>& buf, + Protocol_Version version); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::vector<X509_DN> names; - std::vector<Certificate_Type> types; + std::vector<std::string> cert_key_types; + + std::vector<std::pair<std::string, std::string> > m_supported_algos; }; /** @@ -298,25 +326,23 @@ class Certificate_Verify : public Handshake_Message /** * Check the signature on a certificate verify message * @param cert the purported certificate - * @param hash the running handshake message hash - * @param version the version number we negotiated - * @param master_secret the session key (only used if version is SSL_V3) + * @param state the handshake state */ bool verify(const X509_Certificate& cert, - TLS_Handshake_Hash& hash, - Version_Code version, - const SecureVector<byte>& master_secret); + Handshake_State* state); Certificate_Verify(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_State* state, RandomNumberGenerator& rng, const Private_Key* key); - Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Verify(const MemoryRegion<byte>& buf, + Protocol_Version version); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); + std::string sig_algo; // sig algo used to create signature + std::string hash_algo; // hash used to create signature MemoryVector<byte> signature; }; @@ -331,26 +357,16 @@ class Finished : public Handshake_Message MemoryVector<byte> verify_data() const { return verification_data; } - bool verify(const MemoryRegion<byte>& buf, - Version_Code version, - const TLS_Handshake_Hash& hash, + bool verify(Handshake_State* state, Connection_Side side); Finished(Record_Writer& writer, - TLS_Handshake_Hash& hash, - Version_Code version, - Connection_Side side, - const MemoryRegion<byte>& master_secret); + Handshake_State* state, + Connection_Side side); - Finished(const MemoryRegion<byte>& buf) { deserialize(buf); } + Finished(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - - MemoryVector<byte> compute_verify(const MemoryRegion<byte>& master_secret, - TLS_Handshake_Hash hash, - Connection_Side side, - Version_Code version); Connection_Side side; MemoryVector<byte> verification_data; @@ -365,10 +381,9 @@ class Hello_Request : public Handshake_Message Handshake_Type type() const { return HELLO_REQUEST; } Hello_Request(Record_Writer& writer); - Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); } + Hello_Request(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -378,28 +393,38 @@ class Server_Key_Exchange : public Handshake_Message { public: Handshake_Type type() const { return SERVER_KEX; } - Public_Key* key() const; + + const MemoryVector<byte>& params() const { return m_params; } bool verify(const X509_Certificate& cert, - const MemoryRegion<byte>& c_random, - const MemoryRegion<byte>& s_random) const; + Handshake_State* state) const; + + // Only valid for certain kex types + const Private_Key& server_kex_key() const; Server_Key_Exchange(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_State* state, + const Policy& policy, + Credentials_Manager& creds, RandomNumberGenerator& rng, - const Public_Key* kex_key, - const Private_Key* priv_key, - const MemoryRegion<byte>& c_random, - const MemoryRegion<byte>& s_random); + const Private_Key* signing_key = 0); - Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Key_Exchange(const MemoryRegion<byte>& buf, + const std::string& kex_alg, + const std::string& sig_alg, + Protocol_Version version); + + ~Server_Key_Exchange() { delete m_kex_key; } private: MemoryVector<byte> serialize() const; - MemoryVector<byte> serialize_params() const; - void deserialize(const MemoryRegion<byte>&); - std::vector<BigInt> params; - MemoryVector<byte> signature; + Private_Key* m_kex_key; + + MemoryVector<byte> m_params; + + std::string m_sig_algo; // sig algo used to create signature + std::string m_hash_algo; // hash used to create signature + MemoryVector<byte> m_signature; }; /** @@ -410,11 +435,10 @@ class Server_Hello_Done : public Handshake_Message public: Handshake_Type type() const { return SERVER_HELLO_DONE; } - Server_Hello_Done(Record_Writer& writer, TLS_Handshake_Hash& hash); - Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello_Done(Record_Writer& writer, Handshake_Hash& hash); + Server_Hello_Done(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -428,13 +452,12 @@ class Next_Protocol : public Handshake_Message std::string protocol() const { return m_protocol; } Next_Protocol(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const std::string& protocol); - Next_Protocol(const MemoryRegion<byte>& buf) { deserialize(buf); } + Next_Protocol(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::string m_protocol; }; @@ -465,4 +488,6 @@ class New_Session_Ticket : public Handshake_Message } +} + #endif |