!21869 add select infer func

Merge pull request !21869 from lianliguang/add-selected_infer_func
i-robot 2021-08-23 14:12:54 +00:00 committed by Gitee
commit 24d378bab4
6 changed files with 187 additions and 33 deletions

View File

@@ -642,12 +642,16 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value =
(!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
return (abs->BuildValue() != nullptr);
});
bool need_infer_value = !eval_impl_.in_white_list_;
if (need_infer_value == false) {
need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
auto value = abs->BuildValue();
return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
!value->isa<Monad>() && !value->isa<FuncGraph>());
});
}
AbstractBasePtr abs_base = nullptr;
ValuePtr value = nullptr;
prim_->BeginRecordAddAttr();
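
In short, the rewrite splits the old single predicate: a primitive that is not on the white list now always takes the infer-value path, while a white-listed one is folded only in graph mode and only when every argument's BuildValue() yields a concrete value rather than a placeholder kind (AnyValue, None, Monad, FuncGraph). A minimal standalone sketch of that value filter, with hypothetical stand-in types in place of MindSpore's real class hierarchy:

#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical stand-ins for the value kinds the new gate rejects; the real
// MindSpore classes (AnyValue, None, Monad, FuncGraph) live elsewhere.
struct Value { virtual ~Value() = default; };
struct AnyValue : Value {};
struct NoneValue : Value {};

// Mirrors the new lambda: fold only when every argument built a concrete
// value, i.e. BuildValue() did not return null or a placeholder kind.
bool AllArgsHaveConcreteValue(const std::vector<std::shared_ptr<Value>> &values) {
  return std::all_of(values.begin(), values.end(), [](const std::shared_ptr<Value> &v) {
    return v != nullptr && dynamic_cast<const AnyValue *>(v.get()) == nullptr &&
           dynamic_cast<const NoneValue *>(v.get()) == nullptr;
  });
}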

View File

@@ -55,7 +55,7 @@ constexpr auto kMul = "Mul";
constexpr auto kRealDiv = "RealDiv";
constexpr auto kReciprocal = "Reciprocal";
constexpr auto kLog = "Log";
constexpr auto kSelect = "Select";
constexpr auto kAdd = "Add";
constexpr auto kBiasAdd = "BiasAdd";
constexpr auto kTile = "Tile";
@@ -514,7 +514,7 @@ inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>(kSelect);
inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple");

View File

@@ -15,9 +15,169 @@
*/
#include "ops/select.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameSelect, Select);
namespace {
constexpr auto kCondIndex = 0;
constexpr auto kXIndex = 1;
constexpr auto kYIndex = 2;
template <typename T>
void SelectImpl(const bool *conds, void *x, void *y, void *result, size_t size) {
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(y);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(conds);
T *x_data = reinterpret_cast<T *>(x);
T *y_data = reinterpret_cast<T *>(y);
auto result_data = reinterpret_cast<T *>(result);
MS_EXCEPTION_IF_NULL(x_data);
MS_EXCEPTION_IF_NULL(y_data);
MS_EXCEPTION_IF_NULL(result_data);
for (size_t i = 0; i < size; ++i) {
auto cond = conds[i];
result_data[i] = cond ? x_data[i] : y_data[i];
}
}
abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto cond_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kCondIndex]->BuildShape());
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kXIndex]->BuildShape());
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kYIndex]->BuildShape());
bool error_flag = false;
if (x_shape[kShape] != cond_shape[kShape] || x_shape[kShape] != y_shape[kShape]) {
error_flag = true;
}
if (CheckAndConvertUtils::HasDynamicShapeInput(input_args)) {
if (x_shape[kMaxShape] != cond_shape[kMaxShape] || x_shape[kMaxShape] != y_shape[kMaxShape]) {
error_flag = true;
}
if (x_shape[kMinShape] != cond_shape[kMinShape] || x_shape[kMinShape] != y_shape[kMinShape]) {
error_flag = true;
}
}
if (error_flag) {
MS_LOG(ERROR) << " cond shape :" << input_args[kCondIndex]->BuildShape()->ToString();
MS_LOG(ERROR) << " x shape :" << input_args[kXIndex]->BuildShape()->ToString();
MS_LOG(ERROR) << " y shape :" << input_args[kYIndex]->BuildShape()->ToString();
MS_EXCEPTION(ValueError) << "The x_shape is not the same as y_shape and cond_shape";
}
return input_args[1]->BuildShape();
}
TypePtr SelectInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
auto x_type = input_args[kXIndex]->BuildType();
auto y_type = input_args[kYIndex]->BuildType();
auto cond_type = input_args[kCondIndex]->BuildType();
(void)CheckAndConvertUtils::CheckSubClass("x_type", x_type, {kTensorType}, prim_name);
(void)CheckAndConvertUtils::CheckSubClass("y_type", y_type, {kTensorType}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("cond", cond_type, {kBool}, prim_name);
if (*x_type != *y_type) {
MS_EXCEPTION(TypeError) << prim_name << ": the x_type " << x_type->ToString() << " must be the same as y_type "
<< y_type->ToString();
}
return x_type;
}
AbstractBasePtr SelectInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 3, "ops [select]");
auto type = SelectInferType(primitive, input_args);
auto shape = SelectInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
ValuePtr SelectInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto result_type = SelectInferType(prim, input_args);
auto result_shape = SelectInferShape(prim, input_args)->cast<abstract::ShapePtr>();
auto cond_value = input_args[kCondIndex]->BuildValue();
auto x = input_args[kXIndex]->BuildValue();
auto y = input_args[kYIndex]->BuildValue();
if (x == nullptr || y == nullptr || cond_value == nullptr || result_shape->IsDynamic()) {
return nullptr;
}
auto x_tensor = x->cast<tensor::TensorPtr>();
auto y_tensor = y->cast<tensor::TensorPtr>();
auto cond_tensor = cond_value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x_tensor);
MS_EXCEPTION_IF_NULL(y_tensor);
MS_EXCEPTION_IF_NULL(cond_tensor);
auto conds = cond_tensor->data_c();
MS_EXCEPTION_IF_NULL(conds);
bool *cond_data = reinterpret_cast<bool *>(cond_tensor->data_c());
auto data_size = cond_tensor->DataSize();
auto type_id = x_tensor->data_type();
auto result_tensor = std::make_shared<tensor::Tensor>(type_id, result_shape->shape());
switch (type_id) {
case kNumberTypeBool: {
SelectImpl<bool>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeInt: {
SelectImpl<int>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeInt8: {
SelectImpl<int8_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeInt16: {
SelectImpl<int16_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeInt32: {
SelectImpl<int32_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeInt64: {
SelectImpl<int64_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeUInt: {
SelectImpl<uint32_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeUInt8: {
SelectImpl<uint8_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeUInt16: {
SelectImpl<uint16_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeUInt32: {
SelectImpl<uint32_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeUInt64: {
SelectImpl<uint64_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeFloat: {
SelectImpl<float>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeFloat16: {
SelectImpl<float16>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeFloat32: {
SelectImpl<float>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
case kNumberTypeFloat64: {
SelectImpl<double>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
break;
}
default: {
MS_EXCEPTION(TypeError) << "Select not supported type " << result_type->ToString();
}
}
return result_tensor;
}
} // namespace
REGISTER_PRIMITIVE_EVAL_IMPL(Select, prim::kPrimSelect, SelectInfer, SelectInferValue, true);
} // namespace ops
} // namespace mindspore
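
SelectInferValue is the constant-folding path: once cond, x, and y are all known tensors and the output shape is static, the switch dispatches on the element type id to the matching SelectImpl instantiation, and the trailing true passed to REGISTER_PRIMITIVE_EVAL_IMPL is the white-list flag that the first hunk reads as eval_impl_.in_white_list_. A standalone, runnable sketch of the element-wise merge (hypothetical names, not the file's code):

#include <cstddef>
#include <iostream>

// Standalone model of SelectImpl's element-wise merge:
// result[i] = conds[i] ? x[i] : y[i].
template <typename T>
void SelectLike(const bool *conds, const T *x, const T *y, T *result, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    result[i] = conds[i] ? x[i] : y[i];
  }
}

int main() {
  const bool conds[] = {true, false, true};
  const float x[] = {1.0f, 2.0f, 3.0f};
  const float y[] = {10.0f, 20.0f, 30.0f};
  float out[3] = {0.0f, 0.0f, 0.0f};
  SelectLike(conds, x, y, out, 3);
  for (float v : out) {
    std::cout << v << ' ';  // prints: 1 20 3
  }
  std::cout << '\n';
  return 0;
}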

View File

@@ -721,4 +721,15 @@ size_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs
}
return remove_monad_count;
}
bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_list) {
for (const auto &item : abs_list) {
MS_EXCEPTION_IF_NULL(item);
auto shape = item->BuildShape();
if (shape->IsDynamic()) {
return true;
}
}
return false;
}
} // namespace mindspore
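
HasDynamicShapeInput is the helper SelectInferShape calls to decide whether the min and max shapes must also agree. A minimal sketch of the same test over plain shape vectors, assuming the usual convention that an unknown dimension is encoded as a negative value such as -1:

#include <algorithm>
#include <cstdint>
#include <vector>

// A shape is treated as dynamic when any dimension is unknown; this sketch
// assumes unknown dims are encoded as negative values (e.g. -1), which is
// how abstract::Shape::IsDynamic behaves for dynamic dimensions.
bool HasDynamicShape(const std::vector<std::vector<int64_t>> &shapes) {
  return std::any_of(shapes.begin(), shapes.end(), [](const std::vector<int64_t> &shape) {
    return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim < 0; });
  });
}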

View File

@@ -313,6 +313,7 @@ class CheckAndConvertUtils {
const int64_t match_value, const std::string &prim_name);
static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
const std::string &prim_name);
static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list);
private:
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);

View File

@@ -2817,7 +2817,7 @@ class Rint(PrimitiveWithInfer):
return x_dtype
class Select(PrimitiveWithInfer):
class Select(Primitive):
r"""
Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the `condition`.
@@ -2879,28 +2879,6 @@ class Select(PrimitiveWithInfer):
"""Initialize Select."""
self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])
def infer_shape(self, cond_shape, x_shape, y_shape):
if cond_shape != x_shape or x_shape != y_shape:
raise ValueError('The x_shape and y_shape must be the same as cond_shape.')
return x_shape
def infer_dtype(self, cond_type, x_type, y_type):
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
if x_type != y_type:
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
return x_type
def infer_value(self, cond, x, y):
if cond is not None and x is not None and y is not None:
cond = cond.asnumpy()
x = x.asnumpy()
y = y.asnumpy()
out = np.where(cond, x, y)
return Tensor(out)
return None
def _compute_slicing_length(begin, end, stride, x_shape, i):
"""Computes the length of the slicing."""