osom_lib_prng/
block_prng.rs

1use core::ops::RangeBounds;
2
3use osom_lib_reprc::{macros::reprc, traits::ReprC};
4
5use crate::{
6    prngs::helpers::{
7        generate_f32_in_range, generate_f64_in_range, generate_i32_in_range, generate_i64_in_range,
8        generate_u32_in_range, generate_u64_in_range,
9    },
10    traits::{BlockStream, PRNConcreteBoundedGenerator, PRNConcreteGenerator, PRNGenerator, Seedable},
11};
12
/// The standard block-based PRNG. It accepts a stream and implements
/// the [`PRNGenerator`] trait on top of it.
///
/// Bytes are handed out from an internally buffered block. When a request
/// needs more bytes than remain in that block, the next block is pulled from
/// `stream`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[reprc]
pub struct BlockPRNG<T: BlockStream> {
    /// The underlying source of blocks.
    stream: T,
    /// The most recently fetched block. `BlockPRNG::new_block` asserts that it
    /// holds exactly `T::BLOCK_SIZE` bytes.
    block: T::BlockType,
    /// Index of the next unread byte in `block`. It ranges over
    /// `0..=T::BLOCK_SIZE` and equals `T::BLOCK_SIZE` once the block is fully
    /// consumed.
    position: u32,
}
22
23impl<T: BlockStream> BlockPRNG<T> {
24    /// Creates a new [`BlockPRNG`] from a given stream.
25    #[inline(always)]
26    pub fn new(stream: T) -> Self {
27        let mut stream = stream;
28        let initial_block = Self::new_block(&mut stream);
29        Self {
30            stream,
31            block: initial_block,
32            position: 0,
33        }
34    }
35
36    #[inline(always)]
37    fn new_block(stream: &mut T) -> T::BlockType {
38        let block = stream.next_block();
39        assert!(block.as_ref().len() == T::BLOCK_SIZE as usize, "Block size mismatch.");
40        block
41    }
42}
43
44#[allow(clippy::cast_possible_truncation)]
45impl<T: BlockStream> PRNGenerator for BlockPRNG<T> {
46    unsafe fn fill_raw(&mut self, dst_ptr: *mut u8, dst_len: usize) {
47        let mut dst = dst_ptr;
48        let mut len = dst_len;
49
50        // Fill from whatever it is missing in the current block
51        let diff = T::BLOCK_SIZE - self.position;
52        let to_write = core::cmp::min(len, diff as usize);
53        unsafe {
54            dst.copy_from_nonoverlapping(self.block.as_ref().as_ptr().add(self.position as usize), to_write);
55            dst = dst.add(to_write);
56            len -= to_write;
57        };
58
59        if len == 0 {
60            self.position += to_write as u32;
61            return;
62        }
63
64        // Fill the buffer block by block
65        self.block = Self::new_block(&mut self.stream);
66        self.position = 0;
67
68        while len > T::BLOCK_SIZE as usize {
69            unsafe {
70                dst.copy_from_nonoverlapping(self.block.as_ref().as_ptr(), T::BLOCK_SIZE as usize);
71                dst = dst.add(T::BLOCK_SIZE as usize);
72            };
73            self.block = Self::new_block(&mut self.stream);
74            len -= T::BLOCK_SIZE as usize;
75        }
76
77        // If anything remains, fill it from the current block.
78        if len > 0 {
79            unsafe {
80                dst.copy_from_nonoverlapping(self.block.as_ref().as_ptr(), len);
81            };
82            self.position = len as u32;
83        }
84    }
85}
86
87impl<const N: usize, T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for [u8; N] {
88    #[inline(always)]
89    fn generate(generator: &mut BlockPRNG<T>) -> Self {
90        const {
91            assert!(size_of::<Self>() == N);
92        }
93        let mut item = core::mem::MaybeUninit::<Self>::uninit();
94        unsafe {
95            generator.fill_raw(item.as_mut_ptr().cast(), N);
96            item.assume_init()
97        }
98    }
99}
100
101impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for bool {
102    fn generate(generator: &mut BlockPRNG<T>) -> Self {
103        (generator.generate::<u8>() & 1) == 1
104    }
105}
106
107impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for u8 {
108    fn generate(generator: &mut BlockPRNG<T>) -> Self {
109        u8::from_le_bytes(generator.generate::<[u8; 1]>())
110    }
111}
112
113impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for i8 {
114    fn generate(generator: &mut BlockPRNG<T>) -> Self {
115        i8::from_le_bytes(generator.generate::<[u8; 1]>())
116    }
117}
118
119impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for u16 {
120    fn generate(generator: &mut BlockPRNG<T>) -> Self {
121        u16::from_le_bytes(generator.generate::<[u8; 2]>())
122    }
123}
124
125impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for i16 {
126    fn generate(generator: &mut BlockPRNG<T>) -> Self {
127        i16::from_le_bytes(generator.generate::<[u8; 2]>())
128    }
129}
130
131impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for u32 {
132    fn generate(generator: &mut BlockPRNG<T>) -> Self {
133        u32::from_le_bytes(generator.generate::<[u8; 4]>())
134    }
135}
136
137impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for i32 {
138    fn generate(generator: &mut BlockPRNG<T>) -> Self {
139        i32::from_le_bytes(generator.generate::<[u8; 4]>())
140    }
141}
142
143impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for u64 {
144    fn generate(generator: &mut BlockPRNG<T>) -> Self {
145        u64::from_le_bytes(generator.generate::<[u8; 8]>())
146    }
147}
148
149impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for i64 {
150    fn generate(generator: &mut BlockPRNG<T>) -> Self {
151        i64::from_le_bytes(generator.generate::<[u8; 8]>())
152    }
153}
154
155impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for u128 {
156    fn generate(generator: &mut BlockPRNG<T>) -> Self {
157        u128::from_le_bytes(generator.generate::<[u8; 16]>())
158    }
159}
160
161impl<T: BlockStream> PRNConcreteGenerator<BlockPRNG<T>> for i128 {
162    fn generate(generator: &mut BlockPRNG<T>) -> Self {
163        i128::from_le_bytes(generator.generate::<[u8; 16]>())
164    }
165}
166
167impl<T: BlockStream> PRNConcreteBoundedGenerator<BlockPRNG<T>> for u32 {
168    fn generate<TBounds: RangeBounds<Self>>(generator: &mut BlockPRNG<T>, range: TBounds) -> Self {
169        generate_u32_in_range(generator, range)
170    }
171}
172
173impl<T: BlockStream> PRNConcreteBoundedGenerator<BlockPRNG<T>> for u64 {
174    fn generate<TBounds: RangeBounds<Self>>(generator: &mut BlockPRNG<T>, range: TBounds) -> Self {
175        generate_u64_in_range(generator, range)
176    }
177}
178
179impl<T: BlockStream> PRNConcreteBoundedGenerator<BlockPRNG<T>> for i32 {
180    fn generate<TBounds: RangeBounds<Self>>(generator: &mut BlockPRNG<T>, range: TBounds) -> Self {
181        generate_i32_in_range(generator, range)
182    }
183}
184
185impl<T: BlockStream> PRNConcreteBoundedGenerator<BlockPRNG<T>> for i64 {
186    fn generate<TBounds: RangeBounds<Self>>(generator: &mut BlockPRNG<T>, range: TBounds) -> Self {
187        generate_i64_in_range(generator, range)
188    }
189}
190
191impl<T: BlockStream> PRNConcreteBoundedGenerator<BlockPRNG<T>> for f32 {
192    fn generate<TBounds: RangeBounds<Self>>(generator: &mut BlockPRNG<T>, range: TBounds) -> Self {
193        generate_f32_in_range(generator, range)
194    }
195}
196
197impl<T: BlockStream> PRNConcreteBoundedGenerator<BlockPRNG<T>> for f64 {
198    fn generate<TBounds: RangeBounds<Self>>(generator: &mut BlockPRNG<T>, range: TBounds) -> Self {
199        generate_f64_in_range(generator, range)
200    }
201}
202
203impl<TSeed: ReprC + Copy, T: BlockStream + Seedable<TSeed>> Seedable<TSeed> for BlockPRNG<T> {
204    fn with_seed(seed: TSeed) -> Self {
205        Self::new(T::with_seed(seed))
206    }
207}