forked from mindspore-Ecosystem/mindspore
!18950 adapt for multi-frontend cj
Merge pull request !18950 from qingshanxiaozi/adapt_for_cj
commit 2ffb014dc9
@@ -419,8 +419,9 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern(
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
   VarPtr x3 = std::make_shared<Var>();
+  VarPtr x4 = std::make_shared<Var>();
   VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
+  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x4});
   VectorRef depend(
     {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
   return VectorRef({prim::kPrimMul, depend, x3});
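A note on why the fresh x4 matters: in pattern engines of this kind, each Var is a wildcard, and reusing the same VarPtr in two positions forces both positions to bind the same node. Giving the forward node x4 instead of x2 means the grad node and the forward node no longer have to share their second input, which is what lets the pattern also fire on graphs from the cangjie frontend. The stand-alone sketch below illustrates that binding rule only; all types and names in it are hypothetical, not MindSpore's actual matcher.

#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical mini-matcher: an Expr is an op with children, or a wildcard
// whose name starts with '?'.
struct Expr {
  std::string op;
  std::vector<std::shared_ptr<Expr>> args;
};
using ExprPtr = std::shared_ptr<Expr>;

bool IsVar(const ExprPtr &e) { return !e->op.empty() && e->op[0] == '?'; }

// A wildcard binds freely on first sight but must rebind to the identical
// node afterwards -- reusing one Var in two positions ties those inputs
// together, while a fresh Var leaves them independent.
bool Match(const ExprPtr &pattern, const ExprPtr &expr,
           std::map<std::string, ExprPtr> *binding) {
  if (IsVar(pattern)) {
    auto inserted = binding->emplace(pattern->op, expr);
    return inserted.second || inserted.first->second == expr;
  }
  if (pattern->op != expr->op || pattern->args.size() != expr->args.size()) {
    return false;
  }
  for (size_t i = 0; i < pattern->args.size(); ++i) {
    if (!Match(pattern->args[i], expr->args[i], binding)) return false;
  }
  return true;
}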
@@ -518,19 +519,21 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
 
-  if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
-      AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
-    return nullptr;
-  }
-
   CNodePtr softmax_node;
   auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
   softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
-  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
-  return reduce_node;
+  // Both the forward loss function and the backward loss function from cangjie will match this pattern;
+  // the true branch is for the backward loss function, and the false branch is for the forward one.
+  if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
+      AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
+    return softmax_node_outputs[1];
+  } else {
+    auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
+    return reduce_node;
+  }
 }
 
 const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
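For readers unfamiliar with the fused op this hunk relies on: SoftmaxCrossEntropyWithLogits yields two outputs, the per-sample loss (index 0, fed to ReduceMean on the forward path) and the gradient with respect to the logits (index 1, returned directly on the backward path). Below is a plain-C++ sketch of the underlying math for a single sample; it is purely illustrative and not the MindSpore kernel.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> logits = {2.0, 1.0, 0.1};
  int label = 0;

  // Numerically stable softmax: subtract the max logit before exponentiating.
  double max_logit = logits[0];
  for (double v : logits) max_logit = v > max_logit ? v : max_logit;

  std::vector<double> p(logits.size());
  double denom = 0.0;
  for (size_t i = 0; i < logits.size(); ++i) denom += std::exp(logits[i] - max_logit);
  for (size_t i = 0; i < logits.size(); ++i) p[i] = std::exp(logits[i] - max_logit) / denom;

  double loss = -std::log(p[label]);  // "output 0": per-sample cross-entropy loss
  std::vector<double> grad = p;       // "output 1": softmax(logits) - onehot(label)
  grad[label] -= 1.0;

  std::printf("loss = %.4f\n", loss);
  for (size_t i = 0; i < grad.size(); ++i) std::printf("grad[%zu] = %.4f\n", i, grad[i]);
  return 0;
}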
@@ -294,8 +294,10 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
     unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
     unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
   } else {
-    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    // Add the PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR pass first so that the backward loss
+    // function from the python frontend does not match the pattern defined in
+    // PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR.
+    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
   }
   unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
   unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
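The comment in this hunk carries the key constraint: the Pynative forward pattern also matches the backward loss function, so the more specific grad pass must be registered first. The sketch below shows, with entirely hypothetical types, why registration order decides which rewrite claims a node in a first-match pass pipeline.

#include <functional>
#include <string>
#include <vector>

// Hypothetical node/pass types: each pass tries every node in turn, and a
// node already claimed by an earlier pass is skipped by later ones.
struct Node {
  std::string kind;
  bool is_grad;
  bool rewritten = false;
};
using Pass = std::function<bool(Node *)>;

void RunPasses(const std::vector<Pass> &passes, std::vector<Node> *nodes) {
  for (const auto &pass : passes) {
    for (auto &node : *nodes) {
      if (!node.rewritten && pass(&node)) node.rewritten = true;
    }
  }
}

int main() {
  std::vector<Node> nodes = {{"SparseSoftmaxCE", true}, {"SparseSoftmaxCE", false}};
  // Register the grad-specific pass first, as the hunk above does, so grad
  // nodes are claimed before the broader forward pattern can see them.
  Pass grad_pass = [](Node *n) { return n->kind == "SparseSoftmaxCE" && n->is_grad; };
  Pass fwd_pass = [](Node *n) { return n->kind == "SparseSoftmaxCE"; };
  RunPasses({grad_pass, fwd_pass}, &nodes);
  return 0;
}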
@@ -300,9 +300,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
 #if !defined(_WIN32) && !defined(_WIN64)
   std::shared_ptr<Debugger> debugger_;
 #endif
-#if (ENABLE_CPU && !_WIN32)
-  bool initialized_ps_cache_{false};
-#endif
 };
 
 using SessionPtr = std::shared_ptr<session::SessionBasic>;
@@ -26,4 +26,9 @@ if(ENABLE_GE OR ENABLE_D)
     list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES})
 endif()
 
+if("${ENABLE_HIDDEN}" STREQUAL "OFF")
+    string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+    string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+endif()
+
 add_library(_mindspore_pipeline_jit_obj OBJECT ${_PIPELINE_SRC_FILES})
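This CMake change strips -Werror and flips symbol visibility from hidden to default for this object library when ENABLE_HIDDEN is OFF. As a reminder of what the flag does (standard GCC/Clang behavior, not anything MindSpore-specific): under -fvisibility=hidden a shared library exports only explicitly annotated symbols, while under -fvisibility=default every extern symbol is exported.

// Built with -fvisibility=hidden, only the annotated symbol is exported
// from the resulting shared object; with -fvisibility=default both are.
__attribute__((visibility("default"))) int exported_api() { return 1; }

int internal_helper() { return 2; }  // hidden under -fvisibility=hidden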
@@ -51,6 +51,11 @@ void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id, const std::vect
   }
 }
 
+KernelRuntimeManager &KernelRuntimeManager::Instance() {
+  static KernelRuntimeManager instance{};
+  return instance;
+}
+
 void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
   if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
     (void)runtime_creators_.emplace(device_name, runtime_creator);
@@ -32,10 +32,7 @@ using KernelRuntimeCreator = std::function<std::shared_ptr<KernelRuntime>()>;
 
 class KernelRuntimeManager {
  public:
-  static KernelRuntimeManager &Instance() {
-    static KernelRuntimeManager instance;
-    return instance;
-  }
+  static KernelRuntimeManager &Instance();
   void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
   KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
   KernelRuntime *GetCurrentKernelRuntime();
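Taken together, the last two hunks turn KernelRuntimeManager::Instance() from an inline header definition into an out-of-line definition in the .cc file. One plausible motivation, given the visibility changes above (the commit itself does not say), is that a function-local static defined in a header can end up instantiated once per shared object when symbols are hidden, whereas defining it in a single translation unit guarantees one instance. The shape of the idiom:

// manager.h -- declaration only; no function body in the header.
class Manager {
 public:
  static Manager &Instance();
 private:
  Manager() = default;
};

// manager.cc -- the local static lives in exactly one translation unit and
// is initialized thread-safely on first use (guaranteed since C++11).
Manager &Manager::Instance() {
  static Manager instance{};
  return instance;
}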
@@ -17,9 +17,11 @@
 
 namespace mindspore {
 namespace common {
+namespace {
 const int CACHED_STR_NUM = 1 << 8;
 const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
 std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
+}  // namespace
 const char *SafeCStr(const std::string &&str) {
   static std::atomic<uint32_t> index{0};
   uint32_t cur_index = index++;
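The hunk ends before the rest of SafeCStr, but the declarations it wraps in the new anonymous namespace suggest a 256-slot rotating cache: an atomic counter, a power-of-two slot count, and a mask. A hypothetical completion in that spirit follows; the real body is not shown in this diff, so treat every line after cur_index as an assumption.

#include <atomic>
#include <string>
#include <vector>

namespace {
const int CACHED_STR_NUM = 1 << 8;               // 256 slots
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;  // slot = counter & mask
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
}  // namespace

// Hypothetical sketch: park the temporary in a ring slot so the returned
// pointer stays valid until 256 subsequent calls recycle that slot.
const char *SafeCStr(const std::string &&str) {
  static std::atomic<uint32_t> index{0};
  uint32_t cur_index = index++;
  uint32_t slot = cur_index & CACHED_STR_MASK;
  STR_HOLDER[slot] = str;  // copies; a const rvalue reference cannot be moved from
  return STR_HOLDER[slot].c_str();
}

Note the caveat such a cache accepts by design: a pointer handed out more than 256 calls ago may be overwritten, so callers must consume the C string promptly.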