From 002cfc8f75619004b8ed68f58dbc5f11671ea70f Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Tue, 24 Jun 2014 17:28:40 -0700 Subject: [PATCH] architecting peer --- binary/byteslice.go | 8 + binary/int.go | 7 + merkle/{iavl.go => iavl_node.go} | 110 ----- merkle/iavl_tree.go | 108 ++++ peer/addrbook.go | 80 --- peer/client.go | 232 +++++---- peer/connection.go | 102 ++-- peer/connection_test.go | 39 +- peer/filter.go | 16 + peer/listener.go | 80 +-- peer/msg.go | 30 ++ peer/peer.go | 817 ++++++------------------------- peer/server.go | 32 ++ peer/upnp.go | 368 ++++++++++++++ 14 files changed, 986 insertions(+), 1043 deletions(-) rename merkle/{iavl.go => iavl_node.go} (81%) create mode 100644 merkle/iavl_tree.go create mode 100644 peer/filter.go create mode 100644 peer/msg.go create mode 100644 peer/server.go create mode 100644 peer/upnp.go diff --git a/binary/byteslice.go b/binary/byteslice.go index bcbde1b4..bfe92bea 100644 --- a/binary/byteslice.go +++ b/binary/byteslice.go @@ -40,3 +40,11 @@ func ReadByteSlice(r io.Reader) ByteSlice { if err != nil { panic(err) } return ByteSlice(bytes) } + +func ReadByteSliceSafe(r io.Reader) (ByteSlice, error) { + length := int(ReadUInt32(r)) + bytes := make([]byte, length) + _, err := io.ReadFull(r, bytes) + if err != nil { return nil, err } + return ByteSlice(bytes), nil +} diff --git a/binary/int.go b/binary/int.go index 9aeac271..85bc21c7 100644 --- a/binary/int.go +++ b/binary/int.go @@ -48,6 +48,13 @@ func ReadByte(r io.Reader) Byte { return Byte(buf[0]) } +func ReadByteSafe(r io.Reader) (Byte, error) { + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { return Byte(0), err } + return Byte(buf[0]), nil +} + // Int8 diff --git a/merkle/iavl.go b/merkle/iavl_node.go similarity index 81% rename from merkle/iavl.go rename to merkle/iavl_node.go index 696d672e..d8312827 100644 --- a/merkle/iavl.go +++ b/merkle/iavl_node.go @@ -1,9 +1,3 @@ -/* -This tree is not concurrency safe. -If you want to use it from multiple goroutines, you need to wrap all calls to *IAVLTree -with a mutex. -*/ - package merkle import ( @@ -13,110 +7,6 @@ import ( "crypto/sha256" ) -const HASH_BYTE_SIZE int = 4+32 - -// Immutable AVL Tree (wraps the Node root) - -type IAVLTree struct { - db Db - root *IAVLNode -} - -func NewIAVLTree(db Db) *IAVLTree { - return &IAVLTree{db:db, root:nil} -} - -func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree { - root := &IAVLNode{ - hash: hash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - root.fill(db) - return &IAVLTree{db:db, root:root} -} - -func (t *IAVLTree) Root() Node { - return t.root -} - -func (t *IAVLTree) Size() uint64 { - if t.root == nil { return 0 } - return t.root.Size() -} - -func (t *IAVLTree) Height() uint8 { - if t.root == nil { return 0 } - return t.root.Height() -} - -func (t *IAVLTree) Has(key Key) bool { - if t.root == nil { return false } - return t.root.has(t.db, key) -} - -func (t *IAVLTree) Put(key Key, value Value) (updated bool) { - if t.root == nil { - t.root = NewIAVLNode(key, value) - return false - } - t.root, updated = t.root.put(t.db, key, value) - return updated -} - -func (t *IAVLTree) Hash() (ByteSlice, uint64) { - if t.root == nil { return nil, 0 } - return t.root.Hash() -} - -func (t *IAVLTree) Save() { - if t.root == nil { return } - if t.root.hash == nil { - t.root.Hash() - } - t.root.Save(t.db) -} - -func (t *IAVLTree) Get(key Key) (value Value) { - if t.root == nil { return nil } - return t.root.get(t.db, key) -} - -func (t *IAVLTree) Remove(key Key) (value Value, err error) { - if t.root == nil { return nil, NotFound(key) } - newRoot, _, value, err := t.root.remove(t.db, key) - if err != nil { - return nil, err - } - t.root = newRoot - return value, nil -} - -func (t *IAVLTree) Copy() Tree { - return &IAVLTree{db:t.db, root:t.root} -} - -// Traverses all the nodes of the tree in prefix order. -// return true from cb to halt iteration. -// node.Height() == 0 if you just want a value node. -func (t *IAVLTree) Traverse(cb func(Node) bool) { - if t.root == nil { return } - t.root.traverse(t.db, cb) -} - -func (t *IAVLTree) Values() <-chan Value { - root := t.root - ch := make(chan Value) - go func() { - root.traverse(t.db, func(n Node) bool { - if n.Height() == 0 { ch <- n.Value() } - return true - }) - close(ch) - }() - return ch -} - - // Node type IAVLNode struct { diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go new file mode 100644 index 00000000..de2c6930 --- /dev/null +++ b/merkle/iavl_tree.go @@ -0,0 +1,108 @@ +package merkle + +const HASH_BYTE_SIZE int = 4+32 + +/* +Immutable AVL Tree (wraps the Node root) + +This tree is not concurrency safe. +You must wrap your calls with your own mutex. +*/ +type IAVLTree struct { + db Db + root *IAVLNode +} + +func NewIAVLTree(db Db) *IAVLTree { + return &IAVLTree{db:db, root:nil} +} + +func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree { + root := &IAVLNode{ + hash: hash, + flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, + } + root.fill(db) + return &IAVLTree{db:db, root:root} +} + +func (t *IAVLTree) Root() Node { + return t.root +} + +func (t *IAVLTree) Size() uint64 { + if t.root == nil { return 0 } + return t.root.Size() +} + +func (t *IAVLTree) Height() uint8 { + if t.root == nil { return 0 } + return t.root.Height() +} + +func (t *IAVLTree) Has(key Key) bool { + if t.root == nil { return false } + return t.root.has(t.db, key) +} + +func (t *IAVLTree) Put(key Key, value Value) (updated bool) { + if t.root == nil { + t.root = NewIAVLNode(key, value) + return false + } + t.root, updated = t.root.put(t.db, key, value) + return updated +} + +func (t *IAVLTree) Hash() (ByteSlice, uint64) { + if t.root == nil { return nil, 0 } + return t.root.Hash() +} + +func (t *IAVLTree) Save() { + if t.root == nil { return } + if t.root.hash == nil { + t.root.Hash() + } + t.root.Save(t.db) +} + +func (t *IAVLTree) Get(key Key) (value Value) { + if t.root == nil { return nil } + return t.root.get(t.db, key) +} + +func (t *IAVLTree) Remove(key Key) (value Value, err error) { + if t.root == nil { return nil, NotFound(key) } + newRoot, _, value, err := t.root.remove(t.db, key) + if err != nil { + return nil, err + } + t.root = newRoot + return value, nil +} + +func (t *IAVLTree) Copy() Tree { + return &IAVLTree{db:t.db, root:t.root} +} + +// Traverses all the nodes of the tree in prefix order. +// return true from cb to halt iteration. +// node.Height() == 0 if you just want a value node. +func (t *IAVLTree) Traverse(cb func(Node) bool) { + if t.root == nil { return } + t.root.traverse(t.db, cb) +} + +func (t *IAVLTree) Values() <-chan Value { + root := t.root + ch := make(chan Value) + go func() { + root.traverse(t.db, func(n Node) bool { + if n.Height() == 0 { ch <- n.Value() } + return true + }) + close(ch) + }() + return ch +} diff --git a/peer/addrbook.go b/peer/addrbook.go index c5381498..a833de8b 100644 --- a/peer/addrbook.go +++ b/peer/addrbook.go @@ -36,9 +36,6 @@ type AddrBook struct { quit chan struct{} nOld int nNew int - - lamtx sync.Mutex - localAddresses map[string]*localAddress } const ( @@ -474,83 +471,6 @@ func (a *AddrBook) getOldBucket(addr *NetAddress) int { } -/* Local Address */ - -// addressPrio is an enum type used to describe the heirarchy of local address -// discovery methods. -type addressPrio int - -const ( - InterfacePrio addressPrio = iota // address of local interface. - BoundPrio // Address explicitly bound to. - UpnpPrio // External IP discovered from UPnP - HttpPrio // Obtained from internet service. - ManualPrio // provided by --externalip. -) - -type localAddress struct { - Addr *NetAddress - Score addressPrio -} - -func (a *AddrBook) AddLocalAddress(addr *NetAddress, priority addressPrio) { - a.mtx.Lock(); defer a.mtx.Unlock() - - // sanity check. - if !addr.Routable() { - log.Debugf("rejecting address %s:%d due to routability", addr.IP, addr.Port) - return - } - log.Debugf("adding address %s:%d", addr.IP, addr.Port) - - key := addr.String() - la, ok := a.localAddresses[key] - if !ok || la.Score < priority { - if ok { - la.Score = priority + 1 - } else { - a.localAddresses[key] = &localAddress{ - Addr: addr, - Score: priority, - } - } - } -} - -// Returns the most appropriate local address that we know -// of to be contacted by rna (remote net address) -func (a *AddrBook) GetBestLocalAddress(rna *NetAddress) *NetAddress { - a.mtx.Lock(); defer a.mtx.Unlock() - - bestReach := 0 - var bestScore addressPrio - var bestAddr *NetAddress - for _, la := range a.localAddresses { - reach := rna.ReachabilityTo(la.Addr) - if reach > bestReach || - (reach == bestReach && la.Score > bestScore) { - bestReach = reach - bestScore = la.Score - bestAddr = la.Addr - } - } - if bestAddr != nil { - log.Debugf("Suggesting address %s:%d for %s:%d", - bestAddr.IP, bestAddr.Port, rna.IP, rna.Port) - } else { - log.Debugf("No worthy address for %s:%d", - rna.IP, rna.Port) - // Send something unroutable if nothing suitable. - bestAddr = &NetAddress{ - IP: net.IP([]byte{0, 0, 0, 0}), - Port: 0, - } - } - - return bestAddr -} - - // Return a string representing the network group of this address. // This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string // "local" for a local address and the string "unroutable for an unroutable diff --git a/peer/client.go b/peer/client.go index 567f6d77..c0962b28 100644 --- a/peer/client.go +++ b/peer/client.go @@ -3,153 +3,169 @@ package peer import ( . "github.com/tendermint/tendermint/binary" "github.com/tendermint/tendermint/merkle" + "atomic" "sync" "io" + "errors" ) -/* Client */ +/* Client + + A client is half of a p2p system. + It can reach out to the network and establish connections with servers. + A client doesn't listen for incoming connections -- that's done by the server. + + newPeerCb is a factory method for generating new peers from new *Connections. + newPeerCb(nil) must return a prototypical peer that represents the self "peer". + + XXX what about peer disconnects? +*/ type Client struct { - listener *Listener addrBook AddrBook - strategies map[String]*FilterStrategy targetNumPeers int + newPeerCb func(*Connection) *Peer + self *Peer + inQueues map[String]chan *InboundMsg - peersMtx sync.Mutex + mtx sync.Mutex peers merkle.Tree // addr -> *Peer - - filtersMtx sync.Mutex - filters merkle.Tree // channelName -> Filter (objects that I know of) + quit chan struct{} + stopped uint32 } -func NewClient(protocol string, laddr string) *Client { - // XXX set the handler - listener := NewListener(protocol, laddr, nil) +var ( + CLIENT_STOPPED_ERROR = errors.New("Client already stopped") + CLIENT_DUPLICATE_PEER_ERROR = errors.New("Duplicate peer") +) + +func NewClient(newPeerCb func(*Connect) *Peer) *Client { + self := newPeerCb(nil) + if self == nil { + Panicf("newPeerCb(nil) must return a prototypical peer for self") + } + + inQueues := make(map[String]chan *InboundMsg) + for chName, channel := peer.channels { + inQueues[chName] = make(chan *InboundMsg) + } + c := &Client{ - listener: listener, + newPeerCb: newPeerCb, peers: merkle.NewIAVLTree(nil), - filters: merkle.NewIAVLTree(nil), + self: self, + inQueues: inQueues, } return c } -func (c *Client) Start() (<-chan *IncomingMsg) { - return nil -} - func (c *Client) Stop() { - c.listener.Close() -} - -func (c *Client) LocalAddress() *NetAddress { - return c.listener.LocalAddress() -} - -func (c *Client) ConnectTo(addr *NetAddress) (*Peer, error) { - - conn, err := addr.Dial() - if err != nil { return nil, err } - peer := NewPeer(conn) - // lock - c.peersMtx.Lock() - c.peers.Put(addr, peer) - c.peersMtx.Unlock() + c.mtx.Lock() + if atomic.CompareAndSwapUint32(&c.stopped, 0, 1) { + close(c.quit) + // stop each peer. + for peerValue := range c.peers.Values() { + peer := peerValue.(*Peer) + peer.Stop() + } + // empty tree. + c.peers = merkle.NewIAVLTree(nil) + } + c.mtx.Unlock() // unlock +} + +func (c *Client) AddPeerWithConnection(conn *Connection, outgoing bool) (*Peer, error) { + if atomic.LoadUint32(&c.stopped) == 1 { return nil, CLIENT_STOPPED_ERROR } + + peer := c.newPeerCb(conn) + peer.outgoing = outgoing + err := c.addPeer(peer) + if err != nil { return nil, err } + + go peer.Start(c.inQueues) return peer, nil } -func (c *Client) Broadcast(channel String, msg Binary) { +func (c *Client) Broadcast(chName String, msg Msg) { + if atomic.LoadUint32(&c.stopped) == 1 { return } + for v := range c.peersCopy().Values() { - peer, ok := v.(*Peer) - if !ok { panic("Expected peer but got something else") } - peer.Queue(channel, msg) + peer := v.(*Peer) + success := peer.TryQueueOut(chName , msg) + if !success { + // TODO: notify the peer + } } } -// Updates the client's filter for a channel & broadcasts it. -func (c *Client) UpdateFilter(channel String, filter Filter) { - c.filtersMtx.Lock() - c.filters.Put(channel, filter) - c.filtersMtx.Unlock() +func (c *Client) PopMessage(chName String) *InboundMsg { + if atomic.LoadUint32(&c.stopped) == 1 { return nil } + + channel := c.Channel(chName) + q := c.inQueues[chName] + if q == nil { Panicf("Expected inQueues[%f], found none", chName) } + + for { + select { + case <-quit: + return nil + case msg := <-q: + // skip if known. + if channel.Filter().Has(msg) { + continue + } + return msg + } + } +} + +// Updates self's filter for a channel & broadcasts it. +// TODO: maybe don't expose this +func (c *Client) UpdateFilter(chName String, filter Filter) { + if atomic.LoadUint32(&c.stopped) == 1 { return } + + c.self.Channel(chName).UpdateFilter(filter) c.Broadcast("", &NewFilterMsg{ - Channel: channel, + Channel: chName, Filter: filter, }) } -func (c *Client) peersCopy() merkle.Tree { - c.peersMtx.Lock(); defer c.peersMtx.Unlock() - return c.peers.Copy() -} +func (c *Client) StopPeer(peer *Peer) { + // lock + c.mtx.Lock() + p, _ := c.peers.Remove(peer.RemoteAddress()) + c.mtx.Unlock() + // unlock - -/* Channel */ -type Channel struct { - Name String - Filter Filter - //Stats Stats -} - - -/* Peer */ -type Peer struct { - Conn *Connection - Channels map[String]*Channel -} - -func NewPeer(conn *Connection) *Peer { - return &Peer{ - Conn: conn, - Channels: nil, + if p != nil { + p.Stop() } } -// Must be quick and nonblocking. -func (p *Peer) Queue(channel String, msg Binary) {} +func (c *Client) addPeer(peer *Peer) error { + addr := peer.RemoteAddress() -func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { - return 0, nil // TODO + // lock & defer + c.mtx.Lock(); defer c.mtx.Unlock() + if c.stopped == 1 { return CLIENT_STOPPED_ERROR } + if !c.peers.Has(addr) { + c.peers.Put(addr, peer) + return nil + } else { + // ignore duplicate peer for addr. + log.Infof("Ignoring duplicate peer for addr %v", addr) + return CLIENT_DUPLICATE_PEER_ERROR + } + // unlock deferred } - -/* IncomingMsg */ -type IncomingMsg struct { - SPeer *Peer - SChan *Channel - - Time Time - - Msg Binary -} - - -/* Filter - - A Filter could be a bloom filter for lossy filtering, or could be a lossless filter. - Either way, it's used to keep track of what a peer knows of. -*/ -type Filter interface { - Binary - Add(ByteSlice) - Has(ByteSlice) bool -} - -/* FilterStrategy - - Defines how filters are generated per peer, and whether they need to get refreshed occasionally. -*/ -type FilterStrategy interface { - LoadFilter(ByteSlice) Filter -} - -/* NewFilterMsg */ -type NewFilterMsg struct { - Channel String - Filter Filter -} - -func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) { - return 0, nil // TODO +func (c *Client) peersCopy() merkle.Tree { + // lock & defer + c.mtx.Lock(); defer c.mtx.Unlock() + return c.peers.Copy() + // unlock deferred } diff --git a/peer/connection.go b/peer/connection.go index 6d5b6e53..27e2f57c 100644 --- a/peer/connection.go +++ b/peer/connection.go @@ -3,6 +3,7 @@ package peer import ( . "github.com/tendermint/tendermint/common" . "github.com/tendermint/tendermint/binary" + "atomic" "sync" "net" "runtime" @@ -20,12 +21,10 @@ const ( type Connection struct { ioStats IOStats - mtx sync.Mutex - outQueue chan ByteSlice + outQueue chan ByteSlice // never closes. conn net.Conn quit chan struct{} - disconnected bool - + stopped int32 pingDebouncer *Debouncer pong chan struct{} } @@ -46,13 +45,14 @@ func NewConnection(conn net.Conn) *Connection { } } -func (c *Connection) QueueMessage(msg ByteSlice) bool { - c.mtx.Lock(); defer c.mtx.Unlock() - if c.disconnected { return false } +// returns true if successfully queued, +// returns false if connection was closed. +// blocks. +func (c *Connection) QueueOut(msg ByteSlice) bool { select { case c.outQueue <- msg: return true - default: // buffer full + case <-c.quit: return false } } @@ -62,13 +62,25 @@ func (c *Connection) Start() { go c.inHandler() } -func (c *Connection) Disconnect() { - c.mtx.Lock(); defer c.mtx.Unlock() - close(c.quit) - c.conn.Close() - c.pingDebouncer.Stop() - // do not close c.pong - c.disconnected = true +func (c *Connection) Stop() { + if atomic.SwapAndCompare(&c.stopped, 0, 1) { + close(c.quit) + c.conn.Close() + c.pingDebouncer.Stop() + // We can't close pong safely here because + // inHandler may write to it after we've stopped. + // Though it doesn't need to get closed at all, + // we close it @ inHandler. + // close(c.pong) + } +} + +func (c *Connection) LocalAddress() *NetAddress { + return NewNetAddress(c.conn.LocalAddr()) +} + +func (c *Connection) RemoteAddress() *NetAddress { + return NewNetAddress(c.conn.RemoteAddr()) } func (c *Connection) flush() { @@ -79,41 +91,42 @@ func (c *Connection) outHandler() { FOR_LOOP: for { + var err error select { case <-c.pingDebouncer.Ch: - PACKET_TYPE_PING.WriteTo(c.conn) + _, err = PACKET_TYPE_PING.WriteTo(c.conn) case outMsg := <-c.outQueue: - _, err := outMsg.WriteTo(c.conn) - if err != nil { Panicf("TODO: handle error %v", err) } + _, err = outMsg.WriteTo(c.conn) case <-c.pong: - PACKET_TYPE_PONG.WriteTo(c.conn) + _, err = PACKET_TYPE_PONG.WriteTo(c.conn) case <-c.quit: break FOR_LOOP } + + if err != nil { + log.Infof("Connection %v failed @ outHandler:\n%v", c, err) + c.Stop() + break FOR_LOOP + } + c.flush() } - // cleanup - for _ = range c.outQueue { - // do nothing but drain. - } } func (c *Connection) inHandler() { - defer func() { - if e := recover(); e != nil { - // Get stack trace - buf := make([]byte, 1<<16) - runtime.Stack(buf, false) - // TODO do proper logging - fmt.Printf("Disconnecting due to error:\n\n%v\n", string(buf)) - c.Disconnect() - } - }() - //FOR_LOOP: + FOR_LOOP: for { - msgType := ReadUInt8(c.conn) + msgType, err := ReadUInt8Safe(c.conn) + + if err != nil { + if atomic.LoadUint32(&c.stopped) != 1 { + log.Infof("Connection %v failed @ inHandler", c) + c.Stop() + } + break FOR_LOOP + } switch msgType { case PACKET_TYPE_PING: @@ -121,12 +134,29 @@ func (c *Connection) inHandler() { case PACKET_TYPE_PONG: // do nothing case PACKET_TYPE_MSG: - ReadByteSlice(c.conn) + msg, err := ReadByteSliceSafe(c.conn) + if err != nil { + if atomic.LoadUint32(&c.stopped) != 1 { + log.Infof("Connection %v failed @ inHandler", c) + c.Stop() + } + break FOR_LOOP + } + // What to do? + // TODO + default: Panicf("Unknown message type %v", msgType) } + c.pingDebouncer.Reset() } + + // cleanup + close(c.pong) + for _ = range c.pong { + // drain + } } diff --git a/peer/connection_test.go b/peer/connection_test.go index 99f6c00e..e72e15d8 100644 --- a/peer/connection_test.go +++ b/peer/connection_test.go @@ -2,15 +2,50 @@ package peer import ( "testing" + "time" ) func TestLocalConnection(t *testing.T) { - c1 := NewClient("tcp", ":8080") - c2 := NewClient("tcp", ":8081") + c1 := NewClient(func(conn *Connection) *Peer { + p := &Peer{conn: conn} + ch1 := NewChannel(String("ch1"), + nil, + // XXX these channels should be buffered. + make(chan ByteSlice), + make(chan ByteSlice), + ) + + ch2 := NewChannel(String("ch2"), + nil, + make(chan ByteSlice), + make(chan ByteSlice), + ) + + channels := make(map[String]*Channel) + channels[ch1.Name] = ch1 + channels[ch2.Name] = ch2 + p.channels = channels + + return p + }) + + // XXX make c2 like c1. + + c2 := NewClient(func(conn *Connection) *Peer { + return nil + }) + + // XXX clients don't have "local addresses" c1.ConnectTo(c2.LocalAddress()) + // lets send a message from c1 to c2. + c1.Broadcast(String(""), String("message")) + time.Sleep(500 * time.Millisecond) + + inMsg := c2.PopMessage() + c1.Stop() c2.Stop() } diff --git a/peer/filter.go b/peer/filter.go new file mode 100644 index 00000000..c0177e8a --- /dev/null +++ b/peer/filter.go @@ -0,0 +1,16 @@ +package peer + +/* Filter + + A Filter could be a bloom filter for lossy filtering, or could be a lossless filter. + Either way, it's used to keep track of what a peer knows of. +*/ +type Filter interface { + Binary + Add(Msg) + Has(Msg) bool + + // Loads a new filter. + // Convenience factory method + Load(ByteSlice) Filter +} diff --git a/peer/listener.go b/peer/listener.go index 72ec9e40..c19a5a2e 100644 --- a/peer/listener.go +++ b/peer/listener.go @@ -1,60 +1,76 @@ package peer import ( - "sync" + "atomic" "net" ) /* Listener */ -type Listener struct { - listener net.Listener - handler func(net.Conn) - mtx sync.Mutex - closed bool +type Listener interface { + Connections() <-chan *Connection + LocalAddress() *NetAddress + Stop() } -func NewListener(protocol string, laddr string, handler func(net.Conn)) *Listener { + +/* DefaultListener */ + +type DefaultListener struct { + listener net.Listener + connections chan *Connection + stopped uint32 +} + +const ( + DEFAULT_BUFFERED_CONNECTIONS = 10 +) + +func NewListener(protocol string, laddr string) *Listener { ln, err := net.Listen(protocol, laddr) if err != nil { panic(err) } s := &Listener{ - listener: ln, - handler: handler, + listener: ln, + connections: make(chan *Connection, DEFAULT_BUFFERED_CONNECTIONS), } - go s.listen() + go l.listenHandler() return s } -func (s *Listener) listen() { +func (l *Listener) listenHandler() { for { - conn, err := s.listener.Accept() - if err != nil { - // lock & defer - s.mtx.Lock(); defer s.mtx.Unlock() - if s.closed { - return - } else { - panic(err) - } - // unlock (deferred) - } + conn, err := l.listener.Accept() - go s.handler(conn) + if atomic.LoadUint32(&l.stopped) == 1 { return } + + // listener wasn't stopped, + // yet we encountered an error. + if err != nil { panic(err) } + + c := NewConnection(con) + l.connections <- c + } + + // cleanup + close(l.connections) + for _ = range l.connections { + // drain } } -func (s *Listener) LocalAddress() *NetAddress { - return NewNetAddress(s.listener.Addr()) +func (l *Listener) Connections() <-chan *Connection { + return l.connections } -func (s *Listener) Close() { - // lock - s.mtx.Lock() - s.closed = true - s.mtx.Unlock() - // unlock - s.listener.Close() +func (l *Listener) LocalAddress() *NetAddress { + return NewNetAddress(l.listener.Addr()) +} + +func (l *Listener) Stop() { + if atomic.CompareAndSwapUint32(&l.stopped, 0, 1) { + l.listener.Close() + } } diff --git a/peer/msg.go b/peer/msg.go new file mode 100644 index 00000000..6581c56f --- /dev/null +++ b/peer/msg.go @@ -0,0 +1,30 @@ +package peer + +/* Msg */ + +type Msg struct { + Bytes ByteSlice + Hash ByteSlice +} + + +/* InboundMsg */ + +type InboundMsg struct { + Peer *Peer + Channel *Channel + Time Time + Msg +} + + +/* NewFilterMsg */ + +type NewFilterMsg struct { + ChName String + Filter Filter +} + +func (m *NewFilterMsg) WriteTo(w io.Writer) (int64, error) { + return 0, nil // TODO +} diff --git a/peer/peer.go b/peer/peer.go index 38dab2ff..e7f84d4a 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -1,715 +1,182 @@ package peer import ( - "bytes" - "container/list" - "fmt" - "github.com/davecgh/go-spew/spew" - "github.com/tendermint/btcwire" - "net" - "strconv" + "atomic" "sync" - "sync/atomic" - "time" ) -const ( - // max protocol version the peer supports. - maxProtocolVersion = 70001 +/* Peer */ - // number of elements the output channels use. - outputBufferSize = 50 +type Peer struct { + outgoing bool + conn *Connection + channels map[String]*Channel - // number of seconds of inactivity before we timeout a peer - // that hasn't completed the initial version negotiation. - negotiateTimeoutSeconds = 30 - - // number of minutes of inactivity before we time out a peer. - idleTimeoutMinutes = 5 - - // number of minutes since we last sent a message - // requiring a reply before we will ping a host. - pingTimeoutMinutes = 2 -) - -var ( - userAgentName = "tendermintd" - userAgentVersion = fmt.Sprintf("%d.%d.%d", appMajor, appMinor, appPatch) -) - -// zeroHash is the zero value hash (all zeros). It is defined as a convenience. -var zeroHash btcwire.ShaHash - -// minUint32 is a helper function to return the minimum of two uint32s. -// This avoids a math import and the need to cast to floats. -func minUint32(a, b uint32) uint32 { - if a < b { - return a - } - return b + mtx sync.Mutex + quit chan struct{} + stopped uint32 } -// TODO(davec): Rename and comment this -type outMsg struct { - msg btcwire.Message - doneChan chan bool -} - -/* -The overall data flow is split into 2 goroutines. - -Inbound messages are read via the inHandler goroutine and generally -dispatched to their own handler. - -Outbound messages are queued via QueueMessage. -*/ -type peer struct { - server *server - addr *NetAddress - inbound bool - persistent bool - - started bool // atomic - quit chan bool - - conn net.Conn - connMtx sync.Mutex - disconnected bool // atomic && protected by connMtx - knownAddresses map[string]bool - outputQueue chan outMsg - - statMtx sync.Mutex // protects all below here. - protocolVersion uint32 - timeConnected time.Time - lastSend time.Time - lastRecv time.Time - bytesReceived uint64 - bytesSent uint64 - userAgent string - lastPingNonce uint64 // Set to nonce if we have a pending ping. - lastPingTime time.Time // Time we sent last ping. - lastPingMicros int64 // Time for last ping to return. -} - -// String returns the peer's address and directionality as a human-readable -// string. -func (p *peer) String() string { - return fmt.Sprintf("%s (%s)", p.addr.String(), directionString(p.inbound)) -} - -// VersionKnown returns the whether or not the version of a peer is known locally. -// It is safe for concurrent access. -func (p *peer) VersionKnown() bool { - p.statMtx.Lock(); defer p.statMtx.Unlock() - - return p.protocolVersion != 0 -} - -// ProtocolVersion returns the peer protocol version in a manner that is safe -// for concurrent access. -func (p *peer) ProtocolVersion() uint32 { - p.statMtx.Lock(); defer p.statMtx.Unlock() - - return p.protocolVersion -} - -// pushVersionMsg sends a version message to the connected peer using the -// current state. -func (p *peer) pushVersionMsg() { - _, blockNum, err := p.server.db.NewestSha() - if err != nil { panic(err) } - - // Version message. - // TODO: DisableListen -> send zero address - msg := btcwire.NewMsgVersion( - p.server.addrManager.getBestLocalAddress(p.addr), p.addr, - p.server.nonce, int32(blockNum)) - msg.AddUserAgent(userAgentName, userAgentVersion) - - // Advertise our max supported protocol version. - msg.ProtocolVersion = maxProtocolVersion - - p.QueueMessage(msg, nil) -} - -// handleVersionMsg is invoked when a peer receives a version bitcoin message -// and is used to negotiate the protocol version details as well as kick start -// the communications. -func (p *peer) handleVersionMsg(msg *btcwire.MsgVersion) { - // Detect self connections. - if msg.Nonce == p.server.nonce { - peerLog.Debugf("Disconnecting peer connected to self %s", p) - p.Disconnect() - return - } - - p.statMtx.Lock() // Updating a bunch of stats. - // Limit to one version message per peer. - if p.protocolVersion != 0 { - p.logError("Only one version message per peer is allowed %s.", p) - p.statMtx.Unlock() - p.Disconnect() - return - } - - // Negotiate the protocol version. - p.protocolVersion = minUint32(p.protocolVersion, uint32(msg.ProtocolVersion)) - peerLog.Debugf("Negotiated protocol version %d for peer %s", p.protocolVersion, p) - - // Set the remote peer's user agent. - p.userAgent = msg.UserAgent - - p.statMtx.Unlock() - - // Inbound connections. - if p.inbound { - // Send version. - p.pushVersionMsg() - } - - // Send verack. - p.QueueMessage(btcwire.NewMsgVerAck(), nil) - - if p.inbound { - // A peer might not be advertising the same address that it - // actually connected from. One example of why this can happen - // is with NAT. Only add the address to the address manager if - // the addresses agree. - if msg.AddrMe.String() == p.addr.String() { - p.server.addrManager.AddAddress(p.addr, p.addr) - } - } else { - // Request known addresses from the remote peer. - if !cfg.SimNet && p.server.addrManager.NeedMoreAddresses() { - p.QueueMessage(btcwire.NewMsgGetAddr(), nil) - } - } - - // Mark the address as a known good address. - p.server.addrManager.MarkGood(p.addr) - - // Signal the block manager this peer is a new sync candidate. - p.server.blockManager.NewPeer(p) - - // TODO: Relay alerts. -} - - -// handleGetAddrMsg is invoked when a peer receives a getaddr bitcoin message -// and is used to provide the peer with known addresses from the address -// manager. -func (p *peer) handleGetAddrMsg(msg *btcwire.MsgGetAddr) { - // Don't return any addresses when running on the simulation test - // network. This helps prevent the network from becoming another - // public test network since it will not be able to learn about other - // peers that have not specifically been provided. - if cfg.SimNet { - return - } - - // Get the current known addresses from the address manager. - addrCache := p.server.addrManager.AddressCache() - - // Push the addresses. - p.pushAddrMsg(addrCache) -} - -// pushAddrMsg sends one, or more, addr message(s) to the connected peer using -// the provided addresses. -func (p *peer) pushAddrMsg(addresses []*NetAddress) { - // Nothing to send. - if len(addresses) == 0 { return } - - numAdded := 0 - msg := btcwire.NewMsgAddr() - for _, addr := range addresses { - // Filter addresses the peer already knows about. - if p.knownAddresses[addr.String()] { - continue - } - - // Add the address to the message. - err := msg.AddAddress(addr) - if err != nil { panic(err) } // XXX remove error condition - numAdded++ - - // Split into multiple messages as needed. - if numAdded > 0 && numAdded%btcwire.MaxAddrPerMsg == 0 { - p.QueueMessage(msg, nil) - - // NOTE: This needs to be a new address message and not - // simply call ClearAddresses since the message is a - // pointer and queueing it does not make a copy. - msg = btcwire.NewMsgAddr() - } - } - - // Send message with remaining addresses if needed. - if numAdded%btcwire.MaxAddrPerMsg != 0 { - p.QueueMessage(msg, nil) +func (p *Peer) Start(peerInQueues map[String]chan *InboundMsg ) { + for chName, _ := range p.channels { + go p.inHandler(chName, peerInQueues[chName]) + go p.outHandler(chName) } } -// handleAddrMsg is invoked when a peer receives an addr bitcoin message and -// is used to notify the server about advertised addresses. -func (p *peer) handleAddrMsg(msg *btcwire.MsgAddr) { - // Ignore addresses when running on the simulation test network. This - // helps prevent the network from becoming another public test network - // since it will not be able to learn about other peers that have not - // specifically been provided. - if cfg.SimNet { - return +func (p *Peer) Stop() { + // lock + p.mtx.Lock() + if atomic.CompareAndSwapUint32(&p.stopped, 0, 1) { + close(p.quit) + p.conn.Stop() } - - // A message that has no addresses is invalid. - if len(msg.AddrList) == 0 { - p.logError("Command [%s] from %s does not contain any addresses", msg.Command(), p) - p.Disconnect() - return - } - - for _, addr := range msg.AddrList { - // Set the timestamp to 5 days ago if it's more than 24 hours - // in the future so this address is one of the first to be - // removed when space is needed. - now := time.Now() - if addr.Timestamp.After(now.Add(time.Minute * 10)) { - addr.Timestamp = now.Add(-1 * time.Hour * 24 * 5) - } - - // Add address to known addresses for this peer. - p.knownAddresses[addr.String()] = true - } - - // Add addresses to server address manager. The address manager handles - // the details of things such as preventing duplicate addresses, max - // addresses, and last seen updates. - // XXX bitcoind gives a 2 hour time penalty here, do we want to do the - // same? - p.server.addrManager.AddAddresses(msg.AddrList, p.addr) + p.mtx.Unlock() + // unlock } -func (p *peer) handlePingMsg(msg *btcwire.MsgPing) { - // Include nonce from ping so pong can be identified. - p.QueueMessage(btcwire.NewMsgPong(msg.Nonce), nil) +func (p *Peer) LocalAddress() *NetAddress { + return p.conn.LocalAddress() } -func (p *peer) handlePongMsg(msg *btcwire.MsgPong) { - p.statMtx.Lock(); defer p.statMtx.Unlock() - - // Arguably we could use a buffered channel here sending data - // in a fifo manner whenever we send a ping, or a list keeping track of - // the times of each ping. For now we just make a best effort and - // only record stats if it was for the last ping sent. Any preceding - // and overlapping pings will be ignored. It is unlikely to occur - // without large usage of the ping rpc call since we ping - // infrequently enough that if they overlap we would have timed out - // the peer. - if p.lastPingNonce != 0 && msg.Nonce == p.lastPingNonce { - p.lastPingMicros = time.Now().Sub(p.lastPingTime).Nanoseconds() - p.lastPingMicros /= 1000 // convert to usec. - p.lastPingNonce = 0 - } +func (p *Peer) RemoteAddress() *NetAddress { + return p.conn.RemoteAddress() } -// readMessage reads the next bitcoin message from the peer with logging. -func (p *peer) readMessage() (btcwire.Message, []byte, error) { - n, msg, buf, err := btcwire.ReadMessageN(p.conn, p.ProtocolVersion()) - p.statMtx.Lock() - p.bytesReceived += uint64(n) - p.statMtx.Unlock() - p.server.AddBytesReceived(uint64(n)) - if err != nil { - return nil, nil, err - } - - // Use closures to log expensive operations so they are only run when - // the logging level requires it. - peerLog.Debugf("%v", newLogClosure(func() string { - // Debug summary of message. - summary := messageSummary(msg) - if len(summary) > 0 { - summary = " (" + summary + ")" - } - return fmt.Sprintf("Received %v%s from %s", msg.Command(), summary, p) - })) - peerLog.Tracef("%v", newLogClosure(func() string { - return spew.Sdump(msg) - })) - peerLog.Tracef("%v", newLogClosure(func() string { - return spew.Sdump(buf) - })) - - return msg, buf, nil +func (p *Peer) Channel(chName String) *Channel { + return p.channels[chName] } -// writeMessage sends a bitcoin Message to the peer with logging. -func (p *peer) writeMessage(msg btcwire.Message) { - if p.Disconnected() { return } +// If msg isn't already in the peer's filter, then +// queue the msg for output. +// If the queue is full, just return false. +func (p *Peer) TryQueueOut(chName String, msg Msg) bool { + channel := p.Channel(chName) + outQueue := channel.OutQueue() - if !p.VersionKnown() { - switch msg.(type) { - case *btcwire.MsgVersion: - // This is OK. - default: - // We drop all messages other than version if we - // haven't done the handshake already. - return - } + // just return if already in filter + if channel.Filter().Has(msg) { + return true } - // Use closures to log expensive operations so they are only run when - // the logging level requires it. - peerLog.Debugf("%v", newLogClosure(func() string { - // Debug summary of message. - summary := messageSummary(msg) - if len(summary) > 0 { - summary = " (" + summary + ")" - } - return fmt.Sprintf("Sending %v%s to %s", msg.Command(), summary, p) - })) - peerLog.Tracef("%v", newLogClosure(func() string { - return spew.Sdump(msg) - })) - peerLog.Tracef("%v", newLogClosure(func() string { - var buf bytes.Buffer - err := btcwire.WriteMessage(&buf, msg, p.ProtocolVersion()) - if err != nil { - return err.Error() - } - return spew.Sdump(buf.Bytes()) - })) - - // Write the message to the peer. - n, err := btcwire.WriteMessageN(p.conn, msg, p.ProtocolVersion()) - p.statMtx.Lock() - p.bytesSent += uint64(n) - p.statMtx.Unlock() - p.server.AddBytesSent(uint64(n)) - if err != nil { - p.Disconnect() - p.logError("Can't send message to %s: %v", p, err) - return + // lock & defer + p.mtx.Lock(); defer p.mtx.Unlock() + if p.stopped == 1 { return false } + select { + case outQueue <- msg: + return true + default: // buffer full + return false } + // unlock deferred } - -// inHandler handles all incoming messages for the peer. It must be run as a -// goroutine. -func (p *peer) inHandler() { - // Peers must complete the initial version negotiation within a shorter - // timeframe than a general idle timeout. The timer is then reset below - // to idleTimeoutMinutes for all future messages. - idleTimer := time.AfterFunc(negotiateTimeoutSeconds*time.Second, func() { - if p.VersionKnown() { - peerLog.Warnf("Peer %s no answer for %d minutes, disconnecting", p, idleTimeoutMinutes) - } - p.Disconnect() - }) -out: - for !p.Disconnected() { - rmsg, buf, err := p.readMessage() - // Stop the timer now, if we go around again we will reset it. - idleTimer.Stop() - if err != nil { - if !p.Disconnected() { - p.logError("Can't read message from %s: %v", p, err) - } - break out - } - p.statMtx.Lock() - p.lastRecv = time.Now() - p.statMtx.Unlock() - - // Ensure version message comes first. - if _, ok := rmsg.(*btcwire.MsgVersion); !ok && !p.VersionKnown() { - p.logError("A version message must precede all others") - break out - } - - // Handle each supported message type. - markGood := false - switch msg := rmsg.(type) { - case *btcwire.MsgVersion: - p.handleVersionMsg(msg) - - case *btcwire.MsgVerAck: - // Do nothing. - - case *btcwire.MsgGetAddr: - p.handleGetAddrMsg(msg) - - case *btcwire.MsgAddr: - p.handleAddrMsg(msg) - markGood = true - - case *btcwire.MsgPing: - p.handlePingMsg(msg) - markGood = true - - case *btcwire.MsgPong: - p.handlePongMsg(msg) - - case *btcwire.MsgAlert: - p.server.BroadcastMessage(msg, p) - - case *btcwire.MsgNotFound: - // TODO(davec): Ignore this for now, but ultimately - // it should probably be used to detect when something - // we requested needs to be re-requested from another - // peer. - - default: - peerLog.Debugf("Received unhandled message of type %v: Fix Me", rmsg.Command()) - } - - // Mark the address as currently connected and working as of - // now if one of the messages that trigger it was processed. - if markGood && !p.Disconnected() { - if p.addr == nil { - peerLog.Warnf("we're getting stuff before we got a version message. that's bad") - continue - } - p.server.addrManager.MarkGood(p.addr) - } - // ok we got a message, reset the timer. - // timer just calls p.Disconnect() after logging. - idleTimer.Reset(idleTimeoutMinutes * time.Minute) - } - - idleTimer.Stop() - - // Ensure connection is closed and notify the server that the peer is done. - p.Disconnect() - p.server.donePeers <- p - - // Only tell block manager we are gone if we ever told it we existed. - if p.VersionKnown() { - p.server.blockManager.DonePeer(p) - } - - peerLog.Tracef("Peer input handler done for %s", p) +func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { + return p.RemoteAddress().WriteTo(w) } -// outHandler handles all outgoing messages for the peer. It must be run as a -// goroutine. It uses a buffered channel to serialize output messages while -// allowing the sender to continue running asynchronously. -func (p *peer) outHandler() { - pingTimer := time.AfterFunc(pingTimeoutMinutes*time.Minute, func() { - nonce, err := btcwire.RandomUint64() - if err != nil { - peerLog.Errorf("Not sending ping on timeout to %s: %v", - p, err) - return - } - p.QueueMessage(btcwire.NewMsgPing(nonce), nil) - }) -out: +func (p *Peer) inHandler(chName String, inboundMsgQueue chan<- *InboundMsg) { + channel := p.channels[chName] + inQueue := channel.InQueue() + + FOR_LOOP: for { select { - case msg := <-p.outputQueue: - // If the message is one we should get a reply for - // then reset the timer, we only want to send pings - // when otherwise we would not receive a reply from - // the peer. - peerLog.Tracef("%s: received from outputQueue", p) - reset := true - switch m := msg.msg.(type) { - case *btcwire.MsgVersion: - // should get an ack - case *btcwire.MsgGetAddr: - // should get addresses - case *btcwire.MsgPing: - // expects pong - // Also set up statistics. - p.statMtx.Lock() - p.lastPingNonce = m.Nonce - p.lastPingTime = time.Now() - p.statMtx.Unlock() - default: - // Not one of the above, no sure reply. - // We want to ping if nothing else - // interesting happens. - reset = false + case <-quit: + break FOR_LOOP + case msg := <-inQueue: + // add to channel filter + channel.Filter().Add(msg) + // send to inboundMsgQueue + inboundMsg := &InboundMsg{ + Peer: p, + Channel: channel, + Time: Time(time.Now()), + Msg: msg, } - if reset { - pingTimer.Reset(pingTimeoutMinutes * time.Minute) - } - p.writeMessage(msg.msg) - p.statMtx.Lock() - p.lastSend = time.Now() - p.statMtx.Unlock() - if msg.doneChan != nil { - msg.doneChan <- true - } - - case <-p.quit: - break out - } - } - - pingTimer.Stop() - - // Drain outputQueue - for msg := range p.outputQueue { - if msg.doneChan != nil { - msg.doneChan <- false - } - } - peerLog.Tracef("Peer output handler done for %s", p) -} - -// QueueMessage adds the passed bitcoin message to the peer outputQueue. It -// uses a buffered channel to communicate with the output handler goroutine so -// it is automatically rate limited and safe for concurrent access. -func (p *peer) QueueMessage(msg btcwire.Message, doneChan chan bool) { - // Avoid risk of deadlock if goroutine already exited. The goroutine - // we will be sending to hangs around until it knows for a fact that - // it is marked as disconnected. *then* it drains the channels. - if p.Disconnected() { - // avoid deadlock... - if doneChan != nil { - go func() { - doneChan <- false - }() - } - return - } - p.outputQueue <- outMsg{msg: msg, doneChan: doneChan} -} - -// True if is (or will become) disconnected. -func (p *peer) Disconnected() bool { - return atomic.LoadInt32(&p.disconnected) == 1 -} - -// Disconnects the peer by closing the connection. It also sets -// a flag so the impending shutdown can be detected. -func (p *peer) Disconnect() { - p.connMtx.Lock(); defer p.connMtx.Unlock() - // did we win the race? - if atomic.AddInt32(&p.disconnected, 1) != 1 { - return - } - peerLog.Tracef("disconnecting %s", p) - close(p.quit) - if p.conn != nil { - p.conn.Close() - } -} - -// Sets the connection & starts -func (p *peer) StartWithConnection(conn *net.Conn) { - p.connMtx.Lock(); defer p.connMtx.Unlock() - if p.conn != nil { panic("Conn already set") } - if atomic.LoadInt32(&p.disconnected) == 1 { return } - peerLog.Debugf("Connected to %s", conn.RemoteAddr()) - p.timeConnected = time.Now() - p.conn = conn - p.Start() -} - -// Start begins processing input and output messages. It also sends the initial -// version message for outbound connections to start the negotiation process. -func (p *peer) Start() error { - // Already started? - if atomic.AddInt32(&p.started, 1) != 1 { - return nil - } - - peerLog.Tracef("Starting peer %s", p) - - // Send an initial version message if this is an outbound connection. - if !p.inbound { - p.pushVersionMsg() - } - - // Start processing input and output. - go p.inHandler() - go p.outHandler() - - return nil -} - -// Shutdown gracefully shuts down the peer by disconnecting it. -func (p *peer) Shutdown() { - peerLog.Tracef("Shutdown peer %s", p) - p.Disconnect() -} - -// newPeerBase returns a new base peer for the provided server and inbound flag. -// This is used by the newInboundPeer and newOutboundPeer functions to perform -// base setup needed by both types of peers. -func newPeerBase(s *server, inbound bool) *peer { - p := peer{ - server: s, - protocolVersion: maxProtocolVersion, - inbound: inbound, - knownAddresses: make(map[string]bool), - outputQueue: make(chan outMsg, outputBufferSize), - quit: make(chan bool), - } - return &p -} - -// newPeer returns a new inbound bitcoin peer for the provided server and -// connection. Use Start to begin processing incoming and outgoing messages. -func newInboundPeer(s *server, conn net.Conn) *peer { - addr := NewNetAddress(conn.RemoteAddr()) - // XXX What if p.addr doesn't match (to be) reported addr due to NAT? - s.addrManager.MarkAttempt(addr) - - p := newPeerBase(s, true) - p.conn = conn - p.addr = addr - p.timeConnected = time.Now() - return p -} - -// newOutbountPeer returns a new outbound bitcoin peer for the provided server and -// address and connects to it asynchronously. If the connection is successful -// then the peer will also be started. -func newOutboundPeer(s *server, addr *NetAddress, persistent bool) *peer { - p := newPeerBase(s, false) - p.addr = addr - p.persistent = persistent - - go func() { - // Mark this as one attempt, regardless of # of reconnects. - s.addrManager.MarkAttempt(p.addr) - retryCount := 0 - // Attempt to connect to the peer. If the connection fails and - // this is a persistent connection, retry after the retry - // interval. - for { - peerLog.Debugf("Attempting to connect to %s", addr) - conn, err := addr.Dial() - if err == nil { - p.StartWithConnection(conn) - return - } else { - retryCount++ - peerLog.Debugf("Failed to connect to %s: %v", addr, err) - if !persistent { - p.server.donePeers <- p - return - } - scaledInterval := connectionRetryInterval.Nanoseconds() * retryCount / 2 - scaledDuration := time.Duration(scaledInterval) - peerLog.Debugf("Retrying connection to %s in %s", addr, scaledDuration) - time.Sleep(scaledDuration) + select { + case <-quit: + break FOR_LOOP + case inboundMsgQueue <- inboundMsg: continue } } - }() - return p + } + + // cleanup + // (none) } -// logError makes sure that we only log errors loudly on user peers. -func (p *peer) logError(fmt string, args ...interface{}) { - if p.persistent { - peerLog.Errorf(fmt, args...) - } else { - peerLog.Debugf(fmt, args...) +func (p *Peer) outHandler(chName String) { + outQueue := p.channels[chName].OutQueue() + FOR_LOOP: + for { + select { + case <-quit: + break FOR_LOOP + case msg := <-outQueue: + // blocks until the connection is Stop'd, + // which happens when this peer is Stop'd. + p.conn.QueueOut(msg.Bytes) + } + } + + // cleanup + // (none) +} + + +/* Channel */ + +type Channel struct { + name String + + mtx sync.Mutex + filter Filter + + inQueue chan Msg + outQueue chan Msg + //stats Stats +} + +func NewChannel(name String, filter Filter, in, out chan Msg) *Channel { + return &Channel{ + name: name, + filter: filter, + inQueue: in, + outQueue: out, } } + +func (c *Channel) InQueue() <-chan Msg { + return c.inQueue +} + +func (c *Channel) OutQueue() chan<- Msg { + return c.outQueue +} + +func (c *Channel) Add(msg Msg) { + c.Filter().Add(msg) +} + +func (c *Channel) Has(msg Msg) bool { + return c.Filter().Has(msg) +} + +// TODO: maybe don't expose this +func (c *Channel) Filter() Filter { + // lock & defer + c.mtx.Lock(); defer c.mtx.Unlock() + return c.filter + // unlock deferred +} + +// TODO: maybe don't expose this +func (c *Channel) UpdateFilter(filter Filter) { + // lock + c.mtx.Lock() + c.filter = filter + c.mtx.Unlock() + // unlock +} diff --git a/peer/server.go b/peer/server.go new file mode 100644 index 00000000..787e5f7a --- /dev/null +++ b/peer/server.go @@ -0,0 +1,32 @@ +package peer + +import ( +) + +/* Server */ + +type Server struct { + listener Listener + client *Client +} + +func NewServer(l Listener, c *Client) *Server { + s := &Server{ + listener: l, + client: c, + } + go s.IncomingConnectionsHandler() + return s +} + +// meant to run in a goroutine +func (s *Server) IncomingConnectionHandler() { + for conn := range s.listener.Connections() { + s.client.AddIncomingConnection(conn) + } +} + +func (s *Server) Stop() { + s.listener.Stop() + s.client.Stop() +} diff --git a/peer/upnp.go b/peer/upnp.go new file mode 100644 index 00000000..711440c9 --- /dev/null +++ b/peer/upnp.go @@ -0,0 +1,368 @@ +// from taipei-torrent + +package peer + +// Just enough UPnP to be able to forward ports +// + +import ( + "bytes" + "encoding/xml" + "errors" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + "time" +) + +type upnpNAT struct { + serviceURL string + ourIP string + urnDomain string +} + +func Discover() (nat NAT, err error) { + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") + if err != nil { + return + } + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + return + } + socket := conn.(*net.UDPConn) + defer socket.Close() + + err = socket.SetDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + + st := "InternetGatewayDevice:1" + + buf := bytes.NewBufferString( + "M-SEARCH * HTTP/1.1\r\n" + + "HOST: 239.255.255.250:1900\r\n" + + "ST: ssdp:all\r\n" + + "MAN: \"ssdp:discover\"\r\n" + + "MX: 2\r\n\r\n") + message := buf.Bytes() + answerBytes := make([]byte, 1024) + for i := 0; i < 3; i++ { + _, err = socket.WriteToUDP(message, ssdp) + if err != nil { + return + } + var n int + n, _, err = socket.ReadFromUDP(answerBytes) + for { + n, _, err = socket.ReadFromUDP(answerBytes) + if err != nil { + break + } + answer := string(answerBytes[0:n]) + if strings.Index(answer, st) < 0 { + continue + } + // HTTP header field names are case-insensitive. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + locString := "\r\nlocation:" + answer = strings.ToLower(answer) + locIndex := strings.Index(answer, locString) + if locIndex < 0 { + continue + } + loc := answer[locIndex+len(locString):] + endIndex := strings.Index(loc, "\r\n") + if endIndex < 0 { + continue + } + locURL := strings.TrimSpace(loc[0:endIndex]) + var serviceURL, urnDomain string + serviceURL, urnDomain, err = getServiceURL(locURL) + if err != nil { + return + } + var ourIP net.IP + ourIP, err = localIPv4() + if err != nil { + return + } + nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} + return + } + } + err = errors.New("UPnP port discovery failed.") + return +} + +type Envelope struct { + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` + Soap *SoapBody +} +type SoapBody struct { + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` + ExternalIP *ExternalIPAddressResponse +} + +type ExternalIPAddressResponse struct { + XMLName xml.Name `xml:"GetExternalIPAddressResponse"` + IPAddress string `xml:"NewExternalIPAddress"` +} + +type ExternalIPAddress struct { + XMLName xml.Name `xml:"NewExternalIPAddress"` + IP string +} + +type Service struct { + ServiceType string `xml:"serviceType"` + ControlURL string `xml:"controlURL"` +} + +type DeviceList struct { + Device []Device `xml:"device"` +} + +type ServiceList struct { + Service []Service `xml:"service"` +} + +type Device struct { + XMLName xml.Name `xml:"device"` + DeviceType string `xml:"deviceType"` + DeviceList DeviceList `xml:"deviceList"` + ServiceList ServiceList `xml:"serviceList"` +} + +type Root struct { + Device Device +} + +func getChildDevice(d *Device, deviceType string) *Device { + dl := d.DeviceList.Device + for i := 0; i < len(dl); i++ { + if strings.Index(dl[i].DeviceType, deviceType) >= 0 { + return &dl[i] + } + } + return nil +} + +func getChildService(d *Device, serviceType string) *Service { + sl := d.ServiceList.Service + for i := 0; i < len(sl); i++ { + if strings.Index(sl[i].ServiceType, serviceType) >= 0 { + return &sl[i] + } + } + return nil +} + +func localIPv4() (net.IP, error) { + tt, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, t := range tt { + aa, err := t.Addrs() + if err != nil { + return nil, err + } + for _, a := range aa { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + v4 := ipnet.IP.To4() + if v4 == nil || v4[0] == 127 { // loopback address + continue + } + return v4, nil + } + } + return nil, errors.New("cannot find local IP address") +} + +func getServiceURL(rootURL string) (url, urnDomain string, err error) { + r, err := http.Get(rootURL) + if err != nil { + return + } + defer r.Body.Close() + if r.StatusCode >= 400 { + err = errors.New(string(r.StatusCode)) + return + } + var root Root + err = xml.NewDecoder(r.Body).Decode(&root) + if err != nil { + return + } + a := &root.Device + if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 { + err = errors.New("No InternetGatewayDevice") + return + } + b := getChildDevice(a, "WANDevice:1") + if b == nil { + err = errors.New("No WANDevice") + return + } + c := getChildDevice(b, "WANConnectionDevice:1") + if c == nil { + err = errors.New("No WANConnectionDevice") + return + } + d := getChildService(c, "WANIPConnection:1") + if d == nil { + // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, + // instead of under WanConnectionDevice + d = getChildService(b, "WANIPConnection:1") + + if d == nil { + err = errors.New("No WANIPConnection") + return + } + } + // Extract the domain name, which isn't always 'schemas-upnp-org' + urnDomain = strings.Split(d.ServiceType, ":")[1] + url = combineURL(rootURL, d.ControlURL) + return +} + +func combineURL(rootURL, subURL string) string { + protocolEnd := "://" + protoEndIndex := strings.Index(rootURL, protocolEnd) + a := rootURL[protoEndIndex+len(protocolEnd):] + rootIndex := strings.Index(a, "/") + return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL +} + +func soapRequest(url, function, message, domain string) (r *http.Response, err error) { + fullMessage := "" + + "\r\n" + + "" + message + "" + + req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") + req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") + //req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") + req.Header.Set("Connection", "Close") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Pragma", "no-cache") + + // log.Stderr("soapRequest ", req) + + r, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + /*if r.Body != nil { + defer r.Body.Close() + }*/ + + if r.StatusCode >= 400 { + // log.Stderr(function, r.StatusCode) + err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) + r = nil + return + } + return +} + +type statusInfo struct { + externalIpAddress string +} + +func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) { + + message := "\r\n" + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + var envelope Envelope + data, err := ioutil.ReadAll(response.Body) + reader := bytes.NewReader(data) + xml.NewDecoder(reader).Decode(&envelope) + + info = statusInfo{envelope.Soap.ExternalIP.IPAddress} + + if err != nil { + return + } + + return +} + +func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { + info, err := n.getExternalIPAddress() + if err != nil { + return + } + addr = net.ParseIP(info.externalIpAddress) + return +} + +func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { + // A single concatenation would break ARM compilation. + message := "\r\n" + + "" + strconv.Itoa(externalPort) + message += "" + protocol + "" + message += "" + strconv.Itoa(internalPort) + "" + + "" + n.ourIP + "" + + "1" + message += description + + "" + strconv.Itoa(timeout) + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + + // TODO: check response to see if the port was forwarded + // log.Println(message, response) + mappedExternalPort = externalPort + _ = response + return +} + +func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { + + message := "\r\n" + + "" + strconv.Itoa(externalPort) + + "" + protocol + "" + + "" + + var response *http.Response + response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) + if response != nil { + defer response.Body.Close() + } + if err != nil { + return + } + + // TODO: check response to see if the port was deleted + // log.Println(message, response) + _ = response + return +}