aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2011-12-23 18:22:37 +0000
committerlloyd <[email protected]>2011-12-23 18:22:37 +0000
commit61d461d0a5fb63c3aee906c76b4aefe3335a7591 (patch)
treea936e50187ba7ace33c09fcf5a9119e257987f30
parent917bf37104eb039a97ef989306954dd8bc05f400 (diff)
Centralize a lot of the handshaking and message parsing in TLS_Channel
Also delete the obsolete/never worked CMS examples
-rw-r--r--doc/examples/cms_dec.cpp120
-rw-r--r--doc/examples/cms_enc.cpp59
-rw-r--r--doc/examples/socket.h2
-rw-r--r--doc/examples/tls_client.cpp2
-rw-r--r--doc/examples/tls_server.cpp49
-rw-r--r--src/tls/info.txt2
-rw-r--r--src/tls/tls_channel.cpp188
-rw-r--r--src/tls/tls_channel.h85
-rw-r--r--src/tls/tls_client.cpp190
-rw-r--r--src/tls/tls_client.h46
-rw-r--r--src/tls/tls_server.cpp235
-rw-r--r--src/tls/tls_server.h45
12 files changed, 344 insertions, 679 deletions
diff --git a/doc/examples/cms_dec.cpp b/doc/examples/cms_dec.cpp
deleted file mode 100644
index 84355fb4a..000000000
--- a/doc/examples/cms_dec.cpp
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
-* (C) 2009 Jack Lloyd
-*
-* Distributed under the terms of the Botan license
-*/
-
-#include <botan/botan.h>
-#include <botan/pkcs8.h>
-#include <botan/cms_dec.h>
-using namespace Botan;
-
-#include <iostream>
-#include <memory>
-
-int main(int argc, char* argv[])
- {
- if(argc != 2)
- {
- std::cout << "Usage: " << argv[0] << " <filename>\n";
- return 1;
- }
-
- Botan::LibraryInitializer init;
-
- try {
- AutoSeeded_RNG rng;
-
- X509_Certificate mycert("mycert.pem");
- PKCS8_PrivateKey* mykey = PKCS8::load_key("mykey.pem", rng, "cut");
-
- X509_Certificate yourcert("yourcert.pem");
- X509_Certificate cacert("cacert.pem");
- X509_Certificate int_ca("int_ca.pem");
-
- X509_Store store;
- store.add_cert(mycert);
- store.add_cert(yourcert);
- store.add_cert(cacert, true);
- store.add_cert(int_ca);
-
- DataSource_Stream message(argv[1]);
-
- CMS_Decoder decoder(message, store, mykey);
-
- while(decoder.layer_type() != CMS_Decoder::DATA)
- {
- CMS_Decoder::Status status = decoder.layer_status();
- CMS_Decoder::Content_Type content = decoder.layer_type();
-
- if(status == CMS_Decoder::FAILURE)
- {
- std::cout << "Failure reading CMS data" << std::endl;
- break;
- }
-
- if(content == CMS_Decoder::DIGESTED)
- {
- std::cout << "Digested data, hash = " << decoder.layer_info()
- << std::endl;
- std::cout << "Hash is "
- << ((status == CMS_Decoder::GOOD) ? "good" : "bad")
- << std::endl;
- }
-
- if(content == CMS_Decoder::SIGNED)
- {
- // how to handle multiple signers? they can all exist within a
- // single level...
-
- std::cout << "Signed by " << decoder.layer_info() << std::endl;
- //std::cout << "Sign time: " << decoder.xxx() << std::endl;
- std::cout << "Signature is ";
- if(status == CMS_Decoder::GOOD)
- std::cout << "valid";
- else if(status == CMS_Decoder::BAD)
- std::cout << "bad";
- else if(status == CMS_Decoder::NO_KEY)
- std::cout << "(cannot check, no known cert)";
- std::cout << std::endl;
- }
- if(content == CMS_Decoder::ENVELOPED ||
- content == CMS_Decoder::COMPRESSED ||
- content == CMS_Decoder::AUTHENTICATED)
- {
- if(content == CMS_Decoder::ENVELOPED)
- std::cout << "Enveloped";
- if(content == CMS_Decoder::COMPRESSED)
- std::cout << "Compressed";
- if(content == CMS_Decoder::AUTHENTICATED)
- std::cout << "MACed";
-
- std::cout << ", algo = " << decoder.layer_info() << std::endl;
-
- if(content == CMS_Decoder::AUTHENTICATED)
- {
- std::cout << "MAC status is ";
- if(status == CMS_Decoder::GOOD)
- std::cout << "valid";
- else if(status == CMS_Decoder::BAD)
- std::cout << "bad";
- else if(status == CMS_Decoder::NO_KEY)
- std::cout << "(cannot check, no key)";
- std::cout << std::endl;
- }
- }
- decoder.next_layer();
- }
-
- if(decoder.layer_type() == CMS_Decoder::DATA)
- std::cout << "Message is \"" << decoder.get_data()
- << '"' << std::endl;
- else
- std::cout << "No data anywhere?" << std::endl;
- }
- catch(std::exception& e)
- {
- std::cerr << e.what() << std::endl;
- }
- return 0;
- }
diff --git a/doc/examples/cms_enc.cpp b/doc/examples/cms_enc.cpp
deleted file mode 100644
index 2cf813987..000000000
--- a/doc/examples/cms_enc.cpp
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
-* (C) 2009 Jack Lloyd
-*
-* Distributed under the terms of the Botan license
-*/
-
-#include <botan/botan.h>
-#include <botan/cms_enc.h>
-using namespace Botan;
-
-#include <iostream>
-#include <fstream>
-#include <memory>
-
-int main()
- {
- Botan::LibraryInitializer init;
-
- try {
-
- X509_Certificate mycert("mycert.pem");
- X509_Certificate mycert2("mycert2.pem");
- X509_Certificate yourcert("yourcert.pem");
- X509_Certificate cacert("cacert.pem");
- X509_Certificate int_ca("int_ca.pem");
-
- AutoSeeded_RNG rng;
-
- X509_Store store;
- store.add_cert(mycert);
- store.add_cert(mycert2);
- store.add_cert(yourcert);
- store.add_cert(int_ca);
- store.add_cert(cacert, true);
-
- const std::string msg = "prioncorp: we don't toy\n";
-
- CMS_Encoder encoder(msg);
-
- encoder.compress("Zlib");
- encoder.digest();
- encoder.encrypt(rng, mycert);
-
- /*
- PKCS8_PrivateKey* mykey = PKCS8::load_key("mykey.pem", rng, "cut");
- encoder.sign(store, *mykey);
- */
-
- SecureVector<byte> raw = encoder.get_contents();
- std::ofstream out("out.der");
-
- out.write((const char*)raw.begin(), raw.size());
- }
- catch(std::exception& e)
- {
- std::cerr << e.what() << std::endl;
- }
- return 0;
- }
diff --git a/doc/examples/socket.h b/doc/examples/socket.h
index f10ff9f26..9e16ab36a 100644
--- a/doc/examples/socket.h
+++ b/doc/examples/socket.h
@@ -180,7 +180,7 @@ size_t Socket::read(unsigned char buf[], size_t length, bool partial)
{
ssize_t this_time = ::recv(sockfd, (char*)buf + got, length, flags);
- const bool full_ret = (this_time == length);
+ const bool full_ret = (this_time == (ssize_t)length);
if(this_time == 0)
break;
diff --git a/doc/examples/tls_client.cpp b/doc/examples/tls_client.cpp
index a51febfcf..ee224e9eb 100644
--- a/doc/examples/tls_client.cpp
+++ b/doc/examples/tls_client.cpp
@@ -41,7 +41,7 @@ class HTTPS_Client
quit_reading = false;
- while(!client.handshake_complete() || desired)
+ while(!client.is_active() || desired)
{
const size_t socket_got = socket.read(&socket_buf[0], socket_buf.size());
//printf("Got %d bytes from socket\n", socket_got);
diff --git a/doc/examples/tls_server.cpp b/doc/examples/tls_server.cpp
index 153b26d04..62bc8fadc 100644
--- a/doc/examples/tls_server.cpp
+++ b/doc/examples/tls_server.cpp
@@ -30,6 +30,13 @@ class Server_TLS_Policy : public TLS_Policy
}
};
+void proc_data(const byte data[], size_t data_len, u16bit alert_info)
+ {
+ printf("Block of data %d bytes alert %04X\n", (int)data_len, alert_info);
+ for(size_t i = 0; i != data_len; ++i)
+ printf("%c", data[i]);
+ }
+
int main(int argc, char* argv[])
{
int port = 4433;
@@ -40,7 +47,7 @@ int main(int argc, char* argv[])
try
{
LibraryInitializer botan_init;
- SocketInitializer socket_init;
+ //SocketInitializer socket_init;
AutoSeeded_RNG rng;
@@ -67,28 +74,40 @@ int main(int argc, char* argv[])
printf("Got new connection\n");
TLS_Server tls(
- std::tr1::bind(&Socket::read, std::tr1::ref(sock), _1, _2),
- std::tr1::bind(&Socket::write, std::tr1::ref(sock), _1, _2),
- policy,
- rng,
- cert,
- key);
-
- std::string hostname = tls.requested_hostname();
+ std::tr1::bind(&Socket::write, std::tr1::ref(sock), _1, _2),
+ proc_data,
+ policy,
+ rng,
+ cert,
+ key);
+
+ SecureVector<byte> buf(1024);
+ size_t desired = 0;
+ while(!tls.is_active() || desired)
+ {
+ const size_t socket_got = sock->read(&buf[0], desired || 1);
+ desired = tls.received_data(&buf[0], socket_got);
+ }
+
+ const std::string hostname = tls.server_name_indicator();
if(hostname != "")
printf("Client requested host '%s'\n", hostname.c_str());
printf("Writing some text\n");
- char msg[] = "Foo\nBar\nBaz\nQuux\n";
- tls.write((const Botan::byte*)msg, strlen(msg));
+ char msg[] = "Welcome to the best echo server evar\n";
+ tls.queue_for_sending((const Botan::byte*)msg, strlen(msg));
+
+ while(true)
+ {
+ size_t got = sock->read(&buf[0], buf.size(), true);
- printf("Now trying a read...\n");
+ if(got == 0)
+ break;
- char buf[1024] = { 0 };
- u32bit got = tls.read((Botan::byte*)buf, sizeof(buf)-1);
- printf("%d: '%s'\n", got, buf);
+ tls.received_data(&buf[0], got);
+ }
tls.close();
}
diff --git a/src/tls/info.txt b/src/tls/info.txt
index 1170fef45..f09309bd2 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -8,6 +8,7 @@ serious bugs or security issues.
uses_tr1 yes
<header:public>
+tls_channel.h
tls_client.h
tls_exceptn.h
tls_magic.h
@@ -36,6 +37,7 @@ hello.cpp
rec_read.cpp
rec_wri.cpp
s_kex.cpp
+tls_channel.cpp
tls_client.cpp
tls_policy.cpp
tls_server.cpp
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
new file mode 100644
index 000000000..580c1e5e5
--- /dev/null
+++ b/src/tls/tls_channel.cpp
@@ -0,0 +1,188 @@
+/*
+* TLS Channels
+* (C) 2011 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/tls_channel.h>
+#include <botan/internal/tls_alerts.h>
+#include <botan/internal/tls_state.h>
+#include <botan/loadstor.h>
+
+namespace Botan {
+
+TLS_Channel::TLS_Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn) :
+ proc_fn(proc_fn),
+ writer(socket_output_fn),
+ state(0),
+ active(false)
+ {
+ }
+
+TLS_Channel::~TLS_Channel()
+ {
+ close();
+ delete state;
+ state = 0;
+ }
+
+size_t TLS_Channel::received_data(const byte buf[], size_t buf_size)
+ {
+ try
+ {
+ reader.add_input(buf, buf_size);
+
+ byte rec_type = CONNECTION_CLOSED;
+ SecureVector<byte> record;
+
+ while(!reader.currently_empty())
+ {
+ const size_t bytes_needed = reader.get_record(rec_type, record);
+
+ if(bytes_needed > 0)
+ return bytes_needed;
+
+ if(rec_type == APPLICATION_DATA)
+ {
+ if(active)
+ {
+ proc_fn(&record[0], record.size(), NO_ALERT_TYPE);
+ }
+ else
+ {
+ throw Unexpected_Message("Application data before handshake done");
+ }
+ }
+ else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
+ {
+ read_handshake(rec_type, record);
+ }
+ else if(rec_type == ALERT)
+ {
+ Alert alert_msg(record);
+
+ proc_fn(0, 0, alert_msg.type());
+
+ if(alert_msg.is_fatal() || alert_msg.type() == CLOSE_NOTIFY)
+ {
+ if(alert_msg.type() == CLOSE_NOTIFY)
+ {
+ writer.alert(WARNING, CLOSE_NOTIFY);
+ }
+
+ alert(FATAL, NO_ALERT_TYPE);
+ }
+ }
+ else
+ throw Unexpected_Message("Unknown message type received");
+ }
+
+ return 0; // on a record boundary
+ }
+ catch(TLS_Exception& e)
+ {
+ alert(FATAL, e.type());
+ throw;
+ }
+ catch(std::exception& e)
+ {
+ alert(FATAL, INTERNAL_ERROR);
+ throw;
+ }
+ }
+
+/*
+* Split up and process handshake messages
+*/
+void TLS_Channel::read_handshake(byte rec_type,
+ const MemoryRegion<byte>& rec_buf)
+ {
+ if(rec_type == HANDSHAKE)
+ state->queue.write(&rec_buf[0], rec_buf.size());
+
+ while(true)
+ {
+ Handshake_Type type = HANDSHAKE_NONE;
+ SecureVector<byte> contents;
+
+ if(rec_type == HANDSHAKE)
+ {
+ if(state->queue.size() >= 4)
+ {
+ byte head[4] = { 0 };
+ state->queue.peek(head, 4);
+
+ const size_t length = make_u32bit(0, head[1], head[2], head[3]);
+
+ if(state->queue.size() >= length + 4)
+ {
+ type = static_cast<Handshake_Type>(head[0]);
+ contents.resize(length);
+ state->queue.read(head, 4);
+ state->queue.read(&contents[0], contents.size());
+ }
+ }
+ }
+ else if(rec_type == CHANGE_CIPHER_SPEC)
+ {
+ if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1)
+ type = HANDSHAKE_CCS;
+ else
+ throw Decoding_Error("Malformed ChangeCipherSpec message");
+ }
+ else
+ throw Decoding_Error("Unknown message type in handshake processing");
+
+ if(type == HANDSHAKE_NONE)
+ break;
+
+ process_handshake_msg(type, contents);
+
+ if(type == HANDSHAKE_CCS || !state)
+ break;
+ }
+ }
+
+void TLS_Channel::queue_for_sending(const byte buf[], size_t buf_size)
+ {
+ if(active)
+ {
+ while(!pre_handshake_write_queue.end_of_data())
+ {
+ SecureVector<byte> q_buf(1024);
+ const size_t got = pre_handshake_write_queue.read(&q_buf[0], q_buf.size());
+ writer.send(APPLICATION_DATA, &q_buf[0], got);
+ }
+
+ writer.send(APPLICATION_DATA, buf, buf_size);
+ writer.flush();
+ }
+ else
+ pre_handshake_write_queue.write(buf, buf_size);
+ }
+
+void TLS_Channel::alert(Alert_Level level, Alert_Type alert_code)
+ {
+ if(active && alert_code != NO_ALERT_TYPE)
+ {
+ try
+ {
+ writer.alert(level, alert_code);
+ writer.flush();
+ }
+ catch(...) { /* swallow it */ }
+ }
+
+ if(active && level == FATAL)
+ {
+ reader.reset();
+ writer.reset();
+ delete state;
+ state = 0;
+ active = false;
+ }
+ }
+
+}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
new file mode 100644
index 000000000..d74504ccd
--- /dev/null
+++ b/src/tls/tls_channel.h
@@ -0,0 +1,85 @@
+/*
+* TLS Channel
+* (C) 2011 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef BOTAN_TLS_CHANNEL_H__
+#define BOTAN_TLS_CHANNEL_H__
+
+#include <botan/tls_policy.h>
+#include <botan/tls_record.h>
+#include <botan/x509cert.h>
+#include <vector>
+
+namespace Botan {
+
+/**
+* Generic interface for TLS endpoint
+*/
+class BOTAN_DLL TLS_Channel
+ {
+ public:
+ /**
+ * Inject TLS traffic received from counterparty
+
+ * @return a hint as the how many more bytes we need to process the
+ current record (this may be 0 if on a record boundary)
+ */
+ virtual size_t received_data(const byte buf[], size_t buf_size);
+
+ /**
+ * Inject plaintext intended for counterparty
+ */
+ virtual void queue_for_sending(const byte buf[], size_t buf_size);
+
+ /**
+ * Send a close notification alert
+ */
+ void close() { alert(WARNING, CLOSE_NOTIFY); }
+
+ /**
+ * Send a TLS alert message. If the alert is fatal, the
+ * internal state (keys, etc) will be reset
+ */
+ void alert(Alert_Level level, Alert_Type type);
+
+ /**
+ * Is the connection active?
+ */
+ bool is_active() const { return active; }
+
+ /**
+ * Return the certificates of the peer
+ */
+ std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; }
+
+ TLS_Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn);
+
+ virtual ~TLS_Channel();
+ protected:
+ virtual void read_handshake(byte rec_type,
+ const MemoryRegion<byte>& rec_buf);
+
+ virtual void process_handshake_msg(Handshake_Type type,
+ const MemoryRegion<byte>& contents) = 0;
+
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn;
+
+ Record_Writer writer;
+ Record_Reader reader;
+
+ SecureQueue pre_handshake_write_queue;
+
+ std::vector<X509_Certificate> peer_certs;
+
+ class Handshake_State* state;
+
+ bool active;
+ };
+
+}
+
+#endif
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index cfa86881c..30c440d29 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -6,9 +6,7 @@
*/
#include <botan/tls_client.h>
-#include <botan/internal/tls_alerts.h>
#include <botan/internal/tls_state.h>
-#include <botan/loadstor.h>
#include <botan/rsa.h>
#include <botan/dsa.h>
#include <botan/dh.h>
@@ -17,7 +15,7 @@ namespace Botan {
namespace {
-/**
+/*
* Verify the state transition is allowed
* FIXME: checks are wrong for session reuse (add a flag for that)
*/
@@ -78,19 +76,16 @@ void client_check_state(Handshake_Type new_msg, Handshake_State* state)
}
-/**
+/*
* TLS Client Constructor
*/
-TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
- std::tr1::function<void (const byte[], size_t, u16bit)> process_fn,
+TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
const TLS_Policy& policy,
RandomNumberGenerator& rng) :
+ TLS_Channel(output_fn, proc_fn),
policy(policy),
- rng(rng),
- proc_fn(process_fn),
- writer(socket_output_fn),
- state(0),
- active(false)
+ rng(rng)
{
writer.set_version(policy.pref_version());
@@ -104,185 +99,16 @@ void TLS_Client::add_client_cert(const X509_Certificate& cert,
certs.push_back(std::make_pair(cert, cert_key));
}
-/**
+/*
* TLS Client Destructor
*/
TLS_Client::~TLS_Client()
{
- close();
for(size_t i = 0; i != certs.size(); i++)
delete certs[i].second;
- delete state;
- }
-
-size_t TLS_Client::received_data(const byte buf[], size_t buf_size)
- {
- try
- {
- reader.add_input(buf, buf_size);
-
- byte rec_type = CONNECTION_CLOSED;
- SecureVector<byte> record;
-
- while(!reader.currently_empty())
- {
- const size_t bytes_needed = reader.get_record(rec_type, record);
-
- if(bytes_needed > 0)
- return bytes_needed;
-
- if(rec_type == APPLICATION_DATA)
- {
- if(active)
- {
- proc_fn(&record[0], record.size(), NO_ALERT_TYPE);
- }
- else
- {
- throw Unexpected_Message("Application data before handshake done");
- }
- }
- else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
- {
- read_handshake(rec_type, record);
- }
- else if(rec_type == ALERT)
- {
- Alert alert(record);
-
- proc_fn(0, 0, alert.type());
-
- if(alert.is_fatal() || alert.type() == CLOSE_NOTIFY)
- {
- if(alert.type() == CLOSE_NOTIFY)
- {
- writer.alert(WARNING, CLOSE_NOTIFY);
- }
-
- close(FATAL, NO_ALERT_TYPE);
- }
- }
- else
- throw Unexpected_Message("Unknown message type received");
- }
-
- return 0; // on a record boundary
- }
- catch(TLS_Exception& e)
- {
- close(FATAL, e.type());
- throw;
- }
- catch(std::exception& e)
- {
- close(FATAL, INTERNAL_ERROR);
- throw;
- }
- }
-
-void TLS_Client::queue_for_sending(const byte buf[], size_t buf_size)
- {
- if(active)
- {
- while(!pre_handshake_write_queue.end_of_data())
- {
- SecureVector<byte> q_buf(1024);
- const size_t got = pre_handshake_write_queue.read(&q_buf[0], q_buf.size());
- writer.send(APPLICATION_DATA, &q_buf[0], got);
- }
-
- writer.send(APPLICATION_DATA, buf, buf_size);
- writer.flush();
- }
- else
- pre_handshake_write_queue.write(buf, buf_size);
- }
-
-/**
-* Close a TLS connection
-*/
-void TLS_Client::close()
- {
- close(WARNING, CLOSE_NOTIFY);
}
-/**
-* Close a TLS connection
-*/
-void TLS_Client::close(Alert_Level level, Alert_Type alert_code)
- {
- if(active)
- {
- active = false;
-
- if(alert_code != NO_ALERT_TYPE)
- {
- try
- {
- writer.alert(level, alert_code);
- writer.flush();
- }
- catch(...) { /* swallow it */ }
- }
-
- reader.reset();
- writer.reset();
- }
- }
-
-/**
-* Split up and process handshake messages
-*/
-void TLS_Client::read_handshake(byte rec_type,
- const MemoryRegion<byte>& rec_buf)
- {
- if(rec_type == HANDSHAKE)
- state->queue.write(&rec_buf[0], rec_buf.size());
-
- while(true)
- {
- Handshake_Type type = HANDSHAKE_NONE;
- SecureVector<byte> contents;
-
- if(rec_type == HANDSHAKE)
- {
- if(state->queue.size() >= 4)
- {
- byte head[4] = { 0 };
- state->queue.peek(head, 4);
-
- const size_t length = make_u32bit(0, head[1], head[2], head[3]);
-
- if(state->queue.size() >= length + 4)
- {
- type = static_cast<Handshake_Type>(head[0]);
- contents.resize(length);
- state->queue.read(head, 4);
- state->queue.read(&contents[0], contents.size());
- }
- }
- }
- else if(rec_type == CHANGE_CIPHER_SPEC)
- {
- if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1)
- type = HANDSHAKE_CCS;
- else
- throw Decoding_Error("Malformed ChangeCipherSpec message");
- }
- else
- throw Decoding_Error("Unknown message type in handshake processing");
-
- if(type == HANDSHAKE_NONE)
- break;
-
- process_handshake_msg(type, contents);
-
- if(type == HANDSHAKE_CCS || !state)
- break;
- }
- }
-
-/**
+/*
* Process a handshake message
*/
void TLS_Client::process_handshake_msg(Handshake_Type type,
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index 6d613be33..063323c8b 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -8,17 +8,15 @@
#ifndef BOTAN_TLS_CLIENT_H__
#define BOTAN_TLS_CLIENT_H__
-#include <botan/tls_policy.h>
-#include <botan/tls_record.h>
+#include <botan/tls_channel.h>
#include <vector>
-#include <string>
namespace Botan {
/**
* SSL/TLS Client
*/
-class BOTAN_DLL TLS_Client
+class BOTAN_DLL TLS_Client : public TLS_Channel
{
public:
/**
@@ -29,57 +27,17 @@ class BOTAN_DLL TLS_Client
const TLS_Policy& policy,
RandomNumberGenerator& rng);
- /**
- * Inject TLS traffic received from counterparty
-
- * @return a hint as the how many more bytes we need to process the
- current record (this may be 0 if on a record boundary)
- */
- size_t received_data(const byte buf[], size_t buf_size);
-
- /**
- * Inject plaintext intended for counterparty
- */
- void queue_for_sending(const byte buf[], size_t buf_size);
-
- void close();
-
- bool handshake_complete() const { return active; }
-
- std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; }
-
void add_client_cert(const X509_Certificate& cert,
Private_Key* cert_key);
~TLS_Client();
private:
- void close(Alert_Level, Alert_Type);
-
- size_t get_pending_socket_input(byte output[], size_t length);
-
- void initialize();
- void do_handshake();
-
- void state_machine();
- void read_handshake(byte, const MemoryRegion<byte>&);
void process_handshake_msg(Handshake_Type, const MemoryRegion<byte>&);
const TLS_Policy& policy;
RandomNumberGenerator& rng;
- std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn;
-
- Record_Writer writer;
- Record_Reader reader;
-
- SecureQueue pre_handshake_write_queue;
-
- std::vector<X509_Certificate> peer_certs;
std::vector<std::pair<X509_Certificate, Private_Key*> > certs;
-
- class Handshake_State* state;
- //SecureVector<byte> session_id;
- bool active;
};
}
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 8964be3d7..81ed2c48e 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -1,6 +1,6 @@
/*
* TLS Server
-* (C) 2004-2010 Jack Lloyd
+* (C) 2004-2011 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -85,40 +85,20 @@ void server_check_state(Handshake_Type new_msg, Handshake_State* state)
/*
* TLS Server Constructor
*/
-TLS_Server::TLS_Server(std::tr1::function<size_t (byte[], size_t)> input_fn,
- std::tr1::function<void (const byte[], size_t)> output_fn,
+TLS_Server::TLS_Server(std::tr1::function<void (const byte[], size_t)> output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
const TLS_Policy& policy,
RandomNumberGenerator& rng,
const X509_Certificate& cert,
const Private_Key& cert_key) :
- input_fn(input_fn),
+ TLS_Channel(output_fn, proc_fn),
policy(policy),
- rng(rng),
- writer(output_fn)
+ rng(rng)
{
- state = 0;
+ writer.set_version(TLS_V10);
cert_chain.push_back(cert);
private_key = PKCS8::copy_key(cert_key, rng);
-
- try {
- active = false;
- writer.set_version(TLS_V10);
- do_handshake();
- active = true;
- }
- catch(std::exception& e)
- {
- if(state)
- {
- delete state;
- state = 0;
- }
-
- writer.alert(FATAL, HANDSHAKE_FAILURE);
- throw Stream_IO_Error(std::string("TLS_Server: Handshake failed: ") +
- e.what());
- }
}
/*
@@ -126,143 +106,7 @@ TLS_Server::TLS_Server(std::tr1::function<size_t (byte[], size_t)> input_fn,
*/
TLS_Server::~TLS_Server()
{
- close();
delete private_key;
- delete state;
- }
-
-/*
-* Return the peer's certificate chain
-*/
-std::vector<X509_Certificate> TLS_Server::peer_cert_chain() const
- {
- return peer_certs;
- }
-
-/*
-* Write to a TLS connection
-*/
-void TLS_Server::write(const byte buf[], size_t length)
- {
- if(!active)
- throw Internal_Error("TLS_Server::write called while closed");
-
- writer.send(APPLICATION_DATA, buf, length);
- }
-
-/*
-* Read from a TLS connection
-*/
-size_t TLS_Server::read(byte out[], size_t length)
- {
- if(!active)
- throw Internal_Error("TLS_Server::read called while closed");
-
- writer.flush();
-
- while(read_buf.size() == 0)
- {
- state_machine();
- if(active == false)
- break;
- }
-
- size_t got = std::min<size_t>(read_buf.size(), length);
- read_buf.read(out, got);
- return got;
- }
-
-/*
-* Check connection status
-*/
-bool TLS_Server::is_closed() const
- {
- if(!active)
- return true;
- return false;
- }
-
-/*
-* Close a TLS connection
-*/
-void TLS_Server::close()
- {
- close(WARNING, CLOSE_NOTIFY);
- }
-
-/*
-* Close a TLS connection
-*/
-void TLS_Server::close(Alert_Level level, Alert_Type alert_code)
- {
- if(active)
- {
- try {
- active = false;
- writer.alert(level, alert_code);
- writer.flush();
- }
- catch(...) {}
- }
- }
-
-/*
-* Iterate the TLS state machine
-*/
-void TLS_Server::state_machine()
- {
- byte rec_type = CONNECTION_CLOSED;
- SecureVector<byte> record(1024);
-
- size_t bytes_needed = reader.get_record(rec_type, record);
-
- while(bytes_needed)
- {
- size_t to_get = std::min<size_t>(record.size(), bytes_needed);
- size_t got = input_fn(&record[0], to_get);
-
- if(got == 0)
- {
- rec_type = CONNECTION_CLOSED;
- break;
- }
-
- reader.add_input(&record[0], got);
-
- bytes_needed = reader.get_record(rec_type, record);
- }
-
- if(rec_type == CONNECTION_CLOSED)
- {
- active = false;
- reader.reset();
- writer.reset();
- }
- else if(rec_type == APPLICATION_DATA)
- {
- if(active)
- read_buf.write(&record[0], record.size());
- else
- throw Unexpected_Message("Application data before handshake done");
- }
- else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
- read_handshake(rec_type, record);
- else if(rec_type == ALERT)
- {
- Alert alert(record);
-
- if(alert.is_fatal() || alert.type() == CLOSE_NOTIFY)
- {
- if(alert.type() == CLOSE_NOTIFY)
- writer.alert(WARNING, CLOSE_NOTIFY);
-
- reader.reset();
- writer.reset();
- active = false;
- }
- }
- else
- throw Unexpected_Message("Unknown message type received");
}
/*
@@ -271,54 +115,10 @@ void TLS_Server::state_machine()
void TLS_Server::read_handshake(byte rec_type,
const MemoryRegion<byte>& rec_buf)
{
- if(rec_type == HANDSHAKE)
- {
- if(!state)
- state = new Handshake_State;
- state->queue.write(&rec_buf[0], rec_buf.size());
- }
-
- while(true)
- {
- Handshake_Type type = HANDSHAKE_NONE;
- SecureVector<byte> contents;
-
- if(rec_type == HANDSHAKE)
- {
- if(state->queue.size() >= 4)
- {
- byte head[4] = { 0 };
- state->queue.peek(head, 4);
-
- const size_t length = make_u32bit(0, head[1], head[2], head[3]);
-
- if(state->queue.size() >= length + 4)
- {
- type = static_cast<Handshake_Type>(head[0]);
- contents.resize(length);
- state->queue.read(head, 4);
- state->queue.read(&contents[0], contents.size());
- }
- }
- }
- else if(rec_type == CHANGE_CIPHER_SPEC)
- {
- if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1)
- type = HANDSHAKE_CCS;
- else
- throw Decoding_Error("Malformed ChangeCipherSpec message");
- }
- else
- throw Decoding_Error("Unknown message type in handshake processing");
+ if(rec_type == HANDSHAKE && !state)
+ state = new Handshake_State;
- if(type == HANDSHAKE_NONE)
- break;
-
- process_handshake_msg(type, contents);
-
- if(type == HANDSHAKE_CCS || !state)
- break;
- }
+ TLS_Channel::read_handshake(rec_type, rec_buf);
}
/*
@@ -474,21 +274,4 @@ void TLS_Server::process_handshake_msg(Handshake_Type type,
throw Unexpected_Message("Unknown handshake message received");
}
-/*
-* Perform a server-side TLS handshake
-*/
-void TLS_Server::do_handshake()
- {
- while(true)
- {
- if(active && !state)
- break;
-
- state_machine();
-
- if(!active && !state)
- throw TLS_Exception(HANDSHAKE_FAILURE, "TLS_Server: Handshake failed");
- }
- }
-
}
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 510ad15a7..e975071d2 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -8,8 +8,7 @@
#ifndef BOTAN_TLS_SERVER_H__
#define BOTAN_TLS_SERVER_H__
-#include <botan/tls_record.h>
-#include <botan/tls_policy.h>
+#include <botan/tls_channel.h>
#include <vector>
namespace Botan {
@@ -17,58 +16,42 @@ namespace Botan {
/**
* TLS Server
*/
-class BOTAN_DLL TLS_Server
+class BOTAN_DLL TLS_Server : public TLS_Channel
{
public:
- size_t read(byte buf[], size_t buf_len);
- void write(const byte buf[], size_t buf_len);
- std::vector<X509_Certificate> peer_cert_chain() const;
-
- std::string requested_hostname() const
- { return client_requested_hostname; }
-
- void close();
- bool is_closed() const;
-
- /*
+ /**
+ * TLS_Server initialization
+ *
* FIXME: support cert chains (!)
* FIXME: support anonymous servers
*/
- TLS_Server(std::tr1::function<size_t (byte[], size_t)> input_fn,
- std::tr1::function<void (const byte[], size_t)> output_fn,
+ TLS_Server(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
const TLS_Policy& policy,
RandomNumberGenerator& rng,
const X509_Certificate& cert,
const Private_Key& cert_key);
~TLS_Server();
- private:
- void close(Alert_Level, Alert_Type);
- void do_handshake();
- void state_machine();
+ /**
+ * Return the server name indicator, if set by the client
+ */
+ std::string server_name_indicator() const
+ { return client_requested_hostname; }
+ private:
void read_handshake(byte, const MemoryRegion<byte>&);
void process_handshake_msg(Handshake_Type, const MemoryRegion<byte>&);
- std::tr1::function<size_t (byte[], size_t)> input_fn;
-
const TLS_Policy& policy;
RandomNumberGenerator& rng;
- Record_Writer writer;
- Record_Reader reader;
-
- // FIXME: rename to match TLS_Client
- std::vector<X509_Certificate> cert_chain, peer_certs;
+ std::vector<X509_Certificate> cert_chain;
Private_Key* private_key;
- class Handshake_State* state;
- SecureVector<byte> session_id;
- SecureQueue read_buf;
std::string client_requested_hostname;
- bool active;
};
}