diff options
-rw-r--r-- | src/tls/hello.cpp | 31 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 3 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 3 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 11 | ||||
-rw-r--r-- | src/tls/tls_client.h | 7 | ||||
-rw-r--r-- | src/tls/tls_extensions.cpp | 22 | ||||
-rw-r--r-- | src/tls/tls_extensions.h | 26 | ||||
-rw-r--r-- | src/tls/tls_magic.h | 4 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 116 | ||||
-rw-r--r-- | src/tls/tls_reader.h | 3 |
10 files changed, 158 insertions, 68 deletions
diff --git a/src/tls/hello.cpp b/src/tls/hello.cpp index 59e7c68d4..8136eaa6b 100644 --- a/src/tls/hello.cpp +++ b/src/tls/hello.cpp @@ -99,7 +99,13 @@ MemoryVector<byte> Client_Hello::serialize() const printf("Requesting hostname '%s'\n", requested_hostname.c_str()); + /* + * May not want to send extensions at all in some cases. + * If so, should include SCSV value + */ + TLS_Extensions extensions; + extensions.push_back(new Renegotation_Extension()); extensions.push_back(new Server_Name_Indicator(requested_hostname)); extensions.push_back(new SRP_Identifier(requested_srp_id)); buf += extensions.serialize(); @@ -147,6 +153,8 @@ void Client_Hello::deserialize_sslv2(const MemoryRegion<byte>& buf) */ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) { + has_secure_renegotiation = false; + if(buf.size() == 0) throw Decoding_Error("Client_Hello: Packet corrupted"); @@ -174,10 +182,13 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) requested_hostname = sni->host_name(); else if(SRP_Identifier* srp = dynamic_cast<SRP_Identifier*>(extn)) requested_srp_id = srp->identifier(); + else if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn)) + { + // checked by TLS_Client / TLS_Server as they know the handshake state + has_secure_renegotiation = true; + renegotiation_info_bits = reneg->renegotiation_info(); + } } - - printf("hostname %s srp id %s\n", requested_hostname.c_str(), - requested_srp_id.c_str()); } /* @@ -294,6 +305,20 @@ void Server_Hello::deserialize(const MemoryRegion<byte>& buf) suite = reader.get_u16bit(); comp_method = reader.get_byte(); + + TLS_Extensions extensions(reader); + + for(size_t i = 0; i != extensions.count(); ++i) + { + TLS_Extension* extn = extensions.at(i); + + if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn)) + { + // checked by TLS_Client / TLS_Server as they know the handshake state + has_secure_renegotiation = true; + renegotiation_info_bits = reneg->renegotiation_info(); + } + } } /* diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 96d08a5f4..8cfb15ad6 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -17,7 +17,8 @@ TLS_Channel::TLS_Channel(std::tr1::function<void (const byte[], size_t)> socket_ proc_fn(proc_fn), writer(socket_output_fn), state(0), - active(false) + active(false), + secure_renegotiation(false) { } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index d74504ccd..f8b53684d 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -77,6 +77,9 @@ class BOTAN_DLL TLS_Channel class Handshake_State* state; + bool secure_renegotiation; + MemoryVector<byte> client_verify_data, server_verify_data; + bool active; }; diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 3fe93b846..2119ca44e 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -87,22 +87,25 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, { state->server_hello = new Server_Hello(contents); - if(!state->client_hello->offered_suite( - state->server_hello->ciphersuite() - ) - ) + if(!state->client_hello->offered_suite(state->server_hello->ciphersuite())) + { throw TLS_Exception(HANDSHAKE_FAILURE, "TLS_Client: Server replied with bad ciphersuite"); + } state->version = state->server_hello->version(); if(state->version > state->client_hello->version()) + { throw TLS_Exception(HANDSHAKE_FAILURE, "TLS_Client: Server replied with bad version"); + } if(state->version < policy.min_version()) + { throw TLS_Exception(PROTOCOL_VERSION, "TLS_Client: Server is too old for specified policy"); + } writer.set_version(state->version); reader.set_version(state->version); diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index e80051460..a2ad232c8 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -22,6 +22,13 @@ class BOTAN_DLL TLS_Client : public TLS_Channel public: /** * Set up a new TLS client session + * @param socket_output_fn is called with data for the outbound socket + * @param proc_fn is called when new data (application or alerts) is received + * @param session_manager manages session resumption + * @param policy specifies other connection policy information + * @param rng a random number generator + * @param servername the server's DNS name, if known + * @param srp_username an identifier to use for SRP key exchange */ TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn, std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, diff --git a/src/tls/tls_extensions.cpp b/src/tls/tls_extensions.cpp index b9881ea3a..28523abf2 100644 --- a/src/tls/tls_extensions.cpp +++ b/src/tls/tls_extensions.cpp @@ -41,7 +41,7 @@ TLS_Extensions::TLS_Extensions(class TLS_Data_Reader& reader) MemoryVector<byte> TLS_Extensions::serialize() const { - MemoryVector<byte> buf(2); // allocate length + MemoryVector<byte> buf(2); // 2 bytes for length field for(size_t i = 0; i != extensions.size(); ++i) { @@ -52,8 +52,6 @@ MemoryVector<byte> TLS_Extensions::serialize() const MemoryVector<byte> extn_val = extensions[i]->serialize(); - printf("serializing extn %d of %d bytes\n", extn_code, extn_val.size()); - buf.push_back(get_byte(0, extn_code)); buf.push_back(get_byte(1, extn_code)); @@ -68,9 +66,7 @@ MemoryVector<byte> TLS_Extensions::serialize() const buf[0] = get_byte(0, extn_size); buf[1] = get_byte(1, extn_size); - printf("%d bytes of extensions\n", buf.size()); - - // avoid sending an empty extensions block + // avoid sending a completely empty extensions block if(buf.size() == 2) return MemoryVector<byte>(); @@ -119,13 +115,10 @@ MemoryVector<byte> Server_Name_Indicator::serialize() const buf.push_back(get_byte<u16bit>(0, name_len)); buf.push_back(get_byte<u16bit>(1, name_len)); - buf += std::make_pair( reinterpret_cast<const byte*>(sni_host_name.data()), sni_host_name.size()); - printf("serializing %d bytes %s\n", buf.size(), - sni_host_name.c_str()); return buf; } @@ -146,5 +139,16 @@ MemoryVector<byte> SRP_Identifier::serialize() const return buf; } +Renegotation_Extension::Renegotation_Extension(TLS_Data_Reader& reader) + { + reneg_data = reader.get_range<byte>(1, 0, 255); + } + +MemoryVector<byte> Renegotation_Extension::serialize() const + { + MemoryVector<byte> buf; + append_tls_length_value(buf, reneg_data, 1); + return buf; + } } diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h index 01a4253b3..aa2349cf1 100644 --- a/src/tls/tls_extensions.h +++ b/src/tls/tls_extensions.h @@ -78,6 +78,32 @@ class SRP_Identifier : public TLS_Extension }; /** +* Renegotiation Indication Extension (RFC 5746) +*/ +class Renegotation_Extension : public TLS_Extension + { + public: + TLS_Handshake_Extension_Type type() const + { return TLSEXT_SAFE_RENEGOTIATION; } + + Renegotation_Extension() {} + + Renegotation_Extension(const MemoryRegion<byte>& bits) : + reneg_data(bits) {} + + Renegotation_Extension(TLS_Data_Reader& reader); + + const MemoryVector<byte>& renegotiation_info() const + { return reneg_data; } + + MemoryVector<byte> serialize() const; + + bool empty() const { return false; } // always send this + private: + MemoryVector<byte> reneg_data; + }; + +/** * Represents a block of extensions in a hello message */ class TLS_Extensions diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h index bddbab3ce..535319d41 100644 --- a/src/tls/tls_magic.h +++ b/src/tls/tls_magic.h @@ -201,7 +201,9 @@ enum TLS_Handshake_Extension_Type { TLSEXT_SRP_IDENTIFIER = 12, TLSEXT_CERTIFICATE_TYPES = 9, - TLSEXT_SESSION_TICKET = 35 + TLSEXT_SESSION_TICKET = 35, + + TLSEXT_SAFE_RENEGOTIATION = 65281, }; } diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index eeaf77a39..38d347d28 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -65,7 +65,12 @@ class Client_Hello : public HandshakeMessage std::string srp_identifier() const { return requested_srp_id; } - bool offered_suite(u16bit) const; + bool secure_renegotiation() const { return has_secure_renegotiation; } + + const MemoryVector<byte>& renegotiation_info() + { return renegotiation_info_bits; } + + bool offered_suite(u16bit ciphersuite) const; Client_Hello(Record_Writer& writer, TLS_Handshake_Hash& hash, @@ -94,6 +99,66 @@ class Client_Hello : public HandshakeMessage std::vector<byte> comp_methods; std::string requested_hostname; std::string requested_srp_id; + + bool has_secure_renegotiation; + MemoryVector<byte> renegotiation_info_bits; + }; + +/** +* Server Hello Message +*/ +class Server_Hello : public HandshakeMessage + { + public: + Handshake_Type type() const { return SERVER_HELLO; } + Version_Code version() { return s_version; } + const MemoryVector<byte>& session_id() const { return sess_id; } + u16bit ciphersuite() const { return suite; } + byte compression_method() const { return comp_method; } + + std::vector<byte> session_id_vector() const + { + std::vector<byte> v; + v.insert(v.begin(), &sess_id[0], &sess_id[sess_id.size()]); + return v; + } + + bool secure_renegotiation() const { return has_secure_renegotiation; } + + const MemoryVector<byte>& renegotiation_info() + { return renegotiation_info_bits; } + + const MemoryVector<byte>& random() const { return s_random; } + + Server_Hello(Record_Writer& writer, + TLS_Handshake_Hash& hash, + const TLS_Policy& policies, + RandomNumberGenerator& rng, + const std::vector<X509_Certificate>& certs, + const Client_Hello& other, + const MemoryRegion<byte>& session_id, + Version_Code version); + + Server_Hello(Record_Writer& writer, + TLS_Handshake_Hash& hash, + RandomNumberGenerator& rng, + const MemoryRegion<byte>& session_id, + u16bit ciphersuite, + byte compression, + Version_Code ver); + + Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } + private: + MemoryVector<byte> serialize() const; + void deserialize(const MemoryRegion<byte>&); + + Version_Code s_version; + MemoryVector<byte> sess_id, s_random; + u16bit suite; + byte comp_method; + + bool has_secure_renegotiation; + MemoryVector<byte> renegotiation_info_bits; }; /** @@ -260,55 +325,6 @@ class Hello_Request : public HandshakeMessage }; /** -* Server Hello Message -*/ -class Server_Hello : public HandshakeMessage - { - public: - Handshake_Type type() const { return SERVER_HELLO; } - Version_Code version() { return s_version; } - const MemoryVector<byte>& session_id() const { return sess_id; } - u16bit ciphersuite() const { return suite; } - byte compression_method() const { return comp_method; } - - std::vector<byte> session_id_vector() const - { - std::vector<byte> v; - v.insert(v.begin(), &sess_id[0], &sess_id[sess_id.size()]); - return v; - } - - const MemoryVector<byte>& random() const { return s_random; } - - Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const TLS_Policy& policies, - RandomNumberGenerator& rng, - const std::vector<X509_Certificate>& certs, - const Client_Hello& other, - const MemoryRegion<byte>& session_id, - Version_Code version); - - Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - RandomNumberGenerator& rng, - const MemoryRegion<byte>& session_id, - u16bit ciphersuite, - byte compression, - Version_Code ver); - - Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } - private: - MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - - Version_Code s_version; - MemoryVector<byte> sess_id, s_random; - u16bit suite; - byte comp_method; - }; - -/** * Server Key Exchange Message */ class Server_Key_Exchange : public HandshakeMessage diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h index 9f40bb457..bc2dffb58 100644 --- a/src/tls/tls_reader.h +++ b/src/tls/tls_reader.h @@ -29,7 +29,10 @@ class TLS_Data_Reader ~TLS_Data_Reader() { if(has_remaining()) + { + abort(); throw Decoding_Error("Extra bytes at end of message"); + } } size_t remaining_bytes() const |