Update for the latest Horovod and add support for Process Sets #27

Status: Open. Wants to merge 44 commits into base branch master.

Commits (44):
a4c2834  HOROVOD new version compatible (Jul 20, 2024)
8c1f704  HOROVOD new version compatible for pytorch (Jul 20, 2024)
50d1cdf  Update README.md (GarvitBanga, Jul 20, 2024)
7df7e51  Update README.md (GarvitBanga, Jul 20, 2024)
ca343d3  Update README.md (GarvitBanga, Jul 20, 2024)
90c3bf2  HOROVOD new version compatible for pytorch (Jul 22, 2024)
2143780  HOROVOD new version compatible for pytorch (Jul 22, 2024)
95c8fb0  HOROVOD new version compatible for pytorch (Jul 22, 2024)
eceab88  HOROVOD new version compatible for pytorch (Jul 22, 2024)
53462cc  HOROVOD new version compatible for pytorch (Jul 22, 2024)
9cd3139  HOROVOD new version compatible for pytorch (Jul 22, 2024)
d1f3b62  HOROVOD new version compatible for pytorch (Jul 22, 2024)
31280b9  Delete .DS_Store (GarvitBanga, Jul 22, 2024)
5a2691f  Delete patch_files/.DS_Store (GarvitBanga, Jul 22, 2024)
9f7c3e6  Delete patch_files/horovod/.DS_Store (GarvitBanga, Jul 22, 2024)
0d15751  Delete patch_files/horovod/torch/.DS_Store (GarvitBanga, Jul 22, 2024)
8d574b4  Delete grace_dl/.DS_Store (GarvitBanga, Jul 22, 2024)
83dfee3  Delete grace_dl/torch/.DS_Store (GarvitBanga, Jul 22, 2024)
d2d7db6  Delete grace_dl/torch/communicator/.DS_Store (GarvitBanga, Jul 22, 2024)
fdc904c  Delete grace_dl/torch/compressor/.DS_Store (GarvitBanga, Jul 22, 2024)
1aa5bad  Delete grace_dl/torch/memory/.DS_Store (GarvitBanga, Jul 22, 2024)
f211386  Update allgather.py (GarvitBanga, Jul 22, 2024)
d42bb5a  Update topk.py (GarvitBanga, Jul 22, 2024)
e60a421  Update residual.py (GarvitBanga, Jul 22, 2024)
0146a96  Update __init__.py (GarvitBanga, Jul 22, 2024)
3944638  Update helper.py (GarvitBanga, Jul 22, 2024)
55718c6  Update allgather.py (GarvitBanga, Jul 22, 2024)
c53fb02  Update topk.py (GarvitBanga, Jul 22, 2024)
349b466  Update residual.py (GarvitBanga, Jul 22, 2024)
e847b2f  Update helper.py (GarvitBanga, Jul 22, 2024)
ef702e8  Update allgather.py (GarvitBanga, Jul 22, 2024)
f74095c  Delete grace_dl/dist/.DS_Store (GarvitBanga, Jul 22, 2024)
9123421  Delete grace_dl/tensorflow/.DS_Store (GarvitBanga, Jul 22, 2024)
2610d62  Delete grace_dl/tensorflow/memory/.DS_Store (GarvitBanga, Jul 22, 2024)
f4c77ad  Delete grace_dl/tensorflow/compressor/.DS_Store (GarvitBanga, Jul 22, 2024)
ca8d0f3  Delete grace_dl/tensorflow/communicator/.DS_Store (GarvitBanga, Jul 22, 2024)
04a1a6e  Delete examples/.DS_Store (GarvitBanga, Jul 22, 2024)
e8b7144  Create pytorch_mnist_process_set.py (GarvitBanga, Jul 22, 2024)
8974e5e  Update pytorch_mnist_process_set.py (GarvitBanga, Jul 22, 2024)
4213c39  Update pytorch_mnist_process_set.py (GarvitBanga, Jul 22, 2024)
fe7867e  Update topk.py (GarvitBanga, Jul 22, 2024)
eb82162  backward pass reset count (Jul 23, 2024)
5d693b4  Delete .DS_Store (GarvitBanga, Jul 23, 2024)
7029f8a  Delete patch_files/horovod/.DS_Store (GarvitBanga, Jul 23, 2024)
Files changed:
190 changes: 190 additions & 0 deletions examples/torch/pytorch_mnist_process_set.py
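(This new example adapts Horovod's standard PyTorch MNIST script so that even-numbered and odd-numbered ranks train in two independent process sets, each with its own GRACE compression instance.)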
@@ -0,0 +1,190 @@
from __future__ import print_function

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import datasets, transforms

import horovod.torch as hvd
from grace_dl.torch.communicator.allgather import Allgather
from grace_dl.torch.compressor.topk import TopKCompressor
from grace_dl.torch.memory.residual import ResidualMemory

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Horovod: define the process sets, then initialize the library with them.
# Ranks 0 and 2 form the even set and ranks 1 and 3 the odd set, so this
# example assumes a world size of exactly four.
even_set = hvd.ProcessSet([0, 2])
odd_set = hvd.ProcessSet([1, 3])
hvd.init(process_sets=[even_set, odd_set])
torch.manual_seed(args.seed)

if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())
    torch.cuda.manual_seed(args.seed)

# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(1)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_dataset = \
    datasets.MNIST('data-%d' % hvd.rank(), train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
# Horovod: use DistributedSampler to partition the training data.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)

test_dataset = \
    datasets.MNIST('data-%d' % hvd.rank(), train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
# Horovod: use DistributedSampler to partition the test data.
test_sampler = torch.utils.data.distributed.DistributedSampler(
    test_dataset, num_replicas=hvd.size(), rank=hvd.rank())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size,
                                          sampler=test_sampler, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()

if args.cuda:
    # Move model to GPU.
    model.cuda()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(),
                      momentum=args.momentum)

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# GRACE: compression algorithm.
# grc = Allgather(TopKCompressor(0.3), ResidualMemory(), hvd.size())

from grace_dl.torch.helper import grace_from_params
params = {'compressor': 'topk', 'memory': 'residual', 'communicator': 'allgather'}


# Horovod: wrap the optimizer with DistributedOptimizer, using the GRACE
# instance and the process set this rank belongs to.
my_set = even_set if even_set.included() else odd_set
grc = grace_from_params(params, my_set)
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     grace=grc, process_set=my_set)


def train(epoch):
    model.train()
    # Horovod: set epoch to sampler for shuffling.
    train_sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # Horovod: use train_sampler to determine the number of examples in
            # this worker's partition.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_sampler),
                100. * batch_idx / len(train_loader), loss.item()))


def metric_average(val, name):
    tensor = torch.tensor(val)
    # Horovod: collectives must be entered by every member of a process set,
    # so each worker reduces within its own set.
    avg_tensor = hvd.allreduce(tensor, name=name, process_set=my_set)
    return avg_tensor.item()


def test():
    model.eval()
    test_loss = 0.
    test_accuracy = 0.
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)
        # Sum up batch loss.
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        # Get the index of the max log-probability.
        pred = output.data.max(1, keepdim=True)[1]
        test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum()

    # Horovod: use test_sampler to determine the number of examples in
    # this worker's partition.
    test_loss /= len(test_sampler)
    test_accuracy /= len(test_sampler)

    # Horovod: average metric values across the workers in this process set.
    test_loss = metric_average(test_loss, 'avg_loss')
    test_accuracy = metric_average(test_accuracy, 'avg_accuracy')

    # Horovod: print output only on the first rank of each process set.
    if my_set.rank() == 0:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
            test_loss, 100. * test_accuracy))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()
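Note: even_set and odd_set hard-code ranks 0 through 3, so the script assumes a world size of exactly four (for example, launched with `horovodrun -np 4 python pytorch_mnist_process_set.py`). A small guard right after hvd.init would fail fast on other world sizes; the assertion below is illustrative, not part of the PR:

    # Illustrative guard (not in the PR): the hard-coded process sets only
    # make sense when the world size is exactly 4.
    assert hvd.size() == 4, 'pytorch_mnist_process_set.py expects exactly 4 workers'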
12 changes: 6 additions & 6 deletions grace_dl/torch/__init__.py
@@ -36,23 +36,23 @@ def aggregate(self, tensors):

 class Communicator(ABC):
     @abstractmethod
-    def async_send(self, tensors, name):
+    def async_send(self, tensors, name, process_set):
         raise NotImplemented("async_send was not implemented.")
 
     @abstractmethod
-    def wait_receive(self, handles, ctx):
+    def wait_receive(self, handles, ctx, process_set):
         raise NotImplemented("wait_receive was not implemented.")
 
     def __init__(self, compressor, memory):
         self.compressor = compressor
         self.memory = memory
 
-    def send_step(self, tensor, name):
+    def send_step(self, tensor, name, process_set):
         tensor = self.memory.compensate(tensor, name)
         tensors_compressed, ctx = self.compressor.compress(tensor, name)
         self.memory.update(tensor, name, self.compressor, tensors_compressed, ctx)
-        handles = self.async_send(tensors_compressed, name)
+        handles = self.async_send(tensors_compressed, name, process_set)
         return handles, ctx
 
-    def receive_step(self, handles, ctx):
-        return self.wait_receive(handles, ctx)
+    def receive_step(self, handles, ctx, process_set):
+        return self.wait_receive(handles, ctx, process_set)
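For orientation, a minimal sketch of how a caller drives this interface once a process set is threaded through; `grc`, `grad`, and `pset` are illustrative names (`grc` being any concrete Communicator such as Allgather, `pset` an hvd.ProcessSet):

    # Minimal usage sketch, assuming `grc` is a Communicator and `pset` a process set.
    handles, ctx = grc.send_step(grad, 'fc1.weight', pset)  # compensate, compress, send asynchronously
    new_grad = grc.receive_step(handles, ctx, pset)         # wait, decompress, aggregate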
8 changes: 4 additions & 4 deletions grace_dl/torch/communicator/allgather.py
@@ -9,7 +9,7 @@ def __init__(self, compressor, memory, world_size):
         super().__init__(compressor, memory)
         self.world_size = world_size
 
-    def async_send(self, tensors_compressed, name):
+    def async_send(self, tensors_compressed, name, process_set):
         """
         :param tensors_compressed: list of flat tensors to communicate
         :param name: for the all_gather operation
@@ -26,17 +26,17 @@ def async_send(self, tensors_compressed, name):
             tensor_sizes = zip(*tensors_size_ag)  # transpose
         else:
             tensors_size = torch.tensor(tensors_size)  # TODO: set device
-            gathered = allgather(tensors_size)  # tensor of tensor sizes per rank
+            gathered = allgather(tensor=tensors_size, process_set=process_set)  # tensor of tensor sizes per rank
             tensor_sizes = gathered.view([self.world_size, -1]).t().tolist()  # transpose, to list
 
         handles = []
         for tensor_compressed in tensors_compressed:
-            handle = allgather_async(tensor_compressed)
+            handle = allgather_async(tensor=tensor_compressed, process_set=process_set)
             handles.append(handle)
 
         return handles, tensor_sizes
 
-    def wait_receive(self, result, ctx):
+    def wait_receive(self, result, ctx, process_set):
         handles, tensor_sizes = result
         tensors_ag = []
         for handle, sizes in zip(handles, tensor_sizes):
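The sizes gathered in async_send let wait_receive (whose full body is not shown in this diff) split each allgathered flat tensor back into per-rank chunks before decompressing them. A simplified, illustrative sketch of that splitting step, not the PR's code:

    # Illustrative sketch: carve an allgathered flat tensor into per-rank
    # chunks using the sizes exchanged in async_send.
    def split_by_sizes(gathered, sizes):
        chunks, offset = [], 0
        for size in sizes:
            chunks.append(gathered[offset:offset + size])
            offset += size
        return chunks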
6 changes: 3 additions & 3 deletions grace_dl/torch/helper.py
@@ -1,6 +1,6 @@
-def grace_from_params(params):
-    import horovod.torch as hvd
-    world_size = hvd.size()
+from horovod.torch.mpi_ops import global_process_set
+def grace_from_params(params, process_set=global_process_set):
+    world_size = process_set.size()
     comp = params.get('compressor', 'none')
     mem = params.get('memory', 'none')
     comm = params.get('communicator', 'allreduce')
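Because the new parameter defaults to global_process_set, existing callers of grace_from_params keep working unchanged, while process-set-aware code passes its set explicitly. A hedged usage sketch, reusing even_set from the example above:

    # Hedged usage sketch of the updated helper.
    params = {'compressor': 'topk', 'memory': 'residual', 'communicator': 'allgather'}
    grc_global = grace_from_params(params)          # defaults to global_process_set
    grc_even = grace_from_params(params, even_set)  # scoped to a custom process set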
57 changes: 34 additions & 23 deletions patch_files/horovod/torch/__init__.py
@@ -16,33 +16,44 @@

 from horovod.common.util import check_extension
 
+_MPI_LIB_AVAILABLE = True
 try:
     check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH',
                     __file__, 'mpi_lib_v2')
-except:
-    check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH',
-                    __file__, 'mpi_lib', '_mpi_lib')
-
-from horovod.torch import elastic
-from horovod.torch.compression import Compression
-from horovod.torch.functions import allgather_object, broadcast_object, broadcast_optimizer_state, broadcast_parameters
-from horovod.torch.mpi_ops import allreduce, allreduce_async, allreduce_, allreduce_async_
-from horovod.torch.mpi_ops import grouped_allreduce, grouped_allreduce_async, grouped_allreduce_, grouped_allreduce_async_
-from horovod.torch.mpi_ops import allgather, allgather_async
-from horovod.torch.mpi_ops import broadcast, broadcast_async, broadcast_, broadcast_async_
-from horovod.torch.mpi_ops import alltoall, alltoall_async
-from horovod.torch.mpi_ops import join
-from horovod.torch.mpi_ops import poll, synchronize
-from horovod.torch.mpi_ops import init, shutdown
-from horovod.torch.mpi_ops import is_initialized, start_timeline, stop_timeline
-from horovod.torch.mpi_ops import size, local_size, rank, local_rank
-from horovod.torch.mpi_ops import mpi_threads_supported, mpi_enabled, mpi_built
-from horovod.torch.mpi_ops import gloo_enabled, gloo_built
-from horovod.torch.mpi_ops import nccl_built, ddl_built, ccl_built, cuda_built, rocm_built
-from horovod.torch.mpi_ops import Average, Sum, Adasum
-from horovod.torch.optimizer import DistributedOptimizer
-from horovod.torch.sync_batch_norm import SyncBatchNorm
+except Exception as e:
+    # MPI libs are missing, but python applications are still available.
+    print(e)
+    print("Warning! MPI libs are missing, but python applications are still available.")
+    _MPI_LIB_AVAILABLE = False
 
+# Only import the following when the MPI lib is available.
+if _MPI_LIB_AVAILABLE:
+    from horovod.torch import elastic
+    from horovod.torch.compression import Compression
+    from horovod.torch.functions import allgather_object, broadcast_object, broadcast_optimizer_state, broadcast_parameters
+    from horovod.torch.mpi_ops import allreduce, allreduce_async, allreduce_, allreduce_async_
+    from horovod.torch.mpi_ops import grouped_allreduce, grouped_allreduce_async, grouped_allreduce_, grouped_allreduce_async_
+    from horovod.torch.mpi_ops import sparse_allreduce_async
+    from horovod.torch.mpi_ops import allgather, allgather_async
+    from horovod.torch.mpi_ops import grouped_allgather, grouped_allgather_async
+    from horovod.torch.mpi_ops import broadcast, broadcast_async, broadcast_, broadcast_async_
+    from horovod.torch.mpi_ops import alltoall, alltoall_async
+    from horovod.torch.mpi_ops import reducescatter, reducescatter_async
+    from horovod.torch.mpi_ops import grouped_reducescatter, grouped_reducescatter_async
+    from horovod.torch.mpi_ops import join
+    from horovod.torch.mpi_ops import barrier
+    from horovod.torch.mpi_ops import poll, synchronize
+    from horovod.torch.mpi_ops import init, shutdown
+    from horovod.torch.mpi_ops import is_initialized, start_timeline, stop_timeline
+    from horovod.torch.mpi_ops import size, local_size, cross_size, rank, local_rank, cross_rank
+    from horovod.torch.mpi_ops import mpi_threads_supported, mpi_enabled, mpi_built
+    from horovod.torch.mpi_ops import gloo_enabled, gloo_built
+    from horovod.torch.mpi_ops import nccl_built, ddl_built, ccl_built, cuda_built, rocm_built
+    from horovod.torch.mpi_ops import ProcessSet, global_process_set, add_process_set, remove_process_set
+    from horovod.torch.mpi_ops import Average, Sum, Adasum, Min, Max, Product
+    from horovod.torch.mpi_ops import HorovodInternalError
+    from horovod.torch.optimizer import DistributedOptimizer
+    from horovod.torch.sync_batch_norm import SyncBatchNorm
 
 # Please run this function in a subprocess
 def _check_has_gpu():
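With this guard, importing horovod.torch no longer hard-fails when the native extension is missing. A hedged sketch of how downstream code could probe the flag this patch introduces (getattr keeps unpatched builds working):

    # Hedged sketch: probe the availability flag added by the patched module.
    import horovod.torch as hvd

    if getattr(hvd, '_MPI_LIB_AVAILABLE', True):
        hvd.init()
    else:
        print('Horovod native extension unavailable; distributed ops are disabled.')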