diff options
Diffstat (limited to 'src/tls/tls_extensions.h')
-rw-r--r-- | src/tls/tls_extensions.h | 162 |
1 files changed, 137 insertions, 25 deletions
diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h index 6d4e40434..a9e85221e 100644 --- a/src/tls/tls_extensions.h +++ b/src/tls/tls_extensions.h @@ -12,35 +12,60 @@ #include <botan/tls_magic.h> #include <vector> #include <string> +#include <map> namespace Botan { -class TLS_Session; +namespace TLS { + class TLS_Data_Reader; +enum Handshake_Extension_Type { + TLSEXT_SERVER_NAME_INDICATION = 0, + TLSEXT_MAX_FRAGMENT_LENGTH = 1, + TLSEXT_CLIENT_CERT_URL = 2, + TLSEXT_TRUSTED_CA_KEYS = 3, + TLSEXT_TRUNCATED_HMAC = 4, + + TLSEXT_CERTIFICATE_TYPES = 9, + TLSEXT_USABLE_ELLIPTIC_CURVES = 10, + TLSEXT_EC_POINT_FORMATS = 11, + TLSEXT_SRP_IDENTIFIER = 12, + TLSEXT_SIGNATURE_ALGORITHMS = 13, + + TLSEXT_SESSION_TICKET = 35, + + TLSEXT_NEXT_PROTOCOL = 13172, + + TLSEXT_SAFE_RENEGOTIATION = 65281, +}; + /** * Base class representing a TLS extension of some kind */ -class TLS_Extension +class Extension { public: - virtual TLS_Handshake_Extension_Type type() const = 0; + virtual Handshake_Extension_Type type() const = 0; + virtual MemoryVector<byte> serialize() const = 0; virtual bool empty() const = 0; - virtual ~TLS_Extension() {} + virtual ~Extension() {} }; /** * Server Name Indicator extension (RFC 3546) */ -class Server_Name_Indicator : public TLS_Extension +class Server_Name_Indicator : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_SERVER_NAME_INDICATION; } + Handshake_Extension_Type type() const { return static_type(); } + Server_Name_Indicator(const std::string& host_name) : sni_host_name(host_name) {} @@ -59,12 +84,14 @@ class Server_Name_Indicator : public TLS_Extension /** * SRP identifier extension (RFC 5054) */ -class SRP_Identifier : public TLS_Extension +class SRP_Identifier : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_SRP_IDENTIFIER; } + Handshake_Extension_Type type() const { return static_type(); } + SRP_Identifier(const std::string& identifier) : srp_identifier(identifier) {} @@ -83,12 +110,14 @@ class SRP_Identifier : public TLS_Extension /** * Renegotiation Indication Extension (RFC 5746) */ -class Renegotation_Extension : public TLS_Extension +class Renegotation_Extension : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_SAFE_RENEGOTIATION; } + Handshake_Extension_Type type() const { return static_type(); } + Renegotation_Extension() {} Renegotation_Extension(const MemoryRegion<byte>& bits) : @@ -110,12 +139,14 @@ class Renegotation_Extension : public TLS_Extension /** * Maximum Fragment Length Negotiation Extension (RFC 4366 sec 3.2) */ -class Maximum_Fragment_Length : public TLS_Extension +class Maximum_Fragment_Length : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_MAX_FRAGMENT_LENGTH; } + Handshake_Extension_Type type() const { return static_type(); } + bool empty() const { return val != 0; } size_t fragment_size() const; @@ -147,12 +178,14 @@ class Maximum_Fragment_Length : public TLS_Extension * spec (implemented in Chromium); the internet draft leaves the format * unspecified. */ -class Next_Protocol_Notification : public TLS_Extension +class Next_Protocol_Notification : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_NEXT_PROTOCOL; } + Handshake_Extension_Type type() const { return static_type(); } + const std::vector<std::string>& protocols() const { return m_protocols; } @@ -209,32 +242,111 @@ class Session_Ticket : public TLS_Extension }; /** +* Supported Elliptic Curves Extension (RFC 4492) +*/ +class Supported_Elliptic_Curves : public Extension + { + public: + static Handshake_Extension_Type static_type() + { return TLSEXT_USABLE_ELLIPTIC_CURVES; } + + Handshake_Extension_Type type() const { return static_type(); } + + static std::string curve_id_to_name(u16bit id); + static u16bit name_to_curve_id(const std::string& name); + + const std::vector<std::string>& curves() const { return m_curves; } + + MemoryVector<byte> serialize() const; + + Supported_Elliptic_Curves(const std::vector<std::string>& curves) : + m_curves(curves) {} + + Supported_Elliptic_Curves(TLS_Data_Reader& reader, + u16bit extension_size); + + bool empty() const { return m_curves.empty(); } + private: + std::vector<std::string> m_curves; + }; + +/** +* Signature Algorithms Extension for TLS 1.2 (RFC 5246) +*/ +class Signature_Algorithms : public Extension + { + public: + static Handshake_Extension_Type static_type() + { return TLSEXT_SIGNATURE_ALGORITHMS; } + + Handshake_Extension_Type type() const { return static_type(); } + + static std::string hash_algo_name(byte code); + static byte hash_algo_code(const std::string& name); + + static std::string sig_algo_name(byte code); + static byte sig_algo_code(const std::string& name); + + std::vector<std::pair<std::string, std::string> > + supported_signature_algorthms() const + { + return m_supported_algos; + } + + MemoryVector<byte> serialize() const; + + bool empty() const { return false; } + + Signature_Algorithms(const std::vector<std::pair<std::string, std::string> >& algos) : + m_supported_algos(algos) {} + + Signature_Algorithms(TLS_Data_Reader& reader, + u16bit extension_size); + private: + std::vector<std::pair<std::string, std::string> > m_supported_algos; + }; + +/** * Represents a block of extensions in a hello message */ -class TLS_Extensions +class Extensions { public: - size_t count() const { return extensions.size(); } + template<typename T> + T* get() const + { + Handshake_Extension_Type type = T::static_type(); - TLS_Extension* at(size_t idx) { return extensions.at(idx); } + std::map<Handshake_Extension_Type, Extension*>::const_iterator i = + extensions.find(type); - void push_back(TLS_Extension* extn) - { extensions.push_back(extn); } + if(i != extensions.end()) + return dynamic_cast<T*>(i->second); + return 0; + } + + void add(Extension* extn) + { + delete extensions[extn->type()]; // or hard error if already exists? + extensions[extn->type()] = extn; + } MemoryVector<byte> serialize() const; - TLS_Extensions() {} + Extensions() {} - TLS_Extensions(TLS_Data_Reader& reader); // deserialize + Extensions(TLS_Data_Reader& reader); // deserialize - ~TLS_Extensions(); + ~Extensions(); private: - TLS_Extensions(const TLS_Extensions&) {} - TLS_Extensions& operator=(const TLS_Extensions&) { return (*this); } + Extensions(const Extensions&) {} + Extensions& operator=(const Extensions&) { return (*this); } - std::vector<TLS_Extension*> extensions; + std::map<Handshake_Extension_Type, Extension*> extensions; }; } +} + #endif |