Skip to main content

osom_lib_prng/streams/
chacha.rs

1#![allow(
2    clippy::cast_possible_truncation,
3)]
4
5use core::{convert::Infallible, mem::MaybeUninit};
6
7use osom_lib_reprc::macros::reprc;
8use osom_lib_try_clone::TryClone;
9
10use crate::{
11    aligned_array::AlignArray_4, prngs::SplitMix64, traits::{PRStream, PRNGenerator, Seedable}
12};
13
/// The Osom implementation of the `ChaCha` algorithm, as defined
/// in [RFC 8439](https://www.rfc-editor.org/rfc/rfc8439).
///
/// [`ChaChaStream`] is a cryptographically secure stream.
///
/// Note that the recommended `ROUNDS` value is 20.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[reprc]
#[must_use]
pub struct ChaChaStream<const ROUNDS: u32> {
    // The data necessary to generate blocks.
    // `key` and `nonce` hold the words decoded from the little-endian input bytes.
    key: [u32; 8],
    nonce: [u32; 3],
    // Block counter that will be used for the NEXT generated block.
    counter: u32,

    // Current buffer to serve data from: the most recently generated block,
    // serialized as little-endian bytes.
    buffer: [u8; 64],
    // Number of bytes of `buffer` that have already been handed out (0..=64).
    buffer_len: u32,
}
33
34impl<const ROUNDS: u32> TryClone for ChaChaStream<ROUNDS> {
35    type Error = Infallible;
36
37    fn try_clone(&self) -> Result<Self, Self::Error> {
38        Ok(*self)
39    }
40}
41
42impl<const ROUNDS: u32> ChaChaStream<ROUNDS> {
43    /// Creates a new [`ChaChaStream`] from a given seed. This method
44    /// puts the seed into the key, and passes 0 as the nonce.
45    /// 
46    /// For better control use [`Self::from_arrays`] or [`Self::from_slices`].
47    pub fn from_seed(seed: u128) -> Self {
48        let mut key = [0u8; 32];
49        let seed_bytes = seed.to_le_bytes();
50        let mut index = 0usize;
51        while index < size_of::<u128>() {
52            key[index] = seed_bytes[index];
53            index += 1;
54        }
55
56        #[allow(clippy::cast_possible_truncation)]
57        let mut mixer = SplitMix64::with_seed(seed as u64);
58
59        let nonce = mixer.generate::<[u8; 12]>();
60        Self::from_arrays(key, nonce)
61    }
62
63    /// Creates a new [`ChaChaStream`] from a given key and nonce.
64    #[inline(always)]
65    pub fn from_arrays(key: [u8; 32], nonce: [u8; 12]) -> Self {
66        Self::from_slices(&key, &nonce)
67    }
68
69    /// Creates a new [`ChaChaStream`] from a given key and nonce slices.
70    #[inline(always)]
71    pub fn from_slices(key: &[u8], nonce: &[u8]) -> Self {
72        Self::from_slices_and_counter(key, nonce, 0)
73    }
74
75    /// Creates a new [`ChaChaStream`] from a given key, nonce slices
76    /// and internal counter.
77    /// 
78    /// # Panics
79    /// 
80    /// Panics if the key or nonce is not of size 32 or 12 respectively.
81    /// Also when counter is `u32::MAX`.
82    pub fn from_slices_and_counter(key: &[u8], nonce: &[u8], counter: u32) -> Self {
83        const {
84            assert!(ROUNDS >= 8, "ChaCha rounds must be at least 8, otherwise it won't be secure. The recommended value is 20.");
85            assert!(ROUNDS <= 1000, "ChaCha rounds must be at most 1000. That number is definitely too much for any purpose. The recommended value is 20.");
86        }
87        assert!(key.len() == 32, "ChaCha key must be of size 32");
88        assert!(nonce.len() == 12, "ChaCha nonce must be of size 12");
89        assert!(counter < u32::MAX, "ChaCha counter must be smaller than u32::MAX. It is recommended to make it 0 or 1 or something low. Otherwise overflow will occur fast.");
90
91        // We are going to reinterpre u8 slices as u32 slices. And thus we need to
92        // ensure they are properly aligned to 4.
93        let aligned_key = AlignArray_4::<32>::from_slice(key);
94        let aligned_nonce = AlignArray_4::<12>::from_slice(nonce);
95
96        let mut real_key = MaybeUninit::<[u32; 8]>::uninit();
97        let mut real_nonce = MaybeUninit::<[u32; 3]>::uninit();
98
99        let mut stream;
100        #[allow(clippy::needless_range_loop)]
101        unsafe {
102            let aligned_key = aligned_key.as_slice();
103            let aligned_nonce = aligned_nonce.as_slice();
104            let real_key_ptr = real_key.as_mut_ptr().cast::<u32>();
105            let real_nonce_ptr = real_nonce.as_mut_ptr().cast::<u32>();
106            for idx in 0..8 {
107                real_key_ptr.add(idx).write(from_le_u32(aligned_key, idx * 4));
108            }
109            for idx in 0..3 {
110                real_nonce_ptr.add(idx).write(from_le_u32(aligned_nonce, idx * 4));
111            }
112
113            stream = Self { key: real_key.assume_init(), nonce: real_nonce.assume_init(), counter, buffer: [0u8; 64], buffer_len: 0 };
114        }
115
116        stream.buffer = Self::serialize_block(&stream.next_u32_block());
117        stream
118    }
119
120    /// Generates the next chacha block.
121    /// 
122    /// # Panics
123    /// 
124    /// When the internal counter overflows, i.e. when this function
125    /// is called more than `u32::MAX` times.
126    #[must_use]
127    pub fn next_u32_block(&mut self) -> [u32; 16] {
128        assert!(self.counter != u32::MAX, "ChaCha stream counter overflow.");
129        let counter = self.counter;
130        self.counter += 1;
131        calculate_chacha_block(self.key, self.nonce, counter, ROUNDS)
132    }
133
134    /// Serializes the given chacha block, by applying little endian encoding to each element.
135    #[must_use]
136    pub fn serialize_block(block: &[u32; 16]) -> [u8; 64] {
137        let mut result = MaybeUninit::<[u8; 64]>::uninit();
138        let mut ptr = result.as_mut_ptr().cast::<u8>();
139        for item in block {
140            unsafe {
141                let bytes = item.to_le_bytes();
142                let bytes_ptr = (&raw const bytes).cast();
143                ptr.copy_from_nonoverlapping(bytes_ptr, size_of::<u32>());
144                ptr = ptr.add(size_of::<u32>());
145            }
146        }
147
148        unsafe { result.assume_init() }
149    }
150}
151
152impl<const ROUNDS: u32> PRStream for ChaChaStream<ROUNDS> {
153    unsafe fn fill_raw(&mut self, dst_ptr: *mut u8, dst_len: usize) {
154        let mut dst = dst_ptr;
155        let mut len = dst_len;
156
157        // Fill from whatever it is missing in the current block
158        let diff = 64 - self.buffer_len;
159        let to_write = core::cmp::min(len, diff as usize);
160        unsafe {
161            dst.copy_from_nonoverlapping(self.buffer.as_ptr().add(self.buffer_len as usize), to_write);
162            dst = dst.add(to_write);
163            len -= to_write;
164        };
165
166        if len == 0 {
167            self.buffer_len += to_write as u32;
168            return;
169        }
170
171        // Fill the buffer block by block
172        self.buffer = Self::serialize_block(&self.next_u32_block());
173        self.buffer_len = 0;
174
175        while len > 64 {
176            unsafe {
177                dst.copy_from_nonoverlapping(self.buffer.as_ptr(), 64);
178                dst = dst.add(64);
179            };
180            self.buffer = Self::serialize_block(&self.next_u32_block());
181            self.buffer_len = 0;
182            len -= 64;
183        }
184
185        // If anything remains, fill it from the current block.
186        if len > 0 {
187            unsafe {
188                dst.copy_from_nonoverlapping(self.buffer.as_ptr(), len);
189            };
190            self.buffer_len = len as u32;
191        }
192    }
193    
194
195    // #[inline(always)]
196    // fn next_block(&mut self) -> Self::BlockType {
197    //     Self::serialize_block(&self.next_u32_block())
198    // }
199}
200
/// Decodes the four bytes `arr[start..start + 4]` as a little-endian `u32`.
///
/// The read is bounds-checked and has no alignment requirement.
///
/// # Panics
///
/// Panics if `arr` does not contain four bytes starting at `start`.
#[inline(always)]
fn from_le_u32(arr: &[u8], start: usize) -> u32 {
    let mut word = [0u8; 4];
    // Slicing performs the bounds check; the copy cannot fail afterwards
    // because both sides are exactly four bytes long.
    word.copy_from_slice(&arr[start..start + 4]);
    u32::from_le_bytes(word)
}
210
211impl<const ROUNDS: u32> Seedable<u128> for ChaChaStream<ROUNDS> {
212    fn with_seed(seed: u128) -> Self {
213        Self::from_seed(seed)
214    }
215}
216
217impl<const ROUNDS: u32> Seedable<u64> for ChaChaStream<ROUNDS> {
218    fn with_seed(seed: u64) -> Self {
219        Self::from_seed(u128::from(seed))
220    }
221}
222
/// The recommended [`ChaChaStream`] with 20 rounds,
/// i.e. an alias for `ChaChaStream<20>`.
pub type DefaultChaChaStream = ChaChaStream<20>;
225
// ChaCha quarter round applied in place to the four words of `$arr` at the
// literal indices `$a`, `$b`, `$c`, `$d`. Each of the four steps is
// "add, xor, rotate left" with rotation amounts 16, 12, 8 and 7.
// All additions wrap modulo 2^32.
macro_rules! qr {
    ($arr: expr, $a:literal, $b:literal, $c:literal, $d:literal) => {{
        // a += b; d = (d ^ a) <<< 16
        $arr[$a] = $arr[$a].wrapping_add($arr[$b]);
        $arr[$d] = ($arr[$d] ^ $arr[$a]).rotate_left(16);

        // c += d; b = (b ^ c) <<< 12
        $arr[$c] = $arr[$c].wrapping_add($arr[$d]);
        $arr[$b] = ($arr[$b] ^ $arr[$c]).rotate_left(12);

        // a += b; d = (d ^ a) <<< 8
        $arr[$a] = $arr[$a].wrapping_add($arr[$b]);
        $arr[$d] = ($arr[$d] ^ $arr[$a]).rotate_left(8);

        // c += d; b = (b ^ c) <<< 7
        $arr[$c] = $arr[$c].wrapping_add($arr[$d]);
        $arr[$b] = ($arr[$b] ^ $arr[$c]).rotate_left(7);
    }};
}
245
246#[inline(always)]
247const fn mutate_state(state: &mut [u32; 16], rounds: u32) {
248    let mut index = 0;
249    while index < rounds {
250        // Odd round
251        qr!(state, 0, 4, 8, 12);
252        qr!(state, 1, 5, 9, 13);
253        qr!(state, 2, 6, 10, 14);
254        qr!(state, 3, 7, 11, 15);
255
256        // Even round
257        qr!(state, 0, 5, 10, 15);
258        qr!(state, 1, 6, 11, 12);
259        qr!(state, 2, 7, 8, 13);
260        qr!(state, 3, 4, 9, 14);
261
262        index += 2;
263    }
264}
265
/// Builds the initial 16-word ChaCha state.
///
/// Layout: words 0..4 are the fixed constants (the ASCII string
/// `"expand 32-byte k"` read as little-endian words), words 4..12 are the
/// key, word 12 is the block counter and words 13..16 are the nonce.
#[inline(always)]
fn initialize_block(key: [u32; 8], nonce: [u32; 3], counter: u32) -> [u32; 16] {
    let [k0, k1, k2, k3, k4, k5, k6, k7] = key;
    let [n0, n1, n2] = nonce;
    [
        0x6170_7865,
        0x3320_646e,
        0x7962_2d32,
        0x6b20_6574,
        k0,
        k1,
        k2,
        k3,
        k4,
        k5,
        k6,
        k7,
        counter,
        n0,
        n1,
        n2,
    ]
}
278
279fn calculate_chacha_block(key: [u32; 8], nonce: [u32; 3], counter: u32, rounds: u32) -> [u32; 16] {
280    let block = initialize_block(key, nonce, counter);
281
282    #[allow(clippy::clone_on_copy)]
283    let mut working_block = block.clone();
284
285    mutate_state(&mut working_block, rounds);
286
287    for index in 0..block.len() {
288        working_block[index] = working_block[index].wrapping_add(block[index]);
289    }
290
291    working_block
292}