1#![allow(
2 clippy::cast_possible_truncation,
3)]
4
5use core::mem::MaybeUninit;
6
7use osom_lib_reprc::macros::reprc;
8
9use crate::{
10 aligned_array::AlignArray_4, prngs::SplitMix64, traits::{PRStream, PRNGenerator, Seedable}
11};
12
/// Pseudo-random byte stream built on the ChaCha block function, running
/// `ROUNDS` rounds per generated block.
///
/// Output is produced one 64-byte block at a time: each block is computed
/// from `key`, `nonce` and the current `counter`, serialized into `buffer`,
/// and then handed out piecewise.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[reprc]
#[must_use]
pub struct ChaChaStream<const ROUNDS: u32> {
    /// Key as eight 32-bit words, decoded little-endian from the 32 key bytes.
    key: [u32; 8],
    /// Nonce as three 32-bit words, decoded little-endian from the 12 nonce bytes.
    nonce: [u32; 3],
    /// Block counter that the next generated block will use; it is
    /// incremented every time a block is produced.
    counter: u32,

    /// Little-endian serialization of the most recently generated block.
    buffer: [u8; 64],
    /// Number of bytes of `buffer` that have already been handed out, i.e.
    /// the offset at which the next read from `buffer` starts (`0..=64`).
    buffer_len: u32,
}
32
33impl<const ROUNDS: u32> ChaChaStream<ROUNDS> {
34 pub fn from_seed(seed: u128) -> Self {
39 let mut key = [0u8; 32];
40 let seed_bytes = seed.to_le_bytes();
41 let mut index = 0usize;
42 while index < size_of::<u128>() {
43 key[index] = seed_bytes[index];
44 index += 1;
45 }
46
47 #[allow(clippy::cast_possible_truncation)]
48 let mut mixer = SplitMix64::with_seed(seed as u64);
49
50 let nonce = mixer.generate::<[u8; 12]>();
51 Self::from_arrays(key, nonce)
52 }
53
54 #[inline(always)]
56 pub fn from_arrays(key: [u8; 32], nonce: [u8; 12]) -> Self {
57 Self::from_slices(&key, &nonce)
58 }
59
60 #[inline(always)]
62 pub fn from_slices(key: &[u8], nonce: &[u8]) -> Self {
63 Self::from_slices_and_counter(key, nonce, 0)
64 }
65
66 pub fn from_slices_and_counter(key: &[u8], nonce: &[u8], counter: u32) -> Self {
74 const {
75 assert!(ROUNDS >= 8, "ChaCha rounds must be at least 8, otherwise it won't be secure. The recommended value is 20.");
76 assert!(ROUNDS <= 1000, "ChaCha rounds must be at most 1000. That number is definetely too much for any purpose. The recommended value is 20.");
77 }
78 assert!(key.len() == 32, "ChaCha key must be of size 32");
79 assert!(nonce.len() == 12, "ChaCha nonce must be of size 12");
80 assert!(counter < u32::MAX, "ChaCha counter must be smaller than u32::MAX. It is recommended to make it 0 or 1 or something low. Otherwise overflow will occure fast.");
81
82 let aligned_key = AlignArray_4::<32>::from_slice(key);
85 let aligned_nonce = AlignArray_4::<12>::from_slice(nonce);
86
87 let mut real_key = MaybeUninit::<[u32; 8]>::uninit();
88 let mut real_nonce = MaybeUninit::<[u32; 3]>::uninit();
89
90 let mut stream;
91 #[allow(clippy::needless_range_loop)]
92 unsafe {
93 let aligned_key = aligned_key.as_slice();
94 let aligned_nonce = aligned_nonce.as_slice();
95 let real_key_ptr = real_key.as_mut_ptr().cast::<u32>();
96 let real_nonce_ptr = real_nonce.as_mut_ptr().cast::<u32>();
97 for idx in 0..8 {
98 real_key_ptr.add(idx).write(from_le_u32(aligned_key, idx * 4));
99 }
100 for idx in 0..3 {
101 real_nonce_ptr.add(idx).write(from_le_u32(aligned_nonce, idx * 4));
102 }
103
104 stream = Self { key: real_key.assume_init(), nonce: real_nonce.assume_init(), counter, buffer: [0u8; 64], buffer_len: 0 };
105 }
106
107 stream.buffer = Self::serialize_block(&stream.next_u32_block());
108 stream
109 }
110
111 #[must_use]
118 pub fn next_u32_block(&mut self) -> [u32; 16] {
119 assert!(self.counter != u32::MAX, "ChaCha stream counter overflow.");
120 let counter = self.counter;
121 self.counter += 1;
122 calculate_chacha_block(self.key, self.nonce, counter, ROUNDS)
123 }
124
125 #[must_use]
127 pub fn serialize_block(block: &[u32; 16]) -> [u8; 64] {
128 let mut result = MaybeUninit::<[u8; 64]>::uninit();
129 let mut ptr = result.as_mut_ptr().cast::<u8>();
130 for item in block {
131 unsafe {
132 let bytes = item.to_le_bytes();
133 let bytes_ptr = (&raw const bytes).cast();
134 ptr.copy_from_nonoverlapping(bytes_ptr, size_of::<u32>());
135 ptr = ptr.add(size_of::<u32>());
136 }
137 }
138
139 unsafe { result.assume_init() }
140 }
141}
142
143impl<const ROUNDS: u32> PRStream for ChaChaStream<ROUNDS> {
144 unsafe fn fill_raw(&mut self, dst_ptr: *mut u8, dst_len: usize) {
145 let mut dst = dst_ptr;
146 let mut len = dst_len;
147
148 let diff = 64 - self.buffer_len;
150 let to_write = core::cmp::min(len, diff as usize);
151 unsafe {
152 dst.copy_from_nonoverlapping(self.buffer.as_ptr().add(self.buffer_len as usize), to_write);
153 dst = dst.add(to_write);
154 len -= to_write;
155 };
156
157 if len == 0 {
158 self.buffer_len += to_write as u32;
159 return;
160 }
161
162 self.buffer = Self::serialize_block(&self.next_u32_block());
164 self.buffer_len = 0;
165
166 while len > 64 {
167 unsafe {
168 dst.copy_from_nonoverlapping(self.buffer.as_ptr(), 64);
169 dst = dst.add(64);
170 };
171 self.buffer = Self::serialize_block(&self.next_u32_block());
172 self.buffer_len = 0;
173 len -= 64;
174 }
175
176 if len > 0 {
178 unsafe {
179 dst.copy_from_nonoverlapping(self.buffer.as_ptr(), len);
180 };
181 self.buffer_len = len as u32;
182 }
183 }
184
185
186 }
191
/// Decodes the little-endian `u32` stored in the four bytes of `arr` that
/// begin at index `start`.
///
/// The read is bounds-checked and has no alignment requirement.
///
/// # Panics
///
/// Panics if `arr` does not contain four bytes starting at `start`.
#[inline(always)]
fn from_le_u32(arr: &[u8], start: usize) -> u32 {
    // Slicing in two steps avoids computing `start + 4`, which could overflow.
    let bytes = &arr[start..][..4];
    u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
}
201
202impl<const ROUNDS: u32> Seedable<u128> for ChaChaStream<ROUNDS> {
203 fn with_seed(seed: u128) -> Self {
204 Self::from_seed(seed)
205 }
206}
207
208impl<const ROUNDS: u32> Seedable<u64> for ChaChaStream<ROUNDS> {
209 fn with_seed(seed: u64) -> Self {
210 Self::from_seed(u128::from(seed))
211 }
212}
213
/// [`ChaChaStream`] with 20 rounds, the round count that the constructors'
/// assertion messages recommend.
pub type DefaultChaChaStream = ChaChaStream<20>;
216
// ChaCha quarter round, applied in place to the words of `$arr` at the
// literal indices `$a`, `$b`, `$c` and `$d`.
//
// It consists of four add-xor-rotate steps with rotation amounts 16, 12, 8
// and 7, alternating between the (a, b, d) and (c, d, b) word triples.
macro_rules! qr {
    // Internal rule, one step: `x += y` (wrapping), then `z = (z ^ x) <<< rot`.
    (@step $arr:expr, $x:literal, $y:literal, $z:literal, $rot:literal) => {
        $arr[$x] = $arr[$x].wrapping_add($arr[$y]);
        $arr[$z] = ($arr[$z] ^ $arr[$x]).rotate_left($rot);
    };
    ($arr:expr, $a:literal, $b:literal, $c:literal, $d:literal) => {{
        qr!(@step $arr, $a, $b, $d, 16);
        qr!(@step $arr, $c, $d, $b, 12);
        qr!(@step $arr, $a, $b, $d, 8);
        qr!(@step $arr, $c, $d, $b, 7);
    }};
}
236
237#[inline(always)]
238const fn mutate_state(state: &mut [u32; 16], rounds: u32) {
239 let mut index = 0;
240 while index < rounds {
241 qr!(state, 0, 4, 8, 12);
243 qr!(state, 1, 5, 9, 13);
244 qr!(state, 2, 6, 10, 14);
245 qr!(state, 3, 7, 11, 15);
246
247 qr!(state, 0, 5, 10, 15);
249 qr!(state, 1, 6, 11, 12);
250 qr!(state, 2, 7, 8, 13);
251 qr!(state, 3, 4, 9, 14);
252
253 index += 2;
254 }
255}
256
/// Lays out the initial 16-word ChaCha state: four constant words, the
/// eight key words, the block counter, and the three nonce words, in that
/// order.
#[inline(always)]
fn initialize_block(key: [u32; 8], nonce: [u32; 3], counter: u32) -> [u32; 16] {
    // The four fixed words that occupy the first row of the state.
    const CONSTANTS: [u32; 4] = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];

    let [c0, c1, c2, c3] = CONSTANTS;
    let [k0, k1, k2, k3, k4, k5, k6, k7] = key;
    let [n0, n1, n2] = nonce;

    [
        c0, c1, c2, c3,
        k0, k1, k2, k3,
        k4, k5, k6, k7,
        counter, n0, n1, n2,
    ]
}
269
270fn calculate_chacha_block(key: [u32; 8], nonce: [u32; 3], counter: u32, rounds: u32) -> [u32; 16] {
271 let block = initialize_block(key, nonce, counter);
272
273 #[allow(clippy::clone_on_copy)]
274 let mut working_block = block.clone();
275
276 mutate_state(&mut working_block, rounds);
277
278 for index in 0..block.len() {
279 working_block[index] = working_block[index].wrapping_add(block[index]);
280 }
281
282 working_block
283}