diff options
Diffstat (limited to 'src/lib/pubkey/sm2/sm2_enc.cpp')
-rw-r--r-- | src/lib/pubkey/sm2/sm2_enc.cpp | 231 |
1 files changed, 231 insertions, 0 deletions
diff --git a/src/lib/pubkey/sm2/sm2_enc.cpp b/src/lib/pubkey/sm2/sm2_enc.cpp new file mode 100644 index 000000000..a832dd1ac --- /dev/null +++ b/src/lib/pubkey/sm2/sm2_enc.cpp @@ -0,0 +1,231 @@ +/* +* SM2 Encryption +* (C) 2017 Ribose Inc +* +* Botan is released under the Simplified BSD License (see license.txt) +*/ + +#include <botan/sm2_enc.h> +#include <botan/pk_ops.h> +#include <botan/keypair.h> +#include <botan/kdf.h> +#include <botan/hash.h> + +namespace Botan { + +bool SM2_Encryption_PrivateKey::check_key(RandomNumberGenerator& rng, + bool strong) const + { + if(!public_point().on_the_curve()) + return false; + + if(!strong) + return true; + + return KeyPair::encryption_consistency_check(rng, *this, "SM3"); + } + +SM2_Encryption_PrivateKey::SM2_Encryption_PrivateKey(const AlgorithmIdentifier& alg_id, + const secure_vector<uint8_t>& key_bits) : + EC_PrivateKey(alg_id, key_bits) + { + } + +SM2_Encryption_PrivateKey::SM2_Encryption_PrivateKey(RandomNumberGenerator& rng, + const EC_Group& domain, + const BigInt& x) : + EC_PrivateKey(rng, domain, x) + { + } + +namespace { + +class SM2_Encryption_Operation : public PK_Ops::Encryption + { + public: + SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key) : + m_p_bytes(key.domain().get_curve().get_p().bytes()), + m_order(key.domain().get_order()), + m_base_point(key.domain().get_base_point(), m_order), + m_public_point(key.public_point(), m_order) + {} + + size_t max_input_bits() const override + { + // This is arbitrary, but assumes SM2 is used for key encapsulation + return 512; + } + + secure_vector<uint8_t> encrypt(const uint8_t msg[], + size_t msg_len, + RandomNumberGenerator& rng) override + { + std::unique_ptr<HashFunction> hash = HashFunction::create("SM3"); + std::unique_ptr<KDF> kdf = KDF::create("KDF2(SM3)"); + + secure_vector<uint8_t> ciphertext; + ciphertext.reserve(1 + m_p_bytes*2 + msg_len + hash->output_length()); + + const BigInt k = BigInt::random_integer(rng, 1, m_order); + + const PointGFp C1 = m_base_point.blinded_multiply(k, rng); + const BigInt x1 = C1.get_affine_x(); + const BigInt y1 = C1.get_affine_y(); + std::vector<uint8_t> x1_bytes(m_p_bytes); + std::vector<uint8_t> y1_bytes(m_p_bytes); + BigInt::encode_1363(x1_bytes.data(), x1_bytes.size(), x1); + BigInt::encode_1363(y1_bytes.data(), y1_bytes.size(), y1); + + const PointGFp kPB = m_public_point.blinded_multiply(k, rng); + + const BigInt x2 = kPB.get_affine_x(); + const BigInt y2 = kPB.get_affine_y(); + std::vector<uint8_t> x2_bytes(m_p_bytes); + std::vector<uint8_t> y2_bytes(m_p_bytes); + BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); + BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); + + secure_vector<uint8_t> kdf_input; + kdf_input += x2_bytes; + kdf_input += y2_bytes; + + const secure_vector<uint8_t> kdf_output = + kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); + + secure_vector<uint8_t> masked_msg(msg_len); + xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len); + + hash->update(x2_bytes); + hash->update(msg, msg_len); + hash->update(y2_bytes); + std::vector<uint8_t> C3(hash->output_length()); + hash->final(C3.data()); + + ciphertext.push_back(0x04); + ciphertext += x1_bytes; + ciphertext += y1_bytes; + ciphertext += masked_msg; + ciphertext += C3; + + return ciphertext; + } + + private: + size_t m_p_bytes; + const BigInt& m_order; + Blinded_Point_Multiply m_base_point; + Blinded_Point_Multiply m_public_point; + }; + +class SM2_Decryption_Operation : public PK_Ops::Decryption + { + public: + SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key, + RandomNumberGenerator& rng) : + m_key(key), + m_rng(rng) + {} + + secure_vector<uint8_t> decrypt(uint8_t& valid_mask, + const uint8_t ciphertext[], + size_t ciphertext_len) override + { + const BigInt& cofactor = m_key.domain().get_cofactor(); + const size_t p_bytes = m_key.domain().get_curve().get_p().bytes(); + + valid_mask = 0; + + std::unique_ptr<HashFunction> hash = HashFunction::create("SM3"); + + // Too short to be valid - no timing problem from early return + if(ciphertext_len < 1 + p_bytes*2 + hash->output_length()) + { + return secure_vector<uint8_t>(); + } + + if(ciphertext[0] != 0x04) + { + return secure_vector<uint8_t>(); + } + + const PointGFp C1 = OS2ECP(ciphertext, 1 + p_bytes*2, m_key.domain().get_curve()); + // OS2ECP verifies C1 is on the curve + + Blinded_Point_Multiply C1_mul(C1, m_key.domain().get_order()); + + if(cofactor > 1 && C1_mul.blinded_multiply(cofactor, m_rng).is_zero()) + { + return secure_vector<uint8_t>(); + } + + const PointGFp dbC1 = C1_mul.blinded_multiply(m_key.private_value(), m_rng); + + const BigInt x2 = dbC1.get_affine_x(); + const BigInt y2 = dbC1.get_affine_y(); + + std::vector<uint8_t> x2_bytes(p_bytes); + std::vector<uint8_t> y2_bytes(p_bytes); + BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); + BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); + + secure_vector<uint8_t> kdf_input; + kdf_input += x2_bytes; + kdf_input += y2_bytes; + + const size_t msg_len = ciphertext_len - (1 + p_bytes*2 + hash->output_length()); + + std::unique_ptr<KDF> kdf = KDF::create("KDF2(SM3)"); + const secure_vector<uint8_t> kdf_output = + kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); + + secure_vector<uint8_t> msg(msg_len); + xor_buf(msg.data(), ciphertext + (1+p_bytes*2), kdf_output.data(), msg_len); + + hash->update(x2_bytes); + hash->update(msg); + hash->update(y2_bytes); + secure_vector<uint8_t> u = hash->final(); + + if(same_mem(u.data(), ciphertext + (1+p_bytes*2+msg_len), hash->output_length()) == false) + return secure_vector<uint8_t>(); + + valid_mask = 0xFF; + return msg; + } + private: + const SM2_Encryption_PrivateKey& m_key; + RandomNumberGenerator& m_rng; + const std::string m_ident; + }; + +} + +std::unique_ptr<PK_Ops::Encryption> +SM2_Encryption_PublicKey::create_encryption_op(RandomNumberGenerator& /*rng*/, + const std::string& params, + const std::string& provider) const + { + if(provider == "base" || provider.empty()) + { + if(params == "") + return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this)); + } + + throw Provider_Not_Found(algo_name(), provider); + } + +std::unique_ptr<PK_Ops::Decryption> +SM2_Encryption_PrivateKey::create_decryption_op(RandomNumberGenerator& rng, + const std::string& params, + const std::string& provider) const + { + if(provider == "base" || provider.empty()) + { + if(params == "") + return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng)); + } + + throw Provider_Not_Found(algo_name(), provider); + } + +} |