diff --git a/varint-rs/src/lib.rs b/varint-rs/src/lib.rs index 4a0a4855..7c660bf6 100644 --- a/varint-rs/src/lib.rs +++ b/varint-rs/src/lib.rs @@ -123,34 +123,48 @@ pub trait EncoderHelper: Sized { /// Helper trait to allow multiple integer types to be encoded pub trait DecoderHelper: Sized { /// Decode a single byte - fn decode_one(decoder: &mut DecoderState, byte: u8) -> Option; + fn decode_one(decoder: &mut DecoderState, byte: u8) -> errors::Result>; /// Read as much of the varint as possible fn read(decoder: &mut DecoderState, input: R) -> Poll, Error>; } macro_rules! impl_decoderstate { - ($t:ty) => { impl_decoderstate!($t, <$t>::from, |v| v); }; + ($t:ty) => { + impl_decoderstate!( + $t, + |a| a as $t, + |a: $t, b| -> Option<$t> { a.checked_shl(b as u32) } + ); + }; ($t:ty, $make_fn:expr) => { impl_decoderstate!($t, $make_fn, $make_fn); }; - ($t:ty, $make_fn:expr, $make_shift_fn:expr) => { + ($t:ty, $make_fn:expr, $shift_fn:expr) => { impl DecoderHelper for $t { #[inline] - fn decode_one(decoder: &mut DecoderState, byte: u8) -> Option<$t> { - decoder.accumulator.take().and_then(|accumulator| { - let out = accumulator | - ( - $make_fn(byte & 0x7F) << - $make_shift_fn(decoder.shift * USABLE_BITS_PER_BYTE) - ); + fn decode_one(decoder: &mut DecoderState, byte: u8) -> ::errors::Result> { + let res = decoder.accumulator.take().and_then(|accumulator| { + let out = accumulator | match $shift_fn( + $make_fn(byte & 0x7F), + decoder.shift * USABLE_BITS_PER_BYTE, + ) { + Some(a) => a, + None => return Some(Err(ErrorKind::ParseError.into())), + }; decoder.shift += 1; if byte & 0x80 == 0 { - Some(out) + Some(Ok(out)) } else { decoder.accumulator = AccumulatorState::InProgress(out); None } - }) + }); + + match res { + Some(Ok(number)) => Ok(Some(number)), + Some(Err(err)) => Err(err), + None => Ok(None), + } } fn read( @@ -173,7 +187,7 @@ macro_rules! impl_decoderstate { match input.read_exact(&mut buffer) { Ok(()) => { - if let Some(out) = Self::decode_one(decoder, buffer[0]) { + if let Some(out) = Self::decode_one(decoder, buffer[0])? { break Ok(Async::Ready(Some(out))); } } @@ -258,9 +272,9 @@ impl_encoderstate!(u64, (|val| val as u64)); impl_encoderstate!(u32, (|val| val as u32)); impl_decoderstate!(usize); -impl_decoderstate!(BigUint); -impl_decoderstate!(u64, (|val| val as u64)); -impl_decoderstate!(u32, (|val| val as u32)); +impl_decoderstate!(BigUint, BigUint::from, |a, b| Some(a << b)); +impl_decoderstate!(u64); +impl_decoderstate!(u32); impl EncoderState { pub fn source(&self) -> &T { @@ -368,7 +382,8 @@ impl Decoder for VarintDecoder { // We know that the length is not 0, so this cannot fail. let first_byte = src.split_to(1)[0]; let mut state = self.state.take().unwrap_or_default(); - let out = T::decode_one(&mut state, first_byte); + let out = T::decode_one(&mut state, first_byte) + .map_err(|_| io::Error::from(io::ErrorKind::Other))?; if let Some(out) = out { break Ok(Some(out)); @@ -390,10 +405,12 @@ pub fn decode(mut input: R) -> errors::Resu match input.read_exact(&mut buffer) { Ok(()) => { - if let Some(out) = T::decode_one(&mut decoder, buffer[0]) { - break Ok(out); - } - } + if let Some(out) = T::decode_one(&mut decoder, buffer[0]) + .map_err(|_| io::Error::from(io::ErrorKind::Other))? + { + break Ok(out); + } + } Err(inner) => break Err(Error::with_chain(inner, ErrorKind::ParseError)), } } @@ -417,6 +434,34 @@ mod tests { use num_bigint::BigUint; use futures::{Future, Stream}; + #[test] + fn large_number_fails() { + use std::io::Cursor; + use futures::Async; + use super::WriteState; + + let mut out = vec![0u8; 10]; + + { + let writable: Cursor<&mut [_]> = Cursor::new(&mut out); + + let mut state = EncoderState::new(::std::u64::MAX); + + assert_eq!( + state.write(writable).unwrap(), + Async::Ready(WriteState::Done(10)) + ); + } + + let result: Result, _> = FramedRead::new(&out[..], VarintDecoder::new()) + .into_future() + .map(|(out, _)| out) + .map_err(|(out, _)| out) + .wait(); + + assert!(result.is_err()); + } + #[test] fn can_decode_basic_biguint() { assert_eq!(