
Coord refactor #186

Open · wants to merge 48 commits into base: sycl-develop

Conversation

@t4c1 (Collaborator) commented Jan 17, 2025

Refactor the coordinates for PVC copies so they are consistent with how copies are called for all CUDA GPUs.

@t4c1 t4c1 marked this pull request as ready for review January 29, 2025 09:38
@joeatodd (Collaborator) left a comment

Nice work @t4c1 - a few small things I spotted.

auto
get_pvc_tensor(GShape const& g_shape) const {
static_assert(rank(GShape{}) == 3, "mismatch rank");
return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
Collaborator

get_tma_tensor uses g_stride_ for the 2nd arg to make_layout here. Is there any loss of generality with this simpler approach?

Collaborator

Check whether this stride works correctly for column-major.

Collaborator Author

I am not sure, but I think that is used to encode col/row-major information. For PVC, that is instead encoded in the copy atom itself.
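For context, a minimal sketch of the two constructions being compared; this is illustrative only and assumes the g_stride_ member mentioned above, not code from this PR:

    // TMA-style: the runtime stride goes into the layout, so row/column-major
    // ordering is carried by g_stride_ itself.
    auto tma_coords = make_counting_tensor(make_layout(g_shape, g_stride_));
    // PVC-style (this PR): pure basis strides; the tensor only produces
    // coordinates, and row/column-major handling is left to the copy atom.
    auto pvc_coords = make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));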

Comment on lines 170 to 171
constexpr int dtype_size = sizeof(dtype);
constexpr int bits_in_byte = 8;
Collaborator

Cutlass provides cutlass::sizeof_bits<dtype> for this

Collaborator Author

done

Comment on lines 320 to 323
static_assert(is_rmem<TS>::value);
static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::SrcLayout{}),
"Src tensor size does not match copy atom size");
static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::DstLayout{}),
Collaborator

As above, use cutlass::sizeof_bits<dtype> I think.

Collaborator Author

done
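For illustration, a rough sketch of the asserts with that suggestion applied; the second message string is an assumption here because the quoted snippet is truncated:

    static_assert(size(SLayout{}) * cutlass::sizeof_bits<dtype>::value
                      == size<1>(typename Traits_ST_t::SrcLayout{}),
                  "Src tensor size does not match copy atom size");
    static_assert(size(DLayout{}) * cutlass::sizeof_bits<dtype>::value
                      == size<1>(typename Traits_ST_t::DstLayout{}),
                  "Dst tensor size does not match copy atom size");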

@@ -137,12 +137,31 @@ struct CollectiveMma<
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
Collaborator

I think the changes in this file need to be copied over to xe_mma_mixed_input.hpp. I am getting a local failure of ninja test_unit_gemm_device.

Collaborator Author

done

Comment on lines 151 to 152
using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(0,0,0), StrideA{})); //(m, k)
using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(0,0,0), StrideB{})); //(n, k)
Collaborator

Unused

Collaborator Author

They are used in the universal GEMM (GemmUniversal).


// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
Collaborator

Can we have a TODO(Codeplay): here to fix this later?

Collaborator Author

done

Comment on lines 256 to 261
Tensor tArA = thr_copy_A2.retile_D(tCrA);
Tensor tBrB = thr_copy_B2.retile_D(tCrB);

// Retile global tile for copies
Tensor tAgA = thr_copy_A2.retile_S(tCgA);
Tensor tBgB = thr_copy_B2.retile_S(tCgB);
Collaborator

retile_D and retile_S do the same thing by the way. Not sure if that affects what's going on here - but I don't think I've seen both used anywhere before.

Collaborator Author

I am not sure they always do the same thing. In any case, one is intended for the source and one for the destination, and that is how I use them here.

Comment on lines 325 to 328
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this CTA is responsible for // (BLK_M,BLK_N)
Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
Collaborator

I am wondering whether it should be possible to avoid this and have something like

 Tensor g_cta_D_mnl  = local_tile(mD_mnl, CtaTileMNK{}, make_coord(m_coord,n_coord,l_coord), Step<_1,_1, X>{}); 

Collaborator Author

It is possible if I manually construct the tile for the MN dimensions. I can only use the last argument (projection) if the previous two have the same number of modes.
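A hedged sketch of what manually constructing the MN tile could look like (illustrative only, not necessarily what the PR does; take<0,2> trims CtaTileMNK to its M and N modes so the tiler rank matches the coordinate):

    // Rank-2 tiler over (M, N) only, so a full (m, n, l) coordinate can be used
    // in a single local_tile call to pick out this CTA's tile.
    auto cta_tiler_mn = take<0,2>(CtaTileMNK{});
    Tensor g_cta_D = local_tile(mD_mnl, cta_tiler_mn, make_coord(m_coord, n_coord, l_coord)); // (BLK_M,BLK_N)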

Comment on lines 331 to 334
Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this warp is responsible for
Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
Collaborator

Same here

Collaborator Author

done


// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
@mehdi-goli (Collaborator) Jan 29, 2025

Suggested change
- // To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
+ // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup

Collaborator Author

done

Comment on lines 252 to 253
Tensor tCrA = make_tensor<ElementA>(tCgA(_,_,_,0).shape());
Tensor tCrB = make_tensor<ElementB>(tCgB(_,_,_,0).shape(), make_stride(_1{}, shape<0>(tCgB) * shape<2>(tCgB), shape<0>(tCgB)));
Collaborator

This line too does not seem to match what you are aiming to do.

Collaborator Author

What do you mean by that?

Comment on lines 249 to 257
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)

auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});

// Slice with m_coord and n_coord
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
@mehdi-goli (Collaborator) Jan 29, 2025

Same here; I think it should be possible to say:

Tensor gA = local_tile(mA_mkl, blk_shape, make_coord(m_coord,_,l_coord), Step<_1,  X, _1>{});                                          
Tensor gB = local_tile(mB_nkl, blk_shape, make_coord(n_coord,_,l_coord), Step< X, _1, _1>{});

Collaborator Author

done something similar

@@ -243,22 +235,19 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto blk_shape = TileShape{};
#ifdef CUTLASS_SYCL_SWITCH_WG
Collaborator

If we are not using CUTLASS_SYCL_SWITCH_WG anymore, could you remove the definition in the CMakeLists?

Collaborator Author

done

@FMarno (Collaborator) left a comment

LGTM

using XE_Copy_O = decltype(make_xe_2d_copy(Copy_Atom<Copy_Traits<CopyOpO, StrideO>, ElementO>{}.with(
make_tensor(make_gmem_ptr(static_cast<ElementO const*>(nullptr)), make_layout(make_shape(0, 0, 0), StrideO{}))),
Layout<Shape<_1, Int<SubgroupSize>>>{}));
using XE_Copy_O = decltype(make_tiled_copy(Copy_Atom<Trait_O, ElementO>{}
Collaborator

Please don't do reformatting and code changes in the same PR. It makes review unnecessarily hard.

Collaborator Author

This is not reformatting; it is a functional change: make_xe_2d_copy -> make_tiled_copy.

TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
TiledMMA<MMAAtom,
Layout<Shape<_8,_4,_1>, Stride<_4,_1,_0>>,
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Collaborator

might be easier to read with make_ordered_layout

Collaborator Author

With make_ordered_layout we would also need a decltype. I am a bit on the fence, but I think I prefer it to be explicit. Maybe we can add a helper in the future that does not need decltype, but not in this PR.
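For comparison, a rough sketch of the two spellings; the Step<_1,_0,_2> order argument is an assumption intended to make mode 1 the fastest-varying, and make_ordered_layout would give the size-1 mode a compact stride rather than the explicit _0:

    // Explicit form, as in this PR:
    using SgLayout        = Layout<Shape<_8,_4,_1>, Stride<_4,_1,_0>>;
    // make_ordered_layout form; it is a function call, so decltype is needed in a type alias:
    using SgLayoutOrdered = decltype(make_ordered_layout(Shape<_8,_4,_1>{}, Step<_1,_0,_2>{}));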


7 participants