forked from mindspore-Ecosystem/mindspore
!38953 [NoRepeatNGram] add cpu op and npu dynamic shape
Merge pull request !38953 from bichaoyang/master
commit fb990c85d3
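Note: the operator added in this change masks, in log_probs, any token that would recreate an n-gram already present in state_seq. A minimal usage sketch (shapes and values mirror the tests added below; this is illustrative and not part of the change itself):

import numpy as np
from mindspore import Tensor, context, ops

context.set_context(device_target="CPU")  # exercises the new CPU kernel; Ascend behaves as before

# state_seq: already generated token ids, int32, shape (batch, beam, seq_len)
state_seq = Tensor(np.array([[[1, 2, 1, 2, 5, 1, 2]]], dtype=np.int32))
# log_probs: next-token scores, float32, shape (batch, beam, vocab_size)
log_probs = Tensor(np.random.random((1, 1, 10)).astype(np.float32))

no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
out = no_repeat_ngram(state_seq, log_probs)
# Tokens 1 and 5 would extend the trailing 2-gram (1, 2) into a repeated 3-gram,
# so out[0, 0, 1] and out[0, 0, 5] are pushed to the dtype's lowest value; the rest is unchanged.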
@@ -0,0 +1,148 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/no_repeat_ngram_cpu_kernel.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include <cmath>
#include <limits>

#include "mindspore/core/ops/no_repeat_ngram.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "include/common/thread_pool.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kNoRepeatNGramInputsNum = 2;
constexpr size_t kNoRepeatNGramOutputsNum = 1;
constexpr size_t kNoRepeatNGramDim = 3;
constexpr int64_t kNoRepeatNGramParamValue = 0;
}  // namespace

bool NoRepeatNGramCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::NoRepeatNGram>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "Cast NoRepeatNGram ops failed!";
    return false;
  }
  kernel_name_ = kernel_ptr->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "NoRepeatNGram does not support this kernel data type: " << kernel_attr;
    return false;
  }

  base_operator_ = base_operator;
  kernel_func_ = func_list_[index].second;
  return true;
}

int NoRepeatNGramCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                      const std::vector<KernelTensorPtr> &outputs,
                                      const std::map<uint32_t, tensor::TensorPtr> &others) {
  int ret = 0;
  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
    MS_LOG(WARNING) << kernel_name_ << " resize failed.";
    return ret;
  }
  state_seq_shape_ = inputs[kIndex0]->GetShapeVector();
  log_probs_shape_ = inputs[kIndex1]->GetShapeVector();
  ngram_size_ = GetValue<int64_t>(base_operator->GetAttr("ngram_size"));
  return 0;
}

void NoRepeatNGramCpuKernelMod::CheckAndInitParams() {
  size_t state_seq_shape_len = state_seq_shape_.size();
  size_t log_probs_shape_len = log_probs_shape_.size();
  if ((state_seq_shape_len != log_probs_shape_len) || (state_seq_shape_len != kNoRepeatNGramDim)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'state_seq' " << state_seq_shape_len
                      << " and 'log_probs' " << log_probs_shape_len << " is illegal.";
  }
  state_dim_ = state_seq_shape_[state_seq_shape_len - 1];
  output_dim_ = log_probs_shape_[log_probs_shape_len - 1];
  if ((ngram_size_ <= kNoRepeatNGramParamValue) || (ngram_size_ >= state_dim_)) {
    MS_LOG(EXCEPTION) << "Param ngram_size is illegal";
  }
  num_seq_ = 1;
  for (size_t i = 0; i < log_probs_shape_len - 1; i++) {
    num_seq_ *= log_probs_shape_[i];
  }
  output_size_ = num_seq_ * output_dim_;
}

template <typename T>
bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                             const std::vector<kernel::AddressPtr> &workspace,
                                             const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);
  auto *state_seq = reinterpret_cast<int32_t *>(inputs[kIndex0]->addr);
  auto *log_probs = reinterpret_cast<T *>(inputs[kIndex1]->addr);
  auto *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
  CheckAndInitParams();

  // The output starts as a copy of log_probs; banned tokens are overwritten below.
  for (size_t i = 0; i < LongToSize(output_size_); i++) {
    output[i] = log_probs[i];
  }

  std::vector<int32_t> array_dim(state_dim_);
  std::vector<int32_t> array_ngram(ngram_size_ - 1);
  for (int64_t i = 0; i < num_seq_; i++) {
    int64_t src_index_i = state_dim_ * i;
    int64_t output_index_i = output_dim_ * i;
    // Copy the current sequence and keep its last (ngram_size_ - 1) tokens as the query prefix.
    for (int64_t k = 0; k < state_dim_; k++) {
      int64_t src_index_k = k + src_index_i;
      array_dim[k] = static_cast<int32_t>(state_seq[src_index_k]);
      if (k > (state_dim_ - ngram_size_)) {
        array_ngram[k + ngram_size_ - state_dim_ - 1] = static_cast<int32_t>(state_seq[src_index_k]);
      }
    }
    // Wherever the prefix already occurs in the sequence, ban the token that completed that n-gram.
    for (int64_t j = 0; j < state_dim_ - ngram_size_ + 1; j++) {
      if (equal(array_ngram.begin(), array_ngram.end(), array_dim.begin() + j)) {
        int64_t output_index_j = static_cast<int64_t>(array_dim[j + ngram_size_ - 1]);
        output[output_index_i + output_index_j] = -(std::numeric_limits<T>::max)();
      }
    }
  }
  return true;
}

std::vector<std::pair<KernelAttr, NoRepeatNGramCpuKernelMod::NoRepeatNGramFunc>> NoRepeatNGramCpuKernelMod::func_list_ =
  {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    &NoRepeatNGramCpuKernelMod::LaunchKernel<double>},
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    &NoRepeatNGramCpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
    &NoRepeatNGramCpuKernelMod::LaunchKernel<float16>}};

std::vector<KernelAttr> NoRepeatNGramCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, NoRepeatNGramFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NoRepeatNGram, NoRepeatNGramCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
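Note: for readers skimming the kernel above, the inner loops amount to the following NumPy re-expression (illustrative only; the function and variable names here are not part of the change). Every window of ngram_size - 1 tokens that equals the sequence's trailing window bans the token that followed it:

import numpy as np

def no_repeat_ngram_reference(state_seq, log_probs, ngram_size):
    # Mirrors LaunchKernel: copy log_probs, then mask tokens that would repeat an n-gram.
    out = log_probs.copy()
    flat_seq = state_seq.reshape(-1, state_seq.shape[-1])    # (num_seq, state_dim)
    flat_out = out.reshape(-1, log_probs.shape[-1])          # (num_seq, output_dim), a view into out
    state_dim = flat_seq.shape[-1]
    for i, seq in enumerate(flat_seq):
        prefix = seq[state_dim - ngram_size + 1:]            # last ngram_size - 1 tokens
        for j in range(state_dim - ngram_size + 1):
            if np.array_equal(seq[j:j + ngram_size - 1], prefix):
                banned = seq[j + ngram_size - 1]             # token that completed this n-gram
                flat_out[i, banned] = -np.finfo(log_probs.dtype).max
    return out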
@@ -0,0 +1,75 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include <utility>
#include <map>

#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class NoRepeatNGramCpuKernelMod : public NativeCpuKernelMod {
 public:
  NoRepeatNGramCpuKernelMod() = default;
  ~NoRepeatNGramCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, workspace, outputs);
  }

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  void CheckAndInitParams();
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  using NoRepeatNGramFunc =
    std::function<bool(NoRepeatNGramCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

  static std::vector<std::pair<KernelAttr, NoRepeatNGramFunc>> func_list_;
  NoRepeatNGramFunc kernel_func_{nullptr};
  BaseOperatorPtr base_operator_;
  std::vector<int64_t> state_seq_shape_;
  std::vector<int64_t> log_probs_shape_;
  int64_t ngram_size_{1};
  int64_t output_size_{1};
  int64_t state_dim_{1};
  int64_t output_dim_{1};
  int64_t num_seq_{1};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_
@@ -545,6 +545,7 @@ GVAR_DEF(PrimitivePtr, kPrimMatrixBandPart, std::make_shared<Primitive>(kMatrixB
GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValue, std::make_shared<Primitive>("NonZeroWithValue"));
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValueShape, std::make_shared<Primitive>("NonZeroWithValueShape"));
GVAR_DEF(PrimitivePtr, kPrimNoRepeatNGram, std::make_shared<Primitive>("NoRepeatNGram"));
GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared<Primitive>(kRealInner));
GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared<Primitive>(kReal));
GVAR_DEF(PrimitivePtr, kPrimImag, std::make_shared<Primitive>(kImag));
@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/no_repeat_ngram.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr NoRepeatNGramInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
  auto in_shape = shape_map[kShape];
  auto min_shape = shape_map[kMinShape];
  auto max_shape = shape_map[kMaxShape];
  if (min_shape.size() != 0 && max_shape.size() != 0) {
    return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
  }
  return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr NoRepeatNGramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  std::map<std::string, TypePtr> seq_types;
  (void)seq_types.emplace("seq_type", input_args[0]->BuildType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(seq_types, {kInt32}, prim->name());
  std::set<TypePtr> valid_params_types = {kFloat16, kFloat32, kFloat64};
  std::map<std::string, TypePtr> log_types;
  (void)log_types.emplace("log_types", input_args[1]->BuildType());
  return CheckAndConvertUtils::CheckTensorTypeSame(log_types, valid_params_types, prim->name());
}
}  // namespace

MIND_API_OPERATOR_IMPL(NoRepeatNGram, BaseOperator);
AbstractBasePtr NoRepeatNGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto type = NoRepeatNGramInferType(primitive, input_args);
  auto shape = NoRepeatNGramInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NoRepeatNGram, prim::kPrimNoRepeatNGram, NoRepeatNGramInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
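Note: taken together, the infer helpers above fix the operator's contract: state_seq must be an int32 tensor, log_probs must be float16, float32, or float64, and the output copies both the shape and the dtype of log_probs (the same combinations the CPU kernel registers in func_list_). A quick check of that contract through the public API, assuming a build that ships this op:

import numpy as np
from mindspore import Tensor, ops

state_seq = Tensor(np.zeros((2, 5, 7), dtype=np.int32))
log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float16))

out = ops.NoRepeatNGram(ngram_size=3)(state_seq, log_probs)
assert out.shape == log_probs.shape   # shape is taken from log_probs
assert out.dtype == log_probs.dtype   # dtype follows log_probs (float16 here)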
@@ -0,0 +1,45 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NO_REPEAT_NGRAM_H
#define MINDSPORE_NO_REPEAT_NGRAM_H
#include <map>
#include <vector>
#include <string>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameNoRepeatNGram = "NoRepeatNGram";
/// \brief Masks, in log_probs, the probabilities of tokens that would repeat an n-gram already present in state_seq.
class MIND_API NoRepeatNGram : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(NoRepeatNGram);
  /// \brief Constructor.
  NoRepeatNGram() : BaseOperator(kNameNoRepeatNGram) { InitIOName({"state_seq", "log_probs"}, {"out"}); }
  /// \brief Init.
  void Init() const {}
};

abstract::AbstractBasePtr NoRepeatNGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_NO_REPEAT_NGRAM_H
@@ -38,6 +38,30 @@ from ..composite import _VmapGeneralRule
from ..operations.array_ops import TensorScatterElements


@vmap_rules_getters.register(P.NoRepeatNGram)
def get_no_repeat_ngram_vmap_rule(prim, axis_size):
    """VmapRule for `NoRepeatNGram` operation."""

    def vmap_rule(state_seq_bdim, log_probs_bdim):
        is_all_none, result = vmap_general_preprocess(prim, state_seq_bdim, log_probs_bdim)
        if is_all_none:
            return result

        state_seq, state_seq_dim = state_seq_bdim
        log_probs, log_probs_dim = log_probs_bdim
        state_seq = _bdim_at_front(state_seq, state_seq_dim, axis_size)
        log_probs = _bdim_at_front(log_probs, log_probs_dim, axis_size)
        s_ori_shape = F.shape(state_seq)
        l_ori_shape = F.shape(log_probs)
        state_seq = F.reshape(state_seq, (-1,) + s_ori_shape[-2:])
        log_probs = F.reshape(log_probs, (-1,) + l_ori_shape[-2:])
        out = prim(state_seq, log_probs)
        out = F.reshape(out, l_ori_shape)
        return (out, 0)

    return vmap_rule


@vmap_rules_getters.register("Cast")
def get_cast_vmap_rule(prim, axis_size):
    """VmapRule for `Cast` operation."""
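Note: the vmap rule above batches by flattening every extra leading axis into the first dimension of the 3-D inputs, invoking the primitive once, and reshaping the result back. In user code that means an additional batch axis can simply be mapped over; a sketch along the lines of the CPU vmap test added later in this change (the class names here are illustrative, not part of the change):

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap


class Inner(nn.Cell):
    def __init__(self, ngram_size):
        super(Inner, self).__init__()
        self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

    def construct(self, seq, log):
        return self.no_repeat_ngram(seq, log)


class Batched(nn.Cell):
    """Maps NoRepeatNGram over one extra leading batch axis."""
    def __init__(self, ngram_size):
        super(Batched, self).__init__()
        self.inner = Inner(ngram_size)

    def construct(self, state_seq, log_probs):
        return vmap(self.inner, (0, 0), 0)(state_seq, log_probs)


context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
state_seq = Tensor(np.random.randint(0, 10, size=(4, 2, 5, 7)).astype(np.int32))
log_probs = Tensor(np.random.random((4, 2, 5, 10)).astype(np.float32))
out = Batched(3)(state_seq, log_probs)
assert out.shape == (4, 2, 5, 10)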
@@ -144,7 +144,7 @@ class NoRepeatNGram(PrimitiveWithInfer):
        TypeError: If neither `state_seq` nor `log_probs` is a Tensor.

    Supported Platforms:
        ``Ascend``
        ``Ascend`` ``CPU``

    Examples:
        >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
@@ -29,32 +29,52 @@ class Net(nn.Cell):
        super(Net, self).__init__()
        self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

    def construct(self, state_seq, log_probs):
        return self.no_repeat_ngram(state_seq, log_probs)
    def construct(self, s, l):
        return self.no_repeat_ngram(s, l)


state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
                     [9, 3, 9, 5, 4, 1, 5],
                     [4, 7, 9, 1, 9, 6, 1],
                     [7, 6, 4, 2, 9, 1, 5],
                     [7, 5, 8, 9, 9, 3, 9]],
                    [[7, 7, 2, 7, 9, 9, 4],
                     [3, 4, 7, 4, 7, 6, 8],
                     [1, 9, 5, 7, 6, 9, 3],
                     [4, 8, 6, 4, 5, 6, 4],
                     [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)

log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
expect_log_probs = log_probs.asnumpy().copy()
expect_log_probs[0, 0, 1] = -FLT_MAX
expect_log_probs[0, 0, 5] = -FLT_MAX
expect_log_probs[1, 3, 5] = -FLT_MAX
expect_log_probs[1, 4, 8] = -FLT_MAX


def test_net():
    state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
                         [9, 3, 9, 5, 4, 1, 5],
                         [4, 7, 9, 1, 9, 6, 1],
                         [7, 6, 4, 2, 9, 1, 5],
                         [7, 5, 8, 9, 9, 3, 9]],
                        [[7, 7, 2, 7, 9, 9, 4],
                         [3, 4, 7, 4, 7, 6, 8],
                         [1, 9, 5, 7, 6, 9, 3],
                         [4, 8, 6, 4, 5, 6, 4],
                         [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)

    log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
    expect_log_probs = log_probs.asnumpy().copy()
    expect_log_probs[0, 0, 1] = -FLT_MAX
    expect_log_probs[0, 0, 5] = -FLT_MAX
    expect_log_probs[1, 3, 5] = -FLT_MAX
    expect_log_probs[1, 4, 8] = -FLT_MAX

    """
    Feature: test NoRepeatNGram on Ascend.
    Description: inputs with batch.
    Expectation: the result match with expect.
    """
    net = Net(ngram_size=3)
    output = net(state_seq, log_probs)

    print(expect_log_probs)
    print(output)
    assert np.array_equal(expect_log_probs, output.asnumpy())


def test_net_dynamic_shape():
    """
    Feature: test NoRepeatNGram dynamic shape on Ascend.
    Description: inputs with batch.
    Expectation: the result match with expect.
    """
    net = Net(ngram_size=3)
    place_holder_x = Tensor(shape=[None, 5, 7], dtype=mindspore.int32)
    place_holder_v = Tensor(shape=[None, 5, 10], dtype=mindspore.float32)
    net.set_inputs(place_holder_x, place_holder_v)
    output = net(state_seq, log_probs)
    assert np.array_equal(expect_log_probs, output.asnumpy())
@@ -0,0 +1,191 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import vmap
from mindspore.common.api import ms_function
from mindspore import dtype as mstype

FLT_MAX = 3.4028235e+38


class CpuNet(nn.Cell):
    def __init__(self, ngram_size):
        super(CpuNet, self).__init__()
        self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

    def construct(self, state_seq, log_probs):
        return self.no_repeat_ngram(state_seq, log_probs)


base_state_seq = np.array([[[1, 2, 1, 2, 5, 1, 2],
                            [9, 3, 9, 5, 4, 1, 5],
                            [4, 7, 9, 1, 9, 6, 1],
                            [7, 6, 4, 2, 9, 1, 5],
                            [7, 5, 8, 9, 9, 3, 9]],
                           [[7, 7, 2, 7, 9, 9, 4],
                            [3, 4, 7, 4, 7, 6, 8],
                            [1, 9, 5, 7, 6, 9, 3],
                            [4, 8, 6, 4, 5, 6, 4],
                            [4, 8, 8, 4, 3, 4, 8]]], dtype=np.int32)
base_log_probs = np.random.random((2, 5, 10)).astype(np.float32)
base_expect_log_probs = base_log_probs.copy()
base_expect_log_probs[0, 0, 1] = -FLT_MAX
base_expect_log_probs[0, 0, 5] = -FLT_MAX
base_expect_log_probs[1, 3, 5] = -FLT_MAX
base_expect_log_probs[1, 4, 8] = -FLT_MAX


def test_net():
    """
    Feature: test NoRepeatNGram on CPU.
    Description: inputs with batch.
    Expectation: the result match with expect.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    state_seq = Tensor(base_state_seq)
    log_probs = Tensor(base_log_probs)
    expect_log_probs = base_expect_log_probs

    net = CpuNet(ngram_size=3)
    output = net(state_seq, log_probs)
    assert np.array_equal(expect_log_probs, output.asnumpy())


def test_net_dynamic_shape():
    """
    Feature: test NoRepeatNGram dynamic shape on CPU.
    Description: inputs with batch.
    Expectation: the result match with expect.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    state_seq = Tensor(base_state_seq)
    log_probs = Tensor(base_log_probs)
    expect_log_probs = base_expect_log_probs

    net = CpuNet(ngram_size=3)
    place_holder_x = Tensor(shape=[None, 5, 7], dtype=mstype.int32)
    place_holder_v = Tensor(shape=[None, 5, 10], dtype=mstype.float32)
    net.set_inputs(place_holder_x, place_holder_v)
    output = net(state_seq, log_probs)
    assert np.array_equal(expect_log_probs, output.asnumpy())


def vmap_case():
    class Net(nn.Cell):
        def __init__(self, ngram_size):
            super(Net, self).__init__()
            self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

        def construct(self, seq, log):
            return self.no_repeat_ngram(seq, log)

    class VmapNet(nn.Cell):
        def __init__(self, net, in_axes, out_axes):
            super(VmapNet, self).__init__()
            self.net = net
            self.in_axes = in_axes
            self.out_axes = out_axes

        def construct(self, state_seq, log_probs):
            return vmap(self.net, self.in_axes, self.out_axes)(state_seq, log_probs)

    @ms_function
    def for_net(state_seq, log_probs, ngram_size):
        # split and concat along dimension 0
        output = []
        for i in range(state_seq.shape[0]):
            out = P.NoRepeatNGram(ngram_size)(state_seq[i], log_probs[i])
            output.append(out)
        return F.stack(output)

    state_seq = Tensor(np.tile(base_state_seq, (2, 1, 1, 1)))
    log_probs = Tensor(np.tile(base_log_probs, (2, 1, 1, 1)))

    output = VmapNet(Net(3), (0, 0), 0)(state_seq, log_probs)
    fornet_output = for_net(state_seq, log_probs, 3)
    np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)


def vmap_nested_case():
    class Net(nn.Cell):
        def __init__(self, ngram_size):
            super(Net, self).__init__()
            self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

        def construct(self, seq, log):
            return self.no_repeat_ngram(seq, log)

    class WrapNet(nn.Cell):
        def __init__(self, net, inin_axes, inout_axes, outin_axes, outout_axes):
            super(WrapNet, self).__init__()
            self.net = net
            self.ii = inin_axes
            self.io = inout_axes
            self.oi = outin_axes
            self.oo = outout_axes

        def construct(self, state_seq, log_probs):
            return vmap(vmap(self.net, self.ii, self.io), self.oi, self.oo)(state_seq, log_probs)

    @ms_function
    def for_net(state_seq, log_probs, ngram_size):
        # split and concat along dimension 0 and 1
        output = []
        for i in range(state_seq.shape[0]):
            inner_output = []
            for j in range(state_seq.shape[1]):
                out = P.NoRepeatNGram(ngram_size)(state_seq[i][j], log_probs[i][j])
                inner_output.append(out)
            output.append(F.stack(inner_output))
        return F.stack(output)

    state_seq = Tensor(np.tile(base_state_seq, (2, 2, 1, 1, 1)))
    log_probs = Tensor(np.tile(base_log_probs, (2, 2, 1, 1, 1)))

    output = WrapNet(Net(3), (0, 0), 0, (0, 0), 0)(state_seq, log_probs)
    fornet_output = for_net(state_seq, log_probs, 3)
    np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cpu():
    """
    Feature: test NoRepeatNGram vmap on CPU.
    Description: inputs with batch.
    Expectation: the result match with expect.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    vmap_case()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cpu_nested():
    """
    Feature: test nested NoRepeatNGram vmap on CPU.
    Description: inputs with batch.
    Expectation: the result match with expect.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    vmap_nested_case()