This commit is contained in:
Jae Kwon
2014-07-01 14:50:24 -07:00
parent fa07748d23
commit c40fc65e6b
41 changed files with 3176 additions and 2938 deletions

View File

@ -3,12 +3,14 @@ package binary
import "io" import "io"
type Binary interface { type Binary interface {
WriteTo(w io.Writer) (int64, error) WriteTo(w io.Writer) (int64, error)
} }
func WriteOnto(b Binary, w io.Writer, n int64, err error) (int64, error) { func WriteOnto(b Binary, w io.Writer, n int64, err error) (int64, error) {
if err != nil { return n, err } if err != nil {
var n_ int64 return n, err
n_, err = b.WriteTo(w) }
return n+n_, err var n_ int64
n_, err = b.WriteTo(w)
return n + n_, err
} }

View File

@ -6,44 +6,52 @@ import "bytes"
type ByteSlice []byte type ByteSlice []byte
func (self ByteSlice) Equals(other Binary) bool { func (self ByteSlice) Equals(other Binary) bool {
if o, ok := other.(ByteSlice); ok { if o, ok := other.(ByteSlice); ok {
return bytes.Equal(self, o) return bytes.Equal(self, o)
} else { } else {
return false return false
} }
} }
func (self ByteSlice) Less(other Binary) bool { func (self ByteSlice) Less(other Binary) bool {
if o, ok := other.(ByteSlice); ok { if o, ok := other.(ByteSlice); ok {
return bytes.Compare(self, o) < 0 // -1 if a < b return bytes.Compare(self, o) < 0 // -1 if a < b
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self ByteSlice) ByteSize() int { func (self ByteSlice) ByteSize() int {
return len(self)+4 return len(self) + 4
} }
func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) { func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) {
var n_ int var n_ int
_, err = UInt32(len(self)).WriteTo(w) _, err = UInt32(len(self)).WriteTo(w)
if err != nil { return n, err } if err != nil {
n_, err = w.Write([]byte(self)) return n, err
return int64(n_+4), err }
n_, err = w.Write([]byte(self))
return int64(n_ + 4), err
} }
func ReadByteSliceSafe(r io.Reader) (ByteSlice, error) { func ReadByteSliceSafe(r io.Reader) (ByteSlice, error) {
length, err := ReadUInt32Safe(r) length, err := ReadUInt32Safe(r)
if err != nil { return nil, err } if err != nil {
bytes := make([]byte, int(length)) return nil, err
_, err = io.ReadFull(r, bytes) }
if err != nil { return nil, err } bytes := make([]byte, int(length))
return bytes, nil _, err = io.ReadFull(r, bytes)
if err != nil {
return nil, err
}
return bytes, nil
} }
func ReadByteSlice(r io.Reader) ByteSlice { func ReadByteSlice(r io.Reader) ByteSlice {
bytes, err := ReadByteSliceSafe(r) bytes, err := ReadByteSliceSafe(r)
if r != nil { panic(err) } if r != nil {
return bytes panic(err)
}
return bytes
} }

View File

@ -1,70 +1,100 @@
package binary package binary
import ( import (
"io" "io"
) )
const ( const (
TYPE_NIL = Byte(0x00) TYPE_NIL = Byte(0x00)
TYPE_BYTE = Byte(0x01) TYPE_BYTE = Byte(0x01)
TYPE_INT8 = Byte(0x02) TYPE_INT8 = Byte(0x02)
TYPE_UINT8 = Byte(0x03) TYPE_UINT8 = Byte(0x03)
TYPE_INT16 = Byte(0x04) TYPE_INT16 = Byte(0x04)
TYPE_UINT16 = Byte(0x05) TYPE_UINT16 = Byte(0x05)
TYPE_INT32 = Byte(0x06) TYPE_INT32 = Byte(0x06)
TYPE_UINT32 = Byte(0x07) TYPE_UINT32 = Byte(0x07)
TYPE_INT64 = Byte(0x08) TYPE_INT64 = Byte(0x08)
TYPE_UINT64 = Byte(0x09) TYPE_UINT64 = Byte(0x09)
TYPE_STRING = Byte(0x10) TYPE_STRING = Byte(0x10)
TYPE_BYTESLICE = Byte(0x11) TYPE_BYTESLICE = Byte(0x11)
TYPE_TIME = Byte(0x20) TYPE_TIME = Byte(0x20)
) )
func GetBinaryType(o Binary) Byte { func GetBinaryType(o Binary) Byte {
switch o.(type) { switch o.(type) {
case nil: return TYPE_NIL case nil:
case Byte: return TYPE_BYTE return TYPE_NIL
case Int8: return TYPE_INT8 case Byte:
case UInt8: return TYPE_UINT8 return TYPE_BYTE
case Int16: return TYPE_INT16 case Int8:
case UInt16: return TYPE_UINT16 return TYPE_INT8
case Int32: return TYPE_INT32 case UInt8:
case UInt32: return TYPE_UINT32 return TYPE_UINT8
case Int64: return TYPE_INT64 case Int16:
case UInt64: return TYPE_UINT64 return TYPE_INT16
case Int: panic("Int not supported") case UInt16:
case UInt: panic("UInt not supported") return TYPE_UINT16
case Int32:
return TYPE_INT32
case UInt32:
return TYPE_UINT32
case Int64:
return TYPE_INT64
case UInt64:
return TYPE_UINT64
case Int:
panic("Int not supported")
case UInt:
panic("UInt not supported")
case String: return TYPE_STRING case String:
case ByteSlice: return TYPE_BYTESLICE return TYPE_STRING
case ByteSlice:
return TYPE_BYTESLICE
case Time: return TYPE_TIME case Time:
return TYPE_TIME
default: panic("Unsupported type") default:
} panic("Unsupported type")
}
} }
func ReadBinary(r io.Reader) Binary { func ReadBinary(r io.Reader) Binary {
type_ := ReadByte(r) type_ := ReadByte(r)
switch type_ { switch type_ {
case TYPE_NIL: return nil case TYPE_NIL:
case TYPE_BYTE: return ReadByte(r) return nil
case TYPE_INT8: return ReadInt8(r) case TYPE_BYTE:
case TYPE_UINT8: return ReadUInt8(r) return ReadByte(r)
case TYPE_INT16: return ReadInt16(r) case TYPE_INT8:
case TYPE_UINT16: return ReadUInt16(r) return ReadInt8(r)
case TYPE_INT32: return ReadInt32(r) case TYPE_UINT8:
case TYPE_UINT32: return ReadUInt32(r) return ReadUInt8(r)
case TYPE_INT64: return ReadInt64(r) case TYPE_INT16:
case TYPE_UINT64: return ReadUInt64(r) return ReadInt16(r)
case TYPE_UINT16:
return ReadUInt16(r)
case TYPE_INT32:
return ReadInt32(r)
case TYPE_UINT32:
return ReadUInt32(r)
case TYPE_INT64:
return ReadInt64(r)
case TYPE_UINT64:
return ReadUInt64(r)
case TYPE_STRING: return ReadString(r) case TYPE_STRING:
case TYPE_BYTESLICE:return ReadByteSlice(r) return ReadString(r)
case TYPE_BYTESLICE:
return ReadByteSlice(r)
case TYPE_TIME: return ReadTime(r) case TYPE_TIME:
return ReadTime(r)
default: panic("Unsupported type") default:
} panic("Unsupported type")
}
} }

View File

@ -1,8 +1,8 @@
package binary package binary
import ( import (
"io" "encoding/binary"
"encoding/binary" "io"
) )
type Byte byte type Byte byte
@ -17,397 +17,426 @@ type UInt64 uint64
type Int int type Int int
type UInt uint type UInt uint
// Byte // Byte
func (self Byte) Equals(other Binary) bool { func (self Byte) Equals(other Binary) bool {
return self == other return self == other
} }
func (self Byte) Less(other Binary) bool { func (self Byte) Less(other Binary) bool {
if o, ok := other.(Byte); ok { if o, ok := other.(Byte); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Byte) ByteSize() int { func (self Byte) ByteSize() int {
return 1 return 1
} }
func (self Byte) WriteTo(w io.Writer) (int64, error) { func (self Byte) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)}) n, err := w.Write([]byte{byte(self)})
return int64(n), err return int64(n), err
} }
func ReadByteSafe(r io.Reader) (Byte, error) { func ReadByteSafe(r io.Reader) (Byte, error) {
buf := [1]byte{0} buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return 0, err } if err != nil {
return Byte(buf[0]), nil return 0, err
}
return Byte(buf[0]), nil
} }
func ReadByte(r io.Reader) (Byte) { func ReadByte(r io.Reader) Byte {
b, err := ReadByteSafe(r) b, err := ReadByteSafe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// Int8 // Int8
func (self Int8) Equals(other Binary) bool { func (self Int8) Equals(other Binary) bool {
return self == other return self == other
} }
func (self Int8) Less(other Binary) bool { func (self Int8) Less(other Binary) bool {
if o, ok := other.(Int8); ok { if o, ok := other.(Int8); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Int8) ByteSize() int { func (self Int8) ByteSize() int {
return 1 return 1
} }
func (self Int8) WriteTo(w io.Writer) (int64, error) { func (self Int8) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)}) n, err := w.Write([]byte{byte(self)})
return int64(n), err return int64(n), err
} }
func ReadInt8Safe(r io.Reader) (Int8, error) { func ReadInt8Safe(r io.Reader) (Int8, error) {
buf := [1]byte{0} buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return Int8(0), err } if err != nil {
return Int8(buf[0]), nil return Int8(0), err
}
return Int8(buf[0]), nil
} }
func ReadInt8(r io.Reader) (Int8) { func ReadInt8(r io.Reader) Int8 {
b, err := ReadInt8Safe(r) b, err := ReadInt8Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// UInt8 // UInt8
func (self UInt8) Equals(other Binary) bool { func (self UInt8) Equals(other Binary) bool {
return self == other return self == other
} }
func (self UInt8) Less(other Binary) bool { func (self UInt8) Less(other Binary) bool {
if o, ok := other.(UInt8); ok { if o, ok := other.(UInt8); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self UInt8) ByteSize() int { func (self UInt8) ByteSize() int {
return 1 return 1
} }
func (self UInt8) WriteTo(w io.Writer) (int64, error) { func (self UInt8) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)}) n, err := w.Write([]byte{byte(self)})
return int64(n), err return int64(n), err
} }
func ReadUInt8Safe(r io.Reader) (UInt8, error) { func ReadUInt8Safe(r io.Reader) (UInt8, error) {
buf := [1]byte{0} buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return UInt8(0), err } if err != nil {
return UInt8(buf[0]), nil return UInt8(0), err
}
return UInt8(buf[0]), nil
} }
func ReadUInt8(r io.Reader) (UInt8) { func ReadUInt8(r io.Reader) UInt8 {
b, err := ReadUInt8Safe(r) b, err := ReadUInt8Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// Int16 // Int16
func (self Int16) Equals(other Binary) bool { func (self Int16) Equals(other Binary) bool {
return self == other return self == other
} }
func (self Int16) Less(other Binary) bool { func (self Int16) Less(other Binary) bool {
if o, ok := other.(Int16); ok { if o, ok := other.(Int16); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Int16) ByteSize() int { func (self Int16) ByteSize() int {
return 2 return 2
} }
func (self Int16) WriteTo(w io.Writer) (int64, error) { func (self Int16) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int16(self)) err := binary.Write(w, binary.LittleEndian, int16(self))
return 2, err return 2, err
} }
func ReadInt16Safe(r io.Reader) (Int16, error) { func ReadInt16Safe(r io.Reader) (Int16, error) {
buf := [2]byte{0} buf := [2]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return Int16(0), err } if err != nil {
return Int16(binary.LittleEndian.Uint16(buf[:])), nil return Int16(0), err
}
return Int16(binary.LittleEndian.Uint16(buf[:])), nil
} }
func ReadInt16(r io.Reader) (Int16) { func ReadInt16(r io.Reader) Int16 {
b, err := ReadInt16Safe(r) b, err := ReadInt16Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// UInt16 // UInt16
func (self UInt16) Equals(other Binary) bool { func (self UInt16) Equals(other Binary) bool {
return self == other return self == other
} }
func (self UInt16) Less(other Binary) bool { func (self UInt16) Less(other Binary) bool {
if o, ok := other.(UInt16); ok { if o, ok := other.(UInt16); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self UInt16) ByteSize() int { func (self UInt16) ByteSize() int {
return 2 return 2
} }
func (self UInt16) WriteTo(w io.Writer) (int64, error) { func (self UInt16) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint16(self)) err := binary.Write(w, binary.LittleEndian, uint16(self))
return 2, err return 2, err
} }
func ReadUInt16Safe(r io.Reader) (UInt16, error) { func ReadUInt16Safe(r io.Reader) (UInt16, error) {
buf := [2]byte{0} buf := [2]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return UInt16(0), err } if err != nil {
return UInt16(binary.LittleEndian.Uint16(buf[:])), nil return UInt16(0), err
}
return UInt16(binary.LittleEndian.Uint16(buf[:])), nil
} }
func ReadUInt16(r io.Reader) (UInt16) { func ReadUInt16(r io.Reader) UInt16 {
b, err := ReadUInt16Safe(r) b, err := ReadUInt16Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// Int32 // Int32
func (self Int32) Equals(other Binary) bool { func (self Int32) Equals(other Binary) bool {
return self == other return self == other
} }
func (self Int32) Less(other Binary) bool { func (self Int32) Less(other Binary) bool {
if o, ok := other.(Int32); ok { if o, ok := other.(Int32); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Int32) ByteSize() int { func (self Int32) ByteSize() int {
return 4 return 4
} }
func (self Int32) WriteTo(w io.Writer) (int64, error) { func (self Int32) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int32(self)) err := binary.Write(w, binary.LittleEndian, int32(self))
return 4, err return 4, err
} }
func ReadInt32Safe(r io.Reader) (Int32, error) { func ReadInt32Safe(r io.Reader) (Int32, error) {
buf := [4]byte{0} buf := [4]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return Int32(0), err } if err != nil {
return Int32(binary.LittleEndian.Uint32(buf[:])), nil return Int32(0), err
}
return Int32(binary.LittleEndian.Uint32(buf[:])), nil
} }
func ReadInt32(r io.Reader) (Int32) { func ReadInt32(r io.Reader) Int32 {
b, err := ReadInt32Safe(r) b, err := ReadInt32Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// UInt32 // UInt32
func (self UInt32) Equals(other Binary) bool { func (self UInt32) Equals(other Binary) bool {
return self == other return self == other
} }
func (self UInt32) Less(other Binary) bool { func (self UInt32) Less(other Binary) bool {
if o, ok := other.(UInt32); ok { if o, ok := other.(UInt32); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self UInt32) ByteSize() int { func (self UInt32) ByteSize() int {
return 4 return 4
} }
func (self UInt32) WriteTo(w io.Writer) (int64, error) { func (self UInt32) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint32(self)) err := binary.Write(w, binary.LittleEndian, uint32(self))
return 4, err return 4, err
} }
func ReadUInt32Safe(r io.Reader) (UInt32, error) { func ReadUInt32Safe(r io.Reader) (UInt32, error) {
buf := [4]byte{0} buf := [4]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return UInt32(0), err } if err != nil {
return UInt32(binary.LittleEndian.Uint32(buf[:])), nil return UInt32(0), err
}
return UInt32(binary.LittleEndian.Uint32(buf[:])), nil
} }
func ReadUInt32(r io.Reader) (UInt32) { func ReadUInt32(r io.Reader) UInt32 {
b, err := ReadUInt32Safe(r) b, err := ReadUInt32Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// Int64 // Int64
func (self Int64) Equals(other Binary) bool { func (self Int64) Equals(other Binary) bool {
return self == other return self == other
} }
func (self Int64) Less(other Binary) bool { func (self Int64) Less(other Binary) bool {
if o, ok := other.(Int64); ok { if o, ok := other.(Int64); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Int64) ByteSize() int { func (self Int64) ByteSize() int {
return 8 return 8
} }
func (self Int64) WriteTo(w io.Writer) (int64, error) { func (self Int64) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int64(self)) err := binary.Write(w, binary.LittleEndian, int64(self))
return 8, err return 8, err
} }
func ReadInt64Safe(r io.Reader) (Int64, error) { func ReadInt64Safe(r io.Reader) (Int64, error) {
buf := [8]byte{0} buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return Int64(0), err } if err != nil {
return Int64(binary.LittleEndian.Uint64(buf[:])), nil return Int64(0), err
}
return Int64(binary.LittleEndian.Uint64(buf[:])), nil
} }
func ReadInt64(r io.Reader) (Int64) { func ReadInt64(r io.Reader) Int64 {
b, err := ReadInt64Safe(r) b, err := ReadInt64Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// UInt64 // UInt64
func (self UInt64) Equals(other Binary) bool { func (self UInt64) Equals(other Binary) bool {
return self == other return self == other
} }
func (self UInt64) Less(other Binary) bool { func (self UInt64) Less(other Binary) bool {
if o, ok := other.(UInt64); ok { if o, ok := other.(UInt64); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self UInt64) ByteSize() int { func (self UInt64) ByteSize() int {
return 8 return 8
} }
func (self UInt64) WriteTo(w io.Writer) (int64, error) { func (self UInt64) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint64(self)) err := binary.Write(w, binary.LittleEndian, uint64(self))
return 8, err return 8, err
} }
func ReadUInt64Safe(r io.Reader) (UInt64, error) { func ReadUInt64Safe(r io.Reader) (UInt64, error) {
buf := [8]byte{0} buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { return UInt64(0), err } if err != nil {
return UInt64(binary.LittleEndian.Uint64(buf[:])), nil return UInt64(0), err
}
return UInt64(binary.LittleEndian.Uint64(buf[:])), nil
} }
func ReadUInt64(r io.Reader) (UInt64) { func ReadUInt64(r io.Reader) UInt64 {
b, err := ReadUInt64Safe(r) b, err := ReadUInt64Safe(r)
if err != nil { panic(err) } if err != nil {
return b panic(err)
}
return b
} }
// Int // Int
func (self Int) Equals(other Binary) bool { func (self Int) Equals(other Binary) bool {
return self == other return self == other
} }
func (self Int) Less(other Binary) bool { func (self Int) Less(other Binary) bool {
if o, ok := other.(Int); ok { if o, ok := other.(Int); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Int) ByteSize() int { func (self Int) ByteSize() int {
return 8 return 8
} }
func (self Int) WriteTo(w io.Writer) (int64, error) { func (self Int) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int64(self)) err := binary.Write(w, binary.LittleEndian, int64(self))
return 8, err return 8, err
} }
func ReadInt(r io.Reader) Int { func ReadInt(r io.Reader) Int {
buf := [8]byte{0} buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { panic(err) } if err != nil {
return Int(binary.LittleEndian.Uint64(buf[:])) panic(err)
}
return Int(binary.LittleEndian.Uint64(buf[:]))
} }
// UInt // UInt
func (self UInt) Equals(other Binary) bool { func (self UInt) Equals(other Binary) bool {
return self == other return self == other
} }
func (self UInt) Less(other Binary) bool { func (self UInt) Less(other Binary) bool {
if o, ok := other.(UInt); ok { if o, ok := other.(UInt); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self UInt) ByteSize() int { func (self UInt) ByteSize() int {
return 8 return 8
} }
func (self UInt) WriteTo(w io.Writer) (int64, error) { func (self UInt) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint64(self)) err := binary.Write(w, binary.LittleEndian, uint64(self))
return 8, err return 8, err
} }
func ReadUInt(r io.Reader) UInt { func ReadUInt(r io.Reader) UInt {
buf := [8]byte{0} buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { panic(err) } if err != nil {
return UInt(binary.LittleEndian.Uint64(buf[:])) panic(err)
}
return UInt(binary.LittleEndian.Uint64(buf[:]))
} }

View File

@ -7,40 +7,48 @@ type String string
// String // String
func (self String) Equals(other Binary) bool { func (self String) Equals(other Binary) bool {
return self == other return self == other
} }
func (self String) Less(other Binary) bool { func (self String) Less(other Binary) bool {
if o, ok := other.(String); ok { if o, ok := other.(String); ok {
return self < o return self < o
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self String) ByteSize() int { func (self String) ByteSize() int {
return len(self)+4 return len(self) + 4
} }
func (self String) WriteTo(w io.Writer) (n int64, err error) { func (self String) WriteTo(w io.Writer) (n int64, err error) {
var n_ int var n_ int
_, err = UInt32(len(self)).WriteTo(w) _, err = UInt32(len(self)).WriteTo(w)
if err != nil { return n, err } if err != nil {
n_, err = w.Write([]byte(self)) return n, err
return int64(n_+4), err }
n_, err = w.Write([]byte(self))
return int64(n_ + 4), err
} }
func ReadStringSafe(r io.Reader) (String, error) { func ReadStringSafe(r io.Reader) (String, error) {
length, err := ReadUInt32Safe(r) length, err := ReadUInt32Safe(r)
if err != nil { return "", err } if err != nil {
bytes := make([]byte, int(length)) return "", err
_, err = io.ReadFull(r, bytes) }
if err != nil { return "", err } bytes := make([]byte, int(length))
return String(bytes), nil _, err = io.ReadFull(r, bytes)
if err != nil {
return "", err
}
return String(bytes), nil
} }
func ReadString(r io.Reader) String { func ReadString(r io.Reader) String {
str, err := ReadStringSafe(r) str, err := ReadStringSafe(r)
if r != nil { panic(err) } if r != nil {
return str panic(err)
}
return str
} }

View File

@ -1,38 +1,38 @@
package binary package binary
import ( import (
"io" "io"
"time" "time"
) )
type Time struct { type Time struct {
time.Time time.Time
} }
func (self Time) Equals(other Binary) bool { func (self Time) Equals(other Binary) bool {
if o, ok := other.(Time); ok { if o, ok := other.(Time); ok {
return self.Equal(o.Time) return self.Equal(o.Time)
} else { } else {
return false return false
} }
} }
func (self Time) Less(other Binary) bool { func (self Time) Less(other Binary) bool {
if o, ok := other.(Time); ok { if o, ok := other.(Time); ok {
return self.Before(o.Time) return self.Before(o.Time)
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (self Time) ByteSize() int { func (self Time) ByteSize() int {
return 8 return 8
} }
func (self Time) WriteTo(w io.Writer) (int64, error) { func (self Time) WriteTo(w io.Writer) (int64, error) {
return Int64(self.Unix()).WriteTo(w) return Int64(self.Unix()).WriteTo(w)
} }
func ReadTime(r io.Reader) Time { func ReadTime(r io.Reader) Time {
return Time{time.Unix(int64(ReadInt64(r)), 0)} return Time{time.Unix(int64(ReadInt64(r)), 0)}
} }

View File

@ -1,33 +1,35 @@
package binary package binary
import ( import (
"crypto/sha256" "bytes"
"bytes" "crypto/sha256"
) )
func BinaryBytes(b Binary) ByteSlice { func BinaryBytes(b Binary) ByteSlice {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
b.WriteTo(buf) b.WriteTo(buf)
return ByteSlice(buf.Bytes()) return ByteSlice(buf.Bytes())
} }
// NOTE: does not care about the type, only the binary representation. // NOTE: does not care about the type, only the binary representation.
func BinaryEqual(a, b Binary) bool { func BinaryEqual(a, b Binary) bool {
aBytes := BinaryBytes(a) aBytes := BinaryBytes(a)
bBytes := BinaryBytes(b) bBytes := BinaryBytes(b)
return bytes.Equal(aBytes, bBytes) return bytes.Equal(aBytes, bBytes)
} }
// NOTE: does not care about the type, only the binary representation. // NOTE: does not care about the type, only the binary representation.
func BinaryCompare(a, b Binary) int { func BinaryCompare(a, b Binary) int {
aBytes := BinaryBytes(a) aBytes := BinaryBytes(a)
bBytes := BinaryBytes(b) bBytes := BinaryBytes(b)
return bytes.Compare(aBytes, bBytes) return bytes.Compare(aBytes, bBytes)
} }
func BinaryHash(b Binary) ByteSlice { func BinaryHash(b Binary) ByteSlice {
hasher := sha256.New() hasher := sha256.New()
_, err := b.WriteTo(hasher) _, err := b.WriteTo(hasher)
if err != nil { panic(err) } if err != nil {
return ByteSlice(hasher.Sum(nil)) panic(err)
}
return ByteSlice(hasher.Sum(nil))
} }

View File

@ -1,48 +1,48 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common"
"io" "io"
) )
type AccountId struct { type AccountId struct {
Type Byte Type Byte
Number UInt64 Number UInt64
PubKey ByteSlice PubKey ByteSlice
} }
const ( const (
ACCOUNT_TYPE_NUMBER = Byte(0x01) ACCOUNT_TYPE_NUMBER = Byte(0x01)
ACCOUNT_TYPE_PUBKEY = Byte(0x02) ACCOUNT_TYPE_PUBKEY = Byte(0x02)
ACCOUNT_TYPE_BOTH = Byte(0x03) ACCOUNT_TYPE_BOTH = Byte(0x03)
) )
func ReadAccountId(r io.Reader) AccountId { func ReadAccountId(r io.Reader) AccountId {
switch t := ReadByte(r); t { switch t := ReadByte(r); t {
case ACCOUNT_TYPE_NUMBER: case ACCOUNT_TYPE_NUMBER:
return AccountId{t, ReadUInt64(r), nil} return AccountId{t, ReadUInt64(r), nil}
case ACCOUNT_TYPE_PUBKEY: case ACCOUNT_TYPE_PUBKEY:
return AccountId{t, 0, ReadByteSlice(r)} return AccountId{t, 0, ReadByteSlice(r)}
case ACCOUNT_TYPE_BOTH: case ACCOUNT_TYPE_BOTH:
return AccountId{t, ReadUInt64(r), ReadByteSlice(r)} return AccountId{t, ReadUInt64(r), ReadByteSlice(r)}
default: default:
Panicf("Unknown AccountId type %x", t) Panicf("Unknown AccountId type %x", t)
return AccountId{} return AccountId{}
} }
} }
func (self AccountId) WriteTo(w io.Writer) (n int64, err error) { func (self AccountId) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type, w, n, err) n, err = WriteOnto(self.Type, w, n, err)
if self.Type == ACCOUNT_TYPE_NUMBER || self.Type == ACCOUNT_TYPE_BOTH { if self.Type == ACCOUNT_TYPE_NUMBER || self.Type == ACCOUNT_TYPE_BOTH {
n, err = WriteOnto(self.Number, w, n, err) n, err = WriteOnto(self.Number, w, n, err)
} }
if self.Type == ACCOUNT_TYPE_PUBKEY || self.Type == ACCOUNT_TYPE_BOTH { if self.Type == ACCOUNT_TYPE_PUBKEY || self.Type == ACCOUNT_TYPE_BOTH {
n, err = WriteOnto(self.PubKey, w, n, err) n, err = WriteOnto(self.PubKey, w, n, err)
} }
return return
} }
func AccountNumber(n UInt64) AccountId { func AccountNumber(n UInt64) AccountId {
return AccountId{ACCOUNT_TYPE_NUMBER, n, nil} return AccountId{ACCOUNT_TYPE_NUMBER, n, nil}
} }

View File

@ -1,9 +1,9 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common"
"io" "io"
) )
/* Adjustment /* Adjustment
@ -17,126 +17,122 @@ TODO: signing a bad checkpoint (block)
*/ */
type Adjustment interface { type Adjustment interface {
Type() Byte Type() Byte
Binary Binary
} }
const ( const (
ADJ_TYPE_BOND = Byte(0x01) ADJ_TYPE_BOND = Byte(0x01)
ADJ_TYPE_UNBOND = Byte(0x02) ADJ_TYPE_UNBOND = Byte(0x02)
ADJ_TYPE_TIMEOUT = Byte(0x03) ADJ_TYPE_TIMEOUT = Byte(0x03)
ADJ_TYPE_DUPEOUT = Byte(0x04) ADJ_TYPE_DUPEOUT = Byte(0x04)
) )
func ReadAdjustment(r io.Reader) Adjustment { func ReadAdjustment(r io.Reader) Adjustment {
switch t := ReadByte(r); t { switch t := ReadByte(r); t {
case ADJ_TYPE_BOND: case ADJ_TYPE_BOND:
return &Bond{ return &Bond{
Fee: ReadUInt64(r), Fee: ReadUInt64(r),
UnbondTo: ReadAccountId(r), UnbondTo: ReadAccountId(r),
Amount: ReadUInt64(r), Amount: ReadUInt64(r),
Signature: ReadSignature(r), Signature: ReadSignature(r),
} }
case ADJ_TYPE_UNBOND: case ADJ_TYPE_UNBOND:
return &Unbond{ return &Unbond{
Fee: ReadUInt64(r), Fee: ReadUInt64(r),
Amount: ReadUInt64(r), Amount: ReadUInt64(r),
Signature: ReadSignature(r), Signature: ReadSignature(r),
} }
case ADJ_TYPE_TIMEOUT: case ADJ_TYPE_TIMEOUT:
return &Timeout{ return &Timeout{
Account: ReadAccountId(r), Account: ReadAccountId(r),
Penalty: ReadUInt64(r), Penalty: ReadUInt64(r),
} }
case ADJ_TYPE_DUPEOUT: case ADJ_TYPE_DUPEOUT:
return &Dupeout{ return &Dupeout{
VoteA: ReadVote(r), VoteA: ReadVote(r),
VoteB: ReadVote(r), VoteB: ReadVote(r),
} }
default: default:
Panicf("Unknown Adjustment type %x", t) Panicf("Unknown Adjustment type %x", t)
return nil return nil
} }
} }
/* Bond < Adjustment */ /* Bond < Adjustment */
type Bond struct { type Bond struct {
Fee UInt64 Fee UInt64
UnbondTo AccountId UnbondTo AccountId
Amount UInt64 Amount UInt64
Signature Signature
} }
func (self *Bond) Type() Byte { func (self *Bond) Type() Byte {
return ADJ_TYPE_BOND return ADJ_TYPE_BOND
} }
func (self *Bond) WriteTo(w io.Writer) (n int64, err error) { func (self *Bond) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err) n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err) n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.UnbondTo, w, n, err) n, err = WriteOnto(self.UnbondTo, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err) n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err) n, err = WriteOnto(self.Signature, w, n, err)
return return
} }
/* Unbond < Adjustment */ /* Unbond < Adjustment */
type Unbond struct { type Unbond struct {
Fee UInt64 Fee UInt64
Amount UInt64 Amount UInt64
Signature Signature
} }
func (self *Unbond) Type() Byte { func (self *Unbond) Type() Byte {
return ADJ_TYPE_UNBOND return ADJ_TYPE_UNBOND
} }
func (self *Unbond) WriteTo(w io.Writer) (n int64, err error) { func (self *Unbond) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err) n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err) n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err) n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err) n, err = WriteOnto(self.Signature, w, n, err)
return return
} }
/* Timeout < Adjustment */ /* Timeout < Adjustment */
type Timeout struct { type Timeout struct {
Account AccountId Account AccountId
Penalty UInt64 Penalty UInt64
} }
func (self *Timeout) Type() Byte { func (self *Timeout) Type() Byte {
return ADJ_TYPE_TIMEOUT return ADJ_TYPE_TIMEOUT
} }
func (self *Timeout) WriteTo(w io.Writer) (n int64, err error) { func (self *Timeout) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err) n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Account, w, n, err) n, err = WriteOnto(self.Account, w, n, err)
n, err = WriteOnto(self.Penalty, w, n, err) n, err = WriteOnto(self.Penalty, w, n, err)
return return
} }
/* Dupeout < Adjustment */ /* Dupeout < Adjustment */
type Dupeout struct { type Dupeout struct {
VoteA Vote VoteA Vote
VoteB Vote VoteB Vote
} }
func (self *Dupeout) Type() Byte { func (self *Dupeout) Type() Byte {
return ADJ_TYPE_DUPEOUT return ADJ_TYPE_DUPEOUT
} }
func (self *Dupeout) WriteTo(w io.Writer) (n int64, err error) { func (self *Dupeout) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err) n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.VoteA, w, n, err) n, err = WriteOnto(self.VoteA, w, n, err)
n, err = WriteOnto(self.VoteB, w, n, err) n, err = WriteOnto(self.VoteB, w, n, err)
return return
} }

View File

@ -1,140 +1,137 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle" "github.com/tendermint/tendermint/merkle"
"io" "io"
) )
/* Block */ /* Block */
type Block struct { type Block struct {
Header Header
Validation Validation
Data Data
// Checkpoint // Checkpoint
} }
func ReadBlock(r io.Reader) *Block { func ReadBlock(r io.Reader) *Block {
return &Block{ return &Block{
Header: ReadHeader(r), Header: ReadHeader(r),
Validation: ReadValidation(r), Validation: ReadValidation(r),
Data: ReadData(r), Data: ReadData(r),
} }
} }
func (self *Block) Validate() bool { func (self *Block) Validate() bool {
return false return false
} }
func (self *Block) WriteTo(w io.Writer) (n int64, err error) { func (self *Block) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(&self.Header, w, n, err) n, err = WriteOnto(&self.Header, w, n, err)
n, err = WriteOnto(&self.Validation, w, n, err) n, err = WriteOnto(&self.Validation, w, n, err)
n, err = WriteOnto(&self.Data, w, n, err) n, err = WriteOnto(&self.Data, w, n, err)
return return
} }
/* Block > Header */ /* Block > Header */
type Header struct { type Header struct {
Name String Name String
Height UInt64 Height UInt64
Fees UInt64 Fees UInt64
Time UInt64 Time UInt64
PrevHash ByteSlice PrevHash ByteSlice
ValidationHash ByteSlice ValidationHash ByteSlice
DataHash ByteSlice DataHash ByteSlice
} }
func ReadHeader(r io.Reader) Header { func ReadHeader(r io.Reader) Header {
return Header{ return Header{
Name: ReadString(r), Name: ReadString(r),
Height: ReadUInt64(r), Height: ReadUInt64(r),
Fees: ReadUInt64(r), Fees: ReadUInt64(r),
Time: ReadUInt64(r), Time: ReadUInt64(r),
PrevHash: ReadByteSlice(r), PrevHash: ReadByteSlice(r),
ValidationHash: ReadByteSlice(r), ValidationHash: ReadByteSlice(r),
DataHash: ReadByteSlice(r), DataHash: ReadByteSlice(r),
} }
} }
func (self *Header) WriteTo(w io.Writer) (n int64, err error) { func (self *Header) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Name, w, n, err) n, err = WriteOnto(self.Name, w, n, err)
n, err = WriteOnto(self.Height, w, n, err) n, err = WriteOnto(self.Height, w, n, err)
n, err = WriteOnto(self.Fees, w, n, err) n, err = WriteOnto(self.Fees, w, n, err)
n, err = WriteOnto(self.Time, w, n, err) n, err = WriteOnto(self.Time, w, n, err)
n, err = WriteOnto(self.PrevHash, w, n, err) n, err = WriteOnto(self.PrevHash, w, n, err)
n, err = WriteOnto(self.ValidationHash, w, n, err) n, err = WriteOnto(self.ValidationHash, w, n, err)
n, err = WriteOnto(self.DataHash, w, n, err) n, err = WriteOnto(self.DataHash, w, n, err)
return return
} }
/* Block > Validation */ /* Block > Validation */
type Validation struct { type Validation struct {
Signatures []Signature Signatures []Signature
Adjustments []Adjustment Adjustments []Adjustment
} }
func ReadValidation(r io.Reader) Validation { func ReadValidation(r io.Reader) Validation {
numSigs := int(ReadUInt64(r)) numSigs := int(ReadUInt64(r))
numAdjs := int(ReadUInt64(r)) numAdjs := int(ReadUInt64(r))
sigs := make([]Signature, 0, numSigs) sigs := make([]Signature, 0, numSigs)
for i:=0; i<numSigs; i++ { for i := 0; i < numSigs; i++ {
sigs = append(sigs, ReadSignature(r)) sigs = append(sigs, ReadSignature(r))
} }
adjs := make([]Adjustment, 0, numAdjs) adjs := make([]Adjustment, 0, numAdjs)
for i:=0; i<numAdjs; i++ { for i := 0; i < numAdjs; i++ {
adjs = append(adjs, ReadAdjustment(r)) adjs = append(adjs, ReadAdjustment(r))
} }
return Validation{ return Validation{
Signatures: sigs, Signatures: sigs,
Adjustments: adjs, Adjustments: adjs,
} }
} }
func (self *Validation) WriteTo(w io.Writer) (n int64, err error) { func (self *Validation) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(UInt64(len(self.Signatures)), w, n, err) n, err = WriteOnto(UInt64(len(self.Signatures)), w, n, err)
n, err = WriteOnto(UInt64(len(self.Adjustments)), w, n, err) n, err = WriteOnto(UInt64(len(self.Adjustments)), w, n, err)
for _, sig := range self.Signatures { for _, sig := range self.Signatures {
n, err = WriteOnto(sig, w, n, err) n, err = WriteOnto(sig, w, n, err)
} }
for _, adj := range self.Adjustments { for _, adj := range self.Adjustments {
n, err = WriteOnto(adj, w, n, err) n, err = WriteOnto(adj, w, n, err)
} }
return return
} }
/* Block > Data */ /* Block > Data */
type Data struct { type Data struct {
Txs []Tx Txs []Tx
} }
func ReadData(r io.Reader) Data { func ReadData(r io.Reader) Data {
numTxs := int(ReadUInt64(r)) numTxs := int(ReadUInt64(r))
txs := make([]Tx, 0, numTxs) txs := make([]Tx, 0, numTxs)
for i:=0; i<numTxs; i++ { for i := 0; i < numTxs; i++ {
txs = append(txs, ReadTx(r)) txs = append(txs, ReadTx(r))
} }
return Data{txs} return Data{txs}
} }
func (self *Data) WriteTo(w io.Writer) (n int64, err error) { func (self *Data) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(UInt64(len(self.Txs)), w, n, err) n, err = WriteOnto(UInt64(len(self.Txs)), w, n, err)
for _, tx := range self.Txs { for _, tx := range self.Txs {
n, err = WriteOnto(tx, w, n, err) n, err = WriteOnto(tx, w, n, err)
} }
return return
} }
func (self *Data) MerkleHash() ByteSlice { func (self *Data) MerkleHash() ByteSlice {
bs := make([]Binary, 0, len(self.Txs)) bs := make([]Binary, 0, len(self.Txs))
for i, tx := range self.Txs { for i, tx := range self.Txs {
bs[i] = Binary(tx) bs[i] = Binary(tx)
} }
return merkle.HashFromBinarySlice(bs) return merkle.HashFromBinarySlice(bs)
} }

View File

@ -1,113 +1,115 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/binary" "bytes"
"testing" . "github.com/tendermint/tendermint/binary"
"math/rand" "math/rand"
"bytes" "testing"
) )
// Distributed pseudo-exponentially to test for various cases // Distributed pseudo-exponentially to test for various cases
func randVar() UInt64 { func randVar() UInt64 {
bits := rand.Uint32() % 64 bits := rand.Uint32() % 64
if bits == 0 { return 0 } if bits == 0 {
n := uint64(1 << (bits-1)) return 0
n += uint64(rand.Int63()) & ((1 << (bits-1)) - 1) }
return UInt64(n) n := uint64(1 << (bits - 1))
n += uint64(rand.Int63()) & ((1 << (bits - 1)) - 1)
return UInt64(n)
} }
func randBytes(n int) ByteSlice { func randBytes(n int) ByteSlice {
bs := make([]byte, n) bs := make([]byte, n)
for i:=0; i<n; i++ { for i := 0; i < n; i++ {
bs[i] = byte(rand.Intn(256)) bs[i] = byte(rand.Intn(256))
} }
return bs return bs
} }
func randSig() Signature { func randSig() Signature {
return Signature{AccountNumber(randVar()), randBytes(32)} return Signature{AccountNumber(randVar()), randBytes(32)}
} }
func TestBlock(t *testing.T) { func TestBlock(t *testing.T) {
// Txs // Txs
sendTx := &SendTx{ sendTx := &SendTx{
Signature: randSig(), Signature: randSig(),
Fee: randVar(), Fee: randVar(),
To: AccountNumber(randVar()), To: AccountNumber(randVar()),
Amount: randVar(), Amount: randVar(),
} }
nameTx := &NameTx{ nameTx := &NameTx{
Signature: randSig(), Signature: randSig(),
Fee: randVar(), Fee: randVar(),
Name: String(randBytes(12)), Name: String(randBytes(12)),
PubKey: randBytes(32), PubKey: randBytes(32),
} }
// Adjs // Adjs
bond := &Bond{ bond := &Bond{
Signature: randSig(), Signature: randSig(),
Fee: randVar(), Fee: randVar(),
UnbondTo: AccountNumber(randVar()), UnbondTo: AccountNumber(randVar()),
Amount: randVar(), Amount: randVar(),
} }
unbond := &Unbond{ unbond := &Unbond{
Signature: randSig(), Signature: randSig(),
Fee: randVar(), Fee: randVar(),
Amount: randVar(), Amount: randVar(),
} }
timeout := &Timeout{ timeout := &Timeout{
Account: AccountNumber(randVar()), Account: AccountNumber(randVar()),
Penalty: randVar(), Penalty: randVar(),
} }
dupeout := &Dupeout{ dupeout := &Dupeout{
VoteA: Vote{ VoteA: Vote{
Height: randVar(), Height: randVar(),
BlockHash: randBytes(32), BlockHash: randBytes(32),
Signature: randSig(), Signature: randSig(),
}, },
VoteB: Vote{ VoteB: Vote{
Height: randVar(), Height: randVar(),
BlockHash: randBytes(32), BlockHash: randBytes(32),
Signature: randSig(), Signature: randSig(),
}, },
} }
// Block // Block
block := &Block{ block := &Block{
Header{ Header{
Name: "Tendermint", Name: "Tendermint",
Height: randVar(), Height: randVar(),
Fees: randVar(), Fees: randVar(),
Time: randVar(), Time: randVar(),
PrevHash: randBytes(32), PrevHash: randBytes(32),
ValidationHash: randBytes(32), ValidationHash: randBytes(32),
DataHash: randBytes(32), DataHash: randBytes(32),
}, },
Validation{ Validation{
Signatures: []Signature{randSig(),randSig()}, Signatures: []Signature{randSig(), randSig()},
Adjustments:[]Adjustment{bond,unbond,timeout,dupeout}, Adjustments: []Adjustment{bond, unbond, timeout, dupeout},
}, },
Data{ Data{
Txs: []Tx{sendTx, nameTx}, Txs: []Tx{sendTx, nameTx},
}, },
} }
// Write the block, read it in again, write it again. // Write the block, read it in again, write it again.
// Then, compare. // Then, compare.
blockBytes := BinaryBytes(block) blockBytes := BinaryBytes(block)
block2 := ReadBlock(bytes.NewReader(blockBytes)) block2 := ReadBlock(bytes.NewReader(blockBytes))
blockBytes2 := BinaryBytes(block2) blockBytes2 := BinaryBytes(block2)
if !BinaryEqual(blockBytes, blockBytes2) { if !BinaryEqual(blockBytes, blockBytes2) {
t.Fatal("Write->Read of block failed.") t.Fatal("Write->Read of block failed.")
} }
} }

View File

@ -1,8 +1,8 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"io" "io"
) )
/* /*
@ -19,23 +19,23 @@ It usually follows the message to be signed.
*/ */
type Signature struct { type Signature struct {
Signer AccountId Signer AccountId
SigBytes ByteSlice SigBytes ByteSlice
} }
func ReadSignature(r io.Reader) Signature { func ReadSignature(r io.Reader) Signature {
return Signature{ return Signature{
Signer: ReadAccountId(r), Signer: ReadAccountId(r),
SigBytes: ReadByteSlice(r), SigBytes: ReadByteSlice(r),
} }
} }
func (self Signature) WriteTo(w io.Writer) (n int64, err error) { func (self Signature) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Signer, w, n, err) n, err = WriteOnto(self.Signer, w, n, err)
n, err = WriteOnto(self.SigBytes, w, n, err) n, err = WriteOnto(self.SigBytes, w, n, err)
return return
} }
func (self *Signature) Verify(msg ByteSlice) bool { func (self *Signature) Verify(msg ByteSlice) bool {
return false return false
} }

View File

@ -1,9 +1,9 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common"
"io" "io"
) )
/* /*
@ -21,79 +21,77 @@ Tx wire format:
*/ */
type Tx interface { type Tx interface {
Type() Byte Type() Byte
Binary Binary
} }
const ( const (
TX_TYPE_SEND = Byte(0x01) TX_TYPE_SEND = Byte(0x01)
TX_TYPE_NAME = Byte(0x02) TX_TYPE_NAME = Byte(0x02)
) )
func ReadTx(r io.Reader) Tx { func ReadTx(r io.Reader) Tx {
switch t := ReadByte(r); t { switch t := ReadByte(r); t {
case TX_TYPE_SEND: case TX_TYPE_SEND:
return &SendTx{ return &SendTx{
Fee: ReadUInt64(r), Fee: ReadUInt64(r),
To: ReadAccountId(r), To: ReadAccountId(r),
Amount: ReadUInt64(r), Amount: ReadUInt64(r),
Signature: ReadSignature(r), Signature: ReadSignature(r),
} }
case TX_TYPE_NAME: case TX_TYPE_NAME:
return &NameTx{ return &NameTx{
Fee: ReadUInt64(r), Fee: ReadUInt64(r),
Name: ReadString(r), Name: ReadString(r),
PubKey: ReadByteSlice(r), PubKey: ReadByteSlice(r),
Signature: ReadSignature(r), Signature: ReadSignature(r),
} }
default: default:
Panicf("Unknown Tx type %x", t) Panicf("Unknown Tx type %x", t)
return nil return nil
} }
} }
/* SendTx < Tx */ /* SendTx < Tx */
type SendTx struct { type SendTx struct {
Fee UInt64 Fee UInt64
To AccountId To AccountId
Amount UInt64 Amount UInt64
Signature Signature
} }
func (self *SendTx) Type() Byte { func (self *SendTx) Type() Byte {
return TX_TYPE_SEND return TX_TYPE_SEND
} }
func (self *SendTx) WriteTo(w io.Writer) (n int64, err error) { func (self *SendTx) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err) n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err) n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.To, w, n, err) n, err = WriteOnto(self.To, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err) n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err) n, err = WriteOnto(self.Signature, w, n, err)
return return
} }
/* NameTx < Tx */ /* NameTx < Tx */
type NameTx struct { type NameTx struct {
Fee UInt64 Fee UInt64
Name String Name String
PubKey ByteSlice PubKey ByteSlice
Signature Signature
} }
func (self *NameTx) Type() Byte { func (self *NameTx) Type() Byte {
return TX_TYPE_NAME return TX_TYPE_NAME
} }
func (self *NameTx) WriteTo(w io.Writer) (n int64, err error) { func (self *NameTx) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err) n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err) n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.Name, w, n, err) n, err = WriteOnto(self.Name, w, n, err)
n, err = WriteOnto(self.PubKey, w, n, err) n, err = WriteOnto(self.PubKey, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err) n, err = WriteOnto(self.Signature, w, n, err)
return return
} }

View File

@ -1,8 +1,8 @@
package blocks package blocks
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"io" "io"
) )
/* /*
@ -11,22 +11,22 @@ Typically only the signature is passed around, as the hash & height are implied.
*/ */
type Vote struct { type Vote struct {
Height UInt64 Height UInt64
BlockHash ByteSlice BlockHash ByteSlice
Signature Signature
} }
func ReadVote(r io.Reader) Vote { func ReadVote(r io.Reader) Vote {
return Vote{ return Vote{
Height: ReadUInt64(r), Height: ReadUInt64(r),
BlockHash: ReadByteSlice(r), BlockHash: ReadByteSlice(r),
Signature: ReadSignature(r), Signature: ReadSignature(r),
} }
} }
func (self Vote) WriteTo(w io.Writer) (n int64, err error) { func (self Vote) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Height, w, n, err) n, err = WriteOnto(self.Height, w, n, err)
n, err = WriteOnto(self.BlockHash, w, n, err) n, err = WriteOnto(self.BlockHash, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err) n, err = WriteOnto(self.Signature, w, n, err)
return return
} }

View File

@ -1,45 +1,48 @@
package common package common
import ( import (
"time" "sync"
"sync" "time"
) )
/* Debouncer */ /* Debouncer */
type Debouncer struct { type Debouncer struct {
Ch chan struct{} Ch chan struct{}
quit chan struct{} quit chan struct{}
dur time.Duration dur time.Duration
mtx sync.Mutex mtx sync.Mutex
timer *time.Timer timer *time.Timer
} }
func NewDebouncer(dur time.Duration) *Debouncer { func NewDebouncer(dur time.Duration) *Debouncer {
var timer *time.Timer var timer *time.Timer
var ch = make(chan struct{}) var ch = make(chan struct{})
var quit = make(chan struct{}) var quit = make(chan struct{})
var mtx sync.Mutex var mtx sync.Mutex
fire := func() { fire := func() {
go func() { go func() {
select { select {
case ch <- struct{}{}: case ch <- struct{}{}:
case <-quit: case <-quit:
} }
}() }()
mtx.Lock(); defer mtx.Unlock() mtx.Lock()
timer.Reset(dur) defer mtx.Unlock()
} timer.Reset(dur)
timer = time.AfterFunc(dur, fire) }
return &Debouncer{Ch:ch, dur:dur, quit:quit, mtx:mtx, timer:timer} timer = time.AfterFunc(dur, fire)
return &Debouncer{Ch: ch, dur: dur, quit: quit, mtx: mtx, timer: timer}
} }
func (d *Debouncer) Reset() { func (d *Debouncer) Reset() {
d.mtx.Lock(); defer d.mtx.Unlock() d.mtx.Lock()
d.timer.Reset(d.dur) defer d.mtx.Unlock()
d.timer.Reset(d.dur)
} }
func (d *Debouncer) Stop() bool { func (d *Debouncer) Stop() bool {
d.mtx.Lock(); defer d.mtx.Unlock() d.mtx.Lock()
close(d.quit) defer d.mtx.Unlock()
return d.timer.Stop() close(d.quit)
return d.timer.Stop()
} }

View File

@ -1,28 +1,28 @@
package common package common
import ( import (
"container/heap" "container/heap"
) )
type Heap struct { type Heap struct {
pq priorityQueue pq priorityQueue
} }
func NewHeap() *Heap { func NewHeap() *Heap {
return &Heap{pq:make([]*pqItem, 0)} return &Heap{pq: make([]*pqItem, 0)}
} }
func (h *Heap) Len() int { func (h *Heap) Len() int {
return len(h.pq) return len(h.pq)
} }
func (h *Heap) Push(value interface{}, priority int) { func (h *Heap) Push(value interface{}, priority int) {
heap.Push(&h.pq, &pqItem{value:value, priority:priority}) heap.Push(&h.pq, &pqItem{value: value, priority: priority})
} }
func (h *Heap) Pop() interface{} { func (h *Heap) Pop() interface{} {
item := heap.Pop(&h.pq).(*pqItem) item := heap.Pop(&h.pq).(*pqItem)
return item.value return item.value
} }
/* /*
@ -43,9 +43,9 @@ func main() {
// From: http://golang.org/pkg/container/heap/#example__priorityQueue // From: http://golang.org/pkg/container/heap/#example__priorityQueue
type pqItem struct { type pqItem struct {
value interface{} value interface{}
priority int priority int
index int index int
} }
type priorityQueue []*pqItem type priorityQueue []*pqItem
@ -53,35 +53,34 @@ type priorityQueue []*pqItem
func (pq priorityQueue) Len() int { return len(pq) } func (pq priorityQueue) Len() int { return len(pq) }
func (pq priorityQueue) Less(i, j int) bool { func (pq priorityQueue) Less(i, j int) bool {
return pq[i].priority < pq[j].priority return pq[i].priority < pq[j].priority
} }
func (pq priorityQueue) Swap(i, j int) { func (pq priorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i] pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i pq[i].index = i
pq[j].index = j pq[j].index = j
} }
func (pq *priorityQueue) Push(x interface{}) { func (pq *priorityQueue) Push(x interface{}) {
n := len(*pq) n := len(*pq)
item := x.(*pqItem) item := x.(*pqItem)
item.index = n item.index = n
*pq = append(*pq, item) *pq = append(*pq, item)
} }
func (pq *priorityQueue) Pop() interface{} { func (pq *priorityQueue) Pop() interface{} {
old := *pq old := *pq
n := len(old) n := len(old)
item := old[n-1] item := old[n-1]
item.index = -1 // for safety item.index = -1 // for safety
*pq = old[0 : n-1] *pq = old[0 : n-1]
return item return item
} }
func (pq *priorityQueue) Update(item *pqItem, value interface{}, priority int) { func (pq *priorityQueue) Update(item *pqItem, value interface{}, priority int) {
heap.Remove(pq, item.index) heap.Remove(pq, item.index)
item.value = value item.value = value
item.priority = priority item.priority = priority
heap.Push(pq, item) heap.Push(pq, item)
} }

View File

@ -1,9 +1,9 @@
package common package common
import ( import (
"fmt" "fmt"
) )
func Panicf(s string, args ...interface{}) { func Panicf(s string, args ...interface{}) {
panic(fmt.Sprintf(s, args...)) panic(fmt.Sprintf(s, args...))
} }

View File

@ -1,109 +1,115 @@
package config package config
import ( import (
"encoding/json" "encoding/json"
"fmt" "errors"
"io/ioutil" "fmt"
"log" "io/ioutil"
"os" "log"
"path/filepath" "os"
"strings" "path/filepath"
"errors" "strings"
//"crypto/rand" //"crypto/rand"
//"encoding/hex" //"encoding/hex"
) )
var APP_DIR = os.Getenv("HOME") + "/.tendermint" var APP_DIR = os.Getenv("HOME") + "/.tendermint"
/* Global & initialization */ /* Global & initialization */
var Config Config_ var Config Config_
func init() { func init() {
configFile := APP_DIR+"/config.json" configFile := APP_DIR + "/config.json"
// try to read configuration. if missing, write default // try to read configuration. if missing, write default
configBytes, err := ioutil.ReadFile(configFile) configBytes, err := ioutil.ReadFile(configFile)
if err != nil { if err != nil {
defaultConfig.write(configFile) defaultConfig.write(configFile)
fmt.Println("Config file written to config.json. Please edit & run again") fmt.Println("Config file written to config.json. Please edit & run again")
os.Exit(1) os.Exit(1)
return return
} }
// try to parse configuration. on error, die // try to parse configuration. on error, die
Config = Config_{} Config = Config_{}
err = json.Unmarshal(configBytes, &Config) err = json.Unmarshal(configBytes, &Config)
if err != nil { if err != nil {
log.Panicf("Invalid configuration file %s: %v", configFile, err) log.Panicf("Invalid configuration file %s: %v", configFile, err)
} }
err = Config.validate() err = Config.validate()
if err != nil { if err != nil {
log.Panicf("Invalid configuration file %s: %v", configFile, err) log.Panicf("Invalid configuration file %s: %v", configFile, err)
} }
} }
/* Default configuration */ /* Default configuration */
var defaultConfig = Config_{ var defaultConfig = Config_{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 8770, Port: 8770,
Db: DbConfig{ Db: DbConfig{
Type: "level", Type: "level",
Dir: APP_DIR+"/data", Dir: APP_DIR + "/data",
}, },
Twilio: TwilioConfig{ Twilio: TwilioConfig{},
},
} }
/* Configuration types */ /* Configuration types */
type Config_ struct { type Config_ struct {
Host string Host string
Port int Port int
Db DbConfig Db DbConfig
Twilio TwilioConfig Twilio TwilioConfig
} }
type TwilioConfig struct { type TwilioConfig struct {
Sid string Sid string
Token string Token string
From string From string
To string To string
MinInterval int MinInterval int
} }
type DbConfig struct { type DbConfig struct {
Type string Type string
Dir string Dir string
} }
func (cfg *Config_) validate() error { func (cfg *Config_) validate() error {
if cfg.Host == "" { return errors.New("Host must be set") } if cfg.Host == "" {
if cfg.Port == 0 { return errors.New("Port must be set") } return errors.New("Host must be set")
if cfg.Db.Type == "" { return errors.New("Db.Type must be set") } }
return nil if cfg.Port == 0 {
return errors.New("Port must be set")
}
if cfg.Db.Type == "" {
return errors.New("Db.Type must be set")
}
return nil
} }
func (cfg *Config_) bytes() []byte { func (cfg *Config_) bytes() []byte {
configBytes, err := json.Marshal(cfg) configBytes, err := json.Marshal(cfg)
if err != nil { panic(err) } if err != nil {
return configBytes panic(err)
}
return configBytes
} }
func (cfg *Config_) write(configFile string) { func (cfg *Config_) write(configFile string) {
if strings.Index(configFile, "/") != -1 { if strings.Index(configFile, "/") != -1 {
err := os.MkdirAll(filepath.Dir(configFile), 0700) err := os.MkdirAll(filepath.Dir(configFile), 0700)
if err != nil { panic(err) } if err != nil {
} panic(err)
err := ioutil.WriteFile(configFile, cfg.bytes(), 0600) }
if err != nil { }
panic(err) err := ioutil.WriteFile(configFile, cfg.bytes(), 0600)
} if err != nil {
panic(err)
}
} }
/* TODO: generate priv/pub keys /* TODO: generate priv/pub keys
@ -113,4 +119,3 @@ func generateKeys() string {
return hex.EncodeToString(bytes[:]) return hex.EncodeToString(bytes[:])
} }
*/ */

View File

@ -11,60 +11,60 @@ import "C"
import "unsafe" import "unsafe"
type Verify struct { type Verify struct {
Message []byte Message []byte
PubKey []byte PubKey []byte
Signature []byte Signature []byte
Valid bool Valid bool
} }
func MakePubKey(privKey []byte) []byte { func MakePubKey(privKey []byte) []byte {
pubKey := [32]byte{} pubKey := [32]byte{}
C.ed25519_publickey( C.ed25519_publickey(
(*C.uchar)(unsafe.Pointer(&privKey[0])), (*C.uchar)(unsafe.Pointer(&privKey[0])),
(*C.uchar)(unsafe.Pointer(&pubKey[0])), (*C.uchar)(unsafe.Pointer(&pubKey[0])),
) )
return pubKey[:] return pubKey[:]
} }
func SignMessage(message []byte, privKey []byte, pubKey []byte) []byte { func SignMessage(message []byte, privKey []byte, pubKey []byte) []byte {
sig := [64]byte{} sig := [64]byte{}
C.ed25519_sign( C.ed25519_sign(
(*C.uchar)(unsafe.Pointer(&message[0])), (C.size_t)(len(message)), (*C.uchar)(unsafe.Pointer(&message[0])), (C.size_t)(len(message)),
(*C.uchar)(unsafe.Pointer(&privKey[0])), (*C.uchar)(unsafe.Pointer(&privKey[0])),
(*C.uchar)(unsafe.Pointer(&pubKey[0])), (*C.uchar)(unsafe.Pointer(&pubKey[0])),
(*C.uchar)(unsafe.Pointer(&sig[0])), (*C.uchar)(unsafe.Pointer(&sig[0])),
) )
return sig[:] return sig[:]
} }
func VerifyBatch(verifys []*Verify) bool { func VerifyBatch(verifys []*Verify) bool {
count := len(verifys) count := len(verifys)
msgs := make([]*byte, count) msgs := make([]*byte, count)
lens := make([]C.size_t, count) lens := make([]C.size_t, count)
pubs := make([]*byte, count) pubs := make([]*byte, count)
sigs := make([]*byte, count) sigs := make([]*byte, count)
valids := make([]C.int, count) valids := make([]C.int, count)
for i, v := range verifys { for i, v := range verifys {
msgs[i] = (*byte)(unsafe.Pointer(&v.Message[0])) msgs[i] = (*byte)(unsafe.Pointer(&v.Message[0]))
lens[i] = (C.size_t)(len(v.Message)) lens[i] = (C.size_t)(len(v.Message))
pubs[i] = (*byte)(&v.PubKey[0]) pubs[i] = (*byte)(&v.PubKey[0])
sigs[i] = (*byte)(&v.Signature[0]) sigs[i] = (*byte)(&v.Signature[0])
} }
count_ := (C.size_t)(count) count_ := (C.size_t)(count)
msgs_ := (**C.uchar)(unsafe.Pointer(&msgs[0])) msgs_ := (**C.uchar)(unsafe.Pointer(&msgs[0]))
lens_ := (*C.size_t)(unsafe.Pointer(&lens[0])) lens_ := (*C.size_t)(unsafe.Pointer(&lens[0]))
pubs_ := (**C.uchar)(unsafe.Pointer(&pubs[0])) pubs_ := (**C.uchar)(unsafe.Pointer(&pubs[0]))
sigs_ := (**C.uchar)(unsafe.Pointer(&sigs[0])) sigs_ := (**C.uchar)(unsafe.Pointer(&sigs[0]))
res := C.ed25519_sign_open_batch(msgs_, lens_, pubs_, sigs_, count_, &valids[0]) res := C.ed25519_sign_open_batch(msgs_, lens_, pubs_, sigs_, count_, &valids[0])
for i, valid := range valids { for i, valid := range valids {
verifys[i].Valid = valid > 0 verifys[i].Valid = valid > 0
} }
return res == 0 return res == 0
} }

View File

@ -1,35 +1,47 @@
package crypto package crypto
import ( import (
"testing" "crypto/rand"
"crypto/rand" "testing"
) )
func TestSign(t *testing.T) { func TestSign(t *testing.T) {
privKey := make([]byte, 32) privKey := make([]byte, 32)
_, err := rand.Read(privKey) _, err := rand.Read(privKey)
if err != nil { t.Fatal(err) } if err != nil {
pubKey := MakePubKey(privKey) t.Fatal(err)
signature := SignMessage([]byte("hello"), privKey, pubKey) }
pubKey := MakePubKey(privKey)
signature := SignMessage([]byte("hello"), privKey, pubKey)
v1 := &Verify{ v1 := &Verify{
Message: []byte("hello"), Message: []byte("hello"),
PubKey: pubKey, PubKey: pubKey,
Signature: signature, Signature: signature,
} }
ok := VerifyBatch([]*Verify{v1, v1, v1, v1}) ok := VerifyBatch([]*Verify{v1, v1, v1, v1})
if ok != true { t.Fatal("Expected ok == true") } if ok != true {
if v1.Valid != true { t.Fatal("Expected v1.Valid to be true") } t.Fatal("Expected ok == true")
}
if v1.Valid != true {
t.Fatal("Expected v1.Valid to be true")
}
v2 := &Verify{ v2 := &Verify{
Message: []byte{0x73}, Message: []byte{0x73},
PubKey: pubKey, PubKey: pubKey,
Signature: signature, Signature: signature,
} }
ok = VerifyBatch([]*Verify{v1, v1, v1, v2}) ok = VerifyBatch([]*Verify{v1, v1, v1, v2})
if ok != false { t.Fatal("Expected ok == false") } if ok != false {
if v1.Valid != true { t.Fatal("Expected v1.Valid to be true") } t.Fatal("Expected ok == false")
if v2.Valid != false { t.Fatal("Expected v2.Valid to be true") } }
if v1.Valid != true {
t.Fatal("Expected v1.Valid to be true")
}
if v2.Valid != false {
t.Fatal("Expected v2.Valid to be true")
}
} }

View File

@ -1,54 +1,60 @@
package db package db
import ( import (
"fmt" "fmt"
"github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb"
"path" "path"
) )
type LevelDB struct { type LevelDB struct {
db *leveldb.DB db *leveldb.DB
} }
func NewLevelDB(name string) (*LevelDB, error) { func NewLevelDB(name string) (*LevelDB, error) {
dbPath := path.Join(name) dbPath := path.Join(name)
db, err := leveldb.OpenFile(dbPath, nil) db, err := leveldb.OpenFile(dbPath, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
database := &LevelDB{db: db} database := &LevelDB{db: db}
return database, nil return database, nil
} }
func (db *LevelDB) Put(key []byte, value []byte) { func (db *LevelDB) Put(key []byte, value []byte) {
err := db.db.Put(key, value, nil) err := db.db.Put(key, value, nil)
if err != nil { panic(err) } if err != nil {
panic(err)
}
} }
func (db *LevelDB) Get(key []byte) ([]byte) { func (db *LevelDB) Get(key []byte) []byte {
res, err := db.db.Get(key, nil) res, err := db.db.Get(key, nil)
if err != nil { panic(err) } if err != nil {
return res panic(err)
}
return res
} }
func (db *LevelDB) Delete(key []byte) { func (db *LevelDB) Delete(key []byte) {
err := db.db.Delete(key, nil) err := db.db.Delete(key, nil)
if err != nil { panic(err) } if err != nil {
panic(err)
}
} }
func (db *LevelDB) Db() *leveldb.DB { func (db *LevelDB) Db() *leveldb.DB {
return db.db return db.db
} }
func (db *LevelDB) Close() { func (db *LevelDB) Close() {
db.db.Close() db.db.Close()
} }
func (db *LevelDB) Print() { func (db *LevelDB) Print() {
iter := db.db.NewIterator(nil, nil) iter := db.db.NewIterator(nil, nil)
for iter.Next() { for iter.Next() {
key := iter.Key() key := iter.Key()
value := iter.Value() value := iter.Value()
fmt.Printf("[%x]:\t[%x]", key, value) fmt.Printf("[%x]:\t[%x]", key, value)
} }
} }

View File

@ -1,32 +1,32 @@
package db package db
import ( import (
"fmt" "fmt"
) )
type MemDB struct { type MemDB struct {
db map[string][]byte db map[string][]byte
} }
func NewMemDB() (*MemDB) { func NewMemDB() *MemDB {
database := &MemDB{db:make(map[string][]byte)} database := &MemDB{db: make(map[string][]byte)}
return database return database
} }
func (db *MemDB) Put(key []byte, value []byte) { func (db *MemDB) Put(key []byte, value []byte) {
db.db[string(key)] = value db.db[string(key)] = value
} }
func (db *MemDB) Get(key []byte) ([]byte) { func (db *MemDB) Get(key []byte) []byte {
return db.db[string(key)] return db.db[string(key)]
} }
func (db *MemDB) Delete(key []byte) { func (db *MemDB) Delete(key []byte) {
delete(db.db, string(key)) delete(db.db, string(key))
} }
func (db *MemDB) Print() { func (db *MemDB) Print() {
for key, value := range db.db { for key, value := range db.db {
fmt.Printf("[%x]:\t[%x]", []byte(key), value) fmt.Printf("[%x]:\t[%x]", []byte(key), value)
} }
} }

View File

@ -1,405 +1,447 @@
package merkle package merkle
import ( import (
. "github.com/tendermint/tendermint/binary" "bytes"
"bytes" "crypto/sha256"
"io" . "github.com/tendermint/tendermint/binary"
"crypto/sha256" "io"
) )
// Node // Node
type IAVLNode struct { type IAVLNode struct {
key Key key Key
value Value value Value
size uint64 size uint64
height uint8 height uint8
hash ByteSlice hash ByteSlice
left *IAVLNode left *IAVLNode
right *IAVLNode right *IAVLNode
// volatile // volatile
flags byte flags byte
} }
const ( const (
IAVLNODE_FLAG_PERSISTED = byte(0x01) IAVLNODE_FLAG_PERSISTED = byte(0x01)
IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) IAVLNODE_FLAG_PLACEHOLDER = byte(0x02)
) )
func NewIAVLNode(key Key, value Value) *IAVLNode { func NewIAVLNode(key Key, value Value) *IAVLNode {
return &IAVLNode{ return &IAVLNode{
key: key, key: key,
value: value, value: value,
size: 1, size: 1,
} }
} }
func (self *IAVLNode) Copy() *IAVLNode { func (self *IAVLNode) Copy() *IAVLNode {
if self.height == 0 { if self.height == 0 {
panic("Why are you copying a value node?") panic("Why are you copying a value node?")
} }
return &IAVLNode{ return &IAVLNode{
key: self.key, key: self.key,
size: self.size, size: self.size,
height: self.height, height: self.height,
left: self.left, left: self.left,
right: self.right, right: self.right,
hash: nil, hash: nil,
flags: byte(0), flags: byte(0),
} }
} }
func (self *IAVLNode) Key() Key { func (self *IAVLNode) Key() Key {
return self.key return self.key
} }
func (self *IAVLNode) Value() Value { func (self *IAVLNode) Value() Value {
return self.value return self.value
} }
func (self *IAVLNode) Size() uint64 { func (self *IAVLNode) Size() uint64 {
return self.size return self.size
} }
func (self *IAVLNode) Height() uint8 { func (self *IAVLNode) Height() uint8 {
return self.height return self.height
} }
func (self *IAVLNode) has(db Db, key Key) (has bool) { func (self *IAVLNode) has(db Db, key Key) (has bool) {
if self.key.Equals(key) { if self.key.Equals(key) {
return true return true
} }
if self.height == 0 { if self.height == 0 {
return false return false
} else { } else {
if key.Less(self.key) { if key.Less(self.key) {
return self.leftFilled(db).has(db, key) return self.leftFilled(db).has(db, key)
} else { } else {
return self.rightFilled(db).has(db, key) return self.rightFilled(db).has(db, key)
} }
} }
} }
func (self *IAVLNode) get(db Db, key Key) (value Value) { func (self *IAVLNode) get(db Db, key Key) (value Value) {
if self.height == 0 { if self.height == 0 {
if self.key.Equals(key) { if self.key.Equals(key) {
return self.value return self.value
} else { } else {
return nil return nil
} }
} else { } else {
if key.Less(self.key) { if key.Less(self.key) {
return self.leftFilled(db).get(db, key) return self.leftFilled(db).get(db, key)
} else { } else {
return self.rightFilled(db).get(db, key) return self.rightFilled(db).get(db, key)
} }
} }
} }
func (self *IAVLNode) Hash() (ByteSlice, uint64) { func (self *IAVLNode) Hash() (ByteSlice, uint64) {
if self.hash != nil { if self.hash != nil {
return self.hash, 0 return self.hash, 0
} }
hasher := sha256.New() hasher := sha256.New()
_, hashCount, err := self.saveToCountHashes(hasher, false) _, hashCount, err := self.saveToCountHashes(hasher, false)
if err != nil { panic(err) } if err != nil {
self.hash = hasher.Sum(nil) panic(err)
}
self.hash = hasher.Sum(nil)
return self.hash, hashCount+1 return self.hash, hashCount + 1
} }
func (self *IAVLNode) Save(db Db) { func (self *IAVLNode) Save(db Db) {
if self.hash == nil { if self.hash == nil {
panic("savee.hash can't be nil") panic("savee.hash can't be nil")
} }
if self.flags & IAVLNODE_FLAG_PERSISTED > 0 || if self.flags&IAVLNODE_FLAG_PERSISTED > 0 ||
self.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { self.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
return return
} }
// children // children
if self.height > 0 { if self.height > 0 {
self.left.Save(db) self.left.Save(db)
self.right.Save(db) self.right.Save(db)
} }
// save self // save self
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
_, err := self.WriteTo(buf) _, err := self.WriteTo(buf)
if err != nil { panic(err) } if err != nil {
db.Put([]byte(self.hash), buf.Bytes()) panic(err)
}
db.Put([]byte(self.hash), buf.Bytes())
self.flags |= IAVLNODE_FLAG_PERSISTED self.flags |= IAVLNODE_FLAG_PERSISTED
} }
func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) {
if self.height == 0 { if self.height == 0 {
if key.Less(self.key) { if key.Less(self.key) {
return &IAVLNode{ return &IAVLNode{
key: self.key, key: self.key,
height: 1, height: 1,
size: 2, size: 2,
left: NewIAVLNode(key, value), left: NewIAVLNode(key, value),
right: self, right: self,
}, false }, false
} else if self.key.Equals(key) { } else if self.key.Equals(key) {
return NewIAVLNode(key, value), true return NewIAVLNode(key, value), true
} else { } else {
return &IAVLNode{ return &IAVLNode{
key: key, key: key,
height: 1, height: 1,
size: 2, size: 2,
left: self, left: self,
right: NewIAVLNode(key, value), right: NewIAVLNode(key, value),
}, false }, false
} }
} else { } else {
self = self.Copy() self = self.Copy()
if key.Less(self.key) { if key.Less(self.key) {
self.left, updated = self.leftFilled(db).put(db, key, value) self.left, updated = self.leftFilled(db).put(db, key, value)
} else { } else {
self.right, updated = self.rightFilled(db).put(db, key, value) self.right, updated = self.rightFilled(db).put(db, key, value)
} }
if updated { if updated {
return self, updated return self, updated
} else { } else {
self.calcHeightAndSize(db) self.calcHeightAndSize(db)
return self.balance(db), updated return self.balance(db), updated
} }
} }
} }
// newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed.
func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) { func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) {
if self.height == 0 { if self.height == 0 {
if self.key.Equals(key) { if self.key.Equals(key) {
return nil, nil, self.value, nil return nil, nil, self.value, nil
} else { } else {
return self, nil, nil, NotFound(key) return self, nil, nil, NotFound(key)
} }
} else { } else {
if key.Less(self.key) { if key.Less(self.key) {
var newLeft *IAVLNode var newLeft *IAVLNode
newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) newLeft, newKey, value, err = self.leftFilled(db).remove(db, key)
if err != nil { if err != nil {
return self, nil, value, err return self, nil, value, err
} else if newLeft == nil { // left node held value, was removed } else if newLeft == nil { // left node held value, was removed
return self.right, self.key, value, nil return self.right, self.key, value, nil
} }
self = self.Copy() self = self.Copy()
self.left = newLeft self.left = newLeft
} else { } else {
var newRight *IAVLNode var newRight *IAVLNode
newRight, newKey, value, err = self.rightFilled(db).remove(db, key) newRight, newKey, value, err = self.rightFilled(db).remove(db, key)
if err != nil { if err != nil {
return self, nil, value, err return self, nil, value, err
} else if newRight == nil { // right node held value, was removed } else if newRight == nil { // right node held value, was removed
return self.left, nil, value, nil return self.left, nil, value, nil
} }
self = self.Copy() self = self.Copy()
self.right = newRight self.right = newRight
if newKey != nil { if newKey != nil {
self.key = newKey self.key = newKey
newKey = nil newKey = nil
} }
} }
self.calcHeightAndSize(db) self.calcHeightAndSize(db)
return self.balance(db), newKey, value, err return self.balance(db), newKey, value, err
} }
} }
func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) { func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) {
n, _, err = self.saveToCountHashes(w, true) n, _, err = self.saveToCountHashes(w, true)
return return
} }
func (self *IAVLNode) saveToCountHashes(w io.Writer, meta bool) (n int64, hashCount uint64, err error) { func (self *IAVLNode) saveToCountHashes(w io.Writer, meta bool) (n int64, hashCount uint64, err error) {
var _n int64 var _n int64
if meta { if meta {
// height & size // height & size
_n, err = UInt8(self.height).WriteTo(w) _n, err = UInt8(self.height).WriteTo(w)
if err != nil { return } else { n += _n } if err != nil {
_n, err = UInt64(self.size).WriteTo(w) return
if err != nil { return } else { n += _n } } else {
n += _n
}
_n, err = UInt64(self.size).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
// key // key
_n, err = Byte(GetBinaryType(self.key)).WriteTo(w) _n, err = Byte(GetBinaryType(self.key)).WriteTo(w)
if err != nil { return } else { n += _n } if err != nil {
_n, err = self.key.WriteTo(w) return
if err != nil { return } else { n += _n } } else {
} n += _n
}
_n, err = self.key.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
}
if self.height == 0 { if self.height == 0 {
// value // value
_n, err = Byte(GetBinaryType(self.value)).WriteTo(w) _n, err = Byte(GetBinaryType(self.value)).WriteTo(w)
if err != nil { return } else { n += _n } if err != nil {
if self.value != nil { return
_n, err = self.value.WriteTo(w) } else {
if err != nil { return } else { n += _n } n += _n
} }
} else { if self.value != nil {
// left _n, err = self.value.WriteTo(w)
leftHash, leftCount := self.left.Hash() if err != nil {
hashCount += leftCount return
_n, err = leftHash.WriteTo(w) } else {
if err != nil { return } else { n += _n } n += _n
// right }
rightHash, rightCount := self.right.Hash() }
hashCount += rightCount } else {
_n, err = rightHash.WriteTo(w) // left
if err != nil { return } else { n += _n } leftHash, leftCount := self.left.Hash()
} hashCount += leftCount
_n, err = leftHash.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
// right
rightHash, rightCount := self.right.Hash()
hashCount += rightCount
_n, err = rightHash.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
}
return return
} }
// Given a placeholder node which has only the hash set, // Given a placeholder node which has only the hash set,
// load the rest of the data from db. // load the rest of the data from db.
// Not threadsafe. // Not threadsafe.
func (self *IAVLNode) fill(db Db) { func (self *IAVLNode) fill(db Db) {
if self.hash == nil { if self.hash == nil {
panic("placeholder.hash can't be nil") panic("placeholder.hash can't be nil")
} }
buf := db.Get(self.hash) buf := db.Get(self.hash)
r := bytes.NewReader(buf) r := bytes.NewReader(buf)
// node header // node header
self.height = uint8(ReadUInt8(r)) self.height = uint8(ReadUInt8(r))
self.size = uint64(ReadUInt64(r)) self.size = uint64(ReadUInt64(r))
// key // key
key := ReadBinary(r) key := ReadBinary(r)
self.key = key.(Key) self.key = key.(Key)
if self.height == 0 { if self.height == 0 {
// value // value
self.value = ReadBinary(r) self.value = ReadBinary(r)
} else { } else {
// left // left
var leftHash ByteSlice var leftHash ByteSlice
leftHash = ReadByteSlice(r) leftHash = ReadByteSlice(r)
self.left = &IAVLNode{ self.left = &IAVLNode{
hash: leftHash, hash: leftHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
} }
// right // right
var rightHash ByteSlice var rightHash ByteSlice
rightHash = ReadByteSlice(r) rightHash = ReadByteSlice(r)
self.right = &IAVLNode{ self.right = &IAVLNode{
hash: rightHash, hash: rightHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
} }
if r.Len() != 0 { if r.Len() != 0 {
panic("buf not all consumed") panic("buf not all consumed")
} }
} }
self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER
} }
func (self *IAVLNode) leftFilled(db Db) *IAVLNode { func (self *IAVLNode) leftFilled(db Db) *IAVLNode {
if self.left.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { if self.left.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
self.left.fill(db) self.left.fill(db)
} }
return self.left return self.left
} }
func (self *IAVLNode) rightFilled(db Db) *IAVLNode { func (self *IAVLNode) rightFilled(db Db) *IAVLNode {
if self.right.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { if self.right.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
self.right.fill(db) self.right.fill(db)
} }
return self.right return self.right
} }
func (self *IAVLNode) rotateRight(db Db) *IAVLNode { func (self *IAVLNode) rotateRight(db Db) *IAVLNode {
self = self.Copy() self = self.Copy()
sl := self.leftFilled(db).Copy() sl := self.leftFilled(db).Copy()
slr := sl.right slr := sl.right
sl.right = self sl.right = self
self.left = slr self.left = slr
self.calcHeightAndSize(db) self.calcHeightAndSize(db)
sl.calcHeightAndSize(db) sl.calcHeightAndSize(db)
return sl return sl
} }
func (self *IAVLNode) rotateLeft(db Db) *IAVLNode { func (self *IAVLNode) rotateLeft(db Db) *IAVLNode {
self = self.Copy() self = self.Copy()
sr := self.rightFilled(db).Copy() sr := self.rightFilled(db).Copy()
srl := sr.left srl := sr.left
sr.left = self sr.left = self
self.right = srl self.right = srl
self.calcHeightAndSize(db) self.calcHeightAndSize(db)
sr.calcHeightAndSize(db) sr.calcHeightAndSize(db)
return sr return sr
} }
func (self *IAVLNode) calcHeightAndSize(db Db) { func (self *IAVLNode) calcHeightAndSize(db Db) {
self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1 self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1
self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size() self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size()
} }
func (self *IAVLNode) calcBalance(db Db) int { func (self *IAVLNode) calcBalance(db Db) int {
return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height()) return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height())
} }
func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) { func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) {
balance := self.calcBalance(db) balance := self.calcBalance(db)
if (balance > 1) { if balance > 1 {
if (self.leftFilled(db).calcBalance(db) >= 0) { if self.leftFilled(db).calcBalance(db) >= 0 {
// Left Left Case // Left Left Case
return self.rotateRight(db) return self.rotateRight(db)
} else { } else {
// Left Right Case // Left Right Case
self = self.Copy() self = self.Copy()
self.left = self.leftFilled(db).rotateLeft(db) self.left = self.leftFilled(db).rotateLeft(db)
//self.calcHeightAndSize() //self.calcHeightAndSize()
return self.rotateRight(db) return self.rotateRight(db)
} }
} }
if (balance < -1) { if balance < -1 {
if (self.rightFilled(db).calcBalance(db) <= 0) { if self.rightFilled(db).calcBalance(db) <= 0 {
// Right Right Case // Right Right Case
return self.rotateLeft(db) return self.rotateLeft(db)
} else { } else {
// Right Left Case // Right Left Case
self = self.Copy() self = self.Copy()
self.right = self.rightFilled(db).rotateRight(db) self.right = self.rightFilled(db).rotateRight(db)
//self.calcHeightAndSize() //self.calcHeightAndSize()
return self.rotateLeft(db) return self.rotateLeft(db)
} }
} }
// Nothing changed // Nothing changed
return self return self
} }
func (self *IAVLNode) lmd(db Db) (*IAVLNode) { func (self *IAVLNode) lmd(db Db) *IAVLNode {
if self.height == 0 { if self.height == 0 {
return self return self
} }
return self.leftFilled(db).lmd(db) return self.leftFilled(db).lmd(db)
} }
func (self *IAVLNode) rmd(db Db) (*IAVLNode) { func (self *IAVLNode) rmd(db Db) *IAVLNode {
if self.height == 0 { if self.height == 0 {
return self return self
} }
return self.rightFilled(db).rmd(db) return self.rightFilled(db).rmd(db)
} }
func (self *IAVLNode) traverse(db Db, cb func(Node)bool) bool { func (self *IAVLNode) traverse(db Db, cb func(Node) bool) bool {
stop := cb(self) stop := cb(self)
if stop { return stop } if stop {
if self.height > 0 { return stop
stop = self.leftFilled(db).traverse(db, cb) }
if stop { return stop } if self.height > 0 {
stop = self.rightFilled(db).traverse(db, cb) stop = self.leftFilled(db).traverse(db, cb)
if stop { return stop } if stop {
} return stop
return false }
stop = self.rightFilled(db).traverse(db, cb)
if stop {
return stop
}
}
return false
} }

View File

@ -1,283 +1,283 @@
package merkle package merkle
import ( import (
. "github.com/tendermint/tendermint/binary" "bytes"
"testing" "crypto/sha256"
"fmt" "encoding/binary"
"os" "fmt"
"bytes" . "github.com/tendermint/tendermint/binary"
"math/rand" "github.com/tendermint/tendermint/db"
"encoding/binary" "math/rand"
"github.com/tendermint/tendermint/db" "os"
"crypto/sha256" "runtime"
"runtime" "testing"
) )
func init() { func init() {
if urandom, err := os.Open("/dev/urandom"); err != nil { if urandom, err := os.Open("/dev/urandom"); err != nil {
return return
} else { } else {
buf := make([]byte, 8) buf := make([]byte, 8)
if _, err := urandom.Read(buf); err == nil { if _, err := urandom.Read(buf); err == nil {
buf_reader := bytes.NewReader(buf) buf_reader := bytes.NewReader(buf)
if seed, err := binary.ReadVarint(buf_reader); err == nil { if seed, err := binary.ReadVarint(buf_reader); err == nil {
rand.Seed(seed) rand.Seed(seed)
} }
} }
urandom.Close() urandom.Close()
} }
} }
func TestUnit(t *testing.T) { func TestUnit(t *testing.T) {
// Convenience for a new node // Convenience for a new node
N := func(l, r interface{}) *IAVLNode { N := func(l, r interface{}) *IAVLNode {
var left, right *IAVLNode var left, right *IAVLNode
if _, ok := l.(*IAVLNode); ok { if _, ok := l.(*IAVLNode); ok {
left = l.(*IAVLNode) left = l.(*IAVLNode)
} else { } else {
left = NewIAVLNode(Int32(l.(int)), nil) left = NewIAVLNode(Int32(l.(int)), nil)
} }
if _, ok := r.(*IAVLNode); ok { if _, ok := r.(*IAVLNode); ok {
right = r.(*IAVLNode) right = r.(*IAVLNode)
} else { } else {
right = NewIAVLNode(Int32(r.(int)), nil) right = NewIAVLNode(Int32(r.(int)), nil)
} }
n := &IAVLNode{ n := &IAVLNode{
key: right.lmd(nil).key, key: right.lmd(nil).key,
left: left, left: left,
right: right, right: right,
} }
n.calcHeightAndSize(nil) n.calcHeightAndSize(nil)
n.Hash() n.Hash()
return n return n
} }
// Convenience for simple printing of keys & tree structure // Convenience for simple printing of keys & tree structure
var P func(*IAVLNode) string var P func(*IAVLNode) string
P = func(n *IAVLNode) string { P = func(n *IAVLNode) string {
if n.height == 0 { if n.height == 0 {
return fmt.Sprintf("%v", n.key) return fmt.Sprintf("%v", n.key)
} else { } else {
return fmt.Sprintf("(%v %v)", P(n.left), P(n.right)) return fmt.Sprintf("(%v %v)", P(n.left), P(n.right))
} }
} }
expectHash := func(n2 *IAVLNode, hashCount uint64) { expectHash := func(n2 *IAVLNode, hashCount uint64) {
// ensure number of new hash calculations is as expected. // ensure number of new hash calculations is as expected.
hash, count := n2.Hash() hash, count := n2.Hash()
if count != hashCount { if count != hashCount {
t.Fatalf("Expected %v new hashes, got %v", hashCount, count) t.Fatalf("Expected %v new hashes, got %v", hashCount, count)
} }
// nuke hashes and reconstruct hash, ensure it's the same. // nuke hashes and reconstruct hash, ensure it's the same.
(&IAVLTree{root:n2}).Traverse(func(node Node) bool { (&IAVLTree{root: n2}).Traverse(func(node Node) bool {
node.(*IAVLNode).hash = nil node.(*IAVLNode).hash = nil
return false return false
}) })
// ensure that the new hash after nuking is the same as the old. // ensure that the new hash after nuking is the same as the old.
newHash, _ := n2.Hash() newHash, _ := n2.Hash()
if bytes.Compare(hash, newHash) != 0 { if bytes.Compare(hash, newHash) != 0 {
t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash)
} }
} }
expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) { expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, updated := n.put(nil, Int32(i), nil) n2, updated := n.put(nil, Int32(i), nil)
// ensure node was added & structure is as expected. // ensure node was added & structure is as expected.
if updated == true || P(n2) != repr { if updated == true || P(n2) != repr {
t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v",
i, P(n), repr, P(n2), updated) i, P(n), repr, P(n2), updated)
} }
// ensure hash calculation requirements // ensure hash calculation requirements
expectHash(n2, hashCount) expectHash(n2, hashCount)
} }
expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, _, value, err := n.remove(nil, Int32(i)) n2, _, value, err := n.remove(nil, Int32(i))
// ensure node was added & structure is as expected. // ensure node was added & structure is as expected.
if value != nil || err != nil || P(n2) != repr { if value != nil || err != nil || P(n2) != repr {
t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v",
i, P(n), repr, P(n2), value, err) i, P(n), repr, P(n2), value, err)
} }
// ensure hash calculation requirements // ensure hash calculation requirements
expectHash(n2, hashCount) expectHash(n2, hashCount)
} }
//////// Test Put cases: //////// Test Put cases:
// Case 1: // Case 1:
n1 := N(4, 20) n1 := N(4, 20)
expectPut(n1, 8, "((4 8) 20)", 3) expectPut(n1, 8, "((4 8) 20)", 3)
expectPut(n1, 25, "(4 (20 25))", 3) expectPut(n1, 25, "(4 (20 25))", 3)
n2 := N(4, N(20, 25)) n2 := N(4, N(20, 25))
expectPut(n2, 8, "((4 8) (20 25))", 3) expectPut(n2, 8, "((4 8) (20 25))", 3)
expectPut(n2, 30, "((4 20) (25 30))", 4) expectPut(n2, 30, "((4 20) (25 30))", 4)
n3 := N(N(1, 2), 6) n3 := N(N(1, 2), 6)
expectPut(n3, 4, "((1 2) (4 6))", 4) expectPut(n3, 4, "((1 2) (4 6))", 4)
expectPut(n3, 8, "((1 2) (6 8))", 3) expectPut(n3, 8, "((1 2) (6 8))", 3)
n4 := N(N(1, 2), N(N(5, 6), N(7, 9))) n4 := N(N(1, 2), N(N(5, 6), N(7, 9)))
expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5)
expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5)
//////// Test Remove cases: //////// Test Remove cases:
n10 := N(N(1, 2), 3) n10 := N(N(1, 2), 3)
expectRemove(n10, 2, "(1 3)", 1) expectRemove(n10, 2, "(1 3)", 1)
expectRemove(n10, 3, "(1 2)", 0) expectRemove(n10, 3, "(1 2)", 0)
n11 := N(N(N(1, 2), 3), N(4, 5)) n11 := N(N(N(1, 2), 3), N(4, 5))
expectRemove(n11, 4, "((1 2) (3 5))", 2) expectRemove(n11, 4, "((1 2) (3 5))", 2)
expectRemove(n11, 3, "((1 2) (4 5))", 1) expectRemove(n11, 3, "((1 2) (4 5))", 1)
} }
func TestIntegration(t *testing.T) { func TestIntegration(t *testing.T) {
type record struct { type record struct {
key String key String
value String value String
} }
records := make([]*record, 400) records := make([]*record, 400)
var tree *IAVLTree = NewIAVLTree(nil) var tree *IAVLTree = NewIAVLTree(nil)
var err error var err error
var val Value var val Value
var updated bool var updated bool
randomRecord := func() *record { randomRecord := func() *record {
return &record{ randstr(20), randstr(20) } return &record{randstr(20), randstr(20)}
} }
for i := range records { for i := range records {
r := randomRecord() r := randomRecord()
records[i] = r records[i] = r
//t.Log("New record", r) //t.Log("New record", r)
//PrintIAVLNode(tree.root) //PrintIAVLNode(tree.root)
updated = tree.Put(r.key, String("")) updated = tree.Put(r.key, String(""))
if updated { if updated {
t.Error("should have not been updated") t.Error("should have not been updated")
} }
updated = tree.Put(r.key, r.value) updated = tree.Put(r.key, r.value)
if !updated { if !updated {
t.Error("should have been updated") t.Error("should have been updated")
} }
if tree.Size() != uint64(i+1) { if tree.Size() != uint64(i+1) {
t.Error("size was wrong", tree.Size(), i+1) t.Error("size was wrong", tree.Size(), i+1)
} }
} }
for _, r := range records { for _, r := range records {
if has := tree.Has(r.key); !has { if has := tree.Has(r.key); !has {
t.Error("Missing key", r.key) t.Error("Missing key", r.key)
} }
if has := tree.Has(randstr(12)); has { if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key") t.Error("Table has extra key")
} }
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { if val := tree.Get(r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value") t.Error("wrong value")
} }
} }
for i, x := range records { for i, x := range records {
if val, err = tree.Remove(x.key); err != nil { if val, err = tree.Remove(x.key); err != nil {
t.Error(err) t.Error(err)
} else if !(val.(String)).Equals(x.value) { } else if !(val.(String)).Equals(x.value) {
t.Error("wrong value") t.Error("wrong value")
} }
for _, r := range records[i+1:] { for _, r := range records[i+1:] {
if has := tree.Has(r.key); !has { if has := tree.Has(r.key); !has {
t.Error("Missing key", r.key) t.Error("Missing key", r.key)
} }
if has := tree.Has(randstr(12)); has { if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key") t.Error("Table has extra key")
} }
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { if val := tree.Get(r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value") t.Error("wrong value")
} }
} }
if tree.Size() != uint64(len(records) - (i+1)) { if tree.Size() != uint64(len(records)-(i+1)) {
t.Error("size was wrong", tree.Size(), (len(records) - (i+1))) t.Error("size was wrong", tree.Size(), (len(records) - (i + 1)))
} }
} }
} }
func TestPersistence(t *testing.T) { func TestPersistence(t *testing.T) {
db := db.NewMemDB() db := db.NewMemDB()
// Create some random key value pairs // Create some random key value pairs
records := make(map[String]String) records := make(map[String]String)
for i:=0; i<10000; i++ { for i := 0; i < 10000; i++ {
records[String(randstr(20))] = String(randstr(20)) records[String(randstr(20))] = String(randstr(20))
} }
// Construct some tree and save it // Construct some tree and save it
t1 := NewIAVLTree(db) t1 := NewIAVLTree(db)
for key, value := range records { for key, value := range records {
t1.Put(key, value) t1.Put(key, value)
} }
t1.Save() t1.Save()
hash, _ := t1.Hash() hash, _ := t1.Hash()
// Load a tree // Load a tree
t2 := NewIAVLTreeFromHash(db, hash) t2 := NewIAVLTreeFromHash(db, hash)
for key, value := range records { for key, value := range records {
t2value := t2.Get(key) t2value := t2.Get(key)
if !BinaryEqual(t2value, value) { if !BinaryEqual(t2value, value) {
t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) t.Fatalf("Invalid value. Expected %v, got %v", value, t2value)
} }
} }
} }
func BenchmarkHash(b *testing.B) { func BenchmarkHash(b *testing.B) {
b.StopTimer() b.StopTimer()
s := randstr(128) s := randstr(128)
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte(s)) hasher.Write([]byte(s))
hasher.Sum(nil) hasher.Sum(nil)
} }
} }
func BenchmarkImmutableAvlTree(b *testing.B) { func BenchmarkImmutableAvlTree(b *testing.B) {
b.StopTimer() b.StopTimer()
type record struct { type record struct {
key String key String
value String value String
} }
randomRecord := func() *record { randomRecord := func() *record {
return &record{ randstr(32), randstr(32) } return &record{randstr(32), randstr(32)}
} }
t := NewIAVLTree(nil) t := NewIAVLTree(nil)
for i:=0; i<1000000; i++ { for i := 0; i < 1000000; i++ {
r := randomRecord() r := randomRecord()
t.Put(r.key, r.value) t.Put(r.key, r.value)
} }
fmt.Println("ok, starting") fmt.Println("ok, starting")
runtime.GC() runtime.GC()
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
r := randomRecord() r := randomRecord()
t.Put(r.key, r.value) t.Put(r.key, r.value)
t.Remove(r.key) t.Remove(r.key)
} }
} }

View File

@ -1,10 +1,10 @@
package merkle package merkle
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
) )
const HASH_BYTE_SIZE int = 4+32 const HASH_BYTE_SIZE int = 4 + 32
/* /*
Immutable AVL Tree (wraps the Node root) Immutable AVL Tree (wraps the Node root)
@ -13,100 +13,118 @@ This tree is not concurrency safe.
You must wrap your calls with your own mutex. You must wrap your calls with your own mutex.
*/ */
type IAVLTree struct { type IAVLTree struct {
db Db db Db
root *IAVLNode root *IAVLNode
} }
func NewIAVLTree(db Db) *IAVLTree { func NewIAVLTree(db Db) *IAVLTree {
return &IAVLTree{db:db, root:nil} return &IAVLTree{db: db, root: nil}
} }
func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree { func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree {
root := &IAVLNode{ root := &IAVLNode{
hash: hash, hash: hash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
} }
root.fill(db) root.fill(db)
return &IAVLTree{db:db, root:root} return &IAVLTree{db: db, root: root}
} }
func (t *IAVLTree) Root() Node { func (t *IAVLTree) Root() Node {
return t.root return t.root
} }
func (t *IAVLTree) Size() uint64 { func (t *IAVLTree) Size() uint64 {
if t.root == nil { return 0 } if t.root == nil {
return t.root.Size() return 0
}
return t.root.Size()
} }
func (t *IAVLTree) Height() uint8 { func (t *IAVLTree) Height() uint8 {
if t.root == nil { return 0 } if t.root == nil {
return t.root.Height() return 0
}
return t.root.Height()
} }
func (t *IAVLTree) Has(key Key) bool { func (t *IAVLTree) Has(key Key) bool {
if t.root == nil { return false } if t.root == nil {
return t.root.has(t.db, key) return false
}
return t.root.has(t.db, key)
} }
func (t *IAVLTree) Put(key Key, value Value) (updated bool) { func (t *IAVLTree) Put(key Key, value Value) (updated bool) {
if t.root == nil { if t.root == nil {
t.root = NewIAVLNode(key, value) t.root = NewIAVLNode(key, value)
return false return false
} }
t.root, updated = t.root.put(t.db, key, value) t.root, updated = t.root.put(t.db, key, value)
return updated return updated
} }
func (t *IAVLTree) Hash() (ByteSlice, uint64) { func (t *IAVLTree) Hash() (ByteSlice, uint64) {
if t.root == nil { return nil, 0 } if t.root == nil {
return t.root.Hash() return nil, 0
}
return t.root.Hash()
} }
func (t *IAVLTree) Save() { func (t *IAVLTree) Save() {
if t.root == nil { return } if t.root == nil {
if t.root.hash == nil { return
t.root.Hash() }
} if t.root.hash == nil {
t.root.Save(t.db) t.root.Hash()
}
t.root.Save(t.db)
} }
func (t *IAVLTree) Get(key Key) (value Value) { func (t *IAVLTree) Get(key Key) (value Value) {
if t.root == nil { return nil } if t.root == nil {
return t.root.get(t.db, key) return nil
}
return t.root.get(t.db, key)
} }
func (t *IAVLTree) Remove(key Key) (value Value, err error) { func (t *IAVLTree) Remove(key Key) (value Value, err error) {
if t.root == nil { return nil, NotFound(key) } if t.root == nil {
newRoot, _, value, err := t.root.remove(t.db, key) return nil, NotFound(key)
if err != nil { }
return nil, err newRoot, _, value, err := t.root.remove(t.db, key)
} if err != nil {
t.root = newRoot return nil, err
return value, nil }
t.root = newRoot
return value, nil
} }
func (t *IAVLTree) Copy() Tree { func (t *IAVLTree) Copy() Tree {
return &IAVLTree{db:t.db, root:t.root} return &IAVLTree{db: t.db, root: t.root}
} }
// Traverses all the nodes of the tree in prefix order. // Traverses all the nodes of the tree in prefix order.
// return true from cb to halt iteration. // return true from cb to halt iteration.
// node.Height() == 0 if you just want a value node. // node.Height() == 0 if you just want a value node.
func (t *IAVLTree) Traverse(cb func(Node) bool) { func (t *IAVLTree) Traverse(cb func(Node) bool) {
if t.root == nil { return } if t.root == nil {
t.root.traverse(t.db, cb) return
}
t.root.traverse(t.db, cb)
} }
func (t *IAVLTree) Values() <-chan Value { func (t *IAVLTree) Values() <-chan Value {
root := t.root root := t.root
ch := make(chan Value) ch := make(chan Value)
go func() { go func() {
root.traverse(t.db, func(n Node) bool { root.traverse(t.db, func(n Node) bool {
if n.Height() == 0 { ch <- n.Value() } if n.Height() == 0 {
return true ch <- n.Value()
}) }
close(ch) return true
}() })
return ch close(ch)
}()
return ch
} }

View File

@ -1,50 +1,50 @@
package merkle package merkle
import ( import (
. "github.com/tendermint/tendermint/binary" "fmt"
"fmt" . "github.com/tendermint/tendermint/binary"
) )
type Value interface { type Value interface {
Binary Binary
} }
type Key interface { type Key interface {
Binary Binary
Equals(Binary) bool Equals(Binary) bool
Less(b Binary) bool Less(b Binary) bool
} }
type Db interface { type Db interface {
Get([]byte) []byte Get([]byte) []byte
Put([]byte, []byte) Put([]byte, []byte)
} }
type Node interface { type Node interface {
Binary Binary
Key() Key Key() Key
Value() Value Value() Value
Size() uint64 Size() uint64
Height() uint8 Height() uint8
Hash() (ByteSlice, uint64) Hash() (ByteSlice, uint64)
Save(Db) Save(Db)
} }
type Tree interface { type Tree interface {
Root() Node Root() Node
Size() uint64 Size() uint64
Height() uint8 Height() uint8
Has(key Key) bool Has(key Key) bool
Get(key Key) Value Get(key Key) Value
Hash() (ByteSlice, uint64) Hash() (ByteSlice, uint64)
Save() Save()
Put(Key, Value) bool Put(Key, Value) bool
Remove(Key) (Value, error) Remove(Key) (Value, error)
Copy() Tree Copy() Tree
Traverse(func(Node)bool) Traverse(func(Node) bool)
Values() <-chan Value Values() <-chan Value
} }
func NotFound(key Key) error { func NotFound(key Key) error {
return fmt.Errorf("Key was not found.") return fmt.Errorf("Key was not found.")
} }

View File

@ -1,78 +1,83 @@
package merkle package merkle
import ( import (
. "github.com/tendermint/tendermint/binary" "crypto/sha256"
"os" "fmt"
"fmt" . "github.com/tendermint/tendermint/binary"
"crypto/sha256" "os"
) )
/* /*
Compute a deterministic merkle hash from a list of byteslices. Compute a deterministic merkle hash from a list of byteslices.
*/ */
func HashFromBinarySlice(items []Binary) ByteSlice { func HashFromBinarySlice(items []Binary) ByteSlice {
switch len(items) { switch len(items) {
case 0: case 0:
panic("Cannot compute hash of empty slice") panic("Cannot compute hash of empty slice")
case 1: case 1:
hasher := sha256.New() hasher := sha256.New()
_, err := items[0].WriteTo(hasher) _, err := items[0].WriteTo(hasher)
if err != nil { panic(err) } if err != nil {
return ByteSlice(hasher.Sum(nil)) panic(err)
default: }
hasher := sha256.New() return ByteSlice(hasher.Sum(nil))
_, err := HashFromBinarySlice(items[0:len(items)/2]).WriteTo(hasher) default:
if err != nil { panic(err) } hasher := sha256.New()
_, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher) _, err := HashFromBinarySlice(items[0 : len(items)/2]).WriteTo(hasher)
if err != nil { panic(err) } if err != nil {
return ByteSlice(hasher.Sum(nil)) panic(err)
} }
_, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher)
if err != nil {
panic(err)
}
return ByteSlice(hasher.Sum(nil))
}
} }
func PrintIAVLNode(node *IAVLNode) { func PrintIAVLNode(node *IAVLNode) {
fmt.Println("==== NODE") fmt.Println("==== NODE")
if node != nil { if node != nil {
printIAVLNode(node, 0) printIAVLNode(node, 0)
} }
fmt.Println("==== END") fmt.Println("==== END")
} }
func printIAVLNode(node *IAVLNode, indent int) { func printIAVLNode(node *IAVLNode, indent int) {
indentPrefix := "" indentPrefix := ""
for i:=0; i<indent; i++ { for i := 0; i < indent; i++ {
indentPrefix += " " indentPrefix += " "
} }
if node.right != nil { if node.right != nil {
printIAVLNode(node.rightFilled(nil), indent+1) printIAVLNode(node.rightFilled(nil), indent+1)
} }
fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height) fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height)
if node.left != nil { if node.left != nil {
printIAVLNode(node.leftFilled(nil), indent+1) printIAVLNode(node.leftFilled(nil), indent+1)
} }
} }
func randstr(length int) String { func randstr(length int) String {
if urandom, err := os.Open("/dev/urandom"); err != nil { if urandom, err := os.Open("/dev/urandom"); err != nil {
panic(err) panic(err)
} else { } else {
slice := make([]byte, length) slice := make([]byte, length)
if _, err := urandom.Read(slice); err != nil { if _, err := urandom.Read(slice); err != nil {
panic(err) panic(err)
} }
urandom.Close() urandom.Close()
return String(slice) return String(slice)
} }
panic("unreachable") panic("unreachable")
} }
func maxUint8(a, b uint8) uint8 { func maxUint8(a, b uint8) uint8 {
if a > b { if a > b {
return a return a
} }
return b return b
} }

View File

@ -5,217 +5,236 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/binary" crand "crypto/rand" // for seeding
crand "crypto/rand" // for seeding "encoding/binary"
"encoding/binary" "encoding/json"
"encoding/json" "fmt"
"io" . "github.com/tendermint/tendermint/binary"
"math" "io"
"math/rand" "math"
"net" "math/rand"
"sync" "net"
"sync/atomic" "os"
"time" "sync"
"os" "sync/atomic"
"fmt" "time"
) )
/* AddrBook - concurrency safe peer address manager */ /* AddrBook - concurrency safe peer address manager */
type AddrBook struct { type AddrBook struct {
filePath string filePath string
mtx sync.Mutex mtx sync.Mutex
rand *rand.Rand rand *rand.Rand
key [32]byte key [32]byte
addrIndex map[string]*KnownAddress // addr.String() -> KnownAddress addrIndex map[string]*KnownAddress // addr.String() -> KnownAddress
addrNew [newBucketCount]map[string]*KnownAddress addrNew [newBucketCount]map[string]*KnownAddress
addrOld [oldBucketCount][]*KnownAddress addrOld [oldBucketCount][]*KnownAddress
started int32 started int32
shutdown int32 shutdown int32
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
nOld int nOld int
nNew int nNew int
} }
const ( const (
// addresses under which the address manager will claim to need more addresses. // addresses under which the address manager will claim to need more addresses.
needAddressThreshold = 1000 needAddressThreshold = 1000
// interval used to dump the address cache to disk for future use. // interval used to dump the address cache to disk for future use.
dumpAddressInterval = time.Minute * 2 dumpAddressInterval = time.Minute * 2
// max addresses in each old address bucket. // max addresses in each old address bucket.
oldBucketSize = 64 oldBucketSize = 64
// buckets we split old addresses over. // buckets we split old addresses over.
oldBucketCount = 64 oldBucketCount = 64
// max addresses in each new address bucket. // max addresses in each new address bucket.
newBucketSize = 64 newBucketSize = 64
// buckets that we spread new addresses over. // buckets that we spread new addresses over.
newBucketCount = 256 newBucketCount = 256
// old buckets over which an address group will be spread. // old buckets over which an address group will be spread.
oldBucketsPerGroup = 4 oldBucketsPerGroup = 4
// new buckets over which an source address group will be spread. // new buckets over which an source address group will be spread.
newBucketsPerGroup = 32 newBucketsPerGroup = 32
// buckets a frequently seen new address may end up in. // buckets a frequently seen new address may end up in.
newBucketsPerAddress = 4 newBucketsPerAddress = 4
// days before which we assume an address has vanished // days before which we assume an address has vanished
// if we have not seen it announced in that long. // if we have not seen it announced in that long.
numMissingDays = 30 numMissingDays = 30
// tries without a single success before we assume an address is bad. // tries without a single success before we assume an address is bad.
numRetries = 3 numRetries = 3
// max failures we will accept without a success before considering an address bad. // max failures we will accept without a success before considering an address bad.
maxFailures = 10 maxFailures = 10
// days since the last success before we will consider evicting an address. // days since the last success before we will consider evicting an address.
minBadDays = 7 minBadDays = 7
// max addresses that we will send in response to a getAddr // max addresses that we will send in response to a getAddr
// (in practise the most addresses we will return from a call to AddressCache()). // (in practise the most addresses we will return from a call to AddressCache()).
getAddrMax = 2500 getAddrMax = 2500
// % of total addresses known that we will share with a call to AddressCache. // % of total addresses known that we will share with a call to AddressCache.
getAddrPercent = 23 getAddrPercent = 23
// current version of the on-disk format. // current version of the on-disk format.
serialisationVersion = 1 serialisationVersion = 1
) )
// Use Start to begin processing asynchronous address updates. // Use Start to begin processing asynchronous address updates.
func NewAddrBook(filePath string) *AddrBook { func NewAddrBook(filePath string) *AddrBook {
am := AddrBook{ am := AddrBook{
rand: rand.New(rand.NewSource(time.Now().UnixNano())), rand: rand.New(rand.NewSource(time.Now().UnixNano())),
quit: make(chan struct{}), quit: make(chan struct{}),
filePath: filePath, filePath: filePath,
} }
am.init() am.init()
return &am return &am
} }
// When modifying this, don't forget to update loadFromFile() // When modifying this, don't forget to update loadFromFile()
func (a *AddrBook) init() { func (a *AddrBook) init() {
a.addrIndex = make(map[string]*KnownAddress) a.addrIndex = make(map[string]*KnownAddress)
io.ReadFull(crand.Reader, a.key[:]) io.ReadFull(crand.Reader, a.key[:])
for i := range a.addrNew { for i := range a.addrNew {
a.addrNew[i] = make(map[string]*KnownAddress) a.addrNew[i] = make(map[string]*KnownAddress)
} }
for i := range a.addrOld { for i := range a.addrOld {
a.addrOld[i] = make([]*KnownAddress, 0, oldBucketSize) a.addrOld[i] = make([]*KnownAddress, 0, oldBucketSize)
} }
} }
func (a *AddrBook) Start() { func (a *AddrBook) Start() {
if atomic.AddInt32(&a.started, 1) != 1 { return } if atomic.AddInt32(&a.started, 1) != 1 {
log.Trace("Starting address manager") return
a.loadFromFile(a.filePath) }
a.wg.Add(1) log.Trace("Starting address manager")
go a.addressHandler() a.loadFromFile(a.filePath)
a.wg.Add(1)
go a.addressHandler()
} }
func (a *AddrBook) Stop() { func (a *AddrBook) Stop() {
if atomic.AddInt32(&a.shutdown, 1) != 1 { return } if atomic.AddInt32(&a.shutdown, 1) != 1 {
log.Infof("Address manager shutting down") return
close(a.quit) }
a.wg.Wait() log.Infof("Address manager shutting down")
close(a.quit)
a.wg.Wait()
} }
func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) { func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) {
a.mtx.Lock(); defer a.mtx.Unlock() a.mtx.Lock()
a.addAddress(addr, src) defer a.mtx.Unlock()
a.addAddress(addr, src)
} }
func (a *AddrBook) NeedMoreAddresses() bool { func (a *AddrBook) NeedMoreAddresses() bool {
return a.NumAddresses() < needAddressThreshold return a.NumAddresses() < needAddressThreshold
} }
func (a *AddrBook) NumAddresses() int { func (a *AddrBook) NumAddresses() int {
a.mtx.Lock(); defer a.mtx.Unlock() a.mtx.Lock()
return a.nOld + a.nNew defer a.mtx.Unlock()
return a.nOld + a.nNew
} }
// Pick a new address to connect to. // Pick a new address to connect to.
func (a *AddrBook) PickAddress(class string, newBias int) *KnownAddress { func (a *AddrBook) PickAddress(class string, newBias int) *KnownAddress {
a.mtx.Lock(); defer a.mtx.Unlock() a.mtx.Lock()
defer a.mtx.Unlock()
if a.nOld == 0 && a.nNew == 0 { return nil } if a.nOld == 0 && a.nNew == 0 {
if newBias > 100 { newBias = 100 } return nil
if newBias < 0 { newBias = 0 } }
if newBias > 100 {
newBias = 100
}
if newBias < 0 {
newBias = 0
}
// Bias between new and old addresses. // Bias between new and old addresses.
oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias)) oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias))
newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias) newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias)
if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation { if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation {
// pick random Old bucket. // pick random Old bucket.
var bucket []*KnownAddress = nil var bucket []*KnownAddress = nil
for len(bucket) == 0 { for len(bucket) == 0 {
bucket = a.addrOld[a.rand.Intn(len(a.addrOld))] bucket = a.addrOld[a.rand.Intn(len(a.addrOld))]
} }
// pick a random ka from bucket. // pick a random ka from bucket.
return bucket[a.rand.Intn(len(bucket))] return bucket[a.rand.Intn(len(bucket))]
} else { } else {
// pick random New bucket. // pick random New bucket.
var bucket map[string]*KnownAddress = nil var bucket map[string]*KnownAddress = nil
for len(bucket) == 0 { for len(bucket) == 0 {
bucket = a.addrNew[a.rand.Intn(len(a.addrNew))] bucket = a.addrNew[a.rand.Intn(len(a.addrNew))]
} }
// pick a random ka from bucket. // pick a random ka from bucket.
randIndex := a.rand.Intn(len(bucket)) randIndex := a.rand.Intn(len(bucket))
for _, ka := range bucket { for _, ka := range bucket {
randIndex-- randIndex--
if randIndex == 0 { if randIndex == 0 {
return ka return ka
} }
} }
panic("Should not happen") panic("Should not happen")
} }
return nil return nil
} }
func (a *AddrBook) MarkGood(addr *NetAddress) { func (a *AddrBook) MarkGood(addr *NetAddress) {
a.mtx.Lock(); defer a.mtx.Unlock() a.mtx.Lock()
ka := a.addrIndex[addr.String()] defer a.mtx.Unlock()
if ka == nil { return } ka := a.addrIndex[addr.String()]
ka.MarkAttempt(true) if ka == nil {
if ka.OldBucket == -1 { return
a.moveToOld(ka) }
} ka.MarkAttempt(true)
if ka.OldBucket == -1 {
a.moveToOld(ka)
}
} }
func (a *AddrBook) MarkAttempt(addr *NetAddress) { func (a *AddrBook) MarkAttempt(addr *NetAddress) {
a.mtx.Lock(); defer a.mtx.Unlock() a.mtx.Lock()
ka := a.addrIndex[addr.String()] defer a.mtx.Unlock()
if ka == nil { return } ka := a.addrIndex[addr.String()]
ka.MarkAttempt(false) if ka == nil {
return
}
ka.MarkAttempt(false)
} }
/* Loading & Saving */ /* Loading & Saving */
type addrBookJSON struct { type addrBookJSON struct {
Key [32]byte Key [32]byte
AddrNew [newBucketCount]map[string]*KnownAddress AddrNew [newBucketCount]map[string]*KnownAddress
AddrOld [oldBucketCount][]*KnownAddress AddrOld [oldBucketCount][]*KnownAddress
NOld int NOld int
NNew int NNew int
} }
func (a *AddrBook) saveToFile(filePath string) { func (a *AddrBook) saveToFile(filePath string) {
aJSON := &addrBookJSON{ aJSON := &addrBookJSON{
Key: a.key, Key: a.key,
AddrNew: a.addrNew, AddrNew: a.addrNew,
AddrOld: a.addrOld, AddrOld: a.addrOld,
NOld: a.nOld, NOld: a.nOld,
NNew: a.nNew, NNew: a.nNew,
} }
w, err := os.Create(filePath) w, err := os.Create(filePath)
if err != nil { if err != nil {
@ -225,296 +244,306 @@ func (a *AddrBook) saveToFile(filePath string) {
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
defer w.Close() defer w.Close()
err = enc.Encode(&aJSON) err = enc.Encode(&aJSON)
if err != nil { panic(err) } if err != nil {
panic(err)
}
} }
func (a *AddrBook) loadFromFile(filePath string) { func (a *AddrBook) loadFromFile(filePath string) {
// If doesn't exist, do nothing. // If doesn't exist, do nothing.
_, err := os.Stat(filePath) _, err := os.Stat(filePath)
if os.IsNotExist(err) { return } if os.IsNotExist(err) {
return
}
// Load addrBookJSON{} // Load addrBookJSON{}
r, err := os.Open(filePath) r, err := os.Open(filePath)
if err != nil { if err != nil {
panic(fmt.Errorf("%s error opening file: %v", filePath, err)) panic(fmt.Errorf("%s error opening file: %v", filePath, err))
} }
defer r.Close() defer r.Close()
aJSON := &addrBookJSON{} aJSON := &addrBookJSON{}
dec := json.NewDecoder(r) dec := json.NewDecoder(r)
err = dec.Decode(aJSON) err = dec.Decode(aJSON)
if err != nil { if err != nil {
panic(fmt.Errorf("error reading %s: %v", filePath, err)) panic(fmt.Errorf("error reading %s: %v", filePath, err))
} }
// Now we need to initialize self. // Now we need to initialize self.
copy(a.key[:], aJSON.Key[:]) copy(a.key[:], aJSON.Key[:])
a.addrNew = aJSON.AddrNew a.addrNew = aJSON.AddrNew
for i, oldBucket := range aJSON.AddrOld { for i, oldBucket := range aJSON.AddrOld {
copy(a.addrOld[i], oldBucket) copy(a.addrOld[i], oldBucket)
} }
a.nNew = aJSON.NNew a.nNew = aJSON.NNew
a.nOld = aJSON.NOld a.nOld = aJSON.NOld
a.addrIndex = make(map[string]*KnownAddress) a.addrIndex = make(map[string]*KnownAddress)
for _, newBucket := range a.addrNew { for _, newBucket := range a.addrNew {
for key, ka := range newBucket { for key, ka := range newBucket {
a.addrIndex[key] = ka a.addrIndex[key] = ka
} }
} }
} }
/* Private methods */ /* Private methods */
func (a *AddrBook) addressHandler() { func (a *AddrBook) addressHandler() {
dumpAddressTicker := time.NewTicker(dumpAddressInterval) dumpAddressTicker := time.NewTicker(dumpAddressInterval)
out: out:
for { for {
select { select {
case <-dumpAddressTicker.C: case <-dumpAddressTicker.C:
a.saveToFile(a.filePath) a.saveToFile(a.filePath)
case <-a.quit: case <-a.quit:
break out break out
} }
} }
dumpAddressTicker.Stop() dumpAddressTicker.Stop()
a.saveToFile(a.filePath) a.saveToFile(a.filePath)
a.wg.Done() a.wg.Done()
log.Trace("Address handler done") log.Trace("Address handler done")
} }
func (a *AddrBook) addAddress(addr, src *NetAddress) { func (a *AddrBook) addAddress(addr, src *NetAddress) {
if !addr.Routable() { return } if !addr.Routable() {
return
}
key := addr.String() key := addr.String()
ka := a.addrIndex[key] ka := a.addrIndex[key]
if ka != nil { if ka != nil {
// Already added // Already added
if ka.OldBucket != -1 { return } if ka.OldBucket != -1 {
if ka.NewRefs == newBucketsPerAddress { return } return
}
if ka.NewRefs == newBucketsPerAddress {
return
}
// The more entries we have, the less likely we are to add more. // The more entries we have, the less likely we are to add more.
factor := int32(2 * ka.NewRefs) factor := int32(2 * ka.NewRefs)
if a.rand.Int31n(factor) != 0 { if a.rand.Int31n(factor) != 0 {
return return
} }
} else { } else {
ka = NewKnownAddress(addr, src) ka = NewKnownAddress(addr, src)
a.addrIndex[key] = ka a.addrIndex[key] = ka
a.nNew++ a.nNew++
} }
bucket := a.getNewBucket(addr, src) bucket := a.getNewBucket(addr, src)
// Already exists? // Already exists?
if _, ok := a.addrNew[bucket][key]; ok { if _, ok := a.addrNew[bucket][key]; ok {
return return
} }
// Enforce max addresses. // Enforce max addresses.
if len(a.addrNew[bucket]) > newBucketSize { if len(a.addrNew[bucket]) > newBucketSize {
log.Tracef("new bucket is full, expiring old ") log.Tracef("new bucket is full, expiring old ")
a.expireNew(bucket) a.expireNew(bucket)
} }
// Add to new bucket. // Add to new bucket.
ka.NewRefs++ ka.NewRefs++
a.addrNew[bucket][key] = ka a.addrNew[bucket][key] = ka
log.Tracef("Added new address %s for a total of %d addresses", addr, a.nOld+a.nNew) log.Tracef("Added new address %s for a total of %d addresses", addr, a.nOld+a.nNew)
} }
// Make space in the new buckets by expiring the really bad entries. // Make space in the new buckets by expiring the really bad entries.
// If no bad entries are available we look at a few and remove the oldest. // If no bad entries are available we look at a few and remove the oldest.
func (a *AddrBook) expireNew(bucket int) { func (a *AddrBook) expireNew(bucket int) {
var oldest *KnownAddress var oldest *KnownAddress
for k, v := range a.addrNew[bucket] { for k, v := range a.addrNew[bucket] {
// If an entry is bad, throw it away // If an entry is bad, throw it away
if v.Bad() { if v.Bad() {
log.Tracef("expiring bad address %v", k) log.Tracef("expiring bad address %v", k)
delete(a.addrNew[bucket], k) delete(a.addrNew[bucket], k)
v.NewRefs-- v.NewRefs--
if v.NewRefs == 0 { if v.NewRefs == 0 {
a.nNew-- a.nNew--
delete(a.addrIndex, k) delete(a.addrIndex, k)
} }
return return
} }
// or, keep track of the oldest entry // or, keep track of the oldest entry
if oldest == nil { if oldest == nil {
oldest = v oldest = v
} else if v.LastAttempt.Before(oldest.LastAttempt.Time) { } else if v.LastAttempt.Before(oldest.LastAttempt.Time) {
oldest = v oldest = v
} }
} }
// If we haven't thrown out a bad entry, throw out the oldest entry // If we haven't thrown out a bad entry, throw out the oldest entry
if oldest != nil { if oldest != nil {
key := oldest.Addr.String() key := oldest.Addr.String()
log.Tracef("expiring oldest address %v", key) log.Tracef("expiring oldest address %v", key)
delete(a.addrNew[bucket], key) delete(a.addrNew[bucket], key)
oldest.NewRefs-- oldest.NewRefs--
if oldest.NewRefs == 0 { if oldest.NewRefs == 0 {
a.nNew-- a.nNew--
delete(a.addrIndex, key) delete(a.addrIndex, key)
} }
} }
} }
func (a *AddrBook) moveToOld(ka *KnownAddress) { func (a *AddrBook) moveToOld(ka *KnownAddress) {
// Remove from all new buckets. // Remove from all new buckets.
// Remember one of those new buckets. // Remember one of those new buckets.
addrKey := ka.Addr.String() addrKey := ka.Addr.String()
freedBucket := -1 freedBucket := -1
for i := range a.addrNew { for i := range a.addrNew {
// we check for existance so we can record the first one // we check for existance so we can record the first one
if _, ok := a.addrNew[i][addrKey]; ok { if _, ok := a.addrNew[i][addrKey]; ok {
delete(a.addrNew[i], addrKey) delete(a.addrNew[i], addrKey)
ka.NewRefs-- ka.NewRefs--
if freedBucket == -1 { if freedBucket == -1 {
freedBucket = i freedBucket = i
} }
} }
} }
a.nNew-- a.nNew--
if freedBucket == -1 { panic("Expected to find addr in at least one new bucket") } if freedBucket == -1 {
panic("Expected to find addr in at least one new bucket")
}
oldBucket := a.getOldBucket(ka.Addr) oldBucket := a.getOldBucket(ka.Addr)
// If room in oldBucket, put it in. // If room in oldBucket, put it in.
if len(a.addrOld[oldBucket]) < oldBucketSize { if len(a.addrOld[oldBucket]) < oldBucketSize {
ka.OldBucket = Int16(oldBucket) ka.OldBucket = Int16(oldBucket)
a.addrOld[oldBucket] = append(a.addrOld[oldBucket], ka) a.addrOld[oldBucket] = append(a.addrOld[oldBucket], ka)
a.nOld++ a.nOld++
return return
} }
// No room, we have to evict something else. // No room, we have to evict something else.
rmkaIndex := a.pickOld(oldBucket) rmkaIndex := a.pickOld(oldBucket)
rmka := a.addrOld[oldBucket][rmkaIndex] rmka := a.addrOld[oldBucket][rmkaIndex]
// Find a new bucket to put rmka in. // Find a new bucket to put rmka in.
newBucket := a.getNewBucket(rmka.Addr, rmka.Src) newBucket := a.getNewBucket(rmka.Addr, rmka.Src)
if len(a.addrNew[newBucket]) >= newBucketSize { if len(a.addrNew[newBucket]) >= newBucketSize {
newBucket = freedBucket newBucket = freedBucket
} }
// replace with ka in list. // replace with ka in list.
ka.OldBucket = Int16(oldBucket) ka.OldBucket = Int16(oldBucket)
a.addrOld[oldBucket][rmkaIndex] = ka a.addrOld[oldBucket][rmkaIndex] = ka
rmka.OldBucket = -1 rmka.OldBucket = -1
// put rmka into new bucket // put rmka into new bucket
rmkey := rmka.Addr.String() rmkey := rmka.Addr.String()
log.Tracef("Replacing %s with %s in old", rmkey, addrKey) log.Tracef("Replacing %s with %s in old", rmkey, addrKey)
a.addrNew[newBucket][rmkey] = rmka a.addrNew[newBucket][rmkey] = rmka
rmka.NewRefs++ rmka.NewRefs++
a.nNew++ a.nNew++
} }
// Returns the index in old bucket of oldest entry. // Returns the index in old bucket of oldest entry.
func (a *AddrBook) pickOld(bucket int) int { func (a *AddrBook) pickOld(bucket int) int {
var oldest *KnownAddress var oldest *KnownAddress
var oldestIndex int var oldestIndex int
for i, ka := range a.addrOld[bucket] { for i, ka := range a.addrOld[bucket] {
if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt.Time) { if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt.Time) {
oldest = ka oldest = ka
oldestIndex = i oldestIndex = i
} }
} }
return oldestIndex return oldestIndex
} }
// doublesha256(key + sourcegroup + // doublesha256(key + sourcegroup +
// int64(doublesha256(key + group + sourcegroup))%bucket_per_source_group) % num_new_buckes // int64(doublesha256(key + group + sourcegroup))%bucket_per_source_group) % num_new_buckes
func (a *AddrBook) getNewBucket(addr, src *NetAddress) int { func (a *AddrBook) getNewBucket(addr, src *NetAddress) int {
data1 := []byte{} data1 := []byte{}
data1 = append(data1, a.key[:]...) data1 = append(data1, a.key[:]...)
data1 = append(data1, []byte(GroupKey(addr))...) data1 = append(data1, []byte(GroupKey(addr))...)
data1 = append(data1, []byte(GroupKey(src))...) data1 = append(data1, []byte(GroupKey(src))...)
hash1 := DoubleSha256(data1) hash1 := DoubleSha256(data1)
hash64 := binary.LittleEndian.Uint64(hash1) hash64 := binary.LittleEndian.Uint64(hash1)
hash64 %= newBucketsPerGroup hash64 %= newBucketsPerGroup
var hashbuf [8]byte var hashbuf [8]byte
binary.LittleEndian.PutUint64(hashbuf[:], hash64) binary.LittleEndian.PutUint64(hashbuf[:], hash64)
data2 := []byte{} data2 := []byte{}
data2 = append(data2, a.key[:]...) data2 = append(data2, a.key[:]...)
data2 = append(data2, GroupKey(src)...) data2 = append(data2, GroupKey(src)...)
data2 = append(data2, hashbuf[:]...) data2 = append(data2, hashbuf[:]...)
hash2 := DoubleSha256(data2) hash2 := DoubleSha256(data2)
return int(binary.LittleEndian.Uint64(hash2) % newBucketCount) return int(binary.LittleEndian.Uint64(hash2) % newBucketCount)
} }
// doublesha256(key + group + truncate_to_64bits(doublesha256(key + addr))%buckets_per_group) % num_buckets // doublesha256(key + group + truncate_to_64bits(doublesha256(key + addr))%buckets_per_group) % num_buckets
func (a *AddrBook) getOldBucket(addr *NetAddress) int { func (a *AddrBook) getOldBucket(addr *NetAddress) int {
data1 := []byte{} data1 := []byte{}
data1 = append(data1, a.key[:]...) data1 = append(data1, a.key[:]...)
data1 = append(data1, []byte(addr.String())...) data1 = append(data1, []byte(addr.String())...)
hash1 := DoubleSha256(data1) hash1 := DoubleSha256(data1)
hash64 := binary.LittleEndian.Uint64(hash1) hash64 := binary.LittleEndian.Uint64(hash1)
hash64 %= oldBucketsPerGroup hash64 %= oldBucketsPerGroup
var hashbuf [8]byte var hashbuf [8]byte
binary.LittleEndian.PutUint64(hashbuf[:], hash64) binary.LittleEndian.PutUint64(hashbuf[:], hash64)
data2 := []byte{} data2 := []byte{}
data2 = append(data2, a.key[:]...) data2 = append(data2, a.key[:]...)
data2 = append(data2, GroupKey(addr)...) data2 = append(data2, GroupKey(addr)...)
data2 = append(data2, hashbuf[:]...) data2 = append(data2, hashbuf[:]...)
hash2 := DoubleSha256(data2) hash2 := DoubleSha256(data2)
return int(binary.LittleEndian.Uint64(hash2) % oldBucketCount) return int(binary.LittleEndian.Uint64(hash2) % oldBucketCount)
} }
// Return a string representing the network group of this address. // Return a string representing the network group of this address.
// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string // This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string
// "local" for a local address and the string "unroutable for an unroutable // "local" for a local address and the string "unroutable for an unroutable
// address. // address.
func GroupKey (na *NetAddress) string { func GroupKey(na *NetAddress) string {
if na.Local() { if na.Local() {
return "local" return "local"
} }
if !na.Routable() { if !na.Routable() {
return "unroutable" return "unroutable"
} }
if ipv4 := na.IP.To4(); ipv4 != nil { if ipv4 := na.IP.To4(); ipv4 != nil {
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String() return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String()
} }
if na.RFC6145() || na.RFC6052() { if na.RFC6145() || na.RFC6052() {
// last four bytes are the ip address // last four bytes are the ip address
ip := net.IP(na.IP[12:16]) ip := net.IP(na.IP[12:16])
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
} }
if na.RFC3964() { if na.RFC3964() {
ip := net.IP(na.IP[2:7]) ip := net.IP(na.IP[2:7])
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
} }
if na.RFC4380() { if na.RFC4380() {
// teredo tunnels have the last 4 bytes as the v4 address XOR // teredo tunnels have the last 4 bytes as the v4 address XOR
// 0xff. // 0xff.
ip := net.IP(make([]byte, 4)) ip := net.IP(make([]byte, 4))
for i, byte := range na.IP[12:16] { for i, byte := range na.IP[12:16] {
ip[i] = byte ^ 0xff ip[i] = byte ^ 0xff
} }
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
} }
// OK, so now we know ourselves to be a IPv6 address. // OK, so now we know ourselves to be a IPv6 address.
// bitcoind uses /32 for everything, except for Hurricane Electric's // bitcoind uses /32 for everything, except for Hurricane Electric's
// (he.net) IP range, which it uses /36 for. // (he.net) IP range, which it uses /36 for.
bits := 32 bits := 32
heNet := &net.IPNet{IP: net.ParseIP("2001:470::"), heNet := &net.IPNet{IP: net.ParseIP("2001:470::"),
Mask: net.CIDRMask(32, 128)} Mask: net.CIDRMask(32, 128)}
if heNet.Contains(na.IP) { if heNet.Contains(na.IP) {
bits = 36 bits = 36
} }
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String() return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String()
} }

View File

@ -1,12 +1,12 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/common" "errors"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle" . "github.com/tendermint/tendermint/common"
"sync/atomic" "github.com/tendermint/tendermint/merkle"
"sync" "sync"
"errors" "sync/atomic"
) )
/* Client /* Client
@ -21,147 +21,161 @@ import (
XXX what about peer disconnects? XXX what about peer disconnects?
*/ */
type Client struct { type Client struct {
addrBook *AddrBook addrBook *AddrBook
targetNumPeers int targetNumPeers int
makePeerFn func(*Connection) *Peer makePeerFn func(*Connection) *Peer
self *Peer self *Peer
recvQueues map[String]chan *InboundPacket recvQueues map[String]chan *InboundPacket
mtx sync.Mutex mtx sync.Mutex
peers merkle.Tree // addr -> *Peer peers merkle.Tree // addr -> *Peer
quit chan struct{} quit chan struct{}
stopped uint32 stopped uint32
} }
var ( var (
CLIENT_STOPPED_ERROR = errors.New("Client already stopped") CLIENT_STOPPED_ERROR = errors.New("Client already stopped")
CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer") CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer")
) )
func NewClient(makePeerFn func(*Connection) *Peer) *Client { func NewClient(makePeerFn func(*Connection) *Peer) *Client {
self := makePeerFn(nil) self := makePeerFn(nil)
if self == nil { if self == nil {
Panicf("makePeerFn(nil) must return a prototypical peer for self") Panicf("makePeerFn(nil) must return a prototypical peer for self")
} }
recvQueues := make(map[String]chan *InboundPacket) recvQueues := make(map[String]chan *InboundPacket)
for chName, _ := range self.channels { for chName, _ := range self.channels {
recvQueues[chName] = make(chan *InboundPacket) recvQueues[chName] = make(chan *InboundPacket)
} }
c := &Client{ c := &Client{
addrBook: nil, // TODO addrBook: nil, // TODO
targetNumPeers: 0, // TODO targetNumPeers: 0, // TODO
makePeerFn: makePeerFn, makePeerFn: makePeerFn,
self: self, self: self,
recvQueues: recvQueues, recvQueues: recvQueues,
peers: merkle.NewIAVLTree(nil), peers: merkle.NewIAVLTree(nil),
quit: make(chan struct{}), quit: make(chan struct{}),
stopped: 0, stopped: 0,
} }
return c return c
} }
func (c *Client) Stop() { func (c *Client) Stop() {
log.Infof("Stopping client") log.Infof("Stopping client")
// lock // lock
c.mtx.Lock() c.mtx.Lock()
if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
close(c.quit) close(c.quit)
// stop each peer. // stop each peer.
for peerValue := range c.peers.Values() { for peerValue := range c.peers.Values() {
peer := peerValue.(*Peer) peer := peerValue.(*Peer)
peer.Stop() peer.Stop()
} }
// empty tree. // empty tree.
c.peers = merkle.NewIAVLTree(nil) c.peers = merkle.NewIAVLTree(nil)
} }
c.mtx.Unlock() c.mtx.Unlock()
// unlock // unlock
} }
func (c *Client) AddPeerWithConnection(conn *Connection, outgoing bool) (*Peer, error) { func (c *Client) AddPeerWithConnection(conn *Connection, outgoing bool) (*Peer, error) {
if atomic.LoadUint32(&c.stopped) == 1 { return nil, CLIENT_STOPPED_ERROR } if atomic.LoadUint32(&c.stopped) == 1 {
return nil, CLIENT_STOPPED_ERROR
}
log.Infof("Adding peer with connection: %v, outgoing: %v", conn, outgoing) log.Infof("Adding peer with connection: %v, outgoing: %v", conn, outgoing)
peer := c.makePeerFn(conn) peer := c.makePeerFn(conn)
peer.outgoing = outgoing peer.outgoing = outgoing
err := c.addPeer(peer) err := c.addPeer(peer)
if err != nil { return nil, err } if err != nil {
return nil, err
}
go peer.Start(c.recvQueues) go peer.Start(c.recvQueues)
return peer, nil return peer, nil
} }
func (c *Client) Broadcast(pkt Packet) { func (c *Client) Broadcast(pkt Packet) {
if atomic.LoadUint32(&c.stopped) == 1 { return } if atomic.LoadUint32(&c.stopped) == 1 {
return
}
log.Tracef("Broadcast on [%v] len: %v", pkt.Channel, len(pkt.Bytes)) log.Tracef("Broadcast on [%v] len: %v", pkt.Channel, len(pkt.Bytes))
for v := range c.Peers().Values() { for v := range c.Peers().Values() {
peer := v.(*Peer) peer := v.(*Peer)
success := peer.TrySend(pkt) success := peer.TrySend(pkt)
log.Tracef("Broadcast for peer %v success: %v", peer, success) log.Tracef("Broadcast for peer %v success: %v", peer, success)
if !success { if !success {
// TODO: notify the peer // TODO: notify the peer
} }
} }
} }
// blocks until a message is popped. // blocks until a message is popped.
func (c *Client) Receive(chName String) *InboundPacket { func (c *Client) Receive(chName String) *InboundPacket {
if atomic.LoadUint32(&c.stopped) == 1 { return nil } if atomic.LoadUint32(&c.stopped) == 1 {
return nil
}
log.Tracef("Receive on [%v]", chName) log.Tracef("Receive on [%v]", chName)
q := c.recvQueues[chName] q := c.recvQueues[chName]
if q == nil { Panicf("Expected recvQueues[%f], found none", chName) } if q == nil {
Panicf("Expected recvQueues[%f], found none", chName)
}
for { for {
select { select {
case <-c.quit: case <-c.quit:
return nil return nil
case inPacket := <-q: case inPacket := <-q:
return inPacket return inPacket
} }
} }
} }
func (c *Client) Peers() merkle.Tree { func (c *Client) Peers() merkle.Tree {
// lock & defer // lock & defer
c.mtx.Lock(); defer c.mtx.Unlock() c.mtx.Lock()
return c.peers.Copy() defer c.mtx.Unlock()
// unlock deferred return c.peers.Copy()
// unlock deferred
} }
func (c *Client) StopPeer(peer *Peer) { func (c *Client) StopPeer(peer *Peer) {
// lock // lock
c.mtx.Lock() c.mtx.Lock()
peerValue, _ := c.peers.Remove(peer.RemoteAddress()) peerValue, _ := c.peers.Remove(peer.RemoteAddress())
c.mtx.Unlock() c.mtx.Unlock()
// unlock // unlock
peer_ := peerValue.(*Peer) peer_ := peerValue.(*Peer)
if peer_ != nil { if peer_ != nil {
peer_.Stop() peer_.Stop()
} }
} }
func (c *Client) addPeer(peer *Peer) error { func (c *Client) addPeer(peer *Peer) error {
addr := peer.RemoteAddress() addr := peer.RemoteAddress()
// lock & defer // lock & defer
c.mtx.Lock(); defer c.mtx.Unlock() c.mtx.Lock()
if c.stopped == 1 { return CLIENT_STOPPED_ERROR } defer c.mtx.Unlock()
if !c.peers.Has(addr) { if c.stopped == 1 {
log.Tracef("Actually putting addr: %v, peer: %v", addr, peer) return CLIENT_STOPPED_ERROR
c.peers.Put(addr, peer) }
return nil if !c.peers.Has(addr) {
} else { log.Tracef("Actually putting addr: %v, peer: %v", addr, peer)
// ignore duplicate peer for addr. c.peers.Put(addr, peer)
log.Infof("Ignoring duplicate peer for addr %v", addr) return nil
return CLIENT_DUPLICATE_PEER_ERROR } else {
} // ignore duplicate peer for addr.
// unlock deferred log.Infof("Ignoring duplicate peer for addr %v", addr)
return CLIENT_DUPLICATE_PEER_ERROR
}
// unlock deferred
} }

View File

@ -1,106 +1,105 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"testing" "testing"
"time" "time"
) )
// convenience method for creating two clients connected to each other. // convenience method for creating two clients connected to each other.
func makeClientPair(t *testing.T, bufferSize int, channels []string) (*Client, *Client) { func makeClientPair(t *testing.T, bufferSize int, channels []string) (*Client, *Client) {
peerMaker := func(conn *Connection) *Peer { peerMaker := func(conn *Connection) *Peer {
p := NewPeer(conn) p := NewPeer(conn)
p.channels = map[String]*Channel{} p.channels = map[String]*Channel{}
for chName := range channels { for chName := range channels {
p.channels[String(chName)] = NewChannel(String(chName), bufferSize) p.channels[String(chName)] = NewChannel(String(chName), bufferSize)
} }
return p return p
} }
// Create two clients that will be interconnected. // Create two clients that will be interconnected.
c1 := NewClient(peerMaker) c1 := NewClient(peerMaker)
c2 := NewClient(peerMaker) c2 := NewClient(peerMaker)
// Create a server for the listening client. // Create a server for the listening client.
s1 := NewServer("tcp", ":8001", c1) s1 := NewServer("tcp", ":8001", c1)
// Dial the server & add the connection to c2. // Dial the server & add the connection to c2.
s1laddr := s1.LocalAddress() s1laddr := s1.LocalAddress()
conn, err := s1laddr.Dial() conn, err := s1laddr.Dial()
if err != nil { if err != nil {
t.Fatalf("Could not connect to server address %v", s1laddr) t.Fatalf("Could not connect to server address %v", s1laddr)
} else { } else {
t.Logf("Created a connection to local server address %v", s1laddr) t.Logf("Created a connection to local server address %v", s1laddr)
} }
c2.AddPeerWithConnection(conn, true) c2.AddPeerWithConnection(conn, true)
// Wait for things to happen, peers to get added... // Wait for things to happen, peers to get added...
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
return c1, c2 return c1, c2
} }
func TestClients(t *testing.T) { func TestClients(t *testing.T) {
c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"}) c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"})
// Lets send a message from c1 to c2. // Lets send a message from c1 to c2.
if c1.Peers().Size() != 1 { if c1.Peers().Size() != 1 {
t.Errorf("Expected exactly 1 peer in c1, got %v", c1.Peers().Size()) t.Errorf("Expected exactly 1 peer in c1, got %v", c1.Peers().Size())
} }
if c2.Peers().Size() != 1 { if c2.Peers().Size() != 1 {
t.Errorf("Expected exactly 1 peer in c2, got %v", c2.Peers().Size()) t.Errorf("Expected exactly 1 peer in c2, got %v", c2.Peers().Size())
} }
// Broadcast a message on ch1 // Broadcast a message on ch1
c1.Broadcast(NewPacket("ch1", ByteSlice("channel one"))) c1.Broadcast(NewPacket("ch1", ByteSlice("channel one")))
// Broadcast a message on ch2 // Broadcast a message on ch2
c1.Broadcast(NewPacket("ch2", ByteSlice("channel two"))) c1.Broadcast(NewPacket("ch2", ByteSlice("channel two")))
// Broadcast a message on ch3 // Broadcast a message on ch3
c1.Broadcast(NewPacket("ch3", ByteSlice("channel three"))) c1.Broadcast(NewPacket("ch3", ByteSlice("channel three")))
// Wait for things to settle... // Wait for things to settle...
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Receive message from channel 2 and check // Receive message from channel 2 and check
inMsg := c2.Receive("ch2") inMsg := c2.Receive("ch2")
if string(inMsg.Bytes) != "channel two" { if string(inMsg.Bytes) != "channel two" {
t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes)) t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes))
} }
// Receive message from channel 1 and check // Receive message from channel 1 and check
inMsg = c2.Receive("ch1") inMsg = c2.Receive("ch1")
if string(inMsg.Bytes) != "channel one" { if string(inMsg.Bytes) != "channel one" {
t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes)) t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes))
} }
s1.Stop() s1.Stop()
c2.Stop() c2.Stop()
} }
func BenchmarkClients(b *testing.B) { func BenchmarkClients(b *testing.B) {
b.StopTimer() b.StopTimer()
// TODO: benchmark the random functions, which is faster? // TODO: benchmark the random functions, which is faster?
c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"}) c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"})
// Create a sink on either channel to just pop off messages. // Create a sink on either channel to just pop off messages.
// TODO: ensure that when clients stop, this goroutine stops. // TODO: ensure that when clients stop, this goroutine stops.
func recvHandler(c *Client) { recvHandler := func(c *Client) {
} }
go recvHandler(c1) go recvHandler(c1)
go recvHandler(c2) go recvHandler(c2)
b.StartTimer() b.StartTimer()
// Send random message from one channel to another // Send random message from one channel to another
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
} }
} }

View File

@ -1,191 +1,192 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/common" "fmt"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"sync/atomic" . "github.com/tendermint/tendermint/common"
"net" "net"
"time" "sync/atomic"
"fmt" "time"
) )
const ( const (
OUT_QUEUE_SIZE = 50 OUT_QUEUE_SIZE = 50
IDLE_TIMEOUT_MINUTES = 5 IDLE_TIMEOUT_MINUTES = 5
PING_TIMEOUT_MINUTES = 2 PING_TIMEOUT_MINUTES = 2
) )
/* Connnection */ /* Connnection */
type Connection struct { type Connection struct {
ioStats IOStats ioStats IOStats
sendQueue chan Packet // never closes sendQueue chan Packet // never closes
conn net.Conn conn net.Conn
quit chan struct{} quit chan struct{}
stopped uint32 stopped uint32
pingDebouncer *Debouncer pingDebouncer *Debouncer
pong chan struct{} pong chan struct{}
} }
var ( var (
PACKET_TYPE_PING = UInt8(0x00) PACKET_TYPE_PING = UInt8(0x00)
PACKET_TYPE_PONG = UInt8(0x01) PACKET_TYPE_PONG = UInt8(0x01)
PACKET_TYPE_MSG = UInt8(0x10) PACKET_TYPE_MSG = UInt8(0x10)
) )
func NewConnection(conn net.Conn) *Connection { func NewConnection(conn net.Conn) *Connection {
return &Connection{ return &Connection{
sendQueue: make(chan Packet, OUT_QUEUE_SIZE), sendQueue: make(chan Packet, OUT_QUEUE_SIZE),
conn: conn, conn: conn,
quit: make(chan struct{}), quit: make(chan struct{}),
pingDebouncer: NewDebouncer(PING_TIMEOUT_MINUTES * time.Minute), pingDebouncer: NewDebouncer(PING_TIMEOUT_MINUTES * time.Minute),
pong: make(chan struct{}), pong: make(chan struct{}),
} }
} }
// returns true if successfully queued, // returns true if successfully queued,
// returns false if connection was closed. // returns false if connection was closed.
// blocks. // blocks.
func (c *Connection) Send(pkt Packet) bool { func (c *Connection) Send(pkt Packet) bool {
select { select {
case c.sendQueue <- pkt: case c.sendQueue <- pkt:
return true return true
case <-c.quit: case <-c.quit:
return false return false
} }
} }
func (c *Connection) Start(channels map[String]*Channel) { func (c *Connection) Start(channels map[String]*Channel) {
log.Debugf("Starting %v", c) log.Debugf("Starting %v", c)
go c.sendHandler() go c.sendHandler()
go c.recvHandler(channels) go c.recvHandler(channels)
} }
func (c *Connection) Stop() { func (c *Connection) Stop() {
if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
log.Debugf("Stopping %v", c) log.Debugf("Stopping %v", c)
close(c.quit) close(c.quit)
c.conn.Close() c.conn.Close()
c.pingDebouncer.Stop() c.pingDebouncer.Stop()
// We can't close pong safely here because // We can't close pong safely here because
// recvHandler may write to it after we've stopped. // recvHandler may write to it after we've stopped.
// Though it doesn't need to get closed at all, // Though it doesn't need to get closed at all,
// we close it @ recvHandler. // we close it @ recvHandler.
// close(c.pong) // close(c.pong)
} }
} }
func (c *Connection) LocalAddress() *NetAddress { func (c *Connection) LocalAddress() *NetAddress {
return NewNetAddress(c.conn.LocalAddr()) return NewNetAddress(c.conn.LocalAddr())
} }
func (c *Connection) RemoteAddress() *NetAddress { func (c *Connection) RemoteAddress() *NetAddress {
return NewNetAddress(c.conn.RemoteAddr()) return NewNetAddress(c.conn.RemoteAddr())
} }
func (c *Connection) String() string { func (c *Connection) String() string {
return fmt.Sprintf("Connection{%v}", c.conn.RemoteAddr()) return fmt.Sprintf("Connection{%v}", c.conn.RemoteAddr())
} }
func (c *Connection) flush() { func (c *Connection) flush() {
// TODO flush? (turn off nagel, turn back on, etc) // TODO flush? (turn off nagel, turn back on, etc)
} }
func (c *Connection) sendHandler() { func (c *Connection) sendHandler() {
log.Tracef("%v sendHandler", c) log.Tracef("%v sendHandler", c)
// TODO: catch panics & stop connection. // TODO: catch panics & stop connection.
FOR_LOOP: FOR_LOOP:
for { for {
var err error var err error
select { select {
case <-c.pingDebouncer.Ch: case <-c.pingDebouncer.Ch:
_, err = PACKET_TYPE_PING.WriteTo(c.conn) _, err = PACKET_TYPE_PING.WriteTo(c.conn)
case sendPkt := <-c.sendQueue: case sendPkt := <-c.sendQueue:
log.Tracef("Found pkt from sendQueue. Writing pkt to underlying connection") log.Tracef("Found pkt from sendQueue. Writing pkt to underlying connection")
_, err = PACKET_TYPE_MSG.WriteTo(c.conn) _, err = PACKET_TYPE_MSG.WriteTo(c.conn)
if err != nil { break } if err != nil {
_, err = sendPkt.WriteTo(c.conn) break
case <-c.pong: }
_, err = PACKET_TYPE_PONG.WriteTo(c.conn) _, err = sendPkt.WriteTo(c.conn)
case <-c.quit: case <-c.pong:
break FOR_LOOP _, err = PACKET_TYPE_PONG.WriteTo(c.conn)
} case <-c.quit:
break FOR_LOOP
}
if err != nil { if err != nil {
log.Infof("%v failed @ sendHandler:\n%v", c, err) log.Infof("%v failed @ sendHandler:\n%v", c, err)
c.Stop() c.Stop()
break FOR_LOOP break FOR_LOOP
} }
c.flush() c.flush()
} }
log.Tracef("%v sendHandler done", c) log.Tracef("%v sendHandler done", c)
// cleanup // cleanup
} }
func (c *Connection) recvHandler(channels map[String]*Channel) { func (c *Connection) recvHandler(channels map[String]*Channel) {
log.Tracef("%v recvHandler with %v channels", c, len(channels)) log.Tracef("%v recvHandler with %v channels", c, len(channels))
// TODO: catch panics & stop connection. // TODO: catch panics & stop connection.
FOR_LOOP: FOR_LOOP:
for { for {
pktType, err := ReadUInt8Safe(c.conn) pktType, err := ReadUInt8Safe(c.conn)
if err != nil { if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 { if atomic.LoadUint32(&c.stopped) != 1 {
log.Infof("%v failed @ recvHandler", c) log.Infof("%v failed @ recvHandler", c)
c.Stop() c.Stop()
} }
break FOR_LOOP break FOR_LOOP
} else { } else {
log.Tracef("Found pktType %v", pktType) log.Tracef("Found pktType %v", pktType)
} }
switch pktType { switch pktType {
case PACKET_TYPE_PING: case PACKET_TYPE_PING:
c.pong <- struct{}{} c.pong <- struct{}{}
case PACKET_TYPE_PONG: case PACKET_TYPE_PONG:
// do nothing // do nothing
case PACKET_TYPE_MSG: case PACKET_TYPE_MSG:
pkt, err := ReadPacketSafe(c.conn) pkt, err := ReadPacketSafe(c.conn)
if err != nil { if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 { if atomic.LoadUint32(&c.stopped) != 1 {
log.Infof("%v failed @ recvHandler", c) log.Infof("%v failed @ recvHandler", c)
c.Stop() c.Stop()
} }
break FOR_LOOP break FOR_LOOP
} }
channel := channels[pkt.Channel] channel := channels[pkt.Channel]
if channel == nil { if channel == nil {
Panicf("Unknown channel %v", pkt.Channel) Panicf("Unknown channel %v", pkt.Channel)
} }
channel.recvQueue <- pkt channel.recvQueue <- pkt
default: default:
Panicf("Unknown message type %v", pktType) Panicf("Unknown message type %v", pktType)
} }
c.pingDebouncer.Reset() c.pingDebouncer.Reset()
} }
log.Tracef("%v recvHandler done", c) log.Tracef("%v recvHandler done", c)
// cleanup // cleanup
close(c.pong) close(c.pong)
for _ = range c.pong { for _ = range c.pong {
// drain // drain
} }
} }
/* IOStats */ /* IOStats */
type IOStats struct { type IOStats struct {
TimeConnected Time TimeConnected Time
LastSent Time LastSent Time
LastRecv Time LastRecv Time
BytesRecv UInt64 BytesRecv UInt64
BytesSent UInt64 BytesSent UInt64
PktsRecv UInt64 PktsRecv UInt64
PktsSent UInt64 PktsSent UInt64
} }

View File

@ -1,104 +1,104 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"time" "io"
"io" "time"
) )
/* /*
KnownAddress KnownAddress
tracks information about a known network address that is used tracks information about a known network address that is used
to determine how viable an address is. to determine how viable an address is.
*/ */
type KnownAddress struct { type KnownAddress struct {
Addr *NetAddress Addr *NetAddress
Src *NetAddress Src *NetAddress
Attempts UInt32 Attempts UInt32
LastAttempt Time LastAttempt Time
LastSuccess Time LastSuccess Time
NewRefs UInt16 NewRefs UInt16
OldBucket Int16 // TODO init to -1 OldBucket Int16 // TODO init to -1
} }
func NewKnownAddress(addr *NetAddress, src *NetAddress) *KnownAddress { func NewKnownAddress(addr *NetAddress, src *NetAddress) *KnownAddress {
return &KnownAddress{ return &KnownAddress{
Addr: addr, Addr: addr,
Src: src, Src: src,
OldBucket: -1, OldBucket: -1,
LastAttempt: Time{time.Now()}, LastAttempt: Time{time.Now()},
Attempts: 0, Attempts: 0,
} }
} }
func ReadKnownAddress(r io.Reader) *KnownAddress { func ReadKnownAddress(r io.Reader) *KnownAddress {
return &KnownAddress{ return &KnownAddress{
Addr: ReadNetAddress(r), Addr: ReadNetAddress(r),
Src: ReadNetAddress(r), Src: ReadNetAddress(r),
Attempts: ReadUInt32(r), Attempts: ReadUInt32(r),
LastAttempt: ReadTime(r), LastAttempt: ReadTime(r),
LastSuccess: ReadTime(r), LastSuccess: ReadTime(r),
NewRefs: ReadUInt16(r), NewRefs: ReadUInt16(r),
OldBucket: ReadInt16(r), OldBucket: ReadInt16(r),
} }
} }
func (ka *KnownAddress) WriteTo(w io.Writer) (n int64, err error) { func (ka *KnownAddress) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(ka.Addr, w, n, err) n, err = WriteOnto(ka.Addr, w, n, err)
n, err = WriteOnto(ka.Src, w, n, err) n, err = WriteOnto(ka.Src, w, n, err)
n, err = WriteOnto(ka.Attempts, w, n, err) n, err = WriteOnto(ka.Attempts, w, n, err)
n, err = WriteOnto(ka.LastAttempt, w, n, err) n, err = WriteOnto(ka.LastAttempt, w, n, err)
n, err = WriteOnto(ka.LastSuccess, w, n, err) n, err = WriteOnto(ka.LastSuccess, w, n, err)
n, err = WriteOnto(ka.NewRefs, w, n, err) n, err = WriteOnto(ka.NewRefs, w, n, err)
n, err = WriteOnto(ka.OldBucket, w, n, err) n, err = WriteOnto(ka.OldBucket, w, n, err)
return return
} }
func (ka *KnownAddress) MarkAttempt(success bool) { func (ka *KnownAddress) MarkAttempt(success bool) {
now := Time{time.Now()} now := Time{time.Now()}
ka.LastAttempt = now ka.LastAttempt = now
if success { if success {
ka.LastSuccess = now ka.LastSuccess = now
ka.Attempts = 0 ka.Attempts = 0
} else { } else {
ka.Attempts += 1 ka.Attempts += 1
} }
} }
/* /*
An address is bad if the address in question has not been tried in the last An address is bad if the address in question has not been tried in the last
minute and meets one of the following criteria: minute and meets one of the following criteria:
1) It claims to be from the future 1) It claims to be from the future
2) It hasn't been seen in over a month 2) It hasn't been seen in over a month
3) It has failed at least three times and never succeeded 3) It has failed at least three times and never succeeded
4) It has failed ten times in the last week 4) It has failed ten times in the last week
All addresses that meet these criteria are assumed to be worthless and not All addresses that meet these criteria are assumed to be worthless and not
worth keeping hold of. worth keeping hold of.
*/ */
func (ka *KnownAddress) Bad() bool { func (ka *KnownAddress) Bad() bool {
// Has been attempted in the last minute --> good // Has been attempted in the last minute --> good
if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) { if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) {
return false return false
} }
// Over a month old? // Over a month old?
if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) { if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) {
return true return true
} }
// Never succeeded? // Never succeeded?
if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries { if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries {
return true return true
} }
// Hasn't succeeded in too long? // Hasn't succeeded in too long?
if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) && if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) &&
ka.Attempts >= maxFailures { ka.Attempts >= maxFailures {
return true return true
} }
return false return false
} }

View File

@ -1,127 +1,145 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/common"
"sync/atomic" "net"
"net" "sync/atomic"
) )
const ( const (
DEFAULT_PORT = 8001 DEFAULT_PORT = 8001
) )
/* Listener */ /* Listener */
type Listener interface { type Listener interface {
Connections() <-chan *Connection Connections() <-chan *Connection
LocalAddress() *NetAddress LocalAddress() *NetAddress
Stop() Stop()
} }
/* DefaultListener */ /* DefaultListener */
type DefaultListener struct { type DefaultListener struct {
listener net.Listener listener net.Listener
connections chan *Connection connections chan *Connection
stopped uint32 stopped uint32
} }
const ( const (
DEFAULT_BUFFERED_CONNECTIONS = 10 DEFAULT_BUFFERED_CONNECTIONS = 10
) )
func NewDefaultListener(protocol string, listenAddr string) Listener { func NewDefaultListener(protocol string, listenAddr string) Listener {
listener, err := net.Listen(protocol, listenAddr) listener, err := net.Listen(protocol, listenAddr)
if err != nil { panic(err) } if err != nil {
panic(err)
}
dl := &DefaultListener{ dl := &DefaultListener{
listener: listener, listener: listener,
connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS), connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS),
} }
go dl.listenHandler() go dl.listenHandler()
return dl return dl
} }
func (l *DefaultListener) listenHandler() { func (l *DefaultListener) listenHandler() {
for { for {
conn, err := l.listener.Accept() conn, err := l.listener.Accept()
if atomic.LoadUint32(&l.stopped) == 1 { return } if atomic.LoadUint32(&l.stopped) == 1 {
return
}
// listener wasn't stopped, // listener wasn't stopped,
// yet we encountered an error. // yet we encountered an error.
if err != nil { panic(err) } if err != nil {
panic(err)
}
c := NewConnection(conn) c := NewConnection(conn)
l.connections <- c l.connections <- c
} }
// cleanup // cleanup
close(l.connections) close(l.connections)
for _ = range l.connections { for _ = range l.connections {
// drain // drain
} }
} }
func (l *DefaultListener) Connections() <-chan *Connection { func (l *DefaultListener) Connections() <-chan *Connection {
return l.connections return l.connections
} }
func (l *DefaultListener) LocalAddress() *NetAddress { func (l *DefaultListener) LocalAddress() *NetAddress {
return GetLocalAddress() return GetLocalAddress()
} }
func (l *DefaultListener) Stop() { func (l *DefaultListener) Stop() {
if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) { if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) {
l.listener.Close() l.listener.Close()
} }
} }
/* local address helpers */ /* local address helpers */
func GetLocalAddress() *NetAddress { func GetLocalAddress() *NetAddress {
laddr := GetUPNPLocalAddress() laddr := GetUPNPLocalAddress()
if laddr != nil { return laddr } if laddr != nil {
return laddr
}
laddr = GetDefaultLocalAddress() laddr = GetDefaultLocalAddress()
if laddr != nil { return laddr } if laddr != nil {
return laddr
}
panic("Could not determine local address") panic("Could not determine local address")
} }
// UPNP external address discovery & port mapping // UPNP external address discovery & port mapping
// TODO: more flexible internal & external ports // TODO: more flexible internal & external ports
func GetUPNPLocalAddress() *NetAddress { func GetUPNPLocalAddress() *NetAddress {
nat, err := Discover() nat, err := Discover()
if err != nil { return nil } if err != nil {
return nil
}
ext, err := nat.GetExternalAddress() ext, err := nat.GetExternalAddress()
if err != nil { return nil } if err != nil {
return nil
}
_, err = nat.AddPortMapping("tcp", DEFAULT_PORT, DEFAULT_PORT, "tendermint", 0) _, err = nat.AddPortMapping("tcp", DEFAULT_PORT, DEFAULT_PORT, "tendermint", 0)
if err != nil { return nil } if err != nil {
return nil
}
return NewNetAddressIPPort(ext, DEFAULT_PORT) return NewNetAddressIPPort(ext, DEFAULT_PORT)
} }
// Naive local IPv4 interface address detection // Naive local IPv4 interface address detection
// TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh // TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh
func GetDefaultLocalAddress() *NetAddress { func GetDefaultLocalAddress() *NetAddress {
addrs, err := net.InterfaceAddrs() addrs, err := net.InterfaceAddrs()
if err != nil { Panicf("Unexpected error fetching interface addresses: %v", err) } if err != nil {
Panicf("Unexpected error fetching interface addresses: %v", err)
}
for _, a := range addrs { for _, a := range addrs {
ipnet, ok := a.(*net.IPNet) ipnet, ok := a.(*net.IPNet)
if !ok { continue } if !ok {
v4 := ipnet.IP.To4() continue
if v4 == nil || v4[0] == 127 { continue } // loopback }
return NewNetAddressIPPort(ipnet.IP, DEFAULT_PORT) v4 := ipnet.IP.To4()
} if v4 == nil || v4[0] == 127 {
return nil continue
} // loopback
return NewNetAddressIPPort(ipnet.IP, DEFAULT_PORT)
}
return nil
} }

View File

@ -1,14 +1,14 @@
package peer package peer
import ( import (
"github.com/cihub/seelog" "github.com/cihub/seelog"
) )
var log seelog.LoggerInterface var log seelog.LoggerInterface
func init() { func init() {
// TODO: replace with configuration file in the ~/.tendermint directory. // TODO: replace with configuration file in the ~/.tendermint directory.
config := ` config := `
<seelog type="sync"> <seelog type="sync">
<outputs formatid="colored"> <outputs formatid="colored">
<console/> <console/>
@ -19,7 +19,9 @@ func init() {
</formats> </formats>
</seelog>` </seelog>`
var err error var err error
log, err = seelog.LoggerFromConfigAsBytes([]byte(config)) log, err = seelog.LoggerFromConfigAsBytes([]byte(config))
if err != nil { panic(err) } if err != nil {
panic(err)
}
} }

View File

@ -1,59 +1,61 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/binary"
"io" "io"
) )
/* Packet */ /* Packet */
type Packet struct { type Packet struct {
Channel String Channel String
Bytes ByteSlice Bytes ByteSlice
// Hash // Hash
} }
func NewPacket(chName String, bytes ByteSlice) Packet { func NewPacket(chName String, bytes ByteSlice) Packet {
return Packet{ return Packet{
Channel: chName, Channel: chName,
Bytes: bytes, Bytes: bytes,
} }
} }
func (p Packet) WriteTo(w io.Writer) (n int64, err error) { func (p Packet) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(&p.Channel, w, n, err) n, err = WriteOnto(&p.Channel, w, n, err)
n, err = WriteOnto(&p.Bytes, w, n, err) n, err = WriteOnto(&p.Bytes, w, n, err)
return return
} }
func ReadPacketSafe(r io.Reader) (pkt Packet, err error) { func ReadPacketSafe(r io.Reader) (pkt Packet, err error) {
chName, err := ReadStringSafe(r) chName, err := ReadStringSafe(r)
if err != nil { return } if err != nil {
// TODO: packet length sanity check. return
bytes, err := ReadByteSliceSafe(r) }
if err != nil { return } // TODO: packet length sanity check.
return NewPacket(chName, bytes), nil bytes, err := ReadByteSliceSafe(r)
if err != nil {
return
}
return NewPacket(chName, bytes), nil
} }
/* InboundPacket */ /* InboundPacket */
type InboundPacket struct { type InboundPacket struct {
Peer *Peer Peer *Peer
Channel *Channel Channel *Channel
Time Time Time Time
Packet Packet
} }
/* NewFilterMsg */ /* NewFilterMsg */
type NewFilterMsg struct { type NewFilterMsg struct {
ChName String ChName String
Filter interface{} // todo Filter interface{} // todo
} }
func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) { func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) {
panic("TODO: implement") panic("TODO: implement")
return 0, nil // TODO return 0, nil // TODO
} }

View File

@ -5,150 +5,158 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common"
"io" "io"
"net" "net"
"strconv" "strconv"
) )
/* NetAddress */ /* NetAddress */
type NetAddress struct { type NetAddress struct {
IP net.IP IP net.IP
Port UInt16 Port UInt16
} }
// TODO: socks proxies? // TODO: socks proxies?
func NewNetAddress(addr net.Addr) *NetAddress { func NewNetAddress(addr net.Addr) *NetAddress {
tcpAddr, ok := addr.(*net.TCPAddr) tcpAddr, ok := addr.(*net.TCPAddr)
if !ok { Panicf("Only TCPAddrs are supported. Got: %v", addr) } if !ok {
ip := tcpAddr.IP Panicf("Only TCPAddrs are supported. Got: %v", addr)
port := UInt16(tcpAddr.Port) }
return NewNetAddressIPPort(ip, port) ip := tcpAddr.IP
port := UInt16(tcpAddr.Port)
return NewNetAddressIPPort(ip, port)
} }
func NewNetAddressString(addr string) *NetAddress { func NewNetAddressString(addr string) *NetAddress {
host, portStr, err := net.SplitHostPort(addr) host, portStr, err := net.SplitHostPort(addr)
if err != nil { panic(err) } if err != nil {
ip := net.ParseIP(host) panic(err)
port, err := strconv.ParseUint(portStr, 10, 16) }
if err != nil { panic(err) } ip := net.ParseIP(host)
na := NewNetAddressIPPort(ip, UInt16(port)) port, err := strconv.ParseUint(portStr, 10, 16)
return na if err != nil {
panic(err)
}
na := NewNetAddressIPPort(ip, UInt16(port))
return na
} }
func NewNetAddressIPPort(ip net.IP, port UInt16) *NetAddress { func NewNetAddressIPPort(ip net.IP, port UInt16) *NetAddress {
na := NetAddress{ na := NetAddress{
IP: ip, IP: ip,
Port: port, Port: port,
} }
return &na return &na
} }
func ReadNetAddress(r io.Reader) *NetAddress { func ReadNetAddress(r io.Reader) *NetAddress {
return &NetAddress{ return &NetAddress{
IP: net.IP(ReadByteSlice(r)), IP: net.IP(ReadByteSlice(r)),
Port: ReadUInt16(r), Port: ReadUInt16(r),
} }
} }
func (na *NetAddress) WriteTo(w io.Writer) (n int64, err error) { func (na *NetAddress) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(ByteSlice(na.IP.To16()), w, n, err) n, err = WriteOnto(ByteSlice(na.IP.To16()), w, n, err)
n, err = WriteOnto(na.Port, w, n, err) n, err = WriteOnto(na.Port, w, n, err)
return return
} }
func (na *NetAddress) Equals(other Binary) bool { func (na *NetAddress) Equals(other Binary) bool {
if o, ok := other.(*NetAddress); ok { if o, ok := other.(*NetAddress); ok {
return na.String() == o.String() return na.String() == o.String()
} else { } else {
return false return false
} }
} }
func (na *NetAddress) Less(other Binary) bool { func (na *NetAddress) Less(other Binary) bool {
if o, ok := other.(*NetAddress); ok { if o, ok := other.(*NetAddress); ok {
return na.String() < o.String() return na.String() < o.String()
} else { } else {
panic("Cannot compare unequal types") panic("Cannot compare unequal types")
} }
} }
func (na *NetAddress) String() string { func (na *NetAddress) String() string {
port := strconv.FormatUint(uint64(na.Port), 10) port := strconv.FormatUint(uint64(na.Port), 10)
addr := net.JoinHostPort(na.IP.String(), port) addr := net.JoinHostPort(na.IP.String(), port)
return addr return addr
} }
func (na *NetAddress) Dial() (*Connection, error) { func (na *NetAddress) Dial() (*Connection, error) {
conn, err := net.Dial("tcp", na.String()) conn, err := net.Dial("tcp", na.String())
if err != nil { return nil, err } if err != nil {
return NewConnection(conn), nil return nil, err
}
return NewConnection(conn), nil
} }
func (na *NetAddress) Routable() bool { func (na *NetAddress) Routable() bool {
// TODO(oga) bitcoind doesn't include RFC3849 here, but should we? // TODO(oga) bitcoind doesn't include RFC3849 here, but should we?
return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() || return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() ||
na.RFC4193() || na.RFC4843() || na.Local()) na.RFC4193() || na.RFC4843() || na.Local())
} }
// For IPv4 these are either a 0 or all bits set address. For IPv6 a zero // For IPv4 these are either a 0 or all bits set address. For IPv6 a zero
// address or one that matches the RFC3849 documentation address format. // address or one that matches the RFC3849 documentation address format.
func (na *NetAddress) Valid() bool { func (na *NetAddress) Valid() bool {
return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() || return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() ||
na.IP.Equal(net.IPv4bcast)) na.IP.Equal(net.IPv4bcast))
} }
func (na *NetAddress) Local() bool { func (na *NetAddress) Local() bool {
return na.IP.IsLoopback() || zero4.Contains(na.IP) return na.IP.IsLoopback() || zero4.Contains(na.IP)
} }
func (na *NetAddress) ReachabilityTo(o *NetAddress) int { func (na *NetAddress) ReachabilityTo(o *NetAddress) int {
const ( const (
Unreachable = 0 Unreachable = 0
Default = iota Default = iota
Teredo Teredo
Ipv6_weak Ipv6_weak
Ipv4 Ipv4
Ipv6_strong Ipv6_strong
Private Private
) )
if !na.Routable() { if !na.Routable() {
return Unreachable return Unreachable
} else if na.RFC4380() { } else if na.RFC4380() {
if !o.Routable() { if !o.Routable() {
return Default return Default
} else if o.RFC4380() { } else if o.RFC4380() {
return Teredo return Teredo
} else if o.IP.To4() != nil { } else if o.IP.To4() != nil {
return Ipv4 return Ipv4
} else { // ipv6 } else { // ipv6
return Ipv6_weak return Ipv6_weak
} }
} else if na.IP.To4() != nil { } else if na.IP.To4() != nil {
if o.Routable() && o.IP.To4() != nil { if o.Routable() && o.IP.To4() != nil {
return Ipv4 return Ipv4
} }
return Default return Default
} else /* ipv6 */ { } else /* ipv6 */ {
var tunnelled bool var tunnelled bool
// Is our v6 is tunnelled? // Is our v6 is tunnelled?
if o.RFC3964() || o.RFC6052() || o.RFC6145() { if o.RFC3964() || o.RFC6052() || o.RFC6145() {
tunnelled = true tunnelled = true
} }
if !o.Routable() { if !o.Routable() {
return Default return Default
} else if o.RFC4380() { } else if o.RFC4380() {
return Teredo return Teredo
} else if o.IP.To4() != nil { } else if o.IP.To4() != nil {
return Ipv4 return Ipv4
} else if tunnelled { } else if tunnelled {
// only prioritise ipv6 if we aren't tunnelling it. // only prioritise ipv6 if we aren't tunnelling it.
return Ipv6_weak return Ipv6_weak
} }
return Ipv6_strong return Ipv6_strong
} }
} }
// RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12) // RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12)
@ -161,23 +169,25 @@ func (na *NetAddress) ReachabilityTo(o *NetAddress) int {
// RFC4862: IPv6 Autoconfig (FE80::/64) // RFC4862: IPv6 Autoconfig (FE80::/64)
// RFC6052: IPv6 well known prefix (64:FF9B::/96) // RFC6052: IPv6 well known prefix (64:FF9B::/96)
// RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96 // RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96
var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)} var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}
var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)} var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}
var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)} var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}
var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)} var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)}
var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)} var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}
var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)} var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)}
var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)} var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)}
var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)} var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)}
var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)} var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)}
var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)} var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)}
var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)} var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)}
var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)} var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)}
var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)} var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)}
func (na *NetAddress) RFC1918() bool { return rfc1918_10.Contains(na.IP) || func (na *NetAddress) RFC1918() bool {
rfc1918_192.Contains(na.IP) || return rfc1918_10.Contains(na.IP) ||
rfc1918_172.Contains(na.IP) } rfc1918_192.Contains(na.IP) ||
rfc1918_172.Contains(na.IP)
}
func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) } func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) }
func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) } func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) }
func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) } func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) }

View File

@ -1,172 +1,174 @@
package peer package peer
import ( import (
. "github.com/tendermint/tendermint/binary" "fmt"
"sync/atomic" . "github.com/tendermint/tendermint/binary"
"sync" "io"
"io" "sync"
"time" "sync/atomic"
"fmt" "time"
) )
/* Peer */ /* Peer */
type Peer struct { type Peer struct {
outgoing bool outgoing bool
conn *Connection conn *Connection
channels map[String]*Channel channels map[String]*Channel
mtx sync.Mutex mtx sync.Mutex
quit chan struct{} quit chan struct{}
stopped uint32 stopped uint32
} }
func NewPeer(conn *Connection) *Peer { func NewPeer(conn *Connection) *Peer {
return &Peer{ return &Peer{
conn: conn, conn: conn,
quit: make(chan struct{}), quit: make(chan struct{}),
stopped: 0, stopped: 0,
} }
} }
func (p *Peer) Start(peerRecvQueues map[String]chan *InboundPacket ) { func (p *Peer) Start(peerRecvQueues map[String]chan *InboundPacket) {
log.Debugf("Starting %v", p) log.Debugf("Starting %v", p)
p.conn.Start(p.channels) p.conn.Start(p.channels)
for chName, _ := range p.channels { for chName, _ := range p.channels {
go p.recvHandler(chName, peerRecvQueues[chName]) go p.recvHandler(chName, peerRecvQueues[chName])
go p.sendHandler(chName) go p.sendHandler(chName)
} }
} }
func (p *Peer) Stop() { func (p *Peer) Stop() {
// lock // lock
p.mtx.Lock() p.mtx.Lock()
if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) { if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) {
log.Debugf("Stopping %v", p) log.Debugf("Stopping %v", p)
close(p.quit) close(p.quit)
p.conn.Stop() p.conn.Stop()
} }
p.mtx.Unlock() p.mtx.Unlock()
// unlock // unlock
} }
func (p *Peer) LocalAddress() *NetAddress { func (p *Peer) LocalAddress() *NetAddress {
return p.conn.LocalAddress() return p.conn.LocalAddress()
} }
func (p *Peer) RemoteAddress() *NetAddress { func (p *Peer) RemoteAddress() *NetAddress {
return p.conn.RemoteAddress() return p.conn.RemoteAddress()
} }
func (p *Peer) Channel(chName String) *Channel { func (p *Peer) Channel(chName String) *Channel {
return p.channels[chName] return p.channels[chName]
} }
// If the channel's queue is full, just return false. // If the channel's queue is full, just return false.
// Later the sendHandler will send the pkt to the underlying connection. // Later the sendHandler will send the pkt to the underlying connection.
func (p *Peer) TrySend(pkt Packet) bool { func (p *Peer) TrySend(pkt Packet) bool {
channel := p.Channel(pkt.Channel) channel := p.Channel(pkt.Channel)
sendQueue := channel.SendQueue() sendQueue := channel.SendQueue()
// lock & defer // lock & defer
p.mtx.Lock(); defer p.mtx.Unlock() p.mtx.Lock()
if p.stopped == 1 { return false } defer p.mtx.Unlock()
select { if p.stopped == 1 {
case sendQueue <- pkt: return false
return true }
default: // buffer full select {
return false case sendQueue <- pkt:
} return true
// unlock deferred default: // buffer full
return false
}
// unlock deferred
} }
func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { func (p *Peer) WriteTo(w io.Writer) (n int64, err error) {
return p.RemoteAddress().WriteTo(w) return p.RemoteAddress().WriteTo(w)
} }
func (p *Peer) String() string { func (p *Peer) String() string {
return fmt.Sprintf("Peer{%v-%v,o:%v}", p.LocalAddress(), p.RemoteAddress(), p.outgoing) return fmt.Sprintf("Peer{%v-%v,o:%v}", p.LocalAddress(), p.RemoteAddress(), p.outgoing)
} }
func (p *Peer) recvHandler(chName String, inboundPacketQueue chan<- *InboundPacket) { func (p *Peer) recvHandler(chName String, inboundPacketQueue chan<- *InboundPacket) {
log.Tracef("%v recvHandler [%v]", p, chName) log.Tracef("%v recvHandler [%v]", p, chName)
channel := p.channels[chName] channel := p.channels[chName]
recvQueue := channel.RecvQueue() recvQueue := channel.RecvQueue()
FOR_LOOP: FOR_LOOP:
for { for {
select { select {
case <-p.quit: case <-p.quit:
break FOR_LOOP break FOR_LOOP
case pkt := <-recvQueue: case pkt := <-recvQueue:
// send to inboundPacketQueue // send to inboundPacketQueue
inboundPacket := &InboundPacket{ inboundPacket := &InboundPacket{
Peer: p, Peer: p,
Channel: channel, Channel: channel,
Time: Time{time.Now()}, Time: Time{time.Now()},
Packet: pkt, Packet: pkt,
} }
select { select {
case <-p.quit: case <-p.quit:
break FOR_LOOP break FOR_LOOP
case inboundPacketQueue <- inboundPacket: case inboundPacketQueue <- inboundPacket:
continue continue
} }
} }
} }
log.Tracef("%v recvHandler [%v] closed", p, chName) log.Tracef("%v recvHandler [%v] closed", p, chName)
// cleanup // cleanup
// (none) // (none)
} }
func (p *Peer) sendHandler(chName String) { func (p *Peer) sendHandler(chName String) {
log.Tracef("%v sendHandler [%v]", p, chName) log.Tracef("%v sendHandler [%v]", p, chName)
chSendQueue := p.channels[chName].sendQueue chSendQueue := p.channels[chName].sendQueue
FOR_LOOP: FOR_LOOP:
for { for {
select { select {
case <-p.quit: case <-p.quit:
break FOR_LOOP break FOR_LOOP
case pkt := <-chSendQueue: case pkt := <-chSendQueue:
log.Tracef("Sending packet to peer chSendQueue") log.Tracef("Sending packet to peer chSendQueue")
// blocks until the connection is Stop'd, // blocks until the connection is Stop'd,
// which happens when this peer is Stop'd. // which happens when this peer is Stop'd.
p.conn.Send(pkt) p.conn.Send(pkt)
} }
} }
log.Tracef("%v sendHandler [%v] closed", p, chName) log.Tracef("%v sendHandler [%v] closed", p, chName)
// cleanup // cleanup
// (none) // (none)
} }
/* Channel */ /* Channel */
type Channel struct { type Channel struct {
name String name String
recvQueue chan Packet recvQueue chan Packet
sendQueue chan Packet sendQueue chan Packet
//stats Stats //stats Stats
} }
func NewChannel(name String, bufferSize int) *Channel { func NewChannel(name String, bufferSize int) *Channel {
return &Channel{ return &Channel{
name: name, name: name,
recvQueue: make(chan Packet, bufferSize), recvQueue: make(chan Packet, bufferSize),
sendQueue: make(chan Packet, bufferSize), sendQueue: make(chan Packet, bufferSize),
} }
} }
func (c *Channel) Name() String { func (c *Channel) Name() String {
return c.name return c.name
} }
func (c *Channel) RecvQueue() <-chan Packet { func (c *Channel) RecvQueue() <-chan Packet {
return c.recvQueue return c.recvQueue
} }
func (c *Channel) SendQueue() chan<- Packet { func (c *Channel) SendQueue() chan<- Packet {
return c.sendQueue return c.sendQueue
} }

View File

@ -1,39 +1,38 @@
package peer package peer
import ( import ()
)
/* Server */ /* Server */
type Server struct { type Server struct {
listener Listener listener Listener
client *Client client *Client
} }
func NewServer(protocol string, laddr string, c *Client) *Server { func NewServer(protocol string, laddr string, c *Client) *Server {
l := NewDefaultListener(protocol, laddr) l := NewDefaultListener(protocol, laddr)
s := &Server{ s := &Server{
listener: l, listener: l,
client: c, client: c,
} }
go s.IncomingConnectionHandler() go s.IncomingConnectionHandler()
return s return s
} }
func (s *Server) LocalAddress() *NetAddress { func (s *Server) LocalAddress() *NetAddress {
return s.listener.LocalAddress() return s.listener.LocalAddress()
} }
// meant to run in a goroutine // meant to run in a goroutine
func (s *Server) IncomingConnectionHandler() { func (s *Server) IncomingConnectionHandler() {
for conn := range s.listener.Connections() { for conn := range s.listener.Connections() {
log.Infof("New connection found: %v", conn) log.Infof("New connection found: %v", conn)
s.client.AddPeerWithConnection(conn, false) s.client.AddPeerWithConnection(conn, false)
} }
} }
func (s *Server) Stop() { func (s *Server) Stop() {
log.Infof("Stopping server") log.Infof("Stopping server")
s.listener.Stop() s.listener.Stop()
s.client.Stop() s.client.Stop()
} }

View File

@ -7,370 +7,370 @@ package peer
// //
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
type upnpNAT struct { type upnpNAT struct {
serviceURL string serviceURL string
ourIP string ourIP string
urnDomain string urnDomain string
} }
// protocol is either "udp" or "tcp" // protocol is either "udp" or "tcp"
type NAT interface { type NAT interface {
GetExternalAddress() (addr net.IP, err error) GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error) DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
} }
func Discover() (nat NAT, err error) { func Discover() (nat NAT, err error) {
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil { if err != nil {
return return
} }
conn, err := net.ListenPacket("udp4", ":0") conn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
return return
} }
socket := conn.(*net.UDPConn) socket := conn.(*net.UDPConn)
defer socket.Close() defer socket.Close()
err = socket.SetDeadline(time.Now().Add(3 * time.Second)) err = socket.SetDeadline(time.Now().Add(3 * time.Second))
if err != nil { if err != nil {
return return
} }
st := "InternetGatewayDevice:1" st := "InternetGatewayDevice:1"
buf := bytes.NewBufferString( buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" + "M-SEARCH * HTTP/1.1\r\n" +
"HOST: 239.255.255.250:1900\r\n" + "HOST: 239.255.255.250:1900\r\n" +
"ST: ssdp:all\r\n" + "ST: ssdp:all\r\n" +
"MAN: \"ssdp:discover\"\r\n" + "MAN: \"ssdp:discover\"\r\n" +
"MX: 2\r\n\r\n") "MX: 2\r\n\r\n")
message := buf.Bytes() message := buf.Bytes()
answerBytes := make([]byte, 1024) answerBytes := make([]byte, 1024)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err = socket.WriteToUDP(message, ssdp) _, err = socket.WriteToUDP(message, ssdp)
if err != nil { if err != nil {
return return
} }
var n int var n int
n, _, err = socket.ReadFromUDP(answerBytes) n, _, err = socket.ReadFromUDP(answerBytes)
for { for {
n, _, err = socket.ReadFromUDP(answerBytes) n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil { if err != nil {
break break
} }
answer := string(answerBytes[0:n]) answer := string(answerBytes[0:n])
if strings.Index(answer, st) < 0 { if strings.Index(answer, st) < 0 {
continue continue
} }
// HTTP header field names are case-insensitive. // HTTP header field names are case-insensitive.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
locString := "\r\nlocation:" locString := "\r\nlocation:"
answer = strings.ToLower(answer) answer = strings.ToLower(answer)
locIndex := strings.Index(answer, locString) locIndex := strings.Index(answer, locString)
if locIndex < 0 { if locIndex < 0 {
continue continue
} }
loc := answer[locIndex+len(locString):] loc := answer[locIndex+len(locString):]
endIndex := strings.Index(loc, "\r\n") endIndex := strings.Index(loc, "\r\n")
if endIndex < 0 { if endIndex < 0 {
continue continue
} }
locURL := strings.TrimSpace(loc[0:endIndex]) locURL := strings.TrimSpace(loc[0:endIndex])
var serviceURL, urnDomain string var serviceURL, urnDomain string
serviceURL, urnDomain, err = getServiceURL(locURL) serviceURL, urnDomain, err = getServiceURL(locURL)
if err != nil { if err != nil {
return return
} }
var ourIP net.IP var ourIP net.IP
ourIP, err = localIPv4() ourIP, err = localIPv4()
if err != nil { if err != nil {
return return
} }
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain}
return return
} }
} }
err = errors.New("UPnP port discovery failed.") err = errors.New("UPnP port discovery failed.")
return return
} }
type Envelope struct { type Envelope struct {
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
Soap *SoapBody Soap *SoapBody
} }
type SoapBody struct { type SoapBody struct {
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
ExternalIP *ExternalIPAddressResponse ExternalIP *ExternalIPAddressResponse
} }
type ExternalIPAddressResponse struct { type ExternalIPAddressResponse struct {
XMLName xml.Name `xml:"GetExternalIPAddressResponse"` XMLName xml.Name `xml:"GetExternalIPAddressResponse"`
IPAddress string `xml:"NewExternalIPAddress"` IPAddress string `xml:"NewExternalIPAddress"`
} }
type ExternalIPAddress struct { type ExternalIPAddress struct {
XMLName xml.Name `xml:"NewExternalIPAddress"` XMLName xml.Name `xml:"NewExternalIPAddress"`
IP string IP string
} }
type Service struct { type Service struct {
ServiceType string `xml:"serviceType"` ServiceType string `xml:"serviceType"`
ControlURL string `xml:"controlURL"` ControlURL string `xml:"controlURL"`
} }
type DeviceList struct { type DeviceList struct {
Device []Device `xml:"device"` Device []Device `xml:"device"`
} }
type ServiceList struct { type ServiceList struct {
Service []Service `xml:"service"` Service []Service `xml:"service"`
} }
type Device struct { type Device struct {
XMLName xml.Name `xml:"device"` XMLName xml.Name `xml:"device"`
DeviceType string `xml:"deviceType"` DeviceType string `xml:"deviceType"`
DeviceList DeviceList `xml:"deviceList"` DeviceList DeviceList `xml:"deviceList"`
ServiceList ServiceList `xml:"serviceList"` ServiceList ServiceList `xml:"serviceList"`
} }
type Root struct { type Root struct {
Device Device Device Device
} }
func getChildDevice(d *Device, deviceType string) *Device { func getChildDevice(d *Device, deviceType string) *Device {
dl := d.DeviceList.Device dl := d.DeviceList.Device
for i := 0; i < len(dl); i++ { for i := 0; i < len(dl); i++ {
if strings.Index(dl[i].DeviceType, deviceType) >= 0 { if strings.Index(dl[i].DeviceType, deviceType) >= 0 {
return &dl[i] return &dl[i]
} }
} }
return nil return nil
} }
func getChildService(d *Device, serviceType string) *Service { func getChildService(d *Device, serviceType string) *Service {
sl := d.ServiceList.Service sl := d.ServiceList.Service
for i := 0; i < len(sl); i++ { for i := 0; i < len(sl); i++ {
if strings.Index(sl[i].ServiceType, serviceType) >= 0 { if strings.Index(sl[i].ServiceType, serviceType) >= 0 {
return &sl[i] return &sl[i]
} }
} }
return nil return nil
} }
func localIPv4() (net.IP, error) { func localIPv4() (net.IP, error) {
tt, err := net.Interfaces() tt, err := net.Interfaces()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, t := range tt { for _, t := range tt {
aa, err := t.Addrs() aa, err := t.Addrs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, a := range aa { for _, a := range aa {
ipnet, ok := a.(*net.IPNet) ipnet, ok := a.(*net.IPNet)
if !ok { if !ok {
continue continue
} }
v4 := ipnet.IP.To4() v4 := ipnet.IP.To4()
if v4 == nil || v4[0] == 127 { // loopback address if v4 == nil || v4[0] == 127 { // loopback address
continue continue
} }
return v4, nil return v4, nil
} }
} }
return nil, errors.New("cannot find local IP address") return nil, errors.New("cannot find local IP address")
} }
func getServiceURL(rootURL string) (url, urnDomain string, err error) { func getServiceURL(rootURL string) (url, urnDomain string, err error) {
r, err := http.Get(rootURL) r, err := http.Get(rootURL)
if err != nil { if err != nil {
return return
} }
defer r.Body.Close() defer r.Body.Close()
if r.StatusCode >= 400 { if r.StatusCode >= 400 {
err = errors.New(string(r.StatusCode)) err = errors.New(string(r.StatusCode))
return return
} }
var root Root var root Root
err = xml.NewDecoder(r.Body).Decode(&root) err = xml.NewDecoder(r.Body).Decode(&root)
if err != nil { if err != nil {
return return
} }
a := &root.Device a := &root.Device
if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 { if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 {
err = errors.New("No InternetGatewayDevice") err = errors.New("No InternetGatewayDevice")
return return
} }
b := getChildDevice(a, "WANDevice:1") b := getChildDevice(a, "WANDevice:1")
if b == nil { if b == nil {
err = errors.New("No WANDevice") err = errors.New("No WANDevice")
return return
} }
c := getChildDevice(b, "WANConnectionDevice:1") c := getChildDevice(b, "WANConnectionDevice:1")
if c == nil { if c == nil {
err = errors.New("No WANConnectionDevice") err = errors.New("No WANConnectionDevice")
return return
} }
d := getChildService(c, "WANIPConnection:1") d := getChildService(c, "WANIPConnection:1")
if d == nil { if d == nil {
// Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
// instead of under WanConnectionDevice // instead of under WanConnectionDevice
d = getChildService(b, "WANIPConnection:1") d = getChildService(b, "WANIPConnection:1")
if d == nil { if d == nil {
err = errors.New("No WANIPConnection") err = errors.New("No WANIPConnection")
return return
} }
} }
// Extract the domain name, which isn't always 'schemas-upnp-org' // Extract the domain name, which isn't always 'schemas-upnp-org'
urnDomain = strings.Split(d.ServiceType, ":")[1] urnDomain = strings.Split(d.ServiceType, ":")[1]
url = combineURL(rootURL, d.ControlURL) url = combineURL(rootURL, d.ControlURL)
return return
} }
func combineURL(rootURL, subURL string) string { func combineURL(rootURL, subURL string) string {
protocolEnd := "://" protocolEnd := "://"
protoEndIndex := strings.Index(rootURL, protocolEnd) protoEndIndex := strings.Index(rootURL, protocolEnd)
a := rootURL[protoEndIndex+len(protocolEnd):] a := rootURL[protoEndIndex+len(protocolEnd):]
rootIndex := strings.Index(a, "/") rootIndex := strings.Index(a, "/")
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
} }
func soapRequest(url, function, message, domain string) (r *http.Response, err error) { func soapRequest(url, function, message, domain string) (r *http.Response, err error) {
fullMessage := "<?xml version=\"1.0\" ?>" + fullMessage := "<?xml version=\"1.0\" ?>" +
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + "<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
"<s:Body>" + message + "</s:Body></s:Envelope>" "<s:Body>" + message + "</s:Body></s:Envelope>"
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
//req.Header.Set("Transfer-Encoding", "chunked") //req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
req.Header.Set("Connection", "Close") req.Header.Set("Connection", "Close")
req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache") req.Header.Set("Pragma", "no-cache")
// log.Stderr("soapRequest ", req) // log.Stderr("soapRequest ", req)
r, err = http.DefaultClient.Do(req) r, err = http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
/*if r.Body != nil { /*if r.Body != nil {
defer r.Body.Close() defer r.Body.Close()
}*/ }*/
if r.StatusCode >= 400 { if r.StatusCode >= 400 {
// log.Stderr(function, r.StatusCode) // log.Stderr(function, r.StatusCode)
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
r = nil r = nil
return return
} }
return return
} }
type statusInfo struct { type statusInfo struct {
externalIpAddress string externalIpAddress string
} }
func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) { func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) {
message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"</u:GetExternalIPAddress>" "</u:GetExternalIPAddress>"
var response *http.Response var response *http.Response
response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
if response != nil { if response != nil {
defer response.Body.Close() defer response.Body.Close()
} }
if err != nil { if err != nil {
return return
} }
var envelope Envelope var envelope Envelope
data, err := ioutil.ReadAll(response.Body) data, err := ioutil.ReadAll(response.Body)
reader := bytes.NewReader(data) reader := bytes.NewReader(data)
xml.NewDecoder(reader).Decode(&envelope) xml.NewDecoder(reader).Decode(&envelope)
info = statusInfo{envelope.Soap.ExternalIP.IPAddress} info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
if err != nil { if err != nil {
return return
} }
return return
} }
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getExternalIPAddress() info, err := n.getExternalIPAddress()
if err != nil { if err != nil {
return return
} }
addr = net.ParseIP(info.externalIpAddress) addr = net.ParseIP(info.externalIpAddress)
return return
} }
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation. // A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + "<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" "<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description + message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + "</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>" "</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
if response != nil { if response != nil {
defer response.Body.Close() defer response.Body.Close()
} }
if err != nil { if err != nil {
return return
} }
// TODO: check response to see if the port was forwarded // TODO: check response to see if the port was forwarded
// log.Println(message, response) // log.Println(message, response)
mappedExternalPort = externalPort mappedExternalPort = externalPort
_ = response _ = response
return return
} }
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + "<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>" "</u:DeletePortMapping>"
var response *http.Response var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
if response != nil { if response != nil {
defer response.Body.Close() defer response.Body.Close()
} }
if err != nil { if err != nil {
return return
} }
// TODO: check response to see if the port was deleted // TODO: check response to see if the port was deleted
// log.Println(message, response) // log.Println(message, response)
_ = response _ = response
return return
} }

View File

@ -1,8 +1,8 @@
package peer package peer
import ( import (
"testing" "testing"
"time" "time"
) )
/* /*
@ -11,38 +11,38 @@ TODO: set up or find a service to probe open ports.
*/ */
func TestUPNP(t *testing.T) { func TestUPNP(t *testing.T) {
t.Log("hello!") t.Log("hello!")
nat, err := Discover() nat, err := Discover()
if err != nil { if err != nil {
t.Fatalf("NAT upnp could not be discovered: %v", err) t.Fatalf("NAT upnp could not be discovered: %v", err)
} }
t.Log("ourIP: ", nat.(*upnpNAT).ourIP) t.Log("ourIP: ", nat.(*upnpNAT).ourIP)
ext, err := nat.GetExternalAddress() ext, err := nat.GetExternalAddress()
if err != nil { if err != nil {
t.Fatalf("External address error: %v", err) t.Fatalf("External address error: %v", err)
} }
t.Logf("External address: %v", ext) t.Logf("External address: %v", ext)
port, err := nat.AddPortMapping("tcp", 8001, 8001, "testing", 0) port, err := nat.AddPortMapping("tcp", 8001, 8001, "testing", 0)
if err != nil { if err != nil {
t.Fatalf("Port mapping error: %v", err) t.Fatalf("Port mapping error: %v", err)
} }
t.Logf("Port mapping mapped: %v", port) t.Logf("Port mapping mapped: %v", port)
// also run the listener, open for all remote addresses. // also run the listener, open for all remote addresses.
listener := NewDefaultListener("tcp", "0.0.0.0:8001") listener := NewDefaultListener("tcp", "0.0.0.0:8001")
// now sleep for 10 seconds // now sleep for 10 seconds
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
err = nat.DeletePortMapping("tcp", 8001, 8001) err = nat.DeletePortMapping("tcp", 8001, 8001)
if err != nil { if err != nil {
t.Fatalf("Port mapping delete error: %v", err) t.Fatalf("Port mapping delete error: %v", err)
} }
t.Logf("Port mapping deleted") t.Logf("Port mapping deleted")
listener.Stop() listener.Stop()
} }

View File

@ -1,7 +1,7 @@
package peer package peer
import ( import (
"crypto/sha256" "crypto/sha256"
) )
// DoubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes. // DoubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes.