!14287 fix loss scale for cache embedding

From: @fangzehua
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
This commit is contained in:
mindspore-ci-bot 2021-03-31 09:51:48 +08:00 committed by Gitee
commit d346a861bc
2 changed files with 6 additions and 0 deletions

View File

@ -714,6 +714,11 @@ void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
if (!CheckHostCacheParamSize(param_cache_enable_set)) {
return;
}
for (auto &node : cnodes) {
if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
MS_LOG(EXCEPTION) << "Cache embedding haven't support loss scale yet.";
}
}
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";

View File

@ -465,6 +465,7 @@ inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox"
inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast");
inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
inline const PrimitivePtr kPrimPull = std::make_shared<Primitive>("Pull");
inline const PrimitivePtr kPrimNPUAllocFloatStatus = std::make_shared<Primitive>("NPUAllocFloatStatus");
// Structures
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");