forked from mindspore-Ecosystem/mindspore
Fix issues:
1.Reuse realtuplegetitem op infer. 2.Get ms_role by env. 3.Fix acc error for optimizer_split case.
This commit is contained in:
parent
b834aa3a45
commit
d30c4cac6a
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
@ -385,34 +386,6 @@ void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<Kern
|
|||
<< ". Output object types: " << *output_obj_type;
|
||||
}
|
||||
|
||||
void UpdateAbsForTupleGetItem(const CNodePtr &tuple_get_item_node) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_node);
|
||||
if (!IsPrimitiveCNode(tuple_get_item_node, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(EXCEPTION) << "Node should be TupleGetItem, but got " << tuple_get_item_node->fullname_with_scope() << ", "
|
||||
<< tuple_get_item_node->DebugString();
|
||||
}
|
||||
auto tuple_input = common::AnfAlgo::GetInputNode(tuple_get_item_node, kIndex0);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
auto input_abs = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
if (!input_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "TupleGetItem's first input abstract should be Sequence, but got " << input_abs->ToString();
|
||||
}
|
||||
|
||||
auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(seq_abs);
|
||||
AbstractBasePtrList seq_element = seq_abs->elements();
|
||||
// This method is used for TupleGetItem to RealTupleGetItem converting, the tuple elements must be scalar for now.
|
||||
for (const auto &ele : seq_element) {
|
||||
if (!ele->isa<abstract::AbstractScalar>() && !ele->isa<abstract::AbstractTensor>()) {
|
||||
MS_LOG(EXCEPTION) << "Element of the tuple should be scalar or tensor, but got " << ele->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
int64_t item_index = GetGetitemIndex(tuple_get_item_node);
|
||||
tuple_get_item_node->set_abstract(seq_element[item_index]);
|
||||
}
|
||||
|
||||
// A map of kernel object type pairs to processing functions.
|
||||
static std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;
|
||||
|
||||
|
@ -680,8 +653,10 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfold(const FuncGraphP
|
|||
kg->AddValueNodeToGraph(index_input->cast<ValueNodePtr>());
|
||||
}
|
||||
|
||||
// Need to update TupleGetItem abstract.
|
||||
UpdateAbsForTupleGetItem(node);
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimRealTupleGetItem, {input, index_input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for RealTupleGetItem op is " << abs->ToString();
|
||||
node->set_abstract(abs);
|
||||
|
||||
// The primitive of user is changed.
|
||||
*new_prim = true;
|
||||
|
|
|
@ -109,8 +109,6 @@ std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode);
|
|||
void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<KernelObjectType> *input_obj_type,
|
||||
std::vector<KernelObjectType> *output_obj_type);
|
||||
|
||||
void UpdateAbsForTupleGetItem(const CNodePtr &tuple_get_item_node);
|
||||
|
||||
// After kernel selection phase, one kernel's acquired input type may not be the same as the actual input type(the input
|
||||
// node's output type). We need this pass to transform these types to valid types.
|
||||
class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
|
||||
|
|
|
@ -40,7 +40,7 @@ def _need_reset_device_target_for_ps(target):
|
|||
For Ascend backend, the card can't be occupied by multiple processes in distributed traning,
|
||||
so we need to reset the device target for some roles.
|
||||
'''
|
||||
is_server = (_get_ps_context("ms_role") in ["MS_PSERVER", "MS_SERVER", "MS_SCHED"])
|
||||
is_server = (os.getenv('MS_ROLE') in ["MS_PSERVER", "MS_SERVER", "MS_SCHED"])
|
||||
return is_server and target == "Ascend"
|
||||
|
||||
|
||||
|
|
|
@ -30,9 +30,10 @@ def test_split_ref_without_optim():
|
|||
)
|
||||
if return_code != 0:
|
||||
os.system(f"echo '\n**************** Worker Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' -C 15 ./worker*/worker*.log")
|
||||
os.system(f"grep -E 'ERROR|Error' -C 15 ./worker*/worker*.log")
|
||||
os.system(f"cat ./worker*/worker*.log | grep acc")
|
||||
os.system(f"echo '\n**************** Scheduler Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' -C 15 ./sched/sched.log")
|
||||
os.system(f"grep -E 'ERROR|Error' -C 15 ./sched/sched.log")
|
||||
assert return_code == 0
|
||||
|
||||
|
||||
|
@ -50,7 +51,8 @@ def test_split_optim():
|
|||
)
|
||||
if return_code != 0:
|
||||
os.system(f"echo '\n**************** Worker Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' -C 15 ./worker*/worker*.log")
|
||||
os.system(f"grep -E 'ERROR|Error' -C 15 ./worker*/worker*.log")
|
||||
os.system(f"cat ./worker*/worker*.log | grep acc")
|
||||
os.system(f"echo '\n**************** Scheduler Log ****************'")
|
||||
os.system(f"grep -E 'ERROR|Error|error' -C 15 ./sched/sched.log")
|
||||
os.system(f"grep -E 'ERROR|Error' -C 15 ./sched/sched.log")
|
||||
assert return_code == 0
|
||||
|
|
|
@ -54,7 +54,7 @@ net_need_split_opt_map = {
|
|||
def run_dist():
|
||||
init()
|
||||
net = net_name_map.get(net_name)(dist=True)
|
||||
opt = get_optimizer(net, net_need_split_opt_map.get(net_name))
|
||||
opt = get_optimizer(net, dist=net_need_split_opt_map.get(net_name))
|
||||
criterion = get_loss()
|
||||
model = Model(net, criterion, opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
|
|
Loading…
Reference in New Issue