From 183db977f310cdce34de0508fe849ac0737c2581 Mon Sep 17 00:00:00 2001 From: Rerumu Date: Sun, 12 Jun 2022 02:21:20 -0400 Subject: [PATCH] Refactor `Return` behavior --- codegen-luajit/src/backend/statement.rs | 19 ++-- codegen-luau/src/backend/statement.rs | 19 ++-- wasm-ast/src/builder.rs | 140 +++++++++++++++++------- wasm-ast/src/node.rs | 6 +- wasm-ast/src/visit.rs | 17 +-- 5 files changed, 120 insertions(+), 81 deletions(-) diff --git a/codegen-luajit/src/backend/statement.rs b/codegen-luajit/src/backend/statement.rs index 5b2bafa..aa46d7d 100644 --- a/codegen-luajit/src/backend/statement.rs +++ b/codegen-luajit/src/backend/statement.rs @@ -5,8 +5,8 @@ use std::{ use parity_wasm::elements::ValueType; use wasm_ast::node::{ - Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, Return, SetGlobal, - SetLocal, SetTemporary, Statement, StoreAt, Terminator, + Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, + SetTemporary, Statement, StoreAt, Terminator, }; use super::manager::{ @@ -77,21 +77,12 @@ impl Driver for BrTable { } } -impl Driver for Return { - fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { - write!(w, "do return ")?; - self.list.as_slice().write(mng, w)?; - write!(w, "end ") - } -} - impl Driver for Terminator { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { match self { Self::Unreachable => write!(w, "error(\"out of code bounds\")"), Self::Br(s) => s.write(mng, w), Self::BrTable(s) => s.write(mng, w), - Self::Return(s) => s.write(mng, w), } } } @@ -301,6 +292,12 @@ impl Driver for FuncData { mng.num_param = self.num_param; self.code.write(mng, w)?; + if self.num_result != 0 { + write!(w, "return ")?; + write_ascending("reg", 0..self.num_result, w)?; + write!(w, " ")?; + } + write!(w, "end ") } } diff --git a/codegen-luau/src/backend/statement.rs b/codegen-luau/src/backend/statement.rs index 7eb285e..6f8e3c4 100644 --- a/codegen-luau/src/backend/statement.rs +++ b/codegen-luau/src/backend/statement.rs @@ -5,8 +5,8 @@ use std::{ use parity_wasm::elements::ValueType; use wasm_ast::node::{ - Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, Return, SetGlobal, - SetLocal, SetTemporary, Statement, StoreAt, Terminator, + Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, + SetTemporary, Statement, StoreAt, Terminator, }; use super::manager::{ @@ -61,21 +61,12 @@ impl Driver for BrTable { } } -impl Driver for Return { - fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { - write!(w, "do return ")?; - self.list.as_slice().write(mng, w)?; - write!(w, "end ") - } -} - impl Driver for Terminator { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { match self { Self::Unreachable => write!(w, "error(\"out of code bounds\")"), Self::Br(s) => s.write(mng, w), Self::BrTable(s) => s.write(mng, w), - Self::Return(s) => s.write(mng, w), } } } @@ -306,6 +297,12 @@ impl Driver for FuncData { mng.num_param = self.num_param; self.code.write(mng, w)?; + if self.num_result != 0 { + write!(w, "return ")?; + write_ascending("reg", 0..self.num_result, w)?; + write!(w, " ")?; + } + write!(w, "end ") } } diff --git a/wasm-ast/src/builder.rs b/wasm-ast/src/builder.rs index 9533e44..7e9a118 100644 --- a/wasm-ast/src/builder.rs +++ b/wasm-ast/src/builder.rs @@ -6,7 +6,7 @@ use parity_wasm::elements::{ use crate::node::{ Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType, Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType, - MemoryGrow, MemorySize, Return, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, + MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, StoreType, Terminator, UnOp, UnOpType, Value, }; @@ -117,7 +117,9 @@ struct StatList { last: Option, num_result: usize, + num_param: usize, num_stack: usize, + num_previous: usize, is_else: bool, } @@ -139,8 +141,9 @@ impl StatList { self.num_stack = self.num_stack.max(self.stack.len()); } - fn leak_at(&mut self, var: usize) { - let old = self.stack.get_mut(var).unwrap(); + fn leak_at(&mut self, index: usize) { + let old = self.stack.get_mut(index).unwrap(); + let var = self.num_previous + index; if old.is_temporary(var) { return; @@ -285,6 +288,29 @@ impl StatList { self.try_add_equal_zero(inst) } } + + // Return values from a block by leaking the stack and then + // adjusting the start if necessary. + fn set_return_data(&mut self, par_previous: usize, par_result: usize) { + self.leak_all(); + + // If the start of our copy and parent frame are the same our results + // are already in the right registers. + let start = self.num_previous + self.stack.len() - par_result; + + if start == par_previous { + return; + } + + for i in 0..par_result { + let data = Statement::SetTemporary(SetTemporary { + var: par_previous + i, + value: Expression::GetTemporary(GetTemporary { var: start + i }), + }); + + self.code.push(data); + } + } } impl From for Forward { @@ -333,6 +359,7 @@ impl<'a> Builder<'a> { FuncData { local_data: Vec::new(), + num_result: 1, num_param: 0, num_stack: data.num_stack, code: data.into(), @@ -346,33 +373,58 @@ impl<'a> Builder<'a> { FuncData { local_data: func.locals().to_vec(), + num_result: arity.num_result, num_param: arity.num_param, num_stack: data.num_stack, code: data.into(), } } - // FIXME: Sets up temporaries except when used in a weird way fn start_block(&mut self, typ: BlockType) { - let mut old = std::mem::take(&mut self.target); - - self.target.push_temporary(old.stack.len()); - self.target.num_result = match typ { - BlockType::NoResult => 0, - BlockType::Value(_) => 1, + let (num_param, num_result) = match typ { + BlockType::NoResult => (0, 0), + BlockType::Value(_) => (0, 1), BlockType::TypeIndex(i) => { let id = i.try_into().unwrap(); + let arity = self.type_info.arity_of(id); - self.type_info.arity_of(id).num_result + (arity.num_param, arity.num_result) } }; + let mut old = std::mem::take(&mut self.target); + old.leak_all(); - old.push_temporary(self.target.num_result); + + self.target.stack = old.pop_len(num_param); + self.target.num_result = num_result; + self.target.num_param = num_param; + self.target.num_previous = old.num_previous + old.stack.len(); + + old.push_temporary(num_result); self.pending.push(old); } + fn start_else(&mut self) { + let num_result = self.target.num_result; + let num_param = self.target.num_param; + let num_previous = self.target.num_previous; + + self.end_block(); + + let old = std::mem::take(&mut self.target); + + self.pending.push(old); + + self.target.num_result = num_result; + self.target.num_param = num_param; + self.target.num_previous = num_previous; + self.target.is_else = true; + + self.target.push_temporary(num_result); + } + fn end_block(&mut self) { let old = self.pending.pop().unwrap(); let now = std::mem::replace(&mut self.target, old); @@ -428,12 +480,29 @@ impl<'a> Builder<'a> { self.target.code.push(data); } - fn set_return(&mut self) { - let data = Terminator::Return(Return { - list: self.target.pop_len(self.num_result), - }); + fn get_relative_block(&self, index: usize) -> Option<&StatList> { + if index == 0 { + Some(&self.target) + } else { + self.pending.get(self.pending.len() - index) + } + } - self.target.leak_all(); + fn set_return_data(&mut self, target: usize) { + let (par_previous, par_result) = match self.get_relative_block(target) { + Some(v) => (v.num_previous, v.num_result), + None => (0, self.num_result), + }; + + self.target.set_return_data(par_previous, par_result); + } + + fn set_br_to_block(&mut self, target: usize) { + self.nested_unreachable += 1; + + let data = Terminator::Br(Br { target }); + + self.set_return_data(target); self.target.last = Some(data); } @@ -464,6 +533,8 @@ impl<'a> Builder<'a> { match *inst { Inst::Unreachable => { self.nested_unreachable += 1; + + self.target.leak_all(); self.target.last = Some(Terminator::Unreachable); } Inst::Nop => {} @@ -490,28 +561,17 @@ impl<'a> Builder<'a> { self.pending.last_mut().unwrap().code.push(data); } Inst::Else => { - let num_result = self.target.num_result; - - self.target.leak_all(); - self.end_block(); - self.start_block(BlockType::NoResult); - - self.target.num_result = num_result; - self.target.is_else = true; + self.set_return_data(0); + self.start_else(); } Inst::End => { - self.target.leak_all(); + self.set_return_data(0); self.end_block(); } Inst::Br(v) => { - self.nested_unreachable += 1; + let target = v.try_into().unwrap(); - let data = Terminator::Br(Br { - target: v.try_into().unwrap(), - }); - - self.target.leak_all(); - self.target.last = Some(data); + self.set_br_to_block(target); } Inst::BrIf(v) => { let data = Statement::BrIf(BrIf { @@ -520,23 +580,24 @@ impl<'a> Builder<'a> { }); // FIXME: Does not push results unless true - // self.target.add_result_data(); self.target.code.push(data); } Inst::BrTable(ref v) => { self.nested_unreachable += 1; + let default = v.default.try_into().unwrap(); let data = Terminator::BrTable(BrTable { cond: self.target.pop_required(), data: *v.clone(), }); - self.target.leak_all(); + self.set_return_data(default); self.target.last = Some(data); } Inst::Return => { - self.nested_unreachable += 1; - self.set_return(); + let target = self.pending.len(); + + self.set_br_to_block(target); } Inst::Call(i) => { self.add_call(i.try_into().unwrap()); @@ -661,6 +722,7 @@ impl<'a> Builder<'a> { fn build_stat_list(&mut self, list: &[Instruction], num_result: usize) -> StatList { self.nested_unreachable = 0; self.num_result = num_result; + self.target.num_result = num_result; for inst in list.iter().take(list.len() - 1) { if self.nested_unreachable == 0 { @@ -670,8 +732,8 @@ impl<'a> Builder<'a> { } } - if self.nested_unreachable == 0 && num_result != 0 { - self.set_return(); + if self.nested_unreachable == 0 { + self.set_return_data(0); } std::mem::take(&mut self.target) diff --git a/wasm-ast/src/node.rs b/wasm-ast/src/node.rs index 6ac9a88..9837481 100644 --- a/wasm-ast/src/node.rs +++ b/wasm-ast/src/node.rs @@ -669,15 +669,10 @@ pub struct BrTable { pub data: BrTableData, } -pub struct Return { - pub list: Vec, -} - pub enum Terminator { Unreachable, Br(Br), BrTable(BrTable), - Return(Return), } #[derive(Default)] @@ -753,6 +748,7 @@ pub enum Statement { pub struct FuncData { pub local_data: Vec, + pub num_result: usize, pub num_param: usize, pub num_stack: usize, pub code: Forward, diff --git a/wasm-ast/src/visit.rs b/wasm-ast/src/visit.rs index a3826cf..1edda7f 100644 --- a/wasm-ast/src/visit.rs +++ b/wasm-ast/src/visit.rs @@ -1,7 +1,7 @@ use crate::node::{ Backward, BinOp, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, Expression, Forward, FuncData, - GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Return, Select, - SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value, + GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Select, SetGlobal, + SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value, }; pub trait Visitor { @@ -35,8 +35,6 @@ pub trait Visitor { fn visit_br_table(&mut self, _: &BrTable) {} - fn visit_return(&mut self, _: &Return) {} - fn visit_terminator(&mut self, _: &Terminator) {} fn visit_forward(&mut self, _: &Forward) {} @@ -182,23 +180,12 @@ impl Driver for BrTable { } } -impl Driver for Return { - fn accept(&self, visitor: &mut T) { - for v in &self.list { - v.accept(visitor); - } - - visitor.visit_return(self); - } -} - impl Driver for Terminator { fn accept(&self, visitor: &mut T) { match self { Self::Unreachable => visitor.visit_unreachable(), Self::Br(v) => v.accept(visitor), Self::BrTable(v) => v.accept(visitor), - Self::Return(v) => v.accept(visitor), } visitor.visit_terminator(self);