@@ -20,7 +20,14 @@ struct AlignedU32x8([u32; SIMD_LANES]);
20
20
// [ block1 (32-bits) || block2 (32-bits) || block3 (32-bits) || block4 (32-bits) || block5 (32-bits) ... ]
21
21
// then we perform the normal ChaCha operations on these vectors, meaning that we compute
22
22
// 8 ChaCha blocks in parallel for every operation on these vectors.
23
- pub fn chacha_avx2 < const ROUNDS : usize > ( state : [ u32 ; 16 ] , mut counter : u64 , input : & mut [ u8 ] ) -> u64 {
23
+ pub fn chacha_avx2 < const ROUNDS : usize > (
24
+ state : [ u32 ; 16 ] ,
25
+ mut counter : u64 ,
26
+ input : & mut [ u8 ] ,
27
+ last_keystream_block : & mut [ u8 ; 64 ] ,
28
+ ) -> u64 {
29
+ let mut keystream = [ 0u8 ; SIMD_LANES * 64 ] ;
30
+
24
31
let mut initial_state: [ __m256i ; 16 ] = unsafe {
25
32
[
26
33
// constant
@@ -64,7 +71,6 @@ pub fn chacha_avx2<const ROUNDS: usize>(state: [u32; 16], mut counter: u64, inpu
64
71
}
65
72
66
73
// compute 8 64-byte ChaCha blocks in parallel
67
- let mut keystream = [ 0u8 ; SIMD_LANES * 64 ] ;
68
74
chacha20_avx2_8blocks :: < ROUNDS > ( initial_state, & mut keystream) ;
69
75
70
76
// XOR plaintext with keystream
@@ -76,6 +82,10 @@ pub fn chacha_avx2<const ROUNDS: usize>(state: [u32; 16], mut counter: u64, inpu
76
82
counter = counter. wrapping_add ( ( input_blocks. len ( ) as u64 ) . div_ceil ( 64 ) ) ;
77
83
}
78
84
85
+ let last_keystream_block_index = ( ( input. len ( ) - 1 ) / 64 ) % SIMD_LANES ;
86
+ let last_keystream_block_offset = last_keystream_block_index * 64 ;
87
+ last_keystream_block. copy_from_slice ( & keystream[ last_keystream_block_offset..last_keystream_block_offset + 64 ] ) ;
88
+
79
89
return counter;
80
90
}
81
91
@@ -84,6 +94,8 @@ pub fn chacha_avx2<const ROUNDS: usize>(state: [u32; 16], mut counter: u64, inpu
84
94
/// [ block1 (64 bytes) || block2 (64 bytes) || block3 (64 bytes) || block4 (64 bytes) ... ]
85
95
#[ inline( always) ]
86
96
fn chacha20_avx2_8blocks < const ROUNDS : usize > ( initial_state : [ __m256i ; 16 ] , keystream : & mut [ u8 ; SIMD_LANES * 64 ] ) {
97
+ let keystream_ptr = keystream. as_mut_ptr ( ) ;
98
+
87
99
unsafe {
88
100
let mut working_state = initial_state;
89
101
@@ -130,7 +142,6 @@ fn chacha20_avx2_8blocks<const ROUNDS: usize>(initial_state: [__m256i; 16], keys
130
142
// the second iteration writes block1[4..8], block2[4..8], block3[4..8], block4[4..8], block5[4..8] ...
131
143
// the third iteration writes block1[8..12], block2[8..12], block3[8..12], block4[8..12], block5[8..12] ...
132
144
// and so on, for the 16 32-bit words of the ChaCha state
133
- let keystream_ptr = keystream. as_mut_ptr ( ) ;
134
145
for word_index in 0 ..STATE_WORDS {
135
146
// first we add the working state to the initial state to get the keystream
136
147
working_state[ word_index] = _mm256_add_epi32 ( working_state[ word_index] , initial_state[ word_index] ) ;
0 commit comments