Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow native vectors for LLVM operations #7155

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/dxc/DXIL/DxilInstructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,42 @@ struct LlvmInst_VAArg {
bool isAllowed() const { return false; }
};

/// This instruction extracts from vector
struct LlvmInst_ExtractElement {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_ExtractElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::ExtractElement;
}
// Validation support
bool isAllowed() const { return true; }
};

/// This instruction inserts into vector
struct LlvmInst_InsertElement {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_InsertElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::InsertElement;
}
// Validation support
bool isAllowed() const { return true; }
};

/// This instruction Shuffle two vectors
struct LlvmInst_ShuffleVector {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_ShuffleVector(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::ShuffleVector;
}
// Validation support
bool isAllowed() const { return true; }
};

/// This instruction extracts from aggregate
struct LlvmInst_ExtractValue {
llvm::Instruction *Instr;
Expand Down
2 changes: 2 additions & 0 deletions lib/DxilValidation/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2158,6 +2158,8 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx,
return true;

if (Ty->isVectorTy()) {
if (ValCtx.DxilMod.GetShaderModel()->IsSM69Plus())
return true;
ValCtx.EmitTypeError(Ty, ValidationRule::TypesNoVector);
return false;
}
Expand Down
6 changes: 6 additions & 0 deletions lib/HLSL/DxilLinker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,12 @@ void DxilLinkJob::RunPreparePass(Module &M) {
// For static global handle.
PM.add(createLowerStaticGlobalIntoAlloca());

// Change dynamic indexing vector to array where vectors aren't
// supported, but might be there from the initial compile.
if (!pSM->IsSM69Plus())
PM.add(
createDynamicIndexingVectorToArrayPass(false /* ReplaceAllVector */));

// Remove MultiDimArray from function call arg.
PM.add(createMultiDimArrayToOneDimArrayPass());

Expand Down
44 changes: 28 additions & 16 deletions lib/HLSL/HLMatrixBitcastLowerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ class MatrixBitcastLowerPass : public FunctionPass {

// Lower matrix first.
for (BitCastInst *BCI : matCastSet) {
lowerMatrix(BCI, BCI->getOperand(0));
lowerMatrix(DM, BCI, BCI->getOperand(0));
}
return bUpdated;
}

private:
void lowerMatrix(Instruction *M, Value *A);
void lowerMatrix(DxilModule &DM, Instruction *M, Value *A);
bool hasCallUser(Instruction *M);
};

Expand Down Expand Up @@ -180,7 +180,8 @@ Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
}
} // namespace

void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
void MatrixBitcastLowerPass::lowerMatrix(DxilModule &DM, Instruction *M,
Value *A) {
for (auto it = M->user_begin(); it != M->user_end();) {
User *U = *(it++);
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
Expand All @@ -193,31 +194,42 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
DXASSERT(idxList.size() == 2,
"else not one dim matrix array index to matrix");

HLMatrixType MatTy = HLMatrixType::cast(EltTy);
Value *matSize = Builder.getInt32(MatTy.getNumElements());
idxList.back() = Builder.CreateMul(idxList.back(), matSize);
if (!DM.GetShaderModel()->IsSM69Plus()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we base the low-level decisions on a descriptively named boolean field in the MatrixBitcastLowerPass class, initialized based on the shader model, instead of passing the DxilModule down and checking the shader model each time?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe even something about scalarizing operations, rather than anything to do with the native DXIL vector support?

HLMatrixType MatTy = HLMatrixType::cast(EltTy);
Value *matSize = Builder.getInt32(MatTy.getNumElements());
idxList.back() = Builder.CreateMul(idxList.back(), matSize);
}
Value *NewGEP = Builder.CreateGEP(A, idxList);
lowerMatrix(GEP, NewGEP);
lowerMatrix(DM, GEP, NewGEP);
DXASSERT(GEP->user_empty(), "else lower matrix fail");
GEP->eraseFromParent();
} else {
DXASSERT(0, "invalid GEP for matrix");
}
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
lowerMatrix(BCI, A);
lowerMatrix(DM, BCI, A);
DXASSERT(BCI->user_empty(), "else lower matrix fail");
BCI->eraseFromParent();
} else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
IRBuilder<> Builder(LI);
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
Value *NewVec = UndefValue::get(LI->getType());
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateLoad(GEP);
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
Value *NewVec = nullptr;
if (DM.GetShaderModel()->IsSM69Plus()) {
// Just create a replacement load using the vector pointer.
Instruction *NewLI = LI->clone();
unsigned VecIdx = NewLI->getNumOperands() - 1;
NewLI->setOperand(VecIdx, A);
Builder.Insert(NewLI);
NewVec = NewLI;
} else {
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
NewVec = UndefValue::get(LI->getType());
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateLoad(GEP);
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that below this point, there's:

    } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {

where it still scalarizes the store for the vector.

Did you mean to leave that scalarization in?

}
LI->replaceAllUsesWith(NewVec);
LI->eraseFromParent();
Expand Down
6 changes: 6 additions & 0 deletions lib/Transforms/Scalar/DxilEliminateVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
// //
///////////////////////////////////////////////////////////////////////////////

#include "dxc/DXIL/DxilModule.h"

#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
Expand Down Expand Up @@ -151,6 +153,10 @@ bool DxilEliminateVector::TryRewriteDebugInfoForVector(InsertElementInst *IE) {

bool DxilEliminateVector::runOnFunction(Function &F) {

if (F.getParent()->HasDxilModule())
if (F.getParent()->GetDxilModule().GetShaderModel()->IsSM69Plus())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this is where we presumably will still want to do something different for vec1, right?


auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
DxilValueCache *DVC = &getAnalysis<DxilValueCache>();

Expand Down
18 changes: 15 additions & 3 deletions lib/Transforms/Scalar/LowerTypePasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/HLSL/HLModule.h"
Expand Down Expand Up @@ -180,10 +181,12 @@ bool LowerTypePass::runOnModule(Module &M) {
namespace {
class DynamicIndexingVectorToArray : public LowerTypePass {
bool ReplaceAllVectors;
bool SupportsVectors;

public:
explicit DynamicIndexingVectorToArray(bool ReplaceAll = false)
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll) {}
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll),
SupportsVectors(false) {}
static char ID; // Pass identification, replacement for typeid
void applyOptions(PassOptions O) override;
void dumpConfig(raw_ostream &OS) override;
Expand All @@ -194,6 +197,7 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
Type *lowerType(Type *Ty) override;
Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
StringRef getGlobalPrefix() override { return ".v"; }
void initialize(Module &M) override;

private:
bool HasVectorDynamicIndexing(Value *V);
Expand All @@ -207,6 +211,11 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
void ReplaceAddrSpaceCast(ConstantExpr *CE, Value *A, IRBuilder<> &Builder);
};

void DynamicIndexingVectorToArray::initialize(Module &M) {
if (M.HasHLModule())
SupportsVectors = M.GetHLModule().GetShaderModel()->IsSM69Plus();
}

void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
GetPassOptionBool(O, "ReplaceAllVectors", &ReplaceAllVectors,
ReplaceAllVectors);
Expand Down Expand Up @@ -286,7 +295,7 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
StoreInst *stInst = cast<StoreInst>(GEPUser);
Value *val = stInst->getValueOperand();
Value *ldVal = Builder.CreateLoad(V);
ldVal = Builder.CreateInsertElement(ldVal, val, constIdx);
ldVal = Builder.CreateInsertElement(ldVal, val, constIdx); // UGH
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you want to elaborate on the "UGH" comment? Is it on the general operation being performed, or something specific about this line?

Builder.CreateStore(ldVal, V);
stInst->eraseFromParent();
}
Expand All @@ -306,8 +315,11 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
}

bool DynamicIndexingVectorToArray::needToLower(Value *V) {
// Only needed where vectors aren't supported.
if (SupportsVectors)
return false;
Type *Ty = V->getType()->getPointerElementType();
if (dyn_cast<VectorType>(Ty)) {
if (isa<VectorType>(Ty)) {
if (isa<GlobalVariable>(V) || ReplaceAllVectors) {
return true;
}
Expand Down
18 changes: 11 additions & 7 deletions lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,8 @@ bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
// if
// all its users can be transformed, then split up the aggregate into its
// separate elements.
if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
if (!HLM.GetShaderModel()->IsSM69Plus() && ShouldAttemptScalarRepl(AI) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pass is so complicated as it is, can we capture a bool for DXIL vector support earlier and use that instead of the individual SM 6.9 checks here as well?

isSafeAllocaToScalarRepl(AI)) {
std::vector<Value *> Elts;
IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(AI));
bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
Expand Down Expand Up @@ -1945,8 +1946,9 @@ bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
continue;
}

// Flat Global vector if no dynamic vector indexing.
bool bFlatVector = !hasDynamicVectorIndexing(GV);
// Flat Global vector if no dynamic vector indexing and pre-6.9.
bool bFlatVector =
!hasDynamicVectorIndexing(GV) && !HLM.GetShaderModel()->IsSM69Plus();

if (bFlatVector) {
GVDbgOffset &dbgOffset = GVDbgOffsetMap[GV];
Expand Down Expand Up @@ -1980,10 +1982,12 @@ bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
} else {
// SROA_Parameter_HLSL has no access to a domtree, if one is needed,
// it'll be generated
SROAed = SROA_Helper::DoScalarReplacement(
GV, Elts, Builder, bFlatVector,
// TODO: set precise.
/*hasPrecise*/ false, typeSys, DL, DeadInsts, /*DT*/ nullptr);
if (!HLM.GetShaderModel()->IsSM69Plus()) {
SROAed = SROA_Helper::DoScalarReplacement(
GV, Elts, Builder, bFlatVector,
// TODO: set precise.
/*hasPrecise*/ false, typeSys, DL, DeadInsts, /*DT*/ nullptr);
}
}

if (SROAed) {
Expand Down
6 changes: 6 additions & 0 deletions lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
//
//===----------------------------------------------------------------------===//

#include "dxc/DXIL/DxilModule.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
Expand Down Expand Up @@ -290,6 +292,10 @@ bool Scalarizer::doInitialization(Module &M) {
}

bool Scalarizer::runOnFunction(Function &F) {
if (F.getParent()->HasDxilModule())
if (F.getParent()->GetDxilModule().GetShaderModel()->IsSM69Plus())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually want to turn off scalarization entirely? Maybe this is one place where we need to preserve it for vec1 only?


for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
BasicBlock *BB = BBI;
for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
Expand Down
6 changes: 4 additions & 2 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,14 @@ void AddStdIsEqualImplementation(clang::ASTContext &context, clang::Sema &sema);
clang::CXXRecordDecl *DeclareTemplateTypeWithHandle(
clang::ASTContext &context, llvm::StringRef name,
uint8_t templateArgCount = 1,
clang::TypeSourceInfo *defaultTypeArgValue = nullptr);
clang::TypeSourceInfo *defaultTypeArgValue = nullptr,
clang::InheritableAttr *Attr = nullptr);

clang::CXXRecordDecl *DeclareTemplateTypeWithHandleInDeclContext(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef name, uint8_t templateArgCount,
clang::TypeSourceInfo *defaultTypeArgValue);
clang::TypeSourceInfo *defaultTypeArgValue,
clang::InheritableAttr *Attr = nullptr);

clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandle(
clang::ASTContext &context, llvm::StringRef typeName,
Expand Down
12 changes: 12 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,18 @@ def HLSLNodeTrackRWInputSharing : InheritableAttr {
let Documentation = [Undocumented];
}

def HLSLCBuffer : InheritableAttr {
let Spellings = []; // No spellings!
let Subjects = SubjectList<[CXXRecord]>;
let Documentation = [Undocumented];
}

def HLSLTessPatch : InheritableAttr {
let Spellings = []; // No spellings!
let Subjects = SubjectList<[CXXRecord]>;
let Documentation = [Undocumented];
}

def HLSLNodeObject : InheritableAttr {
let Spellings = []; // No spellings!
let Subjects = SubjectList<[CXXRecord]>;
Expand Down
4 changes: 2 additions & 2 deletions tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -7691,8 +7691,6 @@ def err_hlsl_control_flow_cond_not_scalar : Error<
"%0 statement conditional expressions must evaluate to a scalar">;
def err_hlsl_unsupportedvectortype : Error<
"%0 is declared with type %1, but only primitive scalar values are supported">;
def err_hlsl_unsupportedvectorsize : Error<
"%0 is declared with size %1, but only values 1 through 4 are supported">;
def err_hlsl_unsupportedmatrixsize : Error<
"%0 is declared with size %1x%2, but only values 1 through 4 are supported">;
def err_hlsl_norm_float_only : Error<
Expand Down Expand Up @@ -7843,6 +7841,8 @@ def err_hlsl_load_from_mesh_out_arrays: Error<
"output arrays of a mesh shader can not be read from">;
def err_hlsl_out_indices_array_incorrect_access: Error<
"a vector in out indices array must be accessed as a whole">;
def err_hlsl_unsupported_long_vector: Error<
"Vectors of over 4 elements in %0 are not supported">;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Vectors of over 4 elements in %0 are not supported">;
"vectors of over 4 elements in %0 are not supported">;

def err_hlsl_logical_binop_scalar : Error<
"operands for short-circuiting logical binary operator must be scalar, for non-scalar types use '%select{and|or}0'">;
def err_hlsl_ternary_scalar : Error<
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema,
const clang::InitListExpr *InitList,
const clang::QualType EltTy);

bool HasLongVecs(const clang::QualType &qt);

bool IsConversionToLessOrEqualElements(clang::Sema *self,
const clang::ExprResult &sourceExpr,
const clang::QualType &targetType,
Expand Down
19 changes: 13 additions & 6 deletions tools/clang/lib/AST/ASTContextHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,18 +903,19 @@ void hlsl::AddStdIsEqualImplementation(clang::ASTContext &context,
/// <parm name="templateArgCount">Number of template arguments (one or
/// two).</param> <parm name="defaultTypeArgValue">If assigned, the default
/// argument for the element template.</param>
CXXRecordDecl *
hlsl::DeclareTemplateTypeWithHandle(ASTContext &context, StringRef name,
uint8_t templateArgCount,
TypeSourceInfo *defaultTypeArgValue) {
CXXRecordDecl *hlsl::DeclareTemplateTypeWithHandle(
ASTContext &context, StringRef name, uint8_t templateArgCount,
TypeSourceInfo *defaultTypeArgValue, InheritableAttr *Attr) {
return DeclareTemplateTypeWithHandleInDeclContext(
context, context.getTranslationUnitDecl(), name, templateArgCount,
defaultTypeArgValue);
defaultTypeArgValue, Attr);
}

CXXRecordDecl *hlsl::DeclareTemplateTypeWithHandleInDeclContext(
ASTContext &context, DeclContext *declContext, StringRef name,
uint8_t templateArgCount, TypeSourceInfo *defaultTypeArgValue) {
uint8_t templateArgCount, TypeSourceInfo *defaultTypeArgValue,
InheritableAttr *Attr) {

DXASSERT(templateArgCount != 0,
"otherwise caller should be creating a class or struct");
DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated "
Expand Down Expand Up @@ -968,6 +969,9 @@ CXXRecordDecl *hlsl::DeclareTemplateTypeWithHandleInDeclContext(

typeDeclBuilder.addField("h", elementType);

if (Attr)
typeDeclBuilder.getRecordDecl()->addAttr(Attr);

return typeDeclBuilder.getRecordDecl();
}

Expand Down Expand Up @@ -1131,6 +1135,9 @@ hlsl::DeclareConstantBufferViewType(clang::ASTContext &context, bool bTBuf) {
typeDeclBuilder.addField(
"h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.

typeDeclBuilder.getRecordDecl()->addAttr(
HLSLCBufferAttr::CreateImplicit(context));

typeDeclBuilder.getRecordDecl();

return templateRecordDecl;
Expand Down
Loading