From d35609d82c90091eab69383f778e46396d415652 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Mon, 15 Sep 2025 23:57:11 -0400 Subject: [PATCH 01/21] Upcast variables as necessary --- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 150 ++++++++++--------- Mba.Simplifier/Pipeline/LinearSimplifier.cs | 2 +- 2 files changed, 77 insertions(+), 75 deletions(-) diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index 97c9ec0..ae1ba95 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -88,11 +88,11 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if (ctx.IsConstant(id)) return id; - if(linClass != AstClassification.Nonlinear) + if (linClass != AstClassification.Nonlinear) { // Bail out if there are too many variables. var vars = ctx.CollectVariables(id); - if(vars.Count > 11 || vars.Count == 0) + if (vars.Count > 11 || vars.Count == 0) { var simplified = SimplifyViaTermRewriting(id); simbaCache.Add(id, simplified); @@ -116,7 +116,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) // Discard any vanished substitutions var usedVars = ctx.CollectVariables(withSubstitutions).ToHashSet(); - foreach(var (substValue, substVar) in substMapping.ToList()) + foreach (var (substValue, substVar) in substMapping.ToList()) { if (!usedVars.Contains(substVar)) substMapping.Remove(substValue); @@ -149,7 +149,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) reducedPoly = ReducePolynomials(GetRootTerms(ctx, withSubstitutions), substMapping, inverseMapping); // If we succeeded, reset the state. - if(reducedPoly != null) + if (reducedPoly != null) { // Back substitute the original substitutions. reducedPoly = ApplyBackSubstitution(ctx, reducedPoly.Value, inverseMapping); @@ -352,7 +352,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub (and0, and1) = (and1, and0); (id0, id1) = (id1, id0); } - + // Rewrite (a&mask) as `Trunc(a)`, or `Trunc(a & mask)` if mask is not completely a bit mask. // This is a form of adhoc demanded bits based simplification if (ctx.IsConstant(and0) && !ctx.IsConstant(and1)) @@ -379,7 +379,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub return ext; } } - + return ctx.And(and0, and1); } @@ -508,12 +508,12 @@ public static PolynomialParts GetPolynomialParts(AstCtx ctx, AstIdx id) // Skip if this is not a multiplication. var opcode = ctx.GetOpcode(id); - var roots = GetRootMultiplications(ctx,id); + var roots = GetRootMultiplications(ctx, id); ulong coeffSum = 0; Dictionary constantPowers = new(); List others = new(); - foreach(var root in roots) + foreach (var root in roots) { var code = ctx.GetOpcode(root); if (code == AstOp.Constant) @@ -525,12 +525,12 @@ public static PolynomialParts GetPolynomialParts(AstCtx ctx, AstIdx id) constantPowers.TryAdd(root, 0); constantPowers[root]++; } - else if(code == AstOp.Pow) + else if (code == AstOp.Pow) { // If we have a power by a nonconstant, we can't really do much here. var degree = ctx.GetOp1(root); var constPower = ctx.TryGetConstantValue(degree); - if(constPower == null) + if (constPower == null) { others.Add(root); continue; @@ -568,7 +568,7 @@ public static int VarsFirst(AstCtx ctx, AstIdx a, AstIdx b) return comeFirst; if (op1 && !op0) return comeLast; - if(op0 && op1) + if (op0 && op1) return ctx.GetSymbolName(a).CompareTo(ctx.GetSymbolName(b)); return comeLast; } @@ -590,7 +590,7 @@ private int CompareTo(AstIdx a, AstIdx b) return comeLast; // Sort symbols alphabetically - if(op0 == AstOp.Symbol && op1 == AstOp.Symbol) + if (op0 == AstOp.Symbol && op1 == AstOp.Symbol) return ctx.GetSymbolName(a).CompareTo(ctx.GetSymbolName(b)); if (op0 == AstOp.Pow) return comeLast; @@ -604,7 +604,7 @@ private AstIdx GetSubstitution(AstIdx id, Dictionary substitutio if (substitutionMapping.TryGetValue(id, out var existing)) return existing; - while(true) + while (true) { var subst = ctx.Symbol($"subst{substCount}", ctx.GetWidth(id)); substCount++; @@ -631,7 +631,7 @@ private AstIdx TryUnmergeLinCombs(AstIdx withSubstitutions, Dictionary UnmergeNegatedParts(Dictionary(); var results = new List(); - for(int i = 0 ; i < inputExpressions.Count; i++) + for (int i = 0; i < inputExpressions.Count; i++) { // Substitute all of the nonlinear parts for this expression // Here we share the list of substitutions @@ -701,7 +701,7 @@ private Dictionary UnmergeNegatedParts(Dictionary> vecToExpr = new(); - for(int i = 0; i < results.Count; i++) + for (int i = 0; i < results.Count; i++) { var expr = results[i]; var w = ctx.GetWidth(expr); @@ -717,7 +717,7 @@ private Dictionary UnmergeNegatedParts(Dictionary varToNewSubstValue = new Dictionary(); - foreach(var (key, members) in vecToExpr) + foreach (var (key, members) in vecToExpr) { var temp = results[members.First().index]; var w = ctx.GetWidth(temp); @@ -731,12 +731,12 @@ private Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary varToDemandedBits = new(); - foreach(var (expr, substVar) in substitutionMapping) + foreach (var (expr, substVar) in substitutionMapping) ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits); // Compute the total number of demanded variable bits in the substituted parts. @@ -888,7 +888,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong { // Construct a result vector for the linear part. var substVars = substitutionMapping.Values.ToList(); - var allVars = ctx.CollectVariables(withSubstitutions); + IReadOnlyList allVars = ctx.CollectVariables(withSubstitutions); var bitSize = ctx.GetWidth(withSubstitutions); var numCombinations = (ulong)Math.Pow(2, allVars.Count); var groupSizes = LinearSimplifier.GetGroupSizes(allVars.Count); @@ -927,10 +927,12 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong List constrainedParts = new(); // Decompose result vector into semi-linear, unconstrained, and constrained parts. + // Upcast variables as necessary! + allVars = LinearSimplifier.CastVariables(ctx, allVars, bitSize); int resultVecIdx = 0; - for(int i = 0; i < linearCombinations.Count; i++) + for (int i = 0; i < linearCombinations.Count; i++) { - foreach(var (coeff, bitMask) in linearCombinations[i]) + foreach (var (coeff, bitMask) in linearCombinations[i]) { if (coeff == 0) goto skip; @@ -966,7 +968,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong } // TODO: Refactor out! - private static (ulong[], List>) GetAnf(uint width, List variables, List groupSizes, ulong[] resultVector, bool multiBit) + private static (ulong[], List>) GetAnf(uint width, IReadOnlyList variables, List groupSizes, ulong[] resultVector, bool multiBit) { // Get all combinations of variables. var moduloMask = ModuloReducer.GetMask(width); @@ -1060,7 +1062,7 @@ private static (ulong[], List>) GetAnf(uint w } // Returns true if two expressions are guaranteed to be equivalent - private unsafe bool IsConstrainedExpressionEquivalent(uint width,List inputVars, List<(AstIdx demandedVar, ulong demandedMask)> demandedVars, List> exprToSubstVar, delegate* unmanaged[SuppressGCTransition] jittedWithSubstitutions, ulong[] originalResultVec) + private unsafe bool IsConstrainedExpressionEquivalent(uint width, List inputVars, List<(AstIdx demandedVar, ulong demandedMask)> demandedVars, List> exprToSubstVar, delegate* unmanaged[SuppressGCTransition] jittedWithSubstitutions, ulong[] originalResultVec) { int totalDemanded = demandedVars.Sum(x => BitOperations.PopCount(x.demandedMask)); @@ -1218,7 +1220,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits); var opc = ctx.GetOpcode(idx); - switch(opc) + switch (opc) { // If we have a symbol, union the set of demanded bits case AstOp.Symbol: @@ -1313,7 +1315,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar return null; // Add back any banned parts. - if(banned.Any()) + if (banned.Any()) { var sum = ctx.Add(banned.Select(x => GetAstForPolynomialParts(x))); result = ctx.Add(result.Value, sum); @@ -1336,7 +1338,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar private List UnmergePolynomialParts(Dictionary substitutionMapping, List parts) { // Skip if there is only one substituted part. - if(substitutionMapping.Count <= 1) + if (substitutionMapping.Count <= 1) return parts; // Try to rewrite substituted parts as negations of one another. Exit early if this fails. @@ -1369,11 +1371,11 @@ private List UnmergePolynomialParts(Dictionary // Rewrite as a sum of polynomial parts, where the factors are linear MBAs with substitution of nonlinear parts. var bannedParts = new List(); List partsWithSubstitutions = new(); - foreach(var polyPart in polyParts) + foreach (var polyPart in polyParts) { bool isSemiLinear = false; Dictionary powers = new(); - foreach(var factor in polyPart.ConstantPowers) + foreach (var factor in polyPart.ConstantPowers) { var withSubstitutions = GetAstWithSubstitutions(factor.Key, substMapping, ref isSemiLinear); powers.TryAdd(withSubstitutions, 0); @@ -1381,7 +1383,7 @@ private List UnmergePolynomialParts(Dictionary } // TODO: Handle the semi-linear case. - if(isSemiLinear) + if (isSemiLinear) { bannedParts.Add(polyPart); continue; @@ -1443,7 +1445,7 @@ private List UnmergePolynomialParts(Dictionary } // Calculate the max possible size of the resulting expression when multiplied out. - for(ulong i = 0; i < degree; i++) + for (ulong i = 0; i < degree; i++) { size = SaturatingMul(size, numNonZeroes); } @@ -1459,7 +1461,7 @@ private List UnmergePolynomialParts(Dictionary // When the basis element corresponds to the constant offset, we want to make the base bitwise expression be `1`. // Otherwise we just substitute it with a variable. AstIdx basis = ctx.Constant(1, (byte)bitSize); - if(i != 0) + if (i != 0) { if (!basisSubstitutions.TryGetValue((ulong)i, out basis)) { @@ -1481,7 +1483,7 @@ private List UnmergePolynomialParts(Dictionary // If the expanded form would be too large, we want to block this polynomial. // It would take too long. - if(size >= 1000) + if (size >= 1000) { bannedParts.Add(polyPart); continue; @@ -1502,7 +1504,7 @@ private List UnmergePolynomialParts(Dictionary poly = constOffset; } - + polys.Add(poly.Value); } @@ -1514,7 +1516,7 @@ private List UnmergePolynomialParts(Dictionary var linComb = ctx.Add(polys); var reduced = ExpandReduce(linComb, false); // Add back banned parts - if(bannedParts.Any()) + if (bannedParts.Any()) { var sum = ctx.Add(bannedParts.Select(x => GetAstForPolynomialParts(x))); reduced = ctx.Add(reduced, sum); @@ -1628,16 +1630,16 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // Try to decompose into high degree polynomials parts. List polyTerms = new(); List other = new(); - foreach(var term in terms) + foreach (var term in terms) { // Typically this is going to be a multiplication(coefficient over substituted variable), or whole substituted variable. // TODO: Handle negation. var opcode = ctx.GetOpcode(term); - if(opcode != AstOp.Mul && opcode != AstOp.Symbol) + if (opcode != AstOp.Mul && opcode != AstOp.Symbol) goto skip; - + // Search for coeff*subst - if(opcode == AstOp.Mul) + if (opcode == AstOp.Mul) { // If multiplication, we are looking for coeff*(subst), where coeff is a constant. var coeff = ctx.GetOp0(term); @@ -1656,14 +1658,14 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) } // Search for a plain substitution(omitted coefficient of 1) - if(opcode == AstOp.Symbol && IsSubstitutedPolynomialSymbol(term, inverseSubstMapping)) + if (opcode == AstOp.Symbol && IsSubstitutedPolynomialSymbol(term, inverseSubstMapping)) { var invSubst = inverseSubstMapping[term]; polyTerms.Add(GetPolynomialParts(ctx, invSubst)); continue; } - skip: + skip: other.Add(term); continue; } @@ -1680,7 +1682,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) var uniqueBases = new Dictionary(); foreach (var poly in polyTerms) { - foreach(var (_base, degree) in poly.ConstantPowers) + foreach (var (_base, degree) in poly.ConstantPowers) { // Set the default degree to zero. uniqueBases.TryAdd(_base, 0); @@ -1688,7 +1690,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // For each unique base, we want to keep track of the highest degree. var oldDegree = uniqueBases[_base]; var newDeg = degree; - if(newDeg > oldDegree) + if (newDeg > oldDegree) uniqueBases[_base] = newDeg; } } @@ -1706,7 +1708,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // If the dense vector size would be greater than 64**3, we bail out. // In those cases, we may consider implementing variable partitioning and simplifying each partition separately. - if (vecSize > 64*64*64) + if (vecSize > 64 * 64 * 64) return null; // For now we only support polynomials up to degree 255, although this is a somewhat arbitrary limit. @@ -1727,10 +1729,10 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) foreach (var poly in polyTerms) { var coeff = poly.coeffSum; - + var constPowers = poly.ConstantPowers; var degrees = new byte[orderedVars.Count]; - for(int varIdx = 0; varIdx < orderedVars.Count; varIdx++) + for (int varIdx = 0; varIdx < orderedVars.Count; varIdx++) { var variable = orderedVars[varIdx]; ulong degree = 0; @@ -1756,24 +1758,24 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // Add back all of the ignored parts. newTerms.AddRange(other); // Add back all of the discarded polynomial parts - foreach(var part in discarded) + foreach (var part in discarded) { var ast = GetAstForPolynomialParts(part); newTerms.Add(ast); } // Then finally convert the sparse polynomial back to an AST. - foreach(var (monom, coeff) in simplified.coeffs) + foreach (var (monom, coeff) in simplified.coeffs) { if (coeff == 0) continue; List factors = new(); factors.Add(ctx.Constant(coeff, width)); - for(int i = 0; i < orderedVars.Count; i++) + for (int i = 0; i < orderedVars.Count; i++) { var deg = monom.GetVarDeg(i); - if(deg == 0) + if (deg == 0) { factors.Add(ctx.Constant(1, width)); continue; @@ -1813,17 +1815,17 @@ private bool IsSubstitutedPolynomialSymbol(AstIdx id, IReadOnlyDictionary terms = new(); var width = ctx.GetWidth(id); - foreach(var (monom, coeff) in result.coeffs) + foreach (var (monom, coeff) in result.coeffs) { List factors = new(); factors.Add(ctx.Constant(coeff, width)); - foreach(var (varIdx, degree) in monom.varDegrees) + foreach (var (varIdx, degree) in monom.varDegrees) { // Skip a constant factor of 1 if (degree == 0) continue; - if(degree == 1) + if (degree == 1) { factors.Add(varIdx); continue; @@ -1909,12 +1911,12 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa return poly; }; - switch(opcode) + switch (opcode) { case AstOp.Mul: var factors = GetRootMultiplications(ctx, id); var facPolys = factors.Select(x => TryExpand(x, substMapping, false)).ToList(); - var product = IntermediatePoly.Mul(ctx,facPolys); + var product = IntermediatePoly.Mul(ctx, facPolys); resultPoly = product; // In this case we should probably distribute the coefficient down always. @@ -1981,7 +1983,7 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa // If this is the root of a polynomial part, we want to try and reduce it. // Alternatively we may apply a reduction if there are too many terms. bool shouldReduce = isRoot || resultPoly?.coeffs?.Count > 10; - if(shouldReduce && resultPoly != null) + if (shouldReduce && resultPoly != null) { resultPoly = TryReduce(resultPoly); } @@ -2018,15 +2020,15 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) { // Bail out if the result would be too large. UInt128 result = matrixSize * deg; - if (result > (UInt128)(64*64*64)) + if (result > (UInt128)(64 * 64 * 64)) return poly; matrixSize = SaturatingMul(matrixSize, deg); matrixSize &= poly.moduloMask; } - + // Place a limit on the matrix size. - if (matrixSize > (ulong)(64*64*64)) + if (matrixSize > (ulong)(64 * 64 * 64)) return poly; var width = poly.bitWidth; @@ -2040,7 +2042,7 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) var degrees = new byte[orderedVars.Count]; foreach (var (monom, coeff) in poly.coeffs) { - for(int varIdx = 0; varIdx < orderedVars.Count; varIdx++) + for (int varIdx = 0; varIdx < orderedVars.Count; varIdx++) { var variable = orderedVars[varIdx]; ulong degree = 0; @@ -2059,18 +2061,18 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) var newCount = simplified.coeffs.Count(x => x.Value != 0); // If we got a result with more terms, skip it. // This is required when doing expansion, since expansion is exponential in the number of terms. - if(newCount > oldCount) + if (newCount > oldCount) return poly; var outPoly = new IntermediatePoly(width); // Otherwise we can convert the sparse polynomial back to an AST. - foreach(var (monom, coeff) in simplified.coeffs) + foreach (var (monom, coeff) in simplified.coeffs) { if (coeff == 0) continue; Dictionary varDegrees = new(); - for(int i = 0; i < orderedVars.Count; i++) + for (int i = 0; i < orderedVars.Count; i++) { var deg = monom.GetVarDeg(i); if (deg == 0) @@ -2079,7 +2081,7 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) } // Handle the case of a constant offset. - if(varDegrees.Count == 0) + if (varDegrees.Count == 0) { varDegrees.Add(ctx.Constant(1, width), 1); } diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index b1989bf..0e8d074 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -124,7 +124,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables } } - private static IReadOnlyList CastVariables(AstCtx ctx, IReadOnlyList variables, uint bitSize) + public static IReadOnlyList CastVariables(AstCtx ctx, IReadOnlyList variables, uint bitSize) { // If all variables are of a correct size, no casting is necessary. if (!variables.Any(x => ctx.GetWidth(x) != bitSize)) From a2963b9b28c14e74cf523079e9cebd3820e60b77 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Tue, 16 Sep 2025 00:06:29 -0400 Subject: [PATCH 02/21] demanded bits fix --- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 2 +- Simplifier/Program.cs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index ae1ba95..13fbe72 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -1290,7 +1290,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar op0(currDemanded); break; case AstOp.Zext: - op0(currDemanded & ctx.GetWidth(ctx.GetOp0(idx))); + op0(currDemanded & ModuloReducer.GetMask(ctx.GetWidth(ctx.GetOp0(idx)))); break; default: throw new InvalidOperationException($"Cannot compute demanded bits for {opc}"); diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 882b645..476bbb8 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -18,6 +18,8 @@ bool proveEquivalence = true; string inputText = null; +inputText = "((((1:i32&((uns17:i8 zx i32)&(~uns18:i32)))|(4294964010:i32&(~((uns17:i8 zx i32)|(~uns18:i32)))))|(4294964011:i32&((uns17:i8 zx i32)&uns18:i32)))|(4:i32*(1:i32&(uns19:i8 zx i32))))"; + var printHelp = () => { Console.WriteLine("Usage: Simplifier.exe"); From dad7cfa9dc231a9203118eb784d7efe5d02ff1e2 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Tue, 16 Sep 2025 00:07:50 -0400 Subject: [PATCH 03/21] add pow handler for demanded bits --- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index 13fbe72..f8ba5dd 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -1235,6 +1235,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // For addition by a constant we can also get more precision case AstOp.Add: case AstOp.Mul: + case AstOp.Pow: // If we have addition/multiplication, we only care about bits at and below the highest set bit. var demandedWidth = 64 - (uint)BitOperations.LeadingZeroCount(currDemanded); currDemanded = ModuloReducer.GetMask(demandedWidth); From 0ffa2946cbfa5f860663b01f5f9c1620ec292faa Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Tue, 16 Sep 2025 01:23:53 -0400 Subject: [PATCH 04/21] Bug fix variable elimination logic; Stop treeifying inside recursive isle simplifier --- EqSat/src/simple_ast.rs | 10 +++++- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 14 +++++--- Mba.Simplifier/Pipeline/LinearSimplifier.cs | 38 ++++++++++++++++---- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 78b8333..87f5bcb 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -27,6 +27,7 @@ pub struct AstIdx(pub u32); pub struct Arena { pub elements: Vec<(SimpleAst, AstData)>, ast_to_idx: AHashMap, + isle_cache: AHashMap, // Map a name to it's corresponds symbol index. symbol_ids: Vec<(String, AstIdx)>, @@ -37,6 +38,7 @@ impl Arena { pub fn new() -> Self { let elements = Vec::with_capacity(65536); let ast_to_idx = AHashMap::with_capacity(65536); + let isle_cache = AHashMap::with_capacity(65536); let symbol_ids = Vec::with_capacity(255); let name_to_symbol = AHashMap::with_capacity(255); @@ -44,6 +46,7 @@ impl Arena { Arena { elements: elements, ast_to_idx: ast_to_idx, + isle_cache: isle_cache, symbol_ids: symbol_ids, name_to_symbol: name_to_symbol, @@ -813,6 +816,9 @@ pub fn eval_ast(ctx: &Context, idx: AstIdx, value_mapping: &HashMap // Recursively apply ISLE over an AST. pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx { + if ctx.arena.isle_cache.get(&idx).is_some() { + return *ctx.arena.isle_cache.get(&idx).unwrap(); + } let mut ast = ctx.arena.get_node(idx).clone(); match ast { @@ -862,7 +868,9 @@ pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx { ast = result.unwrap(); } - return ctx.arena.ast_to_idx[&ast]; + let result = ctx.arena.ast_to_idx[&ast]; + ctx.arena.isle_cache.insert(idx, result); + result } // Evaluate the current AST for all possible combinations of zeroes and ones as inputs. diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index f8ba5dd..eb5aec2 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -825,7 +825,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong // TODO: Keep track of which bits are demanded by the parent(withSubstitutions) Dictionary varToDemandedBits = new(); foreach (var (expr, substVar) in substitutionMapping) - ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits); + ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, new()); // Compute the total number of demanded variable bits in the substituted parts. ulong totalDemanded = 0; @@ -1210,14 +1210,20 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width, List i } } + JitUtils.FreeExecutablePage(pagePtr1); + JitUtils.FreeExecutablePage(pagePtr2); return expectedExpr; } // TODO: Cache results to avoid exponentially visiting shared nodes - private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits) + private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet<(AstIdx idx, ulong currDemanded)> seen) { - var op0 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits); - var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits); + if (seen.Contains((idx, currDemanded))) + return; + seen.Add((idx, currDemanded)); + + var op0 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits, seen); + var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits, seen); var opc = ctx.GetOpcode(idx); switch (opc) diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index 0e8d074..3c911c5 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -46,7 +46,7 @@ public class LinearSimplifier private readonly bool tryDecomposeMultiBitBases; private readonly Action? resultVectorHook; - + private readonly int depth; private readonly ApInt moduloMask = 0; // Number of combinations of input variables(2^n), for a single bit index. @@ -69,14 +69,14 @@ public class LinearSimplifier private AstIdx? initialInput = null; - public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null) + public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0) { if (variables == null) variables = ctx.CollectVariables(ast.Value); - return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec).Simplify(false, alreadySplit); + return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec, depth).Simplify(false, alreadySplit); } - public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null) + public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0) { // If we are given an AST, verify that the correct width was passed. if (ast != null && bitSize != ctx.GetWidth(ast.Value)) @@ -90,6 +90,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables this.multiBit = multiBit; this.tryDecomposeMultiBitBases = tryDecomposeMultiBitBases; this.resultVectorHook = resultVectorHook; + this.depth = depth; moduloMask = (ApInt)ModuloReducer.GetMask(bitSize); groupSizes = GetGroupSizes(variables.Count); numCombinations = (ApInt)Math.Pow(2, variables.Count); @@ -610,6 +611,20 @@ private AstIdx FindTwoTermsUnnegated(ApInt constant, ApInt a, ApInt b) private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demandedMask, ApInt[] variableCombinations, List> linearCombinations) { + + var vNames = this.variables.Select(x => ctx.GetAstString(x)); + //var expected = new List() { "subst594:i32", "subst595:i32", "subst596:i32", "subst597:i32", "(uns45:i64 tr i32)", "(uns48:i64 tr i32)"}; + + /* + var expected = new List() { "subst27:i32", "(uns41:i64 tr i32)", "(uns74:i64 tr i32)", "(uns75:i64 tr i32)" }; + if (expected.All(x => vNames.Any(y => y.Contains(x)))) + Debugger.Break(); + + if (depth > 10) + Debugger.Break(); + */ + + // Collect all variables used in the output expression. List mutVars = new(variables.Count); while (demandedMask != 0) @@ -619,6 +634,8 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded demandedMask &= ~(1ul << xorIdx); } + + var clone = variables.ToList(); AstIdx sum = ctx.Constant(constantOffset, width); for (int i = 0; i < linearCombinations.Count; i++) { @@ -628,14 +645,23 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded continue; var combMask = variableCombinations[i]; - var vComb = ctx.GetConjunctionFromVarMask(mutVars, combMask); + var widths = variables.Select(x => ctx.GetWidth(x)).ToList(); + Console.WriteLine(widths.Distinct().Count()); + foreach (var vIdx in variables) + Console.WriteLine($"{ctx.GetAstString(vIdx)} => {ctx.GetWidth(vIdx)}"); + Console.WriteLine("\n\n"); + if (widths.Distinct().Count() != 1) + { + Debugger.Break(); + } + var vComb = ctx.GetConjunctionFromVarMask(clone, combMask); var term = Term(vComb, curr[0].coeff); sum = ctx.Add(sum, term); } // TODO: Instead of constructing a result vector inside the recursive linear simplifier call, we could instead convert the ANF vector back to DNF. // This should be much more efficient than constructing a result vector via JITing and evaluating an AST representation of the ANF vector. - return LinearSimplifier.Run(width, ctx, sum, false, false, false, variables); + return LinearSimplifier.Run(width, ctx, sum, false, false, false, mutVars, depth: depth + 1); } private void EliminateUniqueValues(Dictionary coeffToTable) From d7b3377a1caaa35e8af3012101ea9394ac3160a4 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Tue, 16 Sep 2025 02:48:01 -0400 Subject: [PATCH 05/21] micro optimizations --- Mba.Simplifier/Bindings/AstIdx.cs | 5 ++ .../Minimization/BooleanMinimizer.cs | 2 +- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 86 ++++++++++++------- .../Pipeline/ProbableEquivalenceChecker.cs | 4 +- 4 files changed, 65 insertions(+), 32 deletions(-) diff --git a/Mba.Simplifier/Bindings/AstIdx.cs b/Mba.Simplifier/Bindings/AstIdx.cs index 1a5849f..93d1720 100644 --- a/Mba.Simplifier/Bindings/AstIdx.cs +++ b/Mba.Simplifier/Bindings/AstIdx.cs @@ -22,6 +22,11 @@ public override string ToString() return ctx.GetAstString(Idx); } + public override int GetHashCode() + { + return Idx.GetHashCode(); + } + public unsafe static implicit operator uint(AstIdx reg) => reg.Idx; public unsafe static implicit operator AstIdx(uint reg) => new AstIdx(reg); diff --git a/Mba.Simplifier/Minimization/BooleanMinimizer.cs b/Mba.Simplifier/Minimization/BooleanMinimizer.cs index e88f3fa..a7e37c4 100644 --- a/Mba.Simplifier/Minimization/BooleanMinimizer.cs +++ b/Mba.Simplifier/Minimization/BooleanMinimizer.cs @@ -191,7 +191,7 @@ private static AstIdx MinimizeAnf(AstCtx ctx, IReadOnlyList variables, T } var r = ctx.MinimizeAnf(TableDatabase.Instance.db, truthTable, tempVars, MultibitSiMBA.JitPage.Value); - var backSubst = GeneralSimplifier.ApplyBackSubstitution(ctx, r, invSubstMapping); + var backSubst = GeneralSimplifier.BackSubstitute(ctx, r, invSubstMapping); return backSubst; } } diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index eb5aec2..7174850 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -95,7 +95,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if (vars.Count > 11 || vars.Count == 0) { var simplified = SimplifyViaTermRewriting(id); - simbaCache.Add(id, simplified); + simbaCache.TryAdd(id, simplified); return simplified; } @@ -122,6 +122,11 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) substMapping.Remove(substValue); } + if (substMapping.Count > 8) + { + Console.WriteLine(substMapping.Count); + //Debugger.Break(); + } // Try to take a guess (MSiMBA) and prove it's equivalence var guess = SimplifyViaGuessAndProve(withSubstitutions, substMapping, ref isSemiLinear); if (guess != null) @@ -140,6 +145,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) withSubstitutions = TryUnmergeLinCombs(withSubstitutions, substMapping, ref isSemiLinear); withSubstitutions = SimplifyViaTermRewriting(withSubstitutions); + // If polynomial parts are present, try to simplify them. var inverseMapping = substMapping.ToDictionary(x => x.Value, x => x.Key); AstIdx? reducedPoly = null; @@ -152,7 +158,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if (reducedPoly != null) { // Back substitute the original substitutions. - reducedPoly = ApplyBackSubstitution(ctx, reducedPoly.Value, inverseMapping); + reducedPoly = BackSubstitute(ctx, reducedPoly.Value, inverseMapping); // Reset internal state. substMapping.Clear(); @@ -182,7 +188,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if (variables.Count > 11) { var simplified = SimplifyViaTermRewriting(id); - simbaCache.Add(id, simplified); + simbaCache.TryAdd(id, simplified); return simplified; } @@ -193,7 +199,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) var result = withSubstitutions; if (!ctx.IsConstant(withSubstitutions)) result = LinearSimplifier.Run(ctx.GetWidth(withSubstitutions), ctx, withSubstitutions, false, isSemiLinear, false, variables); - var backSub = ApplyBackSubstitution(ctx, result, inverseMapping); + var backSub = BackSubstitute(ctx, result, inverseMapping); // Apply constant folding / term rewriting. var propagated = SimplifyViaTermRewriting(backSub); @@ -634,7 +640,7 @@ private AstIdx TryUnmergeLinCombs(AstIdx withSubstitutions, Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary varToDemandedBits = new(); + var cache = new HashSet(); foreach (var (expr, substVar) in substitutionMapping) - ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, new()); + ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, cache); // Compute the total number of demanded variable bits in the substituted parts. ulong totalDemanded = 0; @@ -849,7 +856,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong if (constrainedIdx == null) { // Simplify the constrained parts. - var withoutSubstitutions = ApplyBackSubstitution(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var withoutSubstitutions = BackSubstitute(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var r = SimplifyUnconstrained(withoutSubstitutions, varToDemandedBits); if (r == null) return null; @@ -872,7 +879,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong return null; // Simplify unconstrained parts. - var unconstrainedBackSub = ApplyBackSubstitution(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var unconstrainedBackSub = BackSubstitute(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var unconstrainedSimpl = SimplifyUnconstrained(unconstrainedBackSub, varToDemandedBits); if (unconstrainedSimpl == null) return null; @@ -1035,7 +1042,7 @@ private static (ulong[], List>) GetAnf(uint w private unsafe AstIdx? SimplifyConstrained(AstIdx withSubstitutions, Dictionary substitutionMapping, Dictionary varToDemandedBits) { // Compute a result vector for the original expression - var withoutSubstitutions = ApplyBackSubstitution(ctx, withSubstitutions, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var withoutSubstitutions = BackSubstitute(ctx, withSubstitutions, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var w = ctx.GetWidth(withoutSubstitutions); var inputVars = ctx.CollectVariables(withoutSubstitutions); var originalResultVec = LinearSimplifier.JitResultVector(ctx, w, ModuloReducer.GetMask(w), inputVars, withoutSubstitutions, true, (ulong)Math.Pow(2, inputVars.Count)); @@ -1044,7 +1051,7 @@ private static (ulong[], List>) GetAnf(uint w var exprToSubstVar = substitutionMapping.OrderBy(x => ctx.GetAstString(x.Value)).ToList(); var allVars = inputVars.Concat(exprToSubstVar.Select(x => x.Value)).ToList(); // Sort them.... var pagePtr = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(withSubstitutions, allVars, pagePtr, true); + new Amd64OptimizingJit(ctx).Compile(withSubstitutions, allVars, pagePtr, false); var jittedWithSubstitutions = (delegate* unmanaged[SuppressGCTransition])pagePtr; // Return null if the expressions are not provably equivalent @@ -1155,12 +1162,12 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width, List i // Jit the input expression var pagePtr1 = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(withoutSubstitutions, inputVars, pagePtr1, true); + new Amd64OptimizingJit(ctx).Compile(withoutSubstitutions, inputVars, pagePtr1, false); var jittedBefore = (delegate* unmanaged[SuppressGCTransition])pagePtr1; // Jit the output expression var pagePtr2 = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(expectedExpr, inputVars, pagePtr2, true); + new Amd64OptimizingJit(ctx).Compile(expectedExpr, inputVars, pagePtr2, false); var jittedAfter = (delegate* unmanaged[SuppressGCTransition])pagePtr2; // Prove that they are equivalent for all possible input combinations @@ -1215,12 +1222,32 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width, List i return expectedExpr; } + public struct DemandedBitsTuple + { + public AstIdx Idx; + + public ulong CurrDemanded; + + public DemandedBitsTuple(AstIdx idx, ulong currDemanded) + { + Idx = idx; + CurrDemanded = currDemanded; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 31 + Idx.GetHashCode(); + hash = hash * 31 + CurrDemanded.GetHashCode(); + return hash; + } + } + // TODO: Cache results to avoid exponentially visiting shared nodes - private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet<(AstIdx idx, ulong currDemanded)> seen) + private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet seen) { - if (seen.Contains((idx, currDemanded))) + if (!seen.Add(new DemandedBitsTuple(idx, currDemanded))) return; - seen.Add((idx, currDemanded)); var op0 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits, seen); var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits, seen); @@ -1307,7 +1334,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar private AstIdx? TrySimplifyMixedPolynomialParts(AstIdx id, Dictionary substMapping, Dictionary inverseSubstMapping, List varList) { // Back substitute in the (possibly) polynomial parts - var newId = ApplyBackSubstitution(ctx, id, inverseSubstMapping); + var newId = BackSubstitute(ctx, id, inverseSubstMapping); // Decompose each term into structured polynomial parts var terms = GetRootTerms(ctx, newId); @@ -1329,7 +1356,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar } // Do a full back substitution again. - result = ApplyBackSubstitution(ctx, result.Value, inverseSubstMapping); + result = BackSubstitute(ctx, result.Value, inverseSubstMapping); // Bail out if this resulted in a worse result. var cost1 = ctx.GetCost(result.Value); @@ -1359,7 +1386,7 @@ private List UnmergePolynomialParts(Dictionary var outPowers = new Dictionary(); foreach (var (factor, degree) in part.ConstantPowers) { - var unmerged = ApplyBackSubstitution(ctx, factor, rewriteMapping); + var unmerged = BackSubstitute(ctx, factor, rewriteMapping); outPowers.TryAdd(unmerged, 0); outPowers[unmerged] += degree; } @@ -1530,8 +1557,8 @@ private List UnmergePolynomialParts(Dictionary } var invBases = basisSubstitutions.ToDictionary(x => x.Value, x => LinearSimplifier.ConjunctionFromVarMask(ctx, allVars, 1, x.Key)); - var backSub = ApplyBackSubstitution(ctx, reduced, invBases); - backSub = ApplyBackSubstitution(ctx, backSub, substMapping.ToDictionary(x => x.Value, x => x.Key)); + var backSub = BackSubstitute(ctx, reduced, invBases); + backSub = BackSubstitute(ctx, backSub, substMapping.ToDictionary(x => x.Value, x => x.Key)); return backSub; } @@ -1888,7 +1915,7 @@ public AstIdx ExpandReduce(AstIdx id, bool polySimplify = true) // Back substitute the substitute variables. var inverseMapping = substMapping.ToDictionary(x => x.Value, x => x.Key); - sum = ApplyBackSubstitution(ctx, sum, inverseMapping); + sum = BackSubstitute(ctx, sum, inverseMapping); // Try to simplify using the general simplifier. sum = ctx.RecursiveSimplify(sum); @@ -2100,18 +2127,19 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) return outPoly; } - public static AstIdx ApplyBackSubstitution(AstCtx ctx, AstIdx id, Dictionary backSubstitutions, Dictionary cache = null) + public static AstIdx BackSubstitute(AstCtx ctx, AstIdx id, Dictionary backSubstitutions) + => BackSubstitute(ctx, id, backSubstitutions, new(16)); + + public static AstIdx BackSubstitute(AstCtx ctx, AstIdx id, Dictionary backSubstitutions, Dictionary cache) { - if (cache == null) - cache = new(); if (backSubstitutions.TryGetValue(id, out var backSub)) return backSub; if (cache.TryGetValue(id, out var existing)) return existing; - var op0 = () => ApplyBackSubstitution(ctx, ctx.GetOp0(id), backSubstitutions, cache); - var op1 = () => ApplyBackSubstitution(ctx, ctx.GetOp1(id), backSubstitutions, cache); + var op0 = () => BackSubstitute(ctx, ctx.GetOp0(id), backSubstitutions, cache); + var op1 = () => BackSubstitute(ctx, ctx.GetOp1(id), backSubstitutions, cache); var opcode = ctx.GetOpcode(id); var width = ctx.GetWidth(id); diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index 9d98e31..720c5be 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -59,11 +59,11 @@ public ProbableEquivalenceChecker(AstCtx ctx, List variables, AstIdx bef public unsafe bool ProbablyEquivalent(bool slowHeuristics = false) { var jit1 = new Amd64OptimizingJit(ctx); - jit1.Compile(before, variables, pagePtr1, true); + jit1.Compile(before, variables, pagePtr1, false); func1 = (delegate* unmanaged[SuppressGCTransition])pagePtr1; var jit2 = new Amd64OptimizingJit(ctx); - jit2.Compile(after, variables, pagePtr2, true); + jit2.Compile(after, variables, pagePtr2, false); func2 = (delegate* unmanaged[SuppressGCTransition])pagePtr2; var vArray = stackalloc ulong[variables.Count]; From 8f1b2e1b905bbaea5b7d2d7319d29e11bc72b507 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Tue, 16 Sep 2025 03:34:39 -0400 Subject: [PATCH 06/21] Optimize demanded bits; Skip walking of linear subtrees --- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 85 +++++++++++++------- 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index 7174850..493ec2e 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -241,6 +241,26 @@ private static ulong Pow(ulong bbase, ulong exponent) private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary substitutionMapping, ref bool isSemiLinear, bool inBitwise = false) { + // This is dubious: Do we actually need to run simba here... for some reason performance degrades if not + // TODO: Maybe comment this out + var cls = ctx.GetClass(id); + if (cls == AstClassification.Bitwise) + return SimplifyViaRecursiveSiMBA(id); + if (cls == AstClassification.BitwiseWithConstants) + { + isSemiLinear = true; + return SimplifyViaRecursiveSiMBA(id); + } + + if (cls == AstClassification.Linear && !inBitwise) + return SimplifyViaRecursiveSiMBA(id); + if (cls == AstClassification.SemiLinear && !inBitwise) + { + isSemiLinear = true; + return SimplifyViaRecursiveSiMBA(id); + } + + // Sometimes we perform constant folding in this method. // To make sure that we correctly track whether the expression is semi-linear, we use this method to process replacements. var visitReplacement = (AstIdx replacementIdx, bool inBitwise, ref bool isSemiLinear) => GetAstWithSubstitutions(replacementIdx, substitutionMapping, ref isSemiLinear, inBitwise); @@ -830,14 +850,17 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong // Compute demanded bits for each variable // TODO: Keep track of which bits are demanded by the parent(withSubstitutions) Dictionary varToDemandedBits = new(); - var cache = new HashSet(); + var cache = new HashSet<(AstIdx idx, ulong currDemanded)>(); + int totalDemanded = 0; foreach (var (expr, substVar) in substitutionMapping) - ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, cache); + { + ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, cache, ref totalDemanded); + if (totalDemanded > 12) + break; + } + - // Compute the total number of demanded variable bits in the substituted parts. - ulong totalDemanded = 0; - foreach (var demandedBits in varToDemandedBits.Values) - totalDemanded += (ulong)BitOperations.PopCount(demandedBits); + // Bail if there are too many demanded bits! if (totalDemanded > 12) return null; @@ -1244,21 +1267,29 @@ public override int GetHashCode() } // TODO: Cache results to avoid exponentially visiting shared nodes - private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet seen) + private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet<(AstIdx idx, ulong currDemanded)> seen, ref int totalDemanded) { - if (!seen.Add(new DemandedBitsTuple(idx, currDemanded))) + if (totalDemanded > 12) + return; + if (!seen.Add((idx, currDemanded))) return; - var op0 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits, seen); - var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits, seen); + totalDemanded += 1; + + var op0 = (ulong demanded, ref int totalDemanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits, seen, ref totalDemanded); + var op1 = (ulong demanded, ref int totalDemanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits, seen, ref totalDemanded); var opc = ctx.GetOpcode(idx); switch (opc) { // If we have a symbol, union the set of demanded bits case AstOp.Symbol: - symbolDemandedBits.TryAdd(idx, 0); - symbolDemandedBits[idx] |= currDemanded; + //symbolDemandedBits.TryAdd(idx, 0); + symbolDemandedBits.TryGetValue(idx, out var oldDemanded); + var newDemanded = oldDemanded | currDemanded; + symbolDemandedBits[idx] = newDemanded; + totalDemanded += BitOperations.PopCount(newDemanded & ~oldDemanded); + break; // If we have a constant, there is nothing to do. case AstOp.Constant: @@ -1273,22 +1304,22 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar var demandedWidth = 64 - (uint)BitOperations.LeadingZeroCount(currDemanded); currDemanded = ModuloReducer.GetMask(demandedWidth); - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; case AstOp.Lshr: var shiftBy = ctx.GetOp1(idx); var shiftByConstant = ctx.TryGetConstantValue(shiftBy); if (shiftByConstant == null) { - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; } // If we know the value we are shifting by, we can truncate the demanded bits. - op0(currDemanded >> (ushort)shiftByConstant.Value); - op1(currDemanded); + op0(currDemanded >> (ushort)shiftByConstant.Value, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; case AstOp.And: @@ -1296,8 +1327,8 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // If we have a&b, demandedbits(a) does not include any known zero bits from b. Works both ways. var op0Demanded = ~ctx.GetKnownBits(ctx.GetOp1(idx)).Zeroes & currDemanded; var op1Demanded = ~ctx.GetKnownBits(ctx.GetOp0(idx)).Zeroes & currDemanded; - op0(op0Demanded); - op1(op1Demanded); + op0(op0Demanded, ref totalDemanded); + op1(op1Demanded, ref totalDemanded); break; } @@ -1306,25 +1337,25 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // If we have a|b, demandedbits(a) does not include any known one bits from b. Works both ways. var op0Demanded = ~ctx.GetKnownBits(ctx.GetOp1(idx)).Ones & currDemanded; var op1Demanded = ~ctx.GetKnownBits(ctx.GetOp0(idx)).Ones & currDemanded; - op0(op0Demanded); - op1(op1Demanded); + op0(op0Demanded, ref totalDemanded); + op1(op1Demanded, ref totalDemanded); break; } // TODO: We can gain some precision by exploiting XOR known bits. case AstOp.Xor: - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; // TODO: Treat negation as x^-1, then use XOR transfer function case AstOp.Neg: - op0(currDemanded); + op0(currDemanded, ref totalDemanded); break; case AstOp.Trunc: currDemanded &= ModuloReducer.GetMask(ctx.GetWidth(idx)); - op0(currDemanded); + op0(currDemanded, ref totalDemanded); break; case AstOp.Zext: - op0(currDemanded & ModuloReducer.GetMask(ctx.GetWidth(ctx.GetOp0(idx)))); + op0(currDemanded & ModuloReducer.GetMask(ctx.GetWidth(ctx.GetOp0(idx))), ref totalDemanded); break; default: throw new InvalidOperationException($"Cannot compute demanded bits for {opc}"); From f17ab3f3281cb9a047076bde48b0c78959179263 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Sat, 20 Sep 2025 02:19:48 -0400 Subject: [PATCH 07/21] Debugging / fixes --- Mba.Simplifier/Pipeline/GeneralSimplifier.cs | 19 ++- Mba.Simplifier/Pipeline/LinearSimplifier.cs | 13 +- .../Pipeline/ProbableEquivalenceChecker.cs | 68 ++++++---- Mba.Simplifier/Utility/DagFormatter.cs | 102 ++++++++++++++ Simplifier/DatasetTester.cs | 125 ++++++++++++++++++ Simplifier/Program.cs | 37 ++++++ 6 files changed, 328 insertions(+), 36 deletions(-) create mode 100644 Mba.Simplifier/Utility/DagFormatter.cs create mode 100644 Simplifier/DatasetTester.cs diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index 493ec2e..b7dba93 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -20,8 +20,14 @@ namespace Mba.Simplifier.Pipeline { + + public class GeneralSimplifier { + public static bool DbgLog = false; + + private const bool REDUCE_POLYS = false; + private readonly AstCtx ctx; // For any given node, we store the best possible ISLE result. @@ -170,7 +176,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) // If there are any substitutions, we want to try simplifying the polynomial parts. var variables = ctx.CollectVariables(withSubstitutions); - if (polySimplify && substMapping.Count > 0 && ctx.GetHasPoly(id)) + if (REDUCE_POLYS && polySimplify && substMapping.Count > 0 && ctx.GetHasPoly(id)) { var maybeSimplified = TrySimplifyMixedPolynomialParts(withSubstitutions, substMapping, inverseMapping, variables); if (maybeSimplified != null && maybeSimplified.Value != id) @@ -241,6 +247,7 @@ private static ulong Pow(ulong bbase, ulong exponent) private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary substitutionMapping, ref bool isSemiLinear, bool inBitwise = false) { + /* // This is dubious: Do we actually need to run simba here... for some reason performance degrades if not // TODO: Maybe comment this out var cls = ctx.GetClass(id); @@ -252,6 +259,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub return SimplifyViaRecursiveSiMBA(id); } + // Note: These two checks seem to hurt performance too! if (cls == AstClassification.Linear && !inBitwise) return SimplifyViaRecursiveSiMBA(id); if (cls == AstClassification.SemiLinear && !inBitwise) @@ -259,6 +267,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub isSemiLinear = true; return SimplifyViaRecursiveSiMBA(id); } + */ // Sometimes we perform constant folding in this method. @@ -334,6 +343,14 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub var oldSum = sum; var newSum = ctx.SingleSimplify(sum); sum = newSum; + + + if(GeneralSimplifier.DbgLog) + { + Console.WriteLine($"ConstTerm = {DagFormatter.Format(ctx, v0)}"); + Console.WriteLine($"\nWhole dag: {DagFormatter.Format(ctx, sum)}\n\n\n\n\n"); + } + // In this case, we apply constant folding(but we do not search recursively). return GetAstWithSubstitutions(sum, substitutionMapping, ref isSemiLinear, inBitwise); diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index 3c911c5..639c144 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -646,14 +646,11 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded var combMask = variableCombinations[i]; var widths = variables.Select(x => ctx.GetWidth(x)).ToList(); - Console.WriteLine(widths.Distinct().Count()); - foreach (var vIdx in variables) - Console.WriteLine($"{ctx.GetAstString(vIdx)} => {ctx.GetWidth(vIdx)}"); - Console.WriteLine("\n\n"); - if (widths.Distinct().Count() != 1) - { - Debugger.Break(); - } + //Console.WriteLine(widths.Distinct().Count()); + //foreach (var vIdx in variables) + // Console.WriteLine($"{ctx.GetAstString(vIdx)} => {ctx.GetWidth(vIdx)}"); + //Console.WriteLine("\n\n"); + var vComb = ctx.GetConjunctionFromVarMask(clone, combMask); var term = Term(vComb, curr[0].coeff); sum = ctx.Add(sum, term); diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index 720c5be..ab4cc2b 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -5,6 +5,7 @@ using Microsoft.Z3; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -128,6 +129,8 @@ private unsafe bool AllCombs(ulong* vArray, ulong a, ulong b) return false; if (!SignatureVectorEquivalent(vArray, a, b)) return false; + if (!SignatureVectorEquivalent(vArray, b, a)) + return false; return true; } @@ -168,34 +171,45 @@ private ulong Next() public static void ProbablyEquivalentZ3(AstCtx ctx, AstIdx before, AstIdx after) { - var z3Ctx = new Context(); - var translator = new Z3Translator(ctx, z3Ctx); - var beforeZ3 = translator.Translate(before); - var afterZ3 = translator.Translate(after); - var solver = z3Ctx.MkSolver("QF_BV"); - - // Set the maximum timeout to 10 seconds. - var p = z3Ctx.MkParams(); - uint solverLimit = 10000; - p.Add("timeout", solverLimit); - solver.Parameters = p; - - Console.WriteLine("Proving equivalence...\n"); - solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3))); - var check = solver.Check(); - - var printModel = (Model model) => + using (var z3Ctx = new Context()) { - var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}"); - return $"[{String.Join(", ", values)}]"; - }; - - if (check == Status.UNSATISFIABLE) - Console.WriteLine("Expressions are equivalent."); - else if (check == Status.SATISFIABLE) - Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); - else - Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms"); + var translator = new Z3Translator(ctx, z3Ctx); + var beforeZ3 = translator.Translate(before); + var afterZ3 = translator.Translate(after); + var solver = z3Ctx.MkSolver("QF_BV"); + + // Set the maximum timeout to 10 seconds. + var p = z3Ctx.MkParams(); + uint solverLimit = 5000; + p.Add("timeout", solverLimit); + solver.Parameters = p; + + Console.WriteLine("Proving equivalence...\n"); + solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3))); + var check = solver.Check(); + + var printModel = (Model model) => + { + var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}"); + return $"[{String.Join(", ", values)}]"; + }; + + if (check == Status.UNSATISFIABLE) + { + //Console.WriteLine("Expressions are equivalent."); + } + else if (check == Status.SATISFIABLE) + { + Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); + Debugger.Break(); + throw new InvalidOperationException(); + + } + else + { + //Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms"); + } + } } } diff --git a/Mba.Simplifier/Utility/DagFormatter.cs b/Mba.Simplifier/Utility/DagFormatter.cs new file mode 100644 index 0000000..936d147 --- /dev/null +++ b/Mba.Simplifier/Utility/DagFormatter.cs @@ -0,0 +1,102 @@ +using Mba.Simplifier.Bindings; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Mba.Simplifier.Utility +{ + public static class DagFormatter + { + public static string Format(AstCtx ctx, AstIdx idx) + { + var sb = new StringBuilder(); + Format(sb, ctx, idx, new()); + return sb.ToString(); + } + + private static void Format(StringBuilder sb, AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + // Allocate value numbers for the operands if necessary + var opc = ctx.GetOpcode(idx); + var opcount = GetOpCount(opc); + if (opcount >= 1 && !valueNumbers.ContainsKey(ctx.GetOp0(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp0(idx))) + Format(sb, ctx, ctx.GetOp0(idx), valueNumbers); + if (opcount >= 2 && !valueNumbers.ContainsKey(ctx.GetOp1(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp1(idx))) + Format(sb, ctx, ctx.GetOp1(idx), valueNumbers); + + var op0 = () => $"{Lookup(ctx, ctx.GetOp0(idx), valueNumbers)}"; + var op1 = () => $"{Lookup(ctx, ctx.GetOp1(idx), valueNumbers)}"; + + var vNum = valueNumbers.Count; + valueNumbers.Add(idx, vNum); + var width = ctx.GetWidth(idx); + if (opc == AstOp.Symbol) + sb.AppendLine($"i{width} t{vNum} = {ctx.GetSymbolName(idx)}"); + else if (opc == AstOp.Constant) + sb.AppendLine($"i{width} t{vNum} = {ctx.GetConstantValue(idx)}"); + else if (opc == AstOp.Neg) + sb.AppendLine($"i{width} t{vNum} = ~{op0()}"); + else if (opc == AstOp.Zext || opc == AstOp.Trunc) + { + sb.AppendLine($"i{width} t{vNum} = {GetOperatorName(opc)} i{ctx.GetWidth(ctx.GetOp0(idx))} {op0()} to i{width}"); + } + else + { + sb.AppendLine($"i{width} t{vNum} = {op0()} {GetOperatorName(opc)} {op1()}"); + } + } + + private static bool IsConstOrSymbol(AstCtx ctx, AstIdx idx) + => ctx.GetOpcode(idx) == AstOp.Constant || ctx.GetOpcode(idx) == AstOp.Symbol; + + private static string Lookup(AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + var opc = ctx.GetOpcode(idx); + if (opc == AstOp.Constant) + return ctx.GetConstantValue(idx).ToString(); + if (opc == AstOp.Symbol) + return ctx.GetSymbolName(idx); + return $"t{valueNumbers[idx]}"; + } + + private static int GetOpCount(AstOp opc) + { + return opc switch + { + AstOp.None => 0, + AstOp.Add => 2, + AstOp.Mul => 2, + AstOp.Pow => 2, + AstOp.And => 2, + AstOp.Or => 2, + AstOp.Xor => 2, + AstOp.Neg => 1, + AstOp.Lshr => 2, + AstOp.Constant => 0, + AstOp.Symbol => 0, + AstOp.Zext => 1, + AstOp.Trunc => 1, + }; + } + + private static string GetOperatorName(AstOp opc) + { + return opc switch + { + AstOp.Add => "+", + AstOp.Mul => "*", + AstOp.Pow => "**", + AstOp.And => "&", + AstOp.Or => "|", + AstOp.Xor => "^", + AstOp.Neg => "~", + AstOp.Lshr => ">>", + AstOp.Zext => "zext", + AstOp.Trunc => "trunc", + _ => throw new InvalidOperationException(), + }; + } + } +} diff --git a/Simplifier/DatasetTester.cs b/Simplifier/DatasetTester.cs new file mode 100644 index 0000000..f010f27 --- /dev/null +++ b/Simplifier/DatasetTester.cs @@ -0,0 +1,125 @@ +using Mba.Simplifier.Bindings; +using Mba.Simplifier.Pipeline; +using Mba.Simplifier.Utility; +using Mba.Utility; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection.Metadata; +using System.Text; +using System.Threading.Tasks; + +namespace Simplifier +{ + public static class DatasetTester + { + public static void Run() + { + Console.WriteLine(" "); + var lines = File.ReadLines("C:\\Users\\colton\\source\\repos\\mba-database\\real-world-nonlinear-full.txt"); + var beforeAndAfter = lines.Select(x => (x.Split(",")[0], x.Split(",")[1])).ToList(); + + var ctx = new AstCtx(); + AstIdx.ctx = ctx; + + var asts = beforeAndAfter.Select(x => (RustAstParser.Parse(ctx, x.Item1, 64), RustAstParser.Parse(ctx, x.Item2, 64))).ToList(); + + + Parallel.ForEach(asts, x => + { + + ProbableEquivalenceChecker.ProbablyEquivalentZ3(ctx, x.Item1, x.Item2); + } + ); + + foreach(var (before, after) in asts) + { + ProbableEquivalenceChecker.ProbablyEquivalentZ3(ctx, before, after); + } + + Debugger.Break(); + + foreach (var (strBefore_, strAfter) in beforeAndAfter) + { + var strBefore = strBefore_; + + //if (strBefore != "((((((((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))^(213:i64|(-214:i64&RSI:i64)))*2199023256422:i64)+(((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))&(~((213:i64&RSI:i64)|(-214:i64^(-214:i64&RSI:i64)))))*7378699388702425784:i64))+(((((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))&RSI:i64)^-1:i64)|213:i64)*-7378698289190797573:i64))+(7378697189679169362:i64*(~(5:i64&RSI:i64))))+(((-1:i64*(~((-209:i64&RSI:i64)|(208:i64^(208:i64&RSI:i64)))))+((4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))|RSI:i64))*3689348594839584681:i64))+((((-9223372036854775808:i64^((-9223372036854775803:i64&RSI:i64)|(9223372036854775802:i64^(9223372036854775802:i64&RSI:i64))))+(0:i64+(((-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))&RSI:i64)*-1:i64)))+(0:i64+(((4040198467629586909:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))&(~RSI:i64))*-1:i64)))*3689349694351212892:i64))") + // continue; + + + //strBefore = "(-3689349694351212892*(4611686018427387690&(RSI&(4040198467629586696+(1099511628211*(5292288&RBX))))))+(-3689349694351212892*(RSI&(-4040198467629586910+(72056494526299725*(5292288&RBX)))))"; + + //strBefore = "(7610965373738707464:i64+((((956575116354345:i64*(5292288:i64&RBX:i64))+(3689349694351212892:i64*(4611686018427387690:i64&RSI:i64)))+(-3689349694351212892:i64*(4611686018427387690:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(RSI:i64&(-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))))))"; + + // strBefore = "(-3689349694351212892*(4611686018427387690&(RSI&(4040198467629586696+(1099511628211*(5292288&RBX))))))+(-3689349694351212892*(RSI&(-4040198467629586910+(72056494526299725*(5292288&RBX)))))"; + + //strBefore = "(228698418667888:i64+((((((((3689349694351212892:i64*(4611686018427387690:i64&RSI:i64))+(3689348594839584681:i64*(5:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(691752243057131259:i64*(208:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(-3689348594839584681:i64*(5:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(4611686018427387690:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-691754442080387681:i64*(208:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(RSI:i64&(-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))))))"; + + //strBefore = "(228698418667888:i64+((((3689349694351212892:i64*(4611686018427387690:i64&RSI:i64))+(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(-3689349694351212892:i64*(4611686018427387690:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(RSI:i64&(-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))))))"; + + //strBefore = "(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))"; + + //strBefore = "-214&((4040198467629586696+(1099511628211*(5292288&RBX))))"; + + //strBefore = "(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))"; + + //strBefore = "(-3689349694351212892*(4611686018427387690&(RSI&(4040198467629586696+(1099511628211*(5292288&RBX))))))+(-3689349694351212892*(RSI&(-4040198467629586910+(72056494526299725*(5292288&RBX)))))"; + + //if (strBefore != "(((-433557024052896108:i64+(46702856230664876:i64*(5292288:i64&RBX:i64)))+((((0:i64+((((7610965373738707464:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))|5370260760:i64)&-125:i64)*-1:i64))+((-7610965373738707465:i64+(71101018921573591:i64*(5292288:i64&RBX:i64)))^5370260736:i64))+(-7610965373738707565:i64+(71101018921573591:i64*(5292288:i64&RBX:i64))))*2199023256422:i64))+((((7610965373738707572:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))^5370260760:i64)+0:i64)*3298534884633:i64))") + // continue; + //strBefore = "(((-433557024052896108:i64+(46702856230664876:i64*(5292288:i64&RBX:i64)))+((((0:i64+((((7610965373738707464:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))|5370260760:i64)&-125:i64)*-1:i64))+((-7610965373738707465:i64+(71101018921573591:i64*(5292288:i64&RBX:i64)))^5370260736:i64))+(-7610965373738707565:i64+(71101018921573591:i64*(5292288:i64&RBX:i64))))*2199023256422:i64))+((((7610965373738707572:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))^5370260760:i64)+0:i64)*3298534884633:i64))"; + + //strBefore = "(2342386684228996530:i64+((((23351428115332438:i64*(5292288:i64&RBX:i64))+(1099511628211:i64*(-5370260861:i64&(7610965373738707456:i64+(956575116354345:i64*(5292288:i64&RBX:i64))))))+(-3298534884633:i64*(5370260744:i64&(7610965373738707456:i64+(956575116354345:i64*(5292288:i64&RBX:i64))))))+(2199023256422:i64*(5370260736:i64^(-7610965373738707465:i64+(71101018921573591:i64*(5292288:i64&RBX:i64)))))))"; + + //strBefore = "((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))&-214:i64)"; + //strBefore = "(-214:i64&(1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64))))"; + + var before = RustAstParser.Parse(ctx, strBefore, 64); + + + + var after = RustAstParser.Parse(ctx, strAfter, 64); + + + + // if (ctx.GetAstString(before) != "((~(-72056494526299725:i64*(5631088361628047935:i64|(5292288:i64&RBX:i64))))&213:i64)") + // continue; + + + + var cls = ctx.GetClass(before); + var clsAfter = ctx.GetClass(after); + + var generalSimplifier = new GeneralSimplifier(ctx); + var simplified = generalSimplifier.SimplifyGeneral(before); + for (int i = 0; i < 10; i++) + { + generalSimplifier = new(ctx); + simplified = generalSimplifier.SimplifyGeneral(simplified); + } + + var kb = ctx.GetKnownBits(simplified); + var kb2 = ctx.GetKnownBits(before); + Console.WriteLine(kb.ToString() == kb2.ToString()); + + var r = LinearSimplifier.Run(ctx.GetWidth(before), ctx, before, false, true); + var r2 = LinearSimplifier.Run(ctx.GetWidth(simplified), ctx, simplified, false, true); + if (r != r2) + Debugger.Break(); + var rClass = ctx.GetClass(r); + var simplifiedClass = ctx.GetClass(simplified); + if (ctx.GetClass(after) != simplifiedClass) + Debugger.Break(); + //if (ctx.GetClass(before) == AstClassification.Nonlinear) + // Debugger.Break(); + //if (ctx.GetClass(simplified) != AstClassification.Nonlinear || !ctx.IsConstant(r))) + // continue; + + //Debugger.Break(); + } + + Debugger.Break(); + } + } +} diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 476bbb8..f110c50 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -9,6 +9,7 @@ using Mba.Simplifier.Utility; using Mba.Utility; using Microsoft.Z3; +using Simplifier; using System.ComponentModel; using System.Diagnostics; @@ -18,8 +19,29 @@ bool proveEquivalence = true; string inputText = null; +//DatasetTester.Run(); + + inputText = "((((1:i32&((uns17:i8 zx i32)&(~uns18:i32)))|(4294964010:i32&(~((uns17:i8 zx i32)|(~uns18:i32)))))|(4294964011:i32&((uns17:i8 zx i32)&uns18:i32)))|(4:i32*(1:i32&(uns19:i8 zx i32))))"; +inputText = "((2041933603239772578:i64+((((((((((((((((-27487790705275:i64*uns121:i64)+(-9223358842715237276:i64*(-860922984064492326:i64&uns121:i64)))+(9223354444668724432:i64*uns131:i64))+(-9223350046622211588:i64*(860922984064492325:i64&uns131:i64)))+(-8796093025688:i64*uns132:i64))+(4398046512844:i64*uns34:i64))+(17592186051376:i64*uns65:i64))+(-3298534884633:i64*uns91:i64))+(9223367638808262964:i64*(8362449052790283482:i64&uns91:i64)))+(13194139538532:i64*(860922984064492325:i64&(uns121:i64&uns130:i64))))+(14293651166743:i64*(-3750763034362895579:i64&(uns121:i64&uns67:i64))))+(4398046512844:i64*(uns130:i64&uns133:i64)))+(-8796093025688:i64*(1444920025149201626:i64&(uns130:i64&uns91:i64))))+(-4398046512844:i64*(uns131:i64&uns133:i64)))+(-9223350046622211588:i64*(3750763034362895578:i64&(uns131:i64&uns91:i64))))+(-9895604653899:i64*((uns121:i64&uns130:i64)&uns91:i64))))+(3062923494603851298:i64+(((((((-9895604653899:i64*uns130:i64)+(9895604653899:i64*(-3750763034362895579:i64&uns131:i64)))+(9895604653899:i64*(3750763034362895578:i64&uns17:i64)))+(-9895604653899:i64*(-3750763034362895579:i64&(uns121:i64&uns131:i64))))+(9895604653899:i64*(uns130:i64&uns134:i64)))+(-9895604653899:i64*(uns131:i64&uns134:i64)))+(9895604653899:i64*(-3750763034362895579:i64&(uns131:i64&uns91:i64))))))"; + +inputText = "((-1:i64*(~((((8614007388540201639:i64+(((((6919028725695267695:i64*(183:i64&uns16:i64))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64))))&((~uns4:i64)^(((((~uns4:i64)&((167:i16+(((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))&((11175:i16+(((((52335:i16*(183:i16&(uns16:i64 tr i16)))+(5265:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))|(((~uns4:i64)&((167:i16+(((((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))&((11175:i16+(((((52335:i16*(183:i16&(uns16:i64 tr i16)))+(5265:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64)))&uns16:i64)))|(255:i64&(uns16:i64&((((~uns4:i64)&(~((167:i16+(((((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64)))&(~((167:i16+(((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64)))&(~(8614007388540201639:i64+(((((6919028725695267695:i64*(183:i64&uns16:i64))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))))))))|(-256:i64&(((uns16:i64&(8614007388540201639:i64+(((((6919028725695267695:i64*(183:i64&uns16:i64))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))))&((11175:i16+(((((52335:i16*(183:i16&(uns16:i64 tr i16)))+(5265:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))&(((~uns4:i64)&((167:i16+(((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))|((~uns4:i64)&((167:i16+(((((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))))))))+(-1:i64*(((~uns16:i64)&(8614011786586714483:i64+(((((((4398046512844:i64*(5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64))))))+(4398046512844:i64*(-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))))+(6919028725695267695:i64*(183:i64&uns16:i64)))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))))|(uns4:i64&(8614011786586714483:i64+(((((((4398046512844:i64*(5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64))))))+(4398046512844:i64*(-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))))+(6919028725695267695:i64*(183:i64&uns16:i64)))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64))))))))"; + +inputText = "(((-1099511628211:i64*((uns173:i64&(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64)))))))|(uns174:i64&(~(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))))))+(uns175:i64*(-2:i64+(-1:i64*uns174:i64))))+(2199023256422:i64*((((-1:i64*(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64)))))))+(-1:i64*uns173:i64))+(2:i64*((-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))&uns173:i64)))+((-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))&uns174:i64))))"; + +inputText = "(2374945116151681 + 1152921504606846706*(x&8796093022192)) + (-2383741209174193 + 271*(x&8796093022192))"; + +inputText = "(3*(x&0x7FFFFFFFFF0) - 0x180000000490) - 0x100000000350 + 0x178"; + +inputText = "(-43980465112680+(3*(-8796093022208|(8796093022192&(x&~15)))))"; + +inputText = "((0xFFFFFFFFFFFFFE71 * ~(~a4 & (0x8000000023F - (a1 & 0x7FFFFFFFFF0)))) + (~a4 & (v37 + 0x570)) - ((~a4 & (v37 + 0x570)) | a4 & (0x8000000023F - (a1 & 0x7FFFFFFFFF0)))) + (0xFFFFFFFFFFFFFE70 * (~a4 & (0x8000000023F - (a1 & 0x7FFFFFFFFF0))))"; + +inputText = "((2 * (a1 & 0x7FFFFFFFFF0)) - 0x100000000350 + 0x178)"; + +inputText = " (2 * (x & 0x7FFFFFFFFF0)) - 0x100000000350 + 0x178"; + var printHelp = () => { Console.WriteLine("Usage: Simplifier.exe"); @@ -74,6 +96,21 @@ Console.WriteLine($"\nExpression: {ctx.GetAstString(id)}\n\n\n"); +Console.WriteLine(DagFormatter.Format(ctx, id)); + + +var bx = LinearSimplifier.Run(bitWidth, ctx, id, false, true); + +while(false) +{ + var simplifier = new GeneralSimplifier(ctx); + + var sw = Stopwatch.StartNew(); + var r = simplifier.SimplifyGeneral(id); + sw.Stop(); + Console.WriteLine($"Took {sw.ElapsedMilliseconds}ms"); +} + var input = id; id = ctx.RecursiveSimplify(id); for (int i = 0; i < 3; i++) From 8886e007fd0f9bb0c8f91f9adbe11bcb6c6af137 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Mon, 29 Sep 2025 14:57:11 -0400 Subject: [PATCH 08/21] changes --- Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs | 5 +++++ Simplifier/Program.cs | 2 ++ 2 files changed, 7 insertions(+) diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index ab4cc2b..6888f21 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -100,6 +100,7 @@ public unsafe bool ProbablyEquivalent(bool slowHeuristics = false) return true; } + public static bool Log = false; private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses) { var clone = new ulong[variables.Count]; @@ -114,6 +115,10 @@ private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses) var op1 = func1(vArray); var op2 = func2(vArray); + + if(Log) + Console.WriteLine($"{op1}, {op2}"); + if (op1 != op2) return false; } diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index f110c50..8d1e0df 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -42,6 +42,8 @@ inputText = " (2 * (x & 0x7FFFFFFFFF0)) - 0x100000000350 + 0x178"; +inputText = "-3 * ~e_cr3 + (mask ^ e_cr3) + -2 * (mask ^ (mask | e_cr3)) + 2 * (((byte_1400807D0 & 0x10 | 0x3F71D992FBB2CCEB) ^ 0xC08E266D044D3314) + (~e_cr3 | (mask ^ 0x3F71D992FBB2CCEB))) - (mask ^ ~e_cr3 ^ 0x3F71D992FBB2CCEB) - 0x3F71D992FBB2CCEB"; + var printHelp = () => { Console.WriteLine("Usage: Simplifier.exe"); From 4eb91e182ebd29eb0f82acbd334548a10eb4001b Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Tue, 18 Nov 2025 10:25:45 -0500 Subject: [PATCH 09/21] Disable treating the result vector as linear if any variables are truncated --- Mba.Simplifier/Bindings/AstCtx.cs | 1 + Mba.Simplifier/Pipeline/LinearSimplifier.cs | 8 ++++++++ Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/Mba.Simplifier/Bindings/AstCtx.cs b/Mba.Simplifier/Bindings/AstCtx.cs index 3958d40..8d9cc92 100644 --- a/Mba.Simplifier/Bindings/AstCtx.cs +++ b/Mba.Simplifier/Bindings/AstCtx.cs @@ -49,6 +49,7 @@ public unsafe AstCtx() // Constructors public unsafe AstIdx Add(AstIdx a, AstIdx b) => Api.ContextAdd(this, a, b); + public unsafe AstIdx Sub(AstIdx a, AstIdx b) => Add(a, Mul(Constant(ulong.MaxValue, GetWidth(b)), b)); public unsafe AstIdx Mul(AstIdx a, AstIdx b) => Api.ContextMul(this, a, b); public unsafe AstIdx Pow(AstIdx a, AstIdx b) => Api.ContextPow(this, a, b); public unsafe AstIdx And(AstIdx a, AstIdx b) => Api.ContextAnd(this, a, b); diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index 639c144..8579d57 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -240,6 +240,14 @@ private AstIdx Simplify(bool useZ3 = false, bool alreadySplit = false) // If we have a multi-bit result vector, try to rewrite as a linear result vector. If possible, update state accordingly. private unsafe bool IsLinearResultVector() { + foreach(var v in variables) + { + // If the variable is zero extended or truncated, we treat this as a semi-linear signature vector. + // Truncation cannot be treated as linear, though in the future we may be able to get away with treating zero extension as linear? + if (!ctx.IsSymbol(v)) + return false; + } + fixed (ApInt* ptr = &resultVector[0]) { ushort bitIndex = 0; diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index 6888f21..561a4e9 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -207,7 +207,7 @@ public static void ProbablyEquivalentZ3(AstCtx ctx, AstIdx before, AstIdx after) { Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); Debugger.Break(); - throw new InvalidOperationException(); + // throw new InvalidOperationException(); } else From f4d6a5a83381e90caaa79b3e52fa02030b1d36f0 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Wed, 19 Nov 2025 13:24:51 -0500 Subject: [PATCH 10/21] internal stuff --- Mba.Simplifier/Pipeline/LinearSimplifier.cs | 109 ++++++++++++++++---- Simplifier/Program.cs | 16 ++- 2 files changed, 105 insertions(+), 20 deletions(-) diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index 8579d57..0481f25 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -23,6 +23,7 @@ using static Antlr4.Runtime.Atn.SemanticContext; using Mba.Simplifier.Interpreter; using Mba.Simplifier.Utility; +using Microsoft.VisualBasic; namespace Mba.Simplifier.Pipeline { @@ -45,8 +46,16 @@ public class LinearSimplifier // If enabled, we try to find a simpler representation of grouping of basis expressions. private readonly bool tryDecomposeMultiBitBases; + // For internal use in private projects (do not use) private readonly Action? resultVectorHook; + + private readonly ulong[] inVec; + private readonly int depth; + + // + private readonly Dictionary anfDemandedBits; + private readonly ApInt moduloMask = 0; // Number of combinations of input variables(2^n), for a single bit index. @@ -69,14 +78,14 @@ public class LinearSimplifier private AstIdx? initialInput = null; - public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0) + public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0, Dictionary anfDemandedBits = null) { if (variables == null) variables = ctx.CollectVariables(ast.Value); - return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec, depth).Simplify(false, alreadySplit); + return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec, depth, anfDemandedBits).Simplify(false, alreadySplit); } - public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0) + public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0, Dictionary anfDemandedBits = null) { // If we are given an AST, verify that the correct width was passed. if (ast != null && bitSize != ctx.GetWidth(ast.Value)) @@ -90,7 +99,9 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables this.multiBit = multiBit; this.tryDecomposeMultiBitBases = tryDecomposeMultiBitBases; this.resultVectorHook = resultVectorHook; + this.inVec = inVec; this.depth = depth; + this.anfDemandedBits = anfDemandedBits; moduloMask = (ApInt)ModuloReducer.GetMask(bitSize); groupSizes = GetGroupSizes(variables.Count); numCombinations = (ApInt)Math.Pow(2, variables.Count); @@ -223,7 +234,7 @@ private AstIdx Simplify(bool useZ3 = false, bool alreadySplit = false) // If we were given a semi-linear expression, and the ground truth of that expression is linear, // truncate the size of the result vector down to 2^t, then treat it as a linear MBA. - if (multiBit && IsLinearResultVector()) + if (multiBit && IsLinearResultVector()) { multiBit = false; Array.Resize(ref resultVector, (int)numCombinations); @@ -939,7 +950,9 @@ public static ApInt SubtractConstantOffset(ApInt moduloMask, ApInt[] resultVecto if (multiBit) { - var r = SimplifyOneValueMultibit(constant, resultVector.ToArray(), variableCombinations); + //SimplifyManyValuesMultibit(constant, resultVector.ToArray()); + + var r = SimplifyOneValueMultibit(constant, resultVector.ToArray()); if (r != null) { CheckSolutionComplexity(r.Value, 1, null); @@ -1002,6 +1015,18 @@ public static ApInt SubtractConstantOffset(ApInt moduloMask, ApInt[] resultVecto } } + if(anfDemandedBits != null) + { + for(int i = 0; i < linearCombinations.Count; i++) + { + anfDemandedBits.TryAdd((ApInt)i, 0); + foreach(var (coeff, mask) in linearCombinations[i]) + { + anfDemandedBits[(ApInt)i] |= mask; + } + } + } + // Identify variables that are not present in any conjunction. // E.g. if we have a + (b&c), then a is not present in a conjunction, while b is. var withNoConjunctions = GetVariablesWithNoConjunctions(variableCombinations, linearCombinations); @@ -1288,6 +1313,11 @@ private ulong GetVariablesWithNoConjunctions(ulong[] variableCombinations, List< private ApInt? TryGetSingleCoeff((ApInt coeff, ApInt mask)[] uniqueCoeffs) { + int succCount = 0; + int fCount = 0; + + Dictionary<(ApInt coeff, ApInt mask), int> successes = new(); + foreach (var (coeff, mask) in uniqueCoeffs) { bool success = true; @@ -1300,9 +1330,23 @@ private ulong GetVariablesWithNoConjunctions(ulong[] variableCombinations, List< if (TryRewrite(otherCoeff, coeff, otherMask) == null) { + fCount += 1; success = false; break; } + + else if (TryRewrite(otherCoeff, coeff, mask | otherMask) != null) + { + + succCount += 1; + } + + else + { + succCount += 1; + } + + } if (success) @@ -1398,29 +1442,56 @@ private ulong GetVariablesWithNoConjunctions(ulong[] variableCombinations, List< return null; } - private AstIdx? SimplifyOneValueMultibit(ulong constant, ApInt[] withoutConstant, ApInt[] variableCombinations) + private void SimplifyManyValuesMultibit(ulong constant, ApInt[] withoutConstant) { - // Algorithm: Start at some point, check if you can change every coefficient to the target coefficient - bool truthTableIdx = true; - if (!truthTableIdx) - variableCombinations = new List() { 0 }.Concat(variableCombinations).ToArray(); + if (width != 64) + return; - var getConj = (ApInt i, ApInt? mask) => + List<(ApInt coeff, ApInt bitMask)> linearCombinations = new(); + for (ushort bitIndex = 0; bitIndex < GetNumBitIterations(multiBit, width); bitIndex++) { - if (truthTableIdx) + //Console.WriteLine(""); + var offset = bitIndex * numCombinations; + + for (int i = 0; i < (int)numCombinations; i++) { - var boolean = GetBooleanForIndex((int)i); - if (mask == null) - return boolean; + if (i != 1) + withoutConstant[(int)offset + i] = 0; - return ctx.And(ctx.Constant(mask.Value, width), boolean); + var coeff = withoutConstant[(int)offset + i]; + + coeff = refiner.MinimizeCoeff(coeff, bitIndex);; + + /* + //var str = coeff.ToString(); + var str = $"{coeff} * (x{i}&{1ul << bitIndex})"; + var spaces = String.Join("", Enumerable.Repeat(" ", 32 - str.Length)); + Console.Write($"{str + spaces}"); + */ + + if(coeff != 0) + { + linearCombinations.Add((coeff, 1ul << bitIndex)); + } } + } - return ConjunctionFromVarMask(1, i, mask); - }; + var w = refiner.SimplifyMultibitEntry(linearCombinations); + if (w.Count < 5) + { + Debugger.Break(); + Console.WriteLine(w); + } - AstIdx.ctx = ctx; + var single = SimplifyOneValueMultibit(constant, withoutConstant); + if (single != null) + Debugger.Break(); + } + + // Algorithm: Start at some point, check if you can change every coefficient to the target coefficient + private AstIdx? SimplifyOneValueMultibit(ulong constant, ApInt[] withoutConstant) + { // Reduce each row to a canonical form. If a row cannot be canonicalized, there is no solution. var uniqueCoeffs = TryReduceRows(constant, withoutConstant); if (uniqueCoeffs == null) diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 8d1e0df..44bcfa5 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -44,6 +44,18 @@ inputText = "-3 * ~e_cr3 + (mask ^ e_cr3) + -2 * (mask ^ (mask | e_cr3)) + 2 * (((byte_1400807D0 & 0x10 | 0x3F71D992FBB2CCEB) ^ 0xC08E266D044D3314) + (~e_cr3 | (mask ^ 0x3F71D992FBB2CCEB))) - (mask ^ ~e_cr3 ^ 0x3F71D992FBB2CCEB) - 0x3F71D992FBB2CCEB"; +inputText = "((((((((((((((((((((44:i64*(256:i64&v0:i64))+(22:i64*(512:i64&v0:i64)))+(11:i64*(1024:i64&v0:i64)))+(6:i64*(2048:i64&v0:i64)))+(2:i64*(12288:i64&v0:i64)))+(-12288:i64&v0:i64))+(44:i64*(256:i64&v1:i64)))+(22:i64*(512:i64&v1:i64)))+(11:i64*(1024:i64&v1:i64)))+(6:i64*(2048:i64&v1:i64)))+(2:i64*(12288:i64&v1:i64)))+(-12288:i64&v1:i64))+(88:i64*(128:i64&(v0:i64&v1:i64))))+(221567343985614:i64*(256:i64&(v0:i64&v1:i64))))+(-21:i64*(512:i64&(v0:i64&v1:i64))))+(-10:i64*(1024:i64&(v0:i64&v1:i64))))+(221567343985652:i64*(2048:i64&(v0:i64&v1:i64))))+(-2:i64*(4096:i64&(v0:i64&v1:i64))))+(-1:i64*(-9223372036854767616:i64&(v0:i64&v1:i64))))+(221567343985657:i64*(5218521314454437888:i64&(v0:i64&v1:i64))))"; +inputText = "((((((((((((((((((((44:i16*(256:i16&v0:i16))+(22:i16*(512:i16&v0:i16)))+(11:i16*(1024:i16&v0:i16)))+(6:i16*(2048:i16&v0:i16)))+(2:i16*(12288:i16&v0:i16)))+(-12288:i16&v0:i16))+(44:i16*(256:i16&v1:i16)))+(22:i16*(512:i16&v1:i16)))+(11:i16*(1024:i16&v1:i16)))+(6:i16*(2048:i16&v1:i16)))+(2:i16*(12288:i16&v1:i16)))+(-12288:i16&v1:i16))+(88:i16*(128:i16&(v0:i16&v1:i16))))+(221567343985614:i16*(256:i16&(v0:i16&v1:i16))))+(-21:i16*(512:i16&(v0:i16&v1:i16))))+(-10:i16*(1024:i16&(v0:i16&v1:i16))))+(221567343985652:i16*(2048:i16&(v0:i16&v1:i16))))+(-2:i16*(4096:i16&(v0:i16&v1:i16))))+(-1:i16*(-9223372036854767616:i16&(v0:i16&v1:i16))))+(221567343985657:i16*(5218521314454437888:i16&(v0:i16&v1:i16))))"; +//inputText = "((((((((((((((((((((((11111:i64+(255:i64&v0:i64))+(72057594037927893:i64*(256:i64&v0:i64)))+(36028797018963947:i64*(512:i64&v0:i64)))+(18014398509481974:i64*(1024:i64&v0:i64)))+(9007199254740987:i64*(2048:i64&v0:i64)))+(4503599627370494:i64*(4096:i64&v0:i64)))+(2251799813685247:i64*(8192:i64&v0:i64)))+(255:i64&v1:i64))+(72057594037927893:i64*(256:i64&v1:i64)))+(36028797018963947:i64*(512:i64&v1:i64)))+(18014398509481974:i64*(1024:i64&v1:i64)))+(9007199254740987:i64*(2048:i64&v1:i64)))+(4503599627370494:i64*(4096:i64&v1:i64)))+(2251799813685247:i64*(8192:i64&v1:i64)))+(-6773192272221240327:i64*(31:i64&(v0:i64&v1:i64))))+(144115188075855784:i64*(128:i64&(v0:i64&v1:i64))))+(-72057594037927893:i64*(256:i64&(v0:i64&v1:i64))))+(-36028797018963947:i64*(512:i64&(v0:i64&v1:i64))))+(-18014398509481974:i64*(1024:i64&(v0:i64&v1:i64))))+(-9007199254740987:i64*(2048:i64&(v0:i64&v1:i64))))+(-4503599627370494:i64*(4096:i64&(v0:i64&v1:i64))))+(-2251799813685247:i64*(8192:i64&(v0:i64&v1:i64))))"; + +inputText = "(((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((-2445733066508366983:i64+(4776822395524155:i64*(512:i64&v0:i64)))+(2388411197762078:i64*(1024:i64&v0:i64)))+(1194205598881039:i64*(2048:i64&v0:i64)))+(597102799440520:i64*(4096:i64&v0:i64)))+(68:i64*(36028797018963968:i64&v0:i64)))+(136:i64*(18014398509481984:i64&v0:i64)))+(74637849930065:i64*(32768:i64&v0:i64)))+(37318924965033:i64*(65536:i64&v0:i64)))+(18659462482517:i64*(131072:i64&v0:i64)))+(9329731241259:i64*(262144:i64&v0:i64)))+(-4611686018427387649:i64&v0:i64))+(2332432810315:i64*(1048576:i64&v0:i64)))+(284720803:i64*(1152921513196781568:i64&v0:i64)))+(583108202579:i64*(4194304:i64&v0:i64)))+(291554101290:i64*(2305843009222082560:i64&v0:i64)))+(145777050645:i64*(16777216:i64&v0:i64)))+(72888525323:i64*(33554432:i64&v0:i64)))+(36444262662:i64*(67108864:i64&v0:i64)))+(1112191:i64*(2199023255552:i64&v0:i64)))+(2224382:i64*(1099511627776:i64&v0:i64)))+(4555532833:i64*(536870912:i64&v0:i64)))+(2277766417:i64*(144115189149597696:i64&v0:i64)))+(1138883209:i64*(288230378299195392:i64&v0:i64)))+(569441605:i64*(576460756598390784:i64&v0:i64)))+(1166216405158:i64*(2097152:i64&v0:i64)))+(9553644791048309:i64*(256:i64&v0:i64)))+(17378:i64*(140737488355328:i64&v0:i64)))+(35590101:i64*(68719476736:i64&v0:i64)))+(17795051:i64*(137438953472:i64&v0:i64)))+(8897526:i64*(274877906944:i64&v0:i64)))+(4448763:i64*(549755813888:i64&v0:i64)))+(9111065666:i64*(268435456:i64&v0:i64)))+(18222131331:i64*(134217728:i64&v0:i64)))+(556096:i64*(4398046511104:i64&v0:i64)))+(278048:i64*(8796093022208:i64&v0:i64)))+(139024:i64*(17592186044416:i64&v0:i64)))+(69512:i64*(35184372088832:i64&v0:i64)))+(34756:i64*(70368744177664:i64&v0:i64)))+(71180201:i64*(34359738368:i64&v0:i64)))+(8689:i64*(281474976710656:i64&v0:i64)))+(2173:i64*(1125899906842624:i64&v0:i64)))+(1087:i64*(2251799813685248:i64&v0:i64)))+(544:i64*(4503599627370496:i64&v0:i64)))+(272:i64*(9007199254740992:i64&v0:i64)))+(149275699860130:i64*(16384:i64&v0:i64)))+(298551399720260:i64*(8192:i64&v0:i64)))+(34:i64*(72057594037927936:i64&v0:i64)))+(4345:i64*(562949953421312:i64&v0:i64)))+(142360402:i64*(17179869184:i64&v0:i64)))+(4664865620630:i64*(524288:i64&v0:i64)))+(4776822395524155:i64*(512:i64&v1:i64)))+(2388411197762078:i64*(1024:i64&v1:i64)))+(1194205598881039:i64*(2048:i64&v1:i64)))+(597102799440520:i64*(4096:i64&v1:i64)))+(68:i64*(36028797018963968:i64&v1:i64)))+(136:i64*(18014398509481984:i64&v1:i64)))+(74637849930065:i64*(32768:i64&v1:i64)))+(37318924965033:i64*(65536:i64&v1:i64)))+(18659462482517:i64*(131072:i64&v1:i64)))+(9329731241259:i64*(262144:i64&v1:i64)))+(-4611686018427387649:i64&v1:i64))+(2332432810315:i64*(1048576:i64&v1:i64)))+(284720803:i64*(1152921513196781568:i64&v1:i64)))+(583108202579:i64*(4194304:i64&v1:i64)))+(291554101290:i64*(2305843009222082560:i64&v1:i64)))+(145777050645:i64*(16777216:i64&v1:i64)))+(72888525323:i64*(33554432:i64&v1:i64)))+(36444262662:i64*(67108864:i64&v1:i64)))+(1112191:i64*(2199023255552:i64&v1:i64)))+(2224382:i64*(1099511627776:i64&v1:i64)))+(4555532833:i64*(536870912:i64&v1:i64)))+(2277766417:i64*(144115189149597696:i64&v1:i64)))+(1138883209:i64*(288230378299195392:i64&v1:i64)))+(569441605:i64*(576460756598390784:i64&v1:i64)))+(1166216405158:i64*(2097152:i64&v1:i64)))+(9553644791048309:i64*(256:i64&v1:i64)))+(17378:i64*(140737488355328:i64&v1:i64)))+(35590101:i64*(68719476736:i64&v1:i64)))+(17795051:i64*(137438953472:i64&v1:i64)))+(8897526:i64*(274877906944:i64&v1:i64)))+(4448763:i64*(549755813888:i64&v1:i64)))+(9111065666:i64*(268435456:i64&v1:i64)))+(18222131331:i64*(134217728:i64&v1:i64)))+(556096:i64*(4398046511104:i64&v1:i64)))+(278048:i64*(8796093022208:i64&v1:i64)))+(139024:i64*(17592186044416:i64&v1:i64)))+(69512:i64*(35184372088832:i64&v1:i64)))+(34756:i64*(70368744177664:i64&v1:i64)))+(71180201:i64*(34359738368:i64&v1:i64)))+(8689:i64*(281474976710656:i64&v1:i64)))+(2173:i64*(1125899906842624:i64&v1:i64)))+(1087:i64*(2251799813685248:i64&v1:i64)))+(544:i64*(4503599627370496:i64&v1:i64)))+(272:i64*(9007199254740992:i64&v1:i64)))+(149275699860130:i64*(16384:i64&v1:i64)))+(298551399720260:i64*(8192:i64&v1:i64)))+(34:i64*(72057594037927936:i64&v1:i64)))+(4345:i64*(562949953421312:i64&v1:i64)))+(142360402:i64*(17179869184:i64&v1:i64)))+(4664865620630:i64*(524288:i64&v1:i64)))+(-9553644791048309:i64*(256:i64&(v0:i64&v1:i64))))+(-4776822395524155:i64*(512:i64&(v0:i64&v1:i64))))+(-2388411197762078:i64*(1024:i64&(v0:i64&v1:i64))))+(-68:i64*(36028797018963968:i64&(v0:i64&v1:i64))))+(-597102799440520:i64*(4096:i64&(v0:i64&v1:i64))))+(-136:i64*(18014398509481984:i64&(v0:i64&v1:i64))))+(-149275699860130:i64*(16384:i64&(v0:i64&v1:i64))))+(-74637849930065:i64*(32768:i64&(v0:i64&v1:i64))))+(-37318924965033:i64*(65536:i64&(v0:i64&v1:i64))))+(-18659462482517:i64*(131072:i64&(v0:i64&v1:i64))))+(-6773192272221240327:i64*(31:i64&(v0:i64&v1:i64))))+(-4664865620630:i64*(524288:i64&(v0:i64&v1:i64))))+(-569441605:i64*(-4035225261828997120:i64&(v0:i64&v1:i64))))+(-1166216405158:i64*(2097152:i64&(v0:i64&v1:i64))))+(-583108202579:i64*(4194304:i64&(v0:i64&v1:i64))))+(-291554101290:i64*(8388608:i64&(v0:i64&v1:i64))))+(-145777050645:i64*(16777216:i64&(v0:i64&v1:i64))))+(-72888525323:i64*(33554432:i64&(v0:i64&v1:i64))))+(-36444262662:i64*(67108864:i64&(v0:i64&v1:i64))))+(-18222131331:i64*(1152921504741064704:i64&(v0:i64&v1:i64))))+(-9111065666:i64*(2305843009482129408:i64&(v0:i64&v1:i64))))+(-4555532833:i64*(536870912:i64&(v0:i64&v1:i64))))+(-2277766417:i64*(144115189149597696:i64&(v0:i64&v1:i64))))+(-1138883209:i64*(288230378299195392:i64&(v0:i64&v1:i64))))+(-2332432810315:i64*(1048576:i64&(v0:i64&v1:i64))))+(19107289582096616:i64*(128:i64&(v0:i64&v1:i64))))+(-34756:i64*(70368744177664:i64&(v0:i64&v1:i64))))+(-71180201:i64*(34359738368:i64&(v0:i64&v1:i64))))+(-35590101:i64*(68719476736:i64&(v0:i64&v1:i64))))+(-17795051:i64*(137438953472:i64&(v0:i64&v1:i64))))+(-8897526:i64*(274877906944:i64&(v0:i64&v1:i64))))+(-4448763:i64*(549755813888:i64&(v0:i64&v1:i64))))+(-2224382:i64*(1099511627776:i64&(v0:i64&v1:i64))))+(-1112191:i64*(2199023255552:i64&(v0:i64&v1:i64))))+(-556096:i64*(4398046511104:i64&(v0:i64&v1:i64))))+(-278048:i64*(8796093022208:i64&(v0:i64&v1:i64))))+(-139024:i64*(17592186044416:i64&(v0:i64&v1:i64))))+(-69512:i64*(35184372088832:i64&(v0:i64&v1:i64))))+(-142360402:i64*(17179869184:i64&(v0:i64&v1:i64))))+(-17378:i64*(140737488355328:i64&(v0:i64&v1:i64))))+(-4345:i64*(562949953421312:i64&(v0:i64&v1:i64))))+(-2173:i64*(1125899906842624:i64&(v0:i64&v1:i64))))+(-1087:i64*(2251799813685248:i64&(v0:i64&v1:i64))))+(-544:i64*(4503599627370496:i64&(v0:i64&v1:i64))))+(-272:i64*(9007199254740992:i64&(v0:i64&v1:i64))))+(-298551399720260:i64*(8192:i64&(v0:i64&v1:i64))))+(-1194205598881039:i64*(2048:i64&(v0:i64&v1:i64))))+(-34:i64*(72057594037927936:i64&(v0:i64&v1:i64))))+(-8689:i64*(281474976710656:i64&(v0:i64&v1:i64))))+(-284720803:i64*(8589934592:i64&(v0:i64&v1:i64))))+(-9329731241259:i64*(262144:i64&(v0:i64&v1:i64))))"; + +inputText = "(((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((7812993655859954:i64*(2048:i64&v0:i64))+(3906496827929977:i64*(4096:i64&v0:i64)))+(1953248413964989:i64*(8192:i64&v0:i64)))+(445:i64*(36028797018963968:i64&v0:i64)))+(488312103491248:i64*(32768:i64&v0:i64)))+(976624206982495:i64*(16384:i64&v0:i64)))+(122078025872812:i64*(131072:i64&v0:i64)))+(61039012936406:i64*(262144:i64&v0:i64)))+(30519506468203:i64*(524288:i64&v0:i64)))+(62503949246879628:i64*(256:i64&v0:i64)))+(7629876617051:i64*(2097152:i64&v0:i64)))+(931381423:i64*(2305843026393563136:i64&v0:i64)))+(1907469154263:i64*(8388608:i64&v0:i64)))+(953734577132:i64*(16777216:i64&v0:i64)))+(476867288566:i64*(33554432:i64&v0:i64)))+(238433644283:i64*(67108864:i64&v0:i64)))+(119216822142:i64*(134217728:i64&v0:i64)))+(3638209:i64*(4398046511104:i64&v0:i64)))+(7276418:i64*(2199023255552:i64&v0:i64)))+(14902102768:i64*(144115189149597696:i64&v0:i64)))+(7451051384:i64*(288230378299195392:i64&v0:i64)))+(3725525692:i64*(576460756598390784:i64&v0:i64)))+(1862762846:i64*(1152921513196781568:i64&v0:i64)))+(3814938308526:i64*(4194304:i64&v0:i64)))+(15625987311719907:i64*(1024:i64&v0:i64)))+(56848:i64*(281474976710656:i64&v0:i64)))+(116422678:i64*(137438953472:i64&v0:i64)))+(58211339:i64*(274877906944:i64&v0:i64)))+(29105670:i64*(549755813888:i64&v0:i64)))+(14552835:i64*(1099511627776:i64&v0:i64)))+(29804205536:i64*(536870912:i64&v0:i64)))+(59608411071:i64*(268435456:i64&v0:i64)))+(1819105:i64*(8796093022208:i64&v0:i64)))+(909553:i64*(17592186044416:i64&v0:i64)))+(454777:i64*(35184372088832:i64&v0:i64)))+(227389:i64*(70368744177664:i64&v0:i64)))+(113695:i64*(140737488355328:i64&v0:i64)))+(232845356:i64*(68719476736:i64&v0:i64)))+(14212:i64*(1125899906842624:i64&v0:i64)))+(7106:i64*(2251799813685248:i64&v0:i64)))+(3553:i64*(4503599627370496:i64&v0:i64)))+(1777:i64*(9007199254740992:i64&v0:i64)))+(889:i64*(18014398509481984:i64&v0:i64)))+(244156051745624:i64*(65536:i64&v0:i64)))+(223:i64*(72057594037927936:i64&v0:i64)))+(28424:i64*(562949953421312:i64&v0:i64)))+(465690712:i64*(34359738368:i64&v0:i64)))+(15259753234102:i64*(1048576:i64&v0:i64)))+(31251974623439814:i64*(512:i64&v0:i64)))+(7812993655859954:i64*(2048:i64&v1:i64)))+(3906496827929977:i64*(4096:i64&v1:i64)))+(1953248413964989:i64*(8192:i64&v1:i64)))+(445:i64*(36028797018963968:i64&v1:i64)))+(488312103491248:i64*(32768:i64&v1:i64)))+(976624206982495:i64*(16384:i64&v1:i64)))+(122078025872812:i64*(131072:i64&v1:i64)))+(61039012936406:i64*(262144:i64&v1:i64)))+(30519506468203:i64*(524288:i64&v1:i64)))+(62503949246879628:i64*(256:i64&v1:i64)))+(7629876617051:i64*(2097152:i64&v1:i64)))+(931381423:i64*(2305843026393563136:i64&v1:i64)))+(1907469154263:i64*(8388608:i64&v1:i64)))+(953734577132:i64*(16777216:i64&v1:i64)))+(476867288566:i64*(33554432:i64&v1:i64)))+(238433644283:i64*(67108864:i64&v1:i64)))+(119216822142:i64*(134217728:i64&v1:i64)))+(3638209:i64*(4398046511104:i64&v1:i64)))+(7276418:i64*(2199023255552:i64&v1:i64)))+(14902102768:i64*(144115189149597696:i64&v1:i64)))+(7451051384:i64*(288230378299195392:i64&v1:i64)))+(3725525692:i64*(576460756598390784:i64&v1:i64)))+(1862762846:i64*(1152921513196781568:i64&v1:i64)))+(3814938308526:i64*(4194304:i64&v1:i64)))+(15625987311719907:i64*(1024:i64&v1:i64)))+(56848:i64*(281474976710656:i64&v1:i64)))+(116422678:i64*(137438953472:i64&v1:i64)))+(58211339:i64*(274877906944:i64&v1:i64)))+(29105670:i64*(549755813888:i64&v1:i64)))+(14552835:i64*(1099511627776:i64&v1:i64)))+(29804205536:i64*(536870912:i64&v1:i64)))+(59608411071:i64*(268435456:i64&v1:i64)))+(1819105:i64*(8796093022208:i64&v1:i64)))+(909553:i64*(17592186044416:i64&v1:i64)))+(454777:i64*(35184372088832:i64&v1:i64)))+(227389:i64*(70368744177664:i64&v1:i64)))+(113695:i64*(140737488355328:i64&v1:i64)))+(232845356:i64*(68719476736:i64&v1:i64)))+(14212:i64*(1125899906842624:i64&v1:i64)))+(7106:i64*(2251799813685248:i64&v1:i64)))+(3553:i64*(4503599627370496:i64&v1:i64)))+(1777:i64*(9007199254740992:i64&v1:i64)))+(889:i64*(18014398509481984:i64&v1:i64)))+(244156051745624:i64*(65536:i64&v1:i64)))+(223:i64*(72057594037927936:i64&v1:i64)))+(28424:i64*(562949953421312:i64&v1:i64)))+(465690712:i64*(34359738368:i64&v1:i64)))+(15259753234102:i64*(1048576:i64&v1:i64)))+(31251974623439814:i64*(512:i64&v1:i64)))+(-15625987311719906:i64*(576460752303424512:i64&(v0:i64&v1:i64))))+(-7591426311874296:i64*(2048:i64&(v0:i64&v1:i64))))+(-3906496827929976:i64*(4096:i64&(v0:i64&v1:i64))))+(-1953248413964988:i64*(8192:i64&(v0:i64&v1:i64))))+(-444:i64*(36028797018963968:i64&(v0:i64&v1:i64))))+(-976624206982494:i64*(16384:i64&(v0:i64&v1:i64))))+(-244156051745623:i64*(65536:i64&(v0:i64&v1:i64))))+(-122078025872811:i64*(131072:i64&(v0:i64&v1:i64))))+(-120946645661404:i64*(262144:i64&(v0:i64&v1:i64))))+(-55242767104369:i64*(524288:i64&(v0:i64&v1:i64))))+(-22390827825852:i64*(1048576:i64&(v0:i64&v1:i64))))+(125007898493759256:i64*(128:i64&(v0:i64&v1:i64))))+(-1286519989:i64*(17179869184:i64&(v0:i64&v1:i64))))+(-2441473979357:i64*(8388608:i64&(v0:i64&v1:i64))))+(-953734577131:i64*(16777216:i64&(v0:i64&v1:i64))))+(-476867288565:i64*(33554432:i64&(v0:i64&v1:i64))))+(-238433644282:i64*(67108864:i64&(v0:i64&v1:i64))))+(-119216822141:i64*(134217728:i64&(v0:i64&v1:i64))))+(-112576899013:i64*(268435456:i64&(v0:i64&v1:i64))))+(-29804205535:i64*(536870912:i64&(v0:i64&v1:i64))))+(-14902102767:i64*(144115189149597696:i64&(v0:i64&v1:i64))))+(-8879931774:i64*(6917529029788565504:i64&(v0:i64&v1:i64))))+(-3725525691:i64*(4294967296:i64&(v0:i64&v1:i64))))+(-1862762845:i64*(1152921513196781568:i64&(v0:i64&v1:i64))))+(-6547966389172:i64*(4194304:i64&(v0:i64&v1:i64))))+(-31251974623439813:i64*(512:i64&(v0:i64&v1:i64))))+(-27406:i64*(562949953421312:i64&(v0:i64&v1:i64))))+(-203125788:i64*(137438953472:i64&(v0:i64&v1:i64))))+(-77805585:i64*(274877906944:i64&(v0:i64&v1:i64))))+(-29105669:i64*(549755813888:i64&(v0:i64&v1:i64))))+(-17369865:i64*(1099511627776:i64&(v0:i64&v1:i64))))+(-7276417:i64*(2199023255552:i64&(v0:i64&v1:i64))))+(-3638208:i64*(4398046511104:i64&(v0:i64&v1:i64))))+(-2538983:i64*(8796093022208:i64&(v0:i64&v1:i64))))+(-909552:i64*(17592186044416:i64&(v0:i64&v1:i64))))+(-650367:i64*(35184372088832:i64&(v0:i64&v1:i64))))+(-422979:i64*(70368744177664:i64&(v0:i64&v1:i64))))+(-178213:i64*(140737488355328:i64&(v0:i64&v1:i64))))+(-319548466:i64*(68719476736:i64&(v0:i64&v1:i64))))+(-14211:i64*(1125899906842624:i64&(v0:i64&v1:i64))))+(-6088:i64*(2251799813685248:i64&(v0:i64&v1:i64))))+(-3552:i64*(4503599627370496:i64&(v0:i64&v1:i64))))+(-2807:i64*(297237575406452736:i64&(v0:i64&v1:i64))))+(-895:i64*(18014398509481984:i64&(v0:i64&v1:i64))))+(-829694712926902:i64*(32768:i64&(v0:i64&v1:i64))))+(-222:i64*(72057594037927936:i64&(v0:i64&v1:i64))))+(-55830:i64*(281474976710656:i64&(v0:i64&v1:i64))))+(-820829278:i64*(34359738368:i64&(v0:i64&v1:i64))))+(-7629876617050:i64*(2097152:i64&(v0:i64&v1:i64))))+(-62282381902893970:i64*(256:i64&(v0:i64&v1:i64)))"; + +inputText = "(((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((13716790784261793:i64*(1024:i64&v0:i64))+(2569773508279186:i64*(2048:i64&v0:i64)))+(854931267102023:i64*(4096:i64&v0:i64)))+(642443377069797:i64*(1152921504606855168:i64&v0:i64)))+(147:i64*(36028797018963968:i64&v0:i64)))+(321221688534899:i64*(16384:i64&v0:i64)))+(80305422133725:i64*(65536:i64&v0:i64)))+(32409689095277:i64*(131072:i64&v0:i64)))+(20076355533432:i64*(262144:i64&v0:i64)))+(10038177766716:i64*(524288:i64&v0:i64)))+(14868252956188:i64*(1048576:i64&v0:i64)))+(41116376132466962:i64*(128:i64&v0:i64)))+(36662153:i64*(-9223372019674906624:i64&v0:i64)))+(1680457161042:i64*(8388608:i64&v0:i64)))+(313693055210:i64*(16777216:i64&v0:i64)))+(110405950451:i64*(33554432:i64&v0:i64)))+(78423263803:i64*(67108864:i64&v0:i64)))+(39211631902:i64*(134217728:i64&v0:i64)))+(41884715533:i64*(268435456:i64&v0:i64)))+(9802907976:i64*(576460752840294400:i64&v0:i64)))+(4901453988:i64*(1073741824:i64&v0:i64)))+(7549757392:i64*(2147483648:i64&v0:i64)))+(1225363497:i64*(4294967296:i64&v0:i64)))+(1416744851:i64*(2305843017803628544:i64&v0:i64)))+(1254772220840:i64*(4194304:i64&v0:i64)))+(841939291338179:i64*(512:i64&v0:i64)))+(11267:i64*(562949953421312:i64&v0:i64)))+(37049344:i64*(137438953472:i64&v0:i64)))+(19146305:i64*(274877906944:i64&v0:i64)))+(9573153:i64*(549755813888:i64&v0:i64)))+(4786577:i64*(288231475663339520:i64&v0:i64)))+(1150023:i64*(2199023255552:i64&v0:i64)))+(4147683:i64*(4398046511104:i64&v0:i64)))+(598323:i64*(8796093022208:i64&v0:i64)))+(104472:i64*(17592186044416:i64&v0:i64)))+(149581:i64*(35184372088832:i64&v0:i64)))+(142245:i64*(144185556820033536:i64&v0:i64)))+(37396:i64*(140737488355328:i64&v0:i64)))+(75341953:i64*(68719476736:i64&v0:i64)))+(4675:i64*(1125899906842624:i64&v0:i64)))+(4256:i64*(2251799813685248:i64&v0:i64)))+(3087:i64*(4503599627370496:i64&v0:i64)))+(455:i64*(9007199254740992:i64&v0:i64)))+(163:i64*(18014398509481984:i64&v0:i64)))+(293605310651192:i64*(32768:i64&v0:i64)))+(74:i64*(72057594037927936:i64&v0:i64)))+(20616:i64*(281474976710656:i64&v0:i64)))+(153170438:i64*(34359738368:i64&v0:i64)))+(2509544441679:i64*(2097152:i64&v0:i64)))+(11121033324454919:i64*(256:i64&v0:i64)))+(20558188066233481:i64*(256:i64&v1:i64)))+(10279094033116741:i64*(512:i64&v1:i64)))+(5139547016558371:i64*(1024:i64&v1:i64)))+(2569773508279186:i64*(2048:i64&v1:i64)))+(74:i64*(72057594037927936:i64&v1:i64)))+(147:i64*(324259173170675712:i64&v1:i64)))+(321221688534899:i64*(16384:i64&v1:i64)))+(160610844267450:i64*(32768:i64&v1:i64)))+(80305422133725:i64*(65536:i64&v1:i64)))+(40152711066863:i64*(131072:i64&v1:i64)))+(20076355533432:i64*(262144:i64&v1:i64)))+(164465504529867848:i64*(32:i64&v1:i64)))+(5019088883358:i64*(1048576:i64&v1:i64)))+(612681749:i64*(-8070450523657994240:i64&v1:i64)))+(1254772220840:i64*(4194304:i64&v1:i64)))+(627386110420:i64*(8388608:i64&v1:i64)))+(313693055210:i64*(576460752320200704:i64&v1:i64)))+(156846527605:i64*(33554432:i64&v1:i64)))+(78423263803:i64*(2305843009280802816:i64&v1:i64)))+(39211631902:i64*(134217728:i64&v1:i64)))+(19605815951:i64*(268435456:i64&v1:i64)))+(9802907976:i64*(536870912:i64&v1:i64)))+(4901453988:i64*(1073741824:i64&v1:i64)))+(2450726994:i64*(4611686020574871552:i64&v1:i64)))+(1225363497:i64*(4294967296:i64&v1:i64)))+(2509544441679:i64*(2097152:i64&v1:i64)))+(41116376132466962:i64*(128:i64&v1:i64)))+(37396:i64*(140737488355328:i64&v1:i64)))+(76585219:i64*(68719476736:i64&v1:i64)))+(38292610:i64*(137438953472:i64&v1:i64)))+(19146305:i64*(274877906944:i64&v1:i64)))+(9573153:i64*(549755813888:i64&v1:i64)))+(4786577:i64*(1099511627776:i64&v1:i64)))+(2393289:i64*(2199023255552:i64&v1:i64)))+(1196645:i64*(4398046511104:i64&v1:i64)))+(598323:i64*(8796093022208:i64&v1:i64)))+(299162:i64*(17592186044416:i64&v1:i64)))+(149581:i64*(35184372088832:i64&v1:i64)))+(74791:i64*(70368744177664:i64&v1:i64)))+(153170438:i64*(34359738368:i64&v1:i64)))+(18698:i64*(281474976710656:i64&v1:i64)))+(4675:i64*(1125899906842624:i64&v1:i64)))+(2338:i64*(2251799813685248:i64&v1:i64)))+(1169:i64*(4503599627370496:i64&v1:i64)))+(585:i64*(9007199254740992:i64&v1:i64)))+(293:i64*(162129586585337856:i64&v1:i64)))+(642443377069797:i64*(8192:i64&v1:i64)))+(1284886754139593:i64*(4096:i64&v1:i64)))+(9349:i64*(562949953421312:i64&v1:i64)))+(306340875:i64*(17179869184:i64&v1:i64)))+(10038177766716:i64*(524288:i64&v1:i64)))+(82232752264933924:i64*(64:i64&v1:i64)))+(-10279094033116740:i64*(512:i64&(v0:i64&v1:i64))))+(-5139547016558370:i64*(1024:i64&(v0:i64&v1:i64))))+(-2569773508279185:i64*(2048:i64&(v0:i64&v1:i64))))+(-1284886754139592:i64*(4096:i64&(v0:i64&v1:i64))))+(-146:i64*(324259173170675712:i64&(v0:i64&v1:i64))))+(-642443377069796:i64*(8192:i64&(v0:i64&v1:i64))))+(-160610844267449:i64*(32768:i64&(v0:i64&v1:i64))))+(-80305422133724:i64*(65536:i64&(v0:i64&v1:i64))))+(-40152711066862:i64*(131072:i64&(v0:i64&v1:i64))))+(-20076355533431:i64*(262144:i64&(v0:i64&v1:i64))))+(-10038177766715:i64*(524288:i64&(v0:i64&v1:i64))))+(-164465504529867848:i64*(32:i64&(v0:i64&v1:i64))))+(-612681748:i64*(1152921513196781568:i64&(v0:i64&v1:i64))))+(-1254772220839:i64*(4194304:i64&(v0:i64&v1:i64))))+(-627386110419:i64*(8388608:i64&(v0:i64&v1:i64))))+(-313693055209:i64*(576460752320200704:i64&(v0:i64&v1:i64))))+(-156846527604:i64*(33554432:i64&(v0:i64&v1:i64))))+(-78423263802:i64*(-6917529027573972992:i64&(v0:i64&v1:i64))))+(-39211631901:i64*(134217728:i64&(v0:i64&v1:i64))))+(-19605815950:i64*(268435456:i64&(v0:i64&v1:i64))))+(-9802907975:i64*(536870912:i64&(v0:i64&v1:i64))))+(-4901453987:i64*(1073741824:i64&(v0:i64&v1:i64))))+(-2450726993:i64*(4611686020574871552:i64&(v0:i64&v1:i64))))+(-1225363496:i64*(4294967296:i64&(v0:i64&v1:i64))))+(-2509544441678:i64*(2097152:i64&(v0:i64&v1:i64))))+(-20558188066233480:i64*(256:i64&(v0:i64&v1:i64))))+(-37395:i64*(140737488355328:i64&(v0:i64&v1:i64))))+(-76585218:i64*(68719476736:i64&(v0:i64&v1:i64))))+(-38292609:i64*(137438953472:i64&(v0:i64&v1:i64))))+(-19146304:i64*(274877906944:i64&(v0:i64&v1:i64))))+(-9573152:i64*(549755813888:i64&(v0:i64&v1:i64))))+(-4786576:i64*(1099511627776:i64&(v0:i64&v1:i64))))+(-2393288:i64*(2199023255552:i64&(v0:i64&v1:i64))))+(-1196644:i64*(4398046511104:i64&(v0:i64&v1:i64))))+(-598322:i64*(8796093022208:i64&(v0:i64&v1:i64))))+(-299161:i64*(17592186044416:i64&(v0:i64&v1:i64))))+(-149580:i64*(35184372088832:i64&(v0:i64&v1:i64))))+(-74790:i64*(70368744177664:i64&(v0:i64&v1:i64))))+(-153170437:i64*(34359738368:i64&(v0:i64&v1:i64))))+(-9348:i64*(562949953421312:i64&(v0:i64&v1:i64))))+(-4674:i64*(1125899906842624:i64&(v0:i64&v1:i64))))+(-2337:i64*(2251799813685248:i64&(v0:i64&v1:i64))))+(-1168:i64*(4503599627370496:i64&(v0:i64&v1:i64))))+(-584:i64*(9007199254740992:i64&(v0:i64&v1:i64))))+(-292:i64*(18014398509481984:i64&(v0:i64&v1:i64))))+(-321221688534898:i64*(16384:i64&(v0:i64&v1:i64))))+(-73:i64*(72057594037927936:i64&(v0:i64&v1:i64))))+(-18697:i64*(281474976710656:i64&(v0:i64&v1:i64))))+(-306340874:i64*(17179869184:i64&(v0:i64&v1:i64))))+(-5019088883357:i64*(1048576:i64&(v0:i64&v1:i64))))+(-82232752264933924:i64*(144115188075856064:i64&(v0:i64&v1:i64))))"; + +inputText = "(((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((551254427606215992:i64*(32:i64&v0:i64))+(275627213803107996:i64*(64:i64&v0:i64)))+(137813606901553998:i64*(128:i64&v0:i64)))+(68906803450776999:i64*(256:i64&v0:i64)))+(34453401725388500:i64*(512:i64&v0:i64)))+(980:i64*(18014398509481984:i64&v0:i64)))+(3917:i64*(4503599627370496:i64&v0:i64)))+(4306675215673563:i64*(4096:i64&v0:i64)))+(2153337607836782:i64*(8192:i64&v0:i64)))+(1076668803918391:i64*(16384:i64&v0:i64)))+(538334401959196:i64*(32768:i64&v0:i64)))+(269167200979598:i64*(65536:i64&v0:i64)))+(134583600489799:i64*(131072:i64&v0:i64)))+(1102508855212431984:i64*(16:i64&v0:i64)))+(8214331085:i64*(2147483648:i64&v0:i64)))+(16822950061225:i64*(1048576:i64&v0:i64)))+(8411475030613:i64*(2097152:i64&v0:i64)))+(4205737515307:i64*(4194304:i64&v0:i64)))+(2102868757654:i64*(8388608:i64&v0:i64)))+(1051434378827:i64*(16777216:i64&v0:i64)))+(525717189414:i64*(33554432:i64&v0:i64)))+(262858594707:i64*(67108864:i64&v0:i64)))+(131429297354:i64*(134217728:i64&v0:i64)))+(65714648677:i64*(268435456:i64&v0:i64)))+(32857324339:i64*(536870912:i64&v0:i64)))+(16428662170:i64*(1073741824:i64&v0:i64)))+(33645900122450:i64*(524288:i64&v0:i64)))+(67291800244900:i64*(262144:i64&v0:i64)))+(250682:i64*(70368744177664:i64&v0:i64)))+(1026791386:i64*(17179869184:i64&v0:i64)))+(513395693:i64*(34359738368:i64&v0:i64)))+(256697847:i64*(68719476736:i64&v0:i64)))+(128348924:i64*(137438953472:i64&v0:i64)))+(64174462:i64*(288230651029618688:i64&v0:i64)))+(32087231:i64*(576461302059237376:i64&v0:i64)))+(16043616:i64*(1099511627776:i64&v0:i64)))+(8021808:i64*(2199023255552:i64&v0:i64)))+(4010904:i64*(4398046511104:i64&v0:i64)))+(2005452:i64*(8796093022208:i64&v0:i64)))+(1002726:i64*(17592186044416:i64&v0:i64)))+(2053582772:i64*(8589934592:i64&v0:i64)))+(125341:i64*(140737488355328:i64&v0:i64)))+(62671:i64*(281474976710656:i64&v0:i64)))+(31336:i64*(562949953421312:i64&v0:i64)))+(15668:i64*(1125899906842624:i64&v0:i64)))+(7834:i64*(2251799813685248:i64&v0:i64)))+(8613350431347125:i64*(2048:i64&v0:i64)))+(1959:i64*(9007199254740992:i64&v0:i64)))+(17226700862694250:i64*(1024:i64&v0:i64)))+(490:i64*(36028797018963968:i64&v0:i64)))+(245:i64*(72057594037927936:i64&v0:i64)))+(123:i64*(144115188075855872:i64&v0:i64)))+(501363:i64*(35184372088832:i64&v0:i64)))+(4107165543:i64*(4294967296:i64&v0:i64)))+(8137759425928576:i64*(2048:i64&v1:i64)))+(4306675215673563:i64*(4096:i64&v1:i64)))+(1677746602418233:i64*(8192:i64&v1:i64)))+(437:i64*(36028797018963968:i64&v1:i64)))+(538334401959196:i64*(32768:i64&v1:i64)))+(601077798499842:i64*(16384:i64&v1:i64)))+(81205060137234:i64*(131072:i64&v1:i64)))+(13913259892335:i64*(262144:i64&v1:i64)))+(15451731858717:i64*(524288:i64&v1:i64)))+(14388016916912498:i64*(256:i64&v1:i64)))+(8411475030613:i64*(2097152:i64&v1:i64)))+(1026791386:i64*(17179869184:i64&v1:i64)))+(2102868757654:i64*(8388608:i64&v1:i64)))+(1051434378827:i64*(16777216:i64&v1:i64)))+(525717189414:i64*(33554432:i64&v1:i64)))+(210632189278:i64*(67108864:i64&v1:i64)))+(131429297354:i64*(134217728:i64&v1:i64)))+(4010904:i64*(4398046511104:i64&v1:i64)))+(701179:i64*(1297038891705958400:i64&v1:i64)))+(15741864293:i64*(1073741824:i64&v1:i64)))+(7527533208:i64*(2147483648:i64&v1:i64)))+(4107165543:i64*(4611686022722355200:i64&v1:i64)))+(2053582772:i64*(8589934592:i64&v1:i64)))+(4205737515307:i64*(4194304:i64&v1:i64)))+(16751109857275701:i64*(1024:i64&v1:i64)))+(16538:i64*(281474976710656:i64&v1:i64)))+(128348924:i64*(137438953472:i64&v1:i64)))+(64174462:i64*(274877906944:i64&v1:i64)))+(16377994:i64*(549755813888:i64&v1:i64)))+(16043616:i64*(1099511627776:i64&v1:i64)))+(14990657278:i64*(288230376688582656:i64&v1:i64)))+(65714648677:i64*(268435456:i64&v1:i64)))+(2005452:i64*(8796093022208:i64&v1:i64)))+(1022129:i64*(17592186044416:i64&v1:i64)))+(501363:i64*(35184372088832:i64&v1:i64)))+(7941:i64*(70368744177664:i64&v1:i64)))+(13672:i64*(140737488355328:i64&v1:i64)))+(106770882:i64*(68719476736:i64&v1:i64)))+(15668:i64*(1125899906842624:i64&v1:i64)))+(7834:i64*(2251799813685248:i64&v1:i64)))+(2840:i64*(4503599627370496:i64&v1:i64)))+(882:i64*(9007199254740992:i64&v1:i64)))+(927:i64*(594475150812905472:i64&v1:i64)))+(269167200979598:i64*(65536:i64&v1:i64)))+(192:i64*(72057594037927936:i64&v1:i64)))+(17971:i64*(562949953421312:i64&v1:i64)))+(363468728:i64*(34359738368:i64&v1:i64)))+(16822950061225:i64*(1048576:i64&v1:i64)))+(34453401725388500:i64*(512:i64&v1:i64)))+(-1102508855212431984:i64*(16:i64&(v0:i64&v1:i64))))+(-551254427606215992:i64*(32:i64&(v0:i64&v1:i64))))+(-198633974609720372:i64*(192:i64&(v0:i64&v1:i64))))+(-68906803450776998:i64*(256:i64&(v0:i64&v1:i64))))+(-34453401725388499:i64*(512:i64&(v0:i64&v1:i64))))+(-17226700862694249:i64*(1024:i64&(v0:i64&v1:i64))))+(-8613350431347124:i64*(2048:i64&(v0:i64&v1:i64))))+(-4306675215673562:i64*(4096:i64&(v0:i64&v1:i64))))+(-2153337607836781:i64*(8192:i64&(v0:i64&v1:i64))))+(-1076668803918390:i64*(16384:i64&(v0:i64&v1:i64))))+(-538334401959195:i64*(32768:i64&(v0:i64&v1:i64))))+(-269167200979597:i64*(65536:i64&(v0:i64&v1:i64))))+(2205017710424863968:i64*(8:i64&(v0:i64&v1:i64))))+(-67291800244899:i64*(4611686018427650048:i64&(v0:i64&v1:i64))))+(-8214331084:i64*(2147483648:i64&(v0:i64&v1:i64))))+(-16822950061224:i64*(1048576:i64&(v0:i64&v1:i64))))+(-8411475030612:i64*(2097152:i64&(v0:i64&v1:i64))))+(-4205737515306:i64*(4194304:i64&(v0:i64&v1:i64))))+(-2102868757653:i64*(8388608:i64&(v0:i64&v1:i64))))+(-1051434378826:i64*(16777216:i64&(v0:i64&v1:i64))))+(-525717189413:i64*(33554432:i64&(v0:i64&v1:i64))))+(-262858594706:i64*(67108864:i64&(v0:i64&v1:i64))))+(-131429297353:i64*(134217728:i64&(v0:i64&v1:i64))))+(-65714648676:i64*(268435456:i64&(v0:i64&v1:i64))))+(-32857324338:i64*(536870912:i64&(v0:i64&v1:i64))))+(-16428662169:i64*(1073741824:i64&(v0:i64&v1:i64))))+(-33645900122449:i64*(524288:i64&(v0:i64&v1:i64))))+(-134583600489798:i64*(131072:i64&(v0:i64&v1:i64))))+(-501362:i64*(35184372088832:i64&(v0:i64&v1:i64))))+(-1026791385:i64*(17179869184:i64&(v0:i64&v1:i64))))+(-513395692:i64*(34359738368:i64&(v0:i64&v1:i64))))+(-256697846:i64*(68719476736:i64&(v0:i64&v1:i64))))+(-128348923:i64*(137438953472:i64&(v0:i64&v1:i64))))+(-64174461:i64*(288230651029618688:i64&(v0:i64&v1:i64))))+(-32087230:i64*(576461302059237376:i64&(v0:i64&v1:i64))))+(-16043615:i64*(1099511627776:i64&(v0:i64&v1:i64))))+(-8021807:i64*(-5764605324010979328:i64&(v0:i64&v1:i64))))+(-4010903:i64*(4398046511104:i64&(v0:i64&v1:i64))))+(-2005451:i64*(8796093022208:i64&(v0:i64&v1:i64))))+(-1002725:i64*(17592186044416:i64&(v0:i64&v1:i64))))+(-2053582771:i64*(8589934592:i64&(v0:i64&v1:i64))))+(-250681:i64*(70368744177664:i64&(v0:i64&v1:i64))))+(-62670:i64*(281474976710656:i64&(v0:i64&v1:i64))))+(-31335:i64*(562949953421312:i64&(v0:i64&v1:i64))))+(-15667:i64*(1125899906842624:i64&(v0:i64&v1:i64))))+(-7833:i64*(2251799813685248:i64&(v0:i64&v1:i64))))+(-3916:i64*(4503599627370496:i64&(v0:i64&v1:i64))))+(-1958:i64*(9007199254740992:i64&(v0:i64&v1:i64))))+(-979:i64*(18014398509481984:i64&(v0:i64&v1:i64))))+(-489:i64*(36028797018963968:i64&(v0:i64&v1:i64))))+(-244:i64*(72057594037927936:i64&(v0:i64&v1:i64))))+(-122:i64*(144115188075855872:i64&(v0:i64&v1:i64))))+(-125340:i64*(140737488355328:i64&(v0:i64&v1:i64))))+(-4107165542:i64*(4294967296:i64&(v0:i64&v1:i64))))\r\n"; + var printHelp = () => { Console.WriteLine("Usage: Simplifier.exe"); @@ -98,11 +110,13 @@ Console.WriteLine($"\nExpression: {ctx.GetAstString(id)}\n\n\n"); -Console.WriteLine(DagFormatter.Format(ctx, id)); +//Console.WriteLine(DagFormatter.Format(ctx, id)); var bx = LinearSimplifier.Run(bitWidth, ctx, id, false, true); +Console.WriteLine($"Linear simplifier returned: {bx}"); + while(false) { var simplifier = new GeneralSimplifier(ctx); From f4635486d13ddf61df463fd46fb17c77216b3793 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Thu, 20 Nov 2025 08:21:15 -0500 Subject: [PATCH 11/21] Start optimizing jit --- Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs | 9 +++++++++ Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs | 4 ++-- Simplifier/Program.cs | 5 +++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs index 46f90df..b1a5d8a 100644 --- a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs +++ b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs @@ -33,6 +33,15 @@ public unsafe Amd64AssemblerDifferentialTester(byte* buffer) icedAssembler = new IcedAmd64Assembler(new Assembler(64)); } + public static void Test() + { + var buffer = new byte[64 * 4096]; + fixed(byte* p = buffer) + { + new Amd64AssemblerDifferentialTester(p).Run(); + } + } + public void Run() { for (int i = 0; i < registers.Length; i++) diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index 561a4e9..2ce9dee 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -34,13 +34,13 @@ public class ProbableEquivalenceChecker private unsafe delegate* unmanaged[SuppressGCTransition] func2; - public static bool ProbablyEquivalent(AstCtx ctx, AstIdx before, AstIdx after) + public static bool ProbablyEquivalent(AstCtx ctx, AstIdx before, AstIdx after, bool slowHeuristics = true) { var pagePtr1 = JitUtils.AllocateExecutablePage(4096); var pagePtr2 = JitUtils.AllocateExecutablePage(4096); var allVars = ctx.CollectVariables(before).Concat(ctx.CollectVariables(after)).Distinct().OrderBy(x => ctx.GetSymbolName(x)).ToList(); - bool probablyEquivalent = new ProbableEquivalenceChecker(ctx, allVars, before, after, pagePtr1, pagePtr2).ProbablyEquivalent(true); + bool probablyEquivalent = new ProbableEquivalenceChecker(ctx, allVars, before, after, pagePtr1, pagePtr2).ProbablyEquivalent(false); JitUtils.FreeExecutablePage(pagePtr1); JitUtils.FreeExecutablePage(pagePtr2); diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 44bcfa5..d0f1634 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -4,6 +4,7 @@ using Mba.Simplifier.DSL; using Mba.Simplifier.Fuzzing; using Mba.Simplifier.Interpreter; +using Mba.Simplifier.Jit; using Mba.Simplifier.Minimization; using Mba.Simplifier.Pipeline; using Mba.Simplifier.Utility; @@ -22,6 +23,10 @@ //DatasetTester.Run(); +Amd64AssemblerDifferentialTester.Test(); +Debugger.Break(); + + inputText = "((((1:i32&((uns17:i8 zx i32)&(~uns18:i32)))|(4294964010:i32&(~((uns17:i8 zx i32)|(~uns18:i32)))))|(4294964011:i32&((uns17:i8 zx i32)&uns18:i32)))|(4:i32*(1:i32&(uns19:i8 zx i32))))"; inputText = "((2041933603239772578:i64+((((((((((((((((-27487790705275:i64*uns121:i64)+(-9223358842715237276:i64*(-860922984064492326:i64&uns121:i64)))+(9223354444668724432:i64*uns131:i64))+(-9223350046622211588:i64*(860922984064492325:i64&uns131:i64)))+(-8796093025688:i64*uns132:i64))+(4398046512844:i64*uns34:i64))+(17592186051376:i64*uns65:i64))+(-3298534884633:i64*uns91:i64))+(9223367638808262964:i64*(8362449052790283482:i64&uns91:i64)))+(13194139538532:i64*(860922984064492325:i64&(uns121:i64&uns130:i64))))+(14293651166743:i64*(-3750763034362895579:i64&(uns121:i64&uns67:i64))))+(4398046512844:i64*(uns130:i64&uns133:i64)))+(-8796093025688:i64*(1444920025149201626:i64&(uns130:i64&uns91:i64))))+(-4398046512844:i64*(uns131:i64&uns133:i64)))+(-9223350046622211588:i64*(3750763034362895578:i64&(uns131:i64&uns91:i64))))+(-9895604653899:i64*((uns121:i64&uns130:i64)&uns91:i64))))+(3062923494603851298:i64+(((((((-9895604653899:i64*uns130:i64)+(9895604653899:i64*(-3750763034362895579:i64&uns131:i64)))+(9895604653899:i64*(3750763034362895578:i64&uns17:i64)))+(-9895604653899:i64*(-3750763034362895579:i64&(uns121:i64&uns131:i64))))+(9895604653899:i64*(uns130:i64&uns134:i64)))+(-9895604653899:i64*(uns131:i64&uns134:i64)))+(9895604653899:i64*(-3750763034362895579:i64&(uns131:i64&uns91:i64))))))"; From dcbfc970b19f4797885938226559c494be286f28 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Thu, 20 Nov 2025 08:36:11 -0500 Subject: [PATCH 12/21] archive changes before port --- .../Jit/Amd64AssemblerDifferentialTester.cs | 12 +++++++----- Simplifier/Program.cs | 1 + 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs index b1a5d8a..f1d3ea2 100644 --- a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs +++ b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs @@ -65,13 +65,15 @@ private void DiffRegInsts(Register reg1) for (int _ = 0; _ < 100; _++) { var c = (ulong)rand.NextInt64(); + c |= rand.Next(0, 2) == 0 ? 0 : (1ul << 63); + Diff(nameof(IAmd64Assembler.MovabsRegImm64), reg1, c); - Diff(nameof(IAmd64Assembler.AddRegImm32), reg1, c); - Diff(nameof(IAmd64Assembler.SubRegImm32), reg1, c); - Diff(nameof(IAmd64Assembler.AndRegImm32), reg1, c); - Diff(nameof(IAmd64Assembler.ShrRegImm8), reg1, c); + Diff(nameof(IAmd64Assembler.AddRegImm32), reg1, (uint)c); + Diff(nameof(IAmd64Assembler.SubRegImm32), reg1, (uint)c); + Diff(nameof(IAmd64Assembler.AndRegImm32), reg1, (uint)c); + Diff(nameof(IAmd64Assembler.ShrRegImm8), reg1, (byte)c); if (reg1 != rsp) - Diff(nameof(IAmd64Assembler.PushMem64), reg1, c); + Diff(nameof(IAmd64Assembler.PushMem64), reg1, (int)c); } } diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index d0f1634..7360518 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -24,6 +24,7 @@ Amd64AssemblerDifferentialTester.Test(); +Console.WriteLine("Passed!"); Debugger.Break(); From 080021ff94df447bc60d3f47c7f7ed8d54d3dc75 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 08:23:22 -0500 Subject: [PATCH 13/21] Benchmark jit --- .../Jit/Amd64AssemblerDifferentialTester.cs | 2 + Mba.Simplifier/Jit/FastAmd64Assembler.cs | 283 ++++++++++-------- Simplifier/JitBenchmark.cs | 58 ++++ Simplifier/Program.cs | 4 +- 4 files changed, 220 insertions(+), 127 deletions(-) create mode 100644 Simplifier/JitBenchmark.cs diff --git a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs index f1d3ea2..cafc414 100644 --- a/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs +++ b/Mba.Simplifier/Jit/Amd64AssemblerDifferentialTester.cs @@ -101,6 +101,8 @@ private void DiffRegRegInsts(Register reg1, Register reg2) Diff(nameof(IAmd64Assembler.MovMem64Reg), reg2, (int)c, reg1); Diff(nameof(IAmd64Assembler.MovRegMem64), reg1, reg2, (int)c); Diff(nameof(IAmd64Assembler.MovRegMem64), reg2, reg1, (int)c); + Diff(nameof(IAmd64Assembler.AndMem64Reg), reg1, (int)c, reg2); + Diff(nameof(IAmd64Assembler.AndMem64Reg), reg2, (int)c, reg1); } } diff --git a/Mba.Simplifier/Jit/FastAmd64Assembler.cs b/Mba.Simplifier/Jit/FastAmd64Assembler.cs index bb6cd0e..f99adcb 100644 --- a/Mba.Simplifier/Jit/FastAmd64Assembler.cs +++ b/Mba.Simplifier/Jit/FastAmd64Assembler.cs @@ -11,12 +11,46 @@ namespace Mba.Simplifier.Jit { + unsafe ref struct StackBuffer + { + public byte* Ptr; + + public uint Offset; + + public StackBuffer(byte* ptr) + { + this.Ptr = ptr; + } + + public void PushU8(byte value) + { + Ptr[Offset++] = value; + } + + public void PushI32(int value) + => PushU32((uint)value); + + public void PushU32(uint value) + { + *(uint*)&Ptr[Offset] = value; + Offset += 4; + } + + public void PushU64(ulong value) + { + *(ulong*)&Ptr[Offset] = value; + Offset += 8; + } + } + public unsafe class FastAmd64Assembler : IAmd64Assembler { private byte* start; private byte* ptr; + private int offset = 0; + public List Instructions => GetInstructions(); public FastAmd64Assembler(byte* ptr) @@ -25,62 +59,89 @@ public FastAmd64Assembler(byte* ptr) this.ptr = ptr; } + private unsafe void EmitBytes(Span src) + { + for (int i = 0; i < src.Length; i++) + { + //start[offset + i] = src[i]; + *ptr++ = src[i]; + } + + //offset += length; + } + + private unsafe void EmitBytes(Span src, int len) + { + for (int i = 0; i < len; i++) + { + //start[offset + i] = src[i]; + *ptr++ = src[i]; + } + + //offset += length; + } + + private unsafe void EmitBytes(params byte[] bytes) + { + EmitBytes(bytes.AsSpan()); + //offset += length; + } + + private void EmitBuffer(StackBuffer buffer) + { + EmitBytes(new Span(buffer.Ptr, (int)buffer.Offset)); + } + public void PushReg(Register reg) { if (reg >= Register.RAX && reg <= Register.RDI) { byte opcode = (byte)(0x50 + GetRegisterCode(reg)); - *ptr++ = opcode; + EmitBytes(opcode); + return; } - else if (reg >= Register.R8 && reg <= Register.R15) + if (reg >= Register.R8 && reg <= Register.R15) { - byte rex = 0x41; - *ptr++ = rex; - + byte rex = 0x41; byte opcode = (byte)(0x50 + (int)reg - (int)Register.R8); - *ptr++ = opcode; + EmitBytes(rex, opcode); + return; } - else - { - throw new ArgumentException("Invalid register for PUSH instruction."); - } + throw new ArgumentException("Invalid register for PUSH instruction."); + } - // push qword ptr [baseReg+offset] public void PushMem64(Register baseReg, int offset) { - byte rex = 0x48; - if (IsExtended(baseReg)) rex |= 0x01; + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + + // normal, normal + // normal, sib + // extended, sib + // extended, normal + if (IsExtended(baseReg)) + { + byte rex = 0x49; + arr.PushU8(rex); + } byte opcode = 0xFF; byte modrm = (byte)(0x80 | (0x06 << 3) | (GetRegisterCode(baseReg) & 0x07)); + arr.PushU8(opcode); + arr.PushU8(modrm); if (baseReg == Register.RSP || baseReg == Register.R12) { - if (IsExtended(baseReg)) - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; + arr.PushU8(sib); } - else - { - if (IsExtended(baseReg)) - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - } + arr.PushI32(offset); - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + EmitBuffer(arr); } public void PopReg(Register reg) @@ -88,22 +149,20 @@ public void PopReg(Register reg) if (reg >= Register.RAX && reg <= Register.RDI) { byte opcode = (byte)(0x58 + GetRegisterCode(reg)); - *ptr++ = opcode; + EmitBytes(opcode); + return; } - else if (reg >= Register.R8 && reg <= Register.R15) + if (reg >= Register.R8 && reg <= Register.R15) { byte rex = 0x41; - *ptr++ = rex; - byte opcode = (byte)(0x58 + GetRegisterCode(reg) - 8); - *ptr++ = opcode; + EmitBytes(rex, opcode); + return; } - else - { - throw new ArgumentException($"Cannot pop {reg}"); - } + throw new ArgumentException($"Cannot pop {reg}"); + } public void OpcodeRegReg(byte opcode, Register reg1, Register reg2) @@ -113,9 +172,7 @@ public void OpcodeRegReg(byte opcode, Register reg1, Register reg2) if (IsExtended(reg2)) rex |= 0x04; byte modrm = (byte)(0xC0 | ((GetRegisterCode(reg2) & 0x07) << 3) | (GetRegisterCode(reg1) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + EmitBytes(rex, opcode, modrm); } public void MovRegReg(Register reg1, Register reg2) @@ -123,6 +180,9 @@ public void MovRegReg(Register reg1, Register reg2) public void MovRegMem64(Register dstReg, Register baseReg, int offset) { + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + byte rex = 0x48; if (IsExtended(dstReg)) rex |= 0x04; if (IsExtended(baseReg)) rex |= 0x01; @@ -130,32 +190,27 @@ public void MovRegMem64(Register dstReg, Register baseReg, int offset) byte opcode = 0x8B; byte modrm = (byte)(0x80 | ((GetRegisterCode(dstReg) & 0x07) << 3) | (GetRegisterCode(baseReg) & 0x07)); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + if (baseReg == Register.RSP || baseReg == Register.R12) { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; + arr.PushU8(sib); } - else - { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - } + arr.PushI32(offset); - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + EmitBuffer(arr); } // mov qword ptr [baseReg + offset], srcReg public void MovMem64Reg(Register baseReg, int offset, Register srcReg) { + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + byte rex = 0x48; if (IsExtended(srcReg)) rex |= 0x04; if (IsExtended(baseReg)) rex |= 0x01; @@ -163,27 +218,18 @@ public void MovMem64Reg(Register baseReg, int offset, Register srcReg) byte opcode = 0x89; byte modrm = (byte)(0x80 | ((GetRegisterCode(srcReg) & 0x07) << 3) | (GetRegisterCode(baseReg) & 0x07)); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + if (baseReg == Register.RSP || baseReg == Register.R12) { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; + arr.PushU8(sib); } - else - { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - } - - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + arr.PushI32(offset); + EmitBuffer(arr); } public void MovabsRegImm64(Register reg1, ulong imm64) @@ -193,14 +239,14 @@ public void MovabsRegImm64(Register reg1, ulong imm64) var cond = (reg1 >= Register.RAX && reg1 <= Register.RDI); byte opcode = (byte)(0xB8 + (cond ? GetRegisterCode(reg1) : GetRegisterCode(reg1) - 8)); - *ptr++ = rex; - *ptr++ = opcode; - for (int i = 0; i < 8; i++) - { - *ptr++ = (byte)(imm64 & 0xFF); - imm64 >>= 8; - } + byte* p = stackalloc byte[10]; + var arr = new StackBuffer(ptr); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU64(imm64); + + EmitBuffer(arr); } public void AddRegReg(Register reg1, Register reg2) @@ -221,15 +267,15 @@ public void OpcRegImm(byte mask, Register reg1, uint imm32) byte opcode = 0x81; byte modrm = (byte)(0xC0 | (mask << 3) | (GetRegisterCode(reg1) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(imm32 & 0xFF); - imm32 >>= 8; - } + byte* p = stackalloc byte[7]; + var arr = new StackBuffer(ptr); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + arr.PushU32(imm32); + + EmitBuffer(arr); } public void ImulRegReg(Register reg1, Register reg2) @@ -241,10 +287,7 @@ public void ImulRegReg(Register reg1, Register reg2) byte opcode1 = 0x0F; byte opcode2 = 0xAF; byte modrm = (byte)(0xC0 | ((GetRegisterCode(reg1) & 0x07) << 3) | (GetRegisterCode(reg2) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode1; - *ptr++ = opcode2; - *ptr++ = modrm; + EmitBytes(rex, opcode1, opcode2, modrm); } public void AndRegReg(Register reg1, Register reg2) @@ -262,27 +305,22 @@ public void AndMem64Reg(Register baseReg, int offset, Register srcReg) byte opcode = 0x21; byte modrm = (byte)(0x80 | ((GetRegisterCode(srcReg) & 0x07) << 3) | (GetRegisterCode(baseReg) & 0x07)); + byte* p = stackalloc byte[8]; + var arr = new StackBuffer(ptr); + arr.PushU8(rex); + arr.PushU8(opcode); + arr.PushU8(modrm); + + if (baseReg == Register.RSP || baseReg == Register.R12) { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - byte sib = (byte)(0x00 | (0x04 << 3) | (GetRegisterCode(baseReg) & 0x07)); - *ptr++ = sib; - } - else - { - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + arr.PushU8(sib); } - for (int i = 0; i < 4; i++) - { - *ptr++ = (byte)(offset & 0xFF); - offset >>= 8; - } + arr.PushI32(offset); + + EmitBuffer(arr); } public void OrRegReg(Register reg1, Register reg2) @@ -298,9 +336,7 @@ public void NotReg(Register reg1) byte opcode = 0xF7; byte modrm = (byte)(0xC0 | (0x02 << 3) | (GetRegisterCode(reg1) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + EmitBytes(rex, opcode, modrm); } public void ShlRegCl(Register reg) @@ -317,9 +353,7 @@ public void ShiftRegCl(bool shl, Register reg) byte opcode = 0xD3; var m1 = shl ? 0x04 : 0x05; byte modrm = (byte)(0xC0 | (m1 << 3) | (GetRegisterCode(reg) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + EmitBytes(rex, opcode, modrm); } public void ShrRegImm8(Register reg, byte imm8) @@ -329,10 +363,7 @@ public void ShrRegImm8(Register reg, byte imm8) byte opcode = 0xC1; byte modrm = (byte)(0xC0 | (0x05 << 3) | (GetRegisterCode(reg) & 0x07)); - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; - *ptr++ = imm8; + EmitBytes(rex, opcode, modrm, imm8); } public void CallReg(Register reg1) @@ -343,20 +374,20 @@ public void CallReg(Register reg1) byte opcode = 0xFF; byte modrm = (byte)(0xC0 | (0x02 << 3) | (GetRegisterCode(reg1) & 0x07)); - if (rex != 0x00) - *ptr++ = rex; - *ptr++ = opcode; - *ptr++ = modrm; + if (rex != 0) + EmitBytes(rex, opcode, modrm); + else + EmitBytes(opcode, modrm); } public void Ret() - => *ptr++ = 0xC3; + => EmitBytes(0xC3); private bool IsExtended(Register reg) => reg >= Register.R8 && reg <= Register.R15; - private int GetRegisterCode(Register reg) - => (int)reg - (int)Register.RAX; + private uint GetRegisterCode(Register reg) + => (uint)reg - (uint)Register.RAX; public List GetInstructions() { diff --git a/Simplifier/JitBenchmark.cs b/Simplifier/JitBenchmark.cs new file mode 100644 index 0000000..5f97c41 --- /dev/null +++ b/Simplifier/JitBenchmark.cs @@ -0,0 +1,58 @@ +using Mba.Common.MSiMBA; +using Mba.Simplifier.Bindings; +using Mba.Simplifier.Interpreter; +using Mba.Simplifier.Utility; +using Microsoft.Z3; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Simplifier +{ + public class JitBenchmark + { + public static void Run() + { + var bc = new JitBenchmark(); + while(true) + { + bc.Benchmark(); + } + } + + private readonly AstCtx ctx = new(); + + private readonly AstIdx idx; + + private JitBenchmark() + { + var inputText = "(((-1099511628211:i64*((uns173:i64&(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64)))))))|(uns174:i64&(~(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))))))+(uns175:i64*(-2:i64+(-1:i64*uns174:i64))))+(2199023256422:i64*((((-1:i64*(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64)))))))+(-1:i64*uns173:i64))+(2:i64*((-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))&uns173:i64)))+((-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))&uns174:i64))))"; + idx = RustAstParser.Parse(ctx, inputText, 64); + } + + private void Benchmark() + { + var variables = ctx.CollectVariables(idx); + for (int i = 0; i < 1000; i++) + { + var jit = new Amd64OptimizingJit(ctx); + jit.Compile(idx, variables, MultibitSiMBA.JitPage.Value, false); + } + + var sw = Stopwatch.StartNew(); + int limit = 10000; + + for (int i = 0; i < limit; i++) + { + var jit = new Amd64OptimizingJit(ctx); + jit.Compile(idx, variables, MultibitSiMBA.JitPage.Value, false); + } + + sw.Stop(); + Console.WriteLine($"Took {sw.ElapsedMilliseconds}ms to jit {limit} asts"); + } + } +} diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 7360518..a482a77 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -21,7 +21,9 @@ string inputText = null; //DatasetTester.Run(); - +JitBenchmark.Run(); +Console.WriteLine("Finished benchmarking"); +Debugger.Break(); Amd64AssemblerDifferentialTester.Test(); Console.WriteLine("Passed!"); From 5a9b84a65182676a3e83405671859aa3bc7c002b Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 08:44:26 -0500 Subject: [PATCH 14/21] Compact NodeInfo struct to fit in a qword --- Mba.Simplifier/Jit/Amd64OptimizingJit.cs | 28 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs index fe79ed6..7d54bde 100644 --- a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs +++ b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs @@ -43,12 +43,12 @@ public static Location Stack() public struct NodeInfo { - public uint numUses; + public byte numUses; // Allocate stack slot for the node if numUses > 1 - public uint slotIdx = uint.MaxValue; + public ushort slotIdx = ushort.MaxValue; - public NodeInfo(uint numInstances) + public NodeInfo(byte numInstances) { this.numUses = numInstances; } @@ -73,7 +73,7 @@ public class Amd64OptimizingJit private readonly Dictionary seen = new(); - private uint slotCount = 0; + private ushort slotCount = 0; Stack stack = new(16); @@ -123,6 +123,13 @@ public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, boo WriteInstructions(pagePtr, instructions); } + byte Inc(byte cl) + { + cl = (byte)((uint)cl + 1); + return cl == 0 ? (byte)255 : cl; + } + + private void CollectInfo(AstIdx idx) { if (seen.TryGetValue(idx, out var existing)) @@ -147,15 +154,15 @@ private void CollectInfo(AstIdx idx) var op1 = ctx.GetOp1(idx); CollectInfo(op1); - seen[op0] = new NodeInfo(seen[op0].numUses + 1); - seen[op1] = new NodeInfo(seen[op1].numUses + 1); + seen[op0] = new NodeInfo(Inc(seen[op0].numUses)); + seen[op1] = new NodeInfo(Inc(seen[op1].numUses)); break; case AstOp.Neg: case AstOp.Zext: case AstOp.Trunc: var single = ctx.GetOp0(idx); CollectInfo(single); - seen[single] = new NodeInfo(seen[single].numUses + 1); + seen[single] = new NodeInfo(Inc(seen[single].numUses)); break; default: break; @@ -520,8 +527,11 @@ private void AssignValueSlot(AstIdx idx, NodeInfo nodeInfo) { nodeInfo.slotIdx = slotCount; seen[idx] = nodeInfo; - // Bump slot count up - slotCount += 1; + // Bump slot count up. Throw if we hit the max slot limit + checked + { + slotCount += 1; + } } private static void EmitPrologue(IAmd64Assembler assembler, Register localsRegister, uint numStackSlots) From 9ea9f95a17fb9a68f44ad04fc6c63680a163726e Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 08:55:28 -0500 Subject: [PATCH 15/21] Bug fix comparison --- EqSat/src/simple_ast.rs | 17 +++++++++++++++++ Mba.Simplifier/Jit/Amd64OptimizingJit.cs | 2 +- Simplifier/Program.cs | 2 +- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 87f5bcb..9d9ff57 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -76,6 +76,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Add { a, b }, data); @@ -131,6 +132,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Mul { a, b }, data); @@ -157,6 +159,7 @@ impl Arena { has_poly: true, class: AstClass::Nonlinear, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Pow { a, b }, data); @@ -204,6 +207,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Neg { a }, data); } @@ -222,6 +226,7 @@ impl Arena { has_poly: has_poly, class: class, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Lshr { a, b }, data); } @@ -245,6 +250,7 @@ impl Arena { has_poly: has_poly, class: class, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Zext { a, to: width }, data); @@ -269,6 +275,7 @@ impl Arena { has_poly: has_poly, class: class, known_bits: known_bits, + imut_data: 0, }; return self.insert_ast_node(SimpleAst::Trunc { a, to: width }, data); @@ -281,6 +288,7 @@ impl Arena { has_poly: false, class: AstClass::Bitwise, known_bits: KnownBits::constant(c, width), + imut_data: 0, }; // Reduce the constant modulo 2**width @@ -296,6 +304,7 @@ impl Arena { has_poly: false, class: AstClass::Bitwise, known_bits: KnownBits::empty(width), + imut_data: 0, }; return self.insert_ast_node( @@ -322,6 +331,7 @@ impl Arena { has_poly: false, class: AstClass::Bitwise, known_bits: KnownBits::empty(width), + imut_data: 0, }; let symbol_ast_idx = self.insert_ast_node( @@ -417,6 +427,7 @@ impl Arena { has_poly: has_poly, class: max, known_bits: known_bits, + imut_data: 0, }; return data; @@ -490,7 +501,13 @@ pub struct AstData { // Classification of the ast class: AstClass, + // Known zero or one bits known_bits: KnownBits, + + // Internal mutable data for use in different algorithms. + // Specifically we use this field to avoid unnecessarily storing data in hashmaps. + // e.g "how many users does this node have?" can be stored here temporarily. + imut_data: u64, } #[derive(Clone, Hash, PartialEq, Eq)] diff --git a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs index 7d54bde..b8931ac 100644 --- a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs +++ b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs @@ -196,7 +196,7 @@ private unsafe void LowerToX86(List vars) var idx = dfs[i]; var nodeInfo = seen[idx]; // If we've seen this value, load it's value from a local variable slot - if (nodeInfo.numUses > 1 && nodeInfo.slotIdx != uint.MaxValue) + if (nodeInfo.numUses > 1 && nodeInfo.slotIdx != ushort.MaxValue) { LoadSlotValue(nodeInfo.slotIdx); continue; diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index a482a77..2b5f627 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -22,7 +22,7 @@ //DatasetTester.Run(); JitBenchmark.Run(); -Console.WriteLine("Finished benchmarking"); +Console.WriteLine("Finished benchmarking "); Debugger.Break(); Amd64AssemblerDifferentialTester.Test(); From f43860f7c605f9495a55e0dbecee10be7f5242a7 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 10:14:46 -0500 Subject: [PATCH 16/21] Make JIT 3x faster --- EqSat/src/simple_ast.rs | 22 ++++ Mba.Simplifier/Bindings/AstCtx.cs | 12 +- Mba.Simplifier/Jit/Amd64OptimizingJit.cs | 154 ++++++++++++++++++++--- Simplifier/JitBenchmark.cs | 1 + 4 files changed, 171 insertions(+), 18 deletions(-) diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 9d9ff57..72b949b 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -393,6 +393,10 @@ impl Arena { unsafe { self.elements.get_unchecked(idx.0 as usize).1 } } + pub fn set_data(&mut self, idx: AstIdx, data: AstData) { + unsafe { self.elements.get_unchecked_mut(idx.0 as usize).1 = data } + } + pub fn get_bin_width(&self, a: AstIdx, b: AstIdx) -> u8 { let a_width = self.get_width(a); let b_width = self.get_width(b); @@ -1235,6 +1239,24 @@ pub extern "C" fn ContextGetKnownBits(ctx: *mut Context, id: AstIdx) -> KnownBit } } +#[no_mangle] +pub extern "C" fn ContextGetImutData(ctx: *mut Context, id: AstIdx) -> u64 { + unsafe { + let kb = (*ctx).arena.get_data(id).imut_data; + + return kb; + } +} + +#[no_mangle] +pub extern "C" fn ContextSetImutData(ctx: *mut Context, id: AstIdx, imut: u64) { + unsafe { + let mut data = (*ctx).arena.get_data(id).clone(); + data.imut_data = imut; + (*ctx).arena.set_data(id, data); + } +} + #[no_mangle] pub extern "C" fn ContextGetOp0(ctx: *const Context, id: AstIdx) -> AstIdx { unsafe { diff --git a/Mba.Simplifier/Bindings/AstCtx.cs b/Mba.Simplifier/Bindings/AstCtx.cs index 8d9cc92..861c6ab 100644 --- a/Mba.Simplifier/Bindings/AstCtx.cs +++ b/Mba.Simplifier/Bindings/AstCtx.cs @@ -162,6 +162,8 @@ public AstIdx Xor(IEnumerable nodes) public unsafe bool GetHasPoly(AstIdx id) => Api.ContextGetHasPoly(this, id); public unsafe AstClassification GetClass(AstIdx id) => Api.ContextGetClass(this, id); public unsafe KnownBits GetKnownBits(AstIdx id) => Api.ContextGetKnownBits(this, id); + public unsafe ulong GetImutData(AstIdx id) => Api.ContextGetImutData(this, id); + public unsafe void SetImutData(AstIdx id, ulong imut) => Api.ContextSetImutData(this, id, imut); public unsafe AstIdx GetOp0(AstIdx id) => Api.ContextGetOp0(this, id); public unsafe AstIdx GetOp1(AstIdx id) { @@ -353,7 +355,7 @@ public static class Api public unsafe static extern uint ContextGetCost(OpaqueAstCtx* ctx, AstIdx id); [DllImport("eq_sat")] - [SuppressGCTransition] + [SuppressGCTransition] [return: MarshalAs(UnmanagedType.U1)] public unsafe static extern bool ContextGetHasPoly(OpaqueAstCtx* ctx, AstIdx id); @@ -365,6 +367,14 @@ public static class Api [SuppressGCTransition] public unsafe static extern KnownBits ContextGetKnownBits(OpaqueAstCtx* ctx, AstIdx id); + [DllImport("eq_sat")] + [SuppressGCTransition] + public unsafe static extern ulong ContextGetImutData(OpaqueAstCtx* ctx, AstIdx id); + + [DllImport("eq_sat")] + [SuppressGCTransition] + public unsafe static extern void ContextSetImutData(OpaqueAstCtx* ctx, AstIdx id, ulong data); + [DllImport("eq_sat")] [SuppressGCTransition] public unsafe static extern AstIdx ContextGetOp0(OpaqueAstCtx* ctx, AstIdx id); diff --git a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs index b8931ac..879d110 100644 --- a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs +++ b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs @@ -41,22 +41,126 @@ public static Location Stack() public bool IsRegister => Register != Register.None; } - public struct NodeInfo + /* + [StructLayout(LayoutKind.Explicit)] + struct NodeInfo { - public byte numUses; + [FieldOffset(0)] + public ushort numUses; + [FieldOffset(2)] + public ushort varIdx; + [FieldOffset(4)] + public ushort slotIdx = ushort.MaxValue; + [FieldOffset(6)] + public ushort exists = 1; + } + */ - // Allocate stack slot for the node if numUses > 1 + [StructLayout(LayoutKind.Explicit)] + public struct NodeInfo + { + [FieldOffset(0)] + public ushort numUses; + [FieldOffset(2)] + public ushort varIdx; + [FieldOffset(4)] public ushort slotIdx = ushort.MaxValue; + [FieldOffset(6)] + public ushort exists = 1; - public NodeInfo(byte numInstances) + public NodeInfo(ushort numInstances) { this.numUses = numInstances; } + public unsafe ulong ToUlong() + { + fixed(NodeInfo* ptr = &this) + { + return *((ulong*)ptr); + } + } + + public unsafe static NodeInfo FromUlong(ulong x) + { + NodeInfo r; + var ptr = (ulong*)&r; + *ptr = x; + return r; + } + public override string ToString() { return $"numInstances:{numUses}, slotIdx: {slotIdx}"; } + + } + + public interface IInfoStorage + { + public bool Contains(AstIdx idx); + + public NodeInfo Get(AstIdx idx); + + public void Set(AstIdx idx, NodeInfo info); + + public bool TryGet(AstIdx idx, out NodeInfo info); + } + + public class MapInfoStorage : IInfoStorage + { + private readonly Dictionary map = new(); + + public bool Contains(AstIdx idx) + { + return map.ContainsKey(idx); + } + + public NodeInfo Get(AstIdx idx) + { + return map[idx]; + } + + public void Set(AstIdx idx, NodeInfo info) + { + map[idx] = info; + } + + public bool TryGet(AstIdx idx, out NodeInfo info) + { + return map.TryGetValue(idx, out info); + } + } + + public class AuxInfoStorage : IInfoStorage + { + private readonly AstCtx ctx; + + public AuxInfoStorage(AstCtx ctx) + { + this.ctx = ctx; + } + + public bool Contains(AstIdx idx) + { + return Get(idx).exists != 0; + } + + public NodeInfo Get(AstIdx idx) + { + return NodeInfo.FromUlong(ctx.GetImutData(idx)); + } + + public void Set(AstIdx idx, NodeInfo info) + { + ctx.SetImutData(idx, info.ToUlong()); + } + + public bool TryGet(AstIdx idx, out NodeInfo info) + { + info = Get(idx); + return info.exists != 0; + } } // This class implements a JIT compiler to x86 with register allocation and node reuse. @@ -71,7 +175,9 @@ public class Amd64OptimizingJit private readonly List dfs = new(16); - private readonly Dictionary seen = new(); + //private readonly Dictionary seen = new(); + //MapInfoStorage seen = new(); + IInfoStorage seen; private ushort slotCount = 0; @@ -92,6 +198,8 @@ public class Amd64OptimizingJit public Amd64OptimizingJit(AstCtx ctx) { this.ctx = ctx; + seen = new AuxInfoStorage(ctx); + //seen = new MapInfoStorage(); } public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, bool useIcedBackend = false) @@ -102,9 +210,19 @@ public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, boo // Collect information about the nodes necessary for JITing (dfs order, how many users a value has) CollectInfo(idx); + for(int i = 0; i < variables.Count; i++) + { + var vIdx = variables[i]; + var data = seen.Get(vIdx); + data.varIdx = (byte)i; + } + // Compile the instructions to x86. LowerToX86(variables); + foreach (var id in dfs) + ctx.SetImutData(id, 0); + // If using the fast assembler backend, we've already emitted x86. // However the stack pointer adjustment needs to fixed up, because it wasn't known during prologue emission. if (!useIcedBackend) @@ -123,18 +241,18 @@ public unsafe void Compile(AstIdx idx, List variables, nint pagePtr, boo WriteInstructions(pagePtr, instructions); } - byte Inc(byte cl) + ushort Inc(ushort cl) { - cl = (byte)((uint)cl + 1); - return cl == 0 ? (byte)255 : cl; + cl += 1; + return cl == 0 ? ushort.MaxValue : cl; } private void CollectInfo(AstIdx idx) { - if (seen.TryGetValue(idx, out var existing)) + if (seen.TryGet(idx, out var existing)) { - seen[idx] = new NodeInfo(existing.numUses); + seen.Set(idx, new NodeInfo(existing.numUses)); dfs.Add(idx); return; } @@ -154,22 +272,23 @@ private void CollectInfo(AstIdx idx) var op1 = ctx.GetOp1(idx); CollectInfo(op1); - seen[op0] = new NodeInfo(Inc(seen[op0].numUses)); - seen[op1] = new NodeInfo(Inc(seen[op1].numUses)); + seen.Set(op0, new NodeInfo(Inc(seen.Get(op0).numUses))); + seen.Set(op1, new NodeInfo(Inc(seen.Get(op1).numUses))); break; case AstOp.Neg: case AstOp.Zext: case AstOp.Trunc: var single = ctx.GetOp0(idx); CollectInfo(single); - seen[single] = new NodeInfo(Inc(seen[single].numUses)); + seen.Set(single, new NodeInfo(Inc(seen.Get(single).numUses))); break; + case AstOp.Constant: default: break; } dfs.Add(idx); - seen[idx] = new NodeInfo(0); + seen.Set(idx, new NodeInfo(0)); } // Compile the provided DAG to x86 @@ -194,7 +313,7 @@ private unsafe void LowerToX86(List vars) for(int i = 0; i < dfs.Count; i++) { var idx = dfs[i]; - var nodeInfo = seen[idx]; + var nodeInfo = seen.Get(idx); // If we've seen this value, load it's value from a local variable slot if (nodeInfo.numUses > 1 && nodeInfo.slotIdx != ushort.MaxValue) { @@ -407,7 +526,8 @@ private void LowerConstant(AstIdx idx) private void LowerVariable(AstIdx idx, uint width, IReadOnlyList vars) { - uint offset = (uint)vars.IndexOf(idx); + //uint offset = (uint)vars.IndexOf(idx); + uint offset = seen.Get(idx).varIdx; if (freeRegisters.Count != 0) { var dest = freeRegisters.Pop(); @@ -526,7 +646,7 @@ private void ReduceLocationModulo(Location loc, uint width) private void AssignValueSlot(AstIdx idx, NodeInfo nodeInfo) { nodeInfo.slotIdx = slotCount; - seen[idx] = nodeInfo; + seen.Set(idx, nodeInfo); // Bump slot count up. Throw if we hit the max slot limit checked { diff --git a/Simplifier/JitBenchmark.cs b/Simplifier/JitBenchmark.cs index 5f97c41..3d9c75a 100644 --- a/Simplifier/JitBenchmark.cs +++ b/Simplifier/JitBenchmark.cs @@ -35,6 +35,7 @@ private JitBenchmark() private void Benchmark() { + AstIdx.ctx = ctx; var variables = ctx.CollectVariables(idx); for (int i = 0; i < 1000; i++) { From 4cf240246804ae085be05bf157c0f4062fb2fcce Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 10:26:50 -0500 Subject: [PATCH 17/21] Use optimized memcpy impl --- Mba.Simplifier/Jit/FastAmd64Assembler.cs | 31 +++++++++--------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/Mba.Simplifier/Jit/FastAmd64Assembler.cs b/Mba.Simplifier/Jit/FastAmd64Assembler.cs index f99adcb..7e44cd5 100644 --- a/Mba.Simplifier/Jit/FastAmd64Assembler.cs +++ b/Mba.Simplifier/Jit/FastAmd64Assembler.cs @@ -59,37 +59,28 @@ public FastAmd64Assembler(byte* ptr) this.ptr = ptr; } - private unsafe void EmitBytes(Span src) + private unsafe void EmitBytes(params byte[] bytes) { - for (int i = 0; i < src.Length; i++) + fixed(byte* p = &bytes[0]) { - //start[offset + i] = src[i]; - *ptr++ = src[i]; + Memcpy(ptr, p, (uint)bytes.Length); } - //offset += length; + ptr += bytes.Length; + //EmitBytes(bytes.AsSpan()); } - private unsafe void EmitBytes(Span src, int len) + private void EmitBuffer(StackBuffer buffer) { - for (int i = 0; i < len; i++) - { - //start[offset + i] = src[i]; - *ptr++ = src[i]; - } - - //offset += length; + Memcpy(ptr, buffer.Ptr, buffer.Offset); + ptr += buffer.Offset; + //EmitBytes(new Span(buffer.Ptr, (int)buffer.Offset)); } - private unsafe void EmitBytes(params byte[] bytes) + private void Memcpy(void* destination, void* source, uint byteCount) { - EmitBytes(bytes.AsSpan()); - //offset += length; - } + Unsafe.CopyBlockUnaligned(destination, source, byteCount); - private void EmitBuffer(StackBuffer buffer) - { - EmitBytes(new Span(buffer.Ptr, (int)buffer.Offset)); } public void PushReg(Register reg) From 992239c77281597f79541f555ed7da181673204f Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 11:01:47 -0500 Subject: [PATCH 18/21] add new x86 assembler --- EqSat/Cargo.lock | 18 +- EqSat/Cargo.toml | 4 + EqSat/src/assembler.rs | 859 ++++++++++++++++++++++++++++++++++++++++ EqSat/src/main.rs | 1 + EqSat/src/simple_ast.rs | 1 + 5 files changed, 882 insertions(+), 1 deletion(-) create mode 100644 EqSat/src/assembler.rs diff --git a/EqSat/Cargo.lock b/EqSat/Cargo.lock index 93fb5af..73b2441 100644 --- a/EqSat/Cargo.lock +++ b/EqSat/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "ahash" @@ -40,6 +40,7 @@ dependencies = [ "ahash", "cranelift-isle", "foldhash", + "iced-x86", "libc", "mimalloc", "rand", @@ -62,6 +63,21 @@ dependencies = [ "wasi", ] +[[package]] +name = "iced-x86" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c447cff8c7f384a7d4f741cfcff32f75f3ad02b406432e8d6c878d56b1edf6b" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.153" diff --git a/EqSat/Cargo.toml b/EqSat/Cargo.toml index e1c50d3..aac7f92 100644 --- a/EqSat/Cargo.toml +++ b/EqSat/Cargo.toml @@ -20,6 +20,10 @@ mimalloc = { version = "*", default-features = false } # egraph = { path = "./egraph" } foldhash = "=0.1.0" +[dependencies.iced-x86] +version = "1.21.0" +features = ["code_asm"] + [profile.release] debug = true debuginfo-level = 2 diff --git a/EqSat/src/assembler.rs b/EqSat/src/assembler.rs new file mode 100644 index 0000000..4ce0896 --- /dev/null +++ b/EqSat/src/assembler.rs @@ -0,0 +1,859 @@ +use iced_x86::code_asm::*; +use iced_x86::{Instruction, Register}; +use rand::Rng; +use std::fmt::Write; +use std::time::Instant; + +// Wrapper around a stack-allocated byte buffer for building instruction byte sequences +pub struct StackBuffer<'a> { + pub arr: &'a mut [u8], + pub offset: usize, +} + +impl StackBuffer<'_> { + fn push_u8(&mut self, byte: u8) { + unsafe { + *self.arr.get_unchecked_mut(self.offset) = byte; + } + + self.offset += 1; + } + + fn push_i32(&mut self, byte: i32) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut i32; + *ptr = byte; + } + + self.offset += 4; + } + + fn push_u32(&mut self, byte: u32) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut u32; + *ptr = byte; + } + + self.offset += 4; + } + + fn push_u64(&mut self, byte: u64) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut u64; + *ptr = byte; + } + + self.offset += 8; + } +} + +pub struct FastAmd64Assembler { + pub p: *mut u8, + pub offset: usize, +} + +pub trait IAmd64Assembler { + fn push_reg(&mut self, reg: Register); + + fn push_mem64(&mut self, base_reg: Register, offset: i32); + + fn pop_reg(&mut self, reg: Register); + + fn mov_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn mov_reg_mem64(&mut self, dst_reg: Register, base_reg: Register, offset: i32); + + fn mov_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register); + + fn movabs_reg_imm64(&mut self, reg: Register, imm: u64); + + fn add_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn add_reg_imm32(&mut self, reg: Register, imm32: u32); + + fn sub_reg_imm32(&mut self, reg: Register, imm32: u32); + + fn imul_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn and_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn and_reg_imm32(&mut self, reg: Register, imm: u32); + + fn and_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register); + + fn or_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn xor_reg_reg(&mut self, reg1: Register, reg2: Register); + + fn not_reg(&mut self, reg: Register); + + fn shl_reg_cl(&mut self, reg: Register); + + fn shr_reg_cl(&mut self, reg: Register); + + fn shr_reg_imm8(&mut self, reg: Register, imm8: u8); + + fn call_reg(&mut self, reg: Register); + + fn ret(&mut self); + + fn get_instructions(&mut self) -> Vec; + + fn get_bytes(&mut self) -> Vec; + + fn reset(&mut self); +} + +/// x86-64 assembler implementation using the iced-x86 library +pub struct IcedAmd64Assembler { + assembler: CodeAssembler, +} + +impl IcedAmd64Assembler { + /// Creates a new IcedAmd64Assembler with 64-bit mode + pub fn new() -> Result { + Ok(Self { + assembler: CodeAssembler::new(64)?, + }) + } + + /// Converts our Register enum to iced-x86's AsmRegister64 + fn conv(reg: Register) -> AsmRegister64 { + match reg { + Register::RAX => rax, + Register::RCX => rcx, + Register::RDX => rdx, + Register::RBX => rbx, + Register::RSP => rsp, + Register::RBP => rbp, + Register::RSI => rsi, + Register::RDI => rdi, + Register::R8 => r8, + Register::R9 => r9, + Register::R10 => r10, + Register::R11 => r11, + Register::R12 => r12, + Register::R13 => r13, + Register::R14 => r14, + Register::R15 => r15, + _ => panic!("Unsupported register"), + } + } +} + +impl IAmd64Assembler for IcedAmd64Assembler { + fn push_reg(&mut self, reg: Register) { + self.assembler.push(Self::conv(reg)).unwrap(); + } + + fn push_mem64(&mut self, base_reg: Register, offset: i32) { + self.assembler + .push(qword_ptr(Self::conv(base_reg) + offset)) + .unwrap(); + } + + fn pop_reg(&mut self, reg: Register) { + self.assembler.pop(Self::conv(reg)).unwrap(); + } + + fn mov_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .mov(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn mov_reg_mem64(&mut self, dst_reg: Register, base_reg: Register, offset: i32) { + self.assembler + .mov( + Self::conv(dst_reg), + qword_ptr(Self::conv(base_reg) + offset), + ) + .unwrap(); + } + + fn mov_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + self.assembler + .mov( + qword_ptr(Self::conv(base_reg) + offset), + Self::conv(src_reg), + ) + .unwrap(); + } + + fn movabs_reg_imm64(&mut self, reg: Register, imm: u64) { + self.assembler.mov(Self::conv(reg), imm).unwrap(); + } + + fn add_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .add(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn add_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.assembler.add(Self::conv(reg), imm32 as i32).unwrap(); + } + + fn sub_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.assembler.sub(Self::conv(reg), imm32 as i32).unwrap(); + } + + fn imul_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .imul_2(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn and_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .and(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn and_reg_imm32(&mut self, reg: Register, imm: u32) { + self.assembler.and(Self::conv(reg), imm as i32).unwrap(); + } + + fn and_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + self.assembler + .and( + qword_ptr(Self::conv(base_reg) + offset), + Self::conv(src_reg), + ) + .unwrap(); + } + + fn or_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .or(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn xor_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.assembler + .xor(Self::conv(reg1), Self::conv(reg2)) + .unwrap(); + } + + fn not_reg(&mut self, reg: Register) { + self.assembler.not(Self::conv(reg)).unwrap(); + } + + fn shl_reg_cl(&mut self, reg: Register) { + self.assembler.shl(Self::conv(reg), cl).unwrap(); + } + + fn shr_reg_cl(&mut self, reg: Register) { + self.assembler.shr(Self::conv(reg), cl).unwrap(); + } + + fn shr_reg_imm8(&mut self, reg: Register, imm8: u8) { + self.assembler.shr(Self::conv(reg), imm8 as u32).unwrap(); + } + + fn call_reg(&mut self, reg: Register) { + self.assembler.call(Self::conv(reg)).unwrap(); + } + + fn ret(&mut self) { + self.assembler.ret().unwrap(); + } + + fn get_instructions(&mut self) -> Vec { + self.assembler.instructions().to_vec() + } + + fn get_bytes(&mut self) -> Vec { + self.assembler.assemble(0).unwrap() + } + + fn reset(&mut self) { + self.assembler.reset(); + } +} + +impl FastAmd64Assembler { + fn emit_bytes(&mut self, data: &[u8]) { + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), self.p.add(self.offset), data.len()); + } + + self.offset += data.len(); + } + + fn emit_buffer(&mut self, buffer: &StackBuffer) { + unsafe { + std::ptr::copy_nonoverlapping( + buffer.arr.as_ptr(), + self.p.add(self.offset), + buffer.offset, + ); + } + + self.offset += buffer.offset; + } + + pub fn opcode_reg_reg(&mut self, opcode: u8, reg1: Register, reg2: Register) { + let mut rex = 0x48; + if is_extended(reg1) { + rex |= 0x01; + } + if is_extended(reg2) { + rex |= 0x04; + } + + let modrm = 0xC0 + | ((get_register_code(reg2) as u8 & 0x07) << 3) + | (get_register_code(reg1) as u8 & 0x07); + self.emit_bytes(&[rex, opcode, modrm]); + } + + pub fn opc_reg_imm(&mut self, mask: u8, reg: Register, imm32: u32) { + let p = &mut [0u8; 7]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0x81; + let modrm = 0xC0 | (mask << 3) | (self.get_register_code(reg) & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + arr.push_u32(imm32); + + self.emit_buffer(&arr); + } + + pub fn shift_reg_cl(&mut self, shl: bool, reg: Register) { + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0xD3; + let m1 = if shl { 0x04 } else { 0x05 }; + let modrm = 0xC0 | (m1 << 3) | (self.get_register_code(reg) & 0x07); + + self.emit_bytes(&[rex, opcode, modrm]); + } + + pub fn is_extended(&mut self, reg: Register) -> bool { + return reg >= Register::R8 && reg <= Register::R15; + } + + pub fn get_register_code(&mut self, reg: Register) -> u8 { + return (reg as u8) - (Register::RAX as u8); + } +} + +impl IAmd64Assembler for FastAmd64Assembler { + fn push_reg(&mut self, reg: Register) { + if reg >= Register::RAX && reg <= Register::RDI { + let opcode = (0x50 + get_register_code(reg)) as u8; + self.emit_bytes(&[opcode]); + return; + } + + let rex = 0x41; + let opcode = (0x50 + reg as u8 - Register::R8 as u8); + self.emit_bytes(&[rex, opcode]); + } + + fn push_mem64(&mut self, baseReg: Register, disp: i32) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + if is_extended(baseReg) { + let rex = 0x49; + arr.push_u8(rex); + } + + let opcode: u8 = 0xFF; + let modrm = (0x80 | (0x06 << 3) | (get_register_code(baseReg) & 0x07)) as u8; + arr.push_u8(opcode); + arr.push_u8(modrm); + + if baseReg == Register::RSP || baseReg == Register::R12 { + let sib: u8 = (0x00 | (0x04 << 3) | (get_register_code(baseReg) & 0x07)) as u8; + arr.push_u8(sib); + } + + arr.push_i32(disp); + self.emit_buffer(&arr); + } + + fn pop_reg(&mut self, reg: Register) { + if reg >= Register::RAX && reg <= Register::RDI { + let opcode = 0x58 + get_register_code(reg) as u8; + self.emit_bytes(&[opcode]); + return; + } + + let rex = 0x41; + let opcode = 0x58 + get_register_code(reg) as u8 - 8; + self.emit_bytes(&[rex, opcode]); + return; + } + + fn mov_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x89, reg1, reg2); + } + + fn mov_reg_mem64(&mut self, dst_reg: Register, base_reg: Register, offset: i32) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if is_extended(dst_reg) { + rex |= 0x04; + } + if is_extended(base_reg) { + rex |= 0x01; + } + + let opcode = 0x8B; + let modrm = 0x80 + | ((get_register_code(dst_reg) as u8 & 0x07) << 3) + | (get_register_code(base_reg) as u8 & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib = 0x00 | (0x04 << 3) | (get_register_code(base_reg) as u8 & 0x07); + arr.push_u8(sib); + } + + arr.push_i32(offset); + + self.emit_buffer(&arr); + } + + fn mov_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(src_reg) { + rex |= 0x04; + } + if self.is_extended(base_reg) { + rex |= 0x01; + } + + let opcode = 0x89; + let modrm = 0x80 + | ((self.get_register_code(src_reg) & 0x07) << 3) + | (self.get_register_code(base_reg) & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib = 0x00 | (0x04 << 3) | (self.get_register_code(base_reg) & 0x07); + arr.push_u8(sib); + } + + arr.push_i32(offset); + + self.emit_buffer(&arr); + } + + fn movabs_reg_imm64(&mut self, reg: Register, imm64: u64) { + let p = &mut [0u8; 10]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let cond = reg >= Register::RAX && reg <= Register::RDI; + let opcode = 0xB8 + + if cond { + self.get_register_code(reg) + } else { + self.get_register_code(reg) - 8 + }; + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u64(imm64); + + self.emit_buffer(&arr); + } + + fn add_reg_reg(&mut self, dest: Register, src: Register) { + self.opcode_reg_reg(0x01, dest, src); + } + + fn add_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.opc_reg_imm(0x00, reg, imm32); + } + + fn sub_reg_imm32(&mut self, reg: Register, imm32: u32) { + self.opc_reg_imm(0x05, reg, imm32); + } + + fn imul_reg_reg(&mut self, reg1: Register, reg2: Register) { + let mut rex = 0x48; + if self.is_extended(reg1) { + rex |= 0x04; + } + if self.is_extended(reg2) { + rex |= 0x01; + } + + let opcode1 = 0x0F; + let opcode2 = 0xAF; + let modrm = 0xC0 + | ((self.get_register_code(reg1) & 0x07) << 3) + | (self.get_register_code(reg2) & 0x07); + + self.emit_bytes(&[rex, opcode1, opcode2, modrm]); + } + + fn and_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x21, reg1, reg2); + } + + fn and_reg_imm32(&mut self, reg1: Register, imm32: u32) { + self.opc_reg_imm(0x04, reg1, imm32); + } + + fn and_mem64_reg(&mut self, base_reg: Register, offset: i32, src_reg: Register) { + let p = &mut [0u8; 8]; + let mut arr = StackBuffer { arr: p, offset: 0 }; + + let mut rex = 0x48; + if self.is_extended(src_reg) { + rex |= 0x04; + } + if self.is_extended(base_reg) { + rex |= 0x01; + } + + let opcode = 0x21; + let modrm = 0x80 + | ((self.get_register_code(src_reg) & 0x07) << 3) + | (self.get_register_code(base_reg) & 0x07); + + arr.push_u8(rex); + arr.push_u8(opcode); + arr.push_u8(modrm); + + if base_reg == Register::RSP || base_reg == Register::R12 { + let sib = 0x00 | (0x04 << 3) | (self.get_register_code(base_reg) & 0x07); + arr.push_u8(sib); + } + + arr.push_i32(offset); + + self.emit_buffer(&arr); + } + + fn or_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x09, reg1, reg2); + } + + fn xor_reg_reg(&mut self, reg1: Register, reg2: Register) { + self.opcode_reg_reg(0x31, reg1, reg2); + } + + fn not_reg(&mut self, reg: Register) { + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0xF7; + let modrm = 0xC0 | (0x02 << 3) | (self.get_register_code(reg) & 0x07); + + self.emit_bytes(&[rex, opcode, modrm]); + } + + fn shl_reg_cl(&mut self, reg: Register) { + self.shift_reg_cl(true, reg); + } + + fn shr_reg_cl(&mut self, reg: Register) { + self.shift_reg_cl(false, reg); + } + + fn shr_reg_imm8(&mut self, reg: Register, imm8: u8) { + let mut rex = 0x48; + if self.is_extended(reg) { + rex |= 0x01; + } + + let opcode = 0xC1; + let modrm = 0xC0 | (0x05 << 3) | (self.get_register_code(reg) & 0x07); + + self.emit_bytes(&[rex, opcode, modrm, imm8]); + } + + fn call_reg(&mut self, reg: Register) { + let mut rex = 0x00; + if self.is_extended(reg) { + rex = 0x41; + } + + let opcode = 0xFF; + let modrm = 0xC0 | (0x02 << 3) | (self.get_register_code(reg) & 0x07); + + if rex != 0 { + self.emit_bytes(&[rex, opcode, modrm]); + } else { + self.emit_bytes(&[opcode, modrm]); + } + } + + fn ret(&mut self) { + self.emit_bytes(&[0xC3]); + } + + fn get_instructions(&mut self) -> Vec { + let bytes = self.get_bytes(); + let mut decoder = iced_x86::Decoder::new(64, &bytes, iced_x86::DecoderOptions::NONE); + decoder.set_ip(0); + + let mut instructions = Vec::new(); + while decoder.position() < bytes.len() { + let instruction = decoder.decode(); + instructions.push(instruction); + } + + instructions + } + + fn get_bytes(&mut self) -> Vec { + let mut bytes = Vec::new(); + for i in 0..self.offset { + unsafe { + bytes.push(*self.p.add(i)); + } + } + return bytes; + } + + fn reset(&mut self) { + self.offset = 0; + } +} + +fn is_extended(reg: Register) -> bool { + return reg >= Register::R8 && reg <= Register::R15; +} + +fn get_register_code(reg: Register) -> u8 { + return (reg as u8) - (Register::RAX as u8); +} + +/// Differential tester that compares FastAmd64Assembler against IcedAmd64Assembler +pub struct Amd64AssemblerDifferentialTester { + rand: rand::rngs::ThreadRng, + registers: Vec, + iced_assembler: IcedAmd64Assembler, + fast_assembler: FastAmd64Assembler, +} + +impl Amd64AssemblerDifferentialTester { + /// Creates a new differential tester with the given buffer + pub unsafe fn new(buffer: *mut u8) -> Result { + let registers = vec![ + Register::RAX, + Register::RCX, + Register::RDX, + Register::RBX, + Register::RSI, + Register::RDI, + Register::RSP, + Register::RBP, + Register::R8, + Register::R9, + Register::R10, + Register::R11, + Register::R12, + Register::R13, + Register::R14, + Register::R15, + ]; + + Ok(Self { + rand: rand::thread_rng(), + registers, + iced_assembler: IcedAmd64Assembler::new()?, + fast_assembler: FastAmd64Assembler { + p: buffer, + offset: 0, + }, + }) + } + + /// Test entry point - allocates buffer and runs tests + pub fn test() -> Result<(), Box> { + let mut buffer = vec![0u8; 64 * 4096]; + let ptr = buffer.as_mut_ptr(); + + unsafe { + let mut tester = Self::new(ptr)?; + tester.run()?; + } + + Ok(()) + } + + /// Runs all differential tests + pub fn run(&mut self) -> Result<(), Box> { + for i in 0..self.registers.len() { + let reg1 = self.registers[i]; + self.diff_reg_insts(reg1)?; + + for j in (i + 1)..self.registers.len() { + let reg2 = self.registers[j]; + self.diff_reg_reg_insts(reg1, reg2)?; + } + } + + println!("All differential tests passed!"); + Ok(()) + } + + /// Tests single-register instructions + fn diff_reg_insts(&mut self, reg: Register) -> Result<(), Box> { + self.diff("PushReg", |asm| asm.push_reg(reg))?; + self.diff("PopReg", |asm| asm.pop_reg(reg))?; + self.diff("NotReg", |asm| asm.not_reg(reg))?; + self.diff("ShlRegCl", |asm| asm.shl_reg_cl(reg))?; + self.diff("ShrRegCl", |asm| asm.shr_reg_cl(reg))?; + self.diff("CallReg", |asm| asm.call_reg(reg))?; + + // Test reg, constant instructions + for _ in 0..100 { + let c = self.rand.gen::() as u64; + + self.diff("MovabsRegImm64", |asm| asm.movabs_reg_imm64(reg, c))?; + self.diff("AddRegImm32", |asm| asm.add_reg_imm32(reg, c as u32))?; + self.diff("SubRegImm32", |asm| asm.sub_reg_imm32(reg, c as u32))?; + self.diff("AndRegImm32", |asm| asm.and_reg_imm32(reg, c as u32))?; + self.diff("ShrRegImm8", |asm| asm.shr_reg_imm8(reg, c as u8))?; + + if reg != Register::RSP { + self.diff("PushMem64", |asm| asm.push_mem64(reg, c as i32))?; + } + } + + Ok(()) + } + + /// Tests two-register instructions + fn diff_reg_reg_insts( + &mut self, + reg1: Register, + reg2: Register, + ) -> Result<(), Box> { + // Test reg, reg instructions + self.diff("MovRegReg", |asm| asm.mov_reg_reg(reg1, reg2))?; + self.diff("MovRegReg", |asm| asm.mov_reg_reg(reg2, reg1))?; + self.diff("AddRegReg", |asm| asm.add_reg_reg(reg1, reg2))?; + self.diff("AddRegReg", |asm| asm.add_reg_reg(reg2, reg1))?; + self.diff("AndRegReg", |asm| asm.and_reg_reg(reg1, reg2))?; + self.diff("AndRegReg", |asm| asm.and_reg_reg(reg2, reg1))?; + self.diff("OrRegReg", |asm| asm.or_reg_reg(reg1, reg2))?; + self.diff("OrRegReg", |asm| asm.or_reg_reg(reg2, reg1))?; + self.diff("XorRegReg", |asm| asm.xor_reg_reg(reg1, reg2))?; + self.diff("XorRegReg", |asm| asm.xor_reg_reg(reg2, reg1))?; + self.diff("ImulRegReg", |asm| asm.imul_reg_reg(reg1, reg2))?; + self.diff("ImulRegReg", |asm| asm.imul_reg_reg(reg2, reg1))?; + + // Test reg, reg, constant instructions + for _ in 0..100 { + let c = self.rand.gen::(); + + self.diff("MovMem64Reg", |asm| asm.mov_mem64_reg(reg1, c, reg2))?; + self.diff("MovMem64Reg", |asm| asm.mov_mem64_reg(reg2, c, reg1))?; + self.diff("MovRegMem64", |asm| asm.mov_reg_mem64(reg1, reg2, c))?; + self.diff("MovRegMem64", |asm| asm.mov_reg_mem64(reg2, reg1, c))?; + self.diff("AndMem64Reg", |asm| asm.and_mem64_reg(reg1, c, reg2))?; + self.diff("AndMem64Reg", |asm| asm.and_mem64_reg(reg2, c, reg1))?; + } + + Ok(()) + } + + /// Executes a test function on both assemblers and compares results + fn diff(&mut self, method_name: &str, func: F) -> Result<(), Box> + where + F: Fn(&mut dyn IAmd64Assembler), + { + // Assemble the instruction using both assemblers + func(&mut self.iced_assembler); + func(&mut self.fast_assembler); + + // Compare the results + self.compare(method_name)?; + + // Reset both assemblers + self.iced_assembler.reset(); + self.fast_assembler.reset(); + + Ok(()) + } + + /// Compares the output of both assemblers + fn compare(&mut self, method_name: &str) -> Result<(), Box> { + let iced_insts = self.iced_assembler.get_instructions(); + let iced_bytes = self.iced_assembler.get_bytes(); + let our_insts = self.fast_assembler.get_instructions(); + let our_bytes = self.fast_assembler.get_bytes(); + + if iced_insts.is_empty() || iced_bytes.is_empty() || iced_insts.len() != our_insts.len() { + return Err(format!( + "Method {} failed: instruction count mismatch (iced: {}, ours: {})", + method_name, + iced_insts.len(), + our_insts.len() + ) + .into()); + } + + // Check if instructions are equivalent + if iced_insts.len() == 1 && our_insts.len() == 1 { + let iced_str = format!("{}", iced_insts[0]); + let our_str = format!("{}", our_insts[0]); + + if iced_str != our_str { + return Err(format!( + "Method {} failed: Instruction '{}' and '{}' not equivalent!\nIced bytes: {:?}\nOur bytes: {:?}", + method_name, + iced_str, + our_str, + iced_bytes, + our_bytes + ).into()); + } + } else { + // Compare all instructions + for (i, (iced_inst, our_inst)) in iced_insts.iter().zip(our_insts.iter()).enumerate() { + let iced_str = format!("{}", iced_inst); + let our_str = format!("{}", our_inst); + + if iced_str != our_str { + return Err(format!( + "Method {} failed at instruction {}: '{}' != '{}'", + method_name, i, iced_str, our_str + ) + .into()); + } + } + } + + Ok(()) + } +} diff --git a/EqSat/src/main.rs b/EqSat/src/main.rs index 357ebe0..f3b11e0 100644 --- a/EqSat/src/main.rs +++ b/EqSat/src/main.rs @@ -27,6 +27,7 @@ use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; +mod assembler; mod known_bits; mod mba; mod simple_ast; diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 72b949b..72ba99d 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -11,6 +11,7 @@ use ahash::AHashMap; use libc::{c_char, c_void}; use crate::{ + assembler::{self, *}, known_bits::{self, *}, mba::{self, Context as MbaContext}, truth_table_database::{TruthTable, TruthTableDatabase}, From 66b1a72436e177b16b9fbc16f5812c5514c8aef0 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 11:37:53 -0500 Subject: [PATCH 19/21] Implement wrapper for auxiliary data storage --- EqSat/src/simple_ast.rs | 104 +++++++++++++++++++++++ Mba.Simplifier/Jit/FastAmd64Assembler.cs | 44 ++++++++++ 2 files changed, 148 insertions(+) diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 72ba99d..02bbfcd 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -8,7 +8,9 @@ use std::{ }; use ahash::AHashMap; +use iced_x86::Register; use libc::{c_char, c_void}; +use std::marker::PhantomData; use crate::{ assembler::{self, *}, @@ -394,6 +396,10 @@ impl Arena { unsafe { self.elements.get_unchecked(idx.0 as usize).1 } } + pub fn get_data_mut(&mut self, idx: AstIdx) -> &mut AstData { + unsafe { &mut self.elements.get_unchecked_mut(idx.0 as usize).1 } + } + pub fn set_data(&mut self, idx: AstIdx, data: AstData) { unsafe { self.elements.get_unchecked_mut(idx.0 as usize).1 = data } } @@ -2513,3 +2519,101 @@ pub fn get_group_size_index(mask: u64) -> u32 { pub fn get_group_size(idx: u32) -> u32 { return 1 << idx; } + +enum LocKind { + Register, + Stack, +} + +struct Location { + pub register: Register, + pub is_stack: bool, +} + +impl Location { + pub fn is_register(&self) -> bool { + return self.register != Register::None; + } + + pub fn reg(r: Register) -> Location { + return Location { + register: r, + is_stack: false, + }; + } + + pub fn stack() -> Location { + return Location { + register: Register::None, + is_stack: true, + }; + } +} + +trait Exists { + fn is_null(&self) -> bool; +} + +// Assert that `NodeInfo` is 8 bytes in size +const _: () = [(); 1][(core::mem::size_of::() == 8) as usize ^ 1]; + +#[derive(Copy, Clone)] +struct NodeInfo { + pub num_uses: u16, + pub var_idx: u16, + pub slot_idx: u16, + pub exists: u16, +} + +impl From for NodeInfo { + fn from(value: u64) -> Self { + unsafe { + let ptr = (&value) as *const u64 as *const NodeInfo; + *ptr + } + } +} + +impl Into for NodeInfo { + fn into(self) -> u64 { + unsafe { + let ptr = (&self) as *const NodeInfo as *const u64; + *ptr + } + } +} + +impl Exists for NodeInfo { + fn is_null(&self) -> bool { + return self.exists != 0; + } +} + +struct AuxInfoStorage + Into + Exists> { + _marker: PhantomData, +} + +impl + Into + Exists> AuxInfoStorage { + pub fn contains(ctx: &mut Context, idx: AstIdx) -> bool { + let value = Self::get(ctx, idx); + return !value.is_null(); + } + + pub fn get(ctx: &mut Context, idx: AstIdx) -> T { + let value = ctx.arena.get_data(idx).imut_data; + return T::from(value); + } + + pub fn set(ctx: &mut Context, idx: AstIdx, value: T) { + ctx.arena.get_data_mut(idx).imut_data = value.into(); + } + + pub fn try_get(ctx: &mut Context, idx: AstIdx) -> Option { + let value = Self::get(ctx, idx); + if value.is_null() { + return None; + } + + return Some(value); + } +} diff --git a/Mba.Simplifier/Jit/FastAmd64Assembler.cs b/Mba.Simplifier/Jit/FastAmd64Assembler.cs index 7e44cd5..e0fdb79 100644 --- a/Mba.Simplifier/Jit/FastAmd64Assembler.cs +++ b/Mba.Simplifier/Jit/FastAmd64Assembler.cs @@ -6,6 +6,8 @@ using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics.Arm; +using System.Security.Cryptography; using System.Text; using System.Threading.Tasks; @@ -169,7 +171,49 @@ public void OpcodeRegReg(byte opcode, Register reg1, Register reg2) public void MovRegReg(Register reg1, Register reg2) => OpcodeRegReg(0x89, reg1, reg2); + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void MovRegMem64(Register dstReg, Register baseReg, int offset) + { + switch(baseReg) + { + case Register.RAX: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RCX: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RDX: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RBX: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RSP: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RBP: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RSI: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.RDI: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R8: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R9: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R10: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R11: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R12: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R13: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R14: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + case Register.R15: + MovRegMem64Template(dstReg, Register.RAX, offset); break; + default: + break; + } + } + + public void MovRegMem64Template(Register dstReg, Register baseReg, int offset) { byte* p = stackalloc byte[8]; var arr = new StackBuffer(ptr); From 463f37c5863e5bbbef8857d2ec02c030ac100f7d Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 12:21:29 -0500 Subject: [PATCH 20/21] Set default field values for NodeInfo --- EqSat/src/assembler.rs | 18 ++++++++++ EqSat/src/simple_ast.rs | 22 ++++++++++--- Mba.Simplifier/Jit/Amd64OptimizingJit.cs | 15 --------- Mba.Simplifier/Jit/FastAmd64Assembler.cs | 42 ------------------------ 4 files changed, 35 insertions(+), 62 deletions(-) diff --git a/EqSat/src/assembler.rs b/EqSat/src/assembler.rs index 4ce0896..6696721 100644 --- a/EqSat/src/assembler.rs +++ b/EqSat/src/assembler.rs @@ -11,6 +11,24 @@ pub struct StackBuffer<'a> { } impl StackBuffer<'_> { + fn push(&mut self, v: T) { + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut T; + *ptr = v; + } + + self.offset += std::mem::size_of::(); + } + + fn pop(&mut self) -> T { + self.offset -= std::mem::size_of::(); + + unsafe { + let ptr = self.arr.as_mut_ptr().add(self.offset) as *mut T; + *ptr + } + } + fn push_u8(&mut self, byte: u8) { unsafe { *self.arr.get_unchecked_mut(self.offset) = byte; diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 02bbfcd..6552902 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -1,10 +1,11 @@ type Unit = (); +use core::num; use std::{ collections::{hash_map::Entry, HashMap, HashSet}, f32::consts::PI, ffi::{CStr, CString}, - u64, vec, + u16, u64, vec, }; use ahash::AHashMap; @@ -2551,7 +2552,7 @@ impl Location { } trait Exists { - fn is_null(&self) -> bool; + fn exists(&self) -> bool; } // Assert that `NodeInfo` is 8 bytes in size @@ -2565,6 +2566,17 @@ struct NodeInfo { pub exists: u16, } +impl NodeInfo { + pub fn new(num_instances: u16) -> Self { + return NodeInfo { + num_uses: num_instances, + var_idx: 0, + slot_idx: u16::MAX, + exists: 1, + }; + } +} + impl From for NodeInfo { fn from(value: u64) -> Self { unsafe { @@ -2584,7 +2596,7 @@ impl Into for NodeInfo { } impl Exists for NodeInfo { - fn is_null(&self) -> bool { + fn exists(&self) -> bool { return self.exists != 0; } } @@ -2596,7 +2608,7 @@ struct AuxInfoStorage + Into + Exists> { impl + Into + Exists> AuxInfoStorage { pub fn contains(ctx: &mut Context, idx: AstIdx) -> bool { let value = Self::get(ctx, idx); - return !value.is_null(); + return !value.exists(); } pub fn get(ctx: &mut Context, idx: AstIdx) -> T { @@ -2610,7 +2622,7 @@ impl + Into + Exists> AuxInfoStorage { pub fn try_get(ctx: &mut Context, idx: AstIdx) -> Option { let value = Self::get(ctx, idx); - if value.is_null() { + if value.exists() { return None; } diff --git a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs index 879d110..c4724de 100644 --- a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs +++ b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs @@ -41,21 +41,6 @@ public static Location Stack() public bool IsRegister => Register != Register.None; } - /* - [StructLayout(LayoutKind.Explicit)] - struct NodeInfo - { - [FieldOffset(0)] - public ushort numUses; - [FieldOffset(2)] - public ushort varIdx; - [FieldOffset(4)] - public ushort slotIdx = ushort.MaxValue; - [FieldOffset(6)] - public ushort exists = 1; - } - */ - [StructLayout(LayoutKind.Explicit)] public struct NodeInfo { diff --git a/Mba.Simplifier/Jit/FastAmd64Assembler.cs b/Mba.Simplifier/Jit/FastAmd64Assembler.cs index e0fdb79..cb00966 100644 --- a/Mba.Simplifier/Jit/FastAmd64Assembler.cs +++ b/Mba.Simplifier/Jit/FastAmd64Assembler.cs @@ -171,49 +171,7 @@ public void OpcodeRegReg(byte opcode, Register reg1, Register reg2) public void MovRegReg(Register reg1, Register reg2) => OpcodeRegReg(0x89, reg1, reg2); - [MethodImpl(MethodImplOptions.AggressiveInlining)] public void MovRegMem64(Register dstReg, Register baseReg, int offset) - { - switch(baseReg) - { - case Register.RAX: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RCX: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RDX: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RBX: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RSP: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RBP: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RSI: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.RDI: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R8: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R9: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R10: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R11: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R12: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R13: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R14: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - case Register.R15: - MovRegMem64Template(dstReg, Register.RAX, offset); break; - default: - break; - } - } - - public void MovRegMem64Template(Register dstReg, Register baseReg, int offset) { byte* p = stackalloc byte[8]; var arr = new StackBuffer(ptr); From 0e5691965edb92fd992b635bad5207b537ce69a8 Mon Sep 17 00:00:00 2001 From: Colton1skees Date: Fri, 21 Nov 2025 12:56:18 -0500 Subject: [PATCH 21/21] Micro optimizations --- Mba.Simplifier/Jit/Amd64OptimizingJit.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs index c4724de..05274c8 100644 --- a/Mba.Simplifier/Jit/Amd64OptimizingJit.cs +++ b/Mba.Simplifier/Jit/Amd64OptimizingJit.cs @@ -60,7 +60,9 @@ public NodeInfo(ushort numInstances) public unsafe ulong ToUlong() { - fixed(NodeInfo* ptr = &this) + return Unsafe.As(ref this); + //return *((ulong*)&this); + fixed (NodeInfo* ptr = &this) { return *((ulong*)ptr); } @@ -68,6 +70,8 @@ public unsafe ulong ToUlong() public unsafe static NodeInfo FromUlong(ulong x) { + return *((NodeInfo*)&x); + NodeInfo r; var ptr = (ulong*)&r; *ptr = x;