diff --git a/wasm-ast/src/factory.rs b/wasm-ast/src/factory.rs index 8469ec2..3a59eab 100644 --- a/wasm-ast/src/factory.rs +++ b/wasm-ast/src/factory.rs @@ -8,19 +8,9 @@ use crate::{ MemoryCopy, MemoryFill, MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, Statement, StoreAt, StoreType, Terminator, UnOp, UnOpType, Value, }, - stack::{ReadType, Stack}, + stack::{ReadGet, Stack}, }; -macro_rules! leak_on { - ($name:tt, $variant:tt) => { - fn $name(&mut self, id: usize) { - let read = ReadType::$variant(id); - - self.stack.leak_into(&mut self.code, |v| v.has_read(read)) - } - }; -} - #[derive(Clone, Copy)] enum BlockVariant { Forward, @@ -73,14 +63,28 @@ impl StatList { } fn leak_pre_call(&mut self) { - self.stack.leak_into(&mut self.code, |v| { - v.has_global_read() || v.has_memory_read() + self.stack.leak_into(&mut self.code, |node| { + ReadGet::run(node, |_| false, |_| true, |_| true) }); } - leak_on!(leak_local_write, Local); - leak_on!(leak_global_write, Global); - leak_on!(leak_memory_write, Memory); + fn leak_local_write(&mut self, id: usize) { + self.stack.leak_into(&mut self.code, |node| { + ReadGet::run(node, |var| var.var() == id, |_| false, |_| false) + }); + } + + fn leak_global_write(&mut self, id: usize) { + self.stack.leak_into(&mut self.code, |node| { + ReadGet::run(node, |_| false, |var| var.var() == id, |_| false) + }); + } + + fn leak_memory_write(&mut self, id: usize) { + self.stack.leak_into(&mut self.code, |node| { + ReadGet::run(node, |_| false, |_| false, |var| var.memory() == id) + }); + } fn push_load(&mut self, load_type: LoadType, memarg: MemArg) { let memory = memarg.memory.try_into().unwrap(); @@ -93,7 +97,7 @@ impl StatList { pointer: self.stack.pop().into(), }); - self.stack.push_with_single(data); + self.stack.push(data); } fn add_store(&mut self, store_type: StoreType, memarg: MemArg) { @@ -119,43 +123,32 @@ impl StatList { } fn push_un_op(&mut self, op_type: UnOpType) { - let rhs = self.stack.pop_with_read(); let data = Expression::UnOp(UnOp { op_type, - rhs: rhs.0.into(), + rhs: self.stack.pop().into(), }); - self.stack.push_with_read(data, rhs.1); + self.stack.push(data); } fn push_bin_op(&mut self, op_type: BinOpType) { - let mut rhs = self.stack.pop_with_read(); - let lhs = self.stack.pop_with_read(); - let data = Expression::BinOp(BinOp { op_type, - rhs: rhs.0.into(), - lhs: lhs.0.into(), + rhs: self.stack.pop().into(), + lhs: self.stack.pop().into(), }); - rhs.1.extend(lhs.1); - - self.stack.push_with_read(data, rhs.1); + self.stack.push(data); } fn push_cmp_op(&mut self, op_type: CmpOpType) { - let mut rhs = self.stack.pop_with_read(); - let lhs = self.stack.pop_with_read(); - let data = Expression::CmpOp(CmpOp { op_type, - rhs: rhs.0.into(), - lhs: lhs.0.into(), + rhs: self.stack.pop().into(), + lhs: self.stack.pop().into(), }); - rhs.1.extend(lhs.1); - - self.stack.push_with_read(data, rhs.1); + self.stack.push(data); } // Eqz is the only unary comparison so it's "emulated" @@ -293,7 +286,9 @@ impl<'a> Factory<'a> { } fn start_else(&mut self) { - let BlockData::If { ty, .. } = self.target.block_data else { unreachable!() }; + let BlockData::If { ty, .. } = self.target.block_data else { + unreachable!() + }; self.target.leak_all(); self.end_block(); @@ -314,7 +309,9 @@ impl<'a> Factory<'a> { on_false: None, }), BlockData::Else { .. } => { - let Statement::If(last) = self.target.code.last_mut().unwrap() else { unreachable!() }; + let Statement::If(last) = self.target.code.last_mut().unwrap() else { + unreachable!() + }; last.on_false = Some(Box::new(now.into())); @@ -504,26 +501,19 @@ impl<'a> Factory<'a> { self.target.stack.pop(); } Operator::Select => { - let mut condition = self.target.stack.pop_with_read(); - let on_false = self.target.stack.pop_with_read(); - let on_true = self.target.stack.pop_with_read(); - let data = Expression::Select(Select { - condition: condition.0.into(), - on_true: on_true.0.into(), - on_false: on_false.0.into(), + condition: self.target.stack.pop().into(), + on_false: self.target.stack.pop().into(), + on_true: self.target.stack.pop().into(), }); - condition.1.extend(on_true.1); - condition.1.extend(on_false.1); - - self.target.stack.push_with_read(data, condition.1); + self.target.stack.push(data); } Operator::LocalGet { local_index } => { let var = local_index.try_into().unwrap(); let data = Expression::GetLocal(Local { var }); - self.target.stack.push_with_single(data); + self.target.stack.push(data); } Operator::LocalSet { local_index } => { let var = local_index.try_into().unwrap(); @@ -544,14 +534,14 @@ impl<'a> Factory<'a> { }); self.target.leak_local_write(var); - self.target.stack.push_with_single(get); + self.target.stack.push(get); self.target.code.push(set); } Operator::GlobalGet { global_index } => { let var = global_index.try_into().unwrap(); let data = Expression::GetGlobal(GetGlobal { var }); - self.target.stack.push_with_single(data); + self.target.stack.push(data); } Operator::GlobalSet { global_index } => { let var = global_index.try_into().unwrap(); diff --git a/wasm-ast/src/stack.rs b/wasm-ast/src/stack.rs index 9a73430..42ee22d 100644 --- a/wasm-ast/src/stack.rs +++ b/wasm-ast/src/stack.rs @@ -1,42 +1,59 @@ -use std::collections::HashSet; - -use crate::node::{ - Align, Expression, GetGlobal, LoadAt, Local, ResultList, SetTemporary, Statement, Temporary, +use crate::{ + node::{ + Align, Expression, GetGlobal, LoadAt, Local, ResultList, SetTemporary, Statement, Temporary, + }, + visit::{Driver, Visitor}, }; -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub enum ReadType { - Local(usize), - Global(usize), - Memory(usize), +pub struct ReadGet { + has_local: A, + has_global: B, + has_memory: C, + result: bool, } -pub struct Slot { - read: HashSet, - data: Expression, +impl ReadGet +where + A: Fn(Local) -> bool, + B: Fn(GetGlobal) -> bool, + C: Fn(&LoadAt) -> bool, +{ + pub fn run>(node: &D, has_local: A, has_global: B, has_memory: C) -> bool { + let mut visitor = Self { + has_local, + has_global, + has_memory, + result: false, + }; + + node.accept(&mut visitor); + + visitor.result + } } -impl Slot { - const fn is_temporary(&self, id: usize) -> bool { - matches!(self.data, Expression::GetTemporary(ref v) if v.var() == id) +impl Visitor for ReadGet +where + A: Fn(Local) -> bool, + B: Fn(GetGlobal) -> bool, + C: Fn(&LoadAt) -> bool, +{ + fn visit_get_global(&mut self, get_global: GetGlobal) { + self.result |= (self.has_global)(get_global); } - pub fn has_read(&self, id: ReadType) -> bool { - self.read.contains(&id) + fn visit_load_at(&mut self, load_at: &LoadAt) { + self.result |= (self.has_memory)(load_at); } - pub fn has_global_read(&self) -> bool { - self.read.iter().any(|r| matches!(r, ReadType::Global(_))) - } - - pub fn has_memory_read(&self) -> bool { - self.read.iter().any(|r| matches!(r, ReadType::Memory(_))) + fn visit_get_local(&mut self, local: Local) { + self.result |= (self.has_local)(local); } } #[derive(Default)] pub struct Stack { - var_list: Vec, + var_list: Vec, pub capacity: usize, pub previous: usize, } @@ -57,41 +74,18 @@ impl Stack { } } - pub fn push_with_read(&mut self, data: Expression, read: HashSet) { - self.var_list.push(Slot { read, data }); - } - pub fn push(&mut self, data: Expression) { - self.push_with_read(data, HashSet::new()); - } - - pub fn push_with_single(&mut self, data: Expression) { - let mut read = HashSet::new(); - let elem = match data { - Expression::GetLocal(Local { var }) => ReadType::Local(var), - Expression::GetGlobal(GetGlobal { var }) => ReadType::Global(var), - Expression::LoadAt(LoadAt { memory, .. }) => ReadType::Memory(memory), - _ => unreachable!(), - }; - - read.insert(elem); - self.var_list.push(Slot { read, data }); - } - - pub fn pop_with_read(&mut self) -> (Expression, HashSet) { - let var = self.var_list.pop().unwrap(); - - (var.data, var.read) + self.var_list.push(data); } pub fn pop(&mut self) -> Expression { - self.pop_with_read().0 + self.var_list.pop().unwrap() } pub fn pop_len(&'_ mut self, len: usize) -> impl Iterator + '_ { let desired = self.len() - len; - self.var_list.drain(desired..).map(|v| v.data) + self.var_list.drain(desired..) } pub fn push_temporaries(&mut self, num: usize) -> ResultList { @@ -129,21 +123,21 @@ impl Stack { // adjusting the capacity and old index accordingly pub fn leak_into

(&mut self, code: &mut Vec, predicate: P) where - P: Fn(&Slot) -> bool, + P: Fn(&Expression) -> bool, { for (i, old) in self.var_list.iter_mut().enumerate() { let var = self.previous + i; + let is_temporary = + matches!(old, Expression::GetTemporary(temporary) if temporary.var() == var); - if old.is_temporary(var) || !predicate(old) { + if is_temporary || !predicate(old) { continue; } - old.read.clear(); - let get = Expression::GetTemporary(Temporary { var }); let set = Statement::SetTemporary(SetTemporary { var: Temporary { var }, - value: std::mem::replace(&mut old.data, get).into(), + value: std::mem::replace(old, get).into(), }); self.capacity = self.capacity.max(var + 1);