forked from mindspore-Ecosystem/mindspore
add adamweightdecay to opset and update lock input tensors
This commit is contained in:
parent
2ae65ae387
commit
8fa85cac34
|
@ -1339,6 +1339,9 @@ void KernelGraph::SetOptimizerFlag() {
|
|||
for (const auto &cnode : execution_order_) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
|
||||
continue;
|
||||
}
|
||||
if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
|
||||
has_optimizer_ = true;
|
||||
} else if (node_name.find("Assign") == string::npos) {
|
||||
|
|
|
@ -1602,8 +1602,17 @@ std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const Graph
|
|||
if (!graph->has_optimizer()) {
|
||||
return {};
|
||||
}
|
||||
auto input_nodes = graph->inputs();
|
||||
bool check_monad = false;
|
||||
if (input_nodes.size() == inputs.size()) {
|
||||
check_monad = true;
|
||||
}
|
||||
std::vector<tensor::TensorPtr> result;
|
||||
for (auto &tensor : inputs) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (check_monad && HasAbstractMonad(input_nodes[i])) {
|
||||
continue;
|
||||
}
|
||||
auto &tensor = inputs[i];
|
||||
if (!tensor->IsGraphOutput()) {
|
||||
result.emplace_back(tensor);
|
||||
}
|
||||
|
|
|
@ -202,6 +202,7 @@ constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt";
|
|||
constexpr auto kStridedReadOpName = "StridedRead";
|
||||
constexpr auto kStridedWriteOpName = "StridedWrite";
|
||||
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
|
||||
constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
|
||||
constexpr auto kFusedAdamName = "FusedAdam";
|
||||
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
|
||||
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
|
||||
|
@ -320,6 +321,7 @@ constexpr auto kAttrInputNames = "input_names";
|
|||
constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
|
||||
constexpr auto kIsBackendCast = "is_backed_cast";
|
||||
constexpr auto kAttrOutputNames = "output_names";
|
||||
constexpr auto kAttrAsync = "async";
|
||||
constexpr auto kAttrVisited = "visited";
|
||||
constexpr auto kAttrShape = "shape";
|
||||
constexpr auto kAttrMomentum = "momentum";
|
||||
|
@ -581,6 +583,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
|
|||
kAdamApplyOneWithDecayOpName,
|
||||
kAdamApplyOneWithDecayAssignOpName,
|
||||
kFusedAdamWeightDecayName,
|
||||
kAdamWeightDecayName,
|
||||
kFusedAdamName,
|
||||
kFusedSparseAdamName,
|
||||
kFusedMulApplyMomentumOpName,
|
||||
|
|
Loading…
Reference in New Issue