aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/pubkey/sm2/sm2_enc.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/pubkey/sm2/sm2_enc.cpp')
-rw-r--r--src/lib/pubkey/sm2/sm2_enc.cpp55
1 files changed, 30 insertions, 25 deletions
diff --git a/src/lib/pubkey/sm2/sm2_enc.cpp b/src/lib/pubkey/sm2/sm2_enc.cpp
index aca31941d..2d44faacb 100644
--- a/src/lib/pubkey/sm2/sm2_enc.cpp
+++ b/src/lib/pubkey/sm2/sm2_enc.cpp
@@ -8,13 +8,15 @@
#include <botan/sm2_enc.h>
#include <botan/pk_ops.h>
#include <botan/keypair.h>
+#include <botan/der_enc.h>
+#include <botan/ber_dec.h>
#include <botan/kdf.h>
#include <botan/hash.h>
namespace Botan {
bool SM2_Encryption_PrivateKey::check_key(RandomNumberGenerator& rng,
- bool strong) const
+ bool strong) const
{
if(!public_point().on_the_curve())
return false;
@@ -64,9 +66,6 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption
std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");
- secure_vector<uint8_t> ciphertext;
- ciphertext.reserve(1 + m_p_bytes*2 + msg_len + hash->output_length());
-
const BigInt k = BigInt::random_integer(rng, 1, m_order);
const PointGFp C1 = m_base_point.blinded_multiply(k, rng);
@@ -102,13 +101,14 @@ class SM2_Encryption_Operation : public PK_Ops::Encryption
std::vector<uint8_t> C3(hash->output_length());
hash->final(C3.data());
- ciphertext.push_back(0x04);
- ciphertext += x1_bytes;
- ciphertext += y1_bytes;
- ciphertext += C3;
- ciphertext += masked_msg;
-
- return ciphertext;
+ return DER_Encoder()
+ .start_cons(SEQUENCE)
+ .encode(x1)
+ .encode(y1)
+ .encode(C3, OCTET_STRING)
+ .encode(masked_msg, OCTET_STRING)
+ .end_cons()
+ .get_contents();
}
private:
@@ -137,7 +137,7 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption
const BigInt& cofactor = m_key.domain().get_cofactor();
const size_t p_bytes = m_key.domain().get_curve().get_p().bytes();
- valid_mask = 0;
+ valid_mask = 0x00;
std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");
@@ -148,15 +148,21 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption
return secure_vector<uint8_t>();
}
- if(ciphertext[0] != 0x04)
- {
- return secure_vector<uint8_t>();
- }
+ BigInt x1, y1;
+ secure_vector<uint8_t> C3, masked_msg;
- const size_t msg_len = ciphertext_len - (1 + p_bytes*2 + hash->output_length());
+ BER_Decoder(ciphertext, ciphertext_len)
+ .start_cons(SEQUENCE)
+ .decode(x1)
+ .decode(y1)
+ .decode(C3, OCTET_STRING)
+ .decode(masked_msg, OCTET_STRING)
+ .end_cons()
+ .verify_end();
- const PointGFp C1 = OS2ECP(ciphertext, 1 + p_bytes*2, m_key.domain().get_curve());
- // OS2ECP verifies C1 is on the curve
+ const PointGFp C1(m_key.domain().get_curve(), x1, y1);
+ if(!C1.on_the_curve())
+ return secure_vector<uint8_t>();
Blinded_Point_Multiply C1_mul(C1, m_key.domain().get_order());
@@ -180,21 +186,20 @@ class SM2_Decryption_Operation : public PK_Ops::Decryption
kdf_input += y2_bytes;
const secure_vector<uint8_t> kdf_output =
- kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size());
+ kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size());
- secure_vector<uint8_t> msg(msg_len);
- xor_buf(msg.data(), ciphertext + (1+p_bytes*2+hash->output_length()), kdf_output.data(), msg_len);
+ xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size());
hash->update(x2_bytes);
- hash->update(msg);
+ hash->update(masked_msg);
hash->update(y2_bytes);
secure_vector<uint8_t> u = hash->final();
- if(constant_time_compare(u.data(), ciphertext + (1+p_bytes*2), hash->output_length()) == false)
+ if(constant_time_compare(u.data(), C3.data(), hash->output_length()) == false)
return secure_vector<uint8_t>();
valid_mask = 0xFF;
- return msg;
+ return masked_msg;
}
private:
const SM2_Encryption_PrivateKey& m_key;