!28889 optimizes the error message of BroadcastTo, ResizeBilinear, Conv3D and supports zero dims of input for Squeeze.
Merge pull request !28889 from wangshuide/wsd_r1.6
Commit: bd74e1510f
@@ -62,11 +62,8 @@ class SqueezeGpuKernel : public GpuKernel {
       return true;
     }
     int64_t dims = SizeToLong(input_shape.size());
-    if (dims == 0) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input cannot be 0, but got " << dims;
-    }
     for (const auto i : axis) {
-      if (i < -dims || i >= dims) {
+      if (dims != 0 && (i < -dims || i >= dims)) {
         MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the 'axis' should be in the range [-" << dims << "," << dims
                           << "), but got " << i;
       }

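Illustration, not part of the diff: a minimal Python sketch of what the relaxed check above allows, assuming a MindSpore build that includes this patch and the public mindspore.ops.Squeeze primitive; backend support for rank-0 inputs and the printed shape are assumptions.

# Hypothetical usage sketch: Squeeze on a rank-0 (zero-dim) tensor, which
# the patched kernel no longer rejects; the axis range check is skipped
# when dims == 0.
import numpy as np
import mindspore as ms
from mindspore import ops

x = ms.Tensor(np.array(3.14, dtype=np.float32))  # rank-0 tensor, shape ()
squeeze = ops.Squeeze()                          # no axis specified
y = squeeze(x)                                   # expected to pass through unchanged
print(y.shape)                                   # assumption: prints ()
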
@@ -49,6 +49,7 @@ const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const {
 bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   kernel_node_ = kernel_node;
+  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
   queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
   std::vector<std::vector<int>> shapes;
   std::vector<TypePtr> types;

@@ -145,7 +146,8 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
     return false;
   }
   if (total_bytes_ != len) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len << ", expect: " << total_bytes_;
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len
+                  << " Bytes, expect: " << total_bytes_ << " Bytes.";
     return false;
   }

@@ -5499,7 +5499,7 @@ class BroadcastTo(Primitive):
     def __init__(self, shape):
         """Initialize BroadcastTo"""
         validator.check_value_type("shape", shape, (tuple), self.name)
-        validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
+        validator.check("dimension of input_x", len(shape), "", 0, Rel.GT, self.name)
         for ix, i in enumerate(shape):
             validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
             validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)

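Illustration, not part of the diff: a hedged sketch of how the reworded BroadcastTo check surfaces to a user; the exception type and exact wording come from MindSpore's validator and are assumptions here.

# Hypothetical usage sketch: an empty target shape fails the
# "dimension of input_x" > 0 check reworded above.
from mindspore import ops

try:
    ops.BroadcastTo(())                  # empty shape tuple is rejected
except ValueError as e:
    print(e)                             # assumed to mention "dimension of input_x"

broadcast_to = ops.BroadcastTo((2, 3))   # a valid target shape for comparison
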
@@ -3310,7 +3310,7 @@ class ResizeBilinear(PrimitiveWithInfer):
             validator.check_positive_int(value, f'{i}th value of size', self.name)
 
     def infer_shape(self, input_shape):
-        validator.check("input shape rank", len(input_shape), "", 4, Rel.EQ, self.name)
+        validator.check("dimension of input", len(input_shape), "", 4, Rel.EQ, self.name)
         input_shape = list(input_shape)
         batch, channel, _, _ = input_shape
         out_shape = [batch, channel]

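Illustration, not part of the diff: a hedged sketch of the 4-D input requirement checked above; the exception type, its wording, and whether it is raised eagerly or at graph compile time are assumptions.

# Hypothetical usage sketch: ResizeBilinear expects a 4-D NCHW input,
# so a rank-3 tensor trips the "dimension of input" check.
import numpy as np
import mindspore as ms
from mindspore import ops

resize = ops.ResizeBilinear((8, 8))                        # output spatial size 8x8
x_bad = ms.Tensor(np.ones((3, 4, 4), dtype=np.float32))    # rank 3, not 4
try:
    resize(x_bad)
except ValueError as e:
    print(e)                             # assumed to mention "dimension of input"
# a valid input would be shaped (batch, channel, height, width), e.g. (1, 3, 4, 4)
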
@@ -7845,7 +7845,7 @@ class Conv3D(PrimitiveWithInfer):
 
         self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
         self.add_prim_attr('mode', self.mode)
-        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
         self.add_prim_attr('data_format', self.format)
         self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
         self.group = validator.check_equal_int(group, 1, 'group', self.name)

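Illustration, not part of the diff: a hedged sketch of the data_format validation renamed above; only 'NCDHW' is accepted here, and the exception wording is an assumption.

# Hypothetical usage sketch: an unsupported data_format is now reported
# against the 'data_format' argument instead of 'format'.
from mindspore import ops

conv3d = ops.Conv3D(out_channel=16, kernel_size=(3, 3, 3))   # default NCDHW is accepted
try:
    ops.Conv3D(out_channel=16, kernel_size=(3, 3, 3), data_format="NDHWC")
except ValueError as e:
    print(e)                             # assumed to name 'data_format'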