// Ascend (昇腾) paged_attention operator implementation.
/*
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "paged_attention_operation.h"
#include <map>
#include "paged_attention_ops_runner.h"
#include "paged_attention_ops_runner_910a.h"
#include "atb/utils/tensor_check.h"
#include "atb/utils/tensor_util.h"
#include "atb/utils/operation_util.h"
#include "atb/utils/config.h"
#include "atb/utils/param_to_json.h"
#include "atb/core/atb_operation_ir_cfg.h"
#include "atb/utils/singleton.h"
#include "atb/core/op_param_funcs.h"
namespace {
// Mandatory tensor counts: 5 inputs (query, keyCache, valueCache, blockTables,
// contextLens) and 1 output; optional inputs are added on top of IN_TENSOR_NUM
// in PagedAttentionOperation::GetInputNum.
static const uint32_t IN_TENSOR_NUM = 5;
static const uint32_t OUT_TENSOR_NUM = 1;
// NZ-format tensors require the last dimension to be 16-aligned.
static const uint32_t DIM_ALIGN_16_NZ = 16;
// Upper bound on the batch size (contextLens length) in MLA mode.
static const uint32_t MAX_BATCH_SIZE_8192 = 8192;
// Bit flags packed into a single "case code" that selects an operation-IR
// table entry (see Bools2IntLogN / Bools2IntQKVQuant). Note that 0x00040 is
// deliberately shared: LOGN_BIT belongs to the logN encoding and
// QKVQUANTOFFLINE_BIT to the QKV-quant encoding; the two encodings are used
// with separate lookup tables and are never mixed in one code.
static const int QUANTOFFSET_BIT = 0x00001;
static const int USEQUANT_BIT = 0x00002;
static const int BATCHSTATUS_BIT = 0x00004;
static const int MASK_BIT = 0x00008;
static const int QLENS_BIT = 0x00010;
static const int RAZOROFFSET_BIT = 0x00020;
static const int LOGN_BIT = 0x00040;
static const int QKVQUANTOFFLINE_BIT = 0x00040;
static const int QKVQUANTONLINE_BIT = 0x00080;
} // namespace
namespace atb {
// Forward declarations of the parameter validators run by CreateOperation.
// Each validates one feature subset of infer::PagedAttentionParam and returns
// true when the configuration is acceptable (logging the reason otherwise).
static bool DeviceParamCheck(const infer::PagedAttentionParam &opParam);
static bool CompressParamCheck(const infer::PagedAttentionParam &opParam);
static bool CalcParamCheck(const infer::PagedAttentionParam &opParam);
static bool QuantParamCheck(const infer::PagedAttentionParam &opParam);
static bool LogNParamCheck(const infer::PagedAttentionParam &opParam);
static bool BNSDParamCheck(const infer::PagedAttentionParam &opParam);
static bool MlaParamCheck(const infer::PagedAttentionParam &opParam);
/// @brief Factory specialization for the paged attention operator.
/// Validates the parameter struct and, on success, allocates a
/// PagedAttentionOperation into *operation.
/// @return NO_ERROR on success; ERROR_INVALID_PARAM on any validation
///         failure; ERROR_OUT_OF_HOST_MEMORY when allocation fails.
template <> Status CreateOperation(const infer::PagedAttentionParam &opParam, Operation **operation)
{
    if (operation == nullptr) {
        return ERROR_INVALID_PARAM;
    }
    OP_PARAM_RSV_CHECK(opParam);
    if (opParam.headNum <= 0) {
        ATB_LOG(ERROR) << "headNum should be greater than zero!";
        return ERROR_INVALID_PARAM;
    }
    if (opParam.kvHeadNum < 0) {
        ATB_LOG(ERROR) << "kvHeadNum should be no less than zero!";
        return ERROR_INVALID_PARAM;
    }
    // kvHeadNum == 0 means "same as headNum"; otherwise headNum must be an
    // integer multiple of kvHeadNum (GQA grouping).
    if (opParam.kvHeadNum != 0 && opParam.headNum % opParam.kvHeadNum != 0) {
        ATB_LOG(ERROR) << "headNum mod kvHeadNum should be zero";
        return ERROR_INVALID_PARAM;
    }
    // Run every feature-specific validator in order; stop on the first
    // failure (each validator logs its own reason).
    using ParamChecker = bool (*)(const infer::PagedAttentionParam &);
    const ParamChecker checkers[] = {DeviceParamCheck, CompressParamCheck, CalcParamCheck, QuantParamCheck,
                                     LogNParamCheck,   BNSDParamCheck,     MlaParamCheck};
    for (ParamChecker checker : checkers) {
        if (!checker(opParam)) {
            return ERROR_INVALID_PARAM;
        }
    }
    *operation = new (std::nothrow) PagedAttentionOperation(opParam);
    if (*operation == nullptr) {
        ATB_LOG(ERROR) << "failed to new operation";
        return ERROR_OUT_OF_HOST_MEMORY;
    }
    return NO_ERROR;
}
// Rejects parameter combinations not supported by the current SoC.
// Several features are implemented only on Atlas 800I A2 (910B), and the
// Atlas 800 training product (910A) excludes a few more.
bool DeviceParamCheck(const infer::PagedAttentionParam &opParam)
{
    const bool is910B = GetSingleton<Config>().Is910B();
    if (!is910B) {
        if (opParam.batchRunStatusEnable) {
            ATB_LOG(ERROR) << "dynamic batch only support Atlas 800I A2 inference product";
            return false;
        }
        if (opParam.compressType != infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_UNDEFINED) {
            ATB_LOG(ERROR) << "head compress only support Atlas 800I A2 inference product";
            return false;
        }
        if (opParam.mlaVHeadSize > 0) {
            ATB_LOG(ERROR) << "mla mode only support Atlas 800I A2 inference product";
            return false;
        }
        if (opParam.quantType != atb::infer::PagedAttentionParam::QuantType::TYPE_QUANT_UNQUANT) {
            ATB_LOG(ERROR) << "quant feature only support Atlas 800I A2 inference product";
            return false;
        }
    }
    if (!GetSingleton<Config>().Is910A()) {
        return true;
    }
    // Additional restrictions specific to the Atlas 800 training product.
    if (opParam.maskType == atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_SPEC) {
        ATB_LOG(ERROR) << "SPEC_MASK does not support Atlas 800 training product";
        return false;
    }
    if (opParam.calcType != atb::infer::PagedAttentionParam::CalcType::CALC_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "SPEC feature does not support Atlas 800 training product";
        return false;
    }
    if (opParam.scaleType != atb::infer::PagedAttentionParam::ScaleType::SCALE_TYPE_TOR) {
        ATB_LOG(ERROR) << "logN feature does not support Atlas 800 training product";
        return false;
    }
    if (opParam.inputLayout != atb::infer::InputLayout::TYPE_BSND) {
        ATB_LOG(ERROR) << "BNSD feature does not support Atlas 800 training product";
        return false;
    }
    return true;
}
// Validates compression-related parameters.
// Fixes: the first KVHEAD_ROPE error message used the wrong casing
// ("COMPRESS_TYPE_KVHead_ROPE") for the enumerator name, and the second
// message was split across two stream inserts with no separating space,
// producing "...ROPE,batchRunStatusEnable...".
bool CompressParamCheck(const infer::PagedAttentionParam &opParam)
{
    // compressType must be a valid enumerator.
    if (opParam.compressType >= infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_MAX ||
        opParam.compressType < infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "compressType should be in the range of its enum value";
        return false;
    }
    // KVHEAD_ROPE compression is incompatible with masking.
    if (opParam.compressType == infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE &&
        opParam.maskType != infer::PagedAttentionParam::MaskType::UNDEFINED) {
        ATB_LOG(ERROR) << "When compressType is set to COMPRESS_TYPE_KVHEAD_ROPE, maskType must be set to UNDEFINED.";
        return false;
    }
    // KVHEAD_ROPE compression is incompatible with dynamic batch.
    if (opParam.compressType == infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE &&
        opParam.batchRunStatusEnable) {
        ATB_LOG(ERROR) << "When compressType is set to COMPRESS_TYPE_KVHEAD_ROPE, "
                       << "batchRunStatusEnable must be set to false.";
        return false;
    }
    return true;
}
// Validates constraints that apply when a non-default calc type (the
// speculative-decoding path) is enabled.
bool CalcParamCheck(const infer::PagedAttentionParam &opParam)
{
    // Nothing to check for the default calc type.
    if (opParam.calcType == infer::PagedAttentionParam::CalcType::CALC_TYPE_UNDEFINED) {
        return true;
    }
    if (opParam.batchRunStatusEnable) {
        ATB_LOG(ERROR) << "SPEC func does not support dynamic batch";
        return false;
    }
    if (opParam.compressType != infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "SPEC func only support when compressType is COMPRESS_TYPE_UNDEFINED";
        return false;
    }
    // Allowed mask types differ per SoC (and MLA mode lifts the 910B limit).
    const auto maskType = opParam.maskType;
    if (GetSingleton<Config>().Is910B()) {
        const bool maskOk = opParam.mlaVHeadSize != 0 ||
                            maskType == atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_NORM ||
                            maskType == atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_SPEC;
        if (!maskOk) {
            ATB_LOG(ERROR) << "SPEC func only support norm mask and spec mask";
            return false;
        }
    } else {
        const bool maskOk = maskType == atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_SPEC ||
                            maskType == atb::infer::PagedAttentionParam::MaskType::UNDEFINED;
        if (!maskOk) {
            ATB_LOG(ERROR) << "SPEC func only support no mask and spec mask";
            return false;
        }
    }
    return true;
}
// Validates quantization-related parameter combinations.
bool QuantParamCheck(const infer::PagedAttentionParam &opParam)
{
    // Dequant fusion requires the default calc type.
    if (opParam.quantType == infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION &&
        opParam.calcType != infer::PagedAttentionParam::CalcType::CALC_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "Dequant only support when calcType is CALC_TYPE_UNDEFINED";
        return false;
    }
    const bool qkvQuant = opParam.quantType == infer::PagedAttentionParam::TYPE_QUANT_QKV_OFFLINE ||
                          opParam.quantType == infer::PagedAttentionParam::TYPE_QUANT_QKV_ONLINE;
    if (!qkvQuant) {
        return true;
    }
    // QKV quantization: restricted output dtype, no quant offset, no
    // compression, default calc type.
    if (opParam.outDataType != ACL_FLOAT16 && opParam.outDataType != ACL_BF16) {
        ATB_LOG(ERROR) << "outDataType only support ACL_FLOAT16 and ACL_BF16";
        return false;
    }
    if (opParam.hasQuantOffset) {
        ATB_LOG(ERROR) << "QKVQuant only support when hasQuantOffset is False";
        return false;
    }
    if (opParam.compressType != infer::PagedAttentionParam::COMPRESS_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "QKVQuant only support when compressType is COMPRESS_TYPE_UNDEFINED";
        return false;
    }
    if (opParam.calcType != infer::PagedAttentionParam::CalcType::CALC_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "QKVQuant only support when calcType is CALC_TYPE_UNDEFINED";
        return false;
    }
    return true;
}
// Validates scaleType and the constraints the logN scale imposes.
bool LogNParamCheck(const infer::PagedAttentionParam &opParam)
{
    const auto scaleType = opParam.scaleType;
    // scaleType must be a valid enumerator.
    if (scaleType < infer::PagedAttentionParam::ScaleType::SCALE_TYPE_TOR ||
        scaleType >= infer::PagedAttentionParam::ScaleType::SCALE_TYPE_MAX) {
        ATB_LOG(ERROR) << "scaleType should be in the range of its enum value";
        return false;
    }
    if (scaleType != infer::PagedAttentionParam::SCALE_TYPE_LOGN) {
        return true;
    }
    // logN scaling excludes quantization, special calc types and compression.
    if (opParam.quantType != infer::PagedAttentionParam::TYPE_QUANT_UNQUANT) {
        ATB_LOG(ERROR) << "logN func does not support quant feature";
        return false;
    }
    if (opParam.calcType != infer::PagedAttentionParam::CALC_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "logN func does not support calcType feature";
        return false;
    }
    if (opParam.compressType != atb::infer::PagedAttentionParam::COMPRESS_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "logN func does not support compressType feature";
        return false;
    }
    return true;
}
// Validates the constraints imposed by the BNSD input layout.
bool BNSDParamCheck(const infer::PagedAttentionParam &opParam)
{
    // Restrictions only apply when BNSD is actually selected.
    if (opParam.inputLayout != infer::InputLayout::TYPE_BNSD) {
        return true;
    }
    if (opParam.calcType == atb::infer::PagedAttentionParam::CALC_TYPE_SPEC) {
        ATB_LOG(ERROR) << "BNSD feature and calcType feature cannot coexist";
        return false;
    }
    if (opParam.compressType != atb::infer::PagedAttentionParam::COMPRESS_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "BNSD feature and compressType feature cannot coexist";
        return false;
    }
    if (opParam.quantType != atb::infer::PagedAttentionParam::TYPE_QUANT_UNQUANT) {
        ATB_LOG(ERROR) << "BNSD feature and quantType feature cannot coexist";
        return false;
    }
    if (opParam.scaleType != atb::infer::PagedAttentionParam::SCALE_TYPE_TOR) {
        ATB_LOG(ERROR) << "BNSD feature and scaleType feature cannot coexist";
        return false;
    }
    return true;
}
// Validates the constraints of MLA mode (mlaVHeadSize > 0).
bool MlaParamCheck(const infer::PagedAttentionParam &opParam)
{
    // mlaVHeadSize == 0 means MLA mode is off; nothing to validate.
    if (opParam.mlaVHeadSize <= 0) {
        return true;
    }
    if (opParam.maskType == infer::PagedAttentionParam::MaskType::MASK_TYPE_ALIBI) {
        ATB_LOG(ERROR) << "mla mode does not support alibi mask";
        return false;
    }
    // Speculative decoding under MLA excludes norm mask, quantization and
    // dynamic batch.
    const bool specConflict =
        opParam.calcType == infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC &&
        (opParam.maskType == infer::PagedAttentionParam::MaskType::MASK_TYPE_NORM ||
         opParam.quantType != infer::PagedAttentionParam::QuantType::TYPE_QUANT_UNQUANT ||
         opParam.batchRunStatusEnable);
    if (specConflict) {
        ATB_LOG(ERROR) << "spec does not support norm mask, quant and dynamic batch";
        return false;
    }
    if (opParam.compressType != infer::PagedAttentionParam::CompressType::COMPRESS_TYPE_UNDEFINED) {
        ATB_LOG(ERROR) << "mla mode does not support compress mask";
        return false;
    }
    if (opParam.quantType == infer::PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION) {
        ATB_LOG(ERROR) << "mla mode does not support dequant_fusion";
        return false;
    }
    if (opParam.scaleType != infer::PagedAttentionParam::ScaleType::SCALE_TYPE_TOR) {
        ATB_LOG(ERROR) << "mla mode does not support logN scale";
        return false;
    }
    if (opParam.inputLayout != infer::InputLayout::TYPE_BSND) {
        ATB_LOG(ERROR) << "mla mode only support BSND InputLayout";
        return false;
    }
    // MLA mode is MQA-only: the KV cache holds a single head.
    if (opParam.kvHeadNum != 1) {
        ATB_LOG(ERROR) << "kvHeadNum should be 1, mla mode only support MQA";
        return false;
    }
    if (opParam.mlaVHeadSize > 576) { // 576: MLA head-size hardware limit
        ATB_LOG(ERROR) << "mlaVHeadSize should be no greater than 576";
        return false;
    }
    return true;
}
// Constructor: in MLA mode the operation-IR entry is selected by a key built
// from the enabled features; otherwise InitOpIni picks the entry via the
// bit-coded case tables.
PagedAttentionOperation::PagedAttentionOperation(const infer::PagedAttentionParam &param)
    : OperationBase("PagedAttentionOperation"), param_(param)
{
    if (param_.mlaVHeadSize > 0) {
        // Assemble the IR key by appending one token per enabled feature.
        std::string opIrKey = "PagedAttentionOperationMla";
        if (param_.maskType != infer::PagedAttentionParam::MaskType::UNDEFINED) {
            opIrKey += "Mask";
        }
        if (param_.batchRunStatusEnable) {
            opIrKey += "Batch";
        }
        if (param_.quantType == infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_OFFLINE) {
            opIrKey += "QuantOffline";
        } else if (param_.quantType == infer::PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE) {
            opIrKey += "QuantOnline";
        }
        if (param_.calcType == infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC) {
            opIrKey += "Qlens";
        }
        operationIr_ = GetSingleton<AtbOperationIrCfg>().GetOperationIr(opIrKey);
    } else {
        InitOpIni();
    }
    ATB_LOG(INFO) << GetLogPrefix() << "PagedAttentionParam headNum:" << param.headNum << ", qkScale:" << param.qkScale
                  << ", kvHeadNum:" << param.kvHeadNum << ", maskType:" << param.maskType
                  << ", batchRunStatusEnable:" << param.batchRunStatusEnable << ", quantType:" << param.quantType
                  << ", outDataType:" << param.outDataType << ", hasQuantOffset:" << param.hasQuantOffset
                  << ", compressType:" << param.compressType << ", calcType:" << param.calcType
                  << ", scaleType:" << param.scaleType << ", inputLayout:" << param.inputLayout;
}
// Out-of-line defaulted destructor (no resources to release here).
PagedAttentionOperation::~PagedAttentionOperation() = default;
// Returns the number of input tensors: the 5 mandatory inputs plus one slot
// per enabled optional input, minus one in MLA mode (the input list shifts
// down by one position there — see InferShapeDimCheck).
uint32_t PagedAttentionOperation::GetInputNum() const
{
    uint32_t count = IN_TENSOR_NUM;
    if (param_.maskType != atb::infer::PagedAttentionParam::UNDEFINED) {
        count += 1; // mask
    }
    if (param_.batchRunStatusEnable) {
        count += 1; // batchRunStatus
    }
    if (param_.quantType == infer::PagedAttentionParam::TYPE_DEQUANT_FUSION) {
        count += 2; // kDescale and vDescale
        if (param_.hasQuantOffset) {
            count += 2; // kOffset and vOffset
        }
    }
    if (param_.calcType == infer::PagedAttentionParam::CALC_TYPE_SPEC) {
        count += 1; // qSeqLen
    }
    if (param_.compressType == infer::PagedAttentionParam::COMPRESS_TYPE_KVHEAD_ROPE) {
        count += 1; // razorOffset
    }
    const bool qkvQuant = param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_OFFLINE ||
                          param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_ONLINE;
    if (qkvQuant) {
        count += 2; // kDescale and vDescale
        if (param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_OFFLINE) {
            count += 1; // pScale (offline only)
        }
    }
    if (param_.scaleType == infer::PagedAttentionParam::SCALE_TYPE_LOGN) {
        count += 1; // logN
    }
    if (param_.mlaVHeadSize > 0) {
        count -= 1; // MLA mode drops one input slot
    }
    return count;
}
// Paged attention always produces exactly one output tensor (attnOut).
uint32_t PagedAttentionOperation::GetOutputNum() const
{
    return OUT_TENSOR_NUM;
}
// Output shape inference: the output copies the query descriptor; under QKV
// quantization the dtype is overridden by param_.outDataType, and on 910B the
// last dimension is replaced by the value head size.
Status PagedAttentionOperation::InferShapeImpl(const SVector<TensorDesc> &inTensorDescs,
                                               SVector<TensorDesc> &outTensorDescs) const
{
    outTensorDescs.at(0) = inTensorDescs.at(0);
    const bool qkvQuant = param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_OFFLINE ||
                          param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_ONLINE;
    if (qkvQuant) {
        outTensorDescs.at(0).dtype = param_.outDataType;
    }
    if (GetSingleton<Config>().Is910B()) {
        // MLA mode takes the value head size from the param; otherwise it is
        // the last dim of the valueCache tensor (input index 2).
        const TensorDesc &valueDesc = inTensorDescs.at(2); // 2: valueCache
        const int64_t hiddenSize = param_.mlaVHeadSize > 0
                                       ? param_.mlaVHeadSize
                                       : valueDesc.shape.dims[valueDesc.shape.dimNum - 1];
        TensorDesc &outDesc = outTensorDescs.at(0);
        outDesc.shape.dims[outDesc.shape.dimNum - 1] = hiddenSize;
    }
    return NO_ERROR;
}
// Shape-inference validation: BNSD layout on 910B uses a dedicated dim-check
// path; every other configuration goes through the generic (BSND) check.
Status PagedAttentionOperation::InferShapeCheckImpl(const SVector<TensorDesc> &inTensorDescs) const
{
    const bool bnsd910B =
        param_.inputLayout == infer::InputLayout::TYPE_BNSD && GetSingleton<Config>().Is910B();
    return bnsd910B ? InferShapeDimCheckBNSD910B(inTensorDescs) : InferShapeDimCheck(inTensorDescs);
}
// Setup-time validation entry point.
// Fix: the keyCache rank error logged "ErrorCode: ERROR_INVALID_TENSOR_DIM"
// while the function actually returns ERROR_INVALID_TENSOR_DIM_NUM; the log
// now matches the returned status code.
Status PagedAttentionOperation::SetupCheckImpl(const SVector<Tensor> &inTensors,
                                               const SVector<Tensor> &outTensors) const
{
    if (inTensors.at(1).desc.shape.dimNum != 4) { // 1: keyCache, 4: must be 4-dimensional
        ATB_LOG(ERROR) << "ErrorCode: " << ERROR_INVALID_TENSOR_DIM_NUM
                       << ". the keyCache dimNum is:" << inTensors.at(1).desc.shape.dimNum
                       << ". keyCache should be 4 dims";
        return ERROR_INVALID_TENSOR_DIM_NUM;
    }
    // Layout-specific dim checks; other layout/SoC combinations have no
    // additional setup-time checks.
    Status st = NO_ERROR;
    if (param_.inputLayout == infer::InputLayout::TYPE_BSND) {
        st = SetupDimCheck(inTensors, outTensors);
    } else if (param_.inputLayout == infer::InputLayout::TYPE_BNSD && GetSingleton<Config>().Is910B()) {
        st = SetupDimCheckBNSD910B(inTensors, outTensors);
    }
    return st;
}
// Validates keyCache/valueCache shapes on 310P devices where the caches use
// the NZ layout: the last dim must be 16, key/value block sizes must match,
// and the head size derived from the NZ dims must be 16-aligned and within
// the hardware limits.
Status PagedAttentionOperation::KVCacheDimCheck310P(const SVector<TensorDesc> &inTensorDescs) const
{
    if (inTensorDescs.at(1).shape.dims[3] != DIM_ALIGN_16_NZ) { // 1: keyCache 3: last dim
        ATB_LOG(ERROR) << "lastDim of KVCache should be 16";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (inTensorDescs.at(2).shape.dims[3] != DIM_ALIGN_16_NZ) { // 2: valueCache 3: last dim
        ATB_LOG(ERROR) << "lastDim of KVCache should be 16";
        return ERROR_INVALID_TENSOR_DIM;
    }
    int64_t blockSize = inTensorDescs.at(2).shape.dims[2]; // 2: valueCache 2: blockSize
    if (blockSize != inTensorDescs.at(1).shape.dims[2]) { // 1: keyCache 2: blockSize
        ATB_LOG(ERROR) << "blocksize of KVCache should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    int64_t kvHeadNum = param_.kvHeadNum > 0 ? param_.kvHeadNum : param_.headNum;
    // kvHeadNum is checked > 0 in CreateOperation
    // NZ dim 1 holds (kvHeadNum * headSize / 16); recover headSize from it.
    int64_t headSize = inTensorDescs.at(1).shape.dims[1] * DIM_ALIGN_16_NZ / kvHeadNum;
    if (headSize % DIM_ALIGN_16_NZ != 0) {
        ATB_LOG(ERROR) << "head_size should align 16 when format of keycache is NZ";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (headSize > 256 || headSize * blockSize > 128 * 128) { // 256: 310p headSize limit // 128: size limit
        ATB_LOG(ERROR) << "head_size of keyCache should be no greater than 256 and "
                       << "block_size * head_size should be no greater than 128 * 128";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (headSize != inTensorDescs.at(0).shape.dims[2]) { // 0: query 2: headSize
        ATB_LOG(ERROR) << "headSizes of query and keyCache should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    return NO_ERROR;
}
// Returns true when the configured features rule out the MLA kernel path
// (dequant fusion, SPEC/QKV-quant without MLA, logN scale, or compression).
bool PagedAttentionOperation::IsInMLAIncompatible() const
{
    const bool qkvQuant = param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_OFFLINE ||
                          param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_ONLINE;
    const bool noMla = param_.mlaVHeadSize == 0;
    return param_.quantType == infer::PagedAttentionParam::TYPE_DEQUANT_FUSION ||
           (param_.calcType == infer::PagedAttentionParam::CALC_TYPE_SPEC && noMla) ||
           (qkvQuant && noMla) ||
           param_.scaleType == infer::PagedAttentionParam::SCALE_TYPE_LOGN ||
           param_.compressType != infer::PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
}
// MLA mode: the batch size (length of contextLens, input index 3) must not
// exceed the 8192 limit.
bool PagedAttentionOperation::MlaBatchSizeCheck(const SVector<TensorDesc> &inTensorDescs) const
{
    const int64_t maxBatchSize = MAX_BATCH_SIZE_8192;
    const int64_t batchSize = inTensorDescs.at(3).shape.dims[0]; // 3: contextLens
    if (batchSize > maxBatchSize) {
        ATB_LOG(ERROR) << "batchSize should <= " << maxBatchSize;
        return false;
    }
    return true;
}
// Validates keyCache/valueCache head sizes and block sizes on 910B devices,
// with separate limits for the non-MLA path (<=256 head size) and the
// MLA-capable path (<=576 head size plus block-size constraints).
Status PagedAttentionOperation::KVCacheDimCheck910B(const SVector<TensorDesc> &inTensorDescs) const
{
    int64_t headSize = inTensorDescs.at(1).shape.dims[3]; // 1: keyCache 3: headSize
    if (headSize != inTensorDescs.at(0).shape.dims[2]) { // 0: query 2: headSize
        ATB_LOG(ERROR) << "headSize of keyCache and query should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    int64_t blockSize = inTensorDescs.at(1).shape.dims[1]; // 1: keyCache 1: blockSize dim
    if (IsInMLAIncompatible()) { // non-MLA path: strict 256 / 128x128 limits
        if (headSize != inTensorDescs.at(2).shape.dims[3]) { // 2: valueCache dim 3: headSize
            ATB_LOG(ERROR) << "headSize of keyCache and valueCache should be same";
            return ERROR_INVALID_TENSOR_DIM;
        }
        if (headSize > 256 || headSize * blockSize > 128 * 128) { // 256: headSize limit // 128: size limit
            ATB_LOG(ERROR) << "head_size of keyCache should be no greater than 256 and "
                           << "block_size * head_size should be no greater than 128 * 128";
            return ERROR_INVALID_TENSOR_DIM;
        }
    } else {
        if (param_.mlaVHeadSize > headSize) {
            ATB_LOG(ERROR) << "mlaVHeadSize should be no greater than headSizeK";
            return ERROR_INVALID_TENSOR_DIM;
        }
        // In MLA mode the value head size comes from the param; otherwise
        // from the valueCache tensor itself.
        int64_t headSizeV = param_.mlaVHeadSize > 0 ? param_.mlaVHeadSize :
                            inTensorDescs.at(2).shape.dims[3]; // 2: valueCache 3: headSize
        if (headSize > 576 || headSizeV > 576) { // 576: 910b headSize limit
            ATB_LOG(ERROR) << "head_size of keyCache and ValueCache should be no greater than 576";
            return ERROR_INVALID_TENSOR_DIM;
        }
        if (param_.mlaVHeadSize > 0 && !MlaBatchSizeCheck(inTensorDescs)) {
            return ERROR_INVALID_TENSOR_DIM;
        }
        // Special MLA configuration that is allowed to use blockSize 256
        // (bypasses the headSize/blockSize limit below).
        bool blockSize256Check =
            param_.mlaVHeadSize > 0 && blockSize == 256 && param_.kvHeadNum == 1 && // 256: blockSize
            (param_.headNum == 16 || param_.headNum == 32) && headSize == 576 && // 16 32: headNum 576: headSize
            headSizeV == 512 && // 512: headSizeV
            param_.calcType != infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC;
        if (blockSize256Check) {
            return NO_ERROR;
        }
        if (((headSize > 256 || headSizeV > 256) && blockSize > 128)) { // 128: mla blockSize limit, 256: headSize threshold
            ATB_LOG(ERROR) << "blockSize should be no greater than 128 if headSize > 256";
            return ERROR_INVALID_TENSOR_DIM;
        }
    }
    return NO_ERROR;
}
// Checks that keyCache and valueCache agree on numBlocks / blockSize /
// headNum (when a separate valueCache exists, i.e. non-MLA), then delegates
// to the SoC-specific layout check.
Status PagedAttentionOperation::KVCacheDimCheck(const SVector<TensorDesc> &inTensorDescs) const
{
    if (param_.mlaVHeadSize == 0) {
        const auto &keyShape = inTensorDescs.at(1).shape;   // 1: keyCache
        const auto &valueShape = inTensorDescs.at(2).shape; // 2: valueCache
        if (keyShape.dims[0] != valueShape.dims[0]) { // 0: numBlocks
            ATB_LOG(ERROR) << "numBlocks should be same";
            return ERROR_INVALID_TENSOR_DIM;
        }
        if (keyShape.dims[1] != valueShape.dims[1]) { // 1: blockSize
            ATB_LOG(ERROR) << "2nd dim of KVCache should be same";
            return ERROR_INVALID_TENSOR_DIM;
        }
        if (keyShape.dims[2] != valueShape.dims[2]) { // 2: headNum
            ATB_LOG(ERROR) << "3rd dim of KVCache should be same";
            return ERROR_INVALID_TENSOR_DIM;
        }
    }
    // SoC-specific layout constraints.
    return GetSingleton<Config>().Is910B() ? KVCacheDimCheck910B(inTensorDescs)
                                           : KVCacheDimCheck310P(inTensorDescs);
}
// Validates input tensor ranks for the generic (BSND) path. In MLA mode the
// valueCache input is absent, so blockTables/contextLens shift down by one
// position.
Status PagedAttentionOperation::InferShapeDimCheck(const SVector<TensorDesc> &inTensorDescs) const
{
    uint32_t blockTablesPos = 3; // 3: blockTables
    uint32_t contextLensPos = 4; // 4: contextLens
    if (param_.mlaVHeadSize > 0) {
        // MLA mode: no separate valueCache input, later inputs shift by one.
        blockTablesPos--;
        contextLensPos--;
    } else if (inTensorDescs.at(2).shape.dimNum != 4) { // 2: valueCache 4: 4 dims
        ATB_LOG(ERROR) << "invalid intensor2 dimNum";
        return ERROR_INVALID_TENSOR_DIM_NUM;
    }
    if (inTensorDescs.at(0).shape.dimNum != 3 || // 0: query 3: 3 dims
        inTensorDescs.at(1).shape.dimNum != 4 || // 1: keyCache 4: 4 dims
        inTensorDescs.at(blockTablesPos).shape.dimNum != 2 || // 2: 2 dims
        inTensorDescs.at(contextLensPos).shape.dimNum != 1) { // 1: 1 dim
        ATB_LOG(ERROR) << "invalid intensor dimNum";
        return ERROR_INVALID_TENSOR_DIM_NUM;
    }
    int64_t numTokens = inTensorDescs.at(0).shape.dims[0];
    if (param_.batchRunStatusEnable) {
        // Dynamic batch: blockTables/contextLens may describe more batches
        // than the query currently holds, but never fewer.
        if (numTokens > inTensorDescs.at(blockTablesPos).shape.dims[0] || // blockTables
            numTokens > inTensorDescs.at(contextLensPos).shape.dims[0]) { // contextLens
            ATB_LOG(ERROR) << "numTokens in q should be no greater than blockTables and contextLens"
                           << " , and should be same as output";
            return ERROR_INVALID_TENSOR_DIM;
        }
    }
    Status st = KVCacheDimCheck(inTensorDescs);
    if (st != NO_ERROR) {
        return st;
    }
    return NO_ERROR;
}
// Validates input tensor ranks and cross-tensor dims for the BNSD layout on
// 910B: query 3D, key/value caches 4D, blockTables 2D, contextLens 1D; head
// sizes, numBlocks and blockSize must agree between the caches.
Status PagedAttentionOperation::InferShapeDimCheckBNSD910B(const SVector<TensorDesc> &inTensorDescs) const
{
    // Expected ranks per input position.
    static const uint64_t expectedDimNums[] = {3, 4, 4, 2, 1}; // query, keyCache, valueCache, blockTables, contextLens
    for (size_t i = 0; i < 5; ++i) { // 5: mandatory input count
        if (inTensorDescs.at(i).shape.dimNum != expectedDimNums[i]) {
            ATB_LOG(ERROR) << "invalid intensor dimNum";
            return ERROR_INVALID_TENSOR_DIM_NUM;
        }
    }
    const auto &keyShape = inTensorDescs.at(1).shape;   // 1: keyCache
    const auto &valueShape = inTensorDescs.at(2).shape; // 2: valueCache
    const int64_t headSize = inTensorDescs.at(0).shape.dims[2]; // 0: query 2: headSize dim
    if (headSize != keyShape.dims[3] || headSize != valueShape.dims[3]) { // 3: headSize dim
        ATB_LOG(ERROR) << "headSize should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (keyShape.dims[0] != valueShape.dims[0]) { // 0: numBlocks
        ATB_LOG(ERROR) << "numBlocks should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (keyShape.dims[2] != valueShape.dims[2]) { // 2: blockSize dim
        ATB_LOG(ERROR) << "blockSizes should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    return NO_ERROR;
}
// Checks numBlocks/blockSize consistency between the KV caches (only when a
// separate valueCache exists, i.e. non-MLA) and that query and output agree
// on the head count.
bool PagedAttentionOperation::BlockDimCheck(const SVector<Tensor> &inTensors, const SVector<Tensor> &outTensors) const
{
    const auto &keyShape = inTensors.at(1).desc.shape; // 1: keyCache
    const bool hasValueCache = param_.mlaVHeadSize == 0;
    if (hasValueCache && keyShape.dims[0] != inTensors.at(2).desc.shape.dims[0]) { // 2: valueCache 0: numBlocks
        ATB_LOG(ERROR) << GetLogPrefix() << "numBlocks should be same";
        return false;
    }
    if (inTensors.at(0).desc.shape.dims[1] != outTensors.at(0).desc.shape.dims[1]) { // 1: numHeads dim
        ATB_LOG(ERROR) << GetLogPrefix() << "numHeads should be same";
        return false;
    }
    if (hasValueCache && keyShape.dims[1] != inTensors.at(2).desc.shape.dims[1]) { // 2: valueCache 1: blockSize
        ATB_LOG(ERROR) << GetLogPrefix() << "blockSizes should be same";
        return false;
    }
    return true;
}
// KVHEAD_ROPE compression: validates the razorOffset input, which is the last
// input unless logN is enabled (then it is the second-to-last). It must be 2D
// and match keyCache on numBlocks and blockSize.
bool PagedAttentionOperation::RazorDimCheck(const SVector<Tensor> &inTensors) const
{
    if (param_.compressType != infer::PagedAttentionParam::COMPRESS_TYPE_KVHEAD_ROPE) {
        return true;
    }
    const uint64_t indexReverse =
        (param_.scaleType == infer::PagedAttentionParam::SCALE_TYPE_LOGN) ? 2 : 1; // 2: logN is last
    const auto &razorShape = inTensors.at(inTensors.size() - indexReverse).desc.shape;
    const uint64_t dimRazor = 2;
    if (razorShape.dimNum != dimRazor) {
        ATB_LOG(ERROR) << GetLogPrefix() << "invalid intensor dimNum";
        return false;
    }
    const auto &keyShape = inTensors.at(1).desc.shape; // 1: keyCache
    if (keyShape.dims[0] != razorShape.dims[0]) { // 0: numBlocks
        ATB_LOG(ERROR) << GetLogPrefix() << "numBlocks should be same";
        return false;
    }
    if (keyShape.dims[1] != razorShape.dims[1]) { // 1: blockSize
        ATB_LOG(ERROR) << GetLogPrefix() << "blockSizes should be same";
        return false;
    }
    return true;
}
// Validates the logN input (always the last input when logN is enabled): it
// must be 1D and its batch size must match contextLens and, under dynamic
// batch, batchRunStatus. Positions are computed for the non-MLA layout; MLA
// mode cannot reach here because MlaParamCheck rejects SCALE_TYPE_LOGN.
bool PagedAttentionOperation::LogNDimCheck(const SVector<Tensor> &inTensors) const
{
    if (param_.scaleType == infer::PagedAttentionParam::SCALE_TYPE_LOGN) {
        if (inTensors.at(inTensors.size() - 1).desc.shape.dimNum != 1) {
            ATB_LOG(ERROR) << GetLogPrefix() << "invalid logN intensor dimNum";
            return false;
        }
        uint32_t inputNumBase = IN_TENSOR_NUM;
        uint32_t batchLogNNum = inTensors.at(inTensors.size() - 1).desc.shape.dims[0];
        // contextLens is the last mandatory input (index IN_TENSOR_NUM - 1).
        if (inTensors.at(inputNumBase - 1).desc.shape.dims[0] != batchLogNNum) {
            ATB_LOG(ERROR) << "intensor contextLens and intensor logn has different batch size";
            return false;
        }
        if (param_.maskType != atb::infer::PagedAttentionParam::UNDEFINED) {
            inputNumBase += 1; // need to input mask
        }
        if (param_.batchRunStatusEnable) {
            // batchRunStatus sits right after the mandatory inputs (and the
            // mask, when present).
            if (inTensors.at(inputNumBase).desc.shape.dims[0] != batchLogNNum) {
                ATB_LOG(ERROR) << "intensor batchRunStatus and intensor logn has different batch size";
                return false;
            }
        }
    }
    return true;
}
// Setup-time dim validation for the BSND layout: re-runs shape-inference
// checks, then verifies the output tensor against the query/value head sizes
// and the optional razorOffset / logN inputs.
Status PagedAttentionOperation::SetupDimCheck(const SVector<Tensor> &inTensors, const SVector<Tensor> &outTensors) const
{
    SVector<TensorDesc> inTensorDescs = {};
    OperationUtil::InTensorsToInTensorDescs(inTensors, inTensorDescs);
    // Reuse the infer-shape validation on the runtime tensor descriptors.
    Status st = InferShapeCheckImpl(inTensorDescs);
    if (st != NO_ERROR) {
        return st;
    }
    int64_t numTokens = inTensors.at(0).desc.shape.dims[0];
    if (param_.batchRunStatusEnable) {
        if (numTokens != outTensors.at(0).desc.shape.dims[0]) {
            ATB_LOG(ERROR) << GetLogPrefix() << "numTokens of outTensor should be the same as q";
            return ERROR_INVALID_TENSOR_DIM;
        }
    }
    // Expected output head size: the value head size (param in MLA mode,
    // valueCache last dim otherwise); on non-910B it is the query head size.
    int64_t targetHeadSize =
        param_.mlaVHeadSize > 0 ? param_.mlaVHeadSize : inTensors.at(2).desc.shape.dims[3]; // 2: valueCache 3: headSize
    if (!GetSingleton<Config>().Is910B()) {
        targetHeadSize = inTensors.at(0).desc.shape.dims[2]; // 0: query 2: headSize
    }
    if (targetHeadSize != outTensors.at(0).desc.shape.dims[2]) { // 2: headSize
        ATB_LOG(ERROR) << "headSize of attnOut error! It should equal to " << targetHeadSize;
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (!BlockDimCheck(inTensors, outTensors)) {
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (!RazorDimCheck(inTensors)) {
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (!LogNDimCheck(inTensors)) {
        return ERROR_INVALID_TENSOR_DIM;
    }
    return NO_ERROR;
}
// Setup-time dim validation for the BNSD layout on 910B: checks input ranks,
// head-size agreement across query/caches/output, and numBlocks/blockSize
// agreement between the caches.
Status PagedAttentionOperation::SetupDimCheckBNSD910B(const SVector<Tensor> &inTensors,
                                                      const SVector<Tensor> &outTensors) const
{
    // Expected ranks per input position.
    static const uint64_t expectedDimNums[] = {3, 4, 4, 2, 1}; // query, keyCache, valueCache, blockTables, contextLens
    for (size_t i = 0; i < 5; ++i) { // 5: mandatory input count
        if (inTensors.at(i).desc.shape.dimNum != expectedDimNums[i]) {
            ATB_LOG(ERROR) << "invalid intensor dimNum";
            return ERROR_INVALID_TENSOR_DIM_NUM;
        }
    }
    const auto &keyShape = inTensors.at(1).desc.shape;   // 1: keyCache
    const auto &valueShape = inTensors.at(2).desc.shape; // 2: valueCache
    const int64_t headSize = inTensors.at(0).desc.shape.dims[2]; // 0: query 2: headSize dim
    if (headSize != keyShape.dims[3] ||                          // 3: headSize dim
        headSize != valueShape.dims[3] ||
        headSize != outTensors.at(0).desc.shape.dims[2]) { // 2: headSize dim
        ATB_LOG(ERROR) << "headSizes should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (keyShape.dims[0] != valueShape.dims[0]) { // 0: numBlocks
        ATB_LOG(ERROR) << "numBlocks should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    if (keyShape.dims[2] != valueShape.dims[2]) { // 2: blockSize dim
        ATB_LOG(ERROR) << "blockSize should be same";
        return ERROR_INVALID_TENSOR_DIM;
    }
    return NO_ERROR;
}
// Packs the enabled-feature booleans into the QKV-quant case code used to
// index opIniTableQKVQuant (0x40 = offline, 0x80 = online in this encoding).
uint32_t PagedAttentionOperation::Bools2IntQKVQuant(AttentionFlags inputB) const
{
    uint32_t code = 0;
    if (inputB.useQuantOffset) {
        code |= QUANTOFFSET_BIT;
    }
    if (inputB.useQuant) {
        code |= USEQUANT_BIT;
    }
    if (inputB.useBatchRunStatus) {
        code |= BATCHSTATUS_BIT;
    }
    if (inputB.useMask) {
        code |= MASK_BIT;
    }
    if (inputB.useQLens) {
        code |= QLENS_BIT;
    }
    if (inputB.useRazorOffset) {
        code |= RAZOROFFSET_BIT;
    }
    if (inputB.useQKVQuantOffline) {
        code |= QKVQUANTOFFLINE_BIT;
    }
    if (inputB.useQKVQuantOnline) {
        code |= QKVQUANTONLINE_BIT;
    }
    return code;
}
// Packs the enabled-feature booleans into the logN case code used to index
// opIniTableLogN (0x40 = logN in this encoding).
uint32_t PagedAttentionOperation::Bools2IntLogN(AttentionFlags inputB) const
{
    uint32_t code = 0;
    if (inputB.useQuantOffset) {
        code |= QUANTOFFSET_BIT;
    }
    if (inputB.useQuant) {
        code |= USEQUANT_BIT;
    }
    if (inputB.useBatchRunStatus) {
        code |= BATCHSTATUS_BIT;
    }
    if (inputB.useMask) {
        code |= MASK_BIT;
    }
    if (inputB.useQLens) {
        code |= QLENS_BIT;
    }
    if (inputB.useRazorOffset) {
        code |= RAZOROFFSET_BIT;
    }
    if (inputB.useLogN) {
        code |= LOGN_BIT;
    }
    return code;
}
// Selects the operation IR configuration (operationIr_) that matches the
// current param_ feature combination. The active flags are bit-packed into two
// case codes (one for LogN-style cases, one for QKV-quant-style cases) and
// looked up in the corresponding static table; the LogN table is consulted
// first. If neither table matches and no IR was set previously, an error is
// logged.
// NOTE(review): LOGN_BIT and QKVQUANTOFFLINE_BIT share the value 0x40 (see the
// constants at the top of the file), so the two case codes can collide
// numerically — they are only ever looked up against their own separate
// tables; confirm this sharing is intentional.
void PagedAttentionOperation::InitOpIni()
{
// Translate param_ enum/bool settings into the flag struct consumed by the
// Bools2Int* encoders.
AttentionFlags inputB2I;
inputB2I.useQuantOffset = param_.hasQuantOffset;
inputB2I.useQuant = (param_.quantType == infer::PagedAttentionParam::TYPE_DEQUANT_FUSION);
inputB2I.useBatchRunStatus = param_.batchRunStatusEnable;
inputB2I.useMask = (param_.maskType != atb::infer::PagedAttentionParam::UNDEFINED);
inputB2I.useQLens = (param_.calcType == atb::infer::PagedAttentionParam::CALC_TYPE_SPEC);
inputB2I.useRazorOffset = (param_.compressType == infer::PagedAttentionParam::COMPRESS_TYPE_KVHEAD_ROPE);
inputB2I.useLogN = (param_.scaleType == infer::PagedAttentionParam::SCALE_TYPE_LOGN);
inputB2I.useQKVQuantOffline = (param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_OFFLINE);
inputB2I.useQKVQuantOnline = (param_.quantType == atb::infer::PagedAttentionParam::TYPE_QUANT_QKV_ONLINE);
uint32_t caseCodeLogN = Bools2IntLogN(inputB2I);
uint32_t caseCodeQKVQuant = Bools2IntQKVQuant(inputB2I);
// Case code -> op ini section name for LogN-enabled combinations. Each key is
// the bit-packed flag value; the name spells out each flag as 0/1. All keys
// here include LOGN_BIT (0x40), so only useLogN combinations can match.
static std::map<uint32_t, std::string> opIniTableLogN = {
{99, "PagedAttentionOperationLogN1RazorOffset1QLens0Mask0Batch0Quant1Offset1"},
{98, "PagedAttentionOperationLogN1RazorOffset1QLens0Mask0Batch0Quant1Offset0"},
{96, "PagedAttentionOperationLogN1RazorOffset1QLens0Mask0Batch0Quant0Offset0"},
{88, "PagedAttentionOperationLogN1RazorOffset0QLens1Mask1Batch0Quant0Offset0"},
{79, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask1Batch1Quant1Offset1"},
{78, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask1Batch1Quant1Offset0"},
{76, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask1Batch1Quant0Offset0"},
{75, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask1Batch0Quant1Offset1"},
{74, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask1Batch0Quant1Offset0"},
{72, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask1Batch0Quant0Offset0"},
{71, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask0Batch1Quant1Offset1"},
{70, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask0Batch1Quant1Offset0"},
{68, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask0Batch1Quant0Offset0"},
{67, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask0Batch0Quant1Offset1"},
{66, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask0Batch0Quant1Offset0"},
{64, "PagedAttentionOperationLogN1RazorOffset0QLens0Mask0Batch0Quant0Offset0"},
};
// Case code -> op ini section name for QKV-quant (and plain) combinations,
// consulted only when the LogN table had no match.
static std::map<uint32_t, std::string> opIniTableQKVQuant = {
{160, "PagedAttentionOperationQKVQuantOnline1QKVQuantOffline0RazorOffset1QLens0Mask0Batch0Quant0Offset0"},
{140, "PagedAttentionOperationQKVQuantOnline1QKVQuantOffline0RazorOffset0QLens0Mask1Batch1Quant0Offset0"},
{136, "PagedAttentionOperationQKVQuantOnline1QKVQuantOffline0RazorOffset0QLens0Mask1Batch0Quant0Offset0"},
{132, "PagedAttentionOperationQKVQuantOnline1QKVQuantOffline0RazorOffset0QLens0Mask0Batch1Quant0Offset0"},
{128, "PagedAttentionOperationQKVQuantOnline1QKVQuantOffline0RazorOffset0QLens0Mask0Batch0Quant0Offset0"},
{96, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline1RazorOffset1QLens0Mask0Batch0Quant0Offset0"},
{76, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline1RazorOffset0QLens0Mask1Batch1Quant0Offset0"},
{72, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline1RazorOffset0QLens0Mask1Batch0Quant0Offset0"},
{68, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline1RazorOffset0QLens0Mask0Batch1Quant0Offset0"},
{64, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline1RazorOffset0QLens0Mask0Batch0Quant0Offset0"},
{35, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset1QLens0Mask0Batch0Quant1Offset1"},
{34, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset1QLens0Mask0Batch0Quant1Offset0"},
{32, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset1QLens0Mask0Batch0Quant0Offset0"},
{24, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens1Mask1Batch0Quant0Offset0"},
{15, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask1Batch1Quant1Offset1"},
{14, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask1Batch1Quant1Offset0"},
{12, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask1Batch1Quant0Offset0"},
{11, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask1Batch0Quant1Offset1"},
{10, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask1Batch0Quant1Offset0"},
{8, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask1Batch0Quant0Offset0"},
{7, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask0Batch1Quant1Offset1"},
{6, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask0Batch1Quant1Offset0"},
{4, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask0Batch1Quant0Offset0"},
{3, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask0Batch0Quant1Offset1"},
{2, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask0Batch0Quant1Offset0"},
{0, "PagedAttentionOperationQKVQuantOnline0QKVQuantOffline0RazorOffset0QLens0Mask0Batch0Quant0Offset0"},
};
// LogN table takes precedence; return on the first hit.
std::map<uint32_t, std::string>::const_iterator itLogN = opIniTableLogN.find(caseCodeLogN);
if (itLogN != opIniTableLogN.end()) {
operationIr_ = GetSingleton<AtbOperationIrCfg>().GetOperationIr(itLogN->second);
return;
}
std::map<uint32_t, std::string>::const_iterator itQKVQuant = opIniTableQKVQuant.find(caseCodeQKVQuant);
if (itQKVQuant != opIniTableQKVQuant.end()) {
operationIr_ = GetSingleton<AtbOperationIrCfg>().GetOperationIr(itQKVQuant->second);
return;
}
// No table matched: only report an error if operationIr_ was never set; a
// previously assigned IR (if any) is kept untouched.
if (operationIr_ == nullptr) {
ATB_LOG(ERROR) << GetLogPrefix() << "No matched param for op ini";
}
}
std::shared_ptr<Runner> PagedAttentionOperation::CreateRunner(Context &context) const
{
    // Create (or borrow from the context's pool) the runner that executes
    // this operation. Returns nullptr when the context is not a ContextBase.
    ContextBase *contextBase = dynamic_cast<ContextBase *>(&context);
    if (contextBase == nullptr) {
        ATB_LOG(DEBUG) << "context cast to contextBase failed!";
        return nullptr;
    }
    RunnerPool &pool = contextBase->GetRunnerPool(RUNNER_TYPE_PAGED_ATTENTION);
    if (GetSingleton<Config>().Is910B()) {
        // 910B devices use the dedicated runner implementation (not pooled).
        return std::make_shared<PagedAttentionOpsRunner>(param_);
    }
    Runner *pooledRunner = pool.MallocRunner<PagedAttentionOpsRunner910A, infer::PagedAttentionParam>(param_);
    if (pooledRunner == nullptr) {
        // Pool exhausted: fall back to a plain heap allocation.
        return std::make_shared<PagedAttentionOpsRunner910A>(param_);
    }
    // The custom deleter hands the runner back to the pool instead of deleting it.
    return std::shared_ptr<Runner>(pooledRunner, [&pool](Runner *r) { pool.FreeRunner(r); });
}
nlohmann::json PagedAttentionOperation::GetParamJson() const
{
    // Serialize the operation's parameters to JSON (used for logging/dumping).
    nlohmann::json paramJson = OpParamToJson(param_);
    return paramJson;
}
} // namespace atb
// Reference:
// https://gitee.com/ascend/ascend-transformer-boost/blob/master/src/ops_infer/paged_attention/paged_attention_operation.cpp