From c40fc65e6bbdd7821b18f1db48392ba89d4f34c7 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Tue, 1 Jul 2014 14:50:24 -0700 Subject: [PATCH] go fmt --- binary/binary.go | 12 +- binary/byteslice.go | 58 ++-- binary/codec.go | 128 ++++--- binary/int.go | 413 ++++++++++++----------- binary/string.go | 50 +-- binary/time.go | 32 +- binary/util.go | 32 +- blocks/account.go | 58 ++-- blocks/adjustment.go | 146 ++++---- blocks/block.go | 169 +++++----- blocks/block_test.go | 168 ++++----- blocks/signature.go | 24 +- blocks/tx.go | 98 +++--- blocks/vote.go | 28 +- common/debounce.go | 59 ++-- common/heap.go | 57 ++-- common/panic.go | 4 +- config/config.go | 139 ++++---- crypto/ed25519.go | 80 ++--- crypto/ed25519_test.go | 60 ++-- db/level_db.go | 60 ++-- db/mem_db.go | 24 +- merkle/iavl_node.go | 642 ++++++++++++++++++----------------- merkle/iavl_test.go | 444 ++++++++++++------------ merkle/iavl_tree.go | 124 ++++--- merkle/types.go | 56 +-- merkle/util.go | 109 +++--- peer/addrbook.go | 747 +++++++++++++++++++++-------------------- peer/client.go | 230 +++++++------ peer/client_test.go | 139 ++++---- peer/connection.go | 259 +++++++------- peer/knownaddress.go | 142 ++++---- peer/listener.go | 146 ++++---- peer/log.go | 14 +- peer/msg.go | 58 ++-- peer/netaddress.go | 236 ++++++------- peer/peer.go | 216 ++++++------ peer/server.go | 37 +- peer/upnp.go | 558 +++++++++++++++--------------- peer/upnp_test.go | 56 +-- peer/util.go | 2 +- 41 files changed, 3176 insertions(+), 2938 deletions(-) diff --git a/binary/binary.go b/binary/binary.go index 8d13cd35..453079a4 100644 --- a/binary/binary.go +++ b/binary/binary.go @@ -3,12 +3,14 @@ package binary import "io" type Binary interface { - WriteTo(w io.Writer) (int64, error) + WriteTo(w io.Writer) (int64, error) } func WriteOnto(b Binary, w io.Writer, n int64, err error) (int64, error) { - if err != nil { return n, err } - var n_ int64 - n_, err = b.WriteTo(w) - return n+n_, err + if err != nil { + return n, err + } + var n_ int64 + n_, err = b.WriteTo(w) + return n + n_, err } diff --git a/binary/byteslice.go b/binary/byteslice.go index 238207f6..c17625e6 100644 --- a/binary/byteslice.go +++ b/binary/byteslice.go @@ -6,44 +6,52 @@ import "bytes" type ByteSlice []byte func (self ByteSlice) Equals(other Binary) bool { - if o, ok := other.(ByteSlice); ok { - return bytes.Equal(self, o) - } else { - return false - } + if o, ok := other.(ByteSlice); ok { + return bytes.Equal(self, o) + } else { + return false + } } func (self ByteSlice) Less(other Binary) bool { - if o, ok := other.(ByteSlice); ok { - return bytes.Compare(self, o) < 0 // -1 if a < b - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(ByteSlice); ok { + return bytes.Compare(self, o) < 0 // -1 if a < b + } else { + panic("Cannot compare unequal types") + } } func (self ByteSlice) ByteSize() int { - return len(self)+4 + return len(self) + 4 } func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) { - var n_ int - _, err = UInt32(len(self)).WriteTo(w) - if err != nil { return n, err } - n_, err = w.Write([]byte(self)) - return int64(n_+4), err + var n_ int + _, err = UInt32(len(self)).WriteTo(w) + if err != nil { + return n, err + } + n_, err = w.Write([]byte(self)) + return int64(n_ + 4), err } func ReadByteSliceSafe(r io.Reader) (ByteSlice, error) { - length, err := ReadUInt32Safe(r) - if err != nil { return nil, err } - bytes := make([]byte, int(length)) - _, err = io.ReadFull(r, bytes) - if err != nil { return nil, err } - return bytes, nil + length, err := ReadUInt32Safe(r) + if err != nil { + return nil, err + } + bytes := make([]byte, int(length)) + _, err = io.ReadFull(r, bytes) + if err != nil { + return nil, err + } + return bytes, nil } func ReadByteSlice(r io.Reader) ByteSlice { - bytes, err := ReadByteSliceSafe(r) - if r != nil { panic(err) } - return bytes + bytes, err := ReadByteSliceSafe(r) + if r != nil { + panic(err) + } + return bytes } diff --git a/binary/codec.go b/binary/codec.go index f6825e9b..4544c53f 100644 --- a/binary/codec.go +++ b/binary/codec.go @@ -1,70 +1,100 @@ package binary import ( - "io" + "io" ) const ( - TYPE_NIL = Byte(0x00) - TYPE_BYTE = Byte(0x01) - TYPE_INT8 = Byte(0x02) - TYPE_UINT8 = Byte(0x03) - TYPE_INT16 = Byte(0x04) - TYPE_UINT16 = Byte(0x05) - TYPE_INT32 = Byte(0x06) - TYPE_UINT32 = Byte(0x07) - TYPE_INT64 = Byte(0x08) - TYPE_UINT64 = Byte(0x09) + TYPE_NIL = Byte(0x00) + TYPE_BYTE = Byte(0x01) + TYPE_INT8 = Byte(0x02) + TYPE_UINT8 = Byte(0x03) + TYPE_INT16 = Byte(0x04) + TYPE_UINT16 = Byte(0x05) + TYPE_INT32 = Byte(0x06) + TYPE_UINT32 = Byte(0x07) + TYPE_INT64 = Byte(0x08) + TYPE_UINT64 = Byte(0x09) - TYPE_STRING = Byte(0x10) - TYPE_BYTESLICE = Byte(0x11) + TYPE_STRING = Byte(0x10) + TYPE_BYTESLICE = Byte(0x11) - TYPE_TIME = Byte(0x20) + TYPE_TIME = Byte(0x20) ) func GetBinaryType(o Binary) Byte { - switch o.(type) { - case nil: return TYPE_NIL - case Byte: return TYPE_BYTE - case Int8: return TYPE_INT8 - case UInt8: return TYPE_UINT8 - case Int16: return TYPE_INT16 - case UInt16: return TYPE_UINT16 - case Int32: return TYPE_INT32 - case UInt32: return TYPE_UINT32 - case Int64: return TYPE_INT64 - case UInt64: return TYPE_UINT64 - case Int: panic("Int not supported") - case UInt: panic("UInt not supported") + switch o.(type) { + case nil: + return TYPE_NIL + case Byte: + return TYPE_BYTE + case Int8: + return TYPE_INT8 + case UInt8: + return TYPE_UINT8 + case Int16: + return TYPE_INT16 + case UInt16: + return TYPE_UINT16 + case Int32: + return TYPE_INT32 + case UInt32: + return TYPE_UINT32 + case Int64: + return TYPE_INT64 + case UInt64: + return TYPE_UINT64 + case Int: + panic("Int not supported") + case UInt: + panic("UInt not supported") - case String: return TYPE_STRING - case ByteSlice: return TYPE_BYTESLICE + case String: + return TYPE_STRING + case ByteSlice: + return TYPE_BYTESLICE - case Time: return TYPE_TIME + case Time: + return TYPE_TIME - default: panic("Unsupported type") - } + default: + panic("Unsupported type") + } } func ReadBinary(r io.Reader) Binary { - type_ := ReadByte(r) - switch type_ { - case TYPE_NIL: return nil - case TYPE_BYTE: return ReadByte(r) - case TYPE_INT8: return ReadInt8(r) - case TYPE_UINT8: return ReadUInt8(r) - case TYPE_INT16: return ReadInt16(r) - case TYPE_UINT16: return ReadUInt16(r) - case TYPE_INT32: return ReadInt32(r) - case TYPE_UINT32: return ReadUInt32(r) - case TYPE_INT64: return ReadInt64(r) - case TYPE_UINT64: return ReadUInt64(r) + type_ := ReadByte(r) + switch type_ { + case TYPE_NIL: + return nil + case TYPE_BYTE: + return ReadByte(r) + case TYPE_INT8: + return ReadInt8(r) + case TYPE_UINT8: + return ReadUInt8(r) + case TYPE_INT16: + return ReadInt16(r) + case TYPE_UINT16: + return ReadUInt16(r) + case TYPE_INT32: + return ReadInt32(r) + case TYPE_UINT32: + return ReadUInt32(r) + case TYPE_INT64: + return ReadInt64(r) + case TYPE_UINT64: + return ReadUInt64(r) - case TYPE_STRING: return ReadString(r) - case TYPE_BYTESLICE:return ReadByteSlice(r) + case TYPE_STRING: + return ReadString(r) + case TYPE_BYTESLICE: + return ReadByteSlice(r) - case TYPE_TIME: return ReadTime(r) + case TYPE_TIME: + return ReadTime(r) - default: panic("Unsupported type") - } + default: + panic("Unsupported type") + } } diff --git a/binary/int.go b/binary/int.go index 7d381014..1a022367 100644 --- a/binary/int.go +++ b/binary/int.go @@ -1,8 +1,8 @@ package binary import ( - "io" - "encoding/binary" + "encoding/binary" + "io" ) type Byte byte @@ -17,397 +17,426 @@ type UInt64 uint64 type Int int type UInt uint - // Byte func (self Byte) Equals(other Binary) bool { - return self == other + return self == other } func (self Byte) Less(other Binary) bool { - if o, ok := other.(Byte); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Byte); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self Byte) ByteSize() int { - return 1 + return 1 } func (self Byte) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write([]byte{byte(self)}) - return int64(n), err + n, err := w.Write([]byte{byte(self)}) + return int64(n), err } func ReadByteSafe(r io.Reader) (Byte, error) { - buf := [1]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return 0, err } - return Byte(buf[0]), nil + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return 0, err + } + return Byte(buf[0]), nil } -func ReadByte(r io.Reader) (Byte) { - b, err := ReadByteSafe(r) - if err != nil { panic(err) } - return b +func ReadByte(r io.Reader) Byte { + b, err := ReadByteSafe(r) + if err != nil { + panic(err) + } + return b } - // Int8 func (self Int8) Equals(other Binary) bool { - return self == other + return self == other } func (self Int8) Less(other Binary) bool { - if o, ok := other.(Int8); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Int8); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self Int8) ByteSize() int { - return 1 + return 1 } func (self Int8) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write([]byte{byte(self)}) - return int64(n), err + n, err := w.Write([]byte{byte(self)}) + return int64(n), err } func ReadInt8Safe(r io.Reader) (Int8, error) { - buf := [1]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return Int8(0), err } - return Int8(buf[0]), nil + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return Int8(0), err + } + return Int8(buf[0]), nil } -func ReadInt8(r io.Reader) (Int8) { - b, err := ReadInt8Safe(r) - if err != nil { panic(err) } - return b +func ReadInt8(r io.Reader) Int8 { + b, err := ReadInt8Safe(r) + if err != nil { + panic(err) + } + return b } - // UInt8 func (self UInt8) Equals(other Binary) bool { - return self == other + return self == other } func (self UInt8) Less(other Binary) bool { - if o, ok := other.(UInt8); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(UInt8); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self UInt8) ByteSize() int { - return 1 + return 1 } func (self UInt8) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write([]byte{byte(self)}) - return int64(n), err + n, err := w.Write([]byte{byte(self)}) + return int64(n), err } func ReadUInt8Safe(r io.Reader) (UInt8, error) { - buf := [1]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return UInt8(0), err } - return UInt8(buf[0]), nil + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return UInt8(0), err + } + return UInt8(buf[0]), nil } -func ReadUInt8(r io.Reader) (UInt8) { - b, err := ReadUInt8Safe(r) - if err != nil { panic(err) } - return b +func ReadUInt8(r io.Reader) UInt8 { + b, err := ReadUInt8Safe(r) + if err != nil { + panic(err) + } + return b } - // Int16 func (self Int16) Equals(other Binary) bool { - return self == other + return self == other } func (self Int16) Less(other Binary) bool { - if o, ok := other.(Int16); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Int16); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self Int16) ByteSize() int { - return 2 + return 2 } func (self Int16) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, int16(self)) - return 2, err + err := binary.Write(w, binary.LittleEndian, int16(self)) + return 2, err } func ReadInt16Safe(r io.Reader) (Int16, error) { - buf := [2]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return Int16(0), err } - return Int16(binary.LittleEndian.Uint16(buf[:])), nil + buf := [2]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return Int16(0), err + } + return Int16(binary.LittleEndian.Uint16(buf[:])), nil } -func ReadInt16(r io.Reader) (Int16) { - b, err := ReadInt16Safe(r) - if err != nil { panic(err) } - return b +func ReadInt16(r io.Reader) Int16 { + b, err := ReadInt16Safe(r) + if err != nil { + panic(err) + } + return b } - // UInt16 func (self UInt16) Equals(other Binary) bool { - return self == other + return self == other } func (self UInt16) Less(other Binary) bool { - if o, ok := other.(UInt16); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(UInt16); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self UInt16) ByteSize() int { - return 2 + return 2 } func (self UInt16) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, uint16(self)) - return 2, err + err := binary.Write(w, binary.LittleEndian, uint16(self)) + return 2, err } func ReadUInt16Safe(r io.Reader) (UInt16, error) { - buf := [2]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return UInt16(0), err } - return UInt16(binary.LittleEndian.Uint16(buf[:])), nil + buf := [2]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return UInt16(0), err + } + return UInt16(binary.LittleEndian.Uint16(buf[:])), nil } -func ReadUInt16(r io.Reader) (UInt16) { - b, err := ReadUInt16Safe(r) - if err != nil { panic(err) } - return b +func ReadUInt16(r io.Reader) UInt16 { + b, err := ReadUInt16Safe(r) + if err != nil { + panic(err) + } + return b } - // Int32 func (self Int32) Equals(other Binary) bool { - return self == other + return self == other } func (self Int32) Less(other Binary) bool { - if o, ok := other.(Int32); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Int32); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self Int32) ByteSize() int { - return 4 + return 4 } func (self Int32) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, int32(self)) - return 4, err + err := binary.Write(w, binary.LittleEndian, int32(self)) + return 4, err } func ReadInt32Safe(r io.Reader) (Int32, error) { - buf := [4]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return Int32(0), err } - return Int32(binary.LittleEndian.Uint32(buf[:])), nil + buf := [4]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return Int32(0), err + } + return Int32(binary.LittleEndian.Uint32(buf[:])), nil } -func ReadInt32(r io.Reader) (Int32) { - b, err := ReadInt32Safe(r) - if err != nil { panic(err) } - return b +func ReadInt32(r io.Reader) Int32 { + b, err := ReadInt32Safe(r) + if err != nil { + panic(err) + } + return b } - // UInt32 func (self UInt32) Equals(other Binary) bool { - return self == other + return self == other } func (self UInt32) Less(other Binary) bool { - if o, ok := other.(UInt32); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(UInt32); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self UInt32) ByteSize() int { - return 4 + return 4 } func (self UInt32) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, uint32(self)) - return 4, err + err := binary.Write(w, binary.LittleEndian, uint32(self)) + return 4, err } func ReadUInt32Safe(r io.Reader) (UInt32, error) { - buf := [4]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return UInt32(0), err } - return UInt32(binary.LittleEndian.Uint32(buf[:])), nil + buf := [4]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return UInt32(0), err + } + return UInt32(binary.LittleEndian.Uint32(buf[:])), nil } -func ReadUInt32(r io.Reader) (UInt32) { - b, err := ReadUInt32Safe(r) - if err != nil { panic(err) } - return b +func ReadUInt32(r io.Reader) UInt32 { + b, err := ReadUInt32Safe(r) + if err != nil { + panic(err) + } + return b } - // Int64 func (self Int64) Equals(other Binary) bool { - return self == other + return self == other } func (self Int64) Less(other Binary) bool { - if o, ok := other.(Int64); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Int64); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self Int64) ByteSize() int { - return 8 + return 8 } func (self Int64) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, int64(self)) - return 8, err + err := binary.Write(w, binary.LittleEndian, int64(self)) + return 8, err } func ReadInt64Safe(r io.Reader) (Int64, error) { - buf := [8]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return Int64(0), err } - return Int64(binary.LittleEndian.Uint64(buf[:])), nil + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return Int64(0), err + } + return Int64(binary.LittleEndian.Uint64(buf[:])), nil } -func ReadInt64(r io.Reader) (Int64) { - b, err := ReadInt64Safe(r) - if err != nil { panic(err) } - return b +func ReadInt64(r io.Reader) Int64 { + b, err := ReadInt64Safe(r) + if err != nil { + panic(err) + } + return b } - // UInt64 func (self UInt64) Equals(other Binary) bool { - return self == other + return self == other } func (self UInt64) Less(other Binary) bool { - if o, ok := other.(UInt64); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(UInt64); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self UInt64) ByteSize() int { - return 8 + return 8 } func (self UInt64) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, uint64(self)) - return 8, err + err := binary.Write(w, binary.LittleEndian, uint64(self)) + return 8, err } func ReadUInt64Safe(r io.Reader) (UInt64, error) { - buf := [8]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { return UInt64(0), err } - return UInt64(binary.LittleEndian.Uint64(buf[:])), nil + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return UInt64(0), err + } + return UInt64(binary.LittleEndian.Uint64(buf[:])), nil } -func ReadUInt64(r io.Reader) (UInt64) { - b, err := ReadUInt64Safe(r) - if err != nil { panic(err) } - return b +func ReadUInt64(r io.Reader) UInt64 { + b, err := ReadUInt64Safe(r) + if err != nil { + panic(err) + } + return b } - // Int func (self Int) Equals(other Binary) bool { - return self == other + return self == other } func (self Int) Less(other Binary) bool { - if o, ok := other.(Int); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Int); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self Int) ByteSize() int { - return 8 + return 8 } func (self Int) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, int64(self)) - return 8, err + err := binary.Write(w, binary.LittleEndian, int64(self)) + return 8, err } func ReadInt(r io.Reader) Int { - buf := [8]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { panic(err) } - return Int(binary.LittleEndian.Uint64(buf[:])) + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + panic(err) + } + return Int(binary.LittleEndian.Uint64(buf[:])) } - // UInt func (self UInt) Equals(other Binary) bool { - return self == other + return self == other } func (self UInt) Less(other Binary) bool { - if o, ok := other.(UInt); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(UInt); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self UInt) ByteSize() int { - return 8 + return 8 } func (self UInt) WriteTo(w io.Writer) (int64, error) { - err := binary.Write(w, binary.LittleEndian, uint64(self)) - return 8, err + err := binary.Write(w, binary.LittleEndian, uint64(self)) + return 8, err } func ReadUInt(r io.Reader) UInt { - buf := [8]byte{0} - _, err := io.ReadFull(r, buf[:]) - if err != nil { panic(err) } - return UInt(binary.LittleEndian.Uint64(buf[:])) + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { + panic(err) + } + return UInt(binary.LittleEndian.Uint64(buf[:])) } diff --git a/binary/string.go b/binary/string.go index 063385d8..2bc1d8cb 100644 --- a/binary/string.go +++ b/binary/string.go @@ -7,40 +7,48 @@ type String string // String func (self String) Equals(other Binary) bool { - return self == other + return self == other } func (self String) Less(other Binary) bool { - if o, ok := other.(String); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(String); ok { + return self < o + } else { + panic("Cannot compare unequal types") + } } func (self String) ByteSize() int { - return len(self)+4 + return len(self) + 4 } func (self String) WriteTo(w io.Writer) (n int64, err error) { - var n_ int - _, err = UInt32(len(self)).WriteTo(w) - if err != nil { return n, err } - n_, err = w.Write([]byte(self)) - return int64(n_+4), err + var n_ int + _, err = UInt32(len(self)).WriteTo(w) + if err != nil { + return n, err + } + n_, err = w.Write([]byte(self)) + return int64(n_ + 4), err } func ReadStringSafe(r io.Reader) (String, error) { - length, err := ReadUInt32Safe(r) - if err != nil { return "", err } - bytes := make([]byte, int(length)) - _, err = io.ReadFull(r, bytes) - if err != nil { return "", err } - return String(bytes), nil + length, err := ReadUInt32Safe(r) + if err != nil { + return "", err + } + bytes := make([]byte, int(length)) + _, err = io.ReadFull(r, bytes) + if err != nil { + return "", err + } + return String(bytes), nil } func ReadString(r io.Reader) String { - str, err := ReadStringSafe(r) - if r != nil { panic(err) } - return str + str, err := ReadStringSafe(r) + if r != nil { + panic(err) + } + return str } diff --git a/binary/time.go b/binary/time.go index 9d02ea57..62f7f066 100644 --- a/binary/time.go +++ b/binary/time.go @@ -1,38 +1,38 @@ package binary import ( - "io" - "time" + "io" + "time" ) type Time struct { - time.Time + time.Time } func (self Time) Equals(other Binary) bool { - if o, ok := other.(Time); ok { - return self.Equal(o.Time) - } else { - return false - } + if o, ok := other.(Time); ok { + return self.Equal(o.Time) + } else { + return false + } } func (self Time) Less(other Binary) bool { - if o, ok := other.(Time); ok { - return self.Before(o.Time) - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(Time); ok { + return self.Before(o.Time) + } else { + panic("Cannot compare unequal types") + } } func (self Time) ByteSize() int { - return 8 + return 8 } func (self Time) WriteTo(w io.Writer) (int64, error) { - return Int64(self.Unix()).WriteTo(w) + return Int64(self.Unix()).WriteTo(w) } func ReadTime(r io.Reader) Time { - return Time{time.Unix(int64(ReadInt64(r)), 0)} + return Time{time.Unix(int64(ReadInt64(r)), 0)} } diff --git a/binary/util.go b/binary/util.go index 4087378c..e08b8b56 100644 --- a/binary/util.go +++ b/binary/util.go @@ -1,33 +1,35 @@ package binary import ( - "crypto/sha256" - "bytes" + "bytes" + "crypto/sha256" ) func BinaryBytes(b Binary) ByteSlice { - buf := bytes.NewBuffer(nil) - b.WriteTo(buf) - return ByteSlice(buf.Bytes()) + buf := bytes.NewBuffer(nil) + b.WriteTo(buf) + return ByteSlice(buf.Bytes()) } // NOTE: does not care about the type, only the binary representation. func BinaryEqual(a, b Binary) bool { - aBytes := BinaryBytes(a) - bBytes := BinaryBytes(b) - return bytes.Equal(aBytes, bBytes) + aBytes := BinaryBytes(a) + bBytes := BinaryBytes(b) + return bytes.Equal(aBytes, bBytes) } // NOTE: does not care about the type, only the binary representation. func BinaryCompare(a, b Binary) int { - aBytes := BinaryBytes(a) - bBytes := BinaryBytes(b) - return bytes.Compare(aBytes, bBytes) + aBytes := BinaryBytes(a) + bBytes := BinaryBytes(b) + return bytes.Compare(aBytes, bBytes) } func BinaryHash(b Binary) ByteSlice { - hasher := sha256.New() - _, err := b.WriteTo(hasher) - if err != nil { panic(err) } - return ByteSlice(hasher.Sum(nil)) + hasher := sha256.New() + _, err := b.WriteTo(hasher) + if err != nil { + panic(err) + } + return ByteSlice(hasher.Sum(nil)) } diff --git a/blocks/account.go b/blocks/account.go index e7b2c595..b4b1915c 100644 --- a/blocks/account.go +++ b/blocks/account.go @@ -1,48 +1,48 @@ package blocks import ( - . "github.com/tendermint/tendermint/common" - . "github.com/tendermint/tendermint/binary" - "io" + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" + "io" ) type AccountId struct { - Type Byte - Number UInt64 - PubKey ByteSlice + Type Byte + Number UInt64 + PubKey ByteSlice } const ( - ACCOUNT_TYPE_NUMBER = Byte(0x01) - ACCOUNT_TYPE_PUBKEY = Byte(0x02) - ACCOUNT_TYPE_BOTH = Byte(0x03) + ACCOUNT_TYPE_NUMBER = Byte(0x01) + ACCOUNT_TYPE_PUBKEY = Byte(0x02) + ACCOUNT_TYPE_BOTH = Byte(0x03) ) func ReadAccountId(r io.Reader) AccountId { - switch t := ReadByte(r); t { - case ACCOUNT_TYPE_NUMBER: - return AccountId{t, ReadUInt64(r), nil} - case ACCOUNT_TYPE_PUBKEY: - return AccountId{t, 0, ReadByteSlice(r)} - case ACCOUNT_TYPE_BOTH: - return AccountId{t, ReadUInt64(r), ReadByteSlice(r)} - default: - Panicf("Unknown AccountId type %x", t) - return AccountId{} - } + switch t := ReadByte(r); t { + case ACCOUNT_TYPE_NUMBER: + return AccountId{t, ReadUInt64(r), nil} + case ACCOUNT_TYPE_PUBKEY: + return AccountId{t, 0, ReadByteSlice(r)} + case ACCOUNT_TYPE_BOTH: + return AccountId{t, ReadUInt64(r), ReadByteSlice(r)} + default: + Panicf("Unknown AccountId type %x", t) + return AccountId{} + } } func (self AccountId) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type, w, n, err) - if self.Type == ACCOUNT_TYPE_NUMBER || self.Type == ACCOUNT_TYPE_BOTH { - n, err = WriteOnto(self.Number, w, n, err) - } - if self.Type == ACCOUNT_TYPE_PUBKEY || self.Type == ACCOUNT_TYPE_BOTH { - n, err = WriteOnto(self.PubKey, w, n, err) - } - return + n, err = WriteOnto(self.Type, w, n, err) + if self.Type == ACCOUNT_TYPE_NUMBER || self.Type == ACCOUNT_TYPE_BOTH { + n, err = WriteOnto(self.Number, w, n, err) + } + if self.Type == ACCOUNT_TYPE_PUBKEY || self.Type == ACCOUNT_TYPE_BOTH { + n, err = WriteOnto(self.PubKey, w, n, err) + } + return } func AccountNumber(n UInt64) AccountId { - return AccountId{ACCOUNT_TYPE_NUMBER, n, nil} + return AccountId{ACCOUNT_TYPE_NUMBER, n, nil} } diff --git a/blocks/adjustment.go b/blocks/adjustment.go index ecac5896..175bda14 100644 --- a/blocks/adjustment.go +++ b/blocks/adjustment.go @@ -1,9 +1,9 @@ package blocks import ( - . "github.com/tendermint/tendermint/common" - . "github.com/tendermint/tendermint/binary" - "io" + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" + "io" ) /* Adjustment @@ -17,126 +17,122 @@ TODO: signing a bad checkpoint (block) */ type Adjustment interface { - Type() Byte - Binary + Type() Byte + Binary } const ( - ADJ_TYPE_BOND = Byte(0x01) - ADJ_TYPE_UNBOND = Byte(0x02) - ADJ_TYPE_TIMEOUT = Byte(0x03) - ADJ_TYPE_DUPEOUT = Byte(0x04) + ADJ_TYPE_BOND = Byte(0x01) + ADJ_TYPE_UNBOND = Byte(0x02) + ADJ_TYPE_TIMEOUT = Byte(0x03) + ADJ_TYPE_DUPEOUT = Byte(0x04) ) func ReadAdjustment(r io.Reader) Adjustment { - switch t := ReadByte(r); t { - case ADJ_TYPE_BOND: - return &Bond{ - Fee: ReadUInt64(r), - UnbondTo: ReadAccountId(r), - Amount: ReadUInt64(r), - Signature: ReadSignature(r), - } - case ADJ_TYPE_UNBOND: - return &Unbond{ - Fee: ReadUInt64(r), - Amount: ReadUInt64(r), - Signature: ReadSignature(r), - } - case ADJ_TYPE_TIMEOUT: - return &Timeout{ - Account: ReadAccountId(r), - Penalty: ReadUInt64(r), - } - case ADJ_TYPE_DUPEOUT: - return &Dupeout{ - VoteA: ReadVote(r), - VoteB: ReadVote(r), - } - default: - Panicf("Unknown Adjustment type %x", t) - return nil - } + switch t := ReadByte(r); t { + case ADJ_TYPE_BOND: + return &Bond{ + Fee: ReadUInt64(r), + UnbondTo: ReadAccountId(r), + Amount: ReadUInt64(r), + Signature: ReadSignature(r), + } + case ADJ_TYPE_UNBOND: + return &Unbond{ + Fee: ReadUInt64(r), + Amount: ReadUInt64(r), + Signature: ReadSignature(r), + } + case ADJ_TYPE_TIMEOUT: + return &Timeout{ + Account: ReadAccountId(r), + Penalty: ReadUInt64(r), + } + case ADJ_TYPE_DUPEOUT: + return &Dupeout{ + VoteA: ReadVote(r), + VoteB: ReadVote(r), + } + default: + Panicf("Unknown Adjustment type %x", t) + return nil + } } - /* Bond < Adjustment */ type Bond struct { - Fee UInt64 - UnbondTo AccountId - Amount UInt64 - Signature + Fee UInt64 + UnbondTo AccountId + Amount UInt64 + Signature } func (self *Bond) Type() Byte { - return ADJ_TYPE_BOND + return ADJ_TYPE_BOND } func (self *Bond) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type(), w, n, err) - n, err = WriteOnto(self.Fee, w, n, err) - n, err = WriteOnto(self.UnbondTo, w, n, err) - n, err = WriteOnto(self.Amount, w, n, err) - n, err = WriteOnto(self.Signature, w, n, err) - return + n, err = WriteOnto(self.Type(), w, n, err) + n, err = WriteOnto(self.Fee, w, n, err) + n, err = WriteOnto(self.UnbondTo, w, n, err) + n, err = WriteOnto(self.Amount, w, n, err) + n, err = WriteOnto(self.Signature, w, n, err) + return } - /* Unbond < Adjustment */ type Unbond struct { - Fee UInt64 - Amount UInt64 - Signature + Fee UInt64 + Amount UInt64 + Signature } func (self *Unbond) Type() Byte { - return ADJ_TYPE_UNBOND + return ADJ_TYPE_UNBOND } func (self *Unbond) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type(), w, n, err) - n, err = WriteOnto(self.Fee, w, n, err) - n, err = WriteOnto(self.Amount, w, n, err) - n, err = WriteOnto(self.Signature, w, n, err) - return + n, err = WriteOnto(self.Type(), w, n, err) + n, err = WriteOnto(self.Fee, w, n, err) + n, err = WriteOnto(self.Amount, w, n, err) + n, err = WriteOnto(self.Signature, w, n, err) + return } - /* Timeout < Adjustment */ type Timeout struct { - Account AccountId - Penalty UInt64 + Account AccountId + Penalty UInt64 } func (self *Timeout) Type() Byte { - return ADJ_TYPE_TIMEOUT + return ADJ_TYPE_TIMEOUT } func (self *Timeout) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type(), w, n, err) - n, err = WriteOnto(self.Account, w, n, err) - n, err = WriteOnto(self.Penalty, w, n, err) - return + n, err = WriteOnto(self.Type(), w, n, err) + n, err = WriteOnto(self.Account, w, n, err) + n, err = WriteOnto(self.Penalty, w, n, err) + return } - /* Dupeout < Adjustment */ type Dupeout struct { - VoteA Vote - VoteB Vote + VoteA Vote + VoteB Vote } func (self *Dupeout) Type() Byte { - return ADJ_TYPE_DUPEOUT + return ADJ_TYPE_DUPEOUT } func (self *Dupeout) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type(), w, n, err) - n, err = WriteOnto(self.VoteA, w, n, err) - n, err = WriteOnto(self.VoteB, w, n, err) - return + n, err = WriteOnto(self.Type(), w, n, err) + n, err = WriteOnto(self.VoteA, w, n, err) + n, err = WriteOnto(self.VoteB, w, n, err) + return } diff --git a/blocks/block.go b/blocks/block.go index 66bee940..3fcdd8dd 100644 --- a/blocks/block.go +++ b/blocks/block.go @@ -1,140 +1,137 @@ package blocks import ( - . "github.com/tendermint/tendermint/binary" - "github.com/tendermint/tendermint/merkle" - "io" + . "github.com/tendermint/tendermint/binary" + "github.com/tendermint/tendermint/merkle" + "io" ) - /* Block */ type Block struct { - Header - Validation - Data - // Checkpoint + Header + Validation + Data + // Checkpoint } func ReadBlock(r io.Reader) *Block { - return &Block{ - Header: ReadHeader(r), - Validation: ReadValidation(r), - Data: ReadData(r), - } + return &Block{ + Header: ReadHeader(r), + Validation: ReadValidation(r), + Data: ReadData(r), + } } func (self *Block) Validate() bool { - return false + return false } func (self *Block) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(&self.Header, w, n, err) - n, err = WriteOnto(&self.Validation, w, n, err) - n, err = WriteOnto(&self.Data, w, n, err) - return + n, err = WriteOnto(&self.Header, w, n, err) + n, err = WriteOnto(&self.Validation, w, n, err) + n, err = WriteOnto(&self.Data, w, n, err) + return } - /* Block > Header */ type Header struct { - Name String - Height UInt64 - Fees UInt64 - Time UInt64 - PrevHash ByteSlice - ValidationHash ByteSlice - DataHash ByteSlice + Name String + Height UInt64 + Fees UInt64 + Time UInt64 + PrevHash ByteSlice + ValidationHash ByteSlice + DataHash ByteSlice } func ReadHeader(r io.Reader) Header { - return Header{ - Name: ReadString(r), - Height: ReadUInt64(r), - Fees: ReadUInt64(r), - Time: ReadUInt64(r), - PrevHash: ReadByteSlice(r), - ValidationHash: ReadByteSlice(r), - DataHash: ReadByteSlice(r), - } + return Header{ + Name: ReadString(r), + Height: ReadUInt64(r), + Fees: ReadUInt64(r), + Time: ReadUInt64(r), + PrevHash: ReadByteSlice(r), + ValidationHash: ReadByteSlice(r), + DataHash: ReadByteSlice(r), + } } func (self *Header) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Name, w, n, err) - n, err = WriteOnto(self.Height, w, n, err) - n, err = WriteOnto(self.Fees, w, n, err) - n, err = WriteOnto(self.Time, w, n, err) - n, err = WriteOnto(self.PrevHash, w, n, err) - n, err = WriteOnto(self.ValidationHash, w, n, err) - n, err = WriteOnto(self.DataHash, w, n, err) - return + n, err = WriteOnto(self.Name, w, n, err) + n, err = WriteOnto(self.Height, w, n, err) + n, err = WriteOnto(self.Fees, w, n, err) + n, err = WriteOnto(self.Time, w, n, err) + n, err = WriteOnto(self.PrevHash, w, n, err) + n, err = WriteOnto(self.ValidationHash, w, n, err) + n, err = WriteOnto(self.DataHash, w, n, err) + return } - /* Block > Validation */ type Validation struct { - Signatures []Signature - Adjustments []Adjustment + Signatures []Signature + Adjustments []Adjustment } func ReadValidation(r io.Reader) Validation { - numSigs := int(ReadUInt64(r)) - numAdjs := int(ReadUInt64(r)) - sigs := make([]Signature, 0, numSigs) - for i:=0; i Data */ type Data struct { - Txs []Tx + Txs []Tx } func ReadData(r io.Reader) Data { - numTxs := int(ReadUInt64(r)) - txs := make([]Tx, 0, numTxs) - for i:=0; iRead of block failed.") - } + if !BinaryEqual(blockBytes, blockBytes2) { + t.Fatal("Write->Read of block failed.") + } } diff --git a/blocks/signature.go b/blocks/signature.go index d32a6668..f72b4f92 100644 --- a/blocks/signature.go +++ b/blocks/signature.go @@ -1,8 +1,8 @@ package blocks import ( - . "github.com/tendermint/tendermint/binary" - "io" + . "github.com/tendermint/tendermint/binary" + "io" ) /* @@ -19,23 +19,23 @@ It usually follows the message to be signed. */ type Signature struct { - Signer AccountId - SigBytes ByteSlice + Signer AccountId + SigBytes ByteSlice } func ReadSignature(r io.Reader) Signature { - return Signature{ - Signer: ReadAccountId(r), - SigBytes: ReadByteSlice(r), - } + return Signature{ + Signer: ReadAccountId(r), + SigBytes: ReadByteSlice(r), + } } func (self Signature) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Signer, w, n, err) - n, err = WriteOnto(self.SigBytes, w, n, err) - return + n, err = WriteOnto(self.Signer, w, n, err) + n, err = WriteOnto(self.SigBytes, w, n, err) + return } func (self *Signature) Verify(msg ByteSlice) bool { - return false + return false } diff --git a/blocks/tx.go b/blocks/tx.go index 4ae1a2b8..d9f08f4c 100644 --- a/blocks/tx.go +++ b/blocks/tx.go @@ -1,9 +1,9 @@ package blocks import ( - . "github.com/tendermint/tendermint/common" - . "github.com/tendermint/tendermint/binary" - "io" + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" + "io" ) /* @@ -21,79 +21,77 @@ Tx wire format: */ type Tx interface { - Type() Byte - Binary + Type() Byte + Binary } const ( - TX_TYPE_SEND = Byte(0x01) - TX_TYPE_NAME = Byte(0x02) + TX_TYPE_SEND = Byte(0x01) + TX_TYPE_NAME = Byte(0x02) ) func ReadTx(r io.Reader) Tx { - switch t := ReadByte(r); t { - case TX_TYPE_SEND: - return &SendTx{ - Fee: ReadUInt64(r), - To: ReadAccountId(r), - Amount: ReadUInt64(r), - Signature: ReadSignature(r), - } - case TX_TYPE_NAME: - return &NameTx{ - Fee: ReadUInt64(r), - Name: ReadString(r), - PubKey: ReadByteSlice(r), - Signature: ReadSignature(r), - } - default: - Panicf("Unknown Tx type %x", t) - return nil - } + switch t := ReadByte(r); t { + case TX_TYPE_SEND: + return &SendTx{ + Fee: ReadUInt64(r), + To: ReadAccountId(r), + Amount: ReadUInt64(r), + Signature: ReadSignature(r), + } + case TX_TYPE_NAME: + return &NameTx{ + Fee: ReadUInt64(r), + Name: ReadString(r), + PubKey: ReadByteSlice(r), + Signature: ReadSignature(r), + } + default: + Panicf("Unknown Tx type %x", t) + return nil + } } - /* SendTx < Tx */ type SendTx struct { - Fee UInt64 - To AccountId - Amount UInt64 - Signature + Fee UInt64 + To AccountId + Amount UInt64 + Signature } func (self *SendTx) Type() Byte { - return TX_TYPE_SEND + return TX_TYPE_SEND } func (self *SendTx) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type(), w, n, err) - n, err = WriteOnto(self.Fee, w, n, err) - n, err = WriteOnto(self.To, w, n, err) - n, err = WriteOnto(self.Amount, w, n, err) - n, err = WriteOnto(self.Signature, w, n, err) - return + n, err = WriteOnto(self.Type(), w, n, err) + n, err = WriteOnto(self.Fee, w, n, err) + n, err = WriteOnto(self.To, w, n, err) + n, err = WriteOnto(self.Amount, w, n, err) + n, err = WriteOnto(self.Signature, w, n, err) + return } - /* NameTx < Tx */ type NameTx struct { - Fee UInt64 - Name String - PubKey ByteSlice - Signature + Fee UInt64 + Name String + PubKey ByteSlice + Signature } func (self *NameTx) Type() Byte { - return TX_TYPE_NAME + return TX_TYPE_NAME } func (self *NameTx) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Type(), w, n, err) - n, err = WriteOnto(self.Fee, w, n, err) - n, err = WriteOnto(self.Name, w, n, err) - n, err = WriteOnto(self.PubKey, w, n, err) - n, err = WriteOnto(self.Signature, w, n, err) - return + n, err = WriteOnto(self.Type(), w, n, err) + n, err = WriteOnto(self.Fee, w, n, err) + n, err = WriteOnto(self.Name, w, n, err) + n, err = WriteOnto(self.PubKey, w, n, err) + n, err = WriteOnto(self.Signature, w, n, err) + return } diff --git a/blocks/vote.go b/blocks/vote.go index 92772af2..3e64bdd6 100644 --- a/blocks/vote.go +++ b/blocks/vote.go @@ -1,8 +1,8 @@ package blocks import ( - . "github.com/tendermint/tendermint/binary" - "io" + . "github.com/tendermint/tendermint/binary" + "io" ) /* @@ -11,22 +11,22 @@ Typically only the signature is passed around, as the hash & height are implied. */ type Vote struct { - Height UInt64 - BlockHash ByteSlice - Signature + Height UInt64 + BlockHash ByteSlice + Signature } func ReadVote(r io.Reader) Vote { - return Vote{ - Height: ReadUInt64(r), - BlockHash: ReadByteSlice(r), - Signature: ReadSignature(r), - } + return Vote{ + Height: ReadUInt64(r), + BlockHash: ReadByteSlice(r), + Signature: ReadSignature(r), + } } func (self Vote) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(self.Height, w, n, err) - n, err = WriteOnto(self.BlockHash, w, n, err) - n, err = WriteOnto(self.Signature, w, n, err) - return + n, err = WriteOnto(self.Height, w, n, err) + n, err = WriteOnto(self.BlockHash, w, n, err) + n, err = WriteOnto(self.Signature, w, n, err) + return } diff --git a/common/debounce.go b/common/debounce.go index cf777b7e..125aaa98 100644 --- a/common/debounce.go +++ b/common/debounce.go @@ -1,45 +1,48 @@ package common import ( - "time" - "sync" + "sync" + "time" ) /* Debouncer */ type Debouncer struct { - Ch chan struct{} - quit chan struct{} - dur time.Duration - mtx sync.Mutex - timer *time.Timer + Ch chan struct{} + quit chan struct{} + dur time.Duration + mtx sync.Mutex + timer *time.Timer } func NewDebouncer(dur time.Duration) *Debouncer { - var timer *time.Timer - var ch = make(chan struct{}) - var quit = make(chan struct{}) - var mtx sync.Mutex - fire := func() { - go func() { - select { - case ch <- struct{}{}: - case <-quit: - } - }() - mtx.Lock(); defer mtx.Unlock() - timer.Reset(dur) - } - timer = time.AfterFunc(dur, fire) - return &Debouncer{Ch:ch, dur:dur, quit:quit, mtx:mtx, timer:timer} + var timer *time.Timer + var ch = make(chan struct{}) + var quit = make(chan struct{}) + var mtx sync.Mutex + fire := func() { + go func() { + select { + case ch <- struct{}{}: + case <-quit: + } + }() + mtx.Lock() + defer mtx.Unlock() + timer.Reset(dur) + } + timer = time.AfterFunc(dur, fire) + return &Debouncer{Ch: ch, dur: dur, quit: quit, mtx: mtx, timer: timer} } func (d *Debouncer) Reset() { - d.mtx.Lock(); defer d.mtx.Unlock() - d.timer.Reset(d.dur) + d.mtx.Lock() + defer d.mtx.Unlock() + d.timer.Reset(d.dur) } func (d *Debouncer) Stop() bool { - d.mtx.Lock(); defer d.mtx.Unlock() - close(d.quit) - return d.timer.Stop() + d.mtx.Lock() + defer d.mtx.Unlock() + close(d.quit) + return d.timer.Stop() } diff --git a/common/heap.go b/common/heap.go index 265ff4e4..4505eb20 100644 --- a/common/heap.go +++ b/common/heap.go @@ -1,28 +1,28 @@ package common import ( - "container/heap" + "container/heap" ) type Heap struct { - pq priorityQueue + pq priorityQueue } func NewHeap() *Heap { - return &Heap{pq:make([]*pqItem, 0)} + return &Heap{pq: make([]*pqItem, 0)} } func (h *Heap) Len() int { - return len(h.pq) + return len(h.pq) } func (h *Heap) Push(value interface{}, priority int) { - heap.Push(&h.pq, &pqItem{value:value, priority:priority}) + heap.Push(&h.pq, &pqItem{value: value, priority: priority}) } func (h *Heap) Pop() interface{} { - item := heap.Pop(&h.pq).(*pqItem) - return item.value + item := heap.Pop(&h.pq).(*pqItem) + return item.value } /* @@ -43,9 +43,9 @@ func main() { // From: http://golang.org/pkg/container/heap/#example__priorityQueue type pqItem struct { - value interface{} - priority int - index int + value interface{} + priority int + index int } type priorityQueue []*pqItem @@ -53,35 +53,34 @@ type priorityQueue []*pqItem func (pq priorityQueue) Len() int { return len(pq) } func (pq priorityQueue) Less(i, j int) bool { - return pq[i].priority < pq[j].priority + return pq[i].priority < pq[j].priority } func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j } func (pq *priorityQueue) Push(x interface{}) { - n := len(*pq) - item := x.(*pqItem) - item.index = n - *pq = append(*pq, item) + n := len(*pq) + item := x.(*pqItem) + item.index = n + *pq = append(*pq, item) } func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - item := old[n-1] - item.index = -1 // for safety - *pq = old[0 : n-1] - return item + old := *pq + n := len(old) + item := old[n-1] + item.index = -1 // for safety + *pq = old[0 : n-1] + return item } func (pq *priorityQueue) Update(item *pqItem, value interface{}, priority int) { - heap.Remove(pq, item.index) - item.value = value - item.priority = priority - heap.Push(pq, item) + heap.Remove(pq, item.index) + item.value = value + item.priority = priority + heap.Push(pq, item) } - diff --git a/common/panic.go b/common/panic.go index 14f63d2e..8621f163 100644 --- a/common/panic.go +++ b/common/panic.go @@ -1,9 +1,9 @@ package common import ( - "fmt" + "fmt" ) func Panicf(s string, args ...interface{}) { - panic(fmt.Sprintf(s, args...)) + panic(fmt.Sprintf(s, args...)) } diff --git a/config/config.go b/config/config.go index 6c9ca271..0350b6fe 100644 --- a/config/config.go +++ b/config/config.go @@ -1,109 +1,115 @@ package config import ( - "encoding/json" - "fmt" - "io/ioutil" - "log" - "os" - "path/filepath" - "strings" - "errors" - //"crypto/rand" - //"encoding/hex" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + "strings" + //"crypto/rand" + //"encoding/hex" ) var APP_DIR = os.Getenv("HOME") + "/.tendermint" - /* Global & initialization */ var Config Config_ func init() { - configFile := APP_DIR+"/config.json" + configFile := APP_DIR + "/config.json" - // try to read configuration. if missing, write default - configBytes, err := ioutil.ReadFile(configFile) - if err != nil { - defaultConfig.write(configFile) - fmt.Println("Config file written to config.json. Please edit & run again") - os.Exit(1) - return - } + // try to read configuration. if missing, write default + configBytes, err := ioutil.ReadFile(configFile) + if err != nil { + defaultConfig.write(configFile) + fmt.Println("Config file written to config.json. Please edit & run again") + os.Exit(1) + return + } - // try to parse configuration. on error, die - Config = Config_{} - err = json.Unmarshal(configBytes, &Config) - if err != nil { - log.Panicf("Invalid configuration file %s: %v", configFile, err) - } - err = Config.validate() - if err != nil { - log.Panicf("Invalid configuration file %s: %v", configFile, err) - } + // try to parse configuration. on error, die + Config = Config_{} + err = json.Unmarshal(configBytes, &Config) + if err != nil { + log.Panicf("Invalid configuration file %s: %v", configFile, err) + } + err = Config.validate() + if err != nil { + log.Panicf("Invalid configuration file %s: %v", configFile, err) + } } - /* Default configuration */ var defaultConfig = Config_{ - Host: "127.0.0.1", - Port: 8770, - Db: DbConfig{ - Type: "level", - Dir: APP_DIR+"/data", - }, - Twilio: TwilioConfig{ - }, + Host: "127.0.0.1", + Port: 8770, + Db: DbConfig{ + Type: "level", + Dir: APP_DIR + "/data", + }, + Twilio: TwilioConfig{}, } - /* Configuration types */ type Config_ struct { - Host string - Port int - Db DbConfig - Twilio TwilioConfig + Host string + Port int + Db DbConfig + Twilio TwilioConfig } type TwilioConfig struct { - Sid string - Token string - From string - To string - MinInterval int + Sid string + Token string + From string + To string + MinInterval int } type DbConfig struct { - Type string - Dir string + Type string + Dir string } func (cfg *Config_) validate() error { - if cfg.Host == "" { return errors.New("Host must be set") } - if cfg.Port == 0 { return errors.New("Port must be set") } - if cfg.Db.Type == "" { return errors.New("Db.Type must be set") } - return nil + if cfg.Host == "" { + return errors.New("Host must be set") + } + if cfg.Port == 0 { + return errors.New("Port must be set") + } + if cfg.Db.Type == "" { + return errors.New("Db.Type must be set") + } + return nil } func (cfg *Config_) bytes() []byte { - configBytes, err := json.Marshal(cfg) - if err != nil { panic(err) } - return configBytes + configBytes, err := json.Marshal(cfg) + if err != nil { + panic(err) + } + return configBytes } func (cfg *Config_) write(configFile string) { - if strings.Index(configFile, "/") != -1 { - err := os.MkdirAll(filepath.Dir(configFile), 0700) - if err != nil { panic(err) } - } - err := ioutil.WriteFile(configFile, cfg.bytes(), 0600) - if err != nil { - panic(err) - } + if strings.Index(configFile, "/") != -1 { + err := os.MkdirAll(filepath.Dir(configFile), 0700) + if err != nil { + panic(err) + } + } + err := ioutil.WriteFile(configFile, cfg.bytes(), 0600) + if err != nil { + panic(err) + } } /* TODO: generate priv/pub keys @@ -113,4 +119,3 @@ func generateKeys() string { return hex.EncodeToString(bytes[:]) } */ - diff --git a/crypto/ed25519.go b/crypto/ed25519.go index 7eb02e53..65823853 100644 --- a/crypto/ed25519.go +++ b/crypto/ed25519.go @@ -11,60 +11,60 @@ import "C" import "unsafe" type Verify struct { - Message []byte - PubKey []byte - Signature []byte - Valid bool + Message []byte + PubKey []byte + Signature []byte + Valid bool } func MakePubKey(privKey []byte) []byte { - pubKey := [32]byte{} - C.ed25519_publickey( - (*C.uchar)(unsafe.Pointer(&privKey[0])), - (*C.uchar)(unsafe.Pointer(&pubKey[0])), - ) - return pubKey[:] + pubKey := [32]byte{} + C.ed25519_publickey( + (*C.uchar)(unsafe.Pointer(&privKey[0])), + (*C.uchar)(unsafe.Pointer(&pubKey[0])), + ) + return pubKey[:] } func SignMessage(message []byte, privKey []byte, pubKey []byte) []byte { - sig := [64]byte{} - C.ed25519_sign( - (*C.uchar)(unsafe.Pointer(&message[0])), (C.size_t)(len(message)), - (*C.uchar)(unsafe.Pointer(&privKey[0])), - (*C.uchar)(unsafe.Pointer(&pubKey[0])), - (*C.uchar)(unsafe.Pointer(&sig[0])), - ) - return sig[:] + sig := [64]byte{} + C.ed25519_sign( + (*C.uchar)(unsafe.Pointer(&message[0])), (C.size_t)(len(message)), + (*C.uchar)(unsafe.Pointer(&privKey[0])), + (*C.uchar)(unsafe.Pointer(&pubKey[0])), + (*C.uchar)(unsafe.Pointer(&sig[0])), + ) + return sig[:] } func VerifyBatch(verifys []*Verify) bool { - count := len(verifys) + count := len(verifys) - msgs := make([]*byte, count) - lens := make([]C.size_t, count) - pubs := make([]*byte, count) - sigs := make([]*byte, count) - valids := make([]C.int, count) + msgs := make([]*byte, count) + lens := make([]C.size_t, count) + pubs := make([]*byte, count) + sigs := make([]*byte, count) + valids := make([]C.int, count) - for i, v := range verifys { - msgs[i] = (*byte)(unsafe.Pointer(&v.Message[0])) - lens[i] = (C.size_t)(len(v.Message)) - pubs[i] = (*byte)(&v.PubKey[0]) - sigs[i] = (*byte)(&v.Signature[0]) - } + for i, v := range verifys { + msgs[i] = (*byte)(unsafe.Pointer(&v.Message[0])) + lens[i] = (C.size_t)(len(v.Message)) + pubs[i] = (*byte)(&v.PubKey[0]) + sigs[i] = (*byte)(&v.Signature[0]) + } - count_ := (C.size_t)(count) - msgs_ := (**C.uchar)(unsafe.Pointer(&msgs[0])) - lens_ := (*C.size_t)(unsafe.Pointer(&lens[0])) - pubs_ := (**C.uchar)(unsafe.Pointer(&pubs[0])) - sigs_ := (**C.uchar)(unsafe.Pointer(&sigs[0])) + count_ := (C.size_t)(count) + msgs_ := (**C.uchar)(unsafe.Pointer(&msgs[0])) + lens_ := (*C.size_t)(unsafe.Pointer(&lens[0])) + pubs_ := (**C.uchar)(unsafe.Pointer(&pubs[0])) + sigs_ := (**C.uchar)(unsafe.Pointer(&sigs[0])) - res := C.ed25519_sign_open_batch(msgs_, lens_, pubs_, sigs_, count_, &valids[0]) + res := C.ed25519_sign_open_batch(msgs_, lens_, pubs_, sigs_, count_, &valids[0]) - for i, valid := range valids { - verifys[i].Valid = valid > 0 - } + for i, valid := range valids { + verifys[i].Valid = valid > 0 + } - return res == 0 + return res == 0 } diff --git a/crypto/ed25519_test.go b/crypto/ed25519_test.go index 230646d1..31c2a766 100644 --- a/crypto/ed25519_test.go +++ b/crypto/ed25519_test.go @@ -1,35 +1,47 @@ package crypto import ( - "testing" - "crypto/rand" + "crypto/rand" + "testing" ) func TestSign(t *testing.T) { - privKey := make([]byte, 32) - _, err := rand.Read(privKey) - if err != nil { t.Fatal(err) } - pubKey := MakePubKey(privKey) - signature := SignMessage([]byte("hello"), privKey, pubKey) + privKey := make([]byte, 32) + _, err := rand.Read(privKey) + if err != nil { + t.Fatal(err) + } + pubKey := MakePubKey(privKey) + signature := SignMessage([]byte("hello"), privKey, pubKey) - v1 := &Verify{ - Message: []byte("hello"), - PubKey: pubKey, - Signature: signature, - } + v1 := &Verify{ + Message: []byte("hello"), + PubKey: pubKey, + Signature: signature, + } - ok := VerifyBatch([]*Verify{v1, v1, v1, v1}) - if ok != true { t.Fatal("Expected ok == true") } - if v1.Valid != true { t.Fatal("Expected v1.Valid to be true") } + ok := VerifyBatch([]*Verify{v1, v1, v1, v1}) + if ok != true { + t.Fatal("Expected ok == true") + } + if v1.Valid != true { + t.Fatal("Expected v1.Valid to be true") + } - v2 := &Verify{ - Message: []byte{0x73}, - PubKey: pubKey, - Signature: signature, - } + v2 := &Verify{ + Message: []byte{0x73}, + PubKey: pubKey, + Signature: signature, + } - ok = VerifyBatch([]*Verify{v1, v1, v1, v2}) - if ok != false { t.Fatal("Expected ok == false") } - if v1.Valid != true { t.Fatal("Expected v1.Valid to be true") } - if v2.Valid != false { t.Fatal("Expected v2.Valid to be true") } + ok = VerifyBatch([]*Verify{v1, v1, v1, v2}) + if ok != false { + t.Fatal("Expected ok == false") + } + if v1.Valid != true { + t.Fatal("Expected v1.Valid to be true") + } + if v2.Valid != false { + t.Fatal("Expected v2.Valid to be true") + } } diff --git a/db/level_db.go b/db/level_db.go index 48e14a4a..cf9059ba 100644 --- a/db/level_db.go +++ b/db/level_db.go @@ -1,54 +1,60 @@ package db import ( - "fmt" - "github.com/syndtr/goleveldb/leveldb" - "path" + "fmt" + "github.com/syndtr/goleveldb/leveldb" + "path" ) type LevelDB struct { - db *leveldb.DB + db *leveldb.DB } func NewLevelDB(name string) (*LevelDB, error) { - dbPath := path.Join(name) - db, err := leveldb.OpenFile(dbPath, nil) - if err != nil { - return nil, err - } - database := &LevelDB{db: db} - return database, nil + dbPath := path.Join(name) + db, err := leveldb.OpenFile(dbPath, nil) + if err != nil { + return nil, err + } + database := &LevelDB{db: db} + return database, nil } func (db *LevelDB) Put(key []byte, value []byte) { - err := db.db.Put(key, value, nil) - if err != nil { panic(err) } + err := db.db.Put(key, value, nil) + if err != nil { + panic(err) + } } -func (db *LevelDB) Get(key []byte) ([]byte) { - res, err := db.db.Get(key, nil) - if err != nil { panic(err) } - return res +func (db *LevelDB) Get(key []byte) []byte { + res, err := db.db.Get(key, nil) + if err != nil { + panic(err) + } + return res } func (db *LevelDB) Delete(key []byte) { - err := db.db.Delete(key, nil) - if err != nil { panic(err) } + err := db.db.Delete(key, nil) + if err != nil { + panic(err) + } } func (db *LevelDB) Db() *leveldb.DB { - return db.db + return db.db } func (db *LevelDB) Close() { - db.db.Close() + db.db.Close() } func (db *LevelDB) Print() { - iter := db.db.NewIterator(nil, nil) - for iter.Next() { - key := iter.Key() - value := iter.Value() - fmt.Printf("[%x]:\t[%x]", key, value) - } + iter := db.db.NewIterator(nil, nil) + for iter.Next() { + key := iter.Key() + value := iter.Value() + fmt.Printf("[%x]:\t[%x]", key, value) + } } diff --git a/db/mem_db.go b/db/mem_db.go index 9cde2754..21151c65 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -1,32 +1,32 @@ package db import ( - "fmt" + "fmt" ) type MemDB struct { - db map[string][]byte + db map[string][]byte } -func NewMemDB() (*MemDB) { - database := &MemDB{db:make(map[string][]byte)} - return database +func NewMemDB() *MemDB { + database := &MemDB{db: make(map[string][]byte)} + return database } func (db *MemDB) Put(key []byte, value []byte) { - db.db[string(key)] = value + db.db[string(key)] = value } -func (db *MemDB) Get(key []byte) ([]byte) { - return db.db[string(key)] +func (db *MemDB) Get(key []byte) []byte { + return db.db[string(key)] } func (db *MemDB) Delete(key []byte) { - delete(db.db, string(key)) + delete(db.db, string(key)) } func (db *MemDB) Print() { - for key, value := range db.db { - fmt.Printf("[%x]:\t[%x]", []byte(key), value) - } + for key, value := range db.db { + fmt.Printf("[%x]:\t[%x]", []byte(key), value) + } } diff --git a/merkle/iavl_node.go b/merkle/iavl_node.go index d8312827..2f58071e 100644 --- a/merkle/iavl_node.go +++ b/merkle/iavl_node.go @@ -1,405 +1,447 @@ package merkle import ( - . "github.com/tendermint/tendermint/binary" - "bytes" - "io" - "crypto/sha256" + "bytes" + "crypto/sha256" + . "github.com/tendermint/tendermint/binary" + "io" ) // Node type IAVLNode struct { - key Key - value Value - size uint64 - height uint8 - hash ByteSlice - left *IAVLNode - right *IAVLNode + key Key + value Value + size uint64 + height uint8 + hash ByteSlice + left *IAVLNode + right *IAVLNode - // volatile - flags byte + // volatile + flags byte } const ( - IAVLNODE_FLAG_PERSISTED = byte(0x01) - IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) + IAVLNODE_FLAG_PERSISTED = byte(0x01) + IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) ) func NewIAVLNode(key Key, value Value) *IAVLNode { - return &IAVLNode{ - key: key, - value: value, - size: 1, - } + return &IAVLNode{ + key: key, + value: value, + size: 1, + } } func (self *IAVLNode) Copy() *IAVLNode { - if self.height == 0 { - panic("Why are you copying a value node?") - } - return &IAVLNode{ - key: self.key, - size: self.size, - height: self.height, - left: self.left, - right: self.right, - hash: nil, - flags: byte(0), - } + if self.height == 0 { + panic("Why are you copying a value node?") + } + return &IAVLNode{ + key: self.key, + size: self.size, + height: self.height, + left: self.left, + right: self.right, + hash: nil, + flags: byte(0), + } } func (self *IAVLNode) Key() Key { - return self.key + return self.key } func (self *IAVLNode) Value() Value { - return self.value + return self.value } func (self *IAVLNode) Size() uint64 { - return self.size + return self.size } func (self *IAVLNode) Height() uint8 { - return self.height + return self.height } func (self *IAVLNode) has(db Db, key Key) (has bool) { - if self.key.Equals(key) { - return true - } - if self.height == 0 { - return false - } else { - if key.Less(self.key) { - return self.leftFilled(db).has(db, key) - } else { - return self.rightFilled(db).has(db, key) - } - } + if self.key.Equals(key) { + return true + } + if self.height == 0 { + return false + } else { + if key.Less(self.key) { + return self.leftFilled(db).has(db, key) + } else { + return self.rightFilled(db).has(db, key) + } + } } func (self *IAVLNode) get(db Db, key Key) (value Value) { - if self.height == 0 { - if self.key.Equals(key) { - return self.value - } else { - return nil - } - } else { - if key.Less(self.key) { - return self.leftFilled(db).get(db, key) - } else { - return self.rightFilled(db).get(db, key) - } - } + if self.height == 0 { + if self.key.Equals(key) { + return self.value + } else { + return nil + } + } else { + if key.Less(self.key) { + return self.leftFilled(db).get(db, key) + } else { + return self.rightFilled(db).get(db, key) + } + } } func (self *IAVLNode) Hash() (ByteSlice, uint64) { - if self.hash != nil { - return self.hash, 0 - } + if self.hash != nil { + return self.hash, 0 + } - hasher := sha256.New() - _, hashCount, err := self.saveToCountHashes(hasher, false) - if err != nil { panic(err) } - self.hash = hasher.Sum(nil) + hasher := sha256.New() + _, hashCount, err := self.saveToCountHashes(hasher, false) + if err != nil { + panic(err) + } + self.hash = hasher.Sum(nil) - return self.hash, hashCount+1 + return self.hash, hashCount + 1 } func (self *IAVLNode) Save(db Db) { - if self.hash == nil { - panic("savee.hash can't be nil") - } - if self.flags & IAVLNODE_FLAG_PERSISTED > 0 || - self.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { - return - } + if self.hash == nil { + panic("savee.hash can't be nil") + } + if self.flags&IAVLNODE_FLAG_PERSISTED > 0 || + self.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 { + return + } - // children - if self.height > 0 { - self.left.Save(db) - self.right.Save(db) - } + // children + if self.height > 0 { + self.left.Save(db) + self.right.Save(db) + } - // save self - buf := bytes.NewBuffer(nil) - _, err := self.WriteTo(buf) - if err != nil { panic(err) } - db.Put([]byte(self.hash), buf.Bytes()) + // save self + buf := bytes.NewBuffer(nil) + _, err := self.WriteTo(buf) + if err != nil { + panic(err) + } + db.Put([]byte(self.hash), buf.Bytes()) - self.flags |= IAVLNODE_FLAG_PERSISTED + self.flags |= IAVLNODE_FLAG_PERSISTED } func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { - if self.height == 0 { - if key.Less(self.key) { - return &IAVLNode{ - key: self.key, - height: 1, - size: 2, - left: NewIAVLNode(key, value), - right: self, - }, false - } else if self.key.Equals(key) { - return NewIAVLNode(key, value), true - } else { - return &IAVLNode{ - key: key, - height: 1, - size: 2, - left: self, - right: NewIAVLNode(key, value), - }, false - } - } else { - self = self.Copy() - if key.Less(self.key) { - self.left, updated = self.leftFilled(db).put(db, key, value) - } else { - self.right, updated = self.rightFilled(db).put(db, key, value) - } - if updated { - return self, updated - } else { - self.calcHeightAndSize(db) - return self.balance(db), updated - } - } + if self.height == 0 { + if key.Less(self.key) { + return &IAVLNode{ + key: self.key, + height: 1, + size: 2, + left: NewIAVLNode(key, value), + right: self, + }, false + } else if self.key.Equals(key) { + return NewIAVLNode(key, value), true + } else { + return &IAVLNode{ + key: key, + height: 1, + size: 2, + left: self, + right: NewIAVLNode(key, value), + }, false + } + } else { + self = self.Copy() + if key.Less(self.key) { + self.left, updated = self.leftFilled(db).put(db, key, value) + } else { + self.right, updated = self.rightFilled(db).put(db, key, value) + } + if updated { + return self, updated + } else { + self.calcHeightAndSize(db) + return self.balance(db), updated + } + } } // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) { - if self.height == 0 { - if self.key.Equals(key) { - return nil, nil, self.value, nil - } else { - return self, nil, nil, NotFound(key) - } - } else { - if key.Less(self.key) { - var newLeft *IAVLNode - newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) - if err != nil { - return self, nil, value, err - } else if newLeft == nil { // left node held value, was removed - return self.right, self.key, value, nil - } - self = self.Copy() - self.left = newLeft - } else { - var newRight *IAVLNode - newRight, newKey, value, err = self.rightFilled(db).remove(db, key) - if err != nil { - return self, nil, value, err - } else if newRight == nil { // right node held value, was removed - return self.left, nil, value, nil - } - self = self.Copy() - self.right = newRight - if newKey != nil { - self.key = newKey - newKey = nil - } - } - self.calcHeightAndSize(db) - return self.balance(db), newKey, value, err - } + if self.height == 0 { + if self.key.Equals(key) { + return nil, nil, self.value, nil + } else { + return self, nil, nil, NotFound(key) + } + } else { + if key.Less(self.key) { + var newLeft *IAVLNode + newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) + if err != nil { + return self, nil, value, err + } else if newLeft == nil { // left node held value, was removed + return self.right, self.key, value, nil + } + self = self.Copy() + self.left = newLeft + } else { + var newRight *IAVLNode + newRight, newKey, value, err = self.rightFilled(db).remove(db, key) + if err != nil { + return self, nil, value, err + } else if newRight == nil { // right node held value, was removed + return self.left, nil, value, nil + } + self = self.Copy() + self.right = newRight + if newKey != nil { + self.key = newKey + newKey = nil + } + } + self.calcHeightAndSize(db) + return self.balance(db), newKey, value, err + } } func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) { - n, _, err = self.saveToCountHashes(w, true) - return + n, _, err = self.saveToCountHashes(w, true) + return } func (self *IAVLNode) saveToCountHashes(w io.Writer, meta bool) (n int64, hashCount uint64, err error) { - var _n int64 + var _n int64 - if meta { - // height & size - _n, err = UInt8(self.height).WriteTo(w) - if err != nil { return } else { n += _n } - _n, err = UInt64(self.size).WriteTo(w) - if err != nil { return } else { n += _n } + if meta { + // height & size + _n, err = UInt8(self.height).WriteTo(w) + if err != nil { + return + } else { + n += _n + } + _n, err = UInt64(self.size).WriteTo(w) + if err != nil { + return + } else { + n += _n + } - // key - _n, err = Byte(GetBinaryType(self.key)).WriteTo(w) - if err != nil { return } else { n += _n } - _n, err = self.key.WriteTo(w) - if err != nil { return } else { n += _n } - } + // key + _n, err = Byte(GetBinaryType(self.key)).WriteTo(w) + if err != nil { + return + } else { + n += _n + } + _n, err = self.key.WriteTo(w) + if err != nil { + return + } else { + n += _n + } + } - if self.height == 0 { - // value - _n, err = Byte(GetBinaryType(self.value)).WriteTo(w) - if err != nil { return } else { n += _n } - if self.value != nil { - _n, err = self.value.WriteTo(w) - if err != nil { return } else { n += _n } - } - } else { - // left - leftHash, leftCount := self.left.Hash() - hashCount += leftCount - _n, err = leftHash.WriteTo(w) - if err != nil { return } else { n += _n } - // right - rightHash, rightCount := self.right.Hash() - hashCount += rightCount - _n, err = rightHash.WriteTo(w) - if err != nil { return } else { n += _n } - } + if self.height == 0 { + // value + _n, err = Byte(GetBinaryType(self.value)).WriteTo(w) + if err != nil { + return + } else { + n += _n + } + if self.value != nil { + _n, err = self.value.WriteTo(w) + if err != nil { + return + } else { + n += _n + } + } + } else { + // left + leftHash, leftCount := self.left.Hash() + hashCount += leftCount + _n, err = leftHash.WriteTo(w) + if err != nil { + return + } else { + n += _n + } + // right + rightHash, rightCount := self.right.Hash() + hashCount += rightCount + _n, err = rightHash.WriteTo(w) + if err != nil { + return + } else { + n += _n + } + } - return + return } // Given a placeholder node which has only the hash set, // load the rest of the data from db. // Not threadsafe. func (self *IAVLNode) fill(db Db) { - if self.hash == nil { - panic("placeholder.hash can't be nil") - } - buf := db.Get(self.hash) - r := bytes.NewReader(buf) - // node header - self.height = uint8(ReadUInt8(r)) - self.size = uint64(ReadUInt64(r)) - // key - key := ReadBinary(r) - self.key = key.(Key) + if self.hash == nil { + panic("placeholder.hash can't be nil") + } + buf := db.Get(self.hash) + r := bytes.NewReader(buf) + // node header + self.height = uint8(ReadUInt8(r)) + self.size = uint64(ReadUInt64(r)) + // key + key := ReadBinary(r) + self.key = key.(Key) - if self.height == 0 { - // value - self.value = ReadBinary(r) - } else { - // left - var leftHash ByteSlice - leftHash = ReadByteSlice(r) - self.left = &IAVLNode{ - hash: leftHash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - // right - var rightHash ByteSlice - rightHash = ReadByteSlice(r) - self.right = &IAVLNode{ - hash: rightHash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - if r.Len() != 0 { - panic("buf not all consumed") - } - } - self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER + if self.height == 0 { + // value + self.value = ReadBinary(r) + } else { + // left + var leftHash ByteSlice + leftHash = ReadByteSlice(r) + self.left = &IAVLNode{ + hash: leftHash, + flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, + } + // right + var rightHash ByteSlice + rightHash = ReadByteSlice(r) + self.right = &IAVLNode{ + hash: rightHash, + flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, + } + if r.Len() != 0 { + panic("buf not all consumed") + } + } + self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER } func (self *IAVLNode) leftFilled(db Db) *IAVLNode { - if self.left.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { - self.left.fill(db) - } - return self.left + if self.left.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 { + self.left.fill(db) + } + return self.left } func (self *IAVLNode) rightFilled(db Db) *IAVLNode { - if self.right.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { - self.right.fill(db) - } - return self.right + if self.right.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 { + self.right.fill(db) + } + return self.right } func (self *IAVLNode) rotateRight(db Db) *IAVLNode { - self = self.Copy() - sl := self.leftFilled(db).Copy() - slr := sl.right + self = self.Copy() + sl := self.leftFilled(db).Copy() + slr := sl.right - sl.right = self - self.left = slr + sl.right = self + self.left = slr - self.calcHeightAndSize(db) - sl.calcHeightAndSize(db) + self.calcHeightAndSize(db) + sl.calcHeightAndSize(db) - return sl + return sl } func (self *IAVLNode) rotateLeft(db Db) *IAVLNode { - self = self.Copy() - sr := self.rightFilled(db).Copy() - srl := sr.left + self = self.Copy() + sr := self.rightFilled(db).Copy() + srl := sr.left - sr.left = self - self.right = srl + sr.left = self + self.right = srl - self.calcHeightAndSize(db) - sr.calcHeightAndSize(db) + self.calcHeightAndSize(db) + sr.calcHeightAndSize(db) - return sr + return sr } func (self *IAVLNode) calcHeightAndSize(db Db) { - self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1 - self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size() + self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1 + self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size() } func (self *IAVLNode) calcBalance(db Db) int { - return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height()) + return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height()) } func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) { - balance := self.calcBalance(db) - if (balance > 1) { - if (self.leftFilled(db).calcBalance(db) >= 0) { - // Left Left Case - return self.rotateRight(db) - } else { - // Left Right Case - self = self.Copy() - self.left = self.leftFilled(db).rotateLeft(db) - //self.calcHeightAndSize() - return self.rotateRight(db) - } - } - if (balance < -1) { - if (self.rightFilled(db).calcBalance(db) <= 0) { - // Right Right Case - return self.rotateLeft(db) - } else { - // Right Left Case - self = self.Copy() - self.right = self.rightFilled(db).rotateRight(db) - //self.calcHeightAndSize() - return self.rotateLeft(db) - } - } - // Nothing changed - return self + balance := self.calcBalance(db) + if balance > 1 { + if self.leftFilled(db).calcBalance(db) >= 0 { + // Left Left Case + return self.rotateRight(db) + } else { + // Left Right Case + self = self.Copy() + self.left = self.leftFilled(db).rotateLeft(db) + //self.calcHeightAndSize() + return self.rotateRight(db) + } + } + if balance < -1 { + if self.rightFilled(db).calcBalance(db) <= 0 { + // Right Right Case + return self.rotateLeft(db) + } else { + // Right Left Case + self = self.Copy() + self.right = self.rightFilled(db).rotateRight(db) + //self.calcHeightAndSize() + return self.rotateLeft(db) + } + } + // Nothing changed + return self } -func (self *IAVLNode) lmd(db Db) (*IAVLNode) { - if self.height == 0 { - return self - } - return self.leftFilled(db).lmd(db) +func (self *IAVLNode) lmd(db Db) *IAVLNode { + if self.height == 0 { + return self + } + return self.leftFilled(db).lmd(db) } -func (self *IAVLNode) rmd(db Db) (*IAVLNode) { - if self.height == 0 { - return self - } - return self.rightFilled(db).rmd(db) +func (self *IAVLNode) rmd(db Db) *IAVLNode { + if self.height == 0 { + return self + } + return self.rightFilled(db).rmd(db) } -func (self *IAVLNode) traverse(db Db, cb func(Node)bool) bool { - stop := cb(self) - if stop { return stop } - if self.height > 0 { - stop = self.leftFilled(db).traverse(db, cb) - if stop { return stop } - stop = self.rightFilled(db).traverse(db, cb) - if stop { return stop } - } - return false +func (self *IAVLNode) traverse(db Db, cb func(Node) bool) bool { + stop := cb(self) + if stop { + return stop + } + if self.height > 0 { + stop = self.leftFilled(db).traverse(db, cb) + if stop { + return stop + } + stop = self.rightFilled(db).traverse(db, cb) + if stop { + return stop + } + } + return false } diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 7fdf5887..78dc0b2e 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -1,283 +1,283 @@ package merkle import ( - . "github.com/tendermint/tendermint/binary" - "testing" - "fmt" - "os" - "bytes" - "math/rand" - "encoding/binary" - "github.com/tendermint/tendermint/db" - "crypto/sha256" - "runtime" + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + . "github.com/tendermint/tendermint/binary" + "github.com/tendermint/tendermint/db" + "math/rand" + "os" + "runtime" + "testing" ) func init() { - if urandom, err := os.Open("/dev/urandom"); err != nil { - return - } else { - buf := make([]byte, 8) - if _, err := urandom.Read(buf); err == nil { - buf_reader := bytes.NewReader(buf) - if seed, err := binary.ReadVarint(buf_reader); err == nil { - rand.Seed(seed) - } - } - urandom.Close() - } + if urandom, err := os.Open("/dev/urandom"); err != nil { + return + } else { + buf := make([]byte, 8) + if _, err := urandom.Read(buf); err == nil { + buf_reader := bytes.NewReader(buf) + if seed, err := binary.ReadVarint(buf_reader); err == nil { + rand.Seed(seed) + } + } + urandom.Close() + } } func TestUnit(t *testing.T) { - // Convenience for a new node - N := func(l, r interface{}) *IAVLNode { - var left, right *IAVLNode - if _, ok := l.(*IAVLNode); ok { - left = l.(*IAVLNode) - } else { - left = NewIAVLNode(Int32(l.(int)), nil) - } - if _, ok := r.(*IAVLNode); ok { - right = r.(*IAVLNode) - } else { - right = NewIAVLNode(Int32(r.(int)), nil) - } + // Convenience for a new node + N := func(l, r interface{}) *IAVLNode { + var left, right *IAVLNode + if _, ok := l.(*IAVLNode); ok { + left = l.(*IAVLNode) + } else { + left = NewIAVLNode(Int32(l.(int)), nil) + } + if _, ok := r.(*IAVLNode); ok { + right = r.(*IAVLNode) + } else { + right = NewIAVLNode(Int32(r.(int)), nil) + } - n := &IAVLNode{ - key: right.lmd(nil).key, - left: left, - right: right, - } - n.calcHeightAndSize(nil) - n.Hash() - return n - } + n := &IAVLNode{ + key: right.lmd(nil).key, + left: left, + right: right, + } + n.calcHeightAndSize(nil) + n.Hash() + return n + } - // Convenience for simple printing of keys & tree structure - var P func(*IAVLNode) string - P = func(n *IAVLNode) string { - if n.height == 0 { - return fmt.Sprintf("%v", n.key) - } else { - return fmt.Sprintf("(%v %v)", P(n.left), P(n.right)) - } - } + // Convenience for simple printing of keys & tree structure + var P func(*IAVLNode) string + P = func(n *IAVLNode) string { + if n.height == 0 { + return fmt.Sprintf("%v", n.key) + } else { + return fmt.Sprintf("(%v %v)", P(n.left), P(n.right)) + } + } - expectHash := func(n2 *IAVLNode, hashCount uint64) { - // ensure number of new hash calculations is as expected. - hash, count := n2.Hash() - if count != hashCount { - t.Fatalf("Expected %v new hashes, got %v", hashCount, count) - } - // nuke hashes and reconstruct hash, ensure it's the same. - (&IAVLTree{root:n2}).Traverse(func(node Node) bool { - node.(*IAVLNode).hash = nil - return false - }) - // ensure that the new hash after nuking is the same as the old. - newHash, _ := n2.Hash() - if bytes.Compare(hash, newHash) != 0 { - t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) - } - } + expectHash := func(n2 *IAVLNode, hashCount uint64) { + // ensure number of new hash calculations is as expected. + hash, count := n2.Hash() + if count != hashCount { + t.Fatalf("Expected %v new hashes, got %v", hashCount, count) + } + // nuke hashes and reconstruct hash, ensure it's the same. + (&IAVLTree{root: n2}).Traverse(func(node Node) bool { + node.(*IAVLNode).hash = nil + return false + }) + // ensure that the new hash after nuking is the same as the old. + newHash, _ := n2.Hash() + if bytes.Compare(hash, newHash) != 0 { + t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) + } + } - expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, updated := n.put(nil, Int32(i), nil) - // ensure node was added & structure is as expected. - if updated == true || P(n2) != repr { - t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", - i, P(n), repr, P(n2), updated) - } - // ensure hash calculation requirements - expectHash(n2, hashCount) - } + expectPut := func(n *IAVLNode, i int, repr string, hashCount uint64) { + n2, updated := n.put(nil, Int32(i), nil) + // ensure node was added & structure is as expected. + if updated == true || P(n2) != repr { + t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", + i, P(n), repr, P(n2), updated) + } + // ensure hash calculation requirements + expectHash(n2, hashCount) + } - expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, _, value, err := n.remove(nil, Int32(i)) - // ensure node was added & structure is as expected. - if value != nil || err != nil || P(n2) != repr { - t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", - i, P(n), repr, P(n2), value, err) - } - // ensure hash calculation requirements - expectHash(n2, hashCount) - } + expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { + n2, _, value, err := n.remove(nil, Int32(i)) + // ensure node was added & structure is as expected. + if value != nil || err != nil || P(n2) != repr { + t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", + i, P(n), repr, P(n2), value, err) + } + // ensure hash calculation requirements + expectHash(n2, hashCount) + } - //////// Test Put cases: + //////// Test Put cases: - // Case 1: - n1 := N(4, 20) + // Case 1: + n1 := N(4, 20) - expectPut(n1, 8, "((4 8) 20)", 3) - expectPut(n1, 25, "(4 (20 25))", 3) + expectPut(n1, 8, "((4 8) 20)", 3) + expectPut(n1, 25, "(4 (20 25))", 3) - n2 := N(4, N(20, 25)) + n2 := N(4, N(20, 25)) - expectPut(n2, 8, "((4 8) (20 25))", 3) - expectPut(n2, 30, "((4 20) (25 30))", 4) + expectPut(n2, 8, "((4 8) (20 25))", 3) + expectPut(n2, 30, "((4 20) (25 30))", 4) - n3 := N(N(1, 2), 6) + n3 := N(N(1, 2), 6) - expectPut(n3, 4, "((1 2) (4 6))", 4) - expectPut(n3, 8, "((1 2) (6 8))", 3) + expectPut(n3, 4, "((1 2) (4 6))", 4) + expectPut(n3, 8, "((1 2) (6 8))", 3) - n4 := N(N(1, 2), N(N(5, 6), N(7, 9))) + n4 := N(N(1, 2), N(N(5, 6), N(7, 9))) - expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) - expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) + expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) + expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) - //////// Test Remove cases: + //////// Test Remove cases: - n10 := N(N(1, 2), 3) + n10 := N(N(1, 2), 3) - expectRemove(n10, 2, "(1 3)", 1) - expectRemove(n10, 3, "(1 2)", 0) + expectRemove(n10, 2, "(1 3)", 1) + expectRemove(n10, 3, "(1 2)", 0) - n11 := N(N(N(1, 2), 3), N(4, 5)) + n11 := N(N(N(1, 2), 3), N(4, 5)) - expectRemove(n11, 4, "((1 2) (3 5))", 2) - expectRemove(n11, 3, "((1 2) (4 5))", 1) + expectRemove(n11, 4, "((1 2) (3 5))", 2) + expectRemove(n11, 3, "((1 2) (4 5))", 1) } func TestIntegration(t *testing.T) { - type record struct { - key String - value String - } + type record struct { + key String + value String + } - records := make([]*record, 400) - var tree *IAVLTree = NewIAVLTree(nil) - var err error - var val Value - var updated bool + records := make([]*record, 400) + var tree *IAVLTree = NewIAVLTree(nil) + var err error + var val Value + var updated bool - randomRecord := func() *record { - return &record{ randstr(20), randstr(20) } - } + randomRecord := func() *record { + return &record{randstr(20), randstr(20)} + } - for i := range records { - r := randomRecord() - records[i] = r - //t.Log("New record", r) - //PrintIAVLNode(tree.root) - updated = tree.Put(r.key, String("")) - if updated { - t.Error("should have not been updated") - } - updated = tree.Put(r.key, r.value) - if !updated { - t.Error("should have been updated") - } - if tree.Size() != uint64(i+1) { - t.Error("size was wrong", tree.Size(), i+1) - } - } + for i := range records { + r := randomRecord() + records[i] = r + //t.Log("New record", r) + //PrintIAVLNode(tree.root) + updated = tree.Put(r.key, String("")) + if updated { + t.Error("should have not been updated") + } + updated = tree.Put(r.key, r.value) + if !updated { + t.Error("should have been updated") + } + if tree.Size() != uint64(i+1) { + t.Error("size was wrong", tree.Size(), i+1) + } + } - for _, r := range records { - if has := tree.Has(r.key); !has { - t.Error("Missing key", r.key) - } - if has := tree.Has(randstr(12)); has { - t.Error("Table has extra key") - } - if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { - t.Error("wrong value") - } - } + for _, r := range records { + if has := tree.Has(r.key); !has { + t.Error("Missing key", r.key) + } + if has := tree.Has(randstr(12)); has { + t.Error("Table has extra key") + } + if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { + t.Error("wrong value") + } + } - for i, x := range records { - if val, err = tree.Remove(x.key); err != nil { - t.Error(err) - } else if !(val.(String)).Equals(x.value) { - t.Error("wrong value") - } - for _, r := range records[i+1:] { - if has := tree.Has(r.key); !has { - t.Error("Missing key", r.key) - } - if has := tree.Has(randstr(12)); has { - t.Error("Table has extra key") - } - if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { - t.Error("wrong value") - } - } - if tree.Size() != uint64(len(records) - (i+1)) { - t.Error("size was wrong", tree.Size(), (len(records) - (i+1))) - } - } + for i, x := range records { + if val, err = tree.Remove(x.key); err != nil { + t.Error(err) + } else if !(val.(String)).Equals(x.value) { + t.Error("wrong value") + } + for _, r := range records[i+1:] { + if has := tree.Has(r.key); !has { + t.Error("Missing key", r.key) + } + if has := tree.Has(randstr(12)); has { + t.Error("Table has extra key") + } + if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { + t.Error("wrong value") + } + } + if tree.Size() != uint64(len(records)-(i+1)) { + t.Error("size was wrong", tree.Size(), (len(records) - (i + 1))) + } + } } func TestPersistence(t *testing.T) { - db := db.NewMemDB() + db := db.NewMemDB() - // Create some random key value pairs - records := make(map[String]String) - for i:=0; i<10000; i++ { - records[String(randstr(20))] = String(randstr(20)) - } + // Create some random key value pairs + records := make(map[String]String) + for i := 0; i < 10000; i++ { + records[String(randstr(20))] = String(randstr(20)) + } - // Construct some tree and save it - t1 := NewIAVLTree(db) - for key, value := range records { - t1.Put(key, value) - } - t1.Save() + // Construct some tree and save it + t1 := NewIAVLTree(db) + for key, value := range records { + t1.Put(key, value) + } + t1.Save() - hash, _ := t1.Hash() + hash, _ := t1.Hash() - // Load a tree - t2 := NewIAVLTreeFromHash(db, hash) - for key, value := range records { - t2value := t2.Get(key) - if !BinaryEqual(t2value, value) { - t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) - } - } + // Load a tree + t2 := NewIAVLTreeFromHash(db, hash) + for key, value := range records { + t2value := t2.Get(key) + if !BinaryEqual(t2value, value) { + t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) + } + } } func BenchmarkHash(b *testing.B) { - b.StopTimer() + b.StopTimer() - s := randstr(128) + s := randstr(128) - b.StartTimer() - for i := 0; i < b.N; i++ { - hasher := sha256.New() - hasher.Write([]byte(s)) - hasher.Sum(nil) - } + b.StartTimer() + for i := 0; i < b.N; i++ { + hasher := sha256.New() + hasher.Write([]byte(s)) + hasher.Sum(nil) + } } func BenchmarkImmutableAvlTree(b *testing.B) { - b.StopTimer() + b.StopTimer() - type record struct { - key String - value String - } + type record struct { + key String + value String + } - randomRecord := func() *record { - return &record{ randstr(32), randstr(32) } - } + randomRecord := func() *record { + return &record{randstr(32), randstr(32)} + } - t := NewIAVLTree(nil) - for i:=0; i<1000000; i++ { - r := randomRecord() - t.Put(r.key, r.value) - } + t := NewIAVLTree(nil) + for i := 0; i < 1000000; i++ { + r := randomRecord() + t.Put(r.key, r.value) + } - fmt.Println("ok, starting") + fmt.Println("ok, starting") - runtime.GC() + runtime.GC() - b.StartTimer() - for i := 0; i < b.N; i++ { - r := randomRecord() - t.Put(r.key, r.value) - t.Remove(r.key) - } + b.StartTimer() + for i := 0; i < b.N; i++ { + r := randomRecord() + t.Put(r.key, r.value) + t.Remove(r.key) + } } diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index a6c5493e..0e6fb79d 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -1,10 +1,10 @@ package merkle import ( - . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/binary" ) -const HASH_BYTE_SIZE int = 4+32 +const HASH_BYTE_SIZE int = 4 + 32 /* Immutable AVL Tree (wraps the Node root) @@ -13,100 +13,118 @@ This tree is not concurrency safe. You must wrap your calls with your own mutex. */ type IAVLTree struct { - db Db - root *IAVLNode + db Db + root *IAVLNode } func NewIAVLTree(db Db) *IAVLTree { - return &IAVLTree{db:db, root:nil} + return &IAVLTree{db: db, root: nil} } func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree { - root := &IAVLNode{ - hash: hash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - root.fill(db) - return &IAVLTree{db:db, root:root} + root := &IAVLNode{ + hash: hash, + flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, + } + root.fill(db) + return &IAVLTree{db: db, root: root} } func (t *IAVLTree) Root() Node { - return t.root + return t.root } func (t *IAVLTree) Size() uint64 { - if t.root == nil { return 0 } - return t.root.Size() + if t.root == nil { + return 0 + } + return t.root.Size() } func (t *IAVLTree) Height() uint8 { - if t.root == nil { return 0 } - return t.root.Height() + if t.root == nil { + return 0 + } + return t.root.Height() } func (t *IAVLTree) Has(key Key) bool { - if t.root == nil { return false } - return t.root.has(t.db, key) + if t.root == nil { + return false + } + return t.root.has(t.db, key) } func (t *IAVLTree) Put(key Key, value Value) (updated bool) { - if t.root == nil { - t.root = NewIAVLNode(key, value) - return false - } - t.root, updated = t.root.put(t.db, key, value) - return updated + if t.root == nil { + t.root = NewIAVLNode(key, value) + return false + } + t.root, updated = t.root.put(t.db, key, value) + return updated } func (t *IAVLTree) Hash() (ByteSlice, uint64) { - if t.root == nil { return nil, 0 } - return t.root.Hash() + if t.root == nil { + return nil, 0 + } + return t.root.Hash() } func (t *IAVLTree) Save() { - if t.root == nil { return } - if t.root.hash == nil { - t.root.Hash() - } - t.root.Save(t.db) + if t.root == nil { + return + } + if t.root.hash == nil { + t.root.Hash() + } + t.root.Save(t.db) } func (t *IAVLTree) Get(key Key) (value Value) { - if t.root == nil { return nil } - return t.root.get(t.db, key) + if t.root == nil { + return nil + } + return t.root.get(t.db, key) } func (t *IAVLTree) Remove(key Key) (value Value, err error) { - if t.root == nil { return nil, NotFound(key) } - newRoot, _, value, err := t.root.remove(t.db, key) - if err != nil { - return nil, err - } - t.root = newRoot - return value, nil + if t.root == nil { + return nil, NotFound(key) + } + newRoot, _, value, err := t.root.remove(t.db, key) + if err != nil { + return nil, err + } + t.root = newRoot + return value, nil } func (t *IAVLTree) Copy() Tree { - return &IAVLTree{db:t.db, root:t.root} + return &IAVLTree{db: t.db, root: t.root} } // Traverses all the nodes of the tree in prefix order. // return true from cb to halt iteration. // node.Height() == 0 if you just want a value node. func (t *IAVLTree) Traverse(cb func(Node) bool) { - if t.root == nil { return } - t.root.traverse(t.db, cb) + if t.root == nil { + return + } + t.root.traverse(t.db, cb) } func (t *IAVLTree) Values() <-chan Value { - root := t.root - ch := make(chan Value) - go func() { - root.traverse(t.db, func(n Node) bool { - if n.Height() == 0 { ch <- n.Value() } - return true - }) - close(ch) - }() - return ch + root := t.root + ch := make(chan Value) + go func() { + root.traverse(t.db, func(n Node) bool { + if n.Height() == 0 { + ch <- n.Value() + } + return true + }) + close(ch) + }() + return ch } diff --git a/merkle/types.go b/merkle/types.go index b919dd9a..d51cdf98 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -1,50 +1,50 @@ package merkle import ( - . "github.com/tendermint/tendermint/binary" - "fmt" + "fmt" + . "github.com/tendermint/tendermint/binary" ) type Value interface { - Binary + Binary } type Key interface { - Binary - Equals(Binary) bool - Less(b Binary) bool + Binary + Equals(Binary) bool + Less(b Binary) bool } type Db interface { - Get([]byte) []byte - Put([]byte, []byte) + Get([]byte) []byte + Put([]byte, []byte) } type Node interface { - Binary - Key() Key - Value() Value - Size() uint64 - Height() uint8 - Hash() (ByteSlice, uint64) - Save(Db) + Binary + Key() Key + Value() Value + Size() uint64 + Height() uint8 + Hash() (ByteSlice, uint64) + Save(Db) } type Tree interface { - Root() Node - Size() uint64 - Height() uint8 - Has(key Key) bool - Get(key Key) Value - Hash() (ByteSlice, uint64) - Save() - Put(Key, Value) bool - Remove(Key) (Value, error) - Copy() Tree - Traverse(func(Node)bool) - Values() <-chan Value + Root() Node + Size() uint64 + Height() uint8 + Has(key Key) bool + Get(key Key) Value + Hash() (ByteSlice, uint64) + Save() + Put(Key, Value) bool + Remove(Key) (Value, error) + Copy() Tree + Traverse(func(Node) bool) + Values() <-chan Value } func NotFound(key Key) error { - return fmt.Errorf("Key was not found.") + return fmt.Errorf("Key was not found.") } diff --git a/merkle/util.go b/merkle/util.go index 0b5b7ebc..740e8cec 100644 --- a/merkle/util.go +++ b/merkle/util.go @@ -1,78 +1,83 @@ package merkle import ( - . "github.com/tendermint/tendermint/binary" - "os" - "fmt" - "crypto/sha256" + "crypto/sha256" + "fmt" + . "github.com/tendermint/tendermint/binary" + "os" ) /* Compute a deterministic merkle hash from a list of byteslices. */ func HashFromBinarySlice(items []Binary) ByteSlice { - switch len(items) { - case 0: - panic("Cannot compute hash of empty slice") - case 1: - hasher := sha256.New() - _, err := items[0].WriteTo(hasher) - if err != nil { panic(err) } - return ByteSlice(hasher.Sum(nil)) - default: - hasher := sha256.New() - _, err := HashFromBinarySlice(items[0:len(items)/2]).WriteTo(hasher) - if err != nil { panic(err) } - _, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher) - if err != nil { panic(err) } - return ByteSlice(hasher.Sum(nil)) - } + switch len(items) { + case 0: + panic("Cannot compute hash of empty slice") + case 1: + hasher := sha256.New() + _, err := items[0].WriteTo(hasher) + if err != nil { + panic(err) + } + return ByteSlice(hasher.Sum(nil)) + default: + hasher := sha256.New() + _, err := HashFromBinarySlice(items[0 : len(items)/2]).WriteTo(hasher) + if err != nil { + panic(err) + } + _, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher) + if err != nil { + panic(err) + } + return ByteSlice(hasher.Sum(nil)) + } } func PrintIAVLNode(node *IAVLNode) { - fmt.Println("==== NODE") - if node != nil { - printIAVLNode(node, 0) - } - fmt.Println("==== END") + fmt.Println("==== NODE") + if node != nil { + printIAVLNode(node, 0) + } + fmt.Println("==== END") } func printIAVLNode(node *IAVLNode, indent int) { - indentPrefix := "" - for i:=0; i b { - return a - } - return b + if a > b { + return a + } + return b } - diff --git a/peer/addrbook.go b/peer/addrbook.go index d84712fa..f3b2db18 100644 --- a/peer/addrbook.go +++ b/peer/addrbook.go @@ -5,217 +5,236 @@ package peer import ( - . "github.com/tendermint/tendermint/binary" - crand "crypto/rand" // for seeding - "encoding/binary" - "encoding/json" - "io" - "math" - "math/rand" - "net" - "sync" - "sync/atomic" - "time" - "os" - "fmt" + crand "crypto/rand" // for seeding + "encoding/binary" + "encoding/json" + "fmt" + . "github.com/tendermint/tendermint/binary" + "io" + "math" + "math/rand" + "net" + "os" + "sync" + "sync/atomic" + "time" ) /* AddrBook - concurrency safe peer address manager */ type AddrBook struct { - filePath string + filePath string - mtx sync.Mutex - rand *rand.Rand - key [32]byte - addrIndex map[string]*KnownAddress // addr.String() -> KnownAddress - addrNew [newBucketCount]map[string]*KnownAddress - addrOld [oldBucketCount][]*KnownAddress - started int32 - shutdown int32 - wg sync.WaitGroup - quit chan struct{} - nOld int - nNew int + mtx sync.Mutex + rand *rand.Rand + key [32]byte + addrIndex map[string]*KnownAddress // addr.String() -> KnownAddress + addrNew [newBucketCount]map[string]*KnownAddress + addrOld [oldBucketCount][]*KnownAddress + started int32 + shutdown int32 + wg sync.WaitGroup + quit chan struct{} + nOld int + nNew int } const ( - // addresses under which the address manager will claim to need more addresses. - needAddressThreshold = 1000 + // addresses under which the address manager will claim to need more addresses. + needAddressThreshold = 1000 - // interval used to dump the address cache to disk for future use. - dumpAddressInterval = time.Minute * 2 + // interval used to dump the address cache to disk for future use. + dumpAddressInterval = time.Minute * 2 - // max addresses in each old address bucket. - oldBucketSize = 64 + // max addresses in each old address bucket. + oldBucketSize = 64 - // buckets we split old addresses over. - oldBucketCount = 64 + // buckets we split old addresses over. + oldBucketCount = 64 - // max addresses in each new address bucket. - newBucketSize = 64 + // max addresses in each new address bucket. + newBucketSize = 64 - // buckets that we spread new addresses over. - newBucketCount = 256 + // buckets that we spread new addresses over. + newBucketCount = 256 - // old buckets over which an address group will be spread. - oldBucketsPerGroup = 4 + // old buckets over which an address group will be spread. + oldBucketsPerGroup = 4 - // new buckets over which an source address group will be spread. - newBucketsPerGroup = 32 + // new buckets over which an source address group will be spread. + newBucketsPerGroup = 32 - // buckets a frequently seen new address may end up in. - newBucketsPerAddress = 4 + // buckets a frequently seen new address may end up in. + newBucketsPerAddress = 4 - // days before which we assume an address has vanished - // if we have not seen it announced in that long. - numMissingDays = 30 + // days before which we assume an address has vanished + // if we have not seen it announced in that long. + numMissingDays = 30 - // tries without a single success before we assume an address is bad. - numRetries = 3 + // tries without a single success before we assume an address is bad. + numRetries = 3 - // max failures we will accept without a success before considering an address bad. - maxFailures = 10 + // max failures we will accept without a success before considering an address bad. + maxFailures = 10 - // days since the last success before we will consider evicting an address. - minBadDays = 7 + // days since the last success before we will consider evicting an address. + minBadDays = 7 - // max addresses that we will send in response to a getAddr - // (in practise the most addresses we will return from a call to AddressCache()). - getAddrMax = 2500 + // max addresses that we will send in response to a getAddr + // (in practise the most addresses we will return from a call to AddressCache()). + getAddrMax = 2500 - // % of total addresses known that we will share with a call to AddressCache. - getAddrPercent = 23 + // % of total addresses known that we will share with a call to AddressCache. + getAddrPercent = 23 - // current version of the on-disk format. - serialisationVersion = 1 + // current version of the on-disk format. + serialisationVersion = 1 ) // Use Start to begin processing asynchronous address updates. func NewAddrBook(filePath string) *AddrBook { - am := AddrBook{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - quit: make(chan struct{}), - filePath: filePath, - } - am.init() - return &am + am := AddrBook{ + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + quit: make(chan struct{}), + filePath: filePath, + } + am.init() + return &am } // When modifying this, don't forget to update loadFromFile() func (a *AddrBook) init() { - a.addrIndex = make(map[string]*KnownAddress) - io.ReadFull(crand.Reader, a.key[:]) - for i := range a.addrNew { - a.addrNew[i] = make(map[string]*KnownAddress) - } - for i := range a.addrOld { - a.addrOld[i] = make([]*KnownAddress, 0, oldBucketSize) - } + a.addrIndex = make(map[string]*KnownAddress) + io.ReadFull(crand.Reader, a.key[:]) + for i := range a.addrNew { + a.addrNew[i] = make(map[string]*KnownAddress) + } + for i := range a.addrOld { + a.addrOld[i] = make([]*KnownAddress, 0, oldBucketSize) + } } func (a *AddrBook) Start() { - if atomic.AddInt32(&a.started, 1) != 1 { return } - log.Trace("Starting address manager") - a.loadFromFile(a.filePath) - a.wg.Add(1) - go a.addressHandler() + if atomic.AddInt32(&a.started, 1) != 1 { + return + } + log.Trace("Starting address manager") + a.loadFromFile(a.filePath) + a.wg.Add(1) + go a.addressHandler() } func (a *AddrBook) Stop() { - if atomic.AddInt32(&a.shutdown, 1) != 1 { return } - log.Infof("Address manager shutting down") - close(a.quit) - a.wg.Wait() + if atomic.AddInt32(&a.shutdown, 1) != 1 { + return + } + log.Infof("Address manager shutting down") + close(a.quit) + a.wg.Wait() } func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) { - a.mtx.Lock(); defer a.mtx.Unlock() - a.addAddress(addr, src) + a.mtx.Lock() + defer a.mtx.Unlock() + a.addAddress(addr, src) } func (a *AddrBook) NeedMoreAddresses() bool { - return a.NumAddresses() < needAddressThreshold + return a.NumAddresses() < needAddressThreshold } func (a *AddrBook) NumAddresses() int { - a.mtx.Lock(); defer a.mtx.Unlock() - return a.nOld + a.nNew + a.mtx.Lock() + defer a.mtx.Unlock() + return a.nOld + a.nNew } // Pick a new address to connect to. func (a *AddrBook) PickAddress(class string, newBias int) *KnownAddress { - a.mtx.Lock(); defer a.mtx.Unlock() + a.mtx.Lock() + defer a.mtx.Unlock() - if a.nOld == 0 && a.nNew == 0 { return nil } - if newBias > 100 { newBias = 100 } - if newBias < 0 { newBias = 0 } + if a.nOld == 0 && a.nNew == 0 { + return nil + } + if newBias > 100 { + newBias = 100 + } + if newBias < 0 { + newBias = 0 + } - // Bias between new and old addresses. - oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias)) - newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias) + // Bias between new and old addresses. + oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias)) + newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias) - if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation { - // pick random Old bucket. - var bucket []*KnownAddress = nil - for len(bucket) == 0 { - bucket = a.addrOld[a.rand.Intn(len(a.addrOld))] - } - // pick a random ka from bucket. - return bucket[a.rand.Intn(len(bucket))] - } else { - // pick random New bucket. - var bucket map[string]*KnownAddress = nil - for len(bucket) == 0 { - bucket = a.addrNew[a.rand.Intn(len(a.addrNew))] - } - // pick a random ka from bucket. - randIndex := a.rand.Intn(len(bucket)) - for _, ka := range bucket { - randIndex-- - if randIndex == 0 { - return ka - } - } - panic("Should not happen") - } - return nil + if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation { + // pick random Old bucket. + var bucket []*KnownAddress = nil + for len(bucket) == 0 { + bucket = a.addrOld[a.rand.Intn(len(a.addrOld))] + } + // pick a random ka from bucket. + return bucket[a.rand.Intn(len(bucket))] + } else { + // pick random New bucket. + var bucket map[string]*KnownAddress = nil + for len(bucket) == 0 { + bucket = a.addrNew[a.rand.Intn(len(a.addrNew))] + } + // pick a random ka from bucket. + randIndex := a.rand.Intn(len(bucket)) + for _, ka := range bucket { + randIndex-- + if randIndex == 0 { + return ka + } + } + panic("Should not happen") + } + return nil } func (a *AddrBook) MarkGood(addr *NetAddress) { - a.mtx.Lock(); defer a.mtx.Unlock() - ka := a.addrIndex[addr.String()] - if ka == nil { return } - ka.MarkAttempt(true) - if ka.OldBucket == -1 { - a.moveToOld(ka) - } + a.mtx.Lock() + defer a.mtx.Unlock() + ka := a.addrIndex[addr.String()] + if ka == nil { + return + } + ka.MarkAttempt(true) + if ka.OldBucket == -1 { + a.moveToOld(ka) + } } func (a *AddrBook) MarkAttempt(addr *NetAddress) { - a.mtx.Lock(); defer a.mtx.Unlock() - ka := a.addrIndex[addr.String()] - if ka == nil { return } - ka.MarkAttempt(false) + a.mtx.Lock() + defer a.mtx.Unlock() + ka := a.addrIndex[addr.String()] + if ka == nil { + return + } + ka.MarkAttempt(false) } /* Loading & Saving */ type addrBookJSON struct { - Key [32]byte - AddrNew [newBucketCount]map[string]*KnownAddress - AddrOld [oldBucketCount][]*KnownAddress - NOld int - NNew int + Key [32]byte + AddrNew [newBucketCount]map[string]*KnownAddress + AddrOld [oldBucketCount][]*KnownAddress + NOld int + NNew int } func (a *AddrBook) saveToFile(filePath string) { - aJSON := &addrBookJSON{ - Key: a.key, - AddrNew: a.addrNew, - AddrOld: a.addrOld, - NOld: a.nOld, - NNew: a.nNew, - } + aJSON := &addrBookJSON{ + Key: a.key, + AddrNew: a.addrNew, + AddrOld: a.addrOld, + NOld: a.nOld, + NNew: a.nNew, + } w, err := os.Create(filePath) if err != nil { @@ -225,296 +244,306 @@ func (a *AddrBook) saveToFile(filePath string) { enc := json.NewEncoder(w) defer w.Close() err = enc.Encode(&aJSON) - if err != nil { panic(err) } + if err != nil { + panic(err) + } } func (a *AddrBook) loadFromFile(filePath string) { - // If doesn't exist, do nothing. + // If doesn't exist, do nothing. _, err := os.Stat(filePath) - if os.IsNotExist(err) { return } + if os.IsNotExist(err) { + return + } - // Load addrBookJSON{} + // Load addrBookJSON{} r, err := os.Open(filePath) if err != nil { - panic(fmt.Errorf("%s error opening file: %v", filePath, err)) + panic(fmt.Errorf("%s error opening file: %v", filePath, err)) } defer r.Close() - aJSON := &addrBookJSON{} + aJSON := &addrBookJSON{} dec := json.NewDecoder(r) err = dec.Decode(aJSON) if err != nil { panic(fmt.Errorf("error reading %s: %v", filePath, err)) } - // Now we need to initialize self. + // Now we need to initialize self. - copy(a.key[:], aJSON.Key[:]) - a.addrNew = aJSON.AddrNew - for i, oldBucket := range aJSON.AddrOld { - copy(a.addrOld[i], oldBucket) - } - a.nNew = aJSON.NNew - a.nOld = aJSON.NOld + copy(a.key[:], aJSON.Key[:]) + a.addrNew = aJSON.AddrNew + for i, oldBucket := range aJSON.AddrOld { + copy(a.addrOld[i], oldBucket) + } + a.nNew = aJSON.NNew + a.nOld = aJSON.NOld - a.addrIndex = make(map[string]*KnownAddress) - for _, newBucket := range a.addrNew { - for key, ka := range newBucket { - a.addrIndex[key] = ka - } - } + a.addrIndex = make(map[string]*KnownAddress) + for _, newBucket := range a.addrNew { + for key, ka := range newBucket { + a.addrIndex[key] = ka + } + } } - /* Private methods */ func (a *AddrBook) addressHandler() { - dumpAddressTicker := time.NewTicker(dumpAddressInterval) + dumpAddressTicker := time.NewTicker(dumpAddressInterval) out: - for { - select { - case <-dumpAddressTicker.C: - a.saveToFile(a.filePath) - case <-a.quit: - break out - } - } - dumpAddressTicker.Stop() - a.saveToFile(a.filePath) - a.wg.Done() - log.Trace("Address handler done") + for { + select { + case <-dumpAddressTicker.C: + a.saveToFile(a.filePath) + case <-a.quit: + break out + } + } + dumpAddressTicker.Stop() + a.saveToFile(a.filePath) + a.wg.Done() + log.Trace("Address handler done") } func (a *AddrBook) addAddress(addr, src *NetAddress) { - if !addr.Routable() { return } + if !addr.Routable() { + return + } - key := addr.String() - ka := a.addrIndex[key] + key := addr.String() + ka := a.addrIndex[key] - if ka != nil { - // Already added - if ka.OldBucket != -1 { return } - if ka.NewRefs == newBucketsPerAddress { return } + if ka != nil { + // Already added + if ka.OldBucket != -1 { + return + } + if ka.NewRefs == newBucketsPerAddress { + return + } - // The more entries we have, the less likely we are to add more. - factor := int32(2 * ka.NewRefs) - if a.rand.Int31n(factor) != 0 { - return - } - } else { - ka = NewKnownAddress(addr, src) - a.addrIndex[key] = ka - a.nNew++ - } + // The more entries we have, the less likely we are to add more. + factor := int32(2 * ka.NewRefs) + if a.rand.Int31n(factor) != 0 { + return + } + } else { + ka = NewKnownAddress(addr, src) + a.addrIndex[key] = ka + a.nNew++ + } - bucket := a.getNewBucket(addr, src) + bucket := a.getNewBucket(addr, src) - // Already exists? - if _, ok := a.addrNew[bucket][key]; ok { - return - } + // Already exists? + if _, ok := a.addrNew[bucket][key]; ok { + return + } - // Enforce max addresses. - if len(a.addrNew[bucket]) > newBucketSize { - log.Tracef("new bucket is full, expiring old ") - a.expireNew(bucket) - } + // Enforce max addresses. + if len(a.addrNew[bucket]) > newBucketSize { + log.Tracef("new bucket is full, expiring old ") + a.expireNew(bucket) + } - // Add to new bucket. - ka.NewRefs++ - a.addrNew[bucket][key] = ka + // Add to new bucket. + ka.NewRefs++ + a.addrNew[bucket][key] = ka - log.Tracef("Added new address %s for a total of %d addresses", addr, a.nOld+a.nNew) + log.Tracef("Added new address %s for a total of %d addresses", addr, a.nOld+a.nNew) } // Make space in the new buckets by expiring the really bad entries. // If no bad entries are available we look at a few and remove the oldest. func (a *AddrBook) expireNew(bucket int) { - var oldest *KnownAddress - for k, v := range a.addrNew[bucket] { - // If an entry is bad, throw it away - if v.Bad() { - log.Tracef("expiring bad address %v", k) - delete(a.addrNew[bucket], k) - v.NewRefs-- - if v.NewRefs == 0 { - a.nNew-- - delete(a.addrIndex, k) - } - return - } - // or, keep track of the oldest entry - if oldest == nil { - oldest = v - } else if v.LastAttempt.Before(oldest.LastAttempt.Time) { - oldest = v - } - } + var oldest *KnownAddress + for k, v := range a.addrNew[bucket] { + // If an entry is bad, throw it away + if v.Bad() { + log.Tracef("expiring bad address %v", k) + delete(a.addrNew[bucket], k) + v.NewRefs-- + if v.NewRefs == 0 { + a.nNew-- + delete(a.addrIndex, k) + } + return + } + // or, keep track of the oldest entry + if oldest == nil { + oldest = v + } else if v.LastAttempt.Before(oldest.LastAttempt.Time) { + oldest = v + } + } - // If we haven't thrown out a bad entry, throw out the oldest entry - if oldest != nil { - key := oldest.Addr.String() - log.Tracef("expiring oldest address %v", key) - delete(a.addrNew[bucket], key) - oldest.NewRefs-- - if oldest.NewRefs == 0 { - a.nNew-- - delete(a.addrIndex, key) - } - } + // If we haven't thrown out a bad entry, throw out the oldest entry + if oldest != nil { + key := oldest.Addr.String() + log.Tracef("expiring oldest address %v", key) + delete(a.addrNew[bucket], key) + oldest.NewRefs-- + if oldest.NewRefs == 0 { + a.nNew-- + delete(a.addrIndex, key) + } + } } func (a *AddrBook) moveToOld(ka *KnownAddress) { - // Remove from all new buckets. - // Remember one of those new buckets. - addrKey := ka.Addr.String() - freedBucket := -1 - for i := range a.addrNew { - // we check for existance so we can record the first one - if _, ok := a.addrNew[i][addrKey]; ok { - delete(a.addrNew[i], addrKey) - ka.NewRefs-- - if freedBucket == -1 { - freedBucket = i - } - } - } - a.nNew-- - if freedBucket == -1 { panic("Expected to find addr in at least one new bucket") } + // Remove from all new buckets. + // Remember one of those new buckets. + addrKey := ka.Addr.String() + freedBucket := -1 + for i := range a.addrNew { + // we check for existance so we can record the first one + if _, ok := a.addrNew[i][addrKey]; ok { + delete(a.addrNew[i], addrKey) + ka.NewRefs-- + if freedBucket == -1 { + freedBucket = i + } + } + } + a.nNew-- + if freedBucket == -1 { + panic("Expected to find addr in at least one new bucket") + } - oldBucket := a.getOldBucket(ka.Addr) + oldBucket := a.getOldBucket(ka.Addr) - // If room in oldBucket, put it in. - if len(a.addrOld[oldBucket]) < oldBucketSize { - ka.OldBucket = Int16(oldBucket) - a.addrOld[oldBucket] = append(a.addrOld[oldBucket], ka) - a.nOld++ - return - } + // If room in oldBucket, put it in. + if len(a.addrOld[oldBucket]) < oldBucketSize { + ka.OldBucket = Int16(oldBucket) + a.addrOld[oldBucket] = append(a.addrOld[oldBucket], ka) + a.nOld++ + return + } - // No room, we have to evict something else. - rmkaIndex := a.pickOld(oldBucket) - rmka := a.addrOld[oldBucket][rmkaIndex] + // No room, we have to evict something else. + rmkaIndex := a.pickOld(oldBucket) + rmka := a.addrOld[oldBucket][rmkaIndex] - // Find a new bucket to put rmka in. - newBucket := a.getNewBucket(rmka.Addr, rmka.Src) - if len(a.addrNew[newBucket]) >= newBucketSize { - newBucket = freedBucket - } + // Find a new bucket to put rmka in. + newBucket := a.getNewBucket(rmka.Addr, rmka.Src) + if len(a.addrNew[newBucket]) >= newBucketSize { + newBucket = freedBucket + } - // replace with ka in list. - ka.OldBucket = Int16(oldBucket) - a.addrOld[oldBucket][rmkaIndex] = ka - rmka.OldBucket = -1 + // replace with ka in list. + ka.OldBucket = Int16(oldBucket) + a.addrOld[oldBucket][rmkaIndex] = ka + rmka.OldBucket = -1 - // put rmka into new bucket - rmkey := rmka.Addr.String() - log.Tracef("Replacing %s with %s in old", rmkey, addrKey) - a.addrNew[newBucket][rmkey] = rmka - rmka.NewRefs++ - a.nNew++ + // put rmka into new bucket + rmkey := rmka.Addr.String() + log.Tracef("Replacing %s with %s in old", rmkey, addrKey) + a.addrNew[newBucket][rmkey] = rmka + rmka.NewRefs++ + a.nNew++ } // Returns the index in old bucket of oldest entry. func (a *AddrBook) pickOld(bucket int) int { - var oldest *KnownAddress - var oldestIndex int - for i, ka := range a.addrOld[bucket] { - if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt.Time) { - oldest = ka - oldestIndex = i - } - } - return oldestIndex + var oldest *KnownAddress + var oldestIndex int + for i, ka := range a.addrOld[bucket] { + if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt.Time) { + oldest = ka + oldestIndex = i + } + } + return oldestIndex } // doublesha256(key + sourcegroup + // int64(doublesha256(key + group + sourcegroup))%bucket_per_source_group) % num_new_buckes func (a *AddrBook) getNewBucket(addr, src *NetAddress) int { - data1 := []byte{} - data1 = append(data1, a.key[:]...) - data1 = append(data1, []byte(GroupKey(addr))...) - data1 = append(data1, []byte(GroupKey(src))...) - hash1 := DoubleSha256(data1) - hash64 := binary.LittleEndian.Uint64(hash1) - hash64 %= newBucketsPerGroup - var hashbuf [8]byte - binary.LittleEndian.PutUint64(hashbuf[:], hash64) - data2 := []byte{} - data2 = append(data2, a.key[:]...) - data2 = append(data2, GroupKey(src)...) - data2 = append(data2, hashbuf[:]...) + data1 := []byte{} + data1 = append(data1, a.key[:]...) + data1 = append(data1, []byte(GroupKey(addr))...) + data1 = append(data1, []byte(GroupKey(src))...) + hash1 := DoubleSha256(data1) + hash64 := binary.LittleEndian.Uint64(hash1) + hash64 %= newBucketsPerGroup + var hashbuf [8]byte + binary.LittleEndian.PutUint64(hashbuf[:], hash64) + data2 := []byte{} + data2 = append(data2, a.key[:]...) + data2 = append(data2, GroupKey(src)...) + data2 = append(data2, hashbuf[:]...) - hash2 := DoubleSha256(data2) - return int(binary.LittleEndian.Uint64(hash2) % newBucketCount) + hash2 := DoubleSha256(data2) + return int(binary.LittleEndian.Uint64(hash2) % newBucketCount) } // doublesha256(key + group + truncate_to_64bits(doublesha256(key + addr))%buckets_per_group) % num_buckets func (a *AddrBook) getOldBucket(addr *NetAddress) int { - data1 := []byte{} - data1 = append(data1, a.key[:]...) - data1 = append(data1, []byte(addr.String())...) - hash1 := DoubleSha256(data1) - hash64 := binary.LittleEndian.Uint64(hash1) - hash64 %= oldBucketsPerGroup - var hashbuf [8]byte - binary.LittleEndian.PutUint64(hashbuf[:], hash64) - data2 := []byte{} - data2 = append(data2, a.key[:]...) - data2 = append(data2, GroupKey(addr)...) - data2 = append(data2, hashbuf[:]...) + data1 := []byte{} + data1 = append(data1, a.key[:]...) + data1 = append(data1, []byte(addr.String())...) + hash1 := DoubleSha256(data1) + hash64 := binary.LittleEndian.Uint64(hash1) + hash64 %= oldBucketsPerGroup + var hashbuf [8]byte + binary.LittleEndian.PutUint64(hashbuf[:], hash64) + data2 := []byte{} + data2 = append(data2, a.key[:]...) + data2 = append(data2, GroupKey(addr)...) + data2 = append(data2, hashbuf[:]...) - hash2 := DoubleSha256(data2) - return int(binary.LittleEndian.Uint64(hash2) % oldBucketCount) + hash2 := DoubleSha256(data2) + return int(binary.LittleEndian.Uint64(hash2) % oldBucketCount) } - // Return a string representing the network group of this address. // This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string // "local" for a local address and the string "unroutable for an unroutable // address. -func GroupKey (na *NetAddress) string { - if na.Local() { - return "local" - } - if !na.Routable() { - return "unroutable" - } +func GroupKey(na *NetAddress) string { + if na.Local() { + return "local" + } + if !na.Routable() { + return "unroutable" + } - if ipv4 := na.IP.To4(); ipv4 != nil { - return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String() - } - if na.RFC6145() || na.RFC6052() { - // last four bytes are the ip address - ip := net.IP(na.IP[12:16]) - return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() - } + if ipv4 := na.IP.To4(); ipv4 != nil { + return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String() + } + if na.RFC6145() || na.RFC6052() { + // last four bytes are the ip address + ip := net.IP(na.IP[12:16]) + return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() + } - if na.RFC3964() { - ip := net.IP(na.IP[2:7]) - return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() + if na.RFC3964() { + ip := net.IP(na.IP[2:7]) + return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() - } - if na.RFC4380() { - // teredo tunnels have the last 4 bytes as the v4 address XOR - // 0xff. - ip := net.IP(make([]byte, 4)) - for i, byte := range na.IP[12:16] { - ip[i] = byte ^ 0xff - } - return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() - } + } + if na.RFC4380() { + // teredo tunnels have the last 4 bytes as the v4 address XOR + // 0xff. + ip := net.IP(make([]byte, 4)) + for i, byte := range na.IP[12:16] { + ip[i] = byte ^ 0xff + } + return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() + } - // OK, so now we know ourselves to be a IPv6 address. - // bitcoind uses /32 for everything, except for Hurricane Electric's - // (he.net) IP range, which it uses /36 for. - bits := 32 - heNet := &net.IPNet{IP: net.ParseIP("2001:470::"), - Mask: net.CIDRMask(32, 128)} - if heNet.Contains(na.IP) { - bits = 36 - } + // OK, so now we know ourselves to be a IPv6 address. + // bitcoind uses /32 for everything, except for Hurricane Electric's + // (he.net) IP range, which it uses /36 for. + bits := 32 + heNet := &net.IPNet{IP: net.ParseIP("2001:470::"), + Mask: net.CIDRMask(32, 128)} + if heNet.Contains(na.IP) { + bits = 36 + } - return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String() + return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String() } diff --git a/peer/client.go b/peer/client.go index 94390456..27b3c711 100644 --- a/peer/client.go +++ b/peer/client.go @@ -1,12 +1,12 @@ package peer import ( - . "github.com/tendermint/tendermint/common" - . "github.com/tendermint/tendermint/binary" - "github.com/tendermint/tendermint/merkle" - "sync/atomic" - "sync" - "errors" + "errors" + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" + "github.com/tendermint/tendermint/merkle" + "sync" + "sync/atomic" ) /* Client @@ -21,147 +21,161 @@ import ( XXX what about peer disconnects? */ type Client struct { - addrBook *AddrBook - targetNumPeers int - makePeerFn func(*Connection) *Peer - self *Peer - recvQueues map[String]chan *InboundPacket + addrBook *AddrBook + targetNumPeers int + makePeerFn func(*Connection) *Peer + self *Peer + recvQueues map[String]chan *InboundPacket - mtx sync.Mutex - peers merkle.Tree // addr -> *Peer - quit chan struct{} - stopped uint32 + mtx sync.Mutex + peers merkle.Tree // addr -> *Peer + quit chan struct{} + stopped uint32 } var ( - CLIENT_STOPPED_ERROR = errors.New("Client already stopped") - CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer") + CLIENT_STOPPED_ERROR = errors.New("Client already stopped") + CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer") ) func NewClient(makePeerFn func(*Connection) *Peer) *Client { - self := makePeerFn(nil) - if self == nil { - Panicf("makePeerFn(nil) must return a prototypical peer for self") - } + self := makePeerFn(nil) + if self == nil { + Panicf("makePeerFn(nil) must return a prototypical peer for self") + } - recvQueues := make(map[String]chan *InboundPacket) - for chName, _ := range self.channels { - recvQueues[chName] = make(chan *InboundPacket) - } + recvQueues := make(map[String]chan *InboundPacket) + for chName, _ := range self.channels { + recvQueues[chName] = make(chan *InboundPacket) + } - c := &Client{ - addrBook: nil, // TODO - targetNumPeers: 0, // TODO - makePeerFn: makePeerFn, - self: self, - recvQueues: recvQueues, + c := &Client{ + addrBook: nil, // TODO + targetNumPeers: 0, // TODO + makePeerFn: makePeerFn, + self: self, + recvQueues: recvQueues, - peers: merkle.NewIAVLTree(nil), - quit: make(chan struct{}), - stopped: 0, - } - return c + peers: merkle.NewIAVLTree(nil), + quit: make(chan struct{}), + stopped: 0, + } + return c } func (c *Client) Stop() { - log.Infof("Stopping client") - // lock - c.mtx.Lock() - if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { - close(c.quit) - // stop each peer. - for peerValue := range c.peers.Values() { - peer := peerValue.(*Peer) - peer.Stop() - } - // empty tree. - c.peers = merkle.NewIAVLTree(nil) - } - c.mtx.Unlock() - // unlock + log.Infof("Stopping client") + // lock + c.mtx.Lock() + if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { + close(c.quit) + // stop each peer. + for peerValue := range c.peers.Values() { + peer := peerValue.(*Peer) + peer.Stop() + } + // empty tree. + c.peers = merkle.NewIAVLTree(nil) + } + c.mtx.Unlock() + // unlock } func (c *Client) AddPeerWithConnection(conn *Connection, outgoing bool) (*Peer, error) { - if atomic.LoadUint32(&c.stopped) == 1 { return nil, CLIENT_STOPPED_ERROR } + if atomic.LoadUint32(&c.stopped) == 1 { + return nil, CLIENT_STOPPED_ERROR + } - log.Infof("Adding peer with connection: %v, outgoing: %v", conn, outgoing) - peer := c.makePeerFn(conn) - peer.outgoing = outgoing - err := c.addPeer(peer) - if err != nil { return nil, err } + log.Infof("Adding peer with connection: %v, outgoing: %v", conn, outgoing) + peer := c.makePeerFn(conn) + peer.outgoing = outgoing + err := c.addPeer(peer) + if err != nil { + return nil, err + } - go peer.Start(c.recvQueues) + go peer.Start(c.recvQueues) - return peer, nil + return peer, nil } func (c *Client) Broadcast(pkt Packet) { - if atomic.LoadUint32(&c.stopped) == 1 { return } + if atomic.LoadUint32(&c.stopped) == 1 { + return + } - log.Tracef("Broadcast on [%v] len: %v", pkt.Channel, len(pkt.Bytes)) - for v := range c.Peers().Values() { - peer := v.(*Peer) - success := peer.TrySend(pkt) - log.Tracef("Broadcast for peer %v success: %v", peer, success) - if !success { - // TODO: notify the peer - } - } + log.Tracef("Broadcast on [%v] len: %v", pkt.Channel, len(pkt.Bytes)) + for v := range c.Peers().Values() { + peer := v.(*Peer) + success := peer.TrySend(pkt) + log.Tracef("Broadcast for peer %v success: %v", peer, success) + if !success { + // TODO: notify the peer + } + } } // blocks until a message is popped. func (c *Client) Receive(chName String) *InboundPacket { - if atomic.LoadUint32(&c.stopped) == 1 { return nil } + if atomic.LoadUint32(&c.stopped) == 1 { + return nil + } - log.Tracef("Receive on [%v]", chName) - q := c.recvQueues[chName] - if q == nil { Panicf("Expected recvQueues[%f], found none", chName) } + log.Tracef("Receive on [%v]", chName) + q := c.recvQueues[chName] + if q == nil { + Panicf("Expected recvQueues[%f], found none", chName) + } - for { - select { - case <-c.quit: - return nil - case inPacket := <-q: - return inPacket - } - } + for { + select { + case <-c.quit: + return nil + case inPacket := <-q: + return inPacket + } + } } func (c *Client) Peers() merkle.Tree { - // lock & defer - c.mtx.Lock(); defer c.mtx.Unlock() - return c.peers.Copy() - // unlock deferred + // lock & defer + c.mtx.Lock() + defer c.mtx.Unlock() + return c.peers.Copy() + // unlock deferred } func (c *Client) StopPeer(peer *Peer) { - // lock - c.mtx.Lock() - peerValue, _ := c.peers.Remove(peer.RemoteAddress()) - c.mtx.Unlock() - // unlock + // lock + c.mtx.Lock() + peerValue, _ := c.peers.Remove(peer.RemoteAddress()) + c.mtx.Unlock() + // unlock - peer_ := peerValue.(*Peer) - if peer_ != nil { - peer_.Stop() - } + peer_ := peerValue.(*Peer) + if peer_ != nil { + peer_.Stop() + } } func (c *Client) addPeer(peer *Peer) error { - addr := peer.RemoteAddress() + addr := peer.RemoteAddress() - // lock & defer - c.mtx.Lock(); defer c.mtx.Unlock() - if c.stopped == 1 { return CLIENT_STOPPED_ERROR } - if !c.peers.Has(addr) { - log.Tracef("Actually putting addr: %v, peer: %v", addr, peer) - c.peers.Put(addr, peer) - return nil - } else { - // ignore duplicate peer for addr. - log.Infof("Ignoring duplicate peer for addr %v", addr) - return CLIENT_DUPLICATE_PEER_ERROR - } - // unlock deferred + // lock & defer + c.mtx.Lock() + defer c.mtx.Unlock() + if c.stopped == 1 { + return CLIENT_STOPPED_ERROR + } + if !c.peers.Has(addr) { + log.Tracef("Actually putting addr: %v, peer: %v", addr, peer) + c.peers.Put(addr, peer) + return nil + } else { + // ignore duplicate peer for addr. + log.Infof("Ignoring duplicate peer for addr %v", addr) + return CLIENT_DUPLICATE_PEER_ERROR + } + // unlock deferred } diff --git a/peer/client_test.go b/peer/client_test.go index 746bbda7..d7c87620 100644 --- a/peer/client_test.go +++ b/peer/client_test.go @@ -1,106 +1,105 @@ package peer import ( - . "github.com/tendermint/tendermint/binary" - "testing" - "time" + . "github.com/tendermint/tendermint/binary" + "testing" + "time" ) // convenience method for creating two clients connected to each other. func makeClientPair(t *testing.T, bufferSize int, channels []string) (*Client, *Client) { - peerMaker := func(conn *Connection) *Peer { - p := NewPeer(conn) - p.channels = map[String]*Channel{} - for chName := range channels { - p.channels[String(chName)] = NewChannel(String(chName), bufferSize) - } - return p - } + peerMaker := func(conn *Connection) *Peer { + p := NewPeer(conn) + p.channels = map[String]*Channel{} + for chName := range channels { + p.channels[String(chName)] = NewChannel(String(chName), bufferSize) + } + return p + } - // Create two clients that will be interconnected. - c1 := NewClient(peerMaker) - c2 := NewClient(peerMaker) + // Create two clients that will be interconnected. + c1 := NewClient(peerMaker) + c2 := NewClient(peerMaker) - // Create a server for the listening client. - s1 := NewServer("tcp", ":8001", c1) + // Create a server for the listening client. + s1 := NewServer("tcp", ":8001", c1) - // Dial the server & add the connection to c2. - s1laddr := s1.LocalAddress() - conn, err := s1laddr.Dial() - if err != nil { - t.Fatalf("Could not connect to server address %v", s1laddr) - } else { - t.Logf("Created a connection to local server address %v", s1laddr) - } + // Dial the server & add the connection to c2. + s1laddr := s1.LocalAddress() + conn, err := s1laddr.Dial() + if err != nil { + t.Fatalf("Could not connect to server address %v", s1laddr) + } else { + t.Logf("Created a connection to local server address %v", s1laddr) + } - c2.AddPeerWithConnection(conn, true) + c2.AddPeerWithConnection(conn, true) - // Wait for things to happen, peers to get added... - time.Sleep(100 * time.Millisecond) + // Wait for things to happen, peers to get added... + time.Sleep(100 * time.Millisecond) - return c1, c2 + return c1, c2 } func TestClients(t *testing.T) { - c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"}) + c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"}) - // Lets send a message from c1 to c2. - if c1.Peers().Size() != 1 { - t.Errorf("Expected exactly 1 peer in c1, got %v", c1.Peers().Size()) - } - if c2.Peers().Size() != 1 { - t.Errorf("Expected exactly 1 peer in c2, got %v", c2.Peers().Size()) - } + // Lets send a message from c1 to c2. + if c1.Peers().Size() != 1 { + t.Errorf("Expected exactly 1 peer in c1, got %v", c1.Peers().Size()) + } + if c2.Peers().Size() != 1 { + t.Errorf("Expected exactly 1 peer in c2, got %v", c2.Peers().Size()) + } - // Broadcast a message on ch1 - c1.Broadcast(NewPacket("ch1", ByteSlice("channel one"))) - // Broadcast a message on ch2 - c1.Broadcast(NewPacket("ch2", ByteSlice("channel two"))) - // Broadcast a message on ch3 - c1.Broadcast(NewPacket("ch3", ByteSlice("channel three"))) + // Broadcast a message on ch1 + c1.Broadcast(NewPacket("ch1", ByteSlice("channel one"))) + // Broadcast a message on ch2 + c1.Broadcast(NewPacket("ch2", ByteSlice("channel two"))) + // Broadcast a message on ch3 + c1.Broadcast(NewPacket("ch3", ByteSlice("channel three"))) - // Wait for things to settle... - time.Sleep(100 * time.Millisecond) + // Wait for things to settle... + time.Sleep(100 * time.Millisecond) - // Receive message from channel 2 and check - inMsg := c2.Receive("ch2") - if string(inMsg.Bytes) != "channel two" { - t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes)) - } + // Receive message from channel 2 and check + inMsg := c2.Receive("ch2") + if string(inMsg.Bytes) != "channel two" { + t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes)) + } - // Receive message from channel 1 and check - inMsg = c2.Receive("ch1") - if string(inMsg.Bytes) != "channel one" { - t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes)) - } + // Receive message from channel 1 and check + inMsg = c2.Receive("ch1") + if string(inMsg.Bytes) != "channel one" { + t.Errorf("Unexpected received message bytes: %v", string(inMsg.Bytes)) + } - s1.Stop() - c2.Stop() + s1.Stop() + c2.Stop() } - func BenchmarkClients(b *testing.B) { - b.StopTimer() + b.StopTimer() - // TODO: benchmark the random functions, which is faster? + // TODO: benchmark the random functions, which is faster? - c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"}) + c1, c2 := makeClientPair(t, 10, []string{"ch1", "ch2", "ch3"}) - // Create a sink on either channel to just pop off messages. - // TODO: ensure that when clients stop, this goroutine stops. - func recvHandler(c *Client) { - } + // Create a sink on either channel to just pop off messages. + // TODO: ensure that when clients stop, this goroutine stops. + recvHandler := func(c *Client) { + } - go recvHandler(c1) - go recvHandler(c2) + go recvHandler(c1) + go recvHandler(c2) - b.StartTimer() + b.StartTimer() - // Send random message from one channel to another - for i := 0; i < b.N; i++ { - } + // Send random message from one channel to another + for i := 0; i < b.N; i++ { + } } diff --git a/peer/connection.go b/peer/connection.go index 91ed5a74..16af6a64 100644 --- a/peer/connection.go +++ b/peer/connection.go @@ -1,191 +1,192 @@ package peer import ( - . "github.com/tendermint/tendermint/common" - . "github.com/tendermint/tendermint/binary" - "sync/atomic" - "net" - "time" - "fmt" + "fmt" + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" + "net" + "sync/atomic" + "time" ) const ( - OUT_QUEUE_SIZE = 50 - IDLE_TIMEOUT_MINUTES = 5 - PING_TIMEOUT_MINUTES = 2 + OUT_QUEUE_SIZE = 50 + IDLE_TIMEOUT_MINUTES = 5 + PING_TIMEOUT_MINUTES = 2 ) /* Connnection */ type Connection struct { - ioStats IOStats + ioStats IOStats - sendQueue chan Packet // never closes - conn net.Conn - quit chan struct{} - stopped uint32 - pingDebouncer *Debouncer - pong chan struct{} + sendQueue chan Packet // never closes + conn net.Conn + quit chan struct{} + stopped uint32 + pingDebouncer *Debouncer + pong chan struct{} } var ( - PACKET_TYPE_PING = UInt8(0x00) - PACKET_TYPE_PONG = UInt8(0x01) - PACKET_TYPE_MSG = UInt8(0x10) + PACKET_TYPE_PING = UInt8(0x00) + PACKET_TYPE_PONG = UInt8(0x01) + PACKET_TYPE_MSG = UInt8(0x10) ) func NewConnection(conn net.Conn) *Connection { - return &Connection{ - sendQueue: make(chan Packet, OUT_QUEUE_SIZE), - conn: conn, - quit: make(chan struct{}), - pingDebouncer: NewDebouncer(PING_TIMEOUT_MINUTES * time.Minute), - pong: make(chan struct{}), - } + return &Connection{ + sendQueue: make(chan Packet, OUT_QUEUE_SIZE), + conn: conn, + quit: make(chan struct{}), + pingDebouncer: NewDebouncer(PING_TIMEOUT_MINUTES * time.Minute), + pong: make(chan struct{}), + } } // returns true if successfully queued, // returns false if connection was closed. // blocks. func (c *Connection) Send(pkt Packet) bool { - select { - case c.sendQueue <- pkt: - return true - case <-c.quit: - return false - } + select { + case c.sendQueue <- pkt: + return true + case <-c.quit: + return false + } } func (c *Connection) Start(channels map[String]*Channel) { - log.Debugf("Starting %v", c) - go c.sendHandler() - go c.recvHandler(channels) + log.Debugf("Starting %v", c) + go c.sendHandler() + go c.recvHandler(channels) } func (c *Connection) Stop() { - if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { - log.Debugf("Stopping %v", c) - close(c.quit) - c.conn.Close() - c.pingDebouncer.Stop() - // We can't close pong safely here because - // recvHandler may write to it after we've stopped. - // Though it doesn't need to get closed at all, - // we close it @ recvHandler. - // close(c.pong) - } + if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { + log.Debugf("Stopping %v", c) + close(c.quit) + c.conn.Close() + c.pingDebouncer.Stop() + // We can't close pong safely here because + // recvHandler may write to it after we've stopped. + // Though it doesn't need to get closed at all, + // we close it @ recvHandler. + // close(c.pong) + } } func (c *Connection) LocalAddress() *NetAddress { - return NewNetAddress(c.conn.LocalAddr()) + return NewNetAddress(c.conn.LocalAddr()) } func (c *Connection) RemoteAddress() *NetAddress { - return NewNetAddress(c.conn.RemoteAddr()) + return NewNetAddress(c.conn.RemoteAddr()) } func (c *Connection) String() string { - return fmt.Sprintf("Connection{%v}", c.conn.RemoteAddr()) + return fmt.Sprintf("Connection{%v}", c.conn.RemoteAddr()) } func (c *Connection) flush() { - // TODO flush? (turn off nagel, turn back on, etc) + // TODO flush? (turn off nagel, turn back on, etc) } func (c *Connection) sendHandler() { - log.Tracef("%v sendHandler", c) + log.Tracef("%v sendHandler", c) - // TODO: catch panics & stop connection. + // TODO: catch panics & stop connection. - FOR_LOOP: - for { - var err error - select { - case <-c.pingDebouncer.Ch: - _, err = PACKET_TYPE_PING.WriteTo(c.conn) - case sendPkt := <-c.sendQueue: - log.Tracef("Found pkt from sendQueue. Writing pkt to underlying connection") - _, err = PACKET_TYPE_MSG.WriteTo(c.conn) - if err != nil { break } - _, err = sendPkt.WriteTo(c.conn) - case <-c.pong: - _, err = PACKET_TYPE_PONG.WriteTo(c.conn) - case <-c.quit: - break FOR_LOOP - } +FOR_LOOP: + for { + var err error + select { + case <-c.pingDebouncer.Ch: + _, err = PACKET_TYPE_PING.WriteTo(c.conn) + case sendPkt := <-c.sendQueue: + log.Tracef("Found pkt from sendQueue. Writing pkt to underlying connection") + _, err = PACKET_TYPE_MSG.WriteTo(c.conn) + if err != nil { + break + } + _, err = sendPkt.WriteTo(c.conn) + case <-c.pong: + _, err = PACKET_TYPE_PONG.WriteTo(c.conn) + case <-c.quit: + break FOR_LOOP + } - if err != nil { - log.Infof("%v failed @ sendHandler:\n%v", c, err) - c.Stop() - break FOR_LOOP - } + if err != nil { + log.Infof("%v failed @ sendHandler:\n%v", c, err) + c.Stop() + break FOR_LOOP + } - c.flush() - } + c.flush() + } - log.Tracef("%v sendHandler done", c) - // cleanup + log.Tracef("%v sendHandler done", c) + // cleanup } func (c *Connection) recvHandler(channels map[String]*Channel) { - log.Tracef("%v recvHandler with %v channels", c, len(channels)) + log.Tracef("%v recvHandler with %v channels", c, len(channels)) - // TODO: catch panics & stop connection. + // TODO: catch panics & stop connection. - FOR_LOOP: - for { - pktType, err := ReadUInt8Safe(c.conn) - if err != nil { - if atomic.LoadUint32(&c.stopped) != 1 { - log.Infof("%v failed @ recvHandler", c) - c.Stop() - } - break FOR_LOOP - } else { - log.Tracef("Found pktType %v", pktType) - } +FOR_LOOP: + for { + pktType, err := ReadUInt8Safe(c.conn) + if err != nil { + if atomic.LoadUint32(&c.stopped) != 1 { + log.Infof("%v failed @ recvHandler", c) + c.Stop() + } + break FOR_LOOP + } else { + log.Tracef("Found pktType %v", pktType) + } - switch pktType { - case PACKET_TYPE_PING: - c.pong <- struct{}{} - case PACKET_TYPE_PONG: - // do nothing - case PACKET_TYPE_MSG: - pkt, err := ReadPacketSafe(c.conn) - if err != nil { - if atomic.LoadUint32(&c.stopped) != 1 { - log.Infof("%v failed @ recvHandler", c) - c.Stop() - } - break FOR_LOOP - } - channel := channels[pkt.Channel] - if channel == nil { - Panicf("Unknown channel %v", pkt.Channel) - } - channel.recvQueue <- pkt - default: - Panicf("Unknown message type %v", pktType) - } + switch pktType { + case PACKET_TYPE_PING: + c.pong <- struct{}{} + case PACKET_TYPE_PONG: + // do nothing + case PACKET_TYPE_MSG: + pkt, err := ReadPacketSafe(c.conn) + if err != nil { + if atomic.LoadUint32(&c.stopped) != 1 { + log.Infof("%v failed @ recvHandler", c) + c.Stop() + } + break FOR_LOOP + } + channel := channels[pkt.Channel] + if channel == nil { + Panicf("Unknown channel %v", pkt.Channel) + } + channel.recvQueue <- pkt + default: + Panicf("Unknown message type %v", pktType) + } - c.pingDebouncer.Reset() - } + c.pingDebouncer.Reset() + } - log.Tracef("%v recvHandler done", c) - // cleanup - close(c.pong) - for _ = range c.pong { - // drain - } + log.Tracef("%v recvHandler done", c) + // cleanup + close(c.pong) + for _ = range c.pong { + // drain + } } - /* IOStats */ type IOStats struct { - TimeConnected Time - LastSent Time - LastRecv Time - BytesRecv UInt64 - BytesSent UInt64 - PktsRecv UInt64 - PktsSent UInt64 + TimeConnected Time + LastSent Time + LastRecv Time + BytesRecv UInt64 + BytesSent UInt64 + PktsRecv UInt64 + PktsSent UInt64 } diff --git a/peer/knownaddress.go b/peer/knownaddress.go index 3baed60f..c18e9401 100644 --- a/peer/knownaddress.go +++ b/peer/knownaddress.go @@ -1,104 +1,104 @@ package peer import ( - . "github.com/tendermint/tendermint/binary" - "time" - "io" + . "github.com/tendermint/tendermint/binary" + "io" + "time" ) /* - KnownAddress + KnownAddress - tracks information about a known network address that is used - to determine how viable an address is. + tracks information about a known network address that is used + to determine how viable an address is. */ type KnownAddress struct { - Addr *NetAddress - Src *NetAddress - Attempts UInt32 - LastAttempt Time - LastSuccess Time - NewRefs UInt16 - OldBucket Int16 // TODO init to -1 + Addr *NetAddress + Src *NetAddress + Attempts UInt32 + LastAttempt Time + LastSuccess Time + NewRefs UInt16 + OldBucket Int16 // TODO init to -1 } func NewKnownAddress(addr *NetAddress, src *NetAddress) *KnownAddress { - return &KnownAddress{ - Addr: addr, - Src: src, - OldBucket: -1, - LastAttempt: Time{time.Now()}, - Attempts: 0, - } + return &KnownAddress{ + Addr: addr, + Src: src, + OldBucket: -1, + LastAttempt: Time{time.Now()}, + Attempts: 0, + } } func ReadKnownAddress(r io.Reader) *KnownAddress { - return &KnownAddress{ - Addr: ReadNetAddress(r), - Src: ReadNetAddress(r), - Attempts: ReadUInt32(r), - LastAttempt: ReadTime(r), - LastSuccess: ReadTime(r), - NewRefs: ReadUInt16(r), - OldBucket: ReadInt16(r), - } + return &KnownAddress{ + Addr: ReadNetAddress(r), + Src: ReadNetAddress(r), + Attempts: ReadUInt32(r), + LastAttempt: ReadTime(r), + LastSuccess: ReadTime(r), + NewRefs: ReadUInt16(r), + OldBucket: ReadInt16(r), + } } func (ka *KnownAddress) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(ka.Addr, w, n, err) - n, err = WriteOnto(ka.Src, w, n, err) - n, err = WriteOnto(ka.Attempts, w, n, err) - n, err = WriteOnto(ka.LastAttempt, w, n, err) - n, err = WriteOnto(ka.LastSuccess, w, n, err) - n, err = WriteOnto(ka.NewRefs, w, n, err) - n, err = WriteOnto(ka.OldBucket, w, n, err) - return + n, err = WriteOnto(ka.Addr, w, n, err) + n, err = WriteOnto(ka.Src, w, n, err) + n, err = WriteOnto(ka.Attempts, w, n, err) + n, err = WriteOnto(ka.LastAttempt, w, n, err) + n, err = WriteOnto(ka.LastSuccess, w, n, err) + n, err = WriteOnto(ka.NewRefs, w, n, err) + n, err = WriteOnto(ka.OldBucket, w, n, err) + return } func (ka *KnownAddress) MarkAttempt(success bool) { - now := Time{time.Now()} - ka.LastAttempt = now - if success { - ka.LastSuccess = now - ka.Attempts = 0 - } else { - ka.Attempts += 1 - } + now := Time{time.Now()} + ka.LastAttempt = now + if success { + ka.LastSuccess = now + ka.Attempts = 0 + } else { + ka.Attempts += 1 + } } /* - An address is bad if the address in question has not been tried in the last - minute and meets one of the following criteria: + An address is bad if the address in question has not been tried in the last + minute and meets one of the following criteria: - 1) It claims to be from the future - 2) It hasn't been seen in over a month - 3) It has failed at least three times and never succeeded - 4) It has failed ten times in the last week + 1) It claims to be from the future + 2) It hasn't been seen in over a month + 3) It has failed at least three times and never succeeded + 4) It has failed ten times in the last week - All addresses that meet these criteria are assumed to be worthless and not - worth keeping hold of. + All addresses that meet these criteria are assumed to be worthless and not + worth keeping hold of. */ func (ka *KnownAddress) Bad() bool { - // Has been attempted in the last minute --> good - if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) { - return false - } + // Has been attempted in the last minute --> good + if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) { + return false + } - // Over a month old? - if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) { - return true - } + // Over a month old? + if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) { + return true + } - // Never succeeded? - if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries { - return true - } + // Never succeeded? + if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries { + return true + } - // Hasn't succeeded in too long? - if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) && - ka.Attempts >= maxFailures { - return true - } + // Hasn't succeeded in too long? + if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) && + ka.Attempts >= maxFailures { + return true + } - return false + return false } diff --git a/peer/listener.go b/peer/listener.go index 3e4f7ab5..1db13702 100644 --- a/peer/listener.go +++ b/peer/listener.go @@ -1,127 +1,145 @@ package peer import ( - . "github.com/tendermint/tendermint/common" - "sync/atomic" - "net" + . "github.com/tendermint/tendermint/common" + "net" + "sync/atomic" ) const ( - DEFAULT_PORT = 8001 + DEFAULT_PORT = 8001 ) /* Listener */ type Listener interface { - Connections() <-chan *Connection - LocalAddress() *NetAddress - Stop() + Connections() <-chan *Connection + LocalAddress() *NetAddress + Stop() } - /* DefaultListener */ type DefaultListener struct { - listener net.Listener - connections chan *Connection - stopped uint32 + listener net.Listener + connections chan *Connection + stopped uint32 } const ( - DEFAULT_BUFFERED_CONNECTIONS = 10 + DEFAULT_BUFFERED_CONNECTIONS = 10 ) func NewDefaultListener(protocol string, listenAddr string) Listener { - listener, err := net.Listen(protocol, listenAddr) - if err != nil { panic(err) } + listener, err := net.Listen(protocol, listenAddr) + if err != nil { + panic(err) + } - dl := &DefaultListener{ - listener: listener, - connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS), - } + dl := &DefaultListener{ + listener: listener, + connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS), + } - go dl.listenHandler() + go dl.listenHandler() - return dl + return dl } func (l *DefaultListener) listenHandler() { - for { - conn, err := l.listener.Accept() + for { + conn, err := l.listener.Accept() - if atomic.LoadUint32(&l.stopped) == 1 { return } + if atomic.LoadUint32(&l.stopped) == 1 { + return + } - // listener wasn't stopped, - // yet we encountered an error. - if err != nil { panic(err) } + // listener wasn't stopped, + // yet we encountered an error. + if err != nil { + panic(err) + } - c := NewConnection(conn) - l.connections <- c - } + c := NewConnection(conn) + l.connections <- c + } - // cleanup - close(l.connections) - for _ = range l.connections { - // drain - } + // cleanup + close(l.connections) + for _ = range l.connections { + // drain + } } func (l *DefaultListener) Connections() <-chan *Connection { - return l.connections + return l.connections } func (l *DefaultListener) LocalAddress() *NetAddress { - return GetLocalAddress() + return GetLocalAddress() } func (l *DefaultListener) Stop() { - if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) { - l.listener.Close() - } + if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) { + l.listener.Close() + } } - /* local address helpers */ func GetLocalAddress() *NetAddress { - laddr := GetUPNPLocalAddress() - if laddr != nil { return laddr } + laddr := GetUPNPLocalAddress() + if laddr != nil { + return laddr + } - laddr = GetDefaultLocalAddress() - if laddr != nil { return laddr } + laddr = GetDefaultLocalAddress() + if laddr != nil { + return laddr + } - panic("Could not determine local address") + panic("Could not determine local address") } // UPNP external address discovery & port mapping // TODO: more flexible internal & external ports func GetUPNPLocalAddress() *NetAddress { - nat, err := Discover() - if err != nil { return nil } + nat, err := Discover() + if err != nil { + return nil + } - ext, err := nat.GetExternalAddress() - if err != nil { return nil } + ext, err := nat.GetExternalAddress() + if err != nil { + return nil + } - _, err = nat.AddPortMapping("tcp", DEFAULT_PORT, DEFAULT_PORT, "tendermint", 0) - if err != nil { return nil } + _, err = nat.AddPortMapping("tcp", DEFAULT_PORT, DEFAULT_PORT, "tendermint", 0) + if err != nil { + return nil + } - return NewNetAddressIPPort(ext, DEFAULT_PORT) + return NewNetAddressIPPort(ext, DEFAULT_PORT) } // Naive local IPv4 interface address detection // TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh func GetDefaultLocalAddress() *NetAddress { - addrs, err := net.InterfaceAddrs() - if err != nil { Panicf("Unexpected error fetching interface addresses: %v", err) } + addrs, err := net.InterfaceAddrs() + if err != nil { + Panicf("Unexpected error fetching interface addresses: %v", err) + } - for _, a := range addrs { - ipnet, ok := a.(*net.IPNet) - if !ok { continue } - v4 := ipnet.IP.To4() - if v4 == nil || v4[0] == 127 { continue } // loopback - return NewNetAddressIPPort(ipnet.IP, DEFAULT_PORT) - } - return nil + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + v4 := ipnet.IP.To4() + if v4 == nil || v4[0] == 127 { + continue + } // loopback + return NewNetAddressIPPort(ipnet.IP, DEFAULT_PORT) + } + return nil } - - diff --git a/peer/log.go b/peer/log.go index 4fbb1e8e..c9a3702e 100644 --- a/peer/log.go +++ b/peer/log.go @@ -1,14 +1,14 @@ package peer import ( - "github.com/cihub/seelog" + "github.com/cihub/seelog" ) var log seelog.LoggerInterface func init() { - // TODO: replace with configuration file in the ~/.tendermint directory. - config := ` + // TODO: replace with configuration file in the ~/.tendermint directory. + config := ` @@ -19,7 +19,9 @@ func init() { ` - var err error - log, err = seelog.LoggerFromConfigAsBytes([]byte(config)) - if err != nil { panic(err) } + var err error + log, err = seelog.LoggerFromConfigAsBytes([]byte(config)) + if err != nil { + panic(err) + } } diff --git a/peer/msg.go b/peer/msg.go index 396612c5..9f07ef73 100644 --- a/peer/msg.go +++ b/peer/msg.go @@ -1,59 +1,61 @@ package peer import ( - . "github.com/tendermint/tendermint/binary" - "io" + . "github.com/tendermint/tendermint/binary" + "io" ) /* Packet */ type Packet struct { - Channel String - Bytes ByteSlice - // Hash + Channel String + Bytes ByteSlice + // Hash } func NewPacket(chName String, bytes ByteSlice) Packet { - return Packet{ - Channel: chName, - Bytes: bytes, - } + return Packet{ + Channel: chName, + Bytes: bytes, + } } func (p Packet) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(&p.Channel, w, n, err) - n, err = WriteOnto(&p.Bytes, w, n, err) - return + n, err = WriteOnto(&p.Channel, w, n, err) + n, err = WriteOnto(&p.Bytes, w, n, err) + return } func ReadPacketSafe(r io.Reader) (pkt Packet, err error) { - chName, err := ReadStringSafe(r) - if err != nil { return } - // TODO: packet length sanity check. - bytes, err := ReadByteSliceSafe(r) - if err != nil { return } - return NewPacket(chName, bytes), nil + chName, err := ReadStringSafe(r) + if err != nil { + return + } + // TODO: packet length sanity check. + bytes, err := ReadByteSliceSafe(r) + if err != nil { + return + } + return NewPacket(chName, bytes), nil } - /* InboundPacket */ type InboundPacket struct { - Peer *Peer - Channel *Channel - Time Time - Packet + Peer *Peer + Channel *Channel + Time Time + Packet } - /* NewFilterMsg */ type NewFilterMsg struct { - ChName String - Filter interface{} // todo + ChName String + Filter interface{} // todo } func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) { - panic("TODO: implement") - return 0, nil // TODO + panic("TODO: implement") + return 0, nil // TODO } diff --git a/peer/netaddress.go b/peer/netaddress.go index a7d6bc14..df12c262 100644 --- a/peer/netaddress.go +++ b/peer/netaddress.go @@ -5,150 +5,158 @@ package peer import ( - . "github.com/tendermint/tendermint/common" - . "github.com/tendermint/tendermint/binary" - "io" - "net" - "strconv" + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/common" + "io" + "net" + "strconv" ) /* NetAddress */ type NetAddress struct { - IP net.IP - Port UInt16 + IP net.IP + Port UInt16 } // TODO: socks proxies? func NewNetAddress(addr net.Addr) *NetAddress { - tcpAddr, ok := addr.(*net.TCPAddr) - if !ok { Panicf("Only TCPAddrs are supported. Got: %v", addr) } - ip := tcpAddr.IP - port := UInt16(tcpAddr.Port) - return NewNetAddressIPPort(ip, port) + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + Panicf("Only TCPAddrs are supported. Got: %v", addr) + } + ip := tcpAddr.IP + port := UInt16(tcpAddr.Port) + return NewNetAddressIPPort(ip, port) } func NewNetAddressString(addr string) *NetAddress { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { panic(err) } - ip := net.ParseIP(host) - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { panic(err) } - na := NewNetAddressIPPort(ip, UInt16(port)) - return na + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + panic(err) + } + ip := net.ParseIP(host) + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + panic(err) + } + na := NewNetAddressIPPort(ip, UInt16(port)) + return na } func NewNetAddressIPPort(ip net.IP, port UInt16) *NetAddress { - na := NetAddress{ - IP: ip, - Port: port, - } - return &na + na := NetAddress{ + IP: ip, + Port: port, + } + return &na } func ReadNetAddress(r io.Reader) *NetAddress { - return &NetAddress{ - IP: net.IP(ReadByteSlice(r)), - Port: ReadUInt16(r), - } + return &NetAddress{ + IP: net.IP(ReadByteSlice(r)), + Port: ReadUInt16(r), + } } func (na *NetAddress) WriteTo(w io.Writer) (n int64, err error) { - n, err = WriteOnto(ByteSlice(na.IP.To16()), w, n, err) - n, err = WriteOnto(na.Port, w, n, err) - return + n, err = WriteOnto(ByteSlice(na.IP.To16()), w, n, err) + n, err = WriteOnto(na.Port, w, n, err) + return } func (na *NetAddress) Equals(other Binary) bool { - if o, ok := other.(*NetAddress); ok { - return na.String() == o.String() - } else { - return false - } + if o, ok := other.(*NetAddress); ok { + return na.String() == o.String() + } else { + return false + } } func (na *NetAddress) Less(other Binary) bool { - if o, ok := other.(*NetAddress); ok { - return na.String() < o.String() - } else { - panic("Cannot compare unequal types") - } + if o, ok := other.(*NetAddress); ok { + return na.String() < o.String() + } else { + panic("Cannot compare unequal types") + } } func (na *NetAddress) String() string { - port := strconv.FormatUint(uint64(na.Port), 10) - addr := net.JoinHostPort(na.IP.String(), port) - return addr + port := strconv.FormatUint(uint64(na.Port), 10) + addr := net.JoinHostPort(na.IP.String(), port) + return addr } func (na *NetAddress) Dial() (*Connection, error) { - conn, err := net.Dial("tcp", na.String()) - if err != nil { return nil, err } - return NewConnection(conn), nil + conn, err := net.Dial("tcp", na.String()) + if err != nil { + return nil, err + } + return NewConnection(conn), nil } func (na *NetAddress) Routable() bool { - // TODO(oga) bitcoind doesn't include RFC3849 here, but should we? - return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() || - na.RFC4193() || na.RFC4843() || na.Local()) + // TODO(oga) bitcoind doesn't include RFC3849 here, but should we? + return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() || + na.RFC4193() || na.RFC4843() || na.Local()) } // For IPv4 these are either a 0 or all bits set address. For IPv6 a zero // address or one that matches the RFC3849 documentation address format. func (na *NetAddress) Valid() bool { - return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() || - na.IP.Equal(net.IPv4bcast)) + return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() || + na.IP.Equal(net.IPv4bcast)) } func (na *NetAddress) Local() bool { - return na.IP.IsLoopback() || zero4.Contains(na.IP) + return na.IP.IsLoopback() || zero4.Contains(na.IP) } func (na *NetAddress) ReachabilityTo(o *NetAddress) int { - const ( - Unreachable = 0 - Default = iota - Teredo - Ipv6_weak - Ipv4 - Ipv6_strong - Private - ) - if !na.Routable() { - return Unreachable - } else if na.RFC4380() { - if !o.Routable() { - return Default - } else if o.RFC4380() { - return Teredo - } else if o.IP.To4() != nil { - return Ipv4 - } else { // ipv6 - return Ipv6_weak - } - } else if na.IP.To4() != nil { - if o.Routable() && o.IP.To4() != nil { - return Ipv4 - } - return Default - } else /* ipv6 */ { - var tunnelled bool - // Is our v6 is tunnelled? - if o.RFC3964() || o.RFC6052() || o.RFC6145() { - tunnelled = true - } - if !o.Routable() { - return Default - } else if o.RFC4380() { - return Teredo - } else if o.IP.To4() != nil { - return Ipv4 - } else if tunnelled { - // only prioritise ipv6 if we aren't tunnelling it. - return Ipv6_weak - } - return Ipv6_strong - } + const ( + Unreachable = 0 + Default = iota + Teredo + Ipv6_weak + Ipv4 + Ipv6_strong + Private + ) + if !na.Routable() { + return Unreachable + } else if na.RFC4380() { + if !o.Routable() { + return Default + } else if o.RFC4380() { + return Teredo + } else if o.IP.To4() != nil { + return Ipv4 + } else { // ipv6 + return Ipv6_weak + } + } else if na.IP.To4() != nil { + if o.Routable() && o.IP.To4() != nil { + return Ipv4 + } + return Default + } else /* ipv6 */ { + var tunnelled bool + // Is our v6 is tunnelled? + if o.RFC3964() || o.RFC6052() || o.RFC6145() { + tunnelled = true + } + if !o.Routable() { + return Default + } else if o.RFC4380() { + return Teredo + } else if o.IP.To4() != nil { + return Ipv4 + } else if tunnelled { + // only prioritise ipv6 if we aren't tunnelling it. + return Ipv6_weak + } + return Ipv6_strong + } } // RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12) @@ -161,23 +169,25 @@ func (na *NetAddress) ReachabilityTo(o *NetAddress) int { // RFC4862: IPv6 Autoconfig (FE80::/64) // RFC6052: IPv6 well known prefix (64:FF9B::/96) // RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96 -var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)} -var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)} -var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)} -var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)} -var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)} -var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)} -var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)} -var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)} -var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)} -var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)} -var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)} -var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)} -var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)} +var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)} +var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)} +var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)} +var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)} +var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)} +var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)} +var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)} +var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)} +var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)} +var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)} +var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)} +var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)} +var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)} -func (na *NetAddress) RFC1918() bool { return rfc1918_10.Contains(na.IP) || - rfc1918_192.Contains(na.IP) || - rfc1918_172.Contains(na.IP) } +func (na *NetAddress) RFC1918() bool { + return rfc1918_10.Contains(na.IP) || + rfc1918_192.Contains(na.IP) || + rfc1918_172.Contains(na.IP) +} func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) } func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) } func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) } diff --git a/peer/peer.go b/peer/peer.go index c9e9a16d..beea28bf 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -1,172 +1,174 @@ package peer import ( - . "github.com/tendermint/tendermint/binary" - "sync/atomic" - "sync" - "io" - "time" - "fmt" + "fmt" + . "github.com/tendermint/tendermint/binary" + "io" + "sync" + "sync/atomic" + "time" ) /* Peer */ type Peer struct { - outgoing bool - conn *Connection - channels map[String]*Channel + outgoing bool + conn *Connection + channels map[String]*Channel - mtx sync.Mutex - quit chan struct{} - stopped uint32 + mtx sync.Mutex + quit chan struct{} + stopped uint32 } func NewPeer(conn *Connection) *Peer { - return &Peer{ - conn: conn, - quit: make(chan struct{}), - stopped: 0, - } + return &Peer{ + conn: conn, + quit: make(chan struct{}), + stopped: 0, + } } -func (p *Peer) Start(peerRecvQueues map[String]chan *InboundPacket ) { - log.Debugf("Starting %v", p) - p.conn.Start(p.channels) - for chName, _ := range p.channels { - go p.recvHandler(chName, peerRecvQueues[chName]) - go p.sendHandler(chName) - } +func (p *Peer) Start(peerRecvQueues map[String]chan *InboundPacket) { + log.Debugf("Starting %v", p) + p.conn.Start(p.channels) + for chName, _ := range p.channels { + go p.recvHandler(chName, peerRecvQueues[chName]) + go p.sendHandler(chName) + } } func (p *Peer) Stop() { - // lock - p.mtx.Lock() - if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) { - log.Debugf("Stopping %v", p) - close(p.quit) - p.conn.Stop() - } - p.mtx.Unlock() - // unlock + // lock + p.mtx.Lock() + if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) { + log.Debugf("Stopping %v", p) + close(p.quit) + p.conn.Stop() + } + p.mtx.Unlock() + // unlock } func (p *Peer) LocalAddress() *NetAddress { - return p.conn.LocalAddress() + return p.conn.LocalAddress() } func (p *Peer) RemoteAddress() *NetAddress { - return p.conn.RemoteAddress() + return p.conn.RemoteAddress() } func (p *Peer) Channel(chName String) *Channel { - return p.channels[chName] + return p.channels[chName] } // If the channel's queue is full, just return false. // Later the sendHandler will send the pkt to the underlying connection. func (p *Peer) TrySend(pkt Packet) bool { - channel := p.Channel(pkt.Channel) - sendQueue := channel.SendQueue() + channel := p.Channel(pkt.Channel) + sendQueue := channel.SendQueue() - // lock & defer - p.mtx.Lock(); defer p.mtx.Unlock() - if p.stopped == 1 { return false } - select { - case sendQueue <- pkt: - return true - default: // buffer full - return false - } - // unlock deferred + // lock & defer + p.mtx.Lock() + defer p.mtx.Unlock() + if p.stopped == 1 { + return false + } + select { + case sendQueue <- pkt: + return true + default: // buffer full + return false + } + // unlock deferred } func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { - return p.RemoteAddress().WriteTo(w) + return p.RemoteAddress().WriteTo(w) } func (p *Peer) String() string { - return fmt.Sprintf("Peer{%v-%v,o:%v}", p.LocalAddress(), p.RemoteAddress(), p.outgoing) + return fmt.Sprintf("Peer{%v-%v,o:%v}", p.LocalAddress(), p.RemoteAddress(), p.outgoing) } func (p *Peer) recvHandler(chName String, inboundPacketQueue chan<- *InboundPacket) { - log.Tracef("%v recvHandler [%v]", p, chName) - channel := p.channels[chName] - recvQueue := channel.RecvQueue() + log.Tracef("%v recvHandler [%v]", p, chName) + channel := p.channels[chName] + recvQueue := channel.RecvQueue() - FOR_LOOP: - for { - select { - case <-p.quit: - break FOR_LOOP - case pkt := <-recvQueue: - // send to inboundPacketQueue - inboundPacket := &InboundPacket{ - Peer: p, - Channel: channel, - Time: Time{time.Now()}, - Packet: pkt, - } - select { - case <-p.quit: - break FOR_LOOP - case inboundPacketQueue <- inboundPacket: - continue - } - } - } +FOR_LOOP: + for { + select { + case <-p.quit: + break FOR_LOOP + case pkt := <-recvQueue: + // send to inboundPacketQueue + inboundPacket := &InboundPacket{ + Peer: p, + Channel: channel, + Time: Time{time.Now()}, + Packet: pkt, + } + select { + case <-p.quit: + break FOR_LOOP + case inboundPacketQueue <- inboundPacket: + continue + } + } + } - log.Tracef("%v recvHandler [%v] closed", p, chName) - // cleanup - // (none) + log.Tracef("%v recvHandler [%v] closed", p, chName) + // cleanup + // (none) } func (p *Peer) sendHandler(chName String) { - log.Tracef("%v sendHandler [%v]", p, chName) - chSendQueue := p.channels[chName].sendQueue - FOR_LOOP: - for { - select { - case <-p.quit: - break FOR_LOOP - case pkt := <-chSendQueue: - log.Tracef("Sending packet to peer chSendQueue") - // blocks until the connection is Stop'd, - // which happens when this peer is Stop'd. - p.conn.Send(pkt) - } - } + log.Tracef("%v sendHandler [%v]", p, chName) + chSendQueue := p.channels[chName].sendQueue +FOR_LOOP: + for { + select { + case <-p.quit: + break FOR_LOOP + case pkt := <-chSendQueue: + log.Tracef("Sending packet to peer chSendQueue") + // blocks until the connection is Stop'd, + // which happens when this peer is Stop'd. + p.conn.Send(pkt) + } + } - log.Tracef("%v sendHandler [%v] closed", p, chName) - // cleanup - // (none) + log.Tracef("%v sendHandler [%v] closed", p, chName) + // cleanup + // (none) } - /* Channel */ type Channel struct { - name String - recvQueue chan Packet - sendQueue chan Packet - //stats Stats + name String + recvQueue chan Packet + sendQueue chan Packet + //stats Stats } func NewChannel(name String, bufferSize int) *Channel { - return &Channel{ - name: name, - recvQueue: make(chan Packet, bufferSize), - sendQueue: make(chan Packet, bufferSize), - } + return &Channel{ + name: name, + recvQueue: make(chan Packet, bufferSize), + sendQueue: make(chan Packet, bufferSize), + } } func (c *Channel) Name() String { - return c.name + return c.name } func (c *Channel) RecvQueue() <-chan Packet { - return c.recvQueue + return c.recvQueue } func (c *Channel) SendQueue() chan<- Packet { - return c.sendQueue + return c.sendQueue } diff --git a/peer/server.go b/peer/server.go index da21be90..9749ed2a 100644 --- a/peer/server.go +++ b/peer/server.go @@ -1,39 +1,38 @@ package peer -import ( -) +import () /* Server */ type Server struct { - listener Listener - client *Client + listener Listener + client *Client } func NewServer(protocol string, laddr string, c *Client) *Server { - l := NewDefaultListener(protocol, laddr) - s := &Server{ - listener: l, - client: c, - } - go s.IncomingConnectionHandler() - return s + l := NewDefaultListener(protocol, laddr) + s := &Server{ + listener: l, + client: c, + } + go s.IncomingConnectionHandler() + return s } func (s *Server) LocalAddress() *NetAddress { - return s.listener.LocalAddress() + return s.listener.LocalAddress() } // meant to run in a goroutine func (s *Server) IncomingConnectionHandler() { - for conn := range s.listener.Connections() { - log.Infof("New connection found: %v", conn) - s.client.AddPeerWithConnection(conn, false) - } + for conn := range s.listener.Connections() { + log.Infof("New connection found: %v", conn) + s.client.AddPeerWithConnection(conn, false) + } } func (s *Server) Stop() { - log.Infof("Stopping server") - s.listener.Stop() - s.client.Stop() + log.Infof("Stopping server") + s.listener.Stop() + s.client.Stop() } diff --git a/peer/upnp.go b/peer/upnp.go index 920bc329..db8d66f9 100644 --- a/peer/upnp.go +++ b/peer/upnp.go @@ -7,370 +7,370 @@ package peer // import ( - "bytes" - "encoding/xml" - "errors" - "io/ioutil" - "net" - "net/http" - "strconv" - "strings" - "time" + "bytes" + "encoding/xml" + "errors" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + "time" ) type upnpNAT struct { - serviceURL string - ourIP string - urnDomain string + serviceURL string + ourIP string + urnDomain string } // protocol is either "udp" or "tcp" type NAT interface { - GetExternalAddress() (addr net.IP, err error) - AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) - DeletePortMapping(protocol string, externalPort, internalPort int) (err error) + GetExternalAddress() (addr net.IP, err error) + AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) + DeletePortMapping(protocol string, externalPort, internalPort int) (err error) } func Discover() (nat NAT, err error) { - ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") - if err != nil { - return - } - conn, err := net.ListenPacket("udp4", ":0") - if err != nil { - return - } - socket := conn.(*net.UDPConn) - defer socket.Close() + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") + if err != nil { + return + } + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + return + } + socket := conn.(*net.UDPConn) + defer socket.Close() - err = socket.SetDeadline(time.Now().Add(3 * time.Second)) - if err != nil { - return - } + err = socket.SetDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } - st := "InternetGatewayDevice:1" + st := "InternetGatewayDevice:1" - buf := bytes.NewBufferString( - "M-SEARCH * HTTP/1.1\r\n" + - "HOST: 239.255.255.250:1900\r\n" + - "ST: ssdp:all\r\n" + - "MAN: \"ssdp:discover\"\r\n" + - "MX: 2\r\n\r\n") - message := buf.Bytes() - answerBytes := make([]byte, 1024) - for i := 0; i < 3; i++ { - _, err = socket.WriteToUDP(message, ssdp) - if err != nil { - return - } - var n int - n, _, err = socket.ReadFromUDP(answerBytes) - for { - n, _, err = socket.ReadFromUDP(answerBytes) - if err != nil { - break - } - answer := string(answerBytes[0:n]) - if strings.Index(answer, st) < 0 { - continue - } - // HTTP header field names are case-insensitive. - // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 - locString := "\r\nlocation:" - answer = strings.ToLower(answer) - locIndex := strings.Index(answer, locString) - if locIndex < 0 { - continue - } - loc := answer[locIndex+len(locString):] - endIndex := strings.Index(loc, "\r\n") - if endIndex < 0 { - continue - } - locURL := strings.TrimSpace(loc[0:endIndex]) - var serviceURL, urnDomain string - serviceURL, urnDomain, err = getServiceURL(locURL) - if err != nil { - return - } - var ourIP net.IP - ourIP, err = localIPv4() - if err != nil { - return - } - nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} - return - } - } - err = errors.New("UPnP port discovery failed.") - return + buf := bytes.NewBufferString( + "M-SEARCH * HTTP/1.1\r\n" + + "HOST: 239.255.255.250:1900\r\n" + + "ST: ssdp:all\r\n" + + "MAN: \"ssdp:discover\"\r\n" + + "MX: 2\r\n\r\n") + message := buf.Bytes() + answerBytes := make([]byte, 1024) + for i := 0; i < 3; i++ { + _, err = socket.WriteToUDP(message, ssdp) + if err != nil { + return + } + var n int + n, _, err = socket.ReadFromUDP(answerBytes) + for { + n, _, err = socket.ReadFromUDP(answerBytes) + if err != nil { + break + } + answer := string(answerBytes[0:n]) + if strings.Index(answer, st) < 0 { + continue + } + // HTTP header field names are case-insensitive. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + locString := "\r\nlocation:" + answer = strings.ToLower(answer) + locIndex := strings.Index(answer, locString) + if locIndex < 0 { + continue + } + loc := answer[locIndex+len(locString):] + endIndex := strings.Index(loc, "\r\n") + if endIndex < 0 { + continue + } + locURL := strings.TrimSpace(loc[0:endIndex]) + var serviceURL, urnDomain string + serviceURL, urnDomain, err = getServiceURL(locURL) + if err != nil { + return + } + var ourIP net.IP + ourIP, err = localIPv4() + if err != nil { + return + } + nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} + return + } + } + err = errors.New("UPnP port discovery failed.") + return } type Envelope struct { - XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` - Soap *SoapBody + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` + Soap *SoapBody } type SoapBody struct { - XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` - ExternalIP *ExternalIPAddressResponse + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` + ExternalIP *ExternalIPAddressResponse } type ExternalIPAddressResponse struct { - XMLName xml.Name `xml:"GetExternalIPAddressResponse"` - IPAddress string `xml:"NewExternalIPAddress"` + XMLName xml.Name `xml:"GetExternalIPAddressResponse"` + IPAddress string `xml:"NewExternalIPAddress"` } type ExternalIPAddress struct { - XMLName xml.Name `xml:"NewExternalIPAddress"` - IP string + XMLName xml.Name `xml:"NewExternalIPAddress"` + IP string } type Service struct { - ServiceType string `xml:"serviceType"` - ControlURL string `xml:"controlURL"` + ServiceType string `xml:"serviceType"` + ControlURL string `xml:"controlURL"` } type DeviceList struct { - Device []Device `xml:"device"` + Device []Device `xml:"device"` } type ServiceList struct { - Service []Service `xml:"service"` + Service []Service `xml:"service"` } type Device struct { - XMLName xml.Name `xml:"device"` - DeviceType string `xml:"deviceType"` - DeviceList DeviceList `xml:"deviceList"` - ServiceList ServiceList `xml:"serviceList"` + XMLName xml.Name `xml:"device"` + DeviceType string `xml:"deviceType"` + DeviceList DeviceList `xml:"deviceList"` + ServiceList ServiceList `xml:"serviceList"` } type Root struct { - Device Device + Device Device } func getChildDevice(d *Device, deviceType string) *Device { - dl := d.DeviceList.Device - for i := 0; i < len(dl); i++ { - if strings.Index(dl[i].DeviceType, deviceType) >= 0 { - return &dl[i] - } - } - return nil + dl := d.DeviceList.Device + for i := 0; i < len(dl); i++ { + if strings.Index(dl[i].DeviceType, deviceType) >= 0 { + return &dl[i] + } + } + return nil } func getChildService(d *Device, serviceType string) *Service { - sl := d.ServiceList.Service - for i := 0; i < len(sl); i++ { - if strings.Index(sl[i].ServiceType, serviceType) >= 0 { - return &sl[i] - } - } - return nil + sl := d.ServiceList.Service + for i := 0; i < len(sl); i++ { + if strings.Index(sl[i].ServiceType, serviceType) >= 0 { + return &sl[i] + } + } + return nil } func localIPv4() (net.IP, error) { - tt, err := net.Interfaces() - if err != nil { - return nil, err - } - for _, t := range tt { - aa, err := t.Addrs() - if err != nil { - return nil, err - } - for _, a := range aa { - ipnet, ok := a.(*net.IPNet) - if !ok { - continue - } - v4 := ipnet.IP.To4() - if v4 == nil || v4[0] == 127 { // loopback address - continue - } - return v4, nil - } - } - return nil, errors.New("cannot find local IP address") + tt, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, t := range tt { + aa, err := t.Addrs() + if err != nil { + return nil, err + } + for _, a := range aa { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + v4 := ipnet.IP.To4() + if v4 == nil || v4[0] == 127 { // loopback address + continue + } + return v4, nil + } + } + return nil, errors.New("cannot find local IP address") } func getServiceURL(rootURL string) (url, urnDomain string, err error) { - r, err := http.Get(rootURL) - if err != nil { - return - } - defer r.Body.Close() - if r.StatusCode >= 400 { - err = errors.New(string(r.StatusCode)) - return - } - var root Root - err = xml.NewDecoder(r.Body).Decode(&root) - if err != nil { - return - } - a := &root.Device - if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 { - err = errors.New("No InternetGatewayDevice") - return - } - b := getChildDevice(a, "WANDevice:1") - if b == nil { - err = errors.New("No WANDevice") - return - } - c := getChildDevice(b, "WANConnectionDevice:1") - if c == nil { - err = errors.New("No WANConnectionDevice") - return - } - d := getChildService(c, "WANIPConnection:1") - if d == nil { - // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, - // instead of under WanConnectionDevice - d = getChildService(b, "WANIPConnection:1") + r, err := http.Get(rootURL) + if err != nil { + return + } + defer r.Body.Close() + if r.StatusCode >= 400 { + err = errors.New(string(r.StatusCode)) + return + } + var root Root + err = xml.NewDecoder(r.Body).Decode(&root) + if err != nil { + return + } + a := &root.Device + if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 { + err = errors.New("No InternetGatewayDevice") + return + } + b := getChildDevice(a, "WANDevice:1") + if b == nil { + err = errors.New("No WANDevice") + return + } + c := getChildDevice(b, "WANConnectionDevice:1") + if c == nil { + err = errors.New("No WANConnectionDevice") + return + } + d := getChildService(c, "WANIPConnection:1") + if d == nil { + // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, + // instead of under WanConnectionDevice + d = getChildService(b, "WANIPConnection:1") - if d == nil { - err = errors.New("No WANIPConnection") - return - } - } - // Extract the domain name, which isn't always 'schemas-upnp-org' - urnDomain = strings.Split(d.ServiceType, ":")[1] - url = combineURL(rootURL, d.ControlURL) - return + if d == nil { + err = errors.New("No WANIPConnection") + return + } + } + // Extract the domain name, which isn't always 'schemas-upnp-org' + urnDomain = strings.Split(d.ServiceType, ":")[1] + url = combineURL(rootURL, d.ControlURL) + return } func combineURL(rootURL, subURL string) string { - protocolEnd := "://" - protoEndIndex := strings.Index(rootURL, protocolEnd) - a := rootURL[protoEndIndex+len(protocolEnd):] - rootIndex := strings.Index(a, "/") - return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL + protocolEnd := "://" + protoEndIndex := strings.Index(rootURL, protocolEnd) + a := rootURL[protoEndIndex+len(protocolEnd):] + rootIndex := strings.Index(a, "/") + return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL } func soapRequest(url, function, message, domain string) (r *http.Response, err error) { - fullMessage := "" + - "\r\n" + - "" + message + "" + fullMessage := "" + + "\r\n" + + "" + message + "" - req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") - req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") - //req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") - req.Header.Set("Connection", "Close") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Pragma", "no-cache") + req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") + req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") + //req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") + req.Header.Set("Connection", "Close") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Pragma", "no-cache") - // log.Stderr("soapRequest ", req) + // log.Stderr("soapRequest ", req) - r, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - /*if r.Body != nil { - defer r.Body.Close() - }*/ + r, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + /*if r.Body != nil { + defer r.Body.Close() + }*/ - if r.StatusCode >= 400 { - // log.Stderr(function, r.StatusCode) - err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) - r = nil - return - } - return + if r.StatusCode >= 400 { + // log.Stderr(function, r.StatusCode) + err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) + r = nil + return + } + return } type statusInfo struct { - externalIpAddress string + externalIpAddress string } func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) { - message := "\r\n" + - "" + message := "\r\n" + + "" - var response *http.Response - response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) - if response != nil { - defer response.Body.Close() - } - if err != nil { - return - } - var envelope Envelope - data, err := ioutil.ReadAll(response.Body) - reader := bytes.NewReader(data) - xml.NewDecoder(reader).Decode(&envelope) + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + var envelope Envelope + data, err := ioutil.ReadAll(response.Body) + reader := bytes.NewReader(data) + xml.NewDecoder(reader).Decode(&envelope) - info = statusInfo{envelope.Soap.ExternalIP.IPAddress} + info = statusInfo{envelope.Soap.ExternalIP.IPAddress} - if err != nil { - return - } + if err != nil { + return + } - return + return } func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { - info, err := n.getExternalIPAddress() - if err != nil { - return - } - addr = net.ParseIP(info.externalIpAddress) - return + info, err := n.getExternalIPAddress() + if err != nil { + return + } + addr = net.ParseIP(info.externalIpAddress) + return } func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { - // A single concatenation would break ARM compilation. - message := "\r\n" + - "" + strconv.Itoa(externalPort) - message += "" + protocol + "" - message += "" + strconv.Itoa(internalPort) + "" + - "" + n.ourIP + "" + - "1" - message += description + - "" + strconv.Itoa(timeout) + - "" + // A single concatenation would break ARM compilation. + message := "\r\n" + + "" + strconv.Itoa(externalPort) + message += "" + protocol + "" + message += "" + strconv.Itoa(internalPort) + "" + + "" + n.ourIP + "" + + "1" + message += description + + "" + strconv.Itoa(timeout) + + "" - var response *http.Response - response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) - if response != nil { - defer response.Body.Close() - } - if err != nil { - return - } + var response *http.Response + response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } - // TODO: check response to see if the port was forwarded - // log.Println(message, response) - mappedExternalPort = externalPort - _ = response - return + // TODO: check response to see if the port was forwarded + // log.Println(message, response) + mappedExternalPort = externalPort + _ = response + return } func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { - message := "\r\n" + - "" + strconv.Itoa(externalPort) + - "" + protocol + "" + - "" + message := "\r\n" + + "" + strconv.Itoa(externalPort) + + "" + protocol + "" + + "" - var response *http.Response - response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) - if response != nil { - defer response.Body.Close() - } - if err != nil { - return - } + var response *http.Response + response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } - // TODO: check response to see if the port was deleted - // log.Println(message, response) - _ = response - return + // TODO: check response to see if the port was deleted + // log.Println(message, response) + _ = response + return } diff --git a/peer/upnp_test.go b/peer/upnp_test.go index 9735949f..082ed4ff 100644 --- a/peer/upnp_test.go +++ b/peer/upnp_test.go @@ -1,8 +1,8 @@ package peer import ( - "testing" - "time" + "testing" + "time" ) /* @@ -11,38 +11,38 @@ TODO: set up or find a service to probe open ports. */ func TestUPNP(t *testing.T) { - t.Log("hello!") + t.Log("hello!") - nat, err := Discover() - if err != nil { - t.Fatalf("NAT upnp could not be discovered: %v", err) - } + nat, err := Discover() + if err != nil { + t.Fatalf("NAT upnp could not be discovered: %v", err) + } - t.Log("ourIP: ", nat.(*upnpNAT).ourIP) + t.Log("ourIP: ", nat.(*upnpNAT).ourIP) - ext, err := nat.GetExternalAddress() - if err != nil { - t.Fatalf("External address error: %v", err) - } - t.Logf("External address: %v", ext) + ext, err := nat.GetExternalAddress() + if err != nil { + t.Fatalf("External address error: %v", err) + } + t.Logf("External address: %v", ext) - port, err := nat.AddPortMapping("tcp", 8001, 8001, "testing", 0) - if err != nil { - t.Fatalf("Port mapping error: %v", err) - } - t.Logf("Port mapping mapped: %v", port) + port, err := nat.AddPortMapping("tcp", 8001, 8001, "testing", 0) + if err != nil { + t.Fatalf("Port mapping error: %v", err) + } + t.Logf("Port mapping mapped: %v", port) - // also run the listener, open for all remote addresses. - listener := NewDefaultListener("tcp", "0.0.0.0:8001") + // also run the listener, open for all remote addresses. + listener := NewDefaultListener("tcp", "0.0.0.0:8001") - // now sleep for 10 seconds - time.Sleep(10 * time.Second) + // now sleep for 10 seconds + time.Sleep(10 * time.Second) - err = nat.DeletePortMapping("tcp", 8001, 8001) - if err != nil { - t.Fatalf("Port mapping delete error: %v", err) - } - t.Logf("Port mapping deleted") + err = nat.DeletePortMapping("tcp", 8001, 8001) + if err != nil { + t.Fatalf("Port mapping delete error: %v", err) + } + t.Logf("Port mapping deleted") - listener.Stop() + listener.Stop() } diff --git a/peer/util.go b/peer/util.go index e4d7fbfb..e9c1a6df 100644 --- a/peer/util.go +++ b/peer/util.go @@ -1,7 +1,7 @@ package peer import ( - "crypto/sha256" + "crypto/sha256" ) // DoubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes.