converting Binary struct model to native w/ global methods model

This commit is contained in:
Jae Kwon 2014-09-03 19:21:19 -07:00
parent a8ece216f0
commit d0ec18dc16
14 changed files with 521 additions and 908 deletions

View File

@ -6,11 +6,20 @@ type Binary interface {
WriteTo(w io.Writer) (int64, error) WriteTo(w io.Writer) (int64, error)
} }
func WriteTo(b Binary, w io.Writer, n int64, err error) (int64, error) { func WriteTo(w io.Writer, bz []byte, n *int64, err *error) {
if err != nil { if *err != nil {
return n, err return
} }
var n_ int64 n_, err_ := w.Write(bz)
n_, err = b.WriteTo(w) *n += int64(n_)
return n + n_, err *err = err_
}
func ReadFull(r io.Reader, buf []byte, n *int64, err *error) {
if *err != nil {
return
}
n_, err_ := io.ReadFull(r, buf)
*n += int64(n_)
*err = err_
} }

View File

@ -1,71 +1,22 @@
package binary package binary
import "io" import (
import "bytes" "io"
)
type ByteSlice []byte // ByteSlice
func (self ByteSlice) Equals(other interface{}) bool { func WriteByteSlice(w io.Writer, bz []byte, n *int64, err *error) {
if o, ok := other.(ByteSlice); ok { WriteUInt32(w, uint32(len(bz)), n, err)
return bytes.Equal(self, o) WriteTo(w, bz, n, err)
} else { }
return false
func ReadByteSlice(r io.Reader, n *int64, err *error) []byte {
length := ReadUInt32(r, n, err)
if *err != nil {
return nil
} }
} buf := make([]byte, int(length))
ReadFull(r, buf, n, err)
func (self ByteSlice) Less(other interface{}) bool { return buf
if o, ok := other.(ByteSlice); ok {
return bytes.Compare(self, o) < 0 // -1 if a < b
} else {
panic("Cannot compare unequal types")
}
}
func (self ByteSlice) ByteSize() int {
return len(self) + 4
}
func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) {
var n_ int
_, err = UInt32(len(self)).WriteTo(w)
if err != nil {
return n, err
}
n_, err = w.Write([]byte(self))
return int64(n_ + 4), err
}
func (self ByteSlice) Reader() io.Reader {
return bytes.NewReader([]byte(self))
}
func ReadByteSliceSafe(r io.Reader) (bytes ByteSlice, n int64, err error) {
length, n_, err := ReadUInt32Safe(r)
n += n_
if err != nil {
return nil, n, err
}
bytes = make([]byte, int(length))
n__, err := io.ReadFull(r, bytes)
n += int64(n__)
if err != nil {
return nil, n, err
}
return bytes, n, nil
}
func ReadByteSliceN(r io.Reader) (bytes ByteSlice, n int64) {
bytes, n, err := ReadByteSliceSafe(r)
if err != nil {
panic(err)
}
return bytes, n
}
func ReadByteSlice(r io.Reader) (bytes ByteSlice) {
bytes, _, err := ReadByteSliceSafe(r)
if err != nil {
panic(err)
}
return bytes
} }

View File

@ -1,91 +1,138 @@
package binary package binary
import ( import (
"errors"
"io" "io"
"time"
) )
type Codec interface {
WriteTo(io.Writer, interface{}, *int64, *error)
ReadFrom(io.Reader, *int64, *error) interface{}
}
//-----------------------------------------------------------------------------
const ( const (
TYPE_NIL = Byte(0x00) typeNil = byte(0x00)
TYPE_BYTE = Byte(0x01) typeByte = byte(0x01)
TYPE_INT8 = Byte(0x02) typeInt8 = byte(0x02)
TYPE_UINT8 = Byte(0x03) // typeUInt8 = byte(0x03)
TYPE_INT16 = Byte(0x04) typeInt16 = byte(0x04)
TYPE_UINT16 = Byte(0x05) typeUInt16 = byte(0x05)
TYPE_INT32 = Byte(0x06) typeInt32 = byte(0x06)
TYPE_UINT32 = Byte(0x07) typeUInt32 = byte(0x07)
TYPE_INT64 = Byte(0x08) typeInt64 = byte(0x08)
TYPE_UINT64 = Byte(0x09) typeUInt64 = byte(0x09)
TYPE_STRING = Byte(0x10) typeString = byte(0x10)
TYPE_BYTESLICE = Byte(0x11) typeByteSlice = byte(0x11)
TYPE_TIME = Byte(0x20) typeTime = byte(0x20)
) )
func GetBinaryType(o Binary) Byte { var BasicCodec = basicCodec{}
type basicCodec struct{}
func (bc basicCodec) WriteTo(w io.Writer, o interface{}, n *int64, err *error) {
switch o.(type) { switch o.(type) {
case nil: case nil:
return TYPE_NIL WriteByte(w, typeNil, n, err)
case Byte: case byte:
return TYPE_BYTE WriteByte(w, typeByte, n, err)
case Int8: WriteByte(w, o.(byte), n, err)
return TYPE_INT8 case int8:
case UInt8: WriteByte(w, typeInt8, n, err)
return TYPE_UINT8 WriteInt8(w, o.(int8), n, err)
case Int16: //case uint8:
return TYPE_INT16 // WriteByte(w, typeUInt8, n, err)
case UInt16: // WriteUInt8(w, o.(uint8), n, err)
return TYPE_UINT16 case int16:
case Int32: WriteByte(w, typeInt16, n, err)
return TYPE_INT32 WriteInt16(w, o.(int16), n, err)
case UInt32: case uint16:
return TYPE_UINT32 WriteByte(w, typeUInt16, n, err)
case Int64: WriteUInt16(w, o.(uint16), n, err)
return TYPE_INT64 case int32:
case UInt64: WriteByte(w, typeInt32, n, err)
return TYPE_UINT64 WriteInt32(w, o.(int32), n, err)
case String: case uint32:
return TYPE_STRING WriteByte(w, typeUInt32, n, err)
case ByteSlice: WriteUInt32(w, o.(uint32), n, err)
return TYPE_BYTESLICE case int64:
case Time: WriteByte(w, typeInt64, n, err)
return TYPE_TIME WriteInt64(w, o.(int64), n, err)
case uint64:
WriteByte(w, typeUInt64, n, err)
WriteUInt64(w, o.(uint64), n, err)
case string:
WriteByte(w, typeString, n, err)
WriteString(w, o.(string), n, err)
case []byte:
WriteByte(w, typeByteSlice, n, err)
WriteByteSlice(w, o.([]byte), n, err)
case time.Time:
WriteByte(w, typeTime, n, err)
WriteTime(w, o.(time.Time), n, err)
default:
panic("Unsupported type")
}
return
}
func (bc basicCodec) ReadFrom(r io.Reader, n *int64, err *error) interface{} {
type_ := ReadByte(r, n, err)
switch type_ {
case typeNil:
return nil
case typeByte:
return ReadByte(r, n, err)
case typeInt8:
return ReadInt8(r, n, err)
//case typeUInt8:
// return ReadUInt8(r, n, err)
case typeInt16:
return ReadInt16(r, n, err)
case typeUInt16:
return ReadUInt16(r, n, err)
case typeInt32:
return ReadInt32(r, n, err)
case typeUInt32:
return ReadUInt32(r, n, err)
case typeInt64:
return ReadInt64(r, n, err)
case typeUInt64:
return ReadUInt64(r, n, err)
case typeString:
return ReadString(r, n, err)
case typeByteSlice:
return ReadByteSlice(r, n, err)
case typeTime:
return ReadTime(r, n, err)
default: default:
panic("Unsupported type") panic("Unsupported type")
} }
} }
func ReadBinaryN(r io.Reader) (o Binary, n int64) { //-----------------------------------------------------------------------------
type_, n_ := ReadByteN(r)
n += n_ // Creates an adapter codec for Binary things.
switch type_ { // Resulting Codec can be used with merkle/*.
case TYPE_NIL: type BinaryCodec struct {
o, n_ = nil, 0 decoder func(io.Reader, *int64, *error) interface{}
case TYPE_BYTE: }
o, n_ = ReadByteN(r)
case TYPE_INT8: func NewBinaryCodec(decoder func(io.Reader, *int64, *error) interface{}) *BinaryCodec {
o, n_ = ReadInt8N(r) return &BinaryCodec{decoder}
case TYPE_UINT8: }
o, n_ = ReadUInt8N(r)
case TYPE_INT16: func (ca *BinaryCodec) WriteTo(w io.Writer, o interface{}, n *int64, err *error) {
o, n_ = ReadInt16N(r) if bo, ok := o.(Binary); ok {
case TYPE_UINT16: WriteTo(w, BinaryBytes(bo), n, err)
o, n_ = ReadUInt16N(r) } else {
case TYPE_INT32: *err = errors.New("BinaryCodec expected Binary object")
o, n_ = ReadInt32N(r) }
case TYPE_UINT32: }
o, n_ = ReadUInt32N(r)
case TYPE_INT64: func (ca *BinaryCodec) ReadFrom(r io.Reader, n *int64, err *error) interface{} {
o, n_ = ReadInt64N(r) return ca.decoder(r, n, err)
case TYPE_UINT64:
o, n_ = ReadUInt64N(r)
case TYPE_STRING:
o, n_ = ReadStringN(r)
case TYPE_BYTESLICE:
o, n_ = ReadByteSliceN(r)
case TYPE_TIME:
o, n_ = ReadTimeN(r)
default:
panic("Unsupported type")
}
n += n_
return o, n
} }

View File

@ -5,494 +5,118 @@ import (
"io" "io"
) )
type Byte byte
type Int8 int8
type UInt8 uint8
type Int16 int16
type UInt16 uint16
type Int32 int32
type UInt32 uint32
type Int64 int64
type UInt64 uint64
type Int int
type UInt uint
// Byte // Byte
func (self Byte) Equals(other interface{}) bool { func WriteByte(w io.Writer, b byte, n *int64, err *error) {
return self == other WriteTo(w, []byte{b}, n, err)
} }
func (self Byte) Less(other interface{}) bool { func ReadByte(r io.Reader, n *int64, err *error) byte {
if o, ok := other.(Byte); ok { buf := make([]byte, 1)
return self < o ReadFull(r, buf, n, err)
} else { return buf[0]
panic("Cannot compare unequal types")
}
}
func (self Byte) ByteSize() int {
return 1
}
func (self Byte) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)})
return int64(n), err
}
func ReadByteSafe(r io.Reader) (Byte, int64, error) {
buf := [1]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return 0, int64(n), err
}
return Byte(buf[0]), int64(n), nil
}
func ReadByteN(r io.Reader) (Byte, int64) {
b, n, err := ReadByteSafe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadByte(r io.Reader) Byte {
b, _, err := ReadByteSafe(r)
if err != nil {
panic(err)
}
return b
}
func Readbyte(r io.Reader) byte {
return byte(ReadByte(r))
} }
// Int8 // Int8
func (self Int8) Equals(other interface{}) bool { func WriteInt8(w io.Writer, i int8, n *int64, err *error) {
return self == other WriteByte(w, byte(i), n, err)
} }
func (self Int8) Less(other interface{}) bool { func ReadInt8(r io.Reader, n *int64, err *error) int8 {
if o, ok := other.(Int8); ok { return int8(ReadByte(r, n, err))
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Int8) ByteSize() int {
return 1
}
func (self Int8) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)})
return int64(n), err
}
func ReadInt8Safe(r io.Reader) (Int8, int64, error) {
buf := [1]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return Int8(0), int64(n), err
}
return Int8(buf[0]), int64(n), nil
}
func ReadInt8N(r io.Reader) (Int8, int64) {
b, n, err := ReadInt8Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadInt8(r io.Reader) Int8 {
b, _, err := ReadInt8Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readint8(r io.Reader) int8 {
return int8(ReadInt8(r))
} }
// UInt8 // UInt8
func (self UInt8) Equals(other interface{}) bool { func WriteUInt8(w io.Writer, i uint8, n *int64, err *error) {
return self == other WriteByte(w, byte(i), n, err)
} }
func (self UInt8) Less(other interface{}) bool { func ReadUInt8(r io.Reader, n *int64, err *error) uint8 {
if o, ok := other.(UInt8); ok { return uint8(ReadByte(r, n, err))
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self UInt8) ByteSize() int {
return 1
}
func (self UInt8) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)})
return int64(n), err
}
func ReadUInt8Safe(r io.Reader) (UInt8, int64, error) {
buf := [1]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt8(0), int64(n), err
}
return UInt8(buf[0]), int64(n), nil
}
func ReadUInt8N(r io.Reader) (UInt8, int64) {
b, n, err := ReadUInt8Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadUInt8(r io.Reader) UInt8 {
b, _, err := ReadUInt8Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readuint8(r io.Reader) uint8 {
return uint8(ReadUInt8(r))
} }
// Int16 // Int16
func (self Int16) Equals(other interface{}) bool { func WriteInt16(w io.Writer, i int16, n *int64, err *error) {
return self == other buf := make([]byte, 2)
binary.LittleEndian.PutUint16(buf, uint16(i))
WriteTo(w, buf, n, err)
} }
func (self Int16) Less(other interface{}) bool { func ReadInt16(r io.Reader, n *int64, err *error) int16 {
if o, ok := other.(Int16); ok { buf := make([]byte, 2)
return self < o ReadFull(r, buf, n, err)
} else { return int16(binary.LittleEndian.Uint16(buf))
panic("Cannot compare unequal types")
}
}
func (self Int16) ByteSize() int {
return 2
}
func (self Int16) WriteTo(w io.Writer) (int64, error) {
buf := []byte{0, 0}
binary.LittleEndian.PutUint16(buf, uint16(self))
n, err := w.Write(buf)
return int64(n), err
}
func ReadInt16Safe(r io.Reader) (Int16, int64, error) {
buf := [2]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return Int16(0), int64(n), err
}
return Int16(binary.LittleEndian.Uint16(buf[:])), int64(n), nil
}
func ReadInt16N(r io.Reader) (Int16, int64) {
b, n, err := ReadInt16Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadInt16(r io.Reader) Int16 {
b, _, err := ReadInt16Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readint16(r io.Reader) int16 {
return int16(ReadInt16(r))
} }
// UInt16 // UInt16
func (self UInt16) Equals(other interface{}) bool { func WriteUInt16(w io.Writer, i uint16, n *int64, err *error) {
return self == other buf := make([]byte, 2)
binary.LittleEndian.PutUint16(buf, uint16(i))
WriteTo(w, buf, n, err)
} }
func (self UInt16) Less(other interface{}) bool { func ReadUInt16(r io.Reader, n *int64, err *error) uint16 {
if o, ok := other.(UInt16); ok { buf := make([]byte, 2)
return self < o ReadFull(r, buf, n, err)
} else { return uint16(binary.LittleEndian.Uint16(buf))
panic("Cannot compare unequal types")
}
}
func (self UInt16) ByteSize() int {
return 2
}
func (self UInt16) WriteTo(w io.Writer) (int64, error) {
buf := []byte{0, 0}
binary.LittleEndian.PutUint16(buf, uint16(self))
n, err := w.Write(buf)
return int64(n), err
}
func ReadUInt16Safe(r io.Reader) (UInt16, int64, error) {
buf := [2]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt16(0), int64(n), err
}
return UInt16(binary.LittleEndian.Uint16(buf[:])), int64(n), nil
}
func ReadUInt16N(r io.Reader) (UInt16, int64) {
b, n, err := ReadUInt16Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadUInt16(r io.Reader) UInt16 {
b, _, err := ReadUInt16Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readuint16(r io.Reader) uint16 {
return uint16(ReadUInt16(r))
} }
// Int32 // Int32
func (self Int32) Equals(other interface{}) bool { func WriteInt32(w io.Writer, i int32, n *int64, err *error) {
return self == other buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, uint32(i))
WriteTo(w, buf, n, err)
} }
func (self Int32) Less(other interface{}) bool { func ReadInt32(r io.Reader, n *int64, err *error) int32 {
if o, ok := other.(Int32); ok { buf := make([]byte, 4)
return self < o ReadFull(r, buf, n, err)
} else { return int32(binary.LittleEndian.Uint32(buf))
panic("Cannot compare unequal types")
}
}
func (self Int32) ByteSize() int {
return 4
}
func (self Int32) WriteTo(w io.Writer) (int64, error) {
buf := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(buf, uint32(self))
n, err := w.Write(buf)
return int64(n), err
}
func ReadInt32Safe(r io.Reader) (Int32, int64, error) {
buf := [4]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return Int32(0), int64(n), err
}
return Int32(binary.LittleEndian.Uint32(buf[:])), int64(n), nil
}
func ReadInt32N(r io.Reader) (Int32, int64) {
b, n, err := ReadInt32Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadInt32(r io.Reader) Int32 {
b, _, err := ReadInt32Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readint32(r io.Reader) int32 {
return int32(ReadInt32(r))
} }
// UInt32 // UInt32
func (self UInt32) Equals(other interface{}) bool { func WriteUInt32(w io.Writer, i uint32, n *int64, err *error) {
return self == other buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, uint32(i))
WriteTo(w, buf, n, err)
} }
func (self UInt32) Less(other interface{}) bool { func ReadUInt32(r io.Reader, n *int64, err *error) uint32 {
if o, ok := other.(UInt32); ok { buf := make([]byte, 4)
return self < o ReadFull(r, buf, n, err)
} else { return uint32(binary.LittleEndian.Uint32(buf))
panic("Cannot compare unequal types")
}
}
func (self UInt32) ByteSize() int {
return 4
}
func (self UInt32) WriteTo(w io.Writer) (int64, error) {
buf := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(buf, uint32(self))
n, err := w.Write(buf)
return int64(n), err
}
func ReadUInt32Safe(r io.Reader) (UInt32, int64, error) {
buf := [4]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt32(0), int64(n), err
}
return UInt32(binary.LittleEndian.Uint32(buf[:])), int64(n), nil
}
func ReadUInt32N(r io.Reader) (UInt32, int64) {
b, n, err := ReadUInt32Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadUInt32(r io.Reader) UInt32 {
b, _, err := ReadUInt32Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readuint32(r io.Reader) uint32 {
return uint32(ReadUInt32(r))
} }
// Int64 // Int64
func (self Int64) Equals(other interface{}) bool { func WriteInt64(w io.Writer, i int64, n *int64, err *error) {
return self == other buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, uint64(i))
WriteTo(w, buf, n, err)
} }
func (self Int64) Less(other interface{}) bool { func ReadInt64(r io.Reader, n *int64, err *error) int64 {
if o, ok := other.(Int64); ok { buf := make([]byte, 8)
return self < o ReadFull(r, buf, n, err)
} else { return int64(binary.LittleEndian.Uint64(buf))
panic("Cannot compare unequal types")
}
}
func (self Int64) ByteSize() int {
return 8
}
func (self Int64) WriteTo(w io.Writer) (int64, error) {
buf := []byte{0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint64(buf, uint64(self))
n, err := w.Write(buf)
return int64(n), err
}
func ReadInt64Safe(r io.Reader) (Int64, int64, error) {
buf := [8]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return Int64(0), int64(n), err
}
return Int64(binary.LittleEndian.Uint64(buf[:])), int64(n), nil
}
func ReadInt64N(r io.Reader) (Int64, int64) {
b, n, err := ReadInt64Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadInt64(r io.Reader) Int64 {
b, _, err := ReadInt64Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readint64(r io.Reader) int64 {
return int64(ReadInt64(r))
} }
// UInt64 // UInt64
func (self UInt64) Equals(other interface{}) bool { func WriteUInt64(w io.Writer, i uint64, n *int64, err *error) {
return self == other buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, uint64(i))
WriteTo(w, buf, n, err)
} }
func (self UInt64) Less(other interface{}) bool { func ReadUInt64(r io.Reader, n *int64, err *error) uint64 {
if o, ok := other.(UInt64); ok { buf := make([]byte, 8)
return self < o ReadFull(r, buf, n, err)
} else { return uint64(binary.LittleEndian.Uint64(buf))
panic("Cannot compare unequal types")
}
}
func (self UInt64) ByteSize() int {
return 8
}
func (self UInt64) WriteTo(w io.Writer) (int64, error) {
buf := []byte{0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint64(buf, uint64(self))
n, err := w.Write(buf)
return int64(n), err
}
func ReadUInt64Safe(r io.Reader) (UInt64, int64, error) {
buf := [8]byte{0}
n, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt64(0), int64(n), err
}
return UInt64(binary.LittleEndian.Uint64(buf[:])), int64(n), nil
}
func ReadUInt64N(r io.Reader) (UInt64, int64) {
b, n, err := ReadUInt64Safe(r)
if err != nil {
panic(err)
}
return b, n
}
func ReadUInt64(r io.Reader) UInt64 {
b, _, err := ReadUInt64Safe(r)
if err != nil {
panic(err)
}
return b
}
func Readuint64(r io.Reader) uint64 {
return uint64(ReadUInt64(r))
} }

View File

@ -2,67 +2,19 @@ package binary
import "io" import "io"
type String string
// String // String
func (self String) Equals(other interface{}) bool { func WriteString(w io.Writer, s string, n *int64, err *error) {
return self == other WriteUInt32(w, uint32(len(s)), n, err)
WriteTo(w, []byte(s), n, err)
} }
func (self String) Less(other interface{}) bool { func ReadString(r io.Reader, n *int64, err *error) string {
if o, ok := other.(String); ok { length := ReadUInt32(r, n, err)
return self < o if *err != nil {
} else { return ""
panic("Cannot compare unequal types")
} }
} buf := make([]byte, int(length))
ReadFull(r, buf, n, err)
func (self String) ByteSize() int { return string(buf)
return len(self) + 4
}
func (self String) WriteTo(w io.Writer) (n int64, err error) {
var n_ int
_, err = UInt32(len(self)).WriteTo(w)
if err != nil {
return n, err
}
n_, err = w.Write([]byte(self))
return int64(n_ + 4), err
}
func ReadStringSafe(r io.Reader) (str String, n int64, err error) {
length, n_, err := ReadUInt32Safe(r)
n += n_
if err != nil {
return "", n, err
}
bytes := make([]byte, int(length))
n__, err := io.ReadFull(r, bytes)
n += int64(n__)
if err != nil {
return "", n, err
}
return String(bytes), n, nil
}
func ReadStringN(r io.Reader) (str String, n int64) {
str, n, err := ReadStringSafe(r)
if err != nil {
panic(err)
}
return str, n
}
func ReadString(r io.Reader) (str String) {
str, _, err := ReadStringSafe(r)
if err != nil {
panic(err)
}
return str
}
func Readstring(r io.Reader) (str string) {
return string(ReadString(r))
} }

View File

@ -5,58 +5,13 @@ import (
"time" "time"
) )
type Time struct { // Time
time.Time
func WriteTime(w io.Writer, t time.Time, n *int64, err *error) {
WriteInt64(w, t.Unix(), n, err)
} }
func TimeFromUnix(secSinceEpoch int64) Time { func ReadTime(r io.Reader, n *int64, err *error) time.Time {
return Time{time.Unix(secSinceEpoch, 0)} t := ReadInt64(r, n, err)
} return time.Unix(t, 0)
func (self Time) Equals(other interface{}) bool {
if o, ok := other.(Time); ok {
return self.Equal(o.Time)
} else {
return false
}
}
func (self Time) Less(other interface{}) bool {
if o, ok := other.(Time); ok {
return self.Before(o.Time)
} else {
panic("Cannot compare unequal types")
}
}
func (self Time) ByteSize() int {
return 8
}
func (self Time) WriteTo(w io.Writer) (int64, error) {
return Int64(self.Unix()).WriteTo(w)
}
func ReadTimeSafe(r io.Reader) (Time, int64, error) {
t, n, err := ReadInt64Safe(r)
if err != nil {
return Time{}, n, err
}
return Time{time.Unix(int64(t), 0)}, n, nil
}
func ReadTimeN(r io.Reader) (Time, int64) {
t, n, err := ReadTimeSafe(r)
if err != nil {
panic(err)
}
return t, n
}
func ReadTime(r io.Reader) Time {
t, _, err := ReadTimeSafe(r)
if err != nil {
panic(err)
}
return t
} }

View File

@ -5,10 +5,10 @@ import (
"crypto/sha256" "crypto/sha256"
) )
func BinaryBytes(b Binary) ByteSlice { func BinaryBytes(b Binary) []byte {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
b.WriteTo(buf) b.WriteTo(buf)
return ByteSlice(buf.Bytes()) return buf.Bytes()
} }
// NOTE: does not care about the type, only the binary representation. // NOTE: does not care about the type, only the binary representation.
@ -25,11 +25,11 @@ func BinaryCompare(a, b Binary) int {
return bytes.Compare(aBytes, bBytes) return bytes.Compare(aBytes, bBytes)
} }
func BinaryHash(b Binary) ByteSlice { func BinaryHash(b Binary) []byte {
hasher := sha256.New() hasher := sha256.New()
_, err := b.WriteTo(hasher) _, err := b.WriteTo(hasher)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return ByteSlice(hasher.Sum(nil)) return hasher.Sum(nil)
} }

View File

@ -10,11 +10,11 @@ import (
// Node // Node
type IAVLNode struct { type IAVLNode struct {
key Key key []byte
value Value value []byte
size uint64 size uint64
height uint8 height uint8
hash ByteSlice hash []byte
left *IAVLNode left *IAVLNode
right *IAVLNode right *IAVLNode
@ -27,7 +27,7 @@ const (
IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) IAVLNODE_FLAG_PLACEHOLDER = byte(0x02)
) )
func NewIAVLNode(key Key, value Value) *IAVLNode { func NewIAVLNode(key []byte, value []byte) *IAVLNode {
return &IAVLNode{ return &IAVLNode{
key: key, key: key,
value: value, value: value,
@ -50,14 +50,6 @@ func (self *IAVLNode) Copy() *IAVLNode {
} }
} }
func (self *IAVLNode) Key() Key {
return self.key
}
func (self *IAVLNode) Value() Value {
return self.value
}
func (self *IAVLNode) Size() uint64 { func (self *IAVLNode) Size() uint64 {
return self.size return self.size
} }
@ -66,14 +58,14 @@ func (self *IAVLNode) Height() uint8 {
return self.height return self.height
} }
func (self *IAVLNode) has(db Db, key Key) (has bool) { func (self *IAVLNode) has(db Db, key []byte) (has bool) {
if self.key.Equals(key) { if bytes.Equal(self.key, key) {
return true return true
} }
if self.height == 0 { if self.height == 0 {
return false return false
} else { } else {
if key.Less(self.key) { if bytes.Compare(key, self.key) == -1 {
return self.leftFilled(db).has(db, key) return self.leftFilled(db).has(db, key)
} else { } else {
return self.rightFilled(db).has(db, key) return self.rightFilled(db).has(db, key)
@ -81,15 +73,15 @@ func (self *IAVLNode) has(db Db, key Key) (has bool) {
} }
} }
func (self *IAVLNode) get(db Db, key Key) (value Value) { func (self *IAVLNode) get(db Db, key []byte) (value []byte) {
if self.height == 0 { if self.height == 0 {
if self.key.Equals(key) { if bytes.Equal(self.key, key) {
return self.value return self.value
} else { } else {
return nil return nil
} }
} else { } else {
if key.Less(self.key) { if bytes.Compare(key, self.key) == -1 {
return self.leftFilled(db).get(db, key) return self.leftFilled(db).get(db, key)
} else { } else {
return self.rightFilled(db).get(db, key) return self.rightFilled(db).get(db, key)
@ -97,7 +89,7 @@ func (self *IAVLNode) get(db Db, key Key) (value Value) {
} }
} }
func (self *IAVLNode) Hash() (ByteSlice, uint64) { func (self *IAVLNode) HashWithCount() ([]byte, uint64) {
if self.hash != nil { if self.hash != nil {
return self.hash, 0 return self.hash, 0
} }
@ -138,9 +130,9 @@ func (self *IAVLNode) Save(db Db) {
self.flags |= IAVLNODE_FLAG_PERSISTED self.flags |= IAVLNODE_FLAG_PERSISTED
} }
func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { func (self *IAVLNode) set(db Db, key []byte, value []byte) (_ *IAVLNode, updated bool) {
if self.height == 0 { if self.height == 0 {
if key.Less(self.key) { if bytes.Compare(key, self.key) == -1 {
return &IAVLNode{ return &IAVLNode{
key: self.key, key: self.key,
height: 1, height: 1,
@ -148,7 +140,7 @@ func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated boo
left: NewIAVLNode(key, value), left: NewIAVLNode(key, value),
right: self, right: self,
}, false }, false
} else if self.key.Equals(key) { } else if bytes.Equal(self.key, key) {
return NewIAVLNode(key, value), true return NewIAVLNode(key, value), true
} else { } else {
return &IAVLNode{ return &IAVLNode{
@ -161,7 +153,7 @@ func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated boo
} }
} else { } else {
self = self.Copy() self = self.Copy()
if key.Less(self.key) { if bytes.Compare(key, self.key) == -1 {
self.left, updated = self.leftFilled(db).set(db, key, value) self.left, updated = self.leftFilled(db).set(db, key, value)
} else { } else {
self.right, updated = self.rightFilled(db).set(db, key, value) self.right, updated = self.rightFilled(db).set(db, key, value)
@ -176,15 +168,15 @@ func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated boo
} }
// newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed.
func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) { func (self *IAVLNode) remove(db Db, key []byte) (newSelf *IAVLNode, newKey []byte, value []byte, err error) {
if self.height == 0 { if self.height == 0 {
if self.key.Equals(key) { if bytes.Equal(self.key, key) {
return nil, nil, self.value, nil return nil, nil, self.value, nil
} else { } else {
return self, nil, nil, NotFound(key) return self, nil, nil, NotFound(key)
} }
} else { } else {
if key.Less(self.key) { if bytes.Compare(key, self.key) == -1 {
var newLeft *IAVLNode var newLeft *IAVLNode
newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) newLeft, newKey, value, err = self.leftFilled(db).remove(db, key)
if err != nil { if err != nil {
@ -220,74 +212,28 @@ func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) {
} }
func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) { func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) {
var _n int64 // height & size & key
WriteUInt8(w, self.height, &n, &err)
// height & size WriteUInt64(w, self.size, &n, &err)
_n, err = UInt8(self.height).WriteTo(w) WriteByteSlice(w, self.key, &n, &err)
if err != nil { if err != nil {
return return
} else {
n += _n
}
_n, err = UInt64(self.size).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
// key
_n, err = Byte(GetBinaryType(self.key)).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
_n, err = self.key.WriteTo(w)
if err != nil {
return
} else {
n += _n
} }
// value or children // value or children
if self.height == 0 { if self.height == 0 {
// value // value
_n, err = Byte(GetBinaryType(self.value)).WriteTo(w) WriteByteSlice(w, self.value, &n, &err)
if err != nil {
return
} else {
n += _n
}
if self.value != nil {
_n, err = self.value.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
}
} else { } else {
// left // left
leftHash, leftCount := self.left.Hash() leftHash, leftCount := self.left.HashWithCount()
hashCount += leftCount hashCount += leftCount
_n, err = leftHash.WriteTo(w) WriteByteSlice(w, leftHash, &n, &err)
if err != nil {
return
} else {
n += _n
}
// right // right
rightHash, rightCount := self.right.Hash() rightHash, rightCount := self.right.HashWithCount()
hashCount += rightCount hashCount += rightCount
_n, err = rightHash.WriteTo(w) WriteByteSlice(w, rightHash, &n, &err)
if err != nil {
return
} else {
n += _n
}
} }
return return
} }
@ -300,25 +246,30 @@ func (self *IAVLNode) fill(db Db) {
} }
buf := db.Get(self.hash) buf := db.Get(self.hash)
r := bytes.NewReader(buf) r := bytes.NewReader(buf)
// node header var n int64
self.height = uint8(ReadUInt8(r)) var err error
self.size = uint64(ReadUInt64(r))
// key
key, _ := ReadBinaryN(r)
self.key = key.(Key)
// node header & key
self.height = ReadUInt8(r, &n, &err)
self.size = ReadUInt64(r, &n, &err)
self.key = ReadByteSlice(r, &n, &err)
if err != nil {
panic(err)
}
// node value or children.
if self.height == 0 { if self.height == 0 {
// value // value
self.value, _ = ReadBinaryN(r) self.value = ReadByteSlice(r, &n, &err)
} else { } else {
// left // left
leftHash := ReadByteSlice(r) leftHash := ReadByteSlice(r, &n, &err)
self.left = &IAVLNode{ self.left = &IAVLNode{
hash: leftHash, hash: leftHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
} }
// right // right
rightHash := ReadByteSlice(r) rightHash := ReadByteSlice(r, &n, &err)
self.right = &IAVLNode{ self.right = &IAVLNode{
hash: rightHash, hash: rightHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
@ -327,6 +278,9 @@ func (self *IAVLNode) fill(db Db) {
panic("buf not all consumed") panic("buf not all consumed")
} }
} }
if err != nil {
panic(err)
}
self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER
} }
@ -425,7 +379,7 @@ func (self *IAVLNode) rmd(db Db) *IAVLNode {
return self.rightFilled(db).rmd(db) return self.rightFilled(db).rmd(db)
} }
func (self *IAVLNode) traverse(db Db, cb func(Node) bool) bool { func (self *IAVLNode) traverse(db Db, cb func(*IAVLNode) bool) bool {
stop := cb(self) stop := cb(self)
if stop { if stop {
return stop return stop

View File

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/common"
"github.com/tendermint/tendermint/db" "github.com/tendermint/tendermint/db"
@ -17,8 +16,8 @@ func init() {
// TODO: seed rand? // TODO: seed rand?
} }
func randstr(length int) String { func randstr(length int) string {
return String(RandStr(length)) return RandStr(length)
} }
func TestUnit(t *testing.T) { func TestUnit(t *testing.T) {
@ -29,12 +28,12 @@ func TestUnit(t *testing.T) {
if _, ok := l.(*IAVLNode); ok { if _, ok := l.(*IAVLNode); ok {
left = l.(*IAVLNode) left = l.(*IAVLNode)
} else { } else {
left = NewIAVLNode(Int32(l.(int)), nil) left = NewIAVLNode([]byte{byte(l.(int))}, nil)
} }
if _, ok := r.(*IAVLNode); ok { if _, ok := r.(*IAVLNode); ok {
right = r.(*IAVLNode) right = r.(*IAVLNode)
} else { } else {
right = NewIAVLNode(Int32(r.(int)), nil) right = NewIAVLNode([]byte{byte(r.(int))}, nil)
} }
n := &IAVLNode{ n := &IAVLNode{
@ -43,7 +42,7 @@ func TestUnit(t *testing.T) {
right: right, right: right,
} }
n.calcHeightAndSize(nil) n.calcHeightAndSize(nil)
n.Hash() n.HashWithCount()
return n return n
} }
@ -51,7 +50,7 @@ func TestUnit(t *testing.T) {
var P func(*IAVLNode) string var P func(*IAVLNode) string
P = func(n *IAVLNode) string { P = func(n *IAVLNode) string {
if n.height == 0 { if n.height == 0 {
return fmt.Sprintf("%v", n.key) return fmt.Sprintf("%v", n.key[0])
} else { } else {
return fmt.Sprintf("(%v %v)", P(n.left), P(n.right)) return fmt.Sprintf("(%v %v)", P(n.left), P(n.right))
} }
@ -59,24 +58,24 @@ func TestUnit(t *testing.T) {
expectHash := func(n2 *IAVLNode, hashCount uint64) { expectHash := func(n2 *IAVLNode, hashCount uint64) {
// ensure number of new hash calculations is as expected. // ensure number of new hash calculations is as expected.
hash, count := n2.Hash() hash, count := n2.HashWithCount()
if count != hashCount { if count != hashCount {
t.Fatalf("Expected %v new hashes, got %v", hashCount, count) t.Fatalf("Expected %v new hashes, got %v", hashCount, count)
} }
// nuke hashes and reconstruct hash, ensure it's the same. // nuke hashes and reconstruct hash, ensure it's the same.
(&IAVLTree{root: n2}).Traverse(func(node Node) bool { n2.traverse(nil, func(node *IAVLNode) bool {
node.(*IAVLNode).hash = nil node.hash = nil
return false return false
}) })
// ensure that the new hash after nuking is the same as the old. // ensure that the new hash after nuking is the same as the old.
newHash, _ := n2.Hash() newHash, _ := n2.HashWithCount()
if bytes.Compare(hash, newHash) != 0 { if bytes.Compare(hash, newHash) != 0 {
t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash)
} }
} }
expectSet := func(n *IAVLNode, i int, repr string, hashCount uint64) { expectSet := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, updated := n.set(nil, Int32(i), nil) n2, updated := n.set(nil, []byte{byte(i)}, nil)
// ensure node was added & structure is as expected. // ensure node was added & structure is as expected.
if updated == true || P(n2) != repr { if updated == true || P(n2) != repr {
t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v",
@ -87,7 +86,7 @@ func TestUnit(t *testing.T) {
} }
expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, _, value, err := n.remove(nil, Int32(i)) n2, _, value, err := n.remove(nil, []byte{byte(i)})
// ensure node was added & structure is as expected. // ensure node was added & structure is as expected.
if value != nil || err != nil || P(n2) != repr { if value != nil || err != nil || P(n2) != repr {
t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v",
@ -137,14 +136,14 @@ func TestUnit(t *testing.T) {
func TestIntegration(t *testing.T) { func TestIntegration(t *testing.T) {
type record struct { type record struct {
key String key string
value String value string
} }
records := make([]*record, 400) records := make([]*record, 400)
var tree *IAVLTree = NewIAVLTree(nil) var tree *IAVLTree = NewIAVLTree(nil)
var err error var err error
var val Value var val []byte
var updated bool var updated bool
randomRecord := func() *record { randomRecord := func() *record {
@ -156,11 +155,11 @@ func TestIntegration(t *testing.T) {
records[i] = r records[i] = r
//t.Log("New record", r) //t.Log("New record", r)
//PrintIAVLNode(tree.root) //PrintIAVLNode(tree.root)
updated = tree.Set(r.key, String("")) updated = tree.Set([]byte(r.key), []byte(""))
if updated { if updated {
t.Error("should have not been updated") t.Error("should have not been updated")
} }
updated = tree.Set(r.key, r.value) updated = tree.Set([]byte(r.key), []byte(r.value))
if !updated { if !updated {
t.Error("should have been updated") t.Error("should have been updated")
} }
@ -170,31 +169,32 @@ func TestIntegration(t *testing.T) {
} }
for _, r := range records { for _, r := range records {
if has := tree.Has(r.key); !has { if has := tree.Has([]byte(r.key)); !has {
t.Error("Missing key", r.key) t.Error("Missing key", r.key)
} }
if has := tree.Has(randstr(12)); has { if has := tree.Has([]byte(randstr(12))); has {
t.Error("Table has extra key") t.Error("Table has extra key")
} }
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { if val := tree.Get([]byte(r.key)); string(val) != r.value {
t.Error("wrong value") t.Error("wrong value")
} }
} }
for i, x := range records { for i, x := range records {
if val, err = tree.Remove(x.key); err != nil { if val, err = tree.Remove([]byte(x.key)); err != nil {
t.Error(err) t.Error(err)
} else if !(val.(String)).Equals(x.value) { } else if string(val) != x.value {
t.Error("wrong value") t.Error("wrong value")
} }
for _, r := range records[i+1:] { for _, r := range records[i+1:] {
if has := tree.Has(r.key); !has { if has := tree.Has([]byte(r.key)); !has {
t.Error("Missing key", r.key) t.Error("Missing key", r.key)
} }
if has := tree.Has(randstr(12)); has { if has := tree.Has([]byte(randstr(12))); has {
t.Error("Table has extra key") t.Error("Table has extra key")
} }
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { val := tree.Get([]byte(r.key))
if string(val) != r.value {
t.Error("wrong value") t.Error("wrong value")
} }
} }
@ -208,25 +208,25 @@ func TestPersistence(t *testing.T) {
db := db.NewMemDB() db := db.NewMemDB()
// Create some random key value pairs // Create some random key value pairs
records := make(map[String]String) records := make(map[string]string)
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
records[String(randstr(20))] = String(randstr(20)) records[randstr(20)] = randstr(20)
} }
// Construct some tree and save it // Construct some tree and save it
t1 := NewIAVLTree(db) t1 := NewIAVLTree(db)
for key, value := range records { for key, value := range records {
t1.Set(key, value) t1.Set([]byte(key), []byte(value))
} }
t1.Save() t1.Save()
hash, _ := t1.Hash() hash, _ := t1.HashWithCount()
// Load a tree // Load a tree
t2 := NewIAVLTreeFromHash(db, hash) t2 := NewIAVLTreeFromHash(db, hash)
for key, value := range records { for key, value := range records {
t2value := t2.Get(key) t2value := t2.Get([]byte(key))
if !BinaryEqual(t2value, value) { if string(t2value) != value {
t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) t.Fatalf("Invalid value. Expected %v, got %v", value, t2value)
} }
} }
@ -249,8 +249,8 @@ func BenchmarkImmutableAvlTree(b *testing.B) {
b.StopTimer() b.StopTimer()
type record struct { type record struct {
key String key string
value String value string
} }
randomRecord := func() *record { randomRecord := func() *record {
@ -260,7 +260,7 @@ func BenchmarkImmutableAvlTree(b *testing.B) {
t := NewIAVLTree(nil) t := NewIAVLTree(nil)
for i := 0; i < 1000000; i++ { for i := 0; i < 1000000; i++ {
r := randomRecord() r := randomRecord()
t.Set(r.key, r.value) t.Set([]byte(r.key), []byte(r.value))
} }
fmt.Println("ok, starting") fmt.Println("ok, starting")
@ -270,7 +270,7 @@ func BenchmarkImmutableAvlTree(b *testing.B) {
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
r := randomRecord() r := randomRecord()
t.Set(r.key, r.value) t.Set([]byte(r.key), []byte(r.value))
t.Remove(r.key) t.Remove([]byte(r.key))
} }
} }

View File

@ -1,9 +1,5 @@
package merkle package merkle
import (
. "github.com/tendermint/tendermint/binary"
)
const HASH_BYTE_SIZE int = 4 + 32 const HASH_BYTE_SIZE int = 4 + 32
/* /*
@ -18,10 +14,14 @@ type IAVLTree struct {
} }
func NewIAVLTree(db Db) *IAVLTree { func NewIAVLTree(db Db) *IAVLTree {
return &IAVLTree{db: db, root: nil} return &IAVLTree{
db: db,
root: nil,
}
} }
func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree { // TODO rename to Load.
func NewIAVLTreeFromHash(db Db, hash []byte) *IAVLTree {
root := &IAVLNode{ root := &IAVLNode{
hash: hash, hash: hash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
@ -43,10 +43,6 @@ func NewIAVLTreeFromKey(db Db, key string) *IAVLTree {
return &IAVLTree{db: db, root: root} return &IAVLTree{db: db, root: root}
} }
func (t *IAVLTree) Root() Node {
return t.root
}
func (t *IAVLTree) Size() uint64 { func (t *IAVLTree) Size() uint64 {
if t.root == nil { if t.root == nil {
return 0 return 0
@ -61,14 +57,14 @@ func (t *IAVLTree) Height() uint8 {
return t.root.Height() return t.root.Height()
} }
func (t *IAVLTree) Has(key Key) bool { func (t *IAVLTree) Has(key []byte) bool {
if t.root == nil { if t.root == nil {
return false return false
} }
return t.root.has(t.db, key) return t.root.has(t.db, key)
} }
func (t *IAVLTree) Set(key Key, value Value) (updated bool) { func (t *IAVLTree) Set(key []byte, value []byte) (updated bool) {
if t.root == nil { if t.root == nil {
t.root = NewIAVLNode(key, value) t.root = NewIAVLNode(key, value)
return false return false
@ -77,18 +73,26 @@ func (t *IAVLTree) Set(key Key, value Value) (updated bool) {
return updated return updated
} }
func (t *IAVLTree) Hash() (ByteSlice, uint64) { func (t *IAVLTree) Hash() []byte {
if t.root == nil {
return nil
}
hash, _ := t.root.HashWithCount()
return hash
}
func (t *IAVLTree) HashWithCount() ([]byte, uint64) {
if t.root == nil { if t.root == nil {
return nil, 0 return nil, 0
} }
return t.root.Hash() return t.root.HashWithCount()
} }
func (t *IAVLTree) Save() { func (t *IAVLTree) Save() {
if t.root == nil { if t.root == nil {
return return
} }
t.root.Hash() t.root.HashWithCount()
t.root.Save(t.db) t.root.Save(t.db)
} }
@ -96,19 +100,19 @@ func (t *IAVLTree) SaveKey(key string) {
if t.root == nil { if t.root == nil {
return return
} }
hash, _ := t.root.Hash() hash, _ := t.root.HashWithCount()
t.root.Save(t.db) t.root.Save(t.db)
t.db.Set([]byte(key), hash) t.db.Set([]byte(key), hash)
} }
func (t *IAVLTree) Get(key Key) (value Value) { func (t *IAVLTree) Get(key []byte) (value []byte) {
if t.root == nil { if t.root == nil {
return nil return nil
} }
return t.root.get(t.db, key) return t.root.get(t.db, key)
} }
func (t *IAVLTree) Remove(key Key) (value Value, err error) { func (t *IAVLTree) Remove(key []byte) (value []byte, err error) {
if t.root == nil { if t.root == nil {
return nil, NotFound(key) return nil, NotFound(key)
} }
@ -123,32 +127,3 @@ func (t *IAVLTree) Remove(key Key) (value Value, err error) {
func (t *IAVLTree) Copy() Tree { func (t *IAVLTree) Copy() Tree {
return &IAVLTree{db: t.db, root: t.root} return &IAVLTree{db: t.db, root: t.root}
} }
// Traverses all the nodes of the tree in prefix order.
// return true from cb to halt iteration.
// node.Height() == 0 if you just want a value node.
func (t *IAVLTree) Traverse(cb func(Node) bool) {
if t.root == nil {
return
}
t.root.traverse(t.db, cb)
}
func (t *IAVLTree) Values() <-chan Value {
root := t.root
ch := make(chan Value)
if root == nil {
close(ch)
return ch
}
go func() {
root.traverse(t.db, func(n Node) bool {
if n.Height() == 0 {
ch <- n.Value()
}
return true
})
close(ch)
}()
return ch
}

View File

@ -2,50 +2,27 @@ package merkle
import ( import (
"fmt" "fmt"
. "github.com/tendermint/tendermint/binary"
) )
type Value interface {
Binary
}
type Key interface {
Binary
Equals(interface{}) bool
Less(b interface{}) bool
}
type Db interface { type Db interface {
Get([]byte) []byte Get([]byte) []byte
Set([]byte, []byte) Set([]byte, []byte)
} }
type Node interface {
Binary
Key() Key
Value() Value
Size() uint64
Height() uint8
Hash() (ByteSlice, uint64)
Save(Db)
}
type Tree interface { type Tree interface {
Root() Node
Size() uint64 Size() uint64
Height() uint8 Height() uint8
Has(key Key) bool Has(key []byte) bool
Get(key Key) Value Get(key []byte) []byte
Hash() (ByteSlice, uint64) HashWithCount() ([]byte, uint64)
Hash() []byte
Save() Save()
SaveKey(string) SaveKey(string)
Set(Key, Value) bool Set(key []byte, vlaue []byte) bool
Remove(Key) (Value, error) Remove(key []byte) ([]byte, error)
Copy() Tree Copy() Tree
Traverse(func(Node) bool)
Values() <-chan Value
} }
func NotFound(key Key) error { func NotFound(key []byte) error {
return fmt.Errorf("Key was not found.") return fmt.Errorf("Key was not found.")
} }

View File

@ -7,10 +7,15 @@ import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
) )
func HashFromByteSlices(items [][]byte) []byte {
panic("Implement me")
return nil
}
/* /*
Compute a deterministic merkle hash from a list of byteslices. Compute a deterministic merkle hash from a list of byteslices.
*/ */
func HashFromBinarySlice(items []Binary) ByteSlice { func HashFromBinarySlice(items []Binary) []byte {
switch len(items) { switch len(items) {
case 0: case 0:
panic("Cannot compute hash of empty slice") panic("Cannot compute hash of empty slice")
@ -20,18 +25,22 @@ func HashFromBinarySlice(items []Binary) ByteSlice {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return ByteSlice(hasher.Sum(nil)) return hasher.Sum(nil)
default: default:
hasher := sha256.New() var n int64
_, err := HashFromBinarySlice(items[0 : len(items)/2]).WriteTo(hasher) var err error
var hasher = sha256.New()
hash := HashFromBinarySlice(items[0 : len(items)/2])
WriteByteSlice(hasher, hash, &n, &err)
if err != nil { if err != nil {
panic(err) panic(err)
} }
_, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher) hash = HashFromBinarySlice(items[len(items)/2:])
WriteByteSlice(hasher, hash, &n, &err)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return ByteSlice(hasher.Sum(nil)) return hasher.Sum(nil)
} }
} }

19
state/store.go Normal file
View File

@ -0,0 +1,19 @@
package state
import (
. "github.com/tendermint/tendermint/blocks"
)
// XXX ugh, bad name.
type StateStore struct {
}
func (ss *StateStore) StageBlock(block *Block) error {
// XXX implement staging.
return nil
}
func (ss *StateStore) CommitBlock(block *Block) error {
// XXX implement staging.
return nil
}

141
state/validator.go Normal file
View File

@ -0,0 +1,141 @@
package state
import (
"io"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/blocks"
//. "github.com/tendermint/tendermint/common"
db_ "github.com/tendermint/tendermint/db"
)
// Holds state for a Validator at a given height+round.
// Meant to be discarded every round of the consensus protocol.
// TODO consider moving this to another common types package.
type Validator struct {
Account
BondHeight uint32
VotingPower uint64
Accum int64
}
// Used to persist the state of ConsensusStateControl.
func ReadValidator(r io.Reader) *Validator {
return &Validator{
Account: Account{
Id: Readuint64(r),
PubKey: ReadByteSlice(r),
},
BondHeight: Readuint32(r),
VotingPower: Readuint64(r),
Accum: Readint64(r),
}
}
// Creates a new copy of the validator so we can mutate accum.
func (v *Validator) Copy() *Validator {
return &Validator{
Account: v.Account,
BondHeight: v.BondHeight,
VotingPower: v.VotingPower,
Accum: v.Accum,
}
}
// Used to persist the state of ConsensusStateControl.
func (v *Validator) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteTo(UInt64(v.Id), w, n, err)
n, err = WriteTo(v.PubKey, w, n, err)
n, err = WriteTo(UInt32(v.BondHeight), w, n, err)
n, err = WriteTo(UInt64(v.VotingPower), w, n, err)
n, err = WriteTo(Int64(v.Accum), w, n, err)
return
}
//-----------------------------------------------------------------------------
// TODO: Ensure that double signing never happens via an external persistent check.
type PrivValidator struct {
PrivAccount
db *db_.LevelDB
}
// Modifies the vote object in memory.
// Double signing results in an error.
func (pv *PrivValidator) SignVote(vote *Vote) error {
return nil
}
//-----------------------------------------------------------------------------
// Not goroutine-safe.
type ValidatorSet struct {
validators map[uint64]*Validator
}
func NewValidatorSet(validators map[uint64]*Validator) *ValidatorSet {
if validators == nil {
validators = make(map[uint64]*Validator)
}
return &ValidatorSet{
valdiators: validators,
}
}
func (v *ValidatorSet) IncrementAccum() {
totalDelta := int64(0)
for _, validator := range v.validators {
validator.Accum += int64(validator.VotingPower)
totalDelta += int64(validator.VotingPower)
}
proposer := v.GetProposer()
proposer.Accum -= totalDelta
// NOTE: sum(v) here should be zero.
if true {
totalAccum := int64(0)
for _, validator := range v.validators {
totalAccum += validator.Accum
}
if totalAccum != 0 {
Panicf("Total Accum of validators did not equal 0. Got: ", totalAccum)
}
}
}
func (v *ValidatorSet) Copy() *ValidatorSet {
mapCopy := map[uint64]*Validator{}
for _, val := range validators {
mapCopy[val.Id] = val.Copy()
}
return &ValidatorSet{
validators: mapCopy,
}
}
func (v *ValidatorSet) Add(validator *Valdaitor) {
v.validators[validator.Id] = validator
}
func (v *ValidatorSet) Get(id uint64) *Validator {
return v.validators[validator.Id]
}
func (v *ValidatorSet) Map() map[uint64]*Validator {
return v.validators
}
// TODO: cache proposer. invalidate upon increment.
func (v *ValidatorSet) GetProposer() (proposer *Validator) {
highestAccum := int64(0)
for _, validator := range v.validators {
if validator.Accum > highestAccum {
highestAccum = validator.Accum
proposer = validator
} else if validator.Accum == highestAccum {
if validator.Id < proposer.Id { // Seniority
proposer = validator
}
}
}
return
}