fix slice bug of infershape
This commit is contained in:
parent
ff60ea8eaa
commit
15cbfba761
|
@ -23,7 +23,6 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kTwo = 2;
|
||||
constexpr size_t kAdamInputsNum = 10;
|
||||
constexpr size_t kAdamOutputsNum = 3;
|
||||
constexpr size_t kScalarIndex = 0;
|
||||
|
@ -131,13 +130,7 @@ bool AdamCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, con
|
|||
<< "', the shape and dtype of 'v' and 'var' must be the same, but got the memory size of 'v': "
|
||||
<< inputs[kIndexV]->size << " and 'var': " << inputs[kIndexVar]->size;
|
||||
}
|
||||
if ((dtype_ == kNumberTypeFloat32 && inputs[kIndexVar]->size != inputs[kIndexGrad]->size) ||
|
||||
(dtype_ == kNumberTypeFloat16 && inputs[kIndexVar]->size != inputs[kIndexGrad]->size / kTwo)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the shape and dtype of 'gradient' and 'var' must be the same, but got "
|
||||
"the memory size of 'gradient': "
|
||||
<< inputs[kIndexGrad]->size << " and 'var': " << inputs[kIndexVar]->size;
|
||||
}
|
||||
|
||||
size_t f_size = sizeof(float);
|
||||
if (inputs[kIndexBeta1Power]->size != f_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
|
|
|
@ -28,25 +28,21 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
std::vector<std::vector<int64_t>> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr int64_t kDynamicOutValue = -2;
|
||||
std::vector<int64_t> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive, const ValuePtr &input_value) {
|
||||
std::vector<int64_t> tmp_input;
|
||||
std::vector<std::vector<int64_t>> input_values;
|
||||
for (size_t i = 1; i <= kInputIndex2; ++i) {
|
||||
auto input_value = input_args[i]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
if (input_value->isa<tensor::Tensor>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, primitive->name());
|
||||
} else if (input_value->isa<ValueTuple>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, primitive->name());
|
||||
} else if (input_value->isa<ValueList>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckListInt("slice args value", input_value, primitive->name());
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For Slice, the begin and size must be Tuple or List.";
|
||||
}
|
||||
input_values.emplace_back(tmp_input);
|
||||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
if (input_value->isa<tensor::Tensor>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckTensorIntValue("slice args value", input_value, primitive->name());
|
||||
} else if (input_value->isa<ValueTuple>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckTupleInt("slice args value", input_value, primitive->name());
|
||||
} else if (input_value->isa<ValueList>()) {
|
||||
tmp_input = CheckAndConvertUtils::CheckListInt("slice args value", input_value, primitive->name());
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For Slice, the begin and size must be Tuple or List.";
|
||||
}
|
||||
return input_values;
|
||||
|
||||
return tmp_input;
|
||||
}
|
||||
|
||||
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
@ -75,24 +71,30 @@ abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
out_shape_min = input_x_shape;
|
||||
out_shape_max = input_x_shape;
|
||||
}
|
||||
if (input_begin_value_ptr->isa<AnyValue>() || input_size_value_ptr->isa<AnyValue>()) {
|
||||
if (input_size_value_ptr->isa<AnyValue>()) {
|
||||
if (input_size_shape[0] < 0) {
|
||||
MS_EXCEPTION(ValueError) << "For Slice, the size shape haven't support dynamic yet.";
|
||||
}
|
||||
for (size_t i = 0; i < LongToSize(input_size_shape[0]); i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < input_size_shape.size(); i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
if (input_begin_value_ptr->isa<AnyValue>() && !input_size_value_ptr->isa<AnyValue>()) {
|
||||
auto input_value = input_args[kInputIndex2]->BuildValue();
|
||||
auto tmp_input = InferImplSliceFuncCalInputValue(primitive, input_value);
|
||||
for (size_t i = 0; i < tmp_input.size(); i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
|
||||
}
|
||||
auto input_values = InferImplSliceFuncCalInputValue(primitive, input_args);
|
||||
auto input_begin_value = input_values[0];
|
||||
auto input_size_value = input_values[1];
|
||||
if (input_size_value_ptr->isa<AnyValue>()) {
|
||||
if (input_begin_value_ptr->isa<AnyValue>() || input_size_shape.size() == 0) {
|
||||
out_shape.push_back(kDynamicOutValue);
|
||||
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
|
||||
}
|
||||
if (input_size_shape[0] < 0) {
|
||||
MS_EXCEPTION(ValueError) << "For Slice, the size shape haven't support dynamic yet.";
|
||||
}
|
||||
for (int64_t i = 0; i < input_size_shape[0]; i++) {
|
||||
out_shape.push_back(-1);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape, out_shape_min, out_shape_max);
|
||||
}
|
||||
|
||||
auto input_begin_value = InferImplSliceFuncCalInputValue(primitive, input_args[kInputIndex1]->BuildValue());
|
||||
auto input_size_value = InferImplSliceFuncCalInputValue(primitive, input_args[kInputIndex2]->BuildValue());
|
||||
auto rank = input_x_shape.size();
|
||||
if (input_begin_value.size() != rank || input_size_value.size() != rank) {
|
||||
MS_EXCEPTION(ValueError) << "For Slice, the shape of input|begin|size must be equal.";
|
||||
|
|
Loading…
Reference in New Issue