!38953 [NoRepeatNGram] add cpu op and npu dynamic shape

Merge pull request !38953 from bichaoyang/master
i-robot 2022-07-28 01:31:11 +00:00 committed by Gitee
commit fb990c85d3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 598 additions and 21 deletions
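
For context, a minimal usage sketch of the operator this commit enables on the CPU backend; the shapes mirror the new ST test, and the Net wrapper and random inputs are illustrative, not part of the diff:

import numpy as np
from mindspore import Tensor, context, nn
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, ngram_size):
        super(Net, self).__init__()
        self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

    def construct(self, state_seq, log_probs):
        return self.no_repeat_ngram(state_seq, log_probs)


# Shapes follow the new ST test: batch=2, beam=5, seq_len=7 token ids and a vocabulary of 10.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
state_seq = Tensor(np.random.randint(0, 10, (2, 5, 7)).astype(np.int32))
log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
out = Net(ngram_size=3)(state_seq, log_probs)
# Tokens that would repeat an already generated 3-gram get log-prob -FLT_MAX; all others are unchanged.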

@@ -0,0 +1,148 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/no_repeat_ngram_cpu_kernel.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include <cmath>
#include <limits>
#include "mindspore/core/ops/no_repeat_ngram.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kNoRepeatNGramInputsNum = 2;
constexpr size_t kNoRepeatNGramOutputsNum = 1;
constexpr size_t kNoRepeatNGramDim = 3;
constexpr int64_t kNoRepeatNGramParamValue = 0;
} // namespace
bool NoRepeatNGramCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::NoRepeatNGram>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "Cast NoRepeatNGram ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "NoRepeatNGram does not support this kernel data type: " << kernel_attr;
return false;
}
base_operator_ = base_operator;
kernel_func_ = func_list_[index].second;
return true;
}
int NoRepeatNGramCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others) {
int ret = 0;
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
MS_LOG(WARNING) << kernel_name_ << " resize failed.";
return ret;
}
state_seq_shape_ = inputs[kIndex0]->GetShapeVector();
log_probs_shape_ = inputs[kIndex1]->GetShapeVector();
ngram_size_ = GetValue<int64_t>(base_operator->GetAttr("ngram_size"));
return 0;
}
void NoRepeatNGramCpuKernelMod::CheckAndInitParams() {
size_t state_seq_shape_len = state_seq_shape_.size();
size_t log_probs_shape_len = log_probs_shape_.size();
if ((state_seq_shape_len != log_probs_shape_len) || (state_seq_shape_len != kNoRepeatNGramDim)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'state_seq' " << state_seq_shape_len
<< " and 'log_probs' " << log_probs_shape_len << " is illegal.";
}
state_dim_ = state_seq_shape_[state_seq_shape_len - 1];
output_dim_ = log_probs_shape_[log_probs_shape_len - 1];
if ((ngram_size_ <= kNoRepeatNGramParamValue) || (ngram_size_ >= state_dim_)) {
MS_LOG(EXCEPTION) << "Param ngram_size is illegal";
}
num_seq_ = 1;
for (size_t i = 0; i < log_probs_shape_len - 1; i++) {
num_seq_ *= log_probs_shape_[i];
}
output_size_ = num_seq_ * output_dim_;
}
template <typename T>
bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);
auto *state_seq = reinterpret_cast<int32_t *>(inputs[kIndex0]->addr);
auto *log_probs = reinterpret_cast<T *>(inputs[kIndex1]->addr);
auto *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
CheckAndInitParams();
for (size_t i = 0; i < LongToSize(output_size_); i++) {
output[i] = log_probs[i];
}
std::vector<int32_t> array_dim(state_dim_);
std::vector<int32_t> array_ngram(ngram_size_ - 1);
for (int64_t i = 0; i < num_seq_; i++) {
int64_t src_index_i = state_dim_ * i;
int64_t output_index_i = output_dim_ * i;
for (int64_t k = 0; k < state_dim_; k++) {
int64_t src_index_k = k + src_index_i;
array_dim[k] = static_cast<int32_t>(state_seq[src_index_k]);
if (k > (state_dim_ - ngram_size_)) {
array_ngram[k + ngram_size_ - state_dim_ - 1] = static_cast<int32_t>(state_seq[src_index_k]);
}
}
for (int64_t j = 0; j < state_dim_ - ngram_size_ + 1; j++) {
if (equal(array_ngram.begin(), array_ngram.end(), array_dim.begin() + j)) {
int64_t output_index_j = static_cast<int64_t>(array_dim[j + ngram_size_ - 1]);
output[output_index_i + output_index_j] = -(std::numeric_limits<T>::max)();
}
}
}
return true;
}
std::vector<std::pair<KernelAttr, NoRepeatNGramCpuKernelMod::NoRepeatNGramFunc>> NoRepeatNGramCpuKernelMod::func_list_ =
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&NoRepeatNGramCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&NoRepeatNGramCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&NoRepeatNGramCpuKernelMod::LaunchKernel<float16>}};
std::vector<KernelAttr> NoRepeatNGramCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, NoRepeatNGramFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NoRepeatNGram, NoRepeatNGramCpuKernelMod);
} // namespace kernel
} // namespace mindspore
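
For reference, the banning rule implemented by LaunchKernel above can be sketched in NumPy; the helper name and the flattening of the leading batch dimensions are illustrative only:

import numpy as np


def no_repeat_ngram_ref(state_seq, log_probs, ngram_size):
    """NumPy sketch of the CPU kernel: ban any token that would complete an n-gram already seen in the sequence."""
    out = log_probs.copy()
    seqs = state_seq.reshape(-1, state_seq.shape[-1])    # (num_seq_, state_dim_)
    probs = out.reshape(-1, out.shape[-1])               # (num_seq_, output_dim_), a view into out
    for i, seq in enumerate(seqs):
        tail = list(seq[len(seq) - ngram_size + 1:])     # last ngram_size - 1 generated tokens
        for j in range(len(seq) - ngram_size + 1):
            if list(seq[j:j + ngram_size - 1]) == tail:  # an earlier (ngram_size - 1)-gram matches the tail
                probs[i, seq[j + ngram_size - 1]] = -np.finfo(out.dtype).max
    return out

This mirrors the two nested loops over num_seq_ and state_dim_ in the C++ kernel above.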

@@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class NoRepeatNGramCpuKernelMod : public NativeCpuKernelMod {
public:
NoRepeatNGramCpuKernelMod() = default;
~NoRepeatNGramCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs);
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void CheckAndInitParams();
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using NoRepeatNGramFunc =
std::function<bool(NoRepeatNGramCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, NoRepeatNGramFunc>> func_list_;
NoRepeatNGramFunc kernel_func_{nullptr};
BaseOperatorPtr base_operator_;
std::vector<int64_t> state_seq_shape_;
std::vector<int64_t> log_probs_shape_;
int64_t ngram_size_{1};
int64_t output_size_{1};
int64_t state_dim_{1};
int64_t output_dim_{1};
int64_t num_seq_{1};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_

@@ -545,6 +545,7 @@ GVAR_DEF(PrimitivePtr, kPrimMatrixBandPart, std::make_shared<Primitive>(kMatrixB
GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValue, std::make_shared<Primitive>("NonZeroWithValue"));
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValueShape, std::make_shared<Primitive>("NonZeroWithValueShape"));
GVAR_DEF(PrimitivePtr, kPrimNoRepeatNGram, std::make_shared<Primitive>("NoRepeatNGram"));
GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared<Primitive>(kRealInner));
GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared<Primitive>(kReal));
GVAR_DEF(PrimitivePtr, kPrimImag, std::make_shared<Primitive>(kImag));

@@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/no_repeat_ngram.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr NoRepeatNGramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto in_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
if (min_shape.size() != 0 && max_shape.size() != 0) {
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr NoRepeatNGramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> seq_types;
(void)seq_types.emplace("seq_type", input_args[0]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(seq_types, {kInt32}, prim->name());
std::set<TypePtr> valid_params_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> log_types;
(void)log_types.emplace("log_types", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(log_types, valid_params_types, prim->name());
}
} // namespace
MIND_API_OPERATOR_IMPL(NoRepeatNGram, BaseOperator);
AbstractBasePtr NoRepeatNGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto type = NoRepeatNGramInferType(primitive, input_args);
auto shape = NoRepeatNGramInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NoRepeatNGram, prim::kPrimNoRepeatNGram, NoRepeatNGramInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

@@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NO_REPEAT_NGRAM_H
#define MINDSPORE_NO_REPEAT_NGRAM_H
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameNoRepeatNGram = "NoRepeatNGram";
/// \brief Prevents repeated n-grams during sequence generation by masking the log probabilities of tokens that
/// would complete an n-gram already present in state_seq.
class MIND_API NoRepeatNGram : public BaseOperator {
public:
MIND_API_BASE_MEMBER(NoRepeatNGram);
/// \brief Constructor.
NoRepeatNGram() : BaseOperator(kNameNoRepeatNGram) { InitIOName({"state_seq", "log_probs"}, {"out"}); }
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr NoRepeatNGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_NO_REPEAT_NGRAM_H

@@ -38,6 +38,30 @@ from ..composite import _VmapGeneralRule
from ..operations.array_ops import TensorScatterElements
@vmap_rules_getters.register(P.NoRepeatNGram)
def get_no_repeat_ngram_vmap_rule(prim, axis_size):
"""VmapRule for `NoRepeatNGram` operation."""
def vmap_rule(state_seq_bdim, log_probs_bdim):
is_all_none, result = vmap_general_preprocess(prim, state_seq_bdim, log_probs_bdim)
if is_all_none:
return result
state_seq, state_seq_dim = state_seq_bdim
log_probs, log_probs_dim = log_probs_bdim
state_seq = _bdim_at_front(state_seq, state_seq_dim, axis_size)
log_probs = _bdim_at_front(log_probs, log_probs_dim, axis_size)
s_ori_shape = F.shape(state_seq)
l_ori_shape = F.shape(log_probs)
state_seq = F.reshape(state_seq, (-1,) + s_ori_shape[-2:])
log_probs = F.reshape(log_probs, (-1,) + l_ori_shape[-2:])
out = prim(state_seq, log_probs)
out = F.reshape(out, l_ori_shape)
return (out, 0)
return vmap_rule
@vmap_rules_getters.register("Cast")
def get_cast_vmap_rule(prim, axis_size):
"""VmapRule for `Cast` operation."""

@@ -144,7 +144,7 @@ class NoRepeatNGram(PrimitiveWithInfer):
TypeError: If neither `state_seq` nor `log_probs` is a Tensor.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)

@@ -29,32 +29,52 @@ class Net(nn.Cell):
super(Net, self).__init__()
self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)
def construct(self, state_seq, log_probs):
return self.no_repeat_ngram(state_seq, log_probs)
def construct(self, s, l):
return self.no_repeat_ngram(s, l)
state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
[9, 3, 9, 5, 4, 1, 5],
[4, 7, 9, 1, 9, 6, 1],
[7, 6, 4, 2, 9, 1, 5],
[7, 5, 8, 9, 9, 3, 9]],
[[7, 7, 2, 7, 9, 9, 4],
[3, 4, 7, 4, 7, 6, 8],
[1, 9, 5, 7, 6, 9, 3],
[4, 8, 6, 4, 5, 6, 4],
[4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
expect_log_probs = log_probs.asnumpy().copy()
expect_log_probs[0, 0, 1] = -FLT_MAX
expect_log_probs[0, 0, 5] = -FLT_MAX
expect_log_probs[1, 3, 5] = -FLT_MAX
expect_log_probs[1, 4, 8] = -FLT_MAX
def test_net():
state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
[9, 3, 9, 5, 4, 1, 5],
[4, 7, 9, 1, 9, 6, 1],
[7, 6, 4, 2, 9, 1, 5],
[7, 5, 8, 9, 9, 3, 9]],
[[7, 7, 2, 7, 9, 9, 4],
[3, 4, 7, 4, 7, 6, 8],
[1, 9, 5, 7, 6, 9, 3],
[4, 8, 6, 4, 5, 6, 4],
[4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
expect_log_probs = log_probs.asnumpy().copy()
expect_log_probs[0, 0, 1] = -FLT_MAX
expect_log_probs[0, 0, 5] = -FLT_MAX
expect_log_probs[1, 3, 5] = -FLT_MAX
expect_log_probs[1, 4, 8] = -FLT_MAX
"""
Feature: test NoRepeatNGram on Ascend.
Description: inputs with batch.
Expectation: the result match with expect.
"""
net = Net(ngram_size=3)
output = net(state_seq, log_probs)
print(expect_log_probs)
print(output)
assert np.array_equal(expect_log_probs, output.asnumpy())
def test_net_dynamic_shape():
"""
Feature: test NoRepeatNGram dynamic shape on Ascend.
Description: inputs with batch.
Expectation: the result match with expect.
"""
net = Net(ngram_size=3)
place_holder_x = Tensor(shape=[None, 5, 7], dtype=mindspore.int32)
place_holder_v = Tensor(shape=[None, 5, 10], dtype=mindspore.float32)
net.set_inputs(place_holder_x, place_holder_v)
output = net(state_seq, log_probs)
assert np.array_equal(expect_log_probs, output.asnumpy())

@@ -0,0 +1,191 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import vmap
from mindspore.common.api import ms_function
from mindspore import dtype as mstype
FLT_MAX = 3.4028235e+38
class CpuNet(nn.Cell):
def __init__(self, ngram_size):
super(CpuNet, self).__init__()
self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)
def construct(self, state_seq, log_probs):
return self.no_repeat_ngram(state_seq, log_probs)
base_state_seq = np.array([[[1, 2, 1, 2, 5, 1, 2],
[9, 3, 9, 5, 4, 1, 5],
[4, 7, 9, 1, 9, 6, 1],
[7, 6, 4, 2, 9, 1, 5],
[7, 5, 8, 9, 9, 3, 9]],
[[7, 7, 2, 7, 9, 9, 4],
[3, 4, 7, 4, 7, 6, 8],
[1, 9, 5, 7, 6, 9, 3],
[4, 8, 6, 4, 5, 6, 4],
[4, 8, 8, 4, 3, 4, 8]]], dtype=np.int32)
base_log_probs = np.random.random((2, 5, 10)).astype(np.float32)
base_expect_log_probs = base_log_probs.copy()
base_expect_log_probs[0, 0, 1] = -FLT_MAX
base_expect_log_probs[0, 0, 5] = -FLT_MAX
base_expect_log_probs[1, 3, 5] = -FLT_MAX
base_expect_log_probs[1, 4, 8] = -FLT_MAX
def test_net():
"""
Feature: test NoRepeatNGram on CPU.
Description: inputs with batch.
Expectation: the result match with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
state_seq = Tensor(base_state_seq)
log_probs = Tensor(base_log_probs)
expect_log_probs = base_expect_log_probs
net = CpuNet(ngram_size=3)
output = net(state_seq, log_probs)
assert np.array_equal(expect_log_probs, output.asnumpy())
def test_net_dynamic_shape():
"""
Feature: test NoRepeatNGram dynamic shape on CPU.
Description: inputs with batch.
Expectation: the result match with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
state_seq = Tensor(base_state_seq)
log_probs = Tensor(base_log_probs)
expect_log_probs = base_expect_log_probs
net = CpuNet(ngram_size=3)
place_holder_x = Tensor(shape=[None, 5, 7], dtype=mstype.int32)
place_holder_v = Tensor(shape=[None, 5, 10], dtype=mstype.float32)
net.set_inputs(place_holder_x, place_holder_v)
output = net(state_seq, log_probs)
assert np.array_equal(expect_log_probs, output.asnumpy())
def vmap_case():
class Net(nn.Cell):
def __init__(self, ngram_size):
super(Net, self).__init__()
self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)
def construct(self, seq, log):
return self.no_repeat_ngram(seq, log)
class VmapNet(nn.Cell):
def __init__(self, net, in_axes, out_axes):
super(VmapNet, self).__init__()
self.net = net
self.in_axes = in_axes
self.out_axes = out_axes
def construct(self, state_seq, log_probs):
return vmap(self.net, self.in_axes, self.out_axes)(state_seq, log_probs)
@ms_function
def for_net(state_seq, log_probs, ngram_size):
# split and concat along dimension 0
output = []
for i in range(state_seq.shape[0]):
out = P.NoRepeatNGram(ngram_size)(state_seq[i], log_probs[i])
output.append(out)
return F.stack(output)
state_seq = Tensor(np.tile(base_state_seq, (2, 1, 1, 1)))
log_probs = Tensor(np.tile(base_log_probs, (2, 1, 1, 1)))
output = VmapNet(Net(3), (0, 0), 0)(state_seq, log_probs)
fornet_output = for_net(state_seq, log_probs, 3)
np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)
def vmap_nested_case():
class Net(nn.Cell):
def __init__(self, ngram_size):
super(Net, self).__init__()
self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)
def construct(self, seq, log):
return self.no_repeat_ngram(seq, log)
class WrapNet(nn.Cell):
def __init__(self, net, inin_axes, inout_axes, outin_axes, outout_axes):
super(WrapNet, self).__init__()
self.net = net
self.ii = inin_axes
self.io = inout_axes
self.oi = outin_axes
self.oo = outout_axes
def construct(self, state_seq, log_probs):
return vmap(vmap(self.net, self.ii, self.io), self.oi, self.oo)(state_seq, log_probs)
@ms_function
def for_net(state_seq, log_probs, ngram_size):
# split and concat along dimension 0 and 1
output = []
for i in range(state_seq.shape[0]):
inner_output = []
for j in range(state_seq.shape[1]):
out = P.NoRepeatNGram(ngram_size)(state_seq[i][j], log_probs[i][j])
inner_output.append(out)
output.append(F.stack(inner_output))
return F.stack(output)
state_seq = Tensor(np.tile(base_state_seq, (2, 2, 1, 1, 1)))
log_probs = Tensor(np.tile(base_log_probs, (2, 2, 1, 1, 1)))
output = WrapNet(Net(3), (0, 0), 0, (0, 0), 0)(state_seq, log_probs)
fornet_output = for_net(state_seq, log_probs, 3)
np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cpu():
"""
Feature: test NoRepeatNGram vmap on CPU.
Description: inputs with batch.
Expectation: the result match with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
vmap_case()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cpu_nested():
"""
Feature: test nested NoRepeatNGram vmap on CPU.
Description: inputs with batch.
Expectation: the result match with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
vmap_nested_case()