forked from mindspore-Ecosystem/mindspore
fix incorrect in_channel size error messages
This commit is contained in:
parent
09db51a797
commit
d3330a1087
|
@ -333,9 +333,10 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
w_axis = 2;
|
||||
}
|
||||
int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
|
||||
if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) {
|
||||
MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis]
|
||||
<< " (channels) must be divisible by group = " << group;
|
||||
if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
|
||||
((x_shape[c_axis] / group) != w_shape[c_axis])) {
|
||||
MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got "
|
||||
<< (x_shape[c_axis] / group);
|
||||
}
|
||||
int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
|
||||
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
|
||||
|
|
|
@ -53,5 +53,6 @@ def test_lenet5_exception():
|
|||
predict = Tensor(in1)
|
||||
label = Tensor(in2)
|
||||
net = train_step_with_loss_warp(LeNet5())
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_executor.compile(net, predict, label)
|
||||
assert "x_shape[C_in] / group must equal to w_shape[C_in] = " in str(info.value)
|
||||
|
|
Loading…
Reference in New Issue