diff --git a/binary/byteslice.go b/binary/byteslice.go
index bcbde1b4..bfe92bea 100644
--- a/binary/byteslice.go
+++ b/binary/byteslice.go
@@ -40,3 +40,11 @@ func ReadByteSlice(r io.Reader) ByteSlice {
if err != nil { panic(err) }
return ByteSlice(bytes)
}
+
+func ReadByteSliceSafe(r io.Reader) (ByteSlice, error) {
+ length := int(ReadUInt32(r))
+ bytes := make([]byte, length)
+ _, err := io.ReadFull(r, bytes)
+ if err != nil { return nil, err }
+ return ByteSlice(bytes), nil
+}
diff --git a/binary/int.go b/binary/int.go
index 9aeac271..85bc21c7 100644
--- a/binary/int.go
+++ b/binary/int.go
@@ -48,6 +48,13 @@ func ReadByte(r io.Reader) Byte {
return Byte(buf[0])
}
+func ReadByteSafe(r io.Reader) (Byte, error) {
+ buf := [1]byte{0}
+ _, err := io.ReadFull(r, buf[:])
+ if err != nil { return Byte(0), err }
+ return Byte(buf[0]), nil
+}
+
// Int8
diff --git a/merkle/iavl.go b/merkle/iavl_node.go
similarity index 81%
rename from merkle/iavl.go
rename to merkle/iavl_node.go
index 696d672e..d8312827 100644
--- a/merkle/iavl.go
+++ b/merkle/iavl_node.go
@@ -1,9 +1,3 @@
-/*
-This tree is not concurrency safe.
-If you want to use it from multiple goroutines, you need to wrap all calls to *IAVLTree
-with a mutex.
-*/
-
package merkle
import (
@@ -13,110 +7,6 @@ import (
"crypto/sha256"
)
-const HASH_BYTE_SIZE int = 4+32
-
-// Immutable AVL Tree (wraps the Node root)
-
-type IAVLTree struct {
- db Db
- root *IAVLNode
-}
-
-func NewIAVLTree(db Db) *IAVLTree {
- return &IAVLTree{db:db, root:nil}
-}
-
-func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree {
- root := &IAVLNode{
- hash: hash,
- flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
- }
- root.fill(db)
- return &IAVLTree{db:db, root:root}
-}
-
-func (t *IAVLTree) Root() Node {
- return t.root
-}
-
-func (t *IAVLTree) Size() uint64 {
- if t.root == nil { return 0 }
- return t.root.Size()
-}
-
-func (t *IAVLTree) Height() uint8 {
- if t.root == nil { return 0 }
- return t.root.Height()
-}
-
-func (t *IAVLTree) Has(key Key) bool {
- if t.root == nil { return false }
- return t.root.has(t.db, key)
-}
-
-func (t *IAVLTree) Put(key Key, value Value) (updated bool) {
- if t.root == nil {
- t.root = NewIAVLNode(key, value)
- return false
- }
- t.root, updated = t.root.put(t.db, key, value)
- return updated
-}
-
-func (t *IAVLTree) Hash() (ByteSlice, uint64) {
- if t.root == nil { return nil, 0 }
- return t.root.Hash()
-}
-
-func (t *IAVLTree) Save() {
- if t.root == nil { return }
- if t.root.hash == nil {
- t.root.Hash()
- }
- t.root.Save(t.db)
-}
-
-func (t *IAVLTree) Get(key Key) (value Value) {
- if t.root == nil { return nil }
- return t.root.get(t.db, key)
-}
-
-func (t *IAVLTree) Remove(key Key) (value Value, err error) {
- if t.root == nil { return nil, NotFound(key) }
- newRoot, _, value, err := t.root.remove(t.db, key)
- if err != nil {
- return nil, err
- }
- t.root = newRoot
- return value, nil
-}
-
-func (t *IAVLTree) Copy() Tree {
- return &IAVLTree{db:t.db, root:t.root}
-}
-
-// Traverses all the nodes of the tree in prefix order.
-// return true from cb to halt iteration.
-// node.Height() == 0 if you just want a value node.
-func (t *IAVLTree) Traverse(cb func(Node) bool) {
- if t.root == nil { return }
- t.root.traverse(t.db, cb)
-}
-
-func (t *IAVLTree) Values() <-chan Value {
- root := t.root
- ch := make(chan Value)
- go func() {
- root.traverse(t.db, func(n Node) bool {
- if n.Height() == 0 { ch <- n.Value() }
- return true
- })
- close(ch)
- }()
- return ch
-}
-
-
// Node
type IAVLNode struct {
diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go
new file mode 100644
index 00000000..de2c6930
--- /dev/null
+++ b/merkle/iavl_tree.go
@@ -0,0 +1,108 @@
+package merkle
+
+const HASH_BYTE_SIZE int = 4+32
+
+/*
+Immutable AVL Tree (wraps the Node root)
+
+This tree is not concurrency safe.
+You must wrap your calls with your own mutex.
+*/
+type IAVLTree struct {
+ db Db
+ root *IAVLNode
+}
+
+func NewIAVLTree(db Db) *IAVLTree {
+ return &IAVLTree{db:db, root:nil}
+}
+
+func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree {
+ root := &IAVLNode{
+ hash: hash,
+ flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
+ }
+ root.fill(db)
+ return &IAVLTree{db:db, root:root}
+}
+
+func (t *IAVLTree) Root() Node {
+ return t.root
+}
+
+func (t *IAVLTree) Size() uint64 {
+ if t.root == nil { return 0 }
+ return t.root.Size()
+}
+
+func (t *IAVLTree) Height() uint8 {
+ if t.root == nil { return 0 }
+ return t.root.Height()
+}
+
+func (t *IAVLTree) Has(key Key) bool {
+ if t.root == nil { return false }
+ return t.root.has(t.db, key)
+}
+
+func (t *IAVLTree) Put(key Key, value Value) (updated bool) {
+ if t.root == nil {
+ t.root = NewIAVLNode(key, value)
+ return false
+ }
+ t.root, updated = t.root.put(t.db, key, value)
+ return updated
+}
+
+func (t *IAVLTree) Hash() (ByteSlice, uint64) {
+ if t.root == nil { return nil, 0 }
+ return t.root.Hash()
+}
+
+func (t *IAVLTree) Save() {
+ if t.root == nil { return }
+ if t.root.hash == nil {
+ t.root.Hash()
+ }
+ t.root.Save(t.db)
+}
+
+func (t *IAVLTree) Get(key Key) (value Value) {
+ if t.root == nil { return nil }
+ return t.root.get(t.db, key)
+}
+
+func (t *IAVLTree) Remove(key Key) (value Value, err error) {
+ if t.root == nil { return nil, NotFound(key) }
+ newRoot, _, value, err := t.root.remove(t.db, key)
+ if err != nil {
+ return nil, err
+ }
+ t.root = newRoot
+ return value, nil
+}
+
+func (t *IAVLTree) Copy() Tree {
+ return &IAVLTree{db:t.db, root:t.root}
+}
+
+// Traverses all the nodes of the tree in prefix order.
+// return true from cb to halt iteration.
+// node.Height() == 0 if you just want a value node.
+func (t *IAVLTree) Traverse(cb func(Node) bool) {
+ if t.root == nil { return }
+ t.root.traverse(t.db, cb)
+}
+
+func (t *IAVLTree) Values() <-chan Value {
+ root := t.root
+ ch := make(chan Value)
+ go func() {
+ root.traverse(t.db, func(n Node) bool {
+ if n.Height() == 0 { ch <- n.Value() }
+ return true
+ })
+ close(ch)
+ }()
+ return ch
+}
diff --git a/peer/addrbook.go b/peer/addrbook.go
index c5381498..a833de8b 100644
--- a/peer/addrbook.go
+++ b/peer/addrbook.go
@@ -36,9 +36,6 @@ type AddrBook struct {
quit chan struct{}
nOld int
nNew int
-
- lamtx sync.Mutex
- localAddresses map[string]*localAddress
}
const (
@@ -474,83 +471,6 @@ func (a *AddrBook) getOldBucket(addr *NetAddress) int {
}
-/* Local Address */
-
-// addressPrio is an enum type used to describe the heirarchy of local address
-// discovery methods.
-type addressPrio int
-
-const (
- InterfacePrio addressPrio = iota // address of local interface.
- BoundPrio // Address explicitly bound to.
- UpnpPrio // External IP discovered from UPnP
- HttpPrio // Obtained from internet service.
- ManualPrio // provided by --externalip.
-)
-
-type localAddress struct {
- Addr *NetAddress
- Score addressPrio
-}
-
-func (a *AddrBook) AddLocalAddress(addr *NetAddress, priority addressPrio) {
- a.mtx.Lock(); defer a.mtx.Unlock()
-
- // sanity check.
- if !addr.Routable() {
- log.Debugf("rejecting address %s:%d due to routability", addr.IP, addr.Port)
- return
- }
- log.Debugf("adding address %s:%d", addr.IP, addr.Port)
-
- key := addr.String()
- la, ok := a.localAddresses[key]
- if !ok || la.Score < priority {
- if ok {
- la.Score = priority + 1
- } else {
- a.localAddresses[key] = &localAddress{
- Addr: addr,
- Score: priority,
- }
- }
- }
-}
-
-// Returns the most appropriate local address that we know
-// of to be contacted by rna (remote net address)
-func (a *AddrBook) GetBestLocalAddress(rna *NetAddress) *NetAddress {
- a.mtx.Lock(); defer a.mtx.Unlock()
-
- bestReach := 0
- var bestScore addressPrio
- var bestAddr *NetAddress
- for _, la := range a.localAddresses {
- reach := rna.ReachabilityTo(la.Addr)
- if reach > bestReach ||
- (reach == bestReach && la.Score > bestScore) {
- bestReach = reach
- bestScore = la.Score
- bestAddr = la.Addr
- }
- }
- if bestAddr != nil {
- log.Debugf("Suggesting address %s:%d for %s:%d",
- bestAddr.IP, bestAddr.Port, rna.IP, rna.Port)
- } else {
- log.Debugf("No worthy address for %s:%d",
- rna.IP, rna.Port)
- // Send something unroutable if nothing suitable.
- bestAddr = &NetAddress{
- IP: net.IP([]byte{0, 0, 0, 0}),
- Port: 0,
- }
- }
-
- return bestAddr
-}
-
-
// Return a string representing the network group of this address.
// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string
// "local" for a local address and the string "unroutable for an unroutable
diff --git a/peer/client.go b/peer/client.go
index 567f6d77..c0962b28 100644
--- a/peer/client.go
+++ b/peer/client.go
@@ -3,153 +3,169 @@ package peer
import (
. "github.com/tendermint/tendermint/binary"
"github.com/tendermint/tendermint/merkle"
+ "atomic"
"sync"
"io"
+ "errors"
)
-/* Client */
+/* Client
+
+ A client is half of a p2p system.
+ It can reach out to the network and establish connections with servers.
+ A client doesn't listen for incoming connections -- that's done by the server.
+
+ newPeerCb is a factory method for generating new peers from new *Connections.
+ newPeerCb(nil) must return a prototypical peer that represents the self "peer".
+
+ XXX what about peer disconnects?
+*/
type Client struct {
- listener *Listener
addrBook AddrBook
- strategies map[String]*FilterStrategy
targetNumPeers int
+ newPeerCb func(*Connection) *Peer
+ self *Peer
+ inQueues map[String]chan *InboundMsg
- peersMtx sync.Mutex
+ mtx sync.Mutex
peers merkle.Tree // addr -> *Peer
-
- filtersMtx sync.Mutex
- filters merkle.Tree // channelName -> Filter (objects that I know of)
+ quit chan struct{}
+ stopped uint32
}
-func NewClient(protocol string, laddr string) *Client {
- // XXX set the handler
- listener := NewListener(protocol, laddr, nil)
+var (
+ CLIENT_STOPPED_ERROR = errors.New("Client already stopped")
+ CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer")
+)
+
+func NewClient(newPeerCb func(*Connect) *Peer) *Client {
+ self := newPeerCb(nil)
+ if self == nil {
+ Panicf("newPeerCb(nil) must return a prototypical peer for self")
+ }
+
+ inQueues := make(map[String]chan *InboundMsg)
+ for chName, channel := peer.channels {
+ inQueues[chName] = make(chan *InboundMsg)
+ }
+
c := &Client{
- listener: listener,
+ newPeerCb: newPeerCb,
peers: merkle.NewIAVLTree(nil),
- filters: merkle.NewIAVLTree(nil),
+ self: self,
+ inQueues: inQueues,
}
return c
}
-func (c *Client) Start() (<-chan *IncomingMsg) {
- return nil
-}
-
func (c *Client) Stop() {
- c.listener.Close()
-}
-
-func (c *Client) LocalAddress() *NetAddress {
- return c.listener.LocalAddress()
-}
-
-func (c *Client) ConnectTo(addr *NetAddress) (*Peer, error) {
-
- conn, err := addr.Dial()
- if err != nil { return nil, err }
- peer := NewPeer(conn)
-
// lock
- c.peersMtx.Lock()
- c.peers.Put(addr, peer)
- c.peersMtx.Unlock()
+ c.mtx.Lock()
+ if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) {
+ close(c.quit)
+ // stop each peer.
+ for peerValue := range c.peers.Values() {
+ peer := peerValue.(*Peer)
+ peer.Stop()
+ }
+ // empty tree.
+ c.peers = merkle.NewIAVLTree(nil)
+ }
+ c.mtx.Unlock()
// unlock
+}
+
+func (c *Client) AddPeerWithConnection(conn *Connection, outgoing bool) (*Peer, error) {
+ if atomic.LoadUint32(&c.stopped) == 1 { return nil, CLIENT_STOPPED_ERROR }
+
+ peer := c.newPeerCb(conn)
+ peer.outgoing = outgoing
+ err := c.addPeer(peer)
+ if err != nil { return nil, err }
+
+ go peer.Start(c.inQueues)
return peer, nil
}
-func (c *Client) Broadcast(channel String, msg Binary) {
+func (c *Client) Broadcast(chName String, msg Msg) {
+ if atomic.LoadUint32(&c.stopped) == 1 { return }
+
for v := range c.peersCopy().Values() {
- peer, ok := v.(*Peer)
- if !ok { panic("Expected peer but got something else") }
- peer.Queue(channel, msg)
+ peer := v.(*Peer)
+ success := peer.TryQueueOut(chName , msg)
+ if !success {
+ // TODO: notify the peer
+ }
}
}
-// Updates the client's filter for a channel & broadcasts it.
-func (c *Client) UpdateFilter(channel String, filter Filter) {
- c.filtersMtx.Lock()
- c.filters.Put(channel, filter)
- c.filtersMtx.Unlock()
+func (c *Client) PopMessage(chName String) *InboundMsg {
+ if atomic.LoadUint32(&c.stopped) == 1 { return nil }
+
+ channel := c.Channel(chName)
+ q := c.inQueues[chName]
+ if q == nil { Panicf("Expected inQueues[%f], found none", chName) }
+
+ for {
+ select {
+ case <-quit:
+ return nil
+ case msg := <-q:
+ // skip if known.
+ if channel.Filter().Has(msg) {
+ continue
+ }
+ return msg
+ }
+ }
+}
+
+// Updates self's filter for a channel & broadcasts it.
+// TODO: maybe don't expose this
+func (c *Client) UpdateFilter(chName String, filter Filter) {
+ if atomic.LoadUint32(&c.stopped) == 1 { return }
+
+ c.self.Channel(chName).UpdateFilter(filter)
c.Broadcast("", &NewFilterMsg{
- Channel: channel,
+ Channel: chName,
Filter: filter,
})
}
-func (c *Client) peersCopy() merkle.Tree {
- c.peersMtx.Lock(); defer c.peersMtx.Unlock()
- return c.peers.Copy()
-}
+func (c *Client) StopPeer(peer *Peer) {
+ // lock
+ c.mtx.Lock()
+ p, _ := c.peers.Remove(peer.RemoteAddress())
+ c.mtx.Unlock()
+ // unlock
-
-/* Channel */
-type Channel struct {
- Name String
- Filter Filter
- //Stats Stats
-}
-
-
-/* Peer */
-type Peer struct {
- Conn *Connection
- Channels map[String]*Channel
-}
-
-func NewPeer(conn *Connection) *Peer {
- return &Peer{
- Conn: conn,
- Channels: nil,
+ if p != nil {
+ p.Stop()
}
}
-// Must be quick and nonblocking.
-func (p *Peer) Queue(channel String, msg Binary) {}
+func (c *Client) addPeer(peer *Peer) error {
+ addr := peer.RemoteAddress()
-func (p *Peer) WriteTo(w io.Writer) (n int64, err error) {
- return 0, nil // TODO
+ // lock & defer
+ c.mtx.Lock(); defer c.mtx.Unlock()
+ if c.stopped == 1 { return CLIENT_STOPPED_ERROR }
+ if !c.peers.Has(addr) {
+ c.peers.Put(addr, peer)
+ return nil
+ } else {
+ // ignore duplicate peer for addr.
+ log.Infof("Ignoring duplicate peer for addr %v", addr)
+ return CLIENT_DUPLICATE_PEER_ERROR
+ }
+ // unlock deferred
}
-
-/* IncomingMsg */
-type IncomingMsg struct {
- SPeer *Peer
- SChan *Channel
-
- Time Time
-
- Msg Binary
-}
-
-
-/* Filter
-
- A Filter could be a bloom filter for lossy filtering, or could be a lossless filter.
- Either way, it's used to keep track of what a peer knows of.
-*/
-type Filter interface {
- Binary
- Add(ByteSlice)
- Has(ByteSlice) bool
-}
-
-/* FilterStrategy
-
- Defines how filters are generated per peer, and whether they need to get refreshed occasionally.
-*/
-type FilterStrategy interface {
- LoadFilter(ByteSlice) Filter
-}
-
-/* NewFilterMsg */
-type NewFilterMsg struct {
- Channel String
- Filter Filter
-}
-
-func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) {
- return 0, nil // TODO
+func (c *Client) peersCopy() merkle.Tree {
+ // lock & defer
+ c.mtx.Lock(); defer c.mtx.Unlock()
+ return c.peers.Copy()
+ // unlock deferred
}
diff --git a/peer/connection.go b/peer/connection.go
index 6d5b6e53..27e2f57c 100644
--- a/peer/connection.go
+++ b/peer/connection.go
@@ -3,6 +3,7 @@ package peer
import (
. "github.com/tendermint/tendermint/common"
. "github.com/tendermint/tendermint/binary"
+ "atomic"
"sync"
"net"
"runtime"
@@ -20,12 +21,10 @@ const (
type Connection struct {
ioStats IOStats
- mtx sync.Mutex
- outQueue chan ByteSlice
+ outQueue chan ByteSlice // never closes.
conn net.Conn
quit chan struct{}
- disconnected bool
-
+ stopped int32
pingDebouncer *Debouncer
pong chan struct{}
}
@@ -46,13 +45,14 @@ func NewConnection(conn net.Conn) *Connection {
}
}
-func (c *Connection) QueueMessage(msg ByteSlice) bool {
- c.mtx.Lock(); defer c.mtx.Unlock()
- if c.disconnected { return false }
+// returns true if successfully queued,
+// returns false if connection was closed.
+// blocks.
+func (c *Connection) QueueOut(msg ByteSlice) bool {
select {
case c.outQueue <- msg:
return true
- default: // buffer full
+ case <-c.quit:
return false
}
}
@@ -62,13 +62,25 @@ func (c *Connection) Start() {
go c.inHandler()
}
-func (c *Connection) Disconnect() {
- c.mtx.Lock(); defer c.mtx.Unlock()
- close(c.quit)
- c.conn.Close()
- c.pingDebouncer.Stop()
- // do not close c.pong
- c.disconnected = true
+func (c *Connection) Stop() {
+ if atomic.SwapAndCompare(&c.stopped, 0, 1) {
+ close(c.quit)
+ c.conn.Close()
+ c.pingDebouncer.Stop()
+ // We can't close pong safely here because
+ // inHandler may write to it after we've stopped.
+ // Though it doesn't need to get closed at all,
+ // we close it @ inHandler.
+ // close(c.pong)
+ }
+}
+
+func (c *Connection) LocalAddress() *NetAddress {
+ return NewNetAddress(c.conn.LocalAddr())
+}
+
+func (c *Connection) RemoteAddress() *NetAddress {
+ return NewNetAddress(c.conn.RemoteAddr())
}
func (c *Connection) flush() {
@@ -79,41 +91,42 @@ func (c *Connection) outHandler() {
FOR_LOOP:
for {
+ var err error
select {
case <-c.pingDebouncer.Ch:
- PACKET_TYPE_PING.WriteTo(c.conn)
+ _, err = PACKET_TYPE_PING.WriteTo(c.conn)
case outMsg := <-c.outQueue:
- _, err := outMsg.WriteTo(c.conn)
- if err != nil { Panicf("TODO: handle error %v", err) }
+ _, err = outMsg.WriteTo(c.conn)
case <-c.pong:
- PACKET_TYPE_PONG.WriteTo(c.conn)
+ _, err = PACKET_TYPE_PONG.WriteTo(c.conn)
case <-c.quit:
break FOR_LOOP
}
+
+ if err != nil {
+ log.Infof("Connection %v failed @ outHandler:\n%v", c, err)
+ c.Stop()
+ break FOR_LOOP
+ }
+
c.flush()
}
- // cleanup
- for _ = range c.outQueue {
- // do nothing but drain.
- }
}
func (c *Connection) inHandler() {
- defer func() {
- if e := recover(); e != nil {
- // Get stack trace
- buf := make([]byte, 1<<16)
- runtime.Stack(buf, false)
- // TODO do proper logging
- fmt.Printf("Disconnecting due to error:\n\n%v\n", string(buf))
- c.Disconnect()
- }
- }()
- //FOR_LOOP:
+ FOR_LOOP:
for {
- msgType := ReadUInt8(c.conn)
+ msgType, err := ReadUInt8Safe(c.conn)
+
+ if err != nil {
+ if atomic.LoadUint32(&c.stopped) != 1 {
+ log.Infof("Connection %v failed @ inHandler", c)
+ c.Stop()
+ }
+ break FOR_LOOP
+ }
switch msgType {
case PACKET_TYPE_PING:
@@ -121,12 +134,29 @@ func (c *Connection) inHandler() {
case PACKET_TYPE_PONG:
// do nothing
case PACKET_TYPE_MSG:
- ReadByteSlice(c.conn)
+ msg, err := ReadByteSliceSafe(c.conn)
+ if err != nil {
+ if atomic.LoadUint32(&c.stopped) != 1 {
+ log.Infof("Connection %v failed @ inHandler", c)
+ c.Stop()
+ }
+ break FOR_LOOP
+ }
+ // What to do?
+ // TODO
+
default:
Panicf("Unknown message type %v", msgType)
}
+
c.pingDebouncer.Reset()
}
+
+ // cleanup
+ close(c.pong)
+ for _ = range c.pong {
+ // drain
+ }
}
diff --git a/peer/connection_test.go b/peer/connection_test.go
index 99f6c00e..e72e15d8 100644
--- a/peer/connection_test.go
+++ b/peer/connection_test.go
@@ -2,15 +2,50 @@ package peer
import (
"testing"
+ "time"
)
func TestLocalConnection(t *testing.T) {
- c1 := NewClient("tcp", ":8080")
- c2 := NewClient("tcp", ":8081")
+ c1 := NewClient(func(conn *Connection) *Peer {
+ p := &Peer{conn: conn}
+ ch1 := NewChannel(String("ch1"),
+ nil,
+ // XXX these channels should be buffered.
+ make(chan ByteSlice),
+ make(chan ByteSlice),
+ )
+
+ ch2 := NewChannel(String("ch2"),
+ nil,
+ make(chan ByteSlice),
+ make(chan ByteSlice),
+ )
+
+ channels := make(map[String]*Channel)
+ channels[ch1.Name] = ch1
+ channels[ch2.Name] = ch2
+ p.channels = channels
+
+ return p
+ })
+
+ // XXX make c2 like c1.
+
+ c2 := NewClient(func(conn *Connection) *Peer {
+ return nil
+ })
+
+ // XXX clients don't have "local addresses"
c1.ConnectTo(c2.LocalAddress())
+ // lets send a message from c1 to c2.
+ c1.Broadcast(String(""), String("message"))
+ time.Sleep(500 * time.Millisecond)
+
+ inMsg := c2.PopMessage()
+
c1.Stop()
c2.Stop()
}
diff --git a/peer/filter.go b/peer/filter.go
new file mode 100644
index 00000000..c0177e8a
--- /dev/null
+++ b/peer/filter.go
@@ -0,0 +1,16 @@
+package peer
+
+/* Filter
+
+ A Filter could be a bloom filter for lossy filtering, or could be a lossless filter.
+ Either way, it's used to keep track of what a peer knows of.
+*/
+type Filter interface {
+ Binary
+ Add(Msg)
+ Has(Msg) bool
+
+ // Loads a new filter.
+ // Convenience factory method
+ Load(ByteSlice) Filter
+}
diff --git a/peer/listener.go b/peer/listener.go
index 72ec9e40..c19a5a2e 100644
--- a/peer/listener.go
+++ b/peer/listener.go
@@ -1,60 +1,76 @@
package peer
import (
- "sync"
+ "atomic"
"net"
)
/* Listener */
-type Listener struct {
- listener net.Listener
- handler func(net.Conn)
- mtx sync.Mutex
- closed bool
+type Listener interface {
+ Connections() <-chan *Connection
+ LocalAddress() *NetAddress
+ Stop()
}
-func NewListener(protocol string, laddr string, handler func(net.Conn)) *Listener {
+
+/* DefaultListener */
+
+type DefaultListener struct {
+ listener net.Listener
+ connections chan *Connection
+ stopped uint32
+}
+
+const (
+ DEFAULT_BUFFERED_CONNECTIONS = 10
+)
+
+func NewListener(protocol string, laddr string) *Listener {
ln, err := net.Listen(protocol, laddr)
if err != nil { panic(err) }
s := &Listener{
- listener: ln,
- handler: handler,
+ listener: ln,
+ connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS),
}
- go s.listen()
+ go l.listenHandler()
return s
}
-func (s *Listener) listen() {
+func (l *Listener) listenHandler() {
for {
- conn, err := s.listener.Accept()
- if err != nil {
- // lock & defer
- s.mtx.Lock(); defer s.mtx.Unlock()
- if s.closed {
- return
- } else {
- panic(err)
- }
- // unlock (deferred)
- }
+ conn, err := l.listener.Accept()
- go s.handler(conn)
+ if atomic.LoadUint32(&l.stopped) == 1 { return }
+
+ // listener wasn't stopped,
+ // yet we encountered an error.
+ if err != nil { panic(err) }
+
+ c := NewConnection(con)
+ l.connections <- c
+ }
+
+ // cleanup
+ close(l.connections)
+ for _ = range l.connections {
+ // drain
}
}
-func (s *Listener) LocalAddress() *NetAddress {
- return NewNetAddress(s.listener.Addr())
+func (l *Listener) Connections() <-chan *Connection {
+ return l.connections
}
-func (s *Listener) Close() {
- // lock
- s.mtx.Lock()
- s.closed = true
- s.mtx.Unlock()
- // unlock
- s.listener.Close()
+func (l *Listener) LocalAddress() *NetAddress {
+ return NewNetAddress(l.listener.Addr())
+}
+
+func (l *Listener) Stop() {
+ if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) {
+ l.listener.Close()
+ }
}
diff --git a/peer/msg.go b/peer/msg.go
new file mode 100644
index 00000000..6581c56f
--- /dev/null
+++ b/peer/msg.go
@@ -0,0 +1,30 @@
+package peer
+
+/* Msg */
+
+type Msg struct {
+ Bytes ByteSlice
+ Hash ByteSlice
+}
+
+
+/* InboundMsg */
+
+type InboundMsg struct {
+ Peer *Peer
+ Channel *Channel
+ Time Time
+ Msg
+}
+
+
+/* NewFilterMsg */
+
+type NewFilterMsg struct {
+ ChName String
+ Filter Filter
+}
+
+func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) {
+ return 0, nil // TODO
+}
diff --git a/peer/peer.go b/peer/peer.go
index 38dab2ff..e7f84d4a 100644
--- a/peer/peer.go
+++ b/peer/peer.go
@@ -1,715 +1,182 @@
package peer
import (
- "bytes"
- "container/list"
- "fmt"
- "github.com/davecgh/go-spew/spew"
- "github.com/tendermint/btcwire"
- "net"
- "strconv"
+ "atomic"
"sync"
- "sync/atomic"
- "time"
)
-const (
- // max protocol version the peer supports.
- maxProtocolVersion = 70001
+/* Peer */
- // number of elements the output channels use.
- outputBufferSize = 50
+type Peer struct {
+ outgoing bool
+ conn *Connection
+ channels map[String]*Channel
- // number of seconds of inactivity before we timeout a peer
- // that hasn't completed the initial version negotiation.
- negotiateTimeoutSeconds = 30
-
- // number of minutes of inactivity before we time out a peer.
- idleTimeoutMinutes = 5
-
- // number of minutes since we last sent a message
- // requiring a reply before we will ping a host.
- pingTimeoutMinutes = 2
-)
-
-var (
- userAgentName = "tendermintd"
- userAgentVersion = fmt.Sprintf("%d.%d.%d", appMajor, appMinor, appPatch)
-)
-
-// zeroHash is the zero value hash (all zeros). It is defined as a convenience.
-var zeroHash btcwire.ShaHash
-
-// minUint32 is a helper function to return the minimum of two uint32s.
-// This avoids a math import and the need to cast to floats.
-func minUint32(a, b uint32) uint32 {
- if a < b {
- return a
- }
- return b
+ mtx sync.Mutex
+ quit chan struct{}
+ stopped uint32
}
-// TODO(davec): Rename and comment this
-type outMsg struct {
- msg btcwire.Message
- doneChan chan bool
-}
-
-/*
-The overall data flow is split into 2 goroutines.
-
-Inbound messages are read via the inHandler goroutine and generally
-dispatched to their own handler.
-
-Outbound messages are queued via QueueMessage.
-*/
-type peer struct {
- server *server
- addr *NetAddress
- inbound bool
- persistent bool
-
- started bool // atomic
- quit chan bool
-
- conn net.Conn
- connMtx sync.Mutex
- disconnected bool // atomic && protected by connMtx
- knownAddresses map[string]bool
- outputQueue chan outMsg
-
- statMtx sync.Mutex // protects all below here.
- protocolVersion uint32
- timeConnected time.Time
- lastSend time.Time
- lastRecv time.Time
- bytesReceived uint64
- bytesSent uint64
- userAgent string
- lastPingNonce uint64 // Set to nonce if we have a pending ping.
- lastPingTime time.Time // Time we sent last ping.
- lastPingMicros int64 // Time for last ping to return.
-}
-
-// String returns the peer's address and directionality as a human-readable
-// string.
-func (p *peer) String() string {
- return fmt.Sprintf("%s (%s)", p.addr.String(), directionString(p.inbound))
-}
-
-// VersionKnown returns the whether or not the version of a peer is known locally.
-// It is safe for concurrent access.
-func (p *peer) VersionKnown() bool {
- p.statMtx.Lock(); defer p.statMtx.Unlock()
-
- return p.protocolVersion != 0
-}
-
-// ProtocolVersion returns the peer protocol version in a manner that is safe
-// for concurrent access.
-func (p *peer) ProtocolVersion() uint32 {
- p.statMtx.Lock(); defer p.statMtx.Unlock()
-
- return p.protocolVersion
-}
-
-// pushVersionMsg sends a version message to the connected peer using the
-// current state.
-func (p *peer) pushVersionMsg() {
- _, blockNum, err := p.server.db.NewestSha()
- if err != nil { panic(err) }
-
- // Version message.
- // TODO: DisableListen -> send zero address
- msg := btcwire.NewMsgVersion(
- p.server.addrManager.getBestLocalAddress(p.addr), p.addr,
- p.server.nonce, int32(blockNum))
- msg.AddUserAgent(userAgentName, userAgentVersion)
-
- // Advertise our max supported protocol version.
- msg.ProtocolVersion = maxProtocolVersion
-
- p.QueueMessage(msg, nil)
-}
-
-// handleVersionMsg is invoked when a peer receives a version bitcoin message
-// and is used to negotiate the protocol version details as well as kick start
-// the communications.
-func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) {
- // Detect self connections.
- if msg.Nonce == p.server.nonce {
- peerLog.Debugf("Disconnecting peer connected to self %s", p)
- p.Disconnect()
- return
- }
-
- p.statMtx.Lock() // Updating a bunch of stats.
- // Limit to one version message per peer.
- if p.protocolVersion != 0 {
- p.logError("Only one version message per peer is allowed %s.", p)
- p.statMtx.Unlock()
- p.Disconnect()
- return
- }
-
- // Negotiate the protocol version.
- p.protocolVersion = minUint32(p.protocolVersion, uint32(msg.ProtocolVersion))
- peerLog.Debugf("Negotiated protocol version %d for peer %s", p.protocolVersion, p)
-
- // Set the remote peer's user agent.
- p.userAgent = msg.UserAgent
-
- p.statMtx.Unlock()
-
- // Inbound connections.
- if p.inbound {
- // Send version.
- p.pushVersionMsg()
- }
-
- // Send verack.
- p.QueueMessage(btcwire.NewMsgVerAck(), nil)
-
- if p.inbound {
- // A peer might not be advertising the same address that it
- // actually connected from. One example of why this can happen
- // is with NAT. Only add the address to the address manager if
- // the addresses agree.
- if msg.AddrMe.String() == p.addr.String() {
- p.server.addrManager.AddAddress(p.addr, p.addr)
- }
- } else {
- // Request known addresses from the remote peer.
- if !cfg.SimNet && p.server.addrManager.NeedMoreAddresses() {
- p.QueueMessage(btcwire.NewMsgGetAddr(), nil)
- }
- }
-
- // Mark the address as a known good address.
- p.server.addrManager.MarkGood(p.addr)
-
- // Signal the block manager this peer is a new sync candidate.
- p.server.blockManager.NewPeer(p)
-
- // TODO: Relay alerts.
-}
-
-
-// handleGetAddrMsg is invoked when a peer receives a getaddr bitcoin message
-// and is used to provide the peer with known addresses from the address
-// manager.
-func (p *peer) handleGetAddrMsg(msg *btcwire.MsgGetAddr) {
- // Don't return any addresses when running on the simulation test
- // network. This helps prevent the network from becoming another
- // public test network since it will not be able to learn about other
- // peers that have not specifically been provided.
- if cfg.SimNet {
- return
- }
-
- // Get the current known addresses from the address manager.
- addrCache := p.server.addrManager.AddressCache()
-
- // Push the addresses.
- p.pushAddrMsg(addrCache)
-}
-
-// pushAddrMsg sends one, or more, addr message(s) to the connected peer using
-// the provided addresses.
-func (p *peer) pushAddrMsg(addresses []*NetAddress) {
- // Nothing to send.
- if len(addresses) == 0 { return }
-
- numAdded := 0
- msg := btcwire.NewMsgAddr()
- for _, addr := range addresses {
- // Filter addresses the peer already knows about.
- if p.knownAddresses[addr.String()] {
- continue
- }
-
- // Add the address to the message.
- err := msg.AddAddress(addr)
- if err != nil { panic(err) } // XXX remove error condition
- numAdded++
-
- // Split into multiple messages as needed.
- if numAdded > 0 && numAdded%btcwire.MaxAddrPerMsg == 0 {
- p.QueueMessage(msg, nil)
-
- // NOTE: This needs to be a new address message and not
- // simply call ClearAddresses since the message is a
- // pointer and queueing it does not make a copy.
- msg = btcwire.NewMsgAddr()
- }
- }
-
- // Send message with remaining addresses if needed.
- if numAdded%btcwire.MaxAddrPerMsg != 0 {
- p.QueueMessage(msg, nil)
+func (p *Peer) Start(peerInQueues map[String]chan *InboundMsg ) {
+ for chName, _ := range p.channels {
+ go p.inHandler(chName, peerInQueues[chName])
+ go p.outHandler(chName)
}
}
-// handleAddrMsg is invoked when a peer receives an addr bitcoin message and
-// is used to notify the server about advertised addresses.
-func (p *peer) handleAddrMsg(msg *btcwire.MsgAddr) {
- // Ignore addresses when running on the simulation test network. This
- // helps prevent the network from becoming another public test network
- // since it will not be able to learn about other peers that have not
- // specifically been provided.
- if cfg.SimNet {
- return
+func (p *Peer) Stop() {
+ // lock
+ p.mtx.Lock()
+ if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) {
+ close(p.quit)
+ p.conn.Stop()
}
-
- // A message that has no addresses is invalid.
- if len(msg.AddrList) == 0 {
- p.logError("Command [%s] from %s does not contain any addresses", msg.Command(), p)
- p.Disconnect()
- return
- }
-
- for _, addr := range msg.AddrList {
- // Set the timestamp to 5 days ago if it's more than 24 hours
- // in the future so this address is one of the first to be
- // removed when space is needed.
- now := time.Now()
- if addr.Timestamp.After(now.Add(time.Minute * 10)) {
- addr.Timestamp = now.Add(-1 * time.Hour * 24 * 5)
- }
-
- // Add address to known addresses for this peer.
- p.knownAddresses[addr.String()] = true
- }
-
- // Add addresses to server address manager. The address manager handles
- // the details of things such as preventing duplicate addresses, max
- // addresses, and last seen updates.
- // XXX bitcoind gives a 2 hour time penalty here, do we want to do the
- // same?
- p.server.addrManager.AddAddresses(msg.AddrList, p.addr)
+ p.mtx.Unlock()
+ // unlock
}
-func (p *peer) handlePingMsg(msg *btcwire.MsgPing) {
- // Include nonce from ping so pong can be identified.
- p.QueueMessage(btcwire.NewMsgPong(msg.Nonce), nil)
+func (p *Peer) LocalAddress() *NetAddress {
+ return p.conn.LocalAddress()
}
-func (p *peer) handlePongMsg(msg *btcwire.MsgPong) {
- p.statMtx.Lock(); defer p.statMtx.Unlock()
-
- // Arguably we could use a buffered channel here sending data
- // in a fifo manner whenever we send a ping, or a list keeping track of
- // the times of each ping. For now we just make a best effort and
- // only record stats if it was for the last ping sent. Any preceding
- // and overlapping pings will be ignored. It is unlikely to occur
- // without large usage of the ping rpc call since we ping
- // infrequently enough that if they overlap we would have timed out
- // the peer.
- if p.lastPingNonce != 0 && msg.Nonce == p.lastPingNonce {
- p.lastPingMicros = time.Now().Sub(p.lastPingTime).Nanoseconds()
- p.lastPingMicros /= 1000 // convert to usec.
- p.lastPingNonce = 0
- }
+func (p *Peer) RemoteAddress() *NetAddress {
+ return p.conn.RemoteAddress()
}
-// readMessage reads the next bitcoin message from the peer with logging.
-func (p *peer) readMessage() (btcwire.Message, []byte, error) {
- n, msg, buf, err := btcwire.ReadMessageN(p.conn, p.ProtocolVersion())
- p.statMtx.Lock()
- p.bytesReceived += uint64(n)
- p.statMtx.Unlock()
- p.server.AddBytesReceived(uint64(n))
- if err != nil {
- return nil, nil, err
- }
-
- // Use closures to log expensive operations so they are only run when
- // the logging level requires it.
- peerLog.Debugf("%v", newLogClosure(func() string {
- // Debug summary of message.
- summary := messageSummary(msg)
- if len(summary) > 0 {
- summary = " (" + summary + ")"
- }
- return fmt.Sprintf("Received %v%s from %s", msg.Command(), summary, p)
- }))
- peerLog.Tracef("%v", newLogClosure(func() string {
- return spew.Sdump(msg)
- }))
- peerLog.Tracef("%v", newLogClosure(func() string {
- return spew.Sdump(buf)
- }))
-
- return msg, buf, nil
+func (p *Peer) Channel(chName String) *Channel {
+ return p.channels[chName]
}
-// writeMessage sends a bitcoin Message to the peer with logging.
-func (p *peer) writeMessage(msg btcwire.Message) {
- if p.Disconnected() { return }
+// If msg isn't already in the peer's filter, then
+// queue the msg for output.
+// If the queue is full, just return false.
+func (p *Peer) TryQueueOut(chName String, msg Msg) bool {
+ channel := p.Channel(chName)
+ outQueue := channel.OutQueue()
- if !p.VersionKnown() {
- switch msg.(type) {
- case *btcwire.MsgVersion:
- // This is OK.
- default:
- // We drop all messages other than version if we
- // haven't done the handshake already.
- return
- }
+ // just return if already in filter
+ if channel.Filter().Has(msg) {
+ return true
}
- // Use closures to log expensive operations so they are only run when
- // the logging level requires it.
- peerLog.Debugf("%v", newLogClosure(func() string {
- // Debug summary of message.
- summary := messageSummary(msg)
- if len(summary) > 0 {
- summary = " (" + summary + ")"
- }
- return fmt.Sprintf("Sending %v%s to %s", msg.Command(), summary, p)
- }))
- peerLog.Tracef("%v", newLogClosure(func() string {
- return spew.Sdump(msg)
- }))
- peerLog.Tracef("%v", newLogClosure(func() string {
- var buf bytes.Buffer
- err := btcwire.WriteMessage(&buf, msg, p.ProtocolVersion())
- if err != nil {
- return err.Error()
- }
- return spew.Sdump(buf.Bytes())
- }))
-
- // Write the message to the peer.
- n, err := btcwire.WriteMessageN(p.conn, msg, p.ProtocolVersion())
- p.statMtx.Lock()
- p.bytesSent += uint64(n)
- p.statMtx.Unlock()
- p.server.AddBytesSent(uint64(n))
- if err != nil {
- p.Disconnect()
- p.logError("Can't send message to %s: %v", p, err)
- return
+ // lock & defer
+ p.mtx.Lock(); defer p.mtx.Unlock()
+ if p.stopped == 1 { return false }
+ select {
+ case outQueue <- msg:
+ return true
+ default: // buffer full
+ return false
}
+ // unlock deferred
}
-
-// inHandler handles all incoming messages for the peer. It must be run as a
-// goroutine.
-func (p *peer) inHandler() {
- // Peers must complete the initial version negotiation within a shorter
- // timeframe than a general idle timeout. The timer is then reset below
- // to idleTimeoutMinutes for all future messages.
- idleTimer := time.AfterFunc(negotiateTimeoutSeconds*time.Second, func() {
- if p.VersionKnown() {
- peerLog.Warnf("Peer %s no answer for %d minutes, disconnecting", p, idleTimeoutMinutes)
- }
- p.Disconnect()
- })
-out:
- for !p.Disconnected() {
- rmsg, buf, err := p.readMessage()
- // Stop the timer now, if we go around again we will reset it.
- idleTimer.Stop()
- if err != nil {
- if !p.Disconnected() {
- p.logError("Can't read message from %s: %v", p, err)
- }
- break out
- }
- p.statMtx.Lock()
- p.lastRecv = time.Now()
- p.statMtx.Unlock()
-
- // Ensure version message comes first.
- if _, ok := rmsg.(*btcwire.MsgVersion); !ok && !p.VersionKnown() {
- p.logError("A version message must precede all others")
- break out
- }
-
- // Handle each supported message type.
- markGood := false
- switch msg := rmsg.(type) {
- case *btcwire.MsgVersion:
- p.handleVersionMsg(msg)
-
- case *btcwire.MsgVerAck:
- // Do nothing.
-
- case *btcwire.MsgGetAddr:
- p.handleGetAddrMsg(msg)
-
- case *btcwire.MsgAddr:
- p.handleAddrMsg(msg)
- markGood = true
-
- case *btcwire.MsgPing:
- p.handlePingMsg(msg)
- markGood = true
-
- case *btcwire.MsgPong:
- p.handlePongMsg(msg)
-
- case *btcwire.MsgAlert:
- p.server.BroadcastMessage(msg, p)
-
- case *btcwire.MsgNotFound:
- // TODO(davec): Ignore this for now, but ultimately
- // it should probably be used to detect when something
- // we requested needs to be re-requested from another
- // peer.
-
- default:
- peerLog.Debugf("Received unhandled message of type %v: Fix Me", rmsg.Command())
- }
-
- // Mark the address as currently connected and working as of
- // now if one of the messages that trigger it was processed.
- if markGood && !p.Disconnected() {
- if p.addr == nil {
- peerLog.Warnf("we're getting stuff before we got a version message. that's bad")
- continue
- }
- p.server.addrManager.MarkGood(p.addr)
- }
- // ok we got a message, reset the timer.
- // timer just calls p.Disconnect() after logging.
- idleTimer.Reset(idleTimeoutMinutes * time.Minute)
- }
-
- idleTimer.Stop()
-
- // Ensure connection is closed and notify the server that the peer is done.
- p.Disconnect()
- p.server.donePeers <- p
-
- // Only tell block manager we are gone if we ever told it we existed.
- if p.VersionKnown() {
- p.server.blockManager.DonePeer(p)
- }
-
- peerLog.Tracef("Peer input handler done for %s", p)
+func (p *Peer) WriteTo(w io.Writer) (n int64, err error) {
+ return p.RemoteAddress().WriteTo(w)
}
-// outHandler handles all outgoing messages for the peer. It must be run as a
-// goroutine. It uses a buffered channel to serialize output messages while
-// allowing the sender to continue running asynchronously.
-func (p *peer) outHandler() {
- pingTimer := time.AfterFunc(pingTimeoutMinutes*time.Minute, func() {
- nonce, err := btcwire.RandomUint64()
- if err != nil {
- peerLog.Errorf("Not sending ping on timeout to %s: %v",
- p, err)
- return
- }
- p.QueueMessage(btcwire.NewMsgPing(nonce), nil)
- })
-out:
+func (p *Peer) inHandler(chName String, inboundMsgQueue chan<- *InboundMsg) {
+ channel := p.channels[chName]
+ inQueue := channel.InQueue()
+
+ FOR_LOOP:
for {
select {
- case msg := <-p.outputQueue:
- // If the message is one we should get a reply for
- // then reset the timer, we only want to send pings
- // when otherwise we would not receive a reply from
- // the peer.
- peerLog.Tracef("%s: received from outputQueue", p)
- reset := true
- switch m := msg.msg.(type) {
- case *btcwire.MsgVersion:
- // should get an ack
- case *btcwire.MsgGetAddr:
- // should get addresses
- case *btcwire.MsgPing:
- // expects pong
- // Also set up statistics.
- p.statMtx.Lock()
- p.lastPingNonce = m.Nonce
- p.lastPingTime = time.Now()
- p.statMtx.Unlock()
- default:
- // Not one of the above, no sure reply.
- // We want to ping if nothing else
- // interesting happens.
- reset = false
+ case <-quit:
+ break FOR_LOOP
+ case msg := <-inQueue:
+ // add to channel filter
+ channel.Filter().Add(msg)
+ // send to inboundMsgQueue
+ inboundMsg := &InboundMsg{
+ Peer: p,
+ Channel: channel,
+ Time: Time(time.Now()),
+ Msg: msg,
}
- if reset {
- pingTimer.Reset(pingTimeoutMinutes * time.Minute)
- }
- p.writeMessage(msg.msg)
- p.statMtx.Lock()
- p.lastSend = time.Now()
- p.statMtx.Unlock()
- if msg.doneChan != nil {
- msg.doneChan <- true
- }
-
- case <-p.quit:
- break out
- }
- }
-
- pingTimer.Stop()
-
- // Drain outputQueue
- for msg := range p.outputQueue {
- if msg.doneChan != nil {
- msg.doneChan <- false
- }
- }
- peerLog.Tracef("Peer output handler done for %s", p)
-}
-
-// QueueMessage adds the passed bitcoin message to the peer outputQueue. It
-// uses a buffered channel to communicate with the output handler goroutine so
-// it is automatically rate limited and safe for concurrent access.
-func (p *peer) QueueMessage(msg btcwire.Message, doneChan chan bool) {
- // Avoid risk of deadlock if goroutine already exited. The goroutine
- // we will be sending to hangs around until it knows for a fact that
- // it is marked as disconnected. *then* it drains the channels.
- if p.Disconnected() {
- // avoid deadlock...
- if doneChan != nil {
- go func() {
- doneChan <- false
- }()
- }
- return
- }
- p.outputQueue <- outMsg{msg: msg, doneChan: doneChan}
-}
-
-// True if is (or will become) disconnected.
-func (p *peer) Disconnected() bool {
- return atomic.LoadInt32(&p.disconnected) == 1
-}
-
-// Disconnects the peer by closing the connection. It also sets
-// a flag so the impending shutdown can be detected.
-func (p *peer) Disconnect() {
- p.connMtx.Lock(); defer p.connMtx.Unlock()
- // did we win the race?
- if atomic.AddInt32(&p.disconnected, 1) != 1 {
- return
- }
- peerLog.Tracef("disconnecting %s", p)
- close(p.quit)
- if p.conn != nil {
- p.conn.Close()
- }
-}
-
-// Sets the connection & starts
-func (p *peer) StartWithConnection(conn *net.Conn) {
- p.connMtx.Lock(); defer p.connMtx.Unlock()
- if p.conn != nil { panic("Conn already set") }
- if atomic.LoadInt32(&p.disconnected) == 1 { return }
- peerLog.Debugf("Connected to %s", conn.RemoteAddr())
- p.timeConnected = time.Now()
- p.conn = conn
- p.Start()
-}
-
-// Start begins processing input and output messages. It also sends the initial
-// version message for outbound connections to start the negotiation process.
-func (p *peer) Start() error {
- // Already started?
- if atomic.AddInt32(&p.started, 1) != 1 {
- return nil
- }
-
- peerLog.Tracef("Starting peer %s", p)
-
- // Send an initial version message if this is an outbound connection.
- if !p.inbound {
- p.pushVersionMsg()
- }
-
- // Start processing input and output.
- go p.inHandler()
- go p.outHandler()
-
- return nil
-}
-
-// Shutdown gracefully shuts down the peer by disconnecting it.
-func (p *peer) Shutdown() {
- peerLog.Tracef("Shutdown peer %s", p)
- p.Disconnect()
-}
-
-// newPeerBase returns a new base peer for the provided server and inbound flag.
-// This is used by the newInboundPeer and newOutboundPeer functions to perform
-// base setup needed by both types of peers.
-func newPeerBase(s *server, inbound bool) *peer {
- p := peer{
- server: s,
- protocolVersion: maxProtocolVersion,
- inbound: inbound,
- knownAddresses: make(map[string]bool),
- outputQueue: make(chan outMsg, outputBufferSize),
- quit: make(chan bool),
- }
- return &p
-}
-
-// newPeer returns a new inbound bitcoin peer for the provided server and
-// connection. Use Start to begin processing incoming and outgoing messages.
-func newInboundPeer(s *server, conn net.Conn) *peer {
- addr := NewNetAddress(conn.RemoteAddr())
- // XXX What if p.addr doesn't match (to be) reported addr due to NAT?
- s.addrManager.MarkAttempt(addr)
-
- p := newPeerBase(s, true)
- p.conn = conn
- p.addr = addr
- p.timeConnected = time.Now()
- return p
-}
-
-// newOutbountPeer returns a new outbound bitcoin peer for the provided server and
-// address and connects to it asynchronously. If the connection is successful
-// then the peer will also be started.
-func newOutboundPeer(s *server, addr *NetAddress, persistent bool) *peer {
- p := newPeerBase(s, false)
- p.addr = addr
- p.persistent = persistent
-
- go func() {
- // Mark this as one attempt, regardless of # of reconnects.
- s.addrManager.MarkAttempt(p.addr)
- retryCount := 0
- // Attempt to connect to the peer. If the connection fails and
- // this is a persistent connection, retry after the retry
- // interval.
- for {
- peerLog.Debugf("Attempting to connect to %s", addr)
- conn, err := addr.Dial()
- if err == nil {
- p.StartWithConnection(conn)
- return
- } else {
- retryCount++
- peerLog.Debugf("Failed to connect to %s: %v", addr, err)
- if !persistent {
- p.server.donePeers <- p
- return
- }
- scaledInterval := connectionRetryInterval.Nanoseconds() * retryCount / 2
- scaledDuration := time.Duration(scaledInterval)
- peerLog.Debugf("Retrying connection to %s in %s", addr, scaledDuration)
- time.Sleep(scaledDuration)
+ select {
+ case <-quit:
+ break FOR_LOOP
+ case inboundMsgQueue <- inboundMsg:
continue
}
}
- }()
- return p
+ }
+
+ // cleanup
+ // (none)
}
-// logError makes sure that we only log errors loudly on user peers.
-func (p *peer) logError(fmt string, args ...interface{}) {
- if p.persistent {
- peerLog.Errorf(fmt, args...)
- } else {
- peerLog.Debugf(fmt, args...)
+func (p *Peer) outHandler(chName String) {
+ outQueue := p.channels[chName].OutQueue()
+ FOR_LOOP:
+ for {
+ select {
+ case <-quit:
+ break FOR_LOOP
+ case msg := <-outQueue:
+ // blocks until the connection is Stop'd,
+ // which happens when this peer is Stop'd.
+ p.conn.QueueOut(msg.Bytes)
+ }
+ }
+
+ // cleanup
+ // (none)
+}
+
+
+/* Channel */
+
+type Channel struct {
+ name String
+
+ mtx sync.Mutex
+ filter Filter
+
+ inQueue chan Msg
+ outQueue chan Msg
+ //stats Stats
+}
+
+func NewChannel(name String, filter Filter, in, out chan Msg) *Channel {
+ return &Channel{
+ name: name,
+ filter: filter,
+ inQueue: in,
+ outQueue: out,
}
}
+
+func (c *Channel) InQueue() <-chan Msg {
+ return c.inQueue
+}
+
+func (c *Channel) OutQueue() chan<- Msg {
+ return c.outQueue
+}
+
+func (c *Channel) Add(msg Msg) {
+ c.Filter().Add(msg)
+}
+
+func (c *Channel) Has(msg Msg) bool {
+ return c.Filter().Has(msg)
+}
+
+// TODO: maybe don't expose this
+func (c *Channel) Filter() Filter {
+ // lock & defer
+ c.mtx.Lock(); defer c.mtx.Unlock()
+ return c.filter
+ // unlock deferred
+}
+
+// TODO: maybe don't expose this
+func (c *Channel) UpdateFilter(filter Filter) {
+ // lock
+ c.mtx.Lock()
+ c.filter = filter
+ c.mtx.Unlock()
+ // unlock
+}
diff --git a/peer/server.go b/peer/server.go
new file mode 100644
index 00000000..787e5f7a
--- /dev/null
+++ b/peer/server.go
@@ -0,0 +1,32 @@
+package peer
+
+import (
+)
+
+/* Server */
+
+type Server struct {
+ listener Listener
+ client *Client
+}
+
+func NewServer(l Listener, c *Client) *Server {
+ s := &Server{
+ listener: l,
+ client: c,
+ }
+ go s.IncomingConnectionsHandler()
+ return s
+}
+
+// meant to run in a goroutine
+func (s *Server) IncomingConnectionHandler() {
+ for conn := range s.listener.Connections() {
+ s.client.AddIncomingConnection(conn)
+ }
+}
+
+func (s *Server) Stop() {
+ s.listener.Stop()
+ s.client.Stop()
+}
diff --git a/peer/upnp.go b/peer/upnp.go
new file mode 100644
index 00000000..711440c9
--- /dev/null
+++ b/peer/upnp.go
@@ -0,0 +1,368 @@
+// from taipei-torrent
+
+package peer
+
+// Just enough UPnP to be able to forward ports
+//
+
+import (
+ "bytes"
+ "encoding/xml"
+ "errors"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type upnpNAT struct {
+ serviceURL string
+ ourIP string
+ urnDomain string
+}
+
+func Discover() (nat NAT, err error) {
+ ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
+ if err != nil {
+ return
+ }
+ conn, err := net.ListenPacket("udp4", ":0")
+ if err != nil {
+ return
+ }
+ socket := conn.(*net.UDPConn)
+ defer socket.Close()
+
+ err = socket.SetDeadline(time.Now().Add(3 * time.Second))
+ if err != nil {
+ return
+ }
+
+ st := "InternetGatewayDevice:1"
+
+ buf := bytes.NewBufferString(
+ "M-SEARCH * HTTP/1.1\r\n" +
+ "HOST: 239.255.255.250:1900\r\n" +
+ "ST: ssdp:all\r\n" +
+ "MAN: \"ssdp:discover\"\r\n" +
+ "MX: 2\r\n\r\n")
+ message := buf.Bytes()
+ answerBytes := make([]byte, 1024)
+ for i := 0; i < 3; i++ {
+ _, err = socket.WriteToUDP(message, ssdp)
+ if err != nil {
+ return
+ }
+ var n int
+ n, _, err = socket.ReadFromUDP(answerBytes)
+ for {
+ n, _, err = socket.ReadFromUDP(answerBytes)
+ if err != nil {
+ break
+ }
+ answer := string(answerBytes[0:n])
+ if strings.Index(answer, st) < 0 {
+ continue
+ }
+ // HTTP header field names are case-insensitive.
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
+ locString := "\r\nlocation:"
+ answer = strings.ToLower(answer)
+ locIndex := strings.Index(answer, locString)
+ if locIndex < 0 {
+ continue
+ }
+ loc := answer[locIndex+len(locString):]
+ endIndex := strings.Index(loc, "\r\n")
+ if endIndex < 0 {
+ continue
+ }
+ locURL := strings.TrimSpace(loc[0:endIndex])
+ var serviceURL, urnDomain string
+ serviceURL, urnDomain, err = getServiceURL(locURL)
+ if err != nil {
+ return
+ }
+ var ourIP net.IP
+ ourIP, err = localIPv4()
+ if err != nil {
+ return
+ }
+ nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain}
+ return
+ }
+ }
+ err = errors.New("UPnP port discovery failed.")
+ return
+}
+
+type Envelope struct {
+ XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
+ Soap *SoapBody
+}
+type SoapBody struct {
+ XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
+ ExternalIP *ExternalIPAddressResponse
+}
+
+type ExternalIPAddressResponse struct {
+ XMLName xml.Name `xml:"GetExternalIPAddressResponse"`
+ IPAddress string `xml:"NewExternalIPAddress"`
+}
+
+type ExternalIPAddress struct {
+ XMLName xml.Name `xml:"NewExternalIPAddress"`
+ IP string
+}
+
+type Service struct {
+ ServiceType string `xml:"serviceType"`
+ ControlURL string `xml:"controlURL"`
+}
+
+type DeviceList struct {
+ Device []Device `xml:"device"`
+}
+
+type ServiceList struct {
+ Service []Service `xml:"service"`
+}
+
+type Device struct {
+ XMLName xml.Name `xml:"device"`
+ DeviceType string `xml:"deviceType"`
+ DeviceList DeviceList `xml:"deviceList"`
+ ServiceList ServiceList `xml:"serviceList"`
+}
+
+type Root struct {
+ Device Device
+}
+
+func getChildDevice(d *Device, deviceType string) *Device {
+ dl := d.DeviceList.Device
+ for i := 0; i < len(dl); i++ {
+ if strings.Index(dl[i].DeviceType, deviceType) >= 0 {
+ return &dl[i]
+ }
+ }
+ return nil
+}
+
+func getChildService(d *Device, serviceType string) *Service {
+ sl := d.ServiceList.Service
+ for i := 0; i < len(sl); i++ {
+ if strings.Index(sl[i].ServiceType, serviceType) >= 0 {
+ return &sl[i]
+ }
+ }
+ return nil
+}
+
+func localIPv4() (net.IP, error) {
+ tt, err := net.Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, t := range tt {
+ aa, err := t.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, a := range aa {
+ ipnet, ok := a.(*net.IPNet)
+ if !ok {
+ continue
+ }
+ v4 := ipnet.IP.To4()
+ if v4 == nil || v4[0] == 127 { // loopback address
+ continue
+ }
+ return v4, nil
+ }
+ }
+ return nil, errors.New("cannot find local IP address")
+}
+
+func getServiceURL(rootURL string) (url, urnDomain string, err error) {
+ r, err := http.Get(rootURL)
+ if err != nil {
+ return
+ }
+ defer r.Body.Close()
+ if r.StatusCode >= 400 {
+ err = errors.New(string(r.StatusCode))
+ return
+ }
+ var root Root
+ err = xml.NewDecoder(r.Body).Decode(&root)
+ if err != nil {
+ return
+ }
+ a := &root.Device
+ if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 {
+ err = errors.New("No InternetGatewayDevice")
+ return
+ }
+ b := getChildDevice(a, "WANDevice:1")
+ if b == nil {
+ err = errors.New("No WANDevice")
+ return
+ }
+ c := getChildDevice(b, "WANConnectionDevice:1")
+ if c == nil {
+ err = errors.New("No WANConnectionDevice")
+ return
+ }
+ d := getChildService(c, "WANIPConnection:1")
+ if d == nil {
+ // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
+ // instead of under WanConnectionDevice
+ d = getChildService(b, "WANIPConnection:1")
+
+ if d == nil {
+ err = errors.New("No WANIPConnection")
+ return
+ }
+ }
+ // Extract the domain name, which isn't always 'schemas-upnp-org'
+ urnDomain = strings.Split(d.ServiceType, ":")[1]
+ url = combineURL(rootURL, d.ControlURL)
+ return
+}
+
+func combineURL(rootURL, subURL string) string {
+ protocolEnd := "://"
+ protoEndIndex := strings.Index(rootURL, protocolEnd)
+ a := rootURL[protoEndIndex+len(protocolEnd):]
+ rootIndex := strings.Index(a, "/")
+ return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
+}
+
+func soapRequest(url, function, message, domain string) (r *http.Response, err error) {
+ fullMessage := "" +
+ "\r\n" +
+ "" + message + ""
+
+ req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
+ req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
+ //req.Header.Set("Transfer-Encoding", "chunked")
+ req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
+ req.Header.Set("Connection", "Close")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Pragma", "no-cache")
+
+ // log.Stderr("soapRequest ", req)
+
+ r, err = http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ /*if r.Body != nil {
+ defer r.Body.Close()
+ }*/
+
+ if r.StatusCode >= 400 {
+ // log.Stderr(function, r.StatusCode)
+ err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
+ r = nil
+ return
+ }
+ return
+}
+
+type statusInfo struct {
+ externalIpAddress string
+}
+
+func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) {
+
+ message := "\r\n" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+ var envelope Envelope
+ data, err := ioutil.ReadAll(response.Body)
+ reader := bytes.NewReader(data)
+ xml.NewDecoder(reader).Decode(&envelope)
+
+ info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
+
+ if err != nil {
+ return
+ }
+
+ return
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+ info, err := n.getExternalIPAddress()
+ if err != nil {
+ return
+ }
+ addr = net.ParseIP(info.externalIpAddress)
+ return
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
+ // A single concatenation would break ARM compilation.
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort)
+ message += "" + protocol + ""
+ message += "" + strconv.Itoa(internalPort) + "" +
+ "" + n.ourIP + "" +
+ "1"
+ message += description +
+ "" + strconv.Itoa(timeout) +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was forwarded
+ // log.Println(message, response)
+ mappedExternalPort = externalPort
+ _ = response
+ return
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort) +
+ "" + protocol + "" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was deleted
+ // log.Println(message, response)
+ _ = response
+ return
+}