aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/pubkey
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2017-09-05 17:49:25 -0400
committerJack Lloyd <[email protected]>2017-09-05 17:49:25 -0400
commit7c91ef340a40fc163cda1e1a86db6b044d96e165 (patch)
tree906646c136f0d278c8678a5548c2bea925f133a6 /src/lib/pubkey
parent9972e0fc2b407ea831f4cf90c5196b8b343e5e3a (diff)
Support arbitrary hashes for SM2 encryption
This is a contribution from Ribose Inc.
Diffstat (limited to 'src/lib/pubkey')
-rw-r--r--src/lib/pubkey/sm2/sm2_enc.cpp29
1 files changed, 17 insertions, 12 deletions
diff --git a/src/lib/pubkey/sm2/sm2_enc.cpp b/src/lib/pubkey/sm2/sm2_enc.cpp
index a832dd1ac..988db1ee9 100644
--- a/src/lib/pubkey/sm2/sm2_enc.cpp
+++ b/src/lib/pubkey/sm2/sm2_enc.cpp
@@ -43,11 +43,12 @@ namespace {
class SM2_Encryption_Operation : public PK_Ops::Encryption
{
public:
- SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key) :
+ SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key, const std::string& kdf_hash) :
m_p_bytes(key.domain().get_curve().get_p().bytes()),
m_order(key.domain().get_order()),
m_base_point(key.domain().get_base_point(), m_order),
- m_public_point(key.public_point(), m_order)
+ m_public_point(key.public_point(), m_order),
+ m_kdf_hash(kdf_hash)
{}
size_t max_input_bits() const override
@@ -60,8 +61,8 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption
size_t msg_len,
RandomNumberGenerator& rng) override
{
- std::unique_ptr<HashFunction> hash = HashFunction::create("SM3");
- std::unique_ptr<KDF> kdf = KDF::create("KDF2(SM3)");
+ std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
+ std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");
secure_vector<uint8_t> ciphertext;
ciphertext.reserve(1 + m_p_bytes*2 + msg_len + hash->output_length());
@@ -115,15 +116,18 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption
const BigInt& m_order;
Blinded_Point_Multiply m_base_point;
Blinded_Point_Multiply m_public_point;
+ const std::string m_kdf_hash;
};
class SM2_Decryption_Operation : public PK_Ops::Decryption
{
public:
SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key,
- RandomNumberGenerator& rng) :
+ RandomNumberGenerator& rng,
+ const std::string& kdf_hash) :
m_key(key),
- m_rng(rng)
+ m_rng(rng),
+ m_kdf_hash(kdf_hash)
{}
secure_vector<uint8_t> decrypt(uint8_t& valid_mask,
@@ -135,7 +139,8 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption
valid_mask = 0;
- std::unique_ptr<HashFunction> hash = HashFunction::create("SM3");
+ std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
+ std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");
// Too short to be valid - no timing problem from early return
if(ciphertext_len < 1 + p_bytes*2 + hash->output_length())
@@ -174,7 +179,6 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption
const size_t msg_len = ciphertext_len - (1 + p_bytes*2 + hash->output_length());
- std::unique_ptr<KDF> kdf = KDF::create("KDF2(SM3)");
const secure_vector<uint8_t> kdf_output =
kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size());
@@ -196,6 +200,7 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption
const SM2_Encryption_PrivateKey& m_key;
RandomNumberGenerator& m_rng;
const std::string m_ident;
+ const std::string m_kdf_hash;
};
}
@@ -207,8 +212,8 @@ SM2_Encryption_PublicKey::create_encryption_op(RandomNumberGenerator& /*rng*/,
{
if(provider == "base" || provider.empty())
{
- if(params == "")
- return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this));
+ const std::string kdf_hash = (params.empty() ? "SM3" : params);
+ return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this, kdf_hash));
}
throw Provider_Not_Found(algo_name(), provider);
@@ -221,8 +226,8 @@ SM2_Encryption_PrivateKey::create_decryption_op(RandomNumberGenerator& rng,
{
if(provider == "base" || provider.empty())
{
- if(params == "")
- return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng));
+ const std::string kdf_hash = (params.empty() ? "SM3" : params);
+ return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng, kdf_hash));
}
throw Provider_Not_Found(algo_name(), provider);