Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dtensor to all paths #73

Merged
merged 2 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion megablocks/layers/glu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from megablocks import grouped_gemm_util as gg
import stk
import torch
from packaging import version


class SparseGLU(SparseMLP):
Expand Down Expand Up @@ -34,7 +35,15 @@ def forward(self, x, topo):
raise NotImplementedError("Memory optimized implementation not yet supported with GLU with sparse kernels.")

w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))

if version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed._tensor import DTensor
if isinstance(w1, DTensor):
w1 = w1.to_local()
if isinstance(v1, DTensor):
v1 = v1.to_local()
if isinstance(w2, DTensor):
w2 = w2.to_local()

# Compute the GLU.
x1 = stk.ops.sdd(x, w1.t(), topo)
x2 = stk.ops.sdd(x, v1.t(), topo)
Expand Down Expand Up @@ -179,6 +188,14 @@ def forward(self, x, tokens_per_expert):

# Re-shape the weights for the grouped GEMMs.
ne = mpu.experts_per_rank(self.args)
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed._tensor import DTensor
if isinstance(w1, DTensor):
w1 = w1.to_local()
if isinstance(v1, DTensor):
v1 = v1.to_local()
if isinstance(w2, DTensor):
w2 = w2.to_local()
w1 = w1.view(ne, -1, self.args.hidden_size)
v1 = v1.view(ne, -1, self.args.hidden_size)
w2 = w2.view(ne, -1, self.args.hidden_size)
Expand Down
12 changes: 12 additions & 0 deletions megablocks/layers/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def scale_grad(self, w):
return scale_gradient(w, self.gradient_scale)

def forward(self, x):
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed._tensor import DTensor
if isinstance(w1, DTensor):
w1 = w1.to_local()
if isinstance(w2, DTensor):
w2 = w2.to_local()
x = torch.bmm(x, self.scale_grad(self.w1))
x = self.args.activation_fn(x)
return torch.bmm(x, self.scale_grad(self.w2))
Expand Down Expand Up @@ -382,6 +388,12 @@ def parallel_forward(self, x, topo):

def forward(self, x, topo):
w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed._tensor import DTensor
if isinstance(w1, DTensor):
w1 = w1.to_local()
if isinstance(w2, DTensor):
w2 = w2.to_local()
if self.args.moe_weight_parallelism:
return self.parallel_forward(x, topo)
elif self.args.memory_optimized_mlp:
Expand Down