diff options
-rw-r--r-- | src/tls/c_hello.cpp | 16 | ||||
-rw-r--r-- | src/tls/c_kex.cpp | 23 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 10 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 8 | ||||
-rw-r--r-- | src/tls/finished.cpp | 10 | ||||
-rw-r--r-- | src/tls/info.txt | 2 | ||||
-rw-r--r-- | src/tls/rec_read.cpp | 34 | ||||
-rw-r--r-- | src/tls/rec_wri.cpp | 26 | ||||
-rw-r--r-- | src/tls/s_hello.cpp | 21 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 4 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 6 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 4 | ||||
-rw-r--r-- | src/tls/tls_handshake_hash.cpp | 6 | ||||
-rw-r--r-- | src/tls/tls_handshake_hash.h | 3 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 30 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 2 | ||||
-rw-r--r-- | src/tls/tls_magic.h | 8 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 24 | ||||
-rw-r--r-- | src/tls/tls_policy.h | 10 | ||||
-rw-r--r-- | src/tls/tls_record.h | 9 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 22 | ||||
-rw-r--r-- | src/tls/tls_session.cpp | 11 | ||||
-rw-r--r-- | src/tls/tls_session.h | 38 | ||||
-rw-r--r-- | src/tls/tls_session_key.cpp | 12 | ||||
-rw-r--r-- | src/tls/tls_version.cpp | 33 | ||||
-rw-r--r-- | src/tls/tls_version.h | 100 |
26 files changed, 304 insertions, 168 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index 4fdadd455..00728ff16 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -150,8 +150,8 @@ MemoryVector<byte> Client_Hello::serialize() const { MemoryVector<byte> buf; - buf.push_back(static_cast<byte>(m_version >> 8)); - buf.push_back(static_cast<byte>(m_version )); + buf.push_back(m_version.major_version()); + buf.push_back(m_version.minor_version()); buf += m_random; append_tls_length_value(buf, m_session_id, 1); @@ -174,7 +174,7 @@ MemoryVector<byte> Client_Hello::serialize() const extensions.add(new Server_Name_Indicator(m_hostname)); extensions.add(new SRP_Identifier(m_srp_identifier)); - if(m_version >= TLS_V12) + if(m_version >= Protocol_Version::TLS_V12) extensions.add(new Signature_Algorithms(m_supported_algos)); if(m_next_protocol) @@ -220,7 +220,7 @@ void Client_Hello::deserialize_sslv2(const MemoryRegion<byte>& buf) m_suites.push_back(make_u16bit(buf[i+1], buf[i+2])); } - m_version = static_cast<Version_Code>(make_u16bit(buf[1], buf[2])); + m_version = Protocol_Version(buf[1], buf[2]); m_random.resize(challenge_len); copy_mem(&m_random[0], &buf[9+cipher_spec_len+m_session_id_len], challenge_len); @@ -242,7 +242,11 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) TLS_Data_Reader reader(buf); - m_version = static_cast<Version_Code>(reader.get_u16bit()); + const byte major_version = reader.get_byte(); + const byte minor_version = reader.get_byte(); + + m_version = Protocol_Version(major_version, minor_version); + m_random = reader.get_fixed<byte>(32); m_session_id = reader.get_range<byte>(1, 0, 32); @@ -289,7 +293,7 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) } else { - if(m_version >= TLS_V12) + if(m_version >= Protocol_Version::TLS_V12) { /* The rule for when a TLS 1.2 client not sending the extension diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index de8f54fbe..8bf923041 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -89,17 +89,17 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, if(const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(pub_key.get())) { - const Version_Code pref_version = state->client_hello->version(); + const Protocol_Version pref_version = state->client_hello->version(); pre_master = rng.random_vec(48); - pre_master[0] = (pref_version >> 8) & 0xFF; - pre_master[1] = (pref_version ) & 0xFF; + pre_master[0] = pref_version.major_version(); + pre_master[1] = pref_version.minor_version(); PK_Encryptor_EME encryptor(*rsa_pub, "PKCS1v15"); key_material = encryptor.encrypt(pre_master, rng); - if(state->version == SSL_V3) + if(state->version == Protocol_Version::SSL_V3) include_length = false; } else @@ -116,11 +116,11 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, */ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, const Ciphersuite& suite, - Version_Code using_version) + Protocol_Version using_version) { include_length = true; - if(using_version == SSL_V3 && (suite.kex_algo() == "")) + if(using_version == Protocol_Version::SSL_V3 && (suite.kex_algo() == "")) include_length = false; if(include_length) @@ -153,7 +153,7 @@ MemoryVector<byte> Client_Key_Exchange::serialize() const SecureVector<byte> Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, const Private_Key* priv_key, - Version_Code version) + Protocol_Version client_version) { if(const DH_PrivateKey* dh_priv = dynamic_cast<const DH_PrivateKey*>(priv_key)) @@ -184,14 +184,17 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, pre_master = decryptor.decrypt(key_material); if(pre_master.size() != 48 || - make_u16bit(pre_master[0], pre_master[1]) != version) + client_version.major_version() != pre_master[0] || + client_version.minor_version() != pre_master[1]) + { throw Decoding_Error("Client_Key_Exchange: Secret corrupted"); + } } catch(...) { pre_master = rng.random_vec(48); - pre_master[0] = (version >> 8) & 0xFF; - pre_master[1] = (version ) & 0xFF; + pre_master[0] = client_version.major_version(); + pre_master[1] = client_version.minor_version(); } return pre_master; diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index d5a73f64e..3f70c306b 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -13,8 +13,6 @@ #include <botan/loadstor.h> #include <botan/secqueue.h> -#include <stdio.h> - namespace Botan { namespace TLS { @@ -26,7 +24,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, Handshake_Hash& hash, const Policy& policy, const std::vector<X509_Certificate>& ca_certs, - Version_Code version) + Protocol_Version version) { for(size_t i = 0; i != ca_certs.size(); ++i) names.push_back(ca_certs[i].subject_dn()); @@ -34,7 +32,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, cert_types.push_back(RSA_CERT); cert_types.push_back(DSS_CERT); - if(version >= TLS_V12) + if(version >= Protocol_Version::TLS_V12) { std::vector<std::string> hashes = policy.allowed_hashes(); std::vector<std::string> sigs = policy.allowed_signature_methods(); @@ -51,7 +49,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, * Deserialize a Certificate Request message */ Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf, - Version_Code version) + Protocol_Version version) { if(buf.size() < 4) throw Decoding_Error("Certificate_Req: Bad certificate request"); @@ -60,7 +58,7 @@ Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf, cert_types = reader.get_range_vector<byte>(1, 1, 255); - if(version >= TLS_V12) + if(version >= Protocol_Version::TLS_V12) { std::vector<byte> sig_hash_algs = reader.get_range_vector<byte>(2, 2, 65534); diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 923cdbb42..791635b17 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -30,7 +30,7 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, PK_Signer signer(*priv_key, format.first, format.second); - if(state->version == SSL_V3) + if(state->version == Protocol_Version::SSL_V3) { SecureVector<byte> md5_sha = state->hash.final_ssl3( state->keys.master_secret()); @@ -52,11 +52,11 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, * Deserialize a Certificate Verify message */ Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf, - Version_Code version) + Protocol_Version version) { TLS_Data_Reader reader(buf); - if(version >= TLS_V12) + if(version >= Protocol_Version::TLS_V12) { hash_algo = Signature_Algorithms::hash_algo_name(reader.get_byte()); sig_algo = Signature_Algorithms::sig_algo_name(reader.get_byte()); @@ -99,7 +99,7 @@ bool Certificate_Verify::verify(const X509_Certificate& cert, PK_Verifier verifier(*key, format.first, format.second); - if(state->version == SSL_V3) + if(state->version == Protocol_Version::SSL_V3) { SecureVector<byte> md5_sha = state->hash.final_ssl3( state->keys.master_secret()); diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index f7f8a7eb8..80385bd5e 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -11,19 +11,17 @@ #include <botan/sha2_32.h> #include <memory> -#include <stdio.h> - namespace Botan { namespace TLS { namespace { -KDF* choose_tls_prf(Version_Code version) +KDF* choose_tls_prf(Protocol_Version version) { - if(version == TLS_V10 || version == TLS_V11) + if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11) return new TLS_PRF; - else if(version == TLS_V12) + else if(version == Protocol_Version::TLS_V12) return new TLS_12_PRF(new HMAC(new SHA_256)); // might depend on ciphersuite else throw TLS_Exception(PROTOCOL_VERSION, @@ -36,7 +34,7 @@ KDF* choose_tls_prf(Version_Code version) MemoryVector<byte> finished_compute_verify(Handshake_State* state, Connection_Side side) { - if(state->version == SSL_V3) + if(state->version == Protocol_Version::SSL_V3) { const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 }; const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 }; diff --git a/src/tls/info.txt b/src/tls/info.txt index 16d112df2..2774e9be8 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -18,6 +18,7 @@ tls_server.h tls_session.h tls_session_manager.h tls_suites.h +tls_version.h </header:public> <header:internal> @@ -52,6 +53,7 @@ tls_session.cpp tls_session_key.cpp tls_session_manager.cpp tls_suites.cpp +tls_version.cpp </source> <requires> diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp index 4db50262d..3fd2df33f 100644 --- a/src/tls/rec_read.cpp +++ b/src/tls/rec_read.cpp @@ -41,7 +41,7 @@ void Record_Reader::reset() m_block_size = 0; m_iv_size = 0; - m_major = m_minor = 0; + m_version = Protocol_Version(); m_seq_no = 0; set_maximum_fragment_size(0); } @@ -57,10 +57,9 @@ void Record_Reader::set_maximum_fragment_size(size_t max_fragment) /* * Set the version to use */ -void Record_Reader::set_version(Version_Code version) +void Record_Reader::set_version(Protocol_Version version) { - m_major = (version >> 8) & 0xFF; - m_minor = (version & 0xFF); + m_version = version; } /* @@ -102,7 +101,7 @@ void Record_Reader::activate(const Ciphersuite& suite, ); m_block_size = block_size_of(cipher_algo); - if(m_major > 3 || (m_major == 3 && m_minor >= 2)) + if(m_version >= Protocol_Version::TLS_V11) m_iv_size = m_block_size; else m_iv_size = 0; @@ -120,7 +119,7 @@ void Record_Reader::activate(const Ciphersuite& suite, { Algorithm_Factory& af = global_state().algorithm_factory(); - if(m_major == 3 && m_minor == 0) + if(m_version == Protocol_Version::SSL_V3) m_mac = af.make_mac("SSL3-MAC(" + mac_algo + ")"); else m_mac = af.make_mac("HMAC(" + mac_algo + ")"); @@ -220,12 +219,17 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, " from counterparty"); } - const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]); const size_t record_len = make_u16bit(m_readbuf[3], m_readbuf[4]); - if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor)) - throw TLS_Exception(PROTOCOL_VERSION, - "Got unexpected version from counterparty"); + if(m_version.major_version()) + { + if(m_readbuf[1] != m_version.major_version() || + m_readbuf[2] != m_version.minor_version()) + { + throw TLS_Exception(PROTOCOL_VERSION, + "Got unexpected version from counterparty"); + } + } if(record_len > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(RECORD_OVERFLOW, @@ -282,7 +286,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, * This particular countermeasure is recommended in the TLS 1.2 * spec (RFC 5246) in section 6.2.3.2 */ - if(version == SSL_V3) + if(m_version == Protocol_Version::SSL_V3) { if(pad_value > m_block_size) pad_size = 0; @@ -313,9 +317,11 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, m_mac->update_be(m_seq_no); m_mac->update(m_readbuf[0]); // msg_type - if(version != SSL_V3) - for(size_t i = 0; i != 2; ++i) - m_mac->update(get_byte(i, version)); + if(m_version != Protocol_Version::SSL_V3) + { + m_mac->update(m_version.major_version()); + m_mac->update(m_version.minor_version()); + } m_mac->update_be(plain_length); m_mac->update(&m_readbuf[TLS_HEADER_SIZE + m_iv_size], plain_length); diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp index 139d84c50..9e1e4637c 100644 --- a/src/tls/rec_wri.cpp +++ b/src/tls/rec_wri.cpp @@ -48,8 +48,7 @@ void Record_Writer::reset() delete m_mac; m_mac = 0; - m_major = 0; - m_minor = 0; + m_version = Protocol_Version(); m_block_size = 0; m_mac_size = 0; m_iv_size = 0; @@ -60,10 +59,9 @@ void Record_Writer::reset() /* * Set the version to use */ -void Record_Writer::set_version(Version_Code version) +void Record_Writer::set_version(Protocol_Version version) { - m_major = (version >> 8) & 0xFF; - m_minor = (version & 0xFF); + m_version = version; } /* @@ -112,7 +110,7 @@ void Record_Writer::activate(const Ciphersuite& suite, ); m_block_size = block_size_of(cipher_algo); - if(m_major > 3 || (m_major == 3 && m_minor >= 2)) + if(m_version >= Protocol_Version::TLS_V11) m_iv_size = m_block_size; else m_iv_size = 0; @@ -130,7 +128,7 @@ void Record_Writer::activate(const Ciphersuite& suite, { Algorithm_Factory& af = global_state().algorithm_factory(); - if(m_major == 3 && m_minor == 0) + if(m_version == Protocol_Version::SSL_V3) m_mac = af.make_mac("SSL3-MAC(" + mac_algo + ")"); else m_mac = af.make_mac("HMAC(" + mac_algo + ")"); @@ -191,8 +189,8 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length) { const byte header[TLS_HEADER_SIZE] = { type, - m_major, - m_minor, + m_version.major_version(), + m_version.minor_version(), get_byte<u16bit>(0, length), get_byte<u16bit>(1, length) }; @@ -205,10 +203,10 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length) m_mac->update_be(m_seq_no); m_mac->update(type); - if(m_major > 3 || (m_major == 3 && m_minor != 0)) + if(m_version != Protocol_Version::SSL_V3) { - m_mac->update(m_major); - m_mac->update(m_minor); + m_mac->update(m_version.major_version()); + m_mac->update(m_version.minor_version()); } m_mac->update(get_byte<u16bit>(0, length)); @@ -229,8 +227,8 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length) // TLS record header m_writebuf[0] = type; - m_writebuf[1] = m_major; - m_writebuf[2] = m_minor; + m_writebuf[1] = m_version.major_version(); + m_writebuf[2] = m_version.minor_version(); m_writebuf[3] = get_byte<u16bit>(0, buf_size); m_writebuf[4] = get_byte<u16bit>(1, buf_size); diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 9e61f62af..b027c6cc6 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -21,7 +21,7 @@ namespace TLS { */ Server_Hello::Server_Hello(Record_Writer& writer, Handshake_Hash& hash, - Version_Code version, + Protocol_Version version, const Client_Hello& c_hello, const std::vector<X509_Certificate>& certs, const Policy& policy, @@ -68,7 +68,7 @@ Server_Hello::Server_Hello(Record_Writer& writer, Server_Hello::Server_Hello(Record_Writer& writer, Handshake_Hash& hash, const MemoryRegion<byte>& session_id, - Version_Code ver, + Protocol_Version ver, u16bit ciphersuite, byte compression, size_t max_fragment_size, @@ -104,12 +104,15 @@ Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) TLS_Data_Reader reader(buf); - s_version = static_cast<Version_Code>(reader.get_u16bit()); + const byte major_version = reader.get_byte(); + const byte minor_version = reader.get_byte(); - if(s_version != SSL_V3 && - s_version != TLS_V10 && - s_version != TLS_V11 && - s_version != TLS_V12) + s_version = Protocol_Version(major_version, minor_version); + + if(s_version != Protocol_Version::SSL_V3 && + s_version != Protocol_Version::TLS_V10 && + s_version != Protocol_Version::TLS_V11 && + s_version != Protocol_Version::TLS_V12) { throw TLS_Exception(PROTOCOL_VERSION, "Server_Hello: Unsupported server version"); @@ -146,8 +149,8 @@ MemoryVector<byte> Server_Hello::serialize() const { MemoryVector<byte> buf; - buf.push_back(static_cast<byte>(s_version >> 8)); - buf.push_back(static_cast<byte>(s_version )); + buf.push_back(s_version.major_version()); + buf.push_back(s_version.minor_version()); buf += s_random; append_tls_length_value(buf, m_session_id, 1); diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 359ef6f4a..0a7ae9b14 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -87,7 +87,7 @@ MemoryVector<byte> Server_Key_Exchange::serialize_params() const Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf, const std::string& kex_algo, const std::string& sig_algo, - Version_Code version) + Protocol_Version version) { if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); @@ -109,7 +109,7 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf, if(sig_algo != "") { - if(version >= TLS_V12) + if(version >= Protocol_Version::TLS_V12) { m_hash_algo = Signature_Algorithms::hash_algo_name(reader.get_byte()); m_sig_algo = Signature_Algorithms::sig_algo_name(reader.get_byte()); diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index a3ff69d87..76a5424ad 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -17,8 +17,8 @@ namespace Botan { namespace TLS { Channel::Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const Session&)> handshake_complete) : + std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_complete) : proc_fn(proc_fn), handshake_fn(handshake_complete), writer(socket_output_fn), @@ -133,7 +133,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) * Split up and process handshake messages */ void Channel::read_handshake(byte rec_type, - const MemoryRegion<byte>& rec_buf) + const MemoryRegion<byte>& rec_buf) { if(rec_type == HANDSHAKE) { diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index d1b31f137..835e8d4bd 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -33,7 +33,7 @@ Client::Client(std::tr1::function<void (const byte[], size_t)> output_fn, session_manager(session_manager), creds(creds) { - writer.set_version(SSL_V3); + writer.set_version(Protocol_Version::SSL_V3); state = new Handshake_State; state->set_expected_next(SERVER_HELLO); @@ -296,7 +296,7 @@ void Client::process_handshake_msg(Handshake_Type type, std::vector<byte> types = state->cert_req->acceptable_types(); std::vector<X509_Certificate> client_certs = - creds.cert_chain("", // use types here + creds.cert_chain("", // FIXME use types here "tls-client", state->client_hello->sni_hostname()); diff --git a/src/tls/tls_handshake_hash.cpp b/src/tls/tls_handshake_hash.cpp index e521ea342..491b4f6c0 100644 --- a/src/tls/tls_handshake_hash.cpp +++ b/src/tls/tls_handshake_hash.cpp @@ -31,11 +31,11 @@ void Handshake_Hash::update(Handshake_Type handshake_type, /** * Return a TLS Handshake Hash */ -SecureVector<byte> Handshake_Hash::final(Version_Code version) +SecureVector<byte> Handshake_Hash::final(Protocol_Version version) { SecureVector<byte> output; - if(version == TLS_V10 || version == TLS_V11) + if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11) { MD5 md5; SHA_160 sha1; @@ -46,7 +46,7 @@ SecureVector<byte> Handshake_Hash::final(Version_Code version) output += md5.final(); output += sha1.final(); } - else if(version == TLS_V12) + else if(version == Protocol_Version::TLS_V12) { // This might depend on the ciphersuite SHA_256 sha256; diff --git a/src/tls/tls_handshake_hash.h b/src/tls/tls_handshake_hash.h index a6c2b44e1..20f3c51fc 100644 --- a/src/tls/tls_handshake_hash.h +++ b/src/tls/tls_handshake_hash.h @@ -9,6 +9,7 @@ #define BOTAN_TLS_HANDSHAKE_HASH_H__ #include <botan/secmem.h> +#include <botan/tls_version.h> #include <botan/tls_magic.h> namespace Botan { @@ -35,7 +36,7 @@ class Handshake_Hash void update(Handshake_Type handshake_type, const MemoryRegion<byte>& handshake_msg); - SecureVector<byte> final(Version_Code version); + SecureVector<byte> final(Protocol_Version version); SecureVector<byte> final_ssl3(const MemoryRegion<byte>& master_secret); const SecureVector<byte>& get_contents() const diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 5eb44414e..15017648c 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -93,7 +93,7 @@ Handshake_State::Handshake_State() kex_priv = 0; - version = SSL_V3; + version = Protocol_Version::SSL_V3; hand_expecting_mask = 0; hand_received_mask = 0; @@ -133,9 +133,9 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const std::pair<std::string, Signature_Format> Handshake_State::choose_sig_format(const Private_Key* key, - std::string& hash_algo_out, - std::string& sig_algo_out, - bool for_client_auth) + std::string& hash_algo_out, + std::string& sig_algo_out, + bool for_client_auth) { const std::string sig_algo = key->algo_name(); @@ -153,15 +153,15 @@ Handshake_State::choose_sig_format(const Private_Key* key, } } - if(for_client_auth && this->version == SSL_V3) + if(for_client_auth && this->version == Protocol_Version::SSL_V3) hash_algo = "Raw"; - if(hash_algo == "" && this->version == TLS_V12) + if(hash_algo == "" && this->version == Protocol_Version::TLS_V12) hash_algo = "SHA-1"; // TLS 1.2 but no compatible hashes set (?) BOTAN_ASSERT(hash_algo != "", "Couldn't figure out hash to use"); - if(this->version >= TLS_V12) + if(this->version >= Protocol_Version::TLS_V12) { hash_algo_out = hash_algo; sig_algo_out = sig_algo; @@ -185,9 +185,9 @@ Handshake_State::choose_sig_format(const Private_Key* key, std::pair<std::string, Signature_Format> Handshake_State::understand_sig_format(const Public_Key* key, - std::string hash_algo, - std::string sig_algo, - bool for_client_auth) + std::string hash_algo, + std::string sig_algo, + bool for_client_auth) { const std::string algo_name = key->algo_name(); @@ -199,7 +199,7 @@ Handshake_State::understand_sig_format(const Public_Key* key, Or not? */ - if(this->version < TLS_V12) + if(this->version < Protocol_Version::TLS_V12) { if(hash_algo != "" || sig_algo != "") throw Decoding_Error("Counterparty sent hash/sig IDs with old version"); @@ -215,11 +215,11 @@ Handshake_State::understand_sig_format(const Public_Key* key, if(algo_name == "RSA") { - if(for_client_auth && this->version == SSL_V3) + if(for_client_auth && this->version == Protocol_Version::SSL_V3) { hash_algo = "Raw"; } - else if(this->version < TLS_V12) + else if(this->version < Protocol_Version::TLS_V12) { hash_algo = "TLS.Digest.0"; } @@ -229,11 +229,11 @@ Handshake_State::understand_sig_format(const Public_Key* key, } else if(algo_name == "DSA") { - if(for_client_auth && this->version == SSL_V3) + if(for_client_auth && this->version == Protocol_Version::SSL_V3) { hash_algo = "Raw"; } - else if(this->version < TLS_V12) + else if(this->version < Protocol_Version::TLS_V12) { hash_algo = "SHA-1"; } diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 54e0da892..7339033c4 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -60,7 +60,7 @@ class Handshake_State std::string& sig_algo, bool for_client_auth); - Version_Code version; + Protocol_Version version; class Client_Hello* client_hello; class Server_Hello* server_hello; diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h index 09919c26f..ebca860de 100644 --- a/src/tls/tls_magic.h +++ b/src/tls/tls_magic.h @@ -24,14 +24,6 @@ enum Size_Limits { MAX_TLS_RECORD_SIZE = MAX_CIPHERTEXT_SIZE + TLS_HEADER_SIZE, }; -enum Version_Code { - NO_VERSION_SET = 0x0000, - SSL_V3 = 0x0300, - TLS_V10 = 0x0301, - TLS_V11 = 0x0302, - TLS_V12 = 0x0303 -}; - enum Connection_Side { CLIENT = 1, SERVER = 2 }; enum Record_Type { diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 89eb4af16..72d9a1c60 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -50,7 +50,7 @@ class Client_Hello : public Handshake_Message { public: Handshake_Type type() const { return CLIENT_HELLO; } - Version_Code version() const { return m_version; } + Protocol_Version version() const { return m_version; } const MemoryVector<byte>& session_id() const { return m_session_id; } std::vector<byte> session_id_vector() const @@ -106,7 +106,7 @@ class Client_Hello : public Handshake_Message void deserialize(const MemoryRegion<byte>& buf); void deserialize_sslv2(const MemoryRegion<byte>& buf); - Version_Code m_version; + Protocol_Version m_version; MemoryVector<byte> m_session_id, m_random; std::vector<u16bit> m_suites; std::vector<byte> m_comp_methods; @@ -128,7 +128,7 @@ class Server_Hello : public Handshake_Message { public: Handshake_Type type() const { return SERVER_HELLO; } - Version_Code version() { return s_version; } + Protocol_Version version() { return s_version; } const MemoryVector<byte>& session_id() const { return m_session_id; } u16bit ciphersuite() const { return suite; } byte compression_method() const { return comp_method; } @@ -156,7 +156,7 @@ class Server_Hello : public Handshake_Message Server_Hello(Record_Writer& writer, Handshake_Hash& hash, - Version_Code version, + Protocol_Version version, const Client_Hello& other, const std::vector<X509_Certificate>& certs, const Policy& policies, @@ -169,7 +169,7 @@ class Server_Hello : public Handshake_Message Server_Hello(Record_Writer& writer, Handshake_Hash& hash, const MemoryRegion<byte>& session_id, - Version_Code ver, + Protocol_Version ver, u16bit ciphersuite, byte compression, size_t max_fragment_size, @@ -183,7 +183,7 @@ class Server_Hello : public Handshake_Message private: MemoryVector<byte> serialize() const; - Version_Code s_version; + Protocol_Version s_version; MemoryVector<byte> m_session_id, s_random; u16bit suite; byte comp_method; @@ -209,7 +209,7 @@ class Client_Key_Exchange : public Handshake_Message SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng, const Private_Key* key, - Version_Code version); + Protocol_Version version); Client_Key_Exchange(Record_Writer& output, Handshake_State* state, @@ -218,7 +218,7 @@ class Client_Key_Exchange : public Handshake_Message Client_Key_Exchange(const MemoryRegion<byte>& buf, const Ciphersuite& suite, - Version_Code using_version); + Protocol_Version using_version); private: MemoryVector<byte> serialize() const; @@ -267,10 +267,10 @@ class Certificate_Req : public Handshake_Message Handshake_Hash& hash, const Policy& policy, const std::vector<X509_Certificate>& allowed_cas, - Version_Code version); + Protocol_Version version); Certificate_Req(const MemoryRegion<byte>& buf, - Version_Code version); + Protocol_Version version); private: MemoryVector<byte> serialize() const; @@ -302,7 +302,7 @@ class Certificate_Verify : public Handshake_Message const Private_Key* key); Certificate_Verify(const MemoryRegion<byte>& buf, - Version_Code version); + Protocol_Version version); private: MemoryVector<byte> serialize() const; @@ -372,7 +372,7 @@ class Server_Key_Exchange : public Handshake_Message Server_Key_Exchange(const MemoryRegion<byte>& buf, const std::string& kex_alg, const std::string& sig_alg, - Version_Code version); + Protocol_Version version); private: MemoryVector<byte> serialize() const; MemoryVector<byte> serialize_params() const; diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h index f8e608cdb..61de53dcd 100644 --- a/src/tls/tls_policy.h +++ b/src/tls/tls_policy.h @@ -8,7 +8,7 @@ #ifndef BOTAN_TLS_POLICY_H__ #define BOTAN_TLS_POLICY_H__ -#include <botan/tls_magic.h> +#include <botan/tls_version.h> #include <botan/x509cert.h> #include <botan/dl_group.h> #include <vector> @@ -60,7 +60,7 @@ class BOTAN_DLL Policy * renegotiation. * * @warning Changing this to false exposes you to injected - * plaintext attacks. + * plaintext attacks. Read the RFC for background. */ virtual bool require_secure_renegotiation() const { return true; } @@ -72,12 +72,14 @@ class BOTAN_DLL Policy /* * @return the minimum version that we will negotiate */ - virtual Version_Code min_version() const { return SSL_V3; } + virtual Protocol_Version min_version() const + { return Protocol_Version::SSL_V3; } /* * @return the version we would prefer to negotiate */ - virtual Version_Code pref_version() const { return TLS_V12; } + virtual Protocol_Version pref_version() const + { return Protocol_Version::TLS_V12; } virtual ~Policy() {} }; diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 979154001..991243af5 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -9,6 +9,7 @@ #define BOTAN_TLS_RECORDS_H__ #include <botan/tls_suites.h> +#include <botan/tls_version.h> #include <botan/pipe.h> #include <botan/mac.h> #include <botan/secqueue.h> @@ -49,7 +50,7 @@ class BOTAN_DLL Record_Writer const Session_Keys& keys, Connection_Side side); - void set_version(Version_Code version); + void set_version(Protocol_Version version); void reset(); @@ -74,7 +75,7 @@ class BOTAN_DLL Record_Writer size_t m_block_size, m_mac_size, m_iv_size, m_max_fragment; u64bit m_seq_no; - byte m_major, m_minor; + Protocol_Version m_version; }; /** @@ -103,7 +104,7 @@ class BOTAN_DLL Record_Reader const Session_Keys& keys, Connection_Side side); - void set_version(Version_Code version); + void set_version(Protocol_Version version); void reset(); @@ -129,7 +130,7 @@ class BOTAN_DLL Record_Reader MessageAuthenticationCode* m_mac; size_t m_block_size, m_iv_size, m_max_fragment; u64bit m_seq_no; - byte m_major, m_minor; + Protocol_Version m_version; }; } diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 6c6977b91..54873e682 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -67,13 +67,13 @@ bool check_for_resume(Session& session_info, * TLS Server Constructor */ Server::Server(std::tr1::function<void (const byte[], size_t)> output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const Session&)> handshake_fn, - Session_Manager& session_manager, - Credentials_Manager& creds, - const Policy& policy, - RandomNumberGenerator& rng, - const std::vector<std::string>& next_protocols) : + std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_fn, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const std::vector<std::string>& next_protocols) : Channel(output_fn, proc_fn, handshake_fn), policy(policy), rng(rng), @@ -112,7 +112,7 @@ void Server::alert_notify(bool, Alert_Type type) * Split up and process handshake messages */ void Server::read_handshake(byte rec_type, - const MemoryRegion<byte>& rec_buf) + const MemoryRegion<byte>& rec_buf) { if(rec_type == HANDSHAKE && !state) { @@ -127,7 +127,7 @@ void Server::read_handshake(byte rec_type, * Process a handshake message */ void Server::process_handshake_msg(Handshake_Type type, - const MemoryRegion<byte>& contents) + const MemoryRegion<byte>& contents) { if(state == 0) throw Unexpected_Message("Unexpected handshake message from client"); @@ -155,7 +155,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_hostname = state->client_hello->sni_hostname(); - Version_Code client_version = state->client_hello->version(); + Protocol_Version client_version = state->client_hello->version(); if(client_version < policy.min_version()) throw TLS_Exception(PROTOCOL_VERSION, @@ -184,7 +184,7 @@ void Server::process_handshake_msg(Handshake_Type type, writer, state->hash, session_info.session_id(), - Version_Code(session_info.version()), + Protocol_Version(session_info.version()), session_info.ciphersuite(), session_info.compression_method(), session_info.fragment_size(), diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp index 3716878e1..d9ccd6df4 100644 --- a/src/tls/tls_session.cpp +++ b/src/tls/tls_session.cpp @@ -17,7 +17,7 @@ namespace TLS { Session::Session(const MemoryRegion<byte>& session_identifier, const MemoryRegion<byte>& master_secret, - Version_Code version, + Protocol_Version version, u16bit ciphersuite, byte compression_method, Connection_Side side, @@ -51,12 +51,15 @@ Session::Session(const byte ber[], size_t ber_len) ASN1_String sni_hostname_str; ASN1_String srp_identifier_str; + byte major_version = 0, minor_version = 0; + BER_Decoder(ber, ber_len) .decode_and_check(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION), "Unknown version in session structure") .decode(m_identifier, OCTET_STRING) .decode_integer_type(m_start_time) - .decode_integer_type(m_version) + .decode_integer_type(major_version) + .decode_integer_type(minor_version) .decode_integer_type(m_ciphersuite) .decode_integer_type(m_compression_method) .decode_integer_type(side_code) @@ -67,6 +70,7 @@ Session::Session(const byte ber[], size_t ber_len) .decode(sni_hostname_str) .decode(srp_identifier_str); + m_version = Protocol_Version(major_version, minor_version); m_sni_hostname = sni_hostname_str.value(); m_srp_identifier = srp_identifier_str.value(); m_connection_side = static_cast<Connection_Side>(side_code); @@ -79,7 +83,8 @@ SecureVector<byte> Session::BER_encode() const .encode(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION)) .encode(m_identifier, OCTET_STRING) .encode(static_cast<size_t>(m_start_time)) - .encode(static_cast<size_t>(m_version)) + .encode(static_cast<size_t>(m_version.major_version())) + .encode(static_cast<size_t>(m_version.minor_version())) .encode(static_cast<size_t>(m_ciphersuite)) .encode(static_cast<size_t>(m_compression_method)) .encode(static_cast<size_t>(m_connection_side)) diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h index 9b3f5b194..e44967c00 100644 --- a/src/tls/tls_session.h +++ b/src/tls/tls_session.h @@ -9,6 +9,7 @@ #define TLS_SESSION_STATE_H__ #include <botan/x509cert.h> +#include <botan/tls_version.h> #include <botan/tls_magic.h> #include <botan/secmem.h> @@ -28,7 +29,7 @@ class BOTAN_DLL Session */ Session() : m_start_time(0), - m_version(0), + m_version(), m_ciphersuite(0), m_compression_method(0), m_connection_side(static_cast<Connection_Side>(0)), @@ -40,16 +41,16 @@ class BOTAN_DLL Session * New session (sets session start time) */ Session(const MemoryRegion<byte>& session_id, - const MemoryRegion<byte>& master_secret, - Version_Code version, - u16bit ciphersuite, - byte compression_method, - Connection_Side side, - bool secure_renegotiation_supported, - size_t fragment_size, - const std::vector<X509_Certificate>& peer_certs, - const std::string& sni_hostname = "", - const std::string& srp_identifier = ""); + const MemoryRegion<byte>& master_secret, + Protocol_Version version, + u16bit ciphersuite, + byte compression_method, + Connection_Side side, + bool secure_renegotiation_supported, + size_t fragment_size, + const std::vector<X509_Certificate>& peer_certs, + const std::string& sni_hostname = "", + const std::string& srp_identifier = ""); /** * Load a session from BER (created by BER_encode) @@ -66,18 +67,7 @@ class BOTAN_DLL Session /** * Get the version of the saved session */ - Version_Code version() const - { return static_cast<Version_Code>(m_version); } - - /** - * Get the major version of the saved session - */ - byte major_version() const { return get_byte(0, m_version); } - - /** - * Get the minor version of the saved session - */ - byte minor_version() const { return get_byte(1, m_version); } + Protocol_Version version() const { return m_version; } /** * Get the ciphersuite of the saved session @@ -141,7 +131,7 @@ class BOTAN_DLL Session MemoryVector<byte> m_identifier; SecureVector<byte> m_master_secret; - u16bit m_version; + Protocol_Version m_version; u16bit m_ciphersuite; byte m_compression_method; Connection_Side m_connection_side; diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp index 42727273a..541f0b2d9 100644 --- a/src/tls/tls_session_key.cpp +++ b/src/tls/tls_session_key.cpp @@ -17,13 +17,13 @@ namespace TLS { namespace { -std::string lookup_prf_name(Version_Code version) +std::string lookup_prf_name(Protocol_Version version) { - if(version == SSL_V3) + if(version == Protocol_Version::SSL_V3) return "SSL3-PRF"; - else if(version == TLS_V10 || version == TLS_V11) + else if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11) return "TLS-PRF"; - else if(version == TLS_V12) + else if(version == Protocol_Version::TLS_V12) return "TLS-12-PRF(SHA-256)"; else throw Invalid_Argument("Session_Keys: Unknown version code"); @@ -65,7 +65,7 @@ Session_Keys::Session_Keys(Handshake_State* state, { SecureVector<byte> salt; - if(state->version != SSL_V3) + if(state->version != Protocol_Version::SSL_V3) salt += std::make_pair(MASTER_SECRET_MAGIC, sizeof(MASTER_SECRET_MAGIC)); salt += state->client_hello->random(); @@ -75,7 +75,7 @@ Session_Keys::Session_Keys(Handshake_State* state, } SecureVector<byte> salt; - if(state->version != SSL_V3) + if(state->version != Protocol_Version::SSL_V3) salt += std::make_pair(KEY_GEN_MAGIC, sizeof(KEY_GEN_MAGIC)); salt += state->server_hello->random(); salt += state->client_hello->random(); diff --git a/src/tls/tls_version.cpp b/src/tls/tls_version.cpp new file mode 100644 index 000000000..4445998eb --- /dev/null +++ b/src/tls/tls_version.cpp @@ -0,0 +1,33 @@ +/* +* TLS Protocol Version Management +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_version.h> +#include <botan/parsing.h> + +namespace Botan { + +namespace TLS { + +std::string Protocol_Version::to_string() const + { + const byte maj = major_version(); + const byte min = minor_version(); + + // Some very new or very old protocol? + if(maj != 3) + return "Protocol " + Botan::to_string(maj) + "." + Botan::to_string(min); + + if(maj == 3 && min == 0) + return "SSL v3"; + + // The TLS v1.[0123...] case + return "TLS v1." + Botan::to_string(min-1); + } + +} + +} diff --git a/src/tls/tls_version.h b/src/tls/tls_version.h new file mode 100644 index 000000000..e4e6b49a2 --- /dev/null +++ b/src/tls/tls_version.h @@ -0,0 +1,100 @@ +/* +* TLS Protocol Version Management +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_PROTOCOL_VERSION_H__ +#define BOTAN_TLS_PROTOCOL_VERSION_H__ + +#include <botan/get_byte.h> +#include <botan/parsing.h> + +namespace Botan { + +namespace TLS { + +class BOTAN_DLL Protocol_Version + { + public: + enum Version_Code { + SSL_V3 = 0x0300, + TLS_V10 = 0x0301, + TLS_V11 = 0x0302, + TLS_V12 = 0x0303 + }; + + Protocol_Version() : m_major(0), m_minor(0) {} + + Protocol_Version(Version_Code named_version) : + m_major(get_byte<u16bit>(0, named_version)), + m_minor(get_byte<u16bit>(1, named_version)) {} + + Protocol_Version(byte major, byte minor) : m_major(major), m_minor(minor) {} + + /** + * Get the major version of the protocol version + */ + byte major_version() const { return m_major; } + + /** + * Get the minor version of the protocol version + */ + byte minor_version() const { return m_minor; } + + bool operator==(const Protocol_Version& other) const + { + return (cmp(other) == 0); + } + + bool operator!=(const Protocol_Version& other) const + { + return (cmp(other) != 0); + } + + bool operator>=(const Protocol_Version& other) const + { + return (cmp(other) >= 0); + } + + bool operator>(const Protocol_Version& other) const + { + return (cmp(other) > 0); + } + + bool operator<=(const Protocol_Version& other) const + { + return (cmp(other) <= 0); + } + + bool operator<(const Protocol_Version& other) const + { + return (cmp(other) < 0); + } + + std::string to_string() const; + + private: + s32bit cmp(const Protocol_Version& other) const + { + if(major_version() < other.major_version()) + return -1; + if(major_version() > other.major_version()) + return 1; + if(minor_version() < other.minor_version()) + return -1; + if(minor_version() > other.minor_version()) + return 1; + return 0; + } + + byte m_major, m_minor; + }; + +} + +} + +#endif + |