fix incorrect input shape check of TruncatedNormal operator.
parent 813187cf04
commit 1a4d660356

@@ -64,6 +64,14 @@ int TruncatedNormalCPUKernelMod::Resize(const BaseOperatorPtr &base_operator,
   if (ret != KRET_OK) {
     return ret;
   }
+  auto shape_input = inputs[kIndex0]->GetShapeVector();
+  if (shape_input.size() != kInputDims) {
+    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The input tensor must be a 1-D tensor.";
+  }
+  if (shape_input[kIndex0] < kInputSizes) {
+    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the input tensor shape must >= 2, but got "
+                             << shape_input[kIndex0];
+  }
   input_type_ = inputs[kIndex0]->GetDtype();
   output_type_ = outputs[kIndex0]->GetDtype();
   return KRET_OK;
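
The check added above rejects anything but a 1-D shape tensor describing at least two output dimensions. As a rough illustration, here is a minimal Python sketch of the same rule, assuming kInputDims is 1 and kInputSizes is 2 as the error messages suggest (their definitions are not shown in this diff); the helper name and plain-list input are hypothetical, not MindSpore API:

    def validate_shape_tensor(shape_input):
        """Illustrative mirror of the check added to Resize(); shape_input
        stands in for the shape vector of the 1-D shape-tensor input."""
        # shape_input.size() != kInputDims, with kInputDims assumed to be 1
        if len(shape_input) != 1:
            raise ValueError("The input tensor must be a 1-D tensor.")
        # shape_input[kIndex0] < kInputSizes, with kInputSizes assumed to be 2,
        # i.e. the requested output must be at least 2-dimensional
        if shape_input[0] < 2:
            raise ValueError(f"the input tensor shape must >= 2, but got {shape_input[0]}")

    validate_shape_tensor([2])  # OK: a 1-D shape tensor with two elements
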
@@ -126,7 +134,7 @@ std::vector<std::pair<KernelAttr, TruncatedNormalCPUKernelMod::TruncatedNormalFu
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
    &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float, float>},
   {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-   &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, double, double>},
+   &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, double, double>},
   {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
    &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float16, float>},
   {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
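
The one-template-argument fix above pairs the kNumberTypeInt64 registration with an int64_t kernel instantiation; with the old int32_t instantiation, the 8-byte shape values would be read at the wrong width. A rough NumPy sketch of that failure mode, with read_shape as a hypothetical stand-in for the typed LaunchKernel dispatch:

    import numpy as np

    def read_shape(buffer, dtype):
        # Hypothetical stand-in: the registered dtype decides how the
        # shape tensor's raw bytes are interpreted.
        return np.frombuffer(buffer, dtype=dtype)

    shape_bytes = np.array([3, 4], dtype=np.int64).tobytes()
    print(read_shape(shape_bytes, np.int64))  # [3 4]      correct pairing
    print(read_shape(shape_bytes, np.int32))  # [3 0 4 0]  misread (on little-endian)
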
@@ -58,8 +58,8 @@ class ReduceScatterInfer : public abstract::OpInferBase {
     MS_ERROR_IF_NULL_W_RET_VAL(value_ptr, std::make_shared<abstract::Shape>());
     auto rank_size = static_cast<int>(GetValue<int64_t>(value_ptr));
     if (rank_size == 0) {
-      MS_LOG(ERROR) << "For '" << primitive->name() << "', the 'rank_size' can not be zero, but got " << rank_size;
-      return std::make_shared<abstract::Shape>();
+      MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 'rank_size' can not be zero, but got "
+                               << rank_size;
     }
     auto abstract_shape = input_args[kIndex0]->BuildShape();
     MS_ERROR_IF_NULL_W_RET_VAL(abstract_shape, std::make_shared<abstract::Shape>());
@@ -68,9 +68,9 @@ class ReduceScatterInfer : public abstract::OpInferBase {
     }
     auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abstract_shape)[kShape];
     if (shape.empty() || shape[0] % rank_size != 0) {
-      MS_LOG(ERROR) << "the first dimension for 'input_shape' must be divided by 'rank_size', but got input_shape[0]: "
-                    << shape[0] << ", rank_size: " << rank_size;
-      return std::make_shared<abstract::Shape>();
+      MS_EXCEPTION(ValueError)
+        << "the first dimension for 'input_shape' must be divided by 'rank_size', but got input_shape[0]: " << shape[0]
+        << ", rank_size: " << rank_size;
     }
     auto out_shape = shape;
     out_shape[0] = static_cast<int64_t>(shape[0] / rank_size);
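
Both ReduceScatter hunks replace an MS_LOG(ERROR) plus a dummy-shape return with MS_EXCEPTION(ValueError), so a zero rank_size or a non-divisible first dimension now fails loudly during shape inference instead of propagating a placeholder shape. A minimal Python sketch of the inference rule being enforced (the function name and list-based shapes are illustrative only):

    def infer_reduce_scatter_shape(input_shape, rank_size):
        """Illustrative: ReduceScatter splits dim 0 evenly across ranks."""
        if rank_size == 0:
            raise ValueError(f"the 'rank_size' can not be zero, but got {rank_size}")
        if not input_shape or input_shape[0] % rank_size != 0:
            raise ValueError("the first dimension for 'input_shape' must be divided "
                             f"by 'rank_size', but got input_shape: {input_shape}, "
                             f"rank_size: {rank_size}")
        out_shape = list(input_shape)
        out_shape[0] = input_shape[0] // rank_size
        return out_shape

    print(infer_reduce_scatter_shape([8, 16], 4))  # [2, 16]
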
@@ -39,13 +39,6 @@ abstract::ShapePtr TruncatedNormalInferShape(const PrimitivePtr &primitive,
   if (IsDynamicRank(shape_input)) {
     return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
   }
-  if (shape_input.size() != kInputDims) {
-    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', The input tensor must be a 1-D tensor.";
-  }
-  if (shape_input[kInputIndex0] < kInputSizes) {
-    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input tensor shape must >= 2, but got "
-                             << shape_input[kInputIndex0];
-  }
   MS_EXCEPTION_IF_NULL(primitive);
   const uint32_t kInpuDims = 1;
   auto max_length_ptr = primitive->GetAttr("max_length");
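
The frontend copies of the two TruncatedNormal checks are removed here because they now run in the CPU kernel's Resize (first hunk), where concrete shapes are guaranteed; at graph-construction time the shape tensor can still have dynamic rank. A minimal sketch of just that dynamic-rank guard, assuming kShapeRankAny is the sentinel -2 (all names here are illustrative, and the real function goes on to derive the output shape):

    K_SHAPE_RANK_ANY = -2  # assumed stand-in for abstract::Shape::kShapeRankAny

    def frontend_infer(shape_of_shape_tensor):
        """Illustrative: only the dynamic-rank guard, not the full inference."""
        if K_SHAPE_RANK_ANY in shape_of_shape_tensor:  # IsDynamicRank(...)
            # Rank is unknown while the graph is being built, so len() and
            # element checks here would be meaningless; validation is
            # deferred to the kernel's Resize(), where concrete shapes exist.
            return [K_SHAPE_RANK_ANY]
        # Static case: the removed checks are now the kernel's responsibility.
        return shape_of_shape_tensor

    print(frontend_infer([-2]))  # [-2]: checks deferred
    print(frontend_infer([2]))   # [2]:  validated later in Resize()
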
@@ -147,7 +147,10 @@ def get_bprop_pad_v3(self):
 
     def bprop(x, paddings, constant_values, out, dout):
         if mode == 'constant':
-            neg_paddings = tuple(-x for x in paddings)
+            if isinstance(paddings, (list, tuple)):
+                neg_paddings = tuple(-x for x in paddings)
+            else:
+                neg_paddings = -paddings
             dx = pad_v3_grad(dout, neg_paddings, zeros_like(constant_values))
         else:
             dx = pad_v3_grad(dout, paddings)
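
The PadV3 bprop fix makes the constant-mode branch accept paddings as either a Python sequence or a Tensor: tuple(-x for x in paddings) assumes an iterable of numbers, while a Tensor is negated as a whole. A standalone sketch of the branch, with NumPy arrays standing in for MindSpore tensors:

    import numpy as np

    def negate_paddings(paddings):
        """Mirror of the fixed branch: elementwise negation for list/tuple
        paddings, operator negation for array-like (tensor) paddings."""
        if isinstance(paddings, (list, tuple)):
            return tuple(-x for x in paddings)
        return -paddings  # e.g. a NumPy array or a MindSpore Tensor

    print(negate_paddings((1, 2, 1, 2)))            # (-1, -2, -1, -2)
    print(negate_paddings(np.array([1, 2, 1, 2])))  # [-1 -2 -1 -2]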