diff options
author | lloyd <[email protected]> | 2014-12-20 13:45:23 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2014-12-20 13:45:23 +0000 |
commit | 4083193089f91ec11584ae729ecc3b4cc3b4b86a (patch) | |
tree | d6edc4e416f3eb34aeb91b00d2bee6386962f316 | |
parent | 4562cd4366c81c905dc8957837c6128b193a28bd (diff) |
Add abstract database interface so applications can easily store info
in places other than sqlite3, though sqlite3 remains the only
implementation. The interface is currently limited to precisely the
functionality the TLS session manager needs and will likely expand.
-rw-r--r-- | doc/relnotes/1_11_11.rst | 12 | ||||
-rw-r--r-- | src/lib/tls/sessions_sql/info.txt | 5 | ||||
-rw-r--r-- | src/lib/tls/sessions_sql/tls_session_manager_sql.cpp | 215 | ||||
-rw-r--r-- | src/lib/tls/sessions_sql/tls_session_manager_sql.h (renamed from src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.h) | 40 | ||||
-rw-r--r-- | src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.cpp | 222 | ||||
-rw-r--r-- | src/lib/tls/sessions_sqlite3/info.txt (renamed from src/lib/tls/sessions_sqlite/info.txt) | 2 | ||||
-rw-r--r-- | src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.cpp | 29 | ||||
-rw-r--r-- | src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.h | 52 | ||||
-rw-r--r-- | src/lib/utils/database.h | 62 | ||||
-rw-r--r-- | src/lib/utils/info.txt | 1 | ||||
-rw-r--r-- | src/lib/utils/sqlite3/info.txt | 8 | ||||
-rw-r--r-- | src/lib/utils/sqlite3/sqlite3.cpp | 46 | ||||
-rw-r--r-- | src/lib/utils/sqlite3/sqlite3.h | 72 |
13 files changed, 455 insertions, 311 deletions
diff --git a/doc/relnotes/1_11_11.rst b/doc/relnotes/1_11_11.rst index da6d56a85..4eb7e948c 100644 --- a/doc/relnotes/1_11_11.rst +++ b/doc/relnotes/1_11_11.rst @@ -1,3 +1,15 @@ Version 1.11.11, Not Yet Released ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* The Sqlite3 wrapper has been abstracted to a simple interface for + SQL dbs in general, though Sqlite3 remains the only implementation. + The main logic of the TLS session manager which stored encrypted + sessions to a Sqlite3 database (`TLS::Session_Manager_SQLite`) has + been moved to the new `TLS::Session_Manager_SQL`. The Sqlite3 + manager API remains the same but now just subclasses + `TLS::Session_Manager_SQL` and has a constructor instantiate the + concrete database instance. + + Applications which would like to use a different db can now do so + without having to reimplement the session cache logic simply by + implementing a database wrapper subtype. diff --git a/src/lib/tls/sessions_sql/info.txt b/src/lib/tls/sessions_sql/info.txt new file mode 100644 index 000000000..7016a3d42 --- /dev/null +++ b/src/lib/tls/sessions_sql/info.txt @@ -0,0 +1,5 @@ +define TLS_SESSION_MANAGER_SQL_DB 20141219 + +<requires> +pbkdf2 +</requires> diff --git a/src/lib/tls/sessions_sql/tls_session_manager_sql.cpp b/src/lib/tls/sessions_sql/tls_session_manager_sql.cpp new file mode 100644 index 000000000..561939def --- /dev/null +++ b/src/lib/tls/sessions_sql/tls_session_manager_sql.cpp @@ -0,0 +1,215 @@ +/* +* SQL TLS Session Manager +* (C) 2012,2014 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_session_manager_sql.h> +#include <botan/database.h> +#include <botan/lookup.h> +#include <botan/hex.h> +#include <botan/loadstor.h> +#include <chrono> + +namespace Botan { + +namespace TLS { + +namespace { + +SymmetricKey derive_key(const std::string& passphrase, + const byte salt[], + size_t salt_len, + size_t iterations, + size_t& check_val) + { + std::unique_ptr<PBKDF> pbkdf(get_pbkdf("PBKDF2(SHA-512)")); + + secure_vector<byte> x = pbkdf->derive_key(32 + 2, + passphrase, + salt, salt_len, + iterations).bits_of(); + + check_val = make_u16bit(x[0], x[1]); + return SymmetricKey(&x[2], x.size() - 2); + } + +} + +Session_Manager_SQL::Session_Manager_SQL(std::shared_ptr<SQL_Database> db, + const std::string& passphrase, + RandomNumberGenerator& rng, + size_t max_sessions, + std::chrono::seconds session_lifetime) : + m_db(db), + m_rng(rng), + m_max_sessions(max_sessions), + m_session_lifetime(session_lifetime) + { + m_db->create_table( + "create table if not exists tls_sessions " + "(" + "session_id TEXT PRIMARY KEY, " + "session_start INTEGER, " + "hostname TEXT, " + "hostport INTEGER, " + "session BLOB" + ")"); + + m_db->create_table( + "create table if not exists tls_sessions_metadata " + "(" + "passphrase_salt BLOB, " + "passphrase_iterations INTEGER, " + "passphrase_check INTEGER " + ")"); + + const size_t salts = m_db->row_count("tls_sessions_metadata"); + + if(salts == 1) + { + // existing db + auto stmt = m_db->new_statement("select * from tls_sessions_metadata"); + + if(stmt->step()) + { + std::pair<const byte*, size_t> salt = stmt->get_blob(0); + const size_t iterations = stmt->get_size_t(1); + const size_t check_val_db = stmt->get_size_t(2); + + size_t check_val_created; + m_session_key = derive_key(passphrase, + salt.first, + salt.second, + iterations, + check_val_created); + + if(check_val_created != check_val_db) + throw std::runtime_error("Session database password not valid"); + } + } + else + { + // maybe just zap the salts + sessions tables in this case? + if(salts != 0) + throw std::runtime_error("Seemingly corrupted database, multiple salts found"); + + // new database case + + std::vector<byte> salt = unlock(rng.random_vec(16)); + const size_t iterations = 256 * 1024; + size_t check_val = 0; + + m_session_key = derive_key(passphrase, &salt[0], salt.size(), + iterations, check_val); + + auto stmt = m_db->new_statement("insert into tls_sessions_metadata values(?1, ?2, ?3)"); + + stmt->bind(1, salt); + stmt->bind(2, iterations); + stmt->bind(3, check_val); + + stmt->spin(); + } + } + +bool Session_Manager_SQL::load_from_session_id(const std::vector<byte>& session_id, + Session& session) + { + auto stmt = m_db->new_statement("select session from tls_sessions where session_id = ?1"); + + stmt->bind(1, hex_encode(session_id)); + + while(stmt->step()) + { + std::pair<const byte*, size_t> blob = stmt->get_blob(0); + + try + { + session = Session::decrypt(blob.first, blob.second, m_session_key); + return true; + } + catch(...) + { + } + } + + return false; + } + +bool Session_Manager_SQL::load_from_server_info(const Server_Information& server, + Session& session) + { + auto stmt = m_db->new_statement("select session from tls_sessions" + " where hostname = ?1 and hostport = ?2" + " order by session_start desc"); + + stmt->bind(1, server.hostname()); + stmt->bind(2, server.port()); + + while(stmt->step()) + { + std::pair<const byte*, size_t> blob = stmt->get_blob(0); + + try + { + session = Session::decrypt(blob.first, blob.second, m_session_key); + return true; + } + catch(...) + { + } + } + + return false; + } + +void Session_Manager_SQL::remove_entry(const std::vector<byte>& session_id) + { + auto stmt = m_db->new_statement("delete from tls_sessions where session_id = ?1"); + + stmt->bind(1, hex_encode(session_id)); + + stmt->spin(); + } + +void Session_Manager_SQL::save(const Session& session) + { + auto stmt = m_db->new_statement("insert or replace into tls_sessions" + " values(?1, ?2, ?3, ?4, ?5)"); + + stmt->bind(1, hex_encode(session.session_id())); + stmt->bind(2, session.start_time()); + stmt->bind(3, session.server_info().hostname()); + stmt->bind(4, session.server_info().port()); + stmt->bind(5, session.encrypt(m_session_key, m_rng)); + + stmt->spin(); + + prune_session_cache(); + } + +void Session_Manager_SQL::prune_session_cache() + { + // First expire old sessions + auto remove_expired = m_db->new_statement("delete from tls_sessions where session_start <= ?1"); + remove_expired->bind(1, std::chrono::system_clock::now() - m_session_lifetime); + remove_expired->spin(); + + const size_t sessions = m_db->row_count("tls_sessions"); + + // Then if needed expire some more sessions at random + if(sessions > m_max_sessions) + { + auto remove_some = m_db->new_statement("delete from tls_sessions where session_id in " + "(select session_id from tls_sessions limit ?1)"); + + remove_some->bind(1, sessions - m_max_sessions); + remove_some->spin(); + } + } + +} + +} diff --git a/src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.h b/src/lib/tls/sessions_sql/tls_session_manager_sql.h index 7892ccd6a..0935b73ac 100644 --- a/src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.h +++ b/src/lib/tls/sessions_sql/tls_session_manager_sql.h @@ -1,51 +1,52 @@ /* -* SQLite3 TLS Session Manager -* (C) 2012 Jack Lloyd +* TLS Session Manager storing to encrypted SQL db table +* (C) 2012,2014 Jack Lloyd * * Released under the terms of the Botan license */ -#ifndef BOTAN_TLS_SQLITE3_SESSION_MANAGER_H__ -#define BOTAN_TLS_SQLITE3_SESSION_MANAGER_H__ +#ifndef BOTAN_TLS_SQL_SESSION_MANAGER_H__ +#define BOTAN_TLS_SQL_SESSION_MANAGER_H__ #include <botan/tls_session_manager.h> +#include <botan/database.h> #include <botan/rng.h> namespace Botan { -class sqlite3_database; - namespace TLS { /** -* An implementation of Session_Manager that saves values in a SQLite3 +* An implementation of Session_Manager that saves values in a SQL * database file, with the session data encrypted using a passphrase. * * @warning For clients, the hostnames associated with the saved * sessions are stored in the database in plaintext. This may be a * serious privacy risk in some situations. */ -class BOTAN_DLL Session_Manager_SQLite : public Session_Manager +class BOTAN_DLL Session_Manager_SQL : public Session_Manager { public: /** + * @param db A connection to the database to use + The table names botan_tls_sessions and + botan_tls_sessions_metadata will be used * @param passphrase used to encrypt the session data * @param rng a random number generator - * @param db_filename filename of the SQLite database file. - The table names tls_sessions and tls_sessions_metadata - will be used * @param max_sessions a hint on the maximum number of sessions * to keep in memory at any one time. (If zero, don't cap) * @param session_lifetime sessions are expired after this many * seconds have elapsed from initial handshake. */ - Session_Manager_SQLite(const std::string& passphrase, - RandomNumberGenerator& rng, - const std::string& db_filename, - size_t max_sessions = 1000, - std::chrono::seconds session_lifetime = std::chrono::seconds(7200)); + Session_Manager_SQL(std::shared_ptr<SQL_Database> db, + const std::string& passphrase, + RandomNumberGenerator& rng, + size_t max_sessions = 1000, + std::chrono::seconds session_lifetime = std::chrono::seconds(7200)); + + Session_Manager_SQL(const Session_Manager_SQL&) = delete; - ~Session_Manager_SQLite(); + Session_Manager_SQL& operator=(const Session_Manager_SQL&) = delete; bool load_from_session_id(const std::vector<byte>& session_id, Session& session) override; @@ -61,16 +62,13 @@ class BOTAN_DLL Session_Manager_SQLite : public Session_Manager { return m_session_lifetime; } private: - Session_Manager_SQLite(const Session_Manager_SQLite&); - Session_Manager_SQLite& operator=(const Session_Manager_SQLite&); - void prune_session_cache(); + std::shared_ptr<SQL_Database> m_db; SymmetricKey m_session_key; RandomNumberGenerator& m_rng; size_t m_max_sessions; std::chrono::seconds m_session_lifetime; - sqlite3_database* m_db; }; } diff --git a/src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.cpp b/src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.cpp deleted file mode 100644 index 21483067f..000000000 --- a/src/lib/tls/sessions_sqlite/tls_session_manager_sqlite.cpp +++ /dev/null @@ -1,222 +0,0 @@ -/* -* SQLite TLS Session Manager -* (C) 2012 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#include <botan/tls_session_manager_sqlite.h> -#include <botan/internal/sqlite3.h> -#include <botan/lookup.h> -#include <botan/hex.h> -#include <botan/loadstor.h> -#include <chrono> - -namespace Botan { - -namespace TLS { - -namespace { - -SymmetricKey derive_key(const std::string& passphrase, - const byte salt[], - size_t salt_len, - size_t iterations, - size_t& check_val) - { - std::unique_ptr<PBKDF> pbkdf(get_pbkdf("PBKDF2(SHA-512)")); - - secure_vector<byte> x = pbkdf->derive_key(32 + 2, - passphrase, - salt, salt_len, - iterations).bits_of(); - - check_val = make_u16bit(x[0], x[1]); - return SymmetricKey(&x[2], x.size() - 2); - } - -} - -Session_Manager_SQLite::Session_Manager_SQLite(const std::string& passphrase, - RandomNumberGenerator& rng, - const std::string& db_filename, - size_t max_sessions, - std::chrono::seconds session_lifetime) : - m_rng(rng), - m_max_sessions(max_sessions), - m_session_lifetime(session_lifetime) - { - m_db = new sqlite3_database(db_filename); - - m_db->create_table( - "create table if not exists tls_sessions " - "(" - "session_id TEXT PRIMARY KEY, " - "session_start INTEGER, " - "hostname TEXT, " - "hostport INTEGER, " - "session BLOB" - ")"); - - m_db->create_table( - "create table if not exists tls_sessions_metadata " - "(" - "passphrase_salt BLOB, " - "passphrase_iterations INTEGER, " - "passphrase_check INTEGER " - ")"); - - const size_t salts = m_db->row_count("tls_sessions_metadata"); - - if(salts == 1) - { - // existing db - sqlite3_statement stmt(m_db, "select * from tls_sessions_metadata"); - - if(stmt.step()) - { - std::pair<const byte*, size_t> salt = stmt.get_blob(0); - const size_t iterations = stmt.get_size_t(1); - const size_t check_val_db = stmt.get_size_t(2); - - size_t check_val_created; - m_session_key = derive_key(passphrase, - salt.first, - salt.second, - iterations, - check_val_created); - - if(check_val_created != check_val_db) - throw std::runtime_error("Session database password not valid"); - } - } - else - { - // maybe just zap the salts + sessions tables in this case? - if(salts != 0) - throw std::runtime_error("Seemingly corrupted database, multiple salts found"); - - // new database case - - std::vector<byte> salt = unlock(rng.random_vec(16)); - const size_t iterations = 256 * 1024; - size_t check_val = 0; - - m_session_key = derive_key(passphrase, &salt[0], salt.size(), - iterations, check_val); - - sqlite3_statement stmt(m_db, "insert into tls_sessions_metadata" - " values(?1, ?2, ?3)"); - - stmt.bind(1, salt); - stmt.bind(2, iterations); - stmt.bind(3, check_val); - - stmt.spin(); - } - } - -Session_Manager_SQLite::~Session_Manager_SQLite() - { - delete m_db; - } - -bool Session_Manager_SQLite::load_from_session_id(const std::vector<byte>& session_id, - Session& session) - { - sqlite3_statement stmt(m_db, "select session from tls_sessions where session_id = ?1"); - - stmt.bind(1, hex_encode(session_id)); - - while(stmt.step()) - { - std::pair<const byte*, size_t> blob = stmt.get_blob(0); - - try - { - session = Session::decrypt(blob.first, blob.second, m_session_key); - return true; - } - catch(...) - { - } - } - - return false; - } - -bool Session_Manager_SQLite::load_from_server_info(const Server_Information& server, - Session& session) - { - sqlite3_statement stmt(m_db, "select session from tls_sessions" - " where hostname = ?1 and hostport = ?2" - " order by session_start desc"); - - stmt.bind(1, server.hostname()); - stmt.bind(2, server.port()); - - while(stmt.step()) - { - std::pair<const byte*, size_t> blob = stmt.get_blob(0); - - try - { - session = Session::decrypt(blob.first, blob.second, m_session_key); - return true; - } - catch(...) - { - } - } - - return false; - } - -void Session_Manager_SQLite::remove_entry(const std::vector<byte>& session_id) - { - sqlite3_statement stmt(m_db, "delete from tls_sessions where session_id = ?1"); - - stmt.bind(1, hex_encode(session_id)); - - stmt.spin(); - } - -void Session_Manager_SQLite::save(const Session& session) - { - sqlite3_statement stmt(m_db, "insert or replace into tls_sessions" - " values(?1, ?2, ?3, ?4, ?5)"); - - stmt.bind(1, hex_encode(session.session_id())); - stmt.bind(2, session.start_time()); - stmt.bind(3, session.server_info().hostname()); - stmt.bind(4, session.server_info().port()); - stmt.bind(5, session.encrypt(m_session_key, m_rng)); - - stmt.spin(); - - prune_session_cache(); - } - -void Session_Manager_SQLite::prune_session_cache() - { - sqlite3_statement remove_expired(m_db, "delete from tls_sessions where session_start <= ?1"); - - remove_expired.bind(1, std::chrono::system_clock::now() - m_session_lifetime); - - remove_expired.spin(); - - const size_t sessions = m_db->row_count("tls_sessions"); - - if(sessions > m_max_sessions) - { - sqlite3_statement remove_some(m_db, "delete from tls_sessions where session_id in " - "(select session_id from tls_sessions limit ?1)"); - - remove_some.bind(1, sessions - m_max_sessions); - remove_some.spin(); - } - } - -} - -} diff --git a/src/lib/tls/sessions_sqlite/info.txt b/src/lib/tls/sessions_sqlite3/info.txt index 76d53f995..b04b6a9d6 100644 --- a/src/lib/tls/sessions_sqlite/info.txt +++ b/src/lib/tls/sessions_sqlite3/info.txt @@ -1,6 +1,6 @@ define TLS_SQLITE3_SESSION_MANAGER 20131128 <requires> -pbkdf2 +sessions_sql sqlite3 </requires> diff --git a/src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.cpp b/src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.cpp new file mode 100644 index 000000000..30af3699f --- /dev/null +++ b/src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.cpp @@ -0,0 +1,29 @@ +/* +* SQLite TLS Session Manager +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_session_manager_sqlite.h> +#include <botan/sqlite3.h> + +namespace Botan { + +namespace TLS { + +Session_Manager_SQLite::Session_Manager_SQLite(const std::string& passphrase, + RandomNumberGenerator& rng, + const std::string& db_filename, + size_t max_sessions, + std::chrono::seconds session_lifetime) : + Session_Manager_SQL(std::make_shared<Sqlite3_Database>(db_filename), + passphrase, + rng, + max_sessions, + session_lifetime) + {} + +} + +} diff --git a/src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.h b/src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.h new file mode 100644 index 000000000..67c1c9e53 --- /dev/null +++ b/src/lib/tls/sessions_sqlite3/tls_session_manager_sqlite.h @@ -0,0 +1,52 @@ +/* +* SQLite3 TLS Session Manager +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_SQLITE3_SESSION_MANAGER_H__ +#define BOTAN_TLS_SQLITE3_SESSION_MANAGER_H__ + +#include <botan/tls_session_manager_sql.h> +#include <botan/rng.h> + +namespace Botan { + +namespace TLS { + +/** +* An implementation of Session_Manager that saves values in a SQLite3 +* database file, with the session data encrypted using a passphrase. +* +* @warning For clients, the hostnames associated with the saved +* sessions are stored in the database in plaintext. This may be a +* serious privacy risk in some situations. +*/ +class BOTAN_DLL +Session_Manager_SQLite : public Session_Manager_SQL + { + public: + /** + * @param passphrase used to encrypt the session data + * @param rng a random number generator + * @param db_filename filename of the SQLite database file. + The table names tls_sessions and tls_sessions_metadata + will be used + * @param max_sessions a hint on the maximum number of sessions + * to keep in memory at any one time. (If zero, don't cap) + * @param session_lifetime sessions are expired after this many + * seconds have elapsed from initial handshake. + */ + Session_Manager_SQLite(const std::string& passphrase, + RandomNumberGenerator& rng, + const std::string& db_filename, + size_t max_sessions = 1000, + std::chrono::seconds session_lifetime = std::chrono::seconds(7200)); +}; + +} + +} + +#endif diff --git a/src/lib/utils/database.h b/src/lib/utils/database.h new file mode 100644 index 000000000..742c52c7c --- /dev/null +++ b/src/lib/utils/database.h @@ -0,0 +1,62 @@ +/* +* SQL database interface +* (C) 2014 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_SQL_DATABASE_H__ +#define BOTAN_SQL_DATABASE_H__ + +#include <botan/types.h> +#include <string> +#include <chrono> +#include <vector> + +namespace Botan { + +class BOTAN_DLL SQL_Database + { + public: + class BOTAN_DLL Statement + { + public: + /* Bind statement parameters */ + virtual void bind(int column, const std::string& str) = 0; + + virtual void bind(int column, size_t i) = 0; + + virtual void bind(int column, std::chrono::system_clock::time_point time) = 0; + + virtual void bind(int column, const std::vector<byte>& blob) = 0; + + /* Get output */ + virtual std::pair<const byte*, size_t> get_blob(int column) = 0; + + virtual size_t get_size_t(int column) = 0; + + /* Run to completion */ + virtual void spin() = 0; + + /* Maybe update */ + virtual bool step() = 0; + + virtual ~Statement() {} + }; + + /* + * Create a new statement for execution. + * Use ?1, ?2, ?3, etc for parameters to set later with bind + */ + virtual std::shared_ptr<Statement> new_statement(const std::string& base_sql) const = 0; + + virtual size_t row_count(const std::string& table_name) = 0; + + virtual void create_table(const std::string& table_schema) = 0; + + virtual ~SQL_Database() {} +}; + +} + +#endif diff --git a/src/lib/utils/info.txt b/src/lib/utils/info.txt index 17ae249c2..9ba51f193 100644 --- a/src/lib/utils/info.txt +++ b/src/lib/utils/info.txt @@ -18,6 +18,7 @@ bswap.h calendar.h charset.h cpuid.h +database.h exceptn.h get_byte.h loadstor.h diff --git a/src/lib/utils/sqlite3/info.txt b/src/lib/utils/sqlite3/info.txt index 97d59d697..6370f4b2b 100644 --- a/src/lib/utils/sqlite3/info.txt +++ b/src/lib/utils/sqlite3/info.txt @@ -5,10 +5,6 @@ load_on request all -> sqlite3 </libs> -<header:internal> +<header:public> sqlite3.h -</header:internal> - -<source> -sqlite3.cpp -</source> +</header:public> diff --git a/src/lib/utils/sqlite3/sqlite3.cpp b/src/lib/utils/sqlite3/sqlite3.cpp index 7f6626759..0ed8df83c 100644 --- a/src/lib/utils/sqlite3/sqlite3.cpp +++ b/src/lib/utils/sqlite3/sqlite3.cpp @@ -5,13 +5,13 @@ * Released under the terms of the Botan license */ -#include <botan/internal/sqlite3.h> +#include <botan/sqlite3.h> #include <stdexcept> #include <sqlite3.h> namespace Botan { -sqlite3_database::sqlite3_database(const std::string& db_filename) +Sqlite3_Database::Sqlite3_Database(const std::string& db_filename) { int rc = ::sqlite3_open(db_filename.c_str(), &m_db); @@ -24,24 +24,29 @@ sqlite3_database::sqlite3_database(const std::string& db_filename) } } -sqlite3_database::~sqlite3_database() +Sqlite3_Database::~Sqlite3_Database() { if(m_db) ::sqlite3_close(m_db); m_db = nullptr; } -size_t sqlite3_database::row_count(const std::string& table_name) +std::shared_ptr<SQL_Database::Statement> Sqlite3_Database::new_statement(const std::string& base_sql) const { - sqlite3_statement stmt(this, "select count(*) from " + table_name); + return std::make_shared<Sqlite3_Statement>(m_db, base_sql); + } + +size_t Sqlite3_Database::row_count(const std::string& table_name) + { + auto stmt = new_statement("select count(*) from " + table_name); - if(stmt.step()) - return stmt.get_size_t(0); + if(stmt->step()) + return stmt->get_size_t(0); else throw std::runtime_error("Querying size of table " + table_name + " failed"); } -void sqlite3_database::create_table(const std::string& table_schema) +void Sqlite3_Database::create_table(const std::string& table_schema) { char* errmsg = nullptr; int rc = ::sqlite3_exec(m_db, table_schema.c_str(), nullptr, nullptr, &errmsg); @@ -56,44 +61,45 @@ void sqlite3_database::create_table(const std::string& table_schema) } } - -sqlite3_statement::sqlite3_statement(sqlite3_database* db, const std::string& base_sql) +Sqlite3_Database::Sqlite3_Statement::Sqlite3_Statement(sqlite3* db, const std::string& base_sql) { - int rc = ::sqlite3_prepare_v2(db->m_db, base_sql.c_str(), -1, &m_stmt, nullptr); + int rc = ::sqlite3_prepare_v2(db, base_sql.c_str(), -1, &m_stmt, nullptr); if(rc != SQLITE_OK) throw std::runtime_error("sqlite3_prepare failed " + base_sql + ", code " + std::to_string(rc)); } -void sqlite3_statement::bind(int column, const std::string& val) +void Sqlite3_Database::Sqlite3_Statement::bind(int column, const std::string& val) { int rc = ::sqlite3_bind_text(m_stmt, column, val.c_str(), -1, SQLITE_TRANSIENT); if(rc != SQLITE_OK) throw std::runtime_error("sqlite3_bind_text failed, code " + std::to_string(rc)); } -void sqlite3_statement::bind(int column, int val) +void Sqlite3_Database::Sqlite3_Statement::bind(int column, size_t val) { + if(val != static_cast<size_t>(static_cast<int>(val))) // is this legit? + throw std::runtime_error("sqlite3 cannot store " + std::to_string(val) + " without truncation"); int rc = ::sqlite3_bind_int(m_stmt, column, val); if(rc != SQLITE_OK) throw std::runtime_error("sqlite3_bind_int failed, code " + std::to_string(rc)); } -void sqlite3_statement::bind(int column, std::chrono::system_clock::time_point time) +void Sqlite3_Database::Sqlite3_Statement::bind(int column, std::chrono::system_clock::time_point time) { const int timeval = std::chrono::duration_cast<std::chrono::seconds>(time.time_since_epoch()).count(); bind(column, timeval); } -void sqlite3_statement::bind(int column, const std::vector<byte>& val) +void Sqlite3_Database::Sqlite3_Statement::bind(int column, const std::vector<byte>& val) { int rc = ::sqlite3_bind_blob(m_stmt, column, &val[0], val.size(), SQLITE_TRANSIENT); if(rc != SQLITE_OK) throw std::runtime_error("sqlite3_bind_text failed, code " + std::to_string(rc)); } -std::pair<const byte*, size_t> sqlite3_statement::get_blob(int column) +std::pair<const byte*, size_t> Sqlite3_Database::Sqlite3_Statement::get_blob(int column) { BOTAN_ASSERT(::sqlite3_column_type(m_stmt, 0) == SQLITE_BLOB, "Return value is a blob"); @@ -107,7 +113,7 @@ std::pair<const byte*, size_t> sqlite3_statement::get_blob(int column) static_cast<size_t>(session_blob_size)); } -size_t sqlite3_statement::get_size_t(int column) +size_t Sqlite3_Database::Sqlite3_Statement::get_size_t(int column) { BOTAN_ASSERT(::sqlite3_column_type(m_stmt, column) == SQLITE_INTEGER, "Return count is an integer"); @@ -119,17 +125,17 @@ size_t sqlite3_statement::get_size_t(int column) return static_cast<size_t>(sessions_int); } -void sqlite3_statement::spin() +void Sqlite3_Database::Sqlite3_Statement::spin() { while(step()) {} } -bool sqlite3_statement::step() +bool Sqlite3_Database::Sqlite3_Statement::step() { return (::sqlite3_step(m_stmt) == SQLITE_ROW); } -sqlite3_statement::~sqlite3_statement() +Sqlite3_Database::Sqlite3_Statement::~Sqlite3_Statement() { ::sqlite3_finalize(m_stmt); } diff --git a/src/lib/utils/sqlite3/sqlite3.h b/src/lib/utils/sqlite3/sqlite3.h index 3085ff0e3..7853a012f 100644 --- a/src/lib/utils/sqlite3/sqlite3.h +++ b/src/lib/utils/sqlite3/sqlite3.h @@ -1,66 +1,56 @@ /* -* SQLite wrapper -* (C) 2012 Jack Lloyd +* SQLite3 wrapper +* (C) 2012,2014 Jack Lloyd * * Released under the terms of the Botan license */ -#ifndef BOTAN_UTILS_SQLITE_WRAPPER_H__ -#define BOTAN_UTILS_SQLITE_WRAPPER_H__ +#ifndef BOTAN_UTILS_SQLITE3_H__ +#define BOTAN_UTILS_SQLIT3_H__ -#include <botan/types.h> -#include <string> -#include <chrono> -#include <vector> +#include <botan/database.h> class sqlite3; class sqlite3_stmt; namespace Botan { -class sqlite3_database +class BOTAN_DLL Sqlite3_Database : public SQL_Database { public: - sqlite3_database(const std::string& file); + Sqlite3_Database(const std::string& file); - ~sqlite3_database(); + ~Sqlite3_Database(); - size_t row_count(const std::string& table_name); + size_t row_count(const std::string& table_name) override; - void create_table(const std::string& table_schema); + void create_table(const std::string& table_schema) override; + + std::shared_ptr<Statement> new_statement(const std::string& sql) const override; private: - friend class sqlite3_statement; + class Sqlite3_Statement : public Statement + { + public: + void bind(int column, const std::string& val) override; + void bind(int column, size_t val) override; + void bind(int column, std::chrono::system_clock::time_point time) override; + void bind(int column, const std::vector<byte>& val) override; + + std::pair<const byte*, size_t> get_blob(int column) override; + size_t get_size_t(int column) override; + + void spin() override; + bool step() override; + + Sqlite3_Statement(sqlite3* db, const std::string& base_sql); + ~Sqlite3_Statement(); + private: + sqlite3_stmt* m_stmt; + }; sqlite3* m_db; }; -class sqlite3_statement - { - public: - sqlite3_statement(sqlite3_database* db, - const std::string& base_sql); - - void bind(int column, const std::string& val); - - void bind(int column, int val); - - void bind(int column, std::chrono::system_clock::time_point time); - - void bind(int column, const std::vector<byte>& val); - - std::pair<const byte*, size_t> get_blob(int column); - - size_t get_size_t(int column); - - void spin(); - - bool step(); - - ~sqlite3_statement(); - private: - sqlite3_stmt* m_stmt; - }; - } #endif |