aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-01-19 17:07:03 +0000
committerlloyd <[email protected]>2012-01-19 17:07:03 +0000
commit239241568d4d3ff14d2d1994e5829f3d548f2078 (patch)
treea21fe21d7c229f00ae06859dbe1768ead68e6d13
parent30104a60568b392886c1d717a7ca006378552e4d (diff)
Remove Handshake_Message::deserialize which was an unnecessary hook.
Instead deserialize directly in the constructors that are passed the raw message data. This makes it easier to pass contextual information needed for decoding (eg, version numbers) where necessary.
-rw-r--r--src/tls/c_hello.cpp14
-rw-r--r--src/tls/c_kex.cpp22
-rw-r--r--src/tls/cert_req.cpp90
-rw-r--r--src/tls/cert_ver.cpp18
-rw-r--r--src/tls/finished.cpp2
-rw-r--r--src/tls/next_protocol.cpp18
-rw-r--r--src/tls/s_hello.cpp84
-rw-r--r--src/tls/s_kex.cpp2
-rw-r--r--src/tls/tls_messages.h30
-rw-r--r--src/tls/tls_reader.h5
10 files changed, 132 insertions, 153 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp
index 6c4964fb1..60f6de487 100644
--- a/src/tls/c_hello.cpp
+++ b/src/tls/c_hello.cpp
@@ -56,20 +56,20 @@ Hello_Request::Hello_Request(Record_Writer& writer)
}
/*
-* Serialize a Hello Request message
+* Deserialize a Hello Request message
*/
-MemoryVector<byte> Hello_Request::serialize() const
+Hello_Request::Hello_Request(const MemoryRegion<byte>& buf)
{
- return MemoryVector<byte>();
+ if(buf.size())
+ throw Decoding_Error("Hello_Request: Must be empty, and is not");
}
/*
-* Deserialize a Hello Request message
+* Serialize a Hello Request message
*/
-void Hello_Request::deserialize(const MemoryRegion<byte>& buf)
+MemoryVector<byte> Hello_Request::serialize() const
{
- if(buf.size())
- throw Decoding_Error("Hello_Request: Must be empty, and is not");
+ return MemoryVector<byte>();
}
/*
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index f95f74931..3d79116ca 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -90,7 +90,13 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
if(using_version == SSL_V3 && (suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX))
include_length = false;
- deserialize(contents);
+ if(include_length)
+ {
+ TLS_Data_Reader reader(contents);
+ key_material = reader.get_range<byte>(2, 0, 65535);
+ }
+ else
+ key_material = contents;
}
/*
@@ -109,20 +115,6 @@ MemoryVector<byte> Client_Key_Exchange::serialize() const
}
/*
-* Deserialize a Client Key Exchange message
-*/
-void Client_Key_Exchange::deserialize(const MemoryRegion<byte>& buf)
- {
- if(include_length)
- {
- TLS_Data_Reader reader(buf);
- key_material = reader.get_range<byte>(2, 0, 65535);
- }
- else
- key_material = buf;
- }
-
-/*
* Return the pre_master_secret
*/
SecureVector<byte>
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index 0168e4b7d..bdb25057c 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -37,29 +37,9 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
}
/**
-* Serialize a Certificate Request message
-*/
-MemoryVector<byte> Certificate_Req::serialize() const
- {
- MemoryVector<byte> buf;
-
- append_tls_length_value(buf, types, 1);
-
- for(size_t i = 0; i != names.size(); ++i)
- {
- DER_Encoder encoder;
- encoder.encode(names[i]);
-
- append_tls_length_value(buf, encoder.get_contents(), 2);
- }
-
- return buf;
- }
-
-/**
* Deserialize a Certificate Request message
*/
-void Certificate_Req::deserialize(const MemoryRegion<byte>& buf)
+Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf)
{
if(buf.size() < 4)
throw Decoding_Error("Certificate_Req: Bad certificate request");
@@ -96,43 +76,40 @@ void Certificate_Req::deserialize(const MemoryRegion<byte>& buf)
}
/**
-* Create a new Certificate message
+* Serialize a Certificate Request message
*/
-Certificate::Certificate(Record_Writer& writer,
- TLS_Handshake_Hash& hash,
- const std::vector<X509_Certificate>& cert_list)
+MemoryVector<byte> Certificate_Req::serialize() const
{
- certs = cert_list;
- send(writer, hash);
- }
+ MemoryVector<byte> buf;
-/**
-* Serialize a Certificate message
-*/
-MemoryVector<byte> Certificate::serialize() const
- {
- MemoryVector<byte> buf(3);
+ append_tls_length_value(buf, types, 1);
- for(size_t i = 0; i != certs.size(); ++i)
+ for(size_t i = 0; i != names.size(); ++i)
{
- MemoryVector<byte> raw_cert = certs[i].BER_encode();
- const size_t cert_size = raw_cert.size();
- for(size_t i = 0; i != 3; ++i)
- buf.push_back(get_byte<u32bit>(i+1, cert_size));
- buf += raw_cert;
- }
+ DER_Encoder encoder;
+ encoder.encode(names[i]);
- const size_t buf_size = buf.size() - 3;
- for(size_t i = 0; i != 3; ++i)
- buf[i] = get_byte<u32bit>(i+1, buf_size);
+ append_tls_length_value(buf, encoder.get_contents(), 2);
+ }
return buf;
}
/**
+* Create a new Certificate message
+*/
+Certificate::Certificate(Record_Writer& writer,
+ TLS_Handshake_Hash& hash,
+ const std::vector<X509_Certificate>& cert_list)
+ {
+ certs = cert_list;
+ send(writer, hash);
+ }
+
+/**
* Deserialize a Certificate message
*/
-void Certificate::deserialize(const MemoryRegion<byte>& buf)
+Certificate::Certificate(const MemoryRegion<byte>& buf)
{
if(buf.size() < 3)
throw Decoding_Error("Certificate: Message malformed");
@@ -163,4 +140,27 @@ void Certificate::deserialize(const MemoryRegion<byte>& buf)
}
}
+/**
+* Serialize a Certificate message
+*/
+MemoryVector<byte> Certificate::serialize() const
+ {
+ MemoryVector<byte> buf(3);
+
+ for(size_t i = 0; i != certs.size(); ++i)
+ {
+ MemoryVector<byte> raw_cert = certs[i].BER_encode();
+ const size_t cert_size = raw_cert.size();
+ for(size_t i = 0; i != 3; ++i)
+ buf.push_back(get_byte<u32bit>(i+1, cert_size));
+ buf += raw_cert;
+ }
+
+ const size_t buf_size = buf.size() - 3;
+ for(size_t i = 0; i != 3; ++i)
+ buf[i] = get_byte<u32bit>(i+1, buf_size);
+
+ return buf;
+ }
+
}
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index 81d529e88..77d9fe74b 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -54,6 +54,15 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
}
/*
+* Deserialize a Certificate Verify message
+*/
+Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf)
+ {
+ TLS_Data_Reader reader(buf);
+ signature = reader.get_range<byte>(2, 0, 65535);
+ }
+
+/*
* Serialize a Certificate Verify message
*/
MemoryVector<byte> Certificate_Verify::serialize() const
@@ -69,15 +78,6 @@ MemoryVector<byte> Certificate_Verify::serialize() const
}
/*
-* Deserialize a Certificate Verify message
-*/
-void Certificate_Verify::deserialize(const MemoryRegion<byte>& buf)
- {
- TLS_Data_Reader reader(buf);
- signature = reader.get_range<byte>(2, 0, 65535);
- }
-
-/*
* Verify a Certificate Verify message
*/
bool Certificate_Verify::verify(const X509_Certificate& cert,
diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp
index 836512f81..baa663798 100644
--- a/src/tls/finished.cpp
+++ b/src/tls/finished.cpp
@@ -81,7 +81,7 @@ MemoryVector<byte> Finished::serialize() const
/*
* Deserialize a Finished message
*/
-void Finished::deserialize(const MemoryRegion<byte>& buf)
+Finished::Finished(const MemoryRegion<byte>& buf)
{
verification_data = buf;
}
diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp
index 2d2e2e599..a0d4278f1 100644
--- a/src/tls/next_protocol.cpp
+++ b/src/tls/next_protocol.cpp
@@ -19,6 +19,15 @@ Next_Protocol::Next_Protocol(Record_Writer& writer,
send(writer, hash);
}
+Next_Protocol::Next_Protocol(const MemoryRegion<byte>& buf)
+ {
+ TLS_Data_Reader reader(buf);
+
+ m_protocol = reader.get_string(1, 0, 255);
+
+ reader.get_range_vector<byte>(1, 0, 255); // padding, ignored
+ }
+
MemoryVector<byte> Next_Protocol::serialize() const
{
MemoryVector<byte> buf;
@@ -38,13 +47,4 @@ MemoryVector<byte> Next_Protocol::serialize() const
return buf;
}
-void Next_Protocol::deserialize(const MemoryRegion<byte>& buf)
- {
- TLS_Data_Reader reader(buf);
-
- m_protocol = reader.get_string(1, 0, 255);
-
- reader.get_range_vector<byte>(1, 0, 255); // padding, ignored
- }
-
}
diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp
index 90e18ae90..fa185599d 100644
--- a/src/tls/s_hello.cpp
+++ b/src/tls/s_hello.cpp
@@ -90,43 +90,9 @@ Server_Hello::Server_Hello(Record_Writer& writer,
}
/*
-* Serialize a Server Hello message
-*/
-MemoryVector<byte> Server_Hello::serialize() const
- {
- MemoryVector<byte> buf;
-
- buf.push_back(static_cast<byte>(s_version >> 8));
- buf.push_back(static_cast<byte>(s_version ));
- buf += s_random;
-
- append_tls_length_value(buf, m_session_id, 1);
-
- buf.push_back(get_byte(0, suite));
- buf.push_back(get_byte(1, suite));
-
- buf.push_back(comp_method);
-
- TLS_Extensions extensions;
-
- if(m_secure_renegotiation)
- extensions.push_back(new Renegotation_Extension(m_renegotiation_info));
-
- if(m_fragment_size != 0)
- extensions.push_back(new Maximum_Fragment_Length(m_fragment_size));
-
- if(m_next_protocol)
- extensions.push_back(new Next_Protocol_Notification(m_next_protocols));
-
- buf += extensions.serialize();
-
- return buf;
- }
-
-/*
* Deserialize a Server Hello message
*/
-void Server_Hello::deserialize(const MemoryRegion<byte>& buf)
+Server_Hello::Server_Hello(const MemoryRegion<byte>& buf)
{
m_secure_renegotiation = false;
m_next_protocol = false;
@@ -173,6 +139,40 @@ void Server_Hello::deserialize(const MemoryRegion<byte>& buf)
}
/*
+* Serialize a Server Hello message
+*/
+MemoryVector<byte> Server_Hello::serialize() const
+ {
+ MemoryVector<byte> buf;
+
+ buf.push_back(static_cast<byte>(s_version >> 8));
+ buf.push_back(static_cast<byte>(s_version ));
+ buf += s_random;
+
+ append_tls_length_value(buf, m_session_id, 1);
+
+ buf.push_back(get_byte(0, suite));
+ buf.push_back(get_byte(1, suite));
+
+ buf.push_back(comp_method);
+
+ TLS_Extensions extensions;
+
+ if(m_secure_renegotiation)
+ extensions.push_back(new Renegotation_Extension(m_renegotiation_info));
+
+ if(m_fragment_size != 0)
+ extensions.push_back(new Maximum_Fragment_Length(m_fragment_size));
+
+ if(m_next_protocol)
+ extensions.push_back(new Next_Protocol_Notification(m_next_protocols));
+
+ buf += extensions.serialize();
+
+ return buf;
+ }
+
+/*
* Create a new Server Hello Done message
*/
Server_Hello_Done::Server_Hello_Done(Record_Writer& writer,
@@ -182,20 +182,20 @@ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer,
}
/*
-* Serialize a Server Hello Done message
+* Deserialize a Server Hello Done message
*/
-MemoryVector<byte> Server_Hello_Done::serialize() const
+Server_Hello_Done::Server_Hello_Done(const MemoryRegion<byte>& buf)
{
- return MemoryVector<byte>();
+ if(buf.size())
+ throw Decoding_Error("Server_Hello_Done: Must be empty, and is not");
}
/*
-* Deserialize a Server Hello Done message
+* Serialize a Server Hello Done message
*/
-void Server_Hello_Done::deserialize(const MemoryRegion<byte>& buf)
+MemoryVector<byte> Server_Hello_Done::serialize() const
{
- if(buf.size())
- throw Decoding_Error("Server_Hello_Done: Must be empty, and is not");
+ return MemoryVector<byte>();
}
}
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 69531c7c4..7008c89de 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -73,7 +73,7 @@ MemoryVector<byte> Server_Key_Exchange::serialize_params() const
/**
* Deserialize a Server Key Exchange message
*/
-void Server_Key_Exchange::deserialize(const MemoryRegion<byte>& buf)
+Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf)
{
if(buf.size() < 6)
throw Decoding_Error("Server_Key_Exchange: Packet corrupted");
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 0b43545dc..d3735972e 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -37,7 +37,6 @@ class Handshake_Message
private:
Handshake_Message& operator=(const Handshake_Message&) { return (*this); }
virtual MemoryVector<byte> serialize() const = 0;
- virtual void deserialize(const MemoryRegion<byte>&) = 0;
};
MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng);
@@ -179,10 +178,9 @@ class Server_Hello : public Handshake_Message
const std::vector<std::string>& next_protocols,
RandomNumberGenerator& rng);
- Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Server_Hello(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
Version_Code s_version;
MemoryVector<byte> m_session_id, s_random;
@@ -224,7 +222,6 @@ class Client_Key_Exchange : public Handshake_Message
Version_Code using_version);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
SecureVector<byte> key_material, pre_master;
bool include_length;
@@ -246,10 +243,10 @@ class Certificate : public Handshake_Message
TLS_Handshake_Hash& hash,
const std::vector<X509_Certificate>& certs);
- Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Certificate(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
+
std::vector<X509_Certificate> certs;
};
@@ -270,10 +267,9 @@ class Certificate_Req : public Handshake_Message
const std::vector<Certificate_Type>& types =
std::vector<Certificate_Type>());
- Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Certificate_Req(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
std::vector<X509_DN> names;
std::vector<Certificate_Type> types;
@@ -300,10 +296,9 @@ class Certificate_Verify : public Handshake_Message
RandomNumberGenerator& rng,
const Private_Key* key);
- Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Certificate_Verify(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
MemoryVector<byte> signature;
};
@@ -326,10 +321,9 @@ class Finished : public Handshake_Message
TLS_Handshake_State* state,
Connection_Side side);
- Finished(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Finished(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
Connection_Side side;
MemoryVector<byte> verification_data;
@@ -344,10 +338,9 @@ class Hello_Request : public Handshake_Message
Handshake_Type type() const { return HELLO_REQUEST; }
Hello_Request(Record_Writer& writer);
- Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Hello_Request(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
};
/**
@@ -367,11 +360,10 @@ class Server_Key_Exchange : public Handshake_Message
RandomNumberGenerator& rng,
const Private_Key* priv_key);
- Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Server_Key_Exchange(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
MemoryVector<byte> serialize_params() const;
- void deserialize(const MemoryRegion<byte>&);
std::vector<BigInt> params;
MemoryVector<byte> signature;
@@ -386,10 +378,9 @@ class Server_Hello_Done : public Handshake_Message
Handshake_Type type() const { return SERVER_HELLO_DONE; }
Server_Hello_Done(Record_Writer& writer, TLS_Handshake_Hash& hash);
- Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Server_Hello_Done(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
};
/**
@@ -406,10 +397,9 @@ class Next_Protocol : public Handshake_Message
TLS_Handshake_Hash& hash,
const std::string& protocol);
- Next_Protocol(const MemoryRegion<byte>& buf) { deserialize(buf); }
+ Next_Protocol(const MemoryRegion<byte>& buf);
private:
MemoryVector<byte> serialize() const;
- void deserialize(const MemoryRegion<byte>&);
std::string m_protocol;
};
diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h
index ef36912d3..6a0bcd5b1 100644
--- a/src/tls/tls_reader.h
+++ b/src/tls/tls_reader.h
@@ -26,13 +26,10 @@ class TLS_Data_Reader
TLS_Data_Reader(const MemoryRegion<byte>& buf_in) :
buf(buf_in), offset(0) {}
- ~TLS_Data_Reader()
+ void assert_done() const
{
if(has_remaining())
- {
- abort();
throw Decoding_Error("Extra bytes at end of message");
- }
}
size_t remaining_bytes() const