forked from mindspore-Ecosystem/mindspore
!18950 adapt for multi-frontend cj
Merge pull request !18950 from qingshanxiaozi/adapt_for_cj
This commit is contained in:
commit
2ffb014dc9
|
@ -419,8 +419,9 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern(
|
||||||
VarPtr x1 = std::make_shared<Var>();
|
VarPtr x1 = std::make_shared<Var>();
|
||||||
VarPtr x2 = std::make_shared<Var>();
|
VarPtr x2 = std::make_shared<Var>();
|
||||||
VarPtr x3 = std::make_shared<Var>();
|
VarPtr x3 = std::make_shared<Var>();
|
||||||
|
VarPtr x4 = std::make_shared<Var>();
|
||||||
VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
|
VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
|
||||||
VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
|
VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x4});
|
||||||
VectorRef depend(
|
VectorRef depend(
|
||||||
{prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
|
{prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
|
||||||
return VectorRef({prim::kPrimMul, depend, x3});
|
return VectorRef({prim::kPrimMul, depend, x3});
|
||||||
|
@ -518,19 +519,21 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
|
||||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||||
|
|
||||||
if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
|
|
||||||
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
CNodePtr softmax_node;
|
CNodePtr softmax_node;
|
||||||
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
|
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
|
||||||
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
|
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
|
||||||
|
|
||||||
std::vector<AnfNodePtr> softmax_node_outputs;
|
std::vector<AnfNodePtr> softmax_node_outputs;
|
||||||
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
|
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
|
||||||
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
|
// Both of the forward loss function and the backward loss function from cangjie will match this pattern,
|
||||||
return reduce_node;
|
// the true branch is for the backward loss function, and the false branch is for the other one.
|
||||||
|
if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
|
||||||
|
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
|
||||||
|
return softmax_node_outputs[1];
|
||||||
|
} else {
|
||||||
|
auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
|
||||||
|
return reduce_node;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
|
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
|
||||||
|
|
|
@ -294,8 +294,10 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||||
} else {
|
} else {
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
// Add PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR pass first to avoid the backward loss function
|
||||||
|
// from the python frontend matching the pattern defined in PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR.
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||||
|
unify_mindir_pm->AddPass(std::make_shared<opt::PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
|
||||||
}
|
}
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR1>());
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
|
||||||
|
|
|
@ -300,9 +300,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
#if !defined(_WIN32) && !defined(_WIN64)
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
std::shared_ptr<Debugger> debugger_;
|
std::shared_ptr<Debugger> debugger_;
|
||||||
#endif
|
#endif
|
||||||
#if (ENABLE_CPU && !_WIN32)
|
|
||||||
bool initialized_ps_cache_{false};
|
|
||||||
#endif
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using SessionPtr = std::shared_ptr<session::SessionBasic>;
|
using SessionPtr = std::shared_ptr<session::SessionBasic>;
|
||||||
|
|
|
@ -26,4 +26,9 @@ if(ENABLE_GE OR ENABLE_D)
|
||||||
list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES})
|
list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
||||||
|
string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||||
|
string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
add_library(_mindspore_pipeline_jit_obj OBJECT ${_PIPELINE_SRC_FILES})
|
add_library(_mindspore_pipeline_jit_obj OBJECT ${_PIPELINE_SRC_FILES})
|
||||||
|
|
|
@ -51,6 +51,11 @@ void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id, const std::vect
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KernelRuntimeManager &KernelRuntimeManager::Instance() {
|
||||||
|
static KernelRuntimeManager instance{};
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
|
void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
|
||||||
if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
|
if (runtime_creators_.find(device_name) == runtime_creators_.end()) {
|
||||||
(void)runtime_creators_.emplace(device_name, runtime_creator);
|
(void)runtime_creators_.emplace(device_name, runtime_creator);
|
||||||
|
|
|
@ -32,10 +32,7 @@ using KernelRuntimeCreator = std::function<std::shared_ptr<KernelRuntime>()>;
|
||||||
|
|
||||||
class KernelRuntimeManager {
|
class KernelRuntimeManager {
|
||||||
public:
|
public:
|
||||||
static KernelRuntimeManager &Instance() {
|
static KernelRuntimeManager &Instance();
|
||||||
static KernelRuntimeManager instance;
|
|
||||||
return instance;
|
|
||||||
}
|
|
||||||
void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
|
void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
|
||||||
KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
|
KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
|
||||||
KernelRuntime *GetCurrentKernelRuntime();
|
KernelRuntime *GetCurrentKernelRuntime();
|
||||||
|
|
|
@ -17,9 +17,11 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
namespace {
|
||||||
const int CACHED_STR_NUM = 1 << 8;
|
const int CACHED_STR_NUM = 1 << 8;
|
||||||
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
|
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
|
||||||
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
|
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
|
||||||
|
} // namespace
|
||||||
const char *SafeCStr(const std::string &&str) {
|
const char *SafeCStr(const std::string &&str) {
|
||||||
static std::atomic<uint32_t> index{0};
|
static std::atomic<uint32_t> index{0};
|
||||||
uint32_t cur_index = index++;
|
uint32_t cur_index = index++;
|
||||||
|
|
Loading…
Reference in New Issue