scale_and_translate / scale_and_translate_grad support dynamic shape

mengyuanli 2022-11-07 09:42:48 +08:00
parent 413dba1523
commit 09472bb296
4 changed files with 23 additions and 57 deletions
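
In substance, the change is a migration to the dynamic-shape kernel contract: all shape queries move out of Init, which runs once while shapes may still be placeholders, and into Resize, which the framework calls before every launch once concrete shapes are known; the static rank and (2,)-shape checks are dropped from Init accordingly. Below is a minimal sketch of that pattern. Every name in it is a hypothetical stand-in for illustration, not the actual MindSpore KernelMod interface.

// Minimal sketch only; all names are hypothetical stand-ins, not MindSpore's.
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;
constexpr int KRET_OK = 0;  // assumed success code, matching the diff's `return 0;`

class DynamicShapeKernelSketch {
 public:
  // Init runs once at kernel construction. Under dynamic shape the tensor
  // extents may still be unknown here, so only shape-independent state
  // (dtype, attributes such as antialias, the launch function) is cached.
  bool Init(bool antialias) {
    antialias_ = antialias;
    return true;
  }

  // Resize runs before every launch, after concrete shapes are known;
  // this is where the cached shape vectors are (re)filled each step.
  int Resize(const std::vector<ShapeVector> &input_shapes, const ShapeVector &output_shape) {
    input_shapes_ = input_shapes;
    output_shape_ = output_shape;
    return KRET_OK;
  }

 private:
  bool antialias_ = false;
  std::vector<ShapeVector> input_shapes_;
  ShapeVector output_shape_;
};

int main() {
  DynamicShapeKernelSketch k;
  k.Init(/*antialias=*/true);  // shapes may still be unknown at this point
  // images NHWC, plus the three (2,)-shaped size/scale/translation inputs:
  k.Resize({{1, 8, 8, 3}, {2}, {2}, {2}}, {1, 16, 16, 3});
  std::cout << "resize ok\n";
  return 0;
}

Validation of the (2,)-shaped auxiliary inputs then belongs in the op's shape inference (the fourth file below), where dynamic values can be deferred rather than rejected.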

View File

@@ -38,32 +38,9 @@ bool ScaleAndTranslateCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
     return false;
   }
   kernel_name_ = kernel_ptr->name();
-  input0_shape_ = inputs[kIndex0]->GetShapeVector();
-  input1_shape_ = inputs[kIndex1]->GetShapeVector();
-  input2_shape_ = inputs[kIndex2]->GetShapeVector();
-  input3_shape_ = inputs[kIndex3]->GetShapeVector();
   input0_dtype_ = inputs[kIndex0]->GetDtype();
   kernel_type_ = kernel_ptr->get_kernel_type();
   antialias_ = kernel_ptr->get_antialias();
-  size_t input0_dim = 4;
-  std::vector<int64_t> valid_shape = {2};
-  // dims check
-  if (input0_shape_.size() != input0_dim) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[images]'s rank must be 4, but got "
-                      << input0_shape_.size() << ".";
-  }
-  if (input1_shape_ != valid_shape) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[size]'s shape must be (2,), but got " << input1_shape_
-                      << ".";
-  }
-  if (input2_shape_ != valid_shape) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[scale]'s shape must be (2,), but got " << input1_shape_
-                      << ".";
-  }
-  if (input3_shape_ != valid_shape) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[translation]'s shape must be (2,), but got "
-                      << input1_shape_ << ".";
-  }
   switch (input0_dtype_) {
     case kNumberTypeInt8:
       kernel_func_ = &ScaleAndTranslateCpuKernelMod::LaunchKernel<int8_t>;
@@ -102,27 +79,8 @@ bool ScaleAndTranslateGradCpuKernelMod::Init(const BaseOperatorPtr &base_operato
     return false;
   }
   kernel_name_ = kernel_ptr->name();
-  input1_shape_ = inputs[kIndex1]->GetShapeVector();
-  input2_shape_ = inputs[kIndex2]->GetShapeVector();
-  input3_shape_ = inputs[kIndex3]->GetShapeVector();
-  output_shape_ = outputs[kIndex0]->GetShapeVector();
   kernel_type_ = kernel_ptr->get_kernel_type();
   antialias_ = kernel_ptr->get_antialias();
-  size_t dim = 4;
-  std::vector<int64_t> valid_shape = {2};
-  // dims check
-  if (input1_shape_.size() != dim) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[original_image]'s rank must be 4, but got "
-                      << input1_shape_.size() << ".";
-  }
-  if (input2_shape_ != valid_shape) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[scale]'s shape must be (2,), but got " << input1_shape_
-                      << ".";
-  }
-  if (input3_shape_ != valid_shape) {
-    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the input[translation]'s shape must be (2,), but got "
-                      << input1_shape_ << ".";
-  }
   kernel_func_ = &ScaleAndTranslateGradCpuKernelMod::LaunchKernel<float>;
   return true;
 }
@@ -527,10 +485,13 @@ int ScaleAndTranslateCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs,
                                           const std::map<uint32_t, tensor::TensorPtr> &others) {
-  int ret = 0;
-  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, others); ret != KRET_OK) {
     return ret;
   }
+  input0_shape_ = inputs[kIndex0]->GetShapeVector();
+  input1_shape_ = inputs[kIndex1]->GetShapeVector();
+  input2_shape_ = inputs[kIndex2]->GetShapeVector();
+  input3_shape_ = inputs[kIndex3]->GetShapeVector();
   output_shape_ = outputs[kIndex0]->GetShapeVector();
   return 0;
 }
@@ -539,11 +500,14 @@ int ScaleAndTranslateGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operat
                                               const std::vector<KernelTensorPtr> &inputs,
                                               const std::vector<KernelTensorPtr> &outputs,
                                               const std::map<uint32_t, tensor::TensorPtr> &others) {
-  int ret = 0;
-  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, others); ret != KRET_OK) {
     return ret;
   }
+  input0_shape_ = inputs[kIndex0]->GetShapeVector();
+  input1_shape_ = inputs[kIndex1]->GetShapeVector();
+  input2_shape_ = inputs[kIndex2]->GetShapeVector();
+  input3_shape_ = inputs[kIndex3]->GetShapeVector();
   output_shape_ = outputs[kIndex0]->GetShapeVector();
   return 0;
 }
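
One detail worth noting in both Resize bodies above: the old two-line `int ret = 0; if ((ret = ...) != 0)` guard becomes a C++17 if-with-initializer, which scopes ret to the check, and the result is compared against the named KRET_OK instead of a bare 0. A standalone illustration of the idiom follows; BaseResize and the KRET_OK value of 0 are assumptions of this sketch, not the framework's declarations.

#include <iostream>

constexpr int KRET_OK = 0;  // assumption: success code is 0, as the diff's `return 0;` suggests

// Hypothetical stand-in for the base-class KernelMod::Resize call.
int BaseResize() { return KRET_OK; }

int Resize() {
  // C++17 init-statement in `if`: `ret` is visible only inside this statement,
  // so a stale or shadowed `ret` cannot leak into the rest of the function.
  if (auto ret = BaseResize(); ret != KRET_OK) {
    return ret;  // propagate the base-class failure code unchanged
  }
  // ... refresh the cached input/output shapes here ...
  return KRET_OK;
}

int main() { std::cout << Resize() << '\n'; }  // prints 0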

View File

@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_AND_TRANSLATE_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_AND_TRANSLATE_CPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SCALE_AND_TRANSLATE_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SCALE_AND_TRANSLATE_CPU_KERNEL_H_
 #include <map>
 #include <memory>
@@ -60,7 +60,6 @@ class ScaleAndTranslateCpuKernelMod : public NativeCpuKernelMod {
   ~ScaleAndTranslateCpuKernelMod() override = default;
-  // void InitKernel(const CNodePtr &kernel_node) override;
   bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs) override;
   int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -119,7 +118,6 @@ class ScaleAndTranslateGradCpuKernelMod : public NativeCpuKernelMod {
   ~ScaleAndTranslateGradCpuKernelMod() override = default;
-  // void InitKernel(const CNodePtr &kernel_node) override;
   bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs) override;
   int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -157,4 +155,4 @@ class ScaleAndTranslateGradCpuKernelMod : public NativeCpuKernelMod {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_AND_TRANSLATE_CPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SCALE_AND_TRANSLATE_CPU_KERNEL_H_

View File

@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCALE_AND_TRANSLATE_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCALE_AND_TRANSLATE_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SCALE_AND_TRANSLATE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SCALE_AND_TRANSLATE_GPU_KERNEL_H_
 #include <vector>
 #include <string>
@@ -60,4 +60,4 @@ class ScaleAndTranslateGpuKernelMod : public NativeGpuKernelMod {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCALE_AND_TRANSLATE_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SCALE_AND_TRANSLATE_GPU_KERNEL_H_

View File

@@ -36,6 +36,12 @@ abstract::ShapePtr ScaleAndTranslateInferShape(const PrimitivePtr &primitive,
   auto scale_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
   auto translation_shape =
     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
+  // support dynamic rank
+  if (IsDynamicRank(images_shape) || IsDynamicRank(size_shape) || IsDynamicRank(scale_shape) ||
+      IsDynamicRank(translation_shape)) {
+    return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
+  }
   const int64_t kShapeSize = 1;
   const int64_t kElementsNumber = 2;
   const int64_t kImagesShapeSize = 4;
@@ -102,9 +108,7 @@ abstract::ShapePtr ScaleAndTranslateInferShape(const PrimitivePtr &primitive,
   (void)out_shape.emplace_back(-1);
   (void)out_shape.emplace_back(-1);
   (void)out_shape.emplace_back(images_shape[kInputIndex3]);
-  ShapeVector shape_min = {images_shape[kInputIndex0], 1, 1, images_shape[kInputIndex3]};
-  ShapeVector shape_max = {images_shape[kInputIndex0], 1, 1, images_shape[kInputIndex3]};
-  return std::make_shared<abstract::Shape>(out_shape, shape_min, shape_max);
+  return std::make_shared<abstract::Shape>(out_shape);
 }
 }
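
The infer-shape side encodes two levels of "unknown" differently: a known rank with unknown extents uses -1 per dimension (the two emplace_back(-1) calls for the output height and width), while a wholly unknown rank is a one-element shape holding abstract::Shape::kShapeRankAny; the removed shape_min/shape_max pair belonged to the older min/max-shape mechanism that this representation supersedes. A self-contained sketch of the encoding follows, assuming the conventional sentinel values -1 (unknown dimension) and -2 (unknown rank); treat the exact constants and the simplified IsDynamicRank as assumptions.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// Assumed sentinels; the exact values are an assumption of this sketch.
constexpr int64_t kShapeDimAny = -1;   // this dimension's extent is unknown
constexpr int64_t kShapeRankAny = -2;  // even the rank is unknown

// Simplified check in the spirit of the IsDynamicRank call in the diff.
bool IsDynamicRank(const ShapeVector &shape) {
  return std::any_of(shape.begin(), shape.end(),
                     [](int64_t d) { return d == kShapeRankAny; });
}

int main() {
  // NHWC output with batch/channels known but H/W depending on the runtime
  // `size` input, as built by the emplace_back calls above:
  ShapeVector out = {4, kShapeDimAny, kShapeDimAny, 3};
  // Any input with unknown rank short-circuits to this shape:
  ShapeVector any_rank = {kShapeRankAny};
  std::cout << IsDynamicRank(any_rank) << ' ' << IsDynamicRank(out) << '\n';  // 1 0
  return 0;
}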