c++ infer: dynamic_rnn

This commit is contained in:
parent bf776e8807
commit 4777eaa06e

@@ -0,0 +1,168 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <set>
#include "ops/dynamic_rnn.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
constexpr int64_t kDynRnnIdx0 = 0;
constexpr int64_t kDynRnnIdx1 = 1;
constexpr int64_t kDynRnnIdx2 = 2;
constexpr int64_t kDynRnnIdx3 = 3;
constexpr int64_t kDynRnnIdx4 = 4;
constexpr int64_t kDynRnnIdx5 = 5;
constexpr int64_t kDynamicRnnShapeX = 3;
constexpr int64_t kDynamicRnnShapeW = 2;
constexpr int64_t kDynamicRnnShapeB = 1;
constexpr int64_t kDynamicRnnShapeH = 3;
constexpr int64_t kDynamicRnnShapeC = 3;
constexpr int64_t kDynRnnNum4 = 4;

// When x or w has a dynamic dimension, all eight outputs fall back to a rank-3 shape
// whose dimensions are unknown.
abstract::TupleShapePtr DynamicRNNInferDynamicShape(const std::vector<AbstractBasePtr> &input_args) {
  const size_t y_shape_num = 3;
  ShapeVector y_shape_dyn;
  for (size_t i = 0; i < y_shape_num; ++i) {
    y_shape_dyn.push_back(abstract::Shape::kShapeDimAny);
  }
  abstract::ShapePtr y_shape_dyn_ptr = std::make_shared<abstract::Shape>(y_shape_dyn);
  return std::make_shared<abstract::TupleShape>(
    std::vector<abstract::BaseShapePtr>{y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr,
                                        y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr, y_shape_dyn_ptr});
}

// Static-shape contract: x [num_step, batch_size, input_size], w [input_size + hidden_size, 4 * hidden_size],
// b [4 * hidden_size], init_h and init_c [1, batch_size, hidden_size], seq_length None.
void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
  auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx2]->BuildShape())[kShape];
  auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx3]->BuildShape())[kShape];
  auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx4]->BuildShape())[kShape];
  auto c_shape_ptr = input_args[kDynRnnIdx5]->BuildShape();
  auto c_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx5]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kEqual, kDynamicRnnShapeX, op_name);
  (void)CheckAndConvertUtils::CheckInteger("w_shape", SizeToLong(w_shape.size()), kEqual, kDynamicRnnShapeW, op_name);
  (void)CheckAndConvertUtils::CheckInteger("b_shape", SizeToLong(b_shape.size()), kEqual, kDynamicRnnShapeB, op_name);
  (void)CheckAndConvertUtils::CheckInteger("h_shape", SizeToLong(h_shape.size()), kEqual, kDynamicRnnShapeH, op_name);
  (void)CheckAndConvertUtils::CheckInteger("c_shape", SizeToLong(c_shape.size()), kEqual, kDynamicRnnShapeC, op_name);
  int64_t batch_size = x_shape[kDynRnnIdx1];
  int64_t input_size = x_shape[kDynRnnIdx2];
  if (seq_shape.size() != 0) {
    MS_EXCEPTION(ValueError) << "For '" << op_name << "', the rank of input 'seq_length' must be 0, but got "
                             << seq_shape.size() << ".";
  }
  int64_t hidden_size = w_shape[w_shape.size() - 1] / kDynRnnNum4;
  if (w_shape[w_shape.size() - 1] % kDynRnnNum4 != 0) {
    MS_EXCEPTION(ValueError) << "For '" << op_name << "', w_shape[-1] should be a multiple of 4, but got "
                             << w_shape[w_shape.size() - 1] << ".";
  }
  if (w_shape[0] != input_size + hidden_size) {
    MS_EXCEPTION(ValueError) << "For '" << op_name
                             << "', w_shape[0] should be equal to input_size + hidden_size, but got " << w_shape[0]
                             << ".";
  }
  if (b_shape[0] != w_shape[1]) {
    MS_EXCEPTION(ValueError) << "For '" << op_name << "', b_shape[0] should be equal to w_shape[1], but got "
                             << b_shape[0] << ".";
  }
  (void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kDynRnnIdx0], kEqual, (int64_t)1, op_name);
  (void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_shape[kDynRnnIdx1], kEqual, (int64_t)batch_size, op_name);
  (void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kDynRnnIdx2], kEqual, (int64_t)hidden_size, op_name);
  const std::map<std::string, BaseShapePtr> shapes = {{"c_shape", c_shape_ptr}};
  (void)CheckAndConvertUtils::CheckTensorShapeSame(shapes, h_shape, op_name);
}

abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
  // Input 3 (seq_length) is a placeholder and must stay None.
  std::vector<ValuePtr> placeholder_index = {MakeValue((int64_t)3)};
  primitive->set_attr("placeholder_index", MakeValue(placeholder_index));
  if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) {
    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny}),
      std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny})});
  }
  if (IsDynamic(x_shape) || IsDynamic(w_shape)) {
    return DynamicRNNInferDynamicShape(input_args);
  }
  DynamicRNNShapeCheck(primitive, input_args);
  int64_t num_step = x_shape[kDynRnnIdx0];
  int64_t batch_size = x_shape[kDynRnnIdx1];
  int64_t input_size = x_shape[kDynRnnIdx2];
  int64_t hidden_size = w_shape[w_shape.size() - 1] / kDynRnnNum4;
  primitive->set_attr("input_size", MakeValue(input_size));
  primitive->set_attr("hidden_size", MakeValue(hidden_size));
  std::vector<int64_t> y_shape{num_step, batch_size, hidden_size};
  abstract::ShapePtr y_shape_ptr = std::make_shared<abstract::Shape>(y_shape);
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
    y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr, y_shape_ptr});
}

TuplePtr DynamicRNNInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  auto x_dtype = input_args[kDynRnnIdx0]->BuildType();
  auto w_dtype = input_args[kDynRnnIdx1]->BuildType();
  auto b_dtype = input_args[kDynRnnIdx2]->BuildType();
  auto h_dtype = input_args[kDynRnnIdx4]->BuildType();
  auto c_dtype = input_args[kDynRnnIdx5]->BuildType();
  auto seq_type = input_args[kDynRnnIdx3]->BuildType();
  if (seq_type->type_id() != kMetaTypeNone) {
    MS_EXCEPTION(ValueError) << "For '" << op_name << "', input 'seq_length' must be None.";
  }
  std::set<TypePtr> float16_set = {kFloat16};
  MS_EXCEPTION_IF_NULL(x_dtype);
  MS_EXCEPTION_IF_NULL(w_dtype);
  MS_EXCEPTION_IF_NULL(h_dtype);
  MS_EXCEPTION_IF_NULL(c_dtype);
  // x, w, h and c must all be float16; b may be float16 or float32.
  std::map<std::string, TypePtr> types;
  types.emplace("x", x_dtype);
  types.emplace("w", w_dtype);
  types.emplace("h", h_dtype);
  types.emplace("c", c_dtype);
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(types, float16_set, op_name, true);
  const std::set<TypePtr> valid_b_types = {kFloat16, kFloat32};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("b", b_dtype, valid_b_types, op_name);
  return std::make_shared<Tuple>(
    std::vector<TypePtr>{x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype});
}
}  // namespace

MIND_API_OPERATOR_IMPL(DynamicRNN, BaseOperator);

AbstractBasePtr DynamicRNNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 6;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
  auto type = DynamicRNNInferType(primitive, input_args);
  auto shape = DynamicRNNInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicRNN, prim::kPrimDynamicRNN, DynamicRNNInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
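For context, a minimal usage sketch consistent with the checks above (shapes follow the DynamicRNN docstring: seq_length must be None, and w packs the four LSTM gates, so hidden_size = w.shape[-1] // 4). The op runs on Ascend only, and the exact import path may differ across MindSpore versions, so treat this as an illustration rather than a verified test:

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

# x: (num_step=2, batch_size=16, input_size=64)
x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
# w: (input_size + hidden_size=96, 4 * hidden_size=128) -> hidden_size = 32
w = Tensor(np.random.rand(96, 128).astype(np.float16))
# b: (4 * hidden_size,); float16 or float32 are both accepted
b = Tensor(np.random.rand(128).astype(np.float16))
# init_h / init_c: (1, batch_size, hidden_size)
init_h = Tensor(np.zeros((1, 16, 32)).astype(np.float16))
init_c = Tensor(np.zeros((1, 16, 32)).astype(np.float16))

dynamic_rnn = ops.DynamicRNN()
outputs = dynamic_rnn(x, w, b, None, init_h, init_c)  # seq_length must be None
print(outputs[0].shape)  # (2, 16, 32); all eight outputs share this shape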

@@ -0,0 +1,44 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_DYNAMIC_RNN_H_
#define MINDSPORE_CORE_OPS_DYNAMIC_RNN_H_
#include <string>
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameDynamicRNN = "DynamicRNN";
class MIND_API DynamicRNN : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(DynamicRNN);

  DynamicRNN() : BaseOperator(kNameDynamicRNN) {
    InitIOName({"x", "w", "b", "seq_length", "init_h", "init_c", "wci", "wcf", "wco", "mask"},
               {"y", "output_h", "output_c", "i", "j", "f", "o", "tanhc"});
  }
  void Init() {}
};

abstract::AbstractBasePtr DynamicRNNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_DYNAMIC_RNN_H_
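The factor of 4 in the shape checks comes from the four LSTM gate blocks packed along the last axis of w (and b). A hedged sketch of that layout; the i/j/f/o gate order is an assumption for illustration, matching the output names above rather than a documented memory layout:

import numpy as np

def split_gates(w: np.ndarray, hidden_size: int):
    """Split packed weights (input_size + hidden_size, 4 * hidden_size)
    into per-gate blocks; gate order i, j, f, o is assumed for illustration."""
    assert w.shape[-1] == 4 * hidden_size
    return {name: w[:, k * hidden_size:(k + 1) * hidden_size]
            for k, name in enumerate(("i", "j", "f", "o"))}

gates = split_gates(np.zeros((96, 128), dtype=np.float16), hidden_size=32)
print({k: v.shape for k, v in gates.items()})  # each gate block is (96, 32)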

@@ -7283,7 +7283,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
         return c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype


-class DynamicRNN(PrimitiveWithInfer):
+class DynamicRNN(Primitive):
     r"""
     Applies a recurrent neural network to the input.
     Only long short-term memory (LSTM) is supported currently.
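For reference, the textbook LSTM recurrence behind the op's outputs (y, output_h, output_c and the gate tensors i, j, f, o, tanhc); this is the standard formulation, not text from the diff:

\begin{aligned}
i_t &= \sigma(W_i\,[x_t, h_{t-1}] + b_i) \\
j_t &= \tanh(W_j\,[x_t, h_{t-1}] + b_j) \\
f_t &= \sigma(W_f\,[x_t, h_{t-1}] + b_f) \\
o_t &= \sigma(W_o\,[x_t, h_{t-1}] + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot j_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}

so tanhc is \tanh(c_t), and y stacks h_t over the num_step axis.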

@@ -7408,43 +7408,6 @@ class DynamicRNN(PrimitiveWithInfer):
         validator.check_value_type("activation", activation, [str], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-
-    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
-        validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
-        validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
-        validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
-        validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name)
-        validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name)
-        if seq_shape is not None:
-            raise ValueError(f"For '{self.name}', the 'seq_length' must be None.")
-
-        num_step, batch_size, input_size = x_shape
-        hidden_size = w_shape[-1] // 4
-
-        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
-        if w_shape[-1] % 4 != 0:
-            raise ValueError(f"For '{self.name}', the last dimension of 'w' must be a multiple of 4, "
-                             f"but got {w_shape[-1]}.")
-        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
-                        input_size + hidden_size, Rel.EQ, self.name)
-        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
-        validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name)
-        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
-        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
-        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
-        self.placeholder_index = [3]
-        self.add_prim_attr("placeholder_index", self.placeholder_index)
-        self.add_prim_attr("input_size", input_size)
-        self.add_prim_attr("hidden_size", hidden_size)
-        y_shape = (num_step, batch_size, hidden_size)
-        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
-
-    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
-        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name),
-                  ("x", "w", "h", "c"),
-                  (x_dtype, w_dtype, h_dtype, c_dtype)))
-        validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)
-        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype


 class DynamicGRUV2(PrimitiveWithInfer):
     r"""