fix slice bug of infershape

This commit is contained in:
wangyanling10 2022-06-30 22:37:18 +08:00
parent ff60ea8eaa
commit 15cbfba761
2 changed files with 35 additions and 40 deletions

View File

@ -23,7 +23,6 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTwo = 2;
constexpr size_t kAdamInputsNum = 10;
constexpr size_t kAdamOutputsNum = 3;
constexpr size_t kScalarIndex = 0;
@ -131,13 +130,7 @@ bool AdamCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, con
<< "', the shape and dtype of 'v' and 'var' must be the same, but got the memory size of 'v': "
<< inputs[kIndexV]->size << " and 'var': " << inputs[kIndexVar]->size;
}
if ((dtype_ == kNumberTypeFloat32 && inputs[kIndexVar]->size != inputs[kIndexGrad]->size) ||
(dtype_ == kNumberTypeFloat16 && inputs[kIndexVar]->size != inputs[kIndexGrad]->size / kTwo)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the shape and dtype of 'gradient' and 'var' must be the same, but got "
"the memory size of 'gradient': "
<< inputs[kIndexGrad]->size << " and 'var': " << inputs[kIndexVar]->size;
}
size_t f_size = sizeof(float);
if (inputs[kIndexBeta1Power]->size != f_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_

View File

@ -28,12 +28,9 @@
namespace mindspore {
namespace ops {
namespace {
std::vector<std::vector<int64_t>> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr int64_t kDynamicOutValue = -2;
std::vector<int64_t> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive, const ValuePtr &input_value) {
std::vector<int64_t> tmp_input;
std::vector<std::vector<int64_t>> input_values;
for (size_t i = 1; i <= kInputIndex2; ++i) {
auto input_value = input_args[i]->BuildValue();
MS_EXCEPTION_IF_NULL(input_value);
if (input_value->isa<tensor::Tensor>()) {
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, primitive->name());
@ -44,9 +41,8 @@ std::vector<std::vector<int64_t>> InferImplSliceFuncCalInputValue(const Primitiv
} else {
MS_EXCEPTION(TypeError) << "For Slice, the begin and size must be Tuple or List.";
}
input_values.emplace_back(tmp_input);
}
return input_values;
return tmp_input;
}
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
@ -75,24 +71,30 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
out_shape_min = input_x_shape;
out_shape_max = input_x_shape;
}
if (input_begin_value_ptr->isa<AnyValue>() || input_size_value_ptr->isa<AnyValue>()) {
if (input_size_value_ptr->isa<AnyValue>()) {
if (input_size_shape[0] < 0) {
MS_EXCEPTION(ValueError) << "For Slice, the size shape haven't support dynamic yet.";
}
for (size_t i = 0; i < LongToSize(input_size_shape[0]); i++) {
if (input_begin_value_ptr->isa<AnyValue>() && !input_size_value_ptr->isa<AnyValue>()) {
auto input_value = input_args[kInputIndex2]->BuildValue();
auto tmp_input = InferImplSliceFuncCalInputValue(primitive, input_value);
for (size_t i = 0; i < tmp_input.size(); i++) {
out_shape.push_back(-1);
}
} else {
for (size_t i = 0; i < input_size_shape.size(); i++) {
out_shape.push_back(-1);
}
}
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
}
auto input_values = InferImplSliceFuncCalInputValue(primitive, input_args);
auto input_begin_value = input_values[0];
auto input_size_value = input_values[1];
if (input_size_value_ptr->isa<AnyValue>()) {
if (input_begin_value_ptr->isa<AnyValue>() || input_size_shape.size() == 0) {
out_shape.push_back(kDynamicOutValue);
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
}
if (input_size_shape[0] < 0) {
MS_EXCEPTION(ValueError) << "For Slice, the size shape haven't support dynamic yet.";
}
for (int64_t i = 0; i < input_size_shape[0]; i++) {
out_shape.push_back(-1);
}
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
}
auto input_begin_value = InferImplSliceFuncCalInputValue(primitive, input_args[kInputIndex1]->BuildValue());
auto input_size_value = InferImplSliceFuncCalInputValue(primitive, input_args[kInputIndex2]->BuildValue());
auto rank = input_x_shape.size();
if (input_begin_value.size() != rank || input_size_value.size() != rank) {
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";