aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/tls/c_hello.cpp16
-rw-r--r--src/tls/c_kex.cpp23
-rw-r--r--src/tls/cert_req.cpp10
-rw-r--r--src/tls/cert_ver.cpp8
-rw-r--r--src/tls/finished.cpp10
-rw-r--r--src/tls/info.txt2
-rw-r--r--src/tls/rec_read.cpp34
-rw-r--r--src/tls/rec_wri.cpp26
-rw-r--r--src/tls/s_hello.cpp21
-rw-r--r--src/tls/s_kex.cpp4
-rw-r--r--src/tls/tls_channel.cpp6
-rw-r--r--src/tls/tls_client.cpp4
-rw-r--r--src/tls/tls_handshake_hash.cpp6
-rw-r--r--src/tls/tls_handshake_hash.h3
-rw-r--r--src/tls/tls_handshake_state.cpp30
-rw-r--r--src/tls/tls_handshake_state.h2
-rw-r--r--src/tls/tls_magic.h8
-rw-r--r--src/tls/tls_messages.h24
-rw-r--r--src/tls/tls_policy.h10
-rw-r--r--src/tls/tls_record.h9
-rw-r--r--src/tls/tls_server.cpp22
-rw-r--r--src/tls/tls_session.cpp11
-rw-r--r--src/tls/tls_session.h38
-rw-r--r--src/tls/tls_session_key.cpp12
-rw-r--r--src/tls/tls_version.cpp33
-rw-r--r--src/tls/tls_version.h100
26 files changed, 304 insertions, 168 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp
index 4fdadd455..00728ff16 100644
--- a/src/tls/c_hello.cpp
+++ b/src/tls/c_hello.cpp
@@ -150,8 +150,8 @@ MemoryVector<byte> Client_Hello::serialize() const
{
MemoryVector<byte> buf;
- buf.push_back(static_cast<byte>(m_version >> 8));
- buf.push_back(static_cast<byte>(m_version ));
+ buf.push_back(m_version.major_version());
+ buf.push_back(m_version.minor_version());
buf += m_random;
append_tls_length_value(buf, m_session_id, 1);
@@ -174,7 +174,7 @@ MemoryVector<byte> Client_Hello::serialize() const
extensions.add(new Server_Name_Indicator(m_hostname));
extensions.add(new SRP_Identifier(m_srp_identifier));
- if(m_version >= TLS_V12)
+ if(m_version >= Protocol_Version::TLS_V12)
extensions.add(new Signature_Algorithms(m_supported_algos));
if(m_next_protocol)
@@ -220,7 +220,7 @@ void Client_Hello::deserialize_sslv2(const MemoryRegion<byte>& buf)
m_suites.push_back(make_u16bit(buf[i+1], buf[i+2]));
}
- m_version = static_cast<Version_Code>(make_u16bit(buf[1], buf[2]));
+ m_version = Protocol_Version(buf[1], buf[2]);
m_random.resize(challenge_len);
copy_mem(&m_random[0], &buf[9+cipher_spec_len+m_session_id_len], challenge_len);
@@ -242,7 +242,11 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf)
TLS_Data_Reader reader(buf);
- m_version = static_cast<Version_Code>(reader.get_u16bit());
+ const byte major_version = reader.get_byte();
+ const byte minor_version = reader.get_byte();
+
+ m_version = Protocol_Version(major_version, minor_version);
+
m_random = reader.get_fixed<byte>(32);
m_session_id = reader.get_range<byte>(1, 0, 32);
@@ -289,7 +293,7 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf)
}
else
{
- if(m_version >= TLS_V12)
+ if(m_version >= Protocol_Version::TLS_V12)
{
/*
The rule for when a TLS 1.2 client not sending the extension
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index de8f54fbe..8bf923041 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -89,17 +89,17 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
if(const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(pub_key.get()))
{
- const Version_Code pref_version = state->client_hello->version();
+ const Protocol_Version pref_version = state->client_hello->version();
pre_master = rng.random_vec(48);
- pre_master[0] = (pref_version >> 8) & 0xFF;
- pre_master[1] = (pref_version ) & 0xFF;
+ pre_master[0] = pref_version.major_version();
+ pre_master[1] = pref_version.minor_version();
PK_Encryptor_EME encryptor(*rsa_pub, "PKCS1v15");
key_material = encryptor.encrypt(pre_master, rng);
- if(state->version == SSL_V3)
+ if(state->version == Protocol_Version::SSL_V3)
include_length = false;
}
else
@@ -116,11 +116,11 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
*/
Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
const Ciphersuite& suite,
- Version_Code using_version)
+ Protocol_Version using_version)
{
include_length = true;
- if(using_version == SSL_V3 && (suite.kex_algo() == ""))
+ if(using_version == Protocol_Version::SSL_V3 && (suite.kex_algo() == ""))
include_length = false;
if(include_length)
@@ -153,7 +153,7 @@ MemoryVector<byte> Client_Key_Exchange::serialize() const
SecureVector<byte>
Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
const Private_Key* priv_key,
- Version_Code version)
+ Protocol_Version client_version)
{
if(const DH_PrivateKey* dh_priv = dynamic_cast<const DH_PrivateKey*>(priv_key))
@@ -184,14 +184,17 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
pre_master = decryptor.decrypt(key_material);
if(pre_master.size() != 48 ||
- make_u16bit(pre_master[0], pre_master[1]) != version)
+ client_version.major_version() != pre_master[0] ||
+ client_version.minor_version() != pre_master[1])
+ {
throw Decoding_Error("Client_Key_Exchange: Secret corrupted");
+ }
}
catch(...)
{
pre_master = rng.random_vec(48);
- pre_master[0] = (version >> 8) & 0xFF;
- pre_master[1] = (version ) & 0xFF;
+ pre_master[0] = client_version.major_version();
+ pre_master[1] = client_version.minor_version();
}
return pre_master;
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index d5a73f64e..3f70c306b 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -13,8 +13,6 @@
#include <botan/loadstor.h>
#include <botan/secqueue.h>
-#include <stdio.h>
-
namespace Botan {
namespace TLS {
@@ -26,7 +24,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_Certificate>& ca_certs,
- Version_Code version)
+ Protocol_Version version)
{
for(size_t i = 0; i != ca_certs.size(); ++i)
names.push_back(ca_certs[i].subject_dn());
@@ -34,7 +32,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
cert_types.push_back(RSA_CERT);
cert_types.push_back(DSS_CERT);
- if(version >= TLS_V12)
+ if(version >= Protocol_Version::TLS_V12)
{
std::vector<std::string> hashes = policy.allowed_hashes();
std::vector<std::string> sigs = policy.allowed_signature_methods();
@@ -51,7 +49,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
* Deserialize a Certificate Request message
*/
Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf,
- Version_Code version)
+ Protocol_Version version)
{
if(buf.size() < 4)
throw Decoding_Error("Certificate_Req: Bad certificate request");
@@ -60,7 +58,7 @@ Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf,
cert_types = reader.get_range_vector<byte>(1, 1, 255);
- if(version >= TLS_V12)
+ if(version >= Protocol_Version::TLS_V12)
{
std::vector<byte> sig_hash_algs = reader.get_range_vector<byte>(2, 2, 65534);
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index 923cdbb42..791635b17 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -30,7 +30,7 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
PK_Signer signer(*priv_key, format.first, format.second);
- if(state->version == SSL_V3)
+ if(state->version == Protocol_Version::SSL_V3)
{
SecureVector<byte> md5_sha = state->hash.final_ssl3(
state->keys.master_secret());
@@ -52,11 +52,11 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
* Deserialize a Certificate Verify message
*/
Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf,
- Version_Code version)
+ Protocol_Version version)
{
TLS_Data_Reader reader(buf);
- if(version >= TLS_V12)
+ if(version >= Protocol_Version::TLS_V12)
{
hash_algo = Signature_Algorithms::hash_algo_name(reader.get_byte());
sig_algo = Signature_Algorithms::sig_algo_name(reader.get_byte());
@@ -99,7 +99,7 @@ bool Certificate_Verify::verify(const X509_Certificate& cert,
PK_Verifier verifier(*key, format.first, format.second);
- if(state->version == SSL_V3)
+ if(state->version == Protocol_Version::SSL_V3)
{
SecureVector<byte> md5_sha = state->hash.final_ssl3(
state->keys.master_secret());
diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp
index f7f8a7eb8..80385bd5e 100644
--- a/src/tls/finished.cpp
+++ b/src/tls/finished.cpp
@@ -11,19 +11,17 @@
#include <botan/sha2_32.h>
#include <memory>
-#include <stdio.h>
-
namespace Botan {
namespace TLS {
namespace {
-KDF* choose_tls_prf(Version_Code version)
+KDF* choose_tls_prf(Protocol_Version version)
{
- if(version == TLS_V10 || version == TLS_V11)
+ if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11)
return new TLS_PRF;
- else if(version == TLS_V12)
+ else if(version == Protocol_Version::TLS_V12)
return new TLS_12_PRF(new HMAC(new SHA_256)); // might depend on ciphersuite
else
throw TLS_Exception(PROTOCOL_VERSION,
@@ -36,7 +34,7 @@ KDF* choose_tls_prf(Version_Code version)
MemoryVector<byte> finished_compute_verify(Handshake_State* state,
Connection_Side side)
{
- if(state->version == SSL_V3)
+ if(state->version == Protocol_Version::SSL_V3)
{
const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 };
const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 };
diff --git a/src/tls/info.txt b/src/tls/info.txt
index 16d112df2..2774e9be8 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -18,6 +18,7 @@ tls_server.h
tls_session.h
tls_session_manager.h
tls_suites.h
+tls_version.h
</header:public>
<header:internal>
@@ -52,6 +53,7 @@ tls_session.cpp
tls_session_key.cpp
tls_session_manager.cpp
tls_suites.cpp
+tls_version.cpp
</source>
<requires>
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp
index 4db50262d..3fd2df33f 100644
--- a/src/tls/rec_read.cpp
+++ b/src/tls/rec_read.cpp
@@ -41,7 +41,7 @@ void Record_Reader::reset()
m_block_size = 0;
m_iv_size = 0;
- m_major = m_minor = 0;
+ m_version = Protocol_Version();
m_seq_no = 0;
set_maximum_fragment_size(0);
}
@@ -57,10 +57,9 @@ void Record_Reader::set_maximum_fragment_size(size_t max_fragment)
/*
* Set the version to use
*/
-void Record_Reader::set_version(Version_Code version)
+void Record_Reader::set_version(Protocol_Version version)
{
- m_major = (version >> 8) & 0xFF;
- m_minor = (version & 0xFF);
+ m_version = version;
}
/*
@@ -102,7 +101,7 @@ void Record_Reader::activate(const Ciphersuite& suite,
);
m_block_size = block_size_of(cipher_algo);
- if(m_major > 3 || (m_major == 3 && m_minor >= 2))
+ if(m_version >= Protocol_Version::TLS_V11)
m_iv_size = m_block_size;
else
m_iv_size = 0;
@@ -120,7 +119,7 @@ void Record_Reader::activate(const Ciphersuite& suite,
{
Algorithm_Factory& af = global_state().algorithm_factory();
- if(m_major == 3 && m_minor == 0)
+ if(m_version == Protocol_Version::SSL_V3)
m_mac = af.make_mac("SSL3-MAC(" + mac_algo + ")");
else
m_mac = af.make_mac("HMAC(" + mac_algo + ")");
@@ -220,12 +219,17 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz,
" from counterparty");
}
- const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]);
const size_t record_len = make_u16bit(m_readbuf[3], m_readbuf[4]);
- if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor))
- throw TLS_Exception(PROTOCOL_VERSION,
- "Got unexpected version from counterparty");
+ if(m_version.major_version())
+ {
+ if(m_readbuf[1] != m_version.major_version() ||
+ m_readbuf[2] != m_version.minor_version())
+ {
+ throw TLS_Exception(PROTOCOL_VERSION,
+ "Got unexpected version from counterparty");
+ }
+ }
if(record_len > MAX_CIPHERTEXT_SIZE)
throw TLS_Exception(RECORD_OVERFLOW,
@@ -282,7 +286,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz,
* This particular countermeasure is recommended in the TLS 1.2
* spec (RFC 5246) in section 6.2.3.2
*/
- if(version == SSL_V3)
+ if(m_version == Protocol_Version::SSL_V3)
{
if(pad_value > m_block_size)
pad_size = 0;
@@ -313,9 +317,11 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz,
m_mac->update_be(m_seq_no);
m_mac->update(m_readbuf[0]); // msg_type
- if(version != SSL_V3)
- for(size_t i = 0; i != 2; ++i)
- m_mac->update(get_byte(i, version));
+ if(m_version != Protocol_Version::SSL_V3)
+ {
+ m_mac->update(m_version.major_version());
+ m_mac->update(m_version.minor_version());
+ }
m_mac->update_be(plain_length);
m_mac->update(&m_readbuf[TLS_HEADER_SIZE + m_iv_size], plain_length);
diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp
index 139d84c50..9e1e4637c 100644
--- a/src/tls/rec_wri.cpp
+++ b/src/tls/rec_wri.cpp
@@ -48,8 +48,7 @@ void Record_Writer::reset()
delete m_mac;
m_mac = 0;
- m_major = 0;
- m_minor = 0;
+ m_version = Protocol_Version();
m_block_size = 0;
m_mac_size = 0;
m_iv_size = 0;
@@ -60,10 +59,9 @@ void Record_Writer::reset()
/*
* Set the version to use
*/
-void Record_Writer::set_version(Version_Code version)
+void Record_Writer::set_version(Protocol_Version version)
{
- m_major = (version >> 8) & 0xFF;
- m_minor = (version & 0xFF);
+ m_version = version;
}
/*
@@ -112,7 +110,7 @@ void Record_Writer::activate(const Ciphersuite& suite,
);
m_block_size = block_size_of(cipher_algo);
- if(m_major > 3 || (m_major == 3 && m_minor >= 2))
+ if(m_version >= Protocol_Version::TLS_V11)
m_iv_size = m_block_size;
else
m_iv_size = 0;
@@ -130,7 +128,7 @@ void Record_Writer::activate(const Ciphersuite& suite,
{
Algorithm_Factory& af = global_state().algorithm_factory();
- if(m_major == 3 && m_minor == 0)
+ if(m_version == Protocol_Version::SSL_V3)
m_mac = af.make_mac("SSL3-MAC(" + mac_algo + ")");
else
m_mac = af.make_mac("HMAC(" + mac_algo + ")");
@@ -191,8 +189,8 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length)
{
const byte header[TLS_HEADER_SIZE] = {
type,
- m_major,
- m_minor,
+ m_version.major_version(),
+ m_version.minor_version(),
get_byte<u16bit>(0, length),
get_byte<u16bit>(1, length)
};
@@ -205,10 +203,10 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length)
m_mac->update_be(m_seq_no);
m_mac->update(type);
- if(m_major > 3 || (m_major == 3 && m_minor != 0))
+ if(m_version != Protocol_Version::SSL_V3)
{
- m_mac->update(m_major);
- m_mac->update(m_minor);
+ m_mac->update(m_version.major_version());
+ m_mac->update(m_version.minor_version());
}
m_mac->update(get_byte<u16bit>(0, length));
@@ -229,8 +227,8 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length)
// TLS record header
m_writebuf[0] = type;
- m_writebuf[1] = m_major;
- m_writebuf[2] = m_minor;
+ m_writebuf[1] = m_version.major_version();
+ m_writebuf[2] = m_version.minor_version();
m_writebuf[3] = get_byte<u16bit>(0, buf_size);
m_writebuf[4] = get_byte<u16bit>(1, buf_size);
diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp
index 9e61f62af..b027c6cc6 100644
--- a/src/tls/s_hello.cpp
+++ b/src/tls/s_hello.cpp
@@ -21,7 +21,7 @@ namespace TLS {
*/
Server_Hello::Server_Hello(Record_Writer& writer,
Handshake_Hash& hash,
- Version_Code version,
+ Protocol_Version version,
const Client_Hello& c_hello,
const std::vector<X509_Certificate>& certs,
const Policy& policy,
@@ -68,7 +68,7 @@ Server_Hello::Server_Hello(Record_Writer& writer,
Server_Hello::Server_Hello(Record_Writer& writer,
Handshake_Hash& hash,
const MemoryRegion<byte>& session_id,
- Version_Code ver,
+ Protocol_Version ver,
u16bit ciphersuite,
byte compression,
size_t max_fragment_size,
@@ -104,12 +104,15 @@ Server_Hello::Server_Hello(const MemoryRegion<byte>& buf)
TLS_Data_Reader reader(buf);
- s_version = static_cast<Version_Code>(reader.get_u16bit());
+ const byte major_version = reader.get_byte();
+ const byte minor_version = reader.get_byte();
- if(s_version != SSL_V3 &&
- s_version != TLS_V10 &&
- s_version != TLS_V11 &&
- s_version != TLS_V12)
+ s_version = Protocol_Version(major_version, minor_version);
+
+ if(s_version != Protocol_Version::SSL_V3 &&
+ s_version != Protocol_Version::TLS_V10 &&
+ s_version != Protocol_Version::TLS_V11 &&
+ s_version != Protocol_Version::TLS_V12)
{
throw TLS_Exception(PROTOCOL_VERSION,
"Server_Hello: Unsupported server version");
@@ -146,8 +149,8 @@ MemoryVector<byte> Server_Hello::serialize() const
{
MemoryVector<byte> buf;
- buf.push_back(static_cast<byte>(s_version >> 8));
- buf.push_back(static_cast<byte>(s_version ));
+ buf.push_back(s_version.major_version());
+ buf.push_back(s_version.minor_version());
buf += s_random;
append_tls_length_value(buf, m_session_id, 1);
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 359ef6f4a..0a7ae9b14 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -87,7 +87,7 @@ MemoryVector<byte> Server_Key_Exchange::serialize_params() const
Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
const std::string& kex_algo,
const std::string& sig_algo,
- Version_Code version)
+ Protocol_Version version)
{
if(buf.size() < 6)
throw Decoding_Error("Server_Key_Exchange: Packet corrupted");
@@ -109,7 +109,7 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
if(sig_algo != "")
{
- if(version >= TLS_V12)
+ if(version >= Protocol_Version::TLS_V12)
{
m_hash_algo = Signature_Algorithms::hash_algo_name(reader.get_byte());
m_sig_algo = Signature_Algorithms::sig_algo_name(reader.get_byte());
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index a3ff69d87..76a5424ad 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -17,8 +17,8 @@ namespace Botan {
namespace TLS {
Channel::Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
- std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
- std::tr1::function<bool (const Session&)> handshake_complete) :
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
+ std::tr1::function<bool (const Session&)> handshake_complete) :
proc_fn(proc_fn),
handshake_fn(handshake_complete),
writer(socket_output_fn),
@@ -133,7 +133,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
* Split up and process handshake messages
*/
void Channel::read_handshake(byte rec_type,
- const MemoryRegion<byte>& rec_buf)
+ const MemoryRegion<byte>& rec_buf)
{
if(rec_type == HANDSHAKE)
{
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index d1b31f137..835e8d4bd 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -33,7 +33,7 @@ Client::Client(std::tr1::function<void (const byte[], size_t)> output_fn,
session_manager(session_manager),
creds(creds)
{
- writer.set_version(SSL_V3);
+ writer.set_version(Protocol_Version::SSL_V3);
state = new Handshake_State;
state->set_expected_next(SERVER_HELLO);
@@ -296,7 +296,7 @@ void Client::process_handshake_msg(Handshake_Type type,
std::vector<byte> types = state->cert_req->acceptable_types();
std::vector<X509_Certificate> client_certs =
- creds.cert_chain("", // use types here
+ creds.cert_chain("", // FIXME use types here
"tls-client",
state->client_hello->sni_hostname());
diff --git a/src/tls/tls_handshake_hash.cpp b/src/tls/tls_handshake_hash.cpp
index e521ea342..491b4f6c0 100644
--- a/src/tls/tls_handshake_hash.cpp
+++ b/src/tls/tls_handshake_hash.cpp
@@ -31,11 +31,11 @@ void Handshake_Hash::update(Handshake_Type handshake_type,
/**
* Return a TLS Handshake Hash
*/
-SecureVector<byte> Handshake_Hash::final(Version_Code version)
+SecureVector<byte> Handshake_Hash::final(Protocol_Version version)
{
SecureVector<byte> output;
- if(version == TLS_V10 || version == TLS_V11)
+ if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11)
{
MD5 md5;
SHA_160 sha1;
@@ -46,7 +46,7 @@ SecureVector<byte> Handshake_Hash::final(Version_Code version)
output += md5.final();
output += sha1.final();
}
- else if(version == TLS_V12)
+ else if(version == Protocol_Version::TLS_V12)
{
// This might depend on the ciphersuite
SHA_256 sha256;
diff --git a/src/tls/tls_handshake_hash.h b/src/tls/tls_handshake_hash.h
index a6c2b44e1..20f3c51fc 100644
--- a/src/tls/tls_handshake_hash.h
+++ b/src/tls/tls_handshake_hash.h
@@ -9,6 +9,7 @@
#define BOTAN_TLS_HANDSHAKE_HASH_H__
#include <botan/secmem.h>
+#include <botan/tls_version.h>
#include <botan/tls_magic.h>
namespace Botan {
@@ -35,7 +36,7 @@ class Handshake_Hash
void update(Handshake_Type handshake_type,
const MemoryRegion<byte>& handshake_msg);
- SecureVector<byte> final(Version_Code version);
+ SecureVector<byte> final(Protocol_Version version);
SecureVector<byte> final_ssl3(const MemoryRegion<byte>& master_secret);
const SecureVector<byte>& get_contents() const
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 5eb44414e..15017648c 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -93,7 +93,7 @@ Handshake_State::Handshake_State()
kex_priv = 0;
- version = SSL_V3;
+ version = Protocol_Version::SSL_V3;
hand_expecting_mask = 0;
hand_received_mask = 0;
@@ -133,9 +133,9 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const
std::pair<std::string, Signature_Format>
Handshake_State::choose_sig_format(const Private_Key* key,
- std::string& hash_algo_out,
- std::string& sig_algo_out,
- bool for_client_auth)
+ std::string& hash_algo_out,
+ std::string& sig_algo_out,
+ bool for_client_auth)
{
const std::string sig_algo = key->algo_name();
@@ -153,15 +153,15 @@ Handshake_State::choose_sig_format(const Private_Key* key,
}
}
- if(for_client_auth && this->version == SSL_V3)
+ if(for_client_auth && this->version == Protocol_Version::SSL_V3)
hash_algo = "Raw";
- if(hash_algo == "" && this->version == TLS_V12)
+ if(hash_algo == "" && this->version == Protocol_Version::TLS_V12)
hash_algo = "SHA-1"; // TLS 1.2 but no compatible hashes set (?)
BOTAN_ASSERT(hash_algo != "", "Couldn't figure out hash to use");
- if(this->version >= TLS_V12)
+ if(this->version >= Protocol_Version::TLS_V12)
{
hash_algo_out = hash_algo;
sig_algo_out = sig_algo;
@@ -185,9 +185,9 @@ Handshake_State::choose_sig_format(const Private_Key* key,
std::pair<std::string, Signature_Format>
Handshake_State::understand_sig_format(const Public_Key* key,
- std::string hash_algo,
- std::string sig_algo,
- bool for_client_auth)
+ std::string hash_algo,
+ std::string sig_algo,
+ bool for_client_auth)
{
const std::string algo_name = key->algo_name();
@@ -199,7 +199,7 @@ Handshake_State::understand_sig_format(const Public_Key* key,
Or not?
*/
- if(this->version < TLS_V12)
+ if(this->version < Protocol_Version::TLS_V12)
{
if(hash_algo != "" || sig_algo != "")
throw Decoding_Error("Counterparty sent hash/sig IDs with old version");
@@ -215,11 +215,11 @@ Handshake_State::understand_sig_format(const Public_Key* key,
if(algo_name == "RSA")
{
- if(for_client_auth && this->version == SSL_V3)
+ if(for_client_auth && this->version == Protocol_Version::SSL_V3)
{
hash_algo = "Raw";
}
- else if(this->version < TLS_V12)
+ else if(this->version < Protocol_Version::TLS_V12)
{
hash_algo = "TLS.Digest.0";
}
@@ -229,11 +229,11 @@ Handshake_State::understand_sig_format(const Public_Key* key,
}
else if(algo_name == "DSA")
{
- if(for_client_auth && this->version == SSL_V3)
+ if(for_client_auth && this->version == Protocol_Version::SSL_V3)
{
hash_algo = "Raw";
}
- else if(this->version < TLS_V12)
+ else if(this->version < Protocol_Version::TLS_V12)
{
hash_algo = "SHA-1";
}
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 54e0da892..7339033c4 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -60,7 +60,7 @@ class Handshake_State
std::string& sig_algo,
bool for_client_auth);
- Version_Code version;
+ Protocol_Version version;
class Client_Hello* client_hello;
class Server_Hello* server_hello;
diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h
index 09919c26f..ebca860de 100644
--- a/src/tls/tls_magic.h
+++ b/src/tls/tls_magic.h
@@ -24,14 +24,6 @@ enum Size_Limits {
MAX_TLS_RECORD_SIZE = MAX_CIPHERTEXT_SIZE + TLS_HEADER_SIZE,
};
-enum Version_Code {
- NO_VERSION_SET = 0x0000,
- SSL_V3 = 0x0300,
- TLS_V10 = 0x0301,
- TLS_V11 = 0x0302,
- TLS_V12 = 0x0303
-};
-
enum Connection_Side { CLIENT = 1, SERVER = 2 };
enum Record_Type {
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 89eb4af16..72d9a1c60 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -50,7 +50,7 @@ class Client_Hello : public Handshake_Message
{
public:
Handshake_Type type() const { return CLIENT_HELLO; }
- Version_Code version() const { return m_version; }
+ Protocol_Version version() const { return m_version; }
const MemoryVector<byte>& session_id() const { return m_session_id; }
std::vector<byte> session_id_vector() const
@@ -106,7 +106,7 @@ class Client_Hello : public Handshake_Message
void deserialize(const MemoryRegion<byte>& buf);
void deserialize_sslv2(const MemoryRegion<byte>& buf);
- Version_Code m_version;
+ Protocol_Version m_version;
MemoryVector<byte> m_session_id, m_random;
std::vector<u16bit> m_suites;
std::vector<byte> m_comp_methods;
@@ -128,7 +128,7 @@ class Server_Hello : public Handshake_Message
{
public:
Handshake_Type type() const { return SERVER_HELLO; }
- Version_Code version() { return s_version; }
+ Protocol_Version version() { return s_version; }
const MemoryVector<byte>& session_id() const { return m_session_id; }
u16bit ciphersuite() const { return suite; }
byte compression_method() const { return comp_method; }
@@ -156,7 +156,7 @@ class Server_Hello : public Handshake_Message
Server_Hello(Record_Writer& writer,
Handshake_Hash& hash,
- Version_Code version,
+ Protocol_Version version,
const Client_Hello& other,
const std::vector<X509_Certificate>& certs,
const Policy& policies,
@@ -169,7 +169,7 @@ class Server_Hello : public Handshake_Message
Server_Hello(Record_Writer& writer,
Handshake_Hash& hash,
const MemoryRegion<byte>& session_id,
- Version_Code ver,
+ Protocol_Version ver,
u16bit ciphersuite,
byte compression,
size_t max_fragment_size,
@@ -183,7 +183,7 @@ class Server_Hello : public Handshake_Message
private:
MemoryVector<byte> serialize() const;
- Version_Code s_version;
+ Protocol_Version s_version;
MemoryVector<byte> m_session_id, s_random;
u16bit suite;
byte comp_method;
@@ -209,7 +209,7 @@ class Client_Key_Exchange : public Handshake_Message
SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng,
const Private_Key* key,
- Version_Code version);
+ Protocol_Version version);
Client_Key_Exchange(Record_Writer& output,
Handshake_State* state,
@@ -218,7 +218,7 @@ class Client_Key_Exchange : public Handshake_Message
Client_Key_Exchange(const MemoryRegion<byte>& buf,
const Ciphersuite& suite,
- Version_Code using_version);
+ Protocol_Version using_version);
private:
MemoryVector<byte> serialize() const;
@@ -267,10 +267,10 @@ class Certificate_Req : public Handshake_Message
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_Certificate>& allowed_cas,
- Version_Code version);
+ Protocol_Version version);
Certificate_Req(const MemoryRegion<byte>& buf,
- Version_Code version);
+ Protocol_Version version);
private:
MemoryVector<byte> serialize() const;
@@ -302,7 +302,7 @@ class Certificate_Verify : public Handshake_Message
const Private_Key* key);
Certificate_Verify(const MemoryRegion<byte>& buf,
- Version_Code version);
+ Protocol_Version version);
private:
MemoryVector<byte> serialize() const;
@@ -372,7 +372,7 @@ class Server_Key_Exchange : public Handshake_Message
Server_Key_Exchange(const MemoryRegion<byte>& buf,
const std::string& kex_alg,
const std::string& sig_alg,
- Version_Code version);
+ Protocol_Version version);
private:
MemoryVector<byte> serialize() const;
MemoryVector<byte> serialize_params() const;
diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h
index f8e608cdb..61de53dcd 100644
--- a/src/tls/tls_policy.h
+++ b/src/tls/tls_policy.h
@@ -8,7 +8,7 @@
#ifndef BOTAN_TLS_POLICY_H__
#define BOTAN_TLS_POLICY_H__
-#include <botan/tls_magic.h>
+#include <botan/tls_version.h>
#include <botan/x509cert.h>
#include <botan/dl_group.h>
#include <vector>
@@ -60,7 +60,7 @@ class BOTAN_DLL Policy
* renegotiation.
*
* @warning Changing this to false exposes you to injected
- * plaintext attacks.
+ * plaintext attacks. Read the RFC for background.
*/
virtual bool require_secure_renegotiation() const { return true; }
@@ -72,12 +72,14 @@ class BOTAN_DLL Policy
/*
* @return the minimum version that we will negotiate
*/
- virtual Version_Code min_version() const { return SSL_V3; }
+ virtual Protocol_Version min_version() const
+ { return Protocol_Version::SSL_V3; }
/*
* @return the version we would prefer to negotiate
*/
- virtual Version_Code pref_version() const { return TLS_V12; }
+ virtual Protocol_Version pref_version() const
+ { return Protocol_Version::TLS_V12; }
virtual ~Policy() {}
};
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 979154001..991243af5 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -9,6 +9,7 @@
#define BOTAN_TLS_RECORDS_H__
#include <botan/tls_suites.h>
+#include <botan/tls_version.h>
#include <botan/pipe.h>
#include <botan/mac.h>
#include <botan/secqueue.h>
@@ -49,7 +50,7 @@ class BOTAN_DLL Record_Writer
const Session_Keys& keys,
Connection_Side side);
- void set_version(Version_Code version);
+ void set_version(Protocol_Version version);
void reset();
@@ -74,7 +75,7 @@ class BOTAN_DLL Record_Writer
size_t m_block_size, m_mac_size, m_iv_size, m_max_fragment;
u64bit m_seq_no;
- byte m_major, m_minor;
+ Protocol_Version m_version;
};
/**
@@ -103,7 +104,7 @@ class BOTAN_DLL Record_Reader
const Session_Keys& keys,
Connection_Side side);
- void set_version(Version_Code version);
+ void set_version(Protocol_Version version);
void reset();
@@ -129,7 +130,7 @@ class BOTAN_DLL Record_Reader
MessageAuthenticationCode* m_mac;
size_t m_block_size, m_iv_size, m_max_fragment;
u64bit m_seq_no;
- byte m_major, m_minor;
+ Protocol_Version m_version;
};
}
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 6c6977b91..54873e682 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -67,13 +67,13 @@ bool check_for_resume(Session& session_info,
* TLS Server Constructor
*/
Server::Server(std::tr1::function<void (const byte[], size_t)> output_fn,
- std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
- std::tr1::function<bool (const Session&)> handshake_fn,
- Session_Manager& session_manager,
- Credentials_Manager& creds,
- const Policy& policy,
- RandomNumberGenerator& rng,
- const std::vector<std::string>& next_protocols) :
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
+ std::tr1::function<bool (const Session&)> handshake_fn,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ const std::vector<std::string>& next_protocols) :
Channel(output_fn, proc_fn, handshake_fn),
policy(policy),
rng(rng),
@@ -112,7 +112,7 @@ void Server::alert_notify(bool, Alert_Type type)
* Split up and process handshake messages
*/
void Server::read_handshake(byte rec_type,
- const MemoryRegion<byte>& rec_buf)
+ const MemoryRegion<byte>& rec_buf)
{
if(rec_type == HANDSHAKE && !state)
{
@@ -127,7 +127,7 @@ void Server::read_handshake(byte rec_type,
* Process a handshake message
*/
void Server::process_handshake_msg(Handshake_Type type,
- const MemoryRegion<byte>& contents)
+ const MemoryRegion<byte>& contents)
{
if(state == 0)
throw Unexpected_Message("Unexpected handshake message from client");
@@ -155,7 +155,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_hostname = state->client_hello->sni_hostname();
- Version_Code client_version = state->client_hello->version();
+ Protocol_Version client_version = state->client_hello->version();
if(client_version < policy.min_version())
throw TLS_Exception(PROTOCOL_VERSION,
@@ -184,7 +184,7 @@ void Server::process_handshake_msg(Handshake_Type type,
writer,
state->hash,
session_info.session_id(),
- Version_Code(session_info.version()),
+ Protocol_Version(session_info.version()),
session_info.ciphersuite(),
session_info.compression_method(),
session_info.fragment_size(),
diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp
index 3716878e1..d9ccd6df4 100644
--- a/src/tls/tls_session.cpp
+++ b/src/tls/tls_session.cpp
@@ -17,7 +17,7 @@ namespace TLS {
Session::Session(const MemoryRegion<byte>& session_identifier,
const MemoryRegion<byte>& master_secret,
- Version_Code version,
+ Protocol_Version version,
u16bit ciphersuite,
byte compression_method,
Connection_Side side,
@@ -51,12 +51,15 @@ Session::Session(const byte ber[], size_t ber_len)
ASN1_String sni_hostname_str;
ASN1_String srp_identifier_str;
+ byte major_version = 0, minor_version = 0;
+
BER_Decoder(ber, ber_len)
.decode_and_check(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION),
"Unknown version in session structure")
.decode(m_identifier, OCTET_STRING)
.decode_integer_type(m_start_time)
- .decode_integer_type(m_version)
+ .decode_integer_type(major_version)
+ .decode_integer_type(minor_version)
.decode_integer_type(m_ciphersuite)
.decode_integer_type(m_compression_method)
.decode_integer_type(side_code)
@@ -67,6 +70,7 @@ Session::Session(const byte ber[], size_t ber_len)
.decode(sni_hostname_str)
.decode(srp_identifier_str);
+ m_version = Protocol_Version(major_version, minor_version);
m_sni_hostname = sni_hostname_str.value();
m_srp_identifier = srp_identifier_str.value();
m_connection_side = static_cast<Connection_Side>(side_code);
@@ -79,7 +83,8 @@ SecureVector<byte> Session::BER_encode() const
.encode(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION))
.encode(m_identifier, OCTET_STRING)
.encode(static_cast<size_t>(m_start_time))
- .encode(static_cast<size_t>(m_version))
+ .encode(static_cast<size_t>(m_version.major_version()))
+ .encode(static_cast<size_t>(m_version.minor_version()))
.encode(static_cast<size_t>(m_ciphersuite))
.encode(static_cast<size_t>(m_compression_method))
.encode(static_cast<size_t>(m_connection_side))
diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h
index 9b3f5b194..e44967c00 100644
--- a/src/tls/tls_session.h
+++ b/src/tls/tls_session.h
@@ -9,6 +9,7 @@
#define TLS_SESSION_STATE_H__
#include <botan/x509cert.h>
+#include <botan/tls_version.h>
#include <botan/tls_magic.h>
#include <botan/secmem.h>
@@ -28,7 +29,7 @@ class BOTAN_DLL Session
*/
Session() :
m_start_time(0),
- m_version(0),
+ m_version(),
m_ciphersuite(0),
m_compression_method(0),
m_connection_side(static_cast<Connection_Side>(0)),
@@ -40,16 +41,16 @@ class BOTAN_DLL Session
* New session (sets session start time)
*/
Session(const MemoryRegion<byte>& session_id,
- const MemoryRegion<byte>& master_secret,
- Version_Code version,
- u16bit ciphersuite,
- byte compression_method,
- Connection_Side side,
- bool secure_renegotiation_supported,
- size_t fragment_size,
- const std::vector<X509_Certificate>& peer_certs,
- const std::string& sni_hostname = "",
- const std::string& srp_identifier = "");
+ const MemoryRegion<byte>& master_secret,
+ Protocol_Version version,
+ u16bit ciphersuite,
+ byte compression_method,
+ Connection_Side side,
+ bool secure_renegotiation_supported,
+ size_t fragment_size,
+ const std::vector<X509_Certificate>& peer_certs,
+ const std::string& sni_hostname = "",
+ const std::string& srp_identifier = "");
/**
* Load a session from BER (created by BER_encode)
@@ -66,18 +67,7 @@ class BOTAN_DLL Session
/**
* Get the version of the saved session
*/
- Version_Code version() const
- { return static_cast<Version_Code>(m_version); }
-
- /**
- * Get the major version of the saved session
- */
- byte major_version() const { return get_byte(0, m_version); }
-
- /**
- * Get the minor version of the saved session
- */
- byte minor_version() const { return get_byte(1, m_version); }
+ Protocol_Version version() const { return m_version; }
/**
* Get the ciphersuite of the saved session
@@ -141,7 +131,7 @@ class BOTAN_DLL Session
MemoryVector<byte> m_identifier;
SecureVector<byte> m_master_secret;
- u16bit m_version;
+ Protocol_Version m_version;
u16bit m_ciphersuite;
byte m_compression_method;
Connection_Side m_connection_side;
diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp
index 42727273a..541f0b2d9 100644
--- a/src/tls/tls_session_key.cpp
+++ b/src/tls/tls_session_key.cpp
@@ -17,13 +17,13 @@ namespace TLS {
namespace {
-std::string lookup_prf_name(Version_Code version)
+std::string lookup_prf_name(Protocol_Version version)
{
- if(version == SSL_V3)
+ if(version == Protocol_Version::SSL_V3)
return "SSL3-PRF";
- else if(version == TLS_V10 || version == TLS_V11)
+ else if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11)
return "TLS-PRF";
- else if(version == TLS_V12)
+ else if(version == Protocol_Version::TLS_V12)
return "TLS-12-PRF(SHA-256)";
else
throw Invalid_Argument("Session_Keys: Unknown version code");
@@ -65,7 +65,7 @@ Session_Keys::Session_Keys(Handshake_State* state,
{
SecureVector<byte> salt;
- if(state->version != SSL_V3)
+ if(state->version != Protocol_Version::SSL_V3)
salt += std::make_pair(MASTER_SECRET_MAGIC, sizeof(MASTER_SECRET_MAGIC));
salt += state->client_hello->random();
@@ -75,7 +75,7 @@ Session_Keys::Session_Keys(Handshake_State* state,
}
SecureVector<byte> salt;
- if(state->version != SSL_V3)
+ if(state->version != Protocol_Version::SSL_V3)
salt += std::make_pair(KEY_GEN_MAGIC, sizeof(KEY_GEN_MAGIC));
salt += state->server_hello->random();
salt += state->client_hello->random();
diff --git a/src/tls/tls_version.cpp b/src/tls/tls_version.cpp
new file mode 100644
index 000000000..4445998eb
--- /dev/null
+++ b/src/tls/tls_version.cpp
@@ -0,0 +1,33 @@
+/*
+* TLS Protocol Version Management
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/tls_version.h>
+#include <botan/parsing.h>
+
+namespace Botan {
+
+namespace TLS {
+
+std::string Protocol_Version::to_string() const
+ {
+ const byte maj = major_version();
+ const byte min = minor_version();
+
+ // Some very new or very old protocol?
+ if(maj != 3)
+ return "Protocol " + Botan::to_string(maj) + "." + Botan::to_string(min);
+
+ if(maj == 3 && min == 0)
+ return "SSL v3";
+
+ // The TLS v1.[0123...] case
+ return "TLS v1." + Botan::to_string(min-1);
+ }
+
+}
+
+}
diff --git a/src/tls/tls_version.h b/src/tls/tls_version.h
new file mode 100644
index 000000000..e4e6b49a2
--- /dev/null
+++ b/src/tls/tls_version.h
@@ -0,0 +1,100 @@
+/*
+* TLS Protocol Version Management
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef BOTAN_TLS_PROTOCOL_VERSION_H__
+#define BOTAN_TLS_PROTOCOL_VERSION_H__
+
+#include <botan/get_byte.h>
+#include <botan/parsing.h>
+
+namespace Botan {
+
+namespace TLS {
+
+class BOTAN_DLL Protocol_Version
+ {
+ public:
+ enum Version_Code {
+ SSL_V3 = 0x0300,
+ TLS_V10 = 0x0301,
+ TLS_V11 = 0x0302,
+ TLS_V12 = 0x0303
+ };
+
+ Protocol_Version() : m_major(0), m_minor(0) {}
+
+ Protocol_Version(Version_Code named_version) :
+ m_major(get_byte<u16bit>(0, named_version)),
+ m_minor(get_byte<u16bit>(1, named_version)) {}
+
+ Protocol_Version(byte major, byte minor) : m_major(major), m_minor(minor) {}
+
+ /**
+ * Get the major version of the protocol version
+ */
+ byte major_version() const { return m_major; }
+
+ /**
+ * Get the minor version of the protocol version
+ */
+ byte minor_version() const { return m_minor; }
+
+ bool operator==(const Protocol_Version& other) const
+ {
+ return (cmp(other) == 0);
+ }
+
+ bool operator!=(const Protocol_Version& other) const
+ {
+ return (cmp(other) != 0);
+ }
+
+ bool operator>=(const Protocol_Version& other) const
+ {
+ return (cmp(other) >= 0);
+ }
+
+ bool operator>(const Protocol_Version& other) const
+ {
+ return (cmp(other) > 0);
+ }
+
+ bool operator<=(const Protocol_Version& other) const
+ {
+ return (cmp(other) <= 0);
+ }
+
+ bool operator<(const Protocol_Version& other) const
+ {
+ return (cmp(other) < 0);
+ }
+
+ std::string to_string() const;
+
+ private:
+ s32bit cmp(const Protocol_Version& other) const
+ {
+ if(major_version() < other.major_version())
+ return -1;
+ if(major_version() > other.major_version())
+ return 1;
+ if(minor_version() < other.minor_version())
+ return -1;
+ if(minor_version() > other.minor_version())
+ return 1;
+ return 0;
+ }
+
+ byte m_major, m_minor;
+ };
+
+}
+
+}
+
+#endif
+