Skip to content

Commit

Permalink
[CPU] optimize PagedAttention's shape inference (openvinotoolkit#23603)
Browse files Browse the repository at this point in the history
### Details:
 - *Specific shape inference for PagedAttention*
 - *...*

### Tickets:
 - *ticket-id*
  • Loading branch information
luo-cheng2021 authored and bbielawx committed Apr 12, 2024
1 parent 6ae7a10 commit c4af202
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,26 @@ class SDPAShapeInfer : public ShapeInferEmptyPads {
ScaledDotProductAttentionWithKVCache::Config m_config;
};

// Shape inference for the PagedAttentionExtension op.
//
// PagedAttention's output shape is identical to the shape of its first
// ("query") input, so inference is a straight pass-through of that shape.
// No input *data* is inspected (only shapes), hence the empty port mask.
// NOTE(review): assumes output 0 is the only output this inference needs to
// produce — matches the single-element result returned here; confirm against
// the op's output count if it ever grows.
class PAShapeInfer final : public ShapeInferEmptyPads {
public:
    PAShapeInfer() = default;  // stateless; defaulted per modernize-use-equals-default

    /// @param input_shapes  per-port input shapes; front() is the query shape.
    /// @param data_dependency unused — this op's output shape never depends on input values.
    /// @return the query shape as the sole output shape, with success status.
    IShapeInfer::Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
                              const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
        const auto& query_dims = input_shapes.front().get();

        return {{query_dims}, ShapeInferStatus::success};
    }

    // No port carries data the shape inference must read.
    port_mask_t get_port_mask() const override {
        return EMPTY_PORT_MASK;
    }
};

ShapeInferPtr SDPAShapeInferFactory::makeShapeInfer() const {
if (m_op->get_type_name() == std::string("PagedAttentionExtension")) {
return std::make_shared<PAShapeInfer>();
}
if (auto sdpa = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(m_op)) {
const auto& config = sdpa->get_config();
if (config.output_BLHxS == false)
Expand Down

0 comments on commit c4af202

Please sign in to comment.