[fix][assistant][I47V0O] fix bug about the wrong output in the operator ComplexNorm

This commit is contained in:
robert_luo_yibo 2021-09-01 17:55:38 +08:00
parent fc48f21bc2
commit 3d58da9023
7 changed files with 26 additions and 67 deletions

View File

@ -420,66 +420,28 @@ Status MaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
template <typename T>
Status Norm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float power) {
// calcutate total complex num
int32_t dim = input->shape().Size();
int32_t total_num = 1;
for (int32_t i = 0; i < (dim - 1); i++) {
total_num *= (input->shape()[i]);
}
// calculate the output dimension
auto input_size = input->shape().AsVector();
int32_t dim_back = input_size.back();
CHECK_FAIL_RETURN_UNEXPECTED(
dim_back == 2, "ComplexNorm: expect complex input of shape <..., 2>, but got: " + std::to_string(dim_back));
input_size.pop_back();
int32_t complex_num = input_size.back();
int32_t iter_num = total_num / complex_num;
// TensorShape out_put_shape{}
input_size.pop_back();
input_size.emplace_back(2);
TensorShape out_shape = TensorShape(input_size);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(out_shape, input->type(), output));
// slice input into real tensor and imaginary tensor
std::shared_ptr<Tensor> re_tensor;
std::shared_ptr<Tensor> im_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({total_num, 1}), input->type(), &re_tensor));
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({total_num, 1}), input->type(), &im_tensor));
std::vector<SliceOption> slice_re = {};
std::vector<SliceOption> slice_im = {};
for (int32_t i = 0; i < (dim - 1); i++) {
slice_re.emplace_back(SliceOption(true));
slice_im.emplace_back(SliceOption(true));
}
slice_re.emplace_back(SliceOption(std::vector<dsize_t>{0}));
slice_im.emplace_back(SliceOption(std::vector<dsize_t>{1}));
RETURN_IF_NOT_OK(input->Slice(&re_tensor, slice_re));
RETURN_IF_NOT_OK(input->Slice(&im_tensor, slice_im));
// calculate norm, using: .pow(2.).sum(-1).pow(0.5 * power)
auto itr_out = (*output)->begin<T>();
auto itr_re = re_tensor->begin<T>();
auto itr_im = im_tensor->begin<T>();
for (int32_t i = 0; i < iter_num; i++) {
double re = 0.0;
double im = 0.0;
for (int32_t j = complex_num * i; j < complex_num * (i + 1); j++) {
double a = static_cast<double>(*itr_re);
double b = static_cast<double>(*itr_im);
re = re + (pow(a, 2) - pow(b, 2));
im = im + (2 * a * b);
++itr_re;
++itr_im;
}
std::complex<double> comp(re, im);
comp = std::pow(comp, (0.5 * power));
*itr_out = static_cast<T>(comp.real());
++itr_out;
*itr_out = static_cast<T>(comp.imag());
++itr_out;
auto itr_in = input->begin<T>();
for (; itr_out != (*output)->end<T>(); ++itr_out) {
auto a = static_cast<T>(*itr_in);
++itr_in;
auto b = static_cast<T>(*itr_in);
++itr_in;
auto res = pow(a, 2) + pow(b, 2);
*itr_out = static_cast<T>(pow(res, (0.5 * power)));
}
RETURN_IF_NOT_OK((*output)->Reshape(out_shape));
return Status::OK();
}
@ -488,7 +450,6 @@ Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
if (input->type().value() >= DataType::DE_INT8 && input->type().value() <= DataType::DE_FLOAT16) {
// convert the data type to float
std::shared_ptr<Tensor> input_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_FLOAT32), &input_tensor));
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
Norm<float>(input_tensor, output, power);

View File

@ -249,7 +249,7 @@ Status MaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
/// \brief Compute the norm of complex tensor input.
/// \param power Power of the norm description (optional).
/// \param input Tensor shape of <..., complex=2>.
/// \param output Tensor shape of <..., complex=2>.
/// \param output Tensor shape of <..., >.
/// \return Status code.
Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float power);

View File

@ -36,8 +36,6 @@ Status ComplexNormOp::OutputShape(const std::vector<TensorShape> &inputs, std::v
outputs.clear();
auto input_size = inputs[0].AsVector();
input_size.pop_back();
input_size.pop_back();
input_size.emplace_back(2);
TensorShape out = TensorShape(input_size);
outputs.emplace_back(out);
if (!outputs.empty()) return Status::OK();

View File

@ -235,6 +235,7 @@ def check_complex_norm(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[power], _ = parse_user_args(method, *args, **kwargs)
type_check(power, (int, float), "power")
check_non_negative_float32(power, "power")
return method(self, *args, **kwargs)

View File

@ -629,7 +629,7 @@ TEST_F(MindDataTestPipeline, TestComplexNormBasic) {
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {3, 2, 2};
std::vector<int64_t> expected = {3, 2, 4};
int i = 0;
while (row.size() != 0) {

View File

@ -651,7 +651,7 @@ TEST_F(MindDataTestExecute, TestLowpassBiquadEager) {
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {23.5, 13.2, 62.5, 27.1, 15.5, 30.3, 44.9, 25.0,
11.3, 37.4, 67.1, 33.8, 73.4, 53.3, 93.7, 31.1};
Tensor::CreateFromVector(test_vector, TensorShape({4,4}), &test);
Tensor::CreateFromVector(test_vector, TensorShape({4, 4}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
std::shared_ptr<TensorTransform> lowpass_biquad(new audio::LowpassBiquad({sample_rate, cutoff_freq, Q}));
auto transform = Execute({lowpass_biquad});
@ -664,9 +664,9 @@ TEST_F(MindDataTestExecute, TestLowpassBiuqadParamCheckQ) {
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482,
0.3007, 0.9054, 0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5,3}), &test);
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482, 0.3007, 0.9054,
0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5, 3}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check Q
std::shared_ptr<TensorTransform> lowpass_biquad_op = std::make_shared<audio::LowpassBiquad>(44100, 3000.5, 0);
@ -680,9 +680,8 @@ TEST_F(MindDataTestExecute, TestLowpassBiuqadParamCheckSampleRate) {
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.5, 4.6, 2.2, 0.6, 1.9, 4.7,
2.3, 4.9, 4.7, 0.5, 0.8, 0.9};
Tensor::CreateFromVector(test_vector, TensorShape({6,2}), &test);
std::vector<double> test_vector = {0.5, 4.6, 2.2, 0.6, 1.9, 4.7, 2.3, 4.9, 4.7, 0.5, 0.8, 0.9};
Tensor::CreateFromVector(test_vector, TensorShape({6, 2}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check sample_rate
std::shared_ptr<TensorTransform> lowpass_biquad_op = std::make_shared<audio::LowpassBiquad>(0, 2000.5, 0.7);

View File

@ -35,11 +35,11 @@ def test_complex_norm():
dataset = ds.GeneratorDataset(source=gen, column_names=["multi_dim_data"])
dataset = dataset.map(operations=audio.ComplexNorm(2.), input_columns=["multi_dim_data"])
dataset = dataset.map(operations=audio.ComplexNorm(2), input_columns=["multi_dim_data"])
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
assert i["multi_dim_data"].shape == (2,)
expected = np.array([-5., 46.])
assert i["multi_dim_data"].shape == (3,)
expected = np.array([2., 13., 32.])
assert np.array_equal(i["multi_dim_data"], expected)
logger.info("Finish testing ComplexNorm.")
@ -52,9 +52,9 @@ def test_complex_norm_eager():
logger.info("Test ComplexNorm callable.")
input_t = np.array([[1.0, 1.0], [2.0, 3.0], [4.0, 4.0]])
output_t = audio.ComplexNorm(3)(input_t)
assert output_t.shape == (2,)
expected = np.array([-255.6179621231501, 183.64515392460598])
output_t = audio.ComplexNorm()(input_t)
assert output_t.shape == (3,)
expected = np.array([1.4142135623730951, 3.605551275463989, 5.656854249492381])
assert np.array_equal(output_t, expected)
logger.info("Finish testing ComplexNorm.")
@ -69,7 +69,7 @@ def test_complex_norm_uncallable():
try:
input_t = random.rand(2, 4, 3, 2)
output_t = audio.ComplexNorm(-3.)(input_t)
assert output_t.shape == (2, 4, 2)
assert output_t.shape == (2, 4, 3)
except ValueError as e:
assert 'Input power is not within the required interval of [0, 16777216].' in str(e)