forked from mindspore-Ecosystem/mindspore
!21869 add select infer func
Merge pull request !21869 from lianliguang/add-selected_infer_func
This commit is contained in:
commit
24d378bab4
|
@ -642,12 +642,16 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool need_infer_value =
|
||||
(!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
|
||||
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
return (abs->BuildValue() != nullptr);
|
||||
});
|
||||
bool need_infer_value = !eval_impl_.in_white_list_;
|
||||
if (need_infer_value == false) {
|
||||
need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
|
||||
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
auto value = abs->BuildValue();
|
||||
return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
|
||||
!value->isa<Monad>() && !value->isa<FuncGraph>());
|
||||
});
|
||||
}
|
||||
AbstractBasePtr abs_base = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
prim_->BeginRecordAddAttr();
|
||||
|
|
|
@ -55,7 +55,7 @@ constexpr auto kMul = "Mul";
|
|||
constexpr auto kRealDiv = "RealDiv";
|
||||
constexpr auto kReciprocal = "Reciprocal";
|
||||
constexpr auto kLog = "Log";
|
||||
|
||||
constexpr auto kSelect = "Select";
|
||||
constexpr auto kAdd = "Add";
|
||||
constexpr auto kBiasAdd = "BiasAdd";
|
||||
constexpr auto kTile = "Tile";
|
||||
|
@ -514,7 +514,7 @@ inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch
|
|||
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
|
||||
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
|
||||
inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
|
||||
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
|
||||
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>(kSelect);
|
||||
inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
|
||||
|
||||
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple");
|
||||
|
|
|
@ -15,9 +15,169 @@
|
|||
*/
|
||||
|
||||
#include "ops/select.h"
|
||||
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameSelect, Select);
|
||||
namespace {
|
||||
constexpr auto kCondIndex = 0;
|
||||
constexpr auto kXIndex = 1;
|
||||
constexpr auto kYIndex = 2;
|
||||
template <typename T>
|
||||
void SelectImpl(const bool *conds, void *x, void *y, void *result, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(y);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(conds);
|
||||
T *x_data = reinterpret_cast<T *>(x);
|
||||
T *y_data = reinterpret_cast<T *>(y);
|
||||
auto result_data = reinterpret_cast<T *>(result);
|
||||
MS_EXCEPTION_IF_NULL(x_data);
|
||||
MS_EXCEPTION_IF_NULL(y_data);
|
||||
MS_EXCEPTION_IF_NULL(result_data);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto cond = conds[i];
|
||||
result_data[i] = cond ? x_data[i] : y_data[i];
|
||||
}
|
||||
}
|
||||
abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto cond_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kCondIndex]->BuildShape());
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kXIndex]->BuildShape());
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kYIndex]->BuildShape());
|
||||
bool error_flag = false;
|
||||
if (x_shape[kShape] != cond_shape[kShape] || x_shape[kShape] != y_shape[kShape]) {
|
||||
error_flag = true;
|
||||
}
|
||||
if (CheckAndConvertUtils::HasDynamicShapeInput(input_args)) {
|
||||
if (x_shape[kMaxShape] != cond_shape[kMaxShape] || x_shape[kMaxShape] != y_shape[kMaxShape]) {
|
||||
error_flag = true;
|
||||
}
|
||||
if (x_shape[kMinShape] != cond_shape[kMinShape] || x_shape[kMinShape] != y_shape[kMinShape]) {
|
||||
error_flag = true;
|
||||
}
|
||||
}
|
||||
if (error_flag) {
|
||||
MS_LOG(ERROR) << " cond shape :" << input_args[kCondIndex]->BuildShape()->ToString();
|
||||
MS_LOG(ERROR) << " x shape :" << input_args[kXIndex]->BuildShape()->ToString();
|
||||
MS_LOG(ERROR) << " y shape :" << input_args[kYIndex]->BuildShape()->ToString();
|
||||
MS_EXCEPTION(ValueError) << "The x_shape is not same as y_shape and cond_shape";
|
||||
}
|
||||
return input_args[1]->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr SelectInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
auto x_type = input_args[kXIndex]->BuildType();
|
||||
auto y_type = input_args[kYIndex]->BuildType();
|
||||
auto cond_type = input_args[kCondIndex]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", x_type, {kTensorType}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("y_type", y_type, {kTensorType}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("cond", cond_type, {kBool}, prim_name);
|
||||
if (*x_type != *y_type) {
|
||||
MS_EXCEPTION(TypeError) << prim_name << "the x_type " << x_type->ToString() << " must be the same as y_type "
|
||||
<< y_type->ToString();
|
||||
}
|
||||
return x_type;
|
||||
}
|
||||
AbstractBasePtr SelectInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 3, "ops [select]");
|
||||
auto type = SelectInferType(primitive, input_args);
|
||||
auto shape = SelectInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
|
||||
ValuePtr SelectInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto result_type = SelectInferType(prim, input_args);
|
||||
auto result_shape = SelectInferShape(prim, input_args)->cast<abstract::ShapePtr>();
|
||||
auto cond_value = input_args[kCondIndex]->BuildValue();
|
||||
auto x = input_args[kXIndex]->BuildValue();
|
||||
auto y = input_args[kYIndex]->BuildValue();
|
||||
if (x == nullptr || y == nullptr || cond_value == nullptr || result_shape->IsDynamic()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto x_tensor = x->cast<tensor::TensorPtr>();
|
||||
auto y_tensor = y->cast<tensor::TensorPtr>();
|
||||
auto cond_tensor = cond_value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(x_tensor);
|
||||
MS_EXCEPTION_IF_NULL(y_tensor);
|
||||
MS_EXCEPTION_IF_NULL(cond_tensor);
|
||||
auto conds = cond_tensor->data_c();
|
||||
MS_EXCEPTION_IF_NULL(conds);
|
||||
bool *cond_data = reinterpret_cast<bool *>(cond_tensor->data_c());
|
||||
auto data_size = cond_tensor->DataSize();
|
||||
auto type_id = x_tensor->data_type();
|
||||
auto result_tensor = std::make_shared<tensor::Tensor>(type_id, result_shape->shape());
|
||||
switch (type_id) {
|
||||
case kNumberTypeBool: {
|
||||
SelectImpl<bool>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt: {
|
||||
SelectImpl<int>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt8: {
|
||||
SelectImpl<int8_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt16: {
|
||||
SelectImpl<int16_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt32: {
|
||||
SelectImpl<int32_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt64: {
|
||||
SelectImpl<int64_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeUInt: {
|
||||
SelectImpl<uint32_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeUInt8: {
|
||||
SelectImpl<uint8_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeUInt16: {
|
||||
SelectImpl<uint16_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeUInt32: {
|
||||
SelectImpl<uint32_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeUInt64: {
|
||||
SelectImpl<uint64_t>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat: {
|
||||
SelectImpl<float>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat16: {
|
||||
SelectImpl<float16>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat32: {
|
||||
SelectImpl<float>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat64: {
|
||||
SelectImpl<double>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_EXCEPTION(TypeError) << "Select not supported type " << result_type->ToString();
|
||||
}
|
||||
}
|
||||
return result_tensor;
|
||||
}
|
||||
} // namespace
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Select, prim::kPrimSelect, SelectInfer, SelectInferValue, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -721,4 +721,15 @@ size_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs
|
|||
}
|
||||
return remove_monad_count;
|
||||
}
|
||||
|
||||
bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_list) {
|
||||
for (const auto &item : abs_list) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
auto shape = item->BuildShape();
|
||||
if (shape->IsDynamic()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -313,6 +313,7 @@ class CheckAndConvertUtils {
|
|||
const int64_t match_value, const std::string &prim_name);
|
||||
static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
|
||||
const std::string &prim_name);
|
||||
static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list);
|
||||
|
||||
private:
|
||||
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
|
||||
|
|
|
@ -2817,7 +2817,7 @@ class Rint(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class Select(PrimitiveWithInfer):
|
||||
class Select(Primitive):
|
||||
r"""
|
||||
|
||||
Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the `condition`.
|
||||
|
@ -2879,28 +2879,6 @@ class Select(PrimitiveWithInfer):
|
|||
"""Initialize Select."""
|
||||
self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, cond_shape, x_shape, y_shape):
|
||||
if cond_shape != x_shape or x_shape != y_shape:
|
||||
raise ValueError('The x_shape and y_shape must be the same as cond_shape.')
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, cond_type, x_type, y_type):
|
||||
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
|
||||
validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
|
||||
if x_type != y_type:
|
||||
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
|
||||
return x_type
|
||||
|
||||
def infer_value(self, cond, x, y):
|
||||
if cond is not None and x is not None and y is not None:
|
||||
cond = cond.asnumpy()
|
||||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = np.where(cond, x, y)
|
||||
return Tensor(out)
|
||||
return None
|
||||
|
||||
|
||||
def _compute_slicing_length(begin, end, stride, x_shape, i):
|
||||
"""Computes the length of the slicing."""
|
||||
|
|
Loading…
Reference in New Issue