!49395 Speedup Pynative Dynamic Shape
Merge pull request !49395 from 王禹程/sp
This commit is contained in:
commit
8254a39e40
|
@ -505,12 +505,11 @@ BackendOpRunInfoPtr SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, cons
|
|||
const auto &shape = abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
|
||||
bool is_gradient_out =
|
||||
graph_output_info != nullptr &&
|
||||
std::any_of(graph_output_info->output_indexes.cbegin(), graph_output_info->output_indexes.cend(),
|
||||
[cnode](const std::pair<KernelWithIndex, std::vector<std::vector<size_t>>> &output_index) {
|
||||
return output_index.first.first == cnode;
|
||||
});
|
||||
bool is_gradient_out = false;
|
||||
if (graph_output_info != nullptr) {
|
||||
auto lb_iter = graph_output_info->output_indexes.lower_bound({cnode, 0});
|
||||
is_gradient_out = lb_iter != graph_output_info->output_indexes.end() && lb_iter->first.first == cnode;
|
||||
}
|
||||
pynative::BaseOpRunInfo base_op_run_info;
|
||||
base_op_run_info.is_mixed_precision_cast = false;
|
||||
base_op_run_info.lazy_build = !shape->IsDynamic();
|
||||
|
|
|
@ -213,11 +213,11 @@ void AclKernelMod::UpdateReduceAxisAttr(const AnfNodePtr &node) {
|
|||
opt::NormalizeReduceAttrAxis(cnode);
|
||||
}
|
||||
|
||||
void AclKernelMod::ProcessAttribute(const std::shared_ptr<AclOpDesc> &op_desc_ptr) {
|
||||
void AclKernelMod::ProcessAttribute(const std::shared_ptr<AclOpDesc> &op_desc_ptr,
|
||||
const std::vector<string> &input_names) {
|
||||
auto node = anf_node_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &attr_to_input_maps = GeOpConvertor::GetNeedAddInput(node, true);
|
||||
const auto &input_names = kernel::AclUtils::GetOpInputAnchorNames(node);
|
||||
UpdateReduceAxisAttr(node);
|
||||
auto attr_list = GeOpConvertor::GetAttrAndValue(node, true);
|
||||
for (auto &[attr_name, value] : attr_list) {
|
||||
|
@ -246,11 +246,24 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
|||
}
|
||||
auto node = anf_node_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_op_runtime_info = node->user_data<runtime::OpRuntimeInfo>();
|
||||
bool node_acl_runtime_info_legal = node_op_runtime_info != nullptr &&
|
||||
node_op_runtime_info->acl_runtime_info_ != nullptr &&
|
||||
node_op_runtime_info->acl_runtime_info_->use();
|
||||
const auto &input_names =
|
||||
(node_acl_runtime_info_legal && !node_op_runtime_info->acl_runtime_info_->is_dynamic_input_size())
|
||||
? node_op_runtime_info->acl_runtime_info_->input_names()
|
||||
: AclUtils::GetOpInputAnchorNames(node);
|
||||
const auto &output_names =
|
||||
(node_acl_runtime_info_legal && !node_op_runtime_info->acl_runtime_info_->is_dynamic_output_size())
|
||||
? node_op_runtime_info->acl_runtime_info_->output_names()
|
||||
: AclUtils::GetOpOutputAnchorNames(node);
|
||||
|
||||
auto op_desc_ptr = std::make_shared<AclOpDesc>(op_type_, node);
|
||||
MS_EXCEPTION_IF_NULL(op_desc_ptr);
|
||||
op_desc_ptr->AddTensorDesc(input_desc_list_, output_desc_list_);
|
||||
op_desc_ptr->AddDataBuf(inputs, input_size_list_, outputs, output_size_list_);
|
||||
ProcessAttribute(op_desc_ptr);
|
||||
op_desc_ptr->AddDataBuf(inputs, input_size_list_, outputs, output_size_list_, input_names, output_names);
|
||||
ProcessAttribute(op_desc_ptr, input_names);
|
||||
op_desc_ptr->ClearNullTensor();
|
||||
|
||||
// cppcheck-suppress unreadVariable
|
||||
|
|
|
@ -48,7 +48,7 @@ class AclKernelMod : public AscendKernelMod {
|
|||
|
||||
protected:
|
||||
void SyncData() override;
|
||||
void ProcessAttribute(const std::shared_ptr<AclOpDesc> &op_desc_ptr);
|
||||
void ProcessAttribute(const std::shared_ptr<AclOpDesc> &op_desc_ptr, const std::vector<string> &input_names);
|
||||
void UpdateReduceAxisAttr(const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
|
|
|
@ -186,10 +186,10 @@ void AclOpDesc::AddTensorDesc(const std::vector<GeTensorDescPtr> &inputs, const
|
|||
}
|
||||
|
||||
void AclOpDesc::AddDataBuf(const std::vector<AddressPtr> &inputs, const std::vector<size_t> &input_size_list,
|
||||
const std::vector<AddressPtr> &outputs, const std::vector<size_t> &output_size_list) {
|
||||
const std::vector<AddressPtr> &outputs, const std::vector<size_t> &output_size_list,
|
||||
const std::vector<std::string> &input_names, const std::vector<std::string> &output_names) {
|
||||
auto node = anf_node_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &input_names = AclUtils::GetOpInputAnchorNames(node);
|
||||
input_tensor_data_.clear();
|
||||
input_tensor_data_.resize(input_names.size(), nullptr);
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
|
@ -209,7 +209,7 @@ void AclOpDesc::AddDataBuf(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
}
|
||||
input_tensor_data_[idx] = CreateDataBuf(inputs[i], input_size_list[idx]);
|
||||
}
|
||||
const auto &output_names = AclUtils::GetOpOutputAnchorNames(node);
|
||||
|
||||
output_tensor_data_.clear();
|
||||
output_tensor_data_.resize(output_names.size(), aclCreateDataBuffer(nullptr, 0));
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "transform/graph_ir/convert.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "kernel/oplib/super_bar.h"
|
||||
#include "runtime/pynative/op_runtime_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -52,7 +53,8 @@ class AclOpDesc {
|
|||
|
||||
void AddTensorDesc(const std::vector<GeTensorDescPtr> &inputs, const std::vector<GeTensorDescPtr> &outputs);
|
||||
void AddDataBuf(const std::vector<AddressPtr> &inputs, const std::vector<size_t> &input_size_list,
|
||||
const std::vector<AddressPtr> &outputs, const std::vector<size_t> &output_size_list);
|
||||
const std::vector<AddressPtr> &outputs, const std::vector<size_t> &output_size_list,
|
||||
const std::vector<std::string> &input_names, const std::vector<std::string> &output_names);
|
||||
void ProcessAclAttrs(const std::string &attr_name, const ValuePtr &value, const ProcessAttrMode &mode);
|
||||
void ClearNullTensor();
|
||||
|
||||
|
|
|
@ -456,8 +456,8 @@ std::optional<ShapeVector> DeviceShapeTransfer::GetFixedDeviceShape(const ShapeV
|
|||
|
||||
ShapeVector DeviceShapeTransfer::TransCore(const ShapeVector &shape, const std::string &format, const TypeId &type,
|
||||
int64_t groups, const ShapeVector &input_hidden_size) const {
|
||||
using DeviceShapeTransfer = std::function<ShapeVector(const ShapeVector &, const TypeId &)>;
|
||||
const std::map<std::string, DeviceShapeTransfer> device_shape_map = {
|
||||
using DeviceShapeTransferFunc = std::function<ShapeVector(const ShapeVector &, const TypeId &)>;
|
||||
static const mindspore::HashMap<std::string, DeviceShapeTransferFunc> device_shape_map = {
|
||||
{kOpFormat_NCHW, NCHWDeviceShape},
|
||||
{kOpFormat_NHWC, NHWCDeviceShape},
|
||||
{kOpFormat_HWCN, HWCNDeviceShape},
|
||||
|
|
|
@ -443,7 +443,7 @@ GraphId GraphCompiler::CompileDynamicGraph(const GraphSegmentPtr &segment, const
|
|||
|
||||
graph->UpdateGraphAquireGilAttr();
|
||||
graph->SetInputNodes();
|
||||
auto manager = MakeManager({graph});
|
||||
auto manager = Manage(graph);
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(graph);
|
||||
graph->set_manager(manager);
|
||||
|
@ -459,7 +459,7 @@ GraphId GraphCompiler::CompileDynamicGraph(const GraphSegmentPtr &segment, const
|
|||
graph->set_root_graph_id(graph_id);
|
||||
session_->DumpGraphs({graph});
|
||||
|
||||
auto exec_nodes = graph->execution_order();
|
||||
auto &exec_nodes = graph->execution_order();
|
||||
std::for_each(exec_nodes.begin(), exec_nodes.end(),
|
||||
[](const CNodePtr &node) { common::AnfAlgo::SetNodeAttr(kAttrMutableKernel, MakeValue(true), node); });
|
||||
|
||||
|
|
|
@ -21,38 +21,38 @@
|
|||
|
||||
namespace mindspore {
|
||||
bool IsOneOfPosteriorOperator(const std::string &name) {
|
||||
const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
|
||||
static const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
|
||||
|
||||
auto iter = kPosteriorOperatorSet.find(name);
|
||||
return iter != kPosteriorOperatorSet.end();
|
||||
}
|
||||
|
||||
bool IsOneOfCacheBlackList(const std::string &name) {
|
||||
const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
|
||||
kGetNextOpName};
|
||||
static const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
|
||||
kGetNextOpName};
|
||||
|
||||
auto iter = kOpCacheBlackList.find(name);
|
||||
return iter != kOpCacheBlackList.end();
|
||||
}
|
||||
|
||||
bool IsOneOf3DFormat(const std::string &format) {
|
||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||
static const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||
|
||||
auto iter = k3DFormatSet.find(format);
|
||||
return iter != k3DFormatSet.end();
|
||||
}
|
||||
|
||||
bool IsOneOfNoPaddingFormat(const std::string &format) {
|
||||
const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ, kOpFormat_FRACTAL_ZN_RNN,
|
||||
kOpFormat_ND_RNN_BIAS, kOpFormat_DEFAULT};
|
||||
static const std::set<std::string> kNoPaddingFormatSet = {
|
||||
kOpFormat_ChannelLast, kOpFormat_FRAC_NZ, kOpFormat_FRACTAL_ZN_RNN, kOpFormat_ND_RNN_BIAS, kOpFormat_DEFAULT};
|
||||
|
||||
auto iter = kNoPaddingFormatSet.find(format);
|
||||
return iter != kNoPaddingFormatSet.end();
|
||||
}
|
||||
|
||||
bool IsOneOfDynamicShapeConstInputToAttrGPU(const std::string &name) {
|
||||
const std::set<std::string> DynamicShapeConstInputToAttrGPU = {
|
||||
static const std::set<std::string> DynamicShapeConstInputToAttrGPU = {
|
||||
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName,
|
||||
kReduceSumOpName, kReduceMinOpName, kReduceMeanOpName, kReduceMaxOpName, kReduceAllOpName,
|
||||
kReduceAnyOpName, kConcatOpName, kScatterNdOpName, kGatherOpName, kAvgPool3DGradOpName};
|
||||
|
@ -69,120 +69,120 @@ bool IsOneOfCustomAkgType(const std::string &name) {
|
|||
}
|
||||
|
||||
bool IsOneOfOperator(const std::string &name) {
|
||||
const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
|
||||
kApplyMomentumOpName,
|
||||
kApplyMomentumDOpName,
|
||||
kApplyAdadeltaOpName,
|
||||
kApplyAdadeltaDOpName,
|
||||
kApplyAdagradOpName,
|
||||
kApplyAdagradDOpName,
|
||||
kApplyAdagradDAOpName,
|
||||
kApplyAdagradDADOpName,
|
||||
kAdamOpName,
|
||||
kApplyAdamDOpName,
|
||||
kApplyAdamOpName,
|
||||
kApplyAdaMaxOpName,
|
||||
kApplyAdaMaxDOpName,
|
||||
kApplyAddSignOpName,
|
||||
kApplyAddSignDOpName,
|
||||
kApplyCenteredRMSPOpName,
|
||||
kApplyFtrlOpName,
|
||||
kApplyFtrlDOpName,
|
||||
kApplyFtrlV2OpName,
|
||||
kApplyFtrlV2DOpName,
|
||||
kApplyGradientDescentOpName,
|
||||
kApplyPowerSignOpName,
|
||||
kApplyPowerSignDOpName,
|
||||
kApplyProximalAdagradOpName,
|
||||
kApplyProximalAdagradDOpName,
|
||||
kApplyProximalGradientDescentOpName,
|
||||
kApplyRMSPropOpName,
|
||||
kApplyRMSPropDOpname,
|
||||
kAdamApplyOneWithDecayOpName,
|
||||
kAdamApplyOneWithDecayAssignOpName,
|
||||
kFusedAdamWeightDecayName,
|
||||
kAdamWeightDecayName,
|
||||
kFusedCastAdamWeightDecayName,
|
||||
kFusedAdamName,
|
||||
kFusedAdaFactorName,
|
||||
kFusedAdaFactorWithGlobalNormName,
|
||||
kFusedSparseAdamName,
|
||||
kFusedMulApplyMomentumOpName,
|
||||
kFusedWeightScaleApplyMomentum,
|
||||
kFusedScaleApplyMomentum,
|
||||
kApplyCenteredRMSPropOpName,
|
||||
kApplyCenteredRMSPropDOpName,
|
||||
kFusedSparseFtrlName,
|
||||
kFusedSparseProximalAdagradName,
|
||||
kFusedSparseLazyAdamName,
|
||||
kSparseApplyFtrlOpName,
|
||||
kSparseApplyFtrlDOpName,
|
||||
kSparseApplyFtrlV2OpName,
|
||||
kSparseApplyFtrlV2DOpName,
|
||||
kSGDName,
|
||||
kLARSUpdateOpName,
|
||||
kLarsV2UpdateOpName,
|
||||
kCombineMomentumWeightOpName,
|
||||
kCombineMomentumOpName,
|
||||
kScatterAddOpName,
|
||||
kScatterUpdateOpName,
|
||||
kSparseApplyProximalAdagradOpName,
|
||||
kSparseApplyProximalAdagradDOpName,
|
||||
kAdaptiveMaxPool2dOpName,
|
||||
kApplyKerasMomentumDOpName};
|
||||
static const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
|
||||
kApplyMomentumOpName,
|
||||
kApplyMomentumDOpName,
|
||||
kApplyAdadeltaOpName,
|
||||
kApplyAdadeltaDOpName,
|
||||
kApplyAdagradOpName,
|
||||
kApplyAdagradDOpName,
|
||||
kApplyAdagradDAOpName,
|
||||
kApplyAdagradDADOpName,
|
||||
kAdamOpName,
|
||||
kApplyAdamDOpName,
|
||||
kApplyAdamOpName,
|
||||
kApplyAdaMaxOpName,
|
||||
kApplyAdaMaxDOpName,
|
||||
kApplyAddSignOpName,
|
||||
kApplyAddSignDOpName,
|
||||
kApplyCenteredRMSPOpName,
|
||||
kApplyFtrlOpName,
|
||||
kApplyFtrlDOpName,
|
||||
kApplyFtrlV2OpName,
|
||||
kApplyFtrlV2DOpName,
|
||||
kApplyGradientDescentOpName,
|
||||
kApplyPowerSignOpName,
|
||||
kApplyPowerSignDOpName,
|
||||
kApplyProximalAdagradOpName,
|
||||
kApplyProximalAdagradDOpName,
|
||||
kApplyProximalGradientDescentOpName,
|
||||
kApplyRMSPropOpName,
|
||||
kApplyRMSPropDOpname,
|
||||
kAdamApplyOneWithDecayOpName,
|
||||
kAdamApplyOneWithDecayAssignOpName,
|
||||
kFusedAdamWeightDecayName,
|
||||
kAdamWeightDecayName,
|
||||
kFusedCastAdamWeightDecayName,
|
||||
kFusedAdamName,
|
||||
kFusedAdaFactorName,
|
||||
kFusedAdaFactorWithGlobalNormName,
|
||||
kFusedSparseAdamName,
|
||||
kFusedMulApplyMomentumOpName,
|
||||
kFusedWeightScaleApplyMomentum,
|
||||
kFusedScaleApplyMomentum,
|
||||
kApplyCenteredRMSPropOpName,
|
||||
kApplyCenteredRMSPropDOpName,
|
||||
kFusedSparseFtrlName,
|
||||
kFusedSparseProximalAdagradName,
|
||||
kFusedSparseLazyAdamName,
|
||||
kSparseApplyFtrlOpName,
|
||||
kSparseApplyFtrlDOpName,
|
||||
kSparseApplyFtrlV2OpName,
|
||||
kSparseApplyFtrlV2DOpName,
|
||||
kSGDName,
|
||||
kLARSUpdateOpName,
|
||||
kLarsV2UpdateOpName,
|
||||
kCombineMomentumWeightOpName,
|
||||
kCombineMomentumOpName,
|
||||
kScatterAddOpName,
|
||||
kScatterUpdateOpName,
|
||||
kSparseApplyProximalAdagradOpName,
|
||||
kSparseApplyProximalAdagradDOpName,
|
||||
kAdaptiveMaxPool2dOpName,
|
||||
kApplyKerasMomentumDOpName};
|
||||
|
||||
auto iter = kOptOperatorSet.find(name);
|
||||
return iter != kOptOperatorSet.end();
|
||||
}
|
||||
|
||||
bool IsOneOfNotSupportedTransFormat(const std::string &format) {
|
||||
const std::set<std::string> kNotSupportedFormat = {kOpFormat_DHWCN, kOpFormat_NDHWC, kOpFormat_CHWN};
|
||||
static const std::set<std::string> kNotSupportedFormat = {kOpFormat_DHWCN, kOpFormat_NDHWC, kOpFormat_CHWN};
|
||||
return (kNotSupportedFormat.find(format) != kNotSupportedFormat.end());
|
||||
}
|
||||
|
||||
bool IsOneOfComputeDepend(const std::string &name) {
|
||||
const std::set<std::string> kComputeDepend = {kUniqueOpName,
|
||||
kUniqueConsecutiveOpName,
|
||||
kComputeAccidentalHitsOpName,
|
||||
kSubAndFilterOpName,
|
||||
kPadAndShiftOpName,
|
||||
kCTCGreedyDecoderOpName,
|
||||
kMaskedSelectOpName,
|
||||
kDynamicStitchOpName,
|
||||
kGetNextOpName,
|
||||
kListDiffOpName,
|
||||
kNonMaxSuppressionV3OpName,
|
||||
kNonMaxSuppressionWithOverlapsOpName,
|
||||
kCoalesceOpName,
|
||||
kTruncatedNormal,
|
||||
kNonDeterministicInts,
|
||||
kFractionalAvgPoolGradOpName,
|
||||
kDenseToDenseSetOperation,
|
||||
kDenseToSparseSetOperation,
|
||||
kSegmentMaxOpName,
|
||||
kCSRSparseMatrixToSparseTensorOpName,
|
||||
kSegmentMinOpName,
|
||||
kLuUnpackOpName,
|
||||
kSegmentSumOpName,
|
||||
kResizeBicubicOpName,
|
||||
kResizeAreaOpName,
|
||||
kSegmentMeanOpName,
|
||||
kSegmentProdOpName,
|
||||
kSparseSliceOpName,
|
||||
kNonZeroOpName,
|
||||
kSparseSparseMinimumOpName,
|
||||
kSparseSparseMaximumOpName,
|
||||
kRpcRecvOpName,
|
||||
kSparseFillEmptyRows,
|
||||
kSparseCrossOpName,
|
||||
kAdaptiveMaxPool3DGradOpName};
|
||||
static const std::set<std::string> kComputeDepend = {kUniqueOpName,
|
||||
kUniqueConsecutiveOpName,
|
||||
kComputeAccidentalHitsOpName,
|
||||
kSubAndFilterOpName,
|
||||
kPadAndShiftOpName,
|
||||
kCTCGreedyDecoderOpName,
|
||||
kMaskedSelectOpName,
|
||||
kDynamicStitchOpName,
|
||||
kGetNextOpName,
|
||||
kListDiffOpName,
|
||||
kNonMaxSuppressionV3OpName,
|
||||
kNonMaxSuppressionWithOverlapsOpName,
|
||||
kCoalesceOpName,
|
||||
kTruncatedNormal,
|
||||
kNonDeterministicInts,
|
||||
kFractionalAvgPoolGradOpName,
|
||||
kDenseToDenseSetOperation,
|
||||
kDenseToSparseSetOperation,
|
||||
kSegmentMaxOpName,
|
||||
kCSRSparseMatrixToSparseTensorOpName,
|
||||
kSegmentMinOpName,
|
||||
kLuUnpackOpName,
|
||||
kSegmentSumOpName,
|
||||
kResizeBicubicOpName,
|
||||
kResizeAreaOpName,
|
||||
kSegmentMeanOpName,
|
||||
kSegmentProdOpName,
|
||||
kSparseSliceOpName,
|
||||
kNonZeroOpName,
|
||||
kSparseSparseMinimumOpName,
|
||||
kSparseSparseMaximumOpName,
|
||||
kRpcRecvOpName,
|
||||
kSparseFillEmptyRows,
|
||||
kSparseCrossOpName,
|
||||
kAdaptiveMaxPool3DGradOpName};
|
||||
|
||||
auto iter = kComputeDepend.find(name);
|
||||
return iter != kComputeDepend.end();
|
||||
}
|
||||
|
||||
bool IsOneOfHWSpecialFormat(const std::string &format) {
|
||||
const std::set<std::string> kHWSpecialFormatSet = {
|
||||
static const std::set<std::string> kHWSpecialFormatSet = {
|
||||
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
|
||||
kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM,
|
||||
kOpFormat_FRACTAL_ZN_RNN, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
|
||||
|
@ -192,7 +192,7 @@ bool IsOneOfHWSpecialFormat(const std::string &format) {
|
|||
}
|
||||
|
||||
bool IsOneOfFormat(const std::string &format) {
|
||||
const std::set<std::string> kOpFormatList = {
|
||||
static const std::set<std::string> kOpFormatList = {
|
||||
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
|
||||
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
|
||||
kOpFormat_CHWN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
|
||||
|
@ -207,7 +207,7 @@ bool IsOneOfFormat(const std::string &format) {
|
|||
}
|
||||
|
||||
bool IsOneOfServerFormatC04(const std::string &format) {
|
||||
const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||
static const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||
return kServerFormatC04List.find(format) != kServerFormatC04List.end();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue