aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/math
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2018-11-28 10:35:17 -0500
committerJack Lloyd <[email protected]>2018-11-28 10:35:17 -0500
commit007314c530eb12d414ced07515f8cbc25a0f64f5 (patch)
treedc887f97efa0248aa5e7b8468c94145f6a1305f8 /src/lib/math
parentb03f38f57d4f50ace1ed8b57d83ba70eb5bc1dfb (diff)
Add CT::Mask type
Diffstat (limited to 'src/lib/math')
-rw-r--r--src/lib/math/bigint/big_ops2.cpp2
-rw-r--r--src/lib/math/bigint/bigint.cpp24
-rw-r--r--src/lib/math/mp/mp_core.h105
-rw-r--r--src/lib/math/mp/mp_karat.cpp6
-rw-r--r--src/lib/math/numbertheory/monty_exp.cpp6
5 files changed, 75 insertions, 68 deletions
diff --git a/src/lib/math/bigint/big_ops2.cpp b/src/lib/math/bigint/big_ops2.cpp
index 2fd775ccf..68baf3200 100644
--- a/src/lib/math/bigint/big_ops2.cpp
+++ b/src/lib/math/bigint/big_ops2.cpp
@@ -171,7 +171,7 @@ BigInt& BigInt::mod_sub(const BigInt& s, const BigInt& mod, secure_vector<word>&
ws.resize(mod_sw);
// is t < s or not?
- const word is_lt = bigint_ct_is_lt(data(), mod_sw, s.data(), mod_sw);
+ const auto is_lt = bigint_ct_is_lt(data(), mod_sw, s.data(), mod_sw);
// ws = p - s
word borrow = bigint_sub3(ws.data(), mod.data(), mod_sw, s.data(), mod_sw);
diff --git a/src/lib/math/bigint/bigint.cpp b/src/lib/math/bigint/bigint.cpp
index 2cb9394ce..667035686 100644
--- a/src/lib/math/bigint/bigint.cpp
+++ b/src/lib/math/bigint/bigint.cpp
@@ -146,7 +146,7 @@ bool BigInt::is_equal(const BigInt& other) const
return false;
return bigint_ct_is_eq(this->data(), this->sig_words(),
- other.data(), other.sig_words());
+ other.data(), other.sig_words()).is_set();
}
bool BigInt::is_less_than(const BigInt& other) const
@@ -160,11 +160,11 @@ bool BigInt::is_less_than(const BigInt& other) const
if(other.is_negative() && this->is_negative())
{
return !bigint_ct_is_lt(other.data(), other.sig_words(),
- this->data(), this->sig_words(), true);
+ this->data(), this->sig_words(), true).is_set();
}
return bigint_ct_is_lt(this->data(), this->sig_words(),
- other.data(), other.sig_words());
+ other.data(), other.sig_words()).is_set();
}
void BigInt::encode_words(word out[], size_t size) const
@@ -187,7 +187,7 @@ size_t BigInt::Data::calc_sig_words() const
for(size_t i = 0; i != m_reg.size(); ++i)
{
const word w = m_reg[m_reg.size() - i - 1];
- sub &= CT::is_zero(w);
+ sub &= CT::Mask<word>::is_zero(w).value();
sig -= sub;
}
@@ -393,13 +393,18 @@ void BigInt::ct_cond_assign(bool predicate, BigInt& other)
const size_t t_words = size();
const size_t o_words = other.size();
+ if(o_words < t_words)
+ grow_to(o_words);
+
const size_t r_words = std::max(t_words, o_words);
- const word mask = CT::expand_mask<word>(predicate);
+ const auto mask = CT::Mask<word>::expand(predicate);
for(size_t i = 0; i != r_words; ++i)
{
- this->set_word_at(i, CT::select<word>(mask, other.word_at(i), this->word_at(i)));
+ const word o_word = other.word_at(i);
+ const word t_word = this->word_at(i);
+ this->set_word_at(i, mask.select(o_word, t_word));
}
}
@@ -430,10 +435,13 @@ void BigInt::const_time_lookup(secure_vector<word>& output,
BOTAN_ASSERT(vec[i].size() >= words,
"Word size as expected in const_time_lookup");
- const word mask = CT::is_equal(i, idx);
+ const auto mask = CT::Mask<word>::is_equal(i, idx);
for(size_t w = 0; w != words; ++w)
- output[w] |= CT::select<word>(mask, vec[i].word_at(w), 0);
+ {
+ const word viw = vec[i].word_at(w);
+ output[w] = mask.if_set_return(viw);
+ }
}
CT::unpoison(idx);
diff --git a/src/lib/math/mp/mp_core.h b/src/lib/math/mp/mp_core.h
index 9a19a46be..4829ef6fc 100644
--- a/src/lib/math/mp/mp_core.h
+++ b/src/lib/math/mp/mp_core.h
@@ -30,14 +30,14 @@ const word MP_WORD_MAX = MP_WORD_MASK;
*/
inline void bigint_cnd_swap(word cnd, word x[], word y[], size_t size)
{
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
for(size_t i = 0; i != size; ++i)
{
const word a = x[i];
const word b = y[i];
- x[i] = CT::select(mask, b, a);
- y[i] = CT::select(mask, a, b);
+ x[i] = mask.select(b, a);
+ y[i] = mask.select(a, b);
}
}
@@ -46,7 +46,7 @@ inline word bigint_cnd_add(word cnd, word x[], word x_size,
{
BOTAN_ASSERT(x_size >= y_size, "Expected sizes");
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
word carry = 0;
@@ -56,24 +56,22 @@ inline word bigint_cnd_add(word cnd, word x[], word x_size,
for(size_t i = 0; i != blocks; i += 8)
{
carry = word8_add3(z, x + i, y + i, carry);
-
- for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, z[j], x[i+j]);
+ mask.select_n(x + i, z, x + i, 8);
}
for(size_t i = blocks; i != y_size; ++i)
{
z[0] = word_add(x[i], y[i], &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
for(size_t i = y_size; i != x_size; ++i)
{
z[0] = word_add(x[i], 0, &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
- return carry & mask;
+ return mask.if_set_return(carry);
}
/*
@@ -95,7 +93,7 @@ inline word bigint_cnd_sub(word cnd,
{
BOTAN_ASSERT(x_size >= y_size, "Expected sizes");
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
word carry = 0;
@@ -105,24 +103,22 @@ inline word bigint_cnd_sub(word cnd,
for(size_t i = 0; i != blocks; i += 8)
{
carry = word8_sub3(z, x + i, y + i, carry);
-
- for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, z[j], x[i+j]);
+ mask.select_n(x + i, z, x + i, 8);
}
for(size_t i = blocks; i != y_size; ++i)
{
z[0] = word_sub(x[i], y[i], &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
for(size_t i = y_size; i != x_size; ++i)
{
z[0] = word_sub(x[i], 0, &carry);
- x[i] = CT::select(mask, z[0], x[i]);
+ x[i] = mask.select(z[0], x[i]);
}
- return carry & mask;
+ return mask.if_set_return(carry);
}
/*
@@ -142,7 +138,7 @@ inline word bigint_cnd_sub(word cnd, word x[], const word y[], size_t size)
*
* Mask must be either 0 or all 1 bits
*/
-inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
+inline void bigint_cnd_addsub(CT::Mask<word> mask, word x[], const word y[], size_t size)
{
const size_t blocks = size - (size % 8);
@@ -158,7 +154,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
borrow = word8_sub3(t1, x + i, y + i, borrow);
for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, t0[j], t1[j]);
+ x[i+j] = mask.select(t0[j], t1[j]);
}
for(size_t i = blocks; i != size; ++i)
@@ -166,7 +162,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
const word a = word_add(x[i], y[i], &carry);
const word s = word_sub(x[i], y[i], &borrow);
- x[i] = CT::select(mask, a, s);
+ x[i] = mask.select(a, s);
}
}
@@ -179,7 +175,7 @@ inline void bigint_cnd_addsub(word mask, word x[], const word y[], size_t size)
*
* Returns the carry or borrow resp
*/
-inline word bigint_cnd_addsub(word mask, word x[],
+inline word bigint_cnd_addsub(CT::Mask<word> mask, word x[],
const word y[], const word z[],
size_t size)
{
@@ -197,17 +193,17 @@ inline word bigint_cnd_addsub(word mask, word x[],
borrow = word8_sub3(t1, x + i, z + i, borrow);
for(size_t j = 0; j != 8; ++j)
- x[i+j] = CT::select(mask, t0[j], t1[j]);
+ x[i+j] = mask.select(t0[j], t1[j]);
}
for(size_t i = blocks; i != size; ++i)
{
t0[0] = word_add(x[i], y[i], &carry);
t1[0] = word_sub(x[i], z[i], &borrow);
- x[i] = CT::select(mask, t0[0], t1[0]);
+ x[i] = mask.select(t0[0], t1[0]);
}
- return CT::select(mask, carry, borrow);
+ return mask.select(carry, borrow);
}
/*
@@ -217,13 +213,13 @@ inline word bigint_cnd_addsub(word mask, word x[],
*/
inline void bigint_cnd_abs(word cnd, word x[], size_t size)
{
- const word mask = CT::expand_mask(cnd);
+ const auto mask = CT::Mask<word>::expand(cnd);
- word carry = mask & 1;
+ word carry = mask.if_set_return(1);
for(size_t i = 0; i != size; ++i)
{
const word z = word_add(~x[i], 0, &carry);
- x[i] = CT::select(mask, z, x[i]);
+ x[i] = mask.select(z, x[i]);
}
}
@@ -379,9 +375,10 @@ inline word bigint_sub3(word z[],
* @param N length of x and y
* @param ws array of at least 2*N words
*/
-inline word bigint_sub_abs(word z[],
- const word x[], const word y[], size_t N,
- word ws[])
+inline CT::Mask<word>
+bigint_sub_abs(word z[],
+ const word x[], const word y[], size_t N,
+ word ws[])
{
// Subtract in both direction then conditional copy out the result
@@ -544,19 +541,20 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
{
static_assert(sizeof(word) >= sizeof(uint32_t), "Size assumption");
- const uint32_t LT = static_cast<uint32_t>(-1);
- const uint32_t EQ = 0;
- const uint32_t GT = 1;
+ const word LT = static_cast<word>(-1);
+ const word EQ = 0;
+ const word GT = 1;
const size_t common_elems = std::min(x_size, y_size);
- uint32_t result = EQ; // until found otherwise
+ word result = EQ; // until found otherwise
for(size_t i = 0; i != common_elems; i++)
{
- const word is_eq = CT::is_equal(x[i], y[i]);
- const word is_lt = CT::is_less(x[i], y[i]);
- result = CT::select<uint32_t>(is_eq, result, CT::select<uint32_t>(is_lt, LT, GT));
+ const auto is_eq = CT::Mask<word>::is_equal(x[i], y[i]);
+ const auto is_lt = CT::Mask<word>::is_lt(x[i], y[i]);
+
+ result = is_eq.select(result, is_lt.select(LT, GT));
}
if(x_size < y_size)
@@ -566,7 +564,7 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
mask |= y[i];
// If any bits were set in high part of y, then x < y
- result = CT::select<uint32_t>(CT::is_zero(mask), result, LT);
+ result = CT::Mask<word>::is_zero(mask).select(result, LT);
}
else if(y_size < x_size)
{
@@ -575,7 +573,7 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
mask |= x[i];
// If any bits were set in high part of x, then x > y
- result = CT::select<uint32_t>(CT::is_zero(mask), result, GT);
+ result = CT::Mask<word>::is_zero(mask).select(result, GT);
}
CT::unpoison(result);
@@ -588,19 +586,20 @@ inline int32_t bigint_cmp(const word x[], size_t x_size,
* Return ~0 if x[0:x_size] < y[0:y_size] or 0 otherwise
* If lt_or_equal is true, returns ~0 also for x == y
*/
-inline word bigint_ct_is_lt(const word x[], size_t x_size,
- const word y[], size_t y_size,
- bool lt_or_equal = false)
+inline CT::Mask<word>
+bigint_ct_is_lt(const word x[], size_t x_size,
+ const word y[], size_t y_size,
+ bool lt_or_equal = false)
{
const size_t common_elems = std::min(x_size, y_size);
- word is_lt = CT::expand_mask<word>(lt_or_equal);
+ auto is_lt = CT::Mask<word>::expand(lt_or_equal);
for(size_t i = 0; i != common_elems; i++)
{
- const word eq = CT::is_equal(x[i], y[i]);
- const word lt = CT::is_less(x[i], y[i]);
- is_lt = CT::select(eq, is_lt, lt);
+ const auto eq = CT::Mask<word>::is_equal(x[i], y[i]);
+ const auto lt = CT::Mask<word>::is_lt(x[i], y[i]);
+ is_lt = eq.select_mask(is_lt, lt);
}
if(x_size < y_size)
@@ -609,7 +608,7 @@ inline word bigint_ct_is_lt(const word x[], size_t x_size,
for(size_t i = x_size; i != y_size; i++)
mask |= y[i];
// If any bits were set in high part of y, then is_lt should be forced true
- is_lt |= ~CT::is_zero(mask);
+ is_lt |= CT::Mask<word>::expand(mask);
}
else if(y_size < x_size)
{
@@ -618,15 +617,15 @@ inline word bigint_ct_is_lt(const word x[], size_t x_size,
mask |= x[i];
// If any bits were set in high part of x, then is_lt should be false
- is_lt &= CT::is_zero(mask);
+ is_lt &= CT::Mask<word>::is_zero(mask);
}
- CT::unpoison(is_lt);
return is_lt;
}
-inline word bigint_ct_is_eq(const word x[], size_t x_size,
- const word y[], size_t y_size)
+inline CT::Mask<word>
+bigint_ct_is_eq(const word x[], size_t x_size,
+ const word y[], size_t y_size)
{
const size_t common_elems = std::min(x_size, y_size);
@@ -649,9 +648,7 @@ inline word bigint_ct_is_eq(const word x[], size_t x_size,
diff |= x[i];
}
- const word is_equal = CT::is_zero(diff);
- CT::unpoison(is_equal);
- return is_equal;
+ return CT::Mask<word>::is_zero(diff);
}
/**
diff --git a/src/lib/math/mp/mp_karat.cpp b/src/lib/math/mp/mp_karat.cpp
index c0fc5304b..8bd7cf58d 100644
--- a/src/lib/math/mp/mp_karat.cpp
+++ b/src/lib/math/mp/mp_karat.cpp
@@ -124,9 +124,9 @@ void karatsuba_mul(word z[], const word x[], const word y[], size_t N,
*/
// First compute (X_lo - X_hi)*(Y_hi - Y_lo)
- const word cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace);
- const word cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace);
- const word neg_mask = ~(cmp0 ^ cmp1);
+ const auto cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace);
+ const auto cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace);
+ const auto neg_mask = ~(cmp0 ^ cmp1);
karatsuba_mul(ws0, z0, z1, N2, ws1);
diff --git a/src/lib/math/numbertheory/monty_exp.cpp b/src/lib/math/numbertheory/monty_exp.cpp
index 2b5bbd81d..7590005a0 100644
--- a/src/lib/math/numbertheory/monty_exp.cpp
+++ b/src/lib/math/numbertheory/monty_exp.cpp
@@ -85,10 +85,12 @@ void const_time_lookup(secure_vector<word>& output,
BOTAN_ASSERT(vec.size() >= words,
"Word size as expected in const_time_lookup");
- const word mask = CT::is_equal<word>(i, nibble);
+ const auto mask = CT::Mask<word>::is_equal(i, nibble);
for(size_t w = 0; w != words; ++w)
- output[w] |= (mask & vec[w]);
+ {
+ output[w] |= mask.if_set_return(vec[w]);
+ }
}
}