From d493b0e9b40ec3acdd232f0632e56d9302b6a1e8 Mon Sep 17 00:00:00 2001 From: huangxinjing Date: Fri, 10 Feb 2023 11:20:43 +0800 Subject: [PATCH] Add aicpu for mask --- .../ops/mindspore.ops.GenerateEodMask.rst | 22 +++ mindspore/ccsrc/include/common/utils/utils.h | 1 + .../ms_kernel/generate_eod_mask_kernels.cc | 123 +++++++++++++++ .../ms_kernel/generate_eod_mask_kernels.h | 37 +++++ .../device/ascend/kernel/aicpu/aicpu_util.h | 1 + .../optimizer/mindir/aicpu_lib_select.cc | 1 + mindspore/core/ops/core_ops.h | 1 + mindspore/core/ops/generate_eod_mask.cc | 98 ++++++++++++ mindspore/core/ops/generate_eod_mask.h | 54 +++++++ mindspore/core/ops/op_name.h | 1 + .../mindspore/ops/_op_impl/aicpu/__init__.py | 1 + .../ops/_op_impl/aicpu/generate_eod_mask.py | 36 +++++ .../mindspore/ops/operations/__init__.py | 3 +- .../mindspore/ops/operations/inner_ops.py | 46 ++++++ tests/st/parallel/test_generate_eod_mask.py | 143 ++++++++++++++++++ .../python/parallel/test_generate_eod_mask.py | 83 ++++++++++ 16 files changed, 650 insertions(+), 1 deletion(-) create mode 100644 docs/api/api_python/ops/mindspore.ops.GenerateEodMask.rst create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.cc create mode 100644 mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.h create mode 100644 mindspore/core/ops/generate_eod_mask.cc create mode 100644 mindspore/core/ops/generate_eod_mask.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/generate_eod_mask.py create mode 100644 tests/st/parallel/test_generate_eod_mask.py create mode 100644 tests/ut/python/parallel/test_generate_eod_mask.py diff --git a/docs/api/api_python/ops/mindspore.ops.GenerateEodMask.rst b/docs/api/api_python/ops/mindspore.ops.GenerateEodMask.rst new file mode 100644 index 00000000000..1dbb0f2a4b6 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.GenerateEodMask.rst @@ -0,0 
+1,22 @@ +mindspore.ops.GenerateEodMask +============================= + +.. py:class:: mindspore.ops.GenerateEodMask(eod_token_id) + + 根据输入的 `inputs_ids`, 遇到 `eod_token_id` 时,会将输出的位置编码和注意力编码全部重置。 + 即 `position_id` 从0开始重新计数,同时对应的掩码矩阵也会填充为0。 + + 参数: + - **eod_token_id** (int) - `eod_token_id` 的数值。在NLP场景中,这个值对应词表中的 `EndOfDocument` 的符号编码。 + + 输入: + - **inputs_ids** (Tensor) - 词序列。是一个二维Tensor,其shape为 :math:`(batch\_size, seq\_length)` 。 + + 输出: + - **position_id** (Tensor) - 位置编码矩阵。数据类型和shape与输入 `inputs_ids` 相同。 + - **attention_mask** (Tensor) - 注意力掩码矩阵。类型为float16,其shape为: :math:`(batch\_size, seq\_length, seq\_length)` 。 + + 异常: + - **TypeError** - 如果 `eod_token_id` 的数据类型不是int。 + - **TypeError** - 如果 `inputs_ids` 的数据类型不是整数类型。 + - **ValueError** - 如果 `inputs_ids` 不是二维的Tensor。 diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index 816d014f148..73afe7b5d0b 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -321,6 +321,7 @@ constexpr auto kEnvironGetOpName = "EnvironGet"; constexpr auto kEnvironSetOpName = "EnvironSet"; constexpr auto kEqualOpName = "Equal"; constexpr auto kErfOpName = "Erf"; +constexpr auto kGenerateEodMaskOpName = "GenerateEodMask"; constexpr auto kEuclideanNormOpName = "EuclideanNorm"; constexpr auto kEuclideanNormDOpName = "EuclideanNormD"; constexpr auto kExpandOpName = "Expand"; diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.cc new file mode 100644 index 00000000000..48f02d41dbb --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use
this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "./generate_eod_mask_kernels.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils/kernel_util.h"
+#include "cpu_kernel_utils.h"
+
+namespace {
+const char *kGenerateEodMask = "GenerateEodMask";
+constexpr auto kInputSize = 1;
+constexpr auto kOutputSize = 2;
+}  // namespace
+namespace aicpu {
+// Entry point: checks the tensor counts, reads eod_token_id (0 when the attribute is absent), dispatches by dtype.
+uint32_t GenerateEodMaskCpuKernel::Compute(CpuKernelContext &ctx) {
+  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputSize, kOutputSize), "GenerateEodMaskCpu check input and output failed.");
+  Tensor *input = ctx.Input(0);
+  auto data_type_in = input->GetDataType();
+  AttrValue *eod_token_value = ctx.GetAttr("eod_token_id");
+  int64_t eod_token_id = (eod_token_value == nullptr) ? 0 : eod_token_value->GetInt();
+  switch (data_type_in) {
+    case DT_UINT16:
+      return ComputeKernel(ctx, eod_token_id);
+    case DT_UINT32:
+      return ComputeKernel(ctx, eod_token_id);
+    case DT_UINT64:
+      return ComputeKernel(ctx, eod_token_id);
+    case DT_INT32:
+      return ComputeKernel(ctx, eod_token_id);
+    case DT_INT64:
+      return ComputeKernel(ctx, eod_token_id);
+    default:
+      KERNEL_LOG_ERROR("GenerateEodMask kernel data type [%s] not support.", DTypeStr(data_type_in).c_str());
+      return KERNEL_STATUS_PARAM_INVALID;
+  }
+}
+
+// Writes the position ids to output 0 and the Eigen::half attention mask to output 1, sharded over the batch.
+template
+uint32_t GenerateEodMaskCpuKernel::ComputeKernel(CpuKernelContext &ctx, const T &eod_token_id) {
+  auto input_idsptr = reinterpret_cast(ctx.Input(0)->GetData());
+  auto input_positionptr = reinterpret_cast(ctx.Output(0)->GetData());
+  auto outputptr = reinterpret_cast(ctx.Output(1)->GetData());
+  auto output_shape = ctx.Output(1)->GetTensorShape()->GetDimSizes();
+  size_t output_size =  // seed with int64_t: an int seed makes std::accumulate fold (and overflow) in int
+    std::accumulate(output_shape.begin(), output_shape.end(), int64_t{1}, std::multiplies<int64_t>()) * sizeof(Eigen::half);
+  if (memset_s(outputptr, output_size, 0x0, output_size) != EOK) {
+    KERNEL_LOG_ERROR("memset_s failed!");
+    return KERNEL_STATUS_INNER_ERROR;
+  }
+  size_t batch_size = ctx.Input(0)->GetTensorShape()->GetDimSize(0);
+  size_t seq_length = ctx.Input(0)->GetTensorShape()->GetDimSize(1);
+  // Nothing to fill for an empty tensor; for batch_size == 0 this also avoids the zero divisor in get_per_unit_size.
+  if (batch_size == 0 || seq_length == 0) {
+    return KERNEL_STATUS_OK;
+  }
+  // Pass 1: write a lower-triangular block of ones into the mask of every batch entry in [start, end).
+  auto shard_generate_tril = [&](size_t start, size_t end) {
+    size_t x = seq_length * seq_length;
+    for (size_t i = start; i < end; ++i) {
+      for (size_t j = 0; j < seq_length; ++j) {
+        for (size_t k = 0; k < j + 1; ++k) {
+          outputptr[i * x + j * seq_length + k] = (Eigen::half)1.0;
+        }
+      }
+    }
+  };
+  // Pass 2: at every eod token, cut the mask below it and restart the position counter for the following tokens.
+  auto shard_generate_eod_mask = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; ++i) {
+      T sub = 0;      // index at which the current document starts
+      T pre_sub = 0;  // start of the document that the eod token itself still belongs to
+      for (size_t index = 0; index < seq_length; ++index) {
+        size_t sub_index = i * seq_length + index;
+        if (input_idsptr[sub_index] == eod_token_id) {
+          pre_sub = sub;
+          sub = index + 1;
+          size_t seq_n2 = seq_length * seq_length;
+          for (size_t k = index + 1; k < seq_length; ++k) {  // rows after the eod token ...
+            for (size_t m = 0; m < index + 1; ++m) {         // ... lose columns up to and including it
+              outputptr[i * seq_n2 + k * seq_length + m] = (Eigen::half)0.0;
+            }
+          }
+          input_positionptr[sub_index] = index - pre_sub;
+        } else {
+          input_positionptr[sub_index] = index - sub;
+        }
+      }
+    }
+  };
+  // Items per shard: the batch is spread over at most (CPU count - 2) cores, and over at least one.
+  auto get_per_unit_size = [&](int64_t data_size) -> int64_t {
+    const int64_t max_core_num =
+      std::max(static_cast(1), static_cast(aicpu::CpuKernelUtils::GetCPUNum(ctx) - 2));
+    return data_size / std::min(max_core_num, data_size);
+  };
+  KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, batch_size, get_per_unit_size(batch_size), shard_generate_tril),
+                      "GenerateEodMask kernel compute failed.");
+  KERNEL_HANDLE_ERROR(
+    CpuKernelUtils::ParallelFor(ctx, batch_size, get_per_unit_size(batch_size), shard_generate_eod_mask),
+    "GenerateEodMask kernel compute failed.");
+  return KERNEL_STATUS_OK;
+}
+REGISTER_CPU_KERNEL(kGenerateEodMask, GenerateEodMaskCpuKernel);
+}  // namespace aicpu
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.h
new file mode 100644
index 00000000000..b3678c9accf
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/generate_eod_mask_kernels.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AICPU_KERNELS_GENERATEEODMASK_H_ +#define AICPU_KERNELS_GENERATEEODMASK_H_ + +#include +#include "common/kernel_base.h" +#include "cpu_ops_kernel.h" + +namespace aicpu { +class GenerateEodMaskCpuKernel : public CpuKernel { + public: + GenerateEodMaskCpuKernel() = default; + ~GenerateEodMaskCpuKernel() override = default; + + protected: + uint32_t Compute(CpuKernelContext &ctx) override; + + private: + template + uint32_t ComputeKernel(CpuKernelContext &ctx, const T &eod_token_id); +}; +} // namespace aicpu +#endif // AICPU_KERNELS_GENERATEEODMASK_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h index 3b4a26b3cb9..5e1fa20a2c6 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h @@ -72,6 +72,7 @@ constexpr auto kConcatOffset = "ConcatOffset"; constexpr auto kConcatOffsetV1 = "ConcatOffsetV1"; constexpr auto kRandomChoiceWithMask = "RandomChoiceWithMask"; constexpr auto kGatherDGradV2 = "GatherDGradV2"; +constexpr auto kGenerateEodMask = "GenerateEodMask"; constexpr auto kResizeNearestNeighborV2 = "ResizeNearestNeighborV2"; constexpr auto kResizeNearestNeighborV2Grad = "ResizeNearestNeighborV2Grad"; constexpr auto kUpdateCache = "UpdateCache"; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc index 28bbcee4341..f865843277a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc @@ -64,6 +64,7 @@ bool AICpuLibSelectPass::Process(const AnfNodePtr &node) const { mindspore::kComplexAbsOpName, mindspore::kConcatOpName, mindspore::kCosOpName, + 
mindspore::kGenerateEodMaskOpName, mindspore::kCountNonZeroOpName, mindspore::kCumulativeLogsumexpOpName, mindspore::kCumProdOpName, diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 3f7392ac765..c1da88a30e8 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -765,6 +765,7 @@ GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared("NonZero")); GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValue, std::make_shared("NonZeroWithValue")); GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValueShape, std::make_shared("NonZeroWithValueShape")); GVAR_DEF(PrimitivePtr, kPrimNoRepeatNGram, std::make_shared("NoRepeatNGram")); +GVAR_DEF(PrimitivePtr, kPrimGenerateEodMask, std::make_shared("GenerateEodMask")); GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared(kRealInner)); GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared(kReal)); GVAR_DEF(PrimitivePtr, kPrimImag, std::make_shared(kImag)); diff --git a/mindspore/core/ops/generate_eod_mask.cc b/mindspore/core/ops/generate_eod_mask.cc new file mode 100644 index 00000000000..346ccb7a673 --- /dev/null +++ b/mindspore/core/ops/generate_eod_mask.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "ops/generate_eod_mask.h"
+#include
+#include
+#include
+#include
+#include
+
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+void GenerateEodMask::set_eod_token_id(const int64_t eod_token_id) {
+  (void)this->AddAttr(kEodTokenId, api::MakeValue(eod_token_id));
+}
+/// \brief Get EodTokenId.
+///
+/// \return EodTokenId.
+int64_t GenerateEodMask::get_eod_token_id() const { return GetValue(GetAttr(kEodTokenId)); }
+
+MIND_API_OPERATOR_IMPL(GenerateEodMask, BaseOperator);
+
+// AG means auto generated
+class MIND_API AGGenerateEodMaskInfer : public abstract::OpInferBase {
+ public:
+  BaseShapePtr InferShape(const PrimitivePtr &primitive,
+                          const std::vector &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    for (const auto &item : input_args) {
+      MS_EXCEPTION_IF_NULL(item);
+    }
+    const int64_t kInputsIdsRank = 2;
+    auto inputs_ids_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+    auto inputs_ids_shape = inputs_ids_shape_map[kShape];
+    // Input rank still unknown: report both outputs with the rank-any placeholder shape.
+    if (IsDynamicRank(inputs_ids_shape)) {
+      auto any_shape = std::make_shared(std::vector{abstract::Shape::kShapeRankAny});
+      std::vector shapes_list = {any_shape, any_shape};
+      return std::make_shared(shapes_list);
+    }
+    // Validate the rank before calling back(): a rank-0 input has an empty shape vector.
+    (void)CheckAndConvertUtils::CheckInteger("rank of inputs_ids", SizeToLong(inputs_ids_shape.size()), kEqual,
+                                             kInputsIdsRank, prim_name);
+    // attention_mask is (batch_size, seq_length, seq_length): the input shape with its last dimension repeated.
+    ShapeVector attention_mask_shape{inputs_ids_shape.begin(), inputs_ids_shape.end()};
+    attention_mask_shape.push_back(inputs_ids_shape.back());
+    // Output order: position ids (same shape as the input) first, then the attention mask.
+    std::vector shapes_list = {};
+    (void)shapes_list.emplace_back(std::make_shared(inputs_ids_shape));
+    (void)shapes_list.emplace_back(std::make_shared(attention_mask_shape));
+    return std::make_shared(shapes_list);
+  }
+
+  TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    for (const auto &item : input_args) {
+      MS_EXCEPTION_IF_NULL(item);
+    }
+    std::map input_types;
+    auto input_ids_type = input_args[0]->BuildType();
+    (void)input_types.emplace("inputs_ids", input_ids_type);
+    std::set valid_input_types = {kInt32, kInt64, kUInt16, kUInt32, kUInt64};  // the dtypes the AICPU kernel handles
+    (void)CheckAndConvertUtils::CheckTensorTypeSame(input_types, valid_input_types, primitive->name());
+    return std::make_shared(std::vector{input_ids_type, kFloat16});
+  }
+  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
+                                    const std::vector &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    const int64_t kInputsNum = 1;
+    CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
+    auto type = InferType(primitive, input_args);
+    auto shape = InferShape(primitive, input_args);
+    return abstract::MakeAbstract(shape, type);
+  }
+};
+
+REGISTER_PRIMITIVE_OP_INFER_IMPL(GenerateEodMask, prim::kPrimGenerateEodMask, AGGenerateEodMaskInfer, false);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/generate_eod_mask.h b/mindspore/core/ops/generate_eod_mask.h
new file mode 100644
index 00000000000..86a93e88983
--- /dev/null
+++ b/mindspore/core/ops/generate_eod_mask.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_GENERATE_EOD_MASK_H
+#define MINDSPORE_GENERATE_EOD_MASK_H
+#include
+#include
+#include
+#include
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameGenerateEodMask = "GenerateEodMask";
+/// \brief GenerateEodMask operator: input "inputs_ids", outputs "position_ids" and "attention_mask".
+class MIND_API GenerateEodMask : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(GenerateEodMask);
+  /// \brief Constructor.
+  GenerateEodMask() : BaseOperator(kNameGenerateEodMask) {
+    InitIOName({"inputs_ids"}, {"position_ids", "attention_mask"});
+  }
+  /// \brief Init.
+  void Init() const {}
+  /// \brief Set eod_token_id.
+  void set_eod_token_id(const int64_t eod_token_id);
+  /// \brief Get eod_token_id.
+  ///
+  /// \return eod_token_id.
+  int64_t get_eod_token_id() const;
+};
+
+MIND_API abstract::AbstractBasePtr GenerateEodMaskInfer(const abstract::AnalysisEnginePtr &,
+                                                        const PrimitivePtr &primitive,
+                                                        const std::vector &input_args);
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_GENERATE_EOD_MASK_H
diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h
index f58db75bff6..d9d5e8cef17 100644
--- a/mindspore/core/ops/op_name.h
+++ b/mindspore/core/ops/op_name.h
@@ -149,6 +149,7 @@ constexpr auto kNeginf = "neginf";
 constexpr auto kNesterov = "nesterov";
 constexpr auto kNewAxisMask = "new_axis_mask";
 constexpr auto kNgramSize = "ngram_size";
+constexpr auto kEodTokenId = "eod_token_id";
 constexpr auto kNGram = "ngram";
 constexpr auto kNmsThresh = "nms_thresh";
 constexpr auto kNormRegion = "norm_region";
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index ebe8df0987f..5f0b9d96cfa 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -60,6 +60,7 @@ from
.init_data_set_queue import _init_data_set_queue_aicpu from .embedding_lookup import _embedding_lookup_aicpu from .padding import _padding_aicpu from .gather import _gather_aicpu +from .generate_eod_mask import _generate_eod_mask_aicpu from .gather_grad import _gather_grad_aicpu from .gather_d_grad_v2 import _gather_d_grad_v2_aicpu from .gather_d import _gather_d_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/generate_eod_mask.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/generate_eod_mask.py new file mode 100644 index 00000000000..4885970ff45 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/generate_eod_mask.py @@ -0,0 +1,36 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""GenerateEodMask op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+generate_eod_mask_op_info = AiCPURegOp("GenerateEodMask") \
+    .fusion_type("OPAQUE") \
+    .attr("eod_token_id", "int") \
+    .input(0, "inputs_ids", "required") \
+    .output(0, "position_ids", "required") \
+    .output(1, "attention_mask", "required") \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.F16_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
+    .get_op_info()
+
+
+@op_info_register(generate_eod_mask_op_info)
+def _generate_eod_mask_aicpu():
+    """GenerateEodMask AiCPU register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py
index f737df6ba94..98fb90f7da8 100644
--- a/mindspore/python/mindspore/ops/operations/__init__.py
+++ b/mindspore/python/mindspore/ops/operations/__init__.py
@@ -64,7 +64,7 @@ from .image_ops import (CropAndResize, NonMaxSuppressionV3, HSVToRGB, AdjustHue,
                         CombinedNonMaxSuppression, RGBToHSV, ScaleAndTranslate, ResizeLinear1D, ResizeBicubic)
 from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
                         FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor,
-                        FusedAdaFactorWithGlobalNorm)
+                        FusedAdaFactorWithGlobalNorm, GenerateEodMask)
 from .linalg_ops import (Svd, Geqrf)
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2,
                        BatchMatMul, BitwiseAnd, BitwiseOr, Ger,
@@ -170,6 +170,7 @@ __all__ = [
     'Conv3D',
     'Conv2DTranspose',
     'Conv3DTranspose',
+    "GenerateEodMask",
     'FillV2',
     'Flatten',
     'MaxPoolWithArgmax',
diff --git a/mindspore/python/mindspore/ops/operations/inner_ops.py b/mindspore/python/mindspore/ops/operations/inner_ops.py
index 893d9af56d4..2cd7613fc02 100755
--- a/mindspore/python/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/python/mindspore/ops/operations/inner_ops.py
@@ -655,6 +655,52 @@ class FusedAdaFactorWithGlobalNorm(FusedAdaFactor):
         return param_type


+class GenerateEodMask(Primitive):
+    r"""
+    Given the input `inputs_ids`, if `eod_token_id` is found, the output position ids and attention mask are reset.
+    The `position_id` restarts from 0 after each `eod_token_id`, and the matching mask entries are filled with 0.
+
+    Args:
+        eod_token_id (int): In the NLP scenario, this value corresponds to the id of
+            the 'EndOfDocument' symbol in the vocabulary.
+
+    Inputs:
+        - **inputs_ids** (Tensor) - token id, a 2-D Tensor with shape :math:`(batch\_size, seq\_length)`.
+
+    Outputs:
+        - **position_id** (Tensor) - position id matrix with same shape and type as original `inputs_ids`.
+        - **attention_mask** (Tensor) - attention mask matrix with type
+          float16 and shape :math:`(batch\_size, seq\_length, seq\_length)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> op = ops.GenerateEodMask(eod_token_id=0)
+        >>> position, mask = op(Tensor([[1, 0, 3], [1, 0, 0]], dtype=mindspore.int32))
+        >>> print(position)
+        [[0 1 0] [0 1 0]]
+        >>> print(mask)
+        [[[1. 0. 0.]
+          [1. 1. 0.]
+          [0. 0. 1.]]
+         [[1. 0. 0.]
+          [1. 1. 0.]
+          [0. 0. 1.]]]
+
+    Raises:
+        TypeError: If `eod_token_id` is not int.
+        TypeError: If the dtype of `inputs_ids` is not an integer type.
+        ValueError: If `inputs_ids` is not a 2-D Tensor.
+    """
+    @prim_attr_register
+    def __init__(self, eod_token_id):
+        """Initialize GenerateEodMask"""
+        validator.check_value_type("eod_token_id", eod_token_id, [int], self.name)
+        self.init_prim_io_names(inputs=['inputs_ids'],
+                                outputs=['position_ids', 'attention_mask'])
+
+
 class ScaleGrad(PrimitiveWithInfer):
     """
     Scale the input grad according to the loss scale.
diff --git a/tests/st/parallel/test_generate_eod_mask.py b/tests/st/parallel/test_generate_eod_mask.py
new file mode 100644
index 00000000000..13a0abec545
--- /dev/null
+++ b/tests/st/parallel/test_generate_eod_mask.py
@@ -0,0 +1,143 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, eod_token_id):
+        super(Net, self).__init__()
+        self.mask = P.GenerateEodMask(eod_token_id=eod_token_id)
+
+    def construct(self, tensor):
+        return self.mask(tensor)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_generate_eod_mask():
+    """
+    Feature: Test GenerateEodMask.
+ Description: Test eod_token_id 0 + Expectation: no errors. + """ + x = np.array([[1, 0, 3, 4, 0, 6, 7, 8], [1, 0, 3, 0, 0, 6, 7, 0]]) + net = Net(0) + position, mask = net(Tensor(x, dtype=mindspore.int32)) + assert position.shape == (2, 8) + assert mask.shape == (2, 8, 8) + assert np.all(mask.asnumpy() == np.array([[[1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1]], + [[1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1]]])) + + assert np.all(position.asnumpy() == np.array([[0, 1, 0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 0, 0, 1, 2]])) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_generate_eod_mask_negative_value(): + """ + Feature: Test GenerateEodMask.
+ Description: Test EodTokenId Negative + Expectation: no errors + """ + x = np.array([[1, -1, 3, 4, -1, 6, 7, 8], [1, -1, 3, -1, -1, 6, 7, -1]]) + net = Net(eod_token_id=-1) + position, mask = net(Tensor(x, dtype=mindspore.int32)) + assert position.shape == (2, 8) + assert mask.shape == (2, 8, 8) + assert np.all(mask.asnumpy() == np.array([[[1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1]], + [[1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1]]])) + + assert np.all(position.asnumpy() == np.array([[0, 1, 0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 0, 0, 1, 2]])) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_generate_eod_mask_dynamic_inputs(): + """ + Feature: Test GenerateEodMask. + Description: Test dynamic inputs + Expectation: No error. 
+ """ + x = np.array([[1, -1, 3, 4, -1, 6, 7, 8], [1, -1, 3, -1, -1, 6, 7, -1]]) + net = Net(eod_token_id=-1) + dyn_x = Tensor(shape=(None, None), dtype=mindspore.int32) + net.set_inputs(dyn_x) + position, mask = net(Tensor(x, dtype=mindspore.int32)) + assert position.shape == (2, 8) + assert mask.shape == (2, 8, 8) + assert np.all(mask.asnumpy() == np.array([[[1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1]], + [[1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1]]])) + + assert np.all(position.asnumpy() == np.array([[0, 1, 0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 0, 0, 1, 2]])) diff --git a/tests/ut/python/parallel/test_generate_eod_mask.py b/tests/ut/python/parallel/test_generate_eod_mask.py new file mode 100644 index 00000000000..f83482bebdf --- /dev/null +++ b/tests/ut/python/parallel/test_generate_eod_mask.py @@ -0,0 +1,83 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, eod_token_id):
+        super(Net, self).__init__()
+        self.mask = P.GenerateEodMask(eod_token_id=eod_token_id)
+
+    def construct(self, tensor):
+        return self.mask(tensor)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_generate_eod_mask_wrong_type():
+    """
+    Feature: Test GenerateEodMask.
+    Description: Test float32 inputs_ids
+    Expectation: raise TypeError.
+    """
+    x = np.array([[1, 0, 3, 4, 0, 6, 7, 8], [1, 0, 3, 0, 0, 6, 7, 0]])
+    net = Net(0)
+    with pytest.raises(TypeError):
+        net(Tensor(x, dtype=mindspore.float32))
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_generate_eod_mask_float_token_id():
+    """
+    Feature: Test GenerateEodMask.
+    Description: Test float_token_id
+    Expectation: raise TypeError.
+    """
+    # Only build the primitive inside the raises-block. The float32 call that used to
+    # follow raises TypeError on its own (see the wrong-type test above) and would let
+    # this test pass even when the eod_token_id type check is missing.
+    with pytest.raises(TypeError):
+        Net(-1.0)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [mindspore.int32, mindspore.int64, mindspore.uint16,
+                                   mindspore.uint32, mindspore.uint64])
+def test_generate_eod_mask_support_dtype(dtype):
+    """
+    Feature: Test GenerateEodMask.
+    Description: Test multi dtype inputs
+    Expectation: Successful graph compilation.
+ """ + x = np.array([[1, 0, 3, 4, 0, 6, 7, 8], [1, 0, 3, 0, 0, 6, 7, 0]]) + net = Net(0) + net(Tensor(x, dtype=dtype))