diff options
author | Jack Lloyd <[email protected]> | 2018-11-28 10:35:17 -0500 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2018-11-28 10:35:17 -0500 |
commit | 007314c530eb12d414ced07515f8cbc25a0f64f5 (patch) | |
tree | dc887f97efa0248aa5e7b8468c94145f6a1305f8 | |
parent | b03f38f57d4f50ace1ed8b57d83ba70eb5bc1dfb (diff) |
Add CT::Mask type
-rw-r--r-- | src/lib/block/idea/idea.cpp | 4 | ||||
-rw-r--r-- | src/lib/mac/poly1305/poly1305.cpp | 8 | ||||
-rw-r--r-- | src/lib/math/bigint/big_ops2.cpp | 2 | ||||
-rw-r--r-- | src/lib/math/bigint/bigint.cpp | 24 | ||||
-rw-r--r-- | src/lib/math/mp/mp_core.h | 105 | ||||
-rw-r--r-- | src/lib/math/mp/mp_karat.cpp | 6 | ||||
-rw-r--r-- | src/lib/math/numbertheory/monty_exp.cpp | 6 | ||||
-rw-r--r-- | src/lib/modes/mode_pad/mode_pad.cpp | 63 | ||||
-rw-r--r-- | src/lib/pk_pad/eme_oaep/oaep.cpp | 22 | ||||
-rw-r--r-- | src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp | 20 | ||||
-rw-r--r-- | src/lib/pk_pad/iso9796/iso9796.cpp | 34 | ||||
-rw-r--r-- | src/lib/pubkey/dlies/dlies.cpp | 3 | ||||
-rw-r--r-- | src/lib/pubkey/ec_group/point_mul.cpp | 18 | ||||
-rw-r--r-- | src/lib/pubkey/ecies/ecies.cpp | 4 | ||||
-rw-r--r-- | src/lib/pubkey/pubkey.cpp | 18 | ||||
-rw-r--r-- | src/lib/tls/tls_cbc/tls_cbc.cpp | 20 | ||||
-rw-r--r-- | src/lib/utils/ct_utils.h | 377 | ||||
-rw-r--r-- | src/lib/utils/mem_ops.cpp | 9 | ||||
-rw-r--r-- | src/lib/utils/mem_ops.h | 20 | ||||
-rw-r--r-- | src/tests/test_utils.cpp | 73 |
20 files changed, 519 insertions, 317 deletions
diff --git a/src/lib/block/idea/idea.cpp b/src/lib/block/idea/idea.cpp index d6380368d..0bdc36a68 100644 --- a/src/lib/block/idea/idea.cpp +++ b/src/lib/block/idea/idea.cpp @@ -20,7 +20,7 @@ namespace { inline uint16_t mul(uint16_t x, uint16_t y) { const uint32_t P = static_cast<uint32_t>(x) * y; - const uint16_t P_mask = static_cast<uint16_t>(CT::is_zero(P) & 0xFFFF); + const auto P_mask = CT::Mask<uint16_t>(CT::Mask<uint32_t>::is_zero(P)); const uint32_t P_hi = P >> 16; const uint32_t P_lo = P & 0xFFFF; @@ -29,7 +29,7 @@ inline uint16_t mul(uint16_t x, uint16_t y) const uint16_t r_1 = static_cast<uint16_t>((P_lo - P_hi) + carry); const uint16_t r_2 = 1 - x - y; - return CT::select(P_mask, r_2, r_1); + return P_mask.select(r_2, r_1);; } /* diff --git a/src/lib/mac/poly1305/poly1305.cpp b/src/lib/mac/poly1305/poly1305.cpp index 32027b8d0..333a21a1a 100644 --- a/src/lib/mac/poly1305/poly1305.cpp +++ b/src/lib/mac/poly1305/poly1305.cpp @@ -119,10 +119,10 @@ void poly1305_finish(secure_vector<uint64_t>& X, uint8_t mac[16]) uint64_t g2 = h2 + c - (static_cast<uint64_t>(1) << 42); /* select h if h < p, or h + -p if h >= p */ - c = CT::expand_mask<uint64_t>(c); - h0 = CT::select(c, g0, h0); - h1 = CT::select(c, g1, h1); - h2 = CT::select(c, g2, h2); + const auto c_mask = CT::Mask<uint64_t>::expand(c); + h0 = c_mask.select(g0, h0); + h1 = c_mask.select(g1, h1); + h2 = c_mask.select(g2, h2); /* h = (h + pad) */ const uint64_t t0 = X[6]; diff --git a/src/lib/math/bigint/big_ops2.cpp b/src/lib/math/bigint/big_ops2.cpp index 2fd775ccf..68baf3200 100644 --- a/src/lib/math/bigint/big_ops2.cpp +++ b/src/lib/math/bigint/big_ops2.cpp @@ -171,7 +171,7 @@ BigInt& BigInt::mod_sub(const BigInt& s, const BigInt& mod, secure_vector<word>& ws.resize(mod_sw); // is t < s or not? - const word is_lt = bigint_ct_is_lt(data(), mod_sw, s.data(), mod_sw); + const auto is_lt = bigint_ct_is_lt(data(), mod_sw, s.data(), mod_sw); // ws = p - s word borrow = bigint_sub3(ws.data(), mod.data(), mod_sw, s.data(), mod_sw); diff --git a/src/lib/math/bigint/bigint.cpp b/src/lib/math/bigint/bigint.cpp index 2cb9394ce..667035686 100644 --- a/src/lib/math/bigint/bigint.cpp +++ b/src/lib/math/bigint/bigint.cpp @@ -146,7 +146,7 @@ bool BigInt::is_equal(const BigInt& other) const return false; return bigint_ct_is_eq(this->data(), this->sig_words(), - other.data(), other.sig_words()); + other.data(), other.sig_words()).is_set(); } bool BigInt::is_less_than(const BigInt& other) const @@ -160,11 +160,11 @@ bool BigInt::is_less_than(const BigInt& other) const if(other.is_negative() && this->is_negative()) { return !bigint_ct_is_lt(other.data(), other.sig_words(), - this->data(), this->sig_words(), true); + this->data(), this->sig_words(), true).is_set(); } return bigint_ct_is_lt(this->data(), this->sig_words(), - other.data(), other.sig_words()); + other.data(), other.sig_words()).is_set(); } void BigInt::encode_words(word out[], size_t size) const @@ -187,7 +187,7 @@ size_t BigInt::Data::calc_sig_words() const for(size_t i = 0; i != m_reg.size(); ++i) { const word w = m_reg[m_reg.size() - i - 1]; - sub &= CT::is_zero(w); + sub &= CT::Mask<word>::is_zero(w).value(); sig -= sub; } @@ -393,13 +393,18 @@ void BigInt::ct_cond_assign(bool predicate, BigInt& other) const size_t t_words = size(); const size_t o_words = other.size(); + if(o_words < t_words) + grow_to(o_words); + const size_t r_words = std::max(t_words, o_words); - const word mask = CT::expand_mask<word>(predicate); + const auto mask = CT::Mask<word>::expand(predicate); for(size_t i = 0; i != r_words; ++i) { - this->set_word_at(i, CT::select<word>(mask, other.word_at(i), this->word_at(i))); + const word o_word = other.word_at(i); + const word t_word = this->word_at(i); + this->set_word_at(i, mask.select(o_word, t_word)); } } @@ -430,10 +435,13 @@ void BigInt::const_time_lookup(secure_vector<word>& output, BOTAN_ASSERT(vec[i].size() >= words, "Word size as expected in const_time_lookup"); - const word mask = CT::is_equal(i, idx); + const auto mask = CT::Mask<word>::is_equal(i, idx); for(size_t w = 0; w != words; ++w) - output[w] |= CT::select<word>(mask, vec[i].word_at(w), 0); + { + const word viw = vec[i].word_at(w); + output[w] = mask.if_set_return(viw); + } } CT::unpoison(idx); diff --git a/src/lib/math/mp/mp_core.h b/src/lib/math/mp/mp_core.h index 9a19a46be..4829ef6fc 100644 --- a/src/lib/math/mp/mp_core.h +++ b/src/lib/math/mp/mp_core.h @@ -30,14 +30,14 @@ const word MP_WORD_MAX = MP_WORD_MASK; */ inline void bigint_cnd_swap(word cnd, word x[], word y[], size_t size) { - const word mask = CT::expand_mask(cnd); + const auto mask = CT::Mask<word>::expand(cnd); for(size_t i = 0; i != size; ++i) { const word a = x[i]; const word b = y[i]; - x[i] = CT::select(mask, b, a); - y[i] = CT::select(mask, a, b); + x[i] = mask.select(b, a); + y[i] = mask.select(a, b); } } @@ -46,7 +46,7 @@ inline word bigint_cnd_add(word cnd, word x[], word x_size, { BOTAN_ASSERT(x_size >= y_size, "Expected sizes"); - const word mask = CT::expand_mask(cnd); + const auto mask = CT::Mask<word>::expand(cnd); word carry = 0; @@ -56,24 +56,22 @@ inline word bigint_cnd_add(word cnd, word x[], word x_size, for(size_t i = 0; i != blocks; i += 8) { carry = word8_add3(z, x + i, y + i, carry); - - for(size_t j = 0; j != 8; ++j) - x[i+j] = CT::select(mask, z[j], x[i+j]); + mask.select_n(x + i, z, x + i, 8); } for(size_t i = blocks; i != y_size; ++i) { z[0] = word_add(x[i], y[i], &carry); - x[i] = CT::select(mask, z[0], x[i]); + x[i] = mask.select(z[0], x[i]); } for(size_t i = y_size; i != x_size; ++i) { z[0] = word_add(x[i], 0, &carry); - x[i] = CT::select(mask, z[0], x[i]); + x[i] = mask.select(z[0], x[i]); } - return carry & mask; + return mask.if_set_return(carry); } /* @@ -95,7 +93,7 @@ inline word bigint_cnd_sub(word cnd, { BOTAN_ASSERT(x_size >= y_size, "Expected sizes"); - const word mask = CT::expand_mask(cnd); + const auto mask = CT::Mask<word>::expand(cnd); word carry = 0; @@ -105,24 +103,22 @@ inline word bigint_cnd_sub(word cnd, for(size_t i = 0; i != blocks; i += 8) { carry = word8_sub3(z, x + i, y + i, carry); - - for(size_t j = 0; j != 8; ++j) - x[i+j] = CT::select(mask, z[j], x[i+j]); + mask.select_n(x + i, z, x + i, 8); } for(size_t i = blocks; i != y_size; ++i) { z[0] = word_sub(x[i], y[i], &carry); - x[i] = CT::select(mask, z[0], x[i]); + x[i] = mask.select(z[0], x[i]); } for(size_t i = y_size; i != x_size; ++i) { z[0] = word_sub(x[i], 0, &carry); - x[i] = CT::select(mask, z[0], x[i]); + x[i] = mask.select(z[0], x[i]); } - return carry & mask; + return mask.if_set_return(carry); } /* @@ -142,7 +138,7 @@ inline word bigint_cnd_sub(word cnd, word x[], const word y[], size_t size) * * Mask must be either 0 or all 1 bits */ -inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size) +inline void bigint_cnd_addsub(CT::Mask<word> mask, word x[], const word y[], size_t size) { const size_t blocks = size - (size % 8); @@ -158,7 +154,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size) borrow = word8_sub3(t1, x + i, y + i, borrow); for(size_t j = 0; j != 8; ++j) - x[i+j] = CT::select(mask, t0[j], t1[j]); + x[i+j] = mask.select(t0[j], t1[j]); } for(size_t i = blocks; i != size; ++i) @@ -166,7 +162,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size) const word a = word_add(x[i], y[i], &carry); const word s = word_sub(x[i], y[i], &borrow); - x[i] = CT::select(mask, a, s); + x[i] = mask.select(a, s); } } @@ -179,7 +175,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size) * * Returns the carry or borrow resp */ -inline word bigint_cnd_addsub(word mask, word x[], +inline word bigint_cnd_addsub(CT::Mask<word> mask, word x[], const word y[], const word z[], size_t size) { @@ -197,17 +193,17 @@ inline word bigint_cnd_addsub(word mask, word x[], borrow = word8_sub3(t1, x + i, z + i, borrow); for(size_t j = 0; j != 8; ++j) - x[i+j] = CT::select(mask, t0[j], t1[j]); + x[i+j] = mask.select(t0[j], t1[j]); } for(size_t i = blocks; i != size; ++i) { t0[0] = word_add(x[i], y[i], &carry); t1[0] = word_sub(x[i], z[i], &borrow); - x[i] = CT::select(mask, t0[0], t1[0]); + x[i] = mask.select(t0[0], t1[0]); } - return CT::select(mask, carry, borrow); + return mask.select(carry, borrow); } /* @@ -217,13 +213,13 @@ inline word bigint_cnd_addsub(word mask, word x[], */ inline void bigint_cnd_abs(word cnd, word x[], size_t size) { - const word mask = CT::expand_mask(cnd); + const auto mask = CT::Mask<word>::expand(cnd); - word carry = mask & 1; + word carry = mask.if_set_return(1); for(size_t i = 0; i != size; ++i) { const word z = word_add(~x[i], 0, &carry); - x[i] = CT::select(mask, z, x[i]); + x[i] = mask.select(z, x[i]); } } @@ -379,9 +375,10 @@ inline word bigint_sub3(word z[], * @param N length of x and y * @param ws array of at least 2*N words */ -inline word bigint_sub_abs(word z[], - const word x[], const word y[], size_t N, - word ws[]) +inline CT::Mask<word> +bigint_sub_abs(word z[], + const word x[], const word y[], size_t N, + word ws[]) { // Subtract in both direction then conditional copy out the result @@ -544,19 +541,20 @@ inline int32_t bigint_cmp(const word x[], size_t x_size, { static_assert(sizeof(word) >= sizeof(uint32_t), "Size assumption"); - const uint32_t LT = static_cast<uint32_t>(-1); - const uint32_t EQ = 0; - const uint32_t GT = 1; + const word LT = static_cast<word>(-1); + const word EQ = 0; + const word GT = 1; const size_t common_elems = std::min(x_size, y_size); - uint32_t result = EQ; // until found otherwise + word result = EQ; // until found otherwise for(size_t i = 0; i != common_elems; i++) { - const word is_eq = CT::is_equal(x[i], y[i]); - const word is_lt = CT::is_less(x[i], y[i]); - result = CT::select<uint32_t>(is_eq, result, CT::select<uint32_t>(is_lt, LT, GT)); + const auto is_eq = CT::Mask<word>::is_equal(x[i], y[i]); + const auto is_lt = CT::Mask<word>::is_lt(x[i], y[i]); + + result = is_eq.select(result, is_lt.select(LT, GT)); } if(x_size < y_size) @@ -566,7 +564,7 @@ inline int32_t bigint_cmp(const word x[], size_t x_size, mask |= y[i]; // If any bits were set in high part of y, then x < y - result = CT::select<uint32_t>(CT::is_zero(mask), result, LT); + result = CT::Mask<word>::is_zero(mask).select(result, LT); } else if(y_size < x_size) { @@ -575,7 +573,7 @@ inline int32_t bigint_cmp(const word x[], size_t x_size, mask |= x[i]; // If any bits were set in high part of x, then x > y - result = CT::select<uint32_t>(CT::is_zero(mask), result, GT); + result = CT::Mask<word>::is_zero(mask).select(result, GT); } CT::unpoison(result); @@ -588,19 +586,20 @@ inline int32_t bigint_cmp(const word x[], size_t x_size, * Return ~0 if x[0:x_size] < y[0:y_size] or 0 otherwise * If lt_or_equal is true, returns ~0 also for x == y */ -inline word bigint_ct_is_lt(const word x[], size_t x_size, - const word y[], size_t y_size, - bool lt_or_equal = false) +inline CT::Mask<word> +bigint_ct_is_lt(const word x[], size_t x_size, + const word y[], size_t y_size, + bool lt_or_equal = false) { const size_t common_elems = std::min(x_size, y_size); - word is_lt = CT::expand_mask<word>(lt_or_equal); + auto is_lt = CT::Mask<word>::expand(lt_or_equal); for(size_t i = 0; i != common_elems; i++) { - const word eq = CT::is_equal(x[i], y[i]); - const word lt = CT::is_less(x[i], y[i]); - is_lt = CT::select(eq, is_lt, lt); + const auto eq = CT::Mask<word>::is_equal(x[i], y[i]); + const auto lt = CT::Mask<word>::is_lt(x[i], y[i]); + is_lt = eq.select_mask(is_lt, lt); } if(x_size < y_size) @@ -609,7 +608,7 @@ inline word bigint_ct_is_lt(const word x[], size_t x_size, for(size_t i = x_size; i != y_size; i++) mask |= y[i]; // If any bits were set in high part of y, then is_lt should be forced true - is_lt |= ~CT::is_zero(mask); + is_lt |= CT::Mask<word>::expand(mask); } else if(y_size < x_size) { @@ -618,15 +617,15 @@ inline word bigint_ct_is_lt(const word x[], size_t x_size, mask |= x[i]; // If any bits were set in high part of x, then is_lt should be false - is_lt &= CT::is_zero(mask); + is_lt &= CT::Mask<word>::is_zero(mask); } - CT::unpoison(is_lt); return is_lt; } -inline word bigint_ct_is_eq(const word x[], size_t x_size, - const word y[], size_t y_size) +inline CT::Mask<word> +bigint_ct_is_eq(const word x[], size_t x_size, + const word y[], size_t y_size) { const size_t common_elems = std::min(x_size, y_size); @@ -649,9 +648,7 @@ inline word bigint_ct_is_eq(const word x[], size_t x_size, diff |= x[i]; } - const word is_equal = CT::is_zero(diff); - CT::unpoison(is_equal); - return is_equal; + return CT::Mask<word>::is_zero(diff); } /** diff --git a/src/lib/math/mp/mp_karat.cpp b/src/lib/math/mp/mp_karat.cpp index c0fc5304b..8bd7cf58d 100644 --- a/src/lib/math/mp/mp_karat.cpp +++ b/src/lib/math/mp/mp_karat.cpp @@ -124,9 +124,9 @@ void karatsuba_mul(word z[], const word x[], const word y[], size_t N, */ // First compute (X_lo - X_hi)*(Y_hi - Y_lo) - const word cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace); - const word cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace); - const word neg_mask = ~(cmp0 ^ cmp1); + const auto cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace); + const auto cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace); + const auto neg_mask = ~(cmp0 ^ cmp1); karatsuba_mul(ws0, z0, z1, N2, ws1); diff --git a/src/lib/math/numbertheory/monty_exp.cpp b/src/lib/math/numbertheory/monty_exp.cpp index 2b5bbd81d..7590005a0 100644 --- a/src/lib/math/numbertheory/monty_exp.cpp +++ b/src/lib/math/numbertheory/monty_exp.cpp @@ -85,10 +85,12 @@ void const_time_lookup(secure_vector<word>& output, BOTAN_ASSERT(vec.size() >= words, "Word size as expected in const_time_lookup"); - const word mask = CT::is_equal<word>(i, nibble); + const auto mask = CT::Mask<word>::is_equal(i, nibble); for(size_t w = 0; w != words; ++w) - output[w] |= (mask & vec[w]); + { + output[w] |= mask.if_set_return(vec[w]); + } } } diff --git a/src/lib/modes/mode_pad/mode_pad.cpp b/src/lib/modes/mode_pad/mode_pad.cpp index e65114c88..5c949e9cf 100644 --- a/src/lib/modes/mode_pad/mode_pad.cpp +++ b/src/lib/modes/mode_pad/mode_pad.cpp @@ -57,21 +57,30 @@ size_t PKCS7_Padding::unpad(const uint8_t input[], size_t input_length) const return input_length; CT::poison(input, input_length); - size_t bad_input = 0; + const uint8_t last_byte = input[input_length-1]; - bad_input |= CT::expand_mask<size_t>(last_byte > input_length); + /* + The input should == the block size so if the last byte exceeds + that then the padding is certainly invalid + */ + auto bad_input = CT::Mask<size_t>::is_gt(last_byte, input_length); const size_t pad_pos = input_length - last_byte; for(size_t i = 0; i != input_length - 1; ++i) { - const uint8_t in_range = CT::expand_mask<uint8_t>(i >= pad_pos); - bad_input |= in_range & (~CT::is_equal(input[i], last_byte)); + // Does this byte equal the expected pad byte? + const auto pad_eq = CT::Mask<size_t>::is_equal(input[i], last_byte); + + // Ignore values that are not part of the padding + const auto in_range = CT::Mask<size_t>::is_gte(i, pad_pos); + bad_input |= in_range & (~pad_eq); } CT::unpoison(input, input_length); - return CT::conditional_return(bad_input, input_length, pad_pos); + + return bad_input.select_and_unpoison(input_length, pad_pos); } /* @@ -99,21 +108,24 @@ size_t ANSI_X923_Padding::unpad(const uint8_t input[], size_t input_length) cons return input_length; CT::poison(input, input_length); + const size_t last_byte = input[input_length-1]; - uint8_t bad_input = 0; - bad_input |= CT::expand_mask<uint8_t>(last_byte > input_length); + auto bad_input = CT::Mask<size_t>::is_gt(last_byte, input_length); const size_t pad_pos = input_length - last_byte; for(size_t i = 0; i != input_length - 1; ++i) { - const uint8_t in_range = CT::expand_mask<uint8_t>(i >= pad_pos); - bad_input |= CT::expand_mask(input[i]) & in_range; + // Ignore values that are not part of the padding + const auto in_range = CT::Mask<size_t>::is_gte(i, pad_pos); + const auto pad_is_nonzero = CT::Mask<size_t>::expand(input[i]); + bad_input |= pad_is_nonzero & in_range; } CT::unpoison(input, input_length); - return CT::conditional_return(bad_input, input_length, pad_pos); + + return bad_input.select_and_unpoison(input_length, pad_pos); } /* @@ -139,22 +151,26 @@ size_t OneAndZeros_Padding::unpad(const uint8_t input[], size_t input_length) co CT::poison(input, input_length); - uint8_t bad_input = 0; - uint8_t seen_one = 0; + auto bad_input = CT::Mask<uint8_t>::cleared(); + auto seen_0x80 = CT::Mask<uint8_t>::cleared(); + size_t pad_pos = input_length - 1; size_t i = input_length; while(i) { - seen_one |= CT::is_equal<uint8_t>(input[i-1], 0x80); - pad_pos -= CT::select<uint8_t>(~seen_one, 1, 0); - bad_input |= ~CT::is_zero<uint8_t>(input[i-1]) & ~seen_one; + const auto is_0x80 = CT::Mask<uint8_t>::is_equal(input[i-1], 0x80); + const auto is_zero = CT::Mask<uint8_t>::is_zero(input[i-1]); + + seen_0x80 |= is_0x80; + pad_pos -= seen_0x80.if_not_set_return(1); + bad_input |= ~seen_0x80 & ~is_zero; i--; } - bad_input |= ~seen_one; + bad_input |= ~seen_0x80; CT::unpoison(input, input_length); - return CT::conditional_return(bad_input, input_length, pad_pos); + return bad_input.select_and_unpoison(input_length, pad_pos); } /* @@ -183,20 +199,23 @@ size_t ESP_Padding::unpad(const uint8_t input[], size_t input_length) const CT::poison(input, input_length); const size_t last_byte = input[input_length-1]; - uint8_t bad_input = 0; - bad_input |= CT::is_zero(last_byte) | CT::expand_mask<uint8_t>(last_byte > input_length); + + auto bad_input = CT::Mask<uint8_t>::is_zero(last_byte) | + CT::Mask<uint8_t>::is_gt(last_byte, input_length); const size_t pad_pos = input_length - last_byte; size_t i = input_length - 1; while(i) { - const uint8_t in_range = CT::expand_mask<uint8_t>(i > pad_pos); - bad_input |= (~CT::is_equal<uint8_t>(input[i-1], input[i]-1)) & in_range; + const auto in_range = CT::Mask<uint8_t>::is_gt(i, pad_pos); + const auto incrementing = CT::Mask<uint8_t>::is_equal(input[i-1], input[i]-1); + + bad_input |= in_range & ~incrementing; --i; } CT::unpoison(input, input_length); - return CT::conditional_return(bad_input, input_length, pad_pos); + return bad_input.select_and_unpoison(input_length, pad_pos); } diff --git a/src/lib/pk_pad/eme_oaep/oaep.cpp b/src/lib/pk_pad/eme_oaep/oaep.cpp index 398202bd1..9a8676ab9 100644 --- a/src/lib/pk_pad/eme_oaep/oaep.cpp +++ b/src/lib/pk_pad/eme_oaep/oaep.cpp @@ -72,7 +72,7 @@ secure_vector<uint8_t> OAEP::unpad(uint8_t& valid_mask, Therefore, the first byte can always be skipped safely. */ - uint8_t skip_first = CT::is_zero<uint8_t>(in[0]) & 0x01; + const uint8_t skip_first = CT::Mask<uint8_t>::is_zero(in[0]).if_set_return(1); secure_vector<uint8_t> input(in + skip_first, in + in_length); @@ -105,37 +105,37 @@ oaep_find_delim(uint8_t& valid_mask, CT::poison(input, input_len); size_t delim_idx = 2 * hlen; - uint8_t waiting_for_delim = 0xFF; - uint8_t bad_input = 0; + CT::Mask<uint8_t> waiting_for_delim = CT::Mask<uint8_t>::set(); + CT::Mask<uint8_t> bad_input = CT::Mask<uint8_t>::cleared(); for(size_t i = delim_idx; i < input_len; ++i) { - const uint8_t zero_m = CT::is_zero<uint8_t>(input[i]); - const uint8_t one_m = CT::is_equal<uint8_t>(input[i], 1); + const auto zero_m = CT::Mask<uint8_t>::is_zero(input[i]); + const auto one_m = CT::Mask<uint8_t>::is_equal(input[i], 1); - const uint8_t add_m = waiting_for_delim & zero_m; + const auto add_m = waiting_for_delim & zero_m; bad_input |= waiting_for_delim & ~(zero_m | one_m); - delim_idx += CT::select<uint8_t>(add_m, 1, 0); + delim_idx += add_m.if_set_return(1); waiting_for_delim &= zero_m; } // If we never saw any non-zero byte, then it's not valid input bad_input |= waiting_for_delim; - bad_input |= CT::is_equal<uint8_t>(constant_time_compare(&input[hlen], Phash.data(), hlen), false); + bad_input |= CT::Mask<uint8_t>::is_zero(ct_compare_u8(&input[hlen], Phash.data(), hlen)); - delim_idx &= ~CT::expand_mask<size_t>(bad_input); + delim_idx = CT::Mask<size_t>::expand(bad_input.value()).if_not_set_return(delim_idx); CT::unpoison(input, input_len); CT::unpoison(&bad_input, 1); CT::unpoison(&delim_idx, 1); - valid_mask = ~bad_input; + valid_mask = (~bad_input).value(); secure_vector<uint8_t> output(input + delim_idx + 1, input + input_len); - CT::cond_zero_mem(bad_input, output.data(), output.size()); + bad_input.if_set_zero_out(output.data(), output.size()); return output; } diff --git a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp index 58aadbc38..597b7c26a 100644 --- a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp +++ b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp @@ -58,33 +58,33 @@ secure_vector<uint8_t> EME_PKCS1v15::unpad(uint8_t& valid_mask, CT::poison(in, inlen); - uint8_t bad_input_m = 0; - uint8_t seen_zero_m = 0; + CT::Mask<uint8_t> bad_input_m = CT::Mask<uint8_t>::cleared(); + CT::Mask<uint8_t> seen_zero_m = CT::Mask<uint8_t>::cleared(); size_t delim_idx = 0; - bad_input_m |= ~CT::is_equal<uint8_t>(in[0], 0); - bad_input_m |= ~CT::is_equal<uint8_t>(in[1], 2); + bad_input_m |= ~CT::Mask<uint8_t>::is_equal(in[0], 0); + bad_input_m |= ~CT::Mask<uint8_t>::is_equal(in[1], 2); for(size_t i = 2; i < inlen; ++i) { - const uint8_t is_zero_m = CT::is_zero<uint8_t>(in[i]); + const auto is_zero_m = CT::Mask<uint8_t>::is_zero(in[i]); - delim_idx += CT::select<uint8_t>(~seen_zero_m, 1, 0); + delim_idx += seen_zero_m.if_not_set_return(1); - bad_input_m |= is_zero_m & CT::expand_mask<uint8_t>(i < 10); + bad_input_m |= is_zero_m & CT::Mask<uint8_t>(CT::Mask<size_t>::is_lt(i, 10)); seen_zero_m |= is_zero_m; } bad_input_m |= ~seen_zero_m; - bad_input_m |= CT::is_less<size_t>(delim_idx, 8); + bad_input_m |= CT::Mask<uint8_t>(CT::Mask<size_t>::is_lt(delim_idx, 8)); CT::unpoison(in, inlen); CT::unpoison(bad_input_m); CT::unpoison(delim_idx); secure_vector<uint8_t> output(&in[delim_idx + 2], &in[inlen]); - CT::cond_zero_mem(bad_input_m, output.data(), output.size()); - valid_mask = ~bad_input_m; + bad_input_m.if_set_zero_out(output.data(), output.size()); + valid_mask = ~bad_input_m.value(); return output; } diff --git a/src/lib/pk_pad/iso9796/iso9796.cpp b/src/lib/pk_pad/iso9796/iso9796.cpp index 99a1dfd29..28a603bf9 100644 --- a/src/lib/pk_pad/iso9796/iso9796.cpp +++ b/src/lib/pk_pad/iso9796/iso9796.cpp @@ -142,26 +142,29 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded, //recover msg1 and salt size_t msg1_offset = 1; - uint8_t waiting_for_delim = 0xFF; - uint8_t bad_input = 0; + + auto waiting_for_delim = CT::Mask<uint8_t>::set(); + auto bad_input = CT::Mask<uint8_t>::cleared(); + for(size_t j = 0; j < DB_size; ++j) { - const uint8_t one_m = CT::is_equal<uint8_t>(DB[j], 0x01); - const uint8_t zero_m = CT::is_zero(DB[j]); - const uint8_t add_m = waiting_for_delim & zero_m; + const auto is_zero = CT::Mask<uint8_t>::is_zero(DB[j]); + const auto is_one = CT::Mask<uint8_t>::is_equal(DB[j], 0x01); - bad_input |= waiting_for_delim & ~(zero_m | one_m); - msg1_offset += CT::select<uint8_t>(add_m, 1, 0); + const auto add_m = waiting_for_delim & is_zero; - waiting_for_delim &= zero_m; + bad_input |= waiting_for_delim & ~(is_zero | is_one); + msg1_offset += add_m.if_set_return(1); + + waiting_for_delim &= is_zero; } //invalid, if delimiter 0x01 was not found or msg1_offset is too big bad_input |= waiting_for_delim; - bad_input |= CT::is_less(coded.size(), tLength + HASH_SIZE + msg1_offset + SALT_SIZE); + bad_input |= CT::Mask<size_t>::is_lt(coded.size(), tLength + HASH_SIZE + msg1_offset + SALT_SIZE); //in case that msg1_offset is too big, just continue with offset = 0. - msg1_offset = CT::select<size_t>(bad_input, 0, msg1_offset); + msg1_offset = CT::Mask<size_t>::expand(bad_input.value()).if_not_set_return(msg1_offset); CT::unpoison(coded.data(), coded.size()); CT::unpoison(msg1_offset); @@ -172,8 +175,7 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded, coded.end() - tLength - HASH_SIZE); //compute H2(C||msg1||H(msg2)||S*). * indicates a recovered value - const size_t capacity = (key_bits - 2 + 7) / 8 - HASH_SIZE - - SALT_SIZE - tLength - 1; + const size_t capacity = (key_bits - 2 + 7) / 8 - HASH_SIZE - SALT_SIZE - tLength - 1; secure_vector<uint8_t> msg1raw; secure_vector<uint8_t> msg2; if(raw.size() > capacity) @@ -188,7 +190,7 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded, } msg2 = hash->final(); - uint64_t msg1rawLength = msg1raw.size(); + const uint64_t msg1rawLength = msg1raw.size(); hash->update_be(msg1rawLength * 8); hash->update(msg1raw); hash->update(msg2); @@ -196,7 +198,7 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded, secure_vector<uint8_t> H3 = hash->final(); //compute H3(C*||msg1*||H(msg2)||S*) * indicates a recovered value - uint64_t msgLength = msg1.size(); + const uint64_t msgLength = msg1.size(); hash->update_be(msgLength * 8); hash->update(msg1); hash->update(msg2); @@ -204,10 +206,10 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded, secure_vector<uint8_t> H2 = hash->final(); //check if H3 == H2 - bad_input |= CT::is_equal<uint8_t>(constant_time_compare(H3.data(), H2.data(), HASH_SIZE), false); + bad_input |= CT::Mask<uint8_t>::is_zero(ct_compare_u8(H3.data(), H2.data(), HASH_SIZE)); CT::unpoison(bad_input); - return (bad_input == 0); + return (bad_input.is_set() == false); } } diff --git a/src/lib/pubkey/dlies/dlies.cpp b/src/lib/pubkey/dlies/dlies.cpp index 5465401d1..aa214fd8b 100644 --- a/src/lib/pubkey/dlies/dlies.cpp +++ b/src/lib/pubkey/dlies/dlies.cpp @@ -7,7 +7,6 @@ */ #include <botan/dlies.h> -#include <botan/internal/ct_utils.h> #include <limits> namespace Botan { @@ -180,7 +179,7 @@ secure_vector<uint8_t> DLIES_Decryptor::do_decrypt(uint8_t& valid_mask, secure_vector<uint8_t> tag(msg + m_pub_key_size + ciphertext_len, msg + m_pub_key_size + ciphertext_len + m_mac->output_length()); - valid_mask = CT::expand_mask<uint8_t>(constant_time_compare(tag.data(), calculated_tag.data(), tag.size())); + valid_mask = ct_compare_u8(tag.data(), calculated_tag.data(), tag.size()); // decrypt if(m_cipher) diff --git a/src/lib/pubkey/ec_group/point_mul.cpp b/src/lib/pubkey/ec_group/point_mul.cpp index da3abaacc..2707a98f3 100644 --- a/src/lib/pubkey/ec_group/point_mul.cpp +++ b/src/lib/pubkey/ec_group/point_mul.cpp @@ -128,9 +128,9 @@ PointGFp PointGFp_Base_Point_Precompute::mul(const BigInt& k, const word w = scalar.get_substring(2*window, 2); - const word w_is_1 = CT::is_equal<word>(w, 1); - const word w_is_2 = CT::is_equal<word>(w, 2); - const word w_is_3 = CT::is_equal<word>(w, 3); + const auto w_is_1 = CT::Mask<word>::is_equal(w, 1); + const auto w_is_2 = CT::Mask<word>::is_equal(w, 2); + const auto w_is_3 = CT::Mask<word>::is_equal(w, 3); for(size_t j = 0; j != elem_size; ++j) { @@ -138,7 +138,7 @@ PointGFp PointGFp_Base_Point_Precompute::mul(const BigInt& k, const word w2 = m_W[base_addr + 1*elem_size + j]; const word w3 = m_W[base_addr + 2*elem_size + j]; - Wt[j] = CT::select3<word>(w_is_1, w1, w_is_2, w2, w_is_3, w3, 0); + Wt[j] = w_is_1.select(w1, w_is_2.select(w2, w_is_3.select(w3, 0))); } R.add_affine(&Wt[0], m_p_words, &Wt[m_p_words], m_p_words, ws); @@ -255,11 +255,11 @@ PointGFp PointGFp_Var_Point_Precompute::mul(const BigInt& k, clear_mem(e.data(), e.size()); for(size_t i = 1; i != window_elems; ++i) { - const word wmask = CT::is_equal<word>(w, i); + const auto wmask = CT::Mask<word>::is_equal(w, i); for(size_t j = 0; j != elem_size; ++j) { - e[j] |= wmask & m_T[i * elem_size + j]; + e[j] |= wmask.if_set_return(m_T[i * elem_size + j]); } } @@ -282,10 +282,12 @@ PointGFp PointGFp_Var_Point_Precompute::mul(const BigInt& k, clear_mem(e.data(), e.size()); for(size_t i = 1; i != window_elems; ++i) { - const word wmask = CT::is_equal<word>(w, i); + const auto wmask = CT::Mask<word>::is_equal(w, i); for(size_t j = 0; j != elem_size; ++j) - e[j] |= wmask & m_T[i * elem_size + j]; + { + e[j] |= wmask.if_set_return(m_T[i * elem_size + j]); + } } R.add(&e[0], m_p_words, &e[m_p_words], m_p_words, &e[2*m_p_words], m_p_words, ws); diff --git a/src/lib/pubkey/ecies/ecies.cpp b/src/lib/pubkey/ecies/ecies.cpp index b35ecc107..864e0b72a 100644 --- a/src/lib/pubkey/ecies/ecies.cpp +++ b/src/lib/pubkey/ecies/ecies.cpp @@ -10,8 +10,6 @@ #include <botan/numthry.h> #include <botan/cipher_mode.h> #include <botan/mac.h> - -#include <botan/internal/ct_utils.h> #include <botan/internal/pk_ops_impl.h> namespace Botan { @@ -386,7 +384,7 @@ secure_vector<uint8_t> ECIES_Decryptor::do_decrypt(uint8_t& valid_mask, const ui m_mac->update(m_label); } const secure_vector<uint8_t> calculated_mac = m_mac->final(); - valid_mask = CT::expand_mask<uint8_t>(constant_time_compare(mac_data.data(), calculated_mac.data(), mac_data.size())); + valid_mask = ct_compare_u8(mac_data.data(), calculated_mac.data(), mac_data.size()); if(valid_mask) { diff --git a/src/lib/pubkey/pubkey.cpp b/src/lib/pubkey/pubkey.cpp index bb0170548..d98b5dc9e 100644 --- a/src/lib/pubkey/pubkey.cpp +++ b/src/lib/pubkey/pubkey.cpp @@ -37,10 +37,11 @@ PK_Decryptor::decrypt_or_random(const uint8_t in[], { const secure_vector<uint8_t> fake_pms = rng.random_vec(expected_pt_len); - uint8_t valid_mask = 0; - secure_vector<uint8_t> decoded = do_decrypt(valid_mask, in, length); + uint8_t decrypt_valid = 0; + secure_vector<uint8_t> decoded = do_decrypt(decrypt_valid, in, length); - valid_mask &= CT::is_equal(decoded.size(), expected_pt_len); + auto valid_mask = CT::Mask<uint8_t>::is_equal(decrypt_valid, 0xFF); + valid_mask &= CT::Mask<uint8_t>(CT::Mask<size_t>::is_zero(decoded.size() ^ expected_pt_len)); decoded.resize(expected_pt_len); @@ -62,14 +63,13 @@ PK_Decryptor::decrypt_or_random(const uint8_t in[], BOTAN_ASSERT(off < expected_pt_len, "Offset in range of plaintext"); - valid_mask &= CT::is_equal(decoded[off], exp); + auto eq = CT::Mask<uint8_t>::is_equal(decoded[off], exp); + + valid_mask &= eq; } - CT::conditional_copy_mem(valid_mask, - /*output*/decoded.data(), - /*from0*/decoded.data(), - /*from1*/fake_pms.data(), - expected_pt_len); + // If valid_mask is false, assign fake pre master instead + valid_mask.select_n(decoded.data(), decoded.data(), fake_pms.data(), expected_pt_len); return decoded; } diff --git a/src/lib/tls/tls_cbc/tls_cbc.cpp b/src/lib/tls/tls_cbc/tls_cbc.cpp index 7376e655b..f3ea17d42 100644 --- a/src/lib/tls/tls_cbc/tls_cbc.cpp +++ b/src/lib/tls/tls_cbc/tls_cbc.cpp @@ -235,17 +235,17 @@ uint16_t check_tls_cbc_padding(const uint8_t record[], size_t record_len) const uint8_t pad_byte = record[record_len-1]; const uint16_t pad_bytes = 1 + pad_byte; - uint16_t pad_invalid = CT::is_less<uint16_t>(rec16, pad_bytes); + auto pad_invalid = CT::Mask<uint16_t>::is_lt(rec16, pad_byte); for(uint16_t i = rec16 - to_check; i != rec16; ++i) { const uint16_t offset = rec16 - i; - const uint16_t in_pad_range = CT::is_lte<uint16_t>(offset, pad_bytes); - pad_invalid |= (in_pad_range & (record[i] ^ pad_byte)); + const auto in_pad_range = CT::Mask<uint16_t>::is_lte(offset, pad_bytes); + const auto pad_correct = CT::Mask<uint16_t>::is_equal(record[i], pad_byte); + pad_invalid |= in_pad_range & ~pad_correct; } - const uint16_t pad_invalid_mask = CT::expand_mask<uint16_t>(pad_invalid); - return CT::select<uint16_t>(pad_invalid_mask, 0, pad_byte + 1); + return pad_invalid.if_not_set_return(pad_bytes); } void TLS_CBC_HMAC_AEAD_Decryption::cbc_decrypt_record(uint8_t record_contents[], size_t record_len) @@ -337,7 +337,7 @@ void TLS_CBC_HMAC_AEAD_Decryption::perform_additional_compressions(size_t plen, const uint16_t current_compressions = ((L2 + block_size - 1 - max_bytes_in_first_block) / block_size); // number of additional compressions we have to perform const uint16_t add_compressions = max_compresssions - current_compressions; - const uint8_t equal = CT::is_equal(max_compresssions, current_compressions) & 0x01; + const uint8_t equal = CT::Mask<uint16_t>::is_equal(max_compresssions, current_compressions).if_set_return(1); // We compute the data length we need to achieve the number of compressions. // If there are no compressions, we just add 55/111 dummy bytes so that no // compression is performed. @@ -418,8 +418,8 @@ void TLS_CBC_HMAC_AEAD_Decryption::finish(secure_vector<uint8_t>& buffer, size_t (sending empty records, instead of 1/(n-1) splitting) */ - const uint16_t size_ok_mask = CT::is_lte<uint16_t>(static_cast<uint16_t>(tag_size() + pad_size), static_cast<uint16_t>(record_len)); - pad_size &= size_ok_mask; + const auto size_ok_mask = CT::Mask<uint16_t>::is_lte(tag_size() + pad_size, record_len); + pad_size = size_ok_mask.if_set_return(pad_size); CT::unpoison(record_contents, record_len); @@ -442,11 +442,11 @@ void TLS_CBC_HMAC_AEAD_Decryption::finish(secure_vector<uint8_t>& buffer, size_t const bool mac_ok = constant_time_compare(&record_contents[mac_offset], mac_buf.data(), tag_size()); - const uint16_t ok_mask = size_ok_mask & CT::expand_mask<uint16_t>(mac_ok) & CT::expand_mask<uint16_t>(pad_size); + const auto ok_mask = size_ok_mask & CT::Mask<uint16_t>::expand(mac_ok) & CT::Mask<uint16_t>::expand(pad_size); CT::unpoison(ok_mask); - if(ok_mask) + if(ok_mask.is_set()) { buffer.insert(buffer.end(), plaintext_block, plaintext_block + plaintext_length); } diff --git a/src/lib/utils/ct_utils.h b/src/lib/utils/ct_utils.h index 63b8f4640..eb510baa2 100644 --- a/src/lib/utils/ct_utils.h +++ b/src/lib/utils/ct_utils.h @@ -6,13 +6,13 @@ * Wagner, Molnar, et al "The Program Counter Security Model" * * (C) 2010 Falko Strenzke -* (C) 2015,2016 Jack Lloyd +* (C) 2015,2016,2018 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ -#ifndef BOTAN_TIMING_ATTACK_CM_H_ -#define BOTAN_TIMING_ATTACK_CM_H_ +#ifndef BOTAN_CT_UTILS_H_ +#define BOTAN_CT_UTILS_H_ #include <botan/secmem.h> #include <type_traits> @@ -75,130 +75,285 @@ inline void unpoison(T& p) #endif } -/* Mask generation */ - -template<typename T> -inline constexpr T expand_top_bit(T a) - { - static_assert(std::is_unsigned<T>::value, "unsigned integer type required"); - return static_cast<T>(0) - (a >> (sizeof(T)*8-1)); - } - -template<typename T> -inline constexpr T is_zero(T x) - { - static_assert(std::is_unsigned<T>::value, "unsigned integer type required"); - return expand_top_bit<T>(~x & (x - 1)); - } - -/* -* T should be an unsigned machine integer type -* Expand to a mask used for other operations -* @param in an integer -* @return If n is zero, returns zero. Otherwise -* returns a T with all bits set for use as a mask with -* select. +/** +* A Mask type used for constant-time operations. A Mask<T> always has value +* either 0 (all bits cleared) or ~0 (all bits set). All operations in a Mask<T> +* are intended to compile to code which does not contain conditional jumps. +* This must be verified with tooling (eg binary disassembly or using valgrind) +* since you never know what a compiler might do. */ template<typename T> -inline constexpr T expand_mask(T x) +class Mask { - static_assert(std::is_unsigned<T>::value, "unsigned integer type required"); - return ~is_zero(x); - } - -template<typename T> -inline constexpr T select(T mask, T from0, T from1) - { - static_assert(std::is_unsigned<T>::value, "unsigned integer type required"); - //return static_cast<T>((from0 & mask) | (from1 & ~mask)); - return static_cast<T>(from1 ^ (mask & (from0 ^ from1))); - } - -template<typename T> -inline constexpr T select2(T mask0, T val0, T mask1, T val1, T val2) - { - return select<T>(mask0, val0, select<T>(mask1, val1, val2)); - } - -template<typename T> -inline constexpr T select3(T mask0, T val0, T mask1, T val1, T mask2, T val2, T val3) - { - return select2<T>(mask0, val0, mask1, val1, select<T>(mask2, val2, val3)); - } - -template<typename PredT, typename ValT> -inline constexpr ValT val_or_zero(PredT pred_val, ValT val) - { - return select(CT::expand_mask<ValT>(pred_val), val, static_cast<ValT>(0)); - } + public: + static_assert(std::is_unsigned<T>::value, "CT::Mask only defined for unsigned integer types"); + + Mask(const Mask<T>& other) = default; + Mask<T>& operator=(const Mask<T>& other) = default; + + /** + * Derive a Mask from a Mask of a larger type + */ + template<typename U> + Mask(Mask<U> o) : m_mask(o.value()) + { + static_assert(sizeof(U) > sizeof(T), "sizes ok"); + } + + /** + * Return a Mask<T> with all bits set + */ + static Mask<T> set() + { + return Mask<T>(~0); + } + + /** + * Return a Mask<T> with all bits cleared + */ + static Mask<T> cleared() + { + return Mask<T>(0); + } + + /** + * Return a Mask<T> which is set if v is != 0 + */ + static Mask<T> expand(T v) + { + return ~Mask<T>::is_zero(v); + } + + /** + * Return a Mask<T> which is set if v is == 0 or cleared otherwise + */ + static Mask<T> is_zero(T x) + { + return Mask<T>(expand_top_bit(~x & (x - 1))); + } + + /** + * Return a Mask<T> which is set if x == y + */ + static Mask<T> is_equal(T x, T y) + { + return Mask<T>::is_zero(static_cast<T>(x ^ y)); + } + + /** + * Return a Mask<T> which is set if x < y + */ + static Mask<T> is_lt(T x, T y) + { + return Mask<T>(expand_top_bit(x^((x^y) | ((x-y)^x)))); + } + + /** + * Return a Mask<T> which is set if x > y + */ + static Mask<T> is_gt(T x, T y) + { + return Mask<T>::is_lt(y, x); + } + + /** + * Return a Mask<T> which is set if x <= y + */ + static Mask<T> is_lte(T x, T y) + { + return ~Mask<T>::is_gt(x, y); + } + + /** + * Return a Mask<T> which is set if x >= y + */ + static Mask<T> is_gte(T x, T y) + { + return ~Mask<T>::is_lt(x, y); + } + + /** + * AND-combine two masks + */ + Mask<T>& operator&=(Mask<T> o) + { + m_mask &= o.value(); + return (*this); + } + + /** + * XOR-combine two masks + */ + Mask<T>& operator^=(Mask<T> o) + { + m_mask ^= o.value(); + return (*this); + } + + /** + * OR-combine two masks + */ + Mask<T>& operator|=(Mask<T> o) + { + m_mask |= o.value(); + return (*this); + } + + /** + * AND-combine two masks + */ + friend Mask<T> operator&(Mask<T> x, Mask<T> y) + { + return Mask<T>(x.value() & y.value()); + } + + /** + * XOR-combine two masks + */ + friend Mask<T> operator^(Mask<T> x, Mask<T> y) + { + return Mask<T>(x.value() ^ y.value()); + } + + /** + * OR-combine two masks + */ + friend Mask<T> operator|(Mask<T> x, Mask<T> y) + { + return Mask<T>(x.value() | y.value()); + } + + /** + * Negate this mask + */ + Mask<T> operator~() const + { + return Mask<T>(~value()); + } + + /** + * Return x if the mask is set, or otherwise zero + */ + T if_set_return(T x) const + { + return m_mask & x; + } + + /** + * Return x if the mask is cleared, or otherwise zero + */ + T if_not_set_return(T x) const + { + return ~m_mask & x; + } + + /** + * If this mask is set, return x, otherwise return y + */ + T select(T x, T y) const + { + // (x & value()) | (y & ~value()) + return static_cast<T>(y ^ (value() & (x ^ y))); + } + + T select_and_unpoison(T x, T y) const + { + T r = this->select(x, y); + CT::unpoison(r); + return r; + } + + /** + * If this mask is set, return x, otherwise return y + */ + Mask<T> select_mask(Mask<T> x, Mask<T> y) const + { + return Mask<T>(select(x.value(), y.value())); + } + + /** + * Conditionally set output to x or y, depending on if mask is set or + * cleared (resp) + */ + void select_n(T output[], const T x[], const T y[], size_t len) const + { + for(size_t i = 0; i != len; ++i) + output[i] = this->select(x[i], y[i]); + } + + /** + * If this mask is set, zero out buf, otherwise do nothing + */ + void if_set_zero_out(T buf[], size_t elems) + { + for(size_t i = 0; i != elems; ++i) + { + buf[i] = this->if_not_set_return(buf[i]); + } + } + + /** + * Return the value of the mask, unpoisoned + */ + T unpoisoned_value() const + { + T r = value(); + CT::unpoison(r); + return r; + } + + /** + * Return true iff this mask is set + */ + bool is_set() const + { + return unpoisoned_value() != 0; + } + + /** + * Return the underlying value of the mask + */ + T value() const + { + return m_mask; + } + + private: + /** + * If top bit of arg is set, return ~0. Otherwise return 0. + */ + static T expand_top_bit(T a) + { + return static_cast<T>(0) - (a >> (sizeof(T)*8-1)); + } + + Mask(T m) : m_mask(m) {} + + T m_mask; + }; template<typename T> -inline constexpr T is_equal(T x, T y) +inline Mask<T> conditional_copy_mem(T cnd, + T* to, + const T* from0, + const T* from1, + size_t elems) { - return is_zero<T>(x ^ y); - } - -template<typename T> -inline constexpr T is_less(T a, T b) - { - return expand_top_bit<T>(a ^ ((a^b) | ((a-b)^a))); - } - -template<typename T> -inline constexpr T is_lte(T a, T b) - { - return CT::is_less(a, b) | CT::is_equal(a, b); - } - -template<typename C, typename T> -inline T conditional_return(C condvar, T left, T right) - { - const T val = CT::select(CT::expand_mask<T>(condvar), left, right); - CT::unpoison(val); - return val; - } - -template<typename T> -inline T conditional_copy_mem(T value, - T* to, - const T* from0, - const T* from1, - size_t elems) - { - const T mask = CT::expand_mask(value); - - for(size_t i = 0; i != elems; ++i) - { - to[i] = CT::select(mask, from0[i], from1[i]); - } - + const auto mask = CT::Mask<T>::expand(cnd); + mask.select_n(to, from0, from1, elems); return mask; } -template<typename T> -inline void cond_zero_mem(T cond, - T* array, - size_t elems) - { - const T mask = CT::expand_mask(cond); - const T zero(0); - - for(size_t i = 0; i != elems; ++i) - { - array[i] = CT::select(mask, zero, array[i]); - } - } - inline secure_vector<uint8_t> strip_leading_zeros(const uint8_t in[], size_t length) { size_t leading_zeros = 0; - uint8_t only_zeros = 0xFF; + auto only_zeros = Mask<uint8_t>::set(); for(size_t i = 0; i != length; ++i) { - only_zeros = only_zeros & CT::is_zero<uint8_t>(in[i]); - leading_zeros += CT::select<uint8_t>(only_zeros, 1, 0); + only_zeros &= CT::Mask<uint8_t>::is_zero(in[i]); + leading_zeros += only_zeros.if_set_return(1); } return secure_vector<uint8_t>(in + leading_zeros, in + length); diff --git a/src/lib/utils/mem_ops.cpp b/src/lib/utils/mem_ops.cpp index 668437e9f..460fc4b69 100644 --- a/src/lib/utils/mem_ops.cpp +++ b/src/lib/utils/mem_ops.cpp @@ -5,6 +5,7 @@ */ #include <botan/mem_ops.h> +#include <botan/internal/ct_utils.h> #include <cstdlib> #include <new> @@ -49,16 +50,16 @@ void initialize_allocator() #endif } -bool constant_time_compare(const uint8_t x[], - const uint8_t y[], - size_t len) +uint8_t ct_compare_u8(const uint8_t x[], + const uint8_t y[], + size_t len) { volatile uint8_t difference = 0; for(size_t i = 0; i != len; ++i) difference |= (x[i] ^ y[i]); - return difference == 0; + return CT::Mask<uint8_t>::is_zero(difference).value(); } } diff --git a/src/lib/utils/mem_ops.h b/src/lib/utils/mem_ops.h index f0599c8b2..bff15e98a 100644 --- a/src/lib/utils/mem_ops.h +++ b/src/lib/utils/mem_ops.h @@ -65,11 +65,25 @@ BOTAN_PUBLIC_API(2,0) void secure_scrub_memory(void* ptr, size_t n); * @param x a pointer to an array * @param y a pointer to another array * @param len the number of Ts in x and y +* @return 0xFF iff x[i] == y[i] forall i in [0...n) or 0x00 otherwise +*/ +BOTAN_PUBLIC_API(2,9) uint8_t ct_compare_u8(const uint8_t x[], + const uint8_t y[], + size_t len); + +/** +* Memory comparison, input insensitive +* @param x a pointer to an array +* @param y a pointer to another array +* @param len the number of Ts in x and y * @return true iff x[i] == y[i] forall i in [0...n) */ -BOTAN_PUBLIC_API(2,3) bool constant_time_compare(const uint8_t x[], - const uint8_t y[], - size_t len); +inline bool constant_time_compare(const uint8_t x[], + const uint8_t y[], + size_t len) + { + return ct_compare_u8(x, y, len) == 0xFF; + } /** * Zero out some bytes diff --git a/src/tests/test_utils.cpp b/src/tests/test_utils.cpp index 7bbe3745c..f1c6bef43 100644 --- a/src/tests/test_utils.cpp +++ b/src/tests/test_utils.cpp @@ -79,44 +79,10 @@ class Utility_Function_Tests final : public Text_Based_Test std::vector<Test::Result> results; results.push_back(test_loadstore()); - results.push_back(test_ct_utils()); return results; } - Test::Result test_ct_utils() - { - Test::Result result("CT utils"); - - result.test_eq_sz("CT::is_zero8", Botan::CT::is_zero<uint8_t>(0), 0xFF); - result.test_eq_sz("CT::is_zero8", Botan::CT::is_zero<uint8_t>(1), 0x00); - result.test_eq_sz("CT::is_zero8", Botan::CT::is_zero<uint8_t>(0xFF), 0x00); - - result.test_eq_sz("CT::is_zero16", Botan::CT::is_zero<uint16_t>(0), 0xFFFF); - result.test_eq_sz("CT::is_zero16", Botan::CT::is_zero<uint16_t>(1), 0x0000); - result.test_eq_sz("CT::is_zero16", Botan::CT::is_zero<uint16_t>(0xFF), 0x0000); - - result.test_eq_sz("CT::is_zero32", Botan::CT::is_zero<uint32_t>(0), 0xFFFFFFFF); - result.test_eq_sz("CT::is_zero32", Botan::CT::is_zero<uint32_t>(1), 0x00000000); - result.test_eq_sz("CT::is_zero32", Botan::CT::is_zero<uint32_t>(0xFF), 0x00000000); - - result.test_eq_sz("CT::is_less8", Botan::CT::is_less<uint8_t>(0, 1), 0xFF); - result.test_eq_sz("CT::is_less8", Botan::CT::is_less<uint8_t>(1, 0), 0x00); - result.test_eq_sz("CT::is_less8", Botan::CT::is_less<uint8_t>(0xFF, 5), 0x00); - - result.test_eq_sz("CT::is_less16", Botan::CT::is_less<uint16_t>(0, 1), 0xFFFF); - result.test_eq_sz("CT::is_less16", Botan::CT::is_less<uint16_t>(1, 0), 0x0000); - result.test_eq_sz("CT::is_less16", Botan::CT::is_less<uint16_t>(0xFFFF, 5), 0x0000); - - result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(0, 1), 0xFFFFFFFF); - result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(1, 0), 0x00000000); - result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(0xFFFF5, 5), 0x00000000); - result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(0xFFFFFFFF, 5), 0x00000000); - result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(5, 0xFFFFFFFF), 0xFFFFFFFF); - - return result; - } - Test::Result test_loadstore() { Test::Result result("Util load/store"); @@ -227,6 +193,45 @@ class Utility_Function_Tests final : public Text_Based_Test BOTAN_REGISTER_TEST("util", Utility_Function_Tests); +class CT_Mask_Tests final : public Test + { + public: + std::vector<Test::Result> run() override + { + Test::Result result("CT utils"); + + result.test_eq_sz("CT::is_zero8", Botan::CT::Mask<uint8_t>::is_zero(0).value(), 0xFF); + result.test_eq_sz("CT::is_zero8", Botan::CT::Mask<uint8_t>::is_zero(1).value(), 0x00); + result.test_eq_sz("CT::is_zero8", Botan::CT::Mask<uint8_t>::is_zero(0xFF).value(), 0x00); + + result.test_eq_sz("CT::is_zero16", Botan::CT::Mask<uint16_t>::is_zero(0).value(), 0xFFFF); + result.test_eq_sz("CT::is_zero16", Botan::CT::Mask<uint16_t>::is_zero(1).value(), 0x0000); + result.test_eq_sz("CT::is_zero16", Botan::CT::Mask<uint16_t>::is_zero(0xFF).value(), 0x0000); + + result.test_eq_sz("CT::is_zero32", Botan::CT::Mask<uint32_t>::is_zero(0).value(), 0xFFFFFFFF); + result.test_eq_sz("CT::is_zero32", Botan::CT::Mask<uint32_t>::is_zero(1).value(), 0x00000000); + result.test_eq_sz("CT::is_zero32", Botan::CT::Mask<uint32_t>::is_zero(0xFF).value(), 0x00000000); + + result.test_eq_sz("CT::is_less8", Botan::CT::Mask<uint8_t>::is_lt(0, 1).value(), 0xFF); + result.test_eq_sz("CT::is_less8", Botan::CT::Mask<uint8_t>::is_lt(1, 0).value(), 0x00); + result.test_eq_sz("CT::is_less8", Botan::CT::Mask<uint8_t>::is_lt(0xFF, 5).value(), 0x00); + + result.test_eq_sz("CT::is_less16", Botan::CT::Mask<uint16_t>::is_lt(0, 1).value(), 0xFFFF); + result.test_eq_sz("CT::is_less16", Botan::CT::Mask<uint16_t>::is_lt(1, 0).value(), 0x0000); + result.test_eq_sz("CT::is_less16", Botan::CT::Mask<uint16_t>::is_lt(0xFFFF, 5).value(), 0x0000); + + result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(0, 1).value(), 0xFFFFFFFF); + result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(1, 0).value(), 0x00000000); + result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(0xFFFF5, 5).value(), 0x00000000); + result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(0xFFFFFFFF, 5).value(), 0x00000000); + result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(5, 0xFFFFFFFF).value(), 0xFFFFFFFF); + + return {result}; + } + }; + +BOTAN_REGISTER_TEST("ct_utils", CT_Mask_Tests); + #if defined(BOTAN_HAS_POLY_DBL) class Poly_Double_Tests final : public Text_Based_Test |