diff options
Diffstat (limited to 'src/block/aes_ni/aes_ni.cpp')
-rw-r--r-- | src/block/aes_ni/aes_ni.cpp | 779 |
1 files changed, 779 insertions, 0 deletions
diff --git a/src/block/aes_ni/aes_ni.cpp b/src/block/aes_ni/aes_ni.cpp new file mode 100644 index 000000000..3ee0e608c --- /dev/null +++ b/src/block/aes_ni/aes_ni.cpp @@ -0,0 +1,779 @@ +/* +* AES using AES-NI instructions +* (C) 2009 Jack Lloyd +* +* Distributed under the terms of the Botan license +*/ + +#include <botan/aes_ni.h> +#include <botan/loadstor.h> +#include <wmmintrin.h> + +namespace Botan { + +namespace { + +__m128i aes_128_key_expansion(__m128i key, __m128i key_with_rcon) + { + key_with_rcon = _mm_shuffle_epi32(key_with_rcon, _MM_SHUFFLE(3,3,3,3)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + return _mm_xor_si128(key, key_with_rcon); + } + +void aes_192_key_expansion(__m128i* K1, __m128i* K2, __m128i key2_with_rcon, + u32bit out[], bool last) + { + __m128i key1 = *K1; + __m128i key2 = *K2; + + key2_with_rcon = _mm_shuffle_epi32(key2_with_rcon, _MM_SHUFFLE(1,1,1,1)); + key1 = _mm_xor_si128(key1, _mm_slli_si128(key1, 4)); + key1 = _mm_xor_si128(key1, _mm_slli_si128(key1, 4)); + key1 = _mm_xor_si128(key1, _mm_slli_si128(key1, 4)); + key1 = _mm_xor_si128(key1, key2_with_rcon); + + *K1 = key1; + _mm_storeu_si128((__m128i*)out, key1); + + if(last) + return; + + key2 = _mm_xor_si128(key2, _mm_slli_si128(key2, 4)); + key2 = _mm_xor_si128(key2, _mm_shuffle_epi32(key1, _MM_SHUFFLE(3,3,3,3))); + + *K2 = key2; + out[4] = _mm_cvtsi128_si32(key2); + out[5] = _mm_cvtsi128_si32(_mm_srli_si128(key2, 4)); + } + +/* +* The second half of the AES-256 key expansion (other half same as AES-128) +*/ +__m128i aes_256_key_expansion(__m128i key, __m128i key2) + { + __m128i key_with_rcon = _mm_aeskeygenassist_si128(key2, 0x00); + key_with_rcon = _mm_shuffle_epi32(key_with_rcon, _MM_SHUFFLE(2,2,2,2)); + + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + return _mm_xor_si128(key, key_with_rcon); + } + +} + +#define AES_ENC_4_ROUNDS(K) \ + do \ + { \ + B0 = _mm_aesenc_si128(B0, K); \ + B1 = _mm_aesenc_si128(B1, K); \ + B2 = _mm_aesenc_si128(B2, K); \ + B3 = _mm_aesenc_si128(B3, K); \ + } while(0) + +#define AES_ENC_4_LAST_ROUNDS(K) \ + do \ + { \ + B0 = _mm_aesenclast_si128(B0, K); \ + B1 = _mm_aesenclast_si128(B1, K); \ + B2 = _mm_aesenclast_si128(B2, K); \ + B3 = _mm_aesenclast_si128(B3, K); \ + } while(0) + +#define AES_DEC_4_ROUNDS(K) \ + do \ + { \ + B0 = _mm_aesdec_si128(B0, K); \ + B1 = _mm_aesdec_si128(B1, K); \ + B2 = _mm_aesdec_si128(B2, K); \ + B3 = _mm_aesdec_si128(B3, K); \ + } while(0) + +#define AES_DEC_4_LAST_ROUNDS(K) \ + do \ + { \ + B0 = _mm_aesdeclast_si128(B0, K); \ + B1 = _mm_aesdeclast_si128(B1, K); \ + B2 = _mm_aesdeclast_si128(B2, K); \ + B3 = _mm_aesdeclast_si128(B3, K); \ + } while(0) + +/* +* AES-128 Encryption +*/ +void AES_128_NI::encrypt_n(const byte in[], byte out[], size_t blocks) const + { + const __m128i* in_mm = (const __m128i*)in; + __m128i* out_mm = (__m128i*)out; + + const __m128i* key_mm = (const __m128i*)&EK[0]; + + __m128i K0 = _mm_loadu_si128(key_mm); + __m128i K1 = _mm_loadu_si128(key_mm + 1); + __m128i K2 = _mm_loadu_si128(key_mm + 2); + __m128i K3 = _mm_loadu_si128(key_mm + 3); + __m128i K4 = _mm_loadu_si128(key_mm + 4); + __m128i K5 = _mm_loadu_si128(key_mm + 5); + __m128i K6 = _mm_loadu_si128(key_mm + 6); + __m128i K7 = _mm_loadu_si128(key_mm + 7); + __m128i K8 = _mm_loadu_si128(key_mm + 8); + __m128i K9 = _mm_loadu_si128(key_mm + 9); + __m128i K10 = _mm_loadu_si128(key_mm + 10); + + while(blocks >= 4) + { + __m128i B0 = _mm_loadu_si128(in_mm + 0); + __m128i B1 = _mm_loadu_si128(in_mm + 1); + __m128i B2 = _mm_loadu_si128(in_mm + 2); + __m128i B3 = _mm_loadu_si128(in_mm + 3); + + B0 = _mm_xor_si128(B0, K0); + B1 = _mm_xor_si128(B1, K0); + B2 = _mm_xor_si128(B2, K0); + B3 = _mm_xor_si128(B3, K0); + + AES_ENC_4_ROUNDS(K1); + AES_ENC_4_ROUNDS(K2); + AES_ENC_4_ROUNDS(K3); + AES_ENC_4_ROUNDS(K4); + AES_ENC_4_ROUNDS(K5); + AES_ENC_4_ROUNDS(K6); + AES_ENC_4_ROUNDS(K7); + AES_ENC_4_ROUNDS(K8); + AES_ENC_4_ROUNDS(K9); + AES_ENC_4_LAST_ROUNDS(K10); + + _mm_storeu_si128(out_mm + 0, B0); + _mm_storeu_si128(out_mm + 1, B1); + _mm_storeu_si128(out_mm + 2, B2); + _mm_storeu_si128(out_mm + 3, B3); + + blocks -= 4; + in_mm += 4; + out_mm += 4; + } + + for(size_t i = 0; i != blocks; ++i) + { + __m128i B = _mm_loadu_si128(in_mm + i); + + B = _mm_xor_si128(B, K0); + + B = _mm_aesenc_si128(B, K1); + B = _mm_aesenc_si128(B, K2); + B = _mm_aesenc_si128(B, K3); + B = _mm_aesenc_si128(B, K4); + B = _mm_aesenc_si128(B, K5); + B = _mm_aesenc_si128(B, K6); + B = _mm_aesenc_si128(B, K7); + B = _mm_aesenc_si128(B, K8); + B = _mm_aesenc_si128(B, K9); + B = _mm_aesenclast_si128(B, K10); + + _mm_storeu_si128(out_mm + i, B); + } + } + +/* +* AES-128 Decryption +*/ +void AES_128_NI::decrypt_n(const byte in[], byte out[], size_t blocks) const + { + const __m128i* in_mm = (const __m128i*)in; + __m128i* out_mm = (__m128i*)out; + + const __m128i* key_mm = (const __m128i*)&DK[0]; + + __m128i K0 = _mm_loadu_si128(key_mm); + __m128i K1 = _mm_loadu_si128(key_mm + 1); + __m128i K2 = _mm_loadu_si128(key_mm + 2); + __m128i K3 = _mm_loadu_si128(key_mm + 3); + __m128i K4 = _mm_loadu_si128(key_mm + 4); + __m128i K5 = _mm_loadu_si128(key_mm + 5); + __m128i K6 = _mm_loadu_si128(key_mm + 6); + __m128i K7 = _mm_loadu_si128(key_mm + 7); + __m128i K8 = _mm_loadu_si128(key_mm + 8); + __m128i K9 = _mm_loadu_si128(key_mm + 9); + __m128i K10 = _mm_loadu_si128(key_mm + 10); + + while(blocks >= 4) + { + __m128i B0 = _mm_loadu_si128(in_mm + 0); + __m128i B1 = _mm_loadu_si128(in_mm + 1); + __m128i B2 = _mm_loadu_si128(in_mm + 2); + __m128i B3 = _mm_loadu_si128(in_mm + 3); + + B0 = _mm_xor_si128(B0, K0); + B1 = _mm_xor_si128(B1, K0); + B2 = _mm_xor_si128(B2, K0); + B3 = _mm_xor_si128(B3, K0); + + AES_DEC_4_ROUNDS(K1); + AES_DEC_4_ROUNDS(K2); + AES_DEC_4_ROUNDS(K3); + AES_DEC_4_ROUNDS(K4); + AES_DEC_4_ROUNDS(K5); + AES_DEC_4_ROUNDS(K6); + AES_DEC_4_ROUNDS(K7); + AES_DEC_4_ROUNDS(K8); + AES_DEC_4_ROUNDS(K9); + AES_DEC_4_LAST_ROUNDS(K10); + + _mm_storeu_si128(out_mm + 0, B0); + _mm_storeu_si128(out_mm + 1, B1); + _mm_storeu_si128(out_mm + 2, B2); + _mm_storeu_si128(out_mm + 3, B3); + + blocks -= 4; + in_mm += 4; + out_mm += 4; + } + + for(size_t i = 0; i != blocks; ++i) + { + __m128i B = _mm_loadu_si128(in_mm + i); + + B = _mm_xor_si128(B, K0); + + B = _mm_aesdec_si128(B, K1); + B = _mm_aesdec_si128(B, K2); + B = _mm_aesdec_si128(B, K3); + B = _mm_aesdec_si128(B, K4); + B = _mm_aesdec_si128(B, K5); + B = _mm_aesdec_si128(B, K6); + B = _mm_aesdec_si128(B, K7); + B = _mm_aesdec_si128(B, K8); + B = _mm_aesdec_si128(B, K9); + B = _mm_aesdeclast_si128(B, K10); + + _mm_storeu_si128(out_mm + i, B); + } + } + +/* +* AES-128 Key Schedule +*/ +void AES_128_NI::key_schedule(const byte key[], size_t) + { + #define AES_128_key_exp(K, RCON) \ + aes_128_key_expansion(K, _mm_aeskeygenassist_si128(K, RCON)) + + __m128i K0 = _mm_loadu_si128((const __m128i*)(key)); + __m128i K1 = AES_128_key_exp(K0, 0x01); + __m128i K2 = AES_128_key_exp(K1, 0x02); + __m128i K3 = AES_128_key_exp(K2, 0x04); + __m128i K4 = AES_128_key_exp(K3, 0x08); + __m128i K5 = AES_128_key_exp(K4, 0x10); + __m128i K6 = AES_128_key_exp(K5, 0x20); + __m128i K7 = AES_128_key_exp(K6, 0x40); + __m128i K8 = AES_128_key_exp(K7, 0x80); + __m128i K9 = AES_128_key_exp(K8, 0x1B); + __m128i K10 = AES_128_key_exp(K9, 0x36); + + __m128i* EK_mm = (__m128i*)&EK[0]; + _mm_storeu_si128(EK_mm , K0); + _mm_storeu_si128(EK_mm + 1, K1); + _mm_storeu_si128(EK_mm + 2, K2); + _mm_storeu_si128(EK_mm + 3, K3); + _mm_storeu_si128(EK_mm + 4, K4); + _mm_storeu_si128(EK_mm + 5, K5); + _mm_storeu_si128(EK_mm + 6, K6); + _mm_storeu_si128(EK_mm + 7, K7); + _mm_storeu_si128(EK_mm + 8, K8); + _mm_storeu_si128(EK_mm + 9, K9); + _mm_storeu_si128(EK_mm + 10, K10); + + // Now generate decryption keys + + __m128i* DK_mm = (__m128i*)&DK[0]; + _mm_storeu_si128(DK_mm , K10); + _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(K9)); + _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(K8)); + _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(K7)); + _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(K6)); + _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(K5)); + _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(K4)); + _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(K3)); + _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(K2)); + _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(K1)); + _mm_storeu_si128(DK_mm + 10, K0); + } + +/* +* Clear memory of sensitive data +*/ +void AES_128_NI::clear() + { + zeroise(EK); + zeroise(DK); + } + +/* +* AES-192 Encryption +*/ +void AES_192_NI::encrypt_n(const byte in[], byte out[], size_t blocks) const + { + const __m128i* in_mm = (const __m128i*)in; + __m128i* out_mm = (__m128i*)out; + + const __m128i* key_mm = (const __m128i*)&EK[0]; + + __m128i K0 = _mm_loadu_si128(key_mm); + __m128i K1 = _mm_loadu_si128(key_mm + 1); + __m128i K2 = _mm_loadu_si128(key_mm + 2); + __m128i K3 = _mm_loadu_si128(key_mm + 3); + __m128i K4 = _mm_loadu_si128(key_mm + 4); + __m128i K5 = _mm_loadu_si128(key_mm + 5); + __m128i K6 = _mm_loadu_si128(key_mm + 6); + __m128i K7 = _mm_loadu_si128(key_mm + 7); + __m128i K8 = _mm_loadu_si128(key_mm + 8); + __m128i K9 = _mm_loadu_si128(key_mm + 9); + __m128i K10 = _mm_loadu_si128(key_mm + 10); + __m128i K11 = _mm_loadu_si128(key_mm + 11); + __m128i K12 = _mm_loadu_si128(key_mm + 12); + + while(blocks >= 4) + { + __m128i B0 = _mm_loadu_si128(in_mm + 0); + __m128i B1 = _mm_loadu_si128(in_mm + 1); + __m128i B2 = _mm_loadu_si128(in_mm + 2); + __m128i B3 = _mm_loadu_si128(in_mm + 3); + + B0 = _mm_xor_si128(B0, K0); + B1 = _mm_xor_si128(B1, K0); + B2 = _mm_xor_si128(B2, K0); + B3 = _mm_xor_si128(B3, K0); + + AES_ENC_4_ROUNDS(K1); + AES_ENC_4_ROUNDS(K2); + AES_ENC_4_ROUNDS(K3); + AES_ENC_4_ROUNDS(K4); + AES_ENC_4_ROUNDS(K5); + AES_ENC_4_ROUNDS(K6); + AES_ENC_4_ROUNDS(K7); + AES_ENC_4_ROUNDS(K8); + AES_ENC_4_ROUNDS(K9); + AES_ENC_4_ROUNDS(K10); + AES_ENC_4_ROUNDS(K11); + AES_ENC_4_LAST_ROUNDS(K12); + + _mm_storeu_si128(out_mm + 0, B0); + _mm_storeu_si128(out_mm + 1, B1); + _mm_storeu_si128(out_mm + 2, B2); + _mm_storeu_si128(out_mm + 3, B3); + + blocks -= 4; + in_mm += 4; + out_mm += 4; + } + + for(size_t i = 0; i != blocks; ++i) + { + __m128i B = _mm_loadu_si128(in_mm + i); + + B = _mm_xor_si128(B, K0); + + B = _mm_aesenc_si128(B, K1); + B = _mm_aesenc_si128(B, K2); + B = _mm_aesenc_si128(B, K3); + B = _mm_aesenc_si128(B, K4); + B = _mm_aesenc_si128(B, K5); + B = _mm_aesenc_si128(B, K6); + B = _mm_aesenc_si128(B, K7); + B = _mm_aesenc_si128(B, K8); + B = _mm_aesenc_si128(B, K9); + B = _mm_aesenc_si128(B, K10); + B = _mm_aesenc_si128(B, K11); + B = _mm_aesenclast_si128(B, K12); + + _mm_storeu_si128(out_mm + i, B); + } + } + +/* +* AES-192 Decryption +*/ +void AES_192_NI::decrypt_n(const byte in[], byte out[], size_t blocks) const + { + const __m128i* in_mm = (const __m128i*)in; + __m128i* out_mm = (__m128i*)out; + + const __m128i* key_mm = (const __m128i*)&DK[0]; + + __m128i K0 = _mm_loadu_si128(key_mm); + __m128i K1 = _mm_loadu_si128(key_mm + 1); + __m128i K2 = _mm_loadu_si128(key_mm + 2); + __m128i K3 = _mm_loadu_si128(key_mm + 3); + __m128i K4 = _mm_loadu_si128(key_mm + 4); + __m128i K5 = _mm_loadu_si128(key_mm + 5); + __m128i K6 = _mm_loadu_si128(key_mm + 6); + __m128i K7 = _mm_loadu_si128(key_mm + 7); + __m128i K8 = _mm_loadu_si128(key_mm + 8); + __m128i K9 = _mm_loadu_si128(key_mm + 9); + __m128i K10 = _mm_loadu_si128(key_mm + 10); + __m128i K11 = _mm_loadu_si128(key_mm + 11); + __m128i K12 = _mm_loadu_si128(key_mm + 12); + + while(blocks >= 4) + { + __m128i B0 = _mm_loadu_si128(in_mm + 0); + __m128i B1 = _mm_loadu_si128(in_mm + 1); + __m128i B2 = _mm_loadu_si128(in_mm + 2); + __m128i B3 = _mm_loadu_si128(in_mm + 3); + + B0 = _mm_xor_si128(B0, K0); + B1 = _mm_xor_si128(B1, K0); + B2 = _mm_xor_si128(B2, K0); + B3 = _mm_xor_si128(B3, K0); + + AES_DEC_4_ROUNDS(K1); + AES_DEC_4_ROUNDS(K2); + AES_DEC_4_ROUNDS(K3); + AES_DEC_4_ROUNDS(K4); + AES_DEC_4_ROUNDS(K5); + AES_DEC_4_ROUNDS(K6); + AES_DEC_4_ROUNDS(K7); + AES_DEC_4_ROUNDS(K8); + AES_DEC_4_ROUNDS(K9); + AES_DEC_4_ROUNDS(K10); + AES_DEC_4_ROUNDS(K11); + AES_DEC_4_LAST_ROUNDS(K12); + + _mm_storeu_si128(out_mm + 0, B0); + _mm_storeu_si128(out_mm + 1, B1); + _mm_storeu_si128(out_mm + 2, B2); + _mm_storeu_si128(out_mm + 3, B3); + + blocks -= 4; + in_mm += 4; + out_mm += 4; + } + + for(size_t i = 0; i != blocks; ++i) + { + __m128i B = _mm_loadu_si128(in_mm + i); + + B = _mm_xor_si128(B, K0); + + B = _mm_aesdec_si128(B, K1); + B = _mm_aesdec_si128(B, K2); + B = _mm_aesdec_si128(B, K3); + B = _mm_aesdec_si128(B, K4); + B = _mm_aesdec_si128(B, K5); + B = _mm_aesdec_si128(B, K6); + B = _mm_aesdec_si128(B, K7); + B = _mm_aesdec_si128(B, K8); + B = _mm_aesdec_si128(B, K9); + B = _mm_aesdec_si128(B, K10); + B = _mm_aesdec_si128(B, K11); + B = _mm_aesdeclast_si128(B, K12); + + _mm_storeu_si128(out_mm + i, B); + } + } + +/* +* AES-192 Key Schedule +*/ +void AES_192_NI::key_schedule(const byte key[], size_t) + { + __m128i K0 = _mm_loadu_si128((const __m128i*)(key)); + __m128i K1 = _mm_loadu_si128((const __m128i*)(key + 8)); + K1 = _mm_srli_si128(K1, 8); + + load_le(&EK[0], key, 6); + +#define AES_192_key_exp(RCON, EK_OFF) \ + aes_192_key_expansion(&K0, &K1, \ + _mm_aeskeygenassist_si128(K1, RCON), \ + EK + EK_OFF, EK_OFF == 48) + + AES_192_key_exp(0x01, 6); + AES_192_key_exp(0x02, 12); + AES_192_key_exp(0x04, 18); + AES_192_key_exp(0x08, 24); + AES_192_key_exp(0x10, 30); + AES_192_key_exp(0x20, 36); + AES_192_key_exp(0x40, 42); + AES_192_key_exp(0x80, 48); + + // Now generate decryption keys + const __m128i* EK_mm = (const __m128i*)&EK[0]; + __m128i* DK_mm = (__m128i*)&DK[0]; + _mm_storeu_si128(DK_mm , EK_mm[12]); + _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(EK_mm[11])); + _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(EK_mm[10])); + _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(EK_mm[9])); + _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(EK_mm[8])); + _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(EK_mm[7])); + _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(EK_mm[6])); + _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(EK_mm[5])); + _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(EK_mm[4])); + _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(EK_mm[3])); + _mm_storeu_si128(DK_mm + 10, _mm_aesimc_si128(EK_mm[2])); + _mm_storeu_si128(DK_mm + 11, _mm_aesimc_si128(EK_mm[1])); + _mm_storeu_si128(DK_mm + 12, EK_mm[0]); + } + +/* +* Clear memory of sensitive data +*/ +void AES_192_NI::clear() + { + zeroise(EK); + zeroise(DK); + } + +/* +* AES-256 Encryption +*/ +void AES_256_NI::encrypt_n(const byte in[], byte out[], size_t blocks) const + { + const __m128i* in_mm = (const __m128i*)in; + __m128i* out_mm = (__m128i*)out; + + const __m128i* key_mm = (const __m128i*)&EK[0]; + + __m128i K0 = _mm_loadu_si128(key_mm); + __m128i K1 = _mm_loadu_si128(key_mm + 1); + __m128i K2 = _mm_loadu_si128(key_mm + 2); + __m128i K3 = _mm_loadu_si128(key_mm + 3); + __m128i K4 = _mm_loadu_si128(key_mm + 4); + __m128i K5 = _mm_loadu_si128(key_mm + 5); + __m128i K6 = _mm_loadu_si128(key_mm + 6); + __m128i K7 = _mm_loadu_si128(key_mm + 7); + __m128i K8 = _mm_loadu_si128(key_mm + 8); + __m128i K9 = _mm_loadu_si128(key_mm + 9); + __m128i K10 = _mm_loadu_si128(key_mm + 10); + __m128i K11 = _mm_loadu_si128(key_mm + 11); + __m128i K12 = _mm_loadu_si128(key_mm + 12); + __m128i K13 = _mm_loadu_si128(key_mm + 13); + __m128i K14 = _mm_loadu_si128(key_mm + 14); + + while(blocks >= 4) + { + __m128i B0 = _mm_loadu_si128(in_mm + 0); + __m128i B1 = _mm_loadu_si128(in_mm + 1); + __m128i B2 = _mm_loadu_si128(in_mm + 2); + __m128i B3 = _mm_loadu_si128(in_mm + 3); + + B0 = _mm_xor_si128(B0, K0); + B1 = _mm_xor_si128(B1, K0); + B2 = _mm_xor_si128(B2, K0); + B3 = _mm_xor_si128(B3, K0); + + AES_ENC_4_ROUNDS(K1); + AES_ENC_4_ROUNDS(K2); + AES_ENC_4_ROUNDS(K3); + AES_ENC_4_ROUNDS(K4); + AES_ENC_4_ROUNDS(K5); + AES_ENC_4_ROUNDS(K6); + AES_ENC_4_ROUNDS(K7); + AES_ENC_4_ROUNDS(K8); + AES_ENC_4_ROUNDS(K9); + AES_ENC_4_ROUNDS(K10); + AES_ENC_4_ROUNDS(K11); + AES_ENC_4_ROUNDS(K12); + AES_ENC_4_ROUNDS(K13); + AES_ENC_4_LAST_ROUNDS(K14); + + _mm_storeu_si128(out_mm + 0, B0); + _mm_storeu_si128(out_mm + 1, B1); + _mm_storeu_si128(out_mm + 2, B2); + _mm_storeu_si128(out_mm + 3, B3); + + blocks -= 4; + in_mm += 4; + out_mm += 4; + } + + for(size_t i = 0; i != blocks; ++i) + { + __m128i B = _mm_loadu_si128(in_mm + i); + + B = _mm_xor_si128(B, K0); + + B = _mm_aesenc_si128(B, K1); + B = _mm_aesenc_si128(B, K2); + B = _mm_aesenc_si128(B, K3); + B = _mm_aesenc_si128(B, K4); + B = _mm_aesenc_si128(B, K5); + B = _mm_aesenc_si128(B, K6); + B = _mm_aesenc_si128(B, K7); + B = _mm_aesenc_si128(B, K8); + B = _mm_aesenc_si128(B, K9); + B = _mm_aesenc_si128(B, K10); + B = _mm_aesenc_si128(B, K11); + B = _mm_aesenc_si128(B, K12); + B = _mm_aesenc_si128(B, K13); + B = _mm_aesenclast_si128(B, K14); + + _mm_storeu_si128(out_mm + i, B); + } + } + +/* +* AES-256 Decryption +*/ +void AES_256_NI::decrypt_n(const byte in[], byte out[], size_t blocks) const + { + const __m128i* in_mm = (const __m128i*)in; + __m128i* out_mm = (__m128i*)out; + + const __m128i* key_mm = (const __m128i*)&DK[0]; + + __m128i K0 = _mm_loadu_si128(key_mm); + __m128i K1 = _mm_loadu_si128(key_mm + 1); + __m128i K2 = _mm_loadu_si128(key_mm + 2); + __m128i K3 = _mm_loadu_si128(key_mm + 3); + __m128i K4 = _mm_loadu_si128(key_mm + 4); + __m128i K5 = _mm_loadu_si128(key_mm + 5); + __m128i K6 = _mm_loadu_si128(key_mm + 6); + __m128i K7 = _mm_loadu_si128(key_mm + 7); + __m128i K8 = _mm_loadu_si128(key_mm + 8); + __m128i K9 = _mm_loadu_si128(key_mm + 9); + __m128i K10 = _mm_loadu_si128(key_mm + 10); + __m128i K11 = _mm_loadu_si128(key_mm + 11); + __m128i K12 = _mm_loadu_si128(key_mm + 12); + __m128i K13 = _mm_loadu_si128(key_mm + 13); + __m128i K14 = _mm_loadu_si128(key_mm + 14); + + while(blocks >= 4) + { + __m128i B0 = _mm_loadu_si128(in_mm + 0); + __m128i B1 = _mm_loadu_si128(in_mm + 1); + __m128i B2 = _mm_loadu_si128(in_mm + 2); + __m128i B3 = _mm_loadu_si128(in_mm + 3); + + B0 = _mm_xor_si128(B0, K0); + B1 = _mm_xor_si128(B1, K0); + B2 = _mm_xor_si128(B2, K0); + B3 = _mm_xor_si128(B3, K0); + + AES_DEC_4_ROUNDS(K1); + AES_DEC_4_ROUNDS(K2); + AES_DEC_4_ROUNDS(K3); + AES_DEC_4_ROUNDS(K4); + AES_DEC_4_ROUNDS(K5); + AES_DEC_4_ROUNDS(K6); + AES_DEC_4_ROUNDS(K7); + AES_DEC_4_ROUNDS(K8); + AES_DEC_4_ROUNDS(K9); + AES_DEC_4_ROUNDS(K10); + AES_DEC_4_ROUNDS(K11); + AES_DEC_4_ROUNDS(K12); + AES_DEC_4_ROUNDS(K13); + AES_DEC_4_LAST_ROUNDS(K14); + + _mm_storeu_si128(out_mm + 0, B0); + _mm_storeu_si128(out_mm + 1, B1); + _mm_storeu_si128(out_mm + 2, B2); + _mm_storeu_si128(out_mm + 3, B3); + + blocks -= 4; + in_mm += 4; + out_mm += 4; + } + + for(size_t i = 0; i != blocks; ++i) + { + __m128i B = _mm_loadu_si128(in_mm + i); + + B = _mm_xor_si128(B, K0); + + B = _mm_aesdec_si128(B, K1); + B = _mm_aesdec_si128(B, K2); + B = _mm_aesdec_si128(B, K3); + B = _mm_aesdec_si128(B, K4); + B = _mm_aesdec_si128(B, K5); + B = _mm_aesdec_si128(B, K6); + B = _mm_aesdec_si128(B, K7); + B = _mm_aesdec_si128(B, K8); + B = _mm_aesdec_si128(B, K9); + B = _mm_aesdec_si128(B, K10); + B = _mm_aesdec_si128(B, K11); + B = _mm_aesdec_si128(B, K12); + B = _mm_aesdec_si128(B, K13); + B = _mm_aesdeclast_si128(B, K14); + + _mm_storeu_si128(out_mm + i, B); + } + } + +/* +* AES-256 Key Schedule +*/ +void AES_256_NI::key_schedule(const byte key[], size_t) + { + __m128i K0 = _mm_loadu_si128((const __m128i*)(key)); + __m128i K1 = _mm_loadu_si128((const __m128i*)(key + 16)); + + __m128i K2 = aes_128_key_expansion(K0, _mm_aeskeygenassist_si128(K1, 0x01)); + __m128i K3 = aes_256_key_expansion(K1, K2); + + __m128i K4 = aes_128_key_expansion(K2, _mm_aeskeygenassist_si128(K3, 0x02)); + __m128i K5 = aes_256_key_expansion(K3, K4); + + __m128i K6 = aes_128_key_expansion(K4, _mm_aeskeygenassist_si128(K5, 0x04)); + __m128i K7 = aes_256_key_expansion(K5, K6); + + __m128i K8 = aes_128_key_expansion(K6, _mm_aeskeygenassist_si128(K7, 0x08)); + __m128i K9 = aes_256_key_expansion(K7, K8); + + __m128i K10 = aes_128_key_expansion(K8, _mm_aeskeygenassist_si128(K9, 0x10)); + __m128i K11 = aes_256_key_expansion(K9, K10); + + __m128i K12 = aes_128_key_expansion(K10, _mm_aeskeygenassist_si128(K11, 0x20)); + __m128i K13 = aes_256_key_expansion(K11, K12); + + __m128i K14 = aes_128_key_expansion(K12, _mm_aeskeygenassist_si128(K13, 0x40)); + + __m128i* EK_mm = (__m128i*)&EK[0]; + _mm_storeu_si128(EK_mm , K0); + _mm_storeu_si128(EK_mm + 1, K1); + _mm_storeu_si128(EK_mm + 2, K2); + _mm_storeu_si128(EK_mm + 3, K3); + _mm_storeu_si128(EK_mm + 4, K4); + _mm_storeu_si128(EK_mm + 5, K5); + _mm_storeu_si128(EK_mm + 6, K6); + _mm_storeu_si128(EK_mm + 7, K7); + _mm_storeu_si128(EK_mm + 8, K8); + _mm_storeu_si128(EK_mm + 9, K9); + _mm_storeu_si128(EK_mm + 10, K10); + _mm_storeu_si128(EK_mm + 11, K11); + _mm_storeu_si128(EK_mm + 12, K12); + _mm_storeu_si128(EK_mm + 13, K13); + _mm_storeu_si128(EK_mm + 14, K14); + + // Now generate decryption keys + + __m128i* DK_mm = (__m128i*)&DK[0]; + _mm_storeu_si128(DK_mm , K14); + _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(K13)); + _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(K12)); + _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(K11)); + _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(K10)); + _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(K9)); + _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(K8)); + _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(K7)); + _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(K6)); + _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(K5)); + _mm_storeu_si128(DK_mm + 10, _mm_aesimc_si128(K4)); + _mm_storeu_si128(DK_mm + 11, _mm_aesimc_si128(K3)); + _mm_storeu_si128(DK_mm + 12, _mm_aesimc_si128(K2)); + _mm_storeu_si128(DK_mm + 13, _mm_aesimc_si128(K1)); + _mm_storeu_si128(DK_mm + 14, K0); + } + +/* +* Clear memory of sensitive data +*/ +void AES_256_NI::clear() + { + zeroise(EK); + zeroise(DK); + } + +} |