Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move GEMM tests to Cutlass #220

Open
wants to merge 15 commits into
base: sycl-develop
Choose a base branch
from
10 changes: 5 additions & 5 deletions test/unit/cute/intel_xe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ cutlass_test_unit_add_executable(
copy_block.cpp
copy_scatter.cpp
mma.cpp
gemm_partition_src_dst.cpp
gemm_partition_fragment_abc.cpp
gemm_tiled_copy_abc.cpp
gemm_layout.cpp
gemm_data_type.cpp
# gemm_partition_src_dst.cpp
# gemm_partition_fragment_abc.cpp
# gemm_tiled_copy_abc.cpp
# gemm_layout.cpp
# gemm_data_type.cpp
)
else()
cutlass_test_unit_add_executable(
Expand Down
7 changes: 7 additions & 0 deletions test/unit/gemm/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@

if(CUTLASS_ENABLE_SYCL)
if(SYCL_INTEL_TARGET)
# Build target for the plain (no epilogue fusion) bf16 x bf16 -> fp32
# tensor-op GEMM device tests on Intel Xe.
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_xe
xe_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
)

cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_xe
xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp
Expand All @@ -41,13 +46,15 @@ if(CUTLASS_ENABLE_SYCL)
# Aggregate build target: building this builds every Xe device-level GEMM
# unit-test executable listed below.
add_custom_target(
cutlass_test_unit_gemm_device
DEPENDS
cutlass_test_unit_gemm_device_tensorop_xe
cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_xe
cutlass_test_unit_gemm_device_mixed_input_tensorop_xe
)

# Aggregate run target: invoking this runs every Xe device-level GEMM
# unit-test suite listed below.
#
# NOTE: cutlass_test_unit_add_executable() generates a build target
# (cutlass_test_unit_*) and a matching run target (test_unit_*). The
# original diff added the *build*-target name here
# (cutlass_test_unit_gemm_device_tensorop_xe), which is inconsistent with
# the sibling entries and would build the tests without running them —
# fixed to depend on the run target.
add_custom_target(
test_unit_gemm_device
DEPENDS
test_unit_gemm_device_tensorop_xe
test_unit_gemm_device_tensorop_epilogue_fusion_xe
test_unit_gemm_device_mixed_input_tensorop_xe
)
Expand Down
99 changes: 99 additions & 0 deletions test/unit/gemm/device/default_gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,105 @@ struct DefaultGemmConfigurationToCutlass3Types<

///////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ENABLE_SYCL)
namespace detail {

// Trait classes selecting the global-memory tiled-copy atom for each GEMM
// operand on Intel Xe, keyed on element type, layout, alignment, and K size.
// Only the bfloat16 specializations below are defined; other combinations
// are intentionally left incomplete so misuse fails at compile time.
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpXe_OperandA;

template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpXe_OperandB;

//
// Bfloat16
//

/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpXe_OperandA<bfloat16_t, layout::RowMajor, 32, 32>
{
using GmemTiledCopy = XE_2D_U16x32x32_LD_N;
};

/// Operand A - Column-major (M-major)
// NOTE(review): this reuses the same LD_N copy atom as the row-major case;
// if a transposing/VNNI load is required for column-major A, this may need
// a different atom — confirm against the Xe copy-atom definitions.
template <int SizeK>
struct DefaultGemm_TensorOpXe_OperandA<bfloat16_t, layout::ColumnMajor, 32, SizeK>
{
// Gmem
using GmemTiledCopy = XE_2D_U16x32x32_LD_N;
};

/// Operand B - Row-major (N-Major)
template <>
struct DefaultGemm_TensorOpXe_OperandB<bfloat16_t, layout::RowMajor, 32, 32>
{
using GmemTiledCopy = XE_2D_U16x32x32_LD_V;
};

/// Operand B - Column-major (K-major)
// NOTE(review): same LD_V atom as the row-major case — verify this is
// intended for both B layouts.
template <int SizeK>
struct DefaultGemm_TensorOpXe_OperandB<bfloat16_t, layout::ColumnMajor, 32, SizeK>
{
// Gmem
using GmemTiledCopy = XE_2D_U16x32x32_LD_V;
};

} // namespace detail

///////////////////////////////////////////////////////////////////////////////

// Intel PVC (Xe) MMA: bf16 x bf16 -> f32 tensor-op GEMM configuration.
// (The original comment said "Ampere MMA F32F16" — a copy-paste left-over
// from the CUDA section above; this specialization targets arch::IntelPVC.)
template <typename LayoutA, typename LayoutB, typename LayoutC>
struct DefaultGemmConfigurationToCutlass3Types<
arch::OpClassTensorOp, arch::IntelPVC,
bfloat16_t, LayoutA,
bfloat16_t, LayoutB,
float, LayoutC,
float>
{
// Workgroup-level tile: 256 x 256 output tile, K-blocked by 32.
using TileShape = Shape<_256, _256, _32>;

// PVC mainloop dispatch policy; the template argument (3) is presumably
// the software pipeline stage count — TODO confirm against MainloopIntelPVC.
using DispatchPolicy = MainloopIntelPVC<3>;
// 8x16x16 bf16 MMA atom tiled over an 8x4 arrangement of subgroups with
// permutation tiles for the M and N modes.
using TiledMma =
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;

// A: pick the gmem copy atom from the layout-specialized trait above.
static constexpr int kAlignmentA = 32;
using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA<
bfloat16_t, LayoutA, kAlignmentA, 32>;
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;

// B: likewise, selected by LayoutB.
static constexpr int kAlignmentB = 32;
using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB<
bfloat16_t, LayoutB, kAlignmentB, 32>;
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;

// Mainloop: no shared-memory layouts or copy atoms (void) — the 2D block
// loads read global memory directly.
using CollectiveMainloop = collective::CollectiveMma<
DispatchPolicy, TileShape,
bfloat16_t, TagToStrideA_t<LayoutA>,
bfloat16_t, TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA, void, void, cute::identity, // A
GmemTiledCopyB, void, void, cute::identity // B
>;

// Epilogue: D = alpha * acc + beta * C in fp32, default scheduling.
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
float,
TagToStrideC_t<LayoutC>,
TagToStrideC_t<LayoutC>,
epilogue::thread::LinearCombination<float, 1, float, float>,
cutlass::gemm::EpilogueDefault>;
};

#endif
///////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass
64 changes: 64 additions & 0 deletions test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

/*! \file
\brief Tests for Xe bf16t_bf16t_f32
*/

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "default_gemm_configuration.hpp"

#include "gemm_testbed_3x.hpp"

using namespace cute;

// Device-level test: bf16(row) x bf16(row) -> f32(row) GEMM on Intel Xe
// using the 256x256x32 tile configuration defined in
// default_gemm_configuration.hpp; correctness is checked by the shared
// TestXe<> harness from gemm_testbed_3x.hpp.
TEST(XE_Device_Gemm_bf16t_bf16t_f32t_tensor_op_f32, 256x256x32) {
// Pull the PVC bf16/bf16/f32 collective mainloop + epilogue configuration.
using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types<
cutlass::arch::OpClassTensorOp, cutlass::arch::IntelPVC,
bfloat16_t, cutlass::layout::RowMajor,
bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float>;

// Assemble the universal kernel with a runtime (M, N, K, L) problem shape.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
Config::CollectiveMainloop,
Config::CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// TestXe sweeps problem sizes and compares against a reference GEMM.
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
}