osom_asm_x86_64/assembler/implementation/
x86_64_assembler_assemble.rs

1#![allow(
2    unused_unsafe,
3    clippy::checked_conversions,
4    clippy::cast_possible_wrap,
5    clippy::cast_sign_loss,
6    clippy::cast_ptr_alignment,
7    clippy::unnecessary_wraps,
8    clippy::used_underscore_items
9)]
10
11use std::collections::HashMap;
12
13use osom_encoders_x86_64::encoders as enc;
14use osom_encoders_x86_64::models as enc_models;
15
16use crate::assembler::implementation::fragment::FragmentOrderId;
17use crate::assembler::implementation::fragment::RelaxationVariant;
18use crate::assembler::implementation::fragment::const_sizes;
19use crate::assembler::implementation::macros::fragment_end;
20use crate::assembler::{AssembleError, EmissionData};
21use crate::models::Condition;
22use crate::models::Label;
23
24use super::macros::{fragment_at_index, fragment_at_index_mut};
25use super::{X86_64Assembler, fragment::Fragment};
26
27impl X86_64Assembler {
28    pub fn assemble(mut self, stream: &mut dyn std::io::Write) -> Result<EmissionData, AssembleError> {
29        let mut offsets = calculate_initial_offsets(&self)?;
30        relax_instructions_and_update_offsets(&mut self, &mut offsets)?;
31        let labels_map = calculate_labels_map(&self, &offsets)?;
32        patch_addresses(&mut self, &labels_map, &offsets)?;
33        emit_fragments(&self, &labels_map, &offsets, stream)
34    }
35}
36
37fn calculate_initial_offsets(asm: &X86_64Assembler) -> Result<HashMap<FragmentOrderId, u32>, AssembleError> {
38    let mut result = HashMap::with_capacity(asm.fragments_count as usize);
39
40    let start = fragment_at_index!(asm, 0) as *const Fragment;
41    let end = fragment_end!(asm);
42    let get_id = |fragment: *const Fragment| -> FragmentOrderId {
43        let u8_fragment = fragment.cast::<u8>();
44        let u8_start = start.cast::<u8>();
45        let ptr_diff = unsafe { u8_fragment.offset_from(u8_start) };
46        FragmentOrderId::from_index(ptr_diff as u32)
47    };
48
49    let mut current_fragment = start;
50    let mut current_offset = 0;
51    let fragment_id = get_id(current_fragment);
52    result.insert(fragment_id, current_offset);
53
54    while current_fragment < end {
55        let current_fragment_ref = unsafe { &*current_fragment };
56        current_offset += current_fragment_ref.data_length();
57        current_fragment = unsafe { current_fragment_ref.next() };
58        result.insert(get_id(current_fragment), current_offset);
59    }
60
61    Ok(result)
62}
63
64/// I really don't know where this shift by 3 comes from.
65/// But it is needed, so I won't dive deep into it.
66const MAGIC_SHIFT: isize = 3;
67
68fn relax_instructions_and_update_offsets(
69    asm: &mut X86_64Assembler,
70    offsets: &mut HashMap<FragmentOrderId, u32>,
71) -> Result<(), AssembleError> {
72    let start = fragment_at_index_mut!(asm, 0) as *mut Fragment;
73    let end = fragment_end!(asm);
74
75    let get_id = |fragment: *const Fragment| -> FragmentOrderId {
76        let u8_fragment = fragment.cast::<u8>();
77        let u8_start = start.cast::<u8>();
78        let ptr_diff = unsafe { u8_fragment.offset_from(u8_start) };
79        FragmentOrderId::from_index(ptr_diff as u32)
80    };
81
82    let get_position = |label: &Label, offsets: &HashMap<FragmentOrderId, u32>| -> Result<u32, AssembleError> {
83        let Some(label_offset) = asm.label_offsets.get(label) else {
84            return Err(AssembleError::LabelNotSet(*label));
85        };
86
87        let fragment_index = label_offset.fragment_id.index();
88        let fragment = fragment_at_index!(asm, fragment_index);
89
90        let relaxation_offset = match fragment {
91            Fragment::Bytes { .. } => 0,
92            _ => fragment.data_length(),
93        };
94
95        let fragment_offset = offsets.get(&label_offset.fragment_id).unwrap();
96        Ok(*fragment_offset + relaxation_offset + label_offset.in_fragment_offset)
97    };
98
99    macro_rules! update_subsequent_offsets {
100        ($start:expr, $add:expr) => {{
101            let start: *mut Fragment = $start;
102            let add: u32 = $add;
103            let mut current = unsafe { (*start).next() };
104            while current.cast_const() < end {
105                let current_id = get_id(current);
106                *offsets.get_mut(&current_id).unwrap() += add;
107                current = unsafe { (*current).next() };
108            }
109        }};
110    }
111
112    loop {
113        let mut has_changes = false;
114
115        let mut current_fragment = start;
116        while current_fragment.cast_const() < end {
117            let current_fragment_ref = unsafe { &mut *current_fragment };
118            if let Fragment::Bytes { .. } = current_fragment_ref {
119                current_fragment = unsafe { current_fragment_ref.next() };
120                continue;
121            }
122
123            let current_fragment_id = get_id(current_fragment);
124            let current_fragment_offset = (*offsets.get(&current_fragment_id).unwrap()) as isize;
125
126            match current_fragment_ref {
127                Fragment::Relaxable_Jump { variant, label } => {
128                    if *variant == RelaxationVariant::Short {
129                        let label_position = get_position(label, offsets)? as isize;
130                        let diff = current_fragment_offset - label_position - const_sizes::SHORT_JUMP as isize;
131                        if (diff < i8::MIN as isize - MAGIC_SHIFT) || (diff > i8::MAX as isize - MAGIC_SHIFT) {
132                            *variant = RelaxationVariant::Long;
133                            has_changes = true;
134                            let add = const_sizes::LONG_JUMP - const_sizes::SHORT_JUMP;
135                            update_subsequent_offsets!(current_fragment, add);
136                        }
137                    }
138                }
139                Fragment::Relaxable_CondJump { variant, label, .. } => {
140                    if *variant == RelaxationVariant::Short {
141                        let label_position = get_position(label, offsets)? as isize;
142                        let diff = current_fragment_offset - label_position - const_sizes::SHORT_COND_JUMP as isize;
143                        if (diff < i8::MIN as isize - MAGIC_SHIFT) || (diff > i8::MAX as isize - MAGIC_SHIFT) {
144                            *variant = RelaxationVariant::Long;
145                            has_changes = true;
146                            let add = const_sizes::LONG_COND_JUMP - const_sizes::SHORT_COND_JUMP;
147                            update_subsequent_offsets!(current_fragment, add);
148                        }
149                    }
150                }
151                Fragment::Bytes { .. } => unreachable!(),
152            }
153
154            current_fragment = unsafe { current_fragment_ref.next() };
155        }
156
157        if !has_changes {
158            break;
159        }
160    }
161
162    Ok(())
163}
164
165fn calculate_labels_map(
166    asm: &X86_64Assembler,
167    offsets: &HashMap<FragmentOrderId, u32>,
168) -> Result<HashMap<Label, usize>, AssembleError> {
169    let mut result = HashMap::with_capacity(asm.label_offsets.len());
170
171    for (label, label_offset) in &asm.label_offsets {
172        let fragment_index = label_offset.fragment_id.index();
173        let fragment = fragment_at_index!(asm, fragment_index);
174        let relaxation_offset = match fragment {
175            Fragment::Bytes { .. } => 0,
176            _ => fragment.data_length(),
177        };
178
179        let fragment_offset = offsets.get(&label_offset.fragment_id).unwrap();
180        let position = *fragment_offset + relaxation_offset + label_offset.in_fragment_offset;
181        result.insert(*label, position as usize);
182    }
183
184    Ok(result)
185}
186
187fn patch_addresses(
188    asm: &mut X86_64Assembler,
189    labels_map: &HashMap<Label, usize>,
190    offsets: &HashMap<FragmentOrderId, u32>,
191) -> Result<(), AssembleError> {
192    unsafe {
193        for (label, patchable_addresses) in &asm.patchable_addresses {
194            let final_label_position = *labels_map.get(label).unwrap() as isize;
195            for patchable_address in patchable_addresses.as_slice() {
196                let patchable_fragment_id = patchable_address.instruction_position.fragment_id;
197                let patchable_fragment_index = patchable_fragment_id.index();
198                let patchable_fragment = fragment_at_index!(asm, patchable_fragment_index);
199                debug_assert!(
200                    matches!(patchable_fragment, Fragment::Bytes { .. }),
201                    "Patchable fragment is not a bytes fragment. Got: {patchable_fragment:?}"
202                );
203                let patchable_fragment_data_offset = size_of::<Fragment>();
204                let patchable_imm32_offset = patchable_fragment_index as usize
205                    + patchable_fragment_data_offset
206                    + patchable_address.instruction_position.in_fragment_offset as usize
207                    + patchable_address.imm32_offset as usize;
208
209                let patchable_imm32 = asm.fragments.as_mut_ptr().add(patchable_imm32_offset);
210
211                let final_fragment_offset = *offsets.get(&patchable_fragment_id).unwrap() as isize;
212                let final_end_of_instruction = final_fragment_offset
213                    + patchable_address.instruction_length as isize
214                    + patchable_address.instruction_position.in_fragment_offset as isize;
215                let distance = final_label_position - final_end_of_instruction;
216                debug_assert!(
217                    distance >= i32::MIN as isize && distance <= i32::MAX as isize,
218                    "Patchable distance is too far. Got: {distance}"
219                );
220                let distance = distance as i32;
221                let imm32 = enc_models::Immediate32::from_i32(distance).encode();
222                patchable_imm32.copy_from_nonoverlapping(imm32.as_ptr(), imm32.len());
223            }
224        }
225    }
226
227    Ok(())
228}
229
230fn emit_fragments(
231    asm: &X86_64Assembler,
232    labels_map: &HashMap<Label, usize>,
233    offsets: &HashMap<FragmentOrderId, u32>,
234    stream: &mut dyn std::io::Write,
235) -> Result<EmissionData, AssembleError> {
236    let start = fragment_at_index!(asm, 0) as *const Fragment;
237    let end = fragment_end!(asm);
238
239    let mut emitted_bytes = 0;
240    let mut current = start;
241    while current < end {
242        let current_fragment_ref = unsafe { &*current };
243        emitted_bytes += encode_fragment(asm, current_fragment_ref, labels_map, offsets, stream)?;
244        current = unsafe { current_fragment_ref.next() };
245    }
246
247    let mut public_labels = HashMap::with_capacity(asm.public_labels.len());
248    for item in &asm.public_labels {
249        let position = labels_map.get(item).unwrap();
250        public_labels.insert(*item, *position);
251    }
252
253    let emission_data = EmissionData::new(emitted_bytes, public_labels);
254    Ok(emission_data)
255}
256
257fn encode_fragment(
258    asm: &X86_64Assembler,
259    fragment: &Fragment,
260    labels_map: &HashMap<Label, usize>,
261    offsets: &HashMap<FragmentOrderId, u32>,
262    stream: &mut dyn std::io::Write,
263) -> Result<usize, AssembleError> {
264    let start = fragment_at_index!(asm, 0) as *const Fragment;
265    let get_id = |fragment: *const Fragment| -> FragmentOrderId {
266        let u8_fragment = fragment.cast::<u8>();
267        let u8_start = start.cast::<u8>();
268        let ptr_diff = unsafe { u8_fragment.offset_from(u8_start) };
269        FragmentOrderId::from_index(ptr_diff as u32)
270    };
271
272    let get_fragment_position = |fragment: *const Fragment| {
273        let id = get_id(fragment);
274        *offsets.get(&id).unwrap() as isize
275    };
276
277    let emitted_bytes = match fragment {
278        Fragment::Bytes { .. } => {
279            let slice = unsafe {
280                let self_ptr = std::ptr::from_ref(fragment).cast::<u8>();
281                let data_ptr = self_ptr.add(size_of::<Fragment>());
282                let len = fragment.data_length() as usize;
283                std::slice::from_raw_parts(data_ptr, len)
284            };
285            stream.write_all(slice)?;
286            slice.len()
287        }
288        Fragment::Relaxable_Jump { variant, label } => {
289            let position = get_fragment_position(fragment);
290            let label_position = (*labels_map.get(label).unwrap()) as isize;
291            let diff = label_position - position;
292            match variant {
293                RelaxationVariant::Short => {
294                    let diff = diff - const_sizes::SHORT_JUMP as isize;
295                    debug_assert!(
296                        diff >= i8::MIN as isize && diff <= i8::MAX as isize,
297                        "Short relaxable jump is too far. Got: {diff}"
298                    );
299                    let imm8 = enc_models::Immediate8::from_i8(diff as i8);
300                    let encoded = enc::jmp::encode_jmp_imm8(imm8);
301                    stream.write_all(encoded.as_slice())?;
302                    const_sizes::SHORT_JUMP as usize
303                }
304                RelaxationVariant::Long => {
305                    let diff = diff - const_sizes::LONG_JUMP as isize;
306                    debug_assert!(
307                        diff >= i32::MIN as isize && diff <= i32::MAX as isize,
308                        "Long relaxable jump is too far. Got: {diff}"
309                    );
310                    let imm32 = enc_models::Immediate32::from_i32(diff as i32);
311                    let encoded = enc::jmp::encode_jmp_imm32(imm32);
312                    stream.write_all(encoded.as_slice())?;
313                    const_sizes::LONG_JUMP as usize
314                }
315            }
316        }
317        Fragment::Relaxable_CondJump {
318            variant,
319            condition,
320            label,
321        } => {
322            let position = get_fragment_position(fragment);
323            let label_position = (*labels_map.get(label).unwrap()) as isize;
324            let diff = label_position - position;
325            match variant {
326                RelaxationVariant::Short => {
327                    let diff = diff - const_sizes::SHORT_COND_JUMP as isize;
328                    debug_assert!(
329                        diff >= i8::MIN as isize && diff <= i8::MAX as isize,
330                        "Short relaxable jump is too far."
331                    );
332                    let imm8 = enc_models::Immediate8::from_i8(diff as i8);
333                    let encoded = encode_short_cond_jump(*condition, imm8);
334                    stream.write_all(encoded.as_slice())?;
335                    const_sizes::SHORT_COND_JUMP as usize
336                }
337                RelaxationVariant::Long => {
338                    let diff = diff - const_sizes::LONG_COND_JUMP as isize;
339                    debug_assert!(
340                        diff >= i32::MIN as isize && diff <= i32::MAX as isize,
341                        "Long relaxable jump is too far."
342                    );
343                    let imm32 = enc_models::Immediate32::from_i32(diff as i32);
344                    let encoded = encode_long_cond_jump(*condition, imm32);
345                    stream.write_all(encoded.as_slice())?;
346                    const_sizes::LONG_COND_JUMP as usize
347                }
348            }
349        }
350    };
351
352    Ok(emitted_bytes)
353}
354
355fn encode_short_cond_jump(cond: Condition, imm8: enc_models::Immediate8) -> enc_models::EncodedX86_64Instruction {
356    match cond {
357        Condition::Equal => enc::jcc::encode_jcc_E_imm8(imm8),
358        Condition::NotEqual => enc::jcc::encode_jcc_NE_imm8(imm8),
359        Condition::Above => enc::jcc::encode_jcc_A_imm8(imm8),
360        Condition::AboveOrEqual => enc::jcc::encode_jcc_AE_imm8(imm8),
361        Condition::Below => enc::jcc::encode_jcc_B_imm8(imm8),
362        Condition::BelowOrEqual => enc::jcc::encode_jcc_BE_imm8(imm8),
363        Condition::Greater => enc::jcc::encode_jcc_G_imm8(imm8),
364        Condition::GreaterOrEqual => enc::jcc::encode_jcc_GE_imm8(imm8),
365        Condition::Less => enc::jcc::encode_jcc_L_imm8(imm8),
366        Condition::LessOrEqual => enc::jcc::encode_jcc_LE_imm8(imm8),
367        Condition::Overflow => enc::jcc::encode_jcc_O_imm8(imm8),
368        Condition::NotOverflow => enc::jcc::encode_jcc_NO_imm8(imm8),
369        Condition::Parity => enc::jcc::encode_jcc_P_imm8(imm8),
370        Condition::NotParity => enc::jcc::encode_jcc_NP_imm8(imm8),
371        Condition::ParityOdd => todo!(),
372        Condition::ParityEven => todo!(),
373        Condition::Sign => enc::jcc::encode_jcc_S_imm8(imm8),
374        Condition::NotSign => enc::jcc::encode_jcc_NS_imm8(imm8),
375        Condition::Zero => todo!(),
376        Condition::NotZero => todo!(),
377        Condition::Carry => todo!(),
378        Condition::NotCarry => todo!(),
379    }
380}
381
382fn encode_long_cond_jump(cond: Condition, imm32: enc_models::Immediate32) -> enc_models::EncodedX86_64Instruction {
383    match cond {
384        Condition::Equal => enc::jcc::encode_jcc_E_imm32(imm32),
385        Condition::NotEqual => enc::jcc::encode_jcc_NE_imm32(imm32),
386        Condition::Above => enc::jcc::encode_jcc_A_imm32(imm32),
387        Condition::AboveOrEqual => enc::jcc::encode_jcc_AE_imm32(imm32),
388        Condition::Below => enc::jcc::encode_jcc_B_imm32(imm32),
389        Condition::BelowOrEqual => enc::jcc::encode_jcc_BE_imm32(imm32),
390        Condition::Greater => enc::jcc::encode_jcc_G_imm32(imm32),
391        Condition::GreaterOrEqual => enc::jcc::encode_jcc_GE_imm32(imm32),
392        Condition::Less => enc::jcc::encode_jcc_L_imm32(imm32),
393        Condition::LessOrEqual => enc::jcc::encode_jcc_LE_imm32(imm32),
394        Condition::Overflow => enc::jcc::encode_jcc_O_imm32(imm32),
395        Condition::NotOverflow => enc::jcc::encode_jcc_NO_imm32(imm32),
396        Condition::Parity => enc::jcc::encode_jcc_P_imm32(imm32),
397        Condition::NotParity => enc::jcc::encode_jcc_NP_imm32(imm32),
398        Condition::ParityOdd => todo!(),
399        Condition::ParityEven => todo!(),
400        Condition::Sign => enc::jcc::encode_jcc_S_imm32(imm32),
401        Condition::NotSign => enc::jcc::encode_jcc_NS_imm32(imm32),
402        Condition::Zero => todo!(),
403        Condition::NotZero => todo!(),
404        Condition::Carry => todo!(),
405        Condition::NotCarry => todo!(),
406    }
407}