Fix lazy expression reordering

This commit is contained in:
Rerumu 2023-11-03 09:38:08 -04:00
parent bbaa60e8c2
commit 6ca09ca5cd
2 changed files with 91 additions and 107 deletions

View File

@ -8,19 +8,9 @@ use crate::{
MemoryCopy, MemoryFill, MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, Statement, MemoryCopy, MemoryFill, MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, Statement,
StoreAt, StoreType, Terminator, UnOp, UnOpType, Value, StoreAt, StoreType, Terminator, UnOp, UnOpType, Value,
}, },
stack::{ReadType, Stack}, stack::{ReadGet, Stack},
}; };
macro_rules! leak_on {
($name:tt, $variant:tt) => {
fn $name(&mut self, id: usize) {
let read = ReadType::$variant(id);
self.stack.leak_into(&mut self.code, |v| v.has_read(read))
}
};
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
enum BlockVariant { enum BlockVariant {
Forward, Forward,
@ -73,14 +63,28 @@ impl StatList {
} }
fn leak_pre_call(&mut self) { fn leak_pre_call(&mut self) {
self.stack.leak_into(&mut self.code, |v| { self.stack.leak_into(&mut self.code, |node| {
v.has_global_read() || v.has_memory_read() ReadGet::run(node, |_| false, |_| true, |_| true)
}); });
} }
leak_on!(leak_local_write, Local); fn leak_local_write(&mut self, id: usize) {
leak_on!(leak_global_write, Global); self.stack.leak_into(&mut self.code, |node| {
leak_on!(leak_memory_write, Memory); ReadGet::run(node, |var| var.var() == id, |_| false, |_| false)
});
}
fn leak_global_write(&mut self, id: usize) {
self.stack.leak_into(&mut self.code, |node| {
ReadGet::run(node, |_| false, |var| var.var() == id, |_| false)
});
}
fn leak_memory_write(&mut self, id: usize) {
self.stack.leak_into(&mut self.code, |node| {
ReadGet::run(node, |_| false, |_| false, |var| var.memory() == id)
});
}
fn push_load(&mut self, load_type: LoadType, memarg: MemArg) { fn push_load(&mut self, load_type: LoadType, memarg: MemArg) {
let memory = memarg.memory.try_into().unwrap(); let memory = memarg.memory.try_into().unwrap();
@ -93,7 +97,7 @@ impl StatList {
pointer: self.stack.pop().into(), pointer: self.stack.pop().into(),
}); });
self.stack.push_with_single(data); self.stack.push(data);
} }
fn add_store(&mut self, store_type: StoreType, memarg: MemArg) { fn add_store(&mut self, store_type: StoreType, memarg: MemArg) {
@ -119,43 +123,32 @@ impl StatList {
} }
fn push_un_op(&mut self, op_type: UnOpType) { fn push_un_op(&mut self, op_type: UnOpType) {
let rhs = self.stack.pop_with_read();
let data = Expression::UnOp(UnOp { let data = Expression::UnOp(UnOp {
op_type, op_type,
rhs: rhs.0.into(), rhs: self.stack.pop().into(),
}); });
self.stack.push_with_read(data, rhs.1); self.stack.push(data);
} }
fn push_bin_op(&mut self, op_type: BinOpType) { fn push_bin_op(&mut self, op_type: BinOpType) {
let mut rhs = self.stack.pop_with_read();
let lhs = self.stack.pop_with_read();
let data = Expression::BinOp(BinOp { let data = Expression::BinOp(BinOp {
op_type, op_type,
rhs: rhs.0.into(), rhs: self.stack.pop().into(),
lhs: lhs.0.into(), lhs: self.stack.pop().into(),
}); });
rhs.1.extend(lhs.1); self.stack.push(data);
self.stack.push_with_read(data, rhs.1);
} }
fn push_cmp_op(&mut self, op_type: CmpOpType) { fn push_cmp_op(&mut self, op_type: CmpOpType) {
let mut rhs = self.stack.pop_with_read();
let lhs = self.stack.pop_with_read();
let data = Expression::CmpOp(CmpOp { let data = Expression::CmpOp(CmpOp {
op_type, op_type,
rhs: rhs.0.into(), rhs: self.stack.pop().into(),
lhs: lhs.0.into(), lhs: self.stack.pop().into(),
}); });
rhs.1.extend(lhs.1); self.stack.push(data);
self.stack.push_with_read(data, rhs.1);
} }
// Eqz is the only unary comparison so it's "emulated" // Eqz is the only unary comparison so it's "emulated"
@ -293,7 +286,9 @@ impl<'a> Factory<'a> {
} }
fn start_else(&mut self) { fn start_else(&mut self) {
let BlockData::If { ty, .. } = self.target.block_data else { unreachable!() }; let BlockData::If { ty, .. } = self.target.block_data else {
unreachable!()
};
self.target.leak_all(); self.target.leak_all();
self.end_block(); self.end_block();
@ -314,7 +309,9 @@ impl<'a> Factory<'a> {
on_false: None, on_false: None,
}), }),
BlockData::Else { .. } => { BlockData::Else { .. } => {
let Statement::If(last) = self.target.code.last_mut().unwrap() else { unreachable!() }; let Statement::If(last) = self.target.code.last_mut().unwrap() else {
unreachable!()
};
last.on_false = Some(Box::new(now.into())); last.on_false = Some(Box::new(now.into()));
@ -504,26 +501,19 @@ impl<'a> Factory<'a> {
self.target.stack.pop(); self.target.stack.pop();
} }
Operator::Select => { Operator::Select => {
let mut condition = self.target.stack.pop_with_read();
let on_false = self.target.stack.pop_with_read();
let on_true = self.target.stack.pop_with_read();
let data = Expression::Select(Select { let data = Expression::Select(Select {
condition: condition.0.into(), condition: self.target.stack.pop().into(),
on_true: on_true.0.into(), on_false: self.target.stack.pop().into(),
on_false: on_false.0.into(), on_true: self.target.stack.pop().into(),
}); });
condition.1.extend(on_true.1); self.target.stack.push(data);
condition.1.extend(on_false.1);
self.target.stack.push_with_read(data, condition.1);
} }
Operator::LocalGet { local_index } => { Operator::LocalGet { local_index } => {
let var = local_index.try_into().unwrap(); let var = local_index.try_into().unwrap();
let data = Expression::GetLocal(Local { var }); let data = Expression::GetLocal(Local { var });
self.target.stack.push_with_single(data); self.target.stack.push(data);
} }
Operator::LocalSet { local_index } => { Operator::LocalSet { local_index } => {
let var = local_index.try_into().unwrap(); let var = local_index.try_into().unwrap();
@ -544,14 +534,14 @@ impl<'a> Factory<'a> {
}); });
self.target.leak_local_write(var); self.target.leak_local_write(var);
self.target.stack.push_with_single(get); self.target.stack.push(get);
self.target.code.push(set); self.target.code.push(set);
} }
Operator::GlobalGet { global_index } => { Operator::GlobalGet { global_index } => {
let var = global_index.try_into().unwrap(); let var = global_index.try_into().unwrap();
let data = Expression::GetGlobal(GetGlobal { var }); let data = Expression::GetGlobal(GetGlobal { var });
self.target.stack.push_with_single(data); self.target.stack.push(data);
} }
Operator::GlobalSet { global_index } => { Operator::GlobalSet { global_index } => {
let var = global_index.try_into().unwrap(); let var = global_index.try_into().unwrap();

View File

@ -1,42 +1,59 @@
use std::collections::HashSet; use crate::{
node::{
use crate::node::{ Align, Expression, GetGlobal, LoadAt, Local, ResultList, SetTemporary, Statement, Temporary,
Align, Expression, GetGlobal, LoadAt, Local, ResultList, SetTemporary, Statement, Temporary, },
visit::{Driver, Visitor},
}; };
#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct ReadGet<A, B, C> {
pub enum ReadType { has_local: A,
Local(usize), has_global: B,
Global(usize), has_memory: C,
Memory(usize), result: bool,
} }
pub struct Slot { impl<A, B, C> ReadGet<A, B, C>
read: HashSet<ReadType>, where
data: Expression, A: Fn(Local) -> bool,
B: Fn(GetGlobal) -> bool,
C: Fn(&LoadAt) -> bool,
{
pub fn run<D: Driver<Self>>(node: &D, has_local: A, has_global: B, has_memory: C) -> bool {
let mut visitor = Self {
has_local,
has_global,
has_memory,
result: false,
};
node.accept(&mut visitor);
visitor.result
}
} }
impl Slot { impl<A, B, C> Visitor for ReadGet<A, B, C>
const fn is_temporary(&self, id: usize) -> bool { where
matches!(self.data, Expression::GetTemporary(ref v) if v.var() == id) A: Fn(Local) -> bool,
B: Fn(GetGlobal) -> bool,
C: Fn(&LoadAt) -> bool,
{
fn visit_get_global(&mut self, get_global: GetGlobal) {
self.result |= (self.has_global)(get_global);
} }
pub fn has_read(&self, id: ReadType) -> bool { fn visit_load_at(&mut self, load_at: &LoadAt) {
self.read.contains(&id) self.result |= (self.has_memory)(load_at);
} }
pub fn has_global_read(&self) -> bool { fn visit_get_local(&mut self, local: Local) {
self.read.iter().any(|r| matches!(r, ReadType::Global(_))) self.result |= (self.has_local)(local);
}
pub fn has_memory_read(&self) -> bool {
self.read.iter().any(|r| matches!(r, ReadType::Memory(_)))
} }
} }
#[derive(Default)] #[derive(Default)]
pub struct Stack { pub struct Stack {
var_list: Vec<Slot>, var_list: Vec<Expression>,
pub capacity: usize, pub capacity: usize,
pub previous: usize, pub previous: usize,
} }
@ -57,41 +74,18 @@ impl Stack {
} }
} }
pub fn push_with_read(&mut self, data: Expression, read: HashSet<ReadType>) {
self.var_list.push(Slot { read, data });
}
pub fn push(&mut self, data: Expression) { pub fn push(&mut self, data: Expression) {
self.push_with_read(data, HashSet::new()); self.var_list.push(data);
}
pub fn push_with_single(&mut self, data: Expression) {
let mut read = HashSet::new();
let elem = match data {
Expression::GetLocal(Local { var }) => ReadType::Local(var),
Expression::GetGlobal(GetGlobal { var }) => ReadType::Global(var),
Expression::LoadAt(LoadAt { memory, .. }) => ReadType::Memory(memory),
_ => unreachable!(),
};
read.insert(elem);
self.var_list.push(Slot { read, data });
}
pub fn pop_with_read(&mut self) -> (Expression, HashSet<ReadType>) {
let var = self.var_list.pop().unwrap();
(var.data, var.read)
} }
pub fn pop(&mut self) -> Expression { pub fn pop(&mut self) -> Expression {
self.pop_with_read().0 self.var_list.pop().unwrap()
} }
pub fn pop_len(&'_ mut self, len: usize) -> impl Iterator<Item = Expression> + '_ { pub fn pop_len(&'_ mut self, len: usize) -> impl Iterator<Item = Expression> + '_ {
let desired = self.len() - len; let desired = self.len() - len;
self.var_list.drain(desired..).map(|v| v.data) self.var_list.drain(desired..)
} }
pub fn push_temporaries(&mut self, num: usize) -> ResultList { pub fn push_temporaries(&mut self, num: usize) -> ResultList {
@ -129,21 +123,21 @@ impl Stack {
// adjusting the capacity and old index accordingly // adjusting the capacity and old index accordingly
pub fn leak_into<P>(&mut self, code: &mut Vec<Statement>, predicate: P) pub fn leak_into<P>(&mut self, code: &mut Vec<Statement>, predicate: P)
where where
P: Fn(&Slot) -> bool, P: Fn(&Expression) -> bool,
{ {
for (i, old) in self.var_list.iter_mut().enumerate() { for (i, old) in self.var_list.iter_mut().enumerate() {
let var = self.previous + i; let var = self.previous + i;
let is_temporary =
matches!(old, Expression::GetTemporary(temporary) if temporary.var() == var);
if old.is_temporary(var) || !predicate(old) { if is_temporary || !predicate(old) {
continue; continue;
} }
old.read.clear();
let get = Expression::GetTemporary(Temporary { var }); let get = Expression::GetTemporary(Temporary { var });
let set = Statement::SetTemporary(SetTemporary { let set = Statement::SetTemporary(SetTemporary {
var: Temporary { var }, var: Temporary { var },
value: std::mem::replace(&mut old.data, get).into(), value: std::mem::replace(old, get).into(),
}); });
self.capacity = self.capacity.max(var + 1); self.capacity = self.capacity.max(var + 1);