diff options
author | Jack Lloyd <[email protected]> | 2017-09-05 17:49:25 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2017-09-05 17:49:25 -0400 |
commit | 7c91ef340a40fc163cda1e1a86db6b044d96e165 (patch) | |
tree | 906646c136f0d278c8678a5548c2bea925f133a6 /src/lib/pubkey | |
parent | 9972e0fc2b407ea831f4cf90c5196b8b343e5e3a (diff) |
Support arbitrary hashes for SM2 encryption
This is a contribution from Ribose Inc.
Diffstat (limited to 'src/lib/pubkey')
-rw-r--r-- | src/lib/pubkey/sm2/sm2_enc.cpp | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/src/lib/pubkey/sm2/sm2_enc.cpp b/src/lib/pubkey/sm2/sm2_enc.cpp index a832dd1ac..988db1ee9 100644 --- a/src/lib/pubkey/sm2/sm2_enc.cpp +++ b/src/lib/pubkey/sm2/sm2_enc.cpp @@ -43,11 +43,12 @@ namespace { class SM2_Encryption_Operation : public PK_Ops::Encryption { public: - SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key) : + SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key, const std::string& kdf_hash) : m_p_bytes(key.domain().get_curve().get_p().bytes()), m_order(key.domain().get_order()), m_base_point(key.domain().get_base_point(), m_order), - m_public_point(key.public_point(), m_order) + m_public_point(key.public_point(), m_order), + m_kdf_hash(kdf_hash) {} size_t max_input_bits() const override @@ -60,8 +61,8 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption size_t msg_len, RandomNumberGenerator& rng) override { - std::unique_ptr<HashFunction> hash = HashFunction::create("SM3"); - std::unique_ptr<KDF> kdf = KDF::create("KDF2(SM3)"); + std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); + std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); secure_vector<uint8_t> ciphertext; ciphertext.reserve(1 + m_p_bytes*2 + msg_len + hash->output_length()); @@ -115,15 +116,18 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption const BigInt& m_order; Blinded_Point_Multiply m_base_point; Blinded_Point_Multiply m_public_point; + const std::string m_kdf_hash; }; class SM2_Decryption_Operation : public PK_Ops::Decryption { public: SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key, - RandomNumberGenerator& rng) : + RandomNumberGenerator& rng, + const std::string& kdf_hash) : m_key(key), - m_rng(rng) + m_rng(rng), + m_kdf_hash(kdf_hash) {} secure_vector<uint8_t> decrypt(uint8_t& valid_mask, @@ -135,7 +139,8 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption valid_mask = 0; - std::unique_ptr<HashFunction> hash = HashFunction::create("SM3"); + std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); + std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); // Too short to be valid - no timing problem from early return if(ciphertext_len < 1 + p_bytes*2 + hash->output_length()) @@ -174,7 +179,6 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption const size_t msg_len = ciphertext_len - (1 + p_bytes*2 + hash->output_length()); - std::unique_ptr<KDF> kdf = KDF::create("KDF2(SM3)"); const secure_vector<uint8_t> kdf_output = kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); @@ -196,6 +200,7 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption const SM2_Encryption_PrivateKey& m_key; RandomNumberGenerator& m_rng; const std::string m_ident; + const std::string m_kdf_hash; }; } @@ -207,8 +212,8 @@ SM2_Encryption_PublicKey::create_encryption_op(RandomNumberGenerator& /*rng*/, { if(provider == "base" || provider.empty()) { - if(params == "") - return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this)); + const std::string kdf_hash = (params.empty() ? "SM3" : params); + return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this, kdf_hash)); } throw Provider_Not_Found(algo_name(), provider); @@ -221,8 +226,8 @@ SM2_Encryption_PrivateKey::create_decryption_op(RandomNumberGenerator& rng, { if(provider == "base" || provider.empty()) { - if(params == "") - return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng)); + const std::string kdf_hash = (params.empty() ? "SM3" : params); + return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng, kdf_hash)); } throw Provider_Not_Found(algo_name(), provider); |