From b2a96ab803787e8d0bf443c1392be88495faa701 Mon Sep 17 00:00:00 2001 From: yelias Date: Wed, 14 Aug 2024 16:26:24 +0300 Subject: [PATCH 1/8] mnist with summaries updaetd to TF v2 Signed-off-by: yelias --- .../mnist_with_summaries/Dockerfile | 2 +- .../mnist_with_summaries.py | 302 ++++++++---------- 2 files changed, 140 insertions(+), 164 deletions(-) diff --git a/examples/tensorflow/mnist_with_summaries/Dockerfile b/examples/tensorflow/mnist_with_summaries/Dockerfile index a2c5b77abb..fb0ff6b99c 100644 --- a/examples/tensorflow/mnist_with_summaries/Dockerfile +++ b/examples/tensorflow/mnist_with_summaries/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM tensorflow/tensorflow:1.11.0 +FROM tensorflow/tensorflow:2.17.0 ADD examples/tensorflow/mnist_with_summaries/ /var/tf_mnist ENTRYPOINT ["python", "/var/tf_mnist/mnist_with_summaries.py"] diff --git a/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py b/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py index b5c47c65b8..0f0542590d 100644 --- a/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py +++ b/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py @@ -18,165 +18,142 @@ naming summary tags so that they are grouped meaningfully in TensorBoard. It demonstrates the functionality of every TensorBoard dashboard. """ -from __future__ import absolute_import, division, print_function - import argparse import os -import sys +import numpy as np import tensorflow as tf -from tensorflow.examples.tutorials.mnist import input_data - -FLAGS = None - - -def train(): - # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir, fake_data=FLAGS.fake_data) - - sess = tf.InteractiveSession() - # Create a multilayer model. - - # Input placeholders - with tf.name_scope("input"): - x = tf.placeholder(tf.float32, [None, 784], name="x-input") - y_ = tf.placeholder(tf.int64, [None], name="y-input") - - with tf.name_scope("input_reshape"): - image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) - tf.summary.image("input", image_shaped_input, 10) - - # We can't initialize these variables to 0 - the network will get stuck. - def weight_variable(shape): - """Create a weight variable with appropriate initialization.""" - initial = tf.truncated_normal(shape, stddev=0.1) - return tf.Variable(initial) - - def bias_variable(shape): - """Create a bias variable with appropriate initialization.""" - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) - - def variable_summaries(var): - """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" - with tf.name_scope("summaries"): - mean = tf.reduce_mean(var) - tf.summary.scalar("mean", mean) - with tf.name_scope("stddev"): - stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) - tf.summary.scalar("stddev", stddev) - tf.summary.scalar("max", tf.reduce_max(var)) - tf.summary.scalar("min", tf.reduce_min(var)) - tf.summary.histogram("histogram", var) - - def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu): - """Reusable code for making a simple neural net layer. - It does a matrix multiply, bias add, and then uses ReLU to nonlinearize. - It also sets up name scoping so that the resultant graph is easy to read, - and adds a number of summary ops. +from tensorflow.keras.datasets import mnist + + +def load_data(fake_data=False): + """ + Loads the MNIST dataset and converts it into TensorFlow datasets. 
+ + Args: + fake_data (bool): If `True`, loads a fake dataset for testing purposes. + If `False`, loads the real MNIST dataset. + + Returns: + train_ds (tf.data.Dataset): Dataset containing the training data (images and labels). + test_ds (tf.data.Dataset): Dataset containing the test data (images and labels). + """ + if fake_data: + (x_train, y_train), (x_test, y_test) = load_fake_data() + else: + (x_train, y_train), (x_test, y_test) = mnist.load_data(path=FLAGS.data_path) + # Create TensorFlow datasets from the NumPy arrays + train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_ds, test_ds + + +def load_fake_data(): + x_train = np.random.randint(0, 256, (60000, 28, 28)).astype(np.uint8) + y_train = np.random.randint(0, 10, (60000,)).astype(np.uint8) + x_test = np.random.randint(0, 256, (10000, 28, 28)).astype(np.uint8) + y_test = np.random.randint(0, 10, (10000,)).astype(np.uint8) + + return (x_train, y_train), (x_test, y_test) + + +def preprocess(ds): + """ + Preprocesses the dataset by normalizing the images, shuffling, batching, and prefetching. + + Args: + ds (tf.data.Dataset): The dataset to preprocess (either training or testing data). + + Returns: + ds (tf.data.Dataset): The preprocessed dataset. + """ + + def normalize_img(image, label): """ - # Adding a name scope ensures logical grouping of the layers in the graph. - with tf.name_scope(layer_name): - # This Variable will hold the state of the weights for the layer - with tf.name_scope("weights"): - weights = weight_variable([input_dim, output_dim]) - variable_summaries(weights) - with tf.name_scope("biases"): - biases = bias_variable([output_dim]) - variable_summaries(biases) - with tf.name_scope("Wx_plus_b"): - preactivate = tf.matmul(input_tensor, weights) + biases - tf.summary.histogram("pre_activations", preactivate) - activations = act(preactivate, name="activation") - tf.summary.histogram("activations", activations) - return activations - - hidden1 = nn_layer(x, 784, 500, "layer1") - - with tf.name_scope("dropout"): - keep_prob = tf.placeholder(tf.float32) - tf.summary.scalar("dropout_keep_probability", keep_prob) - dropped = tf.nn.dropout(hidden1, keep_prob) - - # Do not apply softmax activation yet, see below. - y = nn_layer(dropped, 500, 10, "layer2", act=tf.identity) - - with tf.name_scope("cross_entropy"): - # The raw formulation of cross-entropy, - # - # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.softmax(y)), - # reduction_indices=[1])) - # - # can be numerically unstable. - # - # So here we use tf.losses.sparse_softmax_cross_entropy on the - # raw logit outputs of the nn_layer above, and then average across - # the batch. 
- with tf.name_scope("total"): - cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y) - tf.summary.scalar("cross_entropy", cross_entropy) - - with tf.name_scope("train"): - train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(cross_entropy) - - with tf.name_scope("accuracy"): - with tf.name_scope("correct_prediction"): - correct_prediction = tf.equal(tf.argmax(y, 1), y_) - with tf.name_scope("accuracy"): - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - tf.summary.scalar("accuracy", accuracy) - - # Merge all the summaries and write them out to - # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default) - merged = tf.summary.merge_all() - train_writer = tf.summary.FileWriter(FLAGS.log_dir + "/train", sess.graph) - test_writer = tf.summary.FileWriter(FLAGS.log_dir + "/test") - tf.global_variables_initializer().run() - - # Train the model, and also write summaries. - # Every 10th step, measure test-set accuracy, and write test summaries - # All other steps, run train_step on training data, & add training summaries - - def feed_dict(train): # pylint: disable=redefined-outer-name - """Make a TensorFlow feed_dict: maps data onto Tensor placeholders.""" - if train or FLAGS.fake_data: - xs, ys = mnist.train.next_batch(FLAGS.batch_size, fake_data=FLAGS.fake_data) - k = FLAGS.dropout - else: - xs, ys = mnist.test.images, mnist.test.labels - k = 1.0 - return {x: xs, y_: ys, keep_prob: k} - - for i in range(FLAGS.max_steps): - if i % 10 == 0: # Record summaries and test-set accuracy - summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) - test_writer.add_summary(summary, i) - print("Accuracy at step %s: %s" % (i, acc)) - else: # Record train set summaries, and train - if i % 100 == 99: # Record execution stats - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - run_metadata = tf.RunMetadata() - summary, _ = sess.run( - [merged, train_step], - feed_dict=feed_dict(True), - options=run_options, - run_metadata=run_metadata, - ) - train_writer.add_run_metadata(run_metadata, "step%03d" % i) - train_writer.add_summary(summary, i) - print("Adding run metadata for", i) - else: # Record a summary - summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True)) - train_writer.add_summary(summary, i) - train_writer.close() - test_writer.close() - - -def main(_): - if tf.gfile.Exists(FLAGS.log_dir): - tf.gfile.DeleteRecursively(FLAGS.log_dir) - tf.gfile.MakeDirs(FLAGS.log_dir) - train() + Normalizes images by scaling pixel values from the range [0, 255] to [0, 1]. + + Args: + image (tf.Tensor): The image tensor. + label (tf.Tensor): The corresponding label tensor. + + Returns: + tuple: The normalized image and the corresponding label. + """ + image = tf.cast(image, tf.float32) / 255.0 + return image, label + + # Map the normalization function across the dataset + ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.cache() # Cache the dataset to improve performance + ds = ds.shuffle( + buffer_size=10000 + ) # Shuffle the dataset with a buffer size of 10,000 + ds = ds.batch(FLAGS.batch_size) # Batch the dataset + ds = ds.prefetch( + buffer_size=tf.data.experimental.AUTOTUNE + ) # Prefetch to improve performance + return ds + + +def build_model(): + """ + Builds a simple neural network model using Keras Sequential API. + + Returns: + model (tf.keras.Model): The compiled Keras model. 
+ """ + model = tf.keras.Sequential( + [ + tf.keras.layers.Input( + shape=(28, 28, 1) + ), # Input layer with the shape of MNIST images + tf.keras.layers.Flatten(), + tf.keras.layers.Dense( + 128, activation="relu" + ), # Dense layer with 128 neurons and ReLU activation + tf.keras.layers.Dropout( + 1 - FLAGS.dropout + ), # Dropout layer to prevent overfitting + tf.keras.layers.Dense( + 10, activation="softmax" + ), # Output layer with 10 neurons (one for each class) + ] + ) + # Define an optimizer with a specific learning rate + optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) + # Compile the model with Adam optimizer and sparse categorical crossentropy loss + model.compile( + optimizer=optimizer, + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + return model + + +def main(): + """ + The main function to load data, preprocess it, build the model, and train it. + """ + # Load and preprocess data + train_ds, test_ds = load_data(fake_data=FLAGS.fake_data) + train_ds = preprocess(train_ds) + test_ds = preprocess(test_ds) + + # Build model + model = build_model() + + # Setup TensorBoard + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=FLAGS.log_dir, histogram_freq=1 + ) + + # Train the model + model.fit( + train_ds, + epochs=FLAGS.epochs, + validation_data=test_ds, + callbacks=[tensorboard_callback], + ) if __name__ == "__main__": @@ -190,13 +167,13 @@ def main(_): help="If true, uses fake data for unit testing.", ) parser.add_argument( - "--max_steps", type=int, default=1000, help="Number of steps to run trainer." + "--epochs", type=int, default=5, help="Number of epochs for training." ) parser.add_argument( "--learning_rate", type=float, default=0.001, help="Initial learning rate" ) parser.add_argument( - "--batch_size", type=int, default=100, help="Training batch size" + "--batch_size", type=int, default=64, help="Training batch size" ) parser.add_argument( "--dropout", @@ -205,12 +182,10 @@ def main(_): help="Keep probability for training dropout.", ) parser.add_argument( - "--data_dir", + "--data_path", type=str, - default=os.path.join( - os.getenv("TEST_TMPDIR", "/tmp"), "tensorflow/mnist/input_data" - ), - help="Directory for storing input data", + default="mnist.npz", + help="Path where to cache the dataset locally (relative to ~/.keras/datasets).", ) parser.add_argument( "--log_dir", @@ -221,5 +196,6 @@ def main(_): ), help="Summaries log directory", ) - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + FLAGS, _ = parser.parse_known_args() + print(f"Run script with {FLAGS=}") + main() From abd968e89baa22b83efa6d2f5450714ac18f83d0 Mon Sep 17 00:00:00 2001 From: yelias Date: Wed, 21 Aug 2024 16:29:27 +0300 Subject: [PATCH 2/8] tf_sample updaetd to TF v2 Signed-off-by: yelias --- examples/tensorflow/tf_sample/Dockerfile | 2 +- examples/tensorflow/tf_sample/tf_smoke.py | 157 +++++++--------------- 2 files changed, 47 insertions(+), 112 deletions(-) diff --git a/examples/tensorflow/tf_sample/Dockerfile b/examples/tensorflow/tf_sample/Dockerfile index 690fb6ec20..5ce15848f5 100644 --- a/examples/tensorflow/tf_sample/Dockerfile +++ b/examples/tensorflow/tf_sample/Dockerfile @@ -1,4 +1,4 @@ -FROM tensorflow/tensorflow:1.8.0 +FROM tensorflow/tensorflow:2.17.0 RUN pip install retrying RUN mkdir -p /opt/kubeflow COPY examples/tensorflow/tf_sample/tf_smoke.py /opt/kubeflow/ diff --git a/examples/tensorflow/tf_sample/tf_smoke.py b/examples/tensorflow/tf_sample/tf_smoke.py index 
bbde61167c..2fc5c3899d 100644 --- a/examples/tensorflow/tf_sample/tf_smoke.py +++ b/examples/tensorflow/tf_sample/tf_smoke.py @@ -1,28 +1,33 @@ -"""Train a simple TF program to verify we can execute ops. +""" +Run a distributed TensorFlow program using +MultiWorkerMirroredStrategy to verify we can execute ops. The program does a simple matrix multiplication. -Only the master assigns ops to devices/workers. +With MultiWorkerMirroredStrategy, the operations are distributed across multiple workers, +and each worker performs the matrix multiplication. The strategy handles the distribution +of operations and aggregation of results. -The master will assign ops to every task in the cluster. This way we can verify -that distributed training is working by executing ops on all devices. +This way we can verify that distributed training is working by executing ops on all devices. """ import argparse -import json -import logging -import os +import time +import numpy as np import retrying import tensorflow as tf +# Set up the MultiWorkerMirroredStrategy to distribute computation across multiple workers. +strategy = tf.distribute.MultiWorkerMirroredStrategy() + def parse_args(): """Parse the command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( - "--sleep_secs", default=0, type=int, help=("Amount of time to sleep at the end") + "--sleep_secs", default=0, type=int, help="Amount of time to sleep at the end" ) # TODO(jlewi): We ignore unknown arguments because the backend is currently @@ -38,116 +43,46 @@ def parse_args(): wait_exponential_max=10000, stop_max_delay=60 * 3 * 1000, ) -def run(server, cluster_spec): # pylint: disable=too-many-statements, too-many-locals - """Build the graph and run the example. - - Args: - server: The TensorFlow server to use. - - Raises: - RuntimeError: If the expected log entries aren't found. +def matrix_multiplication_fn(): """ + Perform matrix multiplication on two example matrices using TensorFlow. - # construct the graph and create a saver object - with tf.Graph().as_default(): # pylint: disable=not-context-manager - # The initial value should be such that type is correctly inferred as - # float. - width = 10 - height = 10 - results = [] - - # The master assigns ops to every TFProcess in the cluster. - for job_name in cluster_spec.keys(): - for i in range(len(cluster_spec[job_name])): - d = "/job:{0}/task:{1}".format(job_name, i) - with tf.device(d): - a = tf.constant(range(width * height), shape=[height, width]) - b = tf.constant(range(width * height), shape=[height, width]) - c = tf.multiply(a, b) - results.append(c) - - init_op = tf.global_variables_initializer() - - if server: - target = server.target - else: - # Create a direct session. - target = "" - - logging.info("Server target: %s", target) - with tf.Session( - target, config=tf.ConfigProto(log_device_placement=True) - ) as sess: - sess.run(init_op) - for r in results: - result = sess.run(r) - logging.info("Result: %s", result) - - -def main(): - """Run training. - - Raises: - ValueError: If the arguments are invalid. + Returns: + tf.Tensor: The result of the matrix multiplication. 
""" - logging.info("Tensorflow version: %s", tf.__version__) - logging.info("Tensorflow git version: %s", tf.__git_version__) - - tf_config_json = os.environ.get("TF_CONFIG", "{}") - tf_config = json.loads(tf_config_json) - logging.info("tf_config: %s", tf_config) - - task = tf_config.get("task", {}) - logging.info("task: %s", task) - - cluster_spec = tf_config.get("cluster", {}) - logging.info("cluster_spec: %s", cluster_spec) - - server = None - device_func = None - if cluster_spec: - cluster_spec_object = tf.train.ClusterSpec(cluster_spec) - server_def = tf.train.ServerDef( - cluster=cluster_spec_object.as_cluster_def(), - protocol="grpc", - job_name=task["type"], - task_index=task["index"], - ) + width = 10 + height = 10 + a = np.arange(width * height).reshape(height, width).astype(np.float32) + b = np.arange(width * height).reshape(height, width).astype(np.float32) + + # Perform matrix multiplication + c = tf.matmul(a, b) + tf.print(f"Result for this device: {c}") - logging.info("server_def: %s", server_def) + return c - logging.info("Building server.") - # Create and start a server for the local task. - server = tf.train.Server(server_def) - logging.info("Finished building server.") - # Assigns ops to the local worker by default. - device_func = tf.train.replica_device_setter( - worker_device="/job:worker/task:%d" % server_def.task_index, - cluster=server_def.cluster, +def run(): + """ + Run the distributed matrix multiplication operation across multiple devices. + """ + with strategy.scope(): + tf.print(f"Number of devices: {strategy.num_replicas_in_sync}") + + result = strategy.run(matrix_multiplication_fn) + + # Reduce results across devices to get a single result + reduced_result = strategy.reduce(tf.distribute.ReduceOp.SUM, result, axis=None) + tf.print( + "Summed result of matrix multiplication across all devices:", reduced_result ) - else: - # This should return a null op device setter since we are using - # all the defaults. - logging.error("Using default device function.") - device_func = tf.train.replica_device_setter() - - job_type = task.get("type", "").lower() - if job_type == "ps": - logging.info("Running PS code.") - server.join() - elif job_type == "worker": - logging.info("Running Worker code.") - # The worker just blocks because we let the master assign all ops. - server.join() - elif job_type in ["master", "chief"] or not job_type: - logging.info("Running master/chief.") - with tf.device(device_func): - run(server=server, cluster_spec=cluster_spec) - else: - raise ValueError("invalid job_type %s" % (job_type,)) if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - main() + args = parse_args() + + # Execute the distributed matrix multiplication. 
+    run()
+
+    if args.sleep_secs:
+        print(f"Sleeping for {args.sleep_secs} seconds")
+        time.sleep(args.sleep_secs)

From 6baa62b496b6c4fbbffeac6ddd559f86ce79a19d Mon Sep 17 00:00:00 2001
From: yelias
Date: Wed, 25 Sep 2024 20:45:34 +0300
Subject: [PATCH 3/8] Add mnist_utils and update dist-mnist

Signed-off-by: yelias
---
 examples/tensorflow/dist-mnist/Dockerfile     |   7 +-
 examples/tensorflow/dist-mnist/dist_mnist.py  | 430 ++++++-------
 .../tensorflow/dist-mnist/tf_job_mnist.yaml   |  14 +-
 examples/tensorflow/mnist_utils.py            | 140 ++++++
 .../mnist_with_summaries/Dockerfile           |   1 +
 .../mnist_with_summaries.py                   | 172 ++-----
 6 files changed, 341 insertions(+), 423 deletions(-)
 create mode 100644 examples/tensorflow/mnist_utils.py

diff --git a/examples/tensorflow/dist-mnist/Dockerfile b/examples/tensorflow/dist-mnist/Dockerfile
index cd03949b04..b0d8fc7d86 100644
--- a/examples/tensorflow/dist-mnist/Dockerfile
+++ b/examples/tensorflow/dist-mnist/Dockerfile
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-FROM tensorflow/tensorflow:1.5.0
+FROM tensorflow/tensorflow:2.17.0
+
+# Using keras-2.17 because of bug on keras-3.4.1 which used by default by TF-2.17
+ENV TF_USE_LEGACY_KERAS 1
+RUN pip install tf_keras
 
 ADD examples/tensorflow/dist-mnist/ /var/tf_dist_mnist
+ADD examples/tensorflow/mnist_utils.py /var/tf_dist_mnist
 ENTRYPOINT ["python", "/var/tf_dist_mnist/dist_mnist.py"]
diff --git a/examples/tensorflow/dist-mnist/dist_mnist.py b/examples/tensorflow/dist-mnist/dist_mnist.py
index 32ad877a2b..0f0c34e1c1 100755
--- a/examples/tensorflow/dist-mnist/dist_mnist.py
+++ b/examples/tensorflow/dist-mnist/dist_mnist.py
@@ -12,17 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Distributed MNIST training and validation, with model replicas.
+"""Distributed MNIST training and validation, with model replicas, using Parameter Server Strategy.
 
-A simple softmax model with one hidden layer is defined. The parameters
-(weights and biases) are located on one parameter server (ps), while the ops
+A Sequential model with a Flatten layer, a Dense layer (128 ReLU units),
+Dropout for regularization, and a final Dense layer with 10 softmax units for classification.
+The parameters (weights and biases) are located on one parameter server (ps), while the ops
 are executed on two worker nodes by default. The TF sessions also run on the
 worker node.
 
-Multiple invocations of this script can be done in parallel, with different
-values for --task_index. There should be exactly one invocation with
---task_index, which will create a master session that carries out variable
-initialization. The other, non-master, sessions will wait for the master
-session to finish the initialization before proceeding to the training stage.
+This script can be run with multiple workers and parameter servers, with at least
+one chief, one worker, and one parameter server.
 
 The coordination between the multiple worker invocations occurs due to
 the definition of the parameters on the same ps devices. The parameter updates
@@ -31,307 +29,167 @@ should lead to increased training speed for the simple model.
""" -from __future__ import absolute_import, division, print_function - -import json -import math +import argparse import os -import sys -import tempfile import time +import mnist_utils as helper import tensorflow as tf -from tensorflow.examples.tutorials.mnist import input_data - -flags = tf.app.flags -flags.DEFINE_string("data_dir", "/tmp/mnist-data", "Directory for storing mnist data") -flags.DEFINE_boolean( - "download_only", - False, - "Only perform downloading of data; Do not proceed to " - "session preparation, model definition or training", -) -flags.DEFINE_integer( - "task_index", - None, - "Worker task index, should be >= 0. task_index=0 is " - "the master worker task the performs the variable " - "initialization ", -) -flags.DEFINE_integer( - "num_gpus", - 1, - "Total number of gpus for each machine." - "If you don't use GPU, please set it to '0'", -) -flags.DEFINE_integer( - "replicas_to_aggregate", - None, - "Number of replicas to aggregate before parameter update" - "is applied (For sync_replicas mode only; default: " - "num_workers)", -) -flags.DEFINE_integer( - "hidden_units", 100, "Number of units in the hidden layer of the NN" -) -flags.DEFINE_integer( - "train_steps", 20000, "Number of (global) training steps to perform" -) -flags.DEFINE_integer("batch_size", 100, "Training batch size") -flags.DEFINE_float("learning_rate", 0.01, "Learning rate") -flags.DEFINE_boolean( - "sync_replicas", - False, - "Use the sync_replicas (synchronized replicas) mode, " - "wherein the parameter updates from workers are aggregated " - "before applied to avoid stale gradients", -) -flags.DEFINE_boolean( - "existing_servers", - False, - "Whether servers already exists. If True, " - "will use the worker hosts via their GRPC URLs (one client process " - "per worker host). Otherwise, will create an in-process TensorFlow " - "server.", -) -flags.DEFINE_string( - "ps_hosts", "localhost:2222", "Comma-separated list of hostname:port pairs" -) -flags.DEFINE_string( - "worker_hosts", - "localhost:2223,localhost:2224", - "Comma-separated list of hostname:port pairs", -) -flags.DEFINE_string("job_name", None, "job name: worker or ps") - -FLAGS = flags.FLAGS - -IMAGE_PIXELS = 28 - -# Example: -# cluster = {'ps': ['host1:2222', 'host2:2222'], -# 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} -# os.environ['TF_CONFIG'] = json.dumps( -# {'cluster': cluster, -# 'task': {'type': 'worker', 'index': 1}}) - - -def main(unused_argv): - # Parse environment variable TF_CONFIG to get job_name and task_index - - # If not explicitly specified in the constructor and the TF_CONFIG - # environment variable is present, load cluster_spec from TF_CONFIG. 
- tf_config = json.loads(os.environ.get("TF_CONFIG") or "{}") - task_config = tf_config.get("task", {}) - task_type = task_config.get("type") - task_index = task_config.get("index") - - FLAGS.job_name = task_type - FLAGS.task_index = task_index - - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) - if FLAGS.download_only: - sys.exit(0) - - if FLAGS.job_name is None or FLAGS.job_name == "": - raise ValueError("Must specify an explicit `job_name`") - if FLAGS.task_index is None or FLAGS.task_index == "": - raise ValueError("Must specify an explicit `task_index`") - - print("job name = %s" % FLAGS.job_name) - print("task index = %d" % FLAGS.task_index) - - cluster_config = tf_config.get("cluster", {}) - ps_hosts = cluster_config.get("ps") - worker_hosts = cluster_config.get("worker") - - ps_hosts_str = ",".join(ps_hosts) - worker_hosts_str = ",".join(worker_hosts) - - FLAGS.ps_hosts = ps_hosts_str - FLAGS.worker_hosts = worker_hosts_str - - # Construct the cluster and start the server - ps_spec = FLAGS.ps_hosts.split(",") - worker_spec = FLAGS.worker_hosts.split(",") - - # Get the number of workers. - num_workers = len(worker_spec) - - cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) - - if not FLAGS.existing_servers: - # Not using existing servers. Create an in-process server. - server = tf.train.Server( - cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index - ) - if FLAGS.job_name == "ps": - server.join() - - is_chief = FLAGS.task_index == 0 - if FLAGS.num_gpus > 0: - # Avoid gpu allocation conflict: now allocate task_num -> #gpu - # for each worker in the corresponding machine - gpu = FLAGS.task_index % FLAGS.num_gpus - worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) - elif FLAGS.num_gpus == 0: - # Just allocate the CPU to worker server - cpu = 0 - worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu) - # The device setter will automatically place Variables ops on separate - # parameter servers (ps). The non-Variable ops will be placed on the workers. - # The ps use CPU and workers use corresponding GPU - with tf.device( - tf.train.replica_device_setter( - worker_device=worker_device, ps_device="/job:ps/cpu:0", cluster=cluster - ) - ): - global_step = tf.Variable(0, name="global_step", trainable=False) - - # Variables of the hidden layer - hid_w = tf.Variable( - tf.truncated_normal( - [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], - stddev=1.0 / IMAGE_PIXELS, - ), - name="hid_w", + +args = None + + +def init_parser(): + global args + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_path", + type=str, + default="mnist.npz", + help="Path where to cache the dataset locally (relative to ~/.keras/datasets).", + ) + parser.add_argument( + "--dropout", + type=float, + default=0.9, + help="Keep probability for training dropout", + ) + parser.add_argument( + "--batch_size", type=int, default=100, help="Training batch size" + ) + parser.add_argument( + "--learning_rate", type=float, default=0.001, help="Learning rate" + ) + parser.add_argument( + "--epochs", type=int, default=5, help="Number of epochs for training" + ) + parser.add_argument( + "--fake_data", + nargs="?", + const=True, + type=bool, + default=False, + help="If true, uses fake data for unit testing.", + ) + args = parser.parse_args() + print(f"Run script with {args=}") + + +def main(): + # Set the environment variable to allow reporting worker and ps failure to the + # coordinator. 
This is a workaround and won't be necessary in the future. + os.environ["GRPC_FAIL_FAST"] = "use_caller" + + cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() + print(f"{cluster_resolver=}") + + # Get the cluster specification + cluster_spec = cluster_resolver.cluster_spec() + + # Get the number of PS replicas (parameter servers) + if "ps" in cluster_spec.jobs: + num_ps = cluster_spec.num_tasks("ps") + print(f"Number of PS replicas: {num_ps}") + else: + raise Exception("No PS replicas found in the cluster configuration.") + + if cluster_resolver.task_type in ("worker", "ps"): + # Start a TensorFlow server and wait. + server = tf.distribute.Server( + cluster_spec, + job_name=cluster_resolver.task_type, + task_index=cluster_resolver.task_id, + protocol=cluster_resolver.rpc_layer or "grpc", + start=True, ) - hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b") - - # Variables of the softmax layer - sm_w = tf.Variable( - tf.truncated_normal( - [FLAGS.hidden_units, 10], stddev=1.0 / math.sqrt(FLAGS.hidden_units) - ), - name="sm_w", + server.join() + else: + # Run the coordinator. + + # Configure ParameterServerStrategy + variable_partitioner = ( + tf.distribute.experimental.partitioners.MinSizePartitioner( + min_shard_bytes=(256 << 10), max_shards=num_ps + ) ) - sm_b = tf.Variable(tf.zeros([10]), name="sm_b") - - # Ops: located on the worker specified with FLAGS.task_index - x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) - y_ = tf.placeholder(tf.float32, [None, 10]) - - hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) - hid = tf.nn.relu(hid_lin) - y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) - cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) - - opt = tf.train.AdamOptimizer(FLAGS.learning_rate) - - if FLAGS.sync_replicas: - if FLAGS.replicas_to_aggregate is None: - replicas_to_aggregate = num_workers - else: - replicas_to_aggregate = FLAGS.replicas_to_aggregate + strategy = tf.distribute.ParameterServerStrategy( + cluster_resolver, variable_partitioner=variable_partitioner + ) - opt = tf.train.SyncReplicasOptimizer( - opt, - replicas_to_aggregate=replicas_to_aggregate, - total_num_replicas=num_workers, - name="mnist_sync_replicas", + # Load and preprocess data + train_ds, test_ds = helper.load_data( + fake_data=args.fake_data, data_path=args.data_path, repeat=True + ) + train_ds = helper.preprocess(ds=train_ds, batch_size=args.batch_size) + test_ds = helper.preprocess(ds=test_ds, batch_size=args.batch_size) + + # Distribute training across workers + with strategy.scope(): + model = helper.build_model( + dropout=args.dropout, + learning_rate=args.learning_rate, ) - train_step = opt.minimize(cross_entropy, global_step=global_step) - - if FLAGS.sync_replicas: - local_init_op = opt.local_step_init_op - if is_chief: - local_init_op = opt.chief_init_op - - ready_for_local_init_op = opt.ready_for_local_init_op - - # Initial token and chief queue runners required by the sync_replicas mode - chief_queue_runner = opt.get_chief_queue_runner() - sync_init_op = opt.get_init_tokens_op() - - init_op = tf.global_variables_initializer() - train_dir = tempfile.mkdtemp() - - if FLAGS.sync_replicas: - sv = tf.train.Supervisor( - is_chief=is_chief, - logdir=train_dir, - init_op=init_op, - local_init_op=local_init_op, - ready_for_local_init_op=ready_for_local_init_op, - recovery_wait_secs=1, - global_step=global_step, - ) - else: - sv = tf.train.Supervisor( - is_chief=is_chief, - logdir=train_dir, - init_op=init_op, - 
recovery_wait_secs=1, - global_step=global_step, - ) + # Start training + time_begin = time.time() + print(f"Training begins @ {time.ctime(time_begin)}") - sess_config = tf.ConfigProto( - allow_soft_placement=True, - log_device_placement=False, - device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index], + model.fit( + train_ds, + batch_size=args.batch_size, + epochs=args.epochs, + steps_per_epoch=6000 // args.batch_size * 2, ) - # The chief worker (task_index==0) session will prepare the session, - # while the remaining workers will wait for the preparation to complete. - if is_chief: - print("Worker %d: Initializing session..." % FLAGS.task_index) - else: - print( - "Worker %d: Waiting for session to be initialized..." % FLAGS.task_index - ) + time_end = time.time() + print(f"Training ends @ {time.ctime(time_end)}") + training_time = time_end - time_begin + print(f"Training elapsed time: {training_time} s") - if FLAGS.existing_servers: - server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index] - print("Using existing server at: %s" % server_grpc_url) + # Validation + coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy) + with strategy.scope(): + eval_accuracy = tf.keras.metrics.Accuracy() - sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) - else: - sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) + @tf.function + def eval_step(iterator): + """ + Perform an evaluation step across replicas. - print("Worker %d: Session initialization complete." % FLAGS.task_index) + Args: + iterator: An iterator for the evaluation dataset. + """ - if FLAGS.sync_replicas and is_chief: - # Chief worker will start the chief queue runner and call the init op. - sess.run(sync_init_op) - sv.start_queue_runners(sess, [chief_queue_runner]) + def replica_fn(batch_data, labels): + # Generates output predictions + pred = model(batch_data, training=False) + # Get the predicted class by taking the argmax over the class probabilities (axis=1) + predicted_class = tf.argmax(pred, axis=1, output_type=tf.int64) + eval_accuracy.update_state(labels, predicted_class) - # Perform training - time_begin = time.time() - print("Training begins @ %f" % time_begin) + batch_data, labels = next(iterator) + # Run the function on all workers using strategy.run + strategy.run(replica_fn, args=(batch_data, labels)) - local_step = 0 - while True: - # Training feed - batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) - train_feed = {x: batch_xs, y_: batch_ys} + # Prepare the per-worker evaluation dataset and iterator + per_worker_eval_dataset = coordinator.create_per_worker_dataset(test_ds) + per_worker_eval_iterator = iter(per_worker_eval_dataset) - _, step = sess.run([train_step, global_step], feed_dict=train_feed) - local_step += 1 + # Calculate evaluation steps per epoch (e.g., based on dataset size and batch size) + eval_steps_per_epoch = 10000 // args.batch_size * 2 - now = time.time() - print( - "%f: Worker %d: training step %d done (global step: %d)" - % (now, FLAGS.task_index, local_step, step) - ) + # Loop through the evaluation steps, scheduling them across the workers + for _ in range(eval_steps_per_epoch): + coordinator.schedule(eval_step, args=(per_worker_eval_iterator,)) - if step >= FLAGS.train_steps: - break + # Wait for all scheduled evaluation steps to complete + coordinator.join() - time_end = time.time() - print("Training ends @ %f" % time_end) - training_time = time_end - time_begin - print("Training elapsed time: %f s" % 
training_time) - - # Validation feed - val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} - val_xent = sess.run(cross_entropy, feed_dict=val_feed) - print( - "After %d training step(s), validation cross entropy = %g" - % (FLAGS.train_steps, val_xent) - ) + # Print the evaluation result (accuracy) + print("Evaluation accuracy: %f" % eval_accuracy.result()) if __name__ == "__main__": - tf.app.run() + init_parser() + main() diff --git a/examples/tensorflow/dist-mnist/tf_job_mnist.yaml b/examples/tensorflow/dist-mnist/tf_job_mnist.yaml index cb6fad1495..c97d03b700 100644 --- a/examples/tensorflow/dist-mnist/tf_job_mnist.yaml +++ b/examples/tensorflow/dist-mnist/tf_job_mnist.yaml @@ -4,16 +4,26 @@ metadata: name: "dist-mnist-for-e2e-test" spec: tfReplicaSpecs: + Chief: + replicas: 1 + restartPolicy: Never + template: + spec: + containers: + - name: tensorflow + image: kubeflow/tf-dist-mnist-test:latest + PS: - replicas: 2 + replicas: 1 restartPolicy: Never template: spec: containers: - name: tensorflow image: kubeflow/tf-dist-mnist-test:latest + Worker: - replicas: 4 + replicas: 2 restartPolicy: Never template: spec: diff --git a/examples/tensorflow/mnist_utils.py b/examples/tensorflow/mnist_utils.py new file mode 100644 index 0000000000..8a698eeb82 --- /dev/null +++ b/examples/tensorflow/mnist_utils.py @@ -0,0 +1,140 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Utility functions for loading, preprocessing, and building models for MNIST data. + +This module provides functions to load the MNIST dataset, preprocess it into TensorFlow datasets, +and build a simple neural network model using TensorFlow's Keras API. +""" + +import numpy as np +import tensorflow as tf +from tensorflow.keras.datasets import mnist + + +def load_data(fake_data=False, data_path=None, repeat=False): + """ + Loads the MNIST dataset and converts it into TensorFlow datasets. + + Args: + fake_data (bool): If `True`, loads a fake dataset for testing purposes. + If `False`, loads the real MNIST dataset. + data_path (str, optional): Path where to cache the dataset locally. + If `None`, the dataset is loaded to the default location. + repeat (bool, optional): If `True`, makes the dataset repeat indefinitely. + + Returns: + train_ds (tf.data.Dataset): Dataset containing the training data (images and labels). + test_ds (tf.data.Dataset): Dataset containing the test data (images and labels). 
+    """
+    if fake_data:
+        (x_train, y_train), (x_test, y_test) = load_fake_data()
+    else:
+        (x_train, y_train), (x_test, y_test) = (
+            mnist.load_data(path=data_path) if data_path else mnist.load_data()
+        )
+    # Create TensorFlow datasets from the NumPy arrays
+    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+    if repeat:
+        return train_ds.repeat(), test_ds.repeat()
+    return train_ds, test_ds
+
+
+def load_fake_data():
+    x_train = np.random.randint(0, 256, (60000, 28, 28)).astype(np.uint8)
+    y_train = np.random.randint(0, 10, (60000,)).astype(np.uint8)
+    x_test = np.random.randint(0, 256, (10000, 28, 28)).astype(np.uint8)
+    y_test = np.random.randint(0, 10, (10000,)).astype(np.uint8)
+
+    return (x_train, y_train), (x_test, y_test)
+
+
+def build_model(dropout=0.9, learning_rate=0.001):
+    """
+    Builds a simple neural network model using Keras Sequential API.
+
+    Args:
+        dropout (float, optional): Keep probability for training dropout.
+        learning_rate (float, optional): The learning rate for the Adam optimizer.
+
+    Returns:
+        model (tf.keras.Model): The compiled Keras model.
+    """
+    model = tf.keras.Sequential(
+        [
+            tf.keras.layers.Input(
+                shape=(28, 28, 1)
+            ),  # Input layer with the shape of MNIST images
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(
+                128, activation="relu"
+            ),  # Dense layer with 128 neurons and ReLU activation
+            tf.keras.layers.Dropout(
+                1 - dropout
+            ),  # Dropout layer to prevent overfitting
+            tf.keras.layers.Dense(
+                10, activation="softmax"
+            ),  # Output layer with 10 neurons (one for each class)
+        ]
+    )
+    # Define an optimizer with a specific learning rate
+    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+    # Compile the model with Adam optimizer and sparse categorical crossentropy loss
+    model.compile(
+        optimizer=optimizer,
+        loss="sparse_categorical_crossentropy",
+        metrics=["accuracy"],
+    )
+    return model
+
+
+def preprocess(ds, batch_size):
+    """
+    Preprocesses the dataset by normalizing the images, shuffling, batching, and prefetching.
+
+    Args:
+        ds (tf.data.Dataset): The dataset to preprocess (either training or testing data).
+        batch_size (int): The number of samples per batch of data.
+
+
+    Returns:
+        ds (tf.data.Dataset): The preprocessed dataset.
+    """
+
+    def normalize_img(image, label):
+        """
+        Normalizes images by scaling pixel values from the range [0, 255] to [0, 1].
+
+        Args:
+            image (tf.Tensor): The image tensor.
+            label (tf.Tensor): The corresponding label tensor.
+
+        Returns:
+            tuple: The normalized image and the corresponding label.
+        """
+        image = tf.cast(image, tf.float32) / 255.0
+        return image, label
+
+    # Map the normalization function across the dataset
+    ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    ds = ds.shuffle(
+        buffer_size=10000
+    )  # Shuffle the dataset with a buffer size of 10,000
+    ds = ds.batch(batch_size)  # Batch the dataset
+    ds = ds.prefetch(
+        buffer_size=tf.data.experimental.AUTOTUNE
+    )  # Prefetch to improve performance.
+ return ds diff --git a/examples/tensorflow/mnist_with_summaries/Dockerfile b/examples/tensorflow/mnist_with_summaries/Dockerfile index fb0ff6b99c..77a6232a36 100644 --- a/examples/tensorflow/mnist_with_summaries/Dockerfile +++ b/examples/tensorflow/mnist_with_summaries/Dockerfile @@ -15,4 +15,5 @@ FROM tensorflow/tensorflow:2.17.0 ADD examples/tensorflow/mnist_with_summaries/ /var/tf_mnist +ADD examples/tensorflow/mnist_utils.py /var/tf_mnist ENTRYPOINT ["python", "/var/tf_mnist/mnist_with_summaries.py"] diff --git a/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py b/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py index 0f0542590d..b2971538cb 100644 --- a/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py +++ b/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py @@ -21,142 +21,14 @@ import argparse import os -import numpy as np +import mnist_utils as helper import tensorflow as tf -from tensorflow.keras.datasets import mnist +args = None -def load_data(fake_data=False): - """ - Loads the MNIST dataset and converts it into TensorFlow datasets. - - Args: - fake_data (bool): If `True`, loads a fake dataset for testing purposes. - If `False`, loads the real MNIST dataset. - - Returns: - train_ds (tf.data.Dataset): Dataset containing the training data (images and labels). - test_ds (tf.data.Dataset): Dataset containing the test data (images and labels). - """ - if fake_data: - (x_train, y_train), (x_test, y_test) = load_fake_data() - else: - (x_train, y_train), (x_test, y_test) = mnist.load_data(path=FLAGS.data_path) - # Create TensorFlow datasets from the NumPy arrays - train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) - test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) - return train_ds, test_ds - - -def load_fake_data(): - x_train = np.random.randint(0, 256, (60000, 28, 28)).astype(np.uint8) - y_train = np.random.randint(0, 10, (60000,)).astype(np.uint8) - x_test = np.random.randint(0, 256, (10000, 28, 28)).astype(np.uint8) - y_test = np.random.randint(0, 10, (10000,)).astype(np.uint8) - - return (x_train, y_train), (x_test, y_test) - - -def preprocess(ds): - """ - Preprocesses the dataset by normalizing the images, shuffling, batching, and prefetching. - - Args: - ds (tf.data.Dataset): The dataset to preprocess (either training or testing data). - - Returns: - ds (tf.data.Dataset): The preprocessed dataset. - """ - - def normalize_img(image, label): - """ - Normalizes images by scaling pixel values from the range [0, 255] to [0, 1]. - - Args: - image (tf.Tensor): The image tensor. - label (tf.Tensor): The corresponding label tensor. - - Returns: - tuple: The normalized image and the corresponding label. - """ - image = tf.cast(image, tf.float32) / 255.0 - return image, label - - # Map the normalization function across the dataset - ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) - ds = ds.cache() # Cache the dataset to improve performance - ds = ds.shuffle( - buffer_size=10000 - ) # Shuffle the dataset with a buffer size of 10,000 - ds = ds.batch(FLAGS.batch_size) # Batch the dataset - ds = ds.prefetch( - buffer_size=tf.data.experimental.AUTOTUNE - ) # Prefetch to improve performance - return ds - - -def build_model(): - """ - Builds a simple neural network model using Keras Sequential API. - - Returns: - model (tf.keras.Model): The compiled Keras model. 
- """ - model = tf.keras.Sequential( - [ - tf.keras.layers.Input( - shape=(28, 28, 1) - ), # Input layer with the shape of MNIST images - tf.keras.layers.Flatten(), - tf.keras.layers.Dense( - 128, activation="relu" - ), # Dense layer with 128 neurons and ReLU activation - tf.keras.layers.Dropout( - 1 - FLAGS.dropout - ), # Dropout layer to prevent overfitting - tf.keras.layers.Dense( - 10, activation="softmax" - ), # Output layer with 10 neurons (one for each class) - ] - ) - # Define an optimizer with a specific learning rate - optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) - # Compile the model with Adam optimizer and sparse categorical crossentropy loss - model.compile( - optimizer=optimizer, - loss="sparse_categorical_crossentropy", - metrics=["accuracy"], - ) - return model - - -def main(): - """ - The main function to load data, preprocess it, build the model, and train it. - """ - # Load and preprocess data - train_ds, test_ds = load_data(fake_data=FLAGS.fake_data) - train_ds = preprocess(train_ds) - test_ds = preprocess(test_ds) - - # Build model - model = build_model() - - # Setup TensorBoard - tensorboard_callback = tf.keras.callbacks.TensorBoard( - log_dir=FLAGS.log_dir, histogram_freq=1 - ) - - # Train the model - model.fit( - train_ds, - epochs=FLAGS.epochs, - validation_data=test_ds, - callbacks=[tensorboard_callback], - ) - -if __name__ == "__main__": +def init_parser(): + global args parser = argparse.ArgumentParser() parser.add_argument( "--fake_data", @@ -196,6 +68,38 @@ def main(): ), help="Summaries log directory", ) - FLAGS, _ = parser.parse_known_args() - print(f"Run script with {FLAGS=}") + args = parser.parse_args() + print(f"Run script with {args=}") + + +def main(): + """ + The main function to load data, preprocess it, build the model, and train it. 
+ """ + # Load and preprocess data + train_ds, test_ds = helper.load_data( + data_path=args.data_path, fake_data=args.fake_data + ) + train_ds = helper.preprocess(ds=train_ds, batch_size=args.batch_size) + test_ds = helper.preprocess(ds=test_ds, batch_size=args.batch_size) + + # Build model + model = helper.build_model(dropout=args.dropout, learning_rate=args.learning_rate) + + # Setup TensorBoard + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=args.log_dir, histogram_freq=1 + ) + + # Train the model + model.fit( + train_ds, + epochs=args.epochs, + validation_data=test_ds, + callbacks=[tensorboard_callback], + ) + + +if __name__ == "__main__": + init_parser() main() From 51a4c7743e075587702fa2804c28c15178006f14 Mon Sep 17 00:00:00 2001 From: yelias Date: Wed, 25 Sep 2024 20:52:10 +0300 Subject: [PATCH 4/8] Add mnist_utils and update dist-mnist Signed-off-by: yelias --- examples/tensorflow/dist-mnist/dist_mnist.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/tensorflow/dist-mnist/dist_mnist.py b/examples/tensorflow/dist-mnist/dist_mnist.py index 0f0c34e1c1..a443a61dc3 100755 --- a/examples/tensorflow/dist-mnist/dist_mnist.py +++ b/examples/tensorflow/dist-mnist/dist_mnist.py @@ -81,7 +81,6 @@ def main(): os.environ["GRPC_FAIL_FAST"] = "use_caller" cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() - print(f"{cluster_resolver=}") # Get the cluster specification cluster_spec = cluster_resolver.cluster_spec() From 053c4b39f9545ce182f473b6d02f4002d799e359 Mon Sep 17 00:00:00 2001 From: yelias Date: Wed, 25 Sep 2024 21:33:21 +0300 Subject: [PATCH 5/8] Remove old example - estimator-API, this example has been replaced by distribution_strategy Signed-off-by: yelias --- .github/workflows/publish-example-images.yaml | 5 +- .../{keras-API => }/Dockerfile | 2 +- .../{keras-API => }/README.md | 0 .../estimator-API/Dockerfile | 4 - .../estimator-API/Makefile | 38 --------- .../estimator-API/README.md | 22 ----- .../estimator-API/distributed_tfjob.yaml | 19 ----- .../estimator-API/keras_model_to_estimator.py | 84 ------------------- .../multi_worker_strategy-with-keras.py | 0 .../{keras-API => }/multi_worker_tfjob.yaml | 0 .../{keras-API => }/pvc.yaml | 0 11 files changed, 2 insertions(+), 172 deletions(-) rename examples/tensorflow/distribution_strategy/{keras-API => }/Dockerfile (78%) rename examples/tensorflow/distribution_strategy/{keras-API => }/README.md (100%) delete mode 100644 examples/tensorflow/distribution_strategy/estimator-API/Dockerfile delete mode 100644 examples/tensorflow/distribution_strategy/estimator-API/Makefile delete mode 100644 examples/tensorflow/distribution_strategy/estimator-API/README.md delete mode 100644 examples/tensorflow/distribution_strategy/estimator-API/distributed_tfjob.yaml delete mode 100644 examples/tensorflow/distribution_strategy/estimator-API/keras_model_to_estimator.py rename examples/tensorflow/distribution_strategy/{keras-API => }/multi_worker_strategy-with-keras.py (100%) rename examples/tensorflow/distribution_strategy/{keras-API => }/multi_worker_tfjob.yaml (100%) rename examples/tensorflow/distribution_strategy/{keras-API => }/pvc.yaml (100%) diff --git a/.github/workflows/publish-example-images.yaml b/.github/workflows/publish-example-images.yaml index c2aff06784..74dc242551 100644 --- a/.github/workflows/publish-example-images.yaml +++ b/.github/workflows/publish-example-images.yaml @@ -24,12 +24,9 @@ jobs: - component-name: tf-dist-mnist-test platforms: linux/amd64,linux/arm64 dockerfile: 
examples/tensorflow/dist-mnist/Dockerfile - - component-name: tf-distributed-worker - platforms: linux/amd64,linux/arm64 - dockerfile: examples/tensorflow/distribution_strategy/estimator-API/Dockerfile - component-name: tf-multi-worker-strategy platforms: linux/amd64,linux/arm64 - dockerfile: examples/tensorflow/distribution_strategy/keras-API/Dockerfile + dockerfile: examples/tensorflow/distribution_strategy/Dockerfile - component-name: tf-mnist-with-summaries platforms: linux/amd64,linux/arm64 dockerfile: examples/tensorflow/mnist_with_summaries/Dockerfile diff --git a/examples/tensorflow/distribution_strategy/keras-API/Dockerfile b/examples/tensorflow/distribution_strategy/Dockerfile similarity index 78% rename from examples/tensorflow/distribution_strategy/keras-API/Dockerfile rename to examples/tensorflow/distribution_strategy/Dockerfile index 4b57046398..8f7cc4e93b 100644 --- a/examples/tensorflow/distribution_strategy/keras-API/Dockerfile +++ b/examples/tensorflow/distribution_strategy/Dockerfile @@ -7,5 +7,5 @@ RUN apt-get update && \ RUN pip install tensorflow==2.11.0 tensorflow_datasets==4.7.0 -COPY examples/tensorflow/distribution_strategy/keras-API/multi_worker_strategy-with-keras.py / +COPY examples/tensorflow/distribution_strategy/multi_worker_strategy-with-keras.py / ENTRYPOINT ["python", "/multi_worker_strategy-with-keras.py", "--saved_model_dir", "/train/saved_model/", "--checkpoint_dir", "/train/checkpoint"] diff --git a/examples/tensorflow/distribution_strategy/keras-API/README.md b/examples/tensorflow/distribution_strategy/README.md similarity index 100% rename from examples/tensorflow/distribution_strategy/keras-API/README.md rename to examples/tensorflow/distribution_strategy/README.md diff --git a/examples/tensorflow/distribution_strategy/estimator-API/Dockerfile b/examples/tensorflow/distribution_strategy/estimator-API/Dockerfile deleted file mode 100644 index 88d6399059..0000000000 --- a/examples/tensorflow/distribution_strategy/estimator-API/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -FROM tensorflow/tensorflow:1.11.0 - -COPY examples/tensorflow/distribution_strategy/estimator-API/keras_model_to_estimator.py / -ENTRYPOINT ["python", "/keras_model_to_estimator.py", "/tmp/tfkeras_example/"] diff --git a/examples/tensorflow/distribution_strategy/estimator-API/Makefile b/examples/tensorflow/distribution_strategy/estimator-API/Makefile deleted file mode 100644 index f0ce957ec1..0000000000 --- a/examples/tensorflow/distribution_strategy/estimator-API/Makefile +++ /dev/null @@ -1,38 +0,0 @@ -IMG = gcr.io/kubeflow-examples/distributed_worker - -# List any changed files. We only include files in the notebooks directory. -# because that is the code in the docker image. -# In particular we exclude changes to the ksonnet configs. -CHANGED_FILES := $(shell git diff-files --relative=tensorflow/tf_sample) - -ifeq ($(strip $(CHANGED_FILES)),) -# Changed files is empty; not dirty -# Don't include --dirty because it could be dirty if files outside the ones we care -# about changed. -GIT_VERSION := $(shell git describe --always) -else -GIT_VERSION := $(shell git describe --always)-dirty-$(shell git diff | shasum -a256 | cut -c -6) -endif - -TAG := $(shell date +v%Y%m%d)-$(GIT_VERSION) -all: build - -# To build without the cache set the environment variable -# export DOCKER_BUILD_OPTS=--no-cache -build: - docker build ${DOCKER_BUILD_OPTS} -t $(IMG):$(TAG) . 
\ - --label=git-verions=$(GIT_VERSION) - docker tag $(IMG):$(TAG) $(IMG):latest - @echo Built $(IMG):latest - @echo Built $(IMG):$(TAG) - - -# Build but don't attach the latest tag. This allows manual testing/inspection of the image -# first. -push: build - gcloud docker -- push $(IMG):$(TAG) - @echo Pushed $(IMG) with :$(TAG) tags - -push-latest: push - gcloud container images add-tag --quiet $(IMG):$(TAG) $(IMG):latest --verbosity=info - echo created $(IMG):latest diff --git a/examples/tensorflow/distribution_strategy/estimator-API/README.md b/examples/tensorflow/distribution_strategy/estimator-API/README.md deleted file mode 100644 index a4a009b2da..0000000000 --- a/examples/tensorflow/distribution_strategy/estimator-API/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Distributed Training on Kubeflow - -This is an example of running distributed training on Kubeflow. The source code is taken from -TensorFlow team's example [here](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy). - -The directory contains the following files: -* Dockerfile: Builds the independent worker image. -* Makefile: For building the above image. -* keras_model_to_estimator.py: This is the model code to run multi-worker training. Identical to the TensorFlow example. -* distributed_tfjob.yaml: The TFJob spec. - -To run the example, edit `distributed_tfjob.yaml` for your cluster's namespace. Then run -``` -kubectl apply -f distributed_tfjob.yaml -``` -to create the job. - -Then use -``` -kubectl -n ${NAMESPACE} describe tfjob distributed-training -``` -to see the status. diff --git a/examples/tensorflow/distribution_strategy/estimator-API/distributed_tfjob.yaml b/examples/tensorflow/distribution_strategy/estimator-API/distributed_tfjob.yaml deleted file mode 100644 index b7a0bc36b5..0000000000 --- a/examples/tensorflow/distribution_strategy/estimator-API/distributed_tfjob.yaml +++ /dev/null @@ -1,19 +0,0 @@ -apiVersion: "kubeflow.org/v1" -kind: "TFJob" -metadata: - name: "distributed-training" -spec: - runPolicy: - cleanPodPolicy: None - tfReplicaSpecs: - Worker: - replicas: 3 - restartPolicy: Never - template: - metadata: - annotations: - scheduling.k8s.io/group-name: "distributed-training" - spec: - containers: - - name: tensorflow - image: kubeflow/tf-distributed-worker:latest diff --git a/examples/tensorflow/distribution_strategy/estimator-API/keras_model_to_estimator.py b/examples/tensorflow/distribution_strategy/estimator-API/keras_model_to_estimator.py deleted file mode 100644 index 8b345cbd99..0000000000 --- a/examples/tensorflow/distribution_strategy/estimator-API/keras_model_to_estimator.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2018 The Kubeflow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""An example of training Keras model with multi-worker strategies."""
-from __future__ import absolute_import, division, print_function
-
-import sys
-
-import numpy as np
-import tensorflow as tf
-
-
-def input_fn():
-    x = np.random.random((1024, 10))
-    y = np.random.randint(2, size=(1024, 1))
-    x = tf.cast(x, tf.float32)
-    dataset = tf.data.Dataset.from_tensor_slices((x, y))
-    dataset = dataset.repeat(100)
-    dataset = dataset.batch(32)
-    return dataset
-
-
-def main(args):
-    if len(args) < 2:
-        print(
-            "You must specify model_dir for checkpoints such as"
-            " /tmp/tfkeras_example/."
-        )
-        return
-
-    model_dir = args[1]
-    print("Using %s to store checkpoints." % model_dir)
-
-    # Define a Keras Model.
-    model = tf.keras.Sequential()
-    model.add(tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)))
-    model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
-
-    # Compile the model.
-    optimizer = tf.train.GradientDescentOptimizer(0.2)
-    model.compile(loss="binary_crossentropy", optimizer=optimizer)
-    model.summary()
-    tf.keras.backend.set_learning_phase(True)
-
-    # Define DistributionStrategies and convert the Keras Model to an
-    # Estimator that utilizes these DistributionStrateges.
-    # Evaluator is a single worker, so using MirroredStrategy.
-    config = tf.estimator.RunConfig(
-        experimental_distribute=tf.contrib.distribute.DistributeConfig(
-            train_distribute=tf.contrib.distribute.CollectiveAllReduceStrategy(
-                num_gpus_per_worker=0
-            ),
-            eval_distribute=tf.contrib.distribute.MirroredStrategy(
-                num_gpus_per_worker=0
-            ),
-        )
-    )
-    keras_estimator = tf.keras.estimator.model_to_estimator(
-        keras_model=model, config=config, model_dir=model_dir
-    )
-
-    # Train and evaluate the model. Evaluation will be skipped if there is not an
-    # "evaluator" job in the cluster.
-    tf.estimator.train_and_evaluate(
-        keras_estimator,
-        train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
-        eval_spec=tf.estimator.EvalSpec(input_fn=input_fn),
-    )
-
-
-if __name__ == "__main__":
-    tf.logging.set_verbosity(tf.logging.INFO)
-    tf.app.run(argv=sys.argv)
diff --git a/examples/tensorflow/distribution_strategy/keras-API/multi_worker_strategy-with-keras.py b/examples/tensorflow/distribution_strategy/multi_worker_strategy-with-keras.py
similarity index 100%
rename from examples/tensorflow/distribution_strategy/keras-API/multi_worker_strategy-with-keras.py
rename to examples/tensorflow/distribution_strategy/multi_worker_strategy-with-keras.py
diff --git a/examples/tensorflow/distribution_strategy/keras-API/multi_worker_tfjob.yaml b/examples/tensorflow/distribution_strategy/multi_worker_tfjob.yaml
similarity index 100%
rename from examples/tensorflow/distribution_strategy/keras-API/multi_worker_tfjob.yaml
rename to examples/tensorflow/distribution_strategy/multi_worker_tfjob.yaml
diff --git a/examples/tensorflow/distribution_strategy/keras-API/pvc.yaml b/examples/tensorflow/distribution_strategy/pvc.yaml
similarity index 100%
rename from examples/tensorflow/distribution_strategy/keras-API/pvc.yaml
rename to examples/tensorflow/distribution_strategy/pvc.yaml

From 6812156e3dc19758e2728cd55670b7ac2840005c Mon Sep 17 00:00:00 2001
From: yelias
Date: Wed, 25 Sep 2024 22:24:48 +0300
Subject: [PATCH 6/8] Small fix

Signed-off-by: yelias
---
 examples/tensorflow/dist-mnist/Dockerfile | 3 ++-
 examples/tensorflow/mnist_utils.py        | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/tensorflow/dist-mnist/Dockerfile b/examples/tensorflow/dist-mnist/Dockerfile
index b0d8fc7d86..e7fd5b0563 100644
--- a/examples/tensorflow/dist-mnist/Dockerfile
+++ b/examples/tensorflow/dist-mnist/Dockerfile
@@ -14,7 +14,8 @@
 
 FROM tensorflow/tensorflow:2.17.0
 
-# Using keras-2.17 because of bug on keras-3.4.1 which used by default by TF-2.17
+# Using keras-2.17 because of a bug in keras-3.4.1,
+# which is used by default by TF-2.17 (https://github.com/tensorflow/tensorflow/issues/72388)
 ENV TF_USE_LEGACY_KERAS 1
 RUN pip install tf_keras
diff --git a/examples/tensorflow/mnist_utils.py b/examples/tensorflow/mnist_utils.py
index 8a698eeb82..8cecd3f462 100644
--- a/examples/tensorflow/mnist_utils.py
+++ b/examples/tensorflow/mnist_utils.py
@@ -29,7 +29,7 @@ def load_data(fake_data=False, data_path=None, repeat=False):
     Loads the MNIST dataset and converts it into TensorFlow datasets.
 
     Args:
-        fake_data (bool): If `True`, loads a fake dataset for testing purposes.
+        fake_data (bool, optional): If `True`, loads a fake dataset for testing purposes.
             If `False`, loads the real MNIST dataset.
         data_path (str, optional): Path where to cache the dataset locally. If
             `None`, the dataset is loaded to the default location.
From 28d48b43a1c6c6f9dab51fca4e8195527abdbfb3 Mon Sep 17 00:00:00 2001
From: yelias
Date: Mon, 28 Oct 2024 13:28:01 +0200
Subject: [PATCH 7/8] Remove unsupported PowerPC Dockerfiles

Signed-off-by: yelias
---
 .../tensorflow/dist-mnist/Dockerfile.ppc64le  | 18 ------------------
 examples/tensorflow/dist-mnist/README.md      |  5 -----
 .../mnist_with_summaries/Dockerfile.ppc64le   | 18 ------------------
 .../tensorflow/mnist_with_summaries/README.md |  5 -----
 4 files changed, 46 deletions(-)
 delete mode 100644 examples/tensorflow/dist-mnist/Dockerfile.ppc64le
 delete mode 100644 examples/tensorflow/mnist_with_summaries/Dockerfile.ppc64le

diff --git a/examples/tensorflow/dist-mnist/Dockerfile.ppc64le b/examples/tensorflow/dist-mnist/Dockerfile.ppc64le
deleted file mode 100644
index 8b9bd79de1..0000000000
--- a/examples/tensorflow/dist-mnist/Dockerfile.ppc64le
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright 2019 The Kubeflow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-FROM ibmcom/tensorflow-ppc64le:1.13.1
-
-ADD . /var/tf_dist_mnist
-ENTRYPOINT ["python", "/var/tf_dist_mnist/dist_mnist.py"]
diff --git a/examples/tensorflow/dist-mnist/README.md b/examples/tensorflow/dist-mnist/README.md
index 4d3f842850..306df6c9b1 100644
--- a/examples/tensorflow/dist-mnist/README.md
+++ b/examples/tensorflow/dist-mnist/README.md
@@ -10,14 +10,9 @@ To build this image on x86_64:
 ```shell
 docker build -f Dockerfile -t kubeflow/tf-dist-mnist-test:1.0 ./
 ```
-To build this image on ppc64le:
-```shell
-docker build -f Dockerfile.ppc64le -t kubeflow123/tf-dist-mnist-test:1.0 ./
-```
 
 **Create TFJob YAML**
 
 ```
 kubectl create -f ./tf_job_mnist.yaml
 ```
- * If on ppc64le, please update tf_job_mnist.yaml to use the image of ppc64le firstly.
diff --git a/examples/tensorflow/mnist_with_summaries/Dockerfile.ppc64le b/examples/tensorflow/mnist_with_summaries/Dockerfile.ppc64le
deleted file mode 100644
index 68587dd875..0000000000
--- a/examples/tensorflow/mnist_with_summaries/Dockerfile.ppc64le
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-FROM ibmcom/tensorflow-ppc64le:1.13.1
-
-ADD examples/tensorflow/tf_sample/ /var/tf_mnist
-ENTRYPOINT ["python", "/var/tf_mnist/mnist_with_summaries.py"]
diff --git a/examples/tensorflow/mnist_with_summaries/README.md b/examples/tensorflow/mnist_with_summaries/README.md
index ddef953fc1..f6b76dd631 100644
--- a/examples/tensorflow/mnist_with_summaries/README.md
+++ b/examples/tensorflow/mnist_with_summaries/README.md
@@ -10,12 +10,7 @@ To build this image on x86_64:
 ```shell
 docker build -f Dockerfile -t kubeflow/tf-mnist-with-summaries:1.0 ./
 ```
-On ppc64le, run as:
-```shell
-docker build -f Dockerfile.ppc64le -t kubeflow123/tf-mnist-with-summaries:1.0 ./
-```
 
 Usage:
 1. Add the persistent volume and claim: `kubectl apply -f tfevent-volume/.`
 1. Deploy the TFJob: `kubectl apply -f tf_job_mnist.yaml`
- * If on ppc64le, please update tf_job_mnist.yaml to use the image of ppc64le firstly.

From 6200bc8a9ca8c2b411c74a8c3a90e7972d7f535d Mon Sep 17 00:00:00 2001
From: yelias
Date: Mon, 4 Nov 2024 11:40:04 +0200
Subject: [PATCH 8/8] Fix typo in copyright

Signed-off-by: yelias
---
 examples/tensorflow/mnist_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/tensorflow/mnist_utils.py b/examples/tensorflow/mnist_utils.py
index 8cecd3f462..5cd436e376 100644
--- a/examples/tensorflow/mnist_utils.py
+++ b/examples/tensorflow/mnist_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2024 The Kubeflow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.