osom_lib_prng/block_streams/
chacha.rs

1use core::mem::MaybeUninit;
2
3use osom_lib_reprc::macros::reprc;
4
5use crate::{
6    aligned_array::AlignArray_4, prngs::SplitMix64, traits::{BlockStream, PRNGenerator, Seedable}
7};
8
9/// The Osom imlementation of `ChaCha` algorithm, as defined
10/// in [RFC 8439](https://www.rfc-editor.org/rfc/rfc8439).
11/// 
12/// [`ChaChaStream`] is cryptographically secure block stream.
13/// 
14/// Note that the recommended `ROUNDS` value is 20.
15#[derive(Debug, PartialEq, Eq, Clone, Copy)]
16#[reprc]
17#[must_use]
18pub struct ChaChaStream<const ROUNDS: u32> {
19    key: [u32; 8],
20    nonce: [u32; 3],
21    counter: u32,
22}
23
24impl<const ROUNDS: u32> ChaChaStream<ROUNDS> {
25    /// Creates a new [`ChaChaStream`] from a given seed. This method
26    /// puts the seed into the key, and passes 0 as the nonce.
27    /// 
28    /// For better control use [`Self::from_arrays`] or [`Self::from_slices`].
29    pub fn from_seed(seed: u128) -> Self {
30        let mut key = [0u8; 32];
31        let seed_bytes = seed.to_le_bytes();
32        let mut index = 0usize;
33        while index < size_of::<u128>() {
34            key[index] = seed_bytes[index];
35            index += 1;
36        }
37
38        #[allow(clippy::cast_possible_truncation)]
39        let mut mixer = SplitMix64::with_seed(seed as u64);
40
41        let nonce = mixer.generate::<[u8; 12]>();
42        Self::from_arrays(key, nonce)
43    }
44
45    /// Creates a new [`ChaChaStream`] from a given key and nonce.
46    #[inline(always)]
47    pub fn from_arrays(key: [u8; 32], nonce: [u8; 12]) -> Self {
48        Self::from_slices(&key, &nonce)
49    }
50
51    /// Creates a new [`ChaChaStream`] from a given key and nonce slices.
52    #[inline(always)]
53    pub fn from_slices(key: &[u8], nonce: &[u8]) -> Self {
54        Self::from_slices_and_counter(key, nonce, 0)
55    }
56
57    /// Creates a new [`ChaChaStream`] from a given key, nonce slices
58    /// and internal counter.
59    /// 
60    /// # Panics
61    /// 
62    /// Panics if the key or nonce is not of size 32 or 8 respectively.
63    /// Also when counter is `u32::MAX`.
64    pub fn from_slices_and_counter(key: &[u8], nonce: &[u8], counter: u32) -> Self {
65        const {
66            assert!(ROUNDS >= 8, "ChaCha rounds must be at least 8, otherwise it won't be secure. The recommended value is 20.");
67            assert!(ROUNDS <= 1000, "ChaCha rounds must be at most 1000. That number is definetely too much for any purpose. The recommended value is 20.");
68        }
69        assert!(key.len() == 32, "ChaCha key must be of size 32");
70        assert!(nonce.len() == 12, "ChaCha nonce must be of size 12");
71        assert!(counter < u32::MAX, "ChaCha counter must be smaller than u32::MAX. It is recommended to make it 0 or 1 or something low. Otherwise overflow will occure fast.");
72
73        // We are going to reinterpre u8 slices as u32 slices. And thus we need to
74        // ensure they are properly aligned to 4.
75        let aligned_key = AlignArray_4::<32>::from_slice(key);
76        let aligned_nonce = AlignArray_4::<12>::from_slice(nonce);
77
78        let mut real_key = MaybeUninit::<[u32; 8]>::uninit();
79        let mut real_nonce = MaybeUninit::<[u32; 3]>::uninit();
80
81        #[allow(clippy::needless_range_loop)]
82        unsafe {
83            let aligned_key = aligned_key.as_slice();
84            let aligned_nonce = aligned_nonce.as_slice();
85            let real_key_ptr = real_key.as_mut_ptr().cast::<u32>();
86            let real_nonce_ptr = real_nonce.as_mut_ptr().cast::<u32>();
87            for idx in 0..8 {
88                real_key_ptr.add(idx).write(from_le_u32(aligned_key, idx * 4));
89            }
90            for idx in 0..3 {
91                real_nonce_ptr.add(idx).write(from_le_u32(aligned_nonce, idx * 4));
92            }
93
94            Self { key: real_key.assume_init(), nonce: real_nonce.assume_init(), counter }
95        }
96    }
97
98    /// Generates the next chacha block.
99    /// 
100    /// # Panics
101    /// 
102    /// When the internal counter overflows, i.e. when this function
103    /// is called more than `u32::MAX` times.
104    #[must_use]
105    pub fn next_u32_block(&mut self) -> [u32; 16] {
106        assert!(self.counter != u32::MAX, "ChaCha stream counter overflow.");
107        let counter = self.counter;
108        self.counter += 1;
109        calculate_chacha_block(self.key, self.nonce, counter, ROUNDS)
110    }
111
112    /// Serializes the given chacha block, by applying little endian encoding to each element.
113    #[must_use]
114    pub fn serialize_block(block: &[u32; 16]) -> [u8; 64] {
115        let mut result = MaybeUninit::<[u8; 64]>::uninit();
116        let mut ptr = result.as_mut_ptr().cast::<u8>();
117        for item in block {
118            unsafe {
119                let bytes = item.to_le_bytes();
120                let bytes_ptr = (&raw const bytes).cast();
121                ptr.copy_from_nonoverlapping(bytes_ptr, size_of::<u32>());
122                ptr = ptr.add(size_of::<u32>());
123            }
124        }
125
126        unsafe { result.assume_init() }
127    }
128}
129
130impl<const ROUNDS: u32> BlockStream for ChaChaStream<ROUNDS> {
131    type BlockType = [u8; 64];
132    const BLOCK_SIZE: u32 = 64;
133
134    #[inline(always)]
135    fn next_block(&mut self) -> Self::BlockType {
136        Self::serialize_block(&self.next_u32_block())
137    }
138}
139
/// Decodes the four bytes `arr[start..start + 4]` as a little-endian `u32`.
///
/// This is bounds-checked and has no alignment requirement. The previous
/// pointer-cast version was a safe `fn` that could read out of bounds or
/// misaligned in release builds.
///
/// # Panics
///
/// Panics if `start + 4` exceeds `arr.len()`.
#[inline(always)]
fn from_le_u32(arr: &[u8], start: usize) -> u32 {
    // A `u32` is four bytes wide.
    let word = &arr[start..start + 4];
    u32::from_le_bytes([word[0], word[1], word[2], word[3]])
}
149
150impl<const ROUNDS: u32> Seedable<u128> for ChaChaStream<ROUNDS> {
151    fn with_seed(seed: u128) -> Self {
152        Self::from_seed(seed)
153    }
154}
155
156impl<const ROUNDS: u32> Seedable<u64> for ChaChaStream<ROUNDS> {
157    fn with_seed(seed: u64) -> Self {
158        Self::from_seed(u128::from(seed))
159    }
160}
161
162/// The recommended [`ChaChaStream`] with 20 rounds.
163pub type DefaultChaChaStream = ChaChaStream<20>;
164
// ChaCha quarter round (RFC 8439 section 2.1), applied in place to the words
// at indices `$a`, `$b`, `$c`, `$d` of `$arr`. Each xor is fused with the
// rotation that follows it.
macro_rules! qr {
    ($arr: expr, $a:literal, $b:literal, $c:literal, $d:literal) => {{
        // a += b; d ^= a; d <<<= 16;
        $arr[$a] = $arr[$a].wrapping_add($arr[$b]);
        $arr[$d] = ($arr[$d] ^ $arr[$a]).rotate_left(16);
        // c += d; b ^= c; b <<<= 12;
        $arr[$c] = $arr[$c].wrapping_add($arr[$d]);
        $arr[$b] = ($arr[$b] ^ $arr[$c]).rotate_left(12);
        // a += b; d ^= a; d <<<= 8;
        $arr[$a] = $arr[$a].wrapping_add($arr[$b]);
        $arr[$d] = ($arr[$d] ^ $arr[$a]).rotate_left(8);
        // c += d; b ^= c; b <<<= 7;
        $arr[$c] = $arr[$c].wrapping_add($arr[$d]);
        $arr[$b] = ($arr[$b] ^ $arr[$c]).rotate_left(7);
    }};
}
184
185#[inline(always)]
186const fn mutate_state(state: &mut [u32; 16], rounds: u32) {
187    let mut index = 0;
188    while index < rounds {
189        // Odd round
190        qr!(state, 0, 4, 8, 12);
191        qr!(state, 1, 5, 9, 13);
192        qr!(state, 2, 6, 10, 14);
193        qr!(state, 3, 7, 11, 15);
194
195        // Even round
196        qr!(state, 0, 5, 10, 15);
197        qr!(state, 1, 6, 11, 12);
198        qr!(state, 2, 7, 8, 13);
199        qr!(state, 3, 4, 9, 14);
200
201        index += 2;
202    }
203}
204
/// Lays out the initial 16-word ChaCha state: four constant words, then the
/// eight key words, the block counter and the three nonce words.
#[inline(always)]
fn initialize_block(key: [u32; 8], nonce: [u32; 3], counter: u32) -> [u32; 16] {
    // The ASCII text "expand 32-byte k", read as four little-endian words.
    const SIGMA: [u32; 4] = [0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574];

    [
        SIGMA[0], SIGMA[1], SIGMA[2], SIGMA[3],
        key[0], key[1], key[2], key[3], key[4], key[5], key[6], key[7],
        counter,
        nonce[0], nonce[1], nonce[2],
    ]
}
217
218fn calculate_chacha_block(key: [u32; 8], nonce: [u32; 3], counter: u32, rounds: u32) -> [u32; 16] {
219    let block = initialize_block(key, nonce, counter);
220
221    #[allow(clippy::clone_on_copy)]
222    let mut working_block = block.clone();
223
224    mutate_state(&mut working_block, rounds);
225
226    for index in 0..block.len() {
227        working_block[index] = working_block[index].wrapping_add(block[index]);
228    }
229
230    working_block
231}