This commit is contained in:
Jae Kwon 2014-07-01 14:50:24 -07:00
parent fa07748d23
commit c40fc65e6b
41 changed files with 3176 additions and 2938 deletions

View File

@ -3,12 +3,14 @@ package binary
import "io"
type Binary interface {
WriteTo(w io.Writer) (int64, error)
WriteTo(w io.Writer) (int64, error)
}
func WriteOnto(b Binary, w io.Writer, n int64, err error) (int64, error) {
if err != nil { return n, err }
var n_ int64
n_, err = b.WriteTo(w)
return n+n_, err
if err != nil {
return n, err
}
var n_ int64
n_, err = b.WriteTo(w)
return n + n_, err
}

View File

@ -6,44 +6,52 @@ import "bytes"
type ByteSlice []byte
func (self ByteSlice) Equals(other Binary) bool {
if o, ok := other.(ByteSlice); ok {
return bytes.Equal(self, o)
} else {
return false
}
if o, ok := other.(ByteSlice); ok {
return bytes.Equal(self, o)
} else {
return false
}
}
func (self ByteSlice) Less(other Binary) bool {
if o, ok := other.(ByteSlice); ok {
return bytes.Compare(self, o) < 0 // -1 if a < b
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(ByteSlice); ok {
return bytes.Compare(self, o) < 0 // -1 if a < b
} else {
panic("Cannot compare unequal types")
}
}
func (self ByteSlice) ByteSize() int {
return len(self)+4
return len(self) + 4
}
func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) {
var n_ int
_, err = UInt32(len(self)).WriteTo(w)
if err != nil { return n, err }
n_, err = w.Write([]byte(self))
return int64(n_+4), err
var n_ int
_, err = UInt32(len(self)).WriteTo(w)
if err != nil {
return n, err
}
n_, err = w.Write([]byte(self))
return int64(n_ + 4), err
}
func ReadByteSliceSafe(r io.Reader) (ByteSlice, error) {
length, err := ReadUInt32Safe(r)
if err != nil { return nil, err }
bytes := make([]byte, int(length))
_, err = io.ReadFull(r, bytes)
if err != nil { return nil, err }
return bytes, nil
length, err := ReadUInt32Safe(r)
if err != nil {
return nil, err
}
bytes := make([]byte, int(length))
_, err = io.ReadFull(r, bytes)
if err != nil {
return nil, err
}
return bytes, nil
}
func ReadByteSlice(r io.Reader) ByteSlice {
bytes, err := ReadByteSliceSafe(r)
if r != nil { panic(err) }
return bytes
bytes, err := ReadByteSliceSafe(r)
if r != nil {
panic(err)
}
return bytes
}

View File

@ -1,70 +1,100 @@
package binary
import (
"io"
"io"
)
const (
TYPE_NIL = Byte(0x00)
TYPE_BYTE = Byte(0x01)
TYPE_INT8 = Byte(0x02)
TYPE_UINT8 = Byte(0x03)
TYPE_INT16 = Byte(0x04)
TYPE_UINT16 = Byte(0x05)
TYPE_INT32 = Byte(0x06)
TYPE_UINT32 = Byte(0x07)
TYPE_INT64 = Byte(0x08)
TYPE_UINT64 = Byte(0x09)
TYPE_NIL = Byte(0x00)
TYPE_BYTE = Byte(0x01)
TYPE_INT8 = Byte(0x02)
TYPE_UINT8 = Byte(0x03)
TYPE_INT16 = Byte(0x04)
TYPE_UINT16 = Byte(0x05)
TYPE_INT32 = Byte(0x06)
TYPE_UINT32 = Byte(0x07)
TYPE_INT64 = Byte(0x08)
TYPE_UINT64 = Byte(0x09)
TYPE_STRING = Byte(0x10)
TYPE_BYTESLICE = Byte(0x11)
TYPE_STRING = Byte(0x10)
TYPE_BYTESLICE = Byte(0x11)
TYPE_TIME = Byte(0x20)
TYPE_TIME = Byte(0x20)
)
func GetBinaryType(o Binary) Byte {
switch o.(type) {
case nil: return TYPE_NIL
case Byte: return TYPE_BYTE
case Int8: return TYPE_INT8
case UInt8: return TYPE_UINT8
case Int16: return TYPE_INT16
case UInt16: return TYPE_UINT16
case Int32: return TYPE_INT32
case UInt32: return TYPE_UINT32
case Int64: return TYPE_INT64
case UInt64: return TYPE_UINT64
case Int: panic("Int not supported")
case UInt: panic("UInt not supported")
switch o.(type) {
case nil:
return TYPE_NIL
case Byte:
return TYPE_BYTE
case Int8:
return TYPE_INT8
case UInt8:
return TYPE_UINT8
case Int16:
return TYPE_INT16
case UInt16:
return TYPE_UINT16
case Int32:
return TYPE_INT32
case UInt32:
return TYPE_UINT32
case Int64:
return TYPE_INT64
case UInt64:
return TYPE_UINT64
case Int:
panic("Int not supported")
case UInt:
panic("UInt not supported")
case String: return TYPE_STRING
case ByteSlice: return TYPE_BYTESLICE
case String:
return TYPE_STRING
case ByteSlice:
return TYPE_BYTESLICE
case Time: return TYPE_TIME
case Time:
return TYPE_TIME
default: panic("Unsupported type")
}
default:
panic("Unsupported type")
}
}
func ReadBinary(r io.Reader) Binary {
type_ := ReadByte(r)
switch type_ {
case TYPE_NIL: return nil
case TYPE_BYTE: return ReadByte(r)
case TYPE_INT8: return ReadInt8(r)
case TYPE_UINT8: return ReadUInt8(r)
case TYPE_INT16: return ReadInt16(r)
case TYPE_UINT16: return ReadUInt16(r)
case TYPE_INT32: return ReadInt32(r)
case TYPE_UINT32: return ReadUInt32(r)
case TYPE_INT64: return ReadInt64(r)
case TYPE_UINT64: return ReadUInt64(r)
type_ := ReadByte(r)
switch type_ {
case TYPE_NIL:
return nil
case TYPE_BYTE:
return ReadByte(r)
case TYPE_INT8:
return ReadInt8(r)
case TYPE_UINT8:
return ReadUInt8(r)
case TYPE_INT16:
return ReadInt16(r)
case TYPE_UINT16:
return ReadUInt16(r)
case TYPE_INT32:
return ReadInt32(r)
case TYPE_UINT32:
return ReadUInt32(r)
case TYPE_INT64:
return ReadInt64(r)
case TYPE_UINT64:
return ReadUInt64(r)
case TYPE_STRING: return ReadString(r)
case TYPE_BYTESLICE:return ReadByteSlice(r)
case TYPE_STRING:
return ReadString(r)
case TYPE_BYTESLICE:
return ReadByteSlice(r)
case TYPE_TIME: return ReadTime(r)
case TYPE_TIME:
return ReadTime(r)
default: panic("Unsupported type")
}
default:
panic("Unsupported type")
}
}

View File

@ -1,8 +1,8 @@
package binary
import (
"io"
"encoding/binary"
"encoding/binary"
"io"
)
type Byte byte
@ -17,397 +17,426 @@ type UInt64 uint64
type Int int
type UInt uint
// Byte
func (self Byte) Equals(other Binary) bool {
return self == other
return self == other
}
func (self Byte) Less(other Binary) bool {
if o, ok := other.(Byte); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Byte); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Byte) ByteSize() int {
return 1
return 1
}
func (self Byte) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)})
return int64(n), err
n, err := w.Write([]byte{byte(self)})
return int64(n), err
}
func ReadByteSafe(r io.Reader) (Byte, error) {
buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return 0, err }
return Byte(buf[0]), nil
buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return 0, err
}
return Byte(buf[0]), nil
}
func ReadByte(r io.Reader) (Byte) {
b, err := ReadByteSafe(r)
if err != nil { panic(err) }
return b
func ReadByte(r io.Reader) Byte {
b, err := ReadByteSafe(r)
if err != nil {
panic(err)
}
return b
}
// Int8
func (self Int8) Equals(other Binary) bool {
return self == other
return self == other
}
func (self Int8) Less(other Binary) bool {
if o, ok := other.(Int8); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Int8); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Int8) ByteSize() int {
return 1
return 1
}
func (self Int8) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)})
return int64(n), err
n, err := w.Write([]byte{byte(self)})
return int64(n), err
}
func ReadInt8Safe(r io.Reader) (Int8, error) {
buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return Int8(0), err }
return Int8(buf[0]), nil
buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return Int8(0), err
}
return Int8(buf[0]), nil
}
func ReadInt8(r io.Reader) (Int8) {
b, err := ReadInt8Safe(r)
if err != nil { panic(err) }
return b
func ReadInt8(r io.Reader) Int8 {
b, err := ReadInt8Safe(r)
if err != nil {
panic(err)
}
return b
}
// UInt8
func (self UInt8) Equals(other Binary) bool {
return self == other
return self == other
}
func (self UInt8) Less(other Binary) bool {
if o, ok := other.(UInt8); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(UInt8); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self UInt8) ByteSize() int {
return 1
return 1
}
func (self UInt8) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write([]byte{byte(self)})
return int64(n), err
n, err := w.Write([]byte{byte(self)})
return int64(n), err
}
func ReadUInt8Safe(r io.Reader) (UInt8, error) {
buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return UInt8(0), err }
return UInt8(buf[0]), nil
buf := [1]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt8(0), err
}
return UInt8(buf[0]), nil
}
func ReadUInt8(r io.Reader) (UInt8) {
b, err := ReadUInt8Safe(r)
if err != nil { panic(err) }
return b
func ReadUInt8(r io.Reader) UInt8 {
b, err := ReadUInt8Safe(r)
if err != nil {
panic(err)
}
return b
}
// Int16
func (self Int16) Equals(other Binary) bool {
return self == other
return self == other
}
func (self Int16) Less(other Binary) bool {
if o, ok := other.(Int16); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Int16); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Int16) ByteSize() int {
return 2
return 2
}
func (self Int16) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int16(self))
return 2, err
err := binary.Write(w, binary.LittleEndian, int16(self))
return 2, err
}
func ReadInt16Safe(r io.Reader) (Int16, error) {
buf := [2]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return Int16(0), err }
return Int16(binary.LittleEndian.Uint16(buf[:])), nil
buf := [2]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return Int16(0), err
}
return Int16(binary.LittleEndian.Uint16(buf[:])), nil
}
func ReadInt16(r io.Reader) (Int16) {
b, err := ReadInt16Safe(r)
if err != nil { panic(err) }
return b
func ReadInt16(r io.Reader) Int16 {
b, err := ReadInt16Safe(r)
if err != nil {
panic(err)
}
return b
}
// UInt16
func (self UInt16) Equals(other Binary) bool {
return self == other
return self == other
}
func (self UInt16) Less(other Binary) bool {
if o, ok := other.(UInt16); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(UInt16); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self UInt16) ByteSize() int {
return 2
return 2
}
func (self UInt16) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint16(self))
return 2, err
err := binary.Write(w, binary.LittleEndian, uint16(self))
return 2, err
}
func ReadUInt16Safe(r io.Reader) (UInt16, error) {
buf := [2]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return UInt16(0), err }
return UInt16(binary.LittleEndian.Uint16(buf[:])), nil
buf := [2]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt16(0), err
}
return UInt16(binary.LittleEndian.Uint16(buf[:])), nil
}
func ReadUInt16(r io.Reader) (UInt16) {
b, err := ReadUInt16Safe(r)
if err != nil { panic(err) }
return b
func ReadUInt16(r io.Reader) UInt16 {
b, err := ReadUInt16Safe(r)
if err != nil {
panic(err)
}
return b
}
// Int32
func (self Int32) Equals(other Binary) bool {
return self == other
return self == other
}
func (self Int32) Less(other Binary) bool {
if o, ok := other.(Int32); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Int32); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Int32) ByteSize() int {
return 4
return 4
}
func (self Int32) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int32(self))
return 4, err
err := binary.Write(w, binary.LittleEndian, int32(self))
return 4, err
}
func ReadInt32Safe(r io.Reader) (Int32, error) {
buf := [4]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return Int32(0), err }
return Int32(binary.LittleEndian.Uint32(buf[:])), nil
buf := [4]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return Int32(0), err
}
return Int32(binary.LittleEndian.Uint32(buf[:])), nil
}
func ReadInt32(r io.Reader) (Int32) {
b, err := ReadInt32Safe(r)
if err != nil { panic(err) }
return b
func ReadInt32(r io.Reader) Int32 {
b, err := ReadInt32Safe(r)
if err != nil {
panic(err)
}
return b
}
// UInt32
func (self UInt32) Equals(other Binary) bool {
return self == other
return self == other
}
func (self UInt32) Less(other Binary) bool {
if o, ok := other.(UInt32); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(UInt32); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self UInt32) ByteSize() int {
return 4
return 4
}
func (self UInt32) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint32(self))
return 4, err
err := binary.Write(w, binary.LittleEndian, uint32(self))
return 4, err
}
func ReadUInt32Safe(r io.Reader) (UInt32, error) {
buf := [4]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return UInt32(0), err }
return UInt32(binary.LittleEndian.Uint32(buf[:])), nil
buf := [4]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt32(0), err
}
return UInt32(binary.LittleEndian.Uint32(buf[:])), nil
}
func ReadUInt32(r io.Reader) (UInt32) {
b, err := ReadUInt32Safe(r)
if err != nil { panic(err) }
return b
func ReadUInt32(r io.Reader) UInt32 {
b, err := ReadUInt32Safe(r)
if err != nil {
panic(err)
}
return b
}
// Int64
func (self Int64) Equals(other Binary) bool {
return self == other
return self == other
}
func (self Int64) Less(other Binary) bool {
if o, ok := other.(Int64); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Int64); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Int64) ByteSize() int {
return 8
return 8
}
func (self Int64) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int64(self))
return 8, err
err := binary.Write(w, binary.LittleEndian, int64(self))
return 8, err
}
func ReadInt64Safe(r io.Reader) (Int64, error) {
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return Int64(0), err }
return Int64(binary.LittleEndian.Uint64(buf[:])), nil
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return Int64(0), err
}
return Int64(binary.LittleEndian.Uint64(buf[:])), nil
}
func ReadInt64(r io.Reader) (Int64) {
b, err := ReadInt64Safe(r)
if err != nil { panic(err) }
return b
func ReadInt64(r io.Reader) Int64 {
b, err := ReadInt64Safe(r)
if err != nil {
panic(err)
}
return b
}
// UInt64
func (self UInt64) Equals(other Binary) bool {
return self == other
return self == other
}
func (self UInt64) Less(other Binary) bool {
if o, ok := other.(UInt64); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(UInt64); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self UInt64) ByteSize() int {
return 8
return 8
}
func (self UInt64) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint64(self))
return 8, err
err := binary.Write(w, binary.LittleEndian, uint64(self))
return 8, err
}
func ReadUInt64Safe(r io.Reader) (UInt64, error) {
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { return UInt64(0), err }
return UInt64(binary.LittleEndian.Uint64(buf[:])), nil
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
return UInt64(0), err
}
return UInt64(binary.LittleEndian.Uint64(buf[:])), nil
}
func ReadUInt64(r io.Reader) (UInt64) {
b, err := ReadUInt64Safe(r)
if err != nil { panic(err) }
return b
func ReadUInt64(r io.Reader) UInt64 {
b, err := ReadUInt64Safe(r)
if err != nil {
panic(err)
}
return b
}
// Int
func (self Int) Equals(other Binary) bool {
return self == other
return self == other
}
func (self Int) Less(other Binary) bool {
if o, ok := other.(Int); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Int); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self Int) ByteSize() int {
return 8
return 8
}
func (self Int) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, int64(self))
return 8, err
err := binary.Write(w, binary.LittleEndian, int64(self))
return 8, err
}
func ReadInt(r io.Reader) Int {
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { panic(err) }
return Int(binary.LittleEndian.Uint64(buf[:]))
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
panic(err)
}
return Int(binary.LittleEndian.Uint64(buf[:]))
}
// UInt
func (self UInt) Equals(other Binary) bool {
return self == other
return self == other
}
func (self UInt) Less(other Binary) bool {
if o, ok := other.(UInt); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(UInt); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self UInt) ByteSize() int {
return 8
return 8
}
func (self UInt) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, uint64(self))
return 8, err
err := binary.Write(w, binary.LittleEndian, uint64(self))
return 8, err
}
func ReadUInt(r io.Reader) UInt {
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil { panic(err) }
return UInt(binary.LittleEndian.Uint64(buf[:]))
buf := [8]byte{0}
_, err := io.ReadFull(r, buf[:])
if err != nil {
panic(err)
}
return UInt(binary.LittleEndian.Uint64(buf[:]))
}

View File

@ -7,40 +7,48 @@ type String string
// String
func (self String) Equals(other Binary) bool {
return self == other
return self == other
}
func (self String) Less(other Binary) bool {
if o, ok := other.(String); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(String); ok {
return self < o
} else {
panic("Cannot compare unequal types")
}
}
func (self String) ByteSize() int {
return len(self)+4
return len(self) + 4
}
func (self String) WriteTo(w io.Writer) (n int64, err error) {
var n_ int
_, err = UInt32(len(self)).WriteTo(w)
if err != nil { return n, err }
n_, err = w.Write([]byte(self))
return int64(n_+4), err
var n_ int
_, err = UInt32(len(self)).WriteTo(w)
if err != nil {
return n, err
}
n_, err = w.Write([]byte(self))
return int64(n_ + 4), err
}
func ReadStringSafe(r io.Reader) (String, error) {
length, err := ReadUInt32Safe(r)
if err != nil { return "", err }
bytes := make([]byte, int(length))
_, err = io.ReadFull(r, bytes)
if err != nil { return "", err }
return String(bytes), nil
length, err := ReadUInt32Safe(r)
if err != nil {
return "", err
}
bytes := make([]byte, int(length))
_, err = io.ReadFull(r, bytes)
if err != nil {
return "", err
}
return String(bytes), nil
}
func ReadString(r io.Reader) String {
str, err := ReadStringSafe(r)
if r != nil { panic(err) }
return str
str, err := ReadStringSafe(r)
if r != nil {
panic(err)
}
return str
}

View File

@ -1,38 +1,38 @@
package binary
import (
"io"
"time"
"io"
"time"
)
type Time struct {
time.Time
time.Time
}
func (self Time) Equals(other Binary) bool {
if o, ok := other.(Time); ok {
return self.Equal(o.Time)
} else {
return false
}
if o, ok := other.(Time); ok {
return self.Equal(o.Time)
} else {
return false
}
}
func (self Time) Less(other Binary) bool {
if o, ok := other.(Time); ok {
return self.Before(o.Time)
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(Time); ok {
return self.Before(o.Time)
} else {
panic("Cannot compare unequal types")
}
}
func (self Time) ByteSize() int {
return 8
return 8
}
func (self Time) WriteTo(w io.Writer) (int64, error) {
return Int64(self.Unix()).WriteTo(w)
return Int64(self.Unix()).WriteTo(w)
}
func ReadTime(r io.Reader) Time {
return Time{time.Unix(int64(ReadInt64(r)), 0)}
return Time{time.Unix(int64(ReadInt64(r)), 0)}
}

View File

@ -1,33 +1,35 @@
package binary
import (
"crypto/sha256"
"bytes"
"bytes"
"crypto/sha256"
)
func BinaryBytes(b Binary) ByteSlice {
buf := bytes.NewBuffer(nil)
b.WriteTo(buf)
return ByteSlice(buf.Bytes())
buf := bytes.NewBuffer(nil)
b.WriteTo(buf)
return ByteSlice(buf.Bytes())
}
// NOTE: does not care about the type, only the binary representation.
func BinaryEqual(a, b Binary) bool {
aBytes := BinaryBytes(a)
bBytes := BinaryBytes(b)
return bytes.Equal(aBytes, bBytes)
aBytes := BinaryBytes(a)
bBytes := BinaryBytes(b)
return bytes.Equal(aBytes, bBytes)
}
// NOTE: does not care about the type, only the binary representation.
func BinaryCompare(a, b Binary) int {
aBytes := BinaryBytes(a)
bBytes := BinaryBytes(b)
return bytes.Compare(aBytes, bBytes)
aBytes := BinaryBytes(a)
bBytes := BinaryBytes(b)
return bytes.Compare(aBytes, bBytes)
}
func BinaryHash(b Binary) ByteSlice {
hasher := sha256.New()
_, err := b.WriteTo(hasher)
if err != nil { panic(err) }
return ByteSlice(hasher.Sum(nil))
hasher := sha256.New()
_, err := b.WriteTo(hasher)
if err != nil {
panic(err)
}
return ByteSlice(hasher.Sum(nil))
}

View File

@ -1,48 +1,48 @@
package blocks
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
"io"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"io"
)
type AccountId struct {
Type Byte
Number UInt64
PubKey ByteSlice
Type Byte
Number UInt64
PubKey ByteSlice
}
const (
ACCOUNT_TYPE_NUMBER = Byte(0x01)
ACCOUNT_TYPE_PUBKEY = Byte(0x02)
ACCOUNT_TYPE_BOTH = Byte(0x03)
ACCOUNT_TYPE_NUMBER = Byte(0x01)
ACCOUNT_TYPE_PUBKEY = Byte(0x02)
ACCOUNT_TYPE_BOTH = Byte(0x03)
)
func ReadAccountId(r io.Reader) AccountId {
switch t := ReadByte(r); t {
case ACCOUNT_TYPE_NUMBER:
return AccountId{t, ReadUInt64(r), nil}
case ACCOUNT_TYPE_PUBKEY:
return AccountId{t, 0, ReadByteSlice(r)}
case ACCOUNT_TYPE_BOTH:
return AccountId{t, ReadUInt64(r), ReadByteSlice(r)}
default:
Panicf("Unknown AccountId type %x", t)
return AccountId{}
}
switch t := ReadByte(r); t {
case ACCOUNT_TYPE_NUMBER:
return AccountId{t, ReadUInt64(r), nil}
case ACCOUNT_TYPE_PUBKEY:
return AccountId{t, 0, ReadByteSlice(r)}
case ACCOUNT_TYPE_BOTH:
return AccountId{t, ReadUInt64(r), ReadByteSlice(r)}
default:
Panicf("Unknown AccountId type %x", t)
return AccountId{}
}
}
func (self AccountId) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type, w, n, err)
if self.Type == ACCOUNT_TYPE_NUMBER || self.Type == ACCOUNT_TYPE_BOTH {
n, err = WriteOnto(self.Number, w, n, err)
}
if self.Type == ACCOUNT_TYPE_PUBKEY || self.Type == ACCOUNT_TYPE_BOTH {
n, err = WriteOnto(self.PubKey, w, n, err)
}
return
n, err = WriteOnto(self.Type, w, n, err)
if self.Type == ACCOUNT_TYPE_NUMBER || self.Type == ACCOUNT_TYPE_BOTH {
n, err = WriteOnto(self.Number, w, n, err)
}
if self.Type == ACCOUNT_TYPE_PUBKEY || self.Type == ACCOUNT_TYPE_BOTH {
n, err = WriteOnto(self.PubKey, w, n, err)
}
return
}
func AccountNumber(n UInt64) AccountId {
return AccountId{ACCOUNT_TYPE_NUMBER, n, nil}
return AccountId{ACCOUNT_TYPE_NUMBER, n, nil}
}

View File

@ -1,9 +1,9 @@
package blocks
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
"io"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"io"
)
/* Adjustment
@ -17,126 +17,122 @@ TODO: signing a bad checkpoint (block)
*/
type Adjustment interface {
Type() Byte
Binary
Type() Byte
Binary
}
const (
ADJ_TYPE_BOND = Byte(0x01)
ADJ_TYPE_UNBOND = Byte(0x02)
ADJ_TYPE_TIMEOUT = Byte(0x03)
ADJ_TYPE_DUPEOUT = Byte(0x04)
ADJ_TYPE_BOND = Byte(0x01)
ADJ_TYPE_UNBOND = Byte(0x02)
ADJ_TYPE_TIMEOUT = Byte(0x03)
ADJ_TYPE_DUPEOUT = Byte(0x04)
)
func ReadAdjustment(r io.Reader) Adjustment {
switch t := ReadByte(r); t {
case ADJ_TYPE_BOND:
return &Bond{
Fee: ReadUInt64(r),
UnbondTo: ReadAccountId(r),
Amount: ReadUInt64(r),
Signature: ReadSignature(r),
}
case ADJ_TYPE_UNBOND:
return &Unbond{
Fee: ReadUInt64(r),
Amount: ReadUInt64(r),
Signature: ReadSignature(r),
}
case ADJ_TYPE_TIMEOUT:
return &Timeout{
Account: ReadAccountId(r),
Penalty: ReadUInt64(r),
}
case ADJ_TYPE_DUPEOUT:
return &Dupeout{
VoteA: ReadVote(r),
VoteB: ReadVote(r),
}
default:
Panicf("Unknown Adjustment type %x", t)
return nil
}
switch t := ReadByte(r); t {
case ADJ_TYPE_BOND:
return &Bond{
Fee: ReadUInt64(r),
UnbondTo: ReadAccountId(r),
Amount: ReadUInt64(r),
Signature: ReadSignature(r),
}
case ADJ_TYPE_UNBOND:
return &Unbond{
Fee: ReadUInt64(r),
Amount: ReadUInt64(r),
Signature: ReadSignature(r),
}
case ADJ_TYPE_TIMEOUT:
return &Timeout{
Account: ReadAccountId(r),
Penalty: ReadUInt64(r),
}
case ADJ_TYPE_DUPEOUT:
return &Dupeout{
VoteA: ReadVote(r),
VoteB: ReadVote(r),
}
default:
Panicf("Unknown Adjustment type %x", t)
return nil
}
}
/* Bond < Adjustment */
type Bond struct {
Fee UInt64
UnbondTo AccountId
Amount UInt64
Signature
Fee UInt64
UnbondTo AccountId
Amount UInt64
Signature
}
func (self *Bond) Type() Byte {
return ADJ_TYPE_BOND
return ADJ_TYPE_BOND
}
func (self *Bond) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.UnbondTo, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.UnbondTo, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
}
/* Unbond < Adjustment */
type Unbond struct {
Fee UInt64
Amount UInt64
Signature
Fee UInt64
Amount UInt64
Signature
}
func (self *Unbond) Type() Byte {
return ADJ_TYPE_UNBOND
return ADJ_TYPE_UNBOND
}
func (self *Unbond) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
}
/* Timeout < Adjustment */
type Timeout struct {
Account AccountId
Penalty UInt64
Account AccountId
Penalty UInt64
}
func (self *Timeout) Type() Byte {
return ADJ_TYPE_TIMEOUT
return ADJ_TYPE_TIMEOUT
}
func (self *Timeout) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Account, w, n, err)
n, err = WriteOnto(self.Penalty, w, n, err)
return
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Account, w, n, err)
n, err = WriteOnto(self.Penalty, w, n, err)
return
}
/* Dupeout < Adjustment */
type Dupeout struct {
VoteA Vote
VoteB Vote
VoteA Vote
VoteB Vote
}
func (self *Dupeout) Type() Byte {
return ADJ_TYPE_DUPEOUT
return ADJ_TYPE_DUPEOUT
}
func (self *Dupeout) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.VoteA, w, n, err)
n, err = WriteOnto(self.VoteB, w, n, err)
return
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.VoteA, w, n, err)
n, err = WriteOnto(self.VoteB, w, n, err)
return
}

View File

@ -1,140 +1,137 @@
package blocks
import (
. "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle"
"io"
. "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle"
"io"
)
/* Block */
type Block struct {
Header
Validation
Data
// Checkpoint
Header
Validation
Data
// Checkpoint
}
func ReadBlock(r io.Reader) *Block {
return &Block{
Header: ReadHeader(r),
Validation: ReadValidation(r),
Data: ReadData(r),
}
return &Block{
Header: ReadHeader(r),
Validation: ReadValidation(r),
Data: ReadData(r),
}
}
func (self *Block) Validate() bool {
return false
return false
}
func (self *Block) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(&self.Header, w, n, err)
n, err = WriteOnto(&self.Validation, w, n, err)
n, err = WriteOnto(&self.Data, w, n, err)
return
n, err = WriteOnto(&self.Header, w, n, err)
n, err = WriteOnto(&self.Validation, w, n, err)
n, err = WriteOnto(&self.Data, w, n, err)
return
}
/* Block > Header */
type Header struct {
Name String
Height UInt64
Fees UInt64
Time UInt64
PrevHash ByteSlice
ValidationHash ByteSlice
DataHash ByteSlice
Name String
Height UInt64
Fees UInt64
Time UInt64
PrevHash ByteSlice
ValidationHash ByteSlice
DataHash ByteSlice
}
func ReadHeader(r io.Reader) Header {
return Header{
Name: ReadString(r),
Height: ReadUInt64(r),
Fees: ReadUInt64(r),
Time: ReadUInt64(r),
PrevHash: ReadByteSlice(r),
ValidationHash: ReadByteSlice(r),
DataHash: ReadByteSlice(r),
}
return Header{
Name: ReadString(r),
Height: ReadUInt64(r),
Fees: ReadUInt64(r),
Time: ReadUInt64(r),
PrevHash: ReadByteSlice(r),
ValidationHash: ReadByteSlice(r),
DataHash: ReadByteSlice(r),
}
}
func (self *Header) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Name, w, n, err)
n, err = WriteOnto(self.Height, w, n, err)
n, err = WriteOnto(self.Fees, w, n, err)
n, err = WriteOnto(self.Time, w, n, err)
n, err = WriteOnto(self.PrevHash, w, n, err)
n, err = WriteOnto(self.ValidationHash, w, n, err)
n, err = WriteOnto(self.DataHash, w, n, err)
return
n, err = WriteOnto(self.Name, w, n, err)
n, err = WriteOnto(self.Height, w, n, err)
n, err = WriteOnto(self.Fees, w, n, err)
n, err = WriteOnto(self.Time, w, n, err)
n, err = WriteOnto(self.PrevHash, w, n, err)
n, err = WriteOnto(self.ValidationHash, w, n, err)
n, err = WriteOnto(self.DataHash, w, n, err)
return
}
/* Block > Validation */
type Validation struct {
Signatures []Signature
Adjustments []Adjustment
Signatures []Signature
Adjustments []Adjustment
}
func ReadValidation(r io.Reader) Validation {
numSigs := int(ReadUInt64(r))
numAdjs := int(ReadUInt64(r))
sigs := make([]Signature, 0, numSigs)
for i:=0; i<numSigs; i++ {
sigs = append(sigs, ReadSignature(r))
}
adjs := make([]Adjustment, 0, numAdjs)
for i:=0; i<numAdjs; i++ {
adjs = append(adjs, ReadAdjustment(r))
}
return Validation{
Signatures: sigs,
Adjustments: adjs,
}
numSigs := int(ReadUInt64(r))
numAdjs := int(ReadUInt64(r))
sigs := make([]Signature, 0, numSigs)
for i := 0; i < numSigs; i++ {
sigs = append(sigs, ReadSignature(r))
}
adjs := make([]Adjustment, 0, numAdjs)
for i := 0; i < numAdjs; i++ {
adjs = append(adjs, ReadAdjustment(r))
}
return Validation{
Signatures: sigs,
Adjustments: adjs,
}
}
func (self *Validation) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(UInt64(len(self.Signatures)), w, n, err)
n, err = WriteOnto(UInt64(len(self.Adjustments)), w, n, err)
for _, sig := range self.Signatures {
n, err = WriteOnto(sig, w, n, err)
}
for _, adj := range self.Adjustments {
n, err = WriteOnto(adj, w, n, err)
}
return
n, err = WriteOnto(UInt64(len(self.Signatures)), w, n, err)
n, err = WriteOnto(UInt64(len(self.Adjustments)), w, n, err)
for _, sig := range self.Signatures {
n, err = WriteOnto(sig, w, n, err)
}
for _, adj := range self.Adjustments {
n, err = WriteOnto(adj, w, n, err)
}
return
}
/* Block > Data */
type Data struct {
Txs []Tx
Txs []Tx
}
func ReadData(r io.Reader) Data {
numTxs := int(ReadUInt64(r))
txs := make([]Tx, 0, numTxs)
for i:=0; i<numTxs; i++ {
txs = append(txs, ReadTx(r))
}
return Data{txs}
numTxs := int(ReadUInt64(r))
txs := make([]Tx, 0, numTxs)
for i := 0; i < numTxs; i++ {
txs = append(txs, ReadTx(r))
}
return Data{txs}
}
func (self *Data) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(UInt64(len(self.Txs)), w, n, err)
for _, tx := range self.Txs {
n, err = WriteOnto(tx, w, n, err)
}
return
n, err = WriteOnto(UInt64(len(self.Txs)), w, n, err)
for _, tx := range self.Txs {
n, err = WriteOnto(tx, w, n, err)
}
return
}
func (self *Data) MerkleHash() ByteSlice {
bs := make([]Binary, 0, len(self.Txs))
for i, tx := range self.Txs {
bs[i] = Binary(tx)
}
return merkle.HashFromBinarySlice(bs)
bs := make([]Binary, 0, len(self.Txs))
for i, tx := range self.Txs {
bs[i] = Binary(tx)
}
return merkle.HashFromBinarySlice(bs)
}

View File

@ -1,113 +1,115 @@
package blocks
import (
. "github.com/tendermint/tendermint/binary"
"testing"
"math/rand"
"bytes"
"bytes"
. "github.com/tendermint/tendermint/binary"
"math/rand"
"testing"
)
// Distributed pseudo-exponentially to test for various cases
func randVar() UInt64 {
bits := rand.Uint32() % 64
if bits == 0 { return 0 }
n := uint64(1 << (bits-1))
n += uint64(rand.Int63()) & ((1 << (bits-1)) - 1)
return UInt64(n)
bits := rand.Uint32() % 64
if bits == 0 {
return 0
}
n := uint64(1 << (bits - 1))
n += uint64(rand.Int63()) & ((1 << (bits - 1)) - 1)
return UInt64(n)
}
func randBytes(n int) ByteSlice {
bs := make([]byte, n)
for i:=0; i<n; i++ {
bs[i] = byte(rand.Intn(256))
}
return bs
bs := make([]byte, n)
for i := 0; i < n; i++ {
bs[i] = byte(rand.Intn(256))
}
return bs
}
func randSig() Signature {
return Signature{AccountNumber(randVar()), randBytes(32)}
return Signature{AccountNumber(randVar()), randBytes(32)}
}
func TestBlock(t *testing.T) {
// Txs
// Txs
sendTx := &SendTx{
Signature: randSig(),
Fee: randVar(),
To: AccountNumber(randVar()),
Amount: randVar(),
}
sendTx := &SendTx{
Signature: randSig(),
Fee: randVar(),
To: AccountNumber(randVar()),
Amount: randVar(),
}
nameTx := &NameTx{
Signature: randSig(),
Fee: randVar(),
Name: String(randBytes(12)),
PubKey: randBytes(32),
}
nameTx := &NameTx{
Signature: randSig(),
Fee: randVar(),
Name: String(randBytes(12)),
PubKey: randBytes(32),
}
// Adjs
// Adjs
bond := &Bond{
Signature: randSig(),
Fee: randVar(),
UnbondTo: AccountNumber(randVar()),
Amount: randVar(),
}
bond := &Bond{
Signature: randSig(),
Fee: randVar(),
UnbondTo: AccountNumber(randVar()),
Amount: randVar(),
}
unbond := &Unbond{
Signature: randSig(),
Fee: randVar(),
Amount: randVar(),
}
unbond := &Unbond{
Signature: randSig(),
Fee: randVar(),
Amount: randVar(),
}
timeout := &Timeout{
Account: AccountNumber(randVar()),
Penalty: randVar(),
}
timeout := &Timeout{
Account: AccountNumber(randVar()),
Penalty: randVar(),
}
dupeout := &Dupeout{
VoteA: Vote{
Height: randVar(),
BlockHash: randBytes(32),
Signature: randSig(),
},
VoteB: Vote{
Height: randVar(),
BlockHash: randBytes(32),
Signature: randSig(),
},
}
dupeout := &Dupeout{
VoteA: Vote{
Height: randVar(),
BlockHash: randBytes(32),
Signature: randSig(),
},
VoteB: Vote{
Height: randVar(),
BlockHash: randBytes(32),
Signature: randSig(),
},
}
// Block
// Block
block := &Block{
Header{
Name: "Tendermint",
Height: randVar(),
Fees: randVar(),
Time: randVar(),
PrevHash: randBytes(32),
ValidationHash: randBytes(32),
DataHash: randBytes(32),
},
Validation{
Signatures: []Signature{randSig(),randSig()},
Adjustments:[]Adjustment{bond,unbond,timeout,dupeout},
},
Data{
Txs: []Tx{sendTx, nameTx},
},
}
block := &Block{
Header{
Name: "Tendermint",
Height: randVar(),
Fees: randVar(),
Time: randVar(),
PrevHash: randBytes(32),
ValidationHash: randBytes(32),
DataHash: randBytes(32),
},
Validation{
Signatures: []Signature{randSig(), randSig()},
Adjustments: []Adjustment{bond, unbond, timeout, dupeout},
},
Data{
Txs: []Tx{sendTx, nameTx},
},
}
// Write the block, read it in again, write it again.
// Then, compare.
// Write the block, read it in again, write it again.
// Then, compare.
blockBytes := BinaryBytes(block)
block2 := ReadBlock(bytes.NewReader(blockBytes))
blockBytes2 := BinaryBytes(block2)
blockBytes := BinaryBytes(block)
block2 := ReadBlock(bytes.NewReader(blockBytes))
blockBytes2 := BinaryBytes(block2)
if !BinaryEqual(blockBytes, blockBytes2) {
t.Fatal("Write->Read of block failed.")
}
if !BinaryEqual(blockBytes, blockBytes2) {
t.Fatal("Write->Read of block failed.")
}
}

View File

@ -1,8 +1,8 @@
package blocks
import (
. "github.com/tendermint/tendermint/binary"
"io"
. "github.com/tendermint/tendermint/binary"
"io"
)
/*
@ -19,23 +19,23 @@ It usually follows the message to be signed.
*/
type Signature struct {
Signer AccountId
SigBytes ByteSlice
Signer AccountId
SigBytes ByteSlice
}
func ReadSignature(r io.Reader) Signature {
return Signature{
Signer: ReadAccountId(r),
SigBytes: ReadByteSlice(r),
}
return Signature{
Signer: ReadAccountId(r),
SigBytes: ReadByteSlice(r),
}
}
func (self Signature) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Signer, w, n, err)
n, err = WriteOnto(self.SigBytes, w, n, err)
return
n, err = WriteOnto(self.Signer, w, n, err)
n, err = WriteOnto(self.SigBytes, w, n, err)
return
}
func (self *Signature) Verify(msg ByteSlice) bool {
return false
return false
}

View File

@ -1,9 +1,9 @@
package blocks
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
"io"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"io"
)
/*
@ -21,79 +21,77 @@ Tx wire format:
*/
type Tx interface {
Type() Byte
Binary
Type() Byte
Binary
}
const (
TX_TYPE_SEND = Byte(0x01)
TX_TYPE_NAME = Byte(0x02)
TX_TYPE_SEND = Byte(0x01)
TX_TYPE_NAME = Byte(0x02)
)
func ReadTx(r io.Reader) Tx {
switch t := ReadByte(r); t {
case TX_TYPE_SEND:
return &SendTx{
Fee: ReadUInt64(r),
To: ReadAccountId(r),
Amount: ReadUInt64(r),
Signature: ReadSignature(r),
}
case TX_TYPE_NAME:
return &NameTx{
Fee: ReadUInt64(r),
Name: ReadString(r),
PubKey: ReadByteSlice(r),
Signature: ReadSignature(r),
}
default:
Panicf("Unknown Tx type %x", t)
return nil
}
switch t := ReadByte(r); t {
case TX_TYPE_SEND:
return &SendTx{
Fee: ReadUInt64(r),
To: ReadAccountId(r),
Amount: ReadUInt64(r),
Signature: ReadSignature(r),
}
case TX_TYPE_NAME:
return &NameTx{
Fee: ReadUInt64(r),
Name: ReadString(r),
PubKey: ReadByteSlice(r),
Signature: ReadSignature(r),
}
default:
Panicf("Unknown Tx type %x", t)
return nil
}
}
/* SendTx < Tx */
type SendTx struct {
Fee UInt64
To AccountId
Amount UInt64
Signature
Fee UInt64
To AccountId
Amount UInt64
Signature
}
func (self *SendTx) Type() Byte {
return TX_TYPE_SEND
return TX_TYPE_SEND
}
func (self *SendTx) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.To, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.To, w, n, err)
n, err = WriteOnto(self.Amount, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
}
/* NameTx < Tx */
type NameTx struct {
Fee UInt64
Name String
PubKey ByteSlice
Signature
Fee UInt64
Name String
PubKey ByteSlice
Signature
}
func (self *NameTx) Type() Byte {
return TX_TYPE_NAME
return TX_TYPE_NAME
}
func (self *NameTx) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.Name, w, n, err)
n, err = WriteOnto(self.PubKey, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
n, err = WriteOnto(self.Type(), w, n, err)
n, err = WriteOnto(self.Fee, w, n, err)
n, err = WriteOnto(self.Name, w, n, err)
n, err = WriteOnto(self.PubKey, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
}

View File

@ -1,8 +1,8 @@
package blocks
import (
. "github.com/tendermint/tendermint/binary"
"io"
. "github.com/tendermint/tendermint/binary"
"io"
)
/*
@ -11,22 +11,22 @@ Typically only the signature is passed around, as the hash & height are implied.
*/
type Vote struct {
Height UInt64
BlockHash ByteSlice
Signature
Height UInt64
BlockHash ByteSlice
Signature
}
func ReadVote(r io.Reader) Vote {
return Vote{
Height: ReadUInt64(r),
BlockHash: ReadByteSlice(r),
Signature: ReadSignature(r),
}
return Vote{
Height: ReadUInt64(r),
BlockHash: ReadByteSlice(r),
Signature: ReadSignature(r),
}
}
func (self Vote) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(self.Height, w, n, err)
n, err = WriteOnto(self.BlockHash, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
n, err = WriteOnto(self.Height, w, n, err)
n, err = WriteOnto(self.BlockHash, w, n, err)
n, err = WriteOnto(self.Signature, w, n, err)
return
}

View File

@ -1,45 +1,48 @@
package common
import (
"time"
"sync"
"sync"
"time"
)
/* Debouncer */
type Debouncer struct {
Ch chan struct{}
quit chan struct{}
dur time.Duration
mtx sync.Mutex
timer *time.Timer
Ch chan struct{}
quit chan struct{}
dur time.Duration
mtx sync.Mutex
timer *time.Timer
}
func NewDebouncer(dur time.Duration) *Debouncer {
var timer *time.Timer
var ch = make(chan struct{})
var quit = make(chan struct{})
var mtx sync.Mutex
fire := func() {
go func() {
select {
case ch <- struct{}{}:
case <-quit:
}
}()
mtx.Lock(); defer mtx.Unlock()
timer.Reset(dur)
}
timer = time.AfterFunc(dur, fire)
return &Debouncer{Ch:ch, dur:dur, quit:quit, mtx:mtx, timer:timer}
var timer *time.Timer
var ch = make(chan struct{})
var quit = make(chan struct{})
var mtx sync.Mutex
fire := func() {
go func() {
select {
case ch <- struct{}{}:
case <-quit:
}
}()
mtx.Lock()
defer mtx.Unlock()
timer.Reset(dur)
}
timer = time.AfterFunc(dur, fire)
return &Debouncer{Ch: ch, dur: dur, quit: quit, mtx: mtx, timer: timer}
}
func (d *Debouncer) Reset() {
d.mtx.Lock(); defer d.mtx.Unlock()
d.timer.Reset(d.dur)
d.mtx.Lock()
defer d.mtx.Unlock()
d.timer.Reset(d.dur)
}
func (d *Debouncer) Stop() bool {
d.mtx.Lock(); defer d.mtx.Unlock()
close(d.quit)
return d.timer.Stop()
d.mtx.Lock()
defer d.mtx.Unlock()
close(d.quit)
return d.timer.Stop()
}

View File

@ -1,28 +1,28 @@
package common
import (
"container/heap"
"container/heap"
)
type Heap struct {
pq priorityQueue
pq priorityQueue
}
func NewHeap() *Heap {
return &Heap{pq:make([]*pqItem, 0)}
return &Heap{pq: make([]*pqItem, 0)}
}
func (h *Heap) Len() int {
return len(h.pq)
return len(h.pq)
}
func (h *Heap) Push(value interface{}, priority int) {
heap.Push(&h.pq, &pqItem{value:value, priority:priority})
heap.Push(&h.pq, &pqItem{value: value, priority: priority})
}
func (h *Heap) Pop() interface{} {
item := heap.Pop(&h.pq).(*pqItem)
return item.value
item := heap.Pop(&h.pq).(*pqItem)
return item.value
}
/*
@ -43,9 +43,9 @@ func main() {
// From: http://golang.org/pkg/container/heap/#example__priorityQueue
type pqItem struct {
value interface{}
priority int
index int
value interface{}
priority int
index int
}
type priorityQueue []*pqItem
@ -53,35 +53,34 @@ type priorityQueue []*pqItem
func (pq priorityQueue) Len() int { return len(pq) }
func (pq priorityQueue) Less(i, j int) bool {
return pq[i].priority < pq[j].priority
return pq[i].priority < pq[j].priority
}
func (pq priorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *priorityQueue) Push(x interface{}) {
n := len(*pq)
item := x.(*pqItem)
item.index = n
*pq = append(*pq, item)
n := len(*pq)
item := x.(*pqItem)
item.index = n
*pq = append(*pq, item)
}
func (pq *priorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
item.index = -1 // for safety
*pq = old[0 : n-1]
return item
old := *pq
n := len(old)
item := old[n-1]
item.index = -1 // for safety
*pq = old[0 : n-1]
return item
}
func (pq *priorityQueue) Update(item *pqItem, value interface{}, priority int) {
heap.Remove(pq, item.index)
item.value = value
item.priority = priority
heap.Push(pq, item)
heap.Remove(pq, item.index)
item.value = value
item.priority = priority
heap.Push(pq, item)
}

View File

@ -1,9 +1,9 @@
package common
import (
"fmt"
"fmt"
)
func Panicf(s string, args ...interface{}) {
panic(fmt.Sprintf(s, args...))
panic(fmt.Sprintf(s, args...))
}

View File

@ -1,109 +1,115 @@
package config
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
"errors"
//"crypto/rand"
//"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
//"crypto/rand"
//"encoding/hex"
)
var APP_DIR = os.Getenv("HOME") + "/.tendermint"
/* Global & initialization */
var Config Config_
func init() {
configFile := APP_DIR+"/config.json"
configFile := APP_DIR + "/config.json"
// try to read configuration. if missing, write default
configBytes, err := ioutil.ReadFile(configFile)
if err != nil {
defaultConfig.write(configFile)
fmt.Println("Config file written to config.json. Please edit & run again")
os.Exit(1)
return
}
// try to read configuration. if missing, write default
configBytes, err := ioutil.ReadFile(configFile)
if err != nil {
defaultConfig.write(configFile)
fmt.Println("Config file written to config.json. Please edit & run again")
os.Exit(1)
return
}
// try to parse configuration. on error, die
Config = Config_{}
err = json.Unmarshal(configBytes, &Config)
if err != nil {
log.Panicf("Invalid configuration file %s: %v", configFile, err)
}
err = Config.validate()
if err != nil {
log.Panicf("Invalid configuration file %s: %v", configFile, err)
}
// try to parse configuration. on error, die
Config = Config_{}
err = json.Unmarshal(configBytes, &Config)
if err != nil {
log.Panicf("Invalid configuration file %s: %v", configFile, err)
}
err = Config.validate()
if err != nil {
log.Panicf("Invalid configuration file %s: %v", configFile, err)
}
}
/* Default configuration */
var defaultConfig = Config_{
Host: "127.0.0.1",
Port: 8770,
Db: DbConfig{
Type: "level",
Dir: APP_DIR+"/data",
},
Twilio: TwilioConfig{
},
Host: "127.0.0.1",
Port: 8770,
Db: DbConfig{
Type: "level",
Dir: APP_DIR + "/data",
},
Twilio: TwilioConfig{},
}
/* Configuration types */
type Config_ struct {
Host string
Port int
Db DbConfig
Twilio TwilioConfig
Host string
Port int
Db DbConfig
Twilio TwilioConfig
}
type TwilioConfig struct {
Sid string
Token string
From string
To string
MinInterval int
Sid string
Token string
From string
To string
MinInterval int
}
type DbConfig struct {
Type string
Dir string
Type string
Dir string
}
func (cfg *Config_) validate() error {
if cfg.Host == "" { return errors.New("Host must be set") }
if cfg.Port == 0 { return errors.New("Port must be set") }
if cfg.Db.Type == "" { return errors.New("Db.Type must be set") }
return nil
if cfg.Host == "" {
return errors.New("Host must be set")
}
if cfg.Port == 0 {
return errors.New("Port must be set")
}
if cfg.Db.Type == "" {
return errors.New("Db.Type must be set")
}
return nil
}
func (cfg *Config_) bytes() []byte {
configBytes, err := json.Marshal(cfg)
if err != nil { panic(err) }
return configBytes
configBytes, err := json.Marshal(cfg)
if err != nil {
panic(err)
}
return configBytes
}
func (cfg *Config_) write(configFile string) {
if strings.Index(configFile, "/") != -1 {
err := os.MkdirAll(filepath.Dir(configFile), 0700)
if err != nil { panic(err) }
}
err := ioutil.WriteFile(configFile, cfg.bytes(), 0600)
if err != nil {
panic(err)
}
if strings.Index(configFile, "/") != -1 {
err := os.MkdirAll(filepath.Dir(configFile), 0700)
if err != nil {
panic(err)
}
}
err := ioutil.WriteFile(configFile, cfg.bytes(), 0600)
if err != nil {
panic(err)
}
}
/* TODO: generate priv/pub keys
@ -113,4 +119,3 @@ func generateKeys() string {
return hex.EncodeToString(bytes[:])
}
*/

View File

@ -11,60 +11,60 @@ import "C"
import "unsafe"
type Verify struct {
Message []byte
PubKey []byte
Signature []byte
Valid bool
Message []byte
PubKey []byte
Signature []byte
Valid bool
}
func MakePubKey(privKey []byte) []byte {
pubKey := [32]byte{}
C.ed25519_publickey(
(*C.uchar)(unsafe.Pointer(&privKey[0])),
(*C.uchar)(unsafe.Pointer(&pubKey[0])),
)
return pubKey[:]
pubKey := [32]byte{}
C.ed25519_publickey(
(*C.uchar)(unsafe.Pointer(&privKey[0])),
(*C.uchar)(unsafe.Pointer(&pubKey[0])),
)
return pubKey[:]
}
func SignMessage(message []byte, privKey []byte, pubKey []byte) []byte {
sig := [64]byte{}
C.ed25519_sign(
(*C.uchar)(unsafe.Pointer(&message[0])), (C.size_t)(len(message)),
(*C.uchar)(unsafe.Pointer(&privKey[0])),
(*C.uchar)(unsafe.Pointer(&pubKey[0])),
(*C.uchar)(unsafe.Pointer(&sig[0])),
)
return sig[:]
sig := [64]byte{}
C.ed25519_sign(
(*C.uchar)(unsafe.Pointer(&message[0])), (C.size_t)(len(message)),
(*C.uchar)(unsafe.Pointer(&privKey[0])),
(*C.uchar)(unsafe.Pointer(&pubKey[0])),
(*C.uchar)(unsafe.Pointer(&sig[0])),
)
return sig[:]
}
func VerifyBatch(verifys []*Verify) bool {
count := len(verifys)
count := len(verifys)
msgs := make([]*byte, count)
lens := make([]C.size_t, count)
pubs := make([]*byte, count)
sigs := make([]*byte, count)
valids := make([]C.int, count)
msgs := make([]*byte, count)
lens := make([]C.size_t, count)
pubs := make([]*byte, count)
sigs := make([]*byte, count)
valids := make([]C.int, count)
for i, v := range verifys {
msgs[i] = (*byte)(unsafe.Pointer(&v.Message[0]))
lens[i] = (C.size_t)(len(v.Message))
pubs[i] = (*byte)(&v.PubKey[0])
sigs[i] = (*byte)(&v.Signature[0])
}
for i, v := range verifys {
msgs[i] = (*byte)(unsafe.Pointer(&v.Message[0]))
lens[i] = (C.size_t)(len(v.Message))
pubs[i] = (*byte)(&v.PubKey[0])
sigs[i] = (*byte)(&v.Signature[0])
}
count_ := (C.size_t)(count)
msgs_ := (**C.uchar)(unsafe.Pointer(&msgs[0]))
lens_ := (*C.size_t)(unsafe.Pointer(&lens[0]))
pubs_ := (**C.uchar)(unsafe.Pointer(&pubs[0]))
sigs_ := (**C.uchar)(unsafe.Pointer(&sigs[0]))
count_ := (C.size_t)(count)
msgs_ := (**C.uchar)(unsafe.Pointer(&msgs[0]))
lens_ := (*C.size_t)(unsafe.Pointer(&lens[0]))
pubs_ := (**C.uchar)(unsafe.Pointer(&pubs[0]))
sigs_ := (**C.uchar)(unsafe.Pointer(&sigs[0]))
res := C.ed25519_sign_open_batch(msgs_, lens_, pubs_, sigs_, count_, &valids[0])
res := C.ed25519_sign_open_batch(msgs_, lens_, pubs_, sigs_, count_, &valids[0])
for i, valid := range valids {
verifys[i].Valid = valid > 0
}
for i, valid := range valids {
verifys[i].Valid = valid > 0
}
return res == 0
return res == 0
}

View File

@ -1,35 +1,47 @@
package crypto
import (
"testing"
"crypto/rand"
"crypto/rand"
"testing"
)
func TestSign(t *testing.T) {
privKey := make([]byte, 32)
_, err := rand.Read(privKey)
if err != nil { t.Fatal(err) }
pubKey := MakePubKey(privKey)
signature := SignMessage([]byte("hello"), privKey, pubKey)
privKey := make([]byte, 32)
_, err := rand.Read(privKey)
if err != nil {
t.Fatal(err)
}
pubKey := MakePubKey(privKey)
signature := SignMessage([]byte("hello"), privKey, pubKey)
v1 := &Verify{
Message: []byte("hello"),
PubKey: pubKey,
Signature: signature,
}
v1 := &Verify{
Message: []byte("hello"),
PubKey: pubKey,
Signature: signature,
}
ok := VerifyBatch([]*Verify{v1, v1, v1, v1})
if ok != true { t.Fatal("Expected ok == true") }
if v1.Valid != true { t.Fatal("Expected v1.Valid to be true") }
ok := VerifyBatch([]*Verify{v1, v1, v1, v1})
if ok != true {
t.Fatal("Expected ok == true")
}
if v1.Valid != true {
t.Fatal("Expected v1.Valid to be true")
}
v2 := &Verify{
Message: []byte{0x73},
PubKey: pubKey,
Signature: signature,
}
v2 := &Verify{
Message: []byte{0x73},
PubKey: pubKey,
Signature: signature,
}
ok = VerifyBatch([]*Verify{v1, v1, v1, v2})
if ok != false { t.Fatal("Expected ok == false") }
if v1.Valid != true { t.Fatal("Expected v1.Valid to be true") }
if v2.Valid != false { t.Fatal("Expected v2.Valid to be true") }
ok = VerifyBatch([]*Verify{v1, v1, v1, v2})
if ok != false {
t.Fatal("Expected ok == false")
}
if v1.Valid != true {
t.Fatal("Expected v1.Valid to be true")
}
if v2.Valid != false {
t.Fatal("Expected v2.Valid to be true")
}
}

View File

@ -1,54 +1,60 @@
package db
import (
"fmt"
"github.com/syndtr/goleveldb/leveldb"
"path"
"fmt"
"github.com/syndtr/goleveldb/leveldb"
"path"
)
type LevelDB struct {
db *leveldb.DB
db *leveldb.DB
}
func NewLevelDB(name string) (*LevelDB, error) {
dbPath := path.Join(name)
db, err := leveldb.OpenFile(dbPath, nil)
if err != nil {
return nil, err
}
database := &LevelDB{db: db}
return database, nil
dbPath := path.Join(name)
db, err := leveldb.OpenFile(dbPath, nil)
if err != nil {
return nil, err
}
database := &LevelDB{db: db}
return database, nil
}
func (db *LevelDB) Put(key []byte, value []byte) {
err := db.db.Put(key, value, nil)
if err != nil { panic(err) }
err := db.db.Put(key, value, nil)
if err != nil {
panic(err)
}
}
func (db *LevelDB) Get(key []byte) ([]byte) {
res, err := db.db.Get(key, nil)
if err != nil { panic(err) }
return res
func (db *LevelDB) Get(key []byte) []byte {
res, err := db.db.Get(key, nil)
if err != nil {
panic(err)
}
return res
}
func (db *LevelDB) Delete(key []byte) {
err := db.db.Delete(key, nil)
if err != nil { panic(err) }
err := db.db.Delete(key, nil)
if err != nil {
panic(err)
}
}
func (db *LevelDB) Db() *leveldb.DB {
return db.db
return db.db
}
func (db *LevelDB) Close() {
db.db.Close()
db.db.Close()
}
func (db *LevelDB) Print() {
iter := db.db.NewIterator(nil, nil)
for iter.Next() {
key := iter.Key()
value := iter.Value()
fmt.Printf("[%x]:\t[%x]", key, value)
}
iter := db.db.NewIterator(nil, nil)
for iter.Next() {
key := iter.Key()
value := iter.Value()
fmt.Printf("[%x]:\t[%x]", key, value)
}
}

View File

@ -1,32 +1,32 @@
package db
import (
"fmt"
"fmt"
)
type MemDB struct {
db map[string][]byte
db map[string][]byte
}
func NewMemDB() (*MemDB) {
database := &MemDB{db:make(map[string][]byte)}
return database
func NewMemDB() *MemDB {
database := &MemDB{db: make(map[string][]byte)}
return database
}
func (db *MemDB) Put(key []byte, value []byte) {
db.db[string(key)] = value
db.db[string(key)] = value
}
func (db *MemDB) Get(key []byte) ([]byte) {
return db.db[string(key)]
func (db *MemDB) Get(key []byte) []byte {
return db.db[string(key)]
}
func (db *MemDB) Delete(key []byte) {
delete(db.db, string(key))
delete(db.db, string(key))
}
func (db *MemDB) Print() {
for key, value := range db.db {
fmt.Printf("[%x]:\t[%x]", []byte(key), value)
}
for key, value := range db.db {
fmt.Printf("[%x]:\t[%x]", []byte(key), value)
}
}

View File

@ -1,405 +1,447 @@
package merkle
import (
. "github.com/tendermint/tendermint/binary"
"bytes"
"io"
"crypto/sha256"
"bytes"
"crypto/sha256"
. "github.com/tendermint/tendermint/binary"
"io"
)
// Node
type IAVLNode struct {
key Key
value Value
size uint64
height uint8
hash ByteSlice
left *IAVLNode
right *IAVLNode
key Key
value Value
size uint64
height uint8
hash ByteSlice
left *IAVLNode
right *IAVLNode
// volatile
flags byte
// volatile
flags byte
}
const (
IAVLNODE_FLAG_PERSISTED = byte(0x01)
IAVLNODE_FLAG_PLACEHOLDER = byte(0x02)
IAVLNODE_FLAG_PERSISTED = byte(0x01)
IAVLNODE_FLAG_PLACEHOLDER = byte(0x02)
)
func NewIAVLNode(key Key, value Value) *IAVLNode {
return &IAVLNode{
key: key,
value: value,
size: 1,
}
return &IAVLNode{
key: key,
value: value,
size: 1,
}
}
func (self *IAVLNode) Copy() *IAVLNode {
if self.height == 0 {
panic("Why are you copying a value node?")
}
return &IAVLNode{
key: self.key,
size: self.size,
height: self.height,
left: self.left,
right: self.right,
hash: nil,
flags: byte(0),
}
if self.height == 0 {
panic("Why are you copying a value node?")
}
return &IAVLNode{
key: self.key,
size: self.size,
height: self.height,
left: self.left,
right: self.right,
hash: nil,
flags: byte(0),
}
}
func (self *IAVLNode) Key() Key {
return self.key
return self.key
}
func (self *IAVLNode) Value() Value {
return self.value
return self.value
}
func (self *IAVLNode) Size() uint64 {
return self.size
return self.size
}
func (self *IAVLNode) Height() uint8 {
return self.height
return self.height
}
func (self *IAVLNode) has(db Db, key Key) (has bool) {
if self.key.Equals(key) {
return true
}
if self.height == 0 {
return false
} else {
if key.Less(self.key) {
return self.leftFilled(db).has(db, key)
} else {
return self.rightFilled(db).has(db, key)
}
}
if self.key.Equals(key) {
return true
}
if self.height == 0 {
return false
} else {
if key.Less(self.key) {
return self.leftFilled(db).has(db, key)
} else {
return self.rightFilled(db).has(db, key)
}
}
}
func (self *IAVLNode) get(db Db, key Key) (value Value) {
if self.height == 0 {
if self.key.Equals(key) {
return self.value
} else {
return nil
}
} else {
if key.Less(self.key) {
return self.leftFilled(db).get(db, key)
} else {
return self.rightFilled(db).get(db, key)
}
}
if self.height == 0 {
if self.key.Equals(key) {
return self.value
} else {
return nil
}
} else {
if key.Less(self.key) {
return self.leftFilled(db).get(db, key)
} else {
return self.rightFilled(db).get(db, key)
}
}
}
func (self *IAVLNode) Hash() (ByteSlice, uint64) {
if self.hash != nil {
return self.hash, 0
}
if self.hash != nil {
return self.hash, 0
}
hasher := sha256.New()
_, hashCount, err := self.saveToCountHashes(hasher, false)
if err != nil { panic(err) }
self.hash = hasher.Sum(nil)
hasher := sha256.New()
_, hashCount, err := self.saveToCountHashes(hasher, false)
if err != nil {
panic(err)
}
self.hash = hasher.Sum(nil)
return self.hash, hashCount+1
return self.hash, hashCount + 1
}
func (self *IAVLNode) Save(db Db) {
if self.hash == nil {
panic("savee.hash can't be nil")
}
if self.flags & IAVLNODE_FLAG_PERSISTED > 0 ||
self.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 {
return
}
if self.hash == nil {
panic("savee.hash can't be nil")
}
if self.flags&IAVLNODE_FLAG_PERSISTED > 0 ||
self.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
return
}
// children
if self.height > 0 {
self.left.Save(db)
self.right.Save(db)
}
// children
if self.height > 0 {
self.left.Save(db)
self.right.Save(db)
}
// save self
buf := bytes.NewBuffer(nil)
_, err := self.WriteTo(buf)
if err != nil { panic(err) }
db.Put([]byte(self.hash), buf.Bytes())
// save self
buf := bytes.NewBuffer(nil)
_, err := self.WriteTo(buf)
if err != nil {
panic(err)
}
db.Put([]byte(self.hash), buf.Bytes())
self.flags |= IAVLNODE_FLAG_PERSISTED
self.flags |= IAVLNODE_FLAG_PERSISTED
}
func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) {
if self.height == 0 {
if key.Less(self.key) {
return &IAVLNode{
key: self.key,
height: 1,
size: 2,
left: NewIAVLNode(key, value),
right: self,
}, false
} else if self.key.Equals(key) {
return NewIAVLNode(key, value), true
} else {
return &IAVLNode{
key: key,
height: 1,
size: 2,
left: self,
right: NewIAVLNode(key, value),
}, false
}
} else {
self = self.Copy()
if key.Less(self.key) {
self.left, updated = self.leftFilled(db).put(db, key, value)
} else {
self.right, updated = self.rightFilled(db).put(db, key, value)
}
if updated {
return self, updated
} else {
self.calcHeightAndSize(db)
return self.balance(db), updated
}
}
if self.height == 0 {
if key.Less(self.key) {
return &IAVLNode{
key: self.key,
height: 1,
size: 2,
left: NewIAVLNode(key, value),
right: self,
}, false
} else if self.key.Equals(key) {
return NewIAVLNode(key, value), true
} else {
return &IAVLNode{
key: key,
height: 1,
size: 2,
left: self,
right: NewIAVLNode(key, value),
}, false
}
} else {
self = self.Copy()
if key.Less(self.key) {
self.left, updated = self.leftFilled(db).put(db, key, value)
} else {
self.right, updated = self.rightFilled(db).put(db, key, value)
}
if updated {
return self, updated
} else {
self.calcHeightAndSize(db)
return self.balance(db), updated
}
}
}
// newKey: new leftmost leaf key for tree after successfully removing 'key' if changed.
func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) {
if self.height == 0 {
if self.key.Equals(key) {
return nil, nil, self.value, nil
} else {
return self, nil, nil, NotFound(key)
}
} else {
if key.Less(self.key) {
var newLeft *IAVLNode
newLeft, newKey, value, err = self.leftFilled(db).remove(db, key)
if err != nil {
return self, nil, value, err
} else if newLeft == nil { // left node held value, was removed
return self.right, self.key, value, nil
}
self = self.Copy()
self.left = newLeft
} else {
var newRight *IAVLNode
newRight, newKey, value, err = self.rightFilled(db).remove(db, key)
if err != nil {
return self, nil, value, err
} else if newRight == nil { // right node held value, was removed
return self.left, nil, value, nil
}
self = self.Copy()
self.right = newRight
if newKey != nil {
self.key = newKey
newKey = nil
}
}
self.calcHeightAndSize(db)
return self.balance(db), newKey, value, err
}
if self.height == 0 {
if self.key.Equals(key) {
return nil, nil, self.value, nil
} else {
return self, nil, nil, NotFound(key)
}
} else {
if key.Less(self.key) {
var newLeft *IAVLNode
newLeft, newKey, value, err = self.leftFilled(db).remove(db, key)
if err != nil {
return self, nil, value, err
} else if newLeft == nil { // left node held value, was removed
return self.right, self.key, value, nil
}
self = self.Copy()
self.left = newLeft
} else {
var newRight *IAVLNode
newRight, newKey, value, err = self.rightFilled(db).remove(db, key)
if err != nil {
return self, nil, value, err
} else if newRight == nil { // right node held value, was removed
return self.left, nil, value, nil
}
self = self.Copy()
self.right = newRight
if newKey != nil {
self.key = newKey
newKey = nil
}
}
self.calcHeightAndSize(db)
return self.balance(db), newKey, value, err
}
}
func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) {
n, _, err = self.saveToCountHashes(w, true)
return
n, _, err = self.saveToCountHashes(w, true)
return
}
func (self *IAVLNode) saveToCountHashes(w io.Writer, meta bool) (n int64, hashCount uint64, err error) {
var _n int64
var _n int64
if meta {
// height & size
_n, err = UInt8(self.height).WriteTo(w)
if err != nil { return } else { n += _n }
_n, err = UInt64(self.size).WriteTo(w)
if err != nil { return } else { n += _n }
if meta {
// height & size
_n, err = UInt8(self.height).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
_n, err = UInt64(self.size).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
// key
_n, err = Byte(GetBinaryType(self.key)).WriteTo(w)
if err != nil { return } else { n += _n }
_n, err = self.key.WriteTo(w)
if err != nil { return } else { n += _n }
}
// key
_n, err = Byte(GetBinaryType(self.key)).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
_n, err = self.key.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
}
if self.height == 0 {
// value
_n, err = Byte(GetBinaryType(self.value)).WriteTo(w)
if err != nil { return } else { n += _n }
if self.value != nil {
_n, err = self.value.WriteTo(w)
if err != nil { return } else { n += _n }
}
} else {
// left
leftHash, leftCount := self.left.Hash()
hashCount += leftCount
_n, err = leftHash.WriteTo(w)
if err != nil { return } else { n += _n }
// right
rightHash, rightCount := self.right.Hash()
hashCount += rightCount
_n, err = rightHash.WriteTo(w)
if err != nil { return } else { n += _n }
}
if self.height == 0 {
// value
_n, err = Byte(GetBinaryType(self.value)).WriteTo(w)
if err != nil {
return
} else {
n += _n
}
if self.value != nil {
_n, err = self.value.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
}
} else {
// left
leftHash, leftCount := self.left.Hash()
hashCount += leftCount
_n, err = leftHash.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
// right
rightHash, rightCount := self.right.Hash()
hashCount += rightCount
_n, err = rightHash.WriteTo(w)
if err != nil {
return
} else {
n += _n
}
}
return
return
}
// Given a placeholder node which has only the hash set,
// load the rest of the data from db.
// Not threadsafe.
func (self *IAVLNode) fill(db Db) {
if self.hash == nil {
panic("placeholder.hash can't be nil")
}
buf := db.Get(self.hash)
r := bytes.NewReader(buf)
// node header
self.height = uint8(ReadUInt8(r))
self.size = uint64(ReadUInt64(r))
// key
key := ReadBinary(r)
self.key = key.(Key)
if self.hash == nil {
panic("placeholder.hash can't be nil")
}
buf := db.Get(self.hash)
r := bytes.NewReader(buf)
// node header
self.height = uint8(ReadUInt8(r))
self.size = uint64(ReadUInt64(r))
// key
key := ReadBinary(r)
self.key = key.(Key)
if self.height == 0 {
// value
self.value = ReadBinary(r)
} else {
// left
var leftHash ByteSlice
leftHash = ReadByteSlice(r)
self.left = &IAVLNode{
hash: leftHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
}
// right
var rightHash ByteSlice
rightHash = ReadByteSlice(r)
self.right = &IAVLNode{
hash: rightHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
}
if r.Len() != 0 {
panic("buf not all consumed")
}
}
self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER
if self.height == 0 {
// value
self.value = ReadBinary(r)
} else {
// left
var leftHash ByteSlice
leftHash = ReadByteSlice(r)
self.left = &IAVLNode{
hash: leftHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
}
// right
var rightHash ByteSlice
rightHash = ReadByteSlice(r)
self.right = &IAVLNode{
hash: rightHash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
}
if r.Len() != 0 {
panic("buf not all consumed")
}
}
self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER
}
func (self *IAVLNode) leftFilled(db Db) *IAVLNode {
if self.left.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 {
self.left.fill(db)
}
return self.left
if self.left.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
self.left.fill(db)
}
return self.left
}
func (self *IAVLNode) rightFilled(db Db) *IAVLNode {
if self.right.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 {
self.right.fill(db)
}
return self.right
if self.right.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
self.right.fill(db)
}
return self.right
}
func (self *IAVLNode) rotateRight(db Db) *IAVLNode {
self = self.Copy()
sl := self.leftFilled(db).Copy()
slr := sl.right
self = self.Copy()
sl := self.leftFilled(db).Copy()
slr := sl.right
sl.right = self
self.left = slr
sl.right = self
self.left = slr
self.calcHeightAndSize(db)
sl.calcHeightAndSize(db)
self.calcHeightAndSize(db)
sl.calcHeightAndSize(db)
return sl
return sl
}
func (self *IAVLNode) rotateLeft(db Db) *IAVLNode {
self = self.Copy()
sr := self.rightFilled(db).Copy()
srl := sr.left
self = self.Copy()
sr := self.rightFilled(db).Copy()
srl := sr.left
sr.left = self
self.right = srl
sr.left = self
self.right = srl
self.calcHeightAndSize(db)
sr.calcHeightAndSize(db)
self.calcHeightAndSize(db)
sr.calcHeightAndSize(db)
return sr
return sr
}
func (self *IAVLNode) calcHeightAndSize(db Db) {
self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1
self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size()
self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1
self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size()
}
func (self *IAVLNode) calcBalance(db Db) int {
return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height())
return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height())
}
func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) {
balance := self.calcBalance(db)
if (balance > 1) {
if (self.leftFilled(db).calcBalance(db) >= 0) {
// Left Left Case
return self.rotateRight(db)
} else {
// Left Right Case
self = self.Copy()
self.left = self.leftFilled(db).rotateLeft(db)
//self.calcHeightAndSize()
return self.rotateRight(db)
}
}
if (balance < -1) {
if (self.rightFilled(db).calcBalance(db) <= 0) {
// Right Right Case
return self.rotateLeft(db)
} else {
// Right Left Case
self = self.Copy()
self.right = self.rightFilled(db).rotateRight(db)
//self.calcHeightAndSize()
return self.rotateLeft(db)
}
}
// Nothing changed
return self
balance := self.calcBalance(db)
if balance > 1 {
if self.leftFilled(db).calcBalance(db) >= 0 {
// Left Left Case
return self.rotateRight(db)
} else {
// Left Right Case
self = self.Copy()
self.left = self.leftFilled(db).rotateLeft(db)
//self.calcHeightAndSize()
return self.rotateRight(db)
}
}
if balance < -1 {
if self.rightFilled(db).calcBalance(db) <= 0 {
// Right Right Case
return self.rotateLeft(db)
} else {
// Right Left Case
self = self.Copy()
self.right = self.rightFilled(db).rotateRight(db)
//self.calcHeightAndSize()
return self.rotateLeft(db)
}
}
// Nothing changed
return self
}
func (self *IAVLNode) lmd(db Db) (*IAVLNode) {
if self.height == 0 {
return self
}
return self.leftFilled(db).lmd(db)
func (self *IAVLNode) lmd(db Db) *IAVLNode {
if self.height == 0 {
return self
}
return self.leftFilled(db).lmd(db)
}
func (self *IAVLNode) rmd(db Db) (*IAVLNode) {
if self.height == 0 {
return self
}
return self.rightFilled(db).rmd(db)
func (self *IAVLNode) rmd(db Db) *IAVLNode {
if self.height == 0 {
return self
}
return self.rightFilled(db).rmd(db)
}
func (self *IAVLNode) traverse(db Db, cb func(Node)bool) bool {
stop := cb(self)
if stop { return stop }
if self.height > 0 {
stop = self.leftFilled(db).traverse(db, cb)
if stop { return stop }
stop = self.rightFilled(db).traverse(db, cb)
if stop { return stop }
}
return false
func (self *IAVLNode) traverse(db Db, cb func(Node) bool) bool {
stop := cb(self)
if stop {
return stop
}
if self.height > 0 {
stop = self.leftFilled(db).traverse(db, cb)
if stop {
return stop
}
stop = self.rightFilled(db).traverse(db, cb)
if stop {
return stop
}
}
return false
}

View File

@ -1,283 +1,283 @@
package merkle
import (
. "github.com/tendermint/tendermint/binary"
"testing"
"fmt"
"os"
"bytes"
"math/rand"
"encoding/binary"
"github.com/tendermint/tendermint/db"
"crypto/sha256"
"runtime"
"bytes"
"crypto/sha256"
"encoding/binary"
"fmt"
. "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/db"
"math/rand"
"os"
"runtime"
"testing"
)
func init() {
if urandom, err := os.Open("/dev/urandom"); err != nil {
return
} else {
buf := make([]byte, 8)
if _, err := urandom.Read(buf); err == nil {
buf_reader := bytes.NewReader(buf)
if seed, err := binary.ReadVarint(buf_reader); err == nil {
rand.Seed(seed)
}
}
urandom.Close()
}
if urandom, err := os.Open("/dev/urandom"); err != nil {
return
} else {
buf := make([]byte, 8)
if _, err := urandom.Read(buf); err == nil {
buf_reader := bytes.NewReader(buf)
if seed, err := binary.ReadVarint(buf_reader); err == nil {
rand.Seed(seed)
}
}
urandom.Close()
}
}
func TestUnit(t *testing.T) {
// Convenience for a new node
N := func(l, r interface{}) *IAVLNode {
var left, right *IAVLNode
if _, ok := l.(*IAVLNode); ok {
left = l.(*IAVLNode)
} else {
left = NewIAVLNode(Int32(l.(int)), nil)
}
if _, ok := r.(*IAVLNode); ok {
right = r.(*IAVLNode)
} else {
right = NewIAVLNode(Int32(r.(int)), nil)
}
// Convenience for a new node
N := func(l, r interface{}) *IAVLNode {
var left, right *IAVLNode
if _, ok := l.(*IAVLNode); ok {
left = l.(*IAVLNode)
} else {
left = NewIAVLNode(Int32(l.(int)), nil)
}
if _, ok := r.(*IAVLNode); ok {
right = r.(*IAVLNode)
} else {
right = NewIAVLNode(Int32(r.(int)), nil)
}
n := &IAVLNode{
key: right.lmd(nil).key,
left: left,
right: right,
}
n.calcHeightAndSize(nil)
n.Hash()
return n
}
n := &IAVLNode{
key: right.lmd(nil).key,
left: left,
right: right,
}
n.calcHeightAndSize(nil)
n.Hash()
return n
}
// Convenience for simple printing of keys & tree structure
var P func(*IAVLNode) string
P = func(n *IAVLNode) string {
if n.height == 0 {
return fmt.Sprintf("%v", n.key)
} else {
return fmt.Sprintf("(%v %v)", P(n.left), P(n.right))
}
}
// Convenience for simple printing of keys & tree structure
var P func(*IAVLNode) string
P = func(n *IAVLNode) string {
if n.height == 0 {
return fmt.Sprintf("%v", n.key)
} else {
return fmt.Sprintf("(%v %v)", P(n.left), P(n.right))
}
}
expectHash := func(n2 *IAVLNode, hashCount uint64) {
// ensure number of new hash calculations is as expected.
hash, count := n2.Hash()
if count != hashCount {
t.Fatalf("Expected %v new hashes, got %v", hashCount, count)
}
// nuke hashes and reconstruct hash, ensure it's the same.
(&IAVLTree{root:n2}).Traverse(func(node Node) bool {
node.(*IAVLNode).hash = nil
return false
})
// ensure that the new hash after nuking is the same as the old.
newHash, _ := n2.Hash()
if bytes.Compare(hash, newHash) != 0 {
t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash)
}
}
expectHash := func(n2 *IAVLNode, hashCount uint64) {
// ensure number of new hash calculations is as expected.
hash, count := n2.Hash()
if count != hashCount {
t.Fatalf("Expected %v new hashes, got %v", hashCount, count)
}
// nuke hashes and reconstruct hash, ensure it's the same.
(&IAVLTree{root: n2}).Traverse(func(node Node) bool {
node.(*IAVLNode).hash = nil
return false
})
// ensure that the new hash after nuking is the same as the old.
newHash, _ := n2.Hash()
if bytes.Compare(hash, newHash) != 0 {
t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash)
}
}
expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, updated := n.put(nil, Int32(i), nil)
// ensure node was added & structure is as expected.
if updated == true || P(n2) != repr {
t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v",
i, P(n), repr, P(n2), updated)
}
// ensure hash calculation requirements
expectHash(n2, hashCount)
}
expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, updated := n.put(nil, Int32(i), nil)
// ensure node was added & structure is as expected.
if updated == true || P(n2) != repr {
t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v",
i, P(n), repr, P(n2), updated)
}
// ensure hash calculation requirements
expectHash(n2, hashCount)
}
expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, _, value, err := n.remove(nil, Int32(i))
// ensure node was added & structure is as expected.
if value != nil || err != nil || P(n2) != repr {
t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v",
i, P(n), repr, P(n2), value, err)
}
// ensure hash calculation requirements
expectHash(n2, hashCount)
}
expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) {
n2, _, value, err := n.remove(nil, Int32(i))
// ensure node was added & structure is as expected.
if value != nil || err != nil || P(n2) != repr {
t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v",
i, P(n), repr, P(n2), value, err)
}
// ensure hash calculation requirements
expectHash(n2, hashCount)
}
//////// Test Put cases:
//////// Test Put cases:
// Case 1:
n1 := N(4, 20)
// Case 1:
n1 := N(4, 20)
expectPut(n1, 8, "((4 8) 20)", 3)
expectPut(n1, 25, "(4 (20 25))", 3)
expectPut(n1, 8, "((4 8) 20)", 3)
expectPut(n1, 25, "(4 (20 25))", 3)
n2 := N(4, N(20, 25))
n2 := N(4, N(20, 25))
expectPut(n2, 8, "((4 8) (20 25))", 3)
expectPut(n2, 30, "((4 20) (25 30))", 4)
expectPut(n2, 8, "((4 8) (20 25))", 3)
expectPut(n2, 30, "((4 20) (25 30))", 4)
n3 := N(N(1, 2), 6)
n3 := N(N(1, 2), 6)
expectPut(n3, 4, "((1 2) (4 6))", 4)
expectPut(n3, 8, "((1 2) (6 8))", 3)
expectPut(n3, 4, "((1 2) (4 6))", 4)
expectPut(n3, 8, "((1 2) (6 8))", 3)
n4 := N(N(1, 2), N(N(5, 6), N(7, 9)))
n4 := N(N(1, 2), N(N(5, 6), N(7, 9)))
expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5)
expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5)
expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5)
expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5)
//////// Test Remove cases:
//////// Test Remove cases:
n10 := N(N(1, 2), 3)
n10 := N(N(1, 2), 3)
expectRemove(n10, 2, "(1 3)", 1)
expectRemove(n10, 3, "(1 2)", 0)
expectRemove(n10, 2, "(1 3)", 1)
expectRemove(n10, 3, "(1 2)", 0)
n11 := N(N(N(1, 2), 3), N(4, 5))
n11 := N(N(N(1, 2), 3), N(4, 5))
expectRemove(n11, 4, "((1 2) (3 5))", 2)
expectRemove(n11, 3, "((1 2) (4 5))", 1)
expectRemove(n11, 4, "((1 2) (3 5))", 2)
expectRemove(n11, 3, "((1 2) (4 5))", 1)
}
func TestIntegration(t *testing.T) {
type record struct {
key String
value String
}
type record struct {
key String
value String
}
records := make([]*record, 400)
var tree *IAVLTree = NewIAVLTree(nil)
var err error
var val Value
var updated bool
records := make([]*record, 400)
var tree *IAVLTree = NewIAVLTree(nil)
var err error
var val Value
var updated bool
randomRecord := func() *record {
return &record{ randstr(20), randstr(20) }
}
randomRecord := func() *record {
return &record{randstr(20), randstr(20)}
}
for i := range records {
r := randomRecord()
records[i] = r
//t.Log("New record", r)
//PrintIAVLNode(tree.root)
updated = tree.Put(r.key, String(""))
if updated {
t.Error("should have not been updated")
}
updated = tree.Put(r.key, r.value)
if !updated {
t.Error("should have been updated")
}
if tree.Size() != uint64(i+1) {
t.Error("size was wrong", tree.Size(), i+1)
}
}
for i := range records {
r := randomRecord()
records[i] = r
//t.Log("New record", r)
//PrintIAVLNode(tree.root)
updated = tree.Put(r.key, String(""))
if updated {
t.Error("should have not been updated")
}
updated = tree.Put(r.key, r.value)
if !updated {
t.Error("should have been updated")
}
if tree.Size() != uint64(i+1) {
t.Error("size was wrong", tree.Size(), i+1)
}
}
for _, r := range records {
if has := tree.Has(r.key); !has {
t.Error("Missing key", r.key)
}
if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key")
}
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
for _, r := range records {
if has := tree.Has(r.key); !has {
t.Error("Missing key", r.key)
}
if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key")
}
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
for i, x := range records {
if val, err = tree.Remove(x.key); err != nil {
t.Error(err)
} else if !(val.(String)).Equals(x.value) {
t.Error("wrong value")
}
for _, r := range records[i+1:] {
if has := tree.Has(r.key); !has {
t.Error("Missing key", r.key)
}
if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key")
}
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
if tree.Size() != uint64(len(records) - (i+1)) {
t.Error("size was wrong", tree.Size(), (len(records) - (i+1)))
}
}
for i, x := range records {
if val, err = tree.Remove(x.key); err != nil {
t.Error(err)
} else if !(val.(String)).Equals(x.value) {
t.Error("wrong value")
}
for _, r := range records[i+1:] {
if has := tree.Has(r.key); !has {
t.Error("Missing key", r.key)
}
if has := tree.Has(randstr(12)); has {
t.Error("Table has extra key")
}
if val := tree.Get(r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
if tree.Size() != uint64(len(records)-(i+1)) {
t.Error("size was wrong", tree.Size(), (len(records) - (i + 1)))
}
}
}
func TestPersistence(t *testing.T) {
db := db.NewMemDB()
db := db.NewMemDB()
// Create some random key value pairs
records := make(map[String]String)
for i:=0; i<10000; i++ {
records[String(randstr(20))] = String(randstr(20))
}
// Create some random key value pairs
records := make(map[String]String)
for i := 0; i < 10000; i++ {
records[String(randstr(20))] = String(randstr(20))
}
// Construct some tree and save it
t1 := NewIAVLTree(db)
for key, value := range records {
t1.Put(key, value)
}
t1.Save()
// Construct some tree and save it
t1 := NewIAVLTree(db)
for key, value := range records {
t1.Put(key, value)
}
t1.Save()
hash, _ := t1.Hash()
hash, _ := t1.Hash()
// Load a tree
t2 := NewIAVLTreeFromHash(db, hash)
for key, value := range records {
t2value := t2.Get(key)
if !BinaryEqual(t2value, value) {
t.Fatalf("Invalid value. Expected %v, got %v", value, t2value)
}
}
// Load a tree
t2 := NewIAVLTreeFromHash(db, hash)
for key, value := range records {
t2value := t2.Get(key)
if !BinaryEqual(t2value, value) {
t.Fatalf("Invalid value. Expected %v, got %v", value, t2value)
}
}
}
func BenchmarkHash(b *testing.B) {
b.StopTimer()
b.StopTimer()
s := randstr(128)
s := randstr(128)
b.StartTimer()
for i := 0; i < b.N; i++ {
hasher := sha256.New()
hasher.Write([]byte(s))
hasher.Sum(nil)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
hasher := sha256.New()
hasher.Write([]byte(s))
hasher.Sum(nil)
}
}
func BenchmarkImmutableAvlTree(b *testing.B) {
b.StopTimer()
b.StopTimer()
type record struct {
key String
value String
}
type record struct {
key String
value String
}
randomRecord := func() *record {
return &record{ randstr(32), randstr(32) }
}
randomRecord := func() *record {
return &record{randstr(32), randstr(32)}
}
t := NewIAVLTree(nil)
for i:=0; i<1000000; i++ {
r := randomRecord()
t.Put(r.key, r.value)
}
t := NewIAVLTree(nil)
for i := 0; i < 1000000; i++ {
r := randomRecord()
t.Put(r.key, r.value)
}
fmt.Println("ok, starting")
fmt.Println("ok, starting")
runtime.GC()
runtime.GC()
b.StartTimer()
for i := 0; i < b.N; i++ {
r := randomRecord()
t.Put(r.key, r.value)
t.Remove(r.key)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
r := randomRecord()
t.Put(r.key, r.value)
t.Remove(r.key)
}
}

View File

@ -1,10 +1,10 @@
package merkle
import (
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/binary"
)
const HASH_BYTE_SIZE int = 4+32
const HASH_BYTE_SIZE int = 4 + 32
/*
Immutable AVL Tree (wraps the Node root)
@ -13,100 +13,118 @@ This tree is not concurrency safe.
You must wrap your calls with your own mutex.
*/
type IAVLTree struct {
db Db
root *IAVLNode
db Db
root *IAVLNode
}
func NewIAVLTree(db Db) *IAVLTree {
return &IAVLTree{db:db, root:nil}
return &IAVLTree{db: db, root: nil}
}
func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree {
root := &IAVLNode{
hash: hash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
}
root.fill(db)
return &IAVLTree{db:db, root:root}
root := &IAVLNode{
hash: hash,
flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
}
root.fill(db)
return &IAVLTree{db: db, root: root}
}
func (t *IAVLTree) Root() Node {
return t.root
return t.root
}
func (t *IAVLTree) Size() uint64 {
if t.root == nil { return 0 }
return t.root.Size()
if t.root == nil {
return 0
}
return t.root.Size()
}
func (t *IAVLTree) Height() uint8 {
if t.root == nil { return 0 }
return t.root.Height()
if t.root == nil {
return 0
}
return t.root.Height()
}
func (t *IAVLTree) Has(key Key) bool {
if t.root == nil { return false }
return t.root.has(t.db, key)
if t.root == nil {
return false
}
return t.root.has(t.db, key)
}
func (t *IAVLTree) Put(key Key, value Value) (updated bool) {
if t.root == nil {
t.root = NewIAVLNode(key, value)
return false
}
t.root, updated = t.root.put(t.db, key, value)
return updated
if t.root == nil {
t.root = NewIAVLNode(key, value)
return false
}
t.root, updated = t.root.put(t.db, key, value)
return updated
}
func (t *IAVLTree) Hash() (ByteSlice, uint64) {
if t.root == nil { return nil, 0 }
return t.root.Hash()
if t.root == nil {
return nil, 0
}
return t.root.Hash()
}
func (t *IAVLTree) Save() {
if t.root == nil { return }
if t.root.hash == nil {
t.root.Hash()
}
t.root.Save(t.db)
if t.root == nil {
return
}
if t.root.hash == nil {
t.root.Hash()
}
t.root.Save(t.db)
}
func (t *IAVLTree) Get(key Key) (value Value) {
if t.root == nil { return nil }
return t.root.get(t.db, key)
if t.root == nil {
return nil
}
return t.root.get(t.db, key)
}
func (t *IAVLTree) Remove(key Key) (value Value, err error) {
if t.root == nil { return nil, NotFound(key) }
newRoot, _, value, err := t.root.remove(t.db, key)
if err != nil {
return nil, err
}
t.root = newRoot
return value, nil
if t.root == nil {
return nil, NotFound(key)
}
newRoot, _, value, err := t.root.remove(t.db, key)
if err != nil {
return nil, err
}
t.root = newRoot
return value, nil
}
func (t *IAVLTree) Copy() Tree {
return &IAVLTree{db:t.db, root:t.root}
return &IAVLTree{db: t.db, root: t.root}
}
// Traverses all the nodes of the tree in prefix order.
// return true from cb to halt iteration.
// node.Height() == 0 if you just want a value node.
func (t *IAVLTree) Traverse(cb func(Node) bool) {
if t.root == nil { return }
t.root.traverse(t.db, cb)
if t.root == nil {
return
}
t.root.traverse(t.db, cb)
}
func (t *IAVLTree) Values() <-chan Value {
root := t.root
ch := make(chan Value)
go func() {
root.traverse(t.db, func(n Node) bool {
if n.Height() == 0 { ch <- n.Value() }
return true
})
close(ch)
}()
return ch
root := t.root
ch := make(chan Value)
go func() {
root.traverse(t.db, func(n Node) bool {
if n.Height() == 0 {
ch <- n.Value()
}
return true
})
close(ch)
}()
return ch
}

View File

@ -1,50 +1,50 @@
package merkle
import (
. "github.com/tendermint/tendermint/binary"
"fmt"
"fmt"
. "github.com/tendermint/tendermint/binary"
)
type Value interface {
Binary
Binary
}
type Key interface {
Binary
Equals(Binary) bool
Less(b Binary) bool
Binary
Equals(Binary) bool
Less(b Binary) bool
}
type Db interface {
Get([]byte) []byte
Put([]byte, []byte)
Get([]byte) []byte
Put([]byte, []byte)
}
type Node interface {
Binary
Key() Key
Value() Value
Size() uint64
Height() uint8
Hash() (ByteSlice, uint64)
Save(Db)
Binary
Key() Key
Value() Value
Size() uint64
Height() uint8
Hash() (ByteSlice, uint64)
Save(Db)
}
type Tree interface {
Root() Node
Size() uint64
Height() uint8
Has(key Key) bool
Get(key Key) Value
Hash() (ByteSlice, uint64)
Save()
Put(Key, Value) bool
Remove(Key) (Value, error)
Copy() Tree
Traverse(func(Node)bool)
Values() <-chan Value
Root() Node
Size() uint64
Height() uint8
Has(key Key) bool
Get(key Key) Value
Hash() (ByteSlice, uint64)
Save()
Put(Key, Value) bool
Remove(Key) (Value, error)
Copy() Tree
Traverse(func(Node) bool)
Values() <-chan Value
}
func NotFound(key Key) error {
return fmt.Errorf("Key was not found.")
return fmt.Errorf("Key was not found.")
}

View File

@ -1,78 +1,83 @@
package merkle
import (
. "github.com/tendermint/tendermint/binary"
"os"
"fmt"
"crypto/sha256"
"crypto/sha256"
"fmt"
. "github.com/tendermint/tendermint/binary"
"os"
)
/*
Compute a deterministic merkle hash from a list of byteslices.
*/
func HashFromBinarySlice(items []Binary) ByteSlice {
switch len(items) {
case 0:
panic("Cannot compute hash of empty slice")
case 1:
hasher := sha256.New()
_, err := items[0].WriteTo(hasher)
if err != nil { panic(err) }
return ByteSlice(hasher.Sum(nil))
default:
hasher := sha256.New()
_, err := HashFromBinarySlice(items[0:len(items)/2]).WriteTo(hasher)
if err != nil { panic(err) }
_, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher)
if err != nil { panic(err) }
return ByteSlice(hasher.Sum(nil))
}
switch len(items) {
case 0:
panic("Cannot compute hash of empty slice")
case 1:
hasher := sha256.New()
_, err := items[0].WriteTo(hasher)
if err != nil {
panic(err)
}
return ByteSlice(hasher.Sum(nil))
default:
hasher := sha256.New()
_, err := HashFromBinarySlice(items[0 : len(items)/2]).WriteTo(hasher)
if err != nil {
panic(err)
}
_, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher)
if err != nil {
panic(err)
}
return ByteSlice(hasher.Sum(nil))
}
}
func PrintIAVLNode(node *IAVLNode) {
fmt.Println("==== NODE")
if node != nil {
printIAVLNode(node, 0)
}
fmt.Println("==== END")
fmt.Println("==== NODE")
if node != nil {
printIAVLNode(node, 0)
}
fmt.Println("==== END")
}
func printIAVLNode(node *IAVLNode, indent int) {
indentPrefix := ""
for i:=0; i<indent; i++ {
indentPrefix += " "
}
indentPrefix := ""
for i := 0; i < indent; i++ {
indentPrefix += " "
}
if node.right != nil {
printIAVLNode(node.rightFilled(nil), indent+1)
}
if node.right != nil {
printIAVLNode(node.rightFilled(nil), indent+1)
}
fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height)
fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height)
if node.left != nil {
printIAVLNode(node.leftFilled(nil), indent+1)
}
if node.left != nil {
printIAVLNode(node.leftFilled(nil), indent+1)
}
}
func randstr(length int) String {
if urandom, err := os.Open("/dev/urandom"); err != nil {
panic(err)
} else {
slice := make([]byte, length)
if _, err := urandom.Read(slice); err != nil {
panic(err)
}
urandom.Close()
return String(slice)
}
panic("unreachable")
if urandom, err := os.Open("/dev/urandom"); err != nil {
panic(err)
} else {
slice := make([]byte, length)
if _, err := urandom.Read(slice); err != nil {
panic(err)
}
urandom.Close()
return String(slice)
}
panic("unreachable")
}
func maxUint8(a, b uint8) uint8 {
if a > b {
return a
}
return b
if a > b {
return a
}
return b
}

View File

@ -5,217 +5,236 @@
package peer
import (
. "github.com/tendermint/tendermint/binary"
crand "crypto/rand" // for seeding
"encoding/binary"
"encoding/json"
"io"
"math"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"os"
"fmt"
crand "crypto/rand" // for seeding
"encoding/binary"
"encoding/json"
"fmt"
. "github.com/tendermint/tendermint/binary"
"io"
"math"
"math/rand"
"net"
"os"
"sync"
"sync/atomic"
"time"
)
/* AddrBook - concurrency safe peer address manager */
type AddrBook struct {
filePath string
filePath string
mtx sync.Mutex
rand *rand.Rand
key [32]byte
addrIndex map[string]*KnownAddress // addr.String() -> KnownAddress
addrNew [newBucketCount]map[string]*KnownAddress
addrOld [oldBucketCount][]*KnownAddress
started int32
shutdown int32
wg sync.WaitGroup
quit chan struct{}
nOld int
nNew int
mtx sync.Mutex
rand *rand.Rand
key [32]byte
addrIndex map[string]*KnownAddress // addr.String() -> KnownAddress
addrNew [newBucketCount]map[string]*KnownAddress
addrOld [oldBucketCount][]*KnownAddress
started int32
shutdown int32
wg sync.WaitGroup
quit chan struct{}
nOld int
nNew int
}
const (
// addresses under which the address manager will claim to need more addresses.
needAddressThreshold = 1000
// addresses under which the address manager will claim to need more addresses.
needAddressThreshold = 1000
// interval used to dump the address cache to disk for future use.
dumpAddressInterval = time.Minute * 2
// interval used to dump the address cache to disk for future use.
dumpAddressInterval = time.Minute * 2
// max addresses in each old address bucket.
oldBucketSize = 64
// max addresses in each old address bucket.
oldBucketSize = 64
// buckets we split old addresses over.
oldBucketCount = 64
// buckets we split old addresses over.
oldBucketCount = 64
// max addresses in each new address bucket.
newBucketSize = 64
// max addresses in each new address bucket.
newBucketSize = 64
// buckets that we spread new addresses over.
newBucketCount = 256
// buckets that we spread new addresses over.
newBucketCount = 256
// old buckets over which an address group will be spread.
oldBucketsPerGroup = 4
// old buckets over which an address group will be spread.
oldBucketsPerGroup = 4
// new buckets over which an source address group will be spread.
newBucketsPerGroup = 32
// new buckets over which an source address group will be spread.
newBucketsPerGroup = 32
// buckets a frequently seen new address may end up in.
newBucketsPerAddress = 4
// buckets a frequently seen new address may end up in.
newBucketsPerAddress = 4
// days before which we assume an address has vanished
// if we have not seen it announced in that long.
numMissingDays = 30
// days before which we assume an address has vanished
// if we have not seen it announced in that long.
numMissingDays = 30
// tries without a single success before we assume an address is bad.
numRetries = 3
// tries without a single success before we assume an address is bad.
numRetries = 3
// max failures we will accept without a success before considering an address bad.
maxFailures = 10
// max failures we will accept without a success before considering an address bad.
maxFailures = 10
// days since the last success before we will consider evicting an address.
minBadDays = 7
// days since the last success before we will consider evicting an address.
minBadDays = 7
// max addresses that we will send in response to a getAddr
// (in practise the most addresses we will return from a call to AddressCache()).
getAddrMax = 2500
// max addresses that we will send in response to a getAddr
// (in practise the most addresses we will return from a call to AddressCache()).
getAddrMax = 2500
// % of total addresses known that we will share with a call to AddressCache.
getAddrPercent = 23
// % of total addresses known that we will share with a call to AddressCache.
getAddrPercent = 23
// current version of the on-disk format.
serialisationVersion = 1
// current version of the on-disk format.
serialisationVersion = 1
)
// Use Start to begin processing asynchronous address updates.
func NewAddrBook(filePath string) *AddrBook {
am := AddrBook{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
quit: make(chan struct{}),
filePath: filePath,
}
am.init()
return &am
am := AddrBook{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
quit: make(chan struct{}),
filePath: filePath,
}
am.init()
return &am
}
// When modifying this, don't forget to update loadFromFile()
func (a *AddrBook) init() {
a.addrIndex = make(map[string]*KnownAddress)
io.ReadFull(crand.Reader, a.key[:])
for i := range a.addrNew {
a.addrNew[i] = make(map[string]*KnownAddress)
}
for i := range a.addrOld {
a.addrOld[i] = make([]*KnownAddress, 0, oldBucketSize)
}
a.addrIndex = make(map[string]*KnownAddress)
io.ReadFull(crand.Reader, a.key[:])
for i := range a.addrNew {
a.addrNew[i] = make(map[string]*KnownAddress)
}
for i := range a.addrOld {
a.addrOld[i] = make([]*KnownAddress, 0, oldBucketSize)
}
}
func (a *AddrBook) Start() {
if atomic.AddInt32(&a.started, 1) != 1 { return }
log.Trace("Starting address manager")
a.loadFromFile(a.filePath)
a.wg.Add(1)
go a.addressHandler()
if atomic.AddInt32(&a.started, 1) != 1 {
return
}
log.Trace("Starting address manager")
a.loadFromFile(a.filePath)
a.wg.Add(1)
go a.addressHandler()
}
func (a *AddrBook) Stop() {
if atomic.AddInt32(&a.shutdown, 1) != 1 { return }
log.Infof("Address manager shutting down")
close(a.quit)
a.wg.Wait()
if atomic.AddInt32(&a.shutdown, 1) != 1 {
return
}
log.Infof("Address manager shutting down")
close(a.quit)
a.wg.Wait()
}
func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) {
a.mtx.Lock(); defer a.mtx.Unlock()
a.addAddress(addr, src)
a.mtx.Lock()
defer a.mtx.Unlock()
a.addAddress(addr, src)
}
func (a *AddrBook) NeedMoreAddresses() bool {
return a.NumAddresses() < needAddressThreshold
return a.NumAddresses() < needAddressThreshold
}
func (a *AddrBook) NumAddresses() int {
a.mtx.Lock(); defer a.mtx.Unlock()
return a.nOld + a.nNew
a.mtx.Lock()
defer a.mtx.Unlock()
return a.nOld + a.nNew
}
// Pick a new address to connect to.
func (a *AddrBook) PickAddress(class string, newBias int) *KnownAddress {
a.mtx.Lock(); defer a.mtx.Unlock()
a.mtx.Lock()
defer a.mtx.Unlock()
if a.nOld == 0 && a.nNew == 0 { return nil }
if newBias > 100 { newBias = 100 }
if newBias < 0 { newBias = 0 }
if a.nOld == 0 && a.nNew == 0 {
return nil
}
if newBias > 100 {
newBias = 100
}
if newBias < 0 {
newBias = 0
}
// Bias between new and old addresses.
oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias))
newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias)
// Bias between new and old addresses.
oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias))
newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias)
if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation {
// pick random Old bucket.
var bucket []*KnownAddress = nil
for len(bucket) == 0 {
bucket = a.addrOld[a.rand.Intn(len(a.addrOld))]
}
// pick a random ka from bucket.
return bucket[a.rand.Intn(len(bucket))]
} else {
// pick random New bucket.
var bucket map[string]*KnownAddress = nil
for len(bucket) == 0 {
bucket = a.addrNew[a.rand.Intn(len(a.addrNew))]
}
// pick a random ka from bucket.
randIndex := a.rand.Intn(len(bucket))
for _, ka := range bucket {
randIndex--
if randIndex == 0 {
return ka
}
}
panic("Should not happen")
}
return nil
if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation {
// pick random Old bucket.
var bucket []*KnownAddress = nil
for len(bucket) == 0 {
bucket = a.addrOld[a.rand.Intn(len(a.addrOld))]
}
// pick a random ka from bucket.
return bucket[a.rand.Intn(len(bucket))]
} else {
// pick random New bucket.
var bucket map[string]*KnownAddress = nil
for len(bucket) == 0 {
bucket = a.addrNew[a.rand.Intn(len(a.addrNew))]
}
// pick a random ka from bucket.
randIndex := a.rand.Intn(len(bucket))
for _, ka := range bucket {
randIndex--
if randIndex == 0 {
return ka
}
}
panic("Should not happen")
}
return nil
}
func (a *AddrBook) MarkGood(addr *NetAddress) {
a.mtx.Lock(); defer a.mtx.Unlock()
ka := a.addrIndex[addr.String()]
if ka == nil { return }
ka.MarkAttempt(true)
if ka.OldBucket == -1 {
a.moveToOld(ka)
}
a.mtx.Lock()
defer a.mtx.Unlock()
ka := a.addrIndex[addr.String()]
if ka == nil {
return
}
ka.MarkAttempt(true)
if ka.OldBucket == -1 {
a.moveToOld(ka)
}
}
func (a *AddrBook) MarkAttempt(addr *NetAddress) {
a.mtx.Lock(); defer a.mtx.Unlock()
ka := a.addrIndex[addr.String()]
if ka == nil { return }
ka.MarkAttempt(false)
a.mtx.Lock()
defer a.mtx.Unlock()
ka := a.addrIndex[addr.String()]
if ka == nil {
return
}
ka.MarkAttempt(false)
}
/* Loading & Saving */
type addrBookJSON struct {
Key [32]byte
AddrNew [newBucketCount]map[string]*KnownAddress
AddrOld [oldBucketCount][]*KnownAddress
NOld int
NNew int
Key [32]byte
AddrNew [newBucketCount]map[string]*KnownAddress
AddrOld [oldBucketCount][]*KnownAddress
NOld int
NNew int
}
func (a *AddrBook) saveToFile(filePath string) {
aJSON := &addrBookJSON{
Key: a.key,
AddrNew: a.addrNew,
AddrOld: a.addrOld,
NOld: a.nOld,
NNew: a.nNew,
}
aJSON := &addrBookJSON{
Key: a.key,
AddrNew: a.addrNew,
AddrOld: a.addrOld,
NOld: a.nOld,
NNew: a.nNew,
}
w, err := os.Create(filePath)
if err != nil {
@ -225,296 +244,306 @@ func (a *AddrBook) saveToFile(filePath string) {
enc := json.NewEncoder(w)
defer w.Close()
err = enc.Encode(&aJSON)
if err != nil { panic(err) }
if err != nil {
panic(err)
}
}
func (a *AddrBook) loadFromFile(filePath string) {
// If doesn't exist, do nothing.
// If doesn't exist, do nothing.
_, err := os.Stat(filePath)
if os.IsNotExist(err) { return }
if os.IsNotExist(err) {
return
}
// Load addrBookJSON{}
// Load addrBookJSON{}
r, err := os.Open(filePath)
if err != nil {
panic(fmt.Errorf("%s error opening file: %v", filePath, err))
panic(fmt.Errorf("%s error opening file: %v", filePath, err))
}
defer r.Close()
aJSON := &addrBookJSON{}
aJSON := &addrBookJSON{}
dec := json.NewDecoder(r)
err = dec.Decode(aJSON)
if err != nil {
panic(fmt.Errorf("error reading %s: %v", filePath, err))
}
// Now we need to initialize self.
// Now we need to initialize self.
copy(a.key[:], aJSON.Key[:])
a.addrNew = aJSON.AddrNew
for i, oldBucket := range aJSON.AddrOld {
copy(a.addrOld[i], oldBucket)
}
a.nNew = aJSON.NNew
a.nOld = aJSON.NOld
copy(a.key[:], aJSON.Key[:])
a.addrNew = aJSON.AddrNew
for i, oldBucket := range aJSON.AddrOld {
copy(a.addrOld[i], oldBucket)
}
a.nNew = aJSON.NNew
a.nOld = aJSON.NOld
a.addrIndex = make(map[string]*KnownAddress)
for _, newBucket := range a.addrNew {
for key, ka := range newBucket {
a.addrIndex[key] = ka
}
}
a.addrIndex = make(map[string]*KnownAddress)
for _, newBucket := range a.addrNew {
for key, ka := range newBucket {
a.addrIndex[key] = ka
}
}
}
/* Private methods */
func (a *AddrBook) addressHandler() {
dumpAddressTicker := time.NewTicker(dumpAddressInterval)
dumpAddressTicker := time.NewTicker(dumpAddressInterval)
out:
for {
select {
case <-dumpAddressTicker.C:
a.saveToFile(a.filePath)
case <-a.quit:
break out
}
}
dumpAddressTicker.Stop()
a.saveToFile(a.filePath)
a.wg.Done()
log.Trace("Address handler done")
for {
select {
case <-dumpAddressTicker.C:
a.saveToFile(a.filePath)
case <-a.quit:
break out
}
}
dumpAddressTicker.Stop()
a.saveToFile(a.filePath)
a.wg.Done()
log.Trace("Address handler done")
}
func (a *AddrBook) addAddress(addr, src *NetAddress) {
if !addr.Routable() { return }
if !addr.Routable() {
return
}
key := addr.String()
ka := a.addrIndex[key]
key := addr.String()
ka := a.addrIndex[key]
if ka != nil {
// Already added
if ka.OldBucket != -1 { return }
if ka.NewRefs == newBucketsPerAddress { return }
if ka != nil {
// Already added
if ka.OldBucket != -1 {
return
}
if ka.NewRefs == newBucketsPerAddress {
return
}
// The more entries we have, the less likely we are to add more.
factor := int32(2 * ka.NewRefs)
if a.rand.Int31n(factor) != 0 {
return
}
} else {
ka = NewKnownAddress(addr, src)
a.addrIndex[key] = ka
a.nNew++
}
// The more entries we have, the less likely we are to add more.
factor := int32(2 * ka.NewRefs)
if a.rand.Int31n(factor) != 0 {
return
}
} else {
ka = NewKnownAddress(addr, src)
a.addrIndex[key] = ka
a.nNew++
}
bucket := a.getNewBucket(addr, src)
bucket := a.getNewBucket(addr, src)
// Already exists?
if _, ok := a.addrNew[bucket][key]; ok {
return
}
// Already exists?
if _, ok := a.addrNew[bucket][key]; ok {
return
}
// Enforce max addresses.
if len(a.addrNew[bucket]) > newBucketSize {
log.Tracef("new bucket is full, expiring old ")
a.expireNew(bucket)
}
// Enforce max addresses.
if len(a.addrNew[bucket]) > newBucketSize {
log.Tracef("new bucket is full, expiring old ")
a.expireNew(bucket)
}
// Add to new bucket.
ka.NewRefs++
a.addrNew[bucket][key] = ka
// Add to new bucket.
ka.NewRefs++
a.addrNew[bucket][key] = ka
log.Tracef("Added new address %s for a total of %d addresses", addr, a.nOld+a.nNew)
log.Tracef("Added new address %s for a total of %d addresses", addr, a.nOld+a.nNew)
}
// Make space in the new buckets by expiring the really bad entries.
// If no bad entries are available we look at a few and remove the oldest.
func (a *AddrBook) expireNew(bucket int) {
var oldest *KnownAddress
for k, v := range a.addrNew[bucket] {
// If an entry is bad, throw it away
if v.Bad() {
log.Tracef("expiring bad address %v", k)
delete(a.addrNew[bucket], k)
v.NewRefs--
if v.NewRefs == 0 {
a.nNew--
delete(a.addrIndex, k)
}
return
}
// or, keep track of the oldest entry
if oldest == nil {
oldest = v
} else if v.LastAttempt.Before(oldest.LastAttempt.Time) {
oldest = v
}
}
var oldest *KnownAddress
for k, v := range a.addrNew[bucket] {
// If an entry is bad, throw it away
if v.Bad() {
log.Tracef("expiring bad address %v", k)
delete(a.addrNew[bucket], k)
v.NewRefs--
if v.NewRefs == 0 {
a.nNew--
delete(a.addrIndex, k)
}
return
}
// or, keep track of the oldest entry
if oldest == nil {
oldest = v
} else if v.LastAttempt.Before(oldest.LastAttempt.Time) {
oldest = v
}
}
// If we haven't thrown out a bad entry, throw out the oldest entry
if oldest != nil {
key := oldest.Addr.String()
log.Tracef("expiring oldest address %v", key)
delete(a.addrNew[bucket], key)
oldest.NewRefs--
if oldest.NewRefs == 0 {
a.nNew--
delete(a.addrIndex, key)
}
}
// If we haven't thrown out a bad entry, throw out the oldest entry
if oldest != nil {
key := oldest.Addr.String()
log.Tracef("expiring oldest address %v", key)
delete(a.addrNew[bucket], key)
oldest.NewRefs--
if oldest.NewRefs == 0 {
a.nNew--
delete(a.addrIndex, key)
}
}
}
func (a *AddrBook) moveToOld(ka *KnownAddress) {
// Remove from all new buckets.
// Remember one of those new buckets.
addrKey := ka.Addr.String()
freedBucket := -1
for i := range a.addrNew {
// we check for existance so we can record the first one
if _, ok := a.addrNew[i][addrKey]; ok {
delete(a.addrNew[i], addrKey)
ka.NewRefs--
if freedBucket == -1 {
freedBucket = i
}
}
}
a.nNew--
if freedBucket == -1 { panic("Expected to find addr in at least one new bucket") }
// Remove from all new buckets.
// Remember one of those new buckets.
addrKey := ka.Addr.String()
freedBucket := -1
for i := range a.addrNew {
// we check for existance so we can record the first one
if _, ok := a.addrNew[i][addrKey]; ok {
delete(a.addrNew[i], addrKey)
ka.NewRefs--
if freedBucket == -1 {
freedBucket = i
}
}
}
a.nNew--
if freedBucket == -1 {
panic("Expected to find addr in at least one new bucket")
}
oldBucket := a.getOldBucket(ka.Addr)
oldBucket := a.getOldBucket(ka.Addr)
// If room in oldBucket, put it in.
if len(a.addrOld[oldBucket]) < oldBucketSize {
ka.OldBucket = Int16(oldBucket)
a.addrOld[oldBucket] = append(a.addrOld[oldBucket], ka)
a.nOld++
return
}
// If room in oldBucket, put it in.
if len(a.addrOld[oldBucket]) < oldBucketSize {
ka.OldBucket = Int16(oldBucket)
a.addrOld[oldBucket] = append(a.addrOld[oldBucket], ka)
a.nOld++
return
}
// No room, we have to evict something else.
rmkaIndex := a.pickOld(oldBucket)
rmka := a.addrOld[oldBucket][rmkaIndex]
// No room, we have to evict something else.
rmkaIndex := a.pickOld(oldBucket)
rmka := a.addrOld[oldBucket][rmkaIndex]
// Find a new bucket to put rmka in.
newBucket := a.getNewBucket(rmka.Addr, rmka.Src)
if len(a.addrNew[newBucket]) >= newBucketSize {
newBucket = freedBucket
}
// Find a new bucket to put rmka in.
newBucket := a.getNewBucket(rmka.Addr, rmka.Src)
if len(a.addrNew[newBucket]) >= newBucketSize {
newBucket = freedBucket
}
// replace with ka in list.
ka.OldBucket = Int16(oldBucket)
a.addrOld[oldBucket][rmkaIndex] = ka
rmka.OldBucket = -1
// replace with ka in list.
ka.OldBucket = Int16(oldBucket)
a.addrOld[oldBucket][rmkaIndex] = ka
rmka.OldBucket = -1
// put rmka into new bucket
rmkey := rmka.Addr.String()
log.Tracef("Replacing %s with %s in old", rmkey, addrKey)
a.addrNew[newBucket][rmkey] = rmka
rmka.NewRefs++
a.nNew++
// put rmka into new bucket
rmkey := rmka.Addr.String()
log.Tracef("Replacing %s with %s in old", rmkey, addrKey)
a.addrNew[newBucket][rmkey] = rmka
rmka.NewRefs++
a.nNew++
}
// Returns the index in old bucket of oldest entry.
func (a *AddrBook) pickOld(bucket int) int {
var oldest *KnownAddress
var oldestIndex int
for i, ka := range a.addrOld[bucket] {
if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt.Time) {
oldest = ka
oldestIndex = i
}
}
return oldestIndex
var oldest *KnownAddress
var oldestIndex int
for i, ka := range a.addrOld[bucket] {
if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt.Time) {
oldest = ka
oldestIndex = i
}
}
return oldestIndex
}
// doublesha256(key + sourcegroup +
// int64(doublesha256(key + group + sourcegroup))%bucket_per_source_group) % num_new_buckes
func (a *AddrBook) getNewBucket(addr, src *NetAddress) int {
data1 := []byte{}
data1 = append(data1, a.key[:]...)
data1 = append(data1, []byte(GroupKey(addr))...)
data1 = append(data1, []byte(GroupKey(src))...)
hash1 := DoubleSha256(data1)
hash64 := binary.LittleEndian.Uint64(hash1)
hash64 %= newBucketsPerGroup
var hashbuf [8]byte
binary.LittleEndian.PutUint64(hashbuf[:], hash64)
data2 := []byte{}
data2 = append(data2, a.key[:]...)
data2 = append(data2, GroupKey(src)...)
data2 = append(data2, hashbuf[:]...)
data1 := []byte{}
data1 = append(data1, a.key[:]...)
data1 = append(data1, []byte(GroupKey(addr))...)
data1 = append(data1, []byte(GroupKey(src))...)
hash1 := DoubleSha256(data1)
hash64 := binary.LittleEndian.Uint64(hash1)
hash64 %= newBucketsPerGroup
var hashbuf [8]byte
binary.LittleEndian.PutUint64(hashbuf[:], hash64)
data2 := []byte{}
data2 = append(data2, a.key[:]...)
data2 = append(data2, GroupKey(src)...)
data2 = append(data2, hashbuf[:]...)
hash2 := DoubleSha256(data2)
return int(binary.LittleEndian.Uint64(hash2) % newBucketCount)
hash2 := DoubleSha256(data2)
return int(binary.LittleEndian.Uint64(hash2) % newBucketCount)
}
// doublesha256(key + group + truncate_to_64bits(doublesha256(key + addr))%buckets_per_group) % num_buckets
func (a *AddrBook) getOldBucket(addr *NetAddress) int {
data1 := []byte{}
data1 = append(data1, a.key[:]...)
data1 = append(data1, []byte(addr.String())...)
hash1 := DoubleSha256(data1)
hash64 := binary.LittleEndian.Uint64(hash1)
hash64 %= oldBucketsPerGroup
var hashbuf [8]byte
binary.LittleEndian.PutUint64(hashbuf[:], hash64)
data2 := []byte{}
data2 = append(data2, a.key[:]...)
data2 = append(data2, GroupKey(addr)...)
data2 = append(data2, hashbuf[:]...)
data1 := []byte{}
data1 = append(data1, a.key[:]...)
data1 = append(data1, []byte(addr.String())...)
hash1 := DoubleSha256(data1)
hash64 := binary.LittleEndian.Uint64(hash1)
hash64 %= oldBucketsPerGroup
var hashbuf [8]byte
binary.LittleEndian.PutUint64(hashbuf[:], hash64)
data2 := []byte{}
data2 = append(data2, a.key[:]...)
data2 = append(data2, GroupKey(addr)...)
data2 = append(data2, hashbuf[:]...)
hash2 := DoubleSha256(data2)
return int(binary.LittleEndian.Uint64(hash2) % oldBucketCount)
hash2 := DoubleSha256(data2)
return int(binary.LittleEndian.Uint64(hash2) % oldBucketCount)
}
// Return a string representing the network group of this address.
// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string
// "local" for a local address and the string "unroutable for an unroutable
// address.
func GroupKey (na *NetAddress) string {
if na.Local() {
return "local"
}
if !na.Routable() {
return "unroutable"
}
func GroupKey(na *NetAddress) string {
if na.Local() {
return "local"
}
if !na.Routable() {
return "unroutable"
}
if ipv4 := na.IP.To4(); ipv4 != nil {
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String()
}
if na.RFC6145() || na.RFC6052() {
// last four bytes are the ip address
ip := net.IP(na.IP[12:16])
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
}
if ipv4 := na.IP.To4(); ipv4 != nil {
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String()
}
if na.RFC6145() || na.RFC6052() {
// last four bytes are the ip address
ip := net.IP(na.IP[12:16])
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
}
if na.RFC3964() {
ip := net.IP(na.IP[2:7])
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
if na.RFC3964() {
ip := net.IP(na.IP[2:7])
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
}
if na.RFC4380() {
// teredo tunnels have the last 4 bytes as the v4 address XOR
// 0xff.
ip := net.IP(make([]byte, 4))
for i, byte := range na.IP[12:16] {
ip[i] = byte ^ 0xff
}
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
}
}
if na.RFC4380() {
// teredo tunnels have the last 4 bytes as the v4 address XOR
// 0xff.
ip := net.IP(make([]byte, 4))
for i, byte := range na.IP[12:16] {
ip[i] = byte ^ 0xff
}
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
}
// OK, so now we know ourselves to be a IPv6 address.
// bitcoind uses /32 for everything, except for Hurricane Electric's
// (he.net) IP range, which it uses /36 for.
bits := 32
heNet := &net.IPNet{IP: net.ParseIP("2001:470::"),
Mask: net.CIDRMask(32, 128)}
if heNet.Contains(na.IP) {
bits = 36
}
// OK, so now we know ourselves to be a IPv6 address.
// bitcoind uses /32 for everything, except for Hurricane Electric's
// (he.net) IP range, which it uses /36 for.
bits := 32
heNet := &net.IPNet{IP: net.ParseIP("2001:470::"),
Mask: net.CIDRMask(32, 128)}
if heNet.Contains(na.IP) {
bits = 36
}
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String()
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String()
}

View File

@ -1,12 +1,12 @@
package peer
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle"
"sync/atomic"
"sync"
"errors"
"errors"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"github.com/tendermint/tendermint/merkle"
"sync"
"sync/atomic"
)
/* Client
@ -21,147 +21,161 @@ import (
XXX what about peer disconnects?
*/
type Client struct {
addrBook *AddrBook
targetNumPeers int
makePeerFn func(*Connection) *Peer
self *Peer
recvQueues map[String]chan *InboundPacket
addrBook *AddrBook
targetNumPeers int
makePeerFn func(*Connection) *Peer
self *Peer
recvQueues map[String]chan *InboundPacket
mtx sync.Mutex
peers merkle.Tree // addr -> *Peer
quit chan struct{}
stopped uint32
mtx sync.Mutex
peers merkle.Tree // addr -> *Peer
quit chan struct{}
stopped uint32
}
var (
CLIENT_STOPPED_ERROR = errors.New("Client already stopped")
CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer")
CLIENT_STOPPED_ERROR = errors.New("Client already stopped")
CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer")
)
func NewClient(makePeerFn func(*Connection) *Peer) *Client {
self := makePeerFn(nil)
if self == nil {
Panicf("makePeerFn(nil) must return a prototypical peer for self")
}
self := makePeerFn(nil)
if self == nil {
Panicf("makePeerFn(nil) must return a prototypical peer for self")
}
recvQueues := make(map[String]chan *InboundPacket)
for chName, _ := range self.channels {
recvQueues[chName] = make(chan *InboundPacket)
}
recvQueues := make(map[String]chan *InboundPacket)
for chName, _ := range self.channels {
recvQueues[chName] = make(chan *InboundPacket)
}
c := &Client{
addrBook: nil, // TODO
targetNumPeers: 0, // TODO
makePeerFn: makePeerFn,
self: self,
recvQueues: recvQueues,
c := &Client{
addrBook: nil, // TODO
targetNumPeers: 0, // TODO
makePeerFn: makePeerFn,
self: self,
recvQueues: recvQueues,
peers: merkle.NewIAVLTree(nil),
quit: make(chan struct{}),
stopped: 0,
}
return c
peers: merkle.NewIAVLTree(nil),
quit: make(chan struct{}),
stopped: 0,
}
return c
}
func (c *Client) Stop() {
log.Infof("Stopping client")
// lock
c.mtx.Lock()
if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
close(c.quit)
// stop each peer.
for peerValue := range c.peers.Values() {
peer := peerValue.(*Peer)
peer.Stop()
}
// empty tree.
c.peers = merkle.NewIAVLTree(nil)
}
c.mtx.Unlock()
// unlock
log.Infof("Stopping client")
// lock
c.mtx.Lock()
if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
close(c.quit)
// stop each peer.
for peerValue := range c.peers.Values() {
peer := peerValue.(*Peer)
peer.Stop()
}
// empty tree.
c.peers = merkle.NewIAVLTree(nil)
}
c.mtx.Unlock()
// unlock
}
func (c *Client) AddPeerWithConnection(conn *Connection, outgoing bool) (*Peer, error) {
if atomic.LoadUint32(&c.stopped) == 1 { return nil, CLIENT_STOPPED_ERROR }
if atomic.LoadUint32(&c.stopped) == 1 {
return nil, CLIENT_STOPPED_ERROR
}
log.Infof("Adding peer with connection: %v, outgoing: %v", conn, outgoing)
peer := c.makePeerFn(conn)
peer.outgoing = outgoing
err := c.addPeer(peer)
if err != nil { return nil, err }
log.Infof("Adding peer with connection: %v, outgoing: %v", conn, outgoing)
peer := c.makePeerFn(conn)
peer.outgoing = outgoing
err := c.addPeer(peer)
if err != nil {
return nil, err
}
go peer.Start(c.recvQueues)
go peer.Start(c.recvQueues)
return peer, nil
return peer, nil
}
func (c *Client) Broadcast(pkt Packet) {
if atomic.LoadUint32(&c.stopped) == 1 { return }
if atomic.LoadUint32(&c.stopped) == 1 {
return
}
log.Tracef("Broadcast on [%v] len: %v", pkt.Channel, len(pkt.Bytes))
for v := range c.Peers().Values() {
peer := v.(*Peer)
success := peer.TrySend(pkt)
log.Tracef("Broadcast for peer %v success: %v", peer, success)
if !success {
// TODO: notify the peer
}
}
log.Tracef("Broadcast on [%v] len: %v", pkt.Channel, len(pkt.Bytes))
for v := range c.Peers().Values() {
peer := v.(*Peer)
success := peer.TrySend(pkt)
log.Tracef("Broadcast for peer %v success: %v", peer, success)
if !success {
// TODO: notify the peer
}
}
}
// blocks until a message is popped.
func (c *Client) Receive(chName String) *InboundPacket {
if atomic.LoadUint32(&c.stopped) == 1 { return nil }
if atomic.LoadUint32(&c.stopped) == 1 {
return nil
}
log.Tracef("Receive on [%v]", chName)
q := c.recvQueues[chName]
if q == nil { Panicf("Expected recvQueues[%f], found none", chName) }
log.Tracef("Receive on [%v]", chName)
q := c.recvQueues[chName]
if q == nil {
Panicf("Expected recvQueues[%f], found none", chName)
}
for {
select {
case <-c.quit:
return nil
case inPacket := <-q:
return inPacket
}
}
for {
select {
case <-c.quit:
return nil
case inPacket := <-q:
return inPacket
}
}
}
func (c *Client) Peers() merkle.Tree {
// lock & defer
c.mtx.Lock(); defer c.mtx.Unlock()
return c.peers.Copy()
// unlock deferred
// lock & defer
c.mtx.Lock()
defer c.mtx.Unlock()
return c.peers.Copy()
// unlock deferred
}
func (c *Client) StopPeer(peer *Peer) {
// lock
c.mtx.Lock()
peerValue, _ := c.peers.Remove(peer.RemoteAddress())
c.mtx.Unlock()
// unlock
// lock
c.mtx.Lock()
peerValue, _ := c.peers.Remove(peer.RemoteAddress())
c.mtx.Unlock()
// unlock
peer_ := peerValue.(*Peer)
if peer_ != nil {
peer_.Stop()
}
peer_ := peerValue.(*Peer)
if peer_ != nil {
peer_.Stop()
}
}
func (c *Client) addPeer(peer *Peer) error {
addr := peer.RemoteAddress()
addr := peer.RemoteAddress()
// lock & defer
c.mtx.Lock(); defer c.mtx.Unlock()
if c.stopped == 1 { return CLIENT_STOPPED_ERROR }
if !c.peers.Has(addr) {
log.Tracef("Actually putting addr: %v, peer: %v", addr, peer)
c.peers.Put(addr, peer)
return nil
} else {
// ignore duplicate peer for addr.
log.Infof("Ignoring duplicate peer for addr %v", addr)
return CLIENT_DUPLICATE_PEER_ERROR
}
// unlock deferred
// lock & defer
c.mtx.Lock()
defer c.mtx.Unlock()
if c.stopped == 1 {
return CLIENT_STOPPED_ERROR
}
if !c.peers.Has(addr) {
log.Tracef("Actually putting addr: %v, peer: %v", addr, peer)
c.peers.Put(addr, peer)
return nil
} else {
// ignore duplicate peer for addr.
log.Infof("Ignoring duplicate peer for addr %v", addr)
return CLIENT_DUPLICATE_PEER_ERROR
}
// unlock deferred
}

View File

@ -1,106 +1,105 @@
package peer
import (
. "github.com/tendermint/tendermint/binary"
"testing"
"time"
. "github.com/tendermint/tendermint/binary"
"testing"
"time"
)
// convenience method for creating two clients connected to each other.
func makeClientPair(t *testing.T, bufferSize int, channels []string) (*Client, *Client) {
peerMaker := func(conn *Connection) *Peer {
p := NewPeer(conn)
p.channels = map[String]*Channel{}
for chName := range channels {
p.channels[String(chName)] = NewChannel(String(chName), bufferSize)
}
return p
}
peerMaker := func(conn *Connection) *Peer {
p := NewPeer(conn)
p.channels = map[String]*Channel{}
for chName := range channels {
p.channels[String(chName)] = NewChannel(String(chName), bufferSize)
}
return p
}
// Create two clients that will be interconnected.
c1 := NewClient(peerMaker)
c2 := NewClient(peerMaker)
// Create two clients that will be interconnected.
c1 := NewClient(peerMaker)
c2 := NewClient(peerMaker)
// Create a server for the listening client.
s1 := NewServer("tcp", ":8001", c1)
// Create a server for the listening client.
s1 := NewServer("tcp", ":8001", c1)
// Dial the server & add the connection to c2.
s1laddr := s1.LocalAddress()
conn, err := s1laddr.Dial()
if err != nil {
t.Fatalf("Could not connect to server address %v", s1laddr)
} else {
t.Logf("Created a connection to local server address %v", s1laddr)
}
// Dial the server & add the connection to c2.
s1laddr := s1.LocalAddress()
conn, err := s1laddr.Dial()
if err != nil {
t.Fatalf("Could not connect to server address %v", s1laddr)
} else {
t.Logf("Created a connection to local server address %v", s1laddr)
}
c2.AddPeerWithConnection(conn, true)
c2.AddPeerWithConnection(conn, true)
// Wait for things to happen, peers to get added...
time.Sleep(100 * time.Millisecond)
// Wait for things to happen, peers to get added...
time.Sleep(100 * time.Millisecond)
return c1, c2
return c1, c2
}
func TestClients(t *testing.T) {
c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"})
c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"})
// Lets send a message from c1 to c2.
if c1.Peers().Size() != 1 {
t.Errorf("Expected exactly 1 peer in c1, got %v", c1.Peers().Size())
}
if c2.Peers().Size() != 1 {
t.Errorf("Expected exactly 1 peer in c2, got %v", c2.Peers().Size())
}
// Lets send a message from c1 to c2.
if c1.Peers().Size() != 1 {
t.Errorf("Expected exactly 1 peer in c1, got %v", c1.Peers().Size())
}
if c2.Peers().Size() != 1 {
t.Errorf("Expected exactly 1 peer in c2, got %v", c2.Peers().Size())
}
// Broadcast a message on ch1
c1.Broadcast(NewPacket("ch1", ByteSlice("channel one")))
// Broadcast a message on ch2
c1.Broadcast(NewPacket("ch2", ByteSlice("channel two")))
// Broadcast a message on ch3
c1.Broadcast(NewPacket("ch3", ByteSlice("channel three")))
// Broadcast a message on ch1
c1.Broadcast(NewPacket("ch1", ByteSlice("channel one")))
// Broadcast a message on ch2
c1.Broadcast(NewPacket("ch2", ByteSlice("channel two")))
// Broadcast a message on ch3
c1.Broadcast(NewPacket("ch3", ByteSlice("channel three")))
// Wait for things to settle...
time.Sleep(100 * time.Millisecond)
// Wait for things to settle...
time.Sleep(100 * time.Millisecond)
// Receive message from channel 2 and check
inMsg := c2.Receive("ch2")
if string(inMsg.Bytes) != "channel two" {
t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes))
}
// Receive message from channel 2 and check
inMsg := c2.Receive("ch2")
if string(inMsg.Bytes) != "channel two" {
t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes))
}
// Receive message from channel 1 and check
inMsg = c2.Receive("ch1")
if string(inMsg.Bytes) != "channel one" {
t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes))
}
// Receive message from channel 1 and check
inMsg = c2.Receive("ch1")
if string(inMsg.Bytes) != "channel one" {
t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes))
}
s1.Stop()
c2.Stop()
s1.Stop()
c2.Stop()
}
func BenchmarkClients(b *testing.B) {
b.StopTimer()
b.StopTimer()
// TODO: benchmark the random functions, which is faster?
// TODO: benchmark the random functions, which is faster?
c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"})
c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"})
// Create a sink on either channel to just pop off messages.
// TODO: ensure that when clients stop, this goroutine stops.
func recvHandler(c *Client) {
}
// Create a sink on either channel to just pop off messages.
// TODO: ensure that when clients stop, this goroutine stops.
recvHandler := func(c *Client) {
}
go recvHandler(c1)
go recvHandler(c2)
go recvHandler(c1)
go recvHandler(c2)
b.StartTimer()
b.StartTimer()
// Send random message from one channel to another
for i := 0; i < b.N; i++ {
}
// Send random message from one channel to another
for i := 0; i < b.N; i++ {
}
}

View File

@ -1,191 +1,192 @@
package peer
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
"sync/atomic"
"net"
"time"
"fmt"
"fmt"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"net"
"sync/atomic"
"time"
)
const (
OUT_QUEUE_SIZE = 50
IDLE_TIMEOUT_MINUTES = 5
PING_TIMEOUT_MINUTES = 2
OUT_QUEUE_SIZE = 50
IDLE_TIMEOUT_MINUTES = 5
PING_TIMEOUT_MINUTES = 2
)
/* Connnection */
type Connection struct {
ioStats IOStats
ioStats IOStats
sendQueue chan Packet // never closes
conn net.Conn
quit chan struct{}
stopped uint32
pingDebouncer *Debouncer
pong chan struct{}
sendQueue chan Packet // never closes
conn net.Conn
quit chan struct{}
stopped uint32
pingDebouncer *Debouncer
pong chan struct{}
}
var (
PACKET_TYPE_PING = UInt8(0x00)
PACKET_TYPE_PONG = UInt8(0x01)
PACKET_TYPE_MSG = UInt8(0x10)
PACKET_TYPE_PING = UInt8(0x00)
PACKET_TYPE_PONG = UInt8(0x01)
PACKET_TYPE_MSG = UInt8(0x10)
)
func NewConnection(conn net.Conn) *Connection {
return &Connection{
sendQueue: make(chan Packet, OUT_QUEUE_SIZE),
conn: conn,
quit: make(chan struct{}),
pingDebouncer: NewDebouncer(PING_TIMEOUT_MINUTES * time.Minute),
pong: make(chan struct{}),
}
return &Connection{
sendQueue: make(chan Packet, OUT_QUEUE_SIZE),
conn: conn,
quit: make(chan struct{}),
pingDebouncer: NewDebouncer(PING_TIMEOUT_MINUTES * time.Minute),
pong: make(chan struct{}),
}
}
// returns true if successfully queued,
// returns false if connection was closed.
// blocks.
func (c *Connection) Send(pkt Packet) bool {
select {
case c.sendQueue <- pkt:
return true
case <-c.quit:
return false
}
select {
case c.sendQueue <- pkt:
return true
case <-c.quit:
return false
}
}
func (c *Connection) Start(channels map[String]*Channel) {
log.Debugf("Starting %v", c)
go c.sendHandler()
go c.recvHandler(channels)
log.Debugf("Starting %v", c)
go c.sendHandler()
go c.recvHandler(channels)
}
func (c *Connection) Stop() {
if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
log.Debugf("Stopping %v", c)
close(c.quit)
c.conn.Close()
c.pingDebouncer.Stop()
// We can't close pong safely here because
// recvHandler may write to it after we've stopped.
// Though it doesn't need to get closed at all,
// we close it @ recvHandler.
// close(c.pong)
}
if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
log.Debugf("Stopping %v", c)
close(c.quit)
c.conn.Close()
c.pingDebouncer.Stop()
// We can't close pong safely here because
// recvHandler may write to it after we've stopped.
// Though it doesn't need to get closed at all,
// we close it @ recvHandler.
// close(c.pong)
}
}
func (c *Connection) LocalAddress() *NetAddress {
return NewNetAddress(c.conn.LocalAddr())
return NewNetAddress(c.conn.LocalAddr())
}
func (c *Connection) RemoteAddress() *NetAddress {
return NewNetAddress(c.conn.RemoteAddr())
return NewNetAddress(c.conn.RemoteAddr())
}
func (c *Connection) String() string {
return fmt.Sprintf("Connection{%v}", c.conn.RemoteAddr())
return fmt.Sprintf("Connection{%v}", c.conn.RemoteAddr())
}
func (c *Connection) flush() {
// TODO flush? (turn off nagel, turn back on, etc)
// TODO flush? (turn off nagel, turn back on, etc)
}
func (c *Connection) sendHandler() {
log.Tracef("%v sendHandler", c)
log.Tracef("%v sendHandler", c)
// TODO: catch panics & stop connection.
// TODO: catch panics & stop connection.
FOR_LOOP:
for {
var err error
select {
case <-c.pingDebouncer.Ch:
_, err = PACKET_TYPE_PING.WriteTo(c.conn)
case sendPkt := <-c.sendQueue:
log.Tracef("Found pkt from sendQueue. Writing pkt to underlying connection")
_, err = PACKET_TYPE_MSG.WriteTo(c.conn)
if err != nil { break }
_, err = sendPkt.WriteTo(c.conn)
case <-c.pong:
_, err = PACKET_TYPE_PONG.WriteTo(c.conn)
case <-c.quit:
break FOR_LOOP
}
FOR_LOOP:
for {
var err error
select {
case <-c.pingDebouncer.Ch:
_, err = PACKET_TYPE_PING.WriteTo(c.conn)
case sendPkt := <-c.sendQueue:
log.Tracef("Found pkt from sendQueue. Writing pkt to underlying connection")
_, err = PACKET_TYPE_MSG.WriteTo(c.conn)
if err != nil {
break
}
_, err = sendPkt.WriteTo(c.conn)
case <-c.pong:
_, err = PACKET_TYPE_PONG.WriteTo(c.conn)
case <-c.quit:
break FOR_LOOP
}
if err != nil {
log.Infof("%v failed @ sendHandler:\n%v", c, err)
c.Stop()
break FOR_LOOP
}
if err != nil {
log.Infof("%v failed @ sendHandler:\n%v", c, err)
c.Stop()
break FOR_LOOP
}
c.flush()
}
c.flush()
}
log.Tracef("%v sendHandler done", c)
// cleanup
log.Tracef("%v sendHandler done", c)
// cleanup
}
func (c *Connection) recvHandler(channels map[String]*Channel) {
log.Tracef("%v recvHandler with %v channels", c, len(channels))
log.Tracef("%v recvHandler with %v channels", c, len(channels))
// TODO: catch panics & stop connection.
// TODO: catch panics & stop connection.
FOR_LOOP:
for {
pktType, err := ReadUInt8Safe(c.conn)
if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 {
log.Infof("%v failed @ recvHandler", c)
c.Stop()
}
break FOR_LOOP
} else {
log.Tracef("Found pktType %v", pktType)
}
FOR_LOOP:
for {
pktType, err := ReadUInt8Safe(c.conn)
if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 {
log.Infof("%v failed @ recvHandler", c)
c.Stop()
}
break FOR_LOOP
} else {
log.Tracef("Found pktType %v", pktType)
}
switch pktType {
case PACKET_TYPE_PING:
c.pong <- struct{}{}
case PACKET_TYPE_PONG:
// do nothing
case PACKET_TYPE_MSG:
pkt, err := ReadPacketSafe(c.conn)
if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 {
log.Infof("%v failed @ recvHandler", c)
c.Stop()
}
break FOR_LOOP
}
channel := channels[pkt.Channel]
if channel == nil {
Panicf("Unknown channel %v", pkt.Channel)
}
channel.recvQueue <- pkt
default:
Panicf("Unknown message type %v", pktType)
}
switch pktType {
case PACKET_TYPE_PING:
c.pong <- struct{}{}
case PACKET_TYPE_PONG:
// do nothing
case PACKET_TYPE_MSG:
pkt, err := ReadPacketSafe(c.conn)
if err != nil {
if atomic.LoadUint32(&c.stopped) != 1 {
log.Infof("%v failed @ recvHandler", c)
c.Stop()
}
break FOR_LOOP
}
channel := channels[pkt.Channel]
if channel == nil {
Panicf("Unknown channel %v", pkt.Channel)
}
channel.recvQueue <- pkt
default:
Panicf("Unknown message type %v", pktType)
}
c.pingDebouncer.Reset()
}
c.pingDebouncer.Reset()
}
log.Tracef("%v recvHandler done", c)
// cleanup
close(c.pong)
for _ = range c.pong {
// drain
}
log.Tracef("%v recvHandler done", c)
// cleanup
close(c.pong)
for _ = range c.pong {
// drain
}
}
/* IOStats */
type IOStats struct {
TimeConnected Time
LastSent Time
LastRecv Time
BytesRecv UInt64
BytesSent UInt64
PktsRecv UInt64
PktsSent UInt64
TimeConnected Time
LastSent Time
LastRecv Time
BytesRecv UInt64
BytesSent UInt64
PktsRecv UInt64
PktsSent UInt64
}

View File

@ -1,104 +1,104 @@
package peer
import (
. "github.com/tendermint/tendermint/binary"
"time"
"io"
. "github.com/tendermint/tendermint/binary"
"io"
"time"
)
/*
KnownAddress
KnownAddress
tracks information about a known network address that is used
to determine how viable an address is.
tracks information about a known network address that is used
to determine how viable an address is.
*/
type KnownAddress struct {
Addr *NetAddress
Src *NetAddress
Attempts UInt32
LastAttempt Time
LastSuccess Time
NewRefs UInt16
OldBucket Int16 // TODO init to -1
Addr *NetAddress
Src *NetAddress
Attempts UInt32
LastAttempt Time
LastSuccess Time
NewRefs UInt16
OldBucket Int16 // TODO init to -1
}
func NewKnownAddress(addr *NetAddress, src *NetAddress) *KnownAddress {
return &KnownAddress{
Addr: addr,
Src: src,
OldBucket: -1,
LastAttempt: Time{time.Now()},
Attempts: 0,
}
return &KnownAddress{
Addr: addr,
Src: src,
OldBucket: -1,
LastAttempt: Time{time.Now()},
Attempts: 0,
}
}
func ReadKnownAddress(r io.Reader) *KnownAddress {
return &KnownAddress{
Addr: ReadNetAddress(r),
Src: ReadNetAddress(r),
Attempts: ReadUInt32(r),
LastAttempt: ReadTime(r),
LastSuccess: ReadTime(r),
NewRefs: ReadUInt16(r),
OldBucket: ReadInt16(r),
}
return &KnownAddress{
Addr: ReadNetAddress(r),
Src: ReadNetAddress(r),
Attempts: ReadUInt32(r),
LastAttempt: ReadTime(r),
LastSuccess: ReadTime(r),
NewRefs: ReadUInt16(r),
OldBucket: ReadInt16(r),
}
}
func (ka *KnownAddress) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(ka.Addr, w, n, err)
n, err = WriteOnto(ka.Src, w, n, err)
n, err = WriteOnto(ka.Attempts, w, n, err)
n, err = WriteOnto(ka.LastAttempt, w, n, err)
n, err = WriteOnto(ka.LastSuccess, w, n, err)
n, err = WriteOnto(ka.NewRefs, w, n, err)
n, err = WriteOnto(ka.OldBucket, w, n, err)
return
n, err = WriteOnto(ka.Addr, w, n, err)
n, err = WriteOnto(ka.Src, w, n, err)
n, err = WriteOnto(ka.Attempts, w, n, err)
n, err = WriteOnto(ka.LastAttempt, w, n, err)
n, err = WriteOnto(ka.LastSuccess, w, n, err)
n, err = WriteOnto(ka.NewRefs, w, n, err)
n, err = WriteOnto(ka.OldBucket, w, n, err)
return
}
func (ka *KnownAddress) MarkAttempt(success bool) {
now := Time{time.Now()}
ka.LastAttempt = now
if success {
ka.LastSuccess = now
ka.Attempts = 0
} else {
ka.Attempts += 1
}
now := Time{time.Now()}
ka.LastAttempt = now
if success {
ka.LastSuccess = now
ka.Attempts = 0
} else {
ka.Attempts += 1
}
}
/*
An address is bad if the address in question has not been tried in the last
minute and meets one of the following criteria:
An address is bad if the address in question has not been tried in the last
minute and meets one of the following criteria:
1) It claims to be from the future
2) It hasn't been seen in over a month
3) It has failed at least three times and never succeeded
4) It has failed ten times in the last week
1) It claims to be from the future
2) It hasn't been seen in over a month
3) It has failed at least three times and never succeeded
4) It has failed ten times in the last week
All addresses that meet these criteria are assumed to be worthless and not
worth keeping hold of.
All addresses that meet these criteria are assumed to be worthless and not
worth keeping hold of.
*/
func (ka *KnownAddress) Bad() bool {
// Has been attempted in the last minute --> good
if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) {
return false
}
// Has been attempted in the last minute --> good
if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) {
return false
}
// Over a month old?
if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) {
return true
}
// Over a month old?
if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) {
return true
}
// Never succeeded?
if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries {
return true
}
// Never succeeded?
if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries {
return true
}
// Hasn't succeeded in too long?
if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) &&
ka.Attempts >= maxFailures {
return true
}
// Hasn't succeeded in too long?
if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) &&
ka.Attempts >= maxFailures {
return true
}
return false
return false
}

View File

@ -1,127 +1,145 @@
package peer
import (
. "github.com/tendermint/tendermint/common"
"sync/atomic"
"net"
. "github.com/tendermint/tendermint/common"
"net"
"sync/atomic"
)
const (
DEFAULT_PORT = 8001
DEFAULT_PORT = 8001
)
/* Listener */
type Listener interface {
Connections() <-chan *Connection
LocalAddress() *NetAddress
Stop()
Connections() <-chan *Connection
LocalAddress() *NetAddress
Stop()
}
/* DefaultListener */
type DefaultListener struct {
listener net.Listener
connections chan *Connection
stopped uint32
listener net.Listener
connections chan *Connection
stopped uint32
}
const (
DEFAULT_BUFFERED_CONNECTIONS = 10
DEFAULT_BUFFERED_CONNECTIONS = 10
)
func NewDefaultListener(protocol string, listenAddr string) Listener {
listener, err := net.Listen(protocol, listenAddr)
if err != nil { panic(err) }
listener, err := net.Listen(protocol, listenAddr)
if err != nil {
panic(err)
}
dl := &DefaultListener{
listener: listener,
connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS),
}
dl := &DefaultListener{
listener: listener,
connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS),
}
go dl.listenHandler()
go dl.listenHandler()
return dl
return dl
}
func (l *DefaultListener) listenHandler() {
for {
conn, err := l.listener.Accept()
for {
conn, err := l.listener.Accept()
if atomic.LoadUint32(&l.stopped) == 1 { return }
if atomic.LoadUint32(&l.stopped) == 1 {
return
}
// listener wasn't stopped,
// yet we encountered an error.
if err != nil { panic(err) }
// listener wasn't stopped,
// yet we encountered an error.
if err != nil {
panic(err)
}
c := NewConnection(conn)
l.connections <- c
}
c := NewConnection(conn)
l.connections <- c
}
// cleanup
close(l.connections)
for _ = range l.connections {
// drain
}
// cleanup
close(l.connections)
for _ = range l.connections {
// drain
}
}
func (l *DefaultListener) Connections() <-chan *Connection {
return l.connections
return l.connections
}
func (l *DefaultListener) LocalAddress() *NetAddress {
return GetLocalAddress()
return GetLocalAddress()
}
func (l *DefaultListener) Stop() {
if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) {
l.listener.Close()
}
if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) {
l.listener.Close()
}
}
/* local address helpers */
func GetLocalAddress() *NetAddress {
laddr := GetUPNPLocalAddress()
if laddr != nil { return laddr }
laddr := GetUPNPLocalAddress()
if laddr != nil {
return laddr
}
laddr = GetDefaultLocalAddress()
if laddr != nil { return laddr }
laddr = GetDefaultLocalAddress()
if laddr != nil {
return laddr
}
panic("Could not determine local address")
panic("Could not determine local address")
}
// UPNP external address discovery & port mapping
// TODO: more flexible internal & external ports
func GetUPNPLocalAddress() *NetAddress {
nat, err := Discover()
if err != nil { return nil }
nat, err := Discover()
if err != nil {
return nil
}
ext, err := nat.GetExternalAddress()
if err != nil { return nil }
ext, err := nat.GetExternalAddress()
if err != nil {
return nil
}
_, err = nat.AddPortMapping("tcp", DEFAULT_PORT, DEFAULT_PORT, "tendermint", 0)
if err != nil { return nil }
_, err = nat.AddPortMapping("tcp", DEFAULT_PORT, DEFAULT_PORT, "tendermint", 0)
if err != nil {
return nil
}
return NewNetAddressIPPort(ext, DEFAULT_PORT)
return NewNetAddressIPPort(ext, DEFAULT_PORT)
}
// Naive local IPv4 interface address detection
// TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh
func GetDefaultLocalAddress() *NetAddress {
addrs, err := net.InterfaceAddrs()
if err != nil { Panicf("Unexpected error fetching interface addresses: %v", err) }
addrs, err := net.InterfaceAddrs()
if err != nil {
Panicf("Unexpected error fetching interface addresses: %v", err)
}
for _, a := range addrs {
ipnet, ok := a.(*net.IPNet)
if !ok { continue }
v4 := ipnet.IP.To4()
if v4 == nil || v4[0] == 127 { continue } // loopback
return NewNetAddressIPPort(ipnet.IP, DEFAULT_PORT)
}
return nil
for _, a := range addrs {
ipnet, ok := a.(*net.IPNet)
if !ok {
continue
}
v4 := ipnet.IP.To4()
if v4 == nil || v4[0] == 127 {
continue
} // loopback
return NewNetAddressIPPort(ipnet.IP, DEFAULT_PORT)
}
return nil
}

View File

@ -1,14 +1,14 @@
package peer
import (
"github.com/cihub/seelog"
"github.com/cihub/seelog"
)
var log seelog.LoggerInterface
func init() {
// TODO: replace with configuration file in the ~/.tendermint directory.
config := `
// TODO: replace with configuration file in the ~/.tendermint directory.
config := `
<seelog type="sync">
<outputs formatid="colored">
<console/>
@ -19,7 +19,9 @@ func init() {
</formats>
</seelog>`
var err error
log, err = seelog.LoggerFromConfigAsBytes([]byte(config))
if err != nil { panic(err) }
var err error
log, err = seelog.LoggerFromConfigAsBytes([]byte(config))
if err != nil {
panic(err)
}
}

View File

@ -1,59 +1,61 @@
package peer
import (
. "github.com/tendermint/tendermint/binary"
"io"
. "github.com/tendermint/tendermint/binary"
"io"
)
/* Packet */
type Packet struct {
Channel String
Bytes ByteSlice
// Hash
Channel String
Bytes ByteSlice
// Hash
}
func NewPacket(chName String, bytes ByteSlice) Packet {
return Packet{
Channel: chName,
Bytes: bytes,
}
return Packet{
Channel: chName,
Bytes: bytes,
}
}
func (p Packet) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(&p.Channel, w, n, err)
n, err = WriteOnto(&p.Bytes, w, n, err)
return
n, err = WriteOnto(&p.Channel, w, n, err)
n, err = WriteOnto(&p.Bytes, w, n, err)
return
}
func ReadPacketSafe(r io.Reader) (pkt Packet, err error) {
chName, err := ReadStringSafe(r)
if err != nil { return }
// TODO: packet length sanity check.
bytes, err := ReadByteSliceSafe(r)
if err != nil { return }
return NewPacket(chName, bytes), nil
chName, err := ReadStringSafe(r)
if err != nil {
return
}
// TODO: packet length sanity check.
bytes, err := ReadByteSliceSafe(r)
if err != nil {
return
}
return NewPacket(chName, bytes), nil
}
/* InboundPacket */
type InboundPacket struct {
Peer *Peer
Channel *Channel
Time Time
Packet
Peer *Peer
Channel *Channel
Time Time
Packet
}
/* NewFilterMsg */
type NewFilterMsg struct {
ChName String
Filter interface{} // todo
ChName String
Filter interface{} // todo
}
func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) {
panic("TODO: implement")
return 0, nil // TODO
panic("TODO: implement")
return 0, nil // TODO
}

View File

@ -5,150 +5,158 @@
package peer
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
"io"
"net"
"strconv"
. "github.com/tendermint/tendermint/binary"
. "github.com/tendermint/tendermint/common"
"io"
"net"
"strconv"
)
/* NetAddress */
type NetAddress struct {
IP net.IP
Port UInt16
IP net.IP
Port UInt16
}
// TODO: socks proxies?
func NewNetAddress(addr net.Addr) *NetAddress {
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok { Panicf("Only TCPAddrs are supported. Got: %v", addr) }
ip := tcpAddr.IP
port := UInt16(tcpAddr.Port)
return NewNetAddressIPPort(ip, port)
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
Panicf("Only TCPAddrs are supported. Got: %v", addr)
}
ip := tcpAddr.IP
port := UInt16(tcpAddr.Port)
return NewNetAddressIPPort(ip, port)
}
func NewNetAddressString(addr string) *NetAddress {
host, portStr, err := net.SplitHostPort(addr)
if err != nil { panic(err) }
ip := net.ParseIP(host)
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil { panic(err) }
na := NewNetAddressIPPort(ip, UInt16(port))
return na
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
panic(err)
}
ip := net.ParseIP(host)
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
panic(err)
}
na := NewNetAddressIPPort(ip, UInt16(port))
return na
}
func NewNetAddressIPPort(ip net.IP, port UInt16) *NetAddress {
na := NetAddress{
IP: ip,
Port: port,
}
return &na
na := NetAddress{
IP: ip,
Port: port,
}
return &na
}
func ReadNetAddress(r io.Reader) *NetAddress {
return &NetAddress{
IP: net.IP(ReadByteSlice(r)),
Port: ReadUInt16(r),
}
return &NetAddress{
IP: net.IP(ReadByteSlice(r)),
Port: ReadUInt16(r),
}
}
func (na *NetAddress) WriteTo(w io.Writer) (n int64, err error) {
n, err = WriteOnto(ByteSlice(na.IP.To16()), w, n, err)
n, err = WriteOnto(na.Port, w, n, err)
return
n, err = WriteOnto(ByteSlice(na.IP.To16()), w, n, err)
n, err = WriteOnto(na.Port, w, n, err)
return
}
func (na *NetAddress) Equals(other Binary) bool {
if o, ok := other.(*NetAddress); ok {
return na.String() == o.String()
} else {
return false
}
if o, ok := other.(*NetAddress); ok {
return na.String() == o.String()
} else {
return false
}
}
func (na *NetAddress) Less(other Binary) bool {
if o, ok := other.(*NetAddress); ok {
return na.String() < o.String()
} else {
panic("Cannot compare unequal types")
}
if o, ok := other.(*NetAddress); ok {
return na.String() < o.String()
} else {
panic("Cannot compare unequal types")
}
}
func (na *NetAddress) String() string {
port := strconv.FormatUint(uint64(na.Port), 10)
addr := net.JoinHostPort(na.IP.String(), port)
return addr
port := strconv.FormatUint(uint64(na.Port), 10)
addr := net.JoinHostPort(na.IP.String(), port)
return addr
}
func (na *NetAddress) Dial() (*Connection, error) {
conn, err := net.Dial("tcp", na.String())
if err != nil { return nil, err }
return NewConnection(conn), nil
conn, err := net.Dial("tcp", na.String())
if err != nil {
return nil, err
}
return NewConnection(conn), nil
}
func (na *NetAddress) Routable() bool {
// TODO(oga) bitcoind doesn't include RFC3849 here, but should we?
return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() ||
na.RFC4193() || na.RFC4843() || na.Local())
// TODO(oga) bitcoind doesn't include RFC3849 here, but should we?
return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() ||
na.RFC4193() || na.RFC4843() || na.Local())
}
// For IPv4 these are either a 0 or all bits set address. For IPv6 a zero
// address or one that matches the RFC3849 documentation address format.
func (na *NetAddress) Valid() bool {
return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() ||
na.IP.Equal(net.IPv4bcast))
return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() ||
na.IP.Equal(net.IPv4bcast))
}
func (na *NetAddress) Local() bool {
return na.IP.IsLoopback() || zero4.Contains(na.IP)
return na.IP.IsLoopback() || zero4.Contains(na.IP)
}
func (na *NetAddress) ReachabilityTo(o *NetAddress) int {
const (
Unreachable = 0
Default = iota
Teredo
Ipv6_weak
Ipv4
Ipv6_strong
Private
)
if !na.Routable() {
return Unreachable
} else if na.RFC4380() {
if !o.Routable() {
return Default
} else if o.RFC4380() {
return Teredo
} else if o.IP.To4() != nil {
return Ipv4
} else { // ipv6
return Ipv6_weak
}
} else if na.IP.To4() != nil {
if o.Routable() && o.IP.To4() != nil {
return Ipv4
}
return Default
} else /* ipv6 */ {
var tunnelled bool
// Is our v6 is tunnelled?
if o.RFC3964() || o.RFC6052() || o.RFC6145() {
tunnelled = true
}
if !o.Routable() {
return Default
} else if o.RFC4380() {
return Teredo
} else if o.IP.To4() != nil {
return Ipv4
} else if tunnelled {
// only prioritise ipv6 if we aren't tunnelling it.
return Ipv6_weak
}
return Ipv6_strong
}
const (
Unreachable = 0
Default = iota
Teredo
Ipv6_weak
Ipv4
Ipv6_strong
Private
)
if !na.Routable() {
return Unreachable
} else if na.RFC4380() {
if !o.Routable() {
return Default
} else if o.RFC4380() {
return Teredo
} else if o.IP.To4() != nil {
return Ipv4
} else { // ipv6
return Ipv6_weak
}
} else if na.IP.To4() != nil {
if o.Routable() && o.IP.To4() != nil {
return Ipv4
}
return Default
} else /* ipv6 */ {
var tunnelled bool
// Is our v6 is tunnelled?
if o.RFC3964() || o.RFC6052() || o.RFC6145() {
tunnelled = true
}
if !o.Routable() {
return Default
} else if o.RFC4380() {
return Teredo
} else if o.IP.To4() != nil {
return Ipv4
} else if tunnelled {
// only prioritise ipv6 if we aren't tunnelling it.
return Ipv6_weak
}
return Ipv6_strong
}
}
// RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12)
@ -161,23 +169,25 @@ func (na *NetAddress) ReachabilityTo(o *NetAddress) int {
// RFC4862: IPv6 Autoconfig (FE80::/64)
// RFC6052: IPv6 well known prefix (64:FF9B::/96)
// RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96
var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}
var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}
var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}
var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)}
var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}
var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)}
var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)}
var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)}
var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)}
var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)}
var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)}
var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)}
var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)}
var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}
var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}
var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}
var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)}
var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}
var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)}
var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)}
var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)}
var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)}
var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)}
var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)}
var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)}
var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)}
func (na *NetAddress) RFC1918() bool { return rfc1918_10.Contains(na.IP) ||
rfc1918_192.Contains(na.IP) ||
rfc1918_172.Contains(na.IP) }
func (na *NetAddress) RFC1918() bool {
return rfc1918_10.Contains(na.IP) ||
rfc1918_192.Contains(na.IP) ||
rfc1918_172.Contains(na.IP)
}
func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) }
func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) }
func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) }

View File

@ -1,172 +1,174 @@
package peer
import (
. "github.com/tendermint/tendermint/binary"
"sync/atomic"
"sync"
"io"
"time"
"fmt"
"fmt"
. "github.com/tendermint/tendermint/binary"
"io"
"sync"
"sync/atomic"
"time"
)
/* Peer */
type Peer struct {
outgoing bool
conn *Connection
channels map[String]*Channel
outgoing bool
conn *Connection
channels map[String]*Channel
mtx sync.Mutex
quit chan struct{}
stopped uint32
mtx sync.Mutex
quit chan struct{}
stopped uint32
}
func NewPeer(conn *Connection) *Peer {
return &Peer{
conn: conn,
quit: make(chan struct{}),
stopped: 0,
}
return &Peer{
conn: conn,
quit: make(chan struct{}),
stopped: 0,
}
}
func (p *Peer) Start(peerRecvQueues map[String]chan *InboundPacket ) {
log.Debugf("Starting %v", p)
p.conn.Start(p.channels)
for chName, _ := range p.channels {
go p.recvHandler(chName, peerRecvQueues[chName])
go p.sendHandler(chName)
}
func (p *Peer) Start(peerRecvQueues map[String]chan *InboundPacket) {
log.Debugf("Starting %v", p)
p.conn.Start(p.channels)
for chName, _ := range p.channels {
go p.recvHandler(chName, peerRecvQueues[chName])
go p.sendHandler(chName)
}
}
func (p *Peer) Stop() {
// lock
p.mtx.Lock()
if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) {
log.Debugf("Stopping %v", p)
close(p.quit)
p.conn.Stop()
}
p.mtx.Unlock()
// unlock
// lock
p.mtx.Lock()
if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) {
log.Debugf("Stopping %v", p)
close(p.quit)
p.conn.Stop()
}
p.mtx.Unlock()
// unlock
}
func (p *Peer) LocalAddress() *NetAddress {
return p.conn.LocalAddress()
return p.conn.LocalAddress()
}
func (p *Peer) RemoteAddress() *NetAddress {
return p.conn.RemoteAddress()
return p.conn.RemoteAddress()
}
func (p *Peer) Channel(chName String) *Channel {
return p.channels[chName]
return p.channels[chName]
}
// If the channel's queue is full, just return false.
// Later the sendHandler will send the pkt to the underlying connection.
func (p *Peer) TrySend(pkt Packet) bool {
channel := p.Channel(pkt.Channel)
sendQueue := channel.SendQueue()
channel := p.Channel(pkt.Channel)
sendQueue := channel.SendQueue()
// lock & defer
p.mtx.Lock(); defer p.mtx.Unlock()
if p.stopped == 1 { return false }
select {
case sendQueue <- pkt:
return true
default: // buffer full
return false
}
// unlock deferred
// lock & defer
p.mtx.Lock()
defer p.mtx.Unlock()
if p.stopped == 1 {
return false
}
select {
case sendQueue <- pkt:
return true
default: // buffer full
return false
}
// unlock deferred
}
func (p *Peer) WriteTo(w io.Writer) (n int64, err error) {
return p.RemoteAddress().WriteTo(w)
return p.RemoteAddress().WriteTo(w)
}
func (p *Peer) String() string {
return fmt.Sprintf("Peer{%v-%v,o:%v}", p.LocalAddress(), p.RemoteAddress(), p.outgoing)
return fmt.Sprintf("Peer{%v-%v,o:%v}", p.LocalAddress(), p.RemoteAddress(), p.outgoing)
}
func (p *Peer) recvHandler(chName String, inboundPacketQueue chan<- *InboundPacket) {
log.Tracef("%v recvHandler [%v]", p, chName)
channel := p.channels[chName]
recvQueue := channel.RecvQueue()
log.Tracef("%v recvHandler [%v]", p, chName)
channel := p.channels[chName]
recvQueue := channel.RecvQueue()
FOR_LOOP:
for {
select {
case <-p.quit:
break FOR_LOOP
case pkt := <-recvQueue:
// send to inboundPacketQueue
inboundPacket := &InboundPacket{
Peer: p,
Channel: channel,
Time: Time{time.Now()},
Packet: pkt,
}
select {
case <-p.quit:
break FOR_LOOP
case inboundPacketQueue <- inboundPacket:
continue
}
}
}
FOR_LOOP:
for {
select {
case <-p.quit:
break FOR_LOOP
case pkt := <-recvQueue:
// send to inboundPacketQueue
inboundPacket := &InboundPacket{
Peer: p,
Channel: channel,
Time: Time{time.Now()},
Packet: pkt,
}
select {
case <-p.quit:
break FOR_LOOP
case inboundPacketQueue <- inboundPacket:
continue
}
}
}
log.Tracef("%v recvHandler [%v] closed", p, chName)
// cleanup
// (none)
log.Tracef("%v recvHandler [%v] closed", p, chName)
// cleanup
// (none)
}
func (p *Peer) sendHandler(chName String) {
log.Tracef("%v sendHandler [%v]", p, chName)
chSendQueue := p.channels[chName].sendQueue
FOR_LOOP:
for {
select {
case <-p.quit:
break FOR_LOOP
case pkt := <-chSendQueue:
log.Tracef("Sending packet to peer chSendQueue")
// blocks until the connection is Stop'd,
// which happens when this peer is Stop'd.
p.conn.Send(pkt)
}
}
log.Tracef("%v sendHandler [%v]", p, chName)
chSendQueue := p.channels[chName].sendQueue
FOR_LOOP:
for {
select {
case <-p.quit:
break FOR_LOOP
case pkt := <-chSendQueue:
log.Tracef("Sending packet to peer chSendQueue")
// blocks until the connection is Stop'd,
// which happens when this peer is Stop'd.
p.conn.Send(pkt)
}
}
log.Tracef("%v sendHandler [%v] closed", p, chName)
// cleanup
// (none)
log.Tracef("%v sendHandler [%v] closed", p, chName)
// cleanup
// (none)
}
/* Channel */
type Channel struct {
name String
recvQueue chan Packet
sendQueue chan Packet
//stats Stats
name String
recvQueue chan Packet
sendQueue chan Packet
//stats Stats
}
func NewChannel(name String, bufferSize int) *Channel {
return &Channel{
name: name,
recvQueue: make(chan Packet, bufferSize),
sendQueue: make(chan Packet, bufferSize),
}
return &Channel{
name: name,
recvQueue: make(chan Packet, bufferSize),
sendQueue: make(chan Packet, bufferSize),
}
}
func (c *Channel) Name() String {
return c.name
return c.name
}
func (c *Channel) RecvQueue() <-chan Packet {
return c.recvQueue
return c.recvQueue
}
func (c *Channel) SendQueue() chan<- Packet {
return c.sendQueue
return c.sendQueue
}

View File

@ -1,39 +1,38 @@
package peer
import (
)
import ()
/* Server */
type Server struct {
listener Listener
client *Client
listener Listener
client *Client
}
func NewServer(protocol string, laddr string, c *Client) *Server {
l := NewDefaultListener(protocol, laddr)
s := &Server{
listener: l,
client: c,
}
go s.IncomingConnectionHandler()
return s
l := NewDefaultListener(protocol, laddr)
s := &Server{
listener: l,
client: c,
}
go s.IncomingConnectionHandler()
return s
}
func (s *Server) LocalAddress() *NetAddress {
return s.listener.LocalAddress()
return s.listener.LocalAddress()
}
// meant to run in a goroutine
func (s *Server) IncomingConnectionHandler() {
for conn := range s.listener.Connections() {
log.Infof("New connection found: %v", conn)
s.client.AddPeerWithConnection(conn, false)
}
for conn := range s.listener.Connections() {
log.Infof("New connection found: %v", conn)
s.client.AddPeerWithConnection(conn, false)
}
}
func (s *Server) Stop() {
log.Infof("Stopping server")
s.listener.Stop()
s.client.Stop()
log.Infof("Stopping server")
s.listener.Stop()
s.client.Stop()
}

View File

@ -7,370 +7,370 @@ package peer
//
import (
"bytes"
"encoding/xml"
"errors"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"time"
"bytes"
"encoding/xml"
"errors"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"time"
)
type upnpNAT struct {
serviceURL string
ourIP string
urnDomain string
serviceURL string
ourIP string
urnDomain string
}
// protocol is either "udp" or "tcp"
type NAT interface {
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
}
func Discover() (nat NAT, err error) {
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil {
return
}
conn, err := net.ListenPacket("udp4", ":0")
if err != nil {
return
}
socket := conn.(*net.UDPConn)
defer socket.Close()
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil {
return
}
conn, err := net.ListenPacket("udp4", ":0")
if err != nil {
return
}
socket := conn.(*net.UDPConn)
defer socket.Close()
err = socket.SetDeadline(time.Now().Add(3 * time.Second))
if err != nil {
return
}
err = socket.SetDeadline(time.Now().Add(3 * time.Second))
if err != nil {
return
}
st := "InternetGatewayDevice:1"
st := "InternetGatewayDevice:1"
buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" +
"HOST: 239.255.255.250:1900\r\n" +
"ST: ssdp:all\r\n" +
"MAN: \"ssdp:discover\"\r\n" +
"MX: 2\r\n\r\n")
message := buf.Bytes()
answerBytes := make([]byte, 1024)
for i := 0; i < 3; i++ {
_, err = socket.WriteToUDP(message, ssdp)
if err != nil {
return
}
var n int
n, _, err = socket.ReadFromUDP(answerBytes)
for {
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil {
break
}
answer := string(answerBytes[0:n])
if strings.Index(answer, st) < 0 {
continue
}
// HTTP header field names are case-insensitive.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
locString := "\r\nlocation:"
answer = strings.ToLower(answer)
locIndex := strings.Index(answer, locString)
if locIndex < 0 {
continue
}
loc := answer[locIndex+len(locString):]
endIndex := strings.Index(loc, "\r\n")
if endIndex < 0 {
continue
}
locURL := strings.TrimSpace(loc[0:endIndex])
var serviceURL, urnDomain string
serviceURL, urnDomain, err = getServiceURL(locURL)
if err != nil {
return
}
var ourIP net.IP
ourIP, err = localIPv4()
if err != nil {
return
}
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain}
return
}
}
err = errors.New("UPnP port discovery failed.")
return
buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" +
"HOST: 239.255.255.250:1900\r\n" +
"ST: ssdp:all\r\n" +
"MAN: \"ssdp:discover\"\r\n" +
"MX: 2\r\n\r\n")
message := buf.Bytes()
answerBytes := make([]byte, 1024)
for i := 0; i < 3; i++ {
_, err = socket.WriteToUDP(message, ssdp)
if err != nil {
return
}
var n int
n, _, err = socket.ReadFromUDP(answerBytes)
for {
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil {
break
}
answer := string(answerBytes[0:n])
if strings.Index(answer, st) < 0 {
continue
}
// HTTP header field names are case-insensitive.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
locString := "\r\nlocation:"
answer = strings.ToLower(answer)
locIndex := strings.Index(answer, locString)
if locIndex < 0 {
continue
}
loc := answer[locIndex+len(locString):]
endIndex := strings.Index(loc, "\r\n")
if endIndex < 0 {
continue
}
locURL := strings.TrimSpace(loc[0:endIndex])
var serviceURL, urnDomain string
serviceURL, urnDomain, err = getServiceURL(locURL)
if err != nil {
return
}
var ourIP net.IP
ourIP, err = localIPv4()
if err != nil {
return
}
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain}
return
}
}
err = errors.New("UPnP port discovery failed.")
return
}
type Envelope struct {
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
Soap *SoapBody
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
Soap *SoapBody
}
type SoapBody struct {
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
ExternalIP *ExternalIPAddressResponse
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
ExternalIP *ExternalIPAddressResponse
}
type ExternalIPAddressResponse struct {
XMLName xml.Name `xml:"GetExternalIPAddressResponse"`
IPAddress string `xml:"NewExternalIPAddress"`
XMLName xml.Name `xml:"GetExternalIPAddressResponse"`
IPAddress string `xml:"NewExternalIPAddress"`
}
type ExternalIPAddress struct {
XMLName xml.Name `xml:"NewExternalIPAddress"`
IP string
XMLName xml.Name `xml:"NewExternalIPAddress"`
IP string
}
type Service struct {
ServiceType string `xml:"serviceType"`
ControlURL string `xml:"controlURL"`
ServiceType string `xml:"serviceType"`
ControlURL string `xml:"controlURL"`
}
type DeviceList struct {
Device []Device `xml:"device"`
Device []Device `xml:"device"`
}
type ServiceList struct {
Service []Service `xml:"service"`
Service []Service `xml:"service"`
}
type Device struct {
XMLName xml.Name `xml:"device"`
DeviceType string `xml:"deviceType"`
DeviceList DeviceList `xml:"deviceList"`
ServiceList ServiceList `xml:"serviceList"`
XMLName xml.Name `xml:"device"`
DeviceType string `xml:"deviceType"`
DeviceList DeviceList `xml:"deviceList"`
ServiceList ServiceList `xml:"serviceList"`
}
type Root struct {
Device Device
Device Device
}
func getChildDevice(d *Device, deviceType string) *Device {
dl := d.DeviceList.Device
for i := 0; i < len(dl); i++ {
if strings.Index(dl[i].DeviceType, deviceType) >= 0 {
return &dl[i]
}
}
return nil
dl := d.DeviceList.Device
for i := 0; i < len(dl); i++ {
if strings.Index(dl[i].DeviceType, deviceType) >= 0 {
return &dl[i]
}
}
return nil
}
func getChildService(d *Device, serviceType string) *Service {
sl := d.ServiceList.Service
for i := 0; i < len(sl); i++ {
if strings.Index(sl[i].ServiceType, serviceType) >= 0 {
return &sl[i]
}
}
return nil
sl := d.ServiceList.Service
for i := 0; i < len(sl); i++ {
if strings.Index(sl[i].ServiceType, serviceType) >= 0 {
return &sl[i]
}
}
return nil
}
func localIPv4() (net.IP, error) {
tt, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, t := range tt {
aa, err := t.Addrs()
if err != nil {
return nil, err
}
for _, a := range aa {
ipnet, ok := a.(*net.IPNet)
if !ok {
continue
}
v4 := ipnet.IP.To4()
if v4 == nil || v4[0] == 127 { // loopback address
continue
}
return v4, nil
}
}
return nil, errors.New("cannot find local IP address")
tt, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, t := range tt {
aa, err := t.Addrs()
if err != nil {
return nil, err
}
for _, a := range aa {
ipnet, ok := a.(*net.IPNet)
if !ok {
continue
}
v4 := ipnet.IP.To4()
if v4 == nil || v4[0] == 127 { // loopback address
continue
}
return v4, nil
}
}
return nil, errors.New("cannot find local IP address")
}
func getServiceURL(rootURL string) (url, urnDomain string, err error) {
r, err := http.Get(rootURL)
if err != nil {
return
}
defer r.Body.Close()
if r.StatusCode >= 400 {
err = errors.New(string(r.StatusCode))
return
}
var root Root
err = xml.NewDecoder(r.Body).Decode(&root)
if err != nil {
return
}
a := &root.Device
if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 {
err = errors.New("No InternetGatewayDevice")
return
}
b := getChildDevice(a, "WANDevice:1")
if b == nil {
err = errors.New("No WANDevice")
return
}
c := getChildDevice(b, "WANConnectionDevice:1")
if c == nil {
err = errors.New("No WANConnectionDevice")
return
}
d := getChildService(c, "WANIPConnection:1")
if d == nil {
// Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
// instead of under WanConnectionDevice
d = getChildService(b, "WANIPConnection:1")
r, err := http.Get(rootURL)
if err != nil {
return
}
defer r.Body.Close()
if r.StatusCode >= 400 {
err = errors.New(string(r.StatusCode))
return
}
var root Root
err = xml.NewDecoder(r.Body).Decode(&root)
if err != nil {
return
}
a := &root.Device
if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 {
err = errors.New("No InternetGatewayDevice")
return
}
b := getChildDevice(a, "WANDevice:1")
if b == nil {
err = errors.New("No WANDevice")
return
}
c := getChildDevice(b, "WANConnectionDevice:1")
if c == nil {
err = errors.New("No WANConnectionDevice")
return
}
d := getChildService(c, "WANIPConnection:1")
if d == nil {
// Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
// instead of under WanConnectionDevice
d = getChildService(b, "WANIPConnection:1")
if d == nil {
err = errors.New("No WANIPConnection")
return
}
}
// Extract the domain name, which isn't always 'schemas-upnp-org'
urnDomain = strings.Split(d.ServiceType, ":")[1]
url = combineURL(rootURL, d.ControlURL)
return
if d == nil {
err = errors.New("No WANIPConnection")
return
}
}
// Extract the domain name, which isn't always 'schemas-upnp-org'
urnDomain = strings.Split(d.ServiceType, ":")[1]
url = combineURL(rootURL, d.ControlURL)
return
}
func combineURL(rootURL, subURL string) string {
protocolEnd := "://"
protoEndIndex := strings.Index(rootURL, protocolEnd)
a := rootURL[protoEndIndex+len(protocolEnd):]
rootIndex := strings.Index(a, "/")
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
protocolEnd := "://"
protoEndIndex := strings.Index(rootURL, protocolEnd)
a := rootURL[protoEndIndex+len(protocolEnd):]
rootIndex := strings.Index(a, "/")
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
}
func soapRequest(url, function, message, domain string) (r *http.Response, err error) {
fullMessage := "<?xml version=\"1.0\" ?>" +
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
"<s:Body>" + message + "</s:Body></s:Envelope>"
fullMessage := "<?xml version=\"1.0\" ?>" +
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
"<s:Body>" + message + "</s:Body></s:Envelope>"
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
//req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
req.Header.Set("Connection", "Close")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache")
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
//req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
req.Header.Set("Connection", "Close")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache")
// log.Stderr("soapRequest ", req)
// log.Stderr("soapRequest ", req)
r, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
/*if r.Body != nil {
defer r.Body.Close()
}*/
r, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
/*if r.Body != nil {
defer r.Body.Close()
}*/
if r.StatusCode >= 400 {
// log.Stderr(function, r.StatusCode)
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
r = nil
return
}
return
if r.StatusCode >= 400 {
// log.Stderr(function, r.StatusCode)
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
r = nil
return
}
return
}
type statusInfo struct {
externalIpAddress string
externalIpAddress string
}
func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) {
message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"</u:GetExternalIPAddress>"
message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"</u:GetExternalIPAddress>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
var envelope Envelope
data, err := ioutil.ReadAll(response.Body)
reader := bytes.NewReader(data)
xml.NewDecoder(reader).Decode(&envelope)
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
var envelope Envelope
data, err := ioutil.ReadAll(response.Body)
reader := bytes.NewReader(data)
xml.NewDecoder(reader).Decode(&envelope)
info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
if err != nil {
return
}
if err != nil {
return
}
return
return
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getExternalIPAddress()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
info, err := n.getExternalIPAddress()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
}
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
}

View File

@ -1,8 +1,8 @@
package peer
import (
"testing"
"time"
"testing"
"time"
)
/*
@ -11,38 +11,38 @@ TODO: set up or find a service to probe open ports.
*/
func TestUPNP(t *testing.T) {
t.Log("hello!")
t.Log("hello!")
nat, err := Discover()
if err != nil {
t.Fatalf("NAT upnp could not be discovered: %v", err)
}
nat, err := Discover()
if err != nil {
t.Fatalf("NAT upnp could not be discovered: %v", err)
}
t.Log("ourIP: ", nat.(*upnpNAT).ourIP)
t.Log("ourIP: ", nat.(*upnpNAT).ourIP)
ext, err := nat.GetExternalAddress()
if err != nil {
t.Fatalf("External address error: %v", err)
}
t.Logf("External address: %v", ext)
ext, err := nat.GetExternalAddress()
if err != nil {
t.Fatalf("External address error: %v", err)
}
t.Logf("External address: %v", ext)
port, err := nat.AddPortMapping("tcp", 8001, 8001, "testing", 0)
if err != nil {
t.Fatalf("Port mapping error: %v", err)
}
t.Logf("Port mapping mapped: %v", port)
port, err := nat.AddPortMapping("tcp", 8001, 8001, "testing", 0)
if err != nil {
t.Fatalf("Port mapping error: %v", err)
}
t.Logf("Port mapping mapped: %v", port)
// also run the listener, open for all remote addresses.
listener := NewDefaultListener("tcp", "0.0.0.0:8001")
// also run the listener, open for all remote addresses.
listener := NewDefaultListener("tcp", "0.0.0.0:8001")
// now sleep for 10 seconds
time.Sleep(10 * time.Second)
// now sleep for 10 seconds
time.Sleep(10 * time.Second)
err = nat.DeletePortMapping("tcp", 8001, 8001)
if err != nil {
t.Fatalf("Port mapping delete error: %v", err)
}
t.Logf("Port mapping deleted")
err = nat.DeletePortMapping("tcp", 8001, 8001)
if err != nil {
t.Fatalf("Port mapping delete error: %v", err)
}
t.Logf("Port mapping deleted")
listener.Stop()
listener.Stop()
}

View File

@ -1,7 +1,7 @@
package peer
import (
"crypto/sha256"
"crypto/sha256"
)
// DoubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes.