Skip to content

Commit

Permalink
Hold rpc ports till server start (#611)
Browse files Browse the repository at this point in the history
* Hold rpc ports till server start

* Remove serverPortHolder
  • Loading branch information
zuston authored Oct 19, 2021
1 parent 7622a6b commit 4f302c4
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 21 deletions.
23 changes: 6 additions & 17 deletions tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.linkedin.tony.events.EventType;
import com.linkedin.tony.rpc.ApplicationRpc;
import com.linkedin.tony.rpc.ApplicationRpcServer;
import com.linkedin.tony.rpc.MetricsRpc;
import com.linkedin.tony.rpc.TaskInfo;
import com.linkedin.tony.rpc.impl.MetricsRpcServer;
import com.linkedin.tony.rpc.impl.TaskStatus;
Expand Down Expand Up @@ -55,13 +54,11 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
Expand Down Expand Up @@ -446,17 +443,13 @@ private boolean prepare() throws IOException {

// Setup application RPC server
String amHostname = Utils.getCurrentHostName();
applicationRpcServer = setupRPCService(amHostname);
applicationRpcServer = setupAppRPCService(amHostname);
containerEnv.put(Constants.AM_HOST, amHostname);
containerEnv.put(Constants.AM_PORT, Integer.toString(amPort));

// Setup metrics RPC server.
ServerSocket rpcSocket = new ServerSocket(0);
int metricsRpcPort = rpcSocket.getLocalPort();
rpcSocket.close();
metricsRpcServer = new MetricsRpcServer();
RPC.Builder metricsServerBuilder = new RPC.Builder(yarnConf).setProtocol(MetricsRpc.class)
.setInstance(metricsRpcServer).setPort(metricsRpcPort);
metricsRpcServer = new MetricsRpcServer(yarnConf);
int metricsRpcPort = metricsRpcServer.getMetricRpcPort();
containerEnv.put(Constants.METRICS_RPC_PORT, Integer.toString(metricsRpcPort));

// Init AMRMClient
Expand Down Expand Up @@ -484,7 +477,7 @@ private boolean prepare() throws IOException {
byte[] secret = response.getClientToAMTokenMasterKey().array();
ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(appAttemptID, secret);
applicationRpcServer.setSecretManager(secretManager);
metricsServerBuilder.setSecretManager(secretManager);
metricsRpcServer.setSecretManager(secretManager);

// create token for application RPC server
Token<? extends TokenIdentifier> tensorflowClusterToken = new Token<>(identifier, secretManager);
Expand All @@ -511,11 +504,7 @@ private boolean prepare() throws IOException {
applicationRpcServer.start();

LOG.info("Starting metrics RPC server at: " + amHostname + ":" + metricsRpcPort);
RPC.Server metricsServer = metricsServerBuilder.build();
if (yarnConf.getBoolean(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, false)) {
metricsServer.refreshServiceAclWithLoadedConfiguration(yarnConf, new TonyPolicyProvider());
}
metricsServer.start();
metricsRpcServer.start();

hbMonitor.start();

Expand Down Expand Up @@ -840,7 +829,7 @@ private void printTaskUrls() {
}
}

private ApplicationRpcServer setupRPCService(String hostname) throws IOException {
private ApplicationRpcServer setupAppRPCService(String hostname) throws IOException {
ApplicationRpcServer rpcServer = new ApplicationRpcServer(hostname, new RpcForClient(), yarnConf);
amPort = rpcServer.getRpcPort();
return rpcServer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import com.linkedin.tony.rpc.impl.pb.service.TonyClusterPBServiceImpl;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.Random;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.ProtocolSignature;
Expand All @@ -25,7 +25,7 @@

public class ApplicationRpcServer extends Thread implements TonyCluster {
private static final RecordFactory RECORD_FACTORY = RecordFactoryProvider.getRecordFactory(null);
private static final Random RANDOM_NUMBER_GENERATOR = new Random();
private ServerSocket rpcSocket;
private final int rpcPort;
private final String rpcAddress;
private final ApplicationRpc appRpc;
Expand All @@ -35,9 +35,10 @@ public class ApplicationRpcServer extends Thread implements TonyCluster {

public ApplicationRpcServer(String hostname, ApplicationRpc rpc, Configuration conf) throws IOException {
this.rpcAddress = hostname;
ServerSocket rpcSocket = new ServerSocket(0);

this.rpcSocket = new ServerSocket(0);
this.rpcPort = rpcSocket.getLocalPort();
rpcSocket.close();

this.appRpc = rpc;
this.conf = conf;
}
Expand Down Expand Up @@ -127,6 +128,7 @@ public void run() {
translator = new TonyClusterPBServiceImpl(this);
BlockingService service = com.linkedin.tony.rpc.proto.TonyCluster.TonyClusterService
.newReflectiveBlockingService(translator);
rpcSocket.close();
server = new RPC.Builder(conf).setProtocol(TonyClusterPB.class)
.setInstance(service).setBindAddress(rpcAddress)
.setPort(rpcPort) // TODO: let RPC randomly generate it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
*/
package com.linkedin.tony.rpc.impl;

import com.linkedin.tony.TonyPolicyProvider;
import com.linkedin.tony.events.Metric;
import com.linkedin.tony.rpc.MetricsRpc;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager;


/**
Expand All @@ -24,6 +30,18 @@ public class MetricsRpcServer implements MetricsRpc {

private Map<String, Map<Integer, MetricsWritable>> metricsMap = new HashMap<>();

private Configuration conf;
private ServerSocket metricRpcSocket;
private int metricRpcPort;
private ClientToAMTokenSecretManager secretManager;

public MetricsRpcServer(Configuration conf) throws IOException {
this.conf = conf;

this.metricRpcSocket = new ServerSocket(0);
this.metricRpcPort = metricRpcSocket.getLocalPort();
}

public List<Metric> getMetrics(String taskType, int taskIndex) {
if (!metricsMap.containsKey(taskType) || !metricsMap.get(taskType).containsKey(taskIndex)) {
LOG.warn("No metrics for " + taskType + " " + taskIndex + "!");
Expand Down Expand Up @@ -53,4 +71,23 @@ public ProtocolSignature getProtocolSignature(String protocol, long clientVersio
throws IOException {
return ProtocolSignature.getProtocolSignature(this, protocol, clientVersion, clientMethodsHash);
}

public int getMetricRpcPort() {
return metricRpcPort;
}

public void setSecretManager(ClientToAMTokenSecretManager secretManager) {
this.secretManager = secretManager;
}

public void start() throws IOException {
RPC.Builder metricsServerBuilder = new RPC.Builder(conf).setProtocol(MetricsRpc.class)
.setInstance(this).setPort(metricRpcPort).setSecretManager(secretManager);
metricRpcSocket.close();
RPC.Server metricsServer = metricsServerBuilder.build();
if (conf.getBoolean(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, false)) {
metricsServer.refreshServiceAclWithLoadedConfiguration(conf, new TonyPolicyProvider());
}
metricsServer.start();
}
}

0 comments on commit 4f302c4

Please sign in to comment.