!28730 optimizes the error message of BroadcastTo, ResizeBilinear, Conv3D and supports zero dims of input for Squeeze.

Merge pull request !28730 from wangshuide/wsd_master
This commit is contained in:
i-robot 2022-01-13 01:31:27 +00:00 committed by Gitee
commit 3dc4999b68
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 7 additions and 8 deletions

View File

@ -62,11 +62,8 @@ class SqueezeGpuKernel : public GpuKernel {
return true; return true;
} }
int64_t dims = SizeToLong(input_shape.size()); int64_t dims = SizeToLong(input_shape.size());
if (dims == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input cannot be 0, but got " << dims;
}
for (const auto i : axis) { for (const auto i : axis) {
if (i < -dims || i >= dims) { if (dims != 0 && (i < -dims || i >= dims)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' should be in the range [-" << dims << "," << dims MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' should be in the range [-" << dims << "," << dims
<< "), but got " << i; << "), but got " << i;
} }

View File

@ -49,6 +49,7 @@ const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const {
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) { bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name"); queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
std::vector<std::vector<int>> shapes; std::vector<std::vector<int>> shapes;
std::vector<TypePtr> types; std::vector<TypePtr> types;
@ -145,7 +146,8 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
return false; return false;
} }
if (total_bytes_ != len) { if (total_bytes_ != len) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len << ", expect: " << total_bytes_; MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len
<< " Bytes, expect: " << total_bytes_ << " Bytes.";
return false; return false;
} }

View File

@ -5500,7 +5500,7 @@ class BroadcastTo(Primitive):
def __init__(self, shape): def __init__(self, shape):
"""Initialize BroadcastTo""" """Initialize BroadcastTo"""
validator.check_value_type("shape", shape, (tuple), self.name) validator.check_value_type("shape", shape, (tuple), self.name)
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name) validator.check("dimension of input_x", len(shape), "", 0, Rel.GT, self.name)
for ix, i in enumerate(shape): for ix, i in enumerate(shape):
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name) validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name) validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)

View File

@ -3278,7 +3278,7 @@ class ResizeBilinear(PrimitiveWithInfer):
validator.check_positive_int(value, f'{i}th value of size', self.name) validator.check_positive_int(value, f'{i}th value of size', self.name)
def infer_shape(self, input_shape): def infer_shape(self, input_shape):
validator.check("input shape rank", len(input_shape), "", 4, Rel.EQ, self.name) validator.check("dimension of input", len(input_shape), "", 4, Rel.EQ, self.name)
input_shape = list(input_shape) input_shape = list(input_shape)
batch, channel, _, _ = input_shape batch, channel, _, _ = input_shape
out_shape = [batch, channel] out_shape = [batch, channel]
@ -7789,7 +7789,7 @@ class Conv3D(PrimitiveWithInfer):
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode) self.add_prim_attr('mode', self.mode)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) self.format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
self.add_prim_attr('data_format', self.format) self.add_prim_attr('data_format', self.format)
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_equal_int(group, 1, 'group', self.name) self.group = validator.check_equal_int(group, 1, 'group', self.name)