diff --git a/EqSat/Cargo.lock b/EqSat/Cargo.lock index 93fb5af..73b2441 100644 --- a/EqSat/Cargo.lock +++ b/EqSat/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "ahash" @@ -40,6 +40,7 @@ dependencies = [ "ahash", "cranelift-isle", "foldhash", + "iced-x86", "libc", "mimalloc", "rand", @@ -62,6 +63,21 @@ dependencies = [ "wasi", ] +[[package]] +name = "iced-x86" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c447cff8c7f384a7d4f741cfcff32f75f3ad02b406432e8d6c878d56b1edf6b" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.153" diff --git a/EqSat/Cargo.toml b/EqSat/Cargo.toml index e1c50d3..6e6cd1f 100644 --- a/EqSat/Cargo.toml +++ b/EqSat/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [lib] crate-type = ["cdylib"] -path = "src/main.rs" +path = "src/lib.rs" [dependencies] @@ -20,9 +20,17 @@ mimalloc = { version = "*", default-features = false } # egraph = { path = "./egraph" } foldhash = "=0.1.0" +[dependencies.iced-x86] +version = "1.21.0" +features = ["code_asm"] + [profile.release] debug = true debuginfo-level = 2 -panic = "abort" lto = true codegen-units = 1 +panic = "abort" +opt-level = 3 + +[build] +rustflags = ["-C", "target-cpu=native"] \ No newline at end of file diff --git a/EqSat/src/assembler/amd64_assembler.rs b/EqSat/src/assembler/amd64_assembler.rs new file mode 100644 index 0000000..8ef3246 --- /dev/null +++ b/EqSat/src/assembler/amd64_assembler.rs @@ -0,0 +1,53 @@ +use iced_x86::{Instruction, Register}; + +pub trait IAmd64Assembler { + fn push_reg(&mut self, reg: Register); + + fn push_mem64(&mut self, base_reg: Register, offset: i32); + + fn pop_reg(&mut self, reg: Register); + + fn mov_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn mov_reg_mem64(&mut self, dst_reg: Register, base_reg: Register, offset: i32); + + fn mov_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register); + + fn movabs_reg_imm64(&mut self, reg: Register, imm: u64); + + fn add_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn add_reg_imm32(&mut self, reg: Register, imm32: u32); + + fn sub_reg_imm32(&mut self, reg: Register, imm32: u32); + + fn imul_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn and_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn and_reg_imm32(&mut self, reg: Register, imm: u32); + + fn and_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register); + + fn or_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn xor_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn not_reg(&mut self, reg: Register); + + fn shl_reg_cl(&mut self, reg: Register); + + fn shr_reg_cl(&mut self, reg: Register); + + fn shr_reg_imm8(&mut self, reg: Register, imm8: u8); + + fn call_reg(&mut self, reg: Register); + + fn ret(&mut self); + + fn get_instructions(&mut self) -> Vec; + + fn get_bytes(&mut self) -> Vec; + + fn reset(&mut self); +} diff --git a/EqSat/src/assembler/differential_tester.rs b/EqSat/src/assembler/differential_tester.rs new file mode 100644 index 0000000..86d8da6 --- /dev/null +++ b/EqSat/src/assembler/differential_tester.rs @@ -0,0 +1,206 @@ +use iced_x86::{IcedError, Register}; +use rand::Rng; + +use 
crate::assembler::{ + amd64_assembler::IAmd64Assembler, fast_amd64_assembler::FastAmd64Assembler, + iced_amd64_assembler::IcedAmd64Assembler, +}; + +/// Differential tester that compares FastAmd64Assembler against IcedAmd64Assembler +pub struct Amd64AssemblerDifferentialTester { + rand: rand::rngs::ThreadRng, + registers: Vec, + iced_assembler: IcedAmd64Assembler, + fast_assembler: FastAmd64Assembler, +} + +impl Amd64AssemblerDifferentialTester { + /// Creates a new differential tester with the given buffer + pub unsafe fn new(buffer: *mut u8) -> Result { + let registers = vec![ + Register::RAX, + Register::RCX, + Register::RDX, + Register::RBX, + Register::RSI, + Register::RDI, + Register::RSP, + Register::RBP, + Register::R8, + Register::R9, + Register::R10, + Register::R11, + Register::R12, + Register::R13, + Register::R14, + Register::R15, + ]; + + Ok(Self { + rand: rand::thread_rng(), + registers, + iced_assembler: IcedAmd64Assembler::new()?, + fast_assembler: FastAmd64Assembler { + p: buffer, + offset: 0, + }, + }) + } + + pub fn test() -> Result<(), Box> { + let mut buffer = vec![0u8; 64 * 4096]; + let ptr = buffer.as_mut_ptr(); + + unsafe { + let mut tester = Self::new(ptr)?; + tester.run()?; + } + + Ok(()) + } + + pub fn run(&mut self) -> Result<(), Box> { + for i in 0..self.registers.len() { + let reg1 = self.registers[i]; + self.diff_reg_insts(reg1)?; + + for j in (i + 1)..self.registers.len() { + let reg2 = self.registers[j]; + self.diff_reg_reg_insts(reg1, reg2)?; + } + } + + println!("All differential tests passed!"); + Ok(()) + } + + fn diff_reg_insts(&mut self, reg: Register) -> Result<(), Box> { + self.diff("PushReg", |asm| asm.push_reg(reg))?; + self.diff("PopReg", |asm| asm.pop_reg(reg))?; + self.diff("NotReg", |asm| asm.not_reg(reg))?; + self.diff("ShlRegCl", |asm| asm.shl_reg_cl(reg))?; + self.diff("ShrRegCl", |asm| asm.shr_reg_cl(reg))?; + self.diff("CallReg", |asm| asm.call_reg(reg))?; + + // Test reg, constant instructions + for _ in 0..100 { + let c = self.rand.gen::() as u64; + + self.diff("MovabsRegImm64", |asm| asm.movabs_reg_imm64(reg, c))?; + self.diff("AddRegImm32", |asm| asm.add_reg_imm32(reg, c as u32))?; + self.diff("SubRegImm32", |asm| asm.sub_reg_imm32(reg, c as u32))?; + self.diff("AndRegImm32", |asm| asm.and_reg_imm32(reg, c as u32))?; + self.diff("ShrRegImm8", |asm| asm.shr_reg_imm8(reg, c as u8))?; + + if reg != Register::RSP { + self.diff("PushMem64", |asm| asm.push_mem64(reg, c as i32))?; + } + } + + Ok(()) + } + + fn diff_reg_reg_insts( + &mut self, + reg1: Register, + reg2: Register, + ) -> Result<(), Box> { + // Test reg, reg instructions + self.diff("MovRegReg", |asm| asm.mov_reg_reg(reg1, reg2))?; + self.diff("MovRegReg", |asm| asm.mov_reg_reg(reg2, reg1))?; + self.diff("AddRegReg", |asm| asm.add_reg_reg(reg1, reg2))?; + self.diff("AddRegReg", |asm| asm.add_reg_reg(reg2, reg1))?; + self.diff("AndRegReg", |asm| asm.and_reg_reg(reg1, reg2))?; + self.diff("AndRegReg", |asm| asm.and_reg_reg(reg2, reg1))?; + self.diff("OrRegReg", |asm| asm.or_reg_reg(reg1, reg2))?; + self.diff("OrRegReg", |asm| asm.or_reg_reg(reg2, reg1))?; + self.diff("XorRegReg", |asm| asm.xor_reg_reg(reg1, reg2))?; + self.diff("XorRegReg", |asm| asm.xor_reg_reg(reg2, reg1))?; + self.diff("ImulRegReg", |asm| asm.imul_reg_reg(reg1, reg2))?; + self.diff("ImulRegReg", |asm| asm.imul_reg_reg(reg2, reg1))?; + + // Test reg, reg, constant instructions + for _ in 0..100 { + let c = self.rand.gen::(); + + self.diff("MovMem64Reg", |asm| asm.mov_mem64_reg(reg1, c, reg2))?; + 
self.diff("MovMem64Reg", |asm| asm.mov_mem64_reg(reg2, c, reg1))?; + self.diff("MovRegMem64", |asm| asm.mov_reg_mem64(reg1, reg2, c))?; + self.diff("MovRegMem64", |asm| asm.mov_reg_mem64(reg2, reg1, c))?; + self.diff("AndMem64Reg", |asm| asm.and_mem64_reg(reg1, c, reg2))?; + self.diff("AndMem64Reg", |asm| asm.and_mem64_reg(reg2, c, reg1))?; + } + + Ok(()) + } + + /// Executes a test function on both assemblers and compares results + fn diff(&mut self, method_name: &str, func: F) -> Result<(), Box> + where + F: Fn(&mut dyn IAmd64Assembler), + { + // Assemble the instruction using both assemblers + func(&mut self.iced_assembler); + func(&mut self.fast_assembler); + + // Compare the results + self.compare(method_name)?; + + // Reset both assemblers + self.iced_assembler.reset(); + self.fast_assembler.reset(); + + Ok(()) + } + + /// Compares the output of both assemblers + fn compare(&mut self, method_name: &str) -> Result<(), Box> { + let iced_insts = self.iced_assembler.get_instructions(); + let iced_bytes = self.iced_assembler.get_bytes(); + let our_insts = self.fast_assembler.get_instructions(); + let our_bytes = self.fast_assembler.get_bytes(); + + if iced_insts.is_empty() || iced_bytes.is_empty() || iced_insts.len() != our_insts.len() { + return Err(format!( + "Method {} failed: instruction count mismatch (iced: {}, ours: {})", + method_name, + iced_insts.len(), + our_insts.len() + ) + .into()); + } + + // Check if instructions are equivalent + if iced_insts.len() == 1 && our_insts.len() == 1 { + let iced_str = format!("{}", iced_insts[0]); + let our_str = format!("{}", our_insts[0]); + + if iced_str != our_str { + return Err(format!( + "Method {} failed: Instruction '{}' and '{}' not equivalent!\nIced bytes: {:?}\nOur bytes: {:?}", + method_name, + iced_str, + our_str, + iced_bytes, + our_bytes + ).into()); + } + } else { + // Compare all instructions + for (i, (iced_inst, our_inst)) in iced_insts.iter().zip(our_insts.iter()).enumerate() { + let iced_str = format!("{}", iced_inst); + let our_str = format!("{}", our_inst); + + if iced_str != our_str { + return Err(format!( + "Method {} failed at instruction {}: '{}' != '{}'", + method_name, i, iced_str, our_str + ) + .into()); + } + } + } + + Ok(()) + } +} diff --git a/EqSat/src/assembler/fast_amd64_assembler.rs b/EqSat/src/assembler/fast_amd64_assembler.rs new file mode 100644 index 0000000..2d0b59d --- /dev/null +++ b/EqSat/src/assembler/fast_amd64_assembler.rs @@ -0,0 +1,478 @@ +use iced_x86::code_asm::*; +use iced_x86::{Instruction, Register}; +use rand::Rng; +use std::fmt::Write; +use std::hint::unreachable_unchecked; +use std::time::Instant; + +use crate::assembler::amd64_assembler::IAmd64Assembler; + +// Wrapper around a stack-allocated byte buffer for building instruction byte sequences +pub struct StackBuffer<'a> { + pub arr: &'a mut [u8], + pub offset: usize, +} + +impl StackBuffer<'_> { + fn push(&mut self, v: T) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut T; + *ptr = v; + } + + self.offset += std::mem::size_of::(); + } + + fn pop(&mut self) -> T { + self.offset -= std::mem::size_of::(); + + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut T; + *ptr + } + } + + fn push_u8(&mut self, byte: u8) { + unsafe { + *self.arr.get_unchecked_mut(self.offset) = byte; + } + + self.offset += 1; + } + + fn push_i32(&mut self, byte: i32) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut i32; + *ptr = byte; + } + + self.offset += 4; + } + + fn push_u32(&mut self, byte: u32) { 
+ unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut u32; + *ptr = byte; + } + + self.offset += 4; + } + + fn push_u64(&mut self, byte: u64) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut u64; + *ptr = byte; + } + + self.offset += 8; + } +} + +pub struct FastAmd64Assembler { + pub p: *mut u8, + pub offset: usize, +} + +impl FastAmd64Assembler { + pub fn new(buffer: *mut u8) -> Self { + FastAmd64Assembler { + p: buffer, + offset: 0, + } + } + + fn emit_bytes(&mut self, data: &[u8]) { + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), self.p.add(self.offset), data.len()); + } + + self.offset += data.len(); + } + + fn emit_buffer(&mut self, buffer: &StackBuffer) { + unsafe { + std::ptr::copy_nonoverlapping( + buffer.arr.as_ptr(), + self.p.add(self.offset), + buffer.offset, + ); + } + + self.offset += buffer.offset; + } + + pub fn opcode_reg_reg(&mut self, opcode: u8, reg1: Register, reg2: Register) { + let mut rex = 0x48; + if is_extended(reg1) { + rex |= 0x01; + } + if is_extended(reg2) { + rex |= 0x04; + } + + let modrm = 0xC0 + | ((get_register_code(reg2) as u8 & 0x07) << 3) + | (get_register_code(reg1) as u8 & 0x07); + self.emit_bytes(&[rex, opcode, modrm]); + } + + pub fn opc_reg_imm(&mut self, mask: u8, reg: Register, imm32: u32) { + let p = &mut [0u8; 7]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0x81; + let modrm = 0xC0 | (mask << 3) | (self.get_register_code(reg) & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + arr.push_u32(imm32); + + self.emit_buffer(&arr); + } + + pub fn shift_reg_cl(&mut self, shl: bool, reg: Register) { + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0xD3; + let m1 = if shl { 0x04 } else { 0x05 }; + let modrm = 0xC0 | (m1 << 3) | (self.get_register_code(reg) & 0x07); + + self.emit_bytes(&[rex, opcode, modrm]); + } + + #[inline(always)] + fn mov_reg_mem64_template(&mut self, dst_reg: Register, base_reg: Register, offset: i32) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if is_extended(dst_reg) { + rex |= 0x04; + } + if is_extended(base_reg) { + rex |= 0x01; + } + + let opcode = 0x8B; + let modrm = 0x80 + | ((get_register_code(dst_reg) as u8 & 0x07) << 3) + | (get_register_code(base_reg) as u8 & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib = 0x00 | (0x04 << 3) | (get_register_code(base_reg) as u8 & 0x07); + arr.push_u8(sib); + } + + arr.push_i32(offset); + + self.emit_buffer(&arr); + } + + pub fn is_extended(&mut self, reg: Register) -> bool { + return reg >= Register::R8 && reg <= Register::R15; + } + + pub fn get_register_code(&mut self, reg: Register) -> u8 { + return (reg as u8) - (Register::RAX as u8); + } +} + +impl IAmd64Assembler for FastAmd64Assembler { + fn push_reg(&mut self, reg: Register) { + if reg >= Register::RAX && reg <= Register::RDI { + let opcode = (0x50 + get_register_code(reg)) as u8; + self.emit_bytes(&[opcode]); + return; + } + + let rex = 0x41; + let opcode = (0x50 + reg as u8 - Register::R8 as u8); + self.emit_bytes(&[rex, opcode]); + } + + fn push_mem64(&mut self, base_reg: Register, disp: i32) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + if is_extended(base_reg) { + let rex = 0x49; + arr.push_u8(rex); + } + + let opcode: u8 = 
0xFF; + let modrm = (0x80 | (0x06 << 3) | (get_register_code(base_reg) & 0x07)) as u8; + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib: u8 = (0x00 | (0x04 << 3) | (get_register_code(base_reg) & 0x07)) as u8; + arr.push_u8(sib); + } + + arr.push_i32(disp); + self.emit_buffer(&arr); + } + + fn pop_reg(&mut self, reg: Register) { + if reg >= Register::RAX && reg <= Register::RDI { + let opcode = 0x58 + get_register_code(reg) as u8; + self.emit_bytes(&[opcode]); + return; + } + + let rex = 0x41; + let opcode = 0x58 + get_register_code(reg) as u8 - 8; + self.emit_bytes(&[rex, opcode]); + return; + } + + fn mov_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x89, reg1, reg2); + } + + fn mov_reg_mem64(&mut self, dst_reg: Register, base_reg: Register, offset: i32) { + // This function is intentionally monomorphized for performance. + // mov_reg_mem64 has a variable length encoding depending on the base register, + // which prevents SROA from promoting our stack buffer to an SSA variable. + if base_reg == Register::R12 { + self.mov_reg_mem64_template(dst_reg, Register::R12, offset); + } else if base_reg == Register::RSP { + self.mov_reg_mem64_template(dst_reg, Register::RSP, offset); + } else { + self.mov_reg_mem64_template(dst_reg, base_reg, offset); + } + } + + fn mov_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(src_reg) { + rex |= 0x04; + } + if self.is_extended(base_reg) { + rex |= 0x01; + } + + let opcode = 0x89; + let modrm = 0x80 + | ((self.get_register_code(src_reg) & 0x07) << 3) + | (self.get_register_code(base_reg) & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib = 0x00 | (0x04 << 3) | (self.get_register_code(base_reg) & 0x07); + arr.push_u8(sib); + } + + arr.push_i32(offset); + + self.emit_buffer(&arr); + } + + fn movabs_reg_imm64(&mut self, reg: Register, imm64: u64) { + let p = &mut [0u8; 10]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let cond = reg >= Register::RAX && reg <= Register::RDI; + let opcode = 0xB8 + + if cond { + self.get_register_code(reg) + } else { + self.get_register_code(reg) - 8 + }; + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u64(imm64); + + self.emit_buffer(&arr); + } + + fn add_reg_reg(&mut self, dest: Register, src: Register) { + self.opcode_reg_reg(0x01, dest, src); + } + + fn add_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.opc_reg_imm(0x00, reg, imm32); + } + + fn sub_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.opc_reg_imm(0x05, reg, imm32); + } + + fn imul_reg_reg(&mut self, reg1: Register, reg2: Register) { + let mut rex = 0x48; + if self.is_extended(reg1) { + rex |= 0x04; + } + if self.is_extended(reg2) { + rex |= 0x01; + } + + let opcode1 = 0x0F; + let opcode2 = 0xAF; + let modrm = 0xC0 + | ((self.get_register_code(reg1) & 0x07) << 3) + | (self.get_register_code(reg2) & 0x07); + + self.emit_bytes(&[rex, opcode1, opcode2, modrm]); + } + + fn and_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x21, reg1, reg2); + } + + fn and_reg_imm32(&mut self, reg1: Register, imm32: u32) { + self.opc_reg_imm(0x04, reg1, imm32); + } + + fn and_mem64_reg(&mut self, base_reg: 
Register, offset: i32, src_reg: Register) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(src_reg) { + rex |= 0x04; + } + if self.is_extended(base_reg) { + rex |= 0x01; + } + + let opcode = 0x21; + let modrm = 0x80 + | ((self.get_register_code(src_reg) & 0x07) << 3) + | (self.get_register_code(base_reg) & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib = 0x00 | (0x04 << 3) | (self.get_register_code(base_reg) & 0x07); + arr.push_u8(sib); + } + + arr.push_i32(offset); + + self.emit_buffer(&arr); + } + + fn or_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x09, reg1, reg2); + } + + fn xor_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x31, reg1, reg2); + } + + fn not_reg(&mut self, reg: Register) { + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0xF7; + let modrm = 0xC0 | (0x02 << 3) | (self.get_register_code(reg) & 0x07); + + self.emit_bytes(&[rex, opcode, modrm]); + } + + fn shl_reg_cl(&mut self, reg: Register) { + self.shift_reg_cl(true, reg); + } + + fn shr_reg_cl(&mut self, reg: Register) { + self.shift_reg_cl(false, reg); + } + + fn shr_reg_imm8(&mut self, reg: Register, imm8: u8) { + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0xC1; + let modrm = 0xC0 | (0x05 << 3) | (self.get_register_code(reg) & 0x07); + + self.emit_bytes(&[rex, opcode, modrm, imm8]); + } + + fn call_reg(&mut self, reg: Register) { + let mut rex = 0x00; + if self.is_extended(reg) { + rex = 0x41; + } + + let opcode = 0xFF; + let modrm = 0xC0 | (0x02 << 3) | (self.get_register_code(reg) & 0x07); + + if rex != 0 { + self.emit_bytes(&[rex, opcode, modrm]); + } else { + self.emit_bytes(&[opcode, modrm]); + } + } + + fn ret(&mut self) { + self.emit_bytes(&[0xC3]); + } + + fn get_instructions(&mut self) -> Vec { + let bytes = self.get_bytes(); + let mut decoder = iced_x86::Decoder::new(64, &bytes, iced_x86::DecoderOptions::NONE); + decoder.set_ip(0); + + let mut instructions = Vec::new(); + while decoder.position() < bytes.len() { + let instruction = decoder.decode(); + instructions.push(instruction); + } + + instructions + } + + fn get_bytes(&mut self) -> Vec { + let mut bytes = Vec::new(); + for i in 0..self.offset { + unsafe { + bytes.push(*self.p.add(i)); + } + } + return bytes; + } + + fn reset(&mut self) { + self.offset = 0; + } +} + +fn is_extended(reg: Register) -> bool { + return reg >= Register::R8 && reg <= Register::R15; +} + +fn get_register_code(reg: Register) -> u8 { + return (reg as u8) - (Register::RAX as u8); +} diff --git a/EqSat/src/assembler/iced_amd64_assembler.rs b/EqSat/src/assembler/iced_amd64_assembler.rs new file mode 100644 index 0000000..1f8b6ae --- /dev/null +++ b/EqSat/src/assembler/iced_amd64_assembler.rs @@ -0,0 +1,170 @@ +use crate::assembler::amd64_assembler::IAmd64Assembler; +use iced_x86::code_asm::*; +use iced_x86::Instruction; +use iced_x86::Register; + +/// x86-64 assembler implementation using the iced-x86 library +pub struct IcedAmd64Assembler { + assembler: CodeAssembler, +} + +impl IcedAmd64Assembler { + pub fn new() -> Result { + Ok(Self { + assembler: CodeAssembler::new(64)?, + }) + } + + fn conv(reg: Register) -> AsmRegister64 { + match reg { + Register::RAX => rax, + Register::RCX => rcx, + Register::RDX => rdx, + Register::RBX => rbx, + Register::RSP => rsp, + 
Register::RBP => rbp, + Register::RSI => rsi, + Register::RDI => rdi, + Register::R8 => r8, + Register::R9 => r9, + Register::R10 => r10, + Register::R11 => r11, + Register::R12 => r12, + Register::R13 => r13, + Register::R14 => r14, + Register::R15 => r15, + _ => panic!("Unsupported register"), + } + } +} + +impl IAmd64Assembler for IcedAmd64Assembler { + fn push_reg(&mut self, reg: Register) { + self.assembler.push(Self::conv(reg)).unwrap(); + } + + fn push_mem64(&mut self, base_reg: Register, offset: i32) { + self.assembler + .push(qword_ptr(Self::conv(base_reg) + offset)) + .unwrap(); + } + + fn pop_reg(&mut self, reg: Register) { + self.assembler.pop(Self::conv(reg)).unwrap(); + } + + fn mov_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .mov(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn mov_reg_mem64(&mut self, dst_reg: Register, base_reg: Register, offset: i32) { + self.assembler + .mov( + Self::conv(dst_reg), + qword_ptr(Self::conv(base_reg) + offset), + ) + .unwrap(); + } + + fn mov_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + self.assembler + .mov( + qword_ptr(Self::conv(base_reg) + offset), + Self::conv(src_reg), + ) + .unwrap(); + } + + fn movabs_reg_imm64(&mut self, reg: Register, imm: u64) { + self.assembler.mov(Self::conv(reg), imm).unwrap(); + } + + fn add_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .add(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn add_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.assembler.add(Self::conv(reg), imm32 as i32).unwrap(); + } + + fn sub_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.assembler.sub(Self::conv(reg), imm32 as i32).unwrap(); + } + + fn imul_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .imul_2(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn and_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .and(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn and_reg_imm32(&mut self, reg: Register, imm: u32) { + self.assembler.and(Self::conv(reg), imm as i32).unwrap(); + } + + fn and_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + self.assembler + .and( + qword_ptr(Self::conv(base_reg) + offset), + Self::conv(src_reg), + ) + .unwrap(); + } + + fn or_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .or(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn xor_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .xor(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn not_reg(&mut self, reg: Register) { + self.assembler.not(Self::conv(reg)).unwrap(); + } + + fn shl_reg_cl(&mut self, reg: Register) { + self.assembler.shl(Self::conv(reg), cl).unwrap(); + } + + fn shr_reg_cl(&mut self, reg: Register) { + self.assembler.shr(Self::conv(reg), cl).unwrap(); + } + + fn shr_reg_imm8(&mut self, reg: Register, imm8: u8) { + self.assembler.shr(Self::conv(reg), imm8 as u32).unwrap(); + } + + fn call_reg(&mut self, reg: Register) { + self.assembler.call(Self::conv(reg)).unwrap(); + } + + fn ret(&mut self) { + self.assembler.ret().unwrap(); + } + + fn get_instructions(&mut self) -> Vec { + self.assembler.instructions().to_vec() + } + + fn get_bytes(&mut self) -> Vec { + self.assembler.assemble(0).unwrap() + } + + fn reset(&mut self) { + self.assembler.reset(); + } +} diff --git a/EqSat/src/assembler/mod.rs b/EqSat/src/assembler/mod.rs new file mode 100644 index 
0000000..1cba12d --- /dev/null +++ b/EqSat/src/assembler/mod.rs @@ -0,0 +1,4 @@ +pub mod amd64_assembler; +pub mod differential_tester; +pub mod fast_amd64_assembler; +pub mod iced_amd64_assembler; diff --git a/EqSat/src/main.rs b/EqSat/src/lib.rs similarity index 98% rename from EqSat/src/main.rs rename to EqSat/src/lib.rs index 357ebe0..f3b11e0 100644 --- a/EqSat/src/main.rs +++ b/EqSat/src/lib.rs @@ -27,6 +27,7 @@ use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; +mod assembler; mod known_bits; mod mba; mod simple_ast; diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 78b8333..1bbc10a 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -1,16 +1,26 @@ type Unit = (); +use core::num; use std::{ collections::{hash_map::Entry, HashMap, HashSet}, f32::consts::PI, ffi::{CStr, CString}, - u64, vec, + ops::Add, + u16, u64, vec, }; use ahash::AHashMap; +use iced_x86::{ + code_asm::{st, CodeAssembler}, + Code, Instruction, Register, +}; use libc::{c_char, c_void}; +use std::marker::PhantomData; use crate::{ + assembler::{ + self, amd64_assembler::IAmd64Assembler, fast_amd64_assembler::FastAmd64Assembler, *, + }, known_bits::{self, *}, mba::{self, Context as MbaContext}, truth_table_database::{TruthTable, TruthTableDatabase}, @@ -27,6 +37,7 @@ pub struct AstIdx(pub u32); pub struct Arena { pub elements: Vec<(SimpleAst, AstData)>, ast_to_idx: AHashMap, + isle_cache: AHashMap, // Map a name to it's corresponds symbol index. symbol_ids: Vec<(String, AstIdx)>, @@ -37,6 +48,7 @@ impl Arena { pub fn new() -> Self { let elements = Vec::with_capacity(65536); let ast_to_idx = AHashMap::with_capacity(65536); + let isle_cache = AHashMap::with_capacity(65536); let symbol_ids = Vec::with_capacity(255); let name_to_symbol = AHashMap::with_capacity(255); @@ -44,6 +56,7 @@ impl Arena { Arena { elements: elements, ast_to_idx: ast_to_idx, + isle_cache: isle_cache, symbol_ids: symbol_ids, name_to_symbol: name_to_symbol, @@ -73,6 +86,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Add { a, b }, data); @@ -128,6 +142,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Mul { a, b }, data); @@ -154,6 +169,7 @@ impl Arena { has_poly: true, class: AstClass::Nonlinear, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Pow { a, b }, data); @@ -201,6 +217,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Neg { a }, data); } @@ -219,6 +236,7 @@ impl Arena { has_poly: has_poly, class: class, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Lshr { a, b }, data); } @@ -242,6 +260,7 @@ impl Arena { has_poly: has_poly, class: class, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Zext { a, to: width }, data); @@ -266,6 +285,7 @@ impl Arena { has_poly: has_poly, class: class, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Trunc { a, to: width }, data); @@ -278,6 +298,7 @@ impl Arena { has_poly: false, class: AstClass::Bitwise, known_bits: KnownBits::constant(c, width), + imut_data: 0, }; // Reduce the constant modulo 2**width @@ -293,6 +314,7 @@ impl Arena { has_poly: false, class: AstClass::Bitwise, known_bits: KnownBits::empty(width), + imut_data: 0, }; return self.insert_ast_node( @@ -319,6 
+341,7 @@ impl Arena { has_poly: false, class: AstClass::Bitwise, known_bits: KnownBits::empty(width), + imut_data: 0, }; let symbol_ast_idx = self.insert_ast_node( @@ -380,6 +403,14 @@ impl Arena { unsafe { self.elements.get_unchecked(idx.0 as usize).1 } } + pub fn get_data_mut(&mut self, idx: AstIdx) -> &mut AstData { + unsafe { &mut self.elements.get_unchecked_mut(idx.0 as usize).1 } + } + + pub fn set_data(&mut self, idx: AstIdx, data: AstData) { + unsafe { self.elements.get_unchecked_mut(idx.0 as usize).1 = data } + } + pub fn get_bin_width(&self, a: AstIdx, b: AstIdx) -> u8 { let a_width = self.get_width(a); let b_width = self.get_width(b); @@ -414,6 +445,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return data; @@ -487,7 +519,13 @@ pub struct AstData { // Classification of the ast class: AstClass, + // Known zero or one bits known_bits: KnownBits, + + // Internal mutable data for use in different algorithms. + // Specifically we use this field to avoid unnecessarily storing data in hashmaps. + // e.g "how many users does this node have?" can be stored here temporarily. + imut_data: u64, } #[derive(Clone, Hash, PartialEq, Eq)] @@ -813,6 +851,9 @@ pub fn eval_ast(ctx: &Context, idx: AstIdx, value_mapping: &HashMap // Recursively apply ISLE over an AST. pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx { + if ctx.arena.isle_cache.get(&idx).is_some() { + return *ctx.arena.isle_cache.get(&idx).unwrap(); + } let mut ast = ctx.arena.get_node(idx).clone(); match ast { @@ -862,7 +903,9 @@ pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx { ast = result.unwrap(); } - return ctx.arena.ast_to_idx[&ast]; + let result = ctx.arena.ast_to_idx[&ast]; + ctx.arena.isle_cache.insert(idx, result); + result } // Evaluate the current AST for all possible combinations of zeroes and ones as inputs. 
@@ -1210,6 +1253,24 @@ pub extern "C" fn ContextGetKnownBits(ctx: *mut Context, id: AstIdx) -> KnownBit } } +#[no_mangle] +pub extern "C" fn ContextGetImutData(ctx: *mut Context, id: AstIdx) -> u64 { + unsafe { + let kb = (*ctx).arena.get_data(id).imut_data; + + return kb; + } +} + +#[no_mangle] +pub extern "C" fn ContextSetImutData(ctx: *mut Context, id: AstIdx, imut: u64) { + unsafe { + let mut data = (*ctx).arena.get_data(id).clone(); + data.imut_data = imut; + (*ctx).arena.set_data(id, data); + } +} + #[no_mangle] pub extern "C" fn ContextGetOp0(ctx: *const Context, id: AstIdx) -> AstIdx { unsafe { @@ -1606,7 +1667,7 @@ unsafe fn jit_constant(c: u64, page: *mut u8, offset: &mut usize) { } #[no_mangle] -pub extern "C" fn GetPowPtr(mut base: u64, mut exp: u64) -> u64 { +pub extern "C" fn GetPowPtr() -> u64 { return Pow as *const () as u64; } @@ -1625,7 +1686,7 @@ pub extern "C" fn Pow(mut base: u64, mut exp: u64) -> u64 { } #[no_mangle] -pub unsafe extern "C" fn ContextCompile( +pub unsafe extern "C" fn ContextCompileLegacy( ctx_p: *mut Context, node: AstIdx, mask: u64, @@ -1674,6 +1735,28 @@ pub unsafe extern "C" fn ContextCompile( emit_u8(page, &mut offset, RET); } +#[no_mangle] +pub unsafe extern "C" fn ContextCompile( + ctx_p: *mut Context, + node: AstIdx, + mask: u64, + variables: *const AstIdx, + var_count: u64, + page: *mut u8, +) { + let mut ctx: &mut Context = &mut (*ctx_p); + + let mut vars: Vec = Vec::new(); + // JIT code + for i in 0..var_count { + vars.push(*variables.add(i as usize)); + } + + let mut assembler = FastAmd64Assembler::new(page); + let mut compiler = Amd64OptimizingJit::::new(); + compiler.compile(ctx, &mut assembler, node, &vars, page, false); +} + #[no_mangle] pub unsafe extern "C" fn ContextExecute( multi_bit_u: u32, @@ -2465,3 +2548,692 @@ pub fn get_group_size_index(mask: u64) -> u32 { pub fn get_group_size(idx: u32) -> u32 { return 1 << idx; } + +#[derive(Copy, Clone)] +struct Location { + pub register: Register, +} + +impl Location { + pub fn is_register(&self) -> bool { + return self.register != Register::None; + } + + pub fn reg(r: Register) -> Location { + return Location { register: r }; + } + + pub fn stack() -> Location { + return Location { + register: Register::None, + }; + } +} + +trait Exists { + fn exists(&self) -> bool; +} + +// Assert that `NodeInfo` is 8 bytes in size +const _: () = [(); 1][(core::mem::size_of::() == 8) as usize ^ 1]; + +#[derive(Copy, Clone)] +struct NodeInfo { + pub num_uses: u16, + pub var_idx: u16, + pub slot_idx: u16, + pub exists: u16, +} + +impl NodeInfo { + pub fn new(num_instances: u16) -> Self { + return NodeInfo { + num_uses: num_instances, + var_idx: 0, + slot_idx: u16::MAX, + exists: 1, + }; + } +} + +impl From for NodeInfo { + fn from(value: u64) -> Self { + unsafe { + let ptr = (&value) as *const u64 as *const NodeInfo; + *ptr + } + } +} + +impl Into for NodeInfo { + fn into(self) -> u64 { + unsafe { + let ptr = (&self) as *const NodeInfo as *const u64; + *ptr + } + } +} + +impl Exists for NodeInfo { + fn exists(&self) -> bool { + return self.exists != 0; + } +} + +struct AuxInfoStorage + Into + Exists> { + _marker: PhantomData, +} + +impl + Into + Exists> AuxInfoStorage { + pub fn contains(ctx: &mut Context, idx: AstIdx) -> bool { + let value = Self::get(ctx, idx); + return value.exists(); + } + + pub fn get(ctx: &mut Context, idx: AstIdx) -> T { + let value = ctx.arena.get_data(idx).imut_data; + return T::from(value); + } + + pub fn get_unsafe(ptr: *mut (SimpleAst, AstData), idx: AstIdx) -> T { + unsafe { + 
let value = (*ptr.add(idx.0 as usize)).1.imut_data; + return T::from(value); + } + } + + pub fn get_ptr_unsafe(ptr: *mut (SimpleAst, AstData), idx: AstIdx) -> *mut NodeInfo { + unsafe { + let data = &mut (*ptr.add(idx.0 as usize)).1; + return (&mut (*ptr).1.imut_data) as *mut u64 as *mut NodeInfo; + } + } + + pub fn set(ctx: &mut Context, idx: AstIdx, value: T) { + ctx.arena.get_data_mut(idx).imut_data = value.into(); + } + + pub fn set_unsafe(ptr: *mut (SimpleAst, AstData), idx: AstIdx, value: T) { + unsafe { + let data = &mut (*ptr.add(idx.0 as usize)).1; + data.imut_data = value.into(); + } + } + + pub fn try_get(ctx: &mut Context, idx: AstIdx) -> Option { + let value = Self::get(ctx, idx); + if value.exists() { + return Some(value); + } + + return None; + } + + pub fn try_get_unsafe(ptr: *mut (SimpleAst, AstData), idx: AstIdx) -> Option { + let value = Self::get_unsafe(ptr, idx); + if value.exists() { + return Some(value); + } + + return None; + } +} + +const ARGS_REGISTER: Register = Register::RCX; +const LOCALS_REGISTER: Register = Register::RBP; +const SCRATCH1: Register = Register::RSI; +const SCRATCH2: Register = Register::RDI; + +static VOLATILE_REGS: &'static [Register] = &[ + Register::RAX, + Register::RCX, + Register::RDX, + Register::R8, + Register::R9, + Register::R10, + Register::R11, +]; +static NONVOLATILE_REGS: &'static [Register] = &[ + Register::RBP, + Register::RBX, + Register::RDI, + Register::RSI, + Register::R12, + Register::R13, + Register::R14, + Register::R15, +]; + +struct Amd64OptimizingJit { + // Available registers for allocation. + free_registers: Vec, + // Post order traversal of the DAG. + dfs: Vec, + // Number of allocated stack slots + slot_count: u16, + // Stack of in-use locations. + stack: Vec, + _marker: PhantomData, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +struct StTuple { + owner: AstIdx, + value: AstIdx, +} + +impl StTuple { + pub fn new(owner: AstIdx, value: AstIdx) -> Self { + return StTuple { owner, value }; + } +} + +impl Amd64OptimizingJit { + fn new() -> Self { + return Amd64OptimizingJit { + free_registers: vec![ + Register::RAX, + Register::RDX, + Register::RBX, + Register::R8, + Register::R9, + Register::R10, + Register::R11, + Register::R12, + Register::R13, + Register::R14, + Register::R15, + ], + dfs: Vec::with_capacity(64), + slot_count: 0, + stack: Vec::with_capacity(16), + _marker: PhantomData, + }; + } + + #[inline(never)] + fn compile( + &mut self, + ctx: &mut Context, + assembler: &mut T, + idx: AstIdx, + variables: &Vec, + page_ptr: *mut u8, + use_iced_backend: bool, + ) { + // Collect necessary information about nodes for JITing (dfs order, how many users a node has). + Self::collect_info(ctx, idx, &mut self.dfs); + + // Store each variables argument index + for i in 0..variables.len() { + let var_idx = variables[i]; + let mut info = AuxInfoStorage::::get(ctx, var_idx); + info.var_idx = i as u16; + AuxInfoStorage::::set(ctx, var_idx, info); + } + + // Compile the instructions to x86. + self.lower_to_x86(ctx, assembler); + + // Clear each node's mutable data. + for id in self.dfs.iter() { + let mut info = AuxInfoStorage::::get(ctx, *id); + AuxInfoStorage::::set(ctx, *id, NodeInfo::from(0)); + } + + // If using the fast assembler backend, we've already emitted x86. + // However the stack pointer adjustment needs to fixed up, because it wasn't known during prologue emission. + if !use_iced_backend { + Self::fixup_frame_ptr(page_ptr, self.slot_count.into()); + return; + } + + // Otherwise adjust the rsp in iced. 
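+ // (fixup_iced_frame_ptr below replaces the placeholder `sub rsp, imm` emitted by the prologue, at instruction index 8, with the real slot_count * 8.)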
+ let mut instructions = assembler.get_instructions(); + Self::fixup_iced_frame_ptr(&mut instructions, self.slot_count.into()); + + // Write the instructions to memory. + // ICED internally emits a list of assembled instructions rather than raw x86 bytes + // so this must be done after the fact. + Self::write_instructions(page_ptr, &instructions); + } + + fn collect_info(ctx: &mut Context, idx: AstIdx, dfs: &mut Vec) { + let existing = AuxInfoStorage::::try_get(ctx, idx); + if existing.is_some() { + dfs.push(idx); + return; + } + + let node = ctx.arena.get_node(idx).clone(); + match node { + SimpleAst::Add { a, b } + | SimpleAst::Mul { a, b } + | SimpleAst::Pow { a, b } + | SimpleAst::And { a, b } + | SimpleAst::Or { a, b } + | SimpleAst::Xor { a, b } + | SimpleAst::Lshr { a, b } => { + Self::collect_info(ctx, a, dfs); + Self::collect_info(ctx, b, dfs); + + Self::inc_users(ctx, a); + Self::inc_users(ctx, b); + } + SimpleAst::Neg { a } | SimpleAst::Zext { a, .. } | SimpleAst::Trunc { a, .. } => { + Self::collect_info(ctx, a, dfs); + Self::inc_users(ctx, a); + } + SimpleAst::Constant { .. } | SimpleAst::Symbol { .. } => (), + } + + dfs.push(idx); + AuxInfoStorage::::set(ctx, idx, NodeInfo::new(0)); + } + + fn inc_users(ctx: &mut Context, idx: AstIdx) { + let mut info = AuxInfoStorage::::get(ctx, idx); + info.num_uses = info.num_uses.add(1); + AuxInfoStorage::::set(ctx, idx, info); + } + + fn inc_users_unsafe(ptr: *mut (SimpleAst, AstData), idx: AstIdx) { + let mut info = AuxInfoStorage::::get_unsafe(ptr, idx); + info.num_uses = info.num_uses.add(1); + AuxInfoStorage::::set_unsafe(ptr, idx, info); + } + + #[inline(never)] + fn lower_to_x86(&mut self, ctx: &mut Context, assembler: &mut T) { + // rcx reserved for local variables ptr (or all vars in the case of a semi-linear result vector) + // RSI, RDI reserved for temporary use + + // Emit the prologue. Initially we reserve space for u32::MAX slots, which we will adjust later. + Self::emit_prologue(assembler, u32::MAX); + + for i in 0..self.dfs.len() { + let idx = unsafe { *self.dfs.get_unchecked(i) }; + let node_info = AuxInfoStorage::::get(ctx, idx); + if node_info.num_uses > 1 && node_info.slot_idx != u16::MAX { + self.load_slot_value(assembler, node_info.slot_idx as u32); + continue; + } + + let width = ctx.arena.get_width(idx) as u32; + let node = ctx.arena.get_node(idx).clone(); + match node { + SimpleAst::Add { a, b } + | SimpleAst::Mul { a, b } + | SimpleAst::Pow { a, b } + | SimpleAst::And { a, b } + | SimpleAst::Or { a, b } + | SimpleAst::Xor { a, b } + | SimpleAst::Lshr { a, b } => { + self.lower_binop(ctx, assembler, idx, node, width, node_info) + } + SimpleAst::Constant { c, width } => self.lower_constant(assembler, c), + SimpleAst::Symbol { .. } => { + self.lower_variable(assembler, node_info.var_idx.into(), width) + } + SimpleAst::Neg { .. } | SimpleAst::Zext { .. } => self.lower_unary_op( + ctx, + assembler, + idx, + width, + node_info, + matches!(node, SimpleAst::Neg { .. 
}), + ), + SimpleAst::Trunc { a, to } => { + let w = ctx.get_width(a); + self.lower_zext(ctx, assembler, idx, w.into(), node_info) + } + } + } + + if self.stack.len() != 1 { + panic!("Unbalanced stack after lowering!"); + } + + let result = self.stack.pop().unwrap(); + if result.is_register() { + assembler.mov_reg_reg(Register::RAX, result.register); + } else { + assembler.pop_reg(Register::RAX); + } + + // Reduce the result modulo 2**w + let w = ctx.get_width(*self.dfs.last().unwrap()); + assembler.movabs_reg_imm64(SCRATCH1, get_modulo_mask(w)); + assembler.and_reg_reg(Register::RAX, SCRATCH1); + + Self::emit_epilogue(assembler, self.slot_count as u32); + } + + fn load_slot_value(&mut self, assembler: &mut T, slot_idx: u32) { + if !self.free_registers.is_empty() { + let t = self.free_registers.pop().unwrap(); + assembler.mov_reg_mem64(t, LOCALS_REGISTER, (slot_idx * 8) as i32); + self.stack.push(Location::reg(t)); + return; + } + + assembler.push_mem64(LOCALS_REGISTER, 8 * (slot_idx as i32)); + self.stack.push(Location::stack()); + } + + fn lower_binop( + &mut self, + ctx: &mut Context, + assembler: &mut T, + idx: AstIdx, + node: SimpleAst, + width: u32, + node_info: NodeInfo, + ) { + let rhs_loc = self.stack.pop().unwrap(); + + // If the rhs is stored in a register, we use it. + let mut rhs_dest = SCRATCH1; + if rhs_loc.is_register() { + rhs_dest = rhs_loc.register; + } + // If stored on the stack, pop into scratch register + else { + assembler.pop_reg(rhs_dest); + } + + // Regardless we have the rhs in a register now. + let lhs_loc = self.stack.pop().unwrap(); + let mut lhs_dest = SCRATCH2; + if lhs_loc.is_register() { + lhs_dest = lhs_loc.register; + } else { + assembler.pop_reg(lhs_dest); + } + + match node { + SimpleAst::Add { a, b } => assembler.add_reg_reg(lhs_dest, rhs_dest), + SimpleAst::Mul { a, b } => assembler.imul_reg_reg(lhs_dest, rhs_dest), + SimpleAst::And { a, b } => assembler.and_reg_reg(lhs_dest, rhs_dest), + SimpleAst::Or { a, b } => assembler.or_reg_reg(lhs_dest, rhs_dest), + SimpleAst::Xor { a, b } => assembler.xor_reg_reg(lhs_dest, rhs_dest), + SimpleAst::Lshr { a, b } => { + if width % 8 != 0 { + panic!("Cannot jit lshr with non power of 2 width!"); + } + + // Reduce shift count modulo the bit width of the operation + // TODO: (a) Handle non power of two bit widths, + // (b) shift beyond bounds should yield zero + assembler.and_reg_imm32(rhs_dest, width - 1); + + assembler.push_reg(Register::RCX); + assembler.mov_reg_reg(Register::RCX, rhs_dest); + assembler.shr_reg_cl(lhs_dest); + assembler.pop_reg(Register::RCX); + } + SimpleAst::Pow { a, b } => { + for r in VOLATILE_REGS.iter() { + assembler.push_reg(*r); + } + + assembler.mov_reg_reg(Register::RCX, lhs_dest); + assembler.mov_reg_reg(Register::RDX, rhs_dest); + + // TODO: Inline 'pow' function + assembler.movabs_reg_imm64(Register::R11, Pow as *const () as u64); + assembler.sub_reg_imm32(Register::RSP, 32); + assembler.call_reg(Register::R11); + assembler.add_reg_imm32(Register::RSP, 32); + assembler.mov_reg_reg(SCRATCH1, Register::RAX); + + // Restore volatile registers + for ® in VOLATILE_REGS.iter().rev() { + assembler.pop_reg(reg); + } + + assembler.mov_reg_reg(lhs_dest, SCRATCH1); + } + _ => unreachable!("Node is not a binary operator"), + } + + Self::reduce_register_modulo(assembler, width, lhs_dest, SCRATCH1); + + if rhs_loc.is_register() { + self.free_registers.push(rhs_loc.register); + } + + // If there are multiple users of this value, throw it in a stack slot. 
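+ // (The value is spilled to [rbp + 8*slot_count]; assign_value_slot records that slot index on the node and bumps slot_count, and load_slot_value later reloads from the same offset.)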
+ let multiple_users = node_info.num_uses > 1; + if multiple_users { + assembler.mov_mem64_reg(LOCALS_REGISTER, 8 * (self.slot_count as i32), lhs_dest); + self.assign_value_slot(ctx, idx, node_info); + } + + // If the lhs is already in a register, don't move it! + if lhs_dest != SCRATCH2 { + self.stack.push(Location::reg(lhs_dest)); + return; + } + + // Try to allocate a reg for this value + if self.free_registers.len() > 0 { + let dest = self.free_registers.pop().unwrap(); + assembler.mov_reg_reg(dest, lhs_dest); + self.stack.push(Location::reg(dest)); + } + // Otherwise this goes on the stack + else { + assembler.push_reg(lhs_dest); + self.stack.push(Location::stack()); + } + + if lhs_loc.is_register() { + self.free_registers.push(lhs_loc.register); + } + } + + fn lower_constant(&mut self, assembler: &mut T, c: u64) { + if !self.free_registers.is_empty() { + let dest = self.free_registers.pop().unwrap(); + assembler.movabs_reg_imm64(dest, c); + self.stack.push(Location::reg(dest)); + return; + } + + assembler.movabs_reg_imm64(SCRATCH1, c); + assembler.push_reg(SCRATCH1); + self.stack.push(Location::stack()); + } + + fn lower_variable(&mut self, assembler: &mut T, var_arr_idx: i32, width: u32) { + if !self.free_registers.is_empty() { + let dest = self.free_registers.pop().unwrap(); + assembler.mov_reg_mem64(dest, ARGS_REGISTER, var_arr_idx * 8); + Self::reduce_register_modulo(assembler, width, dest, SCRATCH1); + self.stack.push(Location::reg(dest)); + return; + } + + assembler.push_mem64(ARGS_REGISTER, var_arr_idx * 8); + self.stack.push(Location::stack()); + Self::reduce_location_modulo(assembler, Location::stack(), width); + } + + fn lower_unary_op( + &mut self, + ctx: &mut Context, + assembler: &mut T, + idx: AstIdx, + width: u32, + node_info: NodeInfo, + is_neg: bool, + ) { + let curr = self.stack.pop().unwrap(); + let mut dest_reg = SCRATCH1; + if curr.is_register() { + dest_reg = curr.register; + } else { + assembler.pop_reg(dest_reg); + } + + if is_neg { + assembler.not_reg(dest_reg); + Self::reduce_register_modulo(assembler, width, dest_reg, SCRATCH2); + } else { + assembler.movabs_reg_imm64(SCRATCH2, get_modulo_mask(width as u8)); + assembler.and_reg_reg(dest_reg, SCRATCH2); + } + + // If there are multiple users, store the value in a slot. + let multiple_users = node_info.num_uses > 1; + if multiple_users { + assembler.mov_mem64_reg(LOCALS_REGISTER, 8 * (self.slot_count as i32), dest_reg); + self.assign_value_slot(ctx, idx, node_info); + } + + if dest_reg != SCRATCH1 { + self.stack.push(Location::reg(dest_reg)); + return; + } + + if !self.free_registers.is_empty() { + let dest = self.free_registers.pop().unwrap(); + assembler.mov_reg_reg(dest, dest_reg); + self.stack.push(Location::reg(dest)); + return; + } + + // Otherwise this goes on the stack + assembler.push_reg(dest_reg); + self.stack.push(Location::stack()); + } + + fn lower_zext( + &mut self, + ctx: &mut Context, + assembler: &mut T, + idx: AstIdx, + from_width: u32, + node_info: NodeInfo, + ) { + // If we only have one user, this is a no-op. The result we care about is already on the location stack, + // and the zero-extension is implicit. + let peek = self.stack.pop().unwrap(); + self.stack.push(peek); + + // Because we are zero extending, we need to reduce the value modulo 2**w + // In other places we can get away with omitting this step. 
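+ // (reduce_location_modulo is a no-op for 64-bit values; otherwise it ANDs the register, or the qword at [rsp], with the from_width mask.)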
+ Self::reduce_location_modulo(assembler, peek, from_width); + + if node_info.num_uses <= 1 { + return; + } + + if peek.is_register() { + assembler.mov_mem64_reg(LOCALS_REGISTER, 8 * (self.slot_count as i32), peek.register); + } else { + assembler.mov_reg_mem64(SCRATCH1, Register::RSP, 0); + assembler.mov_mem64_reg(LOCALS_REGISTER, 8 * (self.slot_count as i32), SCRATCH1); + } + + self.assign_value_slot(ctx, idx, node_info); + } + + fn reduce_register_modulo( + assembler: &mut T, + width: u32, + dst_reg: Register, + free_reg: Register, + ) { + debug_assert!(dst_reg != free_reg); + if width == 64 { + return; + } + + let mask = get_modulo_mask(width as u8); + assembler.movabs_reg_imm64(free_reg, mask); + assembler.and_reg_reg(dst_reg, free_reg); + } + + fn reduce_location_modulo(assembler: &mut T, loc: Location, width: u32) { + if width == 64 { + return; + } + + let mask = get_modulo_mask(width as u8); + assembler.movabs_reg_imm64(SCRATCH1, mask); + if loc.is_register() { + assembler.and_reg_reg(loc.register, SCRATCH1); + } else { + assembler.and_mem64_reg(Register::RSP, 0, SCRATCH1); + } + } + + fn assign_value_slot(&mut self, ctx: &mut Context, idx: AstIdx, mut node_info: NodeInfo) { + node_info.slot_idx = self.slot_count; + AuxInfoStorage::::set(ctx, idx, node_info); + self.slot_count = self.slot_count.checked_add(1).unwrap(); + } + + fn emit_prologue(assembler: &mut T, num_stack_slots: u32) { + // Push all nonvolatile registers + for reg in NONVOLATILE_REGS.iter() { + assembler.push_reg(*reg); + } + + // Allocate stack space for local variables + assembler.sub_reg_imm32(Register::RSP, (num_stack_slots * 8)); + // Point rbp to the local var array + assembler.mov_reg_reg(LOCALS_REGISTER, Register::RSP); + // mov rbp, rsp + assembler.mov_reg_reg(Register::RBP, Register::RSP); + } + + fn emit_epilogue(assembler: &mut T, num_stack_slots: u32) { + // Reset rsp + assembler.add_reg_imm32(Register::RSP, 8 * num_stack_slots); + // Restore nonvolatile registers (including rbp) + for i in NONVOLATILE_REGS.iter().rev() { + assembler.pop_reg(*i); + } + + assembler.ret(); + } + + fn fixup_frame_ptr(ptr: *mut u8, slot_count: u32) { + unsafe { + let sub_rsp_start = ptr.add(12); + let encoding = (*sub_rsp_start.cast::()) & 0xFF00FFFFFFFFFFFF; + if encoding != 0x4800fffff8ec8148 { + panic!("Rsp fixup position changed!"); + } + + let conv = slot_count * 8; + *(sub_rsp_start.add(3).cast::()) = conv; + } + } + + fn fixup_iced_frame_ptr(instructions: &mut Vec, slot_count: u32) { + let sub = instructions[8]; + if sub.code() != Code::Sub_rm64_imm8 && sub.code() != Code::Sub_rm64_imm32 { + panic!("Rsp fixup position changed!"); + } + + instructions[8] = + Instruction::with2(Code::Sub_rm64_imm32, Register::RSP, (slot_count * 8) as i32) + .unwrap(); + } + + fn write_instructions(ptr: *mut u8, instructions: &Vec) { + let mut assembler = CodeAssembler::new(64).unwrap(); + for inst in instructions.iter() { + assembler.add_instruction(*inst); + } + + let bytes = assembler.assemble(ptr as u64).unwrap(); + unsafe { + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len()); + } + } +} diff --git a/MSiMBA b/MSiMBA index f113c3b..7a0fa6d 160000 --- a/MSiMBA +++ b/MSiMBA @@ -1 +1 @@ -Subproject commit f113c3be7305715359e076f008a4aae2b11e6ada +Subproject commit 7a0fa6d24f2171158c726a9bd89554c530738a35 diff --git a/Mba.Simplifier/Bindings/AstCtx.cs b/Mba.Simplifier/Bindings/AstCtx.cs index 3958d40..7f9bcb9 100644 --- a/Mba.Simplifier/Bindings/AstCtx.cs +++ b/Mba.Simplifier/Bindings/AstCtx.cs @@ -49,6 +49,7 @@ public 
unsafe AstCtx() // Constructors public unsafe AstIdx Add(AstIdx a, AstIdx b) => Api.ContextAdd(this, a, b); + public unsafe AstIdx Sub(AstIdx a, AstIdx b) => Add(a, Mul(Constant(ulong.MaxValue, GetWidth(b)), b)); public unsafe AstIdx Mul(AstIdx a, AstIdx b) => Api.ContextMul(this, a, b); public unsafe AstIdx Pow(AstIdx a, AstIdx b) => Api.ContextPow(this, a, b); public unsafe AstIdx And(AstIdx a, AstIdx b) => Api.ContextAnd(this, a, b); @@ -161,6 +162,8 @@ public AstIdx Xor(IEnumerable nodes) public unsafe bool GetHasPoly(AstIdx id) => Api.ContextGetHasPoly(this, id); public unsafe AstClassification GetClass(AstIdx id) => Api.ContextGetClass(this, id); public unsafe KnownBits GetKnownBits(AstIdx id) => Api.ContextGetKnownBits(this, id); + public unsafe ulong GetImutData(AstIdx id) => Api.ContextGetImutData(this, id); + public unsafe void SetImutData(AstIdx id, ulong imut) => Api.ContextSetImutData(this, id, imut); public unsafe AstIdx GetOp0(AstIdx id) => Api.ContextGetOp0(this, id); public unsafe AstIdx GetOp1(AstIdx id) { @@ -273,11 +276,19 @@ public unsafe void JitEvaluate(AstIdx id, ulong mask, bool isMultibit, uint bitW } } - public unsafe nint Compile(AstIdx id, ulong mask, AstIdx[] variables, nint rwxPagePtr) + public unsafe nint CompileLegacy(AstIdx id, ulong mask, AstIdx[] variables, nint rwxPagePtr) { fixed (AstIdx* arrPtr = &variables[0]) { - return (nint)Api.ContextCompile(this, id, mask, arrPtr, (ulong)variables.Length, (ulong*)rwxPagePtr); + return (nint)Api.ContextCompileLegacy(this, id, mask, arrPtr, (ulong)variables.Length, (ulong*)rwxPagePtr); + } + } + + public unsafe void Compile(AstIdx id, ulong mask, AstIdx[] variables, nint rwxPagePtr) + { + fixed (AstIdx* arrPtr = variables) + { + Api.ContextCompile(this, id, mask, arrPtr, (ulong)variables.Length, (ulong*)rwxPagePtr); } } @@ -352,7 +363,7 @@ public static class Api public unsafe static extern uint ContextGetCost(OpaqueAstCtx* ctx, AstIdx id); [DllImport("eq_sat")] - [SuppressGCTransition] + [SuppressGCTransition] [return: MarshalAs(UnmanagedType.U1)] public unsafe static extern bool ContextGetHasPoly(OpaqueAstCtx* ctx, AstIdx id); @@ -364,6 +375,14 @@ public static class Api [SuppressGCTransition] public unsafe static extern KnownBits ContextGetKnownBits(OpaqueAstCtx* ctx, AstIdx id); + [DllImport("eq_sat")] + [SuppressGCTransition] + public unsafe static extern ulong ContextGetImutData(OpaqueAstCtx* ctx, AstIdx id); + + [DllImport("eq_sat")] + [SuppressGCTransition] + public unsafe static extern void ContextSetImutData(OpaqueAstCtx* ctx, AstIdx id, ulong data); + [DllImport("eq_sat")] [SuppressGCTransition] public unsafe static extern AstIdx ContextGetOp0(OpaqueAstCtx* ctx, AstIdx id); @@ -404,7 +423,10 @@ public static class Api public unsafe static extern ulong* ContextJit(OpaqueAstCtx* ctx, AstIdx id, ulong mask, uint isMultiBit, uint bitWidth, AstIdx* variableArray, ulong varCount, ulong numCombinations, ulong* rwxJitPage, ulong* outputArray); [DllImport("eq_sat")] - public unsafe static extern ulong* ContextCompile(OpaqueAstCtx* ctx, AstIdx id, ulong mask, AstIdx* variableArray, ulong varCount, ulong* rwxJitPage); + public unsafe static extern ulong* ContextCompileLegacy(OpaqueAstCtx* ctx, AstIdx id, ulong mask, AstIdx* variableArray, ulong varCount, ulong* rwxJitPage); + + [DllImport("eq_sat")] + public unsafe static extern void ContextCompile(OpaqueAstCtx* ctx, AstIdx id, ulong mask, AstIdx* variableArray, ulong varCount, ulong* rwxJitPage); [DllImport("eq_sat")] public unsafe static extern ulong* 
ContextExecute(uint isMultiBit, uint bitWidth, ulong varCount, ulong numCombinations, ulong* rwxJitPage, ulong* outputArray, uint isOneBitVars); diff --git a/Mba.Simplifier/Bindings/AstIdx.cs b/Mba.Simplifier/Bindings/AstIdx.cs index 1a5849f..93d1720 100644 --- a/Mba.Simplifier/Bindings/AstIdx.cs +++ b/Mba.Simplifier/Bindings/AstIdx.cs @@ -22,6 +22,11 @@ public override string ToString() return ctx.GetAstString(Idx); } + public override int GetHashCode() + { + return Idx.GetHashCode(); + } + public unsafe static implicit operator uint(AstIdx reg) => reg.Idx; public unsafe static implicit operator AstIdx(uint reg) => new AstIdx(reg); diff --git a/Mba.Simplifier/Fuzzing/MSiMBAFuzzer.cs b/Mba.Simplifier/Fuzzing/MSiMBAFuzzer.cs index 49c2467..936d477 100644 --- a/Mba.Simplifier/Fuzzing/MSiMBAFuzzer.cs +++ b/Mba.Simplifier/Fuzzing/MSiMBAFuzzer.cs @@ -36,6 +36,7 @@ public static void Run() // Skip if the expression simplifies to a constant var variables = ctx.CollectVariables(fCase); + if (variables.Count == 0) continue; @@ -47,13 +48,6 @@ public static void Run() var vec2 = LinearSimplifier.JitResultVectorOld(ctx, w, (ulong)ModuloReducer.GetMask(w), variables, result, multiBit, numCombinations); if(!vec1.SequenceEqual(vec2)) throw new InvalidOperationException("Mismatch"); - - // Fuzz the other JIT - var jit = new Amd64OptimizingJit(ctx); - jit.Compile(fCase, variables, sharpPtr, true); - var vec3 = LinearSimplifier.Execute(ctx, w, ulong.MaxValue, variables, multiBit, numCombinations, sharpPtr, false); - if (!vec3.SequenceEqual(vec2)) - throw new InvalidOperationException("Mismatch"); } } diff --git a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs index 46f90df..cafc414 100644 --- a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs +++ b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs @@ -33,6 +33,15 @@ public unsafe Amd64AssemblerDifferentialTester(byte* buffer) icedAssembler = new IcedAmd64Assembler(new Assembler(64)); } + public static void Test() + { + var buffer = new byte[64 * 4096]; + fixed(byte* p = buffer) + { + new Amd64AssemblerDifferentialTester(p).Run(); + } + } + public void Run() { for (int i = 0; i < registers.Length; i++) @@ -56,13 +65,15 @@ private void DiffRegInsts(Register reg1) for (int _ = 0; _ < 100; _++) { var c = (ulong)rand.NextInt64(); + c |= rand.Next(0, 2) == 0 ? 
0 : (1ul << 63); + Diff(nameof(IAmd64Assembler.MovabsRegImm64), reg1, c); - Diff(nameof(IAmd64Assembler.AddRegImm32), reg1, c); - Diff(nameof(IAmd64Assembler.SubRegImm32), reg1, c); - Diff(nameof(IAmd64Assembler.AndRegImm32), reg1, c); - Diff(nameof(IAmd64Assembler.ShrRegImm8), reg1, c); + Diff(nameof(IAmd64Assembler.AddRegImm32), reg1, (uint)c); + Diff(nameof(IAmd64Assembler.SubRegImm32), reg1, (uint)c); + Diff(nameof(IAmd64Assembler.AndRegImm32), reg1, (uint)c); + Diff(nameof(IAmd64Assembler.ShrRegImm8), reg1, (byte)c); if (reg1 != rsp) - Diff(nameof(IAmd64Assembler.PushMem64), reg1, c); + Diff(nameof(IAmd64Assembler.PushMem64), reg1, (int)c); } } @@ -90,6 +101,8 @@ private void DiffRegRegInsts(Register reg1, Register reg2) Diff(nameof(IAmd64Assembler.MovMem64Reg), reg2, (int)c, reg1); Diff(nameof(IAmd64Assembler.MovRegMem64), reg1, reg2, (int)c); Diff(nameof(IAmd64Assembler.MovRegMem64), reg2, reg1, (int)c); + Diff(nameof(IAmd64Assembler.AndMem64Reg), reg1, (int)c, reg2); + Diff(nameof(IAmd64Assembler.AndMem64Reg), reg2, (int)c, reg1); } } diff --git a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs index fe79ed6..42f19f8 100644 --- a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs +++ b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs @@ -12,6 +12,7 @@ using System.Diagnostics; using System.Globalization; using System.Linq; +using System.Net.Mail; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; @@ -41,22 +42,109 @@ public static Location Stack() public bool IsRegister => Register != Register.None; } + [StructLayout(LayoutKind.Explicit)] public struct NodeInfo { - public uint numUses; + [FieldOffset(0)] + public ushort numUses; + [FieldOffset(2)] + public ushort varIdx; + [FieldOffset(4)] + public ushort slotIdx = ushort.MaxValue; + [FieldOffset(6)] + public ushort exists = 1; + + public NodeInfo(ushort numInstances) + { + this.numUses = numInstances; + } - // Allocate stack slot for the node if numUses > 1 - public uint slotIdx = uint.MaxValue; + public unsafe ulong ToUlong() + { + return Unsafe.As(ref this); + } - public NodeInfo(uint numInstances) + public unsafe static NodeInfo FromUlong(ulong x) { - this.numUses = numInstances; + return *((NodeInfo*)&x); } public override string ToString() { return $"numInstances:{numUses}, slotIdx: {slotIdx}"; } + + public bool Equivalent(NodeInfo other) + { + return this.ToUlong() == other.ToUlong(); + } + } + + public interface IInfoStorage + { + public bool Contains(AstIdx idx); + + public NodeInfo Get(AstIdx idx); + + public void Set(AstIdx idx, NodeInfo info); + + public bool TryGet(AstIdx idx, out NodeInfo info); + } + + public class MapInfoStorage : IInfoStorage + { + public readonly Dictionary map = new(); + + public bool Contains(AstIdx idx) + { + return map.ContainsKey(idx); + } + + public NodeInfo Get(AstIdx idx) + { + return map[idx]; + } + + public void Set(AstIdx idx, NodeInfo info) + { + map[idx] = info; + } + + public bool TryGet(AstIdx idx, out NodeInfo info) + { + return map.TryGetValue(idx, out info); + } + } + + public class AuxInfoStorage : IInfoStorage + { + private readonly AstCtx ctx; + + public AuxInfoStorage(AstCtx ctx) + { + this.ctx = ctx; + } + + public bool Contains(AstIdx idx) + { + return Get(idx).exists != 0; + } + + public NodeInfo Get(AstIdx idx) + { + return NodeInfo.FromUlong(ctx.GetImutData(idx)); + } + + public void Set(AstIdx idx, NodeInfo info) + { + ctx.SetImutData(idx, info.ToUlong()); + } + + public bool TryGet(AstIdx idx, out NodeInfo 
info) + { + info = Get(idx); + return info.exists != 0; + } } // This class implements a JIT compiler to x86 with register allocation and node reuse. @@ -71,9 +159,9 @@ public class Amd64OptimizingJit private readonly List dfs = new(16); - private readonly Dictionary seen = new(); + IInfoStorage seen; - private uint slotCount = 0; + private ushort slotCount = 0; Stack stack = new(16); @@ -92,6 +180,7 @@ public class Amd64OptimizingJit public Amd64OptimizingJit(AstCtx ctx) { this.ctx = ctx; + seen = new AuxInfoStorage(ctx); } public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, bool useIcedBackend = false) @@ -99,11 +188,27 @@ public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, boo Assembler icedAssembler = useIcedBackend ? new Assembler(64) : null; assembler = useIcedBackend ? new IcedAmd64Assembler(icedAssembler) : new FastAmd64Assembler((byte*)pagePtr); + // Collect information about the nodes necessary for JITing (dfs order, how many users a value has) - CollectInfo(idx); + CollectInfo(ctx, idx, dfs, seen); + + // Store each variables argument index + for (int i = 0; i < variables.Count; i++) + { + var vIdx = variables[i]; + var data = seen.Get(vIdx); + data.varIdx = (byte)i; + seen.Set(vIdx, data); + } // Compile the instructions to x86. - LowerToX86(variables); + LowerToX86(); + + // Clear each node's mutable data + foreach (var id in dfs) + { + seen.Set(id, NodeInfo.FromUlong(0)); + } // If using the fast assembler backend, we've already emitted x86. // However the stack pointer adjustment needs to fixed up, because it wasn't known during prologue emission. @@ -123,11 +228,17 @@ public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, boo WriteInstructions(pagePtr, instructions); } - private void CollectInfo(AstIdx idx) + static ushort Inc(ushort cl) { - if (seen.TryGetValue(idx, out var existing)) + cl += 1; + return cl == 0 ? ushort.MaxValue : cl; + } + + + private static void CollectInfo(AstCtx ctx, AstIdx idx, List dfs, IInfoStorage seen) + { + if (seen.TryGet(idx, out var existing)) { - seen[idx] = new NodeInfo(existing.numUses); dfs.Add(idx); return; } @@ -143,31 +254,32 @@ private void CollectInfo(AstIdx idx) case AstOp.Xor: case AstOp.Lshr: var op0 = ctx.GetOp0(idx); - CollectInfo(op0); + CollectInfo(ctx,op0, dfs, seen); var op1 = ctx.GetOp1(idx); - CollectInfo(op1); + CollectInfo(ctx, op1, dfs, seen); - seen[op0] = new NodeInfo(seen[op0].numUses + 1); - seen[op1] = new NodeInfo(seen[op1].numUses + 1); + seen.Set(op0, new NodeInfo(Inc(seen.Get(op0).numUses))); + seen.Set(op1, new NodeInfo(Inc(seen.Get(op1).numUses))); break; case AstOp.Neg: case AstOp.Zext: case AstOp.Trunc: var single = ctx.GetOp0(idx); - CollectInfo(single); - seen[single] = new NodeInfo(seen[single].numUses + 1); + CollectInfo(ctx, single, dfs, seen); + seen.Set(single, new NodeInfo(Inc(seen.Get(single).numUses))); break; + case AstOp.Constant: default: break; } dfs.Add(idx); - seen[idx] = new NodeInfo(0); + seen.Set(idx, new NodeInfo(0)); } // Compile the provided DAG to x86 // TODO: Optionally disabled hashmap lookup/use tracking stuff for faster codegen. 
Pretend DAG is an AST if the duplicated cost is not too high - private unsafe void LowerToX86(List vars) + private unsafe void LowerToX86() { // rcx reserved for local variables ptr (or all vars in the case of a semi-linear result vector) // RSI, RDI reserved for temporary use @@ -187,9 +299,9 @@ private unsafe void LowerToX86(List vars) for(int i = 0; i < dfs.Count; i++) { var idx = dfs[i]; - var nodeInfo = seen[idx]; + var nodeInfo = seen.Get(idx); // If we've seen this value, load it's value from a local variable slot - if (nodeInfo.numUses > 1 && nodeInfo.slotIdx != uint.MaxValue) + if (nodeInfo.numUses > 1 && nodeInfo.slotIdx != ushort.MaxValue) { LoadSlotValue(nodeInfo.slotIdx); continue; @@ -215,7 +327,7 @@ private unsafe void LowerToX86(List vars) break; case AstOp.Symbol: - LowerVariable(idx, width, vars); + LowerVariable(nodeInfo.varIdx, width); break; case AstOp.Neg: @@ -302,12 +414,12 @@ private unsafe void LowerBinop(AstIdx idx, AstOp opc, uint width, NodeInfo nodeI case AstOp.Lshr: // TODO: For logical shifts, we need to reduce the other side modulo some constant! // Actually maybe not, we should have already reduced modulo? - var w = (uint)ctx.GetWidth(idx); - if (w % 8 != 0) + if (width % 8 != 0) throw new InvalidOperationException($"Cannot jit shr of non power of 2 width"); // Reduce shift count modulo bit width of operation // TODO: Hand non power of two bit widths - assembler.AndRegImm32(rhsDest, w - 1); + // TODO: Shift beyond bounds should yield zero + assembler.AndRegImm32(rhsDest, width - 1); // Execute lshr assembler.PushReg(Register.RCX); @@ -338,7 +450,7 @@ private unsafe void LowerBinop(AstIdx idx, AstOp opc, uint width, NodeInfo nodeI break; default: - throw new InvalidOperationException($"{opc} is not a valid binop"); + throw new InvalidOperationException($"{opc} is not a valid binop "); } ReduceRegisterModulo(width, lhsDest, scratch1); @@ -350,7 +462,6 @@ private unsafe void LowerBinop(AstIdx idx, AstOp opc, uint width, NodeInfo nodeI bool multipleUsers = nodeInfo.numUses > 1; if (multipleUsers) { - // Otherwise there are multiple users. 
This needs to go on a stack slot assembler.MovMem64Reg(localsRegister, 8 * (int)slotCount, lhsDest); AssignValueSlot(idx, nodeInfo); } @@ -398,19 +509,18 @@ private void LowerConstant(AstIdx idx) stack.Push(Location.Stack()); } - private void LowerVariable(AstIdx idx, uint width, IReadOnlyList vars) + private void LowerVariable(int varArrIdx, uint width) { - uint offset = (uint)vars.IndexOf(idx); if (freeRegisters.Count != 0) { var dest = freeRegisters.Pop(); - assembler.MovRegMem64(dest, argsRegister, 8 * (int)offset); + assembler.MovRegMem64(dest, argsRegister, 8 * varArrIdx); ReduceRegisterModulo(width, dest, scratch1); stack.Push(Location.Reg(dest)); return; } - assembler.PushMem64(argsRegister, 8 * (int)offset); + assembler.PushMem64(argsRegister, 8 * varArrIdx); stack.Push(Location.Stack()); ReduceLocationModulo(stack.Peek(), width); } @@ -433,7 +543,7 @@ private void LowerUnaryOp(AstIdx idx, AstOp opc, uint width, NodeInfo nodeInfo) else { - assembler.MovabsRegImm64(scratch2, (ulong)ModuloReducer.GetMask(ctx.GetWidth(idx))); + assembler.MovabsRegImm64(scratch2, (ulong)ModuloReducer.GetMask(width)); assembler.AndRegReg(destReg, scratch2); } @@ -519,9 +629,12 @@ private void ReduceLocationModulo(Location loc, uint width) private void AssignValueSlot(AstIdx idx, NodeInfo nodeInfo) { nodeInfo.slotIdx = slotCount; - seen[idx] = nodeInfo; - // Bump slot count up - slotCount += 1; + seen.Set(idx, nodeInfo); + // Bump slot count up. Throw if we hit the max slot limit + checked + { + slotCount += 1; + } } private static void EmitPrologue(IAmd64Assembler assembler, Register localsRegister, uint numStackSlots) diff --git a/Mba.Simplifier/Jit/FastAmd64Assembler.cs b/Mba.Simplifier/Jit/FastAmd64Assembler.cs index bb6cd0e..ce5e695 100644 --- a/Mba.Simplifier/Jit/FastAmd64Assembler.cs +++ b/Mba.Simplifier/Jit/FastAmd64Assembler.cs @@ -6,17 +6,53 @@ using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics.Arm; +using System.Security.Cryptography; using System.Text; using System.Threading.Tasks; namespace Mba.Simplifier.Jit { + unsafe ref struct StackBuffer + { + public byte* Ptr; + + public uint Offset; + + public StackBuffer(byte* ptr) + { + this.Ptr = ptr; + } + + public void PushU8(byte value) + { + Ptr[Offset++] = value; + } + + public void PushI32(int value) + => PushU32((uint)value); + + public void PushU32(uint value) + { + *(uint*)&Ptr[Offset] = value; + Offset += 4; + } + + public void PushU64(ulong value) + { + *(ulong*)&Ptr[Offset] = value; + Offset += 8; + } + } + public unsafe class FastAmd64Assembler : IAmd64Assembler { private byte* start; private byte* ptr; + private int offset = 0; + public List Instructions => GetInstructions(); public FastAmd64Assembler(byte* ptr) @@ -25,62 +61,73 @@ public FastAmd64Assembler(byte* ptr) this.ptr = ptr; } + private unsafe void EmitBytes(params byte[] bytes) + { + fixed(byte* p = &bytes[0]) + { + Memcpy(ptr, p, (uint)bytes.Length); + } + + ptr += bytes.Length; + } + + private void EmitBuffer(StackBuffer buffer) + { + Memcpy(ptr, buffer.Ptr, buffer.Offset); + ptr += buffer.Offset; + } + + private void Memcpy(void* destination, void* source, uint byteCount) + { + Unsafe.CopyBlockUnaligned(destination, source, byteCount); + } + public void PushReg(Register reg) { if (reg >= Register.RAX && reg <= Register.RDI) { byte opcode = (byte)(0x50 + GetRegisterCode(reg)); - *ptr++ = opcode; + EmitBytes(opcode); + return; } - else if (reg >= Register.R8 && reg <= Register.R15) + if (reg >= Register.R8 
&& reg <= Register.R15) { - byte rex = 0x41; - *ptr++ = rex; - + byte rex = 0x41; byte opcode = (byte)(0x50 + (int)reg - (int)Register.R8); - *ptr++ = opcode; + EmitBytes(rex, opcode); + return; } - else - { - throw new ArgumentException("Invalid register for PUSH instruction."); - } + throw new ArgumentException("Invalid register for PUSH instruction."); + } - // push qword ptr [baseReg+offset] public void PushMem64(Register baseReg, int offset) { - byte rex = 0x48; - if (IsExtended(baseReg)) rex |= 0x01; + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + + if (IsExtended(baseReg)) + { + byte rex = 0x49; + arr.PushU8(rex); + } byte opcode = 0xFF; byte modrm = (byte)(0x80 | (0x06 << 3) | (GetRegisterCode(baseReg) & 0x07)); + arr.PushU8(opcode); + arr.PushU8(modrm); if (baseReg == Register.RSP || baseReg == Register.R12) { - if (IsExtended(baseReg)) - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; + arr.PushU8(sib); } - else - { - if (IsExtended(baseReg)) - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - } + arr.PushI32(offset); - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + EmitBuffer(arr); } public void PopReg(Register reg) @@ -88,22 +135,20 @@ public void PopReg(Register reg) if (reg >= Register.RAX && reg <= Register.RDI) { byte opcode = (byte)(0x58 + GetRegisterCode(reg)); - *ptr++ = opcode; + EmitBytes(opcode); + return; } - else if (reg >= Register.R8 && reg <= Register.R15) + if (reg >= Register.R8 && reg <= Register.R15) { byte rex = 0x41; - *ptr++ = rex; - byte opcode = (byte)(0x58 + GetRegisterCode(reg) - 8); - *ptr++ = opcode; + EmitBytes(rex, opcode); + return; } - else - { - throw new ArgumentException($"Cannot pop {reg}"); - } + throw new ArgumentException($"Cannot pop {reg}"); + } public void OpcodeRegReg(byte opcode, Register reg1, Register reg2) @@ -113,9 +158,7 @@ public void OpcodeRegReg(byte opcode, Register reg1, Register reg2) if (IsExtended(reg2)) rex |= 0x04; byte modrm = (byte)(0xC0 | ((GetRegisterCode(reg2) & 0x07) << 3) | (GetRegisterCode(reg1) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + EmitBytes(rex, opcode, modrm); } public void MovRegReg(Register reg1, Register reg2) @@ -123,6 +166,9 @@ public void MovRegReg(Register reg1, Register reg2) public void MovRegMem64(Register dstReg, Register baseReg, int offset) { + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + byte rex = 0x48; if (IsExtended(dstReg)) rex |= 0x04; if (IsExtended(baseReg)) rex |= 0x01; @@ -130,32 +176,27 @@ public void MovRegMem64(Register dstReg, Register baseReg, int offset) byte opcode = 0x8B; byte modrm = (byte)(0x80 | ((GetRegisterCode(dstReg) & 0x07) << 3) | (GetRegisterCode(baseReg) & 0x07)); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + if (baseReg == Register.RSP || baseReg == Register.R12) { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; + arr.PushU8(sib); } - else - { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - } + arr.PushI32(offset); - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + EmitBuffer(arr); } // mov qword ptr [baseReg + offset], srcReg public void MovMem64Reg(Register baseReg, int offset, Register srcReg) { + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + byte rex = 0x48; if (IsExtended(srcReg)) 
rex |= 0x04; if (IsExtended(baseReg)) rex |= 0x01; @@ -163,27 +204,18 @@ public void MovMem64Reg(Register baseReg, int offset, Register srcReg) byte opcode = 0x89; byte modrm = (byte)(0x80 | ((GetRegisterCode(srcReg) & 0x07) << 3) | (GetRegisterCode(baseReg) & 0x07)); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + if (baseReg == Register.RSP || baseReg == Register.R12) { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; + arr.PushU8(sib); } - else - { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - } - - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + arr.PushI32(offset); + EmitBuffer(arr); } public void MovabsRegImm64(Register reg1, ulong imm64) @@ -193,14 +225,14 @@ public void MovabsRegImm64(Register reg1, ulong imm64) var cond = (reg1 >= Register.RAX && reg1 <= Register.RDI); byte opcode = (byte)(0xB8 + (cond ? GetRegisterCode(reg1) : GetRegisterCode(reg1) - 8)); - *ptr++ = rex; - *ptr++ = opcode; - for (int i = 0; i < 8; i++) - { - *ptr++ = (byte)(imm64 & 0xFF); - imm64 >>= 8; - } + byte* p = stackalloc byte[10]; + var arr = new StackBuffer(ptr); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU64(imm64); + + EmitBuffer(arr); } public void AddRegReg(Register reg1, Register reg2) @@ -221,15 +253,15 @@ public void OpcRegImm(byte mask, Register reg1, uint imm32) byte opcode = 0x81; byte modrm = (byte)(0xC0 | (mask << 3) | (GetRegisterCode(reg1) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(imm32 & 0xFF); - imm32 >>= 8; - } + byte* p = stackalloc byte[7]; + var arr = new StackBuffer(ptr); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + arr.PushU32(imm32); + + EmitBuffer(arr); } public void ImulRegReg(Register reg1, Register reg2) @@ -241,10 +273,7 @@ public void ImulRegReg(Register reg1, Register reg2) byte opcode1 = 0x0F; byte opcode2 = 0xAF; byte modrm = (byte)(0xC0 | ((GetRegisterCode(reg1) & 0x07) << 3) | (GetRegisterCode(reg2) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode1; - *ptr++ = opcode2; - *ptr++ = modrm; + EmitBytes(rex, opcode1, opcode2, modrm); } public void AndRegReg(Register reg1, Register reg2) @@ -262,27 +291,22 @@ public void AndMem64Reg(Register baseReg, int offset, Register srcReg) byte opcode = 0x21; byte modrm = (byte)(0x80 | ((GetRegisterCode(srcReg) & 0x07) << 3) | (GetRegisterCode(baseReg) & 0x07)); + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + + if (baseReg == Register.RSP || baseReg == Register.R12) { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; - } - else - { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + arr.PushU8(sib); } - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + arr.PushI32(offset); + + EmitBuffer(arr); } public void OrRegReg(Register reg1, Register reg2) @@ -298,9 +322,7 @@ public void NotReg(Register reg1) byte opcode = 0xF7; byte modrm = (byte)(0xC0 | (0x02 << 3) | (GetRegisterCode(reg1) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + EmitBytes(rex, opcode, modrm); } public void ShlRegCl(Register reg) @@ -317,9 +339,7 @@ public void ShiftRegCl(bool shl, Register reg) byte opcode = 0xD3; var m1 = shl ? 
0x04 : 0x05; byte modrm = (byte)(0xC0 | (m1 << 3) | (GetRegisterCode(reg) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + EmitBytes(rex, opcode, modrm); } public void ShrRegImm8(Register reg, byte imm8) @@ -329,10 +349,7 @@ public void ShrRegImm8(Register reg, byte imm8) byte opcode = 0xC1; byte modrm = (byte)(0xC0 | (0x05 << 3) | (GetRegisterCode(reg) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - *ptr++ = imm8; + EmitBytes(rex, opcode, modrm, imm8); } public void CallReg(Register reg1) @@ -343,20 +360,20 @@ public void CallReg(Register reg1) byte opcode = 0xFF; byte modrm = (byte)(0xC0 | (0x02 << 3) | (GetRegisterCode(reg1) & 0x07)); - if (rex != 0x00) - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + if (rex != 0) + EmitBytes(rex, opcode, modrm); + else + EmitBytes(opcode, modrm); } public void Ret() - => *ptr++ = 0xC3; + => EmitBytes(0xC3); private bool IsExtended(Register reg) => reg >= Register.R8 && reg <= Register.R15; - private int GetRegisterCode(Register reg) - => (int)reg - (int)Register.RAX; + private uint GetRegisterCode(Register reg) + => (uint)reg - (uint)Register.RAX; public List GetInstructions() { diff --git a/Mba.Simplifier/Minimization/BooleanMinimizer.cs b/Mba.Simplifier/Minimization/BooleanMinimizer.cs index e88f3fa..a7e37c4 100644 --- a/Mba.Simplifier/Minimization/BooleanMinimizer.cs +++ b/Mba.Simplifier/Minimization/BooleanMinimizer.cs @@ -191,7 +191,7 @@ private static AstIdx MinimizeAnf(AstCtx ctx, IReadOnlyList variables, T } var r = ctx.MinimizeAnf(TableDatabase.Instance.db, truthTable, tempVars, MultibitSiMBA.JitPage.Value); - var backSubst = GeneralSimplifier.ApplyBackSubstitution(ctx, r, invSubstMapping); + var backSubst = GeneralSimplifier.BackSubstitute(ctx, r, invSubstMapping); return backSubst; } } diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index 97c9ec0..9bc2363 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -22,6 +22,8 @@ namespace Mba.Simplifier.Pipeline { public class GeneralSimplifier { + private const bool REDUCE_POLYS = true; + private readonly AstCtx ctx; // For any given node, we store the best possible ISLE result. @@ -42,7 +44,7 @@ public GeneralSimplifier(AstCtx ctx) public AstIdx SimplifyGeneral(AstIdx id, bool useIsle = true) { // Simplify the AST via efficient, recursive term rewriting(ISLE). - if (useIsle) + if(useIsle) id = SimplifyViaTermRewriting(id); // Simplify via recursive SiMBA. @@ -53,7 +55,7 @@ public AstIdx SimplifyGeneral(AstIdx id, bool useIsle = true) // Simplify the AST via efficient, recursive term rewriting(ISLE). private AstIdx SimplifyViaTermRewriting(AstIdx id) { - if (isleCache.TryGetValue(id, out var existingIdx)) + if(isleCache.TryGetValue(id, out var existingIdx)) return existingIdx; var initialId = id; @@ -67,7 +69,7 @@ private AstIdx SimplifyViaTermRewriting(AstIdx id) // TODO: Add to isle cache bool cacheIsle = true; - if (cacheIsle) + if(cacheIsle) { isleCache.TryAdd(initialId, id); isleCache.TryAdd(id, id); @@ -78,14 +80,14 @@ private AstIdx SimplifyViaTermRewriting(AstIdx id) // Simplify the AST via recursive SiMBA application. private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) { - if (simbaCache.TryGetValue(id, out var existing)) + if(simbaCache.TryGetValue(id, out var existing)) return existing; id = SimplifyViaTermRewriting(id); // TODO: We should probably apply ISLE before attempting any other steps. 
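
An aside on the NodeInfo / AuxInfoStorage rework in Amd64OptimizingJit.cs above: the per-node bookkeeping now fits an explicitly laid out 8 bytes so it can be stashed in the AST context's per-node immutable-data slot (SetImutData/GetImutData) instead of a Dictionary keyed by AstIdx. A minimal sketch of that packing trick, with field names chosen for illustration rather than copied from the diff:

    using System.Runtime.CompilerServices;
    using System.Runtime.InteropServices;

    [StructLayout(LayoutKind.Explicit, Size = 8)]
    struct PackedNodeInfo
    {
        [FieldOffset(0)] public ushort NumUses;
        [FieldOffset(2)] public ushort VarIdx;
        [FieldOffset(4)] public ushort SlotIdx;
        [FieldOffset(6)] public ushort Exists;

        // Reinterpret the 8-byte struct as a ulong (and back) without copying field by field.
        public ulong ToUlong() => Unsafe.As<PackedNodeInfo, ulong>(ref this);
        public static PackedNodeInfo FromUlong(ulong raw) => Unsafe.As<ulong, PackedNodeInfo>(ref raw);
    }
    // FromUlong(info.ToUlong()) reproduces all four fields, and a raw value of 0 reads back with
    // Exists == 0, which is what Contains/TryGet key off after the post-compile clearing pass.
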
// For linear and semi-linear MBAs, we can skip the substitution / polynomial simplification steps. var linClass = ctx.GetClass(id); - if (ctx.IsConstant(id)) + if(ctx.IsConstant(id)) return id; if(linClass != AstClassification.Nonlinear) @@ -95,7 +97,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if(vars.Count > 11 || vars.Count == 0) { var simplified = SimplifyViaTermRewriting(id); - simbaCache.Add(id, simplified); + simbaCache.TryAdd(id, simplified); return simplified; } @@ -118,13 +120,13 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) var usedVars = ctx.CollectVariables(withSubstitutions).ToHashSet(); foreach(var (substValue, substVar) in substMapping.ToList()) { - if (!usedVars.Contains(substVar)) + if(!usedVars.Contains(substVar)) substMapping.Remove(substValue); } // Try to take a guess (MSiMBA) and prove it's equivalence var guess = SimplifyViaGuessAndProve(withSubstitutions, substMapping, ref isSemiLinear); - if (guess != null) + if(guess != null) { // Apply constant folding / term rewriting. var simplGuess = SimplifyViaTermRewriting(guess.Value); @@ -136,14 +138,15 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) // If there are multiple substitutions, try to minimize the number of substitutions. - if (substMapping.Count > 1) + if(substMapping.Count > 1) withSubstitutions = TryUnmergeLinCombs(withSubstitutions, substMapping, ref isSemiLinear); withSubstitutions = SimplifyViaTermRewriting(withSubstitutions); + // If polynomial parts are present, try to simplify them. var inverseMapping = substMapping.ToDictionary(x => x.Value, x => x.Key); AstIdx? reducedPoly = null; - if (polySimplify && ctx.GetHasPoly(id)) + if(polySimplify && ctx.GetHasPoly(id)) { // Try to reduce the polynomial parts using "pure" polynomial reduction algorithms. reducedPoly = ReducePolynomials(GetRootTerms(ctx, withSubstitutions), substMapping, inverseMapping); @@ -152,7 +155,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if(reducedPoly != null) { // Back substitute the original substitutions. - reducedPoly = ApplyBackSubstitution(ctx, reducedPoly.Value, inverseMapping); + reducedPoly = BackSubstitute(ctx, reducedPoly.Value, inverseMapping); // Reset internal state. substMapping.Clear(); @@ -164,10 +167,10 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) // If there are any substitutions, we want to try simplifying the polynomial parts. var variables = ctx.CollectVariables(withSubstitutions); - if (polySimplify && substMapping.Count > 0 && ctx.GetHasPoly(id)) + if(REDUCE_POLYS && polySimplify && substMapping.Count > 0 && ctx.GetHasPoly(id)) { var maybeSimplified = TrySimplifyMixedPolynomialParts(withSubstitutions, substMapping, inverseMapping, variables); - if (maybeSimplified != null && maybeSimplified.Value != id) + if(maybeSimplified != null && maybeSimplified.Value != id) { // Reset internal state. substMapping.Clear(); @@ -179,10 +182,10 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) } // If there are still more too many variables remaining, bail out. 
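
Circling back to FastAmd64Assembler.cs: the helpers now build each instruction through a small StackBuffer and flush it with one Unsafe.CopyBlockUnaligned instead of writing byte-by-byte through the code pointer. As written, the buffer is constructed over the output pointer itself and the stackalloc'd scratch appears unused, but the same pattern also works with a genuine scratch area, roughly as below (type and method names are invented for the sketch, not taken from the diff):

    using System.Runtime.CompilerServices;

    // Stage an instruction's bytes in a scratch area, then flush them with one unaligned copy.
    unsafe ref struct ByteStager
    {
        readonly byte* buf;
        int len;

        public ByteStager(byte* scratch) { buf = scratch; len = 0; }

        public void U8(byte b) { buf[len++] = b; }
        public void I32(int v) { *(int*)(buf + len) = v; len += 4; }

        public void FlushTo(ref byte* code)
        {
            Unsafe.CopyBlockUnaligned(code, buf, (uint)len);
            code += len;
        }
    }

    static unsafe class EmitDemo
    {
        // Emits "mov qword ptr [rax + disp32], rcx": REX.W (0x48), opcode 0x89, ModRM 0x88, disp32.
        public static void MovMemRcx(ref byte* code, int disp)
        {
            byte* scratch = stackalloc byte[16];
            var s = new ByteStager(scratch);
            s.U8(0x48); s.U8(0x89); s.U8(0x88);
            s.I32(disp);
            s.FlushTo(ref code);
        }
    }

The staged bytes here are the same family of encodings PushMem64/MovMem64Reg assemble above: a REX.W prefix, the MOV r/m64, r64 opcode, and a ModRM byte selecting [rax] + disp32 with RCX as the source.
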
- if (variables.Count > 11) + if(variables.Count > 11) { var simplified = SimplifyViaTermRewriting(id); - simbaCache.Add(id, simplified); + simbaCache.TryAdd(id, simplified); return simplified; } @@ -191,9 +194,9 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) withSubstitutions = SimplifyViaTermRewriting(withSubstitutions); var result = withSubstitutions; - if (!ctx.IsConstant(withSubstitutions)) + if(!ctx.IsConstant(withSubstitutions)) result = LinearSimplifier.Run(ctx.GetWidth(withSubstitutions), ctx, withSubstitutions, false, isSemiLinear, false, variables); - var backSub = ApplyBackSubstitution(ctx, result, inverseMapping); + var backSub = BackSubstitute(ctx, result, inverseMapping); // Apply constant folding / term rewriting. var propagated = SimplifyViaTermRewriting(backSub); @@ -210,7 +213,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) var newSubstMapping = new Dictionary(); var newWithSubst = GetAstWithSubstitutions(propagated, newSubstMapping, ref isSemiLinear, false); var newVars = ctx.CollectVariables(newWithSubst); - if (newVars.Count < variables.Count) + if(newVars.Count < variables.Count) { propagated = SimplifyViaRecursiveSiMBA(propagated, true); } @@ -226,7 +229,7 @@ private static ulong Pow(ulong bbase, ulong exponent) for (ulong term = bbase; exponent != 0; term = term * term) { - if (exponent % 2 != 0) { result *= term; } + if(exponent % 2 != 0) { result *= term; } exponent /= 2; } @@ -249,10 +252,10 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub case AstOp.Mul: // If we encounter an arithmetic subtree inside of a bitwise operator, it is not linear. // In this case we try to recursively simplify the subtree and check if it was made linear. - if (inBitwise) + if(inBitwise) { var simplified = SimplifyViaRecursiveSiMBA(id); - if (simplified != id) + if(simplified != id) { id = simplified; goto start; @@ -260,7 +263,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub } // If the above check still yielded something that is not linear, we apply substitution. - if (inBitwise) + if(inBitwise) { return GetSubstitution(id, substitutionMapping); } @@ -273,30 +276,30 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub // If the first operand is not a constant, then we want to recursively simplify both children and // check if it yields a constant. - if (opcode == AstOp.Mul && ctx.GetOpcode(v0) != AstOp.Constant) + if(opcode == AstOp.Mul && ctx.GetOpcode(v0) != AstOp.Constant) { v0 = SimplifyViaRecursiveSiMBA(v0); v1 = SimplifyViaRecursiveSiMBA(v1); - if (ctx.GetOpcode(v1) == AstOp.Constant) + if(ctx.GetOpcode(v1) == AstOp.Constant) (v0, v1) = (v1, v0); } // If both children are still not constant after applying recursive simplification, // then we need to perform substitution. - if (opcode == AstOp.Mul && ctx.GetOpcode(v0) != AstOp.Constant) + if(opcode == AstOp.Mul && ctx.GetOpcode(v0) != AstOp.Constant) { var mul = ctx.Mul(v0, v1); return GetSubstitution(mul, substitutionMapping); } // Otherwise we have a multiplication where one term is a constant(linear). - if (opcode == AstOp.Mul) + if(opcode == AstOp.Mul) { var constTerm = v0; // In the case of coeff*(x+y), where coeff is a constant, we want to distribute it, yielding coeff*x + coeff*y. 
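
Pow above is plain square-and-multiply over wrapping 64-bit arithmetic, so folding x ** c during substitution costs O(log c) multiplications and the modulo-2^64 semantics come for free from ulong overflow. A self-contained version of the same loop, assuming the accumulator is initialised to 1 (that line sits outside this hunk):

    static ulong Pow(ulong b, ulong exponent)
    {
        ulong result = 1;
        // Square the running term each round; multiply it in when the low exponent bit is set.
        for (ulong term = b; exponent != 0; term *= term)
        {
            if ((exponent & 1) != 0)
                result *= term;      // wraps mod 2^64, matching the AST's modular semantics
            exponent >>= 1;
        }
        return result;
    }
    // Pow(3, 5) == 243; Pow(2, 64) == 0 because the intermediate product wraps to zero.
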
- if (ctx.GetOpcode(v1) == AstOp.Add) + if(ctx.GetOpcode(v1) == AstOp.Add) { var left = ctx.Mul(constTerm, ctx.GetOp0(v1)); left = ctx.SingleSimplify(left); @@ -308,8 +311,8 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub var oldSum = sum; var newSum = ctx.SingleSimplify(sum); sum = newSum; - // In this case, we apply constant folding(but we do not search recursively). + // In this case, we apply constant folding(but we do not search recursively). return GetAstWithSubstitutions(sum, substitutionMapping, ref isSemiLinear, inBitwise); } @@ -326,7 +329,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub case AstOp.Pow: var basis = SimplifyViaRecursiveSiMBA(ctx.GetOp0(id)); var degree = SimplifyViaRecursiveSiMBA(ctx.GetOp1(id)); - if (ctx.IsConstant(basis) && ctx.IsConstant(degree)) + if(ctx.IsConstant(basis) && ctx.IsConstant(degree)) { var folded = Pow(ctx.GetConstantValue(basis), ctx.GetConstantValue(degree)); return visitReplacement(ctx.Constant(folded, ctx.GetWidth(basis)), inBitwise, ref isSemiLinear); @@ -338,7 +341,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub case AstOp.And: case AstOp.Or: case AstOp.Xor: - if (opcode == AstOp.And) + if(opcode == AstOp.And) { // Simplify both children. var and0 = op0(true, ref isSemiLinear); @@ -347,23 +350,23 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub var id0 = ctx.GetOp0(id); var id1 = ctx.GetOp1(id); // Move constants to the left - if (ctx.IsConstant(and1)) + if(ctx.IsConstant(and1)) { (and0, and1) = (and1, and0); (id0, id1) = (id1, id0); } - + // Rewrite (a&mask) as `Trunc(a)`, or `Trunc(a & mask)` if mask is not completely a bit mask. // This is a form of adhoc demanded bits based simplification - if (ctx.IsConstant(and0) && !ctx.IsConstant(and1)) + if(ctx.IsConstant(and0) && !ctx.IsConstant(and1)) { var andMask = ctx.GetConstantValue(and0); - if (CanEliminateMask(andMask, id1)) + if(CanEliminateMask(andMask, id1)) return visitReplacement(id1, inBitwise, ref isSemiLinear); var truncWidth = ConstantToTruncWidth(andMask); - if (truncWidth != 0 && truncWidth < ctx.GetWidth(id)) + if(truncWidth != 0 && truncWidth < ctx.GetWidth(id)) { isSemiLinear = true; var moduloWidth = 64 - (uint)BitOperations.LeadingZeroCount(andMask); @@ -373,29 +376,29 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub var before = trunc; trunc = visitReplacement(before, true, ref isSemiLinear); var ext = ctx.Zext(trunc, ctx.GetWidth(id)); - if (ModuloReducer.GetMask((uint)truncWidth) != andMask) + if(ModuloReducer.GetMask((uint)truncWidth) != andMask) ext = ctx.And(ctx.Constant(andMask, ctx.GetWidth(id)), ext); return ext; } } - + return ctx.And(and0, and1); } - if (opcode == AstOp.Or) + if(opcode == AstOp.Or) return ctx.Or(op0(true, ref isSemiLinear), op1(true, ref isSemiLinear)); - if (opcode == AstOp.Xor) + if(opcode == AstOp.Xor) return ctx.Xor(op0(true, ref isSemiLinear), op1(true, ref isSemiLinear)); throw new InvalidOperationException("Unrecognized opcode!"); case AstOp.Neg: - if (inBitwise) + if(inBitwise) { - // Deleting because it causes stackoverflow! + // Temporary disabled because it can cause stack overflows. /* // If we encounter a negation inside of a bitwise operator, try to simplify the subtree. 
var simplified = SimplifyViaRecursiveSiMBA(id); - if (simplified != id) + if(simplified != id) { id = simplified; goto start; @@ -416,7 +419,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub var src = SimplifyViaRecursiveSiMBA(ctx.GetOp0(id)); var by = SimplifyViaRecursiveSiMBA(ctx.GetOp1(id)); // Apply constant propagation if both nodes fold. - if (ctx.IsConstant(src) && ctx.IsConstant(by)) + if(ctx.IsConstant(src) && ctx.IsConstant(by)) { var value = ctx.GetConstantValue(src) >> (ushort)(ctx.GetConstantValue(by) % (ulong)ctx.GetWidth(id)); var constant = ctx.Constant(value, ctx.GetWidth(id)); @@ -438,7 +441,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub case AstOp.Constant: // If a bitwise constant is present, we want to mark it as semi-linear - if (inBitwise) + if(inBitwise) isSemiLinear = true; return id; @@ -456,7 +459,7 @@ private bool CanEliminateMask(ulong andMask, AstIdx idx) var knownBits = ctx.GetKnownBits(idx); var zeroes = ~andMask & ModuloReducer.GetMask(ctx.GetWidth(idx)); - if ((zeroes & knownBits.Zeroes) == zeroes) + if((zeroes & knownBits.Zeroes) == zeroes) return true; return false; @@ -467,11 +470,11 @@ private ulong ConstantToTruncWidth(ulong c) { var lz = BitOperations.LeadingZeroCount(c); var minWidth = 64 - lz; - if (minWidth <= 7) + if(minWidth <= 7) return 8; - if (minWidth <= 16) + if(minWidth <= 16) return 16; - if (minWidth <= 32) + if(minWidth <= 32) return 32; return 0; } @@ -508,7 +511,7 @@ public static PolynomialParts GetPolynomialParts(AstCtx ctx, AstIdx id) // Skip if this is not a multiplication. var opcode = ctx.GetOpcode(id); - var roots = GetRootMultiplications(ctx,id); + var roots = GetRootMultiplications(ctx, id); ulong coeffSum = 0; Dictionary constantPowers = new(); List others = new(); @@ -516,11 +519,11 @@ public static PolynomialParts GetPolynomialParts(AstCtx ctx, AstIdx id) foreach(var root in roots) { var code = ctx.GetOpcode(root); - if (code == AstOp.Constant) + if(code == AstOp.Constant) { coeffSum += ctx.GetConstantValue(root); } - else if (code == AstOp.Symbol) + else if(code == AstOp.Symbol) { constantPowers.TryAdd(root, 0); constantPowers[root]++; @@ -564,9 +567,9 @@ public static int VarsFirst(AstCtx ctx, AstIdx a, AstIdx b) var op0 = ctx.IsSymbol(a); var op1 = ctx.IsSymbol(b); - if (op0 && !op1) + if(op0 && !op1) return comeFirst; - if (op1 && !op0) + if(op1 && !op0) return comeLast; if(op0 && op1) return ctx.GetSymbolName(a).CompareTo(ctx.GetSymbolName(b)); @@ -584,31 +587,31 @@ private int CompareTo(AstIdx a, AstIdx b) // Push constants to the left var op0 = ctx.GetOpcode(a); var op1 = ctx.GetOpcode(b); - if (op0 == AstOp.Constant) + if(op0 == AstOp.Constant) return comeFirst; - if (op1 == AstOp.Constant) + if(op1 == AstOp.Constant) return comeLast; // Sort symbols alphabetically if(op0 == AstOp.Symbol && op1 == AstOp.Symbol) return ctx.GetSymbolName(a).CompareTo(ctx.GetSymbolName(b)); - if (op0 == AstOp.Pow) + if(op0 == AstOp.Pow) return comeLast; - if (op1 == AstOp.Pow) + if(op1 == AstOp.Pow) return comeFirst; return -1; } private AstIdx GetSubstitution(AstIdx id, Dictionary substitutionMapping) { - if (substitutionMapping.TryGetValue(id, out var existing)) + if(substitutionMapping.TryGetValue(id, out var existing)) return existing; - while(true) + while (true) { var subst = ctx.Symbol($"subst{substCount}", ctx.GetWidth(id)); substCount++; - if (substitutionMapping.Values.Contains(subst)) + if(substitutionMapping.Values.Contains(subst)) { continue; } @@ -621,11 +624,11 @@ private 
AstIdx GetSubstitution(AstIdx id, Dictionary substitutio private AstIdx TryUnmergeLinCombs(AstIdx withSubstitutions, Dictionary substitutionMapping, ref bool isSemiLinear) { // We cannot rewrite substitutions as negations of one another if there is only one substitution. - if (substitutionMapping.Count == 1) + if(substitutionMapping.Count == 1) return withSubstitutions; var result = UnmergeDisjointParts(withSubstitutions, substitutionMapping, ref isSemiLinear); - if (result != null) + if(result != null) withSubstitutions = result.Value; var rewriteMapping = UnmergeNegatedParts(substitutionMapping); @@ -634,7 +637,7 @@ private AstIdx TryUnmergeLinCombs(AstIdx withSubstitutions, Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary(); var results = new List(); - for(int i = 0 ; i < inputExpressions.Count; i++) + for (int i = 0; i < inputExpressions.Count; i++) { // Substitute all of the nonlinear parts for this expression // Here we share the list of substitutions @@ -696,12 +699,12 @@ private Dictionary UnmergeNegatedParts(Dictionary x.Value, x => x.Key); var vars = results.SelectMany(x => ctx.CollectVariables(x)).Distinct().OrderBy(x => ctx.GetSymbolName(x)).ToList(); - if (vars.Count > 11) + if(vars.Count > 11) return null; // Compute a result vector for each expression Dictionary> vecToExpr = new(); - for(int i = 0; i < results.Count; i++) + for (int i = 0; i < results.Count; i++) { var expr = results[i]; var w = ctx.GetWidth(expr); @@ -731,9 +734,9 @@ private Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary !CanFitConstantInUndemandedBits(kb, x.constantOffset, moduloMask))) + if(members.Any(x => !CanFitConstantInUndemandedBits(kb, x.constantOffset, moduloMask))) continue; // Replace each substituted var with a new substituted var | some constant offset. var newSubstVar = ctx.Symbol($"subst{substCount}", (byte)w); substCount++; - for(int i = 0; i < members.Count; i++) + for (int i = 0; i < members.Count; i++) { var member = members[i]; var expr = ctx.Or(ctx.Constant(member.constantOffset, w), newSubstVar); var inVar = inputSubstVars[member.index]; - if (varToNewSubstValue.ContainsKey(inVar)) + if(varToNewSubstValue.ContainsKey(inVar)) throw new InvalidOperationException($"Cannot share substituted parts!"); - + varToNewSubstValue[inVar] = expr; } var vecExpr = LinearSimplifier.Run(w, ctx, null, false, true, false, vars, null, resultVector); // TODO: ToArray vecExpr = ctx.Add(ctx.Constant(union, w), vecExpr); // Back substitute in the variables we temporarily substituted. - vecExpr = ApplyBackSubstitution(ctx, vecExpr, tempSubstMapping); + vecExpr = BackSubstitute(ctx, vecExpr, tempSubstMapping); substitutionMapping.Remove(vecExpr); substitutionMapping.TryAdd(vecExpr, newSubstVar); isSemiLinear = true; } - withSubstitutions = ApplyBackSubstitution(ctx, withSubstitutions, varToNewSubstValue); + withSubstitutions = BackSubstitute(ctx, withSubstitutions, varToNewSubstValue); var newVars = ctx.CollectVariables(withSubstitutions).ToHashSet(); foreach(var (expr, substVar) in substitutionMapping.ToList()) { - if (!newVars.Contains(substVar)) + if(!newVars.Contains(substVar)) substitutionMapping.Remove(expr); } @@ -805,9 +808,9 @@ private Dictionary UnmergeNegatedParts(Dictionary substitutionMapping, ref bool isSemiLinear) { // If there are no substituted parts, we have a semi-linear MBA. 
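
UnmergeDisjointParts (and the linear simplifier it feeds) works off a result vector: the expression evaluated at every 0/1 assignment of its variables, which is why the code keeps bailing out above 11 variables, since that already means 2^11 = 2048 evaluations (more again for multi-bit vectors). A minimal illustration of building such a vector for an arbitrary evaluator; the delegate-based signature is hypothetical, as the real code JITs the expression first:

    using System;

    static class ResultVector
    {
        // Evaluate f at every assignment where each variable is either 0 or 1;
        // bit v of the combination index chooses the value of variable v.
        public static ulong[] Build(int varCount, Func<ulong[], ulong> f)
        {
            var vec = new ulong[1 << varCount];
            var args = new ulong[varCount];
            for (int comb = 0; comb < vec.Length; comb++)
            {
                for (int v = 0; v < varCount; v++)
                    args[v] = (ulong)((comb >> v) & 1);
                vec[comb] = f(args);
            }
            return vec;
        }
    }
    // Build(2, a => a[0] & a[1]) yields [0, 0, 0, 1], the vector of x & y.
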
- if (substitutionMapping.Count == 0) + if(substitutionMapping.Count == 0) return null; // Compute demanded bits for each variable // TODO: Keep track of which bits are demanded by the parent(withSubstitutions) Dictionary varToDemandedBits = new(); + var cache = new HashSet<(AstIdx idx, ulong currDemanded)>(); + int totalDemanded = 0; foreach(var (expr, substVar) in substitutionMapping) - ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits); + { + ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, cache, ref totalDemanded); + if(totalDemanded > 12) + break; + } - // Compute the total number of demanded variable bits in the substituted parts. - ulong totalDemanded = 0; - foreach (var demandedBits in varToDemandedBits.Values) - totalDemanded += (ulong)BitOperations.PopCount(demandedBits); - if (totalDemanded > 12) + + // Bail if there are too many demanded bits! + if(totalDemanded > 12) return null; // Partition the MBA into semi-linear, unconstrained, and constrained parts. var (semilinearIdx, unconstrainedIdx, constrainedIdx) = PartitionConstrainedParts(withSubstitutions, substitutionMapping); // If there are no constrained or unconstrained parts then this is a semi-linear MBA. - if (unconstrainedIdx == null && constrainedIdx == null) + if(unconstrainedIdx == null && constrainedIdx == null) throw new InvalidOperationException($"Expected nonlinear expression!"); // If we have no unconstrained parts, we can prove the equivalence of the entire expression. - if (unconstrainedIdx == null) + if(unconstrainedIdx == null) return SimplifyConstrained(withSubstitutions, substitutionMapping, varToDemandedBits); // If we have have no constrained parts, we can simplify the entire expression individually. - if (constrainedIdx == null) + if(constrainedIdx == null) { // Simplify the constrained parts. - var withoutSubstitutions = ApplyBackSubstitution(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var withoutSubstitutions = BackSubstitute(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var r = SimplifyUnconstrained(withoutSubstitutions, varToDemandedBits); - if (r == null) + if(r == null) return null; - if (semilinearIdx == null) + if(semilinearIdx == null) return r.Value; // If there are some semi-linear parts, combine and simplify. @@ -863,18 +870,18 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong } // Otherwise we have both constrained and unconstrained parts, which need to be simplified individually and composed back together. - if (semilinearIdx != null) + if(semilinearIdx != null) constrainedIdx = ctx.Add(semilinearIdx.Value, constrainedIdx.Value); // Simplify constrained parts. var constrainedSimpl = SimplifyConstrained(constrainedIdx.Value, substitutionMapping, varToDemandedBits); - if (constrainedSimpl == null) + if(constrainedSimpl == null) return null; // Simplify unconstrained parts. 
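
ComputeSymbolDemandedBits is now passed a (node, demandedMask) seen-set and a running bit budget, so shared subtrees are not revisited exponentially and the walk stops as soon as more than 12 variable bits are demanded. The accounting trick is to charge the budget only for bits that were not already demanded for that symbol; a stripped-down sketch of just that bookkeeping (class and method names are made up for the example):

    using System.Collections.Generic;
    using System.Numerics;

    sealed class DemandedBitsTracker
    {
        readonly Dictionary<string, ulong> demanded = new();
        public int TotalDemanded { get; private set; }

        // Union `mask` into a symbol's demanded bits, charging the budget
        // only for bits that were not already demanded.
        public bool Add(string symbol, ulong mask, int budget = 12)
        {
            demanded.TryGetValue(symbol, out var old);
            ulong merged = old | mask;
            TotalDemanded += BitOperations.PopCount(merged & ~old);
            demanded[symbol] = merged;
            return TotalDemanded <= budget;   // false => the caller should bail out early
        }
    }

Charging only the newly demanded bits keeps the count accurate when the same variable is reached along several paths, which is exactly the case the seen-set and PopCount(newDemanded & ~oldDemanded) change handles.
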
- var unconstrainedBackSub = ApplyBackSubstitution(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var unconstrainedBackSub = BackSubstitute(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var unconstrainedSimpl = SimplifyUnconstrained(unconstrainedBackSub, varToDemandedBits); - if (unconstrainedSimpl == null) + if(unconstrainedSimpl == null) return null; // Compose and simplify @@ -888,7 +895,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong { // Construct a result vector for the linear part. var substVars = substitutionMapping.Values.ToList(); - var allVars = ctx.CollectVariables(withSubstitutions); + IReadOnlyList allVars = ctx.CollectVariables(withSubstitutions); var bitSize = ctx.GetWidth(withSubstitutions); var numCombinations = (ulong)Math.Pow(2, allVars.Count); var groupSizes = LinearSimplifier.GetGroupSizes(allVars.Count); @@ -914,37 +921,39 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong ulong substVarMask = 0; for (int i = 0; i < allVars.Count; i++) { - if (!substVars.Contains(allVars[i])) + if(!substVars.Contains(allVars[i])) continue; substVarMask |= (1ul << (ushort)i); } List semiLinearParts = new(); - if (constantOffset != 0) + if(constantOffset != 0) semiLinearParts.Add(ctx.Constant(constantOffset, bitSize)); List unconstrainedParts = new(); List constrainedParts = new(); // Decompose result vector into semi-linear, unconstrained, and constrained parts. + // Upcast variables as necessary! + allVars = LinearSimplifier.CastVariables(ctx, allVars, bitSize); int resultVecIdx = 0; - for(int i = 0; i < linearCombinations.Count; i++) + for (int i = 0; i < linearCombinations.Count; i++) { foreach(var (coeff, bitMask) in linearCombinations[i]) { - if (coeff == 0) + if(coeff == 0) goto skip; // If the term only contains normal variables, its semi-linear. var varComb = variableCombinations[i]; - if ((varComb & substVarMask) == 0) + if((varComb & substVarMask) == 0) { semiLinearParts.Add(LinearSimplifier.ConjunctionFromVarMask(ctx, allVars, coeff, varComb, bitMask)); goto skip; } // If the term only contains substituted variables, it's unconstrained. - if ((varComb & ~substVarMask) == 0) + if((varComb & ~substVarMask) == 0) { unconstrainedParts.Add(LinearSimplifier.ConjunctionFromVarMask(ctx, allVars, coeff, varComb, bitMask)); goto skip; @@ -966,7 +975,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong } // TODO: Refactor out! - private static (ulong[], List>) GetAnf(uint width, List variables, List groupSizes, ulong[] resultVector, bool multiBit) + private static (ulong[], List>) GetAnf(uint width, IReadOnlyList variables, List groupSizes, ulong[] resultVector, bool multiBit) { // Get all combinations of variables. var moduloMask = ModuloReducer.GetMask(width); @@ -1013,7 +1022,7 @@ private static (ulong[], List>) GetAnf(uint w var comb = variableCombinations[i]; var (trueMask, index) = combToMaskAndIdx[i]; var coeff = ptr[(int)offset + index]; - if (coeff == 0) + if(coeff == 0) continue; // Subtract the coefficient from the result vector. @@ -1033,7 +1042,7 @@ private static (ulong[], List>) GetAnf(uint w private unsafe AstIdx? 
SimplifyConstrained(AstIdx withSubstitutions, Dictionary substitutionMapping, Dictionary varToDemandedBits) { // Compute a result vector for the original expression - var withoutSubstitutions = ApplyBackSubstitution(ctx, withSubstitutions, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var withoutSubstitutions = BackSubstitute(ctx, withSubstitutions, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var w = ctx.GetWidth(withoutSubstitutions); var inputVars = ctx.CollectVariables(withoutSubstitutions); var originalResultVec = LinearSimplifier.JitResultVector(ctx, w, ModuloReducer.GetMask(w), inputVars, withoutSubstitutions, true, (ulong)Math.Pow(2, inputVars.Count)); @@ -1042,12 +1051,12 @@ private static (ulong[], List>) GetAnf(uint w var exprToSubstVar = substitutionMapping.OrderBy(x => ctx.GetAstString(x.Value)).ToList(); var allVars = inputVars.Concat(exprToSubstVar.Select(x => x.Value)).ToList(); // Sort them.... var pagePtr = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(withSubstitutions, allVars, pagePtr, true); + new Amd64OptimizingJit(ctx).Compile(withSubstitutions, allVars, pagePtr, false); var jittedWithSubstitutions = (delegate* unmanaged[SuppressGCTransition])pagePtr; // Return null if the expressions are not provably equivalent var demandedVars = varToDemandedBits.OrderBy(x => ctx.GetSymbolName(x.Key)).Select(x => (x.Key, x.Value)).ToList(); - if (!IsConstrainedExpressionEquivalent(w, inputVars, demandedVars, exprToSubstVar, jittedWithSubstitutions, originalResultVec)) + if(!IsConstrainedExpressionEquivalent(w, inputVars, demandedVars, exprToSubstVar, jittedWithSubstitutions, originalResultVec)) { JitUtils.FreeExecutablePage(pagePtr); return null; @@ -1060,7 +1069,7 @@ private static (ulong[], List>) GetAnf(uint w } // Returns true if two expressions are guaranteed to be equivalent - private unsafe bool IsConstrainedExpressionEquivalent(uint width,List inputVars, List<(AstIdx demandedVar, ulong demandedMask)> demandedVars, List> exprToSubstVar, delegate* unmanaged[SuppressGCTransition] jittedWithSubstitutions, ulong[] originalResultVec) + private unsafe bool IsConstrainedExpressionEquivalent(uint width, List inputVars, List<(AstIdx demandedVar, ulong demandedMask)> demandedVars, List> exprToSubstVar, delegate* unmanaged[SuppressGCTransition] jittedWithSubstitutions, ulong[] originalResultVec) { int totalDemanded = demandedVars.Sum(x => BitOperations.PopCount(x.demandedMask)); @@ -1097,7 +1106,7 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in for (int vIdx = 0; vIdx < (int)totalDemanded; vIdx++) { // If we've chosen values for all bits in this variable, move onto the next one. 
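
IsConstrainedExpressionEquivalent proves the rewrite by enumerating every assignment of the demanded bits (at most 12 of them, so at most 4096 combinations) and comparing the jitted expression with the original result vector. The loop above walks a flat combination index and scatters its bits into each variable's demanded positions; the same scatter step in isolation, with a hypothetical helper name:

    using System.Numerics;

    static class DemandedScatter
    {
        // Scatter the low bits of `comb` into the demanded bit positions of each variable.
        public static ulong[] Scatter(ulong comb, ulong[] demandedMasks)
        {
            var values = new ulong[demandedMasks.Length];
            int bit = 0;
            for (int v = 0; v < demandedMasks.Length; v++)
            {
                for (ulong mask = demandedMasks[v]; mask != 0; mask &= mask - 1, bit++)
                {
                    int pos = BitOperations.TrailingZeroCount(mask);
                    if (((comb >> bit) & 1) != 0)
                        values[v] |= 1ul << pos;
                }
            }
            return values;
        }
    }
    // Scatter(0b101, new[] { 0b0011ul, 0b1000ul }) == { 0b0001, 0b1000 }:
    // comb bit 0 -> bit 0 of var 0, bit 1 -> bit 1 of var 0, bit 2 -> bit 3 of var 1.
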
- if (currDemandedVarMask == 0) + if(currDemandedVarMask == 0) { currVarIdx += 1; currDemandedVarMask = demandedVars[currVarIdx].demandedMask; @@ -1129,7 +1138,7 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in fixed (ulong* vPtr = &vArray[0]) { var curr = jittedWithSubstitutions(vPtr) >> bitIndex; - if (curr != originalResultVec[vecIdx]) + if(curr != originalResultVec[vecIdx]) return false; } } @@ -1153,12 +1162,12 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in // Jit the input expression var pagePtr1 = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(withoutSubstitutions, inputVars, pagePtr1, true); + new Amd64OptimizingJit(ctx).Compile(withoutSubstitutions, inputVars, pagePtr1, false); var jittedBefore = (delegate* unmanaged[SuppressGCTransition])pagePtr1; // Jit the output expression var pagePtr2 = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(expectedExpr, inputVars, pagePtr2, true); + new Amd64OptimizingJit(ctx).Compile(expectedExpr, inputVars, pagePtr2, false); var jittedAfter = (delegate* unmanaged[SuppressGCTransition])pagePtr2; // Prove that they are equivalent for all possible input combinations @@ -1176,7 +1185,7 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in for (int vIdx = 0; vIdx < (int)totalDemanded; vIdx++) { // If we've chosen values for all bits in this variable, move onto the next one. - if (currDemandedVarMask == 0) + if(currDemandedVarMask == 0) { currVarIdx += 1; currDemandedVarMask = demandedVars[currVarIdx].demandedMask; @@ -1199,7 +1208,7 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in { var op1 = jittedBefore(vPtr); var op2 = jittedAfter(vPtr); - if (op1 != op2) + if(op1 != op2) { JitUtils.FreeExecutablePage(pagePtr1); JitUtils.FreeExecutablePage(pagePtr2); @@ -1208,22 +1217,35 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in } } + JitUtils.FreeExecutablePage(pagePtr1); + JitUtils.FreeExecutablePage(pagePtr2); return expectedExpr; } // TODO: Cache results to avoid exponentially visiting shared nodes - private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits) + private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet<(AstIdx idx, ulong currDemanded)> seen, ref int totalDemanded) { - var op0 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits); - var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits); + if(totalDemanded > 12) + return; + if(!seen.Add((idx, currDemanded))) + return; + + totalDemanded += 1; + + var op0 = (ulong demanded, ref int totalDemanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits, seen, ref totalDemanded); + var op1 = (ulong demanded, ref int totalDemanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits, seen, ref totalDemanded); var opc = ctx.GetOpcode(idx); - switch(opc) + switch (opc) { // If we have a symbol, union the set of demanded bits case AstOp.Symbol: - symbolDemandedBits.TryAdd(idx, 0); - symbolDemandedBits[idx] |= currDemanded; + //symbolDemandedBits.TryAdd(idx, 0); + symbolDemandedBits.TryGetValue(idx, out var oldDemanded); + var newDemanded = oldDemanded | currDemanded; + symbolDemandedBits[idx] = newDemanded; + totalDemanded += BitOperations.PopCount(newDemanded & ~oldDemanded); + break; // 
If we have a constant, there is nothing to do. case AstOp.Constant: @@ -1233,26 +1255,27 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // For addition by a constant we can also get more precision case AstOp.Add: case AstOp.Mul: + case AstOp.Pow: // If we have addition/multiplication, we only care about bits at and below the highest set bit. var demandedWidth = 64 - (uint)BitOperations.LeadingZeroCount(currDemanded); currDemanded = ModuloReducer.GetMask(demandedWidth); - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; case AstOp.Lshr: var shiftBy = ctx.GetOp1(idx); var shiftByConstant = ctx.TryGetConstantValue(shiftBy); - if (shiftByConstant == null) + if(shiftByConstant == null) { - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; } // If we know the value we are shifting by, we can truncate the demanded bits. - op0(currDemanded >> (ushort)shiftByConstant.Value); - op1(currDemanded); + op0(currDemanded >> (ushort)shiftByConstant.Value, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; case AstOp.And: @@ -1260,8 +1283,8 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // If we have a&b, demandedbits(a) does not include any known zero bits from b. Works both ways. var op0Demanded = ~ctx.GetKnownBits(ctx.GetOp1(idx)).Zeroes & currDemanded; var op1Demanded = ~ctx.GetKnownBits(ctx.GetOp0(idx)).Zeroes & currDemanded; - op0(op0Demanded); - op1(op1Demanded); + op0(op0Demanded, ref totalDemanded); + op1(op1Demanded, ref totalDemanded); break; } @@ -1270,25 +1293,25 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // If we have a|b, demandedbits(a) does not include any known one bits from b. Works both ways. var op0Demanded = ~ctx.GetKnownBits(ctx.GetOp1(idx)).Ones & currDemanded; var op1Demanded = ~ctx.GetKnownBits(ctx.GetOp0(idx)).Ones & currDemanded; - op0(op0Demanded); - op1(op1Demanded); + op0(op0Demanded, ref totalDemanded); + op1(op1Demanded, ref totalDemanded); break; } // TODO: We can gain some precision by exploiting XOR known bits. case AstOp.Xor: - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; // TODO: Treat negation as x^-1, then use XOR transfer function case AstOp.Neg: - op0(currDemanded); + op0(currDemanded, ref totalDemanded); break; case AstOp.Trunc: currDemanded &= ModuloReducer.GetMask(ctx.GetWidth(idx)); - op0(currDemanded); + op0(currDemanded, ref totalDemanded); break; case AstOp.Zext: - op0(currDemanded & ctx.GetWidth(ctx.GetOp0(idx))); + op0(currDemanded & ModuloReducer.GetMask(ctx.GetWidth(ctx.GetOp0(idx))), ref totalDemanded); break; default: throw new InvalidOperationException($"Cannot compute demanded bits for {opc}"); @@ -1298,7 +1321,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar private AstIdx? 
TrySimplifyMixedPolynomialParts(AstIdx id, Dictionary substMapping, Dictionary inverseSubstMapping, List varList) { // Back substitute in the (possibly) polynomial parts - var newId = ApplyBackSubstitution(ctx, id, inverseSubstMapping); + var newId = BackSubstitute(ctx, id, inverseSubstMapping); // Decompose each term into structured polynomial parts var terms = GetRootTerms(ctx, newId); @@ -1309,7 +1332,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar var result = SimplifyParts(ctx.GetWidth(id), polyParts); - if (result == null) + if(result == null) return null; // Add back any banned parts. @@ -1320,12 +1343,12 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar } // Do a full back substitution again. - result = ApplyBackSubstitution(ctx, result.Value, inverseSubstMapping); + result = BackSubstitute(ctx, result.Value, inverseSubstMapping); // Bail out if this resulted in a worse result. var cost1 = ctx.GetCost(result.Value); var cost2 = ctx.GetCost(newId); - if (cost1 > cost2) + if(cost1 > cost2) { return null; } @@ -1341,16 +1364,16 @@ private List UnmergePolynomialParts(Dictionary // Try to rewrite substituted parts as negations of one another. Exit early if this fails. var rewriteMapping = UnmergeNegatedParts(substitutionMapping); - if (rewriteMapping.Count == 0) + if(rewriteMapping.Count == 0) return parts; var output = new List(); - foreach (var part in parts) + foreach(var part in parts) { var outPowers = new Dictionary(); - foreach (var (factor, degree) in part.ConstantPowers) + foreach(var (factor, degree) in part.ConstantPowers) { - var unmerged = ApplyBackSubstitution(ctx, factor, rewriteMapping); + var unmerged = BackSubstitute(ctx, factor, rewriteMapping); outPowers.TryAdd(unmerged, 0); outPowers[unmerged] += degree; } @@ -1396,7 +1419,7 @@ private List UnmergePolynomialParts(Dictionary // Collect all variables. var withVars = partsWithSubstitutions.SelectMany(x => x.ConstantPowers.Keys); var varSet = new List(); - if (withVars.Any()) + if(withVars.Any()) varSet = ctx.CollectVariables(ctx.Add(withVars)); IReadOnlyList allVars = varSet.OrderBy(x => ctx.GetSymbolName(x)).ToList(); @@ -1418,11 +1441,11 @@ private List UnmergePolynomialParts(Dictionary var moduloMask = (ulong)ModuloReducer.GetMask(bitSize); Dictionary basisSubstitutions = new(); List polys = new(); - foreach (var polyPart in partsWithSubstitutions) + foreach(var polyPart in partsWithSubstitutions) { List factors = new(); ulong size = 1; - foreach (var (factor, degree) in polyPart.ConstantPowers) + foreach(var (factor, degree) in polyPart.ConstantPowers) { // Construct a result vector for the linear part. var resultVec = LinearSimplifier.JitResultVector(ctx, bitSize, moduloMask, allVars, factor, multiBit: false, numCombinations); @@ -1437,13 +1460,13 @@ private List UnmergePolynomialParts(Dictionary { // Skip zero elements. var coeff = anfVector[i]; - if (coeff == 0) + if(coeff == 0) continue; numNonZeroes++; } // Calculate the max possible size of the resulting expression when multiplied out. - for(ulong i = 0; i < degree; i++) + for (ulong i = 0; i < degree; i++) { size = SaturatingMul(size, numNonZeroes); } @@ -1453,7 +1476,7 @@ private List UnmergePolynomialParts(Dictionary { // Skip zero elements. var coeff = anfVector[i]; - if (coeff == 0) + if(coeff == 0) continue; // When the basis element corresponds to the constant offset, we want to make the base bitwise expression be `1`. 
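
Each nonzero basis index gets a stand-in basis{i} symbol here, and after reduction those symbols are swapped back for the conjunction the index encodes via LinearSimplifier.ConjunctionFromVarMask. The mask appears to encode a conjunction: bit j selects variable j, the selected variables are AND'ed together, and the empty mask stands for the constant 1 (matching the i == 0 case above). A toy string-building stand-in, purely for intuition:

    using System.Collections.Generic;

    static class Conjunction
    {
        // AND together the variables selected by `mask`; an empty mask denotes the constant 1.
        public static string FromVarMask(IReadOnlyList<string> vars, ulong mask)
        {
            var parts = new List<string>();
            for (int j = 0; j < vars.Count; j++)
                if (((mask >> j) & 1) != 0)
                    parts.Add(vars[j]);
            return parts.Count == 0 ? "1" : string.Join(" & ", parts);
        }
    }
    // FromVarMask(new[] { "x", "y", "z" }, 0b101) == "x & z"
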
@@ -1461,7 +1484,7 @@ private List UnmergePolynomialParts(Dictionary AstIdx basis = ctx.Constant(1, (byte)bitSize); if(i != 0) { - if (!basisSubstitutions.TryGetValue((ulong)i, out basis)) + if(!basisSubstitutions.TryGetValue((ulong)i, out basis)) { basis = ctx.Symbol($"basis{i}", (byte)bitSize); basisSubstitutions.Add((ulong)i, basis); @@ -1473,7 +1496,7 @@ private List UnmergePolynomialParts(Dictionary } // Add this as a factor. - if (terms.Count == 0) + if(terms.Count == 0) terms.Add(ctx.Constant(0, (byte)bitSize)); var sum = ctx.Add(terms); factors.Add(ctx.Pow(sum, ctx.Constant(degree, (byte)bitSize))); @@ -1490,7 +1513,7 @@ private List UnmergePolynomialParts(Dictionary // Add in the coefficient. AstIdx? poly = null; var constOffset = ctx.Constant(polyPart.coeffSum, (byte)bitSize); - if (factors.Any()) + if(factors.Any()) { poly = ctx.Mul(factors); poly = ctx.Mul(constOffset, poly.Value); @@ -1502,12 +1525,12 @@ private List UnmergePolynomialParts(Dictionary poly = constOffset; } - + polys.Add(poly.Value); } // If there were no polynomial parts we could expand, return null. - if (!polys.Any()) + if(!polys.Any()) return null; // Reduce the polynomial parts. @@ -1521,15 +1544,15 @@ private List UnmergePolynomialParts(Dictionary } var invBases = basisSubstitutions.ToDictionary(x => x.Value, x => LinearSimplifier.ConjunctionFromVarMask(ctx, allVars, 1, x.Key)); - var backSub = ApplyBackSubstitution(ctx, reduced, invBases); - backSub = ApplyBackSubstitution(ctx, backSub, substMapping.ToDictionary(x => x.Value, x => x.Key)); + var backSub = BackSubstitute(ctx, reduced, invBases); + backSub = BackSubstitute(ctx, backSub, substMapping.ToDictionary(x => x.Value, x => x.Key)); return backSub; } private ulong SaturatingMul(ulong a, ulong b) { var value = (UInt128)a * (UInt128)b; - if (value > ulong.MaxValue) + if(value > ulong.MaxValue) return ulong.MaxValue; return (ulong)value; } @@ -1543,7 +1566,7 @@ public static IReadOnlyList GetRootTerms(AstCtx ctx, AstIdx id) { var term = toVisit.Pop(); var opcode = ctx.GetOpcode(term); - if (opcode == AstOp.Add) + if(opcode == AstOp.Add) { toVisit.Push(ctx.GetOp0(term)); toVisit.Push(ctx.GetOp1(term)); @@ -1552,11 +1575,11 @@ public static IReadOnlyList GetRootTerms(AstCtx ctx, AstIdx id) // If we have coeff*(x+y) and coeff is a constant, rewrite as coeff*x + coeff*y. // If coeff is not a constant then we do not apply it - it would yield exponential growth in the worst case. // TODO: Handle polynomial expansion more uniformly. 
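
GetRootTerms flattens nested additions with an explicit stack and, when it meets coeff*(x + y) with a constant coeff, distributes the constant so both halves become root terms of their own; non-constant factors are deliberately left alone because distributing them can blow up exponentially, as the comment notes. A toy version of that walk over a minimal node shape (the record types are invented for the sketch):

    using System.Collections.Generic;

    // Minimal shape: an addition, a constant-times-operand product, or a leaf.
    abstract record Node;
    record Leaf(string Name) : Node;
    record Add(Node L, Node R) : Node;
    record MulConst(ulong C, Node Operand) : Node;

    static class RootTerms
    {
        public static List<Node> Collect(Node root)
        {
            var terms = new List<Node>();
            var toVisit = new Stack<Node>();
            toVisit.Push(root);
            while (toVisit.Count > 0)
            {
                switch (toVisit.Pop())
                {
                    case Add a:
                        toVisit.Push(a.L);
                        toVisit.Push(a.R);
                        break;
                    // coeff*(x+y): push coeff*x and coeff*y instead of keeping the product whole.
                    case MulConst { Operand: Add inner } m:
                        toVisit.Push(new MulConst(m.C, inner.L));
                        toVisit.Push(new MulConst(m.C, inner.R));
                        break;
                    case Node n:
                        terms.Add(n);
                        break;
                }
            }
            return terms;
        }
    }
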
- else if (opcode == AstOp.Mul && ctx.IsConstant(ctx.GetOp0(term))) + else if(opcode == AstOp.Mul && ctx.IsConstant(ctx.GetOp0(term))) { var coeff = ctx.GetOp0(term); var other = ctx.GetOp1(term); - if (ctx.IsAdd(other)) + if(ctx.IsAdd(other)) { var sum1 = ctx.Mul(coeff, ctx.GetOp0(other)); var sum2 = ctx.Mul(coeff, ctx.GetOp1(other)); @@ -1588,7 +1611,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) { var term = toVisit.Pop(); var opcode = ctx.GetOpcode(term); - if (opcode == AstOp.Mul) + if(opcode == AstOp.Mul) { toVisit.Push(ctx.GetOp0(term)); toVisit.Push(ctx.GetOp1(term)); @@ -1602,10 +1625,10 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) List newTerms = new(); ulong coeff = 1; - foreach (var term in terms) + foreach(var term in terms) { var asConstant = ctx.TryGetConstantValue(term); - if (asConstant != null) + if(asConstant != null) { coeff *= asConstant.Value; } @@ -1616,7 +1639,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) } } - if (coeff != null) + if(coeff != null) newTerms.Insert(0, ctx.Constant(coeff, ctx.GetWidth(id))); return newTerms; @@ -1635,18 +1658,18 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) var opcode = ctx.GetOpcode(term); if(opcode != AstOp.Mul && opcode != AstOp.Symbol) goto skip; - + // Search for coeff*subst if(opcode == AstOp.Mul) { // If multiplication, we are looking for coeff*(subst), where coeff is a constant. var coeff = ctx.GetOp0(term); - if (!ctx.IsConstant(coeff)) + if(!ctx.IsConstant(coeff)) goto skip; // Look for a variable on the rhs of the multiplication. var rhs = ctx.GetOp1(term); - if (!IsSubstitutedPolynomialSymbol(rhs, inverseSubstMapping)) + if(!IsSubstitutedPolynomialSymbol(rhs, inverseSubstMapping)) goto skip; // We found a polynomial part, add it to the list. @@ -1663,7 +1686,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) continue; } - skip: + skip: other.Add(term); continue; } @@ -1673,12 +1696,12 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) polyTerms.RemoveAll(x => x.Others.Any()); // Bail out if we found no polynomial terms. - if (!polyTerms.Any()) + if(!polyTerms.Any()) return null; // Now we have a list of polynomial parts, we want to try to simplify them. var uniqueBases = new Dictionary(); - foreach (var poly in polyTerms) + foreach(var poly in polyTerms) { foreach(var (_base, degree) in poly.ConstantPowers) { @@ -1696,23 +1719,23 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // Bail out if there are more than three variables. // Nothing prevents more variables, but in general more than 3 variables indicates that there are bitwise parts // that we are not handling. - if (uniqueBases.Count > 3) + if(uniqueBases.Count > 3) return null; // Compute the dense vector size as a heuristic. ulong vecSize = 1; - foreach (var degree in uniqueBases.Values) + foreach(var degree in uniqueBases.Values) vecSize = SaturatingMul(vecSize, degree); // If the dense vector size would be greater than 64**3, we bail out. // In those cases, we may consider implementing variable partitioning and simplifying each partition separately. - if (vecSize > 64*64*64) + if(vecSize > 64 * 64 * 64) return null; // For now we only support polynomials up to degree 255, although this is a somewhat arbitrary limit. 
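// Before committing to an expansion, the code above also refuses cases that are unlikely to
// pay off: more than three distinct bases, or a dense coefficient vector larger than 64**3
// entries. A standalone sketch of those guards (thresholds taken from this diff, SaturatingMul
// is the helper defined in this file; the container shapes and int base identifiers are
// illustrative):
static bool IsExpansionWorthwhile(IEnumerable<IReadOnlyDictionary<int, ulong>> termPowers)
{
    // Track the maximum degree seen for each unique base across all terms.
    var uniqueBases = new Dictionary<int, ulong>();
    foreach (var powers in termPowers)
        foreach (var (baseId, degree) in powers)
            uniqueBases[baseId] = Math.Max(degree, uniqueBases.GetValueOrDefault(baseId));

    // More than three bases usually indicates bitwise structure this pass does not handle.
    if (uniqueBases.Count > 3)
        return false;

    // Dense vector size is the (saturating) product of the per-base maximum degrees.
    ulong vecSize = 1;
    foreach (var degree in uniqueBases.Values)
        vecSize = SaturatingMul(vecSize, degree);
    return vecSize <= 64ul * 64ul * 64ul;
}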
ulong limit = 254; var maxDeg = uniqueBases.MaxBy(x => x.Value).Value; - if (maxDeg > limit) + if(maxDeg > limit) throw new InvalidOperationException($"Polynomial has degree {maxDeg} which is greater than the limit {limit}"); // Otherwise we can carry on and simplify. @@ -1724,13 +1747,13 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) orderedVars.Sort((x, y) => { return VarsFirst(ctx, x, y); }); // Fill in the sparse polynomial data structure. - foreach (var poly in polyTerms) + foreach(var poly in polyTerms) { var coeff = poly.coeffSum; - + var constPowers = poly.ConstantPowers; var degrees = new byte[orderedVars.Count]; - for(int varIdx = 0; varIdx < orderedVars.Count; varIdx++) + for (int varIdx = 0; varIdx < orderedVars.Count; varIdx++) { var variable = orderedVars[varIdx]; ulong degree = 0; @@ -1749,7 +1772,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // The polynomial reduction algorithm guarantees a minimal degree result, but it's often not the most simple result. // E.g. "x**10" becomes "96*x0 + 40*x0**2 + 84*x0**3 + 210*x0**4 + 161*x0**5 + 171*x0**6 + 42*x0**7 + 220*x0**8 + 1*x0**9" on 8 bits. // In the case of a single term solution, we reject the result if it is more complex. - if (polyTerms.Count == 1 && simplified.coeffs.Count(x => x.Value != 0) > 1) + if(polyTerms.Count == 1 && simplified.coeffs.Count(x => x.Value != 0) > 1) return null; List newTerms = new(); @@ -1765,12 +1788,12 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // Then finally convert the sparse polynomial back to an AST. foreach(var (monom, coeff) in simplified.coeffs) { - if (coeff == 0) + if(coeff == 0) continue; List factors = new(); factors.Add(ctx.Constant(coeff, width)); - for(int i = 0; i < orderedVars.Count; i++) + for (int i = 0; i < orderedVars.Count; i++) { var deg = monom.GetVarDeg(i); if(deg == 0) @@ -1780,7 +1803,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) } var variable = orderedVars[i]; - if (deg == 1) + if(deg == 1) { factors.Add(variable); continue; @@ -1795,7 +1818,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // If the whole polynomial was folded to zero, discard it. - if (newTerms.Count == 0) + if(newTerms.Count == 0) newTerms.Add(ctx.Constant(0, width)); var result = ctx.Add(newTerms); @@ -1806,11 +1829,11 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) private bool IsSubstitutedPolynomialSymbol(AstIdx id, IReadOnlyDictionary inverseSubstMapping) { - if (!ctx.IsSymbol(id)) + if(!ctx.IsSymbol(id)) return false; // Make sure the variable is a substituted part - if (!inverseSubstMapping.TryGetValue(id, out var substituted)) + if(!inverseSubstMapping.TryGetValue(id, out var substituted)) return false; // Ensure that the substituted part atleast contains a polynomial somewhere. if(!ctx.GetHasPoly(substituted)) @@ -1827,7 +1850,7 @@ private bool IsSubstitutedPolynomialSymbol(AstIdx id, IReadOnlyDictionary x.Value, x => x.Key); - sum = ApplyBackSubstitution(ctx, sum, inverseMapping); + sum = BackSubstitute(ctx, sum, inverseMapping); // Try to simplify using the general simplifier. sum = ctx.RecursiveSimplify(sum); sum = SimplifyViaRecursiveSiMBA(sum, polySimplify); // Reject solution if it is more complex. 
- if (ctx.GetCost(sum) > ctx.GetCost(id)) + if(ctx.GetCost(sum) > ctx.GetCost(id)) return id; return sum; @@ -1909,12 +1932,12 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa return poly; }; - switch(opcode) + switch (opcode) { case AstOp.Mul: var factors = GetRootMultiplications(ctx, id); var facPolys = factors.Select(x => TryExpand(x, substMapping, false)).ToList(); - var product = IntermediatePoly.Mul(ctx,facPolys); + var product = IntermediatePoly.Mul(ctx, facPolys); resultPoly = product; // In this case we should probably distribute the coefficient down always. @@ -1929,7 +1952,7 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa case AstOp.Pow: var raisedTo = ctx.TryGetConstantValue(ctx.GetOp1(id)); - if (raisedTo == null) + if(raisedTo == null) throw new InvalidOperationException($"TODO: Handle powers of nonconstant degree"); // Unroll the power into repeated multiplications, then recurse down. @@ -1993,40 +2016,40 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa private IntermediatePoly TryReduce(IntermediatePoly poly) { var uniqueBases = new Dictionary(); - foreach (var monom in poly.coeffs.Keys) + foreach(var monom in poly.coeffs.Keys) { - foreach (var (basis, degree) in monom.varDegrees) + foreach(var (basis, degree) in monom.varDegrees) { uniqueBases.TryAdd(basis, 0); var oldDegree = uniqueBases[basis]; - if (degree > oldDegree) + if(degree > oldDegree) uniqueBases[basis] = degree; } } // For now we only support up to 8 variables. Note that in practice this limit could be increased. - if (uniqueBases.Count > 8) + if(uniqueBases.Count > 8) return poly; // Place a hard limit on the max degree. ulong limit = 254; - if (uniqueBases.Any(x => x.Value > limit)) + if(uniqueBases.Any(x => x.Value > limit)) return poly; ulong matrixSize = 1; - foreach (var deg in uniqueBases.Keys) + foreach(var deg in uniqueBases.Keys) { // Bail out if the result would be too large. UInt128 result = matrixSize * deg; - if (result > (UInt128)(64*64*64)) + if(result > (UInt128)(64 * 64 * 64)) return poly; matrixSize = SaturatingMul(matrixSize, deg); matrixSize &= poly.moduloMask; } - + // Place a limit on the matrix size. - if (matrixSize > (ulong)(64*64*64)) + if(matrixSize > (ulong)(64 * 64 * 64)) return poly; var width = poly.bitWidth; @@ -2038,9 +2061,9 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) // Fill in the sparse polynomial data structure. var degrees = new byte[orderedVars.Count]; - foreach (var (monom, coeff) in poly.coeffs) + foreach(var (monom, coeff) in poly.coeffs) { - for(int varIdx = 0; varIdx < orderedVars.Count; varIdx++) + for (int varIdx = 0; varIdx < orderedVars.Count; varIdx++) { var variable = orderedVars[varIdx]; ulong degree = 0; @@ -2066,14 +2089,14 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) // Otherwise we can convert the sparse polynomial back to an AST. 
foreach(var (monom, coeff) in simplified.coeffs) { - if (coeff == 0) + if(coeff == 0) continue; Dictionary varDegrees = new(); - for(int i = 0; i < orderedVars.Count; i++) + for (int i = 0; i < orderedVars.Count; i++) { var deg = monom.GetVarDeg(i); - if (deg == 0) + if(deg == 0) continue; varDegrees.Add(orderedVars[i], deg); } @@ -2091,18 +2114,19 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) return outPoly; } - public static AstIdx ApplyBackSubstitution(AstCtx ctx, AstIdx id, Dictionary backSubstitutions, Dictionary cache = null) + public static AstIdx BackSubstitute(AstCtx ctx, AstIdx id, Dictionary backSubstitutions) + => BackSubstitute(ctx, id, backSubstitutions, new(16)); + + public static AstIdx BackSubstitute(AstCtx ctx, AstIdx id, Dictionary backSubstitutions, Dictionary cache) { - if (cache == null) - cache = new(); - if (backSubstitutions.TryGetValue(id, out var backSub)) + if(backSubstitutions.TryGetValue(id, out var backSub)) return backSub; - if (cache.TryGetValue(id, out var existing)) + if(cache.TryGetValue(id, out var existing)) return existing; - var op0 = () => ApplyBackSubstitution(ctx, ctx.GetOp0(id), backSubstitutions, cache); - var op1 = () => ApplyBackSubstitution(ctx, ctx.GetOp1(id), backSubstitutions, cache); + var op0 = () => BackSubstitute(ctx, ctx.GetOp0(id), backSubstitutions, cache); + var op1 = () => BackSubstitute(ctx, ctx.GetOp1(id), backSubstitutions, cache); var opcode = ctx.GetOpcode(id); var width = ctx.GetWidth(id); diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index b1989bf..091994a 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -23,6 +23,7 @@ using static Antlr4.Runtime.Atn.SemanticContext; using Mba.Simplifier.Interpreter; using Mba.Simplifier.Utility; +using Microsoft.VisualBasic; namespace Mba.Simplifier.Pipeline { @@ -45,8 +46,15 @@ public class LinearSimplifier // If enabled, we try to find a simpler representation of grouping of basis expressions. private readonly bool tryDecomposeMultiBitBases; + // For internal use in private projects (do not use) private readonly Action? resultVectorHook; + private readonly int depth; + + // For internal use in private projects (do not use) + // Optionally used to track which variables bits are demanded in the expression + private readonly Dictionary anfDemandedBits; + private readonly ApInt moduloMask = 0; // Number of combinations of input variables(2^n), for a single bit index. @@ -69,14 +77,14 @@ public class LinearSimplifier private AstIdx? initialInput = null; - public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null) + public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? 
resultVectorHook = null, ApInt[] inVec = null, int depth = 0, Dictionary anfDemandedBits = null) { if (variables == null) variables = ctx.CollectVariables(ast.Value); - return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec).Simplify(false, alreadySplit); + return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec, depth, anfDemandedBits).Simplify(false, alreadySplit); } - public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null) + public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0, Dictionary anfDemandedBits = null) { // If we are given an AST, verify that the correct width was passed. if (ast != null && bitSize != ctx.GetWidth(ast.Value)) @@ -90,6 +98,8 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables this.multiBit = multiBit; this.tryDecomposeMultiBitBases = tryDecomposeMultiBitBases; this.resultVectorHook = resultVectorHook; + this.depth = depth; + this.anfDemandedBits = anfDemandedBits; moduloMask = (ApInt)ModuloReducer.GetMask(bitSize); groupSizes = GetGroupSizes(variables.Count); numCombinations = (ApInt)Math.Pow(2, variables.Count); @@ -124,7 +134,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables } } - private static IReadOnlyList CastVariables(AstCtx ctx, IReadOnlyList variables, uint bitSize) + public static IReadOnlyList CastVariables(AstCtx ctx, IReadOnlyList variables, uint bitSize) { // If all variables are of a correct size, no casting is necessary. if (!variables.Any(x => ctx.GetWidth(x) != bitSize)) @@ -179,15 +189,14 @@ public unsafe static ApInt[] JitResultVectorOld(AstCtx ctx, uint bitWidth, ApInt public unsafe static ApInt[] JitResultVectorNew(AstCtx ctx, uint bitWidth, ApInt mask, IReadOnlyList variables, AstIdx ast, bool multiBit, ApInt numCombinations) { - var jit = new Amd64OptimizingJit(ctx); - jit.Compile(ast, variables.ToList(), MultibitSiMBA.JitPage.Value, false); + ctx.Compile(ast, ModuloReducer.GetMask(bitWidth), variables.ToArray(), MultibitSiMBA.JitPage.Value); var vec = LinearSimplifier.Execute(ctx, bitWidth, mask, variables, multiBit, numCombinations, MultibitSiMBA.JitPage.Value, false); return vec; } - public unsafe static nint Compile(AstCtx ctx, ApInt mask, IReadOnlyList variables, AstIdx ast, nint codePtr) + public unsafe static nint CompileLegacy(AstCtx ctx, ApInt mask, IReadOnlyList variables, AstIdx ast, nint codePtr) { - return ctx.Compile(ast, mask, variables.ToArray(), codePtr); + return ctx.CompileLegacy(ast, mask, variables.ToArray(), codePtr); } public unsafe static ApInt[] Execute(AstCtx ctx, uint bitWidth, ApInt mask, IReadOnlyList variables, bool multiBit, ApInt numCombinations, nint codePtr, bool isOneBitVars) @@ -222,7 +231,7 @@ private AstIdx Simplify(bool useZ3 = false, bool alreadySplit = false) // If we were given a semi-linear expression, and the ground truth of that expression is linear, // truncate the size of the result vector down to 2^t, then treat it as a linear MBA. 
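// The JIT paths above compile the AST once and then evaluate it over every combination of
// variable values to fill the result vector that the SiMBA-style passes operate on. A minimal
// managed sketch of that evaluation loop for the single-bit case, with the compiled code stood
// in for by a delegate (the bit-per-variable ordering here is an assumption):
static ulong[] BuildResultVector(Func<ulong[], ulong> eval, int numVars, ulong moduloMask)
{
    var numCombinations = 1ul << numVars;
    var vec = new ulong[numCombinations];
    var args = new ulong[numVars];
    for (ulong i = 0; i < numCombinations; i++)
    {
        // Variable j takes bit j of the combination index.
        for (int j = 0; j < numVars; j++)
            args[j] = (i >> j) & 1;
        vec[i] = eval(args) & moduloMask;
    }
    return vec;
}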
- if (multiBit && IsLinearResultVector()) + if (multiBit && IsLinearResultVector()) { multiBit = false; Array.Resize(ref resultVector, (int)numCombinations); @@ -239,6 +248,14 @@ private AstIdx Simplify(bool useZ3 = false, bool alreadySplit = false) // If we have a multi-bit result vector, try to rewrite as a linear result vector. If possible, update state accordingly. private unsafe bool IsLinearResultVector() { + foreach(var v in variables) + { + // If the variable is zero extended or truncated, we treat this as a semi-linear signature vector. + // Truncation cannot be treated as linear, though in the future we may be able to get away with treating zero extension as linear? + if (!ctx.IsSymbol(v)) + return false; + } + fixed (ApInt* ptr = &resultVector[0]) { ushort bitIndex = 0; @@ -619,6 +636,8 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded demandedMask &= ~(1ul << xorIdx); } + + var clone = variables.ToList(); AstIdx sum = ctx.Constant(constantOffset, width); for (int i = 0; i < linearCombinations.Count; i++) { @@ -628,14 +647,16 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded continue; var combMask = variableCombinations[i]; - var vComb = ctx.GetConjunctionFromVarMask(mutVars, combMask); + var widths = variables.Select(x => ctx.GetWidth(x)).ToList(); + + var vComb = ctx.GetConjunctionFromVarMask(clone, combMask); var term = Term(vComb, curr[0].coeff); sum = ctx.Add(sum, term); } // TODO: Instead of constructing a result vector inside the recursive linear simplifier call, we could instead convert the ANF vector back to DNF. // This should be much more efficient than constructing a result vector via JITing and evaluating an AST representation of the ANF vector. - return LinearSimplifier.Run(width, ctx, sum, false, false, false, variables); + return LinearSimplifier.Run(width, ctx, sum, false, false, false, mutVars, depth: depth + 1); } private void EliminateUniqueValues(Dictionary coeffToTable) @@ -908,7 +929,7 @@ public static ApInt SubtractConstantOffset(ApInt moduloMask, ApInt[] resultVecto if (multiBit) { - var r = SimplifyOneValueMultibit(constant, resultVector.ToArray(), variableCombinations); + var r = SimplifyOneValueMultibit(constant, resultVector.ToArray()); if (r != null) { CheckSolutionComplexity(r.Value, 1, null); @@ -971,6 +992,18 @@ public static ApInt SubtractConstantOffset(ApInt moduloMask, ApInt[] resultVecto } } + if(anfDemandedBits != null) + { + for(int i = 0; i < linearCombinations.Count; i++) + { + anfDemandedBits.TryAdd((ApInt)i, 0); + foreach(var (coeff, mask) in linearCombinations[i]) + { + anfDemandedBits[(ApInt)i] |= mask; + } + } + } + // Identify variables that are not present in any conjunction. // E.g. if we have a + (b&c), then a is not present in a conjunction, while b is. var withNoConjunctions = GetVariablesWithNoConjunctions(variableCombinations, linearCombinations); @@ -1367,29 +1400,9 @@ private ulong GetVariablesWithNoConjunctions(ulong[] variableCombinations, List< return null; } - private AstIdx? SimplifyOneValueMultibit(ulong constant, ApInt[] withoutConstant, ApInt[] variableCombinations) + // Algorithm: Start at some point, check if you can change every coefficient to the target coefficient + private AstIdx? 
SimplifyOneValueMultibit(ulong constant, ApInt[] withoutConstant) { - // Algorithm: Start at some point, check if you can change every coefficient to the target coefficient - bool truthTableIdx = true; - if (!truthTableIdx) - variableCombinations = new List() { 0 }.Concat(variableCombinations).ToArray(); - - var getConj = (ApInt i, ApInt? mask) => - { - if (truthTableIdx) - { - var boolean = GetBooleanForIndex((int)i); - if (mask == null) - return boolean; - - return ctx.And(ctx.Constant(mask.Value, width), boolean); - } - - return ConjunctionFromVarMask(1, i, mask); - }; - - AstIdx.ctx = ctx; - // Reduce each row to a canonical form. If a row cannot be canonicalized, there is no solution. var uniqueCoeffs = TryReduceRows(constant, withoutConstant); if (uniqueCoeffs == null) diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index 9d98e31..2a94557 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -5,6 +5,7 @@ using Microsoft.Z3; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -33,13 +34,13 @@ public class ProbableEquivalenceChecker private unsafe delegate* unmanaged[SuppressGCTransition] func2; - public static bool ProbablyEquivalent(AstCtx ctx, AstIdx before, AstIdx after) + public static bool ProbablyEquivalent(AstCtx ctx, AstIdx before, AstIdx after, bool slowHeuristics = true) { var pagePtr1 = JitUtils.AllocateExecutablePage(4096); var pagePtr2 = JitUtils.AllocateExecutablePage(4096); var allVars = ctx.CollectVariables(before).Concat(ctx.CollectVariables(after)).Distinct().OrderBy(x => ctx.GetSymbolName(x)).ToList(); - bool probablyEquivalent = new ProbableEquivalenceChecker(ctx, allVars, before, after, pagePtr1, pagePtr2).ProbablyEquivalent(true); + bool probablyEquivalent = new ProbableEquivalenceChecker(ctx, allVars, before, after, pagePtr1, pagePtr2).ProbablyEquivalent(false); JitUtils.FreeExecutablePage(pagePtr1); JitUtils.FreeExecutablePage(pagePtr2); @@ -59,11 +60,11 @@ public ProbableEquivalenceChecker(AstCtx ctx, List variables, AstIdx bef public unsafe bool ProbablyEquivalent(bool slowHeuristics = false) { var jit1 = new Amd64OptimizingJit(ctx); - jit1.Compile(before, variables, pagePtr1, true); + jit1.Compile(before, variables, pagePtr1, false); func1 = (delegate* unmanaged[SuppressGCTransition])pagePtr1; var jit2 = new Amd64OptimizingJit(ctx); - jit2.Compile(after, variables, pagePtr2, true); + jit2.Compile(after, variables, pagePtr2, false); func2 = (delegate* unmanaged[SuppressGCTransition])pagePtr2; var vArray = stackalloc ulong[variables.Count]; @@ -113,6 +114,7 @@ private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses) var op1 = func1(vArray); var op2 = func2(vArray); + if (op1 != op2) return false; } @@ -128,6 +130,8 @@ private unsafe bool AllCombs(ulong* vArray, ulong a, ulong b) return false; if (!SignatureVectorEquivalent(vArray, a, b)) return false; + if (!SignatureVectorEquivalent(vArray, b, a)) + return false; return true; } @@ -168,34 +172,37 @@ private ulong Next() public static void ProbablyEquivalentZ3(AstCtx ctx, AstIdx before, AstIdx after) { - var z3Ctx = new Context(); - var translator = new Z3Translator(ctx, z3Ctx); - var beforeZ3 = translator.Translate(before); - var afterZ3 = translator.Translate(after); - var solver = z3Ctx.MkSolver("QF_BV"); - - // Set the maximum timeout to 10 seconds. 
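// ProbablyEquivalent above JIT-compiles both expressions and compares them on concrete inputs
// before anything heavier runs; Z3 (below) is only used to actually prove equivalence. A
// standalone sketch of the probabilistic front end, with the compiled code stood in for by
// delegates (guess count and seed are illustrative):
static bool ProbablyEqual(Func<ulong[], ulong> f, Func<ulong[], ulong> g, int numVars, int numGuesses = 1000)
{
    var rng = new Random(1234);
    var args = new ulong[numVars];
    for (int guess = 0; guess < numGuesses; guess++)
    {
        for (int i = 0; i < numVars; i++)
            args[i] = (ulong)rng.NextInt64();
        // Any disagreement is a concrete counterexample; agreement on all guesses only makes
        // equivalence probable, not certain.
        if (f(args) != g(args))
            return false;
    }
    return true;
}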
- var p = z3Ctx.MkParams(); - uint solverLimit = 10000; - p.Add("timeout", solverLimit); - solver.Parameters = p; - - Console.WriteLine("Proving equivalence...\n"); - solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3))); - var check = solver.Check(); - - var printModel = (Model model) => + using (var z3Ctx = new Context()) { - var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}"); - return $"[{String.Join(", ", values)}]"; - }; - - if (check == Status.UNSATISFIABLE) - Console.WriteLine("Expressions are equivalent."); - else if (check == Status.SATISFIABLE) - Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); - else - Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms"); + var translator = new Z3Translator(ctx, z3Ctx); + var beforeZ3 = translator.Translate(before); + var afterZ3 = translator.Translate(after); + var solver = z3Ctx.MkSolver("QF_BV"); + + // Set the maximum timeout to 10 seconds. + var p = z3Ctx.MkParams(); + uint solverLimit = 5000; + p.Add("timeout", solverLimit); + solver.Parameters = p; + + Console.WriteLine("Proving equivalence...\n"); + solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3))); + var check = solver.Check(); + + var printModel = (Model model) => + { + var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}"); + return $"[{String.Join(", ", values)}]"; + }; + + if (check == Status.UNSATISFIABLE) + Console.WriteLine("Expressions are equivalent."); + else if (check == Status.SATISFIABLE) + Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); + else + Console.WriteLine($"Solver timed out - expressions are probably equivalent. 
Could not find counterexample within {solverLimit}ms"); + + } } } diff --git a/Mba.Simplifier/Utility/DagFormatter.cs b/Mba.Simplifier/Utility/DagFormatter.cs new file mode 100644 index 0000000..936d147 --- /dev/null +++ b/Mba.Simplifier/Utility/DagFormatter.cs @@ -0,0 +1,102 @@ +using Mba.Simplifier.Bindings; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Mba.Simplifier.Utility +{ + public static class DagFormatter + { + public static string Format(AstCtx ctx, AstIdx idx) + { + var sb = new StringBuilder(); + Format(sb, ctx, idx, new()); + return sb.ToString(); + } + + private static void Format(StringBuilder sb, AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + // Allocate value numbers for the operands if necessary + var opc = ctx.GetOpcode(idx); + var opcount = GetOpCount(opc); + if (opcount >= 1 && !valueNumbers.ContainsKey(ctx.GetOp0(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp0(idx))) + Format(sb, ctx, ctx.GetOp0(idx), valueNumbers); + if (opcount >= 2 && !valueNumbers.ContainsKey(ctx.GetOp1(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp1(idx))) + Format(sb, ctx, ctx.GetOp1(idx), valueNumbers); + + var op0 = () => $"{Lookup(ctx, ctx.GetOp0(idx), valueNumbers)}"; + var op1 = () => $"{Lookup(ctx, ctx.GetOp1(idx), valueNumbers)}"; + + var vNum = valueNumbers.Count; + valueNumbers.Add(idx, vNum); + var width = ctx.GetWidth(idx); + if (opc == AstOp.Symbol) + sb.AppendLine($"i{width} t{vNum} = {ctx.GetSymbolName(idx)}"); + else if (opc == AstOp.Constant) + sb.AppendLine($"i{width} t{vNum} = {ctx.GetConstantValue(idx)}"); + else if (opc == AstOp.Neg) + sb.AppendLine($"i{width} t{vNum} = ~{op0()}"); + else if (opc == AstOp.Zext || opc == AstOp.Trunc) + { + sb.AppendLine($"i{width} t{vNum} = {GetOperatorName(opc)} i{ctx.GetWidth(ctx.GetOp0(idx))} {op0()} to i{width}"); + } + else + { + sb.AppendLine($"i{width} t{vNum} = {op0()} {GetOperatorName(opc)} {op1()}"); + } + } + + private static bool IsConstOrSymbol(AstCtx ctx, AstIdx idx) + => ctx.GetOpcode(idx) == AstOp.Constant || ctx.GetOpcode(idx) == AstOp.Symbol; + + private static string Lookup(AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + var opc = ctx.GetOpcode(idx); + if (opc == AstOp.Constant) + return ctx.GetConstantValue(idx).ToString(); + if (opc == AstOp.Symbol) + return ctx.GetSymbolName(idx); + return $"t{valueNumbers[idx]}"; + } + + private static int GetOpCount(AstOp opc) + { + return opc switch + { + AstOp.None => 0, + AstOp.Add => 2, + AstOp.Mul => 2, + AstOp.Pow => 2, + AstOp.And => 2, + AstOp.Or => 2, + AstOp.Xor => 2, + AstOp.Neg => 1, + AstOp.Lshr => 2, + AstOp.Constant => 0, + AstOp.Symbol => 0, + AstOp.Zext => 1, + AstOp.Trunc => 1, + }; + } + + private static string GetOperatorName(AstOp opc) + { + return opc switch + { + AstOp.Add => "+", + AstOp.Mul => "*", + AstOp.Pow => "**", + AstOp.And => "&", + AstOp.Or => "|", + AstOp.Xor => "^", + AstOp.Neg => "~", + AstOp.Lshr => ">>", + AstOp.Zext => "zext", + AstOp.Trunc => "trunc", + _ => throw new InvalidOperationException(), + }; + } + } +} diff --git a/Mba.Simplifier/Utility/DagRustFormatter.cs b/Mba.Simplifier/Utility/DagRustFormatter.cs new file mode 100644 index 0000000..e6975e8 --- /dev/null +++ b/Mba.Simplifier/Utility/DagRustFormatter.cs @@ -0,0 +1,105 @@ +using Mba.Simplifier.Bindings; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Mba.Simplifier.Utility +{ + public static class 
DagRustFormatter + { + public static string Format(AstCtx ctx, AstIdx idx) + { + var sb = new StringBuilder(); + Format(sb, ctx, idx, new()); + return sb.ToString(); + } + + private static void Format(StringBuilder sb, AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + // Allocate value numbers for the operands if necessary + var opc = ctx.GetOpcode(idx); + var opcount = GetOpCount(opc); + if (opcount >= 1 && !valueNumbers.ContainsKey(ctx.GetOp0(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp0(idx))) + Format(sb, ctx, ctx.GetOp0(idx), valueNumbers); + if (opcount >= 2 && !valueNumbers.ContainsKey(ctx.GetOp1(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp1(idx))) + Format(sb, ctx, ctx.GetOp1(idx), valueNumbers); + + var op0 = () => $"{Lookup(ctx, ctx.GetOp0(idx), valueNumbers)}"; + var op1 = () => $"{Lookup(ctx, ctx.GetOp1(idx), valueNumbers)}"; + + var vNum = valueNumbers.Count; + valueNumbers.Add(idx, vNum); + var width = ctx.GetWidth(idx); + + var tName = $"u{width}"; + + if (opc == AstOp.Symbol) + sb.AppendLine($"let t{vNum} = ctx.arena.symbol_with_name({ctx.GetSymbolName(idx)}, {width})"); + else if (opc == AstOp.Constant) + sb.AppendLine($"let t{vNum} = {ctx.GetConstantValue(idx)}"); + else if (opc == AstOp.Neg) + sb.AppendLine($"let t{vNum} = ctx.arena.neg({op0()})"); + else if (opc == AstOp.Zext || opc == AstOp.Trunc) + { + sb.AppendLine($"i{width} t{vNum} = {GetOperatorName(opc)} i{ctx.GetWidth(ctx.GetOp0(idx))} {op0()} to i{width}"); + } + else + { + sb.AppendLine($"i{width} t{vNum} = {op0()} {GetOperatorName(opc)} {op1()}"); + } + } + + private static bool IsConstOrSymbol(AstCtx ctx, AstIdx idx) + => ctx.GetOpcode(idx) == AstOp.Constant || ctx.GetOpcode(idx) == AstOp.Symbol; + + private static string Lookup(AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + var opc = ctx.GetOpcode(idx); + if (opc == AstOp.Constant) + return ctx.GetConstantValue(idx).ToString(); + if (opc == AstOp.Symbol) + return ctx.GetSymbolName(idx); + return $"t{valueNumbers[idx]}"; + } + + private static int GetOpCount(AstOp opc) + { + return opc switch + { + AstOp.None => 0, + AstOp.Add => 2, + AstOp.Mul => 2, + AstOp.Pow => 2, + AstOp.And => 2, + AstOp.Or => 2, + AstOp.Xor => 2, + AstOp.Neg => 1, + AstOp.Lshr => 2, + AstOp.Constant => 0, + AstOp.Symbol => 0, + AstOp.Zext => 1, + AstOp.Trunc => 1, + }; + } + + private static string GetOperatorName(AstOp opc) + { + return opc switch + { + AstOp.Add => "+", + AstOp.Mul => "*", + AstOp.Pow => "**", + AstOp.And => "&", + AstOp.Or => "|", + AstOp.Xor => "^", + AstOp.Neg => "~", + AstOp.Lshr => ">>", + AstOp.Zext => "zext", + AstOp.Trunc => "trunc", + _ => throw new InvalidOperationException(), + }; + } + } +} diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 882b645..d704c56 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -15,7 +15,7 @@ bool printUsage = false; uint bitWidth = 64; bool useEqsat = false; -bool proveEquivalence = true; +bool proveEquivalence = false; string inputText = null; var printHelp = () => @@ -37,7 +37,7 @@ case "-h": printUsage = true; break; - case "-b": + case "-b": bitWidth = uint.Parse(args[i + 1]); i++; break; @@ -81,7 +81,7 @@ // Run the simplification pipeline. id = simplifier.SimplifyGeneral(id); // Try to expand and reduce the polynomial parts(if any exist). - if(ctx.GetHasPoly(id)) + if (ctx.GetHasPoly(id)) id = simplifier.ExpandReduce(id); if (!useEqsat)
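// Usage sketch for the DagFormatter added above: each distinct subterm is printed once under a
// value number, so shared DAG nodes are not expanded repeatedly. Assumes an AstCtx instance
// named ctx; ctx.Symbol/ctx.Mul/ctx.Add mirror the helpers used throughout this diff, and the
// exact signatures are assumptions.
var x = ctx.Symbol("x", 64);
var y = ctx.Symbol("y", 64);
var z = ctx.Symbol("z", 64);
var expr = ctx.Add(x, ctx.Mul(y, z));
Console.WriteLine(DagFormatter.Format(ctx, expr));
// Expected output (value numbers assigned in visitation order):
//   i64 t0 = y * z
//   i64 t1 = x + t0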