add memory parity for PL vs Vanilla #5170

Merged: 22 commits, Dec 23, 2020
7 changes: 4 additions & 3 deletions benchmarks/generate_comparison.py
@@ -16,7 +16,7 @@
import matplotlib.pylab as plt
import pandas as pd

from benchmarks.test_basic_parity import lightning_loop, vanilla_loop
from benchmarks.test_basic_parity import measure_loops
from tests.base.models import ParityModuleMNIST, ParityModuleRNN

NUM_EPOCHS = 20
@@ -34,8 +34,9 @@ def _main():
if os.path.isfile(path_csv):
df_time = pd.read_csv(path_csv, index_col=0)
else:
vanilla = vanilla_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
lightning = lightning_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
# todo: kind="Vanilla PT" -> use_lightning=False
vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)

df_time = pd.DataFrame({'vanilla PT': vanilla['durations'][1:], 'PT Lightning': lightning['durations'][1:]})
df_time /= NUM_RUNS
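
Note (illustration, not part of the diff): a minimal sketch of how the updated script is expected to drive measure_loops, based only on the hunks above. The NUM_RUNS value, the choice of ParityModuleRNN, and the describe() call are assumptions for illustration.

import pandas as pd

from benchmarks.test_basic_parity import measure_loops
from tests.base.models import ParityModuleRNN

NUM_EPOCHS = 20
NUM_RUNS = 5  # assumed value; the real constant is defined in generate_comparison.py

# run the same model once through plain PyTorch and once through Lightning
vanilla = measure_loops(ParityModuleRNN, kind="Vanilla PT", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
lightning = measure_loops(ParityModuleRNN, kind="PT Lightning", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)

# each result is a dict with 'losses', 'durations' and, new in this PR, 'memory';
# the first run only warms up the dataset (download & filter), so it is dropped
df_time = pd.DataFrame({
    'vanilla PT': vanilla['durations'][1:],
    'PT Lightning': lightning['durations'][1:],
})
print(df_time.describe())
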
190 changes: 114 additions & 76 deletions benchmarks/test_basic_parity.py
@@ -11,126 +11,164 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import time

import numpy as np
import pytest
import torch
from tqdm import tqdm

from pytorch_lightning import seed_everything, Trainer
import tests.base.develop_utils as tutils
from pytorch_lightning import LightningModule, seed_everything, Trainer
from tests.base.models import ParityModuleMNIST, ParityModuleRNN


def assert_parity_relative(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.1):
# assert speeds
diffs = np.asarray(pl_values) - np.mean(pt_values)
# norm by vanilla time
diffs = diffs / norm_by
# relative to mean reference value
diffs = diffs / np.mean(pt_values)
assert np.mean(diffs) < max_diff, f"Lightning diff {diffs} was worse than vanilla PT (threshold {max_diff})"


def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.55):
# assert speeds
diffs = np.asarray(pl_values) - np.mean(pt_values)
# norm by event count
diffs = diffs / norm_by
assert np.mean(diffs) < max_diff, f"Lightning {diffs} was worse than vanilla PT (threshold {max_diff})"


# ParityModuleMNIST runs with num_workers=1
@pytest.mark.parametrize('cls_model,max_diff', [
(ParityModuleRNN, 0.05),
(ParityModuleMNIST, 0.25), # todo: lower this thr
@pytest.mark.parametrize('cls_model,max_diff_speed,max_diff_memory', [
(ParityModuleRNN, 0.05, 0.0),
(ParityModuleMNIST, 0.25, 0.0), # todo: lower this thr
Review thread on the thresholds above:

SeanNaren (Contributor), Dec 20, 2020:
Is it really 0.0 still? I know you investigated a bit and was curious. It doesn't seem correct, but maybe that's because of how small the memory difference is (memory usage is tiny). Maybe move to a significant figure like 1e-5?

Member (Author):
I think so; the model is super small and we run just 4 epochs.

Contributor:
Not sure adding a memory check for such small models makes sense.

awaelchli (Contributor), Dec 23, 2020:
What difference does it make how big the models are? max_diff_memory here is the difference between the PyTorch run and the Lightning run with the SAME model. It's perfectly fine if Lightning uses the same amount of memory as PyTorch; in fact, how would you even explain any other numbers? There is no logging, no fancy Lightning features, nothing that should occupy extra memory on the GPU.

])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
def test_pytorch_parity(
tmpdir,
cls_model: LightningModule,
max_diff_speed: float,
max_diff_memory: float,
num_epochs: int = 4,
num_runs: int = 3,
):
"""
Verify that the same pytorch and lightning models achieve the same results
"""
lightning = lightning_loop(cls_model, num_runs, num_epochs)
vanilla = vanilla_loop(cls_model, num_runs, num_epochs)
lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs)
vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=num_epochs, num_runs=num_runs)

# make sure the losses match exactly to 5 decimal places
print(f"Losses are for... \n vanilla: {vanilla['losses']} \n lightning: {lightning['losses']}")
for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
np.testing.assert_almost_equal(pl_out, pt_out, 5)

# the fist run initialize dataset (download & filter)
tutils.assert_speed_parity_absolute(
lightning['durations'][1:], vanilla['durations'][1:], nb_epochs=num_epochs, max_diff=max_diff
# drop the first run, which initializes the dataset (download & filter)
assert_parity_absolute(
lightning['durations'][1:], vanilla['durations'][1:], norm_by=num_epochs, max_diff=max_diff_speed
)

assert_parity_relative(lightning['memory'], vanilla['memory'], max_diff=max_diff_memory)

def vanilla_loop(cls_model, num_runs=10, num_epochs=10):

def _hook_memory():
if torch.cuda.is_available():
torch.cuda.synchronize()
used_memory = torch.cuda.max_memory_allocated()
else:
used_memory = np.nan
return used_memory


def measure_loops(cls_model, kind, num_runs=10, num_epochs=10):
"""
Returns an array with the last loss from each epoch for each run
"""
hist_losses = []
hist_durations = []
hist_memory = []

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
device_type = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cudnn.deterministic = True
for i in tqdm(range(num_runs), desc=f'Vanilla PT with {cls_model.__name__}'):
time_start = time.perf_counter()
for i in tqdm(range(num_runs), desc=f'{kind} with {cls_model.__name__}'):
gc.collect()
if device_type == 'cuda':
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_cached()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.reset_peak_memory_stats()
time.sleep(1)

# set seed
seed = i
seed_everything(seed)

# init model parts
model = cls_model()
dl = model.train_dataloader()
optimizer = model.configure_optimizers()

# model to GPU
model = model.to(device)

epoch_losses = []
# as the first run is skipped, no need to run it long
for epoch in range(num_epochs if i > 0 else 1):

# run through full training set
for j, batch in enumerate(dl):
batch = [x.to(device) for x in batch]
loss_dict = model.training_step(batch, j)
loss = loss_dict['loss']
loss.backward()
optimizer.step()
optimizer.zero_grad()
time_start = time.perf_counter()

# track last epoch loss
epoch_losses.append(loss.item())
_loop = lightning_loop if kind == "PT Lightning" else vanilla_loop
final_loss, used_memory = _loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)

time_end = time.perf_counter()
hist_durations.append(time_end - time_start)

hist_losses.append(epoch_losses[-1])
hist_losses.append(final_loss)
hist_durations.append(time_end - time_start)
hist_memory.append(used_memory)

return {
'losses': hist_losses,
'durations': hist_durations,
'memory': hist_memory,
}


def lightning_loop(cls_model, num_runs=10, num_epochs=10):
hist_losses = []
hist_durations = []
def vanilla_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
device = torch.device(device_type)
# set seed
seed_everything(idx)

for i in tqdm(range(num_runs), desc=f'PT Lightning with {cls_model.__name__}'):
time_start = time.perf_counter()
# init model parts
model = cls_model()
dl = model.train_dataloader()
optimizer = model.configure_optimizers()

# set seed
seed = i
seed_everything(seed)

model = cls_model()
# init model parts
trainer = Trainer(
# as the first run is skipped, no need to run it long
max_epochs=num_epochs if i > 0 else 1,
progress_bar_refresh_rate=0,
weights_summary=None,
gpus=1,
checkpoint_callback=False,
deterministic=True,
logger=False,
replace_sampler_ddp=False,
)
trainer.fit(model)

final_loss = trainer.train_loop.running_loss.last().item()
hist_losses.append(final_loss)
# model to GPU
model = model.to(device)

time_end = time.perf_counter()
hist_durations.append(time_end - time_start)
epoch_losses = []
# as the first run is skipped, no need to run it long
for epoch in range(num_epochs if idx > 0 else 1):

return {
'losses': hist_losses,
'durations': hist_durations,
}
# run through full training set
for j, batch in enumerate(dl):
batch = [x.to(device) for x in batch]
loss_dict = model.training_step(batch, j)
loss = loss_dict['loss']
loss.backward()
optimizer.step()
optimizer.zero_grad()

# track last epoch loss
epoch_losses.append(loss.item())

return epoch_losses[-1], _hook_memory()


def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
seed_everything(idx)

model = cls_model()
# init model parts
trainer = Trainer(
# as the first run is skipped, no need to run it long
max_epochs=num_epochs if idx > 0 else 1,
progress_bar_refresh_rate=0,
weights_summary=None,
gpus=1 if device_type == 'cuda' else 0,
checkpoint_callback=False,
deterministic=True,
logger=False,
replace_sampler_ddp=False,
)
trainer.fit(model)

return trainer.train_loop.running_loss.last().item(), _hook_memory()
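
Note (illustration, not part of the diff): a small worked sketch of the arithmetic that assert_parity_relative applies to the 'memory' lists returned by measure_loops, relevant to the max_diff_memory discussion in the review thread above. The byte counts are made up; the real readings come from torch.cuda.max_memory_allocated() via _hook_memory().

import numpy as np

pl_memory = [52_428_800, 52_428_800, 52_428_800]  # hypothetical peak bytes per Lightning run
pt_memory = [52_428_800, 52_428_800, 52_428_800]  # hypothetical peak bytes per vanilla run

# same steps as assert_parity_relative, with the default norm_by=1
diffs = np.asarray(pl_memory) - np.mean(pt_memory)
diffs = diffs / 1
diffs = diffs / np.mean(pt_memory)

# identical peak allocations give a mean relative difference of exactly 0.0,
# which is the scale the max_diff_memory thresholds above are probing
print(np.mean(diffs))
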
42 changes: 17 additions & 25 deletions benchmarks/test_sharded_parity.py
@@ -28,64 +28,59 @@
from tests.base.boring_model import BoringModel, RandomDataset


@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
plugin_parity_test(
accelerator='ddp_cpu',
max_percent_speed_diff=0.15, # slower speed due to one CPU doing additional sequential memory saving calls
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.15, # todo: slower speed due to one CPU doing additional sequential memory saving calls
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
gpus=1,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
)


@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
gpus=1,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
model_cls=SeedTrainLoaderModel,
)


@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
gpus=2,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
@@ -95,13 +90,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
@@ -111,7 +105,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
accelerator='ddp_spawn',
plugin='ddp_sharded',
model_cls=SeedTrainLoaderModel,
max_percent_speed_diff=0.25
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@@ -147,8 +141,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):

@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
@@ -159,14 +152,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderMultipleOptimizersModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
@@ -177,7 +169,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
)


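
Note (illustration, not part of the diff): as the name and arguments suggest, plugin_parity_test (defined further down in benchmarks/test_sharded_parity.py, outside the hunks shown) trains the same seeded model with and without the sharded plugin and compares the runs. A rough sketch of the kind of relative speed check that max_percent_speed_diff parameterizes is given below; the function name and structure are assumptions, not the actual helper.

def check_speed_parity(ddp_time: float, sharded_time: float, max_percent_speed_diff: float = 0.1) -> None:
    # how much slower, as a fraction, the sharded run was compared to the plain DDP run
    percent_diff = (sharded_time - ddp_time) / ddp_time
    assert percent_diff < max_percent_speed_diff, (
        f"Sharded plugin was {percent_diff:.2%} slower than DDP "
        f"(allowed: {max_percent_speed_diff:.2%})"
    )


# example: a 2-GPU run where sharding two optimizers costs ~20% extra time passes a 0.25 threshold
check_speed_parity(ddp_time=10.0, sharded_time=12.0, max_percent_speed_diff=0.25)
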