From 233aee2c5e0b9ba1afcdf79e5ceb5171fb07f8da Mon Sep 17 00:00:00 2001 From: Rerumu Date: Fri, 22 Apr 2022 03:54:45 -0400 Subject: [PATCH] Simplify type and local handling --- codegen-luajit/src/gen.rs | 28 +++---- codegen-luau/src/gen.rs | 28 +++---- wasm-ast/src/builder.rs | 165 ++++++++++++++++---------------------- wasm-ast/src/node.rs | 6 +- wasm-ast/src/visit.rs | 2 +- 5 files changed, 100 insertions(+), 129 deletions(-) diff --git a/codegen-luajit/src/gen.rs b/codegen-luajit/src/gen.rs index f987675..a1f8016 100644 --- a/codegen-luajit/src/gen.rs +++ b/codegen-luajit/src/gen.rs @@ -5,7 +5,7 @@ use parity_wasm::elements::{ }; use wasm_ast::{ - builder::{Arities, Builder}, + builder::{Builder, TypeInfo}, node::{ AnyBinOp, AnyCmpOp, AnyLoad, AnyStore, AnyUnOp, Backward, Br, BrIf, BrTable, Call, CallIndirect, Else, Expression, Forward, Function, GetGlobal, GetLocal, If, Memorize, @@ -121,19 +121,17 @@ fn write_result_list(range: Range, w: Writer) -> Result<()> { } fn write_variable_list(func: &Function, w: Writer) -> Result<()> { - if !func.local_list.is_empty() { - let num_local = func.local_list.len().try_into().unwrap(); - + for data in &func.local_data { write!(w, "local ")?; - write_in_order("loc", num_local, w)?; + write_in_order("loc", data.count(), w)?; write!(w, " = ")?; - for (i, t) in func.local_list.iter().enumerate() { + for i in 0..data.count() { if i != 0 { write!(w, ", ")?; } - write!(w, "ZERO_{} ", t)?; + write!(w, "ZERO_{} ", data.value_type())?; } } @@ -597,7 +595,7 @@ impl Driver for Function { write!(w, "local temp ")?; v.num_param = self.num_param; - self.body.visit(v, w)?; + self.code.visit(v, w)?; write!(w, "end ") } @@ -605,16 +603,16 @@ impl Driver for Function { pub struct Generator<'a> { wasm: &'a Module, - arity: Arities, + type_info: TypeInfo<'a>, } static RUNTIME: &str = include_str!("../runtime/runtime.lua"); impl<'a> Transpiler<'a> for Generator<'a> { fn new(wasm: &'a Module) -> Self { - let arity = Arities::new(wasm); + let type_info = TypeInfo::from_module(wasm); - Self { wasm, arity } + Self { wasm, type_info } } fn runtime(w: Writer) -> Result<()> { @@ -850,10 +848,10 @@ impl<'a> Generator<'a> { // FIXME: Make `pub` only for fuzzing. #[must_use] pub fn build_func_list(&self) -> Vec { - let range = 0..self.arity.len_in(); + let list = self.wasm.code_section().unwrap().bodies(); + let iter = list.iter().enumerate(); - range - .map(|i| Builder::new(self.wasm, &self.arity).consume(i)) + iter.map(|f| Builder::new(&self.type_info).consume(f.0, f.1)) .collect() } @@ -863,7 +861,7 @@ impl<'a> Generator<'a> { /// # Panics /// If the number of functions overflows 32 bits. pub fn gen_func_list(&self, func_list: &[Function], w: Writer) -> Result<()> { - let o = self.arity.len_ex(); + let o = self.type_info.len_ex(); func_list.iter().enumerate().try_for_each(|(i, v)| { write_func_name(self.wasm, i.try_into().unwrap(), o.try_into().unwrap(), w)?; diff --git a/codegen-luau/src/gen.rs b/codegen-luau/src/gen.rs index 5a9b4f8..20b5ecb 100644 --- a/codegen-luau/src/gen.rs +++ b/codegen-luau/src/gen.rs @@ -5,7 +5,7 @@ use parity_wasm::elements::{ }; use wasm_ast::{ - builder::{Arities, Builder}, + builder::{Builder, TypeInfo}, node::{ AnyBinOp, AnyCmpOp, AnyLoad, AnyStore, AnyUnOp, Backward, Br, BrIf, BrTable, Call, CallIndirect, Else, Expression, Forward, Function, GetGlobal, GetLocal, If, Memorize, @@ -120,19 +120,17 @@ fn write_result_list(range: Range, w: Writer) -> Result<()> { } fn write_variable_list(func: &Function, w: Writer) -> Result<()> { - if !func.local_list.is_empty() { - let num_local = func.local_list.len().try_into().unwrap(); - + for data in &func.local_data { write!(w, "local ")?; - write_in_order("loc", num_local, w)?; + write_in_order("loc", data.count(), w)?; write!(w, " = ")?; - for (i, t) in func.local_list.iter().enumerate() { + for i in 0..data.count() { if i != 0 { write!(w, ", ")?; } - write!(w, "ZERO_{} ", t)?; + write!(w, "ZERO_{} ", data.value_type())?; } } @@ -593,7 +591,7 @@ impl Driver for Function { write_variable_list(self, w)?; v.num_param = self.num_param; - self.body.visit(v, w)?; + self.code.visit(v, w)?; write!(w, "end ") } @@ -601,16 +599,16 @@ impl Driver for Function { pub struct Generator<'a> { wasm: &'a Module, - arity: Arities, + type_info: TypeInfo<'a>, } static RUNTIME: &str = include_str!("../runtime/runtime.lua"); impl<'a> Transpiler<'a> for Generator<'a> { fn new(wasm: &'a Module) -> Self { - let arity = Arities::new(wasm); + let type_info = TypeInfo::from_module(wasm); - Self { wasm, arity } + Self { wasm, type_info } } fn runtime(w: Writer) -> Result<()> { @@ -843,15 +841,15 @@ impl<'a> Generator<'a> { } fn build_func_list(&self) -> Vec { - let range = 0..self.arity.len_in(); + let list = self.wasm.code_section().unwrap().bodies(); + let iter = list.iter().enumerate(); - range - .map(|i| Builder::new(self.wasm, &self.arity).consume(i)) + iter.map(|f| Builder::new(&self.type_info).consume(f.0, f.1)) .collect() } fn gen_func_list(&self, func_list: &[Function], w: Writer) -> Result<()> { - let o = self.arity.len_ex(); + let o = self.type_info.len_ex(); func_list.iter().enumerate().try_for_each(|(i, v)| { write_func_name(self.wasm, i.try_into().unwrap(), o.try_into().unwrap(), w)?; diff --git a/wasm-ast/src/builder.rs b/wasm-ast/src/builder.rs index 524a8c2..95600fd 100644 --- a/wasm-ast/src/builder.rs +++ b/wasm-ast/src/builder.rs @@ -1,6 +1,6 @@ use parity_wasm::elements::{ - BlockType, External, FuncBody, FunctionType, ImportEntry, Instruction, Local, Module, Type, - ValueType, + BlockType, External, Func, FuncBody, FunctionSection, FunctionType, ImportEntry, ImportSection, + Instruction, Module, Type, TypeSection, }; use crate::node::{ @@ -25,77 +25,80 @@ impl Arity { num_result, } } - - fn from_index(types: &[Type], index: u32) -> Self { - let Type::Function(typ) = &types[index as usize]; - - Self::from_type(typ) - } - - fn new_arity_ext(types: &[Type], import: &ImportEntry) -> Option { - if let External::Function(i) = import.external() { - Some(Arity::from_index(types, *i)) - } else { - None - } - } - - fn new_in_list(wasm: &Module) -> Vec { - let (types, funcs) = match (wasm.type_section(), wasm.function_section()) { - (Some(t), Some(f)) => (t.types(), f.entries()), - _ => return Vec::new(), - }; - - funcs - .iter() - .map(|i| Self::from_index(types, i.type_ref())) - .collect() - } - - fn new_ex_list(wasm: &Module) -> Vec { - let (types, imports) = match (wasm.type_section(), wasm.import_section()) { - (Some(t), Some(i)) => (t.types(), i.entries()), - _ => return Vec::new(), - }; - - imports - .iter() - .filter_map(|i| Self::new_arity_ext(types, i)) - .collect() - } } -pub struct Arities { - ex_arity: Vec, - in_arity: Vec, +pub struct TypeInfo<'a> { + data: &'a [Type], + func_ex: Vec, + func_in: Vec, } -impl Arities { +impl<'a> TypeInfo<'a> { #[must_use] - pub fn new(parent: &Module) -> Self { + pub fn from_module(parent: &'a Module) -> Self { + let data = parent + .type_section() + .map_or([].as_ref(), TypeSection::types); + + let func_ex = Self::new_ex_list(parent); + let func_in = Self::new_in_list(parent); + Self { - ex_arity: Arity::new_ex_list(parent), - in_arity: Arity::new_in_list(parent), + data, + func_ex, + func_in, } } #[must_use] pub fn len_in(&self) -> usize { - self.in_arity.len() + self.func_in.len() } #[must_use] pub fn len_ex(&self) -> usize { - self.ex_arity.len() + self.func_ex.len() } - fn arity_of(&self, index: usize) -> &Arity { - let offset = self.ex_arity.len(); + fn raw_arity_of(&self, index: u32) -> Arity { + let Type::Function(typ) = &self.data[index as usize]; - self.ex_arity - .get(index) - .or_else(|| self.in_arity.get(index - offset)) - .unwrap() + Arity::from_type(typ) + } + + fn arity_of(&self, index: usize) -> Arity { + let adjusted = self + .func_ex + .iter() + .chain(self.func_in.iter()) + .nth(index) + .unwrap(); + + self.raw_arity_of(*adjusted) + } + + fn func_of_import(import: &ImportEntry) -> Option { + if let &External::Function(i) = import.external() { + Some(i) + } else { + None + } + } + + fn new_ex_list(wasm: &Module) -> Vec { + let list = wasm + .import_section() + .map_or([].as_ref(), ImportSection::entries); + + list.iter().filter_map(Self::func_of_import).collect() + } + + fn new_in_list(wasm: &Module) -> Vec { + let list = wasm + .function_section() + .map_or([].as_ref(), FunctionSection::entries); + + list.iter().map(Func::type_ref).collect() } } @@ -259,8 +262,7 @@ impl Stacked { pub struct Builder<'a> { // target state - wasm: &'a Module, - other: &'a Arities, + type_info: &'a TypeInfo<'a>, num_result: u32, // translation state @@ -278,67 +280,40 @@ fn is_dead_precursor(inst: &Instruction) -> bool { ) } -fn flat_local_list(local: Local) -> impl Iterator { - std::iter::repeat(local.value_type()).take(local.count().try_into().unwrap()) -} - -fn load_local_list(func: &FuncBody) -> Vec { - func.locals() - .iter() - .copied() - .flat_map(flat_local_list) - .collect() -} - -fn load_func_at(wasm: &Module, index: usize) -> &FuncBody { - &wasm.code_section().unwrap().bodies()[index] -} - impl<'a> Builder<'a> { #[must_use] - pub fn new(wasm: &'a Module, other: &'a Arities) -> Builder<'a> { + pub fn new(info: &'a TypeInfo) -> Builder<'a> { Builder { - wasm, - other, + type_info: info, num_result: 0, data: Stacked::new(), } } #[must_use] - pub fn consume(mut self, index: usize) -> Function { - let func = load_func_at(self.wasm, index); - let arity = &self.other.in_arity[index]; - - let local_list = load_local_list(func); - let num_param = arity.num_param; + pub fn consume(mut self, index: usize, func: &'a FuncBody) -> Function { + let arity = &self.type_info.arity_of(self.type_info.len_ex() + index); self.num_result = arity.num_result; - let body = self.new_forward(&mut func.code().elements()); + let code = self.new_forward(&mut func.code().elements()); let num_stack = self.data.last_stack.try_into().unwrap(); Function { - local_list, - num_param, + local_data: func.locals().to_vec(), + num_param: arity.num_param, num_stack, - body, + code, } } - fn get_type_of(&self, index: u32) -> Arity { - let types = self.wasm.type_section().unwrap().types(); - - Arity::from_index(types, index) - } - fn push_block_result(&mut self, typ: BlockType) { let num = match typ { BlockType::NoResult => { return; } BlockType::Value(_) => 1, - BlockType::TypeIndex(i) => self.get_type_of(i).num_result, + BlockType::TypeIndex(i) => self.type_info.raw_arity_of(i).num_result, }; self.data.push_recall(num); @@ -354,7 +329,7 @@ impl<'a> Builder<'a> { } fn gen_call(&mut self, func: u32, stat: &mut Vec) { - let arity = self.other.arity_of(func as usize); + let arity = self.type_info.arity_of(func as usize); let param_list = self.data.pop_many(arity.num_param as usize); let first = u32::try_from(self.data.stack.len()).unwrap(); @@ -371,7 +346,7 @@ impl<'a> Builder<'a> { } fn gen_call_indirect(&mut self, typ: u32, table: u8, stat: &mut Vec) { - let arity = self.get_type_of(typ); + let arity = self.type_info.raw_arity_of(typ); let index = self.data.pop(); let param_list = self.data.pop_many(arity.num_param as usize); diff --git a/wasm-ast/src/node.rs b/wasm-ast/src/node.rs index a703d3b..ecf0eb5 100644 --- a/wasm-ast/src/node.rs +++ b/wasm-ast/src/node.rs @@ -1,6 +1,6 @@ use std::ops::Range; -use parity_wasm::elements::{BrTableData, ValueType}; +use parity_wasm::elements::{BrTableData, Local}; use std::convert::TryFrom; @@ -767,8 +767,8 @@ pub enum Statement { } pub struct Function { - pub local_list: Vec, + pub local_data: Vec, pub num_param: u32, pub num_stack: u32, - pub body: Forward, + pub code: Forward, } diff --git a/wasm-ast/src/visit.rs b/wasm-ast/src/visit.rs index 463f48b..d1ddca2 100644 --- a/wasm-ast/src/visit.rs +++ b/wasm-ast/src/visit.rs @@ -324,6 +324,6 @@ impl Driver for Statement { impl Driver for Function { fn accept(&self, visitor: &mut T) { - self.body.accept(visitor); + self.code.accept(visitor); } }