Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Set allocation as loop in fusion segmentor #3880

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ TensorView* castIntermediateValueInCompleteFusion(
//! Final fix-up pass run once segmentation is complete. The three steps are
//! order-dependent: unused vals must be removed before boundary edges are
//! rewritten, and sharded boundary TVs are adjusted last.
void SegmentedFusion::finalize() {
  // Drop expressions/values that no segment references anymore.
  impl_.cleanUnused();
  // Lower the FP precision of segment-boundary inputs/outputs where
  // applicable (see castInputOutputToLowerPrecision below).
  castInputOutputToLowerPrecision(edges());
  // Pin the allocation domain of sharded boundary TensorViews to their loop
  // domain (see setAllocationAsLoopForShardedTvs).
  setAllocationAsLoopForShardedTvs();
}

//! Lower FP precision of inputs and outputs specified by the given
Expand Down Expand Up @@ -1435,6 +1436,22 @@ void SegmentedFusion::revertInputOutputPrecisionChanges(
}
}

//! For every segment, sets the allocation domain of each sharded TensorView
//! appearing as a group input or output to that TV's loop domain, so the
//! device-parallel (sharded) axis is reflected in how the tensor is
//! allocated across segment boundaries.
void SegmentedFusion::setAllocationAsLoopForShardedTvs() {
  // Take the vals by const reference: group->inputs()/outputs() yield a
  // std::vector<Val*>, and the original lambda copied it on every call.
  const auto set_allocation_as_loop = [](const std::vector<Val*>& vals) {
    for (auto* tv : ir_utils::filterByType<TensorView>(vals)) {
      if (isSharded(tv)) {
        // `true` here is the contiguity argument passed alongside the new
        // domain — NOTE(review): presumably marks all axes contiguous;
        // confirm against setAllocationDomain's contract.
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    }
  };

  for (auto* group : groups()) {
    set_allocation_as_loop(group->inputs());
    set_allocation_as_loop(group->outputs());
  }
}

//! An utility class to compute and maintain the "producers of"
//! relationship in a segmented graph. Space heavy and should
//! avoid use on very large graphs.
Expand Down
3 changes: 3 additions & 0 deletions csrc/fusion_segmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ class SegmentedFusion {
//! Deserialize SegmentedFusion using flatbuffers
void deserialize(const serde::SegmentedFusion* buffer);

//! Set allocation domain as loop domain for sharded tensors
void setAllocationAsLoopForShardedTvs();

private:
void validateDAG() const;
void validateDisjoint() const;
Expand Down
35 changes: 23 additions & 12 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,31 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
}

bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
continue;
}
// First check allocation domain if available, or the logical domain.
auto num_sharded_axes = std::count_if(
tv->getMaybeAllocationDomain().begin(),
tv->getMaybeAllocationDomain().end(),
[](IterDomain* id) { return id->isDeviceDim(); });

// Only one axis can be sharded on DIDx.
NVF_ERROR(
!is_sharded,
"Multiple IterDomains parallelized on DIDx in TensorView ",
tv);
is_sharded = true;
if (num_sharded_axes == 1) {
return true;
}
return is_sharded;

// Check if only the loop domain is sharded.
// It is possible if the allocation domain has not been set yet.
if (num_sharded_axes == 0) {
num_sharded_axes = std::count_if(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[](IterDomain* id) { return id->isDeviceDim(); });
}

NVF_ERROR(
num_sharded_axes <= 1,
"Multiple IterDomains parallelized on DIDx in TensorView ",
tv);

return num_sharded_axes == 1;
}

namespace {
Expand Down
2 changes: 1 addition & 1 deletion csrc/preseg_passes/pre_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace nvfuser::preseg_passes {
OptimizationPass<PropagateShardingsPass>::runPass(fusion);
OptimizationPass<InsertReshardingsPass>::runPass(fusion);
OptimizationPass<ReorderShardedAxisPass>::runPass(fusion);
OptimizationPass<MakeReshardingContiguousPass>::runPass(fusion);
// OptimizationPass<MakeReshardingContiguousPass>::runPass(fusion);

// Replace TensorViews with zero extent. Outputs and inputs may still be empty
OptimizationPass<RemoveEmptyPass>::runPass(fusion);
Expand Down
Loading