aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2018-11-28 10:35:17 -0500
committerJack Lloyd <[email protected]>2018-11-28 10:35:17 -0500
commit007314c530eb12d414ced07515f8cbc25a0f64f5 (patch)
treedc887f97efa0248aa5e7b8468c94145f6a1305f8 /src
parentb03f38f57d4f50ace1ed8b57d83ba70eb5bc1dfb (diff)
Add CT::Mask type
Diffstat (limited to 'src')
-rw-r--r--src/lib/block/idea/idea.cpp4
-rw-r--r--src/lib/mac/poly1305/poly1305.cpp8
-rw-r--r--src/lib/math/bigint/big_ops2.cpp2
-rw-r--r--src/lib/math/bigint/bigint.cpp24
-rw-r--r--src/lib/math/mp/mp_core.h105
-rw-r--r--src/lib/math/mp/mp_karat.cpp6
-rw-r--r--src/lib/math/numbertheory/monty_exp.cpp6
-rw-r--r--src/lib/modes/mode_pad/mode_pad.cpp63
-rw-r--r--src/lib/pk_pad/eme_oaep/oaep.cpp22
-rw-r--r--src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp20
-rw-r--r--src/lib/pk_pad/iso9796/iso9796.cpp34
-rw-r--r--src/lib/pubkey/dlies/dlies.cpp3
-rw-r--r--src/lib/pubkey/ec_group/point_mul.cpp18
-rw-r--r--src/lib/pubkey/ecies/ecies.cpp4
-rw-r--r--src/lib/pubkey/pubkey.cpp18
-rw-r--r--src/lib/tls/tls_cbc/tls_cbc.cpp20
-rw-r--r--src/lib/utils/ct_utils.h377
-rw-r--r--src/lib/utils/mem_ops.cpp9
-rw-r--r--src/lib/utils/mem_ops.h20
-rw-r--r--src/tests/test_utils.cpp73
20 files changed, 519 insertions, 317 deletions
diff --git a/src/lib/block/idea/idea.cpp b/src/lib/block/idea/idea.cpp
index d6380368d..0bdc36a68 100644
--- a/src/lib/block/idea/idea.cpp
+++ b/src/lib/block/idea/idea.cpp
@@ -20,7 +20,7 @@ namespace {
inline uint16_t mul(uint16_t x, uint16_t y)
{
const uint32_t P = static_cast<uint32_t>(x) * y;
- const uint16_t P_mask = static_cast<uint16_t>(CT::is_zero(P) & 0xFFFF);
+ const auto P_mask = CT::Mask<uint16_t>(CT::Mask<uint32_t>::is_zero(P));
const uint32_t P_hi = P >> 16;
const uint32_t P_lo = P & 0xFFFF;
@@ -29,7 +29,7 @@ inline uint16_t mul(uint16_t x, uint16_t y)
const uint16_t r_1 = static_cast<uint16_t>((P_lo - P_hi) + carry);
const uint16_t r_2 = 1 - x - y;
- return CT::select(P_mask, r_2, r_1);
+ return P_mask.select(r_2, r_1);;
}
/*
diff --git a/src/lib/mac/poly1305/poly1305.cpp b/src/lib/mac/poly1305/poly1305.cpp
index 32027b8d0..333a21a1a 100644
--- a/src/lib/mac/poly1305/poly1305.cpp
+++ b/src/lib/mac/poly1305/poly1305.cpp
@@ -119,10 +119,10 @@ void poly1305_finish(secure_vector<uint64_t>& X, uint8_t mac[16])
uint64_t g2 = h2 + c - (static_cast<uint64_t>(1) << 42);
/* select h if h < p, or h + -p if h >= p */
- c = CT::expand_mask<uint64_t>(c);
- h0 = CT::select(c, g0, h0);
- h1 = CT::select(c, g1, h1);
- h2 = CT::select(c, g2, h2);
+ const auto c_mask = CT::Mask<uint64_t>::expand(c);
+ h0 = c_mask.select(g0, h0);
+ h1 = c_mask.select(g1, h1);
+ h2 = c_mask.select(g2, h2);
/* h = (h + pad) */
const uint64_t t0 = X[6];
diff --git a/src/lib/math/bigint/big_ops2.cpp b/src/lib/math/bigint/big_ops2.cpp
index 2fd775ccf..68baf3200 100644
--- a/src/lib/math/bigint/big_ops2.cpp
+++ b/src/lib/math/bigint/big_ops2.cpp
@@ -171,7 +171,7 @@ BigInt& BigInt::mod_sub(const BigInt& s, const BigInt& mod, secure_vector<word>&
ws.resize(mod_sw);
// is t < s or not?
- const word is_lt = bigint_ct_is_lt(data(), mod_sw, s.data(), mod_sw);
+ const auto is_lt = bigint_ct_is_lt(data(), mod_sw, s.data(), mod_sw);
// ws = p - s
word borrow = bigint_sub3(ws.data(), mod.data(), mod_sw, s.data(), mod_sw);
diff --git a/src/lib/math/bigint/bigint.cpp b/src/lib/math/bigint/bigint.cpp
index 2cb9394ce..667035686 100644
--- a/src/lib/math/bigint/bigint.cpp
+++ b/src/lib/math/bigint/bigint.cpp
@@ -146,7 +146,7 @@ bool BigInt::is_equal(const BigInt& other) const
return false;
return bigint_ct_is_eq(this->data(), this->sig_words(),
- other.data(), other.sig_words());
+ other.data(), other.sig_words()).is_set();
}
bool BigInt::is_less_than(const BigInt& other) const
@@ -160,11 +160,11 @@ bool BigInt::is_less_than(const BigInt& other) const
if(other.is_negative() && this->is_negative())
{
return !bigint_ct_is_lt(other.data(), other.sig_words(),
- this->data(), this->sig_words(), true);
+ this->data(), this->sig_words(), true).is_set();
}
return bigint_ct_is_lt(this->data(), this->sig_words(),
- other.data(), other.sig_words());
+ other.data(), other.sig_words()).is_set();
}
void BigInt::encode_words(word out[], size_t size) const
@@ -187,7 +187,7 @@ size_t BigInt::Data::calc_sig_words() const
for(size_t i = 0; i != m_reg.size(); ++i)
{
const word w = m_reg[m_reg.size() - i - 1];
- sub &= CT::is_zero(w);
+ sub &= CT::Mask<word>::is_zero(w).value();
sig -= sub;
}
@@ -393,13 +393,18 @@ void BigInt::ct_cond_assign(bool predicate, BigInt& other)
const size_t t_words = size();
const size_t o_words = other.size();
+ if(o_words < t_words)
+ grow_to(o_words);
+
const size_t r_words = std::max(t_words, o_words);
- const word mask = CT::expand_mask<word>(predicate);
+ const auto mask = CT::Mask<word>::expand(predicate);
for(size_t i = 0; i != r_words; ++i)
{
- this->set_word_at(i, CT::select<word>(mask, other.word_at(i), this->word_at(i)));
+ const word o_word = other.word_at(i);
+ const word t_word = this->word_at(i);
+ this->set_word_at(i, mask.select(o_word, t_word));
}
}
@@ -430,10 +435,13 @@ void BigInt::const_time_lookup(secure_vector<word>& output,
BOTAN_ASSERT(vec[i].size() >= words,
"Word size as expected in const_time_lookup");
- const word mask = CT::is_equal(i, idx);
+ const auto mask = CT::Mask<word>::is_equal(i, idx);
for(size_t w = 0; w != words; ++w)
- output[w] |= CT::select<word>(mask, vec[i].word_at(w), 0);
+ {
+ const word viw = vec[i].word_at(w);
+ output[w] = mask.if_set_return(viw);
+ }
}
CT::unpoison(idx);
diff --git a/src/lib/math/mp/mp_core.h b/src/lib/math/mp/mp_core.h
index 9a19a46be..4829ef6fc 100644
--- a/src/lib/math/mp/mp_core.h
+++ b/src/lib/math/mp/mp_core.h
@@ -30,14 +30,14 @@ const word MP_WORD_MAX = MP_WORD_MASK;
*/
inline void bigint_cnd_swap(word cnd, word x[], word y[], size_t size)
{
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
for(size_t i = 0; i != size; ++i)
{
const word a = x[i];
const word b = y[i];
- x[i] = CT::select(mask, b, a);
- y[i] = CT::select(mask, a, b);
+ x[i] = mask.select(b, a);
+ y[i] = mask.select(a, b);
}
}
@@ -46,7 +46,7 @@ inline word bigint_cnd_add(word cnd, word x[], word x_size,
{
BOTAN_ASSERT(x_size >= y_size, "Expected sizes");
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
word carry = 0;
@@ -56,24 +56,22 @@ inline word bigint_cnd_add(word cnd, word x[], word x_size,
for(size_t i = 0; i != blocks; i += 8)
{
carry = word8_add3(z, x + i, y + i, carry);
-
- for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, z[j], x[i+j]);
+ mask.select_n(x + i, z, x + i, 8);
}
for(size_t i = blocks; i != y_size; ++i)
{
z[0] = word_add(x[i], y[i], &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
for(size_t i = y_size; i != x_size; ++i)
{
z[0] = word_add(x[i], 0, &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
- return carry & mask;
+ return mask.if_set_return(carry);
}
/*
@@ -95,7 +93,7 @@ inline word bigint_cnd_sub(word cnd,
{
BOTAN_ASSERT(x_size >= y_size, "Expected sizes");
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
word carry = 0;
@@ -105,24 +103,22 @@ inline word bigint_cnd_sub(word cnd,
for(size_t i = 0; i != blocks; i += 8)
{
carry = word8_sub3(z, x + i, y + i, carry);
-
- for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, z[j], x[i+j]);
+ mask.select_n(x + i, z, x + i, 8);
}
for(size_t i = blocks; i != y_size; ++i)
{
z[0] = word_sub(x[i], y[i], &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
for(size_t i = y_size; i != x_size; ++i)
{
z[0] = word_sub(x[i], 0, &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
- return carry & mask;
+ return mask.if_set_return(carry);
}
/*
@@ -142,7 +138,7 @@ inline word bigint_cnd_sub(word cnd, word x[], const word y[], size_t size)
*
* Mask must be either 0 or all 1 bits
*/
-inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
+inline void bigint_cnd_addsub(CT::Mask<word> mask, word x[], const word y[], size_t size)
{
const size_t blocks = size - (size % 8);
@@ -158,7 +154,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
borrow = word8_sub3(t1, x + i, y + i, borrow);
for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, t0[j], t1[j]);
+ x[i+j] = mask.select(t0[j], t1[j]);
}
for(size_t i = blocks; i != size; ++i)
@@ -166,7 +162,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
const word a = word_add(x[i], y[i], &carry);
const word s = word_sub(x[i], y[i], &borrow);
- x[i] = CT::select(mask, a, s);
+ x[i] = mask.select(a, s);
}
}
@@ -179,7 +175,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
*
* Returns the carry or borrow resp
*/
-inline word bigint_cnd_addsub(word mask, word x[],
+inline word bigint_cnd_addsub(CT::Mask<word> mask, word x[],
const word y[], const word z[],
size_t size)
{
@@ -197,17 +193,17 @@ inline word bigint_cnd_addsub(word mask, word x[],
borrow = word8_sub3(t1, x + i, z + i, borrow);
for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, t0[j], t1[j]);
+ x[i+j] = mask.select(t0[j], t1[j]);
}
for(size_t i = blocks; i != size; ++i)
{
t0[0] = word_add(x[i], y[i], &carry);
t1[0] = word_sub(x[i], z[i], &borrow);
- x[i] = CT::select(mask, t0[0], t1[0]);
+ x[i] = mask.select(t0[0], t1[0]);
}
- return CT::select(mask, carry, borrow);
+ return mask.select(carry, borrow);
}
/*
@@ -217,13 +213,13 @@ inline word bigint_cnd_addsub(word mask, word x[],
*/
inline void bigint_cnd_abs(word cnd, word x[], size_t size)
{
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
- word carry = mask & 1;
+ word carry = mask.if_set_return(1);
for(size_t i = 0; i != size; ++i)
{
const word z = word_add(~x[i], 0, &carry);
- x[i] = CT::select(mask, z, x[i]);
+ x[i] = mask.select(z, x[i]);
}
}
@@ -379,9 +375,10 @@ inline word bigint_sub3(word z[],
* @param N length of x and y
* @param ws array of at least 2*N words
*/
-inline word bigint_sub_abs(word z[],
- const word x[], const word y[], size_t N,
- word ws[])
+inline CT::Mask<word>
+bigint_sub_abs(word z[],
+ const word x[], const word y[], size_t N,
+ word ws[])
{
// Subtract in both direction then conditional copy out the result
@@ -544,19 +541,20 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
{
static_assert(sizeof(word) >= sizeof(uint32_t), "Size assumption");
- const uint32_t LT = static_cast<uint32_t>(-1);
- const uint32_t EQ = 0;
- const uint32_t GT = 1;
+ const word LT = static_cast<word>(-1);
+ const word EQ = 0;
+ const word GT = 1;
const size_t common_elems = std::min(x_size, y_size);
- uint32_t result = EQ; // until found otherwise
+ word result = EQ; // until found otherwise
for(size_t i = 0; i != common_elems; i++)
{
- const word is_eq = CT::is_equal(x[i], y[i]);
- const word is_lt = CT::is_less(x[i], y[i]);
- result = CT::select<uint32_t>(is_eq, result, CT::select<uint32_t>(is_lt, LT, GT));
+ const auto is_eq = CT::Mask<word>::is_equal(x[i], y[i]);
+ const auto is_lt = CT::Mask<word>::is_lt(x[i], y[i]);
+
+ result = is_eq.select(result, is_lt.select(LT, GT));
}
if(x_size < y_size)
@@ -566,7 +564,7 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
mask |= y[i];
// If any bits were set in high part of y, then x < y
- result = CT::select<uint32_t>(CT::is_zero(mask), result, LT);
+ result = CT::Mask<word>::is_zero(mask).select(result, LT);
}
else if(y_size < x_size)
{
@@ -575,7 +573,7 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
mask |= x[i];
// If any bits were set in high part of x, then x > y
- result = CT::select<uint32_t>(CT::is_zero(mask), result, GT);
+ result = CT::Mask<word>::is_zero(mask).select(result, GT);
}
CT::unpoison(result);
@@ -588,19 +586,20 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
* Return ~0 if x[0:x_size] < y[0:y_size] or 0 otherwise
* If lt_or_equal is true, returns ~0 also for x == y
*/
-inline word bigint_ct_is_lt(const word x[], size_t x_size,
- const word y[], size_t y_size,
- bool lt_or_equal = false)
+inline CT::Mask<word>
+bigint_ct_is_lt(const word x[], size_t x_size,
+ const word y[], size_t y_size,
+ bool lt_or_equal = false)
{
const size_t common_elems = std::min(x_size, y_size);
- word is_lt = CT::expand_mask<word>(lt_or_equal);
+ auto is_lt = CT::Mask<word>::expand(lt_or_equal);
for(size_t i = 0; i != common_elems; i++)
{
- const word eq = CT::is_equal(x[i], y[i]);
- const word lt = CT::is_less(x[i], y[i]);
- is_lt = CT::select(eq, is_lt, lt);
+ const auto eq = CT::Mask<word>::is_equal(x[i], y[i]);
+ const auto lt = CT::Mask<word>::is_lt(x[i], y[i]);
+ is_lt = eq.select_mask(is_lt, lt);
}
if(x_size < y_size)
@@ -609,7 +608,7 @@ inline word bigint_ct_is_lt(const word x[], size_t x_size,
for(size_t i = x_size; i != y_size; i++)
mask |= y[i];
// If any bits were set in high part of y, then is_lt should be forced true
- is_lt |= ~CT::is_zero(mask);
+ is_lt |= CT::Mask<word>::expand(mask);
}
else if(y_size < x_size)
{
@@ -618,15 +617,15 @@ inline word bigint_ct_is_lt(const word x[], size_t x_size,
mask |= x[i];
// If any bits were set in high part of x, then is_lt should be false
- is_lt &= CT::is_zero(mask);
+ is_lt &= CT::Mask<word>::is_zero(mask);
}
- CT::unpoison(is_lt);
return is_lt;
}
-inline word bigint_ct_is_eq(const word x[], size_t x_size,
- const word y[], size_t y_size)
+inline CT::Mask<word>
+bigint_ct_is_eq(const word x[], size_t x_size,
+ const word y[], size_t y_size)
{
const size_t common_elems = std::min(x_size, y_size);
@@ -649,9 +648,7 @@ inline word bigint_ct_is_eq(const word x[], size_t x_size,
diff |= x[i];
}
- const word is_equal = CT::is_zero(diff);
- CT::unpoison(is_equal);
- return is_equal;
+ return CT::Mask<word>::is_zero(diff);
}
/**
diff --git a/src/lib/math/mp/mp_karat.cpp b/src/lib/math/mp/mp_karat.cpp
index c0fc5304b..8bd7cf58d 100644
--- a/src/lib/math/mp/mp_karat.cpp
+++ b/src/lib/math/mp/mp_karat.cpp
@@ -124,9 +124,9 @@ void karatsuba_mul(word z[], const word x[], const word y[], size_t N,
*/
// First compute (X_lo - X_hi)*(Y_hi - Y_lo)
- const word cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace);
- const word cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace);
- const word neg_mask = ~(cmp0 ^ cmp1);
+ const auto cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace);
+ const auto cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace);
+ const auto neg_mask = ~(cmp0 ^ cmp1);
karatsuba_mul(ws0, z0, z1, N2, ws1);
diff --git a/src/lib/math/numbertheory/monty_exp.cpp b/src/lib/math/numbertheory/monty_exp.cpp
index 2b5bbd81d..7590005a0 100644
--- a/src/lib/math/numbertheory/monty_exp.cpp
+++ b/src/lib/math/numbertheory/monty_exp.cpp
@@ -85,10 +85,12 @@ void const_time_lookup(secure_vector<word>& output,
BOTAN_ASSERT(vec.size() >= words,
"Word size as expected in const_time_lookup");
- const word mask = CT::is_equal<word>(i, nibble);
+ const auto mask = CT::Mask<word>::is_equal(i, nibble);
for(size_t w = 0; w != words; ++w)
- output[w] |= (mask & vec[w]);
+ {
+ output[w] |= mask.if_set_return(vec[w]);
+ }
}
}
diff --git a/src/lib/modes/mode_pad/mode_pad.cpp b/src/lib/modes/mode_pad/mode_pad.cpp
index e65114c88..5c949e9cf 100644
--- a/src/lib/modes/mode_pad/mode_pad.cpp
+++ b/src/lib/modes/mode_pad/mode_pad.cpp
@@ -57,21 +57,30 @@ size_t PKCS7_Padding::unpad(const uint8_t input[], size_t input_length) const
return input_length;
CT::poison(input, input_length);
- size_t bad_input = 0;
+
const uint8_t last_byte = input[input_length-1];
- bad_input |= CT::expand_mask<size_t>(last_byte > input_length);
+ /*
+ The input should == the block size so if the last byte exceeds
+ that then the padding is certainly invalid
+ */
+ auto bad_input = CT::Mask<size_t>::is_gt(last_byte, input_length);
const size_t pad_pos = input_length - last_byte;
for(size_t i = 0; i != input_length - 1; ++i)
{
- const uint8_t in_range = CT::expand_mask<uint8_t>(i >= pad_pos);
- bad_input |= in_range & (~CT::is_equal(input[i], last_byte));
+ // Does this byte equal the expected pad byte?
+ const auto pad_eq = CT::Mask<size_t>::is_equal(input[i], last_byte);
+
+ // Ignore values that are not part of the padding
+ const auto in_range = CT::Mask<size_t>::is_gte(i, pad_pos);
+ bad_input |= in_range & (~pad_eq);
}
CT::unpoison(input, input_length);
- return CT::conditional_return(bad_input, input_length, pad_pos);
+
+ return bad_input.select_and_unpoison(input_length, pad_pos);
}
/*
@@ -99,21 +108,24 @@ size_t ANSI_X923_Padding::unpad(const uint8_t input[], size_t input_length) cons
return input_length;
CT::poison(input, input_length);
+
const size_t last_byte = input[input_length-1];
- uint8_t bad_input = 0;
- bad_input |= CT::expand_mask<uint8_t>(last_byte > input_length);
+ auto bad_input = CT::Mask<size_t>::is_gt(last_byte, input_length);
const size_t pad_pos = input_length - last_byte;
for(size_t i = 0; i != input_length - 1; ++i)
{
- const uint8_t in_range = CT::expand_mask<uint8_t>(i >= pad_pos);
- bad_input |= CT::expand_mask(input[i]) & in_range;
+ // Ignore values that are not part of the padding
+ const auto in_range = CT::Mask<size_t>::is_gte(i, pad_pos);
+ const auto pad_is_nonzero = CT::Mask<size_t>::expand(input[i]);
+ bad_input |= pad_is_nonzero & in_range;
}
CT::unpoison(input, input_length);
- return CT::conditional_return(bad_input, input_length, pad_pos);
+
+ return bad_input.select_and_unpoison(input_length, pad_pos);
}
/*
@@ -139,22 +151,26 @@ size_t OneAndZeros_Padding::unpad(const uint8_t input[], size_t input_length) co
CT::poison(input, input_length);
- uint8_t bad_input = 0;
- uint8_t seen_one = 0;
+ auto bad_input = CT::Mask<uint8_t>::cleared();
+ auto seen_0x80 = CT::Mask<uint8_t>::cleared();
+
size_t pad_pos = input_length - 1;
size_t i = input_length;
while(i)
{
- seen_one |= CT::is_equal<uint8_t>(input[i-1], 0x80);
- pad_pos -= CT::select<uint8_t>(~seen_one, 1, 0);
- bad_input |= ~CT::is_zero<uint8_t>(input[i-1]) & ~seen_one;
+ const auto is_0x80 = CT::Mask<uint8_t>::is_equal(input[i-1], 0x80);
+ const auto is_zero = CT::Mask<uint8_t>::is_zero(input[i-1]);
+
+ seen_0x80 |= is_0x80;
+ pad_pos -= seen_0x80.if_not_set_return(1);
+ bad_input |= ~seen_0x80 & ~is_zero;
i--;
}
- bad_input |= ~seen_one;
+ bad_input |= ~seen_0x80;
CT::unpoison(input, input_length);
- return CT::conditional_return(bad_input, input_length, pad_pos);
+ return bad_input.select_and_unpoison(input_length, pad_pos);
}
/*
@@ -183,20 +199,23 @@ size_t ESP_Padding::unpad(const uint8_t input[], size_t input_length) const
CT::poison(input, input_length);
const size_t last_byte = input[input_length-1];
- uint8_t bad_input = 0;
- bad_input |= CT::is_zero(last_byte) | CT::expand_mask<uint8_t>(last_byte > input_length);
+
+ auto bad_input = CT::Mask<uint8_t>::is_zero(last_byte) |
+ CT::Mask<uint8_t>::is_gt(last_byte, input_length);
const size_t pad_pos = input_length - last_byte;
size_t i = input_length - 1;
while(i)
{
- const uint8_t in_range = CT::expand_mask<uint8_t>(i > pad_pos);
- bad_input |= (~CT::is_equal<uint8_t>(input[i-1], input[i]-1)) & in_range;
+ const auto in_range = CT::Mask<uint8_t>::is_gt(i, pad_pos);
+ const auto incrementing = CT::Mask<uint8_t>::is_equal(input[i-1], input[i]-1);
+
+ bad_input |= in_range & ~incrementing;
--i;
}
CT::unpoison(input, input_length);
- return CT::conditional_return(bad_input, input_length, pad_pos);
+ return bad_input.select_and_unpoison(input_length, pad_pos);
}
diff --git a/src/lib/pk_pad/eme_oaep/oaep.cpp b/src/lib/pk_pad/eme_oaep/oaep.cpp
index 398202bd1..9a8676ab9 100644
--- a/src/lib/pk_pad/eme_oaep/oaep.cpp
+++ b/src/lib/pk_pad/eme_oaep/oaep.cpp
@@ -72,7 +72,7 @@ secure_vector<uint8_t> OAEP::unpad(uint8_t& valid_mask,
Therefore, the first byte can always be skipped safely.
*/
- uint8_t skip_first = CT::is_zero<uint8_t>(in[0]) & 0x01;
+ const uint8_t skip_first = CT::Mask<uint8_t>::is_zero(in[0]).if_set_return(1);
secure_vector<uint8_t> input(in + skip_first, in + in_length);
@@ -105,37 +105,37 @@ oaep_find_delim(uint8_t& valid_mask,
CT::poison(input, input_len);
size_t delim_idx = 2 * hlen;
- uint8_t waiting_for_delim = 0xFF;
- uint8_t bad_input = 0;
+ CT::Mask<uint8_t> waiting_for_delim = CT::Mask<uint8_t>::set();
+ CT::Mask<uint8_t> bad_input = CT::Mask<uint8_t>::cleared();
for(size_t i = delim_idx; i < input_len; ++i)
{
- const uint8_t zero_m = CT::is_zero<uint8_t>(input[i]);
- const uint8_t one_m = CT::is_equal<uint8_t>(input[i], 1);
+ const auto zero_m = CT::Mask<uint8_t>::is_zero(input[i]);
+ const auto one_m = CT::Mask<uint8_t>::is_equal(input[i], 1);
- const uint8_t add_m = waiting_for_delim & zero_m;
+ const auto add_m = waiting_for_delim & zero_m;
bad_input |= waiting_for_delim & ~(zero_m | one_m);
- delim_idx += CT::select<uint8_t>(add_m, 1, 0);
+ delim_idx += add_m.if_set_return(1);
waiting_for_delim &= zero_m;
}
// If we never saw any non-zero byte, then it's not valid input
bad_input |= waiting_for_delim;
- bad_input |= CT::is_equal<uint8_t>(constant_time_compare(&input[hlen], Phash.data(), hlen), false);
+ bad_input |= CT::Mask<uint8_t>::is_zero(ct_compare_u8(&input[hlen], Phash.data(), hlen));
- delim_idx &= ~CT::expand_mask<size_t>(bad_input);
+ delim_idx = CT::Mask<size_t>::expand(bad_input.value()).if_not_set_return(delim_idx);
CT::unpoison(input, input_len);
CT::unpoison(&bad_input, 1);
CT::unpoison(&delim_idx, 1);
- valid_mask = ~bad_input;
+ valid_mask = (~bad_input).value();
secure_vector<uint8_t> output(input + delim_idx + 1, input + input_len);
- CT::cond_zero_mem(bad_input, output.data(), output.size());
+ bad_input.if_set_zero_out(output.data(), output.size());
return output;
}
diff --git a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp
index 58aadbc38..597b7c26a 100644
--- a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp
+++ b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp
@@ -58,33 +58,33 @@ secure_vector<uint8_t> EME_PKCS1v15::unpad(uint8_t& valid_mask,
CT::poison(in, inlen);
- uint8_t bad_input_m = 0;
- uint8_t seen_zero_m = 0;
+ CT::Mask<uint8_t> bad_input_m = CT::Mask<uint8_t>::cleared();
+ CT::Mask<uint8_t> seen_zero_m = CT::Mask<uint8_t>::cleared();
size_t delim_idx = 0;
- bad_input_m |= ~CT::is_equal<uint8_t>(in[0], 0);
- bad_input_m |= ~CT::is_equal<uint8_t>(in[1], 2);
+ bad_input_m |= ~CT::Mask<uint8_t>::is_equal(in[0], 0);
+ bad_input_m |= ~CT::Mask<uint8_t>::is_equal(in[1], 2);
for(size_t i = 2; i < inlen; ++i)
{
- const uint8_t is_zero_m = CT::is_zero<uint8_t>(in[i]);
+ const auto is_zero_m = CT::Mask<uint8_t>::is_zero(in[i]);
- delim_idx += CT::select<uint8_t>(~seen_zero_m, 1, 0);
+ delim_idx += seen_zero_m.if_not_set_return(1);
- bad_input_m |= is_zero_m & CT::expand_mask<uint8_t>(i < 10);
+ bad_input_m |= is_zero_m & CT::Mask<uint8_t>(CT::Mask<size_t>::is_lt(i, 10));
seen_zero_m |= is_zero_m;
}
bad_input_m |= ~seen_zero_m;
- bad_input_m |= CT::is_less<size_t>(delim_idx, 8);
+ bad_input_m |= CT::Mask<uint8_t>(CT::Mask<size_t>::is_lt(delim_idx, 8));
CT::unpoison(in, inlen);
CT::unpoison(bad_input_m);
CT::unpoison(delim_idx);
secure_vector<uint8_t> output(&in[delim_idx + 2], &in[inlen]);
- CT::cond_zero_mem(bad_input_m, output.data(), output.size());
- valid_mask = ~bad_input_m;
+ bad_input_m.if_set_zero_out(output.data(), output.size());
+ valid_mask = ~bad_input_m.value();
return output;
}
diff --git a/src/lib/pk_pad/iso9796/iso9796.cpp b/src/lib/pk_pad/iso9796/iso9796.cpp
index 99a1dfd29..28a603bf9 100644
--- a/src/lib/pk_pad/iso9796/iso9796.cpp
+++ b/src/lib/pk_pad/iso9796/iso9796.cpp
@@ -142,26 +142,29 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded,
//recover msg1 and salt
size_t msg1_offset = 1;
- uint8_t waiting_for_delim = 0xFF;
- uint8_t bad_input = 0;
+
+ auto waiting_for_delim = CT::Mask<uint8_t>::set();
+ auto bad_input = CT::Mask<uint8_t>::cleared();
+
for(size_t j = 0; j < DB_size; ++j)
{
- const uint8_t one_m = CT::is_equal<uint8_t>(DB[j], 0x01);
- const uint8_t zero_m = CT::is_zero(DB[j]);
- const uint8_t add_m = waiting_for_delim & zero_m;
+ const auto is_zero = CT::Mask<uint8_t>::is_zero(DB[j]);
+ const auto is_one = CT::Mask<uint8_t>::is_equal(DB[j], 0x01);
- bad_input |= waiting_for_delim & ~(zero_m | one_m);
- msg1_offset += CT::select<uint8_t>(add_m, 1, 0);
+ const auto add_m = waiting_for_delim & is_zero;
- waiting_for_delim &= zero_m;
+ bad_input |= waiting_for_delim & ~(is_zero | is_one);
+ msg1_offset += add_m.if_set_return(1);
+
+ waiting_for_delim &= is_zero;
}
//invalid, if delimiter 0x01 was not found or msg1_offset is too big
bad_input |= waiting_for_delim;
- bad_input |= CT::is_less(coded.size(), tLength + HASH_SIZE + msg1_offset + SALT_SIZE);
+ bad_input |= CT::Mask<size_t>::is_lt(coded.size(), tLength + HASH_SIZE + msg1_offset + SALT_SIZE);
//in case that msg1_offset is too big, just continue with offset = 0.
- msg1_offset = CT::select<size_t>(bad_input, 0, msg1_offset);
+ msg1_offset = CT::Mask<size_t>::expand(bad_input.value()).if_not_set_return(msg1_offset);
CT::unpoison(coded.data(), coded.size());
CT::unpoison(msg1_offset);
@@ -172,8 +175,7 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded,
coded.end() - tLength - HASH_SIZE);
//compute H2(C||msg1||H(msg2)||S*). * indicates a recovered value
- const size_t capacity = (key_bits - 2 + 7) / 8 - HASH_SIZE
- - SALT_SIZE - tLength - 1;
+ const size_t capacity = (key_bits - 2 + 7) / 8 - HASH_SIZE - SALT_SIZE - tLength - 1;
secure_vector<uint8_t> msg1raw;
secure_vector<uint8_t> msg2;
if(raw.size() > capacity)
@@ -188,7 +190,7 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded,
}
msg2 = hash->final();
- uint64_t msg1rawLength = msg1raw.size();
+ const uint64_t msg1rawLength = msg1raw.size();
hash->update_be(msg1rawLength * 8);
hash->update(msg1raw);
hash->update(msg2);
@@ -196,7 +198,7 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded,
secure_vector<uint8_t> H3 = hash->final();
//compute H3(C*||msg1*||H(msg2)||S*) * indicates a recovered value
- uint64_t msgLength = msg1.size();
+ const uint64_t msgLength = msg1.size();
hash->update_be(msgLength * 8);
hash->update(msg1);
hash->update(msg2);
@@ -204,10 +206,10 @@ bool iso9796_verification(const secure_vector<uint8_t>& const_coded,
secure_vector<uint8_t> H2 = hash->final();
//check if H3 == H2
- bad_input |= CT::is_equal<uint8_t>(constant_time_compare(H3.data(), H2.data(), HASH_SIZE), false);
+ bad_input |= CT::Mask<uint8_t>::is_zero(ct_compare_u8(H3.data(), H2.data(), HASH_SIZE));
CT::unpoison(bad_input);
- return (bad_input == 0);
+ return (bad_input.is_set() == false);
}
}
diff --git a/src/lib/pubkey/dlies/dlies.cpp b/src/lib/pubkey/dlies/dlies.cpp
index 5465401d1..aa214fd8b 100644
--- a/src/lib/pubkey/dlies/dlies.cpp
+++ b/src/lib/pubkey/dlies/dlies.cpp
@@ -7,7 +7,6 @@
*/
#include <botan/dlies.h>
-#include <botan/internal/ct_utils.h>
#include <limits>
namespace Botan {
@@ -180,7 +179,7 @@ secure_vector<uint8_t> DLIES_Decryptor::do_decrypt(uint8_t& valid_mask,
secure_vector<uint8_t> tag(msg + m_pub_key_size + ciphertext_len,
msg + m_pub_key_size + ciphertext_len + m_mac->output_length());
- valid_mask = CT::expand_mask<uint8_t>(constant_time_compare(tag.data(), calculated_tag.data(), tag.size()));
+ valid_mask = ct_compare_u8(tag.data(), calculated_tag.data(), tag.size());
// decrypt
if(m_cipher)
diff --git a/src/lib/pubkey/ec_group/point_mul.cpp b/src/lib/pubkey/ec_group/point_mul.cpp
index da3abaacc..2707a98f3 100644
--- a/src/lib/pubkey/ec_group/point_mul.cpp
+++ b/src/lib/pubkey/ec_group/point_mul.cpp
@@ -128,9 +128,9 @@ PointGFp PointGFp_Base_Point_Precompute::mul(const BigInt& k,
const word w = scalar.get_substring(2*window, 2);
- const word w_is_1 = CT::is_equal<word>(w, 1);
- const word w_is_2 = CT::is_equal<word>(w, 2);
- const word w_is_3 = CT::is_equal<word>(w, 3);
+ const auto w_is_1 = CT::Mask<word>::is_equal(w, 1);
+ const auto w_is_2 = CT::Mask<word>::is_equal(w, 2);
+ const auto w_is_3 = CT::Mask<word>::is_equal(w, 3);
for(size_t j = 0; j != elem_size; ++j)
{
@@ -138,7 +138,7 @@ PointGFp PointGFp_Base_Point_Precompute::mul(const BigInt& k,
const word w2 = m_W[base_addr + 1*elem_size + j];
const word w3 = m_W[base_addr + 2*elem_size + j];
- Wt[j] = CT::select3<word>(w_is_1, w1, w_is_2, w2, w_is_3, w3, 0);
+ Wt[j] = w_is_1.select(w1, w_is_2.select(w2, w_is_3.select(w3, 0)));
}
R.add_affine(&Wt[0], m_p_words, &Wt[m_p_words], m_p_words, ws);
@@ -255,11 +255,11 @@ PointGFp PointGFp_Var_Point_Precompute::mul(const BigInt& k,
clear_mem(e.data(), e.size());
for(size_t i = 1; i != window_elems; ++i)
{
- const word wmask = CT::is_equal<word>(w, i);
+ const auto wmask = CT::Mask<word>::is_equal(w, i);
for(size_t j = 0; j != elem_size; ++j)
{
- e[j] |= wmask & m_T[i * elem_size + j];
+ e[j] |= wmask.if_set_return(m_T[i * elem_size + j]);
}
}
@@ -282,10 +282,12 @@ PointGFp PointGFp_Var_Point_Precompute::mul(const BigInt& k,
clear_mem(e.data(), e.size());
for(size_t i = 1; i != window_elems; ++i)
{
- const word wmask = CT::is_equal<word>(w, i);
+ const auto wmask = CT::Mask<word>::is_equal(w, i);
for(size_t j = 0; j != elem_size; ++j)
- e[j] |= wmask & m_T[i * elem_size + j];
+ {
+ e[j] |= wmask.if_set_return(m_T[i * elem_size + j]);
+ }
}
R.add(&e[0], m_p_words, &e[m_p_words], m_p_words, &e[2*m_p_words], m_p_words, ws);
diff --git a/src/lib/pubkey/ecies/ecies.cpp b/src/lib/pubkey/ecies/ecies.cpp
index b35ecc107..864e0b72a 100644
--- a/src/lib/pubkey/ecies/ecies.cpp
+++ b/src/lib/pubkey/ecies/ecies.cpp
@@ -10,8 +10,6 @@
#include <botan/numthry.h>
#include <botan/cipher_mode.h>
#include <botan/mac.h>
-
-#include <botan/internal/ct_utils.h>
#include <botan/internal/pk_ops_impl.h>
namespace Botan {
@@ -386,7 +384,7 @@ secure_vector<uint8_t> ECIES_Decryptor::do_decrypt(uint8_t& valid_mask, const ui
m_mac->update(m_label);
}
const secure_vector<uint8_t> calculated_mac = m_mac->final();
- valid_mask = CT::expand_mask<uint8_t>(constant_time_compare(mac_data.data(), calculated_mac.data(), mac_data.size()));
+ valid_mask = ct_compare_u8(mac_data.data(), calculated_mac.data(), mac_data.size());
if(valid_mask)
{
diff --git a/src/lib/pubkey/pubkey.cpp b/src/lib/pubkey/pubkey.cpp
index bb0170548..d98b5dc9e 100644
--- a/src/lib/pubkey/pubkey.cpp
+++ b/src/lib/pubkey/pubkey.cpp
@@ -37,10 +37,11 @@ PK_Decryptor::decrypt_or_random(const uint8_t in[],
{
const secure_vector<uint8_t> fake_pms = rng.random_vec(expected_pt_len);
- uint8_t valid_mask = 0;
- secure_vector<uint8_t> decoded = do_decrypt(valid_mask, in, length);
+ uint8_t decrypt_valid = 0;
+ secure_vector<uint8_t> decoded = do_decrypt(decrypt_valid, in, length);
- valid_mask &= CT::is_equal(decoded.size(), expected_pt_len);
+ auto valid_mask = CT::Mask<uint8_t>::is_equal(decrypt_valid, 0xFF);
+ valid_mask &= CT::Mask<uint8_t>(CT::Mask<size_t>::is_zero(decoded.size() ^ expected_pt_len));
decoded.resize(expected_pt_len);
@@ -62,14 +63,13 @@ PK_Decryptor::decrypt_or_random(const uint8_t in[],
BOTAN_ASSERT(off < expected_pt_len, "Offset in range of plaintext");
- valid_mask &= CT::is_equal(decoded[off], exp);
+ auto eq = CT::Mask<uint8_t>::is_equal(decoded[off], exp);
+
+ valid_mask &= eq;
}
- CT::conditional_copy_mem(valid_mask,
- /*output*/decoded.data(),
- /*from0*/decoded.data(),
- /*from1*/fake_pms.data(),
- expected_pt_len);
+ // If valid_mask is false, assign fake pre master instead
+ valid_mask.select_n(decoded.data(), decoded.data(), fake_pms.data(), expected_pt_len);
return decoded;
}
diff --git a/src/lib/tls/tls_cbc/tls_cbc.cpp b/src/lib/tls/tls_cbc/tls_cbc.cpp
index 7376e655b..f3ea17d42 100644
--- a/src/lib/tls/tls_cbc/tls_cbc.cpp
+++ b/src/lib/tls/tls_cbc/tls_cbc.cpp
@@ -235,17 +235,17 @@ uint16_t check_tls_cbc_padding(const uint8_t record[], size_t record_len)
const uint8_t pad_byte = record[record_len-1];
const uint16_t pad_bytes = 1 + pad_byte;
- uint16_t pad_invalid = CT::is_less<uint16_t>(rec16, pad_bytes);
+ auto pad_invalid = CT::Mask<uint16_t>::is_lt(rec16, pad_byte);
for(uint16_t i = rec16 - to_check; i != rec16; ++i)
{
const uint16_t offset = rec16 - i;
- const uint16_t in_pad_range = CT::is_lte<uint16_t>(offset, pad_bytes);
- pad_invalid |= (in_pad_range & (record[i] ^ pad_byte));
+ const auto in_pad_range = CT::Mask<uint16_t>::is_lte(offset, pad_bytes);
+ const auto pad_correct = CT::Mask<uint16_t>::is_equal(record[i], pad_byte);
+ pad_invalid |= in_pad_range & ~pad_correct;
}
- const uint16_t pad_invalid_mask = CT::expand_mask<uint16_t>(pad_invalid);
- return CT::select<uint16_t>(pad_invalid_mask, 0, pad_byte + 1);
+ return pad_invalid.if_not_set_return(pad_bytes);
}
void TLS_CBC_HMAC_AEAD_Decryption::cbc_decrypt_record(uint8_t record_contents[], size_t record_len)
@@ -337,7 +337,7 @@ void TLS_CBC_HMAC_AEAD_Decryption::perform_additional_compressions(size_t plen,
const uint16_t current_compressions = ((L2 + block_size - 1 - max_bytes_in_first_block) / block_size);
// number of additional compressions we have to perform
const uint16_t add_compressions = max_compresssions - current_compressions;
- const uint8_t equal = CT::is_equal(max_compresssions, current_compressions) & 0x01;
+ const uint8_t equal = CT::Mask<uint16_t>::is_equal(max_compresssions, current_compressions).if_set_return(1);
// We compute the data length we need to achieve the number of compressions.
// If there are no compressions, we just add 55/111 dummy bytes so that no
// compression is performed.
@@ -418,8 +418,8 @@ void TLS_CBC_HMAC_AEAD_Decryption::finish(secure_vector<uint8_t>& buffer, size_t
(sending empty records, instead of 1/(n-1) splitting)
*/
- const uint16_t size_ok_mask = CT::is_lte<uint16_t>(static_cast<uint16_t>(tag_size() + pad_size), static_cast<uint16_t>(record_len));
- pad_size &= size_ok_mask;
+ const auto size_ok_mask = CT::Mask<uint16_t>::is_lte(tag_size() + pad_size, record_len);
+ pad_size = size_ok_mask.if_set_return(pad_size);
CT::unpoison(record_contents, record_len);
@@ -442,11 +442,11 @@ void TLS_CBC_HMAC_AEAD_Decryption::finish(secure_vector<uint8_t>& buffer, size_t
const bool mac_ok = constant_time_compare(&record_contents[mac_offset], mac_buf.data(), tag_size());
- const uint16_t ok_mask = size_ok_mask & CT::expand_mask<uint16_t>(mac_ok) & CT::expand_mask<uint16_t>(pad_size);
+ const auto ok_mask = size_ok_mask & CT::Mask<uint16_t>::expand(mac_ok) & CT::Mask<uint16_t>::expand(pad_size);
CT::unpoison(ok_mask);
- if(ok_mask)
+ if(ok_mask.is_set())
{
buffer.insert(buffer.end(), plaintext_block, plaintext_block + plaintext_length);
}
diff --git a/src/lib/utils/ct_utils.h b/src/lib/utils/ct_utils.h
index 63b8f4640..eb510baa2 100644
--- a/src/lib/utils/ct_utils.h
+++ b/src/lib/utils/ct_utils.h
@@ -6,13 +6,13 @@
* Wagner, Molnar, et al "The Program Counter Security Model"
*
* (C) 2010 Falko Strenzke
-* (C) 2015,2016 Jack Lloyd
+* (C) 2015,2016,2018 Jack Lloyd
*
* Botan is released under the Simplified BSD License (see license.txt)
*/
-#ifndef BOTAN_TIMING_ATTACK_CM_H_
-#define BOTAN_TIMING_ATTACK_CM_H_
+#ifndef BOTAN_CT_UTILS_H_
+#define BOTAN_CT_UTILS_H_
#include <botan/secmem.h>
#include <type_traits>
@@ -75,130 +75,285 @@ inline void unpoison(T& p)
#endif
}
-/* Mask generation */
-
-template<typename T>
-inline constexpr T expand_top_bit(T a)
- {
- static_assert(std::is_unsigned<T>::value, "unsigned integer type required");
- return static_cast<T>(0) - (a >> (sizeof(T)*8-1));
- }
-
-template<typename T>
-inline constexpr T is_zero(T x)
- {
- static_assert(std::is_unsigned<T>::value, "unsigned integer type required");
- return expand_top_bit<T>(~x & (x - 1));
- }
-
-/*
-* T should be an unsigned machine integer type
-* Expand to a mask used for other operations
-* @param in an integer
-* @return If n is zero, returns zero. Otherwise
-* returns a T with all bits set for use as a mask with
-* select.
+/**
+* A Mask type used for constant-time operations. A Mask<T> always has value
+* either 0 (all bits cleared) or ~0 (all bits set). All operations in a Mask<T>
+* are intended to compile to code which does not contain conditional jumps.
+* This must be verified with tooling (eg binary disassembly or using valgrind)
+* since you never know what a compiler might do.
*/
template<typename T>
-inline constexpr T expand_mask(T x)
+class Mask
{
- static_assert(std::is_unsigned<T>::value, "unsigned integer type required");
- return ~is_zero(x);
- }
-
-template<typename T>
-inline constexpr T select(T mask, T from0, T from1)
- {
- static_assert(std::is_unsigned<T>::value, "unsigned integer type required");
- //return static_cast<T>((from0 & mask) | (from1 & ~mask));
- return static_cast<T>(from1 ^ (mask & (from0 ^ from1)));
- }
-
-template<typename T>
-inline constexpr T select2(T mask0, T val0, T mask1, T val1, T val2)
- {
- return select<T>(mask0, val0, select<T>(mask1, val1, val2));
- }
-
-template<typename T>
-inline constexpr T select3(T mask0, T val0, T mask1, T val1, T mask2, T val2, T val3)
- {
- return select2<T>(mask0, val0, mask1, val1, select<T>(mask2, val2, val3));
- }
-
-template<typename PredT, typename ValT>
-inline constexpr ValT val_or_zero(PredT pred_val, ValT val)
- {
- return select(CT::expand_mask<ValT>(pred_val), val, static_cast<ValT>(0));
- }
+ public:
+ static_assert(std::is_unsigned<T>::value, "CT::Mask only defined for unsigned integer types");
+
+ Mask(const Mask<T>& other) = default;
+ Mask<T>& operator=(const Mask<T>& other) = default;
+
+ /**
+ * Derive a Mask from a Mask of a larger type
+ */
+ template<typename U>
+ Mask(Mask<U> o) : m_mask(o.value())
+ {
+ static_assert(sizeof(U) > sizeof(T), "sizes ok");
+ }
+
+ /**
+ * Return a Mask<T> with all bits set
+ */
+ static Mask<T> set()
+ {
+ return Mask<T>(~0);
+ }
+
+ /**
+ * Return a Mask<T> with all bits cleared
+ */
+ static Mask<T> cleared()
+ {
+ return Mask<T>(0);
+ }
+
+ /**
+ * Return a Mask<T> which is set if v is != 0
+ */
+ static Mask<T> expand(T v)
+ {
+ return ~Mask<T>::is_zero(v);
+ }
+
+ /**
+ * Return a Mask<T> which is set if v is == 0 or cleared otherwise
+ */
+ static Mask<T> is_zero(T x)
+ {
+ return Mask<T>(expand_top_bit(~x & (x - 1)));
+ }
+
+ /**
+ * Return a Mask<T> which is set if x == y
+ */
+ static Mask<T> is_equal(T x, T y)
+ {
+ return Mask<T>::is_zero(static_cast<T>(x ^ y));
+ }
+
+ /**
+ * Return a Mask<T> which is set if x < y
+ */
+ static Mask<T> is_lt(T x, T y)
+ {
+ return Mask<T>(expand_top_bit(x^((x^y) | ((x-y)^x))));
+ }
+
+ /**
+ * Return a Mask<T> which is set if x > y
+ */
+ static Mask<T> is_gt(T x, T y)
+ {
+ return Mask<T>::is_lt(y, x);
+ }
+
+ /**
+ * Return a Mask<T> which is set if x <= y
+ */
+ static Mask<T> is_lte(T x, T y)
+ {
+ return ~Mask<T>::is_gt(x, y);
+ }
+
+ /**
+ * Return a Mask<T> which is set if x >= y
+ */
+ static Mask<T> is_gte(T x, T y)
+ {
+ return ~Mask<T>::is_lt(x, y);
+ }
+
+ /**
+ * AND-combine two masks
+ */
+ Mask<T>& operator&=(Mask<T> o)
+ {
+ m_mask &= o.value();
+ return (*this);
+ }
+
+ /**
+ * XOR-combine two masks
+ */
+ Mask<T>& operator^=(Mask<T> o)
+ {
+ m_mask ^= o.value();
+ return (*this);
+ }
+
+ /**
+ * OR-combine two masks
+ */
+ Mask<T>& operator|=(Mask<T> o)
+ {
+ m_mask |= o.value();
+ return (*this);
+ }
+
+ /**
+ * AND-combine two masks
+ */
+ friend Mask<T> operator&(Mask<T> x, Mask<T> y)
+ {
+ return Mask<T>(x.value() & y.value());
+ }
+
+ /**
+ * XOR-combine two masks
+ */
+ friend Mask<T> operator^(Mask<T> x, Mask<T> y)
+ {
+ return Mask<T>(x.value() ^ y.value());
+ }
+
+ /**
+ * OR-combine two masks
+ */
+ friend Mask<T> operator|(Mask<T> x, Mask<T> y)
+ {
+ return Mask<T>(x.value() | y.value());
+ }
+
+ /**
+ * Negate this mask
+ */
+ Mask<T> operator~() const
+ {
+ return Mask<T>(~value());
+ }
+
+ /**
+ * Return x if the mask is set, or otherwise zero
+ */
+ T if_set_return(T x) const
+ {
+ return m_mask & x;
+ }
+
+ /**
+ * Return x if the mask is cleared, or otherwise zero
+ */
+ T if_not_set_return(T x) const
+ {
+ return ~m_mask & x;
+ }
+
+ /**
+ * If this mask is set, return x, otherwise return y
+ */
+ T select(T x, T y) const
+ {
+ // (x & value()) | (y & ~value())
+ return static_cast<T>(y ^ (value() & (x ^ y)));
+ }
+
+ T select_and_unpoison(T x, T y) const
+ {
+ T r = this->select(x, y);
+ CT::unpoison(r);
+ return r;
+ }
+
+ /**
+ * If this mask is set, return x, otherwise return y
+ */
+ Mask<T> select_mask(Mask<T> x, Mask<T> y) const
+ {
+ return Mask<T>(select(x.value(), y.value()));
+ }
+
+ /**
+ * Conditionally set output to x or y, depending on if mask is set or
+ * cleared (resp)
+ */
+ void select_n(T output[], const T x[], const T y[], size_t len) const
+ {
+ for(size_t i = 0; i != len; ++i)
+ output[i] = this->select(x[i], y[i]);
+ }
+
+ /**
+ * If this mask is set, zero out buf, otherwise do nothing
+ */
+ void if_set_zero_out(T buf[], size_t elems)
+ {
+ for(size_t i = 0; i != elems; ++i)
+ {
+ buf[i] = this->if_not_set_return(buf[i]);
+ }
+ }
+
+ /**
+ * Return the value of the mask, unpoisoned
+ */
+ T unpoisoned_value() const
+ {
+ T r = value();
+ CT::unpoison(r);
+ return r;
+ }
+
+ /**
+ * Return true iff this mask is set
+ */
+ bool is_set() const
+ {
+ return unpoisoned_value() != 0;
+ }
+
+ /**
+ * Return the underlying value of the mask
+ */
+ T value() const
+ {
+ return m_mask;
+ }
+
+ private:
+ /**
+ * If top bit of arg is set, return ~0. Otherwise return 0.
+ */
+ static T expand_top_bit(T a)
+ {
+ return static_cast<T>(0) - (a >> (sizeof(T)*8-1));
+ }
+
+ Mask(T m) : m_mask(m) {}
+
+ T m_mask;
+ };
template<typename T>
-inline constexpr T is_equal(T x, T y)
+inline Mask<T> conditional_copy_mem(T cnd,
+ T* to,
+ const T* from0,
+ const T* from1,
+ size_t elems)
{
- return is_zero<T>(x ^ y);
- }
-
-template<typename T>
-inline constexpr T is_less(T a, T b)
- {
- return expand_top_bit<T>(a ^ ((a^b) | ((a-b)^a)));
- }
-
-template<typename T>
-inline constexpr T is_lte(T a, T b)
- {
- return CT::is_less(a, b) | CT::is_equal(a, b);
- }
-
-template<typename C, typename T>
-inline T conditional_return(C condvar, T left, T right)
- {
- const T val = CT::select(CT::expand_mask<T>(condvar), left, right);
- CT::unpoison(val);
- return val;
- }
-
-template<typename T>
-inline T conditional_copy_mem(T value,
- T* to,
- const T* from0,
- const T* from1,
- size_t elems)
- {
- const T mask = CT::expand_mask(value);
-
- for(size_t i = 0; i != elems; ++i)
- {
- to[i] = CT::select(mask, from0[i], from1[i]);
- }
-
+ const auto mask = CT::Mask<T>::expand(cnd);
+ mask.select_n(to, from0, from1, elems);
return mask;
}
-template<typename T>
-inline void cond_zero_mem(T cond,
- T* array,
- size_t elems)
- {
- const T mask = CT::expand_mask(cond);
- const T zero(0);
-
- for(size_t i = 0; i != elems; ++i)
- {
- array[i] = CT::select(mask, zero, array[i]);
- }
- }
-
inline secure_vector<uint8_t> strip_leading_zeros(const uint8_t in[], size_t length)
{
size_t leading_zeros = 0;
- uint8_t only_zeros = 0xFF;
+ auto only_zeros = Mask<uint8_t>::set();
for(size_t i = 0; i != length; ++i)
{
- only_zeros = only_zeros & CT::is_zero<uint8_t>(in[i]);
- leading_zeros += CT::select<uint8_t>(only_zeros, 1, 0);
+ only_zeros &= CT::Mask<uint8_t>::is_zero(in[i]);
+ leading_zeros += only_zeros.if_set_return(1);
}
return secure_vector<uint8_t>(in + leading_zeros, in + length);
diff --git a/src/lib/utils/mem_ops.cpp b/src/lib/utils/mem_ops.cpp
index 668437e9f..460fc4b69 100644
--- a/src/lib/utils/mem_ops.cpp
+++ b/src/lib/utils/mem_ops.cpp
@@ -5,6 +5,7 @@
*/
#include <botan/mem_ops.h>
+#include <botan/internal/ct_utils.h>
#include <cstdlib>
#include <new>
@@ -49,16 +50,16 @@ void initialize_allocator()
#endif
}
-bool constant_time_compare(const uint8_t x[],
- const uint8_t y[],
- size_t len)
+uint8_t ct_compare_u8(const uint8_t x[],
+ const uint8_t y[],
+ size_t len)
{
volatile uint8_t difference = 0;
for(size_t i = 0; i != len; ++i)
difference |= (x[i] ^ y[i]);
- return difference == 0;
+ return CT::Mask<uint8_t>::is_zero(difference).value();
}
}
diff --git a/src/lib/utils/mem_ops.h b/src/lib/utils/mem_ops.h
index f0599c8b2..bff15e98a 100644
--- a/src/lib/utils/mem_ops.h
+++ b/src/lib/utils/mem_ops.h
@@ -65,11 +65,25 @@ BOTAN_PUBLIC_API(2,0) void secure_scrub_memory(void* ptr, size_t n);
* @param x a pointer to an array
* @param y a pointer to another array
* @param len the number of Ts in x and y
+* @return 0xFF iff x[i] == y[i] forall i in [0...n) or 0x00 otherwise
+*/
+BOTAN_PUBLIC_API(2,9) uint8_t ct_compare_u8(const uint8_t x[],
+ const uint8_t y[],
+ size_t len);
+
+/**
+* Memory comparison, input insensitive
+* @param x a pointer to an array
+* @param y a pointer to another array
+* @param len the number of Ts in x and y
* @return true iff x[i] == y[i] forall i in [0...n)
*/
-BOTAN_PUBLIC_API(2,3) bool constant_time_compare(const uint8_t x[],
- const uint8_t y[],
- size_t len);
+inline bool constant_time_compare(const uint8_t x[],
+ const uint8_t y[],
+ size_t len)
+ {
+ return ct_compare_u8(x, y, len) == 0xFF;
+ }
/**
* Zero out some bytes
diff --git a/src/tests/test_utils.cpp b/src/tests/test_utils.cpp
index 7bbe3745c..f1c6bef43 100644
--- a/src/tests/test_utils.cpp
+++ b/src/tests/test_utils.cpp
@@ -79,44 +79,10 @@ class Utility_Function_Tests final : public Text_Based_Test
std::vector<Test::Result> results;
results.push_back(test_loadstore());
- results.push_back(test_ct_utils());
return results;
}
- Test::Result test_ct_utils()
- {
- Test::Result result("CT utils");
-
- result.test_eq_sz("CT::is_zero8", Botan::CT::is_zero<uint8_t>(0), 0xFF);
- result.test_eq_sz("CT::is_zero8", Botan::CT::is_zero<uint8_t>(1), 0x00);
- result.test_eq_sz("CT::is_zero8", Botan::CT::is_zero<uint8_t>(0xFF), 0x00);
-
- result.test_eq_sz("CT::is_zero16", Botan::CT::is_zero<uint16_t>(0), 0xFFFF);
- result.test_eq_sz("CT::is_zero16", Botan::CT::is_zero<uint16_t>(1), 0x0000);
- result.test_eq_sz("CT::is_zero16", Botan::CT::is_zero<uint16_t>(0xFF), 0x0000);
-
- result.test_eq_sz("CT::is_zero32", Botan::CT::is_zero<uint32_t>(0), 0xFFFFFFFF);
- result.test_eq_sz("CT::is_zero32", Botan::CT::is_zero<uint32_t>(1), 0x00000000);
- result.test_eq_sz("CT::is_zero32", Botan::CT::is_zero<uint32_t>(0xFF), 0x00000000);
-
- result.test_eq_sz("CT::is_less8", Botan::CT::is_less<uint8_t>(0, 1), 0xFF);
- result.test_eq_sz("CT::is_less8", Botan::CT::is_less<uint8_t>(1, 0), 0x00);
- result.test_eq_sz("CT::is_less8", Botan::CT::is_less<uint8_t>(0xFF, 5), 0x00);
-
- result.test_eq_sz("CT::is_less16", Botan::CT::is_less<uint16_t>(0, 1), 0xFFFF);
- result.test_eq_sz("CT::is_less16", Botan::CT::is_less<uint16_t>(1, 0), 0x0000);
- result.test_eq_sz("CT::is_less16", Botan::CT::is_less<uint16_t>(0xFFFF, 5), 0x0000);
-
- result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(0, 1), 0xFFFFFFFF);
- result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(1, 0), 0x00000000);
- result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(0xFFFF5, 5), 0x00000000);
- result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(0xFFFFFFFF, 5), 0x00000000);
- result.test_eq_sz("CT::is_less32", Botan::CT::is_less<uint32_t>(5, 0xFFFFFFFF), 0xFFFFFFFF);
-
- return result;
- }
-
Test::Result test_loadstore()
{
Test::Result result("Util load/store");
@@ -227,6 +193,45 @@ class Utility_Function_Tests final : public Text_Based_Test
BOTAN_REGISTER_TEST("util", Utility_Function_Tests);
+class CT_Mask_Tests final : public Test
+ {
+ public:
+ std::vector<Test::Result> run() override
+ {
+ Test::Result result("CT utils");
+
+ result.test_eq_sz("CT::is_zero8", Botan::CT::Mask<uint8_t>::is_zero(0).value(), 0xFF);
+ result.test_eq_sz("CT::is_zero8", Botan::CT::Mask<uint8_t>::is_zero(1).value(), 0x00);
+ result.test_eq_sz("CT::is_zero8", Botan::CT::Mask<uint8_t>::is_zero(0xFF).value(), 0x00);
+
+ result.test_eq_sz("CT::is_zero16", Botan::CT::Mask<uint16_t>::is_zero(0).value(), 0xFFFF);
+ result.test_eq_sz("CT::is_zero16", Botan::CT::Mask<uint16_t>::is_zero(1).value(), 0x0000);
+ result.test_eq_sz("CT::is_zero16", Botan::CT::Mask<uint16_t>::is_zero(0xFF).value(), 0x0000);
+
+ result.test_eq_sz("CT::is_zero32", Botan::CT::Mask<uint32_t>::is_zero(0).value(), 0xFFFFFFFF);
+ result.test_eq_sz("CT::is_zero32", Botan::CT::Mask<uint32_t>::is_zero(1).value(), 0x00000000);
+ result.test_eq_sz("CT::is_zero32", Botan::CT::Mask<uint32_t>::is_zero(0xFF).value(), 0x00000000);
+
+ result.test_eq_sz("CT::is_less8", Botan::CT::Mask<uint8_t>::is_lt(0, 1).value(), 0xFF);
+ result.test_eq_sz("CT::is_less8", Botan::CT::Mask<uint8_t>::is_lt(1, 0).value(), 0x00);
+ result.test_eq_sz("CT::is_less8", Botan::CT::Mask<uint8_t>::is_lt(0xFF, 5).value(), 0x00);
+
+ result.test_eq_sz("CT::is_less16", Botan::CT::Mask<uint16_t>::is_lt(0, 1).value(), 0xFFFF);
+ result.test_eq_sz("CT::is_less16", Botan::CT::Mask<uint16_t>::is_lt(1, 0).value(), 0x0000);
+ result.test_eq_sz("CT::is_less16", Botan::CT::Mask<uint16_t>::is_lt(0xFFFF, 5).value(), 0x0000);
+
+ result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(0, 1).value(), 0xFFFFFFFF);
+ result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(1, 0).value(), 0x00000000);
+ result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(0xFFFF5, 5).value(), 0x00000000);
+ result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(0xFFFFFFFF, 5).value(), 0x00000000);
+ result.test_eq_sz("CT::is_less32", Botan::CT::Mask<uint32_t>::is_lt(5, 0xFFFFFFFF).value(), 0xFFFFFFFF);
+
+ return {result};
+ }
+ };
+
+BOTAN_REGISTER_TEST("ct_utils", CT_Mask_Tests);
+
#if defined(BOTAN_HAS_POLY_DBL)
class Poly_Double_Tests final : public Text_Based_Test