From d695554ae5578848ce6ace0834dc8a49e6e7f7c6 Mon Sep 17 00:00:00 2001 From: huangbingjian Date: Tue, 15 Jun 2021 15:38:38 +0800 Subject: [PATCH] modify the error message of abstract join --- mindspore/ccsrc/debug/trace.cc | 16 ++ .../jit/static_analysis/static_analysis.cc | 77 +++++++--- .../jit/static_analysis/static_analysis.h | 2 +- mindspore/core/abstract/abstract_value.cc | 145 +++++++++++++----- mindspore/core/abstract/abstract_value.h | 9 ++ mindspore/core/abstract/utils.cc | 3 +- tests/ut/python/ops/test_control_ops.py | 2 +- 7 files changed, 189 insertions(+), 65 deletions(-) diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index e9cd715818b..0c5598d2f1e 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -35,6 +35,7 @@ #include "debug/common.h" #include "pipeline/jit/static_analysis/evaluator.h" #include "utils/log_adapter.h" +#include "abstract/abstract_value.h" namespace mindspore { // namespace to support debug trace information @@ -605,5 +606,20 @@ struct TraceProviderRegister { } ~TraceProviderRegister() = default; } trace_provider_regsiter; + +// Register trace cnode provider to AbstractBase. 
+struct TraceNodeProviderRegister { + TraceNodeProviderRegister() { + abstract::AbstractBase::set_trace_node_provider([](AnfNodePtr *node) { + auto stack = GetCNodeDebugStack(); + if (!stack.empty()) { + auto conf = stack.back(); + MS_EXCEPTION_IF_NULL(conf); + *node = conf->node(); + } + }); + } + ~TraceNodeProviderRegister() = default; +} trace_node_provider_regsiter; } // namespace trace } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index daaee7d83ff..f207e3b9646 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -49,9 +49,7 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) { AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) { if (dyn_cast(arg1) && dyn_cast(arg2)) { - auto abstract = arg1->Join(arg2); - MS_EXCEPTION_IF_NULL(abstract); - return abstract; + return arg1->Join(arg2); } return nullptr; } @@ -644,7 +642,39 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vectorToString() + << ", and that of the previous branch is " << last_spec->ToString() << ". 
Please check the node " + << node->DebugString(); + if (node->isa<CNode>()) { + auto cnode = node->cast<CNodePtr>()->input(0); + if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { + // {prim::kPrimSwitch, cond, true_branch, false_branch} + constexpr int true_index = 2; + constexpr int false_index = 3; + auto inputs = cnode->cast<CNodePtr>()->inputs(); + buffer << ", true branch: " << inputs.at(true_index)->ToString() + << ", false branch: " << inputs.at(false_index)->ToString(); + } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { + // {prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, branch1, branch2, ...}} + constexpr int branch_index = 2; + auto tuple_node = cnode->cast<CNodePtr>()->input(branch_index); + if (IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) { + auto tuple_inputs = tuple_node->cast<CNodePtr>()->inputs(); + for (size_t i = 1; i < tuple_inputs.size(); i++) { + buffer << ", branch" << i << ": " << tuple_inputs.at(i)->ToString(); + } + } + } + } + buffer << ". trace: " << trace::DumpSourceLines(node); + return buffer.str(); +} + +EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) { if (out_specs.size() == 0) { MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; } @@ -654,8 +684,28 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_ // If only one result derived, then broaden it to avoid wrong constant propagation. 
return std::make_shared(out_specs[0]->Broaden(), std::make_shared()); } - auto joined_spec = AbstractJoin(out_specs); - MS_EXCEPTION_IF_NULL(joined_spec); + MS_EXCEPTION_IF_NULL(node); + + AbstractBasePtr last_spec = out_specs[0]; + AbstractBasePtr joined_spec = out_specs[0]; + for (const auto &spec : out_specs) { + MS_EXCEPTION_IF_NULL(spec); + try { + joined_spec = joined_spec->Join(spec); + } catch (const py::type_error &ex) { + auto error_info = ExtractLoggingInfo(ex.what()); + MS_EXCEPTION(TypeError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info); + } catch (const py::value_error &ex) { + auto error_info = ExtractLoggingInfo(ex.what()); + MS_EXCEPTION(ValueError) << JoinBranchesFailedInfo(spec, last_spec, node, error_info); + } catch (const std::exception &ex) { + auto error_info = ExtractLoggingInfo(ex.what()); + MS_LOG(EXCEPTION) << JoinBranchesFailedInfo(spec, last_spec, node, error_info); + } + MS_EXCEPTION_IF_NULL(joined_spec); + last_spec = spec; + } + MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); return std::make_shared(joined_spec, std::make_shared()); } @@ -664,8 +714,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorabstract(); MS_EXCEPTION_IF_NULL(eval_abstract); - if (last_abstract != nullptr && eval_abstract->Join(last_abstract) == nullptr) { - auto node = out_conf->node(); - MS_LOG(EXCEPTION) << "Abstracts cannot be joined! 
Please check the data type of node : " << node->DebugString() - << ".\nThe current evaluator is " << eval->ToString() << " with abstract " - << eval_abstract->ToString() << ", and the previous evaluator is " << last_eval->ToString() - << " with abstract " << last_abstract->ToString() << trace::DumpSourceLines(node); - } else { - last_abstract = eval_abstract; - last_eval = eval; - } - out_specs.push_back(eval_abstract); eval_trace_.pop_back(); if (eval_trace_.empty()) { @@ -729,7 +766,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectornode()); } EvalResultPtr AnfNodeConfig::ObtainEvalResult() { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 67db99ff5e9..9d00eebb2ad 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -316,7 +316,7 @@ class AnalysisEngine : public std::enable_shared_from_this { EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, bool *continue_flag); - EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); + EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node); const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index dc7572dfcea..8d2bb730ca3 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -18,14 +18,73 @@ #include "abstract/abstract_value.h" +#include #include #include "utils/symbolic.h" #include "abstract/utils.h" #include "utils/ms_context.h" +#include "utils/trace_base.h" namespace mindspore { namespace abstract { +AnfNodePtr GetTraceNode(const AbstractBasePtr &abs) { + 
AnfNodePtr node = nullptr; + if (abs->trace_node_provider_ != nullptr) { + abs->trace_node_provider_(&node); + } + return node; +} + +inline void AbstractTypeJoinLogging(const AbstractBasePtr &abstract1, const AbstractBasePtr &abstract2) { + std::ostringstream oss; + oss << "Type Join Failed: abstract type " << abstract1->type_name() << " cannot join with " + << abstract2->type_name() << ". For more details, please refer to the FAQ at https://www.mindspore.cn. " + << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString(); + auto node = GetTraceNode(abstract1); + if (node != nullptr) { + oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node); + } + MS_EXCEPTION(TypeError) << oss.str(); +} + +inline void TypeJoinLogging(const TypePtr &type1, const TypePtr &type2, const AbstractBasePtr &abstract1, + const AbstractBasePtr &abstract2) { + std::ostringstream oss; + oss << "Type Join Failed: dtype1 = " << type1->ToString() << ", dtype2 = " << type2->ToString() + << ". For more details, please refer to the FAQ at https://www.mindspore.cn. " + << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString(); + auto node = GetTraceNode(abstract1); + if (node != nullptr) { + oss << ". Please check the node " << node->DebugString() << ". trace: " << trace::DumpSourceLines(node); + } + MS_EXCEPTION(TypeError) << oss.str(); +} + +inline void ShapeJoinLogging(const BaseShapePtr &shape1, const BaseShapePtr &shape2, const AbstractBasePtr &abstract1, + const AbstractBasePtr &abstract2) { + std::ostringstream oss; + oss << "Shape Join Failed: shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString() + << ". For more details, please refer to the FAQ at https://www.mindspore.cn. " + << "this: " << abstract1->ToString() << ", other: " << abstract2->ToString(); + auto node = GetTraceNode(abstract1); + if (node != nullptr) { + oss << ". Please check the node " << node->DebugString() << ". 
trace: " << trace::DumpSourceLines(node); + } + MS_EXCEPTION(ValueError) << oss.str(); +} + +std::string ExtractLoggingInfo(const std::string &info) { + // Extract log information based on the keyword "Type Join Failed" or "Shape Join Failed" + std::regex e("(Type Join Failed|Shape Join Failed).*?\\."); + std::smatch result; + bool found = std::regex_search(info, result, e); + if (found) { + return result.str(); + } + return ""; +} + bool AbstractBase::operator==(const AbstractBase &other) const { if (tid() != other.tid()) { return false; @@ -112,15 +171,14 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { if (*this == *other) { return shared_from_base(); } - auto value_self = GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_self); - TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); - auto value_other = other->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(value_other); + auto type_self = GetTypeTrack(); + auto type_other = other->GetTypeTrack(); + TypePtr res_type = TypeJoin(type_self, type_other); if (res_type == kAnyType) { - MS_LOG(ERROR) << "Type join failed, type1 = " << value_self->ToString() << ", type2 = " << value_other->ToString(); - return nullptr; + TypeJoinLogging(type_self, type_other, shared_from_base(), other); } + auto value_self = GetValueTrack(); + auto value_other = other->GetValueTrack(); ValuePtr res_value = ValueJoin(value_self, value_other); if (res_value == value_self) { return shared_from_base(); @@ -187,7 +245,7 @@ AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); auto other_func = dyn_cast(other); if (other_func == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + AbstractTypeJoinLogging(shared_from_base(), other); } return Join(other_func); } @@ -281,15 +339,10 @@ AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); auto 
other_sequeue = dyn_cast(other); if (other_sequeue == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + AbstractTypeJoinLogging(shared_from_base(), other); } auto joined_list = AbstractJoin(elements_, other_sequeue->elements_); bool changes = false; - if (elements_.size() > joined_list.size()) { - MS_EXCEPTION(IndexError) << "Abstract " << ToString() << "'s size is " << elements_.size() - << " but element joined abstract " << other->ToString() << "with the the size " - << joined_list.size(); - } for (std::size_t i = 0; i < elements_.size(); i++) { if (elements_[i] != joined_list[i]) { changes = true; @@ -469,29 +522,41 @@ BaseShapePtr AbstractTensor::BuildShape() const { AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); auto type = other->BuildType(); - MS_EXCEPTION_IF_NULL(element_); MS_EXCEPTION_IF_NULL(type); + MS_EXCEPTION_IF_NULL(element_); + + // AbstractTensor join with AbstractUndetermined if (type->type_id() == kObjectTypeUndeterminedType) { auto other_undetermined_tensor = dyn_cast(other); MS_EXCEPTION_IF_NULL(other_undetermined_tensor); - auto element = element_->Join(other_undetermined_tensor->element()); - if (element == nullptr) { - return nullptr; + // check shape + auto res_shape = ShapeJoin(shape(), other_undetermined_tensor->shape()); + if (res_shape == nullptr) { + ShapeJoinLogging(shape(), other_undetermined_tensor->shape(), shared_from_base(), other); } - return std::make_shared(element, ShapeJoin(shape(), other_undetermined_tensor->shape())); + // check element + auto element = element_->Join(other_undetermined_tensor->element()); + MS_EXCEPTION_IF_NULL(element); + return std::make_shared(element, res_shape); } + + // AbstractTensor join with AbstractTensor auto other_tensor = dyn_cast(other); if (other_tensor == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << 
other->ToString(); + AbstractTypeJoinLogging(shared_from_base(), other); } if (*this == *other) { return shared_from_base(); } - auto element = element_->Join(other_tensor->element_); - if (element == nullptr) { - return nullptr; + // check shape + auto res_shape = ShapeJoin(this->shape(), other_tensor->shape()); + if (res_shape == nullptr) { + ShapeJoinLogging(shape(), other_tensor->shape(), shared_from_base(), other); } - return std::make_shared(element, ShapeJoin(this->shape(), other_tensor->shape())); + // check element + auto element = element_->Join(other_tensor->element_); + MS_EXCEPTION_IF_NULL(element); + return std::make_shared(element, res_shape); } bool AbstractTensor::equal_to(const AbstractTensor &other) const { @@ -826,7 +891,7 @@ AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); auto other_jtagged = dyn_cast(other); if (other_jtagged == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + AbstractTypeJoinLogging(shared_from_base(), other); } auto joined_elem = element_->Join(other_jtagged->element_); return std::make_shared(joined_elem); @@ -1328,15 +1393,14 @@ std::string AbstractSparseTensor::ToString() const { AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); - if (other->isa()) { - return shared_from_base(); + if (!other->isa()) { + auto this_type = GetTypeTrack(); + auto other_type = other->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(this_type); + MS_EXCEPTION_IF_NULL(other); + TypeJoinLogging(this_type, other_type, shared_from_base(), other); } - auto this_type = GetTypeTrack(); - auto other_type = other->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(this_type); - MS_EXCEPTION_IF_NULL(other); - MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << this_type->ToString() - << ", type2 = " << other_type->ToString(); + return shared_from_base(); } bool AbstractUMonad::operator==(const 
AbstractUMonad &) const { return true; } @@ -1350,15 +1414,14 @@ bool AbstractUMonad::operator==(const AbstractBase &other) const { AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); - if (other->isa()) { - return shared_from_base(); + if (!other->isa()) { + auto this_type = GetTypeTrack(); + auto other_type = other->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(this_type); + MS_EXCEPTION_IF_NULL(other); + TypeJoinLogging(this_type, other_type, shared_from_base(), other); } - auto this_type = GetTypeTrack(); - auto other_type = other->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(this_type); - MS_EXCEPTION_IF_NULL(other); - MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << this_type->ToString() - << ", type2 = " << other_type->ToString(); + return shared_from_base(); } bool AbstractIOMonad::operator==(const AbstractIOMonad &) const { return true; } diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index 34ac231e09e..f59afcc3c0d 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -45,6 +45,8 @@ using AbstractBasePtrList = std::vector; // to express the type, shape, and value of the real value. 
class AbstractBase : public Base { public: + using TraceNodeProvider = std::function; + explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, const BaseShapePtr &shape = kNoShape) : value_(value), type_(type), shape_(shape) {} @@ -72,6 +74,11 @@ class AbstractBase : public Base { virtual BaseShapePtr BuildShape() const { return kNoShape; } virtual AbstractBasePtr Clone() const = 0; + static void set_trace_node_provider(TraceNodeProvider trace_node_provider) { + trace_node_provider_ = trace_node_provider; + } + + inline static TraceNodeProvider trace_node_provider_ = nullptr; // mask for Broaden config inline static const uint8_t kBroadenTensorOnly = 1; inline static const uint8_t kBroadenParameterOnly = 2; @@ -760,6 +767,8 @@ class AbstractIOMonad : public AbstractMonad { }; using AbstractIOMonadPtr = std::shared_ptr; +AnfNodePtr GetTraceNode(const AbstractBasePtr &abs); +std::string ExtractLoggingInfo(const std::string &info); } // namespace abstract } // namespace mindspore #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 5ff1d5b85cc..5cb9f381a83 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -115,8 +115,7 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) { return shape2; } - MS_EXCEPTION(ValueError) << "Unsupported shape join. 
shape1 = " << shape1->ToString() - << ", shape2 = " << shape2->ToString(); + return nullptr; } ShapeVector dims; bool has_dynamic_shape = false; diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index d786ce5eace..4144547f0e6 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -889,7 +889,7 @@ def test_switch_layer_dtype_join_failed(): inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) i = Tensor(0, mstype.int32) - with pytest.raises(Exception) as err: + with pytest.raises(TypeError) as err: net(i, inp)