Fix optimization of parallel dimension predicate with IdModel #3938

Draft: wants to merge 4 commits into main
Conversation

@naoyam (Collaborator) commented Feb 21, 2025

No description provided.

@naoyam (Collaborator, Author) commented Feb 21, 2025

!test --diff

github-actions bot commented Feb 21, 2025

Review updated until commit e342f14

Description

  • Introduced IdModel-based unswitch analysis in predicate_compute.cpp

  • Added new tests for parallel dimension predicate with unswitch in test_indexing.cpp

  • Enhanced ValGraphVisitor.h with utility functions for expression group inputs and outputs (a sketch follows below)
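
To make that last bullet concrete, here is a minimal sketch of what direction-aware helpers for the inputs and outputs of an expression group could look like. The signatures and the Direction-based swap are assumptions inferred from how the helpers are used later on this page, not the PR's exact code.

    // Hedged sketch, not the PR's code: direction-aware accessors for an
    // ExprGroup's inputs and outputs in a ValGraph. Assumes ValGraph
    // provides inputGroups()/outputGroups(); walking an expression
    // backward swaps the roles of its inputs and outputs.
    std::vector<ValGroup> getInputsOfExprGroup(
        const ValGraph& graph,
        const ExprGroup& expr_group,
        Direction dir) {
      return dir == Direction::Forward ? graph.inputGroups(expr_group)
                                       : graph.outputGroups(expr_group);
    }

    std::vector<ValGroup> getOutputsOfExprGroup(
        const ValGraph& graph,
        const ExprGroup& expr_group,
        Direction dir) {
      return dir == Direction::Forward ? graph.outputGroups(expr_group)
                                       : graph.inputGroups(expr_group);
    }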


Changes walkthrough 📝

Relevant files

Enhancement

  • predicate_compute.cpp (csrc/predicate_compute.cpp): Update unswitch analysis with IdModel
      • Added new functions getFullyUnswitchedLoopIds and isFullyUnswitched
      • Updated getPredicateMap to use getFullyUnswitchedLoopIds
      • Removed old getNonUnswitchedRootDomains and isFullyUnswitched functions
      • +99/-37

  • val_graph_visitor.h (csrc/val_graph_visitor.h): Add utility functions for expression groups
      • Added utility functions getInputsOfExprGroup and getOutputsOfExprGroup
      • +16/-0

Tests

  • test_indexing.cpp (tests/cpp/test_indexing.cpp): Add tests for parallel dimension predicate with unswitch
      • Added new test ParallelDimensionPredicateWithUnswitch
      • Added new test ParallelDimensionPredicateWithUnswitchAndSetLoopDomain
      • Updated PredicateIndexValidator to handle inline predicates more gracefully
      • +189/-8
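
Per the walkthrough, the key behavioral change is that getPredicateMap now consults getFullyUnswitchedLoopIds before emitting a parallel dimension predicate. The hypothetical helper below illustrates the intended effect; needsParallelDimPredicate is an invented name for illustration, and only getFullyUnswitchedLoopIds's return value and isParallelTypeThread are taken from this page.

    // Hypothetical sketch of the call-site logic, not the PR's code: a
    // thread-parallel loop ID that is fully unswitched is already bounded
    // by the unswitch predicate evaluated at its maximum index, so the
    // extra (threadIdx < extent) guard can be skipped for it.
    bool needsParallelDimPredicate(
        IterDomain* loop_id,
        const std::vector<IterDomain*>& fully_unswitched_loop_ids) {
      return isParallelTypeThread(loop_id->getParallelType()) &&
          std::find(
              fully_unswitched_loop_ids.begin(),
              fully_unswitched_loop_ids.end(),
              loop_id) == fully_unswitched_loop_ids.end();
    }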

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Code Complexity

The new function getFullyUnswitchedLoopIds is quite complex; adding more detailed comments explaining its logic and flow would help future maintainers understand the code.

    std::vector<IterDomain*> getFullyUnswitchedLoopIds(
        const Expr* expr,
        const std::vector<ForLoop*>& loops,
        ForLoop* unswitched_loop) {
      if (unswitched_loop == nullptr) {
        return {};
      }
    
      const auto& id_model = GpuLower::current()->idModel();
      const auto& indexing_graph =
          id_model.idGraph(TensorIndexer::traversalGraphType());
    
      auto out_tv = ir_utils::getTvOutput(expr);
      NVF_ERROR(out_tv != nullptr);
    
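      // Collect the promoted loop IterDomain for each ForLoop; the loop
      // promotion map gives the IterDomain that actually represents each
      // loop group.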
      std::vector<IterDomain*> loop_ids;
      loop_ids.reserve(loops.size());
      std::transform(
          loops.begin(),
          loops.end(),
          std::back_inserter(loop_ids),
          [&](ForLoop* loop) {
            const auto& loop_group =
                id_model.idGraph(IdMappingMode::LOOP).toGroup(loop->iter_domain());
            auto promotion_it = id_model.loopPromotionMap().find(loop_group);
            NVF_ERROR(
                promotion_it != id_model.loopPromotionMap().end(),
                "Loop promotion not found for ",
                loop->iter_domain()->toString());
            return promotion_it->second;
          });
    
      const auto predicate_ids = getPredicateDomains(out_tv, expr);
    
      const IndexingTraversal::ExprPath predicate_path =
          IndexingTraversal::getExprsBetween(
              expr, indexing_graph, loop_ids, predicate_ids);
    
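      // Split the loop nest at the unswitched loop: loops outside it seed
      // non_unswitch_dep_ids, while the unswitched loop and everything
      // inside it are collected as unswitched_loop_ids.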
      ValGroups non_unswitch_dep_ids;
      std::vector<IterDomain*> unswitched_loop_ids;
      bool unswitch_found = false;
      for (const auto loop : loops) {
        if (loop == unswitched_loop) {
          unswitch_found = true;
        }
        if (unswitch_found) {
          unswitched_loop_ids.push_back(loop->iter_domain());
        } else {
          non_unswitch_dep_ids.pushBack(
              indexing_graph.toGroup(loop->iter_domain()));
        }
      }
    
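      // Propagate dependence on the non-unswitched loop IDs forward along
      // the predicate path: any ID produced from a non-unswitched input is
      // itself treated as non-unswitched-dependent.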
      for (const auto& [expr_g, dir] : predicate_path) {
        const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir);
        const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir);
        if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) {
              return non_unswitch_dep_ids.has(input);
            })) {
          // Depends on non-unswitched ids
          non_unswitch_dep_ids.pushBack(outputs);
        }
      }
    
      // If an unswitched loop ID is never combined with the non-unswitched
      // loop IDs along the predicate path, it is considered fully
      // unswitched.
    
      std::vector<IterDomain*> fully_unswitched_loop_ids;
      for (auto unswitched_loop_id : unswitched_loop_ids) {
        if (!isParallelTypeThread(unswitched_loop_id->getParallelType())) {
          continue;
        }
    
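        // A conflict means this unswitched ID is consumed by a predicate
        // path expression together with an ID that depends on the
        // non-unswitched loops.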
        ValGroups unswitch_dep_ids;
        unswitch_dep_ids.pushBack(indexing_graph.toGroup(unswitched_loop_id));
    
        bool conflict_found = false;
        for (const auto& [expr_g, dir] : predicate_path) {
          const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir);
          const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir);
          if (std::none_of(
                  inputs.begin(), inputs.end(), [&](const ValGroup& input) {
                    return unswitch_dep_ids.has(input);
                  })) {
            continue;
          }
    
          if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) {
                return non_unswitch_dep_ids.has(input);
              })) {
            conflict_found = true;
            break;
          }
        }
    
        if (!conflict_found) {
          fully_unswitched_loop_ids.push_back(unswitched_loop_id);
        }
      }
    
      return fully_unswitched_loop_ids;
    }
Error Handling

The function getFullyUnswitchedLoopIds uses NVF_ERROR for error handling. Ensure that all possible error conditions are covered and that the error messages are clear and informative.

    const auto& id_model = GpuLower::current()->idModel();
    const auto& indexing_graph =
        id_model.idGraph(TensorIndexer::traversalGraphType());
    
    auto out_tv = ir_utils::getTvOutput(expr);
    NVF_ERROR(out_tv != nullptr);
    
    std::vector<IterDomain*> loop_ids;
    loop_ids.reserve(loops.size());
    std::transform(
        loops.begin(),
        loops.end(),
        std::back_inserter(loop_ids),
        [&](ForLoop* loop) {
          const auto& loop_group =
              id_model.idGraph(IdMappingMode::LOOP).toGroup(loop->iter_domain());
          auto promotion_it = id_model.loopPromotionMap().find(loop_group);
          NVF_ERROR(
              promotion_it != id_model.loopPromotionMap().end(),
              "Loop promotion not found for ",
              loop->iter_domain()->toString());
          return promotion_it->second;
        });
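
One concrete check the note above could translate into, offered as a suggestion rather than code from the PR: assert that unswitched_loop actually appears in loops, so that a mismatched caller fails loudly instead of silently classifying every loop as non-unswitched.

    // Suggested additional check (hypothetical, not in the PR): fail fast
    // if the unswitched loop is not part of the given loop nest.
    NVF_ERROR(
        std::find(loops.begin(), loops.end(), unswitched_loop) != loops.end(),
        "Unswitched loop not found in the loop nest: ",
        unswitched_loop->iter_domain()->toString());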
Test Coverage

The new tests (ParallelDimensionPredicateWithUnswitch and ParallelDimensionPredicateWithUnswitchAndSetLoopDomain) are comprehensive, but consider adding more edge cases to ensure robustness; one possible direction is sketched after the test listings below.

    TEST_F(PredicateIndexingTest, ParallelDimensionPredicateWithUnswitch) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigTensor(1);
      fusion.addInput(tv0);
      auto tv1 = makeContigTensor(1);
      fusion.addInput(tv1);
      auto tv2 = makeContigTensor(1);
      fusion.addInput(tv2);
    
      // Just to make TIDx non unique so the parallel dimension predicate
      // is required for TIDx
      auto tv3 = set(tv0);
      fusion.addOutput(tv3);
      tv3->axis(0)->parallelize(ParallelType::TIDx);
    
      auto tv4 = set(tv1);
      fusion.addOutput(tv4);
    
      auto tv5 = set(tv2);
      fusion.addOutput(tv5);
    
      // TIDx-parallelized ID is fully unswitched
      tv4->split(0, 128);
      tv4->split(0, 1, false);
      tv4->axis(0)->parallelize(ParallelType::Unswitch);
      tv4->axis(-1)->parallelize(ParallelType::TIDx);
    
      // TIDx-parallelized ID is not fully unswitched. The unswitch
      // predicate should have (threadIdx.x < 128)
      tv5->split(0, 128);
      tv5->split(0, 1);
      tv5->axis(-2)->parallelize(ParallelType::Unswitch);
      tv5->axis(-1)->parallelize(ParallelType::TIDx);
    
      struct GetReference : AbstractGetReference {
        GetReference(const TensorIndexer& indexer, const IdModel& id_model)
            : AbstractGetReference(indexer, id_model) {}
    
        // The unswitch predicate should look like:
        //
        // tv4:
        // ((nvfuser_index_t)threadIdx.x) >= 0LL &&
        // (ceilDiv(T4.logical_size[0LL], 128LL)) - 1LL) * 128LL) +
        // ((nvfuser_index_t)threadIdx.x)) < T4.logical_size[0LL]
        //
        // tv5:
        //  ((nvfuser_index_t)threadIdx.x) < 128LL &&
        // (i1 * 128LL) + ((nvfuser_index_t)threadIdx.x) >= 0LL &&
        // (i1 * 128LL) + ((nvfuser_index_t)threadIdx.x) < T5.logical_size[0LL]
    
        Val* getOuterPredicate(TensorView* tv) const override {
          std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);
          Val* zero = tv->fusion()->zeroVal();
          Val* one = tv->fusion()->oneVal();
    
          if (tv->name() == 4) {
            auto min_idx = addExpr(
                IrBuilder::mulExpr(zero, createInt(128)), loop_indices.back());
            auto min_pred = geExpr(min_idx, zero);
            auto max_idx = addExpr(
                IrBuilder::mulExpr(
                    subExpr(
                        ceilDivExpr(
                            tv->getLogicalDomain().at(0)->extent(), createInt(128)),
                        one),
                    createInt(128)),
                loop_indices.back());
            auto max_pred = ltExpr(max_idx, tv->getLogicalDomain().at(0)->extent());
            return andExpr(min_pred, max_pred);
          } else if (tv->name() == 5) {
            auto tidx_pred = ltExpr(loop_indices.back(), createInt(128));
            auto idx = addExpr(
                IrBuilder::mulExpr(loop_indices.at(0), createInt(128)),
                loop_indices.back());
            auto min_pred = geExpr(idx, zero);
            auto max_pred = ltExpr(idx, tv->getLogicalDomain().at(0)->extent());
            return andExpr(andExpr(tidx_pred, min_pred), max_pred);
          } else {
            return nullptr;
          }
        }
      };
    
      PredicateIndexValidator<GetReference>::validate(&fusion, false);
    }
    
    // Check if a parallel dimension predicate is properly used with a
    // loop domain set as the producer of a logical domain. Because of the
    // reversed dependency, BFS traversal is required. This test resulted in
    // a validation failure before PR #3938.
    TEST_F(
        PredicateIndexingTest,
        ParallelDimensionPredicateWithUnswitchAndSetLoopDomain) {
      // EnableOptionsGuard enable_options_guard;
      EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigTensor(1);
      fusion.addInput(tv0);
    
      auto tv1 = makeContigConcreteTensor({4, 8});
      fusion.addInput(tv1);
    
      // Just to make TIDx non unique so the parallel dimension predicate
      // is required for TIDx
      auto tv2 = set(tv0);
      fusion.addOutput(tv2);
      tv2->axis(0)->parallelize(ParallelType::TIDx);
    
      auto tv3 = reshape(tv1, {4, 8}, {32});
      auto tv4 = sum(tv3, {0});
      fusion.addOutput(tv4);
    
      // Cancel the reshape in the loop domain [4, 8]
      tv3->setLoopDomain(tv3->getRootDomain());
    
      // Make the loop domain of tv4 look like that of tv3.
      // TODO: use scheduler_tools::scheduleLoopDomainsLike, which doesn't
      // seem to properly set the IterType of the new IDs.
      auto tv4_loop_id0 = IterDomainBuilder(tv3->getLoopDomain().at(0))
                              .iter_type(IterType::Reduction)
                              .build();
      auto tv4_loop_id1 = IterDomainBuilder(tv3->getLoopDomain().at(1))
                              .iter_type(IterType::Reduction)
                              .build();
      IrBuilder::create<Merge>(
          tv4->getLogicalDomain().at(0), tv4_loop_id0, tv4_loop_id1);
      tv4->setLoopDomain({tv4_loop_id0, tv4_loop_id1});
    
      // Schedule tv3 and tv4 as:
      // [Serial(4), Unswitch(1), TIDx(8)]
      for (auto tv : {tv3, tv4}) {
        tv->split(1, 1, false);
      }
    
      tv3->inlineAt(-1);
    
      tv4->axis(-2)->parallelize(ParallelType::Unswitch);
      tv4->axis(-1)->parallelize(ParallelType::TIDx);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({128}, options);
      at::Tensor t1 = at::randn({4, 8}, options);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1});
      auto outputs = ke.run({t0, t1});
    
      testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
    
      struct GetReference : AbstractGetReference {
        GetReference(const TensorIndexer& indexer, const IdModel& id_model)
            : AbstractGetReference(indexer, id_model) {}
    
        // The unswitch predicate should look like:
        //
        // tv3:
        // ((nvfuser_index_t)threadIdx.x) < 8LL &&
    // ((i0 * 8LL) + ((nvfuser_index_t)threadIdx.x)) >= 0LL &&
        // ((i0 * 8LL) + ((nvfuser_index_t)threadIdx.x)) < 32LL
    
        Val* getOuterPredicate(TensorView* tv) const override {
          std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);
          Val* zero = tv->fusion()->zeroVal();
    
          if (tv->name() == 3) {
            auto tidx = loop_indices.back();
            auto tid_pred = ltExpr(tidx, createInt(8));
            auto idx = addExpr(mulExpr(loop_indices.front(), createInt(8)), tidx);
            auto min_pred = geExpr(idx, zero);
            auto max_pred = ltExpr(idx, createInt(32));
            return andExpr(andExpr(tid_pred, min_pred), max_pred);
          } else {
            return nullptr;
          }
        }
      };
    
      PredicateIndexValidator<GetReference>::validate(&fusion, false);
    }
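
Following up on the edge-case note above, one additional direction could be the tv5-style schedule with TIDy instead of TIDx, validated end to end. The sketch below mirrors the structure of the tests on this page but is an untested illustration, not part of the PR.

    // Hedged sketch (not from the PR): same non-fully-unswitched shape as
    // tv5 in the first test, but with TIDy, checked numerically rather
    // than against a reference predicate.
    TEST_F(PredicateIndexingTest, ParallelDimensionPredicateWithUnswitchTIDy) {
      Fusion fusion;
      FusionGuard fg(&fusion);

      auto tv0 = makeContigTensor(1);
      fusion.addInput(tv0);

      // Make TIDy non unique so the parallel dimension predicate is
      // required for TIDy
      auto tv1 = set(tv0);
      fusion.addOutput(tv1);
      tv1->axis(0)->parallelize(ParallelType::TIDy);

      auto tv2 = set(tv0);
      fusion.addOutput(tv2);

      // [Serial, Unswitch(1), TIDy(128)], as with tv5 in the first test
      tv2->split(0, 128);
      tv2->split(0, 1);
      tv2->axis(-2)->parallelize(ParallelType::Unswitch);
      tv2->axis(-1)->parallelize(ParallelType::TIDy);

      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({1000}, options);

      KernelExecutor ke;
      ke.compile(&fusion, {t0});
      auto outputs = ke.run({t0});

      testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
    }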

@naoyam changed the title from "Optimization of parallel dimension predicate with IdModel" to "Fix optimization of parallel dimension predicate with IdModel" on Feb 21, 2025
@naoyam (Collaborator, Author) commented Feb 21, 2025

!test --diff
