diff --git a/src/elements/export_entry.rs b/src/elements/export_entry.rs index c885fa6..71dae9c 100644 --- a/src/elements/export_entry.rs +++ b/src/elements/export_entry.rs @@ -1,6 +1,6 @@ -use std::io; use std::string::String; use super::{Deserialize, Serialize, Error, VarUint7, VarUint32}; +use io; /// Internal reference of the exported entry. #[derive(Debug, Clone, Copy, PartialEq)] diff --git a/src/elements/func.rs b/src/elements/func.rs index 82782ea..d7077fd 100644 --- a/src/elements/func.rs +++ b/src/elements/func.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use std::vec::Vec; use super::{ Deserialize, Error, ValueType, VarUint32, CountedList, Opcodes, diff --git a/src/elements/global_entry.rs b/src/elements/global_entry.rs index 1851a6c..e4e3755 100644 --- a/src/elements/global_entry.rs +++ b/src/elements/global_entry.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use super::{Deserialize, Serialize, Error, GlobalType, InitExpr}; /// Global entry in the module. diff --git a/src/elements/import_entry.rs b/src/elements/import_entry.rs index cba39bd..41de029 100644 --- a/src/elements/import_entry.rs +++ b/src/elements/import_entry.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use std::string::String; use super::{ Deserialize, Serialize, Error, VarUint7, VarInt7, VarUint32, VarUint1, diff --git a/src/elements/index_map.rs b/src/elements/index_map.rs index e2be2a2..d0c200c 100644 --- a/src/elements/index_map.rs +++ b/src/elements/index_map.rs @@ -1,10 +1,10 @@ use std::cmp::min; -use std::io::{Read, Write}; use std::iter::{FromIterator, IntoIterator}; use std::mem; use std::slice; use std::vec; use std::vec::Vec; +use io; use super::{Deserialize, Error, Serialize, VarUint32}; @@ -151,7 +151,7 @@ impl IndexMap { rdr: &mut R, ) -> Result, Error> where - R: Read, + R: io::Read, F: Fn(u32, &mut R) -> Result, { let len: u32 = VarUint32::deserialize(rdr)?.into(); @@ -326,7 +326,7 @@ where { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Self::Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Self::Error> { VarUint32::from(self.len()).serialize(wtr)?; for (idx, value) in self { VarUint32::from(idx).serialize(wtr)?; @@ -344,7 +344,7 @@ where /// Deserialize a map containing simple values that support `Deserialize`. /// We will allocate an underlying array no larger than `max_entry_space` to /// hold the data, so the maximum index must be less than `max_entry_space`. - pub fn deserialize( + pub fn deserialize( max_entry_space: usize, rdr: &mut R, ) -> Result { @@ -357,7 +357,7 @@ where #[cfg(test)] mod tests { - use std::io; + use io; use super::*; #[test] diff --git a/src/elements/mod.rs b/src/elements/mod.rs index dae2e1c..4dd2f87 100644 --- a/src/elements/mod.rs +++ b/src/elements/mod.rs @@ -1,8 +1,7 @@ //! Elements of the WebAssembly binary format. -use std::error; use std::fmt; -use std::io; +use io; use std::vec::Vec; use std::string::String; @@ -14,7 +13,7 @@ macro_rules! buffered_read { let mut buf = [0u8; $buffer_size]; while total_read < $length { let next_to_read = if $length - total_read > $buffer_size { $buffer_size } else { $length - total_read }; - $reader.read_exact(&mut buf[0..next_to_read])?; + $reader.read(&mut buf[0..next_to_read])?; vec_buf.extend_from_slice(&buf[0..next_to_read]); total_read += next_to_read; } @@ -177,7 +176,8 @@ impl fmt::Display for Error { } } -impl error::Error for Error { +#[cfg(feature = "std")] +impl ::std::error::Error for Error { fn description(&self) -> &str { match *self { Error::UnexpectedEof => "Unexpected end of input", @@ -212,7 +212,7 @@ impl error::Error for Error { impl From for Error { fn from(err: io::Error) -> Self { - Error::HeapOther(format!("I/O Error: {}", err)) + Error::HeapOther(format!("I/O Error: {:?}", err)) } } @@ -225,7 +225,7 @@ impl Deserialize for Unparsed { fn deserialize(reader: &mut R) -> Result { let len = VarUint32::deserialize(reader)?.into(); let mut vec = vec![0u8; len]; - reader.read_exact(&mut vec[..])?; + reader.read(&mut vec[..])?; Ok(Unparsed(vec)) } } @@ -236,22 +236,14 @@ impl From for Vec { } } -/// Deserialize module from file. -pub fn deserialize_file>(p: P) -> Result { - use std::io::Read; - - let mut contents = Vec::new(); - ::std::fs::File::open(p)?.read_to_end(&mut contents)?; - - deserialize_buffer(&contents) -} - /// Deserialize deserializable type from buffer. pub fn deserialize_buffer(contents: &[u8]) -> Result { let mut reader = io::Cursor::new(contents); let result = T::deserialize(&mut reader)?; - if reader.position() != contents.len() as u64 { - return Err(io::Error::from(io::ErrorKind::InvalidData).into()) + if reader.position() != contents.len() { + // It's a TrailingData, since if there is not enough data then + // UnexpectedEof must have been returned earlier in T::deserialize. + return Err(io::Error::TrailingData.into()) } Ok(result) } @@ -263,9 +255,33 @@ pub fn serialize(val: T) -> Result, T::Error> { Ok(buf) } -/// Serialize module to the file -pub fn serialize_to_file>(p: P, module: Module) -> Result<(), Error> -{ - let mut io = ::std::fs::File::create(p)?; - module.serialize(&mut io) +/// Deserialize module from file. +#[cfg(feature = "std")] +pub fn deserialize_file>(p: P) -> Result { + use std::io::Read; + + let mut contents = Vec::new(); + + ::std::fs::File::open(p) + .and_then(|mut f| f.read_to_end(&mut contents)) + .map_err(|e| Error::HeapOther(format!("Can't read from the file: {:?}", e)))?; + + deserialize_buffer(&contents) +} + +/// Serialize module to the file +#[cfg(feature = "std")] +pub fn serialize_to_file>(p: P, module: Module) -> Result<(), Error> { + use std::io::Write; + + let mut io = ::std::fs::File::create(p) + .map_err(|e| Error::HeapOther(format!("Can't create the file: {:?}", e)))?; + let mut buf = Vec::new(); + + module.serialize(&mut buf)?; + + io.write_all(&buf) + .map_err(|e| Error::HeapOther(format!("Can't write to the file: {:?}", e)))?; + + Ok(()) } diff --git a/src/elements/module.rs b/src/elements/module.rs index 4ecec53..e6961ef 100644 --- a/src/elements/module.rs +++ b/src/elements/module.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use std::vec::Vec; use byteorder::{LittleEndian, ByteOrder}; @@ -310,8 +310,8 @@ impl Module { Ok(reloc_section) => reloc_section, Err(e) => { parse_errors.push((i, e)); continue; } }; - if rdr.position() != custom.payload().len() as u64 { - parse_errors.push((i, io::Error::from(io::ErrorKind::InvalidData).into())); + if rdr.position() != custom.payload().len() { + parse_errors.push((i, io::Error::InvalidData.into())); continue; } Some(Section::Reloc(reloc_section)) @@ -448,17 +448,17 @@ struct PeekSection<'a> { } impl<'a> io::Read for PeekSection<'a> { - fn read(&mut self, buf: &mut [u8]) -> ::std::io::Result { + fn read(&mut self, buf: &mut [u8]) -> io::Result<()> { let available = ::std::cmp::min(buf.len(), self.region.len() - self.cursor); if available < buf.len() { - return Err(::std::io::Error::from(::std::io::ErrorKind::UnexpectedEof)); + return Err(io::Error::UnexpectedEof); } let range = self.cursor..self.cursor + buf.len(); buf.copy_from_slice(&self.region[range]); self.cursor += available; - Ok(available) + Ok(()) } } diff --git a/src/elements/name_section.rs b/src/elements/name_section.rs index 0de7d81..6935aed 100644 --- a/src/elements/name_section.rs +++ b/src/elements/name_section.rs @@ -1,4 +1,4 @@ -use std::io::{Read, Write}; +use io; use std::vec::Vec; use std::string::String; @@ -32,7 +32,7 @@ pub enum NameSection { impl NameSection { /// Deserialize a name section. - pub fn deserialize( + pub fn deserialize( module: &Module, rdr: &mut R, ) -> Result { @@ -44,7 +44,7 @@ impl NameSection { NAME_TYPE_LOCAL => NameSection::Local(LocalNameSection::deserialize(module, rdr)?), _ => { let mut name_payload = vec![0u8; name_payload_len as usize]; - rdr.read_exact(&mut name_payload)?; + rdr.read(&mut name_payload)?; NameSection::Unparsed { name_type, name_payload, @@ -58,7 +58,7 @@ impl NameSection { impl Serialize for NameSection { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Error> { let (name_type, name_payload) = match self { NameSection::Module(mod_name) => { let mut buffer = vec![]; @@ -82,7 +82,7 @@ impl Serialize for NameSection { }; VarUint7::from(name_type).serialize(wtr)?; VarUint32::from(name_payload.len()).serialize(wtr)?; - wtr.write_all(&name_payload)?; + wtr.write(&name_payload)?; Ok(()) } } @@ -113,7 +113,7 @@ impl ModuleNameSection { impl Serialize for ModuleNameSection { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Error> { self.name.serialize(wtr) } } @@ -121,7 +121,7 @@ impl Serialize for ModuleNameSection { impl Deserialize for ModuleNameSection { type Error = Error; - fn deserialize(rdr: &mut R) -> Result { + fn deserialize(rdr: &mut R) -> Result { let name = String::deserialize(rdr)?; Ok(ModuleNameSection { name }) } @@ -145,7 +145,7 @@ impl FunctionNameSection { } /// Deserialize names, making sure that all names correspond to functions. - pub fn deserialize( + pub fn deserialize( module: &Module, rdr: &mut R, ) -> Result { @@ -157,7 +157,7 @@ impl FunctionNameSection { impl Serialize for FunctionNameSection { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Error> { self.names.serialize(wtr) } } @@ -182,7 +182,7 @@ impl LocalNameSection { /// Deserialize names, making sure that all names correspond to local /// variables. - pub fn deserialize( + pub fn deserialize( module: &Module, rdr: &mut R, ) -> Result { @@ -221,7 +221,7 @@ impl LocalNameSection { impl Serialize for LocalNameSection { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Error> { self.local_names.serialize(wtr) } } diff --git a/src/elements/ops.rs b/src/elements/ops.rs index 4352b04..b293060 100644 --- a/src/elements/ops.rs +++ b/src/elements/ops.rs @@ -1,6 +1,7 @@ -use std::{io, fmt}; +use std::fmt; use std::vec::Vec; use std::boxed::Box; +use io; use super::{ Serialize, Deserialize, Error, Uint8, VarUint32, CountedList, BlockType, @@ -613,7 +614,7 @@ impl Deserialize for Opcode { macro_rules! op { ($writer: expr, $byte: expr) => ({ let b: u8 = $byte; - $writer.write_all(&[b])?; + $writer.write(&[b])?; }); ($writer: expr, $byte: expr, $s: block) => ({ op!($writer, $byte); diff --git a/src/elements/primitives.rs b/src/elements/primitives.rs index a224279..c765ce7 100644 --- a/src/elements/primitives.rs +++ b/src/elements/primitives.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use std::vec::Vec; use std::string::String; use byteorder::{LittleEndian, ByteOrder}; @@ -44,7 +44,7 @@ impl Deserialize for VarUint32 { loop { if shift > 31 { return Err(Error::InvalidVarUint32); } - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; let b = u8buf[0] as u32; res |= (b & 0x7f).checked_shl(shift).ok_or(Error::InvalidVarUint32)?; shift += 7; @@ -71,7 +71,7 @@ impl Serialize for VarUint32 { if v > 0 { buf[0] |= 0b1000_0000; } - writer.write_all(&buf[..])?; + writer.write(&buf[..])?; if v == 0 { break; } } @@ -100,7 +100,7 @@ impl Deserialize for VarUint64 { loop { if shift > 63 { return Err(Error::InvalidVarUint64); } - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; let b = u8buf[0] as u64; res |= (b & 0x7f).checked_shl(shift).ok_or(Error::InvalidVarUint64)?; shift += 7; @@ -127,7 +127,7 @@ impl Serialize for VarUint64 { if v > 0 { buf[0] |= 0b1000_0000; } - writer.write_all(&buf[..])?; + writer.write(&buf[..])?; if v == 0 { break; } } @@ -162,7 +162,7 @@ impl Deserialize for VarUint7 { fn deserialize(reader: &mut R) -> Result { let mut u8buf = [0u8; 1]; - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; Ok(VarUint7(u8buf[0])) } } @@ -172,7 +172,7 @@ impl Serialize for VarUint7 { fn serialize(self, writer: &mut W) -> Result<(), Self::Error> { // todo check range? - writer.write_all(&[self.0])?; + writer.write(&[self.0])?; Ok(()) } } @@ -198,7 +198,7 @@ impl Deserialize for VarInt7 { fn deserialize(reader: &mut R) -> Result { let mut u8buf = [0u8; 1]; - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; // check if number is not continued! if u8buf[0] & 0b1000_0000 != 0 { @@ -219,7 +219,7 @@ impl Serialize for VarInt7 { // todo check range? let mut b: u8 = self.0 as u8; if self.0 < 0 { b |= 0b0100_0000; b &= 0b0111_1111; } - writer.write_all(&[b])?; + writer.write(&[b])?; Ok(()) } } @@ -246,7 +246,7 @@ impl Deserialize for Uint8 { fn deserialize(reader: &mut R) -> Result { let mut u8buf = [0u8; 1]; - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; Ok(Uint8(u8buf[0])) } } @@ -255,7 +255,7 @@ impl Serialize for Uint8 { type Error = Error; fn serialize(self, writer: &mut W) -> Result<(), Self::Error> { - writer.write_all(&[self.0])?; + writer.write(&[self.0])?; Ok(()) } } @@ -286,7 +286,7 @@ impl Deserialize for VarInt32 { let mut u8buf = [0u8; 1]; loop { if shift > 31 { return Err(Error::InvalidVarInt32); } - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; let b = u8buf[0]; res |= ((b & 0x7f) as i32).checked_shl(shift).ok_or(Error::InvalidVarInt32)?; @@ -327,7 +327,7 @@ impl Serialize for VarInt32 { buf[0] |= 0b1000_0000 } - writer.write_all(&buf[..])?; + writer.write(&buf[..])?; } Ok(()) @@ -360,7 +360,7 @@ impl Deserialize for VarInt64 { loop { if shift > 63 { return Err(Error::InvalidVarInt64); } - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; let b = u8buf[0]; res |= ((b & 0x7f) as i64).checked_shl(shift).ok_or(Error::InvalidVarInt64)?; @@ -399,7 +399,7 @@ impl Serialize for VarInt64 { buf[0] |= 0b1000_0000 } - writer.write_all(&buf[..])?; + writer.write(&buf[..])?; } Ok(()) @@ -415,7 +415,7 @@ impl Deserialize for Uint32 { fn deserialize(reader: &mut R) -> Result { let mut buf = [0u8; 4]; - reader.read_exact(&mut buf)?; + reader.read(&mut buf)?; // todo check range Ok(Uint32(LittleEndian::read_u32(&buf))) } @@ -433,7 +433,7 @@ impl Serialize for Uint32 { fn serialize(self, writer: &mut W) -> Result<(), Self::Error> { let mut buf = [0u8; 4]; LittleEndian::write_u32(&mut buf, self.0); - writer.write_all(&buf)?; + writer.write(&buf)?; Ok(()) } } @@ -451,7 +451,7 @@ impl Deserialize for Uint64 { fn deserialize(reader: &mut R) -> Result { let mut buf = [0u8; 8]; - reader.read_exact(&mut buf)?; + reader.read(&mut buf)?; // todo check range Ok(Uint64(LittleEndian::read_u64(&buf))) } @@ -463,7 +463,7 @@ impl Serialize for Uint64 { fn serialize(self, writer: &mut W) -> Result<(), Self::Error> { let mut buf = [0u8; 8]; LittleEndian::write_u64(&mut buf, self.0); - writer.write_all(&buf)?; + writer.write(&buf)?; Ok(()) } } @@ -500,7 +500,7 @@ impl Deserialize for VarUint1 { fn deserialize(reader: &mut R) -> Result { let mut u8buf = [0u8; 1]; - reader.read_exact(&mut u8buf)?; + reader.read(&mut u8buf)?; match u8buf[0] { 0 => Ok(VarUint1(false)), 1 => Ok(VarUint1(true)), @@ -513,7 +513,7 @@ impl Serialize for VarUint1 { type Error = Error; fn serialize(self, writer: &mut W) -> Result<(), Self::Error> { - writer.write_all(&[ + writer.write(&[ if self.0 { 1u8 } else { 0u8 } ])?; Ok(()) @@ -539,7 +539,7 @@ impl Serialize for String { fn serialize(self, writer: &mut W) -> Result<(), Error> { VarUint32::from(self.len()).serialize(writer)?; - writer.write_all(&self.into_bytes()[..])?; + writer.write(&self.into_bytes()[..])?; Ok(()) } } @@ -589,24 +589,15 @@ impl<'a, W: 'a + io::Write> CountedWriter<'a, W> { let data = self.data; VarUint32::from(data.len()) .serialize(writer) - .map_err( - |_| io::Error::new( - io::ErrorKind::Other, - "Length serialization error", - ) - )?; - writer.write_all(&data[..])?; + .map_err(|_| io::Error::InvalidData)?; + writer.write(&data[..])?; Ok(()) } } impl<'a, W: 'a + io::Write> io::Write for CountedWriter<'a, W> { - fn write(&mut self, buf: &[u8]) -> io::Result { + fn write(&mut self, buf: &[u8]) -> io::Result<()> { self.data.extend(buf.to_vec()); - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { Ok(()) } } diff --git a/src/elements/reloc_section.rs b/src/elements/reloc_section.rs index 540bb13..bf29bdb 100644 --- a/src/elements/reloc_section.rs +++ b/src/elements/reloc_section.rs @@ -1,4 +1,4 @@ -use std::io::{Read, Write}; +use io; use std::vec::Vec; use std::string::String; @@ -73,7 +73,7 @@ impl RelocSection { impl RelocSection { /// Deserialize a reloc section. - pub fn deserialize( + pub fn deserialize( name: String, rdr: &mut R, ) -> Result { @@ -101,7 +101,7 @@ impl RelocSection { impl Serialize for RelocSection { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Error> { let mut counted_writer = CountedWriter::new(wtr); self.name.serialize(&mut counted_writer)?; @@ -209,7 +209,7 @@ pub enum RelocationEntry { impl Deserialize for RelocationEntry { type Error = Error; - fn deserialize(rdr: &mut R) -> Result { + fn deserialize(rdr: &mut R) -> Result { match VarUint7::deserialize(rdr)?.into() { FUNCTION_INDEX_LEB => Ok(RelocationEntry::FunctionIndexLeb { offset: VarUint32::deserialize(rdr)?.into(), @@ -262,7 +262,7 @@ impl Deserialize for RelocationEntry { impl Serialize for RelocationEntry { type Error = Error; - fn serialize(self, wtr: &mut W) -> Result<(), Error> { + fn serialize(self, wtr: &mut W) -> Result<(), Error> { match self { RelocationEntry::FunctionIndexLeb { offset, index } => { VarUint7::from(FUNCTION_INDEX_LEB).serialize(wtr)?; diff --git a/src/elements/section.rs b/src/elements/section.rs index e27743b..ad2d0e1 100644 --- a/src/elements/section.rs +++ b/src/elements/section.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use std::vec::Vec; use std::string::String; use super::{ @@ -145,7 +145,7 @@ impl Serialize for Section { }, Section::Unparsed { id, payload } => { VarUint7::from(id).serialize(writer)?; - writer.write_all(&payload[..])?; + writer.write(&payload[..])?; }, Section::Type(type_section) => { VarUint7::from(0x01).serialize(writer)?; @@ -250,12 +250,12 @@ impl SectionReader { }) } - pub fn close(self) -> Result<(), ::elements::Error> { + pub fn close(self) -> Result<(), io::Error> { let cursor = self.cursor; let buf_length = self.declared_length; - if cursor.position() != buf_length as u64 { - Err(io::Error::from(io::ErrorKind::InvalidData).into()) + if cursor.position() != buf_length { + Err(io::Error::InvalidData) } else { Ok(()) } @@ -263,8 +263,9 @@ impl SectionReader { } impl io::Read for SectionReader { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.cursor.read(buf) + fn read(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.cursor.read(buf)?; + Ok(()) } } @@ -313,7 +314,7 @@ impl Deserialize for CustomSection { fn deserialize(reader: &mut R) -> Result { let section_length: usize = u32::from(VarUint32::deserialize(reader)?) as usize; let buf = buffered_read!(16384, section_length, reader); - let mut cursor = ::std::io::Cursor::new(&buf[..]); + let mut cursor = io::Cursor::new(&buf[..]); let name = String::deserialize(&mut cursor)?; let payload = buf[cursor.position() as usize..].to_vec(); Ok(CustomSection { name: name, payload: payload }) @@ -324,11 +325,11 @@ impl Serialize for CustomSection { type Error = Error; fn serialize(self, writer: &mut W) -> Result<(), Self::Error> { - use std::io::Write; + use io::Write; let mut counted_writer = CountedWriter::new(writer); self.name.serialize(&mut counted_writer)?; - counted_writer.write_all(&self.payload[..])?; + counted_writer.write(&self.payload[..])?; counted_writer.done()?; Ok(()) } diff --git a/src/elements/segment.rs b/src/elements/segment.rs index e2447fc..71365ae 100644 --- a/src/elements/segment.rs +++ b/src/elements/segment.rs @@ -1,4 +1,4 @@ -use std::io; +use io; use std::vec::Vec; use super::{Deserialize, Serialize, Error, VarUint32, CountedList, InitExpr, CountedListWriter}; @@ -128,7 +128,7 @@ impl Serialize for DataSegment { let value = self.value; VarUint32::from(value.len()).serialize(writer)?; - writer.write_all(&value[..])?; + writer.write(&value[..])?; Ok(()) } } diff --git a/src/elements/types.rs b/src/elements/types.rs index d034092..7d79f14 100644 --- a/src/elements/types.rs +++ b/src/elements/types.rs @@ -1,4 +1,5 @@ -use std::{io, fmt}; +use io; +use std::fmt; use std::vec::Vec; use super::{ Deserialize, Serialize, Error, VarUint7, VarInt7, VarUint1, CountedList, diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..c6730bb --- /dev/null +++ b/src/io.rs @@ -0,0 +1,102 @@ +//! Simple abstractions for the IO operations. +//! +//! Basically it just a replacement for the std::io that is usable from +//! the `no_std` environment. + +use std::vec::Vec; + +/// IO specific error. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + /// Some unexpected data left in the buffer after reading all data. + TrailingData, + + /// Unexpected End-Of-File + UnexpectedEof, + + /// Invalid data is encountered. + InvalidData, +} + +/// IO specific Result. +pub type Result = ::std::result::Result; + +pub trait Write { + /// Write a buffer of data into this write. + /// + /// All data is written at once. + fn write(&mut self, buf: &[u8]) -> Result<()>; +} + +pub trait Read { + /// Read a data from this read to a buffer. + /// + /// If there is not enough data in this read then `UnexpectedEof` will be returned. + fn read(&mut self, buf: &mut [u8]) -> Result<()>; +} + +impl Write for Vec { + fn write(&mut self, buf: &[u8]) -> Result<()> { + self.extend(buf); + Ok(()) + } +} + +/// Reader that saves the last position. +pub struct Cursor { + inner: T, + pos: usize, +} + +impl Cursor { + pub fn new(inner: T) -> Cursor { + Cursor { + inner, + pos: 0, + } + } + + pub fn position(&self) -> usize { + self.pos + } +} + +impl> Read for Cursor { + fn read(&mut self, buf: &mut [u8]) -> Result<()> { + let slice = self.inner.as_ref(); + let remainder = slice.len() - self.pos; + let requested = buf.len(); + if requested > remainder { + return Err(Error::UnexpectedEof); + } + buf.copy_from_slice(&slice[self.pos..(self.pos + requested)]); + self.pos += requested; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cursor() { + let mut cursor = Cursor::new(vec![0xFFu8, 0x7Fu8]); + assert_eq!(cursor.position(), 0); + + let mut buf = [0u8]; + assert!(cursor.read(&mut buf[..]).is_ok()); + assert_eq!(cursor.position(), 1); + assert_eq!(buf[0], 0xFFu8); + assert!(cursor.read(&mut buf[..]).is_ok()); + assert_eq!(buf[0], 0x7Fu8); + assert_eq!(cursor.position(), 2); + } + + #[test] + fn overflow_in_cursor() { + let mut cursor = Cursor::new(vec![0u8]); + let mut buf = [0, 1, 2]; + assert!(cursor.read(&mut buf[..]).is_err()); + } +}