c++ infer: dynamic_rnn

yiyanzhi_akane 2022-11-02 14:08:07 +08:00
parent bf776e8807
commit 4777eaa06e
3 changed files with 213 additions and 38 deletions


@@ -0,0 +1,168 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <set>
#include "ops/dynamic_rnn.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
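// Indices of the DynamicRNN inputs (x, w, b, seq_length, init_h, init_c), followed by the expected rank
// of each shape-checked input; kDynRnnNum4 reflects the four LSTM gates packed into w's last dimension.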
constexpr int64_t kDynRnnIdx0 = 0;
constexpr int64_t kDynRnnIdx1 = 1;
constexpr int64_t kDynRnnIdx2 = 2;
constexpr int64_t kDynRnnIdx3 = 3;
constexpr int64_t kDynRnnIdx4 = 4;
constexpr int64_t kDynRnnIdx5 = 5;
constexpr int64_t kDynamicRnnShapeX = 3;
constexpr int64_t kDynamicRnnShapeW = 2;
constexpr int64_t kDynamicRnnShapeB = 1;
constexpr int64_t kDynamicRnnShapeH = 3;
constexpr int64_t kDynamicRnnShapeC = 3;
constexpr int64_t kDynRnnNum4 = 4;
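// Dynamic-shape fallback: each of the eight outputs is reported as a 3-D shape whose dimensions are all
// unknown.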
abstract::TupleShapePtr DynamicRNNInferDynamicShape(const std::vector<AbstractBasePtr> &input_args) {
const int64_t y_shape_num = 3;
ShapeVector y_shape_dyn;
for (int64_t i = 0; i < y_shape_num; ++i) {
y_shape_dyn.push_back(abstract::Shape::kShapeDimAny);
}
abstract::ShapePtr y_shape_dyn_ptr = std::make_shared<abstract::Shape>(y_shape_dyn);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr,
y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr});
}
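// Static-shape validation. Expected shapes: x (num_step, batch_size, input_size),
// w (input_size + hidden_size, 4 * hidden_size), b (4 * hidden_size,), init_h and init_c
// (1, batch_size, hidden_size); seq_length must be None. For example, with input_size = 64 and
// hidden_size = 32, w must be (96, 128) and b must be (128,).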
void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx2]->BuildShape())[kShape];
auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx3]->BuildShape())[kShape];
auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx4]->BuildShape())[kShape];
auto c_shape_ptr = input_args[kDynRnnIdx5]->BuildShape();
auto c_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx5]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kEqual, kDynamicRnnShapeX, op_name);
(void)CheckAndConvertUtils::CheckInteger("w_shape", SizeToLong(w_shape.size()), kEqual, kDynamicRnnShapeW, op_name);
(void)CheckAndConvertUtils::CheckInteger("b_shape", SizeToLong(b_shape.size()), kEqual, kDynamicRnnShapeB, op_name);
(void)CheckAndConvertUtils::CheckInteger("h_shape", SizeToLong(h_shape.size()), kEqual, kDynamicRnnShapeH, op_name);
(void)CheckAndConvertUtils::CheckInteger("c_shape", SizeToLong(c_shape.size()), kEqual, kDynamicRnnShapeC, op_name);
int64_t batch_size = x_shape[kDynRnnIdx1];
int64_t input_size = x_shape[kDynRnnIdx2];
if (seq_shape.size() != 0) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the rank of 'seq_length' must be 0, i.e. 'seq_length' "
<< "must be None, but got rank " << seq_shape.size() << ".";
}
if (w_shape[w_shape.size() - 1] % kDynRnnNum4 != 0) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', w_shape[-1] must be a multiple of 4, but got "
<< w_shape[w_shape.size() - 1] << ".";
}
int64_t hidden_size = w_shape[w_shape.size() - 1] / kDynRnnNum4;
if (w_shape[0] != input_size + hidden_size) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', w_shape[0] must be equal to input_size + hidden_size, but got " << w_shape[0]
<< ".";
}
if (b_shape[0] != w_shape[1]) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', b_shape[0] must be equal to w_shape[1], but got "
<< b_shape[0] << ".";
}
(void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kDynRnnIdx0], kEqual, (int64_t)1, op_name);
(void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_shape[kDynRnnIdx1], kEqual, (int64_t)batch_size, op_name);
(void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kDynRnnIdx2], kEqual, (int64_t)hidden_size, op_name);
const std::map<std::string, BaseShapePtr> shapes = {{"c_shape", c_shape_ptr}};
(void)CheckAndConvertUtils::CheckTensorShapeSame(shapes, h_shape, op_name);
}
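// Shape inference entry: an unknown rank in x or w yields eight rank-unknown outputs, any other dynamic
// dimension falls back to DynamicRNNInferDynamicShape, and the static case is validated and produces
// (num_step, batch_size, hidden_size) for all eight outputs. seq_length (input 3) is recorded as a
// placeholder via the "placeholder_index" attribute.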
abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
std::vector<ValuePtr> placeholder_index = {MakeValue((int64_t)3)};
primitive->set_attr("placeholder_index", MakeValue(placeholder_index));
if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) {
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny})});
}
if (IsDynamic(x_shape) || IsDynamic(w_shape)) {
return DynamicRNNInferDynamicShape(input_args);
}
DynamicRNNShapeCheck(primitive, input_args);
int64_t num_step = x_shape[kDynRnnIdx0];
int64_t batch_size = x_shape[kDynRnnIdx1];
int64_t input_size = x_shape[kDynRnnIdx2];
int64_t hidden_size = w_shape[w_shape.size() - 1] / kDynRnnNum4;
primitive->set_attr("input_size", MakeValue(input_size));
primitive->set_attr("hidden_size", MakeValue(hidden_size));
std::vector<int64_t> y_shape{num_step, batch_size, hidden_size};
abstract::ShapePtr y_shape_ptr = std::make_shared<abstract::Shape>(y_shape);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr});
}
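// Type inference: x, w, init_h and init_c must all be float16, b may be float16 or float32, and
// seq_length must be None. All eight outputs take x's dtype.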
TuplePtr DynamicRNNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto x_dtype = input_args[kDynRnnIdx0]->BuildType();
auto w_dtype = input_args[kDynRnnIdx1]->BuildType();
auto b_dtype = input_args[kDynRnnIdx2]->BuildType();
auto h_dtype = input_args[kDynRnnIdx4]->BuildType();
auto c_dtype = input_args[kDynRnnIdx5]->BuildType();
auto seq_type = input_args[kDynRnnIdx3]->BuildType();
if (seq_type->type_id() != kMetaTypeNone) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', 'seq_length' must be None, but got a non-None input; "
<< "please check its type.";
}
std::set<TypePtr> float16_set = {kFloat16};
MS_EXCEPTION_IF_NULL(x_dtype);
MS_EXCEPTION_IF_NULL(w_dtype);
MS_EXCEPTION_IF_NULL(h_dtype);
MS_EXCEPTION_IF_NULL(c_dtype);
std::map<std::string, TypePtr> types;
types.emplace("x", x_dtype);
types.emplace("w", w_dtype);
types.emplace("h", h_dtype);
types.emplace("c", c_dtype);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(types, float16_set, op_name, true);
const std::set<TypePtr> valid_b_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("b", b_dtype, valid_b_types, op_name);
return std::make_shared<Tuple>(
std::vector<TypePtr>{x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype});
}
} // namespace
MIND_API_OPERATOR_IMPL(DynamicRNN, BaseOperator);
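// Top-level infer entry: requires at least six inputs, then combines the dtype and shape inference above.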
AbstractBasePtr DynamicRNNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 6;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto type = DynamicRNNInferType(primitive, input_args);
auto shape = DynamicRNNInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicRNN, prim::kPrimDynamicRNN, DynamicRNNInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_DYNAMIC_RNN_H_
#define MINDSPORE_CORE_OPS_DYNAMIC_RNN_H_
#include <string>
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameDynamicRNN = "DynamicRNN";
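// C++ counterpart of the Python ops.DynamicRNN primitive; shape and dtype inference live in dynamic_rnn.cc.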
class MIND_API DynamicRNN : public BaseOperator {
public:
MIND_API_BASE_MEMBER(DynamicRNN);
DynamicRNN() : BaseOperator(kNameDynamicRNN) {
InitIOName({"x", "w", "b", "seq_length", "init_h", "init_c", "wci", "wcf", "wco", "mask"},
{"y", "output_h", "output_c", "i", "j", "f", "o", "tanhc"});
}
void Init() {}
};
abstract::AbstractBasePtr DynamicRNNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DYNAMIC_RNN_H_


@@ -7283,7 +7283,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
         return c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype
 
 
-class DynamicRNN(PrimitiveWithInfer):
+class DynamicRNN(Primitive):
     r"""
     Applies a recurrent neural network to the input.
     Only long short-term memory (LSTM) is supported currently.
@@ -7408,43 +7408,6 @@ class DynamicRNN(PrimitiveWithInfer):
         validator.check_value_type("activation", activation, [str], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-
-    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
-        validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
-        validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
-        validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
-        validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name)
-        validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name)
-        if seq_shape is not None:
-            raise ValueError(f"For '{self.name}', the 'seq_length' must be None.")
-        num_step, batch_size, input_size = x_shape
-        hidden_size = w_shape[-1] // 4
-        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
-        if w_shape[-1] % 4 != 0:
-            raise ValueError(f"For '{self.name}', the last dimension of 'w' must be a multiple of 4, "
-                             f"but got {w_shape[-1]}.")
-        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
-                        input_size + hidden_size, Rel.EQ, self.name)
-        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
-        validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name)
-        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
-        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
-        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
-        self.placeholder_index = [3]
-        self.add_prim_attr("placeholder_index", self.placeholder_index)
-        self.add_prim_attr("input_size", input_size)
-        self.add_prim_attr("hidden_size", hidden_size)
-        y_shape = (num_step, batch_size, hidden_size)
-        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
-
-    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
-        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name),
-                  ("x", "w", "h", "c"),
-                  (x_dtype, w_dtype, h_dtype, c_dtype)))
-        validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)
-        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
 
 
 class DynamicGRUV2(PrimitiveWithInfer):
     r"""