diff options
author | lloyd <[email protected]> | 2012-08-03 16:46:53 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-08-03 16:46:53 +0000 |
commit | 950103dd7bbddec16330788c2ce11bcb545aaf25 (patch) | |
tree | bdfab618c883d6cc361bda21293a0e94c9e42ca8 | |
parent | 419935515357bc1b7b39825ed3a0def12362746d (diff) |
Take the initial record version from the Handshake_IO instance instead
of hardcoding it to SSLv3.
-rw-r--r-- | src/tls/tls_client.cpp | 6 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 10 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h | 7 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_record.h | 2 |
5 files changed, 24 insertions, 3 deletions
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 77ff010f3..46e5296e2 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -35,8 +35,6 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, m_hostname(hostname), m_port(port) { - m_writer.set_version(Protocol_Version::SSL_V3); - const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname); const Protocol_Version version = m_policy.pref_version(); @@ -69,6 +67,10 @@ void Client::initiate_handshake(bool force_full_renegotiation, std::function<std::string (std::vector<std::string>)> next_protocol) { m_state.reset(new_handshake_state()); + + if(!m_writer.record_version_set()) + m_writer.set_version(m_state->handshake_io().initial_record_version()); + m_state->set_expected_next(SERVER_HELLO); m_state->client_npn_cb = next_protocol; diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index 8ab0bea9c..452fa8f15 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -33,6 +33,11 @@ void store_be24(byte out[3], size_t val) } +Protocol_Version Stream_Handshake_IO::initial_record_version() const + { + return Protocol_Version::SSL_V3; + } + void Stream_Handshake_IO::add_input(const byte rec_type, const byte record[], size_t record_size) @@ -119,6 +124,11 @@ std::vector<byte> Stream_Handshake_IO::send(Handshake_Message& msg) return buf; } +Protocol_Version Datagram_Handshake_IO::initial_record_version() const + { + return Protocol_Version::DTLS_V10; + } + void Datagram_Handshake_IO::add_input(const byte rec_type, const byte record[], size_t record_size) diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index 039f92121..da3bfd5c8 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -9,6 +9,7 @@ #define BOTAN_TLS_HANDSHAKE_IO_H__ #include <botan/tls_magic.h> +#include <botan/tls_version.h> #include <botan/loadstor.h> #include <vector> #include <deque> @@ -28,6 +29,8 @@ class Handshake_Message; class Handshake_IO { public: + virtual Protocol_Version initial_record_version() const = 0; + virtual std::vector<byte> send(Handshake_Message& msg) = 0; virtual std::vector<byte> format( @@ -61,6 +64,8 @@ class Stream_Handshake_IO : public Handshake_IO public: Stream_Handshake_IO(Record_Writer& writer) : m_writer(writer) {} + Protocol_Version initial_record_version() const override; + std::vector<byte> send(Handshake_Message& msg) override; std::vector<byte> format( @@ -89,6 +94,8 @@ class Datagram_Handshake_IO : public Handshake_IO public: Datagram_Handshake_IO(Record_Writer& writer) : m_writer(writer) {} + Protocol_Version initial_record_version() const override; + std::vector<byte> send(Handshake_Message& msg) override; std::vector<byte> format( diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 77a1b52fc..2046652ec 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -87,7 +87,7 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) */ Handshake_State::Handshake_State(Handshake_IO* io) : m_handshake_io(io), - m_version(Protocol_Version::SSL_V3) + m_version(m_handshake_io->initial_record_version()) { } diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index fa5da52b4..9f63b5dec 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -45,6 +45,8 @@ class BOTAN_DLL Record_Writer void set_version(Protocol_Version version); + bool record_version_set() const { return m_version.valid(); } + void reset(); void set_maximum_fragment_size(size_t max_fragment); |