[CPU] FullyConnected acceleration with 8bit weights decompression on SPR
dmitry-gorokhov committed Aug 10, 2023
1 parent f5f221a commit 64bdaa8
Showing 6 changed files with 124 additions and 36 deletions.
29 changes: 19 additions & 10 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.cpp
@@ -251,32 +251,41 @@ void DnnlPostOpsComposer::appendClip(const std::vector<float>& low, const std::v
}
}

void DnnlPostOpsComposer::appendDecompressionScales(const std::vector<float>& scales) {
MemoryPtr DnnlPostOpsComposer::prepackDecompressionParams(const std::vector<float>& params, size_t icBlock) {
// Prepack params from the [oc] layout to [oc, icBlock]: each per-OC parameter is duplicated icBlock times
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({icBlock * params.size()}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
size_t dstIdx = 0;
auto decomp_scales_buf = static_cast<float*>(mem->getData());
for (int oc = 0; oc < params.size(); oc++) {
for (int intIdx = 0; intIdx < icBlock; intIdx++) {
decomp_scales_buf[dstIdx] = params[oc];
dstIdx++;
}
}
return mem;
}

void DnnlPostOpsComposer::appendDecompressionScales(const std::vector<float>& scales, size_t icBlock) {
if (scales.empty())
return;

int mask = scales.size() > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);

DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({scales.size()}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
memcpy(mem->getData(), scales.data(), scales.size() * sizeof(float));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = mem;
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(scales, icBlock);
}

void DnnlPostOpsComposer::appendDecompressionZeroPoints(const std::vector<float>& zero_points) {
void DnnlPostOpsComposer::appendDecompressionZeroPoints(const std::vector<float>& zero_points, size_t icBlock) {
if (zero_points.empty())
return;

int mask = zero_points.size() > 1 ? weightScaleMaskPerChannel : 0;
DEBUG_LOG("Set weights zero points mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", mask);
attr.set_zero_points_mask(DNNL_ARG_WEIGHTS, mask);

DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({zero_points.size()}));
auto mem = std::make_shared<Memory>(engine, memoryDesc);
memcpy(mem->getData(), zero_points.data(), zero_points.size() * sizeof(float));
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = mem;
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = prepackDecompressionParams(zero_points, icBlock);
}

} // namespace intel_cpu
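For illustration, a minimal sketch of what the new prepacking step produces, assuming a hypothetical three-channel scale vector and icBlock = 2 (the names below are illustrative, not part of the plugin):

```cpp
#include <cstddef>
#include <vector>

// Expand per-OC decompression params from [oc] to [oc * icBlock],
// duplicating each value icBlock times (same idea as prepackDecompressionParams).
std::vector<float> prepack(const std::vector<float>& params, size_t icBlock) {
    std::vector<float> out;
    out.reserve(params.size() * icBlock);
    for (float p : params)
        out.insert(out.end(), icBlock, p);
    return out;
}

// Example: prepack({0.5f, 1.0f, 2.0f}, 2) -> {0.5f, 0.5f, 1.0f, 1.0f, 2.0f, 2.0f}
```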
5 changes: 3 additions & 2 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.h
@@ -42,8 +42,8 @@ class DnnlPostOpsComposer {
bool appendLinear(const std::vector<float>& scale, const std::vector<float>& shift, bool isLastPostOp, bool allowBinary = true);
void appendClip(const std::vector<float>& low, const std::vector<float>& high);

void appendDecompressionScales(const std::vector<float>& scales);
void appendDecompressionZeroPoints(const std::vector<float>& zero_points);
void appendDecompressionScales(const std::vector<float>& scales, size_t icBlock);
void appendDecompressionZeroPoints(const std::vector<float>& zero_points, size_t icBlock);

const VectorDims& getOutputDims() {
return outputDims;
@@ -69,6 +69,7 @@ class DnnlPostOpsComposer {

void updateWeiScales();
void updateDestScales();
MemoryPtr prepackDecompressionParams(const std::vector<float>& params, size_t icBlock);
};

} // namespace intel_cpu
25 changes: 21 additions & 4 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -287,6 +287,7 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {

void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
const std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8};
const std::set<InferenceEngine::Precision> supportedDataPrecisions{InferenceEngine::Precision::FP32, InferenceEngine::Precision::BF16};
auto expectedNode = [](NodePtr node, Type expectedType) {
return node->getType() == expectedType && node->getChildEdges().size() == 1;
};
@@ -334,14 +335,14 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
continue;

// Precision limitations
if (fcNode->getOriginalInputPrecisionAtPort(0) != Precision::FP32)
continue;
if (multiplyConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue;
if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end())
continue;
if (withSubtract && subtractConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue;
if (supportedDataPrecisions.find(fcNode->getOriginalInputPrecisionAtPort(0)) == supportedDataPrecisions.end())
continue;
if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end())
continue;

// Shape limitations
const auto weightsShape = weightsNode->getOutputShapeAtPort(0);
@@ -356,6 +357,22 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
if (withSubtract && subtractConstNode->getOutputShapeAtPort(0).getDims() != expectedDims)
continue;

// HW specific shape limitations
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx)) {
// The oneDNN AMX IP implementation has limited shape support due to performance considerations. As a current solution,
// the conditions below are copied from oneDNN to make sure the correct IP implementation is used, since the fallback one
// doesn't support the weights decompression feature.
size_t OC = withTranspose ? weightsShape.getDims()[1] : weightsShape.getDims()[0];
size_t IC = withTranspose ? weightsShape.getDims()[0] : weightsShape.getDims()[1];
size_t simdWidth = 16;
size_t vnniFactor = 2;
size_t maxSize = 512;
auto amxRow = vnniFactor * simdWidth;

if ((IC <= amxRow && OC <= amxRow) || (IC <= maxSize && OC <= maxSize && IC % amxRow != 0))
continue;
}

// Fusion processing
fcNode->fuseDecompressionMultiply(multiplyConstNode);
if (withSubtract)
fcNode->fuseDecompressionSubtract(subtractConstNode);
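To make the hardware-specific skip condition above concrete, here is a small self-contained sketch that evaluates it for a few hypothetical (IC, OC) pairs; the constants mirror the ones used in the fusion pass (simdWidth = 16, vnniFactor = 2, maxSize = 512, hence amxRow = 32):

```cpp
#include <cstddef>
#include <iostream>

// Returns true when the FC weights-decompression fusion should be skipped on AMX,
// following the shape conditions from the fusion pass above.
bool skipAmxFusion(size_t IC, size_t OC) {
    const size_t simdWidth = 16;
    const size_t vnniFactor = 2;
    const size_t maxSize = 512;
    const size_t amxRow = vnniFactor * simdWidth;  // 32

    return (IC <= amxRow && OC <= amxRow) ||
           (IC <= maxSize && OC <= maxSize && IC % amxRow != 0);
}

int main() {
    std::cout << skipAmxFusion(16, 32) << "\n";    // 1: both dims fit into a single AMX row
    std::cout << skipAmxFusion(496, 240) << "\n";  // 1: IC is not a multiple of 32 and both dims <= 512
    std::cout << skipAmxFusion(480, 256) << "\n";  // 0: large enough and IC % 32 == 0, so fusion proceeds
}
```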
15 changes: 12 additions & 3 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -204,7 +204,8 @@ void FullyConnected::getSupportedDescriptors() {

useSparseWeights = useSparseWeightsDecompression();
useWeightsDecompressionImpl = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
inputDataType == memory::data_type::f32 && weightsDataType == memory::data_type::u8;
one_of(inputDataType, memory::data_type::f32, memory::data_type::bf16) &&
weightsDataType == memory::data_type::u8;

// revert back outputDataType on special cases
if (inputDataType == memory::data_type::f32) {
@@ -594,10 +595,18 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, canBeExecutedInInt8(),
1 << 0, getDQScales(), withBiases);

NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
if (selected_pd == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
// The oneDNN API doesn't provide the ability to query the optimal layout for runtime attributes.
// As a workaround, we assume that all AMX IP implementations use the same internal IC block size for the weights layout
// and prepack the runtime attributes accordingly for better performance.
bool withAMX = selected_pd->getImplementationType() & impl_desc_type::amx;
int icBlock = withAMX ? 2 : 1;
if (!decompressionMultiply.empty())
dnnlpoc.appendDecompressionScales(decompressionMultiply);
dnnlpoc.appendDecompressionScales(decompressionMultiply, icBlock);
if (!decompressionSubtract.empty())
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtract);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtract, icBlock);

for (size_t i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];
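As context for the icBlock handling above, a minimal reference sketch of the dequantization that the fused kernel performs on the fly, assuming per-output-channel scales and optional zero points; the function and parameter names are illustrative, not taken from the plugin:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference (scalar) view of 8-bit weights decompression:
// each u8 weight is shifted by a per-OC zero point and scaled by a per-OC scale.
std::vector<float> decompressWeights(const std::vector<uint8_t>& w_u8,      // [OC * IC]
                                     const std::vector<float>& scales,      // [OC]
                                     const std::vector<float>& zero_points, // [OC] or empty
                                     size_t OC, size_t IC) {
    std::vector<float> w_f32(OC * IC);
    for (size_t oc = 0; oc < OC; ++oc) {
        const float zp = zero_points.empty() ? 0.0f : zero_points[oc];
        for (size_t ic = 0; ic < IC; ++ic)
            w_f32[oc * IC + ic] = (static_cast<float>(w_u8[oc * IC + ic]) - zp) * scales[oc];
    }
    return w_f32;
}
```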
@@ -37,7 +37,8 @@ using MatmulWeightsDecompressionParams = std::tuple<std::vector<InputShape>, //
bool, // decompression subtract
bool, // reshape on decompression constants
std::map<std::string, std::string>, // additional config
fusingSpecificParams>;
fusingSpecificParams,
bool>; // should use decompression implementation

class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeightsDecompressionParams>,
virtual public SubgraphBaseTest,
@@ -51,14 +52,16 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
bool reshape_on_decompression;
std::map<std::string, std::string> additional_config;
fusingSpecificParams fusing_params;
bool should_fuse;

std::tie(inputShapes,
weights_precision,
transpose,
decompression_sub,
reshape_on_decompression,
additional_config,
fusing_params) = obj.param;
fusing_params,
should_fuse) = obj.param;

std::ostringstream result;
for (const auto& shape : inputShapes) {
@@ -158,22 +161,22 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
bool reshape_on_decompression;
std::map<std::string, std::string> additional_config;
fusingSpecificParams fusing_params;
bool should_fuse;

std::tie(inputShapes,
weights_precision,
transpose_weights,
decompression_sub,
reshape_on_decompression,
additional_config,
fusing_params) = GetParam();
fusing_params,
should_fuse) = GetParam();

configuration.insert(additional_config.begin(), additional_config.end());
std::tie(postOpMgrPtr, fusedOps) = fusing_params;
init_input_shapes(inputShapes);

ElementType netType = element::f32;
if (additional_config[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES)
netType = ElementType::bf16;
inType = outType = netType;

function = initSubgraph(inputDynamicShapes, netType, weights_precision, transpose_weights, decompression_sub, reshape_on_decompression);
@@ -182,16 +185,15 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
void checkResults() {
const auto& test_param = GetParam();
ov::test::ElementType weights_precision = std::get<1>(test_param);
bool should_fuse = std::get<7>(test_param);
for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
if (n->get_friendly_name() == "Compressed_weights") {
ASSERT_EQ(n->get_output_element_type(0), weights_precision);
}
}

std::map<std::string, std::string> additional_config = std::get<5>(test_param);
const size_t expected_count =
InferenceEngine::with_cpu_x86_avx2() &&
additional_config[PluginConfigParams::KEY_ENFORCE_BF16] != PluginConfigParams::YES ? 0 : 1;
const size_t expected_count = should_fuse ? 0 : 1;
CheckNumberOfNodesWithType(compiledModel, "Convert", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Eltwise", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Subgraph", 0);
Expand All @@ -207,24 +209,44 @@ TEST_P(MatmulWeightsDecompression, CompareWithRefs) {
namespace {

std::vector<std::map<std::string, std::string>> filterAdditionalConfig() {
std::vector<std::map<std::string, std::string>> additional_config{CPUTestUtils::cpuEmptyPluginConfig};
std::vector<std::map<std::string, std::string>> additional_config;//{CPUTestUtils::cpuEmptyPluginConfig};
if (with_cpu_x86_avx512_core())
additional_config.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
return additional_config;
}

bool shouldUseDecompressionKernel() {
// No decompression support on systems without AVX2
if (!with_cpu_x86_avx2())
return false;

return true;
}

bool shouldUseDecompressionKernelAMX() {
// AMX decompression support has shape limitations
if (with_cpu_x86_avx512_core_amx())
return false;

return shouldUseDecompressionKernel();
}

const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8};
const std::vector<std::vector<InputShape>> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {{}, {{16, 32}}}},
{{{}, {{1, 4, 16}}}, {{}, {{1, 16, 32}}}},
{{{}, {{10, 40, 496}}}, {{}, {{1, 496, 240}}}},
{{{}, {{1, 4, 32}}}, {{}, {{32, 256}}}},
{{{}, {{1, 4, 48}}}, {{}, {{48, 256}}}},
{{{}, {{11, 339, 377}}}, {{}, {{377, 335}}}},
};
const std::vector<std::vector<InputShape>> input_shapes_big = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {{}, {{1, 480, 256}}}},
{{{}, {{1, 4, 32}}}, {{}, {{32, 256}}}},
{{{}, {{1, 4, 512}}}, {{}, {{512, 256}}}},
{{{}, {{1, 16, 32}}}, {{}, {{32, 64}}}},
{{{}, {{2, 4, 32}}}, {{}, {{32, 65}}}},
{{{}, {{11, 339, 377}}}, {{}, {{377, 335}}}},
{{{}, {{3, 12, 768}}}, {{}, {{768, 1024}}}},
{{{}, {{11, 339, 577}}}, {{}, {{577, 335}}}},
};
const std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec,
@@ -239,26 +261,56 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
::testing::Values(true),
::testing::Values(true),
::testing::ValuesIn(filterAdditionalConfig()),
::testing::ValuesIn(fusingParamsSet)),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(shouldUseDecompressionKernelAMX())),
MatmulWeightsDecompression::getTestCaseName);

const std::vector<std::vector<InputShape>> input_shapes_corner_cases = {
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_big,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_big),
::testing::ValuesIn(weights_precisions),
::testing::Values(true),
::testing::Values(true),
::testing::Values(true),
::testing::ValuesIn(filterAdditionalConfig()),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(shouldUseDecompressionKernel())),
MatmulWeightsDecompression::getTestCaseName);

const std::vector<std::vector<InputShape>> input_shapes_corner_cases_basic = {
{{{-1, -1, -1}, {{1, 4, 16}}}, {{}, {{1, 16, 32}}}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {{}, {{16, 32}}}},
};
const std::vector<std::vector<InputShape>> input_shapes_corner_cases_big = {
{{{-1, -1, -1}, {{10, 40, 480}, {11, 40, 480}}}, {{}, {{1, 480, 256}}}},
};

const std::vector<bool> transpose_weights = {true, false};
const std::vector<bool> add_decompression_sub = {true, false};
const std::vector<bool> reshape_on_decompression = {true, false};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases,
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_basic,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_basic),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(transpose_weights),
::testing::ValuesIn(add_decompression_sub),
::testing::ValuesIn(reshape_on_decompression),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig),
::testing::Values(emptyFusingSpec),
::testing::Values(shouldUseDecompressionKernelAMX())),
MatmulWeightsDecompression::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_big,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases),
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases_big),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(transpose_weights),
::testing::ValuesIn(add_decompression_sub),
::testing::ValuesIn(reshape_on_decompression),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig),
::testing::Values(emptyFusingSpec)),
::testing::Values(emptyFusingSpec),
::testing::Values(shouldUseDecompressionKernel())),
MatmulWeightsDecompression::getTestCaseName);
} // namespace
