aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-08-03 16:46:53 +0000
committerlloyd <[email protected]>2012-08-03 16:46:53 +0000
commit950103dd7bbddec16330788c2ce11bcb545aaf25 (patch)
treebdfab618c883d6cc361bda21293a0e94c9e42ca8
parent419935515357bc1b7b39825ed3a0def12362746d (diff)
Take the initial record version from the Handshake_IO instance instead
of hardcoding it to SSLv3.
-rw-r--r--src/tls/tls_client.cpp6
-rw-r--r--src/tls/tls_handshake_io.cpp10
-rw-r--r--src/tls/tls_handshake_io.h7
-rw-r--r--src/tls/tls_handshake_state.cpp2
-rw-r--r--src/tls/tls_record.h2
5 files changed, 24 insertions, 3 deletions
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index 77ff010f3..46e5296e2 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -35,8 +35,6 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
m_hostname(hostname),
m_port(port)
{
- m_writer.set_version(Protocol_Version::SSL_V3);
-
const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname);
const Protocol_Version version = m_policy.pref_version();
@@ -69,6 +67,10 @@ void Client::initiate_handshake(bool force_full_renegotiation,
std::function<std::string (std::vector<std::string>)> next_protocol)
{
m_state.reset(new_handshake_state());
+
+ if(!m_writer.record_version_set())
+ m_writer.set_version(m_state->handshake_io().initial_record_version());
+
m_state->set_expected_next(SERVER_HELLO);
m_state->client_npn_cb = next_protocol;
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index 8ab0bea9c..452fa8f15 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -33,6 +33,11 @@ void store_be24(byte out[3], size_t val)
}
+Protocol_Version Stream_Handshake_IO::initial_record_version() const
+ {
+ return Protocol_Version::SSL_V3;
+ }
+
void Stream_Handshake_IO::add_input(const byte rec_type,
const byte record[],
size_t record_size)
@@ -119,6 +124,11 @@ std::vector<byte> Stream_Handshake_IO::send(Handshake_Message& msg)
return buf;
}
+Protocol_Version Datagram_Handshake_IO::initial_record_version() const
+ {
+ return Protocol_Version::DTLS_V10;
+ }
+
void Datagram_Handshake_IO::add_input(const byte rec_type,
const byte record[],
size_t record_size)
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index 039f92121..da3bfd5c8 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -9,6 +9,7 @@
#define BOTAN_TLS_HANDSHAKE_IO_H__
#include <botan/tls_magic.h>
+#include <botan/tls_version.h>
#include <botan/loadstor.h>
#include <vector>
#include <deque>
@@ -28,6 +29,8 @@ class Handshake_Message;
class Handshake_IO
{
public:
+ virtual Protocol_Version initial_record_version() const = 0;
+
virtual std::vector<byte> send(Handshake_Message& msg) = 0;
virtual std::vector<byte> format(
@@ -61,6 +64,8 @@ class Stream_Handshake_IO : public Handshake_IO
public:
Stream_Handshake_IO(Record_Writer& writer) : m_writer(writer) {}
+ Protocol_Version initial_record_version() const override;
+
std::vector<byte> send(Handshake_Message& msg) override;
std::vector<byte> format(
@@ -89,6 +94,8 @@ class Datagram_Handshake_IO : public Handshake_IO
public:
Datagram_Handshake_IO(Record_Writer& writer) : m_writer(writer) {}
+ Protocol_Version initial_record_version() const override;
+
std::vector<byte> send(Handshake_Message& msg) override;
std::vector<byte> format(
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 77a1b52fc..2046652ec 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -87,7 +87,7 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
*/
Handshake_State::Handshake_State(Handshake_IO* io) :
m_handshake_io(io),
- m_version(Protocol_Version::SSL_V3)
+ m_version(m_handshake_io->initial_record_version())
{
}
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index fa5da52b4..9f63b5dec 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -45,6 +45,8 @@ class BOTAN_DLL Record_Writer
void set_version(Protocol_Version version);
+ bool record_version_set() const { return m_version.valid(); }
+
void reset();
void set_maximum_fragment_size(size_t max_fragment);