Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add network id to peer discovery #601

Merged
merged 11 commits into from
Jun 23, 2018
Merged
22 changes: 16 additions & 6 deletions rskj-core/src/main/java/co/rsk/net/discovery/PeerExplorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class PeerExplorer {
private final Map<String, PeerDiscoveryRequest> pendingFindNodeRequests = new ConcurrentHashMap<>();

private final Map<NodeID, Node> establishedConnections = new ConcurrentHashMap<>();
private final Integer networkId;

private UDPChannel udpChannel;

Expand All @@ -72,12 +73,12 @@ public class PeerExplorer {

private long requestTimeout;

public PeerExplorer(List<String> initialBootNodes, Node localNode, NodeDistanceTable distanceTable, ECKey key, long reqTimeOut, long updatePeriod, long cleanPeriod) {
public PeerExplorer(List<String> initialBootNodes, Node localNode, NodeDistanceTable distanceTable, ECKey key, long reqTimeOut, long updatePeriod, long cleanPeriod, Integer networkId) {
this.localNode = localNode;
this.key = key;
this.distanceTable = distanceTable;
this.updateEntryLock = new ReentrantLock();

this.networkId = networkId;
loadInitialBootNodes(initialBootNodes);

this.cleaner = new PeerExplorerCleaner(this, updatePeriod, cleanPeriod);
Expand Down Expand Up @@ -111,6 +112,12 @@ public void setUDPChannel(UDPChannel udpChannel) {

public void handleMessage(DiscoveryEvent event) {
DiscoveryMessageType type = event.getMessage().getMessageType();
//If this is not from my network ignore it. But if the messages do not
//have a networkId in the message yet, then just let them through, for now.
if (event.getMessage().getNetworkId().isPresent() &&
event.getMessage().getNetworkId().getAsInt() != this.networkId) {
return;
}
if (type == DiscoveryMessageType.PING) {
this.handlePingMessage(event.getAddressIp(), (PingPeerMessage) event.getMessage());
}
Expand Down Expand Up @@ -198,7 +205,10 @@ public PingPeerMessage sendPing(InetSocketAddress nodeAddress, int attempt, Node

InetSocketAddress localAddress = this.localNode.getAddress();
String id = UUID.randomUUID().toString();
nodeMessage = PingPeerMessage.create(localAddress.getAddress().getHostAddress(), localAddress.getPort(), id, this.key);
nodeMessage = PingPeerMessage.create(
localAddress.getAddress().getHostAddress(),
localAddress.getPort(),
id, this.key, this.networkId);
udpChannel.write(new DiscoveryEvent(nodeMessage, nodeAddress));

PeerDiscoveryRequest request = PeerDiscoveryRequestBuilder.builder().messageId(id)
Expand Down Expand Up @@ -231,7 +241,7 @@ private PingPeerMessage checkPendingPeerToAddress(InetSocketAddress address) {

public PongPeerMessage sendPong(String ip, PingPeerMessage message) {
InetSocketAddress localAddress = this.localNode.getAddress();
PongPeerMessage pongPeerMessage = PongPeerMessage.create(localAddress.getHostName(), localAddress.getPort(), message.getMessageId(), this.key);
PongPeerMessage pongPeerMessage = PongPeerMessage.create(localAddress.getHostName(), localAddress.getPort(), message.getMessageId(), this.key, this.networkId);
InetSocketAddress nodeAddress = new InetSocketAddress(ip, message.getPort());
udpChannel.write(new DiscoveryEvent(pongPeerMessage, nodeAddress));

Expand All @@ -241,7 +251,7 @@ public PongPeerMessage sendPong(String ip, PingPeerMessage message) {
public FindNodePeerMessage sendFindNode(Node node) {
InetSocketAddress nodeAddress = node.getAddress();
String id = UUID.randomUUID().toString();
FindNodePeerMessage findNodePeerMessage = FindNodePeerMessage.create(this.key.getNodeId(), id, this.key);
FindNodePeerMessage findNodePeerMessage = FindNodePeerMessage.create(this.key.getNodeId(), id, this.key, this.networkId);
udpChannel.write(new DiscoveryEvent(findNodePeerMessage, nodeAddress));
PeerDiscoveryRequest request = PeerDiscoveryRequestBuilder.builder().messageId(id).relatedNode(node)
.message(findNodePeerMessage).address(nodeAddress).expectedResponse(DiscoveryMessageType.NEIGHBORS)
Expand All @@ -253,7 +263,7 @@ public FindNodePeerMessage sendFindNode(Node node) {

public NeighborsPeerMessage sendNeighbors(InetSocketAddress nodeAddress, List<Node> nodes, String id) {
List<Node> nodesToSend = getRandomizeLimitedList(nodes, MAX_NODES_PER_MSG, 5);
NeighborsPeerMessage sendNodesMessage = NeighborsPeerMessage.create(nodesToSend, id, this.key);
NeighborsPeerMessage sendNodesMessage = NeighborsPeerMessage.create(nodesToSend, id, this.key, networkId);
udpChannel.write(new DiscoveryEvent(sendNodesMessage, nodeAddress));
logger.debug(" [{}] Neighbors Sent to ip:[{}] port:[{}]", nodesToSend.size(), nodeAddress.getAddress().getHostAddress(), nodeAddress.getPort());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@

import org.apache.commons.lang3.builder.ToStringBuilder;
import org.ethereum.crypto.ECKey;
import org.ethereum.util.RLP;
import org.ethereum.util.RLPItem;
import org.ethereum.util.RLPList;
import org.ethereum.util.*;
import org.spongycastle.util.encoders.Hex;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.OptionalInt;

import static org.ethereum.util.ByteUtil.intToBytes;
import static org.ethereum.util.ByteUtil.stripLeadingZeroes;

/**
* Created by mario on 16/02/17.
Expand All @@ -44,20 +46,24 @@ public FindNodePeerMessage(byte[] wire, byte[] mdc, byte[] signature, byte[] typ
private FindNodePeerMessage() {
}

public static FindNodePeerMessage create(byte[] nodeId, String check, ECKey privKey) {
public static FindNodePeerMessage create(byte[] nodeId, String check, ECKey privKey, Integer networkId) {

/* RLP Encode data */
byte[] rlpCheck = RLP.encodeElement(check.getBytes(StandardCharsets.UTF_8));
byte[] rlpNodeId = RLP.encodeElement(nodeId);

byte[] type = new byte[]{(byte) DiscoveryMessageType.FIND_NODE.getTypeValue()};
byte[] data = RLP.encodeList(rlpNodeId, rlpCheck);

byte[] data;
byte[] rlpNetworkId = RLP.encodeElement(stripLeadingZeroes(intToBytes(networkId)));
data = RLP.encodeList(rlpNodeId, rlpCheck, rlpNetworkId);

FindNodePeerMessage message = new FindNodePeerMessage();
message.encode(type, data, privKey);

message.messageId = check;
message.nodeId = nodeId;
message.setNetworkId(OptionalInt.of(networkId));

return message;
}
Expand All @@ -72,6 +78,8 @@ public final void parse(byte[] data) {
RLPItem nodeRlp = (RLPItem) dataList.get(0);

this.nodeId = nodeRlp.getRLPData();

this.setNetworkIdWithRLP(dataList.size()>2?dataList.get(2):null);
}


Expand All @@ -88,6 +96,7 @@ public DiscoveryMessageType getMessageType() {
public String toString() {
return new ToStringBuilder(this)
.append(Hex.toHexString(this.nodeId))
.append(this.getNetworkId())
.append(this.messageId).toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalInt;

import static org.ethereum.util.ByteUtil.intToBytes;
import static org.ethereum.util.ByteUtil.stripLeadingZeroes;

/**
* Created by mario on 16/02/17.
Expand Down Expand Up @@ -62,9 +66,11 @@ public final void parse(byte[] data) {
RLPItem chk = (RLPItem) list.get(1);

this.messageId = new String(chk.getRLPData(), Charset.forName("UTF-8"));

this.setNetworkIdWithRLP(list.size()>2?list.get(2):null);
}

public static NeighborsPeerMessage create(List<Node> nodes, String check, ECKey privKey) {
public static NeighborsPeerMessage create(List<Node> nodes, String check, ECKey privKey, Integer networkId) {
byte[][] nodeRLPs = null;

if (nodes != null) {
Expand All @@ -80,10 +86,14 @@ public static NeighborsPeerMessage create(List<Node> nodes, String check, ECKey
byte[] rlpCheck = RLP.encodeElement(check.getBytes(StandardCharsets.UTF_8));

byte[] type = new byte[]{(byte) DiscoveryMessageType.NEIGHBORS.getTypeValue()};
byte[] data = RLP.encodeList(rlpListNodes, rlpCheck);
byte[] data;
byte[] tmpNetworkId = intToBytes(networkId);
byte[] rlpNetworkId = RLP.encodeElement(stripLeadingZeroes(tmpNetworkId));
data = RLP.encodeList(rlpListNodes, rlpCheck, rlpNetworkId);

NeighborsPeerMessage neighborsMessage = new NeighborsPeerMessage();
neighborsMessage.encode(type, data, privKey);
neighborsMessage.setNetworkId(OptionalInt.of(networkId));
neighborsMessage.nodes = nodes;
neighborsMessage.messageId = check;

Expand Down Expand Up @@ -111,7 +121,8 @@ public int countNodes() {
public String toString() {
return new ToStringBuilder(this)
.append(this.nodes)
.append(this.messageId).toString();
.append(this.messageId)
.append(this.getNetworkId()).toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.ethereum.crypto.ECKey;
import org.ethereum.crypto.HashUtil;
import org.ethereum.util.ByteUtil;
import org.ethereum.util.RLPElement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.spongycastle.util.BigIntegers;
import org.spongycastle.util.encoders.Hex;

import java.security.SignatureException;
import java.util.Optional;
import java.util.OptionalInt;

import static org.ethereum.crypto.HashUtil.keccak256;
import static org.ethereum.util.ByteUtil.merge;
Expand All @@ -41,6 +45,7 @@ public abstract class PeerDiscoveryMessage {
private byte[] signature;
private byte[] type;
private byte[] data;
private OptionalInt networkId;

public PeerDiscoveryMessage() {}

Expand All @@ -51,6 +56,7 @@ public PeerDiscoveryMessage(byte[] wire, byte[] mdc, byte[] signature, byte[] ty
this.data = data;
this.wire = wire;
}

public PeerDiscoveryMessage encode(byte[] type, byte[] data, ECKey privKey) {
/* [1] Calc sha3 - prepare for sig */
byte[] payload = new byte[type.length + data.length];
Expand Down Expand Up @@ -110,6 +116,22 @@ public ECKey getKey() {
return outKey;
}

public OptionalInt getNetworkId() {
return this.networkId;
}

protected void setNetworkId(final OptionalInt networkId) {
this.networkId = networkId;
}

protected void setNetworkIdWithRLP(final RLPElement networkId) {
Integer setValue = null;
if (networkId != null) {
setValue = ByteUtil.byteArrayToInt(networkId.getRLPData());
}
this.setNetworkId(Optional.ofNullable(setValue).map(OptionalInt::of).orElseGet(OptionalInt::empty));
}

public NodeID getNodeId() {
byte[] nodeID = new byte[64];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.OptionalInt;

import static org.ethereum.util.ByteUtil.intToBytes;
import static org.ethereum.util.ByteUtil.longToBytes;
import static org.ethereum.util.ByteUtil.stripLeadingZeroes;

Expand All @@ -46,7 +48,7 @@ public PingPeerMessage(byte[] wire, byte[] mdc, byte[] signature, byte[] type, b

private PingPeerMessage() {}

public static PingPeerMessage create(String host, int port, String check, ECKey privKey) {
public static PingPeerMessage create(String host, int port, String check, ECKey privKey, Integer networkId) {
/* RLP Encode data */
byte[] rlpIp = RLP.encodeElement(host.getBytes(StandardCharsets.UTF_8));

Expand All @@ -56,17 +58,19 @@ public static PingPeerMessage create(String host, int port, String check, ECKey
byte[] rlpIpTo = RLP.encodeElement(host.getBytes(StandardCharsets.UTF_8));
byte[] tmpPortTo = longToBytes(port);
byte[] rlpPortTo = RLP.encodeElement(stripLeadingZeroes(tmpPortTo));

byte[] rlpCheck = RLP.encodeElement(check.getBytes(StandardCharsets.UTF_8));

byte[] type = new byte[]{(byte) DiscoveryMessageType.PING.getTypeValue()};
byte[] rlpFromList = RLP.encodeList(rlpIp, rlpPort, rlpPort);
byte[] rlpToList = RLP.encodeList(rlpIpTo, rlpPortTo, rlpPortTo);
byte[] data = RLP.encodeList(rlpFromList, rlpToList, rlpCheck);
byte[] rlpCheck = RLP.encodeElement(check.getBytes(StandardCharsets.UTF_8));
byte[] data;
byte[] tmpNetworkId = intToBytes(networkId);
byte[] rlpNetworkID = RLP.encodeElement(stripLeadingZeroes(tmpNetworkId));
data = RLP.encodeList(rlpFromList, rlpToList, rlpCheck, rlpNetworkID);

PingPeerMessage message = new PingPeerMessage();
message.encode(type, data, privKey);

message.setNetworkId(OptionalInt.of(networkId));
message.messageId = check;
message.host = host;
message.port = port;
Expand All @@ -85,8 +89,10 @@ public final void parse(byte[] data) {
this.host = new String(ipB, Charset.forName("UTF-8"));
this.port = ByteUtil.byteArrayToInt(fromList.get(1).getRLPData());
this.messageId = new String(chk.getRLPData(), Charset.forName("UTF-8"));
}

//Message from nodes that do not have this
this.setNetworkIdWithRLP(dataList.size()>3?dataList.get(3):null);
}

public String getMessageId() {
return this.messageId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.OptionalInt;

import static org.ethereum.util.ByteUtil.intToBytes;
import static org.ethereum.util.ByteUtil.longToBytes;
import static org.ethereum.util.ByteUtil.stripLeadingZeroes;

Expand All @@ -47,7 +49,7 @@ public PongPeerMessage(byte[] wire, byte[] mdc, byte[] signature, byte[] type, b
private PongPeerMessage() {
}

public static PongPeerMessage create(String host, int port, String check, ECKey privKey) {
public static PongPeerMessage create(String host, int port, String check, ECKey privKey, Integer networkId) {
/* RLP Encode data */
byte[] rlpIp = RLP.encodeElement(host.getBytes(StandardCharsets.UTF_8));

Expand All @@ -63,11 +65,16 @@ public static PongPeerMessage create(String host, int port, String check, ECKey
byte[] type = new byte[]{(byte) DiscoveryMessageType.PONG.getTypeValue()};
byte[] rlpFromList = RLP.encodeList(rlpIp, rlpPort, rlpPort);
byte[] rlpToList = RLP.encodeList(rlpIpTo, rlpPortTo, rlpPortTo);
byte[] data = RLP.encodeList(rlpFromList, rlpToList, rlpCheck);
byte[] data;

byte[] tmpNetworkId = intToBytes(networkId);
byte[] rlpNetworkID = RLP.encodeElement(stripLeadingZeroes(tmpNetworkId));
data = RLP.encodeList(rlpFromList, rlpToList, rlpCheck, rlpNetworkID);

PongPeerMessage message = new PongPeerMessage();
message.encode(type, data, privKey);

message.setNetworkId(OptionalInt.of(networkId));
message.messageId = check;
message.host = host;
message.port = port;
Expand Down Expand Up @@ -95,6 +102,9 @@ public final void parse(byte[] data) {
RLPItem chk = (RLPItem) dataList.get(2);

this.messageId = new String(chk.getRLPData(), Charset.forName("UTF-8"));

//Message from nodes that do not have this
this.setNetworkIdWithRLP(dataList.size()>3?dataList.get(3):null);
}

public String getMessageId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ public BlockValidationRule minerServerBlockValidationRule(
@Bean
public PeerExplorer peerExplorer(RskSystemProperties rskConfig) {
ECKey key = rskConfig.getMyKey();
Integer networkId = rskConfig.networkId();
Node localNode = new Node(key.getNodeId(), rskConfig.getPublicIp(), rskConfig.getPeerPort());
NodeDistanceTable distanceTable = new NodeDistanceTable(KademliaOptions.BINS, KademliaOptions.BUCKET_SIZE, localNode);
long msgTimeOut = rskConfig.peerDiscoveryMessageTimeOut();
Expand All @@ -219,7 +220,7 @@ public PeerExplorer peerExplorer(RskSystemProperties rskConfig) {
initialBootNodes.add(address.getHostName() + ":" + address.getPort());
}
}
return new PeerExplorer(initialBootNodes, localNode, distanceTable, key, msgTimeOut, refreshPeriod, cleanPeriod);
return new PeerExplorer(initialBootNodes, localNode, distanceTable, key, msgTimeOut, refreshPeriod, cleanPeriod, networkId);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.spongycastle.util.encoders.Hex;

import java.util.ArrayList;
import java.util.OptionalInt;
import java.util.UUID;

/**
Expand All @@ -38,6 +39,7 @@ public class NodeChallengeManagerTest {
private static final String KEY_1 = "bd1d20e480dfb1c9c07ba0bc8cf9052f89923d38b5128c5dbfc18d4eea38261f";
private static final String HOST_1 = "localhost";
private static final int PORT_1 = 44035;
private static final int NETWORK_ID = 1;

private static final String KEY_2 = "bd2d20e480dfb1c9c07ba0bc8cf9052f89923d38b5128c5dbfc18d4eea38262f";
private static final String HOST_2 = "localhost";
Expand All @@ -63,7 +65,7 @@ public void startChallenge() {
Node node3 = new Node(key3.getNodeId(), HOST_3, PORT_3);

NodeDistanceTable distanceTable = new NodeDistanceTable(KademliaOptions.BINS, KademliaOptions.BUCKET_SIZE, node1);
PeerExplorer peerExplorer = new PeerExplorer(new ArrayList<>(), node1, distanceTable, new ECKey(), TIMEOUT, UPDATE, CLEAN);
PeerExplorer peerExplorer = new PeerExplorer(new ArrayList<>(), node1, distanceTable, new ECKey(), TIMEOUT, UPDATE, CLEAN, NETWORK_ID);
peerExplorer.setUDPChannel(Mockito.mock(UDPChannel.class));

NodeChallengeManager manager = new NodeChallengeManager();
Expand Down
Loading