Fix alignment when branching

This commit is contained in:
Rerumu 2022-06-13 22:22:08 -04:00
parent 20b888bfa0
commit a6cf4fdf07
5 changed files with 187 additions and 153 deletions

View File

@ -5,7 +5,7 @@ use std::{
use parity_wasm::elements::ValueType; use parity_wasm::elements::ValueType;
use wasm_ast::node::{ use wasm_ast::node::{
Backward, Br, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal,
SetTemporary, Statement, StoreAt, Terminator, SetTemporary, Statement, StoreAt, Terminator,
}; };
@ -13,66 +13,40 @@ use super::manager::{
write_ascending, write_condition, write_separated, write_variable, Driver, Manager, write_ascending, write_condition, write_separated, write_variable, Driver, Manager,
}; };
fn write_br_at(up: usize, mng: &Manager, w: &mut dyn Write) -> Result<()> {
let level = mng.label_list().iter().nth_back(up).unwrap();
write!(w, "goto continue_at_{level} ")
}
impl Driver for Br { impl Driver for Br {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write_br_at(self.target, mng, w) let level = *mng.label_list().iter().nth_back(self.target).unwrap();
}
}
fn condense_jump_table(list: &[u32]) -> Vec<(usize, usize, u32)> { if !self.align.is_aligned() {
let mut result = Vec::new(); write_ascending("reg", self.align.new_range(), w)?;
let mut index = 0; write!(w, " = ")?;
write_ascending("reg", self.align.old_range(), w)?;
while index < list.len() { write!(w, " ")?;
let start = index;
loop {
index += 1;
// if end of list or next value is not equal, break
if index == list.len() || list[index - 1] != list[index] {
break;
}
} }
result.push((start, index - 1, list[start])); write!(w, "goto continue_at_{level} ")
} }
result
} }
impl Driver for BrTable { impl Driver for BrTable {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
let default = self.data.default.try_into().unwrap();
// Our condition should be pure so we probably don't need
// to emit it in this case.
if self.data.table.is_empty() {
return write_br_at(default, mng, w);
}
write!(w, "temp = ")?; write!(w, "temp = ")?;
self.cond.write(mng, w)?; self.cond.write(mng, w)?;
for (start, end, dest) in condense_jump_table(&self.data.table) { // Our condition should be pure so we probably don't need
if start == end { // to emit it in this case.
write!(w, "if temp == {start} then ")?; if self.data.is_empty() {
} else { return self.default.write(mng, w);
write!(w, "if temp >= {start} and temp <= {end} then ")?;
} }
write_br_at(dest.try_into().unwrap(), mng, w)?; for (case, dest) in self.data.iter().enumerate() {
write!(w, "if temp == {case} then ")?;
dest.write(mng, w)?;
write!(w, "else")?; write!(w, "else")?;
} }
write!(w, " ")?; write!(w, " ")?;
write_br_at(default, mng, w)?; self.default.write(mng, w)?;
write!(w, "end ") write!(w, "end ")
} }
} }
@ -127,6 +101,16 @@ impl Driver for Backward {
} }
} }
impl Driver for BrIf {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "if ")?;
write_condition(&self.cond, mng, w)?;
write!(w, "then ")?;
self.target.write(mng, w)?;
write!(w, "end ")
}
}
impl Driver for If { impl Driver for If {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "if ")?; write!(w, "if ")?;
@ -220,6 +204,7 @@ impl Driver for Statement {
match self { match self {
Self::Forward(s) => s.write(mng, w), Self::Forward(s) => s.write(mng, w),
Self::Backward(s) => s.write(mng, w), Self::Backward(s) => s.write(mng, w),
Self::BrIf(s) => s.write(mng, w),
Self::If(s) => s.write(mng, w), Self::If(s) => s.write(mng, w),
Self::Call(s) => s.write(mng, w), Self::Call(s) => s.write(mng, w),
Self::CallIndirect(s) => s.write(mng, w), Self::CallIndirect(s) => s.write(mng, w),

View File

@ -5,7 +5,7 @@ use std::{
use parity_wasm::elements::ValueType; use parity_wasm::elements::ValueType;
use wasm_ast::node::{ use wasm_ast::node::{
Backward, Br, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal, Backward, Br, BrIf, BrTable, Call, CallIndirect, Forward, FuncData, If, SetGlobal, SetLocal,
SetTemporary, Statement, StoreAt, Terminator, SetTemporary, Statement, StoreAt, Terminator,
}; };
@ -13,50 +13,53 @@ use super::manager::{
write_ascending, write_condition, write_separated, write_variable, Driver, Label, Manager, write_ascending, write_condition, write_separated, write_variable, Driver, Label, Manager,
}; };
fn write_br_at(up: usize, mng: &Manager, w: &mut dyn Write) -> Result<()> { impl Driver for Br {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "do ")?; write!(w, "do ")?;
if up == 0 { if !self.align.is_aligned() {
write_ascending("reg", self.align.new_range(), w)?;
write!(w, " = ")?;
write_ascending("reg", self.align.old_range(), w)?;
write!(w, " ")?;
}
if self.target == 0 {
if let Some(&Label::Backward) = mng.label_list().last() { if let Some(&Label::Backward) = mng.label_list().last() {
write!(w, "continue ")?; write!(w, "continue ")?;
} else { } else {
write!(w, "break ")?; write!(w, "break ")?;
} }
} else { } else {
let level = mng.label_list().len() - 1 - up; let level = mng.label_list().len() - 1 - self.target;
write!(w, "desired = {level} ")?; write!(w, "desired = {level} ")?;
write!(w, "break ")?; write!(w, "break ")?;
} }
write!(w, "end ") write!(w, "end ")
}
impl Driver for Br {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write_br_at(self.target, mng, w)
} }
} }
impl Driver for BrTable { impl Driver for BrTable {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "do ")?; write!(w, "temp = ")?;
write!(w, "local temp = {{")?;
if !self.data.table.is_empty() {
write!(w, "[0] =")?;
for d in self.data.table.iter() {
write!(w, "{d}, ")?;
}
}
write!(w, "}} ")?;
write!(w, "desired = temp[")?;
self.cond.write(mng, w)?; self.cond.write(mng, w)?;
write!(w, "] or {} ", self.data.default)?;
write!(w, "break ")?; // Our condition should be pure so we probably don't need
// to emit it in this case.
if self.data.is_empty() {
return self.default.write(mng, w);
}
for (case, dest) in self.data.iter().enumerate() {
write!(w, "if temp == {case} then ")?;
dest.write(mng, w)?;
write!(w, "else")?;
}
write!(w, " ")?;
self.default.write(mng, w)?;
write!(w, "end ") write!(w, "end ")
} }
} }
@ -133,6 +136,16 @@ impl Driver for Backward {
} }
} }
impl Driver for BrIf {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "if ")?;
write_condition(&self.cond, mng, w)?;
write!(w, "then ")?;
self.target.write(mng, w)?;
write!(w, "end ")
}
}
impl Driver for If { impl Driver for If {
fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> { fn write(&self, mng: &mut Manager, w: &mut dyn Write) -> Result<()> {
write!(w, "while true do ")?; write!(w, "while true do ")?;
@ -227,6 +240,7 @@ impl Driver for Statement {
match self { match self {
Self::Forward(s) => s.write(mng, w), Self::Forward(s) => s.write(mng, w),
Self::Backward(s) => s.write(mng, w), Self::Backward(s) => s.write(mng, w),
Self::BrIf(s) => s.write(mng, w),
Self::If(s) => s.write(mng, w), Self::If(s) => s.write(mng, w),
Self::Call(s) => s.write(mng, w), Self::Call(s) => s.write(mng, w),
Self::CallIndirect(s) => s.write(mng, w), Self::CallIndirect(s) => s.write(mng, w),

View File

@ -4,10 +4,10 @@ use parity_wasm::elements::{
}; };
use crate::node::{ use crate::node::{
Backward, BinOp, BinOpType, Br, BrTable, Call, CallIndirect, CmpOp, CmpOpType, Expression, Align, Backward, BinOp, BinOpType, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, CmpOpType,
Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType, MemoryGrow, Expression, Forward, FuncData, GetGlobal, GetLocal, GetTemporary, If, LoadAt, LoadType,
MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt, StoreType, MemoryGrow, MemorySize, Select, SetGlobal, SetLocal, SetTemporary, Statement, StoreAt,
Terminator, UnOp, UnOpType, Value, StoreType, Terminator, UnOp, UnOpType, Value,
}; };
macro_rules! leak_with_predicate { macro_rules! leak_with_predicate {
@ -291,27 +291,21 @@ impl StatList {
} }
} }
// Return values from a block by leaking the stack and then // Return the alignment necessary for this block to branch out to a
// adjusting the start if necessary. // another given block
fn set_return_data(&mut self, par_previous: usize, par_result: usize) { fn get_br_alignment(&self, par_start: usize, par_result: usize) -> Align {
let start = self.stack.len() + self.num_previous - par_result;
Align {
new: par_start,
old: start,
length: par_result,
}
}
fn set_terminator(&mut self, term: Terminator) {
self.leak_all(); self.leak_all();
self.last = Some(term);
// If the start of our copy and parent frame are the same our results
// are already in the right registers.
let start = self.num_previous + self.stack.len() - par_result;
if start == par_previous {
return;
}
for i in 0..par_result {
let data = Statement::SetTemporary(SetTemporary {
var: par_previous + i,
value: Expression::GetTemporary(GetTemporary { var: start + i }),
});
self.code.push(data);
}
} }
} }
@ -443,6 +437,25 @@ impl<'a> Builder<'a> {
} }
} }
fn get_relative_block(&self, index: usize) -> Option<&StatList> {
if index == 0 {
Some(&self.target)
} else {
self.pending.get(self.pending.len() - index)
}
}
fn get_br_terminator(&self, target: usize) -> Br {
let (par_start, par_result) = match self.get_relative_block(target) {
Some(v) => (v.num_previous, v.num_result),
None => (0, self.num_result),
};
let align = self.target.get_br_alignment(par_start, par_result);
Br { target, align }
}
fn add_call(&mut self, func: usize) { fn add_call(&mut self, func: usize) {
let arity = self.type_info.rel_arity_of(func); let arity = self.type_info.rel_arity_of(func);
let param_list = self.target.pop_len(arity.num_param); let param_list = self.target.pop_len(arity.num_param);
@ -483,32 +496,6 @@ impl<'a> Builder<'a> {
self.target.code.push(data); self.target.code.push(data);
} }
fn get_relative_block(&self, index: usize) -> Option<&StatList> {
if index == 0 {
Some(&self.target)
} else {
self.pending.get(self.pending.len() - index)
}
}
fn set_return_data(&mut self, target: usize) {
let (par_previous, par_result) = match self.get_relative_block(target) {
Some(v) => (v.num_previous, v.num_result),
None => (0, self.num_result),
};
self.target.set_return_data(par_previous, par_result);
}
fn set_br_to_block(&mut self, target: usize) {
self.nested_unreachable += 1;
let data = Terminator::Br(Br { target });
self.set_return_data(target);
self.target.last = Some(data);
}
#[cold] #[cold]
fn drop_unreachable(&mut self, inst: &Instruction) { fn drop_unreachable(&mut self, inst: &Instruction) {
match inst { match inst {
@ -543,8 +530,7 @@ impl<'a> Builder<'a> {
Inst::Unreachable => { Inst::Unreachable => {
self.nested_unreachable += 1; self.nested_unreachable += 1;
self.target.leak_all(); self.target.set_terminator(Terminator::Unreachable);
self.target.last = Some(Terminator::Unreachable);
} }
Inst::Nop => {} Inst::Nop => {}
Inst::Block(typ) => { Inst::Block(typ) => {
@ -567,48 +553,55 @@ impl<'a> Builder<'a> {
self.start_block(typ, stat); self.start_block(typ, stat);
} }
Inst::Else => { Inst::Else => {
self.set_return_data(0); self.target.leak_all();
self.start_else(); self.start_else();
} }
Inst::End => { Inst::End => {
self.set_return_data(0); self.target.leak_all();
self.end_block(); self.end_block();
} }
Inst::Br(v) => { Inst::Br(v) => {
let target = v.try_into().unwrap(); let target = v.try_into().unwrap();
let term = Terminator::Br(self.get_br_terminator(target));
self.set_br_to_block(target); self.target.set_terminator(term);
self.nested_unreachable += 1;
} }
Inst::BrIf(v) => { Inst::BrIf(v) => {
let target: usize = v.try_into().unwrap(); let data = Statement::BrIf(BrIf {
let stat = Statement::If(If {
cond: self.target.pop_required(), cond: self.target.pop_required(),
truthy: Forward::default(), target: self.get_br_terminator(v.try_into().unwrap()),
falsey: None,
}); });
self.start_block(BlockType::NoResult, stat); self.target.leak_all();
self.set_br_to_block(target + 1); self.target.code.push(data);
self.end_block();
self.nested_unreachable -= 1;
} }
Inst::BrTable(ref v) => { Inst::BrTable(ref v) => {
self.nested_unreachable += 1; let cond = self.target.pop_required();
let data = v
.table
.iter()
.copied()
.map(|v| self.get_br_terminator(v.try_into().unwrap()))
.collect();
let default = v.default.try_into().unwrap(); let default = self.get_br_terminator(v.default.try_into().unwrap());
let data = Terminator::BrTable(BrTable {
cond: self.target.pop_required(), let term = Terminator::BrTable(BrTable {
data: *v.clone(), cond,
data,
default,
}); });
self.set_return_data(default); self.target.set_terminator(term);
self.target.last = Some(data); self.nested_unreachable += 1;
} }
Inst::Return => { Inst::Return => {
let target = self.pending.len(); let target = self.pending.len();
let term = Terminator::Br(self.get_br_terminator(target));
self.set_br_to_block(target); self.target.set_terminator(term);
self.nested_unreachable += 1;
} }
Inst::Call(i) => { Inst::Call(i) => {
self.add_call(i.try_into().unwrap()); self.add_call(i.try_into().unwrap());
@ -744,7 +737,7 @@ impl<'a> Builder<'a> {
} }
if self.nested_unreachable == 0 { if self.nested_unreachable == 0 {
self.set_return_data(0); self.target.leak_all();
} }
std::mem::take(&mut self.target) std::mem::take(&mut self.target)

View File

@ -1,6 +1,6 @@
use std::ops::Range; use std::ops::Range;
use parity_wasm::elements::{BrTableData, Instruction, Local, SignExtInstruction}; use parity_wasm::elements::{Instruction, Local, SignExtInstruction};
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
@ -660,13 +660,38 @@ impl Expression {
} }
} }
pub struct Align {
pub new: usize,
pub old: usize,
pub length: usize,
}
impl Align {
#[must_use]
pub fn is_aligned(&self) -> bool {
self.length == 0 || self.new == self.old
}
#[must_use]
pub fn new_range(&self) -> Range<usize> {
self.new..self.new + self.length
}
#[must_use]
pub fn old_range(&self) -> Range<usize> {
self.old..self.old + self.length
}
}
pub struct Br { pub struct Br {
pub target: usize, pub target: usize,
pub align: Align,
} }
pub struct BrTable { pub struct BrTable {
pub cond: Expression, pub cond: Expression,
pub data: BrTableData, pub data: Vec<Br>,
pub default: Br,
} }
pub enum Terminator { pub enum Terminator {
@ -687,6 +712,11 @@ pub struct Backward {
pub last: Option<Terminator>, pub last: Option<Terminator>,
} }
pub struct BrIf {
pub cond: Expression,
pub target: Br,
}
pub struct If { pub struct If {
pub cond: Expression, pub cond: Expression,
pub truthy: Forward, pub truthy: Forward,
@ -731,6 +761,7 @@ pub struct StoreAt {
pub enum Statement { pub enum Statement {
Forward(Forward), Forward(Forward),
Backward(Backward), Backward(Backward),
BrIf(BrIf),
If(If), If(If),
Call(Call), Call(Call),
CallIndirect(CallIndirect), CallIndirect(CallIndirect),

View File

@ -1,5 +1,5 @@
use crate::node::{ use crate::node::{
Backward, BinOp, Br, BrTable, Call, CallIndirect, CmpOp, Expression, Forward, FuncData, Backward, BinOp, Br, BrIf, BrTable, Call, CallIndirect, CmpOp, Expression, Forward, FuncData,
GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Select, SetGlobal, GetGlobal, GetLocal, GetTemporary, If, LoadAt, MemoryGrow, MemorySize, Select, SetGlobal,
SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value, SetLocal, SetTemporary, Statement, StoreAt, Terminator, UnOp, Value,
}; };
@ -41,6 +41,8 @@ pub trait Visitor {
fn visit_backward(&mut self, _: &Backward) {} fn visit_backward(&mut self, _: &Backward) {}
fn visit_br_if(&mut self, _: &BrIf) {}
fn visit_if(&mut self, _: &If) {} fn visit_if(&mut self, _: &If) {}
fn visit_call(&mut self, _: &Call) {} fn visit_call(&mut self, _: &Call) {}
@ -218,6 +220,14 @@ impl<T: Visitor> Driver<T> for Backward {
} }
} }
impl<T: Visitor> Driver<T> for BrIf {
fn accept(&self, visitor: &mut T) {
self.cond.accept(visitor);
visitor.visit_br_if(self);
}
}
impl<T: Visitor> Driver<T> for If { impl<T: Visitor> Driver<T> for If {
fn accept(&self, visitor: &mut T) { fn accept(&self, visitor: &mut T) {
self.cond.accept(visitor); self.cond.accept(visitor);
@ -291,6 +301,7 @@ impl<T: Visitor> Driver<T> for Statement {
match self { match self {
Self::Forward(v) => v.accept(visitor), Self::Forward(v) => v.accept(visitor),
Self::Backward(v) => v.accept(visitor), Self::Backward(v) => v.accept(visitor),
Self::BrIf(v) => v.accept(visitor),
Self::If(v) => v.accept(visitor), Self::If(v) => v.accept(visitor),
Self::Call(v) => v.accept(visitor), Self::Call(v) => v.accept(visitor),
Self::CallIndirect(v) => v.accept(visitor), Self::CallIndirect(v) => v.accept(visitor),