forked from mindspore-Ecosystem/mindspore
!38953 [NoRepeatNGram] add cpu op and npu dynamic shape
Merge pull request !38953 from bichaoyang/master
This commit is contained in: commit fb990c85d3
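For orientation: NoRepeatNGram is the beam-search constraint that forbids any token which would recreate an n-gram already present in the decoded sequence. A minimal usage sketch of the Python-facing op this patch extends to CPU (the values are illustrative, not taken from the patch):

import numpy as np
import mindspore
from mindspore import Tensor, ops

# Ban every token that would complete an already-seen 3-gram.
no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)

# state_seq: (batch, beam, seq_len) int32 token ids;
# log_probs: (batch, beam, vocab_size) next-token scores.
state_seq = Tensor(np.array([[[1, 2, 1, 2, 5, 1, 2]]], dtype=np.int32))
log_probs = Tensor(np.random.random((1, 1, 10)).astype(np.float32))

out = no_repeat_ngram(state_seq, log_probs)
# The row ends in (1, 2); the 3-grams (1, 2, 1) and (1, 2, 5) occurred earlier,
# so out[0, 0, 1] and out[0, 0, 5] are overwritten with -FLT_MAX.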
plugin/device/cpu/kernel/no_repeat_ngram_cpu_kernel.cc (new file):
@ -0,0 +1,148 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/no_repeat_ngram_cpu_kernel.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <map>
#include <cmath>
#include <limits>

#include "mindspore/core/ops/no_repeat_ngram.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "include/common/thread_pool.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kNoRepeatNGramInputsNum = 2;
constexpr size_t kNoRepeatNGramOutputsNum = 1;
constexpr size_t kNoRepeatNGramDim = 3;
constexpr int64_t kNoRepeatNGramParamValue = 0;
}  // namespace

bool NoRepeatNGramCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::NoRepeatNGram>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "Cast NoRepeatNGram ops failed!";
    return false;
  }
  kernel_name_ = kernel_ptr->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "NoRepeatNGram does not support this kernel data type: " << kernel_attr;
    return false;  // bail out so func_list_ is never indexed with an invalid position
  }

  base_operator_ = base_operator;
  kernel_func_ = func_list_[index].second;
  return true;
}

int NoRepeatNGramCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                      const std::vector<KernelTensorPtr> &outputs,
                                      const std::map<uint32_t, tensor::TensorPtr> &others) {
  int ret = 0;
  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
    MS_LOG(WARNING) << kernel_name_ << " resize failed.";
    return ret;
  }
  state_seq_shape_ = inputs[kIndex0]->GetShapeVector();
  log_probs_shape_ = inputs[kIndex1]->GetShapeVector();
  ngram_size_ = GetValue<int64_t>(base_operator->GetAttr("ngram_size"));
  return 0;
}

void NoRepeatNGramCpuKernelMod::CheckAndInitParams() {
  size_t state_seq_shape_len = state_seq_shape_.size();
  size_t log_probs_shape_len = log_probs_shape_.size();
  if ((state_seq_shape_len != log_probs_shape_len) || (state_seq_shape_len != kNoRepeatNGramDim)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimensions of 'state_seq' (" << state_seq_shape_len
                      << ") and 'log_probs' (" << log_probs_shape_len << ") are illegal; both must be "
                      << kNoRepeatNGramDim << ".";
  }
  state_dim_ = state_seq_shape_[state_seq_shape_len - 1];
  output_dim_ = log_probs_shape_[log_probs_shape_len - 1];
  if ((ngram_size_ <= kNoRepeatNGramParamValue) || (ngram_size_ >= state_dim_)) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', param 'ngram_size' must be in (0, seq_len), but got "
                      << ngram_size_;
  }
  num_seq_ = 1;
  for (size_t i = 0; i < log_probs_shape_len - 1; i++) {
    num_seq_ *= log_probs_shape_[i];
  }
  output_size_ = num_seq_ * output_dim_;
}

template <typename T>
bool NoRepeatNGramCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                             const std::vector<kernel::AddressPtr> &workspace,
                                             const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNoRepeatNGramInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNoRepeatNGramOutputsNum, kernel_name_);
  auto *state_seq = reinterpret_cast<int32_t *>(inputs[kIndex0]->addr);
  auto *log_probs = reinterpret_cast<T *>(inputs[kIndex1]->addr);
  auto *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
  CheckAndInitParams();

  // Start from a copy of log_probs; banned tokens are overwritten below.
  for (size_t i = 0; i < LongToSize(output_size_); i++) {
    output[i] = log_probs[i];
  }

  std::vector<int32_t> array_dim(state_dim_);
  std::vector<int32_t> array_ngram(ngram_size_ - 1);
  for (int64_t i = 0; i < num_seq_; i++) {
    int64_t src_index_i = state_dim_ * i;
    int64_t output_index_i = output_dim_ * i;
    // Copy the sequence and remember its trailing (ngram_size - 1) tokens.
    for (int64_t k = 0; k < state_dim_; k++) {
      int64_t src_index_k = k + src_index_i;
      array_dim[k] = static_cast<int32_t>(state_seq[src_index_k]);
      if (k > (state_dim_ - ngram_size_)) {
        array_ngram[k + ngram_size_ - state_dim_ - 1] = static_cast<int32_t>(state_seq[src_index_k]);
      }
    }
    // Wherever the trailing prefix occurred earlier in the sequence, ban the
    // token that completed the n-gram there.
    for (int64_t j = 0; j < state_dim_ - ngram_size_ + 1; j++) {
      if (std::equal(array_ngram.begin(), array_ngram.end(), array_dim.begin() + j)) {
        int64_t output_index_j = static_cast<int64_t>(array_dim[j + ngram_size_ - 1]);
        output[output_index_i + output_index_j] = -(std::numeric_limits<T>::max)();
      }
    }
  }
  return true;
}

std::vector<std::pair<KernelAttr, NoRepeatNGramCpuKernelMod::NoRepeatNGramFunc>> NoRepeatNGramCpuKernelMod::func_list_ =
  {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    &NoRepeatNGramCpuKernelMod::LaunchKernel<double>},
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    &NoRepeatNGramCpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
    &NoRepeatNGramCpuKernelMod::LaunchKernel<float16>}};

std::vector<KernelAttr> NoRepeatNGramCpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, NoRepeatNGramFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NoRepeatNGram, NoRepeatNGramCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
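To make the banning rule concrete, here is a small NumPy model of the loop in LaunchKernel above. It is an illustrative sketch, not code from the patch (the helper name no_repeat_ngram_ref is made up):

import numpy as np

def no_repeat_ngram_ref(state_seq, log_probs, ngram_size):
    """Reference model of the CPU kernel: for every (batch, beam) row, find
    earlier occurrences of the row's trailing (ngram_size - 1) tokens and ban
    the token that completed the n-gram there by writing -FLT_MAX."""
    out = log_probs.copy()
    seq_len = state_seq.shape[-1]
    flat_seq = state_seq.reshape(-1, seq_len)
    flat_out = out.reshape(-1, log_probs.shape[-1])
    for row in range(flat_seq.shape[0]):
        seq = flat_seq[row]
        prefix = list(seq[seq_len - ngram_size + 1:])  # trailing n-1 tokens
        for j in range(seq_len - ngram_size + 1):
            if list(seq[j:j + ngram_size - 1]) == prefix:
                banned = seq[j + ngram_size - 1]
                flat_out[row, banned] = -np.finfo(out.dtype).max
    return out

# Example: a row that repeats its trailing 2-gram gets the completing tokens banned.
seq = np.array([[[1, 2, 1, 2, 5, 1, 2]]], dtype=np.int32)
probs = np.random.random((1, 1, 10)).astype(np.float32)
res = no_repeat_ngram_ref(seq, probs, 3)
# res[0, 0, 1] and res[0, 0, 5] are now -FLT_MAX; everything else is unchanged.

Applied to the test data further down, this model reproduces exactly the four -FLT_MAX positions the new tests expect.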
plugin/device/cpu/kernel/no_repeat_ngram_cpu_kernel.h (new file):
@ -0,0 +1,75 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include <utility>
#include <map>

#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class NoRepeatNGramCpuKernelMod : public NativeCpuKernelMod {
 public:
  NoRepeatNGramCpuKernelMod() = default;
  ~NoRepeatNGramCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, workspace, outputs);
  }

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  void CheckAndInitParams();
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  using NoRepeatNGramFunc =
    std::function<bool(NoRepeatNGramCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

  static std::vector<std::pair<KernelAttr, NoRepeatNGramFunc>> func_list_;
  NoRepeatNGramFunc kernel_func_{nullptr};
  BaseOperatorPtr base_operator_;
  std::vector<int64_t> state_seq_shape_;
  std::vector<int64_t> log_probs_shape_;
  int64_t ngram_size_{1};
  int64_t output_size_{1};
  int64_t state_dim_{1};
  int64_t output_dim_{1};
  int64_t num_seq_{1};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NO_REPEAT_NGRAM_CPU_KERNEL_H_
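The wiring here is the usual CPU-kernel dispatch pattern: Launch forwards to kernel_func_, which Init selected from the static func_list_ by matching the runtime KernelAttr, so one kernel class serves the float16, float32, and float64 instantiations of LaunchKernel.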
Primitive registration (core ops definition header):
@ -545,6 +545,7 @@ GVAR_DEF(PrimitivePtr, kPrimMatrixBandPart, std::make_shared<Primitive>(kMatrixB
 GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
 GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValue, std::make_shared<Primitive>("NonZeroWithValue"));
 GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValueShape, std::make_shared<Primitive>("NonZeroWithValueShape"));
+GVAR_DEF(PrimitivePtr, kPrimNoRepeatNGram, std::make_shared<Primitive>("NoRepeatNGram"));
 GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared<Primitive>(kRealInner));
 GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared<Primitive>(kReal));
 GVAR_DEF(PrimitivePtr, kPrimImag, std::make_shared<Primitive>(kImag));
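This single registration makes the primitive addressable from C++ as prim::kPrimNoRepeatNGram; the REGISTER_PRIMITIVE_EVAL_IMPL call in ops/no_repeat_ngram.cc below uses exactly that symbol to attach the shape and type inference.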
ops/no_repeat_ngram.cc (new file):
@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/no_repeat_ngram.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr NoRepeatNGramInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  // The output has the shape of log_probs (input 1).
  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
  auto in_shape = shape_map[kShape];
  auto min_shape = shape_map[kMinShape];
  auto max_shape = shape_map[kMaxShape];
  // Dynamic-shape path: propagate the min/max shape when it is known.
  if (min_shape.size() != 0 && max_shape.size() != 0) {
    return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
  }
  return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr NoRepeatNGramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  // state_seq must be int32; log_probs may be float16/32/64 and sets the output type.
  std::map<std::string, TypePtr> seq_types;
  (void)seq_types.emplace("seq_type", input_args[0]->BuildType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(seq_types, {kInt32}, prim->name());
  std::set<TypePtr> valid_params_types = {kFloat16, kFloat32, kFloat64};
  std::map<std::string, TypePtr> log_types;
  (void)log_types.emplace("log_types", input_args[1]->BuildType());
  return CheckAndConvertUtils::CheckTensorTypeSame(log_types, valid_params_types, prim->name());
}
}  // namespace

MIND_API_OPERATOR_IMPL(NoRepeatNGram, BaseOperator);
AbstractBasePtr NoRepeatNGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto type = NoRepeatNGramInferType(primitive, input_args);
  auto shape = NoRepeatNGramInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NoRepeatNGram, prim::kPrimNoRepeatNGram, NoRepeatNGramInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
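The net effect of the two infer helpers: the output inherits log_probs' shape and float dtype, and state_seq must be int32. A quick illustrative check (the values are hypothetical, and it assumes a build where this op is available on the current target):

import numpy as np
import mindspore
from mindspore import Tensor, ops

op = ops.NoRepeatNGram(ngram_size=3)
state_seq = Tensor(np.zeros((2, 5, 7), dtype=np.int32))
log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float16))

out = op(state_seq, log_probs)
# Output mirrors log_probs: shape (2, 5, 10), dtype float16.
assert out.shape == (2, 5, 10)
assert out.dtype == mindspore.float16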
ops/no_repeat_ngram.h (new file):
@ -0,0 +1,45 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NO_REPEAT_NGRAM_H
#define MINDSPORE_NO_REPEAT_NGRAM_H
#include <map>
#include <vector>
#include <string>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameNoRepeatNGram = "NoRepeatNGram";
/// \brief Updates log_probs with repeat n-grams. Refer to the Python API mindspore.ops.NoRepeatNGram for details.
class MIND_API NoRepeatNGram : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(NoRepeatNGram);
  /// \brief Constructor.
  NoRepeatNGram() : BaseOperator(kNameNoRepeatNGram) { InitIOName({"state_seq", "log_probs"}, {"out"}); }
  /// \brief Init.
  void Init() const {}
};

abstract::AbstractBasePtr NoRepeatNGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_NO_REPEAT_NGRAM_H
Python vmap rule registration:
@ -38,6 +38,30 @@ from ..composite import _VmapGeneralRule
 from ..operations.array_ops import TensorScatterElements
 
 
+@vmap_rules_getters.register(P.NoRepeatNGram)
+def get_no_repeat_ngram_vmap_rule(prim, axis_size):
+    """VmapRule for `NoRepeatNGram` operation."""
+
+    def vmap_rule(state_seq_bdim, log_probs_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, state_seq_bdim, log_probs_bdim)
+        if is_all_none:
+            return result
+
+        state_seq, state_seq_dim = state_seq_bdim
+        log_probs, log_probs_dim = log_probs_bdim
+        state_seq = _bdim_at_front(state_seq, state_seq_dim, axis_size)
+        log_probs = _bdim_at_front(log_probs, log_probs_dim, axis_size)
+        s_ori_shape = F.shape(state_seq)
+        l_ori_shape = F.shape(log_probs)
+        state_seq = F.reshape(state_seq, (-1,) + s_ori_shape[-2:])
+        log_probs = F.reshape(log_probs, (-1,) + l_ori_shape[-2:])
+        out = prim(state_seq, log_probs)
+        out = F.reshape(out, l_ori_shape)
+        return (out, 0)
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register("Cast")
 def get_cast_vmap_rule(prim, axis_size):
     """VmapRule for `Cast` operation."""
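The rule folds every leading batch dimension into axis 0 because the kernel only accepts 3-D inputs, then restores the original shape afterwards. The same trick in plain NumPy, assuming hypothetical 4-D batched inputs:

import numpy as np

state_seq = np.zeros((4, 2, 5, 7), dtype=np.int32)     # (vmap axis, batch, beam, seq_len)
log_probs = np.zeros((4, 2, 5, 10), dtype=np.float32)  # (vmap axis, batch, beam, vocab)

s3 = state_seq.reshape((-1,) + state_seq.shape[-2:])   # -> (8, 5, 7)
l3 = log_probs.reshape((-1,) + log_probs.shape[-2:])   # -> (8, 5, 10)
# out3 = no_repeat_ngram(s3, l3)                       # the 3-D op call
# out = out3.reshape(log_probs.shape)                  # restore (4, 2, 5, 10)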
Docstring update for ops.NoRepeatNGram:
@ -144,7 +144,7 @@ class NoRepeatNGram(PrimitiveWithInfer):
         TypeError: If neither `state_seq` nor `log_probs` is a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
Ascend ST test update (adds a dynamic-shape case and hoists the test data to module level):
@ -29,32 +29,52 @@ class Net(nn.Cell):
         super(Net, self).__init__()
         self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)
 
-    def construct(self, state_seq, log_probs):
-        return self.no_repeat_ngram(state_seq, log_probs)
+    def construct(self, s, l):
+        return self.no_repeat_ngram(s, l)
 
 
+state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
+                     [9, 3, 9, 5, 4, 1, 5],
+                     [4, 7, 9, 1, 9, 6, 1],
+                     [7, 6, 4, 2, 9, 1, 5],
+                     [7, 5, 8, 9, 9, 3, 9]],
+                    [[7, 7, 2, 7, 9, 9, 4],
+                     [3, 4, 7, 4, 7, 6, 8],
+                     [1, 9, 5, 7, 6, 9, 3],
+                     [4, 8, 6, 4, 5, 6, 4],
+                     [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
+
+log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
+expect_log_probs = log_probs.asnumpy().copy()
+expect_log_probs[0, 0, 1] = -FLT_MAX
+expect_log_probs[0, 0, 5] = -FLT_MAX
+expect_log_probs[1, 3, 5] = -FLT_MAX
+expect_log_probs[1, 4, 8] = -FLT_MAX
+
+
 def test_net():
-    state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
-                         [9, 3, 9, 5, 4, 1, 5],
-                         [4, 7, 9, 1, 9, 6, 1],
-                         [7, 6, 4, 2, 9, 1, 5],
-                         [7, 5, 8, 9, 9, 3, 9]],
-                        [[7, 7, 2, 7, 9, 9, 4],
-                         [3, 4, 7, 4, 7, 6, 8],
-                         [1, 9, 5, 7, 6, 9, 3],
-                         [4, 8, 6, 4, 5, 6, 4],
-                         [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
-
-    log_probs = Tensor(np.random.random((2, 5, 10)).astype(np.float32))
-    expect_log_probs = log_probs.asnumpy().copy()
-    expect_log_probs[0, 0, 1] = -FLT_MAX
-    expect_log_probs[0, 0, 5] = -FLT_MAX
-    expect_log_probs[1, 3, 5] = -FLT_MAX
-    expect_log_probs[1, 4, 8] = -FLT_MAX
-
+    """
+    Feature: test NoRepeatNGram on Ascend.
+    Description: inputs with batch.
+    Expectation: the result matches the expectation.
+    """
     net = Net(ngram_size=3)
     output = net(state_seq, log_probs)
 
     print(expect_log_probs)
     print(output)
     assert np.array_equal(expect_log_probs, output.asnumpy())
+
+
+def test_net_dynamic_shape():
+    """
+    Feature: test NoRepeatNGram dynamic shape on Ascend.
+    Description: inputs with batch.
+    Expectation: the result matches the expectation.
+    """
+    net = Net(ngram_size=3)
+    place_holder_x = Tensor(shape=[None, 5, 7], dtype=mindspore.int32)
+    place_holder_v = Tensor(shape=[None, 5, 10], dtype=mindspore.float32)
+    net.set_inputs(place_holder_x, place_holder_v)
+    output = net(state_seq, log_probs)
+    assert np.array_equal(expect_log_probs, output.asnumpy())
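Why exactly those four positions: with ngram_size=3, each row's trailing 2-gram is searched for earlier in the same row, and the token that completed the 3-gram there is banned. Row [1, 2, 1, 2, 5, 1, 2] ends in (1, 2), which earlier preceded tokens 1 and 5, hence [0, 0, 1] and [0, 0, 5]; row [4, 8, 6, 4, 5, 6, 4] ends in (6, 4), which earlier preceded 5, hence [1, 3, 5]; row [4, 8, 8, 4, 3, 4, 8] ends in (4, 8), which earlier preceded 8, hence [1, 4, 8]. No other row repeats its trailing 2-gram, so nothing else changes.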
New CPU ST test for NoRepeatNGram:
@ -0,0 +1,191 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import vmap
from mindspore.common.api import ms_function
from mindspore import dtype as mstype

FLT_MAX = 3.4028235e+38


class CpuNet(nn.Cell):
    def __init__(self, ngram_size):
        super(CpuNet, self).__init__()
        self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

    def construct(self, state_seq, log_probs):
        return self.no_repeat_ngram(state_seq, log_probs)


base_state_seq = np.array([[[1, 2, 1, 2, 5, 1, 2],
                            [9, 3, 9, 5, 4, 1, 5],
                            [4, 7, 9, 1, 9, 6, 1],
                            [7, 6, 4, 2, 9, 1, 5],
                            [7, 5, 8, 9, 9, 3, 9]],
                           [[7, 7, 2, 7, 9, 9, 4],
                            [3, 4, 7, 4, 7, 6, 8],
                            [1, 9, 5, 7, 6, 9, 3],
                            [4, 8, 6, 4, 5, 6, 4],
                            [4, 8, 8, 4, 3, 4, 8]]], dtype=np.int32)
base_log_probs = np.random.random((2, 5, 10)).astype(np.float32)
base_expect_log_probs = base_log_probs.copy()
base_expect_log_probs[0, 0, 1] = -FLT_MAX
base_expect_log_probs[0, 0, 5] = -FLT_MAX
base_expect_log_probs[1, 3, 5] = -FLT_MAX
base_expect_log_probs[1, 4, 8] = -FLT_MAX


def test_net():
    """
    Feature: test NoRepeatNGram on CPU.
    Description: inputs with batch.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    state_seq = Tensor(base_state_seq)
    log_probs = Tensor(base_log_probs)
    expect_log_probs = base_expect_log_probs

    net = CpuNet(ngram_size=3)
    output = net(state_seq, log_probs)
    assert np.array_equal(expect_log_probs, output.asnumpy())


def test_net_dynamic_shape():
    """
    Feature: test NoRepeatNGram dynamic shape on CPU.
    Description: inputs with batch.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    state_seq = Tensor(base_state_seq)
    log_probs = Tensor(base_log_probs)
    expect_log_probs = base_expect_log_probs

    net = CpuNet(ngram_size=3)
    place_holder_x = Tensor(shape=[None, 5, 7], dtype=mstype.int32)
    place_holder_v = Tensor(shape=[None, 5, 10], dtype=mstype.float32)
    net.set_inputs(place_holder_x, place_holder_v)
    output = net(state_seq, log_probs)
    assert np.array_equal(expect_log_probs, output.asnumpy())


def vmap_case():
    class Net(nn.Cell):
        def __init__(self, ngram_size):
            super(Net, self).__init__()
            self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

        def construct(self, seq, log):
            return self.no_repeat_ngram(seq, log)

    class VmapNet(nn.Cell):
        def __init__(self, net, in_axes, out_axes):
            super(VmapNet, self).__init__()
            self.net = net
            self.in_axes = in_axes
            self.out_axes = out_axes

        def construct(self, state_seq, log_probs):
            return vmap(self.net, self.in_axes, self.out_axes)(state_seq, log_probs)

    @ms_function
    def for_net(state_seq, log_probs, ngram_size):
        # split and concat along dimension 0
        output = []
        for i in range(state_seq.shape[0]):
            out = P.NoRepeatNGram(ngram_size)(state_seq[i], log_probs[i])
            output.append(out)
        return F.stack(output)

    state_seq = Tensor(np.tile(base_state_seq, (2, 1, 1, 1)))
    log_probs = Tensor(np.tile(base_log_probs, (2, 1, 1, 1)))

    output = VmapNet(Net(3), (0, 0), 0)(state_seq, log_probs)
    fornet_output = for_net(state_seq, log_probs, 3)
    np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)


def vmap_nested_case():
    class Net(nn.Cell):
        def __init__(self, ngram_size):
            super(Net, self).__init__()
            self.no_repeat_ngram = P.NoRepeatNGram(ngram_size)

        def construct(self, seq, log):
            return self.no_repeat_ngram(seq, log)

    class WrapNet(nn.Cell):
        def __init__(self, net, inin_axes, inout_axes, outin_axes, outout_axes):
            super(WrapNet, self).__init__()
            self.net = net
            self.ii = inin_axes
            self.io = inout_axes
            self.oi = outin_axes
            self.oo = outout_axes

        def construct(self, state_seq, log_probs):
            return vmap(vmap(self.net, self.ii, self.io), self.oi, self.oo)(state_seq, log_probs)

    @ms_function
    def for_net(state_seq, log_probs, ngram_size):
        # split and concat along dimension 0 and 1
        output = []
        for i in range(state_seq.shape[0]):
            inner_output = []
            for j in range(state_seq.shape[1]):
                out = P.NoRepeatNGram(ngram_size)(state_seq[i][j], log_probs[i][j])
                inner_output.append(out)
            output.append(F.stack(inner_output))
        return F.stack(output)

    state_seq = Tensor(np.tile(base_state_seq, (2, 2, 1, 1, 1)))
    log_probs = Tensor(np.tile(base_log_probs, (2, 2, 1, 1, 1)))

    output = WrapNet(Net(3), (0, 0), 0, (0, 0), 0)(state_seq, log_probs)
    fornet_output = for_net(state_seq, log_probs, 3)
    np.testing.assert_allclose(output.asnumpy(), fornet_output.asnumpy(), rtol=1e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cpu():
    """
    Feature: test NoRepeatNGram vmap on CPU.
    Description: inputs with batch.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    vmap_case()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cpu_nested():
    """
    Feature: test nested NoRepeatNGram vmap on CPU.
    Description: inputs with batch.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    vmap_nested_case()