Add comprehensive stack use analysis

This commit is contained in:
Rerumu 2022-06-17 21:14:40 -04:00
parent 684f2d9ad7
commit b3c931a38e
4 changed files with 222 additions and 153 deletions

View File

@ -3,17 +3,26 @@ use parity_wasm::elements::{
Instruction, Module, Type, TypeSection, Instruction, Module, Type, TypeSection,
}; };
use crate::node::{ use crate::{
Align, Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType, node::{
Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType, Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType,
MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, Expression, Forward, FuncData, GetGlobal, GetLocal, If, LoadAt, LoadType, MemoryGrow,
StoreType, Terminator, UnOp, UnOpType, Value, MemorySize, Select, SetGlobal, SetLocal, Statement, StoreAt, StoreType, Terminator, UnOp,
UnOpType, Value,
},
stack::{ReadType, Stack},
}; };
macro_rules! leak_with_predicate { macro_rules! leak_on {
($name:tt, $predicate:tt) => { ($name:tt, $variant:tt) => {
fn $name(&mut self, id: usize) { fn $name(&mut self, id: usize) {
self.leak_with(|v| v.$predicate(id)); let read = ReadType::$variant(id);
for i in 0..self.stack.var_list.len() {
if self.stack.var_list[i].read.contains(&read) {
self.leak_at(i);
}
}
} }
}; };
} }
@ -151,13 +160,11 @@ impl Default for BlockData {
#[derive(Default)] #[derive(Default)]
struct StatList { struct StatList {
stack: Vec<Expression>, stack: Stack,
code: Vec<Statement>, code: Vec<Statement>,
last: Option<Terminator>, last: Option<Terminator>,
block_data: BlockData, block_data: BlockData,
num_stack: usize,
num_previous: usize,
} }
impl StatList { impl StatList {
@ -165,88 +172,38 @@ impl StatList {
Self::default() Self::default()
} }
fn push_data(&mut self, data: Expression) {
self.stack.push(data);
}
fn pop_required(&mut self) -> Expression {
self.stack.pop().unwrap()
}
fn pop_len(&mut self, len: usize) -> Vec<Expression> {
self.stack.split_off(self.stack.len() - len)
}
fn push_temporary(&mut self, num: usize) {
let len = self.stack.len() + self.num_previous;
for var in len..len + num {
let data = Expression::GetTemporary(GetTemporary { var });
self.push_data(data);
}
self.num_stack = self.num_stack.max(len + num);
}
fn leak_at(&mut self, index: usize) { fn leak_at(&mut self, index: usize) {
let old = self.stack.get_mut(index).unwrap(); if let Some(set) = self.stack.leak_at(index) {
let var = self.num_previous + index; self.code.push(set);
if old.is_temporary(var) {
return;
}
let get = Expression::GetTemporary(GetTemporary { var });
let set = Statement::SetTemporary(SetTemporary {
var,
value: std::mem::replace(old, get),
});
self.num_stack = self.num_stack.max(var + 1);
self.code.push(set);
}
fn leak_with<P>(&mut self, predicate: P)
where
P: Fn(&Expression) -> bool,
{
let pend: Vec<_> = self
.stack
.iter()
.enumerate()
.filter_map(|v| predicate(v.1).then(|| v.0))
.collect();
for var in pend {
self.leak_at(var);
} }
} }
leak_with_predicate!(leak_local_write, is_local_read);
leak_with_predicate!(leak_global_write, is_global_read);
leak_with_predicate!(leak_memory_write, is_memory_read);
fn leak_all(&mut self) { fn leak_all(&mut self) {
self.leak_with(|_| true); for i in 0..self.stack.var_list.len() {
self.leak_at(i);
}
} }
leak_on!(leak_local_write, Local);
leak_on!(leak_global_write, Global);
leak_on!(leak_memory_write, Memory);
fn push_load(&mut self, what: LoadType, offset: u32) { fn push_load(&mut self, what: LoadType, offset: u32) {
let data = Expression::LoadAt(LoadAt { let data = Expression::LoadAt(LoadAt {
what, what,
offset, offset,
pointer: self.pop_required().into(), pointer: self.stack.pop().into(),
}); });
self.push_data(data); self.stack.push_with_single(data);
} }
fn add_store(&mut self, what: StoreType, offset: u32) { fn add_store(&mut self, what: StoreType, offset: u32) {
let data = Statement::StoreAt(StoreAt { let data = Statement::StoreAt(StoreAt {
what, what,
offset, offset,
value: self.pop_required(), value: self.stack.pop(),
pointer: self.pop_required(), pointer: self.stack.pop(),
}); });
self.leak_memory_write(0); self.leak_memory_write(0);
@ -256,36 +213,47 @@ impl StatList {
fn push_constant<T: Into<Value>>(&mut self, value: T) { fn push_constant<T: Into<Value>>(&mut self, value: T) {
let value = Expression::Value(value.into()); let value = Expression::Value(value.into());
self.push_data(value); self.stack.push(value);
} }
fn push_un_op(&mut self, op: UnOpType) { fn push_un_op(&mut self, op: UnOpType) {
let rhs = self.stack.pop_with_read();
let data = Expression::UnOp(UnOp { let data = Expression::UnOp(UnOp {
op, op,
rhs: self.pop_required().into(), rhs: rhs.0.into(),
}); });
self.push_data(data); self.stack.push_with_read(data, rhs.1);
} }
fn push_bin_op(&mut self, op: BinOpType) { fn push_bin_op(&mut self, op: BinOpType) {
let mut rhs = self.stack.pop_with_read();
let lhs = self.stack.pop_with_read();
let data = Expression::BinOp(BinOp { let data = Expression::BinOp(BinOp {
op, op,
rhs: self.pop_required().into(), rhs: rhs.0.into(),
lhs: self.pop_required().into(), lhs: lhs.0.into(),
}); });
self.push_data(data); rhs.1.extend(lhs.1);
self.stack.push_with_read(data, rhs.1);
} }
fn push_cmp_op(&mut self, op: CmpOpType) { fn push_cmp_op(&mut self, op: CmpOpType) {
let mut rhs = self.stack.pop_with_read();
let lhs = self.stack.pop_with_read();
let data = Expression::CmpOp(CmpOp { let data = Expression::CmpOp(CmpOp {
op, op,
rhs: self.pop_required().into(), rhs: rhs.0.into(),
lhs: self.pop_required().into(), lhs: lhs.0.into(),
}); });
self.push_data(data); rhs.1.extend(lhs.1);
self.stack.push_with_read(data, rhs.1);
} }
// Eqz is the only unary comparison so it's "emulated" // Eqz is the only unary comparison so it's "emulated"
@ -327,18 +295,6 @@ impl StatList {
} }
} }
// Return the alignment necessary for this block to branch out to a
// another given block
fn get_br_alignment(&self, par_start: usize, par_result: usize) -> Align {
let start = self.stack.len() + self.num_previous - par_result;
Align {
new: par_start,
old: start,
length: par_result,
}
}
fn set_terminator(&mut self, term: Terminator) { fn set_terminator(&mut self, term: Terminator) {
self.leak_all(); self.leak_all();
self.last = Some(term); self.last = Some(term);
@ -391,7 +347,7 @@ impl<'a> Builder<'a> {
local_data: Vec::new(), local_data: Vec::new(),
num_result: 1, num_result: 1,
num_param: 0, num_param: 0,
num_stack: data.num_stack, num_stack: data.stack.capacity,
code: data.into(), code: data.into(),
} }
} }
@ -405,7 +361,7 @@ impl<'a> Builder<'a> {
local_data: func.locals().to_vec(), local_data: func.locals().to_vec(),
num_result: arity.num_result, num_result: arity.num_result,
num_param: arity.num_param, num_param: arity.num_param,
num_stack: data.num_stack, num_stack: data.stack.capacity,
code: data.into(), code: data.into(),
} }
} }
@ -425,18 +381,16 @@ impl<'a> Builder<'a> {
BlockVariant::Backward => BlockData::Backward { num_param }, BlockVariant::Backward => BlockData::Backward { num_param },
BlockVariant::If => BlockData::If { num_result, typ }, BlockVariant::If => BlockData::If { num_result, typ },
BlockVariant::Else => { BlockVariant::Else => {
old.pop_len(num_result); old.stack.pop_len(num_result).for_each(drop);
old.push_temporary(num_param); old.stack.push_temporary(num_param);
BlockData::Else { num_result } BlockData::Else { num_result }
} }
}; };
self.target.stack = old.pop_len(num_param); self.target.stack = old.stack.split_last(num_param);
self.target.num_stack = old.num_stack;
self.target.num_previous = old.num_previous + old.stack.len();
old.push_temporary(num_result); old.stack.push_temporary(num_result);
self.pending.push(old); self.pending.push(old);
} }
@ -457,13 +411,13 @@ impl<'a> Builder<'a> {
let old = self.pending.pop().unwrap(); let old = self.pending.pop().unwrap();
let now = std::mem::replace(&mut self.target, old); let now = std::mem::replace(&mut self.target, old);
self.target.num_stack = now.num_stack; self.target.stack.capacity = now.stack.capacity;
let stat = match now.block_data { let stat = match now.block_data {
BlockData::Forward { .. } => Statement::Forward(now.into()), BlockData::Forward { .. } => Statement::Forward(now.into()),
BlockData::Backward { .. } => Statement::Backward(now.into()), BlockData::Backward { .. } => Statement::Backward(now.into()),
BlockData::If { .. } => Statement::If(If { BlockData::If { .. } => Statement::If(If {
cond: self.target.pop_required(), cond: self.target.stack.pop(),
truthy: now.into(), truthy: now.into(),
falsey: None, falsey: None,
}), }),
@ -498,20 +452,23 @@ impl<'a> Builder<'a> {
BlockData::Backward { num_param } => num_param, BlockData::Backward { num_param } => num_param,
}; };
let align = self.target.get_br_alignment(block.num_previous, par_result); let align = self
.target
.stack
.get_br_alignment(block.stack.previous, par_result);
Br { target, align } Br { target, align }
} }
fn add_call(&mut self, func: usize) { fn add_call(&mut self, func: usize) {
let arity = self.type_info.rel_arity_of(func); let arity = self.type_info.rel_arity_of(func);
let param_list = self.target.pop_len(arity.num_param); let param_list = self.target.stack.pop_len(arity.num_param).collect();
let first = self.target.stack.len(); let first = self.target.stack.var_list.len();
let result = first..first + arity.num_result; let result = first..first + arity.num_result;
self.target.leak_all(); self.target.leak_all();
self.target.push_temporary(arity.num_result); self.target.stack.push_temporary(arity.num_result);
let data = Statement::Call(Call { let data = Statement::Call(Call {
func, func,
@ -524,14 +481,14 @@ impl<'a> Builder<'a> {
fn add_call_indirect(&mut self, typ: usize, table: usize) { fn add_call_indirect(&mut self, typ: usize, table: usize) {
let arity = self.type_info.arity_of(typ); let arity = self.type_info.arity_of(typ);
let index = self.target.pop_required(); let index = self.target.stack.pop();
let param_list = self.target.pop_len(arity.num_param); let param_list = self.target.stack.pop_len(arity.num_param).collect();
let first = self.target.stack.len(); let first = self.target.stack.var_list.len();
let result = first..first + arity.num_result; let result = first..first + arity.num_result;
self.target.leak_all(); self.target.leak_all();
self.target.push_temporary(arity.num_result); self.target.stack.push_temporary(arity.num_result);
let data = Statement::CallIndirect(CallIndirect { let data = Statement::CallIndirect(CallIndirect {
table, table,
@ -588,7 +545,7 @@ impl<'a> Builder<'a> {
self.start_block(typ, BlockVariant::Backward); self.start_block(typ, BlockVariant::Backward);
} }
Inst::If(typ) => { Inst::If(typ) => {
let cond = self.target.pop_required(); let cond = self.target.stack.pop();
self.start_block(typ, BlockVariant::If); self.start_block(typ, BlockVariant::If);
self.pending.last_mut().unwrap().stack.push(cond); self.pending.last_mut().unwrap().stack.push(cond);
@ -609,7 +566,7 @@ impl<'a> Builder<'a> {
} }
Inst::BrIf(v) => { Inst::BrIf(v) => {
let data = Statement::BrIf(BrIf { let data = Statement::BrIf(BrIf {
cond: self.target.pop_required(), cond: self.target.stack.pop(),
target: self.get_br_terminator(v.try_into().unwrap()), target: self.get_br_terminator(v.try_into().unwrap()),
}); });
@ -617,7 +574,7 @@ impl<'a> Builder<'a> {
self.target.code.push(data); self.target.code.push(data);
} }
Inst::BrTable(ref v) => { Inst::BrTable(ref v) => {
let cond = self.target.pop_required(); let cond = self.target.stack.pop();
let data = v let data = v
.table .table
.iter() .iter()
@ -650,35 +607,41 @@ impl<'a> Builder<'a> {
self.add_call_indirect(i.try_into().unwrap(), t.into()); self.add_call_indirect(i.try_into().unwrap(), t.into());
} }
Inst::Drop => { Inst::Drop => {
let last = self.target.stack.len() - 1; let last = self.target.stack.var_list.len() - 1;
if self.target.stack[last].has_side_effect() { if self.target.stack.var_list[last].data.has_side_effect() {
self.target.leak_at(last); self.target.leak_at(last);
} }
self.target.pop_required(); self.target.stack.pop();
} }
Inst::Select => { Inst::Select => {
let mut cond = self.target.stack.pop_with_read();
let b = self.target.stack.pop_with_read();
let a = self.target.stack.pop_with_read();
let data = Expression::Select(Select { let data = Expression::Select(Select {
cond: self.target.pop_required().into(), cond: cond.0.into(),
b: self.target.pop_required().into(), b: b.0.into(),
a: self.target.pop_required().into(), a: a.0.into(),
}); });
self.target.push_data(data); cond.1.extend(b.1);
cond.1.extend(a.1);
self.target.stack.push_with_read(data, cond.1);
} }
Inst::GetLocal(i) => { Inst::GetLocal(i) => {
let data = Expression::GetLocal(GetLocal { let var = i.try_into().unwrap();
var: i.try_into().unwrap(), let data = Expression::GetLocal(GetLocal { var });
});
self.target.push_data(data); self.target.stack.push_with_single(data);
} }
Inst::SetLocal(i) => { Inst::SetLocal(i) => {
let var = i.try_into().unwrap(); let var = i.try_into().unwrap();
let data = Statement::SetLocal(SetLocal { let data = Statement::SetLocal(SetLocal {
var, var,
value: self.target.pop_required(), value: self.target.stack.pop(),
}); });
self.target.leak_local_write(var); self.target.leak_local_write(var);
@ -689,25 +652,24 @@ impl<'a> Builder<'a> {
let get = Expression::GetLocal(GetLocal { var }); let get = Expression::GetLocal(GetLocal { var });
let set = Statement::SetLocal(SetLocal { let set = Statement::SetLocal(SetLocal {
var, var,
value: self.target.pop_required(), value: self.target.stack.pop(),
}); });
self.target.leak_local_write(var); self.target.leak_local_write(var);
self.target.push_data(get); self.target.stack.push_with_single(get);
self.target.code.push(set); self.target.code.push(set);
} }
Inst::GetGlobal(i) => { Inst::GetGlobal(i) => {
let data = Expression::GetGlobal(GetGlobal { let var = i.try_into().unwrap();
var: i.try_into().unwrap(), let data = Expression::GetGlobal(GetGlobal { var });
});
self.target.push_data(data); self.target.stack.push_with_single(data);
} }
Inst::SetGlobal(i) => { Inst::SetGlobal(i) => {
let var = i.try_into().unwrap(); let var = i.try_into().unwrap();
let data = Statement::SetGlobal(SetGlobal { let data = Statement::SetGlobal(SetGlobal {
var, var,
value: self.target.pop_required(), value: self.target.stack.pop(),
}); });
self.target.leak_global_write(var); self.target.leak_global_write(var);
@ -740,16 +702,16 @@ impl<'a> Builder<'a> {
let memory = i.try_into().unwrap(); let memory = i.try_into().unwrap();
let data = Expression::MemorySize(MemorySize { memory }); let data = Expression::MemorySize(MemorySize { memory });
self.target.push_data(data); self.target.stack.push(data);
} }
Inst::GrowMemory(i) => { Inst::GrowMemory(i) => {
let memory = i.try_into().unwrap(); let memory = i.try_into().unwrap();
let data = Expression::MemoryGrow(MemoryGrow { let data = Expression::MemoryGrow(MemoryGrow {
memory, memory,
value: self.target.pop_required().into(), value: self.target.stack.pop().into(),
}); });
self.target.push_data(data); self.target.stack.push(data);
self.target.leak_all(); self.target.leak_all();
} }
Inst::I32Const(v) => self.target.push_constant(v), Inst::I32Const(v) => self.target.push_constant(v),

View File

@ -1,3 +1,4 @@
pub mod builder; pub mod builder;
pub mod node; pub mod node;
mod stack;
pub mod visit; pub mod visit;

View File

@ -637,21 +637,6 @@ impl Expression {
pub fn is_temporary(&self, id: usize) -> bool { pub fn is_temporary(&self, id: usize) -> bool {
matches!(self, Expression::GetTemporary(v) if v.var == id) matches!(self, Expression::GetTemporary(v) if v.var == id)
} }
#[must_use]
pub fn is_local_read(&self, id: usize) -> bool {
matches!(self, Expression::GetLocal(v) if v.var == id)
}
#[must_use]
pub fn is_global_read(&self, id: usize) -> bool {
matches!(self, Expression::GetGlobal(v) if v.var == id)
}
#[must_use]
pub fn is_memory_read(&self, id: usize) -> bool {
id == 0 && matches!(self, Expression::LoadAt(_))
}
} }
pub struct Align { pub struct Align {

121
wasm-ast/src/stack.rs Normal file
View File

@ -0,0 +1,121 @@
use std::collections::HashSet;
use crate::node::{
Align, Expression, GetGlobal, GetLocal, GetTemporary, LoadAt, SetTemporary, Statement,
};
// Identifies a piece of state that a pending stack expression reads from.
//
// A slot carries a set of these so that a later write to the same local,
// global or memory can find the slots that still depend on the old value.
//
// `Debug` is derived so values can be printed and compared in assertions;
// `Clone`/`Copy` are derived because every variant is a plain `usize` tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReadType {
    // Read of the local variable with the given index.
    Local(usize),
    // Read of the global variable with the given index.
    Global(usize),
    // Read of the memory with the given index.
    Memory(usize),
}
// One entry of the virtual stack: an expression plus what it reads.
pub struct Slot {
    // Every local, global or memory read recorded for `data`.
    pub read: HashSet<ReadType>,
    // The expression currently held in this slot.
    pub data: Expression,
}
// Virtual value stack for one block, with read tracking per slot.
#[derive(Default)]
pub struct Stack {
    // Slots owned by this frame; the last element is the top of the stack.
    pub var_list: Vec<Slot>,
    // Running maximum of `temporary index + 1` seen by this stack.
    pub capacity: usize,
    // Offset added to a slot's position to obtain its temporary index.
    pub previous: usize,
}
impl Stack {
    // Detach the top `len` slots into a new `Stack`.
    //
    // The new stack copies this stack's `capacity`, and its `previous` is
    // this stack's `previous` plus the number of slots left behind.
    // Panics if `len` is larger than the number of slots held.
    pub fn split_last(&mut self, len: usize) -> Self {
        let remaining = self.var_list.len() - len;

        Self {
            var_list: self.var_list.split_off(remaining),
            capacity: self.capacity,
            previous: self.previous + remaining,
        }
    }

    // Push `data` together with an explicit set of reads.
    pub fn push_with_read(&mut self, data: Expression, read: HashSet<ReadType>) {
        let slot = Slot { read, data };

        self.var_list.push(slot);
    }

    // Push `data` with an empty read set.
    pub fn push(&mut self, data: Expression) {
        self.push_with_read(data, HashSet::new());
    }

    // Push an expression that is itself a single read, recording that read.
    //
    // Only `GetLocal`, `GetGlobal` and `LoadAt` are accepted; any other
    // expression hits `unreachable!()`. A `LoadAt` is always recorded as
    // a read of memory 0.
    pub fn push_with_single(&mut self, data: Expression) {
        let elem = match &data {
            Expression::GetLocal(GetLocal { var }) => ReadType::Local(*var),
            Expression::GetGlobal(GetGlobal { var }) => ReadType::Global(*var),
            Expression::LoadAt(LoadAt { .. }) => ReadType::Memory(0),
            _ => unreachable!(),
        };

        self.push_with_read(data, std::iter::once(elem).collect());
    }

    // Pop the top slot, returning its expression and its read set.
    // Panics if the stack is empty.
    pub fn pop_with_read(&mut self) -> (Expression, HashSet<ReadType>) {
        let Slot { read, data } = self.var_list.pop().unwrap();

        (data, read)
    }

    // Pop the top slot and return only its expression; the read set is
    // discarded. Panics if the stack is empty.
    pub fn pop(&mut self) -> Expression {
        let (data, _) = self.pop_with_read();

        data
    }

    // Remove the top `len` slots and yield their expressions, oldest first.
    //
    // The read sets are discarded. The slots are drained from `var_list`,
    // so they are gone once the returned iterator is dropped, whether or
    // not it was fully consumed.
    pub fn pop_len(&mut self, len: usize) -> impl Iterator<Item = Expression> + '_ {
        let start = self.var_list.len() - len;

        self.var_list.drain(start..).map(|slot| slot.data)
    }

    // Push `num` `GetTemporary` expressions with empty read sets.
    //
    // Each one refers to the temporary index of the slot it lands in
    // (`previous` plus its position), and `capacity` is raised to cover
    // the highest index pushed.
    pub fn push_temporary(&mut self, num: usize) {
        let first = self.var_list.len() + self.previous;
        let end = first + num;

        self.var_list.extend((first..end).map(|var| Slot {
            read: HashSet::new(),
            data: Expression::GetTemporary(GetTemporary { var }),
        }));

        self.capacity = self.capacity.max(end);
    }

    // Spill the slot at `index` into its temporary.
    //
    // Returns `None` when the slot already holds the `GetTemporary` for
    // its own index. Otherwise the slot's reads are cleared, its
    // expression is swapped for that `GetTemporary`, `capacity` is raised
    // to cover the index, and the `SetTemporary` statement carrying the
    // old expression is returned for the caller to emit.
    pub fn leak_at(&mut self, index: usize) -> Option<Statement> {
        let slot = &mut self.var_list[index];
        let var = self.previous + index;

        if slot.data.is_temporary(var) {
            return None;
        }

        slot.read.clear();

        let placeholder = Expression::GetTemporary(GetTemporary { var });
        let value = std::mem::replace(&mut slot.data, placeholder);

        self.capacity = self.capacity.max(var + 1);

        Some(Statement::SetTemporary(SetTemporary { var, value }))
    }

    // Build the alignment needed for this stack to branch out to another
    // stack frame.
    //
    // `old` is the temporary index of the first of the top `par_result`
    // slots of this stack, `new` is `par_start`, and `length` is
    // `par_result`.
    pub fn get_br_alignment(&self, par_start: usize, par_result: usize) -> Align {
        let depth = self.var_list.len() + self.previous;

        Align {
            new: par_start,
            old: depth - par_result,
            length: par_result,
        }
    }
}