/* * SM2 Encryption * (C) 2017 Ribose Inc * * Botan is released under the Simplified BSD License (see license.txt) */ #include #include #include #include #include #include #include #include namespace Botan { bool SM2_Encryption_PrivateKey::check_key(RandomNumberGenerator& rng, bool strong) const { if(!public_point().on_the_curve()) return false; if(!strong) return true; return KeyPair::encryption_consistency_check(rng, *this, "SM3"); } SM2_Encryption_PrivateKey::SM2_Encryption_PrivateKey(const AlgorithmIdentifier& alg_id, const secure_vector& key_bits) : EC_PrivateKey(alg_id, key_bits) { } SM2_Encryption_PrivateKey::SM2_Encryption_PrivateKey(RandomNumberGenerator& rng, const EC_Group& domain, const BigInt& x) : EC_PrivateKey(rng, domain, x) { } namespace { class SM2_Encryption_Operation final : public PK_Ops::Encryption { public: SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key, const std::string& kdf_hash) : m_group(key.domain()), m_mul_public_point(key.public_point()), m_kdf_hash(kdf_hash) {} size_t max_input_bits() const override { // This is arbitrary, but assumes SM2 is used for key encapsulation return 512; } secure_vector encrypt(const uint8_t msg[], size_t msg_len, RandomNumberGenerator& rng) override { std::unique_ptr hash = HashFunction::create_or_throw(m_kdf_hash); std::unique_ptr kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); const size_t p_bytes = m_group.get_p_bytes(); const BigInt k = BigInt::random_integer(rng, 1, m_group.get_order()); const PointGFp C1 = m_group.blinded_base_point_multiply(k, rng, m_ws); const BigInt x1 = C1.get_affine_x(); const BigInt y1 = C1.get_affine_y(); std::vector x1_bytes(p_bytes); std::vector y1_bytes(p_bytes); BigInt::encode_1363(x1_bytes.data(), x1_bytes.size(), x1); BigInt::encode_1363(y1_bytes.data(), y1_bytes.size(), y1); const PointGFp kPB = m_mul_public_point.mul(k, rng, m_group.get_order(), m_ws); const BigInt x2 = kPB.get_affine_x(); const BigInt y2 = kPB.get_affine_y(); std::vector x2_bytes(p_bytes); std::vector y2_bytes(p_bytes); BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); secure_vector kdf_input; kdf_input += x2_bytes; kdf_input += y2_bytes; const secure_vector kdf_output = kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); secure_vector masked_msg(msg_len); xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len); hash->update(x2_bytes); hash->update(msg, msg_len); hash->update(y2_bytes); std::vector C3(hash->output_length()); hash->final(C3.data()); return DER_Encoder() .start_cons(SEQUENCE) .encode(x1) .encode(y1) .encode(C3, OCTET_STRING) .encode(masked_msg, OCTET_STRING) .end_cons() .get_contents(); } private: const EC_Group m_group; PointGFp_Var_Point_Precompute m_mul_public_point; const std::string m_kdf_hash; std::vector m_ws; }; class SM2_Decryption_Operation final : public PK_Ops::Decryption { public: SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key, RandomNumberGenerator& rng, const std::string& kdf_hash) : m_key(key), m_rng(rng), m_kdf_hash(kdf_hash) {} secure_vector decrypt(uint8_t& valid_mask, const uint8_t ciphertext[], size_t ciphertext_len) override { const EC_Group& group = m_key.domain(); const BigInt& cofactor = group.get_cofactor(); const size_t p_bytes = group.get_p_bytes(); valid_mask = 0x00; std::unique_ptr hash = HashFunction::create_or_throw(m_kdf_hash); std::unique_ptr kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); // Too short to be valid - no timing problem from early return if(ciphertext_len < 1 + p_bytes*2 + hash->output_length()) { return secure_vector(); } BigInt x1, y1; secure_vector C3, masked_msg; BER_Decoder(ciphertext, ciphertext_len) .start_cons(SEQUENCE) .decode(x1) .decode(y1) .decode(C3, OCTET_STRING) .decode(masked_msg, OCTET_STRING) .end_cons() .verify_end(); PointGFp C1 = group.point(x1, y1); C1.randomize_repr(m_rng); if(!C1.on_the_curve()) return secure_vector(); if(cofactor > 1 && (C1 * cofactor).is_zero()) { return secure_vector(); } const PointGFp dbC1 = group.blinded_var_point_multiply( C1, m_key.private_value(), m_rng, m_ws); const BigInt x2 = dbC1.get_affine_x(); const BigInt y2 = dbC1.get_affine_y(); std::vector x2_bytes(p_bytes); std::vector y2_bytes(p_bytes); BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); secure_vector kdf_input; kdf_input += x2_bytes; kdf_input += y2_bytes; const secure_vector kdf_output = kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size()); xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size()); hash->update(x2_bytes); hash->update(masked_msg); hash->update(y2_bytes); secure_vector u = hash->final(); if(constant_time_compare(u.data(), C3.data(), hash->output_length()) == false) return secure_vector(); valid_mask = 0xFF; return masked_msg; } private: const SM2_Encryption_PrivateKey& m_key; RandomNumberGenerator& m_rng; const std::string m_kdf_hash; std::vector m_ws; }; } std::unique_ptr SM2_Encryption_PublicKey::create_encryption_op(RandomNumberGenerator& /*rng*/, const std::string& params, const std::string& provider) const { if(provider == "base" || provider.empty()) { const std::string kdf_hash = (params.empty() ? "SM3" : params); return std::unique_ptr(new SM2_Encryption_Operation(*this, kdf_hash)); } throw Provider_Not_Found(algo_name(), provider); } std::unique_ptr SM2_Encryption_PrivateKey::create_decryption_op(RandomNumberGenerator& rng, const std::string& params, const std::string& provider) const { if(provider == "base" || provider.empty()) { const std::string kdf_hash = (params.empty() ? "SM3" : params); return std::unique_ptr(new SM2_Decryption_Operation(*this, rng, kdf_hash)); } throw Provider_Not_Found(algo_name(), provider); } }