diff --git a/src/elements/func.rs b/src/elements/func.rs index 4aa3bbc..91f9feb 100644 --- a/src/elements/func.rs +++ b/src/elements/func.rs @@ -3,6 +3,7 @@ use super::{ Deserialize, Error, ValueType, VarUint32, CountedList, Opcodes, Serialize, CountedWriter, CountedListWriter, }; +use elements::section::SectionReader; /// Function signature (type reference) #[derive(Debug, Copy, Clone)] @@ -116,9 +117,10 @@ impl Deserialize for FuncBody { fn deserialize(reader: &mut R) -> Result { // todo: maybe use reader.take(section_length) - let _body_size = VarUint32::deserialize(reader)?; - let locals: Vec = CountedList::deserialize(reader)?.into_inner(); - let opcodes = Opcodes::deserialize(reader)?; + let mut body_reader = SectionReader::new(reader)?; + let locals: Vec = CountedList::::deserialize(&mut body_reader)?.into_inner(); + let opcodes = Opcodes::deserialize(&mut body_reader)?; + body_reader.close()?; Ok(FuncBody { locals: locals, opcodes: opcodes }) } } diff --git a/src/elements/mod.rs b/src/elements/mod.rs index 6c1a2cd..0b7d1e4 100644 --- a/src/elements/mod.rs +++ b/src/elements/mod.rs @@ -4,9 +4,26 @@ use std::error; use std::fmt; use std::io; +macro_rules! buffered_read { + ($buffer_size: expr, $length: expr, $reader: expr) => { + { + let mut vec_buf = Vec::new(); + let mut total_read = 0; + let mut buf = [0u8; $buffer_size]; + while total_read < $length { + let next_to_read = if $length - total_read > $buffer_size { $buffer_size } else { $length - total_read }; + $reader.read_exact(&mut buf[0..next_to_read])?; + vec_buf.extend_from_slice(&buf[0..next_to_read]); + total_read += next_to_read; + } + vec_buf + } + } +} + +mod primitives; mod module; mod section; -mod primitives; mod types; mod import_entry; mod export_entry; @@ -43,15 +60,16 @@ pub use self::name_section::{ /// Deserialization from serial i/o pub trait Deserialize : Sized { /// Serialization error produced by deserialization routine. - type Error; + type Error: From; /// Deserialize type from serial i/o fn deserialize(reader: &mut R) -> Result; } -/// Serialization to serial i/o +/// Serialization to serial i/o. Takes self by value to consume less memory +/// (parity-wasm IR is being partially freed by filling the result buffer). pub trait Serialize { /// Serialization error produced by serialization routine. - type Error; + type Error: From; /// Serialize type to serial i/o fn serialize(self, writer: &mut W) -> Result<(), Self::Error>; } @@ -100,6 +118,18 @@ pub enum Error { InvalidVarUint64, /// Inconsistent metadata InconsistentMetadata, + /// Invalid section id + InvalidSectionId(u8), + /// Sections are out of order + SectionsOutOfOrder, + /// Duplicated sections + DuplicatedSections(u8), + /// Invalid memory reference (should be 0) + InvalidMemoryReference(u8), + /// Invalid table reference (should be 0) + InvalidTableReference(u8), + /// Unknown function form (should be 0x60) + UnknownFunctionForm(u8), } impl fmt::Display for Error { @@ -125,6 +155,12 @@ impl fmt::Display for Error { Error::InvalidVarInt64 => write!(f, "Not a signed 64-bit integer"), Error::InvalidVarUint64 => write!(f, "Not an unsigned 64-bit integer"), Error::InconsistentMetadata => write!(f, "Inconsistent metadata"), + Error::InvalidSectionId(ref id) => write!(f, "Invalid section id: {}", id), + Error::SectionsOutOfOrder => write!(f, "Sections out of order"), + Error::DuplicatedSections(ref id) => write!(f, "Dupliated sections ({})", id), + Error::InvalidMemoryReference(ref mem_ref) => write!(f, "Invalid memory reference ({})", mem_ref), + Error::InvalidTableReference(ref table_ref) => write!(f, "Invalid table reference ({})", table_ref), + Error::UnknownFunctionForm(ref form) => write!(f, "Unknown function form ({})", form), } } } @@ -150,6 +186,12 @@ impl error::Error for Error { Error::InvalidVarInt64 => "Not a signed 64-bit integer", Error::InvalidVarUint64 => "Not an unsigned 64-bit integer", Error::InconsistentMetadata => "Inconsistent metadata", + Error::InvalidSectionId(_) => "Invalid section id", + Error::SectionsOutOfOrder => "Sections out of order", + Error::DuplicatedSections(_) => "Duplicated section", + Error::InvalidMemoryReference(_) => "Invalid memory reference", + Error::InvalidTableReference(_) => "Invalid table reference", + Error::UnknownFunctionForm(_) => "Unknown function form", } } } @@ -193,7 +235,11 @@ pub fn deserialize_file>(p: P) -> Result(contents: &[u8]) -> Result { let mut reader = io::Cursor::new(contents); - T::deserialize(&mut reader) + let result = T::deserialize(&mut reader)?; + if reader.position() != contents.len() as u64 { + return Err(io::Error::from(io::ErrorKind::InvalidData).into()) + } + Ok(result) } /// Create buffer with serialized value. diff --git a/src/elements/module.rs b/src/elements/module.rs index 1538d50..394b48c 100644 --- a/src/elements/module.rs +++ b/src/elements/module.rs @@ -230,11 +230,23 @@ impl Deserialize for Module { return Err(Error::UnsupportedVersion(version)); } + let mut last_section_id = 0; + loop { match Section::deserialize(reader) { Err(Error::UnexpectedEof) => { break; }, Err(e) => { return Err(e) }, - Ok(section) => { sections.push(section); } + Ok(section) => { + if section.id() != 0 { + if last_section_id > section.id() { + return Err(Error::SectionsOutOfOrder); + } else if last_section_id == section.id() { + return Err(Error::DuplicatedSections(last_section_id)); + } + } + last_section_id = section.id(); + sections.push(section); + } } } @@ -319,7 +331,7 @@ pub fn peek_size(source: &[u8]) -> usize { #[cfg(test)] mod integration_tests { - use super::super::{deserialize_file, serialize, deserialize_buffer, Section, Error}; + use super::super::{deserialize_file, serialize, deserialize_buffer, Section}; use super::Module; #[test] @@ -468,15 +480,6 @@ mod integration_tests { assert_eq!(Module::default().magic, module2.magic); } - #[test] - fn inconsistent_meta() { - let result = deserialize_file("./res/cases/v1/payload_len.wasm"); - - // should be error, not panic - if let Err(Error::InconsistentMetadata) = result {} - else { panic!("Should return inconsistent metadata error"); } - } - #[test] fn names() { use super::super::name_section::NameSection; diff --git a/src/elements/ops.rs b/src/elements/ops.rs index dbfbc67..28c6c40 100644 --- a/src/elements/ops.rs +++ b/src/elements/ops.rs @@ -344,9 +344,16 @@ impl Deserialize for Opcode { }, 0x0f => Return, 0x10 => Call(VarUint32::deserialize(reader)?.into()), - 0x11 => CallIndirect( - VarUint32::deserialize(reader)?.into(), - Uint8::deserialize(reader)?.into()), + 0x11 => { + let signature: u32 = VarUint32::deserialize(reader)?.into(); + let table_ref: u8 = Uint8::deserialize(reader)?.into(); + if table_ref != 0 { return Err(Error::InvalidTableReference(table_ref)); } + + CallIndirect( + signature, + table_ref, + ) + }, 0x1a => Drop, 0x1b => Select, @@ -449,8 +456,16 @@ impl Deserialize for Opcode { VarUint32::deserialize(reader)?.into()), - 0x3f => CurrentMemory(Uint8::deserialize(reader)?.into()), - 0x40 => GrowMemory(Uint8::deserialize(reader)?.into()), + 0x3f => { + let mem_ref: u8 = Uint8::deserialize(reader)?.into(); + if mem_ref != 0 { return Err(Error::InvalidMemoryReference(mem_ref)); } + CurrentMemory(mem_ref) + }, + 0x40 => { + let mem_ref: u8 = Uint8::deserialize(reader)?.into(); + if mem_ref != 0 { return Err(Error::InvalidMemoryReference(mem_ref)); } + GrowMemory(mem_ref) + } 0x41 => I32Const(VarInt32::deserialize(reader)?.into()), 0x42 => I64Const(VarInt64::deserialize(reader)?.into()), diff --git a/src/elements/primitives.rs b/src/elements/primitives.rs index bf65978..3b78299 100644 --- a/src/elements/primitives.rs +++ b/src/elements/primitives.rs @@ -492,11 +492,9 @@ impl Deserialize for String { type Error = Error; fn deserialize(reader: &mut R) -> Result { - let length = VarUint32::deserialize(reader)?.into(); + let length = u32::from(VarUint32::deserialize(reader)?) as usize; if length > 0 { - let mut buf = vec![0u8; length]; - reader.read_exact(&mut buf)?; - String::from_utf8(buf).map_err(|_| Error::NonUtf8String) + String::from_utf8(buffered_read!(1024, length, reader)).map_err(|_| Error::NonUtf8String) } else { Ok(String::new()) @@ -600,6 +598,7 @@ impl, T: IntoIterator> Serialize f } } + #[cfg(test)] mod tests { diff --git a/src/elements/section.rs b/src/elements/section.rs index 640cdf7..98fa451 100644 --- a/src/elements/section.rs +++ b/src/elements/section.rs @@ -2,7 +2,6 @@ use std::io; use super::{ Serialize, Deserialize, - Unparsed, Error, VarUint7, VarUint32, @@ -25,6 +24,8 @@ use super::{ use super::types::Type; use super::name_section::NameSection; +const ENTRIES_BUFFER_LENGTH: usize = 16384; + /// Section in the WebAssembly module. #[derive(Debug, Clone)] pub enum Section { @@ -102,8 +103,10 @@ impl Deserialize for Section { Section::Export(ExportSection::deserialize(reader)?) }, 8 => { - let _section_length = VarUint32::deserialize(reader)?; - Section::Start(VarUint32::deserialize(reader)?.into()) + let mut section_reader = SectionReader::new(reader)?; + let start_idx = VarUint32::deserialize(&mut section_reader)?; + section_reader.close()?; + Section::Start(start_idx.into()) }, 9 => { Section::Element(ElementSection::deserialize(reader)?) @@ -114,9 +117,9 @@ impl Deserialize for Section { 11 => { Section::Data(DataSection::deserialize(reader)?) }, - _ => { - Section::Unparsed { id: id.into(), payload: Unparsed::deserialize(reader)?.into() } - } + invalid_id => { + return Err(Error::InvalidSectionId(invalid_id)) + }, } ) } @@ -194,6 +197,72 @@ impl Serialize for Section { } } +impl Section { + pub(crate) fn id(&self) -> u8 { + match *self { + Section::Custom(_) => 0x00, + Section::Unparsed { .. } => 0x00, + Section::Type(_) => 0x1, + Section::Import(_) => 0x2, + Section::Function(_) => 0x3, + Section::Table(_) => 0x4, + Section::Memory(_) => 0x5, + Section::Global(_) => 0x6, + Section::Export(_) => 0x7, + Section::Start(_) => 0x8, + Section::Element(_) => 0x9, + Section::Code(_) => 0x0a, + Section::Data(_) => 0x0b, + Section::Name(_) => 0x00, + } + } +} + +pub(crate) struct SectionReader { + cursor: io::Cursor>, + declared_length: usize, +} + +impl SectionReader { + pub fn new(reader: &mut R) -> Result { + let length = u32::from(VarUint32::deserialize(reader)?) as usize; + let inner_buffer = buffered_read!(ENTRIES_BUFFER_LENGTH, length, reader); + let buf_length = inner_buffer.len(); + let cursor = io::Cursor::new(inner_buffer); + + Ok(SectionReader { + cursor: cursor, + declared_length: buf_length, + }) + } + + pub fn close(self) -> Result<(), ::elements::Error> { + let cursor = self.cursor; + let buf_length = self.declared_length; + + if cursor.position() != buf_length as u64 { + Err(io::Error::from(io::ErrorKind::InvalidData).into()) + } else { + Ok(()) + } + } +} + +impl io::Read for SectionReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.cursor.read(buf) + } +} + +fn read_entries>(reader: &mut R) + -> Result, ::elements::Error> +{ + let mut section_reader = SectionReader::new(reader)?; + let result = CountedList::::deserialize(&mut section_reader)?.into_inner(); + section_reader.close()?; + Ok(result) +} + /// Custom section #[derive(Debug, Default, Clone)] pub struct CustomSection { @@ -228,21 +297,11 @@ impl Deserialize for CustomSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let section_length: u32 = VarUint32::deserialize(reader)?.into(); - - let name = String::deserialize(reader)?; - let total_naming = name.len() as u32 + name.len() as u32 / 128 + 1; - if total_naming > section_length { - return Err(Error::InconsistentMetadata) - } else if total_naming == section_length { - return Ok(CustomSection { name: name, payload: Vec::new() }); - } - - let payload_left = section_length - total_naming; - let mut payload = vec![0u8; payload_left as usize]; - reader.read_exact(&mut payload[..])?; - + let section_length: usize = u32::from(VarUint32::deserialize(reader)?) as usize; + let buf = buffered_read!(16384, section_length, reader); + let mut cursor = ::std::io::Cursor::new(&buf[..]); + let name = String::deserialize(&mut cursor)?; + let payload = buf[cursor.position() as usize..].to_vec(); Ok(CustomSection { name: name, payload: payload }) } } @@ -286,10 +345,7 @@ impl Deserialize for TypeSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let types: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(TypeSection(types)) + Ok(TypeSection(read_entries(reader)?)) } } @@ -348,10 +404,7 @@ impl Deserialize for ImportSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(ImportSection(entries)) + Ok(ImportSection(read_entries(reader)?)) } } @@ -396,11 +449,7 @@ impl Deserialize for FunctionSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let funcs: Vec = CountedList::::deserialize(reader)? - .into_inner(); - Ok(FunctionSection(funcs)) + Ok(FunctionSection(read_entries(reader)?)) } } @@ -445,10 +494,7 @@ impl Deserialize for TableSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(TableSection(entries)) + Ok(TableSection(read_entries(reader)?)) } } @@ -493,10 +539,7 @@ impl Deserialize for MemorySection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(MemorySection(entries)) + Ok(MemorySection(read_entries(reader)?)) } } @@ -541,10 +584,7 @@ impl Deserialize for GlobalSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(GlobalSection(entries)) + Ok(GlobalSection(read_entries(reader)?)) } } @@ -589,10 +629,7 @@ impl Deserialize for ExportSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(ExportSection(entries)) + Ok(ExportSection(read_entries(reader)?)) } } @@ -637,10 +674,7 @@ impl Deserialize for CodeSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(CodeSection(entries)) + Ok(CodeSection(read_entries(reader)?)) } } @@ -685,10 +719,7 @@ impl Deserialize for ElementSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(ElementSection(entries)) + Ok(ElementSection(read_entries(reader)?)) } } @@ -733,10 +764,7 @@ impl Deserialize for DataSection { type Error = Error; fn deserialize(reader: &mut R) -> Result { - // todo: maybe use reader.take(section_length) - let _section_length = VarUint32::deserialize(reader)?; - let entries: Vec = CountedList::deserialize(reader)?.into_inner(); - Ok(DataSection(entries)) + Ok(DataSection(read_entries(reader)?)) } } @@ -846,23 +874,23 @@ mod tests { fn types_test_payload() -> &'static [u8] { &[ // section length - 148u8, 0x80, 0x80, 0x80, 0x0, + 11, // 2 functions - 130u8, 0x80, 0x80, 0x80, 0x0, + 2, // func 1, form =1 - 0x01, + 0x60, // param_count=1 - 129u8, 0x80, 0x80, 0x80, 0x0, + 1, // first param 0x7e, // i64 // no return params 0x00, // func 2, form=1 - 0x01, - // param_count=1 - 130u8, 0x80, 0x80, 0x80, 0x0, + 0x60, + // param_count=2 + 2, // first param 0x7e, // second param @@ -898,9 +926,9 @@ mod tests { // section id 0x07, // section length - 148u8, 0x80, 0x80, 0x80, 0x0, + 28, // 6 entries - 134u8, 0x80, 0x80, 0x80, 0x0, + 6, // func "A", index 6 // [name_len(1-5 bytes), name_bytes(name_len, internal_kind(1byte), internal_index(1-5 bytes)]) 0x01, 0x41, 0x01, 0x86, 0x80, 0x00, @@ -978,10 +1006,11 @@ mod tests { fn data_payload() -> &'static [u8] { &[ 0x0bu8, // section id - 19, // 19 bytes overall + 20, // 20 bytes overall 0x01, // number of segments 0x00, // index 0x0b, // just `end` op + 0x10, // 16x 0x00 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/src/elements/segment.rs b/src/elements/segment.rs index e77a6d0..1fe49c5 100644 --- a/src/elements/segment.rs +++ b/src/elements/segment.rs @@ -107,10 +107,8 @@ impl Deserialize for DataSegment { fn deserialize(reader: &mut R) -> Result { let index = VarUint32::deserialize(reader)?; let offset = InitExpr::deserialize(reader)?; - let value_len = VarUint32::deserialize(reader)?; - - let mut value_buf = vec![0u8; value_len.into()]; - reader.read_exact(&mut value_buf[..])?; + let value_len = u32::from(VarUint32::deserialize(reader)?) as usize; + let value_buf = buffered_read!(65536, value_len, reader); Ok(DataSegment { index: index.into(), diff --git a/src/elements/types.rs b/src/elements/types.rs index 9e82767..11fedcc 100644 --- a/src/elements/types.rs +++ b/src/elements/types.rs @@ -171,6 +171,10 @@ impl Deserialize for FunctionType { fn deserialize(reader: &mut R) -> Result { let form: u8 = VarUint7::deserialize(reader)?.into(); + if form != 0x60 { + return Err(Error::UnknownFunctionForm(form)); + } + let params: Vec = CountedList::deserialize(reader)?.into_inner(); let has_return_type = VarUint1::deserialize(reader)?;