author     Jack Lloyd <[email protected]>    2018-04-18 18:47:11 -0400
committer  Jack Lloyd <[email protected]>    2018-08-26 01:13:15 -0400
commit     10e2aebeb3c21ed30e77167a49adca2b5ac4755d (patch)
tree       a82b95bf521a7358ee88f65f0abde87b9b2196a7 /src/lib/stream/chacha
parent     a955c8a777550535bc3b6922395529e6a11b4c9d (diff)
Add AVX2 version of ChaCha
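For orientation: the kernel added below vectorizes the standard ChaCha quarter round, applying it to eight 64-byte blocks per call. A scalar sketch of that quarter round follows (illustrative helper names, not taken from the patch; the add/xor/rotate-by-16,12,8,7 sequence is what the CHACHA_QUARTER_ROUND macro in chacha.cpp and the new AVX2 intrinsics compute):

#include <cstdint>

// Rotate a 32-bit word left by n bits (n in 1..31).
static inline uint32_t rotl32(uint32_t x, int n) { return (x << n) | (x >> (32 - n)); }

// One ChaCha quarter round on four state words. The AVX2 kernel in this commit
// runs this sequence on eight blocks in parallel, using byte shuffles for the
// 16- and 8-bit rotations and shift+or (mm_rotl) for the 12- and 7-bit rotations.
static inline void chacha_quarter_round(uint32_t& a, uint32_t& b, uint32_t& c, uint32_t& d)
   {
   a += b; d ^= a; d = rotl32(d, 16);
   c += d; b ^= c; b = rotl32(b, 12);
   a += b; d ^= a; d = rotl32(d, 8);
   c += d; b ^= c; b = rotl32(b, 7);
   }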
Diffstat (limited to 'src/lib/stream/chacha')
-rw-r--r--   src/lib/stream/chacha/chacha.cpp                    |  39
-rw-r--r--   src/lib/stream/chacha/chacha.h                      |   8
-rw-r--r--   src/lib/stream/chacha/chacha_avx2/chacha_avx2.cpp   | 264
-rw-r--r--   src/lib/stream/chacha/chacha_avx2/info.txt          |   5
4 files changed, 304 insertions, 12 deletions
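In the diff below, ChaCha::provider() gains an "avx2" branch checked ahead of "sse2", so the fastest kernel the CPU supports is selected at runtime via CPUID. A minimal sketch of querying this through the public API (assumes a Botan 2.x build with this commit; the surrounding program is hypothetical):

#include <botan/stream_cipher.h>
#include <iostream>
#include <memory>

int main()
   {
   // The factory returns a ChaCha20 object; provider() reports which backend
   // (avx2, sse2, or base) the CPUID checks selected on this machine.
   std::unique_ptr<Botan::StreamCipher> chacha = Botan::StreamCipher::create("ChaCha(20)");
   if(!chacha)
      {
      std::cerr << "ChaCha not available in this build\n";
      return 1;
      }
   std::cout << "ChaCha provider: " << chacha->provider() << '\n';
   return 0;
   }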
diff --git a/src/lib/stream/chacha/chacha.cpp b/src/lib/stream/chacha/chacha.cpp
index 8edb685da..0670faa5e 100644
--- a/src/lib/stream/chacha/chacha.cpp
+++ b/src/lib/stream/chacha/chacha.cpp
@@ -67,6 +67,13 @@ ChaCha::ChaCha(size_t rounds) : m_rounds(rounds)
 
 std::string ChaCha::provider() const
    {
+#if defined(BOTAN_HAS_CHACHA_AVX2)
+   if(CPUID::has_avx2())
+      {
+      return "avx2";
+      }
+#endif
+
 #if defined(BOTAN_HAS_CHACHA_SSE2)
    if(CPUID::has_sse2())
       {
@@ -78,19 +85,28 @@ std::string ChaCha::provider() const
    }
 
 //static
-void ChaCha::chacha_x4(uint8_t output[64*4], uint32_t input[16], size_t rounds)
+void ChaCha::chacha_x8(uint8_t output[64*8], uint32_t input[16], size_t rounds)
    {
    BOTAN_ASSERT(rounds % 2 == 0, "Valid rounds");
 
+#if defined(BOTAN_HAS_CHACHA_AVX2)
+   if(CPUID::has_avx2())
+      {
+      return ChaCha::chacha_avx2_x8(output, input, rounds);
+      }
+#endif
+
 #if defined(BOTAN_HAS_CHACHA_SSE2)
    if(CPUID::has_sse2())
       {
-      return ChaCha::chacha_sse2_x4(output, input, rounds);
+      ChaCha::chacha_sse2_x4(output, input, rounds);
+      ChaCha::chacha_sse2_x4(output + 4*64, input, rounds);
+      return;
       }
 #endif
 
    // TODO interleave rounds
-   for(size_t i = 0; i != 4; ++i)
+   for(size_t i = 0; i != 8; ++i)
      {
      uint32_t x00 = input[ 0], x01 = input[ 1], x02 = input[ 2], x03 = input[ 3],
               x04 = input[ 4], x05 = input[ 5], x06 = input[ 6], x07 = input[ 7],
@@ -110,8 +126,6 @@ void ChaCha::chacha_x4(uint8_t output[64*4], uint32_t input[16], size_t rounds)
         CHACHA_QUARTER_ROUND(x03, x04, x09, x14);
         }
 
-#undef CHACHA_QUARTER_ROUND
-
      x00 += input[0];
      x01 += input[1];
      x02 += input[2];
@@ -151,6 +165,8 @@ void ChaCha::chacha_x4(uint8_t output[64*4], uint32_t input[16], size_t rounds)
      }
    }
 
+#undef CHACHA_QUARTER_ROUND
+
 /*
 * Combine cipher stream with message
 */
@@ -164,7 +180,7 @@ void ChaCha::cipher(const uint8_t in[], uint8_t out[], size_t length)
       length -= (m_buffer.size() - m_position);
       in += (m_buffer.size() - m_position);
       out += (m_buffer.size() - m_position);
-      chacha_x4(m_buffer.data(), m_state.data(), m_rounds);
+      chacha_x8(m_buffer.data(), m_state.data(), m_rounds);
       m_position = 0;
       }
 
@@ -182,7 +198,7 @@ void ChaCha::write_keystream(uint8_t out[], size_t length)
      copy_mem(out, &m_buffer[m_position], m_buffer.size() - m_position);
      length -= (m_buffer.size() - m_position);
      out += (m_buffer.size() - m_position);
-     chacha_x4(m_buffer.data(), m_state.data(), m_rounds);
+     chacha_x8(m_buffer.data(), m_state.data(), m_rounds);
      m_position = 0;
      }
 
@@ -246,7 +262,10 @@ void ChaCha::key_schedule(const uint8_t key[], size_t length)
    load_le<uint32_t>(m_key.data(), key, m_key.size());
 
    m_state.resize(16);
-   m_buffer.resize(4*64);
+
+   const size_t chacha_parallelism = 8; // chacha_x8
+   const size_t chacha_block = 64;
+   m_buffer.resize(chacha_parallelism * chacha_block);
 
    set_iv(nullptr, 0);
    }
@@ -321,7 +340,7 @@ void ChaCha::set_iv(const uint8_t iv[], size_t length)
       m_state[15] = load_le<uint32_t>(iv, 5);
       }
 
-   chacha_x4(m_buffer.data(), m_state.data(), m_rounds);
+   chacha_x8(m_buffer.data(), m_state.data(), m_rounds);
    m_position = 0;
    }
 
@@ -352,7 +371,7 @@ void ChaCha::seek(uint64_t offset)
    m_state[12] = load_le<uint32_t>(out, 0);
    m_state[13] += load_le<uint32_t>(out, 1);
 
-   chacha_x4(m_buffer.data(), m_state.data(), m_rounds);
+   chacha_x8(m_buffer.data(), m_state.data(), m_rounds);
    m_position = offset % 64;
    }
 }
diff --git a/src/lib/stream/chacha/chacha.h b/src/lib/stream/chacha/chacha.h
index 346e25c28..390c3b788 100644
--- a/src/lib/stream/chacha/chacha.h
+++ b/src/lib/stream/chacha/chacha.h
@@ -1,6 +1,6 @@
 /*
 * ChaCha20
-* (C) 2014 Jack Lloyd
+* (C) 2014,2018 Jack Lloyd
 *
 * Botan is released under the Simplified BSD License (see license.txt)
 */
@@ -58,12 +58,16 @@ class BOTAN_PUBLIC_API(2,0) ChaCha final : public StreamCipher
 
       void initialize_state();
 
-      void chacha_x4(uint8_t output[64*4], uint32_t state[16], size_t rounds);
+      void chacha_x8(uint8_t output[64*8], uint32_t state[16], size_t rounds);
 
 #if defined(BOTAN_HAS_CHACHA_SSE2)
       void chacha_sse2_x4(uint8_t output[64*4], uint32_t state[16], size_t rounds);
 #endif
 
+#if defined(BOTAN_HAS_CHACHA_AVX2)
+      void chacha_avx2_x8(uint8_t output[64*8], uint32_t state[16], size_t rounds);
+#endif
+
       size_t m_rounds;
       secure_vector<uint32_t> m_key;
       secure_vector<uint32_t> m_state;
diff --git a/src/lib/stream/chacha/chacha_avx2/chacha_avx2.cpp b/src/lib/stream/chacha/chacha_avx2/chacha_avx2.cpp
new file mode 100644
index 000000000..5f5ad820f
--- /dev/null
+++ b/src/lib/stream/chacha/chacha_avx2/chacha_avx2.cpp
@@ -0,0 +1,264 @@
+/*
+* (C) 2018 Jack Lloyd
+*
+* Botan is released under the Simplified BSD License (see license.txt)
+*/
+
+#include <botan/chacha.h>
+#include <immintrin.h>
+
+namespace Botan {
+
+//static
+BOTAN_FUNC_ISA("avx2")
+void ChaCha::chacha_avx2_x8(uint8_t output[64*8], uint32_t input[16], size_t rounds)
+   {
+   BOTAN_ASSERT(rounds % 2 == 0, "Valid rounds");
+
+   const __m128i* input_mm = reinterpret_cast<const __m128i*>(input);
+   __m256i* output_mm = reinterpret_cast<__m256i*>(output);
+   __m128i* output_mm128 = reinterpret_cast<__m128i*>(output);
+
+   const __m256i input0 = _mm256_broadcastsi128_si256(_mm_loadu_si128(input_mm));
+   const __m256i input1 = _mm256_broadcastsi128_si256(_mm_loadu_si128(input_mm + 1));
+   const __m256i input2 = _mm256_broadcastsi128_si256(_mm_loadu_si128(input_mm + 2));
+   const __m256i input3 = _mm256_broadcastsi128_si256(_mm_loadu_si128(input_mm + 3));
+
+   const __m256i CTR0 = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 4);
+   const __m256i CTR1 = _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 5);
+   const __m256i CTR2 = _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 6);
+   const __m256i CTR3 = _mm256_set_epi32(0, 0, 0, 3, 0, 0, 0, 7);
+
+   const __m256i shuf_rotl_16 = _mm256_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2,
+                                                13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2);
+   const __m256i shuf_rotl_8 = _mm256_set_epi8(14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3,
+                                               14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3);
+
+#define mm_rotl(r, n) \
+   _mm256_or_si256(_mm256_slli_epi32(r, n), _mm256_srli_epi32(r, 32-n))
+
+   __m256i X0_0 = input0;
+   __m256i X0_1 = input1;
+   __m256i X0_2 = input2;
+   __m256i X0_3 = _mm256_add_epi64(input3, CTR0);
+
+   __m256i X1_0 = input0;
+   __m256i X1_1 = input1;
+   __m256i X1_2 = input2;
+   __m256i X1_3 = _mm256_add_epi64(input3, CTR1);
+
+   __m256i X2_0 = input0;
+   __m256i X2_1 = input1;
+   __m256i X2_2 = input2;
+   __m256i X2_3 = _mm256_add_epi64(input3, CTR2);
+
+   __m256i X3_0 = input0;
+   __m256i X3_1 = input1;
+   __m256i X3_2 = input2;
+   __m256i X3_3 = _mm256_add_epi64(input3, CTR3);
+
+   for(size_t r = 0; r != rounds / 2; ++r)
+      {
+      X0_0 = _mm256_add_epi32(X0_0, X0_1);
+      X1_0 = _mm256_add_epi32(X1_0, X1_1);
+      X2_0 = _mm256_add_epi32(X2_0, X2_1);
+      X3_0 = _mm256_add_epi32(X3_0, X3_1);
+
+      X0_3 = _mm256_xor_si256(X0_3, X0_0);
+      X1_3 = _mm256_xor_si256(X1_3, X1_0);
+      X2_3 = _mm256_xor_si256(X2_3, X2_0);
+      X3_3 = _mm256_xor_si256(X3_3, X3_0);
+
+      X0_3 = _mm256_shuffle_epi8(X0_3, shuf_rotl_16);
+      X1_3 = _mm256_shuffle_epi8(X1_3, shuf_rotl_16);
+      X2_3 = _mm256_shuffle_epi8(X2_3, shuf_rotl_16);
+      X3_3 = _mm256_shuffle_epi8(X3_3, shuf_rotl_16);
+
+      X0_2 = _mm256_add_epi32(X0_2, X0_3);
+      X1_2 = _mm256_add_epi32(X1_2, X1_3);
+      X2_2 = _mm256_add_epi32(X2_2, X2_3);
+      X3_2 = _mm256_add_epi32(X3_2, X3_3);
+
+      X0_1 = _mm256_xor_si256(X0_1, X0_2);
+      X1_1 = _mm256_xor_si256(X1_1, X1_2);
+      X2_1 = _mm256_xor_si256(X2_1, X2_2);
+      X3_1 = _mm256_xor_si256(X3_1, X3_2);
+
+      X0_1 = mm_rotl(X0_1, 12);
+      X1_1 = mm_rotl(X1_1, 12);
+      X2_1 = mm_rotl(X2_1, 12);
+      X3_1 = mm_rotl(X3_1, 12);
+
+      X0_0 = _mm256_add_epi32(X0_0, X0_1);
+      X1_0 = _mm256_add_epi32(X1_0, X1_1);
+      X2_0 = _mm256_add_epi32(X2_0, X2_1);
+      X3_0 = _mm256_add_epi32(X3_0, X3_1);
+
+      X0_3 = _mm256_xor_si256(X0_3, X0_0);
+      X1_3 = _mm256_xor_si256(X1_3, X1_0);
+      X2_3 = _mm256_xor_si256(X2_3, X2_0);
+      X3_3 = _mm256_xor_si256(X3_3, X3_0);
+
+      X0_3 = _mm256_shuffle_epi8(X0_3, shuf_rotl_8);
+      X1_3 = _mm256_shuffle_epi8(X1_3, shuf_rotl_8);
+      X2_3 = _mm256_shuffle_epi8(X2_3, shuf_rotl_8);
+      X3_3 = _mm256_shuffle_epi8(X3_3, shuf_rotl_8);
+
+      X0_2 = _mm256_add_epi32(X0_2, X0_3);
+      X1_2 = _mm256_add_epi32(X1_2, X1_3);
+      X2_2 = _mm256_add_epi32(X2_2, X2_3);
+      X3_2 = _mm256_add_epi32(X3_2, X3_3);
+
+      X0_1 = _mm256_xor_si256(X0_1, X0_2);
+      X1_1 = _mm256_xor_si256(X1_1, X1_2);
+      X2_1 = _mm256_xor_si256(X2_1, X2_2);
+      X3_1 = _mm256_xor_si256(X3_1, X3_2);
+
+      X0_1 = mm_rotl(X0_1, 7);
+      X1_1 = mm_rotl(X1_1, 7);
+      X2_1 = mm_rotl(X2_1, 7);
+      X3_1 = mm_rotl(X3_1, 7);
+
+      X0_1 = _mm256_shuffle_epi32(X0_1, _MM_SHUFFLE(0, 3, 2, 1));
+      X0_2 = _mm256_shuffle_epi32(X0_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X0_3 = _mm256_shuffle_epi32(X0_3, _MM_SHUFFLE(2, 1, 0, 3));
+
+      X1_1 = _mm256_shuffle_epi32(X1_1, _MM_SHUFFLE(0, 3, 2, 1));
+      X1_2 = _mm256_shuffle_epi32(X1_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X1_3 = _mm256_shuffle_epi32(X1_3, _MM_SHUFFLE(2, 1, 0, 3));
+
+      X2_1 = _mm256_shuffle_epi32(X2_1, _MM_SHUFFLE(0, 3, 2, 1));
+      X2_2 = _mm256_shuffle_epi32(X2_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X2_3 = _mm256_shuffle_epi32(X2_3, _MM_SHUFFLE(2, 1, 0, 3));
+
+      X3_1 = _mm256_shuffle_epi32(X3_1, _MM_SHUFFLE(0, 3, 2, 1));
+      X3_2 = _mm256_shuffle_epi32(X3_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X3_3 = _mm256_shuffle_epi32(X3_3, _MM_SHUFFLE(2, 1, 0, 3));
+
+      X0_0 = _mm256_add_epi32(X0_0, X0_1);
+      X1_0 = _mm256_add_epi32(X1_0, X1_1);
+      X2_0 = _mm256_add_epi32(X2_0, X2_1);
+      X3_0 = _mm256_add_epi32(X3_0, X3_1);
+
+      X0_3 = _mm256_xor_si256(X0_3, X0_0);
+      X1_3 = _mm256_xor_si256(X1_3, X1_0);
+      X2_3 = _mm256_xor_si256(X2_3, X2_0);
+      X3_3 = _mm256_xor_si256(X3_3, X3_0);
+
+      X0_3 = _mm256_shuffle_epi8(X0_3, shuf_rotl_16);
+      X1_3 = _mm256_shuffle_epi8(X1_3, shuf_rotl_16);
+      X2_3 = _mm256_shuffle_epi8(X2_3, shuf_rotl_16);
+      X3_3 = _mm256_shuffle_epi8(X3_3, shuf_rotl_16);
+
+      X0_2 = _mm256_add_epi32(X0_2, X0_3);
+      X1_2 = _mm256_add_epi32(X1_2, X1_3);
+      X2_2 = _mm256_add_epi32(X2_2, X2_3);
+      X3_2 = _mm256_add_epi32(X3_2, X3_3);
+
+      X0_1 = _mm256_xor_si256(X0_1, X0_2);
+      X1_1 = _mm256_xor_si256(X1_1, X1_2);
+      X2_1 = _mm256_xor_si256(X2_1, X2_2);
+      X3_1 = _mm256_xor_si256(X3_1, X3_2);
+
+      X0_1 = mm_rotl(X0_1, 12);
+      X1_1 = mm_rotl(X1_1, 12);
+      X2_1 = mm_rotl(X2_1, 12);
+      X3_1 = mm_rotl(X3_1, 12);
+
+      X0_0 = _mm256_add_epi32(X0_0, X0_1);
+      X1_0 = _mm256_add_epi32(X1_0, X1_1);
+      X2_0 = _mm256_add_epi32(X2_0, X2_1);
+      X3_0 = _mm256_add_epi32(X3_0, X3_1);
+
+      X0_3 = _mm256_xor_si256(X0_3, X0_0);
+      X1_3 = _mm256_xor_si256(X1_3, X1_0);
+      X2_3 = _mm256_xor_si256(X2_3, X2_0);
+      X3_3 = _mm256_xor_si256(X3_3, X3_0);
+
+      X0_3 = _mm256_shuffle_epi8(X0_3, shuf_rotl_8);
+      X1_3 = _mm256_shuffle_epi8(X1_3, shuf_rotl_8);
+      X2_3 = _mm256_shuffle_epi8(X2_3, shuf_rotl_8);
+      X3_3 = _mm256_shuffle_epi8(X3_3, shuf_rotl_8);
+
+      X0_2 = _mm256_add_epi32(X0_2, X0_3);
+      X1_2 = _mm256_add_epi32(X1_2, X1_3);
+      X2_2 = _mm256_add_epi32(X2_2, X2_3);
+      X3_2 = _mm256_add_epi32(X3_2, X3_3);
+
+      X0_1 = _mm256_xor_si256(X0_1, X0_2);
+      X1_1 = _mm256_xor_si256(X1_1, X1_2);
+      X2_1 = _mm256_xor_si256(X2_1, X2_2);
+      X3_1 = _mm256_xor_si256(X3_1, X3_2);
+
+      X0_1 = mm_rotl(X0_1, 7);
+      X1_1 = mm_rotl(X1_1, 7);
+      X2_1 = mm_rotl(X2_1, 7);
+      X3_1 = mm_rotl(X3_1, 7);
+
+      X0_1 = _mm256_shuffle_epi32(X0_1, _MM_SHUFFLE(2, 1, 0, 3));
+      X0_2 = _mm256_shuffle_epi32(X0_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X0_3 = _mm256_shuffle_epi32(X0_3, _MM_SHUFFLE(0, 3, 2, 1));
+
+      X1_1 = _mm256_shuffle_epi32(X1_1, _MM_SHUFFLE(2, 1, 0, 3));
+      X1_2 = _mm256_shuffle_epi32(X1_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X1_3 = _mm256_shuffle_epi32(X1_3, _MM_SHUFFLE(0, 3, 2, 1));
+
+      X2_1 = _mm256_shuffle_epi32(X2_1, _MM_SHUFFLE(2, 1, 0, 3));
+      X2_2 = _mm256_shuffle_epi32(X2_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X2_3 = _mm256_shuffle_epi32(X2_3, _MM_SHUFFLE(0, 3, 2, 1));
+
+      X3_1 = _mm256_shuffle_epi32(X3_1, _MM_SHUFFLE(2, 1, 0, 3));
+      X3_2 = _mm256_shuffle_epi32(X3_2, _MM_SHUFFLE(1, 0, 3, 2));
+      X3_3 = _mm256_shuffle_epi32(X3_3, _MM_SHUFFLE(0, 3, 2, 1));
+      }
+
+   X0_0 = _mm256_add_epi32(X0_0, input0);
+   X0_1 = _mm256_add_epi32(X0_1, input1);
+   X0_2 = _mm256_add_epi32(X0_2, input2);
+   X0_3 = _mm256_add_epi32(X0_3, input3);
+   X0_3 = _mm256_add_epi64(X0_3, CTR0);
+
+   X1_0 = _mm256_add_epi32(X1_0, input0);
+   X1_1 = _mm256_add_epi32(X1_1, input1);
+   X1_2 = _mm256_add_epi32(X1_2, input2);
+   X1_3 = _mm256_add_epi32(X1_3, input3);
+   X1_3 = _mm256_add_epi64(X1_3, CTR1);
+
+   X2_0 = _mm256_add_epi32(X2_0, input0);
+   X2_1 = _mm256_add_epi32(X2_1, input1);
+   X2_2 = _mm256_add_epi32(X2_2, input2);
+   X2_3 = _mm256_add_epi32(X2_3, input3);
+   X2_3 = _mm256_add_epi64(X2_3, CTR2);
+
+   X3_0 = _mm256_add_epi32(X3_0, input0);
+   X3_1 = _mm256_add_epi32(X3_1, input1);
+   X3_2 = _mm256_add_epi32(X3_2, input2);
+   X3_3 = _mm256_add_epi32(X3_3, input3);
+   X3_3 = _mm256_add_epi64(X3_3, CTR3);
+
+   _mm256_storeu_si256(output_mm     , _mm256_permute2x128_si256(X0_0, X0_1, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  1, _mm256_permute2x128_si256(X0_2, X0_3, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  2, _mm256_permute2x128_si256(X1_0, X1_1, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  3, _mm256_permute2x128_si256(X1_2, X1_3, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  4, _mm256_permute2x128_si256(X2_0, X2_1, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  5, _mm256_permute2x128_si256(X2_2, X2_3, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  6, _mm256_permute2x128_si256(X3_0, X3_1, 1 + (3 << 4)));
+   _mm256_storeu_si256(output_mm +  7, _mm256_permute2x128_si256(X3_2, X3_3, 1 + (3 << 4)));
+
+   _mm256_storeu_si256(output_mm +  8, _mm256_permute2x128_si256(X0_0, X0_1, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm +  9, _mm256_permute2x128_si256(X0_2, X0_3, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm + 10, _mm256_permute2x128_si256(X1_0, X1_1, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm + 11, _mm256_permute2x128_si256(X1_2, X1_3, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm + 12, _mm256_permute2x128_si256(X2_0, X2_1, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm + 13, _mm256_permute2x128_si256(X2_2, X2_3, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm + 14, _mm256_permute2x128_si256(X3_0, X3_1, 0 + (2 << 4)));
+   _mm256_storeu_si256(output_mm + 15, _mm256_permute2x128_si256(X3_2, X3_3, 0 + (2 << 4)));
+
+#undef mm_rotl
+
+   input[12] += 8;
+   if(input[12] < 8)
+      input[13]++;
+
+   }
+}
diff --git a/src/lib/stream/chacha/chacha_avx2/info.txt b/src/lib/stream/chacha/chacha_avx2/info.txt
new file mode 100644
index 000000000..3ec1e39d5
--- /dev/null
+++ b/src/lib/stream/chacha/chacha_avx2/info.txt
@@ -0,0 +1,5 @@
+<defines>
+CHACHA_AVX2 -> 20180418
+</defines>
+
+need_isa avx2
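The 16- and 8-bit rotations in the new kernel use _mm256_shuffle_epi8 with the shuf_rotl_16/shuf_rotl_8 masks, since rotating a 32-bit lane by a multiple of 8 is just a byte permutation; the 12- and 7-bit rotations fall back to the shift+or mm_rotl macro. A standalone sketch checking that equivalence for the 16-bit case (not part of the commit; hypothetical test program, compile with AVX2 enabled, e.g. -mavx2):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
   {
   // Same mask as shuf_rotl_16 in chacha_avx2.cpp: within each 32-bit lane,
   // output byte i takes input byte (i + 2) % 4, i.e. a left rotation by 16 bits.
   const __m256i shuf_rotl_16 =
      _mm256_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2,
                      13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2);

   const uint32_t in[8] = { 0x01234567, 0x89ABCDEF, 0xDEADBEEF, 0x00000001,
                            0xFFFFFFFF, 0x0BADF00D, 0xCAFEBABE, 0x12345678 };

   const __m256i v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(in));

   // Rotation by 16 done as a single byte shuffle...
   const __m256i by_shuffle = _mm256_shuffle_epi8(v, shuf_rotl_16);
   // ...must equal the generic two-shift-plus-or rotation (mm_rotl in the patch).
   const __m256i by_shift = _mm256_or_si256(_mm256_slli_epi32(v, 16), _mm256_srli_epi32(v, 16));

   uint32_t a[8], b[8];
   _mm256_storeu_si256(reinterpret_cast<__m256i*>(a), by_shuffle);
   _mm256_storeu_si256(reinterpret_cast<__m256i*>(b), by_shift);

   std::printf("%s\n", std::memcmp(a, b, sizeof(a)) == 0 ? "match" : "MISMATCH");
   return 0;
   }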