diff options
Diffstat (limited to 'src/lib/pubkey/sm2/sm2_enc.cpp')
-rw-r--r-- | src/lib/pubkey/sm2/sm2_enc.cpp | 55 |
1 files changed, 30 insertions, 25 deletions
diff --git a/src/lib/pubkey/sm2/sm2_enc.cpp b/src/lib/pubkey/sm2/sm2_enc.cpp index aca31941d..2d44faacb 100644 --- a/src/lib/pubkey/sm2/sm2_enc.cpp +++ b/src/lib/pubkey/sm2/sm2_enc.cpp @@ -8,13 +8,15 @@ #include <botan/sm2_enc.h> #include <botan/pk_ops.h> #include <botan/keypair.h> +#include <botan/der_enc.h> +#include <botan/ber_dec.h> #include <botan/kdf.h> #include <botan/hash.h> namespace Botan { bool SM2_Encryption_PrivateKey::check_key(RandomNumberGenerator& rng, - bool strong) const + bool strong) const { if(!public_point().on_the_curve()) return false; @@ -64,9 +66,6 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); - secure_vector<uint8_t> ciphertext; - ciphertext.reserve(1 + m_p_bytes*2 + msg_len + hash->output_length()); - const BigInt k = BigInt::random_integer(rng, 1, m_order); const PointGFp C1 = m_base_point.blinded_multiply(k, rng); @@ -102,13 +101,14 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption std::vector<uint8_t> C3(hash->output_length()); hash->final(C3.data()); - ciphertext.push_back(0x04); - ciphertext += x1_bytes; - ciphertext += y1_bytes; - ciphertext += C3; - ciphertext += masked_msg; - - return ciphertext; + return DER_Encoder() + .start_cons(SEQUENCE) + .encode(x1) + .encode(y1) + .encode(C3, OCTET_STRING) + .encode(masked_msg, OCTET_STRING) + .end_cons() + .get_contents(); } private: @@ -137,7 +137,7 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption const BigInt& cofactor = m_key.domain().get_cofactor(); const size_t p_bytes = m_key.domain().get_curve().get_p().bytes(); - valid_mask = 0; + valid_mask = 0x00; std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); @@ -148,15 +148,21 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption return secure_vector<uint8_t>(); } - if(ciphertext[0] != 0x04) - { - return secure_vector<uint8_t>(); - } + BigInt x1, y1; + secure_vector<uint8_t> C3, masked_msg; - const size_t msg_len = ciphertext_len - (1 + p_bytes*2 + hash->output_length()); + BER_Decoder(ciphertext, ciphertext_len) + .start_cons(SEQUENCE) + .decode(x1) + .decode(y1) + .decode(C3, OCTET_STRING) + .decode(masked_msg, OCTET_STRING) + .end_cons() + .verify_end(); - const PointGFp C1 = OS2ECP(ciphertext, 1 + p_bytes*2, m_key.domain().get_curve()); - // OS2ECP verifies C1 is on the curve + const PointGFp C1(m_key.domain().get_curve(), x1, y1); + if(!C1.on_the_curve()) + return secure_vector<uint8_t>(); Blinded_Point_Multiply C1_mul(C1, m_key.domain().get_order()); @@ -180,21 +186,20 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption kdf_input += y2_bytes; const secure_vector<uint8_t> kdf_output = - kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); + kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size()); - secure_vector<uint8_t> msg(msg_len); - xor_buf(msg.data(), ciphertext + (1+p_bytes*2+hash->output_length()), kdf_output.data(), msg_len); + xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size()); hash->update(x2_bytes); - hash->update(msg); + hash->update(masked_msg); hash->update(y2_bytes); secure_vector<uint8_t> u = hash->final(); - if(constant_time_compare(u.data(), ciphertext + (1+p_bytes*2), hash->output_length()) == false) + if(constant_time_compare(u.data(), C3.data(), hash->output_length()) == false) return secure_vector<uint8_t>(); valid_mask = 0xFF; - return msg; + return masked_msg; } private: const SM2_Encryption_PrivateKey& m_key; |