forked from mindspore-Ecosystem/mindspore
!28730 optimizes the error message of BroadcastTo, ResizeBilinear, Conv3D and supports zero dims of input for Squeeze.
Merge pull request !28730 from wangshuide/wsd_master
This commit is contained in:
commit
3dc4999b68
|
@ -62,11 +62,8 @@ class SqueezeGpuKernel : public GpuKernel {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
int64_t dims = SizeToLong(input_shape.size());
|
int64_t dims = SizeToLong(input_shape.size());
|
||||||
if (dims == 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input cannot be 0, but got " << dims;
|
|
||||||
}
|
|
||||||
for (const auto i : axis) {
|
for (const auto i : axis) {
|
||||||
if (i < -dims || i >= dims) {
|
if (dims != 0 && (i < -dims || i >= dims)) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' should be in the range [-" << dims << "," << dims
|
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' should be in the range [-" << dims << "," << dims
|
||||||
<< "), but got " << i;
|
<< "), but got " << i;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,6 +49,7 @@ const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const {
|
||||||
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
|
bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_node_ = kernel_node;
|
kernel_node_ = kernel_node;
|
||||||
|
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
|
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<std::vector<int>> shapes;
|
||||||
std::vector<TypePtr> types;
|
std::vector<TypePtr> types;
|
||||||
|
@ -145,7 +146,8 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (total_bytes_ != len) {
|
if (total_bytes_ != len) {
|
||||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len << ", expect: " << total_bytes_;
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len
|
||||||
|
<< " Bytes, expect: " << total_bytes_ << " Bytes.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5500,7 +5500,7 @@ class BroadcastTo(Primitive):
|
||||||
def __init__(self, shape):
|
def __init__(self, shape):
|
||||||
"""Initialize BroadcastTo"""
|
"""Initialize BroadcastTo"""
|
||||||
validator.check_value_type("shape", shape, (tuple), self.name)
|
validator.check_value_type("shape", shape, (tuple), self.name)
|
||||||
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
|
validator.check("dimension of input_x", len(shape), "", 0, Rel.GT, self.name)
|
||||||
for ix, i in enumerate(shape):
|
for ix, i in enumerate(shape):
|
||||||
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
|
validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
|
||||||
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
|
validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
|
||||||
|
|
|
@ -3278,7 +3278,7 @@ class ResizeBilinear(PrimitiveWithInfer):
|
||||||
validator.check_positive_int(value, f'{i}th value of size', self.name)
|
validator.check_positive_int(value, f'{i}th value of size', self.name)
|
||||||
|
|
||||||
def infer_shape(self, input_shape):
|
def infer_shape(self, input_shape):
|
||||||
validator.check("input shape rank", len(input_shape), "", 4, Rel.EQ, self.name)
|
validator.check("dimension of input", len(input_shape), "", 4, Rel.EQ, self.name)
|
||||||
input_shape = list(input_shape)
|
input_shape = list(input_shape)
|
||||||
batch, channel, _, _ = input_shape
|
batch, channel, _, _ = input_shape
|
||||||
out_shape = [batch, channel]
|
out_shape = [batch, channel]
|
||||||
|
@ -7789,7 +7789,7 @@ class Conv3D(PrimitiveWithInfer):
|
||||||
|
|
||||||
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
|
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
|
||||||
self.add_prim_attr('mode', self.mode)
|
self.add_prim_attr('mode', self.mode)
|
||||||
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
|
self.format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
|
||||||
self.add_prim_attr('data_format', self.format)
|
self.add_prim_attr('data_format', self.format)
|
||||||
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
||||||
self.group = validator.check_equal_int(group, 1, 'group', self.name)
|
self.group = validator.check_equal_int(group, 1, 'group', self.name)
|
||||||
|
|
Loading…
Reference in New Issue