Add module loading validity checks

This commit is contained in:
Rerumu 2022-08-22 13:23:25 -04:00
parent e40023e8e6
commit 2c913e86ed
8 changed files with 38 additions and 35 deletions

View File

@ -18,7 +18,7 @@ fn do_runtime(lock: &mut dyn Write) -> Result<()> {
fn main() -> Result<()> { fn main() -> Result<()> {
let data = load_arg_source()?; let data = load_arg_source()?;
let wasm = Module::from_data(&data); let wasm = Module::try_from_data(&data).unwrap();
let lock = &mut std::io::stdout().lock(); let lock = &mut std::io::stdout().lock();

View File

@ -22,7 +22,7 @@ fn do_runtime(lock: &mut dyn Write) -> Result<()> {
fn main() -> Result<()> { fn main() -> Result<()> {
let data = load_arg_source()?; let data = load_arg_source()?;
let wasm = Module::from_data(&data); let wasm = Module::try_from_data(&data).unwrap();
let lock = &mut std::io::stdout().lock(); let lock = &mut std::io::stdout().lock();

View File

@ -5,7 +5,7 @@ use wasm_smith::Module as RngModule;
libfuzzer_sys::fuzz_target!(|module: RngModule| { libfuzzer_sys::fuzz_target!(|module: RngModule| {
let data = module.to_bytes(); let data = module.to_bytes();
let wasm = Module::from_data(&data); let wasm = Module::try_from_data(&data).unwrap();
let sink = &mut std::io::sink(); let sink = &mut std::io::sink();

View File

@ -5,7 +5,7 @@ use wasm_smith::Module as RngModule;
libfuzzer_sys::fuzz_target!(|module: RngModule| { libfuzzer_sys::fuzz_target!(|module: RngModule| {
let data = module.to_bytes(); let data = module.to_bytes();
let wasm = Module::from_data(&data); let wasm = Module::try_from_data(&data).unwrap();
let sink = &mut std::io::sink(); let sink = &mut std::io::sink();

View File

@ -90,7 +90,7 @@ impl Target for LuaJIT {
Wat::Module(ast) => ast.encode().unwrap(), Wat::Module(ast) => ast.encode().unwrap(),
Wat::Component(_) => unimplemented!(), Wat::Component(_) => unimplemented!(),
}; };
let data = Module::from_data(&bytes); let data = Module::try_from_data(&bytes).unwrap();
writeln!(w, "assert_trap((function()")?; writeln!(w, "assert_trap((function()")?;
codegen_luajit::from_module_untyped(&data, w)?; codegen_luajit::from_module_untyped(&data, w)?;

View File

@ -103,7 +103,7 @@ impl Target for Luau {
Wat::Module(ast) => ast.encode().unwrap(), Wat::Module(ast) => ast.encode().unwrap(),
Wat::Component(_) => unimplemented!(), Wat::Component(_) => unimplemented!(),
}; };
let data = Module::from_data(&bytes); let data = Module::try_from_data(&bytes).unwrap();
writeln!(w, "assert_trap((function()")?; writeln!(w, "assert_trap((function()")?;
codegen_luau::from_module_untyped(&data, w)?; codegen_luau::from_module_untyped(&data, w)?;

View File

@ -81,7 +81,7 @@ pub trait Target: Sized {
let mut ast = try_into_ast_module(data).expect("Must be a module"); let mut ast = try_into_ast_module(data).expect("Must be a module");
let bytes = ast.encode().unwrap(); let bytes = ast.encode().unwrap();
let data = AstModule::from_data(&bytes); let data = AstModule::try_from_data(&bytes).unwrap();
let name = ast.id.as_ref().map(Id::name); let name = ast.id.as_ref().map(Id::name);
Self::write_module(&data, name, w)?; Self::write_module(&data, name, w)?;

View File

@ -2,17 +2,9 @@ use std::collections::HashMap;
use wasmparser::{ use wasmparser::{
BlockType, Data, Element, Export, ExternalKind, FunctionBody, Global, Import, MemoryType, Name, BlockType, Data, Element, Export, ExternalKind, FunctionBody, Global, Import, MemoryType, Name,
NameSectionReader, Parser, Payload, TableType, Type, TypeRef, NameSectionReader, Parser, Payload, Result, TableType, Type, TypeRef,
}; };
macro_rules! to_section {
($data:ident) => {{
let read: Result<_, _> = $data.into_iter().collect();
read.unwrap()
}};
}
#[derive(PartialEq, Eq, Clone, Copy)] #[derive(PartialEq, Eq, Clone, Copy)]
pub enum External { pub enum External {
Func, Func,
@ -46,6 +38,13 @@ impl From<ExternalKind> for External {
} }
} }
pub(crate) fn read_checked<T, I>(reader: I) -> Result<Vec<T>>
where
I: IntoIterator<Item = Result<T>>,
{
reader.into_iter().collect()
}
pub struct Module<'a> { pub struct Module<'a> {
type_section: Vec<Type>, type_section: Vec<Type>,
import_section: Vec<Import<'a>>, import_section: Vec<Import<'a>>,
@ -64,8 +63,10 @@ pub struct Module<'a> {
} }
impl<'a> Module<'a> { impl<'a> Module<'a> {
#[must_use] /// # Errors
pub fn from_data(data: &'a [u8]) -> Self { ///
/// Returns a `BinaryReaderError` if any module section is malformed.
pub fn try_from_data(data: &'a [u8]) -> Result<Self> {
let mut temp = Module { let mut temp = Module {
type_section: Vec::new(), type_section: Vec::new(),
import_section: Vec::new(), import_section: Vec::new(),
@ -81,22 +82,22 @@ impl<'a> Module<'a> {
start_section: None, start_section: None,
}; };
temp.load_data(data); temp.load_data(data)?;
temp Ok(temp)
} }
fn load_data(&mut self, data: &'a [u8]) { fn load_data(&mut self, data: &'a [u8]) -> Result<()> {
for payload in Parser::new(0).parse_all(data).flatten() { for payload in Parser::new(0).parse_all(data) {
match payload { match payload? {
Payload::TypeSection(v) => self.type_section = to_section!(v), Payload::TypeSection(v) => self.type_section = read_checked(v)?,
Payload::ImportSection(v) => self.import_section = to_section!(v), Payload::ImportSection(v) => self.import_section = read_checked(v)?,
Payload::FunctionSection(v) => self.func_section = to_section!(v), Payload::FunctionSection(v) => self.func_section = read_checked(v)?,
Payload::TableSection(v) => self.table_section = to_section!(v), Payload::TableSection(v) => self.table_section = read_checked(v)?,
Payload::MemorySection(v) => self.memory_section = to_section!(v), Payload::MemorySection(v) => self.memory_section = read_checked(v)?,
Payload::GlobalSection(v) => self.global_section = to_section!(v), Payload::GlobalSection(v) => self.global_section = read_checked(v)?,
Payload::ExportSection(v) => self.export_section = to_section!(v), Payload::ExportSection(v) => self.export_section = read_checked(v)?,
Payload::ElementSection(v) => self.element_section = to_section!(v), Payload::ElementSection(v) => self.element_section = read_checked(v)?,
Payload::DataSection(v) => self.data_section = to_section!(v), Payload::DataSection(v) => self.data_section = read_checked(v)?,
Payload::CodeSectionEntry(v) => { Payload::CodeSectionEntry(v) => {
self.code_section.push(v); self.code_section.push(v);
} }
@ -104,9 +105,9 @@ impl<'a> Module<'a> {
self.start_section = Some(func); self.start_section = Some(func);
} }
Payload::CustomSection(v) if v.name() == "name" => { Payload::CustomSection(v) if v.name() == "name" => {
for name in NameSectionReader::new(v.data(), v.data_offset()).unwrap() { for name in NameSectionReader::new(v.data(), v.data_offset())? {
if let Name::Function(map) = name.unwrap() { if let Name::Function(map) = name? {
let mut iter = map.get_map().unwrap(); let mut iter = map.get_map()?;
while let Ok(elem) = iter.read() { while let Ok(elem) = iter.read() {
self.name_section.insert(elem.index, elem.name); self.name_section.insert(elem.index, elem.name);
@ -117,6 +118,8 @@ impl<'a> Module<'a> {
_ => {} _ => {}
} }
} }
Ok(())
} }
#[must_use] #[must_use]