diff --git a/codegen-luajit/src/backend/statement.rs b/codegen-luajit/src/backend/statement.rs index 00c4ca3..5ebe4d0 100644 --- a/codegen-luajit/src/backend/statement.rs +++ b/codegen-luajit/src/backend/statement.rs @@ -5,7 +5,7 @@ use std::{ use parity_wasm::elements::ValueType; use wasm_ast::node::{ - Backward, Br, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, + Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, Terminator, }; @@ -13,66 +13,40 @@ use super::manager::{ write_ascending, write_condition, write_separated, write_variable, Driver, Manager, }; -fn write_br_at(up: usize, mng: &Manager, w: &mut dyn Write) -> Result<()> { - let level = mng.label_list().iter().nth_back(up).unwrap(); - - write!(w, "goto continue_at_{level} ") -} - impl Driver for Br { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { - write_br_at(self.target, mng, w) - } -} + let level = *mng.label_list().iter().nth_back(self.target).unwrap(); -fn condense_jump_table(list: &[u32]) -> Vec<(usize, usize, u32)> { - let mut result = Vec::new(); - let mut index = 0; - - while index < list.len() { - let start = index; - - loop { - index += 1; - - // if end of list or next value is not equal, break - if index == list.len() || list[index - 1] != list[index] { - break; - } + if !self.align.is_aligned() { + write_ascending("reg", self.align.new_range(), w)?; + write!(w, " = ")?; + write_ascending("reg", self.align.old_range(), w)?; + write!(w, " ")?; } - result.push((start, index - 1, list[start])); + write!(w, "goto continue_at_{level} ") } - - result } impl Driver for BrTable { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { - let default = self.data.default.try_into().unwrap(); - - // Our condition should be pure so we probably don't need - // to emit it in this case. - if self.data.table.is_empty() { - return write_br_at(default, mng, w); - } - write!(w, "temp = ")?; self.cond.write(mng, w)?; - for (start, end, dest) in condense_jump_table(&self.data.table) { - if start == end { - write!(w, "if temp == {start} then ")?; - } else { - write!(w, "if temp >= {start} and temp <= {end} then ")?; - } + // Our condition should be pure so we probably don't need + // to emit it in this case. + if self.data.is_empty() { + return self.default.write(mng, w); + } - write_br_at(dest.try_into().unwrap(), mng, w)?; + for (case, dest) in self.data.iter().enumerate() { + write!(w, "if temp == {case} then ")?; + dest.write(mng, w)?; write!(w, "else")?; } write!(w, " ")?; - write_br_at(default, mng, w)?; + self.default.write(mng, w)?; write!(w, "end ") } } @@ -127,6 +101,16 @@ impl Driver for Backward { } } +impl Driver for BrIf { + fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { + write!(w, "if ")?; + write_condition(&self.cond, mng, w)?; + write!(w, "then ")?; + self.target.write(mng, w)?; + write!(w, "end ") + } +} + impl Driver for If { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { write!(w, "if ")?; @@ -220,6 +204,7 @@ impl Driver for Statement { match self { Self::Forward(s) => s.write(mng, w), Self::Backward(s) => s.write(mng, w), + Self::BrIf(s) => s.write(mng, w), Self::If(s) => s.write(mng, w), Self::Call(s) => s.write(mng, w), Self::CallIndirect(s) => s.write(mng, w), diff --git a/codegen-luau/src/backend/statement.rs b/codegen-luau/src/backend/statement.rs index bc0bb08..8f82e47 100644 --- a/codegen-luau/src/backend/statement.rs +++ b/codegen-luau/src/backend/statement.rs @@ -5,7 +5,7 @@ use std::{ use parity_wasm::elements::ValueType; use wasm_ast::node::{ - Backward, Br, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, + Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, Terminator, }; @@ -13,50 +13,53 @@ use super::manager::{ write_ascending, write_condition, write_separated, write_variable, Driver, Label, Manager, }; -fn write_br_at(up: usize, mng: &Manager, w: &mut dyn Write) -> Result<()> { - write!(w, "do ")?; - - if up == 0 { - if let Some(&Label::Backward) = mng.label_list().last() { - write!(w, "continue ")?; - } else { - write!(w, "break ")?; - } - } else { - let level = mng.label_list().len() - 1 - up; - - write!(w, "desired = {level} ")?; - write!(w, "break ")?; - } - - write!(w, "end ") -} - impl Driver for Br { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { - write_br_at(self.target, mng, w) + write!(w, "do ")?; + + if !self.align.is_aligned() { + write_ascending("reg", self.align.new_range(), w)?; + write!(w, " = ")?; + write_ascending("reg", self.align.old_range(), w)?; + write!(w, " ")?; + } + + if self.target == 0 { + if let Some(&Label::Backward) = mng.label_list().last() { + write!(w, "continue ")?; + } else { + write!(w, "break ")?; + } + } else { + let level = mng.label_list().len() - 1 - self.target; + + write!(w, "desired = {level} ")?; + write!(w, "break ")?; + } + + write!(w, "end ") } } impl Driver for BrTable { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { - write!(w, "do ")?; - write!(w, "local temp = {{")?; + write!(w, "temp = ")?; + self.cond.write(mng, w)?; - if !self.data.table.is_empty() { - write!(w, "[0] =")?; - - for d in self.data.table.iter() { - write!(w, "{d}, ")?; - } + // Our condition should be pure so we probably don't need + // to emit it in this case. + if self.data.is_empty() { + return self.default.write(mng, w); } - write!(w, "}} ")?; + for (case, dest) in self.data.iter().enumerate() { + write!(w, "if temp == {case} then ")?; + dest.write(mng, w)?; + write!(w, "else")?; + } - write!(w, "desired = temp[")?; - self.cond.write(mng, w)?; - write!(w, "] or {} ", self.data.default)?; - write!(w, "break ")?; + write!(w, " ")?; + self.default.write(mng, w)?; write!(w, "end ") } } @@ -133,6 +136,16 @@ impl Driver for Backward { } } +impl Driver for BrIf { + fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { + write!(w, "if ")?; + write_condition(&self.cond, mng, w)?; + write!(w, "then ")?; + self.target.write(mng, w)?; + write!(w, "end ") + } +} + impl Driver for If { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { write!(w, "while true do ")?; @@ -227,6 +240,7 @@ impl Driver for Statement { match self { Self::Forward(s) => s.write(mng, w), Self::Backward(s) => s.write(mng, w), + Self::BrIf(s) => s.write(mng, w), Self::If(s) => s.write(mng, w), Self::Call(s) => s.write(mng, w), Self::CallIndirect(s) => s.write(mng, w), diff --git a/wasm-ast/src/builder.rs b/wasm-ast/src/builder.rs index adf38f8..d323e65 100644 --- a/wasm-ast/src/builder.rs +++ b/wasm-ast/src/builder.rs @@ -4,10 +4,10 @@ use parity_wasm::elements::{ }; use crate::node::{ - Backward, BinOp, BinOpType, Br, BrTable, Call, CallIndirect, CmpOp, CmpOpType, Expression, - Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType, MemoryGrow, - MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, StoreType, - Terminator, UnOp, UnOpType, Value, + Align, Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType, + Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType, + MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, + StoreType, Terminator, UnOp, UnOpType, Value, }; macro_rules! leak_with_predicate { @@ -291,27 +291,21 @@ impl StatList { } } - // Return values from a block by leaking the stack and then - // adjusting the start if necessary. - fn set_return_data(&mut self, par_previous: usize, par_result: usize) { + // Return the alignment necessary for this block to branch out to a + // another given block + fn get_br_alignment(&self, par_start: usize, par_result: usize) -> Align { + let start = self.stack.len() + self.num_previous - par_result; + + Align { + new: par_start, + old: start, + length: par_result, + } + } + + fn set_terminator(&mut self, term: Terminator) { self.leak_all(); - - // If the start of our copy and parent frame are the same our results - // are already in the right registers. - let start = self.num_previous + self.stack.len() - par_result; - - if start == par_previous { - return; - } - - for i in 0..par_result { - let data = Statement::SetTemporary(SetTemporary { - var: par_previous + i, - value: Expression::GetTemporary(GetTemporary { var: start + i }), - }); - - self.code.push(data); - } + self.last = Some(term); } } @@ -443,6 +437,25 @@ impl<'a> Builder<'a> { } } + fn get_relative_block(&self, index: usize) -> Option<&StatList> { + if index == 0 { + Some(&self.target) + } else { + self.pending.get(self.pending.len() - index) + } + } + + fn get_br_terminator(&self, target: usize) -> Br { + let (par_start, par_result) = match self.get_relative_block(target) { + Some(v) => (v.num_previous, v.num_result), + None => (0, self.num_result), + }; + + let align = self.target.get_br_alignment(par_start, par_result); + + Br { target, align } + } + fn add_call(&mut self, func: usize) { let arity = self.type_info.rel_arity_of(func); let param_list = self.target.pop_len(arity.num_param); @@ -483,32 +496,6 @@ impl<'a> Builder<'a> { self.target.code.push(data); } - fn get_relative_block(&self, index: usize) -> Option<&StatList> { - if index == 0 { - Some(&self.target) - } else { - self.pending.get(self.pending.len() - index) - } - } - - fn set_return_data(&mut self, target: usize) { - let (par_previous, par_result) = match self.get_relative_block(target) { - Some(v) => (v.num_previous, v.num_result), - None => (0, self.num_result), - }; - - self.target.set_return_data(par_previous, par_result); - } - - fn set_br_to_block(&mut self, target: usize) { - self.nested_unreachable += 1; - - let data = Terminator::Br(Br { target }); - - self.set_return_data(target); - self.target.last = Some(data); - } - #[cold] fn drop_unreachable(&mut self, inst: &Instruction) { match inst { @@ -543,8 +530,7 @@ impl<'a> Builder<'a> { Inst::Unreachable => { self.nested_unreachable += 1; - self.target.leak_all(); - self.target.last = Some(Terminator::Unreachable); + self.target.set_terminator(Terminator::Unreachable); } Inst::Nop => {} Inst::Block(typ) => { @@ -567,48 +553,55 @@ impl<'a> Builder<'a> { self.start_block(typ, stat); } Inst::Else => { - self.set_return_data(0); + self.target.leak_all(); self.start_else(); } Inst::End => { - self.set_return_data(0); + self.target.leak_all(); self.end_block(); } Inst::Br(v) => { let target = v.try_into().unwrap(); + let term = Terminator::Br(self.get_br_terminator(target)); - self.set_br_to_block(target); + self.target.set_terminator(term); + self.nested_unreachable += 1; } Inst::BrIf(v) => { - let target: usize = v.try_into().unwrap(); - let stat = Statement::If(If { + let data = Statement::BrIf(BrIf { cond: self.target.pop_required(), - truthy: Forward::default(), - falsey: None, + target: self.get_br_terminator(v.try_into().unwrap()), }); - self.start_block(BlockType::NoResult, stat); - self.set_br_to_block(target + 1); - self.end_block(); - - self.nested_unreachable -= 1; + self.target.leak_all(); + self.target.code.push(data); } Inst::BrTable(ref v) => { - self.nested_unreachable += 1; + let cond = self.target.pop_required(); + let data = v + .table + .iter() + .copied() + .map(|v| self.get_br_terminator(v.try_into().unwrap())) + .collect(); - let default = v.default.try_into().unwrap(); - let data = Terminator::BrTable(BrTable { - cond: self.target.pop_required(), - data: *v.clone(), + let default = self.get_br_terminator(v.default.try_into().unwrap()); + + let term = Terminator::BrTable(BrTable { + cond, + data, + default, }); - self.set_return_data(default); - self.target.last = Some(data); + self.target.set_terminator(term); + self.nested_unreachable += 1; } Inst::Return => { let target = self.pending.len(); + let term = Terminator::Br(self.get_br_terminator(target)); - self.set_br_to_block(target); + self.target.set_terminator(term); + self.nested_unreachable += 1; } Inst::Call(i) => { self.add_call(i.try_into().unwrap()); @@ -744,7 +737,7 @@ impl<'a> Builder<'a> { } if self.nested_unreachable == 0 { - self.set_return_data(0); + self.target.leak_all(); } std::mem::take(&mut self.target) diff --git a/wasm-ast/src/node.rs b/wasm-ast/src/node.rs index 1df0cbd..047746e 100644 --- a/wasm-ast/src/node.rs +++ b/wasm-ast/src/node.rs @@ -1,6 +1,6 @@ use std::ops::Range; -use parity_wasm::elements::{BrTableData, Instruction, Local, SignExtInstruction}; +use parity_wasm::elements::{Instruction, Local, SignExtInstruction}; #[allow(non_camel_case_types)] #[derive(Clone, Copy)] @@ -660,13 +660,38 @@ impl Expression { } } +pub struct Align { + pub new: usize, + pub old: usize, + pub length: usize, +} + +impl Align { + #[must_use] + pub fn is_aligned(&self) -> bool { + self.length == 0 || self.new == self.old + } + + #[must_use] + pub fn new_range(&self) -> Range { + self.new..self.new + self.length + } + + #[must_use] + pub fn old_range(&self) -> Range { + self.old..self.old + self.length + } +} + pub struct Br { pub target: usize, + pub align: Align, } pub struct BrTable { pub cond: Expression, - pub data: BrTableData, + pub data: Vec
, + pub default: Br, } pub enum Terminator { @@ -687,6 +712,11 @@ pub struct Backward { pub last: Option, } +pub struct BrIf { + pub cond: Expression, + pub target: Br, +} + pub struct If { pub cond: Expression, pub truthy: Forward, @@ -731,6 +761,7 @@ pub struct StoreAt { pub enum Statement { Forward(Forward), Backward(Backward), + BrIf(BrIf), If(If), Call(Call), CallIndirect(CallIndirect), diff --git a/wasm-ast/src/visit.rs b/wasm-ast/src/visit.rs index 4a4d758..97c8ad5 100644 --- a/wasm-ast/src/visit.rs +++ b/wasm-ast/src/visit.rs @@ -1,5 +1,5 @@ use crate::node::{ - Backward, BinOp, Br, BrTable, Call, CallIndirect, CmpOp, Expression, Forward, FuncData, + Backward, BinOp, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value, }; @@ -41,6 +41,8 @@ pub trait Visitor { fn visit_backward(&mut self, _: &Backward) {} + fn visit_br_if(&mut self, _: &BrIf) {} + fn visit_if(&mut self, _: &If) {} fn visit_call(&mut self, _: &Call) {} @@ -218,6 +220,14 @@ impl Driver for Backward { } } +impl Driver for BrIf { + fn accept(&self, visitor: &mut T) { + self.cond.accept(visitor); + + visitor.visit_br_if(self); + } +} + impl Driver for If { fn accept(&self, visitor: &mut T) { self.cond.accept(visitor); @@ -291,6 +301,7 @@ impl Driver for Statement { match self { Self::Forward(v) => v.accept(visitor), Self::Backward(v) => v.accept(visitor), + Self::BrIf(v) => v.accept(visitor), Self::If(v) => v.accept(visitor), Self::Call(v) => v.accept(visitor), Self::CallIndirect(v) => v.accept(visitor),