diff --git a/cmd/devp2p/crawl.go b/cmd/devp2p/crawl.go index 92aaad72a3..cc68dbcbfc 100644 --- a/cmd/devp2p/crawl.go +++ b/cmd/devp2p/crawl.go @@ -20,14 +20,13 @@ import ( "time" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" ) type crawler struct { input nodeSet output nodeSet - disc *discover.UDPv4 + disc resolver iters []enode.Iterator inputIter enode.Iterator ch chan *enode.Node @@ -37,7 +36,11 @@ type crawler struct { revalidateInterval time.Duration } -func newCrawler(input nodeSet, disc *discover.UDPv4, iters ...enode.Iterator) *crawler { +type resolver interface { + RequestENR(*enode.Node) (*enode.Node, error) +} + +func newCrawler(input nodeSet, disc resolver, iters ...enode.Iterator) *crawler { c := &crawler{ input: input, output: make(nodeSet, len(input)), diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index e9195b8b10..3dd62daffe 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -83,6 +83,18 @@ var ( Name: "bootnodes", Usage: "Comma separated nodes used for bootstrapping", } + nodekeyFlag = cli.StringFlag{ + Name: "nodekey", + Usage: "Hex-encoded node key", + } + nodedbFlag = cli.StringFlag{ + Name: "nodedb", + Usage: "Nodes database location", + } + listenAddrFlag = cli.StringFlag{ + Name: "addr", + Usage: "Listening address", + } crawlTimeoutFlag = cli.DurationFlag{ Name: "timeout", Usage: "Time limit for the crawl.", @@ -180,29 +192,31 @@ func discv4Crawl(ctx *cli.Context) error { return nil } -func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { - s := utils.GetBootstrapNodes(ctx) - if ctx.IsSet(bootnodesFlag.Name) { - s = strings.Split(ctx.String(bootnodesFlag.Name), ",") - } - nodes := make([]*enode.Node, len(s)) - var err error - for i, record := range s { - nodes[i], err = parseNode(record) - if err != nil { - return nil, fmt.Errorf("invalid bootstrap node: %v", err) - } - } - return nodes, nil -} - // startV4 starts an ephemeral discovery V4 node. func startV4(ctx *cli.Context) *discover.UDPv4 { networkId := ctx.GlobalUint64(networkIdFlag.Name) - socket, ln, cfg, err := listen(networkId) + ln, config := makeDiscoveryConfig(ctx, networkId) + socket := listen(ln, ctx.String(listenAddrFlag.Name)) + disc, err := discover.ListenV4(socket, ln, config) if err != nil { exit(err) } + return disc +} + +func makeDiscoveryConfig(ctx *cli.Context, networkId uint64) (*enode.LocalNode, discover.Config) { + var cfg discover.Config + + if ctx.IsSet(nodekeyFlag.Name) { + key, err := crypto.HexToECDSA(ctx.String(nodekeyFlag.Name)) + if err != nil { + exit(fmt.Errorf("-%s: %v", nodekeyFlag.Name, err)) + } + cfg.PrivateKey = key + } else { + cfg.PrivateKey, _ = crypto.GenerateKey() + } + if commandHasFlag(ctx, bootnodesFlag) { bn, err := parseBootnodes(ctx) if err != nil { @@ -210,26 +224,43 @@ func startV4(ctx *cli.Context) *discover.UDPv4 { } cfg.Bootnodes = bn } - disc, err := discover.ListenV4(socket, ln, cfg) + + dbpath := ctx.String(nodedbFlag.Name) + db, err := enode.OpenDB(dbpath) if err != nil { exit(err) } - return disc -} - -func listen(networkId uint64) (*net.UDPConn, *enode.LocalNode, discover.Config, error) { - var cfg discover.Config - cfg.PrivateKey, _ = crypto.GenerateKey() - db, _ := enode.OpenDB("") ln := enode.NewLocalNode(db, cfg.PrivateKey, networkId) + return ln, cfg +} - socket, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{0, 0, 0, 0}}) +func listen(ln *enode.LocalNode, addr string) *net.UDPConn { + if addr == "" { + addr = "0.0.0.0:0" + } + socket, err := net.ListenPacket("udp4", addr) if err != nil { - db.Close() - return nil, nil, cfg, err + exit(err) } - addr := socket.LocalAddr().(*net.UDPAddr) + usocket := socket.(*net.UDPConn) + uaddr := socket.LocalAddr().(*net.UDPAddr) ln.SetFallbackIP(net.IP{127, 0, 0, 1}) - ln.SetFallbackUDP(addr.Port) - return socket, ln, cfg, nil + ln.SetFallbackUDP(uaddr.Port) + return usocket +} + +func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { + s := utils.GetBootstrapNodes(ctx) + if ctx.IsSet(bootnodesFlag.Name) { + s = strings.Split(ctx.String(bootnodesFlag.Name), ",") + } + nodes := make([]*enode.Node, len(s)) + var err error + for i, record := range s { + nodes[i], err = parseNode(record) + if err != nil { + return nil, fmt.Errorf("invalid bootstrap node: %v", err) + } + } + return nodes, nil } diff --git a/cmd/devp2p/discv5cmd.go b/cmd/devp2p/discv5cmd.go new file mode 100644 index 0000000000..cb23de5b9f --- /dev/null +++ b/cmd/devp2p/discv5cmd.go @@ -0,0 +1,124 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/p2p/discover" + "gopkg.in/urfave/cli.v1" +) + +var ( + discv5Command = cli.Command{ + Name: "discv5", + Usage: "Node Discovery v5 tools", + Subcommands: []cli.Command{ + discv5PingCommand, + discv5ResolveCommand, + discv5CrawlCommand, + discv5ListenCommand, + }, + } + discv5PingCommand = cli.Command{ + Name: "ping", + Usage: "Sends ping to a node", + Action: discv5Ping, + } + discv5ResolveCommand = cli.Command{ + Name: "resolve", + Usage: "Finds a node in the DHT", + Action: discv5Resolve, + Flags: []cli.Flag{bootnodesFlag}, + } + discv5CrawlCommand = cli.Command{ + Name: "crawl", + Usage: "Updates a nodes.json file with random nodes found in the DHT", + Action: discv5Crawl, + Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, + } + discv5ListenCommand = cli.Command{ + Name: "listen", + Usage: "Runs a node", + Action: discv5Listen, + Flags: []cli.Flag{ + bootnodesFlag, + nodekeyFlag, + nodedbFlag, + listenAddrFlag, + }, + } +) + +func discv5Ping(ctx *cli.Context) error { + n := getNodeArg(ctx) + disc := startV5(ctx) + defer disc.Close() + + fmt.Println(disc.Ping(n)) + return nil +} + +func discv5Resolve(ctx *cli.Context) error { + n := getNodeArg(ctx) + disc := startV5(ctx) + defer disc.Close() + + fmt.Println(disc.Resolve(n)) + return nil +} + +func discv5Crawl(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + nodesFile := ctx.Args().First() + var inputSet nodeSet + if common.FileExist(nodesFile) { + inputSet = loadNodesJSON(nodesFile) + } + + disc := startV5(ctx) + defer disc.Close() + c := newCrawler(inputSet, disc, disc.RandomNodes()) + c.revalidateInterval = 10 * time.Minute + output := c.run(ctx.Duration(crawlTimeoutFlag.Name)) + writeNodesJSON(nodesFile, output) + return nil +} + +func discv5Listen(ctx *cli.Context) error { + disc := startV5(ctx) + defer disc.Close() + + fmt.Println(disc.Self()) + select {} +} + +// startV5 starts an ephemeral discovery v5 node. +func startV5(ctx *cli.Context) *discover.UDPv5 { + networkId := ctx.GlobalUint64(networkIdFlag.Name) + ln, config := makeDiscoveryConfig(ctx, networkId) + socket := listen(ln, ctx.String(listenAddrFlag.Name)) + disc, err := discover.ListenV5(socket, ln, config) + if err != nil { + exit(err) + } + return disc +} diff --git a/cmd/devp2p/main.go b/cmd/devp2p/main.go index 6faa650937..d5e777811a 100644 --- a/cmd/devp2p/main.go +++ b/cmd/devp2p/main.go @@ -59,6 +59,7 @@ func init() { app.Commands = []cli.Command{ enrdumpCommand, discv4Command, + discv5Command, dnsCommand, nodesetCommand, } diff --git a/p2p/discover/common.go b/p2p/discover/common.go index 74178caa52..6d40d35ebb 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -20,8 +20,10 @@ import ( "crypto/ecdsa" "net" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" ) @@ -44,6 +46,21 @@ type Config struct { Unhandled chan<- ReadPacket // unhandled packets are sent on this channel Log log.Logger // if set, log messages go here PingIPFromPacket bool + ValidSchemes enr.IdentityScheme // allowed identity schemes + Clock mclock.Clock +} + +func (cfg Config) withDefaults() Config { + if cfg.Log == nil { + cfg.Log = log.Root() + } + if cfg.ValidSchemes == nil { + cfg.ValidSchemes = enode.ValidSchemes + } + if cfg.Clock == nil { + cfg.Clock = mclock.System{} + } + return cfg } // ListenUDP starts listening for discovery packets on the given UDP socket. @@ -52,8 +69,15 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { } // ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled -// channel if configured. This is exported for internal use, do not use this type. +// channel if configured. type ReadPacket struct { Data []byte Addr *net.UDPAddr } + +func min(x, y int) int { + if x > y { + return y + } + return x +} diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index ab825fb05d..9ab4a71ce7 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -104,9 +104,7 @@ func (it *lookup) startQueries() bool { // The first query returns nodes from the local table. if it.queries == -1 { - it.tab.mutex.Lock() - closest := it.tab.closest(it.result.target, bucketSize, false) - it.tab.mutex.Unlock() + closest := it.tab.findnodeByID(it.result.target, bucketSize, false) // Avoid finishing the lookup too quickly if table is empty. It'd be better to wait // for the table to fill in this case, but there is no good mechanism for that // yet. @@ -150,11 +148,14 @@ func (it *lookup) query(n *node, reply chan<- []*node) { } else if len(r) == 0 { fails++ it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails) - it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) - if fails >= maxFindnodeFailures { - it.tab.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) + // Remove the node from the local table if it fails to return anything useful too + // many times, but only if there are enough other nodes in the bucket. + dropped := false + if fails >= maxFindnodeFailures && it.tab.bucketLen(n.ID()) >= bucketSize/2 { + dropped = true it.tab.delete(n) } + it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "failcount", fails, "dropped", dropped, "err", err) } else if fails > 0 { // Reset failure counter because it counts _consecutive_ failures. it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0) diff --git a/p2p/discover/node.go b/p2p/discover/node.go index a7d9ce7368..e635c64ac9 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -18,6 +18,7 @@ package discover import ( "crypto/ecdsa" + "crypto/elliptic" "errors" "math/big" "net" @@ -45,13 +46,13 @@ func encodePubkey(key *ecdsa.PublicKey) encPubkey { return e } -func decodePubkey(e encPubkey) (*ecdsa.PublicKey, error) { - p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} +func decodePubkey(curve elliptic.Curve, e encPubkey) (*ecdsa.PublicKey, error) { + p := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)} half := len(e) / 2 p.X.SetBytes(e[:half]) p.Y.SetBytes(e[half:]) if !p.Curve.IsOnCurve(p.X, p.Y) { - return nil, errors.New("invalid secp256k1 curve point") + return nil, errors.New("invalid curve point") } return p, nil } @@ -60,17 +61,6 @@ func (e encPubkey) id() enode.ID { return enode.ID(crypto.Keccak256Hash(e[:])) } -// recoverNodeKey computes the public key used to sign the -// given hash from the signature. -func recoverNodeKey(hash, sig []byte) (key encPubkey, err error) { - pubkey, err := crypto.Ecrecover(hash, sig) - if err != nil { - return key, err - } - copy(key[:], pubkey[1:]) - return key, nil -} - func wrapNode(n *enode.Node) *node { return &node{Node: *n} } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 706d4d2a89..d7837fd778 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -362,10 +362,12 @@ func (tab *Table) doRevalidate(done chan<- struct{}) { } // No reply received, pick a replacement or delete the node if there aren't // any replacements. - if r := tab.replace(b, last); r != nil { - tab.log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) - } else { + if r := tab.replace(b, last); r == nil { tab.log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks) + } else if r == last { + tab.log.Debug("Left dead node in bucket", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) + } else { + tab.log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) } } @@ -407,22 +409,35 @@ func (tab *Table) copyLiveNodes() { } } -// closest returns the n nodes in the table that are closest to the -// given id. The caller must hold tab.mutex. -func (tab *Table) closest(target enode.ID, nresults int, checklive bool) *nodesByDistance { - // This is a very wasteful way to find the closest nodes but - // obviously correct. I believe that tree-based buckets would make - // this easier to implement efficiently. - close := &nodesByDistance{target: target} +// findnodeByID returns the n nodes in the table that are closest to the given id. +// This is used by the FINDNODE/v4 handler. +// +// The preferLive parameter says whether the caller wants liveness-checked results. If +// preferLive is true and the table contains any verified nodes, the result will not +// contain unverified nodes. However, if there are no verified nodes at all, the result +// will contain unverified nodes. +func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + // Scan all buckets. There might be a better way to do this, but there aren't that many + // buckets, so this solution should be fine. The worst-case complexity of this loop + // is O(tab.len() * nresults). + nodes := &nodesByDistance{target: target} + liveNodes := &nodesByDistance{target: target} for _, b := range &tab.buckets { for _, n := range b.entries { - if checklive && n.livenessChecks == 0 { - continue + nodes.push(n, nresults) + if preferLive && n.livenessChecks > 0 { + liveNodes.push(n, nresults) } - close.push(n, nresults) } } - return close + + if preferLive && len(liveNodes.entries) > 0 { + return liveNodes + } + return nodes } // len returns the number of nodes in the table. @@ -436,9 +451,21 @@ func (tab *Table) len() (n int) { return n } +// bucketLen returns the number of nodes in the bucket for the given ID. +func (tab *Table) bucketLen(id enode.ID) int { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + return len(tab.bucket(id).entries) +} + // bucket returns the bucket for the given node ID hash. func (tab *Table) bucket(id enode.ID) *bucket { d := enode.LogDist(tab.self().ID(), id) + return tab.bucketAtDistance(d) +} + +func (tab *Table) bucketAtDistance(d int) *bucket { if d <= bucketMinDistance { return tab.buckets[0] } @@ -573,6 +600,7 @@ func (tab *Table) addReplacement(b *bucket, n *node) { // replace removes n from the replacement list and replaces 'last' with it if it is the // last entry in the bucket. If 'last' isn't the last entry, it has either been replaced // with someone else or became active. +// If last is the only node in the bucket and there are no replacements, leave it there. func (tab *Table) replace(b *bucket, last *node) *node { if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID() != last.ID() { // Entry has moved, don't replace it. @@ -580,6 +608,9 @@ func (tab *Table) replace(b *bucket, last *node) *node { } // Still the last entry. if len(b.replacements) == 0 { + if len(b.entries) == 1 { + return last + } tab.deleteInBucket(b, last) return nil } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 32e1bb37eb..270385cdfa 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -189,7 +189,7 @@ func checkIPLimitInvariant(t *testing.T, tab *Table) { } } -func TestTable_closest(t *testing.T) { +func TestTable_findnodeByID(t *testing.T) { t.Parallel() test := func(test *closeTest) bool { @@ -201,7 +201,7 @@ func TestTable_closest(t *testing.T) { fillTable(tab, test.All) // check that closest(Target, N) returns nodes - result := tab.closest(test.Target, test.N, false).entries + result := tab.findnodeByID(test.Target, test.N, false).entries if hasDuplicates(result) { t.Errorf("result contains duplicates") return false diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 7c89f7cf8d..5b0f5354df 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -24,7 +24,6 @@ import ( "fmt" "math/rand" "net" - "reflect" "sort" "sync" @@ -56,6 +55,23 @@ func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) } +// nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld. +func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node { + results := make([]*enode.Node, n) + for i := range results { + results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i))) + } + return results +} + +func nodesToRecords(nodes []*enode.Node) []*enr.Record { + records := make([]*enr.Record, len(nodes)) + for i := range nodes { + records[i] = nodes[i].Record() + } + return records +} + // idAtDistance returns a random hash such that enode.LogDist(a, b) == n func idAtDistance(a enode.ID, n int) (b enode.ID) { if n == 0 { @@ -172,9 +188,16 @@ func hasDuplicates(slice []*node) bool { } func checkNodesEqual(got, want []*enode.Node) error { - if reflect.DeepEqual(got, want) { - return nil + if len(got) == len(want) { + for i := range got { + if !nodeEqual(got[i], want[i]) { + goto NotEqual + } + return nil + } } + +NotEqual: output := new(bytes.Buffer) fmt.Fprintf(output, "got %d nodes:\n", len(got)) for _, n := range got { @@ -187,6 +210,10 @@ func checkNodesEqual(got, want []*enode.Node) error { return errors.New(output.String()) } +func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { + return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) +} + func sortByID(nodes []*enode.Node) { sort.Slice(nodes, func(i, j int) bool { return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes()) diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 9b4042c5a2..2009385262 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -24,7 +24,9 @@ import ( "testing" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" ) func TestUDPv4_Lookup(t *testing.T) { @@ -32,7 +34,7 @@ func TestUDPv4_Lookup(t *testing.T) { test := newUDPTest(t) // Lookup on empty table returns no nodes. - targetKey, _ := decodePubkey(lookupTestnet.target) + targetKey, _ := decodePubkey(crypto.S256(), lookupTestnet.target) if results := test.udp.LookupPubkey(targetKey); len(results) > 0 { t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) } @@ -59,15 +61,7 @@ func TestUDPv4_Lookup(t *testing.T) { if len(results) != bucketSize { t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) } - if hasDuplicates(wrapNodes(results)) { - t.Errorf("result set contains duplicate entries") - } - if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) { - t.Errorf("result set not sorted by distance to target") - } - if err := checkNodesEqual(results, lookupTestnet.closest(bucketSize)); err != nil { - t.Errorf("results aren't the closest %d nodes\n%v", bucketSize, err) - } + checkLookupResults(t, lookupTestnet, results) } func TestUDPv4_LookupIterator(t *testing.T) { @@ -142,20 +136,40 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) { func serveTestnet(test *udpTest, testnet *preminedTestnet) { for done := false; !done; { - done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) { + done = test.waitPacketOut(func(p v4wire.Packet, to *net.UDPAddr, hash []byte) { n, key := testnet.nodeByAddr(to) switch p.(type) { - case *pingV4: - test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash}) - case *findnodeV4: + case *v4wire.Ping: + test.packetInFrom(nil, key, to, &v4wire.Pong{Expiration: futureExp, ReplyTok: hash}) + case *v4wire.Findnode: dist := enode.LogDist(n.ID(), testnet.target.id()) nodes := testnet.nodesAtDistance(dist - 1) - test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes}) + test.packetInFrom(nil, key, to, &v4wire.Neighbors{Expiration: futureExp, Nodes: nodes}) } }) } } +// checkLookupResults verifies that the results of a lookup are the closest nodes to +// the testnet's target. +func checkLookupResults(t *testing.T, tn *preminedTestnet, results []*enode.Node) { + t.Helper() + t.Logf("results:") + for _, e := range results { + t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes()) + } + if hasDuplicates(wrapNodes(results)) { + t.Errorf("result set contains duplicate entries") + } + if !sortedByDistanceTo(tn.target.id(), wrapNodes(results)) { + t.Errorf("result set not sorted by distance to target") + } + wantNodes := tn.closest(len(results)) + if err := checkNodesEqual(results, wantNodes); err != nil { + t.Error(err) + } +} + // This is the test network for the Lookup test. // The nodes were obtained by running lookupTestnet.mine with a random NodeID as target. var lookupTestnet = &preminedTestnet{ @@ -242,8 +256,12 @@ func (tn *preminedTestnet) nodes() []*enode.Node { func (tn *preminedTestnet) node(dist, index int) *enode.Node { key := tn.dists[dist][index] - ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)} - return enode.NewV4(&key.PublicKey, ip, 0, 5000) + rec := new(enr.Record) + rec.Set(enr.IP{127, byte(dist >> 8), byte(dist), byte(index)}) + rec.Set(enr.UDP(5000)) + enode.SignV4(rec, key) + n, _ := enode.New(enode.ValidSchemes, rec) + return n } func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.PrivateKey) { @@ -253,14 +271,27 @@ func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.Pr return tn.node(dist, index), key } -func (tn *preminedTestnet) nodesAtDistance(dist int) []rpcNode { - result := make([]rpcNode, len(tn.dists[dist])) +func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node { + result := make([]v4wire.Node, len(tn.dists[dist])) for i := range result { result[i] = nodeToRPC(wrapNode(tn.node(dist, i))) } return result } +func (tn *preminedTestnet) neighborsAtDistance(base *enode.Node, distance uint, elems int) []*enode.Node { + nodes := nodesByDistance{target: base.ID()} + for d := range lookupTestnet.dists { + for i := range lookupTestnet.dists[d] { + n := lookupTestnet.node(d, i) + if uint(enode.LogDist(n.ID(), base.ID())) == distance { + nodes.push(wrapNode(n), elems) + } + } + } + return unwrapNodes(nodes.entries) +} + func (tn *preminedTestnet) closest(n int) (nodes []*enode.Node) { for d := range tn.dists { for i := range tn.dists[d] { diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index 1d423437ba..8eee9ec016 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -31,18 +31,14 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) -var celoClientSalt = []byte{0x63, 0x65, 0x6C, 0x6F} - // Errors var ( - errPacketTooSmall = errors.New("too small") - errBadHash = errors.New("bad hash") errExpired = errors.New("expired") errUnsolicitedReply = errors.New("unsolicited reply") errUnknownNode = errors.New("unknown node") @@ -50,6 +46,7 @@ var ( errClockWarp = errors.New("reply deadline too far in the future") errClosed = errors.New("socket closed") errBadNetworkId = errors.New("bad networkId") + errLowPort = errors.New("low port") ) const ( @@ -68,137 +65,6 @@ const ( maxPacketSize = 1280 ) -// RPC packet types -const ( - p_pingV4 = iota + 1 // zero is 'reserved' - p_pongV4 - p_findnodeV4 - p_neighborsV4 - p_enrRequestV4 - p_enrResponseV4 -) - -// RPC request structures -type ( - pingV4 struct { - senderKey *ecdsa.PublicKey // filled in by preverify - - Version uint - From, To rpcEndpoint - Expiration uint64 - NetworkId uint64 - - // Ignore additional fields (for forward compatibility). - Rest []rlp.RawValue `rlp:"tail"` - } - - // pongV4 is the reply to pingV4. - pongV4 struct { - // This field should mirror the UDP envelope address - // of the ping packet, which provides a way to discover the - // the external address (after NAT). - To rpcEndpoint - - ReplyTok []byte // This contains the hash of the ping packet. - Expiration uint64 // Absolute timestamp at which the packet becomes invalid. - // Ignore additional fields (for forward compatibility). - Rest []rlp.RawValue `rlp:"tail"` - } - - // findnodeV4 is a query for nodes close to the given target. - findnodeV4 struct { - Target encPubkey - Expiration uint64 - // Ignore additional fields (for forward compatibility). - Rest []rlp.RawValue `rlp:"tail"` - } - - // neighborsV4 is the reply to findnodeV4. - neighborsV4 struct { - Nodes []rpcNode - Expiration uint64 - // Ignore additional fields (for forward compatibility). - Rest []rlp.RawValue `rlp:"tail"` - } - - // enrRequestV4 queries for the remote node's record. - enrRequestV4 struct { - Expiration uint64 - // Ignore additional fields (for forward compatibility). - Rest []rlp.RawValue `rlp:"tail"` - } - - // enrResponseV4 is the reply to enrRequestV4. - enrResponseV4 struct { - ReplyTok []byte // Hash of the enrRequest packet. - Record enr.Record - // Ignore additional fields (for forward compatibility). - Rest []rlp.RawValue `rlp:"tail"` - } - - rpcNode struct { - IP net.IP // len 4 for IPv4 or 16 for IPv6 - UDP uint16 // for discovery protocol - TCP uint16 // for RLPx protocol - ID encPubkey - } - - rpcEndpoint struct { - IP net.IP // len 4 for IPv4 or 16 for IPv6 - UDP uint16 // for discovery protocol - TCP uint16 // for RLPx protocol - } -) - -// packetV4 is implemented by all v4 protocol messages. -type packetV4 interface { - // preverify checks whether the packet is valid and should be handled at all. - preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error - // handle handles the packet. - handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) - // packet name and type for logging purposes. - name() string - kind() byte -} - -func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { - ip := net.IP{} - if ip4 := addr.IP.To4(); ip4 != nil { - ip = ip4 - } else if ip6 := addr.IP.To16(); ip6 != nil { - ip = ip6 - } - return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} -} - -func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*node, error) { - if rn.UDP <= 1024 { - return nil, errors.New("low port") - } - if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { - return nil, err - } - if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { - return nil, errors.New("not contained in netrestrict whitelist") - } - key, err := decodePubkey(rn.ID) - if err != nil { - return nil, err - } - n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))) - err = n.ValidateComplete() - return n, err -} - -func nodeToRPC(n *node) rpcNode { - var key ecdsa.PublicKey - var ekey encPubkey - if err := n.Load((*enode.Secp256k1)(&key)); err == nil { - ekey = encodePubkey(&key) - } - return rpcNode{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())} -} - // UDPv4 implements the v4 wire protocol. type UDPv4 struct { conn UDPConn @@ -215,7 +81,7 @@ type UDPv4 struct { addReplyMatcher chan *replyMatcher gotreply chan reply closeCtx context.Context - cancelCloseCtx func() + cancelCloseCtx context.CancelFunc } // replyMatcher represents a pending reply. @@ -248,22 +114,23 @@ type replyMatcher struct { // reply contains the most recent reply. This field is safe for reading after errc has // received a value. - reply packetV4 + reply v4wire.Packet } -type replyMatchFunc func(interface{}) (matched bool, requestDone bool) +type replyMatchFunc func(v4wire.Packet) (matched bool, requestDone bool) // reply is a reply packet from a certain node. type reply struct { from enode.ID ip net.IP - data packetV4 + data v4wire.Packet // loop indicates whether there was // a matching request by sending on this channel. matched chan<- bool } func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { + cfg = cfg.withDefaults() closeCtx, cancel := context.WithCancel(context.Background()) t := &UDPv4{ conn: c, @@ -278,9 +145,6 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { log: cfg.Log, pingIPFromPacket: cfg.PingIPFromPacket, } - if t.log == nil { - t.log = log.Root() - } tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) if err != nil { @@ -344,10 +208,10 @@ func (t *UDPv4) Info() *TableInfo { return t.tab.Info() } -func (t *UDPv4) ourEndpoint() rpcEndpoint { +func (t *UDPv4) ourEndpoint() v4wire.Endpoint { n := t.Self() a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} - return makeEndpoint(a, uint16(n.TCP())) + return v4wire.NewEndpoint(a, uint16(n.TCP())) } // Ping sends a ping message to the given node. @@ -360,7 +224,7 @@ func (t *UDPv4) Ping(n *enode.Node) error { func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { rm := t.sendPing(n.ID(), &net.UDPAddr{IP: n.IP(), Port: n.UDP()}, nil) if err = <-rm.errc; err == nil { - seq = seqFromTail(rm.reply.(*pongV4).Rest) + seq = rm.reply.(*v4wire.Pong).ENRSeq() } return seq, err } @@ -369,7 +233,7 @@ func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { // when the reply arrives. func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *replyMatcher { req := t.makePing(toaddr) - packet, hash, err := t.encode(t.priv, req) + packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { errc := make(chan error, 1) errc <- err @@ -377,8 +241,8 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r } // Add a matcher for the reply to the pending reply queue. Pongs are matched if they // reference the ping we're about to send. - rm := t.pending(toid, toaddr.IP, p_pongV4, func(p interface{}) (matched bool, requestDone bool) { - matched = bytes.Equal(p.(*pongV4).ReplyTok, hash) + rm := t.pending(toid, toaddr.IP, v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) { + matched = bytes.Equal(p.(*v4wire.Pong).ReplyTok, hash) if matched && callback != nil { callback() } @@ -386,16 +250,16 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r }) // Send the packet. t.localNode.UDPContact(toaddr) - t.write(toaddr, toid, req.name(), packet) + t.write(toaddr, toid, req.Name(), packet) return rm } -func (t *UDPv4) makePing(toaddr *net.UDPAddr) *pingV4 { +func (t *UDPv4) makePing(toaddr *net.UDPAddr) *v4wire.Ping { seq, _ := rlp.EncodeToBytes(t.localNode.Node().Seq()) - return &pingV4{ + return &v4wire.Ping{ Version: 4, From: t.ourEndpoint(), - To: makeEndpoint(toaddr, 0), + To: v4wire.NewEndpoint(toaddr, 0), Expiration: uint64(time.Now().Add(expiration).Unix()), NetworkId: t.localNode.NetworkId(), Rest: []rlp.RawValue{seq}, @@ -435,23 +299,24 @@ func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup { target := enode.ID(crypto.Keccak256Hash(targetKey[:])) + ekey := v4wire.Pubkey(targetKey) it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { - return t.findnode(n.ID(), n.addr(), targetKey) + return t.findnode(n.ID(), n.addr(), ekey) }) return it } // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. -func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { +func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubkey) ([]*node, error) { t.ensureBond(toid, toaddr) // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is // active until enough nodes have been received. nodes := make([]*node, 0, bucketSize) nreceived := 0 - rm := t.pending(toid, toaddr.IP, p_neighborsV4, func(r interface{}) (matched bool, requestDone bool) { - reply := r.(*neighborsV4) + rm := t.pending(toid, toaddr.IP, v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { + reply := r.(*v4wire.Neighbors) for _, rn := range reply.Nodes { nreceived++ n, err := t.nodeFromRPC(toaddr, rn) @@ -463,11 +328,20 @@ func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ( } return true, nreceived >= bucketSize }) - t.send(toaddr, toid, &findnodeV4{ + t.send(toaddr, toid, &v4wire.Findnode{ Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) - return nodes, <-rm.errc + // Ensure that callers don't see a timeout if the node actually responded. Since + // findnode can receive more than one neighbors response, the reply matcher will be + // active until the remote node sends enough nodes. If the remote end doesn't have + // enough nodes the reply matcher will time out waiting for the second reply, but + // there's no need for an error in that case. + err := <-rm.errc + if err == errTimeout && rm.reply != nil { + err = nil + } + return nodes, err } // RequestENR sends enrRequest to the given node and waits for a response. @@ -475,26 +349,27 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} t.ensureBond(n.ID(), addr) - req := &enrRequestV4{ + req := &v4wire.ENRRequest{ Expiration: uint64(time.Now().Add(expiration).Unix()), } - packet, hash, err := t.encode(t.priv, req) + packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { return nil, err } + // Add a matcher for the reply to the pending reply queue. Responses are matched if // they reference the request we're about to send. - rm := t.pending(n.ID(), addr.IP, p_enrResponseV4, func(r interface{}) (matched bool, requestDone bool) { - matched = bytes.Equal(r.(*enrResponseV4).ReplyTok, hash) + rm := t.pending(n.ID(), addr.IP, v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) { + matched = bytes.Equal(r.(*v4wire.ENRResponse).ReplyTok, hash) return matched, matched }) // Send the packet and wait for the reply. - t.write(addr, n.ID(), req.name(), packet) + t.write(addr, n.ID(), req.Name(), packet) if err := <-rm.errc; err != nil { return nil, err } // Verify the response record. - respN, err := enode.New(enode.ValidSchemes, &rm.reply.(*enrResponseV4).Record) + respN, err := enode.New(enode.ValidSchemes, &rm.reply.(*v4wire.ENRResponse).Record) if err != nil { return nil, err } @@ -526,7 +401,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF // handleReply dispatches a reply packet, invoking reply matchers. It returns // whether any matcher considered the packet acceptable. -func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req packetV4) bool { +func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req v4wire.Packet) bool { matched := make(chan bool, 1) select { case t.gotreply <- reply{from, fromIP, req, matched}: @@ -592,12 +467,12 @@ func (t *UDPv4) loop() { var matched bool // whether any replyMatcher considered the reply acceptable. for el := plist.Front(); el != nil; el = el.Next() { p := el.Value.(*replyMatcher) - if p.from == r.from && p.ptype == r.data.kind() && (t.pingIPFromPacket || p.ip.Equal(r.ip)) { + if p.from == r.from && p.ptype == r.data.Kind() && (t.pingIPFromPacket || p.ip.Equal(r.ip)) { ok, requestDone := p.callback(r.data) matched = matched || ok + p.reply = r.data // Remove the matcher if callback indicates that all replies have been received. if requestDone { - p.reply = r.data p.errc <- nil plist.Remove(el) } @@ -629,44 +504,12 @@ func (t *UDPv4) loop() { } } -const ( - macSize = 256 / 8 - sigSize = 520 / 8 - headSize = macSize + sigSize // space of packet frame data -) - -var ( - headSpace = make([]byte, headSize) - - // Neighbors replies are sent across multiple packets to - // stay below the packet size limit. We compute the maximum number - // of entries by stuffing a packet until it grows too large. - maxNeighbors int -) - -func init() { - p := neighborsV4{Expiration: ^uint64(0)} - maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} - for n := 0; ; n++ { - p.Nodes = append(p.Nodes, maxSizeNode) - size, _, err := rlp.EncodeToReader(p) - if err != nil { - // If this ever happens, it will be caught by the unit tests. - panic("cannot encode: " + err.Error()) - } - if headSize+size+1 >= maxPacketSize { - maxNeighbors = n - break - } - } -} - -func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req packetV4) ([]byte, error) { - packet, hash, err := t.encode(t.priv, req) +func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]byte, error) { + packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { return hash, err } - return hash, t.write(toaddr, toid, req.name(), packet) + return hash, t.write(toaddr, toid, req.Name(), packet) } func (t *UDPv4) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error { @@ -678,30 +521,6 @@ func (t *UDPv4) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet [] return err } -func (t *UDPv4) encode(priv *ecdsa.PrivateKey, req packetV4) (packet, hash []byte, err error) { - name := req.name() - b := new(bytes.Buffer) - b.Write(headSpace) - b.WriteByte(req.kind()) - if err := rlp.Encode(b, req); err != nil { - t.log.Error(fmt.Sprintf("Can't encode %s packet", name), "err", err) - return nil, nil, err - } - packet = b.Bytes() - sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) - if err != nil { - t.log.Error(fmt.Sprintf("Can't sign %s packet", name), "err", err) - return nil, nil, err - } - copy(packet[macSize:], sig) - // add the hash to the front. Note: this doesn't protect the - // packet in any way. Our public key will be part of this hash in - // The future. - hash = crypto.Keccak256(packet[macSize:], celoClientSalt) - copy(packet, hash) - return packet, hash, nil -} - // readLoop runs in its own goroutine. it handles incoming UDP packets. func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) { defer t.wg.Done() @@ -734,58 +553,23 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) { } func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error { - packet, fromKey, hash, err := decodeV4(buf) + rawpacket, fromKey, hash, err := v4wire.Decode(buf) if err != nil { t.log.Debug("Bad discv4 packet", "addr", from, "err", err) return err } - fromID := fromKey.id() - if err == nil { - err = packet.preverify(t, from, fromID, fromKey) + packet := t.wrapPacket(rawpacket) + fromID := fromKey.ID() + if err == nil && packet.preverify != nil { + err = packet.preverify(packet, from, fromID, fromKey) } - t.log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err) - if err == nil { - packet.handle(t, from, fromID, hash) + t.log.Trace("<< "+packet.Name(), "id", fromID, "addr", from, "err", err) + if err == nil && packet.handle != nil { + packet.handle(packet, from, fromID, hash) } return err } -func decodeV4(buf []byte) (packetV4, encPubkey, []byte, error) { - if len(buf) < headSize+1 { - return nil, encPubkey{}, nil, errPacketTooSmall - } - hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] - shouldhash := crypto.Keccak256(buf[macSize:], celoClientSalt) - if !bytes.Equal(hash, shouldhash) { - return nil, encPubkey{}, nil, errBadHash - } - fromKey, err := recoverNodeKey(crypto.Keccak256(buf[headSize:]), sig) - if err != nil { - return nil, fromKey, hash, err - } - - var req packetV4 - switch ptype := sigdata[0]; ptype { - case p_pingV4: - req = new(pingV4) - case p_pongV4: - req = new(pongV4) - case p_findnodeV4: - req = new(findnodeV4) - case p_neighborsV4: - req = new(neighborsV4) - case p_enrRequestV4: - req = new(enrRequestV4) - case p_enrResponseV4: - req = new(enrResponseV4) - default: - return nil, fromKey, hash, fmt.Errorf("unknown type: %d", ptype) - } - s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) - err = s.Decode(req) - return req, fromKey, hash, err -} - // checkBond checks if the given node has a recent enough endpoint proof. func (t *UDPv4) checkBond(id enode.ID, ip net.IP) bool { return time.Since(t.db.LastPongReceived(id, ip)) < bondExpiration @@ -803,52 +587,102 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) { } } -// expired checks whether the given UNIX time stamp is in the past. -func expired(ts uint64) bool { - return time.Unix(int64(ts), 0).Before(time.Now()) +func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) { + if rn.UDP <= 1024 { + return nil, errLowPort + } + if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + return nil, err + } + if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { + return nil, errors.New("not contained in netrestrict whitelist") + } + key, err := v4wire.DecodePubkey(crypto.S256(), rn.ID) + if err != nil { + return nil, err + } + n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))) + err = n.ValidateComplete() + return n, err } -func seqFromTail(tail []rlp.RawValue) uint64 { - if len(tail) == 0 { - return 0 - } - var seq uint64 - rlp.DecodeBytes(tail[0], &seq) - return seq +func nodeToRPC(n *node) v4wire.Node { + var key ecdsa.PublicKey + var ekey v4wire.Pubkey + if err := n.Load((*enode.Secp256k1)(&key)); err == nil { + ekey = v4wire.EncodePubkey(&key) + } + return v4wire.Node{ID: ekey, IP: n.IP(), UDP: uint16(n.UDP()), TCP: uint16(n.TCP())} +} + +// wrapPacket returns the handler functions applicable to a packet. +func (t *UDPv4) wrapPacket(p v4wire.Packet) *packetHandlerV4 { + var h packetHandlerV4 + h.Packet = p + switch p.(type) { + case *v4wire.Ping: + h.preverify = t.verifyPing + h.handle = t.handlePing + case *v4wire.Pong: + h.preverify = t.verifyPong + case *v4wire.Findnode: + h.preverify = t.verifyFindnode + h.handle = t.handleFindnode + case *v4wire.Neighbors: + h.preverify = t.verifyNeighbors + case *v4wire.ENRRequest: + h.preverify = t.verifyENRRequest + h.handle = t.handleENRRequest + case *v4wire.ENRResponse: + h.preverify = t.verifyENRResponse + } + return &h +} + +// packetHandlerV4 wraps a packet with handler functions. +type packetHandlerV4 struct { + v4wire.Packet + senderKey *ecdsa.PublicKey // used for ping + + // preverify checks whether the packet is valid and should be handled at all. + preverify func(p *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error + // handle handles the packet. + handle func(req *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) } // PING/v4 -func (req *pingV4) name() string { return "PING/v4" } -func (req *pingV4) kind() byte { return p_pingV4 } +func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { + req := h.Packet.(*v4wire.Ping) -func (req *pingV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { if t.localNode.NetworkId() != req.NetworkId { return errBadNetworkId } - if expired(req.Expiration) { - return errExpired - } - key, err := decodePubkey(fromKey) + senderKey, err := v4wire.DecodePubkey(crypto.S256(), fromKey) if err != nil { - return errors.New("invalid public key") + return err } - req.senderKey = key + if v4wire.Expired(req.Expiration) { + return errExpired + } + h.senderKey = senderKey return nil } -func (req *pingV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { + req := h.Packet.(*v4wire.Ping) + // Reply. seq, _ := rlp.EncodeToBytes(t.localNode.Node().Seq()) - t.send(from, fromID, &pongV4{ - To: makeEndpoint(from, req.From.TCP), + t.send(from, fromID, &v4wire.Pong{ + To: v4wire.NewEndpoint(from, req.From.TCP), ReplyTok: mac, Expiration: uint64(time.Now().Add(expiration).Unix()), Rest: []rlp.RawValue{seq}, }) // Ping back if our last pong on file is too far in the past. - n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port)) + n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port)) if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { t.sendPing(fromID, from, func() { t.tab.addVerifiedNode(n) @@ -864,31 +698,26 @@ func (req *pingV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []by // PONG/v4 -func (req *pongV4) name() string { return "PONG/v4" } -func (req *pongV4) kind() byte { return p_pongV4 } +func (t *UDPv4) verifyPong(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { + req := h.Packet.(*v4wire.Pong) -func (req *pongV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { - if expired(req.Expiration) { + if v4wire.Expired(req.Expiration) { return errExpired } if !t.handleReply(fromID, from.IP, req) { return errUnsolicitedReply } - return nil -} - -func (req *pongV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) { t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) + return nil } // FINDNODE/v4 -func (req *findnodeV4) name() string { return "FINDNODE/v4" } -func (req *findnodeV4) kind() byte { return p_findnodeV4 } +func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { + req := h.Packet.(*v4wire.Findnode) -func (req *findnodeV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { - if expired(req.Expiration) { + if v4wire.Expired(req.Expiration) { return errExpired } if !t.checkBond(fromID, from.IP) { @@ -903,22 +732,22 @@ func (req *findnodeV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, f return nil } -func (req *findnodeV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { + req := h.Packet.(*v4wire.Findnode) + // Determine closest nodes. target := enode.ID(crypto.Keccak256Hash(req.Target[:])) - t.tab.mutex.Lock() - closest := t.tab.closest(target, bucketSize, true).entries - t.tab.mutex.Unlock() + closest := t.tab.findnodeByID(target, bucketSize, true).entries // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the packet size limit. - p := neighborsV4{Expiration: uint64(time.Now().Add(expiration).Unix())} + p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} var sent bool for _, n := range closest { if netutil.CheckRelayIP(from.IP, n.IP()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } - if len(p.Nodes) == maxNeighbors { + if len(p.Nodes) == v4wire.MaxNeighbors { t.send(from, fromID, &p) p.Nodes = p.Nodes[:0] sent = true @@ -931,29 +760,24 @@ func (req *findnodeV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac // NEIGHBORS/v4 -func (req *neighborsV4) name() string { return "NEIGHBORS/v4" } -func (req *neighborsV4) kind() byte { return p_neighborsV4 } +func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { + req := h.Packet.(*v4wire.Neighbors) -func (req *neighborsV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { - if expired(req.Expiration) { + if v4wire.Expired(req.Expiration) { return errExpired } - if !t.handleReply(fromID, from.IP, req) { + if !t.handleReply(fromID, from.IP, h.Packet) { return errUnsolicitedReply } return nil } -func (req *neighborsV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) { -} - // ENRREQUEST/v4 -func (req *enrRequestV4) name() string { return "ENRREQUEST/v4" } -func (req *enrRequestV4) kind() byte { return p_enrRequestV4 } +func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { + req := h.Packet.(*v4wire.ENRRequest) -func (req *enrRequestV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { - if expired(req.Expiration) { + if v4wire.Expired(req.Expiration) { return errExpired } if !t.checkBond(fromID, from.IP) { @@ -962,8 +786,8 @@ func (req *enrRequestV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, return nil } -func (req *enrRequestV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) { - t.send(from, fromID, &enrResponseV4{ +func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { + t.send(from, fromID, &v4wire.ENRResponse{ ReplyTok: mac, Record: *t.localNode.Node().Record(), }) @@ -971,15 +795,9 @@ func (req *enrRequestV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, ma // ENRRESPONSE/v4 -func (req *enrResponseV4) name() string { return "ENRRESPONSE/v4" } -func (req *enrResponseV4) kind() byte { return p_enrResponseV4 } - -func (req *enrResponseV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error { - if !t.handleReply(fromID, from.IP, req) { +func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { + if !t.handleReply(fromID, from.IP, h.Packet) { return errUnsolicitedReply } return nil } - -func (req *enrResponseV4) handle(t *UDPv4, from *net.UDPAddr, fromID enode.ID, mac []byte) { -} diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index ba999cea38..87bee5ec57 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -21,8 +21,8 @@ import ( "crypto/ecdsa" crand "crypto/rand" "encoding/binary" - "encoding/hex" "errors" + "fmt" "io" "math/rand" "net" @@ -31,27 +31,20 @@ import ( "testing" "time" - "github.com/davecgh/go-spew/spew" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover/v4wire" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/rlp" ) -func init() { - spew.Config.DisableMethods = true -} - // shared test variables var ( futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) - testTarget = encPubkey{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1} - testRemote = rpcEndpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2} - testLocalAnnounced = rpcEndpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4} - testLocal = rpcEndpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6} + testTarget = v4wire.Pubkey{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1} + testRemote = v4wire.Endpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2} + testLocalAnnounced = v4wire.Endpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4} + testLocal = v4wire.Endpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6} testNetworkId = uint64(1) ) @@ -93,19 +86,19 @@ func (test *udpTest) close() { } // handles a packet as if it had been sent to the transport. -func (test *udpTest) packetIn(wantError error, data packetV4) { +func (test *udpTest) packetIn(wantError error, data v4wire.Packet) { test.t.Helper() test.packetInFrom(wantError, test.remotekey, test.remoteaddr, data) } // handles a packet as if it had been sent to the transport by the key/endpoint. -func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, data packetV4) { +func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, data v4wire.Packet) { test.t.Helper() - enc, _, err := test.udp.encode(key, data) + enc, _, err := v4wire.Encode(key, data) if err != nil { - test.t.Errorf("%s encode error: %v", data.name(), err) + test.t.Errorf("%s encode error: %v", data.Name(), err) } test.sent = append(test.sent, enc) if err = test.udp.handlePacket(addr, enc); err != wantError { @@ -118,11 +111,14 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr * func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() - dgram, ok := test.pipe.receive() - if !ok { + dgram, err := test.pipe.receive() + if err == errClosed { return true + } else if err != nil { + test.t.Error("packet receive error:", err) + return false } - p, _, hash, err := decodeV4(dgram.data) + p, _, hash, err := v4wire.Decode(dgram.data) if err != nil { test.t.Errorf("sent packet decode error: %v", err) return false @@ -141,11 +137,11 @@ func TestUDPv4_packetErrors(t *testing.T) { test := newUDPTest(t) defer test.close() - test.packetIn(errExpired, &pingV4{From: testRemote, To: testLocalAnnounced, Version: 4, NetworkId: testNetworkId}) - test.packetIn(errUnsolicitedReply, &pongV4{ReplyTok: []byte{}, Expiration: futureExp}) - test.packetIn(errUnknownNode, &findnodeV4{Expiration: futureExp}) - test.packetIn(errUnsolicitedReply, &neighborsV4{Expiration: futureExp}) - test.packetIn(errBadNetworkId, &pingV4{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId + 1}) + test.packetIn(errExpired, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, NetworkId: testNetworkId}) + test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: []byte{}, Expiration: futureExp}) + test.packetIn(errUnknownNode, &v4wire.Findnode{Expiration: futureExp}) + test.packetIn(errUnsolicitedReply, &v4wire.Neighbors{Expiration: futureExp}) + test.packetIn(errBadNetworkId, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId + 1}) } func TestUDPv4_pingTimeout(t *testing.T) { @@ -163,13 +159,8 @@ func TestUDPv4_pingTimeout(t *testing.T) { type testPacket byte -func (req testPacket) kind() byte { return byte(req) } -func (req testPacket) name() string { return "" } -func (req testPacket) preverify(*UDPv4, *net.UDPAddr, enode.ID, encPubkey) error { - return nil -} -func (req testPacket) handle(*UDPv4, *net.UDPAddr, enode.ID, []byte) { -} +func (req testPacket) Kind() byte { return byte(req) } +func (req testPacket) Name() string { return "" } func TestUDPv4_responseTimeouts(t *testing.T) { t.Parallel() @@ -194,7 +185,7 @@ func TestUDPv4_responseTimeouts(t *testing.T) { // within the timeout window. p := &replyMatcher{ ptype: byte(rand.Intn(255)), - callback: func(interface{}) (bool, bool) { return true, true }, + callback: func(v4wire.Packet) (bool, bool) { return true, true }, } binary.BigEndian.PutUint64(p.from[:], uint64(i)) if p.ptype <= 128 { @@ -250,7 +241,7 @@ func TestUDPv4_findnodeTimeout(t *testing.T) { toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} toid := enode.ID{1, 2, 3, 4} - target := encPubkey{4, 5, 6, 7} + target := v4wire.Pubkey{4, 5, 6, 7} result, err := test.udp.findnode(toid, toaddr, target) if err != errTimeout { t.Error("expected timeout error, got", err) @@ -267,7 +258,7 @@ func TestUDPv4_findnode(t *testing.T) { // put a few nodes into the table. their exact // distribution shouldn't matter much, although we need to // take care not to overflow any bucket. - nodes := &nodesByDistance{target: testTarget.id()} + nodes := &nodesByDistance{target: testTarget.ID()} live := make(map[enode.ID]bool) numCandidates := 2 * bucketSize for i := 0; i < numCandidates; i++ { @@ -285,32 +276,32 @@ func TestUDPv4_findnode(t *testing.T) { // ensure there's a bond with the test node, // findnode won't be accepted otherwise. - remoteID := encodePubkey(&test.remotekey.PublicKey).id() + remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID() test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) // check that closest neighbors are returned. - expected := test.table.closest(testTarget.id(), bucketSize, true) - test.packetIn(nil, &findnodeV4{Target: testTarget, Expiration: futureExp}) + expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) + test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp}) waitNeighbors := func(want []*node) { - test.waitPacketOut(func(p *neighborsV4, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Neighbors, to *net.UDPAddr, hash []byte) { if len(p.Nodes) != len(want) { t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize) } for i, n := range p.Nodes { - if n.ID.id() != want[i].ID() { + if n.ID.ID() != want[i].ID() { t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i]) } - if !live[n.ID.id()] { - t.Errorf("result includes dead node %v", n.ID.id()) + if !live[n.ID.ID()] { + t.Errorf("result includes dead node %v", n.ID.ID()) } } }) } // Receive replies. want := expected.entries - if len(want) > maxNeighbors { - waitNeighbors(want[:maxNeighbors]) - want = want[maxNeighbors:] + if len(want) > v4wire.MaxNeighbors { + waitNeighbors(want[:v4wire.MaxNeighbors]) + want = want[v4wire.MaxNeighbors:] } waitNeighbors(want) } @@ -336,7 +327,7 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { // wait for the findnode to be sent. // after it is sent, the transport is waiting for a reply - test.waitPacketOut(func(p *findnodeV4, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Findnode, to *net.UDPAddr, hash []byte) { if p.Target != testTarget { t.Errorf("wrong target: got %v, want %v", p.Target, testTarget) } @@ -349,12 +340,12 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { wrapNode(enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")), wrapNode(enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")), } - rpclist := make([]rpcNode, len(list)) + rpclist := make([]v4wire.Node, len(list)) for i := range list { rpclist[i] = nodeToRPC(list[i]) } - test.packetIn(nil, &neighborsV4{Expiration: futureExp, Nodes: rpclist[:2]}) - test.packetIn(nil, &neighborsV4{Expiration: futureExp, Nodes: rpclist[2:]}) + test.packetIn(nil, &v4wire.Neighbors{Expiration: futureExp, Nodes: rpclist[:2]}) + test.packetIn(nil, &v4wire.Neighbors{Expiration: futureExp, Nodes: rpclist[2:]}) // check that the sent neighbors are all returned by findnode select { @@ -378,10 +369,10 @@ func TestUDPv4_pingMatch(t *testing.T) { randToken := make([]byte, 32) crand.Read(randToken) - test.packetIn(nil, &pingV4{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId}) - test.waitPacketOut(func(*pongV4, *net.UDPAddr, []byte) {}) - test.waitPacketOut(func(*pingV4, *net.UDPAddr, []byte) {}) - test.packetIn(errUnsolicitedReply, &pongV4{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) + test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId}) + test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) + test.waitPacketOut(func(*v4wire.Ping, *net.UDPAddr, []byte) {}) + test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) } // This test checks that reply matching of pong verifies the sender IP address. @@ -389,12 +380,12 @@ func TestUDPv4_pingMatchIP(t *testing.T) { test := newUDPTest(t) defer test.close() - test.packetIn(nil, &pingV4{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId}) - test.waitPacketOut(func(*pongV4, *net.UDPAddr, []byte) {}) + test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId}) + test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) - test.waitPacketOut(func(p *pingV4, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000} - test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &pongV4{ + test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &v4wire.Pong{ ReplyTok: hash, To: testLocalAnnounced, Expiration: futureExp, @@ -409,15 +400,15 @@ func TestUDPv4_successfulPing(t *testing.T) { defer test.close() // The remote side sends a ping packet to initiate the exchange. - go test.packetIn(nil, &pingV4{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId}) + go test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp, NetworkId: testNetworkId}) // The ping is replied to. - test.waitPacketOut(func(p *pongV4, to *net.UDPAddr, hash []byte) { - pinghash := test.sent[0][:macSize] + test.waitPacketOut(func(p *v4wire.Pong, to *net.UDPAddr, hash []byte) { + pinghash := test.sent[0][:32] if !bytes.Equal(p.ReplyTok, pinghash) { t.Errorf("got pong.ReplyTok %x, want %x", p.ReplyTok, pinghash) } - wantTo := rpcEndpoint{ + wantTo := v4wire.Endpoint{ // The mirrored UDP address is the UDP packet sender IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), // The mirrored TCP port is the one from the ping packet @@ -429,11 +420,11 @@ func TestUDPv4_successfulPing(t *testing.T) { }) // Remote is unknown, the table pings back. - test.waitPacketOut(func(p *pingV4, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) } - wantTo := rpcEndpoint{ + wantTo := v4wire.Endpoint{ // The mirrored UDP address is the UDP packet sender. IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), @@ -442,7 +433,7 @@ func TestUDPv4_successfulPing(t *testing.T) { if !reflect.DeepEqual(p.To, wantTo) { t.Errorf("got ping.To %v, want %v", p.To, wantTo) } - test.packetIn(nil, &pongV4{ReplyTok: hash, Expiration: futureExp}) + test.packetIn(nil, &v4wire.Pong{ReplyTok: hash, Expiration: futureExp}) }) // The node should be added to the table shortly after getting the @@ -476,25 +467,25 @@ func TestUDPv4_EIP868(t *testing.T) { wantNode := test.udp.localNode.Node() // ENR requests aren't allowed before endpoint proof. - test.packetIn(errUnknownNode, &enrRequestV4{Expiration: futureExp}) + test.packetIn(errUnknownNode, &v4wire.ENRRequest{Expiration: futureExp}) // Perform endpoint proof and check for sequence number in packet tail. - test.packetIn(nil, &pingV4{Expiration: futureExp, NetworkId: testNetworkId}) - test.waitPacketOut(func(p *pongV4, addr *net.UDPAddr, hash []byte) { - if seq := seqFromTail(p.Rest); seq != wantNode.Seq() { - t.Errorf("wrong sequence number in pong: %d, want %d", seq, wantNode.Seq()) + test.packetIn(nil, &v4wire.Ping{Expiration: futureExp, NetworkId: testNetworkId}) + test.waitPacketOut(func(p *v4wire.Pong, addr *net.UDPAddr, hash []byte) { + if p.ENRSeq() != wantNode.Seq() { + t.Errorf("wrong sequence number in pong: %d, want %d", p.ENRSeq(), wantNode.Seq()) } }) - test.waitPacketOut(func(p *pingV4, addr *net.UDPAddr, hash []byte) { - if seq := seqFromTail(p.Rest); seq != wantNode.Seq() { - t.Errorf("wrong sequence number in ping: %d, want %d", seq, wantNode.Seq()) + test.waitPacketOut(func(p *v4wire.Ping, addr *net.UDPAddr, hash []byte) { + if p.ENRSeq() != wantNode.Seq() { + t.Errorf("wrong sequence number in ping: %d, want %d", p.ENRSeq(), wantNode.Seq()) } - test.packetIn(nil, &pongV4{Expiration: futureExp, ReplyTok: hash}) + test.packetIn(nil, &v4wire.Pong{Expiration: futureExp, ReplyTok: hash}) }) // Request should work now. - test.packetIn(nil, &enrRequestV4{Expiration: futureExp}) - test.waitPacketOut(func(p *enrResponseV4, addr *net.UDPAddr, hash []byte) { + test.packetIn(nil, &v4wire.ENRRequest{Expiration: futureExp}) + test.waitPacketOut(func(p *v4wire.ENRResponse, addr *net.UDPAddr, hash []byte) { n, err := enode.New(enode.ValidSchemes, &p.Record) if err != nil { t.Fatalf("invalid record: %v", err) @@ -505,119 +496,91 @@ func TestUDPv4_EIP868(t *testing.T) { }) } -// EIP-8 test vectors. -var testPackets = []struct { - input string - wantPacket interface{} -}{ - { - input: "5f76a8dbcc2cfb869e84ed53a0c511642bcf4b4725ac09f4bbb05758519b4a0c820b24a50e9a92ab6b54c29ec27415e4b1fb2e7221ae54df539e24eb7b0708ec5cd65263edbf18c639658308a5fb6cbe273b11231dc6db1eb8f0e91ebcd52e740101eb04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a35501", - wantPacket: &pingV4{ - Version: 4, - From: rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544}, - To: rpcEndpoint{net.ParseIP("::1"), 2222, 3333}, - Expiration: 1136239445, - NetworkId: testNetworkId, - Rest: []rlp.RawValue{}, - }, - }, - { - input: "c552fb8e82b033d29aa9a0d8a419430ccb60ccbd850c772c2b566b6f5648567563ece5c430d9583ce11f1ef5cf2eaba463d3b0b3dcb48d2989803052eca8189173a60da7d08d5c756d0aad6fc05cecdfa7ab4149be85e4c1e9ee32e34457ca050101ed04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a355010102", - wantPacket: &pingV4{ - Version: 4, - From: rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544}, - To: rpcEndpoint{net.ParseIP("::1"), 2222, 3333}, - Expiration: 1136239445, - NetworkId: testNetworkId, - Rest: []rlp.RawValue{{0x01}, {0x02}}, - }, - }, - { - input: "756eea192b2715cf03ba03d69b0bd9e0f9d1bc9b1fab7a79f1e09291e94d88a780480dcfd0ef9b5c82ec6ce99588be5a4d34adf61c6d0baaa31a4336a44b1b9214bfe9e095459073a494529b9b5fbbdf5eac1e5ec8c7017b80cc8d8da0f07c560101f83f82022bd79020010db83c4d001500000000abcdef12820cfa8215a8d79020010db885a308d313198a2e037073488208ae82823a8443b9a35501c50102030405", - wantPacket: &pingV4{ - Version: 555, - From: rpcEndpoint{net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 3322, 5544}, - To: rpcEndpoint{net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), 2222, 33338}, - Expiration: 1136239445, - NetworkId: testNetworkId, - Rest: []rlp.RawValue{{0xC5, 0x01, 0x02, 0x03, 0x04, 0x05}}, - }, - }, - { - input: "393771faaa6d2ec01c45c1c476e564e29aef48b65a480e7a3aee77c05b43c7c12f80cc9ae69d463552cf47c9fd3b468003f30cf198561d6fc76ef838dae25d2209c41dd7eb554cdd311a01d8dc628cd60ff761a484b858cea3b852f5f59d5a560102f846d79020010db885a308d313198a2e037073488208ae82823aa0fbc914b16819237dcd8801d7e53f69e9719adecb3cc0e790c57e91ca4461c9548443b9a355c6010203c2040506", - wantPacket: &pongV4{ - To: rpcEndpoint{net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), 2222, 33338}, - ReplyTok: common.Hex2Bytes("fbc914b16819237dcd8801d7e53f69e9719adecb3cc0e790c57e91ca4461c954"), - Expiration: 1136239445, - Rest: []rlp.RawValue{{0xC6, 0x01, 0x02, 0x03, 0xC2, 0x04, 0x05}, {0x06}}, - }, - }, - { - input: "b4ddc372344d2fea1d58c26edfc7cdb8f7359cb4f6858484cf48ec23feeeaff0fee71339958ee7859a936d61e6e4e43f74f5dc119fffcd6b424df1929f55197b159aaef76f9bac9fed4f35677e85b049a618cdb62d5cdb70a3b238439c79bce30103f84eb840ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f8443b9a35582999983999999", - wantPacket: &findnodeV4{ - Target: hexEncPubkey("ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f"), - Expiration: 1136239445, - Rest: []rlp.RawValue{{0x82, 0x99, 0x99}, {0x83, 0x99, 0x99, 0x99}}, - }, - }, - { - input: "46742cb11a565879175e92ca349eaf9f9ce0380ac4d8976c8ea01f6c2620475c56b8409b9176e182b36ebc0715d6197a69b0eb806d6a7b7aa8615677891e15705c4cf3849f0ff477db229126dc4c0715e11f3ee9172659726dbb3eff8a64a1590004f9015bf90150f84d846321163782115c82115db8403155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32f84984010203040101b840312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069dbf8599020010db83c4d001500000000abcdef12820d05820d05b84038643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aacf8599020010db885a308d313198a2e037073488203e78203e8b8408dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df738443b9a355010203", - wantPacket: &neighborsV4{ - Nodes: []rpcNode{ - { - ID: hexEncPubkey("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"), - IP: net.ParseIP("99.33.22.55").To4(), - UDP: 4444, - TCP: 4445, - }, - { - ID: hexEncPubkey("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"), - IP: net.ParseIP("1.2.3.4").To4(), - UDP: 1, - TCP: 1, - }, - { - ID: hexEncPubkey("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"), - IP: net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), - UDP: 3333, - TCP: 3333, - }, - { - ID: hexEncPubkey("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"), - IP: net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), - UDP: 999, - TCP: 1000, - }, - }, - Expiration: 1136239445, - Rest: []rlp.RawValue{{0x01}, {0x02}, {0x03}}, - }, - }, -} - -func TestUDPv4_forwardCompatibility(t *testing.T) { - testkey, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") - wantNodeKey := encodePubkey(&testkey.PublicKey) +// This test verifies that a small network of nodes can boot up into a healthy state. +func TestUDPv4_smallNetConvergence(t *testing.T) { + t.Parallel() - for _, test := range testPackets { - input, err := hex.DecodeString(test.input) - if err != nil { - t.Fatalf("invalid hex: %s", test.input) - } - packet, nodekey, _, err := decodeV4(input) - if err != nil { - t.Errorf("did not accept packet %s\n%v", test.input, err) - continue + // Start the network. + nodes := make([]*UDPv4, 4) + for i := range nodes { + var cfg Config + if i > 0 { + bn := nodes[0].Self() + cfg.Bootnodes = []*enode.Node{bn} } - if !reflect.DeepEqual(packet, test.wantPacket) { - t.Errorf("got %s\nwant %s", spew.Sdump(packet), spew.Sdump(test.wantPacket)) - } - if nodekey != wantNodeKey { - t.Errorf("got id %v\nwant id %v", nodekey, wantNodeKey) + nodes[i] = startLocalhostV4(t, cfg) + defer nodes[i].Close() + } + + // Run through the iterator on all nodes until + // they have all found each other. + status := make(chan error, len(nodes)) + for i := range nodes { + node := nodes[i] + go func() { + found := make(map[enode.ID]bool, len(nodes)) + it := node.RandomNodes() + for it.Next() { + found[it.Node().ID()] = true + if len(found) == len(nodes) { + status <- nil + return + } + } + status <- fmt.Errorf("node %s didn't find all nodes", node.Self().ID().TerminalString()) + }() + } + + // Wait for all status reports. + timeout := time.NewTimer(30 * time.Second) + defer timeout.Stop() + for received := 0; received < len(nodes); { + select { + case <-timeout.C: + for _, node := range nodes { + node.Close() + } + case err := <-status: + received++ + if err != nil { + t.Error("ERROR:", err) + return + } } } } +func startLocalhostV4(t *testing.T, cfg Config) *UDPv4 { + t.Helper() + + cfg.PrivateKey = newkey() + db, _ := enode.OpenDB("") + ln := enode.NewLocalNode(db, cfg.PrivateKey, testNetworkId) + + // Prefix logs with node ID. + lprefix := fmt.Sprintf("(%s)", ln.ID().TerminalString()) + lfmt := log.TerminalFormat(false) + cfg.Log = testlog.Logger(t, log.LvlTrace) + cfg.Log.SetHandler(log.FuncHandler(func(r *log.Record) error { + t.Logf("%s %s", lprefix, lfmt.Format(r)) + return nil + })) + + // Listen. + socket, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + if err != nil { + t.Fatal(err) + } + realaddr := socket.LocalAddr().(*net.UDPAddr) + ln.SetStaticIP(realaddr.IP) + ln.SetFallbackUDP(realaddr.Port) + udp, err := ListenV4(socket, ln, cfg) + if err != nil { + t.Fatal(err) + } + return udp +} + // dgramPipe is a fake UDP socket. It queues all sent datagrams. type dgramPipe struct { mu *sync.Mutex @@ -676,17 +639,30 @@ func (c *dgramPipe) LocalAddr() net.Addr { return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)} } -func (c *dgramPipe) receive() (dgram, bool) { +func (c *dgramPipe) receive() (dgram, error) { c.mu.Lock() defer c.mu.Unlock() - for len(c.queue) == 0 && !c.closed { + + var timedOut bool + timer := time.AfterFunc(3*time.Second, func() { + c.mu.Lock() + timedOut = true + c.mu.Unlock() + c.cond.Broadcast() + }) + defer timer.Stop() + + for len(c.queue) == 0 && !c.closed && !timedOut { c.cond.Wait() } if c.closed { - return dgram{}, false + return dgram{}, errClosed + } + if timedOut { + return dgram{}, errTimeout } p := c.queue[0] copy(c.queue, c.queue[1:]) c.queue = c.queue[:len(c.queue)-1] - return p, true + return p, nil } diff --git a/p2p/discover/v4wire/v4wire.go b/p2p/discover/v4wire/v4wire.go new file mode 100644 index 0000000000..1ac61ec0dc --- /dev/null +++ b/p2p/discover/v4wire/v4wire.go @@ -0,0 +1,303 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package v4wire implements the Discovery v4 Wire Protocol. +package v4wire + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "errors" + "fmt" + "math/big" + "net" + "time" + + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +// RPC packet types +const ( + PingPacket = iota + 1 // zero is 'reserved' + PongPacket + FindnodePacket + NeighborsPacket + ENRRequestPacket + ENRResponsePacket +) + +// RPC request structures +type ( + Ping struct { + Version uint + From, To Endpoint + Expiration uint64 + NetworkId uint64 + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` + } + + // Pong is the reply to ping. + Pong struct { + // This field should mirror the UDP envelope address + // of the ping packet, which provides a way to discover the + // the external address (after NAT). + To Endpoint + ReplyTok []byte // This contains the hash of the ping packet. + Expiration uint64 // Absolute timestamp at which the packet becomes invalid. + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` + } + + // Findnode is a query for nodes close to the given target. + Findnode struct { + Target Pubkey + Expiration uint64 + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` + } + + // Neighbors is the reply to findnode. + Neighbors struct { + Nodes []Node + Expiration uint64 + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` + } + + // enrRequest queries for the remote node's record. + ENRRequest struct { + Expiration uint64 + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` + } + + // enrResponse is the reply to enrRequest. + ENRResponse struct { + ReplyTok []byte // Hash of the enrRequest packet. + Record enr.Record + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` + } +) + +// This number is the maximum number of neighbor nodes in a Neigbors packet. +const MaxNeighbors = 12 + +// This code computes the MaxNeighbors constant value. + +// func init() { +// var maxNeighbors int +// p := Neighbors{Expiration: ^uint64(0)} +// maxSizeNode := Node{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} +// for n := 0; ; n++ { +// p.Nodes = append(p.Nodes, maxSizeNode) +// size, _, err := rlp.EncodeToReader(p) +// if err != nil { +// // If this ever happens, it will be caught by the unit tests. +// panic("cannot encode: " + err.Error()) +// } +// if headSize+size+1 >= 1280 { +// maxNeighbors = n +// break +// } +// } +// fmt.Println("maxNeighbors", maxNeighbors) +// } + +// Pubkey represents an encoded 64-byte secp256k1 public key. +type Pubkey [64]byte + +// ID returns the node ID corresponding to the public key. +func (e Pubkey) ID() enode.ID { + return enode.ID(crypto.Keccak256Hash(e[:])) +} + +// Node represents information about a node. +type Node struct { + IP net.IP // len 4 for IPv4 or 16 for IPv6 + UDP uint16 // for discovery protocol + TCP uint16 // for RLPx protocol + ID Pubkey +} + +// Endpoint represents a network endpoint. +type Endpoint struct { + IP net.IP // len 4 for IPv4 or 16 for IPv6 + UDP uint16 // for discovery protocol + TCP uint16 // for RLPx protocol +} + +// NewEndpoint creates an endpoint. +func NewEndpoint(addr *net.UDPAddr, tcpPort uint16) Endpoint { + ip := net.IP{} + if ip4 := addr.IP.To4(); ip4 != nil { + ip = ip4 + } else if ip6 := addr.IP.To16(); ip6 != nil { + ip = ip6 + } + return Endpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} +} + +type Packet interface { + // packet name and type for logging purposes. + Name() string + Kind() byte +} + +func (req *Ping) Name() string { return "PING/v4" } +func (req *Ping) Kind() byte { return PingPacket } +func (req *Ping) ENRSeq() uint64 { return seqFromTail(req.Rest) } + +func (req *Pong) Name() string { return "PONG/v4" } +func (req *Pong) Kind() byte { return PongPacket } +func (req *Pong) ENRSeq() uint64 { return seqFromTail(req.Rest) } + +func (req *Findnode) Name() string { return "FINDNODE/v4" } +func (req *Findnode) Kind() byte { return FindnodePacket } + +func (req *Neighbors) Name() string { return "NEIGHBORS/v4" } +func (req *Neighbors) Kind() byte { return NeighborsPacket } + +func (req *ENRRequest) Name() string { return "ENRREQUEST/v4" } +func (req *ENRRequest) Kind() byte { return ENRRequestPacket } + +func (req *ENRResponse) Name() string { return "ENRRESPONSE/v4" } +func (req *ENRResponse) Kind() byte { return ENRResponsePacket } + +// Expired checks whether the given UNIX time stamp is in the past. +func Expired(ts uint64) bool { + return time.Unix(int64(ts), 0).Before(time.Now()) +} + +func seqFromTail(tail []rlp.RawValue) uint64 { + if len(tail) == 0 { + return 0 + } + var seq uint64 + rlp.DecodeBytes(tail[0], &seq) + return seq +} + +// Encoder/decoder. + +const ( + macSize = 32 + sigSize = crypto.SignatureLength + headSize = macSize + sigSize // space of packet frame data +) + +var ( + ErrPacketTooSmall = errors.New("too small") + ErrBadHash = errors.New("bad hash") + ErrBadPoint = errors.New("invalid curve point") +) + +var headSpace = make([]byte, headSize) + +var celoClientSalt = []byte{0x63, 0x65, 0x6C, 0x6F} + +// Decode reads a discovery v4 packet. +func Decode(input []byte) (Packet, Pubkey, []byte, error) { + if len(input) < headSize+1 { + return nil, Pubkey{}, nil, ErrPacketTooSmall + } + hash, sig, sigdata := input[:macSize], input[macSize:headSize], input[headSize:] + shouldhash := crypto.Keccak256(input[macSize:], celoClientSalt) + if !bytes.Equal(hash, shouldhash) { + return nil, Pubkey{}, nil, ErrBadHash + } + fromKey, err := recoverNodeKey(crypto.Keccak256(input[headSize:]), sig) + if err != nil { + return nil, fromKey, hash, err + } + + var req Packet + switch ptype := sigdata[0]; ptype { + case PingPacket: + req = new(Ping) + case PongPacket: + req = new(Pong) + case FindnodePacket: + req = new(Findnode) + case NeighborsPacket: + req = new(Neighbors) + case ENRRequestPacket: + req = new(ENRRequest) + case ENRResponsePacket: + req = new(ENRResponse) + default: + return nil, fromKey, hash, fmt.Errorf("unknown type: %d", ptype) + } + s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) + err = s.Decode(req) + return req, fromKey, hash, err +} + +// Encode encodes a discovery packet. +func Encode(priv *ecdsa.PrivateKey, req Packet) (packet, hash []byte, err error) { + b := new(bytes.Buffer) + b.Write(headSpace) + b.WriteByte(req.Kind()) + if err := rlp.Encode(b, req); err != nil { + return nil, nil, err + } + packet = b.Bytes() + sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) + if err != nil { + return nil, nil, err + } + copy(packet[macSize:], sig) + // Add the hash to the front. Note: this doesn't protect the packet in any way. + hash = crypto.Keccak256(packet[macSize:], celoClientSalt) + copy(packet, hash) + return packet, hash, nil +} + +// recoverNodeKey computes the public key used to sign the given hash from the signature. +func recoverNodeKey(hash, sig []byte) (key Pubkey, err error) { + pubkey, err := crypto.Ecrecover(hash, sig) + if err != nil { + return key, err + } + copy(key[:], pubkey[1:]) + return key, nil +} + +// EncodePubkey encodes a secp256k1 public key. +func EncodePubkey(key *ecdsa.PublicKey) Pubkey { + var e Pubkey + math.ReadBits(key.X, e[:len(e)/2]) + math.ReadBits(key.Y, e[len(e)/2:]) + return e +} + +// DecodePubkey reads an encoded secp256k1 public key. +func DecodePubkey(curve elliptic.Curve, e Pubkey) (*ecdsa.PublicKey, error) { + p := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)} + half := len(e) / 2 + p.X.SetBytes(e[:half]) + p.Y.SetBytes(e[half:]) + if !p.Curve.IsOnCurve(p.X, p.Y) { + return nil, ErrBadPoint + } + return p, nil +} diff --git a/p2p/discover/v4wire/v4wire_test.go b/p2p/discover/v4wire/v4wire_test.go new file mode 100644 index 0000000000..c6413edb33 --- /dev/null +++ b/p2p/discover/v4wire/v4wire_test.go @@ -0,0 +1,157 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package v4wire + +import ( + "encoding/hex" + "net" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" +) + +var testNetworkId = uint64(1) + +// EIP-8 test vectors. +var testPackets = []struct { + input string + wantPacket interface{} +}{ + { + input: "5f76a8dbcc2cfb869e84ed53a0c511642bcf4b4725ac09f4bbb05758519b4a0c820b24a50e9a92ab6b54c29ec27415e4b1fb2e7221ae54df539e24eb7b0708ec5cd65263edbf18c639658308a5fb6cbe273b11231dc6db1eb8f0e91ebcd52e740101eb04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a35501", + wantPacket: &Ping{ + Version: 4, + From: Endpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544}, + To: Endpoint{net.ParseIP("::1"), 2222, 3333}, + Expiration: 1136239445, + NetworkId: testNetworkId, + Rest: []rlp.RawValue{}, + }, + }, + { + input: "c552fb8e82b033d29aa9a0d8a419430ccb60ccbd850c772c2b566b6f5648567563ece5c430d9583ce11f1ef5cf2eaba463d3b0b3dcb48d2989803052eca8189173a60da7d08d5c756d0aad6fc05cecdfa7ab4149be85e4c1e9ee32e34457ca050101ed04cb847f000001820cfa8215a8d790000000000000000000000000000000018208ae820d058443b9a355010102", + wantPacket: &Ping{ + Version: 4, + From: Endpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544}, + To: Endpoint{net.ParseIP("::1"), 2222, 3333}, + Expiration: 1136239445, + NetworkId: testNetworkId, + Rest: []rlp.RawValue{{0x01}, {0x02}}, + }, + }, + { + input: "756eea192b2715cf03ba03d69b0bd9e0f9d1bc9b1fab7a79f1e09291e94d88a780480dcfd0ef9b5c82ec6ce99588be5a4d34adf61c6d0baaa31a4336a44b1b9214bfe9e095459073a494529b9b5fbbdf5eac1e5ec8c7017b80cc8d8da0f07c560101f83f82022bd79020010db83c4d001500000000abcdef12820cfa8215a8d79020010db885a308d313198a2e037073488208ae82823a8443b9a35501c50102030405", + wantPacket: &Ping{ + Version: 555, + From: Endpoint{net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 3322, 5544}, + To: Endpoint{net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), 2222, 33338}, + Expiration: 1136239445, + NetworkId: testNetworkId, + Rest: []rlp.RawValue{{0xC5, 0x01, 0x02, 0x03, 0x04, 0x05}}, + }, + }, + { + input: "393771faaa6d2ec01c45c1c476e564e29aef48b65a480e7a3aee77c05b43c7c12f80cc9ae69d463552cf47c9fd3b468003f30cf198561d6fc76ef838dae25d2209c41dd7eb554cdd311a01d8dc628cd60ff761a484b858cea3b852f5f59d5a560102f846d79020010db885a308d313198a2e037073488208ae82823aa0fbc914b16819237dcd8801d7e53f69e9719adecb3cc0e790c57e91ca4461c9548443b9a355c6010203c2040506", + wantPacket: &Pong{ + To: Endpoint{net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), 2222, 33338}, + ReplyTok: common.Hex2Bytes("fbc914b16819237dcd8801d7e53f69e9719adecb3cc0e790c57e91ca4461c954"), + Expiration: 1136239445, + Rest: []rlp.RawValue{{0xC6, 0x01, 0x02, 0x03, 0xC2, 0x04, 0x05}, {0x06}}, + }, + }, + { + input: "b4ddc372344d2fea1d58c26edfc7cdb8f7359cb4f6858484cf48ec23feeeaff0fee71339958ee7859a936d61e6e4e43f74f5dc119fffcd6b424df1929f55197b159aaef76f9bac9fed4f35677e85b049a618cdb62d5cdb70a3b238439c79bce30103f84eb840ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f8443b9a35582999983999999", + wantPacket: &Findnode{ + Target: hexPubkey("ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd31387574077f301b421bc84df7266c44e9e6d569fc56be00812904767bf5ccd1fc7f"), + Expiration: 1136239445, + Rest: []rlp.RawValue{{0x82, 0x99, 0x99}, {0x83, 0x99, 0x99, 0x99}}, + }, + }, + { + input: "46742cb11a565879175e92ca349eaf9f9ce0380ac4d8976c8ea01f6c2620475c56b8409b9176e182b36ebc0715d6197a69b0eb806d6a7b7aa8615677891e15705c4cf3849f0ff477db229126dc4c0715e11f3ee9172659726dbb3eff8a64a1590004f9015bf90150f84d846321163782115c82115db8403155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32f84984010203040101b840312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069dbf8599020010db83c4d001500000000abcdef12820d05820d05b84038643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aacf8599020010db885a308d313198a2e037073488203e78203e8b8408dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df738443b9a355010203", + wantPacket: &Neighbors{ + Nodes: []Node{ + { + ID: hexPubkey("3155e1427f85f10a5c9a7755877748041af1bcd8d474ec065eb33df57a97babf54bfd2103575fa829115d224c523596b401065a97f74010610fce76382c0bf32"), + IP: net.ParseIP("99.33.22.55").To4(), + UDP: 4444, + TCP: 4445, + }, + { + ID: hexPubkey("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d20951933beea1e4dfa6f968212385e829f04c2d314fc2d4e255e0d3bc08792b069db"), + IP: net.ParseIP("1.2.3.4").To4(), + UDP: 1, + TCP: 1, + }, + { + ID: hexPubkey("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c765dd2d96126051913f44582e8c199ad7c6d6819e9a56483f637feaac9448aac"), + IP: net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), + UDP: 3333, + TCP: 3333, + }, + { + ID: hexPubkey("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2d47295286fc00cc081bb542d760717d1bdd6bec2c37cd72eca367d6dd3b9df73"), + IP: net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), + UDP: 999, + TCP: 1000, + }, + }, + Expiration: 1136239445, + Rest: []rlp.RawValue{{0x01}, {0x02}, {0x03}}, + }, + }, +} + +// This test checks that the decoder accepts packets according to EIP-8. +func TestForwardCompatibility(t *testing.T) { + testkey, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + wantNodeKey := EncodePubkey(&testkey.PublicKey) + + for _, test := range testPackets { + input, err := hex.DecodeString(test.input) + if err != nil { + t.Fatalf("invalid hex: %s", test.input) + } + packet, nodekey, _, err := Decode(input) + if err != nil { + t.Errorf("did not accept packet %s\n%v", test.input, err) + continue + } + if !reflect.DeepEqual(packet, test.wantPacket) { + t.Errorf("got %s\nwant %s", spew.Sdump(packet), spew.Sdump(test.wantPacket)) + } + if nodekey != wantNodeKey { + t.Errorf("got id %v\nwant id %v", nodekey, wantNodeKey) + } + } +} + +func hexPubkey(h string) (ret Pubkey) { + b, err := hex.DecodeString(h) + if err != nil { + panic(err) + } + if len(b) != len(ret) { + panic("invalid length") + } + copy(ret[:], b) + return ret +} diff --git a/p2p/discover/v5_encoding.go b/p2p/discover/v5_encoding.go new file mode 100644 index 0000000000..842234e790 --- /dev/null +++ b/p2p/discover/v5_encoding.go @@ -0,0 +1,659 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" + crand "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "hash" + "net" + "time" + + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" + "golang.org/x/crypto/hkdf" +) + +// TODO concurrent WHOAREYOU tie-breaker +// TODO deal with WHOAREYOU amplification factor (min packet size?) +// TODO add counter to nonce +// TODO rehandshake after X packets + +// Discovery v5 packet types. +const ( + p_pingV5 byte = iota + 1 + p_pongV5 + p_findnodeV5 + p_nodesV5 + p_requestTicketV5 + p_ticketV5 + p_regtopicV5 + p_regconfirmationV5 + p_topicqueryV5 + p_unknownV5 = byte(255) // any non-decryptable packet + p_whoareyouV5 = byte(254) // the WHOAREYOU packet +) + +// Discovery v5 packet structures. +type ( + // unknownV5 represents any packet that can't be decrypted. + unknownV5 struct { + AuthTag []byte + } + + // WHOAREYOU contains the handshake challenge. + whoareyouV5 struct { + AuthTag []byte + IDNonce [32]byte // To be signed by recipient. + RecordSeq uint64 // ENR sequence number of recipient + + node *enode.Node + sent mclock.AbsTime + } + + // PING is sent during liveness checks. + pingV5 struct { + ReqID []byte + ENRSeq uint64 + } + + // PONG is the reply to PING. + pongV5 struct { + ReqID []byte + ENRSeq uint64 + ToIP net.IP // These fields should mirror the UDP envelope address of the ping + ToPort uint16 // packet, which provides a way to discover the the external address (after NAT). + } + + // FINDNODE is a query for nodes in the given bucket. + findnodeV5 struct { + ReqID []byte + Distance uint + } + + // NODES is the reply to FINDNODE and TOPICQUERY. + nodesV5 struct { + ReqID []byte + Total uint8 + Nodes []*enr.Record + } + + // REQUESTTICKET requests a ticket for a topic queue. + requestTicketV5 struct { + ReqID []byte + Topic []byte + } + + // TICKET is the response to REQUESTTICKET. + ticketV5 struct { + ReqID []byte + Ticket []byte + } + + // REGTOPIC registers the sender in a topic queue using a ticket. + regtopicV5 struct { + ReqID []byte + Ticket []byte + ENR *enr.Record + } + + // REGCONFIRMATION is the reply to REGTOPIC. + regconfirmationV5 struct { + ReqID []byte + Registered bool + } + + // TOPICQUERY asks for nodes with the given topic. + topicqueryV5 struct { + ReqID []byte + Topic []byte + } +) + +const ( + // Encryption/authentication parameters. + authSchemeName = "gcm" + aesKeySize = 16 + gcmNonceSize = 12 + idNoncePrefix = "discovery-id-nonce" + handshakeTimeout = time.Second +) + +var ( + errTooShort = errors.New("packet too short") + errUnexpectedHandshake = errors.New("unexpected auth response, not in handshake") + errHandshakeNonceMismatch = errors.New("wrong nonce in auth response") + errInvalidAuthKey = errors.New("invalid ephemeral pubkey") + errUnknownAuthScheme = errors.New("unknown auth scheme in handshake") + errNoRecord = errors.New("expected ENR in handshake but none sent") + errInvalidNonceSig = errors.New("invalid ID nonce signature") + zeroNonce = make([]byte, gcmNonceSize) +) + +// wireCodec encodes and decodes discovery v5 packets. +type wireCodec struct { + sha256 hash.Hash + localnode *enode.LocalNode + privkey *ecdsa.PrivateKey + myChtagHash enode.ID + myWhoareyouMagic []byte + + sc *sessionCache +} + +type handshakeSecrets struct { + writeKey, readKey, authRespKey []byte +} + +type authHeader struct { + authHeaderList + isHandshake bool +} + +type authHeaderList struct { + Auth []byte // authentication info of packet + IDNonce [32]byte // IDNonce of WHOAREYOU + Scheme string // name of encryption/authentication scheme + EphemeralKey []byte // ephemeral public key + Response []byte // encrypted authResponse +} + +type authResponse struct { + Version uint + Signature []byte + Record *enr.Record `rlp:"nil"` // sender's record +} + +func (h *authHeader) DecodeRLP(r *rlp.Stream) error { + k, _, err := r.Kind() + if err != nil { + return err + } + if k == rlp.Byte || k == rlp.String { + return r.Decode(&h.Auth) + } + h.isHandshake = true + return r.Decode(&h.authHeaderList) +} + +// ephemeralKey decodes the ephemeral public key in the header. +func (h *authHeaderList) ephemeralKey(curve elliptic.Curve) *ecdsa.PublicKey { + var key encPubkey + copy(key[:], h.EphemeralKey) + pubkey, _ := decodePubkey(curve, key) + return pubkey +} + +// newWireCodec creates a wire codec. +func newWireCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *wireCodec { + c := &wireCodec{ + sha256: sha256.New(), + localnode: ln, + privkey: key, + sc: newSessionCache(1024, clock), + } + // Create magic strings for packet matching. + self := ln.ID() + c.myWhoareyouMagic = c.sha256sum(self[:], []byte("WHOAREYOU")) + copy(c.myChtagHash[:], c.sha256sum(self[:])) + return c +} + +// encode encodes a packet to a node. 'id' and 'addr' specify the destination node. The +// 'challenge' parameter should be the most recently received WHOAREYOU packet from that +// node. +func (c *wireCodec) encode(id enode.ID, addr string, packet packetV5, challenge *whoareyouV5) ([]byte, []byte, error) { + if packet.kind() == p_whoareyouV5 { + p := packet.(*whoareyouV5) + enc, err := c.encodeWhoareyou(id, p) + if err == nil { + c.sc.storeSentHandshake(id, addr, p) + } + return enc, nil, err + } + // Ensure calling code sets node if needed. + if challenge != nil && challenge.node == nil { + panic("BUG: missing challenge.node in encode") + } + writeKey := c.sc.writeKey(id, addr) + if writeKey != nil || challenge != nil { + return c.encodeEncrypted(id, addr, packet, writeKey, challenge) + } + return c.encodeRandom(id) +} + +// encodeRandom encodes a random packet. +func (c *wireCodec) encodeRandom(toID enode.ID) ([]byte, []byte, error) { + tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID()) + r := make([]byte, 44) // TODO randomize size + if _, err := crand.Read(r); err != nil { + return nil, nil, err + } + nonce := make([]byte, gcmNonceSize) + if _, err := crand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("can't get random data: %v", err) + } + b := new(bytes.Buffer) + b.Write(tag[:]) + rlp.Encode(b, nonce) + b.Write(r) + return b.Bytes(), nonce, nil +} + +// encodeWhoareyou encodes WHOAREYOU. +func (c *wireCodec) encodeWhoareyou(toID enode.ID, packet *whoareyouV5) ([]byte, error) { + // Sanity check node field to catch misbehaving callers. + if packet.RecordSeq > 0 && packet.node == nil { + panic("BUG: missing node in whoareyouV5 with non-zero seq") + } + b := new(bytes.Buffer) + b.Write(c.sha256sum(toID[:], []byte("WHOAREYOU"))) + err := rlp.Encode(b, packet) + return b.Bytes(), err +} + +// encodeEncrypted encodes an encrypted packet. +func (c *wireCodec) encodeEncrypted(toID enode.ID, toAddr string, packet packetV5, writeKey []byte, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) { + nonce := make([]byte, gcmNonceSize) + if _, err := crand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("can't get random data: %v", err) + } + + var headEnc []byte + if challenge == nil { + // Regular packet, use existing key and simply encode nonce. + headEnc, _ = rlp.EncodeToBytes(nonce) + } else { + // We're answering WHOAREYOU, generate new keys and encrypt with those. + header, sec, err := c.makeAuthHeader(nonce, challenge) + if err != nil { + return nil, nil, err + } + if headEnc, err = rlp.EncodeToBytes(header); err != nil { + return nil, nil, err + } + c.sc.storeNewSession(toID, toAddr, sec.readKey, sec.writeKey) + writeKey = sec.writeKey + } + + // Encode the packet. + body := new(bytes.Buffer) + body.WriteByte(packet.kind()) + if err := rlp.Encode(body, packet); err != nil { + return nil, nil, err + } + tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID()) + headsize := len(tag) + len(headEnc) + headbuf := make([]byte, headsize) + copy(headbuf[:], tag[:]) + copy(headbuf[len(tag):], headEnc) + + // Encrypt the body. + enc, err = encryptGCM(headbuf, writeKey, nonce, body.Bytes(), tag[:]) + return enc, nonce, err +} + +// encodeAuthHeader creates the auth header on a call packet following WHOAREYOU. +func (c *wireCodec) makeAuthHeader(nonce []byte, challenge *whoareyouV5) (*authHeaderList, *handshakeSecrets, error) { + resp := &authResponse{Version: 5} + + // Add our record to response if it's newer than what remote + // side has. + ln := c.localnode.Node() + if challenge.RecordSeq < ln.Seq() { + resp.Record = ln.Record() + } + + // Create the ephemeral key. This needs to be first because the + // key is part of the ID nonce signature. + var remotePubkey = new(ecdsa.PublicKey) + if err := challenge.node.Load((*enode.Secp256k1)(remotePubkey)); err != nil { + return nil, nil, fmt.Errorf("can't find secp256k1 key for recipient") + } + ephkey, err := crypto.GenerateKey() + if err != nil { + return nil, nil, fmt.Errorf("can't generate ephemeral key") + } + ephpubkey := encodePubkey(&ephkey.PublicKey) + + // Add ID nonce signature to response. + idsig, err := c.signIDNonce(challenge.IDNonce[:], ephpubkey[:]) + if err != nil { + return nil, nil, fmt.Errorf("can't sign: %v", err) + } + resp.Signature = idsig + + // Create session keys. + sec := c.deriveKeys(c.localnode.ID(), challenge.node.ID(), ephkey, remotePubkey, challenge) + if sec == nil { + return nil, nil, fmt.Errorf("key derivation failed") + } + + // Encrypt the authentication response and assemble the auth header. + respRLP, err := rlp.EncodeToBytes(resp) + if err != nil { + return nil, nil, fmt.Errorf("can't encode auth response: %v", err) + } + respEnc, err := encryptGCM(nil, sec.authRespKey, zeroNonce, respRLP, nil) + if err != nil { + return nil, nil, fmt.Errorf("can't encrypt auth response: %v", err) + } + head := &authHeaderList{ + Auth: nonce, + Scheme: authSchemeName, + IDNonce: challenge.IDNonce, + EphemeralKey: ephpubkey[:], + Response: respEnc, + } + return head, sec, err +} + +// deriveKeys generates session keys using elliptic-curve Diffie-Hellman key agreement. +func (c *wireCodec) deriveKeys(n1, n2 enode.ID, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, challenge *whoareyouV5) *handshakeSecrets { + eph := ecdh(priv, pub) + if eph == nil { + return nil + } + + info := []byte("discovery v5 key agreement") + info = append(info, n1[:]...) + info = append(info, n2[:]...) + kdf := hkdf.New(c.sha256reset, eph, challenge.IDNonce[:], info) + sec := handshakeSecrets{ + writeKey: make([]byte, aesKeySize), + readKey: make([]byte, aesKeySize), + authRespKey: make([]byte, aesKeySize), + } + kdf.Read(sec.writeKey) + kdf.Read(sec.readKey) + kdf.Read(sec.authRespKey) + for i := range eph { + eph[i] = 0 + } + return &sec +} + +// signIDNonce creates the ID nonce signature. +func (c *wireCodec) signIDNonce(nonce, ephkey []byte) ([]byte, error) { + idsig, err := crypto.Sign(c.idNonceHash(nonce, ephkey), c.privkey) + if err != nil { + return nil, fmt.Errorf("can't sign: %v", err) + } + return idsig[:len(idsig)-1], nil // remove recovery ID +} + +// idNonceHash computes the hash of id nonce with prefix. +func (c *wireCodec) idNonceHash(nonce, ephkey []byte) []byte { + h := c.sha256reset() + h.Write([]byte(idNoncePrefix)) + h.Write(nonce) + h.Write(ephkey) + return h.Sum(nil) +} + +// decode decodes a discovery packet. +func (c *wireCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { + // Delete timed-out handshakes. This must happen before decoding to avoid + // processing the same handshake twice. + c.sc.handshakeGC() + + if len(input) < 32 { + return enode.ID{}, nil, nil, errTooShort + } + if bytes.HasPrefix(input, c.myWhoareyouMagic) { + p, err := c.decodeWhoareyou(input) + return enode.ID{}, nil, p, err + } + sender := xorTag(input[:32], c.myChtagHash) + p, n, err := c.decodeEncrypted(sender, addr, input) + return sender, n, p, err +} + +// decodeWhoareyou decode a WHOAREYOU packet. +func (c *wireCodec) decodeWhoareyou(input []byte) (packetV5, error) { + packet := new(whoareyouV5) + err := rlp.DecodeBytes(input[32:], packet) + return packet, err +} + +// decodeEncrypted decodes an encrypted discovery packet. +func (c *wireCodec) decodeEncrypted(fromID enode.ID, fromAddr string, input []byte) (packetV5, *enode.Node, error) { + // Decode packet header. + var head authHeader + r := bytes.NewReader(input[32:]) + err := rlp.Decode(r, &head) + if err != nil { + return nil, nil, err + } + + // Decrypt and process auth response. + readKey, node, err := c.decodeAuth(fromID, fromAddr, &head) + if err != nil { + return nil, nil, err + } + + // Decrypt and decode the packet body. + headsize := len(input) - r.Len() + bodyEnc := input[headsize:] + body, err := decryptGCM(readKey, head.Auth, bodyEnc, input[:32]) + if err != nil { + if !head.isHandshake { + // Can't decrypt, start handshake. + return &unknownV5{AuthTag: head.Auth}, nil, nil + } + return nil, nil, fmt.Errorf("handshake failed: %v", err) + } + if len(body) == 0 { + return nil, nil, errTooShort + } + p, err := decodePacketBodyV5(body[0], body[1:]) + return p, node, err +} + +// decodeAuth processes an auth header. +func (c *wireCodec) decodeAuth(fromID enode.ID, fromAddr string, head *authHeader) ([]byte, *enode.Node, error) { + if !head.isHandshake { + return c.sc.readKey(fromID, fromAddr), nil, nil + } + + // Remote is attempting handshake. Verify against our last WHOAREYOU. + challenge := c.sc.getHandshake(fromID, fromAddr) + if challenge == nil { + return nil, nil, errUnexpectedHandshake + } + if head.IDNonce != challenge.IDNonce { + return nil, nil, errHandshakeNonceMismatch + } + sec, n, err := c.decodeAuthResp(fromID, fromAddr, &head.authHeaderList, challenge) + if err != nil { + return nil, n, err + } + // Swap keys to match remote. + sec.readKey, sec.writeKey = sec.writeKey, sec.readKey + c.sc.storeNewSession(fromID, fromAddr, sec.readKey, sec.writeKey) + c.sc.deleteHandshake(fromID, fromAddr) + return sec.readKey, n, err +} + +// decodeAuthResp decodes and verifies an authentication response. +func (c *wireCodec) decodeAuthResp(fromID enode.ID, fromAddr string, head *authHeaderList, challenge *whoareyouV5) (*handshakeSecrets, *enode.Node, error) { + // Decrypt / decode the response. + if head.Scheme != authSchemeName { + return nil, nil, errUnknownAuthScheme + } + ephkey := head.ephemeralKey(c.privkey.Curve) + if ephkey == nil { + return nil, nil, errInvalidAuthKey + } + sec := c.deriveKeys(fromID, c.localnode.ID(), c.privkey, ephkey, challenge) + respPT, err := decryptGCM(sec.authRespKey, zeroNonce, head.Response, nil) + if err != nil { + return nil, nil, fmt.Errorf("can't decrypt auth response header: %v", err) + } + var resp authResponse + if err := rlp.DecodeBytes(respPT, &resp); err != nil { + return nil, nil, fmt.Errorf("invalid auth response: %v", err) + } + + // Verify response node record. The remote node should include the record + // if we don't have one or if ours is older than the latest version. + node := challenge.node + if resp.Record != nil { + if node == nil || node.Seq() < resp.Record.Seq() { + n, err := enode.New(enode.ValidSchemes, resp.Record) + if err != nil { + return nil, nil, fmt.Errorf("invalid node record: %v", err) + } + if n.ID() != fromID { + return nil, nil, fmt.Errorf("record in auth respose has wrong ID: %v", n.ID()) + } + node = n + } + } + if node == nil { + return nil, nil, errNoRecord + } + + // Verify ID nonce signature. + err = c.verifyIDSignature(challenge.IDNonce[:], head.EphemeralKey, resp.Signature, node) + if err != nil { + return nil, nil, err + } + return sec, node, nil +} + +// verifyIDSignature checks that signature over idnonce was made by the node with given record. +func (c *wireCodec) verifyIDSignature(nonce, ephkey, sig []byte, n *enode.Node) error { + switch idscheme := n.Record().IdentityScheme(); idscheme { + case "v4": + var pk ecdsa.PublicKey + n.Load((*enode.Secp256k1)(&pk)) // cannot fail because record is valid + if !crypto.VerifySignature(crypto.FromECDSAPub(&pk), c.idNonceHash(nonce, ephkey), sig) { + return errInvalidNonceSig + } + return nil + default: + return fmt.Errorf("can't verify ID nonce signature against scheme %q", idscheme) + } +} + +// decodePacketBody decodes the body of an encrypted discovery packet. +func decodePacketBodyV5(ptype byte, body []byte) (packetV5, error) { + var dec packetV5 + switch ptype { + case p_pingV5: + dec = new(pingV5) + case p_pongV5: + dec = new(pongV5) + case p_findnodeV5: + dec = new(findnodeV5) + case p_nodesV5: + dec = new(nodesV5) + case p_requestTicketV5: + dec = new(requestTicketV5) + case p_ticketV5: + dec = new(ticketV5) + case p_regtopicV5: + dec = new(regtopicV5) + case p_regconfirmationV5: + dec = new(regconfirmationV5) + case p_topicqueryV5: + dec = new(topicqueryV5) + default: + return nil, fmt.Errorf("unknown packet type %d", ptype) + } + if err := rlp.DecodeBytes(body, dec); err != nil { + return nil, err + } + return dec, nil +} + +// sha256reset returns the shared hash instance. +func (c *wireCodec) sha256reset() hash.Hash { + c.sha256.Reset() + return c.sha256 +} + +// sha256sum computes sha256 on the concatenation of inputs. +func (c *wireCodec) sha256sum(inputs ...[]byte) []byte { + c.sha256.Reset() + for _, b := range inputs { + c.sha256.Write(b) + } + return c.sha256.Sum(nil) +} + +func xorTag(a []byte, b enode.ID) enode.ID { + var r enode.ID + for i := range r { + r[i] = a[i] ^ b[i] + } + return r +} + +// ecdh creates a shared secret. +func ecdh(privkey *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey) []byte { + secX, secY := pubkey.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes()) + if secX == nil { + return nil + } + sec := make([]byte, 33) + sec[0] = 0x02 | byte(secY.Bit(0)) + math.ReadBits(secX, sec[1:]) + return sec +} + +// encryptGCM encrypts pt using AES-GCM with the given key and nonce. +func encryptGCM(dest, key, nonce, pt, authData []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + panic(fmt.Errorf("can't create block cipher: %v", err)) + } + aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize) + if err != nil { + panic(fmt.Errorf("can't create GCM: %v", err)) + } + return aesgcm.Seal(dest, nonce, pt, authData), nil +} + +// decryptGCM decrypts ct using AES-GCM with the given key and nonce. +func decryptGCM(key, nonce, ct, authData []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("can't create block cipher: %v", err) + } + if len(nonce) != gcmNonceSize { + return nil, fmt.Errorf("invalid GCM nonce size: %d", len(nonce)) + } + aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize) + if err != nil { + return nil, fmt.Errorf("can't create GCM: %v", err) + } + pt := make([]byte, 0, len(ct)) + return aesgcm.Open(pt, nonce, ct, authData) +} diff --git a/p2p/discover/v5_encoding_test.go b/p2p/discover/v5_encoding_test.go new file mode 100644 index 0000000000..87ce9e72ce --- /dev/null +++ b/p2p/discover/v5_encoding_test.go @@ -0,0 +1,373 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "crypto/ecdsa" + "encoding/hex" + "fmt" + "net" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +var ( + testKeyA, _ = crypto.HexToECDSA("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") + testKeyB, _ = crypto.HexToECDSA("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") + testIDnonce = [32]byte{5, 6, 7, 8, 9, 10, 11, 12} +) + +func TestDeriveKeysV5(t *testing.T) { + t.Parallel() + + var ( + n1 = enode.ID{1} + n2 = enode.ID{2} + challenge = &whoareyouV5{} + db, _ = enode.OpenDB("") + ln = enode.NewLocalNode(db, testKeyA, testNetworkId) + c = newWireCodec(ln, testKeyA, mclock.System{}) + ) + defer db.Close() + + sec1 := c.deriveKeys(n1, n2, testKeyA, &testKeyB.PublicKey, challenge) + sec2 := c.deriveKeys(n1, n2, testKeyB, &testKeyA.PublicKey, challenge) + if sec1 == nil || sec2 == nil { + t.Fatal("key agreement failed") + } + if !reflect.DeepEqual(sec1, sec2) { + t.Fatalf("keys not equal:\n %+v\n %+v", sec1, sec2) + } +} + +// This test checks the basic handshake flow where A talks to B and A has no secrets. +func TestHandshakeV5(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + resp := net.nodeB.expectDecode(t, p_unknownV5, packet) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{ + AuthTag: resp.(*unknownV5).AuthTag, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + if len(net.nodeB.c.sc.handshakes) > 0 { + t.Fatalf("node B didn't remove handshake from challenge map") + } + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// This test checks that handshake attempts are removed within the timeout. +func TestHandshakeV5_timeout(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + resp := net.nodeB.expectDecode(t, p_unknownV5, packet) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{ + AuthTag: resp.(*unknownV5).AuthTag, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE after timeout + net.clock.Run(handshakeTimeout + 1) + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode) +} + +// This test checks handshake behavior when no record is sent in the auth response. +func TestHandshakeV5_norecord(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + resp := net.nodeB.expectDecode(t, p_unknownV5, packet) + + // A <- B WHOAREYOU + nodeA := net.nodeA.n() + if nodeA.Seq() == 0 { + t.Fatal("need non-zero sequence number") + } + challenge := &whoareyouV5{ + AuthTag: resp.(*unknownV5).AuthTag, + IDNonce: testIDnonce, + RecordSeq: nodeA.Seq(), + node: nodeA, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// In this test, A tries to send FINDNODE with existing secrets but B doesn't know +// anything about A. +func TestHandshakeV5_rekey(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + initKeys := &handshakeSecrets{ + readKey: []byte("BBBBBBBBBBBBBBBB"), + writeKey: []byte("AAAAAAAAAAAAAAAA"), + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeys.readKey, initKeys.writeKey) + + // A -> B FINDNODE (encrypted with zero keys) + findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + net.nodeB.expectDecode(t, p_unknownV5, findnode) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce} + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // Check that new keys haven't been stored yet. + if s := net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()); !bytes.Equal(s.writeKey, initKeys.writeKey) || !bytes.Equal(s.readKey, initKeys.readKey) { + t.Fatal("node A stored keys too early") + } + if s := net.nodeB.c.sc.session(net.nodeA.id(), net.nodeA.addr()); s != nil { + t.Fatal("node B stored keys too early") + } + + // A -> B FINDNODE encrypted with new keys + findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// In this test A and B have different keys before the handshake. +func TestHandshakeV5_rekey2(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + initKeysA := &handshakeSecrets{ + readKey: []byte("BBBBBBBBBBBBBBBB"), + writeKey: []byte("AAAAAAAAAAAAAAAA"), + } + initKeysB := &handshakeSecrets{ + readKey: []byte("CCCCCCCCCCCCCCCC"), + writeKey: []byte("DDDDDDDDDDDDDDDD"), + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA.readKey, initKeysA.writeKey) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB.readKey, initKeysA.writeKey) + + // A -> B FINDNODE encrypted with initKeysA + findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{Distance: 3}) + net.nodeB.expectDecode(t, p_unknownV5, findnode) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce} + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE encrypted with new keys + findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// This test checks some malformed packets. +func TestDecodeErrorsV5(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + net.nodeA.expectDecodeErr(t, errTooShort, []byte{}) + // TODO some more tests would be nice :) +} + +// This benchmark checks performance of authHeader decoding, verification and key derivation. +func BenchmarkV5_DecodeAuthSecp256k1(b *testing.B) { + net := newHandshakeTest() + defer net.close() + + var ( + idA = net.nodeA.id() + addrA = net.nodeA.addr() + challenge = &whoareyouV5{AuthTag: []byte("authresp"), RecordSeq: 0, node: net.nodeB.n()} + nonce = make([]byte, gcmNonceSize) + ) + header, _, _ := net.nodeA.c.makeAuthHeader(nonce, challenge) + challenge.node = nil // force ENR signature verification in decoder + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _, err := net.nodeB.c.decodeAuthResp(idA, addrA, header, challenge) + if err != nil { + b.Fatal(err) + } + } +} + +// This benchmark checks how long it takes to decode an encrypted ping packet. +func BenchmarkV5_DecodePing(b *testing.B) { + net := newHandshakeTest() + defer net.close() + + r := []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17} + w := []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134} + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), r, w) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), w, r) + addrB := net.nodeA.addr() + ping := &pingV5{ReqID: []byte("reqid"), ENRSeq: 5} + enc, _, err := net.nodeA.c.encode(net.nodeB.id(), addrB, ping, nil) + if err != nil { + b.Fatalf("can't encode: %v", err) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _, p, _ := net.nodeB.c.decode(enc, addrB) + if _, ok := p.(*pingV5); !ok { + b.Fatalf("wrong packet type %T", p) + } + } +} + +var pp = spew.NewDefaultConfig() + +type handshakeTest struct { + nodeA, nodeB handshakeTestNode + clock mclock.Simulated +} + +type handshakeTestNode struct { + ln *enode.LocalNode + c *wireCodec +} + +func newHandshakeTest() *handshakeTest { + t := new(handshakeTest) + t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) + t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) + return t +} + +func (t *handshakeTest) close() { + t.nodeA.ln.Database().Close() + t.nodeB.ln.Database().Close() +} + +func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { + db, _ := enode.OpenDB("") + n.ln = enode.NewLocalNode(db, key, testNetworkId) + n.ln.SetStaticIP(ip) + n.c = newWireCodec(n.ln, key, clock) +} + +func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p packetV5) ([]byte, []byte) { + t.Helper() + return n.encodeWithChallenge(t, to, nil, p) +} + +func (n *handshakeTestNode) encodeWithChallenge(t testing.TB, to handshakeTestNode, c *whoareyouV5, p packetV5) ([]byte, []byte) { + t.Helper() + // Copy challenge and add destination node. This avoids sharing 'c' among the two codecs. + var challenge *whoareyouV5 + if c != nil { + challengeCopy := *c + challenge = &challengeCopy + challenge.node = to.n() + } + // Encode to destination. + enc, authTag, err := n.c.encode(to.id(), to.addr(), p, challenge) + if err != nil { + t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) + } + t.Logf("(%s) -> (%s) %s\n%s", n.ln.ID().TerminalString(), to.id().TerminalString(), p.name(), hex.Dump(enc)) + return enc, authTag +} + +func (n *handshakeTestNode) expectDecode(t *testing.T, ptype byte, p []byte) packetV5 { + t.Helper() + dec, err := n.decode(p) + if err != nil { + t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) + } + t.Logf("(%s) %#v", n.ln.ID().TerminalString(), pp.NewFormatter(dec)) + if dec.kind() != ptype { + t.Fatalf("expected packet type %d, got %d", ptype, dec.kind()) + } + return dec +} + +func (n *handshakeTestNode) expectDecodeErr(t *testing.T, wantErr error, p []byte) { + t.Helper() + if _, err := n.decode(p); !reflect.DeepEqual(err, wantErr) { + t.Fatal(fmt.Errorf("(%s) got err %q, want %q", n.ln.ID().TerminalString(), err, wantErr)) + } +} + +func (n *handshakeTestNode) decode(input []byte) (packetV5, error) { + _, _, p, err := n.c.decode(input, "127.0.0.1") + return p, err +} + +func (n *handshakeTestNode) n() *enode.Node { + return n.ln.Node() +} + +func (n *handshakeTestNode) addr() string { + return n.ln.Node().IP().String() +} + +func (n *handshakeTestNode) id() enode.ID { + return n.ln.ID() +} diff --git a/p2p/discover/v5_session.go b/p2p/discover/v5_session.go new file mode 100644 index 0000000000..e19f25a335 --- /dev/null +++ b/p2p/discover/v5_session.go @@ -0,0 +1,123 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + crand "crypto/rand" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/hashicorp/golang-lru/simplelru" +) + +// The sessionCache keeps negotiated encryption keys and +// state for in-progress handshakes in the Discovery v5 wire protocol. +type sessionCache struct { + sessions *simplelru.LRU + handshakes map[sessionID]*whoareyouV5 + clock mclock.Clock +} + +// sessionID identifies a session or handshake. +type sessionID struct { + id enode.ID + addr string +} + +// session contains session information +type session struct { + writeKey []byte + readKey []byte + nonceCounter uint32 //nolint:unused +} + +func newSessionCache(maxItems int, clock mclock.Clock) *sessionCache { + cache, err := simplelru.NewLRU(maxItems, nil) + if err != nil { + panic("can't create session cache") + } + return &sessionCache{ + sessions: cache, + handshakes: make(map[sessionID]*whoareyouV5), + clock: clock, + } +} + +// nextNonce creates a nonce for encrypting a message to the given session. +func (sc *sessionCache) nextNonce(id enode.ID, addr string) []byte { //nolint:unused + n := make([]byte, gcmNonceSize) + crand.Read(n) + return n +} + +// session returns the current session for the given node, if any. +func (sc *sessionCache) session(id enode.ID, addr string) *session { + item, ok := sc.sessions.Get(sessionID{id, addr}) + if !ok { + return nil + } + return item.(*session) +} + +// readKey returns the current read key for the given node. +func (sc *sessionCache) readKey(id enode.ID, addr string) []byte { + if s := sc.session(id, addr); s != nil { + return s.readKey + } + return nil +} + +// writeKey returns the current read key for the given node. +func (sc *sessionCache) writeKey(id enode.ID, addr string) []byte { + if s := sc.session(id, addr); s != nil { + return s.writeKey + } + return nil +} + +// storeNewSession stores new encryption keys in the cache. +func (sc *sessionCache) storeNewSession(id enode.ID, addr string, r, w []byte) { + sc.sessions.Add(sessionID{id, addr}, &session{ + readKey: r, writeKey: w, + }) +} + +// getHandshake gets the handshake challenge we previously sent to the given remote node. +func (sc *sessionCache) getHandshake(id enode.ID, addr string) *whoareyouV5 { + return sc.handshakes[sessionID{id, addr}] +} + +// storeSentHandshake stores the handshake challenge sent to the given remote node. +func (sc *sessionCache) storeSentHandshake(id enode.ID, addr string, challenge *whoareyouV5) { + challenge.sent = sc.clock.Now() + sc.handshakes[sessionID{id, addr}] = challenge +} + +// deleteHandshake deletes handshake data for the given node. +func (sc *sessionCache) deleteHandshake(id enode.ID, addr string) { + delete(sc.handshakes, sessionID{id, addr}) +} + +// handshakeGC deletes timed-out handshakes. +func (sc *sessionCache) handshakeGC() { + deadline := sc.clock.Now().Add(-handshakeTimeout) + for key, challenge := range sc.handshakes { + if challenge.sent < deadline { + delete(sc.handshakes, key) + } + } +} diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go new file mode 100644 index 0000000000..598dba871c --- /dev/null +++ b/p2p/discover/v5_udp.go @@ -0,0 +1,832 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "context" + "crypto/ecdsa" + crand "crypto/rand" + "errors" + "fmt" + "io" + "math" + "net" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/netutil" +) + +const ( + lookupRequestLimit = 3 // max requests against a single node during lookup + findnodeResultLimit = 15 // applies in FINDNODE handler + totalNodesResponseLimit = 5 // applies in waitForNodes + nodesResponseItemLimit = 3 // applies in sendNodes + + respTimeoutV5 = 700 * time.Millisecond +) + +// codecV5 is implemented by wireCodec (and testCodec). +// +// The UDPv5 transport is split into two objects: the codec object deals with +// encoding/decoding and with the handshake; the UDPv5 object handles higher-level concerns. +type codecV5 interface { + // encode encodes a packet. The 'challenge' parameter is non-nil for calls which got a + // WHOAREYOU response. + encode(fromID enode.ID, fromAddr string, p packetV5, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) + + // decode decodes a packet. It returns an *unknownV5 packet if decryption fails. + // The fromNode return value is non-nil when the input contains a handshake response. + decode(input []byte, fromAddr string) (fromID enode.ID, fromNode *enode.Node, p packetV5, err error) +} + +// packetV5 is implemented by all discv5 packet type structs. +type packetV5 interface { + // These methods provide information and set the request ID. + name() string + kind() byte + setreqid([]byte) + // handle should perform the appropriate action to handle the packet, i.e. this is the + // place to send the response. + handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) +} + +// UDPv5 is the implementation of protocol version 5. +type UDPv5 struct { + // static fields + conn UDPConn + tab *Table + netrestrict *netutil.Netlist + priv *ecdsa.PrivateKey + localNode *enode.LocalNode + db *enode.DB + log log.Logger + clock mclock.Clock + validSchemes enr.IdentityScheme + + // channels into dispatch + packetInCh chan ReadPacket + readNextCh chan struct{} + callCh chan *callV5 + callDoneCh chan *callV5 + respTimeoutCh chan *callTimeout + + // state of dispatch + codec codecV5 + activeCallByNode map[enode.ID]*callV5 + activeCallByAuth map[string]*callV5 + callQueue map[enode.ID][]*callV5 + + // shutdown stuff + closeOnce sync.Once + closeCtx context.Context + cancelCloseCtx context.CancelFunc + wg sync.WaitGroup +} + +// callV5 represents a remote procedure call against another node. +type callV5 struct { + node *enode.Node + packet packetV5 + responseType byte // expected packet type of response + reqid []byte + ch chan packetV5 // responses sent here + err chan error // errors sent here + + // Valid for active calls only: + authTag []byte // authTag of request packet + handshakeCount int // # times we attempted handshake for this call + challenge *whoareyouV5 // last sent handshake challenge + timeout mclock.Timer +} + +// callTimeout is the response timeout event of a call. +type callTimeout struct { + c *callV5 + timer mclock.Timer +} + +// ListenV5 listens on the given connection. +func ListenV5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { + t, err := newUDPv5(conn, ln, cfg) + if err != nil { + return nil, err + } + go t.tab.loop() + t.wg.Add(2) + go t.readLoop() + go t.dispatch() + return t, nil +} + +// newUDPv5 creates a UDPv5 transport, but doesn't start any goroutines. +func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { + closeCtx, cancelCloseCtx := context.WithCancel(context.Background()) + cfg = cfg.withDefaults() + t := &UDPv5{ + // static fields + conn: conn, + localNode: ln, + db: ln.Database(), + netrestrict: cfg.NetRestrict, + priv: cfg.PrivateKey, + log: cfg.Log, + validSchemes: cfg.ValidSchemes, + clock: cfg.Clock, + // channels into dispatch + packetInCh: make(chan ReadPacket, 1), + readNextCh: make(chan struct{}, 1), + callCh: make(chan *callV5), + callDoneCh: make(chan *callV5), + respTimeoutCh: make(chan *callTimeout), + // state of dispatch + codec: newWireCodec(ln, cfg.PrivateKey, cfg.Clock), + activeCallByNode: make(map[enode.ID]*callV5), + activeCallByAuth: make(map[string]*callV5), + callQueue: make(map[enode.ID][]*callV5), + // shutdown + closeCtx: closeCtx, + cancelCloseCtx: cancelCloseCtx, + } + tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.Log) + if err != nil { + return nil, err + } + t.tab = tab + return t, nil +} + +// Self returns the local node record. +func (t *UDPv5) Self() *enode.Node { + return t.localNode.Node() +} + +// Close shuts down packet processing. +func (t *UDPv5) Close() { + t.closeOnce.Do(func() { + t.cancelCloseCtx() + t.conn.Close() + t.wg.Wait() + t.tab.close() + }) +} + +// Ping sends a ping message to the given node. +func (t *UDPv5) Ping(n *enode.Node) error { + _, err := t.ping(n) + return err +} + +// Resolve searches for a specific node with the given ID and tries to get the most recent +// version of the node record for it. It returns n if the node could not be resolved. +func (t *UDPv5) Resolve(n *enode.Node) *enode.Node { + if intable := t.tab.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + n = intable + } + // Try asking directly. This works if the node is still responding on the endpoint we have. + if resp, err := t.RequestENR(n); err == nil { + return resp + } + // Otherwise do a network lookup. + result := t.Lookup(n.ID()) + for _, rn := range result { + if rn.ID() == n.ID() && rn.Seq() > n.Seq() { + return rn + } + } + return n +} + +func (t *UDPv5) RandomNodes() enode.Iterator { + if t.tab.len() == 0 { + // All nodes were dropped, refresh. The very first query will hit this + // case and run the bootstrapping logic. + <-t.tab.refresh() + } + + return newLookupIterator(t.closeCtx, t.newRandomLookup) +} + +// Lookup performs a recursive lookup for the given target. +// It returns the closest nodes to target. +func (t *UDPv5) Lookup(target enode.ID) []*enode.Node { + return t.newLookup(t.closeCtx, target).run() +} + +// lookupRandom looks up a random target. +// This is needed to satisfy the transport interface. +func (t *UDPv5) lookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).run() +} + +// lookupSelf looks up our own node ID. +// This is needed to satisfy the transport interface. +func (t *UDPv5) lookupSelf() []*enode.Node { + return t.newLookup(t.closeCtx, t.Self().ID()).run() +} + +func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup { + var target enode.ID + crand.Read(target[:]) + return t.newLookup(ctx, target) +} + +func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { + return newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { + return t.lookupWorker(n, target) + }) +} + +// lookupWorker performs FINDNODE calls against a single node during lookup. +func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { + var ( + dists = lookupDistances(target, destNode.ID()) + nodes = nodesByDistance{target: target} + err error + ) + for i := 0; i < lookupRequestLimit && len(nodes.entries) < findnodeResultLimit; i++ { + var r []*enode.Node + r, err = t.findnode(unwrapNode(destNode), dists[i]) + if err == errClosed { + return nil, err + } + for _, n := range r { + if n.ID() != t.Self().ID() { + nodes.push(wrapNode(n), findnodeResultLimit) + } + } + } + return nodes.entries, err +} + +// lookupDistances computes the distance parameter for FINDNODE calls to dest. +// It chooses distances adjacent to logdist(target, dest), e.g. for a target +// with logdist(target, dest) = 255 the result is [255, 256, 254]. +func lookupDistances(target, dest enode.ID) (dists []int) { + td := enode.LogDist(target, dest) + dists = append(dists, td) + for i := 1; len(dists) < lookupRequestLimit; i++ { + if td+i < 256 { + dists = append(dists, td+i) + } + if td-i > 0 { + dists = append(dists, td-i) + } + } + return dists +} + +// ping calls PING on a node and waits for a PONG response. +func (t *UDPv5) ping(n *enode.Node) (uint64, error) { + resp := t.call(n, p_pongV5, &pingV5{ENRSeq: t.localNode.Node().Seq()}) + defer t.callDone(resp) + select { + case pong := <-resp.ch: + return pong.(*pongV5).ENRSeq, nil + case err := <-resp.err: + return 0, err + } +} + +// requestENR requests n's record. +func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { + nodes, err := t.findnode(n, 0) + if err != nil { + return nil, err + } + if len(nodes) != 1 { + return nil, fmt.Errorf("%d nodes in response for distance zero", len(nodes)) + } + return nodes[0], nil +} + +// requestTicket calls REQUESTTICKET on a node and waits for a TICKET response. +func (t *UDPv5) requestTicket(n *enode.Node) ([]byte, error) { //nolint:unused + resp := t.call(n, p_ticketV5, &pingV5{}) + defer t.callDone(resp) + select { + case response := <-resp.ch: + return response.(*ticketV5).Ticket, nil + case err := <-resp.err: + return nil, err + } +} + +// findnode calls FINDNODE on a node and waits for responses. +func (t *UDPv5) findnode(n *enode.Node, distance int) ([]*enode.Node, error) { + resp := t.call(n, p_nodesV5, &findnodeV5{Distance: uint(distance)}) + return t.waitForNodes(resp, distance) +} + +// waitForNodes waits for NODES responses to the given call. +func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) { + defer t.callDone(c) + + var ( + nodes []*enode.Node + seen = make(map[enode.ID]struct{}) + received, total = 0, -1 + ) + for { + select { + case responseP := <-c.ch: + response := responseP.(*nodesV5) + for _, record := range response.Nodes { + node, err := t.verifyResponseNode(c, record, distance, seen) + if err != nil { + t.log.Debug("Invalid record in "+response.name(), "id", c.node.ID(), "err", err) + continue + } + nodes = append(nodes, node) + } + if total == -1 { + total = min(int(response.Total), totalNodesResponseLimit) + } + if received++; received == total { + return nodes, nil + } + case err := <-c.err: + return nodes, err + } + } +} + +// verifyResponseNode checks validity of a record in a NODES response. +func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen map[enode.ID]struct{}) (*enode.Node, error) { + node, err := enode.New(t.validSchemes, r) + if err != nil { + return nil, err + } + if err := netutil.CheckRelayIP(c.node.IP(), node.IP()); err != nil { + return nil, err + } + if c.node.UDP() <= 1024 { + return nil, errLowPort + } + if distance != -1 { + if d := enode.LogDist(c.node.ID(), node.ID()); d != distance { + return nil, fmt.Errorf("wrong distance %d", d) + } + } + if _, ok := seen[node.ID()]; ok { + return nil, fmt.Errorf("duplicate record") + } + seen[node.ID()] = struct{}{} + return node, nil +} + +// call sends the given call and sets up a handler for response packets (of type c.responseType). +// Responses are dispatched to the call's response channel. +func (t *UDPv5) call(node *enode.Node, responseType byte, packet packetV5) *callV5 { + c := &callV5{ + node: node, + packet: packet, + responseType: responseType, + reqid: make([]byte, 8), + ch: make(chan packetV5, 1), + err: make(chan error, 1), + } + // Assign request ID. + crand.Read(c.reqid) + packet.setreqid(c.reqid) + // Send call to dispatch. + select { + case t.callCh <- c: + case <-t.closeCtx.Done(): + c.err <- errClosed + } + return c +} + +// callDone tells dispatch that the active call is done. +func (t *UDPv5) callDone(c *callV5) { + select { + case t.callDoneCh <- c: + case <-t.closeCtx.Done(): + } +} + +// dispatch runs in its own goroutine, handles incoming packets and deals with calls. +// +// For any destination node there is at most one 'active call', stored in the t.activeCall* +// maps. A call is made active when it is sent. The active call can be answered by a +// matching response, in which case c.ch receives the response; or by timing out, in which case +// c.err receives the error. When the function that created the call signals the active +// call is done through callDone, the next call from the call queue is started. +// +// Calls may also be answered by a WHOAREYOU packet referencing the call packet's authTag. +// When that happens the call is simply re-sent to complete the handshake. We allow one +// handshake attempt per call. +func (t *UDPv5) dispatch() { + defer t.wg.Done() + + // Arm first read. + t.readNextCh <- struct{}{} + + for { + select { + case c := <-t.callCh: + id := c.node.ID() + t.callQueue[id] = append(t.callQueue[id], c) + t.sendNextCall(id) + + case ct := <-t.respTimeoutCh: + active := t.activeCallByNode[ct.c.node.ID()] + if ct.c == active && ct.timer == active.timeout { + ct.c.err <- errTimeout + } + + case c := <-t.callDoneCh: + id := c.node.ID() + active := t.activeCallByNode[id] + if active != c { + panic("BUG: callDone for inactive call") + } + c.timeout.Stop() + delete(t.activeCallByAuth, string(c.authTag)) + delete(t.activeCallByNode, id) + t.sendNextCall(id) + + case p := <-t.packetInCh: + t.handlePacket(p.Data, p.Addr) + // Arm next read. + t.readNextCh <- struct{}{} + + case <-t.closeCtx.Done(): + close(t.readNextCh) + for id, queue := range t.callQueue { + for _, c := range queue { + c.err <- errClosed + } + delete(t.callQueue, id) + } + for id, c := range t.activeCallByNode { + c.err <- errClosed + delete(t.activeCallByNode, id) + delete(t.activeCallByAuth, string(c.authTag)) + } + return + } + } +} + +// startResponseTimeout sets the response timer for a call. +func (t *UDPv5) startResponseTimeout(c *callV5) { + if c.timeout != nil { + c.timeout.Stop() + } + var ( + timer mclock.Timer + done = make(chan struct{}) + ) + timer = t.clock.AfterFunc(respTimeoutV5, func() { + <-done + select { + case t.respTimeoutCh <- &callTimeout{c, timer}: + case <-t.closeCtx.Done(): + } + }) + c.timeout = timer + close(done) +} + +// sendNextCall sends the next call in the call queue if there is no active call. +func (t *UDPv5) sendNextCall(id enode.ID) { + queue := t.callQueue[id] + if len(queue) == 0 || t.activeCallByNode[id] != nil { + return + } + t.activeCallByNode[id] = queue[0] + t.sendCall(t.activeCallByNode[id]) + if len(queue) == 1 { + delete(t.callQueue, id) + } else { + copy(queue, queue[1:]) + t.callQueue[id] = queue[:len(queue)-1] + } +} + +// sendCall encodes and sends a request packet to the call's recipient node. +// This performs a handshake if needed. +func (t *UDPv5) sendCall(c *callV5) { + if len(c.authTag) > 0 { + // The call already has an authTag from a previous handshake attempt. Remove the + // entry for the authTag because we're about to generate a new authTag for this + // call. + delete(t.activeCallByAuth, string(c.authTag)) + } + + addr := &net.UDPAddr{IP: c.node.IP(), Port: c.node.UDP()} + newTag, _ := t.send(c.node.ID(), addr, c.packet, c.challenge) + c.authTag = newTag + t.activeCallByAuth[string(c.authTag)] = c + t.startResponseTimeout(c) +} + +// sendResponse sends a response packet to the given node. +// This doesn't trigger a handshake even if no keys are available. +func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet packetV5) error { + _, err := t.send(toID, toAddr, packet, nil) + return err +} + +// send sends a packet to the given node. +func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet packetV5, c *whoareyouV5) ([]byte, error) { + addr := toAddr.String() + enc, authTag, err := t.codec.encode(toID, addr, packet, c) + if err != nil { + t.log.Warn(">> "+packet.name(), "id", toID, "addr", addr, "err", err) + return authTag, err + } + _, err = t.conn.WriteToUDP(enc, toAddr) + t.log.Trace(">> "+packet.name(), "id", toID, "addr", addr) + return authTag, err +} + +// readLoop runs in its own goroutine and reads packets from the network. +func (t *UDPv5) readLoop() { + defer t.wg.Done() + + buf := make([]byte, maxPacketSize) + for range t.readNextCh { + nbytes, from, err := t.conn.ReadFromUDP(buf) + if netutil.IsTemporaryError(err) { + // Ignore temporary read errors. + t.log.Debug("Temporary UDP read error", "err", err) + continue + } else if err != nil { + // Shut down the loop for permament errors. + if err != io.EOF { + t.log.Debug("UDP read error", "err", err) + } + return + } + t.dispatchReadPacket(from, buf[:nbytes]) + } +} + +// dispatchReadPacket sends a packet into the dispatch loop. +func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { + select { + case t.packetInCh <- ReadPacket{content, from}: + return true + case <-t.closeCtx.Done(): + return false + } +} + +// handlePacket decodes and processes an incoming packet from the network. +func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { + addr := fromAddr.String() + fromID, fromNode, packet, err := t.codec.decode(rawpacket, addr) + if err != nil { + t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err) + return err + } + if fromNode != nil { + // Handshake succeeded, add to table. + t.tab.addSeenNode(wrapNode(fromNode)) + } + if packet.kind() != p_whoareyouV5 { + // WHOAREYOU logged separately to report the sender ID. + t.log.Trace("<< "+packet.name(), "id", fromID, "addr", addr) + } + packet.handle(t, fromID, fromAddr) + return nil +} + +// handleCallResponse dispatches a response packet to the call waiting for it. +func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, reqid []byte, p packetV5) { + ac := t.activeCallByNode[fromID] + if ac == nil || !bytes.Equal(reqid, ac.reqid) { + t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.name()), "id", fromID, "addr", fromAddr) + return + } + if !fromAddr.IP.Equal(ac.node.IP()) || fromAddr.Port != ac.node.UDP() { + t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.name()), "id", fromID, "addr", fromAddr) + return + } + if p.kind() != ac.responseType { + t.log.Debug(fmt.Sprintf("Wrong disv5 response type %s", p.name()), "id", fromID, "addr", fromAddr) + return + } + t.startResponseTimeout(ac) + ac.ch <- p +} + +// getNode looks for a node record in table and database. +func (t *UDPv5) getNode(id enode.ID) *enode.Node { + if n := t.tab.getNode(id); n != nil { + return n + } + if n := t.localNode.Database().Node(id); n != nil { + return n + } + return nil +} + +// UNKNOWN + +func (p *unknownV5) name() string { return "UNKNOWN/v5" } +func (p *unknownV5) kind() byte { return p_unknownV5 } +func (p *unknownV5) setreqid(id []byte) {} + +func (p *unknownV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + challenge := &whoareyouV5{AuthTag: p.AuthTag} + crand.Read(challenge.IDNonce[:]) + if n := t.getNode(fromID); n != nil { + challenge.node = n + challenge.RecordSeq = n.Seq() + } + t.sendResponse(fromID, fromAddr, challenge) +} + +// WHOAREYOU + +func (p *whoareyouV5) name() string { return "WHOAREYOU/v5" } +func (p *whoareyouV5) kind() byte { return p_whoareyouV5 } +func (p *whoareyouV5) setreqid(id []byte) {} + +func (p *whoareyouV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + c, err := p.matchWithCall(t, p.AuthTag) + if err != nil { + t.log.Debug("Invalid WHOAREYOU/v5", "addr", fromAddr, "err", err) + return + } + // Resend the call that was answered by WHOAREYOU. + t.log.Trace("<< "+p.name(), "id", c.node.ID(), "addr", fromAddr) + c.handshakeCount++ + c.challenge = p + p.node = c.node + t.sendCall(c) +} + +var ( + errChallengeNoCall = errors.New("no matching call") + errChallengeTwice = errors.New("second handshake") +) + +// matchWithCall checks whether the handshake attempt matches the active call. +func (p *whoareyouV5) matchWithCall(t *UDPv5, authTag []byte) (*callV5, error) { + c := t.activeCallByAuth[string(authTag)] + if c == nil { + return nil, errChallengeNoCall + } + if c.handshakeCount > 0 { + return nil, errChallengeTwice + } + return c, nil +} + +// PING + +func (p *pingV5) name() string { return "PING/v5" } +func (p *pingV5) kind() byte { return p_pingV5 } +func (p *pingV5) setreqid(id []byte) { p.ReqID = id } + +func (p *pingV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, &pongV5{ + ReqID: p.ReqID, + ToIP: fromAddr.IP, + ToPort: uint16(fromAddr.Port), + ENRSeq: t.localNode.Node().Seq(), + }) +} + +// PONG + +func (p *pongV5) name() string { return "PONG/v5" } +func (p *pongV5) kind() byte { return p_pongV5 } +func (p *pongV5) setreqid(id []byte) { p.ReqID = id } + +func (p *pongV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// FINDNODE + +func (p *findnodeV5) name() string { return "FINDNODE/v5" } +func (p *findnodeV5) kind() byte { return p_findnodeV5 } +func (p *findnodeV5) setreqid(id []byte) { p.ReqID = id } + +func (p *findnodeV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + if p.Distance == 0 { + t.sendNodes(fromID, fromAddr, p.ReqID, []*enode.Node{t.Self()}) + return + } + if p.Distance > 256 { + p.Distance = 256 + } + // Get bucket entries. + t.tab.mutex.Lock() + nodes := unwrapNodes(t.tab.bucketAtDistance(int(p.Distance)).entries) + t.tab.mutex.Unlock() + if len(nodes) > findnodeResultLimit { + nodes = nodes[:findnodeResultLimit] + } + t.sendNodes(fromID, fromAddr, p.ReqID, nodes) +} + +// sendNodes sends the given records in one or more NODES packets. +func (t *UDPv5) sendNodes(toID enode.ID, toAddr *net.UDPAddr, reqid []byte, nodes []*enode.Node) { + // TODO livenessChecks > 1 + // TODO CheckRelayIP + total := uint8(math.Ceil(float64(len(nodes)) / 3)) + resp := &nodesV5{ReqID: reqid, Total: total, Nodes: make([]*enr.Record, 3)} + sent := false + for len(nodes) > 0 { + items := min(nodesResponseItemLimit, len(nodes)) + resp.Nodes = resp.Nodes[:items] + for i := 0; i < items; i++ { + resp.Nodes[i] = nodes[i].Record() + } + t.sendResponse(toID, toAddr, resp) + nodes = nodes[items:] + sent = true + } + // Ensure at least one response is sent. + if !sent { + resp.Total = 1 + resp.Nodes = nil + t.sendResponse(toID, toAddr, resp) + } +} + +// NODES + +func (p *nodesV5) name() string { return "NODES/v5" } +func (p *nodesV5) kind() byte { return p_nodesV5 } +func (p *nodesV5) setreqid(id []byte) { p.ReqID = id } + +func (p *nodesV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// REQUESTTICKET + +func (p *requestTicketV5) name() string { return "REQUESTTICKET/v5" } +func (p *requestTicketV5) kind() byte { return p_requestTicketV5 } +func (p *requestTicketV5) setreqid(id []byte) { p.ReqID = id } + +func (p *requestTicketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, &ticketV5{ReqID: p.ReqID}) +} + +// TICKET + +func (p *ticketV5) name() string { return "TICKET/v5" } +func (p *ticketV5) kind() byte { return p_ticketV5 } +func (p *ticketV5) setreqid(id []byte) { p.ReqID = id } + +func (p *ticketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// REGTOPIC + +func (p *regtopicV5) name() string { return "REGTOPIC/v5" } +func (p *regtopicV5) kind() byte { return p_regtopicV5 } +func (p *regtopicV5) setreqid(id []byte) { p.ReqID = id } + +func (p *regtopicV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, ®confirmationV5{ReqID: p.ReqID, Registered: false}) +} + +// REGCONFIRMATION + +func (p *regconfirmationV5) name() string { return "REGCONFIRMATION/v5" } +func (p *regconfirmationV5) kind() byte { return p_regconfirmationV5 } +func (p *regconfirmationV5) setreqid(id []byte) { p.ReqID = id } + +func (p *regconfirmationV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// TOPICQUERY + +func (p *topicqueryV5) name() string { return "TOPICQUERY/v5" } +func (p *topicqueryV5) kind() byte { return p_topicqueryV5 } +func (p *topicqueryV5) setreqid(id []byte) { p.ReqID = id } + +func (p *topicqueryV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { +} diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go new file mode 100644 index 0000000000..24a14e1edd --- /dev/null +++ b/p2p/discover/v5_udp_test.go @@ -0,0 +1,622 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "crypto/ecdsa" + "encoding/binary" + "fmt" + "math/rand" + "net" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +// Real sockets, real crypto: this test checks end-to-end connectivity for UDPv5. +func TestEndToEndV5(t *testing.T) { + t.Parallel() + + var nodes []*UDPv5 + for i := 0; i < 5; i++ { + var cfg Config + if len(nodes) > 0 { + bn := nodes[0].Self() + cfg.Bootnodes = []*enode.Node{bn} + } + node := startLocalhostV5(t, cfg) + nodes = append(nodes, node) + defer node.Close() + } + + last := nodes[len(nodes)-1] + target := nodes[rand.Intn(len(nodes)-2)].Self() + results := last.Lookup(target.ID()) + if len(results) == 0 || results[0].ID() != target.ID() { + t.Fatalf("lookup returned wrong results: %v", results) + } +} + +func startLocalhostV5(t *testing.T, cfg Config) *UDPv5 { + cfg.PrivateKey = newkey() + db, _ := enode.OpenDB("") + ln := enode.NewLocalNode(db, cfg.PrivateKey, testNetworkId) + + // Prefix logs with node ID. + lprefix := fmt.Sprintf("(%s)", ln.ID().TerminalString()) + lfmt := log.TerminalFormat(false) + cfg.Log = testlog.Logger(t, log.LvlTrace) + cfg.Log.SetHandler(log.FuncHandler(func(r *log.Record) error { + t.Logf("%s %s", lprefix, lfmt.Format(r)) + return nil + })) + + // Listen. + socket, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + if err != nil { + t.Fatal(err) + } + realaddr := socket.LocalAddr().(*net.UDPAddr) + ln.SetStaticIP(realaddr.IP) + ln.Set(enr.UDP(realaddr.Port)) + udp, err := ListenV5(socket, ln, cfg) + if err != nil { + t.Fatal(err) + } + return udp +} + +// This test checks that incoming PING calls are handled correctly. +func TestUDPv5_pingHandling(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + test.packetIn(&pingV5{ReqID: []byte("foo")}) + test.waitPacketOut(func(p *pongV5, addr *net.UDPAddr, authTag []byte) { + if !bytes.Equal(p.ReqID, []byte("foo")) { + t.Error("wrong request ID in response:", p.ReqID) + } + if p.ENRSeq != test.table.self().Seq() { + t.Error("wrong ENR sequence number in response:", p.ENRSeq) + } + }) +} + +// This test checks that incoming 'unknown' packets trigger the handshake. +func TestUDPv5_unknownPacket(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + authTag := [12]byte{1, 2, 3} + check := func(p *whoareyouV5, wantSeq uint64) { + t.Helper() + if !bytes.Equal(p.AuthTag, authTag[:]) { + t.Error("wrong token in WHOAREYOU:", p.AuthTag, authTag[:]) + } + if p.IDNonce == ([32]byte{}) { + t.Error("all zero ID nonce") + } + if p.RecordSeq != wantSeq { + t.Errorf("wrong record seq %d in WHOAREYOU, want %d", p.RecordSeq, wantSeq) + } + } + + // Unknown packet from unknown node. + test.packetIn(&unknownV5{AuthTag: authTag[:]}) + test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { + check(p, 0) + }) + + // Make node known. + n := test.getNode(test.remotekey, test.remoteaddr).Node() + test.table.addSeenNode(wrapNode(n)) + + test.packetIn(&unknownV5{AuthTag: authTag[:]}) + test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { + check(p, n.Seq()) + }) +} + +// This test checks that incoming FINDNODE calls are handled correctly. +func TestUDPv5_findnodeHandling(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + // Create test nodes and insert them into the table. + nodes := nodesAtDistance(test.table.self().ID(), 253, 10) + fillTable(test.table, wrapNodes(nodes)) + + // Requesting with distance zero should return the node's own record. + test.packetIn(&findnodeV5{ReqID: []byte{0}, Distance: 0}) + test.expectNodes([]byte{0}, 1, []*enode.Node{test.udp.Self()}) + + // Requesting with distance > 256 caps it at 256. + test.packetIn(&findnodeV5{ReqID: []byte{1}, Distance: 4234098}) + test.expectNodes([]byte{1}, 1, nil) + + // This request gets no nodes because the corresponding bucket is empty. + test.packetIn(&findnodeV5{ReqID: []byte{2}, Distance: 254}) + test.expectNodes([]byte{2}, 1, nil) + + // This request gets all test nodes. + test.packetIn(&findnodeV5{ReqID: []byte{3}, Distance: 253}) + test.expectNodes([]byte{3}, 4, nodes) +} + +func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes []*enode.Node) { + nodeSet := make(map[enode.ID]*enr.Record) + for _, n := range wantNodes { + nodeSet[n.ID()] = n.Record() + } + for { + test.waitPacketOut(func(p *nodesV5, addr *net.UDPAddr, authTag []byte) { + if len(p.Nodes) > 3 { + test.t.Fatalf("too many nodes in response") + } + if p.Total != wantTotal { + test.t.Fatalf("wrong total response count %d", p.Total) + } + if !bytes.Equal(p.ReqID, wantReqID) { + test.t.Fatalf("wrong request ID in response: %v", p.ReqID) + } + for _, record := range p.Nodes { + n, _ := enode.New(enode.ValidSchemesForTesting, record) + want := nodeSet[n.ID()] + if want == nil { + test.t.Fatalf("unexpected node in response: %v", n) + } + if !reflect.DeepEqual(record, want) { + test.t.Fatalf("wrong record in response: %v", n) + } + delete(nodeSet, n.ID()) + } + }) + if len(nodeSet) == 0 { + return + } + } +} + +// This test checks that outgoing PING calls work. +func TestUDPv5_pingCall(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 1) + + // This ping times out. + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) {}) + if err := <-done; err != errTimeout { + t.Fatalf("want errTimeout, got %q", err) + } + + // This ping works. + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetInFrom(test.remotekey, test.remoteaddr, &pongV5{ReqID: p.ReqID}) + }) + if err := <-done; err != nil { + t.Fatal(err) + } + + // This ping gets a reply from the wrong endpoint. + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} + test.packetInFrom(test.remotekey, wrongAddr, &pongV5{ReqID: p.ReqID}) + }) + if err := <-done; err != errTimeout { + t.Fatalf("want errTimeout for reply from wrong IP, got %q", err) + } +} + +// This test checks that outgoing FINDNODE calls work and multiple NODES +// replies are aggregated. +func TestUDPv5_findnodeCall(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + // Launch the request: + var ( + distance = 230 + remote = test.getNode(test.remotekey, test.remoteaddr).Node() + nodes = nodesAtDistance(remote.ID(), distance, 8) + done = make(chan error, 1) + response []*enode.Node + ) + go func() { + var err error + response, err = test.udp.findnode(remote, distance) + done <- err + }() + + // Serve the responses: + test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { + if p.Distance != uint(distance) { + t.Fatalf("wrong bucket: %d", p.Distance) + } + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[:4]), + }) + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[4:]), + }) + }) + + // Check results: + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(response, nodes) { + t.Fatalf("wrong nodes in response") + } + + // TODO: check invalid IPs + // TODO: check invalid/unsigned record +} + +// This test checks that pending calls are re-sent when a handshake happens. +func TestUDPv5_callResend(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 2) + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + + // Ping answered by WHOAREYOU. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&whoareyouV5{AuthTag: authTag}) + }) + // Ping should be re-sent. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&pongV5{ReqID: p.ReqID}) + }) + // Answer the other ping. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&pongV5{ReqID: p.ReqID}) + }) + if err := <-done; err != nil { + t.Fatalf("unexpected ping error: %v", err) + } + if err := <-done; err != nil { + t.Fatalf("unexpected ping error: %v", err) + } +} + +// This test ensures we don't allow multiple rounds of WHOAREYOU for a single call. +func TestUDPv5_multipleHandshakeRounds(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 1) + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + + // Ping answered by WHOAREYOU. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&whoareyouV5{AuthTag: authTag}) + }) + // Ping answered by WHOAREYOU again. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&whoareyouV5{AuthTag: authTag}) + }) + if err := <-done; err != errTimeout { + t.Fatalf("unexpected ping error: %q", err) + } +} + +// This test checks that calls with n replies may take up to n * respTimeout. +func TestUDPv5_callTimeoutReset(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + // Launch the request: + var ( + distance = 230 + remote = test.getNode(test.remotekey, test.remoteaddr).Node() + nodes = nodesAtDistance(remote.ID(), distance, 8) + done = make(chan error, 1) + ) + go func() { + _, err := test.udp.findnode(remote, distance) + done <- err + }() + + // Serve two responses, slowly. + test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { + time.Sleep(respTimeout - 50*time.Millisecond) + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[:4]), + }) + + time.Sleep(respTimeout - 50*time.Millisecond) + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[4:]), + }) + }) + if err := <-done; err != nil { + t.Fatalf("unexpected error: %q", err) + } +} + +// This test checks that lookup works. +func TestUDPv5_lookup(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + + // Lookup on empty table returns no nodes. + if results := test.udp.Lookup(lookupTestnet.target.id()); len(results) > 0 { + t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) + } + + // Ensure the tester knows all nodes in lookupTestnet by IP. + for d, nn := range lookupTestnet.dists { + for i, key := range nn { + n := lookupTestnet.node(d, i) + test.getNode(key, &net.UDPAddr{IP: n.IP(), Port: n.UDP()}) + } + } + + // Seed table with initial node. + fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}) + + // Start the lookup. + resultC := make(chan []*enode.Node, 1) + go func() { + resultC <- test.udp.Lookup(lookupTestnet.target.id()) + test.close() + }() + + // Answer lookup packets. + for done := false; !done; { + done = test.waitPacketOut(func(p packetV5, to *net.UDPAddr, authTag []byte) { + recipient, key := lookupTestnet.nodeByAddr(to) + switch p := p.(type) { + case *pingV5: + test.packetInFrom(key, to, &pongV5{ReqID: p.ReqID}) + case *findnodeV5: + nodes := lookupTestnet.neighborsAtDistance(recipient, p.Distance, 3) + response := &nodesV5{ReqID: p.ReqID, Total: 1, Nodes: nodesToRecords(nodes)} + test.packetInFrom(key, to, response) + } + }) + } + + // Verify result nodes. + checkLookupResults(t, lookupTestnet, <-resultC) +} + +// udpV5Test is the framework for all tests above. +// It runs the UDPv5 transport on a virtual socket and allows testing outgoing packets. +type udpV5Test struct { + t *testing.T + pipe *dgramPipe + table *Table + db *enode.DB + udp *UDPv5 + localkey, remotekey *ecdsa.PrivateKey + remoteaddr *net.UDPAddr + nodesByID map[enode.ID]*enode.LocalNode + nodesByIP map[string]*enode.LocalNode +} + +type testCodec struct { + test *udpV5Test + id enode.ID + ctr uint64 +} + +type testCodecFrame struct { + NodeID enode.ID + AuthTag []byte + Ptype byte + Packet rlp.RawValue +} + +func (c *testCodec) encode(toID enode.ID, addr string, p packetV5, _ *whoareyouV5) ([]byte, []byte, error) { + c.ctr++ + authTag := make([]byte, 8) + binary.BigEndian.PutUint64(authTag, c.ctr) + penc, _ := rlp.EncodeToBytes(p) + frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.kind(), penc}) + return frame, authTag, err +} + +func (c *testCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { + frame, p, err := c.decodeFrame(input) + if err != nil { + return enode.ID{}, nil, nil, err + } + if p.kind() == p_whoareyouV5 { + frame.NodeID = enode.ID{} // match wireCodec behavior + } + return frame.NodeID, nil, p, nil +} + +func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p packetV5, err error) { + if err = rlp.DecodeBytes(input, &frame); err != nil { + return frame, nil, fmt.Errorf("invalid frame: %v", err) + } + switch frame.Ptype { + case p_unknownV5: + dec := new(unknownV5) + err = rlp.DecodeBytes(frame.Packet, &dec) + p = dec + case p_whoareyouV5: + dec := new(whoareyouV5) + err = rlp.DecodeBytes(frame.Packet, &dec) + p = dec + default: + p, err = decodePacketBodyV5(frame.Ptype, frame.Packet) + } + return frame, p, err +} + +func newUDPV5Test(t *testing.T) *udpV5Test { + test := &udpV5Test{ + t: t, + pipe: newpipe(), + localkey: newkey(), + remotekey: newkey(), + remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, + nodesByID: make(map[enode.ID]*enode.LocalNode), + nodesByIP: make(map[string]*enode.LocalNode), + } + test.db, _ = enode.OpenDB("") + ln := enode.NewLocalNode(test.db, test.localkey, testNetworkId) + ln.SetStaticIP(net.IP{10, 0, 0, 1}) + ln.Set(enr.UDP(30303)) + test.udp, _ = ListenV5(test.pipe, ln, Config{ + PrivateKey: test.localkey, + Log: testlog.Logger(t, log.LvlTrace), + ValidSchemes: enode.ValidSchemesForTesting, + }) + test.udp.codec = &testCodec{test: test, id: ln.ID()} + test.table = test.udp.tab + test.nodesByID[ln.ID()] = ln + // Wait for initial refresh so the table doesn't send unexpected findnode. + <-test.table.initDone + return test +} + +// handles a packet as if it had been sent to the transport. +func (test *udpV5Test) packetIn(packet packetV5) { + test.t.Helper() + test.packetInFrom(test.remotekey, test.remoteaddr, packet) +} + +// handles a packet as if it had been sent to the transport by the key/endpoint. +func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet packetV5) { + test.t.Helper() + + ln := test.getNode(key, addr) + codec := &testCodec{test: test, id: ln.ID()} + enc, _, err := codec.encode(test.udp.Self().ID(), addr.String(), packet, nil) + if err != nil { + test.t.Errorf("%s encode error: %v", packet.name(), err) + } + if test.udp.dispatchReadPacket(addr, enc) { + <-test.udp.readNextCh // unblock UDPv5.dispatch + } +} + +// getNode ensures the test knows about a node at the given endpoint. +func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.LocalNode { + id := encodePubkey(&key.PublicKey).id() + ln := test.nodesByID[id] + if ln == nil { + db, _ := enode.OpenDB("") + ln = enode.NewLocalNode(db, key, testNetworkId) + ln.SetStaticIP(addr.IP) + ln.Set(enr.UDP(addr.Port)) + test.nodesByID[id] = ln + } + test.nodesByIP[string(addr.IP)] = ln + return ln +} + +func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { + test.t.Helper() + fn := reflect.ValueOf(validate) + exptype := fn.Type().In(0) + + dgram, err := test.pipe.receive() + if err == errClosed { + return true + } + if err == errTimeout { + test.t.Fatalf("timed out waiting for %v", exptype) + return false + } + ln := test.nodesByIP[string(dgram.to.IP)] + if ln == nil { + test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to) + return false + } + codec := &testCodec{test: test, id: ln.ID()} + frame, p, err := codec.decodeFrame(dgram.data) + if err != nil { + test.t.Errorf("sent packet decode error: %v", err) + return false + } + if !reflect.TypeOf(p).AssignableTo(exptype) { + test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) + return false + } + fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(frame.AuthTag)}) + return false +} + +func (test *udpV5Test) close() { + test.t.Helper() + + test.udp.Close() + test.db.Close() + for id, n := range test.nodesByID { + if id != test.udp.Self().ID() { + n.Database().Close() + } + } + if len(test.pipe.queue) != 0 { + test.t.Fatalf("%d unmatched UDP packets in queue", len(test.pipe.queue)) + } +} diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index 44332640c7..bd066ce857 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -41,6 +41,7 @@ const ( dbNodePrefix = "n:" // Identifier to prefix node entries with dbLocalPrefix = "local:" dbDiscoverRoot = "v4" + dbDiscv5Root = "v5" // These fields are stored per ID and IP, the full key is "n::v4::findfail". // Use nodeItemKey to create those keys. @@ -172,6 +173,16 @@ func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) { return id, ip, field } +func v5Key(id ID, ip net.IP, field string) []byte { + return bytes.Join([][]byte{ + []byte(dbNodePrefix), + id[:], + []byte(dbDiscv5Root), + ip.To16(), + []byte(field), + }, []byte{':'}) +} + // localItemKey returns the key of a local node item. func localItemKey(id ID, field string) []byte { key := append([]byte(dbLocalPrefix), id[:]...) @@ -378,6 +389,16 @@ func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error { return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails)) } +// FindFailsV5 retrieves the discv5 findnode failure counter. +func (db *DB) FindFailsV5(id ID, ip net.IP) int { + return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails))) +} + +// UpdateFindFailsV5 stores the discv5 findnode failure counter. +func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error { + return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) +} + // LocalSeq retrieves the local record sequence counter. func (db *DB) localSeq(id ID) uint64 { return db.fetchUint64(localItemKey(id, dbLocalSeq)) diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go index 2adb14145d..d2b187896f 100644 --- a/p2p/enode/nodedb_test.go +++ b/p2p/enode/nodedb_test.go @@ -462,3 +462,14 @@ func TestDBExpiration(t *testing.T) { } } } + +// This test checks that expiration works when discovery v5 data is present +// in the database. +func TestDBExpireV5(t *testing.T) { + db, _ := OpenDB("") + defer db.Close() + + ip := net.IP{127, 0, 0, 1} + db.UpdateFindFailsV5(ID{}, ip, 4) + db.expireNodes() +}