add adamweightdecay to opset and update lock input tensors

This commit is contained in:
kswang 2021-07-30 17:37:32 +08:00
parent 2ae65ae387
commit 8fa85cac34
3 changed files with 16 additions and 1 deletions

View File

@ -1339,6 +1339,9 @@ void KernelGraph::SetOptimizerFlag() {
for (const auto &cnode : execution_order_) {
MS_EXCEPTION_IF_NULL(cnode);
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
continue;
}
if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
has_optimizer_ = true;
} else if (node_name.find("Assign") == string::npos) {

View File

@ -1602,8 +1602,17 @@ std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const Graph
if (!graph->has_optimizer()) {
return {};
}
auto input_nodes = graph->inputs();
bool check_monad = false;
if (input_nodes.size() == inputs.size()) {
check_monad = true;
}
std::vector<tensor::TensorPtr> result;
for (auto &tensor : inputs) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (check_monad && HasAbstractMonad(input_nodes[i])) {
continue;
}
auto &tensor = inputs[i];
if (!tensor->IsGraphOutput()) {
result.emplace_back(tensor);
}

View File

@ -202,6 +202,7 @@ constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt";
constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
@ -320,6 +321,7 @@ constexpr auto kAttrInputNames = "input_names";
constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrAsync = "async";
constexpr auto kAttrVisited = "visited";
constexpr auto kAttrShape = "shape";
constexpr auto kAttrMomentum = "momentum";
@ -581,6 +583,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kAdamApplyOneWithDecayOpName,
kAdamApplyOneWithDecayAssignOpName,
kFusedAdamWeightDecayName,
kAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedMulApplyMomentumOpName,