diff --git a/wasm-ast/src/builder.rs b/wasm-ast/src/builder.rs index d323e65..3a9bf75 100644 --- a/wasm-ast/src/builder.rs +++ b/wasm-ast/src/builder.rs @@ -82,6 +82,24 @@ impl<'a> TypeInfo<'a> { self.arity_of(*adjusted) } + fn block_arity_of(&self, typ: BlockType) -> Arity { + match typ { + BlockType::NoResult => Arity { + num_param: 0, + num_result: 0, + }, + BlockType::Value(_) => Arity { + num_param: 0, + num_result: 1, + }, + BlockType::TypeIndex(i) => { + let id = i.try_into().unwrap(); + + self.arity_of(id) + } + } + } + fn func_of_import(import: &ImportEntry) -> Option { if let &External::Function(i) = import.external() { Some(i.try_into().unwrap()) @@ -110,17 +128,35 @@ impl<'a> TypeInfo<'a> { } } +enum BlockVariant { + Forward, + Backward, + If, + Else, +} + +enum BlockData { + Forward { num_result: usize }, + Backward { num_param: usize }, + If { num_result: usize, typ: BlockType }, + Else { num_result: usize }, +} + +impl Default for BlockData { + fn default() -> Self { + BlockData::Forward { num_result: 0 } + } +} + #[derive(Default)] struct StatList { stack: Vec, code: Vec, last: Option, - num_result: usize, - num_param: usize, + block_data: BlockData, num_stack: usize, num_previous: usize, - is_else: bool, } impl StatList { @@ -376,26 +412,29 @@ impl<'a> Builder<'a> { } } - fn start_block(&mut self, typ: BlockType, stat: Statement) { - let (num_param, num_result) = match typ { - BlockType::NoResult => (0, 0), - BlockType::Value(_) => (0, 1), - BlockType::TypeIndex(i) => { - let id = i.try_into().unwrap(); - let arity = self.type_info.arity_of(id); - - (arity.num_param, arity.num_result) - } - }; + fn start_block(&mut self, typ: BlockType, variant: BlockVariant) { + let Arity { + num_param, + num_result, + } = self.type_info.block_arity_of(typ); let mut old = std::mem::take(&mut self.target); old.leak_all(); - old.code.push(stat); + + self.target.block_data = match variant { + BlockVariant::Forward => BlockData::Forward { num_result }, + BlockVariant::Backward => BlockData::Backward { num_param }, + BlockVariant::If => BlockData::If { num_result, typ }, + BlockVariant::Else => { + old.pop_len(num_result); + old.push_temporary(num_param); + + BlockData::Else { num_result } + } + }; self.target.stack = old.pop_len(num_param); - self.target.num_result = num_result; - self.target.num_param = num_param; self.target.num_previous = old.num_previous + old.stack.len(); old.push_temporary(num_result); @@ -404,22 +443,15 @@ impl<'a> Builder<'a> { } fn start_else(&mut self) { - let mut temp = StatList { - num_result: self.target.num_result, - num_param: self.target.num_param, - num_stack: self.target.num_stack, - num_previous: self.target.num_previous, - is_else: true, - ..Default::default() + let typ = if let BlockData::If { typ, .. } = self.target.block_data { + typ + } else { + unreachable!() }; - temp.push_temporary(temp.num_param); - + self.target.leak_all(); self.end_block(); - - let old = std::mem::replace(&mut self.target, temp); - - self.pending.push(old); + self.start_block(typ, BlockVariant::Else); } fn end_block(&mut self) { @@ -428,13 +460,26 @@ impl<'a> Builder<'a> { self.target.num_stack = now.num_stack; - match self.target.code.last_mut().unwrap() { - Statement::Forward(data) => *data = now.into(), - Statement::Backward(data) => *data = now.into(), - Statement::If(data) if !now.is_else => data.truthy = now.into(), - Statement::If(data) if now.is_else => data.falsey = Some(now.into()), - _ => unreachable!(), - } + let stat = match now.block_data { + BlockData::Forward { .. } => Statement::Forward(now.into()), + BlockData::Backward { .. } => Statement::Backward(now.into()), + BlockData::If { .. } => Statement::If(If { + cond: self.target.pop_required(), + truthy: now.into(), + falsey: None, + }), + BlockData::Else { .. } => { + if let Statement::If(v) = self.target.code.last_mut().unwrap() { + v.falsey = Some(now.into()); + } else { + unreachable!() + } + + return; + } + }; + + self.target.code.push(stat); } fn get_relative_block(&self, index: usize) -> Option<&StatList> { @@ -447,7 +492,15 @@ impl<'a> Builder<'a> { fn get_br_terminator(&self, target: usize) -> Br { let (par_start, par_result) = match self.get_relative_block(target) { - Some(v) => (v.num_previous, v.num_result), + Some(v) => ( + v.num_previous, + match v.block_data { + BlockData::Forward { num_result } => num_result, + BlockData::Backward { num_param } => num_param, + BlockData::If { num_result, .. } => num_result, + BlockData::Else { num_result } => num_result, + }, + ), None => (0, self.num_result), }; @@ -534,26 +587,18 @@ impl<'a> Builder<'a> { } Inst::Nop => {} Inst::Block(typ) => { - let stat = Statement::Forward(Forward::default()); - - self.start_block(typ, stat); + self.start_block(typ, BlockVariant::Forward); } Inst::Loop(typ) => { - let stat = Statement::Backward(Backward::default()); - - self.start_block(typ, stat); + self.start_block(typ, BlockVariant::Backward); } Inst::If(typ) => { - let stat = Statement::If(If { - cond: self.target.pop_required(), - truthy: Forward::default(), - falsey: None, - }); + let cond = self.target.pop_required(); - self.start_block(typ, stat); + self.start_block(typ, BlockVariant::If); + self.pending.last_mut().unwrap().stack.push(cond); } Inst::Else => { - self.target.leak_all(); self.start_else(); } Inst::End => { @@ -726,7 +771,7 @@ impl<'a> Builder<'a> { fn build_stat_list(&mut self, list: &[Instruction], num_result: usize) -> StatList { self.nested_unreachable = 0; self.num_result = num_result; - self.target.num_result = num_result; + self.target.block_data = BlockData::Forward { num_result }; for inst in list.iter().take(list.len() - 1) { if self.nested_unreachable == 0 {