Add module loading validity checks

This commit is contained in:
Rerumu 2022-08-22 13:23:25 -04:00
parent e40023e8e6
commit 2c913e86ed
8 changed files with 38 additions and 35 deletions

View File

@ -18,7 +18,7 @@ fn do_runtime(lock: &mut dyn Write) -> Result<()> {
fn main() -> Result<()> {
let data = load_arg_source()?;
let wasm = Module::from_data(&data);
let wasm = Module::try_from_data(&data).unwrap();
let lock = &mut std::io::stdout().lock();

View File

@ -22,7 +22,7 @@ fn do_runtime(lock: &mut dyn Write) -> Result<()> {
fn main() -> Result<()> {
let data = load_arg_source()?;
let wasm = Module::from_data(&data);
let wasm = Module::try_from_data(&data).unwrap();
let lock = &mut std::io::stdout().lock();

View File

@ -5,7 +5,7 @@ use wasm_smith::Module as RngModule;
libfuzzer_sys::fuzz_target!(|module: RngModule| {
let data = module.to_bytes();
let wasm = Module::from_data(&data);
let wasm = Module::try_from_data(&data).unwrap();
let sink = &mut std::io::sink();

View File

@ -5,7 +5,7 @@ use wasm_smith::Module as RngModule;
libfuzzer_sys::fuzz_target!(|module: RngModule| {
let data = module.to_bytes();
let wasm = Module::from_data(&data);
let wasm = Module::try_from_data(&data).unwrap();
let sink = &mut std::io::sink();

View File

@ -90,7 +90,7 @@ impl Target for LuaJIT {
Wat::Module(ast) => ast.encode().unwrap(),
Wat::Component(_) => unimplemented!(),
};
let data = Module::from_data(&bytes);
let data = Module::try_from_data(&bytes).unwrap();
writeln!(w, "assert_trap((function()")?;
codegen_luajit::from_module_untyped(&data, w)?;

View File

@ -103,7 +103,7 @@ impl Target for Luau {
Wat::Module(ast) => ast.encode().unwrap(),
Wat::Component(_) => unimplemented!(),
};
let data = Module::from_data(&bytes);
let data = Module::try_from_data(&bytes).unwrap();
writeln!(w, "assert_trap((function()")?;
codegen_luau::from_module_untyped(&data, w)?;

View File

@ -81,7 +81,7 @@ pub trait Target: Sized {
let mut ast = try_into_ast_module(data).expect("Must be a module");
let bytes = ast.encode().unwrap();
let data = AstModule::from_data(&bytes);
let data = AstModule::try_from_data(&bytes).unwrap();
let name = ast.id.as_ref().map(Id::name);
Self::write_module(&data, name, w)?;

View File

@ -2,17 +2,9 @@ use std::collections::HashMap;
use wasmparser::{
BlockType, Data, Element, Export, ExternalKind, FunctionBody, Global, Import, MemoryType, Name,
NameSectionReader, Parser, Payload, TableType, Type, TypeRef,
NameSectionReader, Parser, Payload, Result, TableType, Type, TypeRef,
};
macro_rules! to_section {
($data:ident) => {{
let read: Result<_, _> = $data.into_iter().collect();
read.unwrap()
}};
}
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum External {
Func,
@ -46,6 +38,13 @@ impl From<ExternalKind> for External {
}
}
pub(crate) fn read_checked<T, I>(reader: I) -> Result<Vec<T>>
where
I: IntoIterator<Item = Result<T>>,
{
reader.into_iter().collect()
}
pub struct Module<'a> {
type_section: Vec<Type>,
import_section: Vec<Import<'a>>,
@ -64,8 +63,10 @@ pub struct Module<'a> {
}
impl<'a> Module<'a> {
#[must_use]
pub fn from_data(data: &'a [u8]) -> Self {
/// # Errors
///
/// Returns a `BinaryReaderError` if any module section is malformed.
pub fn try_from_data(data: &'a [u8]) -> Result<Self> {
let mut temp = Module {
type_section: Vec::new(),
import_section: Vec::new(),
@ -81,22 +82,22 @@ impl<'a> Module<'a> {
start_section: None,
};
temp.load_data(data);
temp
temp.load_data(data)?;
Ok(temp)
}
fn load_data(&mut self, data: &'a [u8]) {
for payload in Parser::new(0).parse_all(data).flatten() {
match payload {
Payload::TypeSection(v) => self.type_section = to_section!(v),
Payload::ImportSection(v) => self.import_section = to_section!(v),
Payload::FunctionSection(v) => self.func_section = to_section!(v),
Payload::TableSection(v) => self.table_section = to_section!(v),
Payload::MemorySection(v) => self.memory_section = to_section!(v),
Payload::GlobalSection(v) => self.global_section = to_section!(v),
Payload::ExportSection(v) => self.export_section = to_section!(v),
Payload::ElementSection(v) => self.element_section = to_section!(v),
Payload::DataSection(v) => self.data_section = to_section!(v),
fn load_data(&mut self, data: &'a [u8]) -> Result<()> {
for payload in Parser::new(0).parse_all(data) {
match payload? {
Payload::TypeSection(v) => self.type_section = read_checked(v)?,
Payload::ImportSection(v) => self.import_section = read_checked(v)?,
Payload::FunctionSection(v) => self.func_section = read_checked(v)?,
Payload::TableSection(v) => self.table_section = read_checked(v)?,
Payload::MemorySection(v) => self.memory_section = read_checked(v)?,
Payload::GlobalSection(v) => self.global_section = read_checked(v)?,
Payload::ExportSection(v) => self.export_section = read_checked(v)?,
Payload::ElementSection(v) => self.element_section = read_checked(v)?,
Payload::DataSection(v) => self.data_section = read_checked(v)?,
Payload::CodeSectionEntry(v) => {
self.code_section.push(v);
}
@ -104,9 +105,9 @@ impl<'a> Module<'a> {
self.start_section = Some(func);
}
Payload::CustomSection(v) if v.name() == "name" => {
for name in NameSectionReader::new(v.data(), v.data_offset()).unwrap() {
if let Name::Function(map) = name.unwrap() {
let mut iter = map.get_map().unwrap();
for name in NameSectionReader::new(v.data(), v.data_offset())? {
if let Name::Function(map) = name? {
let mut iter = map.get_map()?;
while let Ok(elem) = iter.read() {
self.name_section.insert(elem.index, elem.name);
@ -117,6 +118,8 @@ impl<'a> Module<'a> {
_ => {}
}
}
Ok(())
}
#[must_use]