Skip to content

Commit

Permalink
Merge branch 'branch-0.3.20' into 'branch-0.3.20'
Browse files Browse the repository at this point in the history
Backport: Hold rpc ports till server start tony-framework#611

See merge request !65
  • Loading branch information
zhangjunfan committed Oct 20, 2021
2 parents 18295d9 + 032b5bf commit 6e009c7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 20 deletions.
21 changes: 5 additions & 16 deletions tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
Expand Down Expand Up @@ -92,7 +90,6 @@
import com.linkedin.tony.plugin.TFTrainMetricCollector;
import com.linkedin.tony.rpc.ApplicationRpc;
import com.linkedin.tony.rpc.ApplicationRpcServer;
import com.linkedin.tony.rpc.MetricsRpc;
import com.linkedin.tony.rpc.TaskInfo;
import com.linkedin.tony.rpc.impl.MetricsRpcServer;
import com.linkedin.tony.rpc.impl.TaskStatus;
Expand Down Expand Up @@ -497,12 +494,8 @@ private boolean prepare() throws IOException {
setOpalEnv(containerEnv);

// Setup metrics RPC server.
ServerSocket rpcSocket = new ServerSocket(0);
int metricsRpcPort = rpcSocket.getLocalPort();
rpcSocket.close();
metricsRpcServer = new MetricsRpcServer();
RPC.Builder metricsServerBuilder = new RPC.Builder(yarnConf).setProtocol(MetricsRpc.class)
.setInstance(metricsRpcServer).setPort(metricsRpcPort);
metricsRpcServer = new MetricsRpcServer(yarnConf);
int metricsRpcPort = metricsRpcServer.getMetricRpcPort();
containerEnv.put(Constants.METRICS_RPC_PORT, Integer.toString(metricsRpcPort));

// Init AMRMClient
Expand Down Expand Up @@ -531,7 +524,7 @@ private boolean prepare() throws IOException {
byte[] secret = response.getClientToAMTokenMasterKey().array();
ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(appAttemptID, secret);
applicationRpcServer.setSecretManager(secretManager);
metricsServerBuilder.setSecretManager(secretManager);
metricsRpcServer.setSecretManager(secretManager);

// create token for application RPC server
Token<? extends TokenIdentifier> tensorflowClusterToken = new Token<>(identifier, secretManager);
Expand All @@ -558,11 +551,7 @@ private boolean prepare() throws IOException {
applicationRpcServer.start();

LOG.info("Starting metrics RPC server at: " + amHostname + ":" + metricsRpcPort);
RPC.Server metricsServer = metricsServerBuilder.build();
if (yarnConf.getBoolean(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, false)) {
metricsServer.refreshServiceAclWithLoadedConfiguration(yarnConf, new TonyPolicyProvider());
}
metricsServer.start();
metricsRpcServer.start();

hbMonitor.start();

Expand Down Expand Up @@ -1000,7 +989,7 @@ private void printTaskUrls() {
}
}

private ApplicationRpcServer setupRPCService(String hostname) {
private ApplicationRpcServer setupRPCService(String hostname) throws IOException {
ApplicationRpcServer rpcServer = new ApplicationRpcServer(hostname, new RpcForClient(), yarnConf);
amPort = rpcServer.getRpcPort();
return rpcServer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.linkedin.tony.TonyPolicyProvider;
import com.linkedin.tony.rpc.impl.pb.service.TensorFlowClusterPBServiceImpl;
import java.io.IOException;
import java.util.Random;
import java.net.ServerSocket;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.ProtocolSignature;
Expand All @@ -24,19 +24,21 @@

public class ApplicationRpcServer extends Thread implements TensorFlowCluster {
private static final RecordFactory RECORD_FACTORY = RecordFactoryProvider.getRecordFactory(null);
private static final Random RANDOM_NUMBER_GENERATOR = new Random();
private final int rpcPort;
private final String rpcAddress;
private final ApplicationRpc appRpc;
private ClientToAMTokenSecretManager secretManager;
private Server server;
private Configuration conf;
private ServerSocket rpcSocket;

public ApplicationRpcServer(String hostname, ApplicationRpc rpc, Configuration conf) {
public ApplicationRpcServer(String hostname, ApplicationRpc rpc, Configuration conf) throws IOException {
this.rpcAddress = hostname;
this.rpcPort = 10000 + RANDOM_NUMBER_GENERATOR.nextInt(5000) + 1;
this.appRpc = rpc;
this.conf = conf;

this.rpcSocket = new ServerSocket(0);
this.rpcPort = rpcSocket.getLocalPort();
}

@Override
Expand Down Expand Up @@ -124,6 +126,7 @@ public void run() {
translator = new TensorFlowClusterPBServiceImpl(this);
BlockingService service = com.linkedin.tony.rpc.proto.TensorFlowCluster.TensorFlowClusterService
.newReflectiveBlockingService(translator);
rpcSocket.close();
server = new RPC.Builder(conf).setProtocol(TensorFlowClusterPB.class)
.setInstance(service).setBindAddress(rpcAddress)
.setPort(rpcPort) // TODO: let RPC randomly generate it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
*/
package com.linkedin.tony.rpc.impl;

import com.linkedin.tony.TonyPolicyProvider;
import com.linkedin.tony.events.Metric;
import com.linkedin.tony.rpc.MetricsRpc;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager;


/**
Expand All @@ -24,6 +30,22 @@ public class MetricsRpcServer implements MetricsRpc {

private Map<String, Map<Integer, MetricsWritable>> metricsMap = new HashMap<>();

private Configuration conf;
private ServerSocket metricRpcSocket;
private int metricRpcPort;
private ClientToAMTokenSecretManager secretManager;

public MetricsRpcServer() {
// ignore
}

public MetricsRpcServer(Configuration conf) throws IOException {
this.conf = conf;

this.metricRpcSocket = new ServerSocket(0);
this.metricRpcPort = metricRpcSocket.getLocalPort();
}

public List<Metric> getMetrics(String taskType, int taskIndex) {
if (!metricsMap.containsKey(taskType) || !metricsMap.get(taskType).containsKey(taskIndex)) {
LOG.warn("No metrics for " + taskType + " " + taskIndex + "!");
Expand Down Expand Up @@ -53,4 +75,23 @@ public ProtocolSignature getProtocolSignature(String protocol, long clientVersio
throws IOException {
return ProtocolSignature.getProtocolSignature(this, protocol, clientVersion, clientMethodsHash);
}

public int getMetricRpcPort() {
return metricRpcPort;
}

public void setSecretManager(ClientToAMTokenSecretManager secretManager) {
this.secretManager = secretManager;
}

public void start() throws IOException {
RPC.Builder metricsServerBuilder = new RPC.Builder(conf).setProtocol(MetricsRpc.class)
.setInstance(this).setPort(metricRpcPort).setSecretManager(secretManager);
metricRpcSocket.close();
RPC.Server metricsServer = metricsServerBuilder.build();
if (conf.getBoolean(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION, false)) {
metricsServer.refreshServiceAclWithLoadedConfiguration(conf, new TonyPolicyProvider());
}
metricsServer.start();
}
}

0 comments on commit 6e009c7

Please sign in to comment.