Coord refactor #186
base: sycl-develop
Conversation
Nice work @t4c1 - a few small things I spotted.
auto
get_pvc_tensor(GShape const& g_shape) const {
  static_assert(rank(GShape{}) == 3, "mismatch rank");
  return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
get_tma_tensor uses g_stride_ for the second argument to make_layout here. Is there any loss of generality with this simpler approach?
Check whether this stride works correctly for column-major.
I am not sure, but I think that is used to encode col/row-major information. For PVC, that is instead encoded in the copy atom itself.
include/cute/atom/copy_traits_xe.hpp
Outdated
constexpr int dtype_size = sizeof(dtype);
constexpr int bits_in_byte = 8;
CUTLASS provides cutlass::sizeof_bits<dtype> for this.
done
include/cute/atom/copy_traits_xe.hpp
Outdated
static_assert(is_rmem<TS>::value);
static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::SrcLayout{}),
              "Src tensor size does not match copy atom size");
static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::DstLayout{}),
As above, use cutlass::sizeof_bits<dtype>, I think.
done
@@ -137,12 +137,31 @@ struct CollectiveMma<
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
I think the changes from this file need to be copied over to xe_mma_mixed_input.hpp. I am getting a local failure of ninja test_unit_gemm_device.
done
using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(0,0,0), StrideA{})); // (m, k)
using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(0,0,0), StrideB{})); // (n, k)
Unused
It is used in the universal GEMM.
// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
Can we have a TODO(Codeplay): here to fix this later?
done
Tensor tArA = thr_copy_A2.retile_D(tCrA);
Tensor tBrB = thr_copy_B2.retile_D(tCrB);

// Retile global tile for copies
Tensor tAgA = thr_copy_A2.retile_S(tCgA);
Tensor tBgB = thr_copy_B2.retile_S(tCgB);
retile_D and retile_S do the same thing, by the way. Not sure if that affects what's going on here, but I don't think I've seen both used anywhere before.
I am not sure they always do the same thing. In any case, one is intended for the source and one for the destination, and that is how I use them here.
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this CTA is responsible for
Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
I am wondering whether it would be possible to avoid this and have something like:
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(m_coord,n_coord,l_coord), Step<_1,_1, X>{});
It is possible if I manually construct the tile for the MN dimensions. I can only use the last argument (the projection) if the previous two have the same number of modes.
Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this warp is responsible for
Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
Same here
done
// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
Suggested change:
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
done
Tensor tCrA = make_tensor<ElementA>(tCgA(_,_,_,0).shape());
Tensor tCrB = make_tensor<ElementB>(tCgB(_,_,_,0).shape(), make_stride(_1{}, shape<0>(tCgB) * shape<2>(tCgB), shape<0>(tCgB)));
This line, too, does not seem to match what you are aiming to do.
What do you mean by that?
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)

auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});

// Slice with m_coord and n_coord
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
Same here; I think it should be possible to say:
Tensor gA = local_tile(mA_mkl, blk_shape, make_coord(m_coord,_,l_coord), Step<_1, X, _1>{});
Tensor gB = local_tile(mB_nkl, blk_shape, make_coord(n_coord,_,l_coord), Step< X, _1, _1>{});
done something similar
@@ -243,22 +235,19 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto blk_shape = TileShape{};
#ifdef CUTLASS_SYCL_SWITCH_WG
If we are not using CUTLASS_SYCL_SWITCH_WG anymore, could you remove the definition in the CMakeLists?
done
LGTM
using XE_Copy_O = decltype(make_xe_2d_copy(Copy_Atom<Copy_Traits<CopyOpO, StrideO>, ElementO>{}.with(
    make_tensor(make_gmem_ptr(static_cast<ElementO const*>(nullptr)), make_layout(make_shape(0, 0, 0), StrideO{}))),
    Layout<Shape<_1, Int<SubgroupSize>>>{}));
using XE_Copy_O = decltype(make_tiled_copy(Copy_Atom<Trait_O, ElementO>{}
Please don't mix reformatting and code changes in the same PR; it makes review unnecessarily hard.
This is not reformatting; it is a functional change: make_xe_2d_copy -> make_tiled_copy.
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
TiledMMA<MMAAtom,
         Layout<Shape<_8,_4,_1>, Stride<_4,_1,_0>>,
         Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Might be easier to read with make_ordered_layout.
With make_ordered_layout we would also need decltype. I am a bit on the fence, but I think I prefer it to be explicit. Maybe we can make a helper in the future that does not need decltype, but not in this PR.
…into coord_refactor
Refactor coordinates for PVC copies to be consistent with how copies for all CUDA GPUs are called.