forked from mindspore-Ecosystem/mindspore
add adamweightdecay to opset and update lock input tensors
This commit is contained in:
parent
2ae65ae387
commit
8fa85cac34
|
@ -1339,6 +1339,9 @@ void KernelGraph::SetOptimizerFlag() {
|
||||||
for (const auto &cnode : execution_order_) {
|
for (const auto &cnode : execution_order_) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||||
|
if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
|
if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
|
||||||
has_optimizer_ = true;
|
has_optimizer_ = true;
|
||||||
} else if (node_name.find("Assign") == string::npos) {
|
} else if (node_name.find("Assign") == string::npos) {
|
||||||
|
|
|
@ -1602,8 +1602,17 @@ std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const Graph
|
||||||
if (!graph->has_optimizer()) {
|
if (!graph->has_optimizer()) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
auto input_nodes = graph->inputs();
|
||||||
|
bool check_monad = false;
|
||||||
|
if (input_nodes.size() == inputs.size()) {
|
||||||
|
check_monad = true;
|
||||||
|
}
|
||||||
std::vector<tensor::TensorPtr> result;
|
std::vector<tensor::TensorPtr> result;
|
||||||
for (auto &tensor : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (check_monad && HasAbstractMonad(input_nodes[i])) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto &tensor = inputs[i];
|
||||||
if (!tensor->IsGraphOutput()) {
|
if (!tensor->IsGraphOutput()) {
|
||||||
result.emplace_back(tensor);
|
result.emplace_back(tensor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -202,6 +202,7 @@ constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt";
|
||||||
constexpr auto kStridedReadOpName = "StridedRead";
|
constexpr auto kStridedReadOpName = "StridedRead";
|
||||||
constexpr auto kStridedWriteOpName = "StridedWrite";
|
constexpr auto kStridedWriteOpName = "StridedWrite";
|
||||||
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
|
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
|
||||||
|
constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
|
||||||
constexpr auto kFusedAdamName = "FusedAdam";
|
constexpr auto kFusedAdamName = "FusedAdam";
|
||||||
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
|
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
|
||||||
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
|
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
|
||||||
|
@ -320,6 +321,7 @@ constexpr auto kAttrInputNames = "input_names";
|
||||||
constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
|
constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
|
||||||
constexpr auto kIsBackendCast = "is_backed_cast";
|
constexpr auto kIsBackendCast = "is_backed_cast";
|
||||||
constexpr auto kAttrOutputNames = "output_names";
|
constexpr auto kAttrOutputNames = "output_names";
|
||||||
|
constexpr auto kAttrAsync = "async";
|
||||||
constexpr auto kAttrVisited = "visited";
|
constexpr auto kAttrVisited = "visited";
|
||||||
constexpr auto kAttrShape = "shape";
|
constexpr auto kAttrShape = "shape";
|
||||||
constexpr auto kAttrMomentum = "momentum";
|
constexpr auto kAttrMomentum = "momentum";
|
||||||
|
@ -581,6 +583,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
|
||||||
kAdamApplyOneWithDecayOpName,
|
kAdamApplyOneWithDecayOpName,
|
||||||
kAdamApplyOneWithDecayAssignOpName,
|
kAdamApplyOneWithDecayAssignOpName,
|
||||||
kFusedAdamWeightDecayName,
|
kFusedAdamWeightDecayName,
|
||||||
|
kAdamWeightDecayName,
|
||||||
kFusedAdamName,
|
kFusedAdamName,
|
||||||
kFusedSparseAdamName,
|
kFusedSparseAdamName,
|
||||||
kFusedMulApplyMomentumOpName,
|
kFusedMulApplyMomentumOpName,
|
||||||
|
|
Loading…
Reference in New Issue