diff --git a/wasm-ast/src/builder.rs b/wasm-ast/src/builder.rs index 8878acc..b37ba19 100644 --- a/wasm-ast/src/builder.rs +++ b/wasm-ast/src/builder.rs @@ -3,17 +3,26 @@ use parity_wasm::elements::{ Instruction, Module, Type, TypeSection, }; -use crate::node::{ - Align, Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType, - Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType, - MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, - StoreType, Terminator, UnOp, UnOpType, Value, +use crate::{ + node::{ + Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType, + Expression, Forward, FuncData, GetGlobal, GetLocal, If, LoadAt, LoadType, MemoryGrow, + MemorySize, Select, SetGlobal, SetLocal, Statement, StoreAt, StoreType, Terminator, UnOp, + UnOpType, Value, + }, + stack::{ReadType, Stack}, }; -macro_rules! leak_with_predicate { - ($name:tt, $predicate:tt) => { +macro_rules! leak_on { + ($name:tt, $variant:tt) => { fn $name(&mut self, id: usize) { - self.leak_with(|v| v.$predicate(id)); + let read = ReadType::$variant(id); + + for i in 0..self.stack.var_list.len() { + if self.stack.var_list[i].read.contains(&read) { + self.leak_at(i); + } + } } }; } @@ -151,13 +160,11 @@ impl Default for BlockData { #[derive(Default)] struct StatList { - stack: Vec, + stack: Stack, code: Vec, last: Option, block_data: BlockData, - num_stack: usize, - num_previous: usize, } impl StatList { @@ -165,88 +172,38 @@ impl StatList { Self::default() } - fn push_data(&mut self, data: Expression) { - self.stack.push(data); - } - - fn pop_required(&mut self) -> Expression { - self.stack.pop().unwrap() - } - - fn pop_len(&mut self, len: usize) -> Vec { - self.stack.split_off(self.stack.len() - len) - } - - fn push_temporary(&mut self, num: usize) { - let len = self.stack.len() + self.num_previous; - - for var in len..len + num { - let data = Expression::GetTemporary(GetTemporary { var }); - - self.push_data(data); - } - - self.num_stack = self.num_stack.max(len + num); - } - fn leak_at(&mut self, index: usize) { - let old = self.stack.get_mut(index).unwrap(); - let var = self.num_previous + index; - - if old.is_temporary(var) { - return; - } - - let get = Expression::GetTemporary(GetTemporary { var }); - let set = Statement::SetTemporary(SetTemporary { - var, - value: std::mem::replace(old, get), - }); - - self.num_stack = self.num_stack.max(var + 1); - self.code.push(set); - } - - fn leak_with

(&mut self, predicate: P) - where - P: Fn(&Expression) -> bool, - { - let pend: Vec<_> = self - .stack - .iter() - .enumerate() - .filter_map(|v| predicate(v.1).then(|| v.0)) - .collect(); - - for var in pend { - self.leak_at(var); + if let Some(set) = self.stack.leak_at(index) { + self.code.push(set); } } - leak_with_predicate!(leak_local_write, is_local_read); - leak_with_predicate!(leak_global_write, is_global_read); - leak_with_predicate!(leak_memory_write, is_memory_read); - fn leak_all(&mut self) { - self.leak_with(|_| true); + for i in 0..self.stack.var_list.len() { + self.leak_at(i); + } } + leak_on!(leak_local_write, Local); + leak_on!(leak_global_write, Global); + leak_on!(leak_memory_write, Memory); + fn push_load(&mut self, what: LoadType, offset: u32) { let data = Expression::LoadAt(LoadAt { what, offset, - pointer: self.pop_required().into(), + pointer: self.stack.pop().into(), }); - self.push_data(data); + self.stack.push_with_single(data); } fn add_store(&mut self, what: StoreType, offset: u32) { let data = Statement::StoreAt(StoreAt { what, offset, - value: self.pop_required(), - pointer: self.pop_required(), + value: self.stack.pop(), + pointer: self.stack.pop(), }); self.leak_memory_write(0); @@ -256,36 +213,47 @@ impl StatList { fn push_constant>(&mut self, value: T) { let value = Expression::Value(value.into()); - self.push_data(value); + self.stack.push(value); } fn push_un_op(&mut self, op: UnOpType) { + let rhs = self.stack.pop_with_read(); let data = Expression::UnOp(UnOp { op, - rhs: self.pop_required().into(), + rhs: rhs.0.into(), }); - self.push_data(data); + self.stack.push_with_read(data, rhs.1); } fn push_bin_op(&mut self, op: BinOpType) { + let mut rhs = self.stack.pop_with_read(); + let lhs = self.stack.pop_with_read(); + let data = Expression::BinOp(BinOp { op, - rhs: self.pop_required().into(), - lhs: self.pop_required().into(), + rhs: rhs.0.into(), + lhs: lhs.0.into(), }); - self.push_data(data); + rhs.1.extend(lhs.1); + + self.stack.push_with_read(data, rhs.1); } fn push_cmp_op(&mut self, op: CmpOpType) { + let mut rhs = self.stack.pop_with_read(); + let lhs = self.stack.pop_with_read(); + let data = Expression::CmpOp(CmpOp { op, - rhs: self.pop_required().into(), - lhs: self.pop_required().into(), + rhs: rhs.0.into(), + lhs: lhs.0.into(), }); - self.push_data(data); + rhs.1.extend(lhs.1); + + self.stack.push_with_read(data, rhs.1); } // Eqz is the only unary comparison so it's "emulated" @@ -327,18 +295,6 @@ impl StatList { } } - // Return the alignment necessary for this block to branch out to a - // another given block - fn get_br_alignment(&self, par_start: usize, par_result: usize) -> Align { - let start = self.stack.len() + self.num_previous - par_result; - - Align { - new: par_start, - old: start, - length: par_result, - } - } - fn set_terminator(&mut self, term: Terminator) { self.leak_all(); self.last = Some(term); @@ -391,7 +347,7 @@ impl<'a> Builder<'a> { local_data: Vec::new(), num_result: 1, num_param: 0, - num_stack: data.num_stack, + num_stack: data.stack.capacity, code: data.into(), } } @@ -405,7 +361,7 @@ impl<'a> Builder<'a> { local_data: func.locals().to_vec(), num_result: arity.num_result, num_param: arity.num_param, - num_stack: data.num_stack, + num_stack: data.stack.capacity, code: data.into(), } } @@ -425,18 +381,16 @@ impl<'a> Builder<'a> { BlockVariant::Backward => BlockData::Backward { num_param }, BlockVariant::If => BlockData::If { num_result, typ }, BlockVariant::Else => { - old.pop_len(num_result); - old.push_temporary(num_param); + old.stack.pop_len(num_result).for_each(drop); + old.stack.push_temporary(num_param); BlockData::Else { num_result } } }; - self.target.stack = old.pop_len(num_param); - self.target.num_stack = old.num_stack; - self.target.num_previous = old.num_previous + old.stack.len(); + self.target.stack = old.stack.split_last(num_param); - old.push_temporary(num_result); + old.stack.push_temporary(num_result); self.pending.push(old); } @@ -457,13 +411,13 @@ impl<'a> Builder<'a> { let old = self.pending.pop().unwrap(); let now = std::mem::replace(&mut self.target, old); - self.target.num_stack = now.num_stack; + self.target.stack.capacity = now.stack.capacity; let stat = match now.block_data { BlockData::Forward { .. } => Statement::Forward(now.into()), BlockData::Backward { .. } => Statement::Backward(now.into()), BlockData::If { .. } => Statement::If(If { - cond: self.target.pop_required(), + cond: self.target.stack.pop(), truthy: now.into(), falsey: None, }), @@ -498,20 +452,23 @@ impl<'a> Builder<'a> { BlockData::Backward { num_param } => num_param, }; - let align = self.target.get_br_alignment(block.num_previous, par_result); + let align = self + .target + .stack + .get_br_alignment(block.stack.previous, par_result); Br { target, align } } fn add_call(&mut self, func: usize) { let arity = self.type_info.rel_arity_of(func); - let param_list = self.target.pop_len(arity.num_param); + let param_list = self.target.stack.pop_len(arity.num_param).collect(); - let first = self.target.stack.len(); + let first = self.target.stack.var_list.len(); let result = first..first + arity.num_result; self.target.leak_all(); - self.target.push_temporary(arity.num_result); + self.target.stack.push_temporary(arity.num_result); let data = Statement::Call(Call { func, @@ -524,14 +481,14 @@ impl<'a> Builder<'a> { fn add_call_indirect(&mut self, typ: usize, table: usize) { let arity = self.type_info.arity_of(typ); - let index = self.target.pop_required(); - let param_list = self.target.pop_len(arity.num_param); + let index = self.target.stack.pop(); + let param_list = self.target.stack.pop_len(arity.num_param).collect(); - let first = self.target.stack.len(); + let first = self.target.stack.var_list.len(); let result = first..first + arity.num_result; self.target.leak_all(); - self.target.push_temporary(arity.num_result); + self.target.stack.push_temporary(arity.num_result); let data = Statement::CallIndirect(CallIndirect { table, @@ -588,7 +545,7 @@ impl<'a> Builder<'a> { self.start_block(typ, BlockVariant::Backward); } Inst::If(typ) => { - let cond = self.target.pop_required(); + let cond = self.target.stack.pop(); self.start_block(typ, BlockVariant::If); self.pending.last_mut().unwrap().stack.push(cond); @@ -609,7 +566,7 @@ impl<'a> Builder<'a> { } Inst::BrIf(v) => { let data = Statement::BrIf(BrIf { - cond: self.target.pop_required(), + cond: self.target.stack.pop(), target: self.get_br_terminator(v.try_into().unwrap()), }); @@ -617,7 +574,7 @@ impl<'a> Builder<'a> { self.target.code.push(data); } Inst::BrTable(ref v) => { - let cond = self.target.pop_required(); + let cond = self.target.stack.pop(); let data = v .table .iter() @@ -650,35 +607,41 @@ impl<'a> Builder<'a> { self.add_call_indirect(i.try_into().unwrap(), t.into()); } Inst::Drop => { - let last = self.target.stack.len() - 1; + let last = self.target.stack.var_list.len() - 1; - if self.target.stack[last].has_side_effect() { + if self.target.stack.var_list[last].data.has_side_effect() { self.target.leak_at(last); } - self.target.pop_required(); + self.target.stack.pop(); } Inst::Select => { + let mut cond = self.target.stack.pop_with_read(); + let b = self.target.stack.pop_with_read(); + let a = self.target.stack.pop_with_read(); + let data = Expression::Select(Select { - cond: self.target.pop_required().into(), - b: self.target.pop_required().into(), - a: self.target.pop_required().into(), + cond: cond.0.into(), + b: b.0.into(), + a: a.0.into(), }); - self.target.push_data(data); + cond.1.extend(b.1); + cond.1.extend(a.1); + + self.target.stack.push_with_read(data, cond.1); } Inst::GetLocal(i) => { - let data = Expression::GetLocal(GetLocal { - var: i.try_into().unwrap(), - }); + let var = i.try_into().unwrap(); + let data = Expression::GetLocal(GetLocal { var }); - self.target.push_data(data); + self.target.stack.push_with_single(data); } Inst::SetLocal(i) => { let var = i.try_into().unwrap(); let data = Statement::SetLocal(SetLocal { var, - value: self.target.pop_required(), + value: self.target.stack.pop(), }); self.target.leak_local_write(var); @@ -689,25 +652,24 @@ impl<'a> Builder<'a> { let get = Expression::GetLocal(GetLocal { var }); let set = Statement::SetLocal(SetLocal { var, - value: self.target.pop_required(), + value: self.target.stack.pop(), }); self.target.leak_local_write(var); - self.target.push_data(get); + self.target.stack.push_with_single(get); self.target.code.push(set); } Inst::GetGlobal(i) => { - let data = Expression::GetGlobal(GetGlobal { - var: i.try_into().unwrap(), - }); + let var = i.try_into().unwrap(); + let data = Expression::GetGlobal(GetGlobal { var }); - self.target.push_data(data); + self.target.stack.push_with_single(data); } Inst::SetGlobal(i) => { let var = i.try_into().unwrap(); let data = Statement::SetGlobal(SetGlobal { var, - value: self.target.pop_required(), + value: self.target.stack.pop(), }); self.target.leak_global_write(var); @@ -740,16 +702,16 @@ impl<'a> Builder<'a> { let memory = i.try_into().unwrap(); let data = Expression::MemorySize(MemorySize { memory }); - self.target.push_data(data); + self.target.stack.push(data); } Inst::GrowMemory(i) => { let memory = i.try_into().unwrap(); let data = Expression::MemoryGrow(MemoryGrow { memory, - value: self.target.pop_required().into(), + value: self.target.stack.pop().into(), }); - self.target.push_data(data); + self.target.stack.push(data); self.target.leak_all(); } Inst::I32Const(v) => self.target.push_constant(v), diff --git a/wasm-ast/src/lib.rs b/wasm-ast/src/lib.rs index 0dd2259..3fe9dcc 100644 --- a/wasm-ast/src/lib.rs +++ b/wasm-ast/src/lib.rs @@ -1,3 +1,4 @@ pub mod builder; pub mod node; +mod stack; pub mod visit; diff --git a/wasm-ast/src/node.rs b/wasm-ast/src/node.rs index 5f889fa..cb808ee 100644 --- a/wasm-ast/src/node.rs +++ b/wasm-ast/src/node.rs @@ -637,21 +637,6 @@ impl Expression { pub fn is_temporary(&self, id: usize) -> bool { matches!(self, Expression::GetTemporary(v) if v.var == id) } - - #[must_use] - pub fn is_local_read(&self, id: usize) -> bool { - matches!(self, Expression::GetLocal(v) if v.var == id) - } - - #[must_use] - pub fn is_global_read(&self, id: usize) -> bool { - matches!(self, Expression::GetGlobal(v) if v.var == id) - } - - #[must_use] - pub fn is_memory_read(&self, id: usize) -> bool { - id == 0 && matches!(self, Expression::LoadAt(_)) - } } pub struct Align { diff --git a/wasm-ast/src/stack.rs b/wasm-ast/src/stack.rs new file mode 100644 index 0000000..680a4ce --- /dev/null +++ b/wasm-ast/src/stack.rs @@ -0,0 +1,121 @@ +use std::collections::HashSet; + +use crate::node::{ + Align, Expression, GetGlobal, GetLocal, GetTemporary, LoadAt, SetTemporary, Statement, +}; + +#[derive(PartialEq, Eq, Hash)] +pub enum ReadType { + Local(usize), + Global(usize), + Memory(usize), +} + +pub struct Slot { + pub read: HashSet, + pub data: Expression, +} + +#[derive(Default)] +pub struct Stack { + pub var_list: Vec, + pub capacity: usize, + pub previous: usize, +} + +impl Stack { + pub fn split_last(&mut self, len: usize) -> Self { + let desired = self.var_list.len() - len; + let content = self.var_list.split_off(desired); + + Self { + var_list: content, + capacity: self.capacity, + previous: self.previous + desired, + } + } + + pub fn push_with_read(&mut self, data: Expression, read: HashSet) { + self.var_list.push(Slot { read, data }); + } + + pub fn push(&mut self, data: Expression) { + self.push_with_read(data, HashSet::new()); + } + + pub fn push_with_single(&mut self, data: Expression) { + let mut read = HashSet::new(); + let elem = match data { + Expression::GetLocal(GetLocal { var }) => ReadType::Local(var), + Expression::GetGlobal(GetGlobal { var }) => ReadType::Global(var), + Expression::LoadAt(LoadAt { .. }) => ReadType::Memory(0), + _ => unreachable!(), + }; + + read.insert(elem); + self.var_list.push(Slot { read, data }); + } + + pub fn pop_with_read(&mut self) -> (Expression, HashSet) { + let var = self.var_list.pop().unwrap(); + + (var.data, var.read) + } + + pub fn pop(&mut self) -> Expression { + self.pop_with_read().0 + } + + pub fn pop_len(&'_ mut self, len: usize) -> impl Iterator + '_ { + let desired = self.var_list.len() - len; + + self.var_list.drain(desired..).map(|v| v.data) + } + + pub fn push_temporary(&mut self, num: usize) { + let len = self.var_list.len() + self.previous; + + for var in len..len + num { + let data = Expression::GetTemporary(GetTemporary { var }); + + self.push(data); + } + + self.capacity = self.capacity.max(len + num); + } + + // Try to leak a slot's value to a `SetTemporary` instruction, + // adjusting the capacity and old index accordingly + pub fn leak_at(&mut self, index: usize) -> Option { + let old = &mut self.var_list[index]; + let var = self.previous + index; + + if old.data.is_temporary(var) { + return None; + } + + old.read.clear(); + + let get = Expression::GetTemporary(GetTemporary { var }); + let set = Statement::SetTemporary(SetTemporary { + var, + value: std::mem::replace(&mut old.data, get), + }); + + self.capacity = self.capacity.max(var + 1); + + Some(set) + } + + // Return the alignment necessary for this block to branch out to a + // another given stack frame + pub fn get_br_alignment(&self, par_start: usize, par_result: usize) -> Align { + let start = self.var_list.len() + self.previous - par_result; + + Align { + new: par_start, + old: start, + length: par_result, + } + } +}