forked from mindspore-Ecosystem/mindspore
!21301 fix kernel size bug in conv2d
Merge pull request !21301 from Simson/opinfer
This commit is contained in:
commit
8610b7e38d
|
@ -713,12 +713,12 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
|
|||
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), ®)) {
|
||||
return input_abstract;
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (AnfAlgo::HasDynamicShapeFlag(primitive)) {
|
||||
return input_abstract;
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device == kGPUDevice) {
|
||||
if (DynamicShapeConstInputToAttrGPU.find(primitive->name()) != DynamicShapeConstInputToAttrGPU.end()) {
|
||||
return input_abstract;
|
||||
|
|
|
@ -2076,6 +2076,7 @@ class Conv2DBackpropInput(Primitive):
|
|||
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
|
||||
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
||||
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
||||
self.add_prim_attr('kernel_size', self.kernel_size)
|
||||
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
||||
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
||||
raise ValueError("NHWC format only support in GPU target.")
|
||||
|
|
Loading…
Reference in New Issue