forked from mindspore-Ecosystem/mindspore
!14287 fix loss scale for cache embedding
From: @fangzehua Reviewed-by: @zhoufeng54,@jjfeing Signed-off-by: @jjfeing
This commit is contained in:
commit
d346a861bc
|
@ -714,6 +714,11 @@ void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
|
|||
if (!CheckHostCacheParamSize(param_cache_enable_set)) {
|
||||
return;
|
||||
}
|
||||
for (auto &node : cnodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
|
||||
MS_LOG(EXCEPTION) << "Cache embedding haven't support loss scale yet.";
|
||||
}
|
||||
}
|
||||
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
|
||||
if (unique_cache_enable.empty()) {
|
||||
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
|
||||
|
|
|
@ -465,6 +465,7 @@ inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox"
|
|||
inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast");
|
||||
inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
|
||||
inline const PrimitivePtr kPrimPull = std::make_shared<Primitive>("Pull");
|
||||
inline const PrimitivePtr kPrimNPUAllocFloatStatus = std::make_shared<Primitive>("NPUAllocFloatStatus");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
|
|
Loading…
Reference in New Issue