forked from mindspore-Ecosystem/mindspore
!16494 remove monad abs when using cpp infer
From: @lianliguang Reviewed-by: Signed-off-by:
This commit is contained in:
commit
efcc4c7df7
|
@ -613,7 +613,9 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool need_infer_value =
|
||||
!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
|
||||
(!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
|
||||
std::all_of(args.begin(), args.end(),
|
||||
[](const AbstractBasePtr &abs) -> bool { return (abs->BuildValue() != nullptr); });
|
||||
AbstractBasePtr abs_base = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
prim_->BeginRecordAddAttr();
|
||||
|
|
|
@ -25,6 +25,23 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
namespace {
|
||||
std::string ShapeVectorToStr(const std::vector<int64_t> &shp) {
|
||||
std::ostringstream buffer;
|
||||
bool f_begin = true;
|
||||
buffer << "(";
|
||||
for (auto &x : shp) {
|
||||
if (!f_begin) {
|
||||
buffer << ", ";
|
||||
} else {
|
||||
f_begin = false;
|
||||
}
|
||||
buffer << x;
|
||||
}
|
||||
buffer << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
} // namespace
|
||||
// used for print BaseShape content
|
||||
std::ostream &operator<<(std::ostream &os, const BaseShape &bs) {
|
||||
os << bs.ToString();
|
||||
|
@ -48,17 +65,18 @@ bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == oth
|
|||
|
||||
std::string Shape::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
bool f_begin = true;
|
||||
buffer << "(";
|
||||
for (auto &x : shape_) {
|
||||
if (!f_begin) {
|
||||
buffer << ", ";
|
||||
} else {
|
||||
f_begin = false;
|
||||
}
|
||||
buffer << x;
|
||||
bool has_dyn_shape = IsDynamic();
|
||||
if (has_dyn_shape) {
|
||||
buffer << "{ shape : ";
|
||||
}
|
||||
buffer << ShapeVectorToStr(shape_);
|
||||
if (has_dyn_shape) {
|
||||
buffer << " | min shape: ";
|
||||
buffer << ShapeVectorToStr(min_shape_);
|
||||
buffer << " | max shape: ";
|
||||
buffer << ShapeVectorToStr(max_shape_);
|
||||
buffer << " }";
|
||||
}
|
||||
buffer << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -286,8 +286,6 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
template <typename T>
|
||||
|
|
|
@ -562,19 +562,5 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
}
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref, value, [universal]
|
||||
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
|
||||
|
||||
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
|
||||
auto type = args_spec_list[0]->BuildType();
|
||||
if (type->type_id() == kObjectTypeRefKey) {
|
||||
return args_spec_list[1]->Broaden();
|
||||
} else {
|
||||
return args_spec_list[0];
|
||||
}
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -136,7 +136,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
// Others
|
||||
{prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
|
||||
{prim::kPrimLoad, {InferImplLoad, nullptr, true}},
|
||||
{prim::kPrimAssign, {InferImplAssign, nullptr, true}},
|
||||
// Set impl to null as it will use PartialEvaluator;
|
||||
{prim::kPrimPartial, {nullptr, nullptr, true}},
|
||||
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}},
|
||||
|
|
|
@ -26,6 +26,24 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameAssign, Assign);
|
||||
AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Assign infer", (CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)),
|
||||
kEqual, 2, prim_name);
|
||||
auto check_types = common_valid_types;
|
||||
check_types.emplace(kBool);
|
||||
auto variable_type = args_spec_list[0]->BuildType();
|
||||
auto value_type = args_spec_list[1]->BuildType();
|
||||
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(std::map<std::string, TypePtr>{{"value", value_type}}, check_types,
|
||||
prim_name);
|
||||
if (variable_type->isa<RefKeyType>()) {
|
||||
return args_spec_list[1]->Broaden();
|
||||
}
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name);
|
||||
return args_spec_list[0];
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -643,4 +643,19 @@ int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) {
|
|||
}
|
||||
return data_format;
|
||||
}
|
||||
int64_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) {
|
||||
int64_t remove_monad_count = abs_list.size();
|
||||
for (const auto &item : abs_list) {
|
||||
if (item->isa<abstract::AbstractMonad>()) {
|
||||
--remove_monad_count;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < remove_monad_count; ++i) {
|
||||
if (abs_list[i]->isa<abstract::AbstractMonad>()) {
|
||||
MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs.";
|
||||
}
|
||||
}
|
||||
return remove_monad_count;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -323,6 +323,7 @@ class CheckAndConvertUtils {
|
|||
const std::string &arg_name);
|
||||
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
|
||||
static int64_t GetAndCheckFormat(const ValuePtr &value);
|
||||
static int64_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
|
||||
|
||||
private:
|
||||
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
|
||||
|
|
|
@ -23,7 +23,7 @@ from ...common import dtype as mstype
|
|||
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
|
||||
class Assign(PrimitiveWithCheck):
|
||||
class Assign(Primitive):
|
||||
"""
|
||||
Assigns `Parameter` with a value.
|
||||
|
||||
|
@ -73,12 +73,6 @@ class Assign(PrimitiveWithCheck):
|
|||
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def check_dtype(self, variable, value):
|
||||
types = mstype.number_type + (mstype.bool_,)
|
||||
if variable != mstype.type_refkey:
|
||||
validator.check_tensor_dtype_valid("variable", variable, types, self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"value": value}, types, self.name)
|
||||
|
||||
|
||||
class InplaceAssign(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -29,6 +29,7 @@ class TestDShape : public UT::Common {
|
|||
Shape shp_2;
|
||||
Shape shp_3;
|
||||
Shape shp_4;
|
||||
Shape shp_5;
|
||||
|
||||
NoShape shp_noshp_1;
|
||||
NoShape shp_noshp_2;
|
||||
|
@ -42,6 +43,7 @@ class TestDShape : public UT::Common {
|
|||
shp_2({1, 1}),
|
||||
shp_3({1, 2}),
|
||||
shp_4({1}),
|
||||
shp_5({-1, 2}, {1, 2}, {3, 3}),
|
||||
|
||||
shp_noshp_1(),
|
||||
shp_noshp_2(),
|
||||
|
@ -67,6 +69,7 @@ TEST_F(TestDShape, ToString) {
|
|||
ASSERT_EQ(shp_3.ToString(), "(1, 2)");
|
||||
ASSERT_EQ(shp_noshp_1.ToString(), "NoShape");
|
||||
ASSERT_EQ(shp_tuple_2.ToString(), "TupleShape(NoShape, (1, 1, 1))");
|
||||
ASSERT_EQ(shp_5.ToString(), "{ shape : (-1, 2) | min shape: (1, 2) | max shape: (3, 3) }");
|
||||
}
|
||||
|
||||
TEST_F(TestDShape, Clone) {
|
||||
|
|
Loading…
Reference in New Issue