!16494 remove monad abs when using cpp infer

From: @lianliguang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-05-24 09:33:38 +08:00 committed by Gitee
commit efcc4c7df7
10 changed files with 70 additions and 36 deletions

View File

@ -613,7 +613,9 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value =
!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
(!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
std::all_of(args.begin(), args.end(),
[](const AbstractBasePtr &abs) -> bool { return (abs->BuildValue() != nullptr); });
AbstractBasePtr abs_base = nullptr;
ValuePtr value = nullptr;
prim_->BeginRecordAddAttr();

View File

@ -25,6 +25,23 @@
namespace mindspore {
namespace abstract {
namespace {
std::string ShapeVectorToStr(const std::vector<int64_t> &shp) {
std::ostringstream buffer;
bool f_begin = true;
buffer << "(";
for (auto &x : shp) {
if (!f_begin) {
buffer << ", ";
} else {
f_begin = false;
}
buffer << x;
}
buffer << ")";
return buffer.str();
}
} // namespace
// used for print BaseShape content
std::ostream &operator<<(std::ostream &os, const BaseShape &bs) {
os << bs.ToString();
@ -48,17 +65,18 @@ bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == oth
std::string Shape::ToString() const {
std::ostringstream buffer;
bool f_begin = true;
buffer << "(";
for (auto &x : shape_) {
if (!f_begin) {
buffer << ", ";
} else {
f_begin = false;
}
buffer << x;
bool has_dyn_shape = IsDynamic();
if (has_dyn_shape) {
buffer << "{ shape : ";
}
buffer << ShapeVectorToStr(shape_);
if (has_dyn_shape) {
buffer << " | min shape: ";
buffer << ShapeVectorToStr(min_shape_);
buffer << " | max shape: ";
buffer << ShapeVectorToStr(max_shape_);
buffer << " }";
}
buffer << ")";
return buffer.str();
}

View File

@ -286,8 +286,6 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>

View File

@ -562,19 +562,5 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
}
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref, value, [universal]
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
auto type = args_spec_list[0]->BuildType();
if (type->type_id() == kObjectTypeRefKey) {
return args_spec_list[1]->Broaden();
} else {
return args_spec_list[0];
}
}
} // namespace abstract
} // namespace mindspore

View File

@ -136,7 +136,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
// Others
{prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
{prim::kPrimLoad, {InferImplLoad, nullptr, true}},
{prim::kPrimAssign, {InferImplAssign, nullptr, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, {nullptr, nullptr, true}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}},

View File

@ -26,6 +26,24 @@
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameAssign, Assign);
AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("Assign infer", (CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)),
kEqual, 2, prim_name);
auto check_types = common_valid_types;
check_types.emplace(kBool);
auto variable_type = args_spec_list[0]->BuildType();
auto value_type = args_spec_list[1]->BuildType();
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(std::map<std::string, TypePtr>{{"value", value_type}}, check_types,
prim_name);
if (variable_type->isa<RefKeyType>()) {
return args_spec_list[1]->Broaden();
}
CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name);
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -643,4 +643,19 @@ int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) {
}
return data_format;
}
int64_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) {
int64_t remove_monad_count = abs_list.size();
for (const auto &item : abs_list) {
if (item->isa<abstract::AbstractMonad>()) {
--remove_monad_count;
}
}
for (int64_t i = 0; i < remove_monad_count; ++i) {
if (abs_list[i]->isa<abstract::AbstractMonad>()) {
MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs.";
}
}
return remove_monad_count;
}
} // namespace mindspore

View File

@ -323,6 +323,7 @@ class CheckAndConvertUtils {
const std::string &arg_name);
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
static int64_t GetAndCheckFormat(const ValuePtr &value);
static int64_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
private:
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);

View File

@ -23,7 +23,7 @@ from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
class Assign(PrimitiveWithCheck):
class Assign(Primitive):
"""
Assigns `Parameter` with a value.
@ -73,12 +73,6 @@ class Assign(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
def check_dtype(self, variable, value):
types = mstype.number_type + (mstype.bool_,)
if variable != mstype.type_refkey:
validator.check_tensor_dtype_valid("variable", variable, types, self.name)
validator.check_scalar_or_tensor_types_same({"value": value}, types, self.name)
class InplaceAssign(PrimitiveWithInfer):
"""

View File

@ -29,6 +29,7 @@ class TestDShape : public UT::Common {
Shape shp_2;
Shape shp_3;
Shape shp_4;
Shape shp_5;
NoShape shp_noshp_1;
NoShape shp_noshp_2;
@ -42,6 +43,7 @@ class TestDShape : public UT::Common {
shp_2({1, 1}),
shp_3({1, 2}),
shp_4({1}),
shp_5({-1, 2}, {1, 2}, {3, 3}),
shp_noshp_1(),
shp_noshp_2(),
@ -67,6 +69,7 @@ TEST_F(TestDShape, ToString) {
ASSERT_EQ(shp_3.ToString(), "(1, 2)");
ASSERT_EQ(shp_noshp_1.ToString(), "NoShape");
ASSERT_EQ(shp_tuple_2.ToString(), "TupleShape(NoShape, (1, 1, 1))");
ASSERT_EQ(shp_5.ToString(), "{ shape : (-1, 2) | min shape: (1, 2) | max shape: (3, 3) }");
}
TEST_F(TestDShape, Clone) {