Refactor Return behavior

This commit is contained in:
Rerumu 2022-06-12 02:21:20 -04:00
parent b8e40fe740
commit 183db977f3
5 changed files with 120 additions and 81 deletions

View File

@ -5,8 +5,8 @@ use std::{
use parity_wasm::elements::ValueType;
use wasm_ast::node::{
Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, Return, SetGlobal,
SetLocal, SetTemporary, Statement, StoreAt, Terminator,
Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal,
SetTemporary, Statement, StoreAt, Terminator,
};
use super::manager::{
@ -77,21 +77,12 @@ impl Driver for BrTable {
}
}
impl Driver for Return {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "do return ")?;
self.list.as_slice().write(mng, w)?;
write!(w, "end ")
}
}
impl Driver for Terminator {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
match self {
Self::Unreachable => write!(w, "error(\"out of code bounds\")"),
Self::Br(s) => s.write(mng, w),
Self::BrTable(s) => s.write(mng, w),
Self::Return(s) => s.write(mng, w),
}
}
}
@ -301,6 +292,12 @@ impl Driver for FuncData {
mng.num_param = self.num_param;
self.code.write(mng, w)?;
if self.num_result != 0 {
write!(w, "return ")?;
write_ascending("reg", 0..self.num_result, w)?;
write!(w, " ")?;
}
write!(w, "end ")
}
}

View File

@ -5,8 +5,8 @@ use std::{
use parity_wasm::elements::ValueType;
use wasm_ast::node::{
Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, Return, SetGlobal,
SetLocal, SetTemporary, Statement, StoreAt, Terminator,
Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal,
SetTemporary, Statement, StoreAt, Terminator,
};
use super::manager::{
@ -61,21 +61,12 @@ impl Driver for BrTable {
}
}
impl Driver for Return {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "do return ")?;
self.list.as_slice().write(mng, w)?;
write!(w, "end ")
}
}
impl Driver for Terminator {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
match self {
Self::Unreachable => write!(w, "error(\"out of code bounds\")"),
Self::Br(s) => s.write(mng, w),
Self::BrTable(s) => s.write(mng, w),
Self::Return(s) => s.write(mng, w),
}
}
}
@ -306,6 +297,12 @@ impl Driver for FuncData {
mng.num_param = self.num_param;
self.code.write(mng, w)?;
if self.num_result != 0 {
write!(w, "return ")?;
write_ascending("reg", 0..self.num_result, w)?;
write!(w, " ")?;
}
write!(w, "end ")
}
}

View File

@ -6,7 +6,7 @@ use parity_wasm::elements::{
use crate::node::{
Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType,
Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType,
MemoryGrow, MemorySize, Return, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt,
MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt,
StoreType, Terminator, UnOp, UnOpType, Value,
};
@ -117,7 +117,9 @@ struct StatList {
last: Option<Terminator>,
num_result: usize,
num_param: usize,
num_stack: usize,
num_previous: usize,
is_else: bool,
}
@ -139,8 +141,9 @@ impl StatList {
self.num_stack = self.num_stack.max(self.stack.len());
}
fn leak_at(&mut self, var: usize) {
let old = self.stack.get_mut(var).unwrap();
fn leak_at(&mut self, index: usize) {
let old = self.stack.get_mut(index).unwrap();
let var = self.num_previous + index;
if old.is_temporary(var) {
return;
@ -285,6 +288,29 @@ impl StatList {
self.try_add_equal_zero(inst)
}
}
// Return values from a block by leaking the stack and then
// adjusting the start if necessary.
fn set_return_data(&mut self, par_previous: usize, par_result: usize) {
self.leak_all();
// If the start of our copy and parent frame are the same our results
// are already in the right registers.
let start = self.num_previous + self.stack.len() - par_result;
if start == par_previous {
return;
}
for i in 0..par_result {
let data = Statement::SetTemporary(SetTemporary {
var: par_previous + i,
value: Expression::GetTemporary(GetTemporary { var: start + i }),
});
self.code.push(data);
}
}
}
impl From<StatList> for Forward {
@ -333,6 +359,7 @@ impl<'a> Builder<'a> {
FuncData {
local_data: Vec::new(),
num_result: 1,
num_param: 0,
num_stack: data.num_stack,
code: data.into(),
@ -346,33 +373,58 @@ impl<'a> Builder<'a> {
FuncData {
local_data: func.locals().to_vec(),
num_result: arity.num_result,
num_param: arity.num_param,
num_stack: data.num_stack,
code: data.into(),
}
}
// FIXME: Sets up temporaries except when used in a weird way
fn start_block(&mut self, typ: BlockType) {
let mut old = std::mem::take(&mut self.target);
self.target.push_temporary(old.stack.len());
self.target.num_result = match typ {
BlockType::NoResult => 0,
BlockType::Value(_) => 1,
let (num_param, num_result) = match typ {
BlockType::NoResult => (0, 0),
BlockType::Value(_) => (0, 1),
BlockType::TypeIndex(i) => {
let id = i.try_into().unwrap();
let arity = self.type_info.arity_of(id);
self.type_info.arity_of(id).num_result
(arity.num_param, arity.num_result)
}
};
let mut old = std::mem::take(&mut self.target);
old.leak_all();
old.push_temporary(self.target.num_result);
self.target.stack = old.pop_len(num_param);
self.target.num_result = num_result;
self.target.num_param = num_param;
self.target.num_previous = old.num_previous + old.stack.len();
old.push_temporary(num_result);
self.pending.push(old);
}
fn start_else(&mut self) {
let num_result = self.target.num_result;
let num_param = self.target.num_param;
let num_previous = self.target.num_previous;
self.end_block();
let old = std::mem::take(&mut self.target);
self.pending.push(old);
self.target.num_result = num_result;
self.target.num_param = num_param;
self.target.num_previous = num_previous;
self.target.is_else = true;
self.target.push_temporary(num_result);
}
fn end_block(&mut self) {
let old = self.pending.pop().unwrap();
let now = std::mem::replace(&mut self.target, old);
@ -428,12 +480,29 @@ impl<'a> Builder<'a> {
self.target.code.push(data);
}
fn set_return(&mut self) {
let data = Terminator::Return(Return {
list: self.target.pop_len(self.num_result),
});
fn get_relative_block(&self, index: usize) -> Option<&StatList> {
if index == 0 {
Some(&self.target)
} else {
self.pending.get(self.pending.len() - index)
}
}
self.target.leak_all();
fn set_return_data(&mut self, target: usize) {
let (par_previous, par_result) = match self.get_relative_block(target) {
Some(v) => (v.num_previous, v.num_result),
None => (0, self.num_result),
};
self.target.set_return_data(par_previous, par_result);
}
fn set_br_to_block(&mut self, target: usize) {
self.nested_unreachable += 1;
let data = Terminator::Br(Br { target });
self.set_return_data(target);
self.target.last = Some(data);
}
@ -464,6 +533,8 @@ impl<'a> Builder<'a> {
match *inst {
Inst::Unreachable => {
self.nested_unreachable += 1;
self.target.leak_all();
self.target.last = Some(Terminator::Unreachable);
}
Inst::Nop => {}
@ -490,28 +561,17 @@ impl<'a> Builder<'a> {
self.pending.last_mut().unwrap().code.push(data);
}
Inst::Else => {
let num_result = self.target.num_result;
self.target.leak_all();
self.end_block();
self.start_block(BlockType::NoResult);
self.target.num_result = num_result;
self.target.is_else = true;
self.set_return_data(0);
self.start_else();
}
Inst::End => {
self.target.leak_all();
self.set_return_data(0);
self.end_block();
}
Inst::Br(v) => {
self.nested_unreachable += 1;
let target = v.try_into().unwrap();
let data = Terminator::Br(Br {
target: v.try_into().unwrap(),
});
self.target.leak_all();
self.target.last = Some(data);
self.set_br_to_block(target);
}
Inst::BrIf(v) => {
let data = Statement::BrIf(BrIf {
@ -520,23 +580,24 @@ impl<'a> Builder<'a> {
});
// FIXME: Does not push results unless true
// self.target.add_result_data();
self.target.code.push(data);
}
Inst::BrTable(ref v) => {
self.nested_unreachable += 1;
let default = v.default.try_into().unwrap();
let data = Terminator::BrTable(BrTable {
cond: self.target.pop_required(),
data: *v.clone(),
});
self.target.leak_all();
self.set_return_data(default);
self.target.last = Some(data);
}
Inst::Return => {
self.nested_unreachable += 1;
self.set_return();
let target = self.pending.len();
self.set_br_to_block(target);
}
Inst::Call(i) => {
self.add_call(i.try_into().unwrap());
@ -661,6 +722,7 @@ impl<'a> Builder<'a> {
fn build_stat_list(&mut self, list: &[Instruction], num_result: usize) -> StatList {
self.nested_unreachable = 0;
self.num_result = num_result;
self.target.num_result = num_result;
for inst in list.iter().take(list.len() - 1) {
if self.nested_unreachable == 0 {
@ -670,8 +732,8 @@ impl<'a> Builder<'a> {
}
}
if self.nested_unreachable == 0 && num_result != 0 {
self.set_return();
if self.nested_unreachable == 0 {
self.set_return_data(0);
}
std::mem::take(&mut self.target)

View File

@ -669,15 +669,10 @@ pub struct BrTable {
pub data: BrTableData,
}
pub struct Return {
pub list: Vec<Expression>,
}
pub enum Terminator {
Unreachable,
Br(Br),
BrTable(BrTable),
Return(Return),
}
#[derive(Default)]
@ -753,6 +748,7 @@ pub enum Statement {
pub struct FuncData {
pub local_data: Vec<Local>,
pub num_result: usize,
pub num_param: usize,
pub num_stack: usize,
pub code: Forward,

View File

@ -1,7 +1,7 @@
use crate::node::{
Backward, BinOp, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, Expression, Forward, FuncData,
GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Return, Select,
SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value,
GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Select, SetGlobal,
SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value,
};
pub trait Visitor {
@ -35,8 +35,6 @@ pub trait Visitor {
fn visit_br_table(&mut self, _: &BrTable) {}
fn visit_return(&mut self, _: &Return) {}
fn visit_terminator(&mut self, _: &Terminator) {}
fn visit_forward(&mut self, _: &Forward) {}
@ -182,23 +180,12 @@ impl<T: Visitor> Driver<T> for BrTable {
}
}
impl<T: Visitor> Driver<T> for Return {
fn accept(&self, visitor: &mut T) {
for v in &self.list {
v.accept(visitor);
}
visitor.visit_return(self);
}
}
impl<T: Visitor> Driver<T> for Terminator {
fn accept(&self, visitor: &mut T) {
match self {
Self::Unreachable => visitor.visit_unreachable(),
Self::Br(v) => v.accept(visitor),
Self::BrTable(v) => v.accept(visitor),
Self::Return(v) => v.accept(visitor),
}
visitor.visit_terminator(self);