Fix lazy expression reordering

This commit is contained in:
Rerumu 2023-11-03 09:38:08 -04:00
parent bbaa60e8c2
commit 6ca09ca5cd
2 changed files with 91 additions and 107 deletions

View File

@ -8,19 +8,9 @@ use crate::{
MemoryCopy, MemoryFill, MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, Statement,
StoreAt, StoreType, Terminator, UnOp, UnOpType, Value,
},
stack::{ReadType, Stack},
stack::{ReadGet, Stack},
};
macro_rules! leak_on {
($name:tt, $variant:tt) => {
fn $name(&mut self, id: usize) {
let read = ReadType::$variant(id);
self.stack.leak_into(&mut self.code, |v| v.has_read(read))
}
};
}
#[derive(Clone, Copy)]
enum BlockVariant {
Forward,
@ -73,14 +63,28 @@ impl StatList {
}
fn leak_pre_call(&mut self) {
self.stack.leak_into(&mut self.code, |v| {
v.has_global_read() || v.has_memory_read()
self.stack.leak_into(&mut self.code, |node| {
ReadGet::run(node, |_| false, |_| true, |_| true)
});
}
leak_on!(leak_local_write, Local);
leak_on!(leak_global_write, Global);
leak_on!(leak_memory_write, Memory);
fn leak_local_write(&mut self, id: usize) {
self.stack.leak_into(&mut self.code, |node| {
ReadGet::run(node, |var| var.var() == id, |_| false, |_| false)
});
}
fn leak_global_write(&mut self, id: usize) {
self.stack.leak_into(&mut self.code, |node| {
ReadGet::run(node, |_| false, |var| var.var() == id, |_| false)
});
}
fn leak_memory_write(&mut self, id: usize) {
self.stack.leak_into(&mut self.code, |node| {
ReadGet::run(node, |_| false, |_| false, |var| var.memory() == id)
});
}
fn push_load(&mut self, load_type: LoadType, memarg: MemArg) {
let memory = memarg.memory.try_into().unwrap();
@ -93,7 +97,7 @@ impl StatList {
pointer: self.stack.pop().into(),
});
self.stack.push_with_single(data);
self.stack.push(data);
}
fn add_store(&mut self, store_type: StoreType, memarg: MemArg) {
@ -119,43 +123,32 @@ impl StatList {
}
fn push_un_op(&mut self, op_type: UnOpType) {
let rhs = self.stack.pop_with_read();
let data = Expression::UnOp(UnOp {
op_type,
rhs: rhs.0.into(),
rhs: self.stack.pop().into(),
});
self.stack.push_with_read(data, rhs.1);
self.stack.push(data);
}
fn push_bin_op(&mut self, op_type: BinOpType) {
let mut rhs = self.stack.pop_with_read();
let lhs = self.stack.pop_with_read();
let data = Expression::BinOp(BinOp {
op_type,
rhs: rhs.0.into(),
lhs: lhs.0.into(),
rhs: self.stack.pop().into(),
lhs: self.stack.pop().into(),
});
rhs.1.extend(lhs.1);
self.stack.push_with_read(data, rhs.1);
self.stack.push(data);
}
fn push_cmp_op(&mut self, op_type: CmpOpType) {
let mut rhs = self.stack.pop_with_read();
let lhs = self.stack.pop_with_read();
let data = Expression::CmpOp(CmpOp {
op_type,
rhs: rhs.0.into(),
lhs: lhs.0.into(),
rhs: self.stack.pop().into(),
lhs: self.stack.pop().into(),
});
rhs.1.extend(lhs.1);
self.stack.push_with_read(data, rhs.1);
self.stack.push(data);
}
// Eqz is the only unary comparison so it's "emulated"
@ -293,7 +286,9 @@ impl<'a> Factory<'a> {
}
fn start_else(&mut self) {
let BlockData::If { ty, .. } = self.target.block_data else { unreachable!() };
let BlockData::If { ty, .. } = self.target.block_data else {
unreachable!()
};
self.target.leak_all();
self.end_block();
@ -314,7 +309,9 @@ impl<'a> Factory<'a> {
on_false: None,
}),
BlockData::Else { .. } => {
let Statement::If(last) = self.target.code.last_mut().unwrap() else { unreachable!() };
let Statement::If(last) = self.target.code.last_mut().unwrap() else {
unreachable!()
};
last.on_false = Some(Box::new(now.into()));
@ -504,26 +501,19 @@ impl<'a> Factory<'a> {
self.target.stack.pop();
}
Operator::Select => {
let mut condition = self.target.stack.pop_with_read();
let on_false = self.target.stack.pop_with_read();
let on_true = self.target.stack.pop_with_read();
let data = Expression::Select(Select {
condition: condition.0.into(),
on_true: on_true.0.into(),
on_false: on_false.0.into(),
condition: self.target.stack.pop().into(),
on_false: self.target.stack.pop().into(),
on_true: self.target.stack.pop().into(),
});
condition.1.extend(on_true.1);
condition.1.extend(on_false.1);
self.target.stack.push_with_read(data, condition.1);
self.target.stack.push(data);
}
Operator::LocalGet { local_index } => {
let var = local_index.try_into().unwrap();
let data = Expression::GetLocal(Local { var });
self.target.stack.push_with_single(data);
self.target.stack.push(data);
}
Operator::LocalSet { local_index } => {
let var = local_index.try_into().unwrap();
@ -544,14 +534,14 @@ impl<'a> Factory<'a> {
});
self.target.leak_local_write(var);
self.target.stack.push_with_single(get);
self.target.stack.push(get);
self.target.code.push(set);
}
Operator::GlobalGet { global_index } => {
let var = global_index.try_into().unwrap();
let data = Expression::GetGlobal(GetGlobal { var });
self.target.stack.push_with_single(data);
self.target.stack.push(data);
}
Operator::GlobalSet { global_index } => {
let var = global_index.try_into().unwrap();

View File

@ -1,42 +1,59 @@
use std::collections::HashSet;
use crate::node::{
use crate::{
node::{
Align, Expression, GetGlobal, LoadAt, Local, ResultList, SetTemporary, Statement, Temporary,
},
visit::{Driver, Visitor},
};
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReadType {
Local(usize),
Global(usize),
Memory(usize),
pub struct ReadGet<A, B, C> {
has_local: A,
has_global: B,
has_memory: C,
result: bool,
}
pub struct Slot {
read: HashSet<ReadType>,
data: Expression,
impl<A, B, C> ReadGet<A, B, C>
where
A: Fn(Local) -> bool,
B: Fn(GetGlobal) -> bool,
C: Fn(&LoadAt) -> bool,
{
pub fn run<D: Driver<Self>>(node: &D, has_local: A, has_global: B, has_memory: C) -> bool {
let mut visitor = Self {
has_local,
has_global,
has_memory,
result: false,
};
node.accept(&mut visitor);
visitor.result
}
}
impl Slot {
const fn is_temporary(&self, id: usize) -> bool {
matches!(self.data, Expression::GetTemporary(ref v) if v.var() == id)
impl<A, B, C> Visitor for ReadGet<A, B, C>
where
A: Fn(Local) -> bool,
B: Fn(GetGlobal) -> bool,
C: Fn(&LoadAt) -> bool,
{
fn visit_get_global(&mut self, get_global: GetGlobal) {
self.result |= (self.has_global)(get_global);
}
pub fn has_read(&self, id: ReadType) -> bool {
self.read.contains(&id)
fn visit_load_at(&mut self, load_at: &LoadAt) {
self.result |= (self.has_memory)(load_at);
}
pub fn has_global_read(&self) -> bool {
self.read.iter().any(|r| matches!(r, ReadType::Global(_)))
}
pub fn has_memory_read(&self) -> bool {
self.read.iter().any(|r| matches!(r, ReadType::Memory(_)))
fn visit_get_local(&mut self, local: Local) {
self.result |= (self.has_local)(local);
}
}
#[derive(Default)]
pub struct Stack {
var_list: Vec<Slot>,
var_list: Vec<Expression>,
pub capacity: usize,
pub previous: usize,
}
@ -57,41 +74,18 @@ impl Stack {
}
}
pub fn push_with_read(&mut self, data: Expression, read: HashSet<ReadType>) {
self.var_list.push(Slot { read, data });
}
pub fn push(&mut self, data: Expression) {
self.push_with_read(data, HashSet::new());
}
pub fn push_with_single(&mut self, data: Expression) {
let mut read = HashSet::new();
let elem = match data {
Expression::GetLocal(Local { var }) => ReadType::Local(var),
Expression::GetGlobal(GetGlobal { var }) => ReadType::Global(var),
Expression::LoadAt(LoadAt { memory, .. }) => ReadType::Memory(memory),
_ => unreachable!(),
};
read.insert(elem);
self.var_list.push(Slot { read, data });
}
pub fn pop_with_read(&mut self) -> (Expression, HashSet<ReadType>) {
let var = self.var_list.pop().unwrap();
(var.data, var.read)
self.var_list.push(data);
}
pub fn pop(&mut self) -> Expression {
self.pop_with_read().0
self.var_list.pop().unwrap()
}
pub fn pop_len(&'_ mut self, len: usize) -> impl Iterator<Item = Expression> + '_ {
let desired = self.len() - len;
self.var_list.drain(desired..).map(|v| v.data)
self.var_list.drain(desired..)
}
pub fn push_temporaries(&mut self, num: usize) -> ResultList {
@ -129,21 +123,21 @@ impl Stack {
// adjusting the capacity and old index accordingly
pub fn leak_into<P>(&mut self, code: &mut Vec<Statement>, predicate: P)
where
P: Fn(&Slot) -> bool,
P: Fn(&Expression) -> bool,
{
for (i, old) in self.var_list.iter_mut().enumerate() {
let var = self.previous + i;
let is_temporary =
matches!(old, Expression::GetTemporary(temporary) if temporary.var() == var);
if old.is_temporary(var) || !predicate(old) {
if is_temporary || !predicate(old) {
continue;
}
old.read.clear();
let get = Expression::GetTemporary(Temporary { var });
let set = Statement::SetTemporary(SetTemporary {
var: Temporary { var },
value: std::mem::replace(&mut old.data, get).into(),
value: std::mem::replace(old, get).into(),
});
self.capacity = self.capacity.max(var + 1);