diff --git a/codegen/luajit/src/bin/wasm2luajit.rs b/codegen/luajit/src/bin/wasm2luajit.rs index e08684f..34ccd23 100644 --- a/codegen/luajit/src/bin/wasm2luajit.rs +++ b/codegen/luajit/src/bin/wasm2luajit.rs @@ -18,7 +18,7 @@ fn do_runtime(lock: &mut dyn Write) -> Result<()> { fn main() -> Result<()> { let data = load_arg_source()?; - let wasm = Module::from_data(&data); + let wasm = Module::try_from_data(&data).unwrap(); let lock = &mut std::io::stdout().lock(); diff --git a/codegen/luau/src/bin/wasm2luau.rs b/codegen/luau/src/bin/wasm2luau.rs index 76419aa..4bfa8fb 100644 --- a/codegen/luau/src/bin/wasm2luau.rs +++ b/codegen/luau/src/bin/wasm2luau.rs @@ -22,7 +22,7 @@ fn do_runtime(lock: &mut dyn Write) -> Result<()> { fn main() -> Result<()> { let data = load_arg_source()?; - let wasm = Module::from_data(&data); + let wasm = Module::try_from_data(&data).unwrap(); let lock = &mut std::io::stdout().lock(); diff --git a/dev-test/fuzz_targets/luajit_translate.rs b/dev-test/fuzz_targets/luajit_translate.rs index 8bc201f..7994e77 100644 --- a/dev-test/fuzz_targets/luajit_translate.rs +++ b/dev-test/fuzz_targets/luajit_translate.rs @@ -5,7 +5,7 @@ use wasm_smith::Module as RngModule; libfuzzer_sys::fuzz_target!(|module: RngModule| { let data = module.to_bytes(); - let wasm = Module::from_data(&data); + let wasm = Module::try_from_data(&data).unwrap(); let sink = &mut std::io::sink(); diff --git a/dev-test/fuzz_targets/luau_translate.rs b/dev-test/fuzz_targets/luau_translate.rs index a136725..92fc8f6 100644 --- a/dev-test/fuzz_targets/luau_translate.rs +++ b/dev-test/fuzz_targets/luau_translate.rs @@ -5,7 +5,7 @@ use wasm_smith::Module as RngModule; libfuzzer_sys::fuzz_target!(|module: RngModule| { let data = module.to_bytes(); - let wasm = Module::from_data(&data); + let wasm = Module::try_from_data(&data).unwrap(); let sink = &mut std::io::sink(); diff --git a/dev-test/tests/luajit_translate.rs b/dev-test/tests/luajit_translate.rs index 4283924..c2d0b83 100644 --- a/dev-test/tests/luajit_translate.rs +++ b/dev-test/tests/luajit_translate.rs @@ -90,7 +90,7 @@ impl Target for LuaJIT { Wat::Module(ast) => ast.encode().unwrap(), Wat::Component(_) => unimplemented!(), }; - let data = Module::from_data(&bytes); + let data = Module::try_from_data(&bytes).unwrap(); writeln!(w, "assert_trap((function()")?; codegen_luajit::from_module_untyped(&data, w)?; diff --git a/dev-test/tests/luau_translate.rs b/dev-test/tests/luau_translate.rs index 91f0b26..e090e88 100644 --- a/dev-test/tests/luau_translate.rs +++ b/dev-test/tests/luau_translate.rs @@ -103,7 +103,7 @@ impl Target for Luau { Wat::Module(ast) => ast.encode().unwrap(), Wat::Component(_) => unimplemented!(), }; - let data = Module::from_data(&bytes); + let data = Module::try_from_data(&bytes).unwrap(); writeln!(w, "assert_trap((function()")?; codegen_luau::from_module_untyped(&data, w)?; diff --git a/dev-test/tests/target.rs b/dev-test/tests/target.rs index 9843c5f..3faa63a 100644 --- a/dev-test/tests/target.rs +++ b/dev-test/tests/target.rs @@ -81,7 +81,7 @@ pub trait Target: Sized { let mut ast = try_into_ast_module(data).expect("Must be a module"); let bytes = ast.encode().unwrap(); - let data = AstModule::from_data(&bytes); + let data = AstModule::try_from_data(&bytes).unwrap(); let name = ast.id.as_ref().map(Id::name); Self::write_module(&data, name, w)?; diff --git a/wasm-ast/src/module.rs b/wasm-ast/src/module.rs index 0a7a6be..0aa1c8b 100644 --- a/wasm-ast/src/module.rs +++ b/wasm-ast/src/module.rs @@ -2,17 +2,9 @@ use std::collections::HashMap; use wasmparser::{ BlockType, Data, Element, Export, ExternalKind, FunctionBody, Global, Import, MemoryType, Name, - NameSectionReader, Parser, Payload, TableType, Type, TypeRef, + NameSectionReader, Parser, Payload, Result, TableType, Type, TypeRef, }; -macro_rules! to_section { - ($data:ident) => {{ - let read: Result<_, _> = $data.into_iter().collect(); - - read.unwrap() - }}; -} - #[derive(PartialEq, Eq, Clone, Copy)] pub enum External { Func, @@ -46,6 +38,13 @@ impl From for External { } } +pub(crate) fn read_checked(reader: I) -> Result> +where + I: IntoIterator>, +{ + reader.into_iter().collect() +} + pub struct Module<'a> { type_section: Vec, import_section: Vec>, @@ -64,8 +63,10 @@ pub struct Module<'a> { } impl<'a> Module<'a> { - #[must_use] - pub fn from_data(data: &'a [u8]) -> Self { + /// # Errors + /// + /// Returns a `BinaryReaderError` if any module section is malformed. + pub fn try_from_data(data: &'a [u8]) -> Result { let mut temp = Module { type_section: Vec::new(), import_section: Vec::new(), @@ -81,22 +82,22 @@ impl<'a> Module<'a> { start_section: None, }; - temp.load_data(data); - temp + temp.load_data(data)?; + Ok(temp) } - fn load_data(&mut self, data: &'a [u8]) { - for payload in Parser::new(0).parse_all(data).flatten() { - match payload { - Payload::TypeSection(v) => self.type_section = to_section!(v), - Payload::ImportSection(v) => self.import_section = to_section!(v), - Payload::FunctionSection(v) => self.func_section = to_section!(v), - Payload::TableSection(v) => self.table_section = to_section!(v), - Payload::MemorySection(v) => self.memory_section = to_section!(v), - Payload::GlobalSection(v) => self.global_section = to_section!(v), - Payload::ExportSection(v) => self.export_section = to_section!(v), - Payload::ElementSection(v) => self.element_section = to_section!(v), - Payload::DataSection(v) => self.data_section = to_section!(v), + fn load_data(&mut self, data: &'a [u8]) -> Result<()> { + for payload in Parser::new(0).parse_all(data) { + match payload? { + Payload::TypeSection(v) => self.type_section = read_checked(v)?, + Payload::ImportSection(v) => self.import_section = read_checked(v)?, + Payload::FunctionSection(v) => self.func_section = read_checked(v)?, + Payload::TableSection(v) => self.table_section = read_checked(v)?, + Payload::MemorySection(v) => self.memory_section = read_checked(v)?, + Payload::GlobalSection(v) => self.global_section = read_checked(v)?, + Payload::ExportSection(v) => self.export_section = read_checked(v)?, + Payload::ElementSection(v) => self.element_section = read_checked(v)?, + Payload::DataSection(v) => self.data_section = read_checked(v)?, Payload::CodeSectionEntry(v) => { self.code_section.push(v); } @@ -104,9 +105,9 @@ impl<'a> Module<'a> { self.start_section = Some(func); } Payload::CustomSection(v) if v.name() == "name" => { - for name in NameSectionReader::new(v.data(), v.data_offset()).unwrap() { - if let Name::Function(map) = name.unwrap() { - let mut iter = map.get_map().unwrap(); + for name in NameSectionReader::new(v.data(), v.data_offset())? { + if let Name::Function(map) = name? { + let mut iter = map.get_map()?; while let Ok(elem) = iter.read() { self.name_section.insert(elem.index, elem.name); @@ -117,6 +118,8 @@ impl<'a> Module<'a> { _ => {} } } + + Ok(()) } #[must_use]