move GetUnsortedSegmentOpScalarArg to inner

This commit is contained in:
i-robot 2021-11-23 09:51:02 +08:00 committed by lianliguang
parent 511441a27e
commit 6e886038bb
4 changed files with 57 additions and 50 deletions

View File

@ -26,6 +26,36 @@
namespace mindspore {
namespace abstract {
namespace {
// Get 3rd argument for UnsortedSegmentOps' inferImpl function
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
int64_t num_segments_value = 0;
constexpr size_t scalar_index = 2;
if (args_spec_list[scalar_index]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[scalar_index]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else {
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
}
} else if (args_spec_list[scalar_index]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, scalar_index);
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
}
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in " << op_name;
}
return num_segments_value;
}
} // namespace
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a scalar.

View File

@ -313,34 +313,6 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec
*max_shape = (*max_shape).empty() ? shape : *max_shape;
}
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
int64_t num_segments_value = 0;
constexpr size_t scalar_index = 2;
if (args_spec_list[scalar_index]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[scalar_index]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
auto num_segments_value_ptr = num_segments->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_tensor);
if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
} else {
num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
}
} else if (args_spec_list[scalar_index]->isa<AbstractScalar>()) { // num_segments is Scalar
auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, scalar_index);
if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
} else {
num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
}
} else {
MS_LOG(EXCEPTION) << "num_segments incorrect type in " << op_name;
}
return num_segments_value;
}
AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(type);

View File

@ -60,8 +60,6 @@ ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tenso
// Check dynamic shape routine
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
// Get 3rd argument for UnsortedSegmentOps' inferImpl function
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name);
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);

View File

@ -76,11 +76,12 @@ TypePtr SelectInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
(void)CheckAndConvertUtils::CheckSubClass("y_type", y_type, {kTensorType}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("cond", cond_type, {kBool}, prim_name);
if (*x_type != *y_type) {
MS_EXCEPTION(TypeError) << prim_name << "the x_type " << x_type->ToString() << " must be the same as y_type "
MS_EXCEPTION(TypeError) << prim_name << "'s the x_type " << x_type->ToString() << " must be the same as y_type "
<< y_type->ToString();
}
return x_type;
}
AbstractBasePtr SelectInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
const int64_t input_num = 3;
@ -90,27 +91,11 @@ AbstractBasePtr SelectInfer(const abstract::AnalysisEnginePtr &, const Primitive
return abstract::MakeAbstract(shape, type);
}
ValuePtr SelectInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto result_type = SelectInferType(prim, input_args);
auto result_shape = SelectInferShape(prim, input_args)->cast<abstract::ShapePtr>();
auto cond_value = input_args[kCondIndex]->BuildValue();
auto x = input_args[kXIndex]->BuildValue();
auto y = input_args[kYIndex]->BuildValue();
if (x == nullptr || y == nullptr || cond_value == nullptr || result_shape->IsDynamic()) {
return nullptr;
}
auto x_tensor = x->cast<tensor::TensorPtr>();
auto y_tensor = y->cast<tensor::TensorPtr>();
auto cond_tensor = cond_value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x_tensor);
MS_EXCEPTION_IF_NULL(y_tensor);
MS_EXCEPTION_IF_NULL(cond_tensor);
auto conds = cond_tensor->data_c();
MS_EXCEPTION_IF_NULL(conds);
void SelectInnerInferValue(const tensor::TensorPtr &cond_tensor, const tensor::TensorPtr &x_tensor,
const tensor::TensorPtr &y_tensor, const tensor::TensorPtr &result_tensor) {
bool *cond_data = reinterpret_cast<bool *>(cond_tensor->data_c());
auto data_size = cond_tensor->DataSize();
auto type_id = x_tensor->data_type();
auto result_tensor = std::make_shared<tensor::Tensor>(type_id, result_shape->shape());
switch (type_id) {
case kNumberTypeBool: {
SelectImpl<bool>(cond_data, x_tensor->data_c(), y_tensor->data_c(), result_tensor->data_c(), data_size);
@ -173,9 +158,31 @@ ValuePtr SelectInferValue(const PrimitivePtr &prim, const std::vector<AbstractBa
break;
}
default: {
MS_EXCEPTION(TypeError) << "Select not supported type " << result_type->ToString();
MS_EXCEPTION(TypeError) << "Select not supported type " << result_tensor->type()->ToString();
}
}
}
ValuePtr SelectInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
(void)SelectInferType(prim, input_args);
auto result_shape = SelectInferShape(prim, input_args)->cast<abstract::ShapePtr>();
auto cond_value = input_args[kCondIndex]->BuildValue();
auto x = input_args[kXIndex]->BuildValue();
auto y = input_args[kYIndex]->BuildValue();
if (x == nullptr || y == nullptr || cond_value == nullptr || result_shape->IsDynamic()) {
return nullptr;
}
auto x_tensor = x->cast<tensor::TensorPtr>();
auto y_tensor = y->cast<tensor::TensorPtr>();
auto cond_tensor = cond_value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x_tensor);
MS_EXCEPTION_IF_NULL(y_tensor);
MS_EXCEPTION_IF_NULL(cond_tensor);
auto conds = cond_tensor->data_c();
MS_EXCEPTION_IF_NULL(conds);
auto type_id = x_tensor->data_type();
auto result_tensor = std::make_shared<tensor::Tensor>(type_id, result_shape->shape());
SelectInnerInferValue(cond_tensor, x_tensor, y_tensor, result_tensor);
return result_tensor;
}
} // namespace