diff options
-rw-r--r-- | src/tls/c_hello.cpp | 14 | ||||
-rw-r--r-- | src/tls/c_kex.cpp | 22 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 90 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 18 | ||||
-rw-r--r-- | src/tls/finished.cpp | 2 | ||||
-rw-r--r-- | src/tls/next_protocol.cpp | 18 | ||||
-rw-r--r-- | src/tls/s_hello.cpp | 84 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 30 | ||||
-rw-r--r-- | src/tls/tls_reader.h | 5 |
10 files changed, 132 insertions, 153 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index 6c4964fb1..60f6de487 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -56,20 +56,20 @@ Hello_Request::Hello_Request(Record_Writer& writer) } /* -* Serialize a Hello Request message +* Deserialize a Hello Request message */ -MemoryVector<byte> Hello_Request::serialize() const +Hello_Request::Hello_Request(const MemoryRegion<byte>& buf) { - return MemoryVector<byte>(); + if(buf.size()) + throw Decoding_Error("Hello_Request: Must be empty, and is not"); } /* -* Deserialize a Hello Request message +* Serialize a Hello Request message */ -void Hello_Request::deserialize(const MemoryRegion<byte>& buf) +MemoryVector<byte> Hello_Request::serialize() const { - if(buf.size()) - throw Decoding_Error("Hello_Request: Must be empty, and is not"); + return MemoryVector<byte>(); } /* diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index f95f74931..3d79116ca 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -90,7 +90,13 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, if(using_version == SSL_V3 && (suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX)) include_length = false; - deserialize(contents); + if(include_length) + { + TLS_Data_Reader reader(contents); + key_material = reader.get_range<byte>(2, 0, 65535); + } + else + key_material = contents; } /* @@ -109,20 +115,6 @@ MemoryVector<byte> Client_Key_Exchange::serialize() const } /* -* Deserialize a Client Key Exchange message -*/ -void Client_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) - { - if(include_length) - { - TLS_Data_Reader reader(buf); - key_material = reader.get_range<byte>(2, 0, 65535); - } - else - key_material = buf; - } - -/* * Return the pre_master_secret */ SecureVector<byte> diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index 0168e4b7d..bdb25057c 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -37,29 +37,9 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, } /** -* Serialize a Certificate Request message -*/ -MemoryVector<byte> Certificate_Req::serialize() const - { - MemoryVector<byte> buf; - - append_tls_length_value(buf, types, 1); - - for(size_t i = 0; i != names.size(); ++i) - { - DER_Encoder encoder; - encoder.encode(names[i]); - - append_tls_length_value(buf, encoder.get_contents(), 2); - } - - return buf; - } - -/** * Deserialize a Certificate Request message */ -void Certificate_Req::deserialize(const MemoryRegion<byte>& buf) +Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf) { if(buf.size() < 4) throw Decoding_Error("Certificate_Req: Bad certificate request"); @@ -96,43 +76,40 @@ void Certificate_Req::deserialize(const MemoryRegion<byte>& buf) } /** -* Create a new Certificate message +* Serialize a Certificate Request message */ -Certificate::Certificate(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const std::vector<X509_Certificate>& cert_list) +MemoryVector<byte> Certificate_Req::serialize() const { - certs = cert_list; - send(writer, hash); - } + MemoryVector<byte> buf; -/** -* Serialize a Certificate message -*/ -MemoryVector<byte> Certificate::serialize() const - { - MemoryVector<byte> buf(3); + append_tls_length_value(buf, types, 1); - for(size_t i = 0; i != certs.size(); ++i) + for(size_t i = 0; i != names.size(); ++i) { - MemoryVector<byte> raw_cert = certs[i].BER_encode(); - const size_t cert_size = raw_cert.size(); - for(size_t i = 0; i != 3; ++i) - buf.push_back(get_byte<u32bit>(i+1, cert_size)); - buf += raw_cert; - } + DER_Encoder encoder; + encoder.encode(names[i]); - const size_t buf_size = buf.size() - 3; - for(size_t i = 0; i != 3; ++i) - buf[i] = get_byte<u32bit>(i+1, buf_size); + append_tls_length_value(buf, encoder.get_contents(), 2); + } return buf; } /** +* Create a new Certificate message +*/ +Certificate::Certificate(Record_Writer& writer, + TLS_Handshake_Hash& hash, + const std::vector<X509_Certificate>& cert_list) + { + certs = cert_list; + send(writer, hash); + } + +/** * Deserialize a Certificate message */ -void Certificate::deserialize(const MemoryRegion<byte>& buf) +Certificate::Certificate(const MemoryRegion<byte>& buf) { if(buf.size() < 3) throw Decoding_Error("Certificate: Message malformed"); @@ -163,4 +140,27 @@ void Certificate::deserialize(const MemoryRegion<byte>& buf) } } +/** +* Serialize a Certificate message +*/ +MemoryVector<byte> Certificate::serialize() const + { + MemoryVector<byte> buf(3); + + for(size_t i = 0; i != certs.size(); ++i) + { + MemoryVector<byte> raw_cert = certs[i].BER_encode(); + const size_t cert_size = raw_cert.size(); + for(size_t i = 0; i != 3; ++i) + buf.push_back(get_byte<u32bit>(i+1, cert_size)); + buf += raw_cert; + } + + const size_t buf_size = buf.size() - 3; + for(size_t i = 0; i != 3; ++i) + buf[i] = get_byte<u32bit>(i+1, buf_size); + + return buf; + } + } diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 81d529e88..77d9fe74b 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -54,6 +54,15 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, } /* +* Deserialize a Certificate Verify message +*/ +Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf) + { + TLS_Data_Reader reader(buf); + signature = reader.get_range<byte>(2, 0, 65535); + } + +/* * Serialize a Certificate Verify message */ MemoryVector<byte> Certificate_Verify::serialize() const @@ -69,15 +78,6 @@ MemoryVector<byte> Certificate_Verify::serialize() const } /* -* Deserialize a Certificate Verify message -*/ -void Certificate_Verify::deserialize(const MemoryRegion<byte>& buf) - { - TLS_Data_Reader reader(buf); - signature = reader.get_range<byte>(2, 0, 65535); - } - -/* * Verify a Certificate Verify message */ bool Certificate_Verify::verify(const X509_Certificate& cert, diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index 836512f81..baa663798 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -81,7 +81,7 @@ MemoryVector<byte> Finished::serialize() const /* * Deserialize a Finished message */ -void Finished::deserialize(const MemoryRegion<byte>& buf) +Finished::Finished(const MemoryRegion<byte>& buf) { verification_data = buf; } diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index 2d2e2e599..a0d4278f1 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -19,6 +19,15 @@ Next_Protocol::Next_Protocol(Record_Writer& writer, send(writer, hash); } +Next_Protocol::Next_Protocol(const MemoryRegion<byte>& buf) + { + TLS_Data_Reader reader(buf); + + m_protocol = reader.get_string(1, 0, 255); + + reader.get_range_vector<byte>(1, 0, 255); // padding, ignored + } + MemoryVector<byte> Next_Protocol::serialize() const { MemoryVector<byte> buf; @@ -38,13 +47,4 @@ MemoryVector<byte> Next_Protocol::serialize() const return buf; } -void Next_Protocol::deserialize(const MemoryRegion<byte>& buf) - { - TLS_Data_Reader reader(buf); - - m_protocol = reader.get_string(1, 0, 255); - - reader.get_range_vector<byte>(1, 0, 255); // padding, ignored - } - } diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 90e18ae90..fa185599d 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -90,43 +90,9 @@ Server_Hello::Server_Hello(Record_Writer& writer, } /* -* Serialize a Server Hello message -*/ -MemoryVector<byte> Server_Hello::serialize() const - { - MemoryVector<byte> buf; - - buf.push_back(static_cast<byte>(s_version >> 8)); - buf.push_back(static_cast<byte>(s_version )); - buf += s_random; - - append_tls_length_value(buf, m_session_id, 1); - - buf.push_back(get_byte(0, suite)); - buf.push_back(get_byte(1, suite)); - - buf.push_back(comp_method); - - TLS_Extensions extensions; - - if(m_secure_renegotiation) - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); - - if(m_fragment_size != 0) - extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); - - if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); - - buf += extensions.serialize(); - - return buf; - } - -/* * Deserialize a Server Hello message */ -void Server_Hello::deserialize(const MemoryRegion<byte>& buf) +Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) { m_secure_renegotiation = false; m_next_protocol = false; @@ -173,6 +139,40 @@ void Server_Hello::deserialize(const MemoryRegion<byte>& buf) } /* +* Serialize a Server Hello message +*/ +MemoryVector<byte> Server_Hello::serialize() const + { + MemoryVector<byte> buf; + + buf.push_back(static_cast<byte>(s_version >> 8)); + buf.push_back(static_cast<byte>(s_version )); + buf += s_random; + + append_tls_length_value(buf, m_session_id, 1); + + buf.push_back(get_byte(0, suite)); + buf.push_back(get_byte(1, suite)); + + buf.push_back(comp_method); + + TLS_Extensions extensions; + + if(m_secure_renegotiation) + extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); + + if(m_fragment_size != 0) + extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); + + if(m_next_protocol) + extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); + + buf += extensions.serialize(); + + return buf; + } + +/* * Create a new Server Hello Done message */ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, @@ -182,20 +182,20 @@ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, } /* -* Serialize a Server Hello Done message +* Deserialize a Server Hello Done message */ -MemoryVector<byte> Server_Hello_Done::serialize() const +Server_Hello_Done::Server_Hello_Done(const MemoryRegion<byte>& buf) { - return MemoryVector<byte>(); + if(buf.size()) + throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); } /* -* Deserialize a Server Hello Done message +* Serialize a Server Hello Done message */ -void Server_Hello_Done::deserialize(const MemoryRegion<byte>& buf) +MemoryVector<byte> Server_Hello_Done::serialize() const { - if(buf.size()) - throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); + return MemoryVector<byte>(); } } diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 69531c7c4..7008c89de 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -73,7 +73,7 @@ MemoryVector<byte> Server_Key_Exchange::serialize_params() const /** * Deserialize a Server Key Exchange message */ -void Server_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) +Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf) { if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 0b43545dc..d3735972e 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -37,7 +37,6 @@ class Handshake_Message private: Handshake_Message& operator=(const Handshake_Message&) { return (*this); } virtual MemoryVector<byte> serialize() const = 0; - virtual void deserialize(const MemoryRegion<byte>&) = 0; }; MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng); @@ -179,10 +178,9 @@ class Server_Hello : public Handshake_Message const std::vector<std::string>& next_protocols, RandomNumberGenerator& rng); - Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); Version_Code s_version; MemoryVector<byte> m_session_id, s_random; @@ -224,7 +222,6 @@ class Client_Key_Exchange : public Handshake_Message Version_Code using_version); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); SecureVector<byte> key_material, pre_master; bool include_length; @@ -246,10 +243,10 @@ class Certificate : public Handshake_Message TLS_Handshake_Hash& hash, const std::vector<X509_Certificate>& certs); - Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); + std::vector<X509_Certificate> certs; }; @@ -270,10 +267,9 @@ class Certificate_Req : public Handshake_Message const std::vector<Certificate_Type>& types = std::vector<Certificate_Type>()); - Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Req(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::vector<X509_DN> names; std::vector<Certificate_Type> types; @@ -300,10 +296,9 @@ class Certificate_Verify : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* key); - Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Verify(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); MemoryVector<byte> signature; }; @@ -326,10 +321,9 @@ class Finished : public Handshake_Message TLS_Handshake_State* state, Connection_Side side); - Finished(const MemoryRegion<byte>& buf) { deserialize(buf); } + Finished(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); Connection_Side side; MemoryVector<byte> verification_data; @@ -344,10 +338,9 @@ class Hello_Request : public Handshake_Message Handshake_Type type() const { return HELLO_REQUEST; } Hello_Request(Record_Writer& writer); - Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); } + Hello_Request(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -367,11 +360,10 @@ class Server_Key_Exchange : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* priv_key); - Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Key_Exchange(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; MemoryVector<byte> serialize_params() const; - void deserialize(const MemoryRegion<byte>&); std::vector<BigInt> params; MemoryVector<byte> signature; @@ -386,10 +378,9 @@ class Server_Hello_Done : public Handshake_Message Handshake_Type type() const { return SERVER_HELLO_DONE; } Server_Hello_Done(Record_Writer& writer, TLS_Handshake_Hash& hash); - Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello_Done(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -406,10 +397,9 @@ class Next_Protocol : public Handshake_Message TLS_Handshake_Hash& hash, const std::string& protocol); - Next_Protocol(const MemoryRegion<byte>& buf) { deserialize(buf); } + Next_Protocol(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::string m_protocol; }; diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h index ef36912d3..6a0bcd5b1 100644 --- a/src/tls/tls_reader.h +++ b/src/tls/tls_reader.h @@ -26,13 +26,10 @@ class TLS_Data_Reader TLS_Data_Reader(const MemoryRegion<byte>& buf_in) : buf(buf_in), offset(0) {} - ~TLS_Data_Reader() + void assert_done() const { if(has_remaining()) - { - abort(); throw Decoding_Error("Extra bytes at end of message"); - } } size_t remaining_bytes() const |