forked from mindspore-Ecosystem/mindspore
!19000 update LayerNormGrad split pass to V2
Merge pull request !19000 from yuchaojie/ir_fusion2
commit 32281f84e7
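This change updates the LayerNormGrad split pass to emit the V2 kernel pair: LayerNormXBackpropV2 now returns a second output, res_for_gamma, which LayerNormBetaGammaBackpropV2 consumes directly instead of taking x, variance and mean again. Below is a minimal NumPy sketch of that decomposition, assuming res_for_gamma is the normalized input (x - mean) / sqrt(variance + eps) kept in float32; the function names and eps handling are illustrative, and the authoritative definitions live in the TBE kernels registered further down.

import numpy as np

def layer_norm_x_backprop_v2(dy, x, variance, mean, gamma, eps=1e-12):
    # Mimics LayerNormXBackpropV2 (TBE input order: dy first): two outputs.
    std = np.sqrt(variance + eps)
    res_for_gamma = ((x - mean) / std).astype(np.float32)  # always float32
    dxhat = dy * gamma
    pd_x = (dxhat - dxhat.mean(-1, keepdims=True)
            - res_for_gamma * (dxhat * res_for_gamma).mean(-1, keepdims=True)) / std
    return pd_x.astype(x.dtype), res_for_gamma

def layer_norm_beta_gamma_backprop_v2(dy, res_for_gamma):
    # Mimics LayerNormBetaGammaBackpropV2: x/mean/variance no longer needed.
    axes = tuple(range(dy.ndim - 1))  # reduce every axis except the gamma axis
    pd_gamma = (dy * res_for_gamma).sum(axis=axes)
    pd_beta = dy.sum(axis=axes)
    return pd_gamma, pd_beta

# Two-stage pipeline, as wired by the updated pass below.
x = np.random.randn(2, 8).astype(np.float32)
dy = np.random.randn(2, 8).astype(np.float32)
gamma = np.random.randn(8).astype(np.float32)
mean = x.mean(-1, keepdims=True)
variance = x.var(-1, keepdims=True)
pd_x, res = layer_norm_x_backprop_v2(dy, x, variance, mean, gamma)
pd_gamma, pd_beta = layer_norm_beta_gamma_backprop_v2(dy, res)
assert res.dtype == np.float32 and pd_gamma.shape == gamma.shape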
@@ -59,10 +59,16 @@ int TypeStrToDstType(const std::string &type_str) {
 }
 }  // namespace
-std::unordered_set<std::string> TbeAdapter::input_order_adjusted_ops_ = {
-  kConv2DBackpropInputOpName, kConv2DBackpropFilterOpName, kLogSoftmaxGradOpName,
-  kLayerNormGradOpName, kLayerNormXBackpropOpName, kLayerNormBetaGammaBackpropOpName,
-  kMinimumGradOpName, kMaximumGradOpName, kApplyCenteredRMSPropOpName};
+std::unordered_set<std::string> TbeAdapter::input_order_adjusted_ops_ = {kConv2DBackpropInputOpName,
+                                                                         kConv2DBackpropFilterOpName,
+                                                                         kLogSoftmaxGradOpName,
+                                                                         kLayerNormGradOpName,
+                                                                         kLayerNormXBackpropOpName,
+                                                                         kLayerNormXBackpropV2OpName,
+                                                                         kLayerNormBetaGammaBackpropOpName,
+                                                                         kMinimumGradOpName,
+                                                                         kMaximumGradOpName,
+                                                                         kApplyCenteredRMSPropOpName};
 
 std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = {
   // TODO(xxx): tbeadapter max and min
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,16 +26,19 @@
 
 namespace mindspore {
 namespace opt {
 namespace {
 constexpr size_t kLayerNormGradOutputGammaIndex = 1;
 constexpr size_t kLayerNormGradOutputBetaIndex = 2;
 constexpr size_t kLayerNormGradInputGammaIndex = 4;
 }  // namespace
 
-void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
+void LayerNormGradSplit::CreateOutputsOfLayerNormXBackpropV2(
   const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
   std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(layer_norm_grad);
   MS_EXCEPTION_IF_NULL(layer_norm_x_backprop_outputs);
-  auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName);
+  auto prim = std::make_shared<Primitive>(kLayerNormXBackpropV2OpName);
   std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)};
   for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
     layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i));
@@ -44,23 +47,23 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
   MS_EXCEPTION_IF_NULL(layer_norm_x_backprop);
   layer_norm_x_backprop->set_scope(layer_norm_grad->scope());
 
-  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0)};
-  auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, 0)};
+  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0), kNumberTypeFloat32};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, 0),
+                 AnfAlgo::GetPrevNodeOutputDetailShape(layer_norm_grad, 1)};
   AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, layer_norm_x_backprop.get());
 
-  (*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop);
+  CreateMultipleOutputsOfAnfNode(graph, layer_norm_x_backprop, kLayerNormXBackpropV2OutputNum,
+                                 layer_norm_x_backprop_outputs);
 }
 
-void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
-  const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackpropV2(
+  const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, const AnfNodePtr &res_for_gamma,
   std::vector<AnfNodePtr> *layer_norm_beta_gamma_backprop_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(layer_norm_grad);
-  auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName);
-  std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)};
-  for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) {
-    layer_norm_beta_gamma_backprop_inputs.push_back(layer_norm_grad->input(i));
-  }
+  auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropV2OpName);
+  std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim), layer_norm_grad->input(kIndex2),
+                                                                   res_for_gamma};
   auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs);
   MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop);
   auto kernel_info = std::make_shared<device::KernelInfo>();
@@ -99,15 +102,16 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An
 
   // create layer_norm_x_backprop
   std::vector<AnfNodePtr> layer_norm_x_backprop_outputs;
-  CreateOutputsOfLayerNormXBackprop(graph, cnode, &layer_norm_x_backprop_outputs);
-  if (layer_norm_x_backprop_outputs.size() != kSingleOutputNum) {
+  CreateOutputsOfLayerNormXBackpropV2(graph, cnode, &layer_norm_x_backprop_outputs);
+  if (layer_norm_x_backprop_outputs.size() != kLayerNormXBackpropV2OutputNum) {
     MS_LOG(EXCEPTION) << "layer_norm_grad_outputs has wrong size"
                       << " trace: " << trace::DumpSourceLines(node);
   }
 
   // create layer_norm_beta_gamma_backprop
   std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_outputs;
-  CreateOutputsOfLayerNormBetaGammaBackprop(graph, cnode, &layer_norm_beta_gamma_backprop_outputs);
+  CreateOutputsOfLayerNormBetaGammaBackpropV2(graph, cnode, layer_norm_x_backprop_outputs[1],
+                                              &layer_norm_beta_gamma_backprop_outputs);
   if (layer_norm_beta_gamma_backprop_outputs.size() != kLayerNormBetaGammaBackpropOutputNum) {
     MS_LOG(EXCEPTION) << "layer_norm_beta_gamma_outputs has wrong size"
                       << " trace: " << trace::DumpSourceLines(node);
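The rewritten Process wires output 1 of the new LayerNormXBackpropV2 node (res_for_gamma) into LayerNormBetaGammaBackpropV2 together with dy (input kIndex2 of the original LayerNormGrad CNode, assuming its tensor inputs are x, dy, variance, mean, gamma). A toy, self-contained sketch of the resulting dataflow — stand-in functions only, not MindSpore API:

def x_backprop_v2(x, dy, variance, mean, gamma):
    return ("pd_x", "res_for_gamma")      # kLayerNormXBackpropV2OutputNum == 2

def beta_gamma_backprop_v2(dy, res_for_gamma):
    return ("pd_gamma", "pd_beta")        # kLayerNormBetaGammaBackpropOutputNum == 2

x_outs = x_backprop_v2("x", "dy", "variance", "mean", "gamma")
bg_outs = beta_gamma_backprop_v2("dy", x_outs[1])    # outputs[1] is res_for_gamma
replacement = (x_outs[0], bg_outs[0], bg_outs[1])    # replaces LayerNormGrad's (dx, dgamma, dbeta)

The IR keeps LayerNormGrad's input order; the TBE kernels reorder inputs (dy first) via the GetRealInputIndex table further down in this diff.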
@@ -32,10 +32,11 @@ class LayerNormGradSplit : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
-  void CreateOutputsOfLayerNormXBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
-                                         std::vector<AnfNodePtr> *layer_norm_grad_outputs) const;
-  void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
-                                                 std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const;
+  void CreateOutputsOfLayerNormXBackpropV2(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+                                           std::vector<AnfNodePtr> *layer_norm_grad_outputs) const;
+  void CreateOutputsOfLayerNormBetaGammaBackpropV2(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+                                                   const AnfNodePtr &res_for_gamma,
+                                                   std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -91,6 +91,9 @@ constexpr size_t kBackendTransposeInputTensorNum = 1;
 constexpr size_t kAdamApplyOneWithDecayOutputNum = 3;
 constexpr size_t kLayerNormBetaGammaBackpropInputTensorNum = 4;
 constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2;
+constexpr size_t kLayerNormBetaGammaBackpropV2InputTensorNum = 2;
+constexpr size_t kLayerNormXBackpropOutputNum = 4;
+constexpr size_t kLayerNormXBackpropV2OutputNum = 2;
 constexpr size_t kLayerNormGradInputTensorNum = 5;
 constexpr size_t kAdamApplyOneOutputNum = 3;
 constexpr size_t kApplyMomentumInputTensorNum = 5;
@@ -1157,6 +1157,11 @@ abstract::BaseShapePtr AnfRuntimeAlgorithm::GetOutputDetailShape(const AnfNodePt
                     << " trace: " << trace::DumpSourceLines(node);
 }
 
+abstract::BaseShapePtr AnfRuntimeAlgorithm::GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx) {
+  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
+  return AnfRuntimeAlgorithm::GetOutputDetailShape(kernel_with_index.first, kernel_with_index.second);
+}
+
 // set infer shapes and types of anf node
 void AnfRuntimeAlgorithm::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
                                                       const std::vector<abstract::BaseShapePtr> &shapes,
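The new GetPrevNodeOutputDetailShape helper resolves input edge input_idx to its producing node plus output index (a KernelWithIndex) and then queries that producer's inferred shape; this is what lets the split pass above type the res_for_gamma output. A hedged Python analogue over a toy graph encoding, purely illustrative:

def get_prev_node_output_shape(graph, node, input_idx):
    # Follow the input edge to (producer, output index) -- KernelWithIndex analogue.
    producer, out_idx = graph[node]["inputs"][input_idx]
    return graph[producer]["output_shapes"][out_idx]

graph = {
    "dy": {"inputs": [], "output_shapes": [(1, 128, 1024)]},
    "layer_norm_grad": {"inputs": [("dy", 0)], "output_shapes": [(1, 128, 1024)]},
}
assert get_prev_node_output_shape(graph, "layer_norm_grad", 0) == (1, 128, 1024)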
@@ -1539,6 +1544,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
     {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
     {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
     {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
+    {prim::kPrimLayerNormXBackpropV2->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
     {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
     {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
     {prim::kPrimApplyCenteredRMSProp->name(),
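Each {cur, real} pair in this table translates between a node's IR input position and the position the order-adjusted TBE kernel expects; LayerNormXBackpropV2 inherits the map of its V1 counterpart, swapping positions 0 and 1 because the kernel takes dy first while the IR node (assumed order x, dy, variance, mean, gamma) carries x first. An illustrative remap, not MindSpore API:

index_map = {0: 1, 1: 0, 2: 2, 3: 3, 4: 4}   # {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}

def reorder(args, index_map):
    # Whichever direction the lookup runs, the table encodes the x <-> dy swap.
    out = [None] * len(args)
    for cur, real in index_map.items():
        out[real] = args[cur]
    return out

assert reorder(["x", "dy", "variance", "mean", "gamma"], index_map) == \
    ["dy", "x", "variance", "mean", "gamma"]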
@@ -199,6 +199,7 @@ class AnfRuntimeAlgorithm {
                                    const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
   // get and set output shape ptr
   static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
+  static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
   static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
                                           const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
   static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
@@ -144,7 +144,9 @@ constexpr auto kLambNextRightOpName = "LambNextRight";
 constexpr auto kConfusionSoftmaxGradOpName = "ConfusionSoftmaxGrad";
 constexpr auto kLambUpdateWithLrV2OpName = "LambUpdateWithLrV2";
 constexpr auto kLayerNormXBackpropOpName = "LayerNormXBackprop";
+constexpr auto kLayerNormXBackpropV2OpName = "LayerNormXBackpropV2";
 constexpr auto kLayerNormBetaGammaBackpropOpName = "LayerNormBetaGammaBackprop";
+constexpr auto kLayerNormBetaGammaBackpropV2OpName = "LayerNormBetaGammaBackpropV2";
 constexpr auto kLambNextMVOpName = "LambNextMV";
 constexpr auto kConfusionTransposeDOpName = "ConfusionTransposeD";
 constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
@@ -336,7 +336,10 @@ inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("LRN");
 inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>(kLayerNorm);
 inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>(kLayerNormGrad);
 inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
+inline const PrimitivePtr kPrimLayerNormXBackpropV2 = std::make_shared<Primitive>("LayerNormXBackpropV2");
 inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
+inline const PrimitivePtr kPrimLayerNormBetaGammaBackpropV2 =
+  std::make_shared<Primitive>("LayerNormBetaGammaBackpropV2");
 inline const PrimitivePtr kPrimLog1p = std::make_shared<Primitive>("Log1p");
 inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>(kDropoutGenMask);
 inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>(kDropoutDoMask);
@@ -0,0 +1,75 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/layer_norm_beta_gamma_backprop_v2.h"
#include <memory>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr LayerNormBetaGammaBackpropV2InferShape(const PrimitivePtr &primitive,
                                                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  ValuePtr gamma_value_ptr = primitive->GetAttr(kShapeGamma);
  MS_EXCEPTION_IF_NULL(gamma_value_ptr);
  auto gamma_shape = GetValue<ShapeVector>(gamma_value_ptr);
  auto gamma_shape_ptr = std::make_shared<abstract::Shape>(gamma_shape);
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{gamma_shape_ptr, gamma_shape_ptr});
}

TypePtr LayerNormBetaGammaBackpropV2InferType(const PrimitivePtr &prim,
                                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // check
  std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> types;
  types.emplace("gamma", input_args[0]->BuildType());
  types.emplace("beta", input_args[0]->BuildType());
  auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{output_type, output_type});
}
}  // namespace

void LayerNormBetaGammaBackpropV2::Init(const std::vector<int64_t> &shape_gamma) { set_shape_gamma(shape_gamma); }

void LayerNormBetaGammaBackpropV2::set_shape_gamma(const std::vector<int64_t> &shape_gamma) {
  (void)AddAttr(kShapeGamma, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kShapeGamma, shape_gamma, name())));
}

std::vector<int64_t> LayerNormBetaGammaBackpropV2::get_shape_gamma() const {
  auto value_ptr = GetAttr(kShapeGamma);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

AbstractBasePtr LayerNormBetaGammaBackpropV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInteger("LayerNormBetaGammaBackpropV2 infer", SizeToLong(input_args.size()),
                                     kGreaterEqual, input_num, primitive->name());
  return abstract::MakeAbstract(LayerNormBetaGammaBackpropV2InferShape(primitive, input_args),
                                LayerNormBetaGammaBackpropV2InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormBetaGammaBackpropV2, prim::kPrimLayerNormBetaGammaBackpropV2,
                             LayerNormBetaGammaBackpropV2Infer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
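The infer rule above in short: both outputs (pd_gamma, pd_beta) take their shape from the shape_gamma attribute and their element type from the checked input type. A small Python restatement — hypothetical helper, mirroring the unit test near the end of this diff:

def infer_beta_gamma_backprop_v2(shape_gamma, input_dtype):
    assert input_dtype in ("float16", "float32")   # valid_types check
    out_shape = tuple(shape_gamma)                 # from the kShapeGamma attribute
    return ((out_shape, out_shape), (input_dtype, input_dtype))

assert infer_beta_gamma_backprop_v2([1024], "float16") == \
    (((1024,), (1024,)), ("float16", "float16"))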
@@ -0,0 +1,45 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_LAYERNORMBETAGAMMABACKPROPV2_H_
#define MINDSPORE_CORE_OPS_LAYERNORMBETAGAMMABACKPROPV2_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
class LayerNormBetaGammaBackpropV2 : public PrimitiveC {
 public:
  LayerNormBetaGammaBackpropV2() : PrimitiveC(prim::kPrimLayerNormBetaGammaBackpropV2->name()) {}
  ~LayerNormBetaGammaBackpropV2() = default;
  MS_DECLARE_PARENT(LayerNormBetaGammaBackpropV2, PrimitiveC);
  void Init(const std::vector<int64_t> &shape_gamma);
  void set_shape_gamma(const std::vector<int64_t> &shape_gamma);
  std::vector<int64_t> get_shape_gamma() const;
};

AbstractBasePtr LayerNormBetaGammaBackpropV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const std::vector<AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_LAYERNORMBETAGAMMABACKPROPV2_H_
@@ -0,0 +1,59 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/layer_norm_x_backprop_v2.h"
#include <memory>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr LayerNormXBackpropV2InferShape(const PrimitivePtr &primitive,
                                                       const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto x_shape = CheckAndConvertUtils::GetTensorInputShape(primitive->name(), input_args, 0);
  auto res_for_gamma_shape = CheckAndConvertUtils::GetTensorInputShape(primitive->name(), input_args, 1);
  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape, res_for_gamma_shape});
}

TypePtr LayerNormXBackpropV2InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // check
  std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  auto x_type = input_args[0]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, kFloat32});
}
}  // namespace

AbstractBasePtr LayerNormXBackpropV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 5;
  CheckAndConvertUtils::CheckInteger("LayerNormXBackpropV2 infer", SizeToLong(input_args.size()), kGreaterEqual,
                                     input_num, primitive->name());
  return abstract::MakeAbstract(LayerNormXBackpropV2InferShape(primitive, input_args),
                                LayerNormXBackpropV2InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormXBackpropV2, prim::kPrimLayerNormXBackpropV2, LayerNormXBackpropV2Infer, nullptr,
                             true);
}  // namespace ops
}  // namespace mindspore
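Summarizing this op's infer rule: pd_x copies the shape and dtype of input 0, while res_for_gamma copies the shape of input 1 and is always float32. Restated as a hypothetical Python helper, with the same shapes the unit test below uses:

def infer_x_backprop_v2(input_shapes, input_dtypes):
    assert input_dtypes[0] in ("float16", "float32")   # valid_types check on input 0
    return ((input_shapes[0], input_shapes[1]), (input_dtypes[0], "float32"))

shapes = [(1, 128, 1024), (1, 128, 1), (1, 128, 1), (1, 128, 1), (1024,)]
assert infer_x_backprop_v2(shapes, ["float16"] * 5) == \
    (((1, 128, 1024), (1, 128, 1)), ("float16", "float32"))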
@@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_LAYERNORMXBACKPROPV2_H_
#define MINDSPORE_CORE_OPS_LAYERNORMXBACKPROPV2_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
class LayerNormXBackpropV2 : public PrimitiveC {
 public:
  LayerNormXBackpropV2() : PrimitiveC(prim::kPrimLayerNormXBackpropV2->name()) {}
  ~LayerNormXBackpropV2() = default;
  MS_DECLARE_PARENT(LayerNormXBackpropV2, PrimitiveC);
  void Init() {}
};

AbstractBasePtr LayerNormXBackpropV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_LAYERNORMXBACKPROPV2_H_
@@ -177,6 +177,7 @@ constexpr auto kSeed2 = "seed2";
 constexpr auto kSeqDim = "seq_dim";
 constexpr auto kSetattrFlag = "setattr_flag";
 constexpr auto kShape = "shape";
+constexpr auto kShapeGamma = "shape_gamma";
 constexpr auto kShapeSize = "shape_size";
 constexpr auto kShift = "shift";
 constexpr auto kShrinkAxisMask = "shrink_axis_mask";
@@ -196,10 +196,14 @@ from .clip_by_norm_no_div_sum import _clip_by_norm_no_div_sum_tbe
 from .clip_by_value import _clip_by_value_tbe
 from .layer_norm_beta_gamma_backprop import _layer_norm_beta_gamma_backprop_tbe
 from .layer_norm_beta_gamma_backprop_ds import _layer_norm_beta_gamma_backprop_ds_tbe
+from .layer_norm_beta_gamma_backprop_v2 import _layer_norm_beta_gamma_backprop_v2_tbe
+from .layer_norm_beta_gamma_backprop_v2_ds import _layer_norm_beta_gamma_backprop_v2_ds_tbe
 from .layer_norm import _layer_norm_tbe
 from .layer_norm_ds import _layer_norm_ds_tbe
 from .layer_norm_grad import _layer_norm_grad_tbe
 from .layer_norm_x_backprop_ds import _layer_norm_x_backprop_ds_tbe
+from .layer_norm_x_backprop_v2 import _layer_norm_x_backprop_v2_tbe
+from .layer_norm_x_backprop_v2_ds import _layer_norm_x_backprop_v2_ds_tbe
 from .l2_loss import _l2_loss_tbe
 from .l2_normalize import _l2_normalize_tbe
 from .l2_normalize_grad import _l2_normalize_grad_tbe
@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LayerNormBetaGammaBackprop op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_beta_gamma_backprop_v2_op_info = TBERegOp("LayerNormBetaGammaBackpropV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("layer_norm_beta_gamma_backprop_v2.so") \
    .compute_cost(10) \
    .kernel_name("layer_norm_beta_gamma_backprop_v2") \
    .partial_flag(True) \
    .attr("shape_gamma", "required", "listInt", "all") \
    .input(0, "dy", False, "required", "all") \
    .input(1, "res_for_gamma", False, "required", "all") \
    .output(0, "pd_gamma", False, "required", "all") \
    .output(1, "pd_beta", False, "required", "all") \
    .is_dynamic_format(True) \
    .dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(layer_norm_beta_gamma_backprop_v2_op_info)
def _layer_norm_beta_gamma_backprop_v2_tbe():
    """LayerNormBetaGammaBackpropV2 TBE register"""
    return
@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LayerNormBetaGammaBackprop op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_beta_gamma_backprop_v2_op_info = TBERegOp("LayerNormBetaGammaBackpropV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("layer_norm_beta_gamma_backprop_v2.so") \
    .compute_cost(10) \
    .kernel_name("layer_norm_beta_gamma_backprop_v2") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .attr("shape_gamma", "required", "listInt", "all") \
    .input(0, "dy", False, "required", "all") \
    .input(1, "res_for_gamma", False, "required", "all") \
    .output(0, "pd_gamma", False, "required", "all") \
    .output(1, "pd_beta", False, "required", "all") \
    .is_dynamic_format(True) \
    .dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(layer_norm_beta_gamma_backprop_v2_op_info)
def _layer_norm_beta_gamma_backprop_v2_ds_tbe():
    """LayerNormBetaGammaBackpropV2 TBE register"""
    return
@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LayerNormXBackprop op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_x_backprop_v2_op_info = TBERegOp("LayerNormXBackpropV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("layer_norm_x_backprop_v2.so") \
    .compute_cost(10) \
    .kernel_name("layer_norm_x_backprop_v2") \
    .partial_flag(True) \
    .input(0, "dy", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "variance", False, "required", "all") \
    .input(3, "mean", False, "required", "all") \
    .input(4, "gamma", False, "required", "all") \
    .output(0, "pd_x", False, "required", "all") \
    .output(1, "res_for_gamma", False, "required", "all") \
    .is_dynamic_format(True) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
                  DataType.F16_None, DataType.F16_None, DataType.F32_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
                  DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(layer_norm_x_backprop_v2_op_info)
def _layer_norm_x_backprop_v2_tbe():
    """LayerNormXBackpropV2 TBE register"""
    return
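Each .dtype_format(...) row in the registration above carries one entry per declared input and then per output, in declaration order, so the seven columns pair up as shown below; note res_for_gamma stays F32 even in the F16 row. An illustrative reading of how the rows are conventionally interpreted:

inputs = ["dy", "x", "variance", "mean", "gamma"]
outputs = ["pd_x", "res_for_gamma"]
f16_row = ["F16", "F16", "F16", "F16", "F16", "F16", "F32"]
print(dict(zip(inputs + outputs, f16_row)))
# {'dy': 'F16', 'x': 'F16', 'variance': 'F16', 'mean': 'F16',
#  'gamma': 'F16', 'pd_x': 'F16', 'res_for_gamma': 'F32'}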
@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LayerNormXBackprop op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_x_backprop_v2_op_info = TBERegOp("LayerNormXBackpropV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("layer_norm_x_backprop_v2.so") \
    .compute_cost(10) \
    .kernel_name("layer_norm_x_backprop_v2") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "dy", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "variance", False, "required", "all") \
    .input(3, "mean", False, "required", "all") \
    .input(4, "gamma", False, "required", "all") \
    .output(0, "pd_x", False, "required", "all") \
    .output(1, "res_for_gamma", False, "required", "all") \
    .is_dynamic_format(True) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
                  DataType.F16_None, DataType.F16_None, DataType.F32_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
                  DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(layer_norm_x_backprop_v2_op_info)
def _layer_norm_x_backprop_v2_ds_tbe():
    """LayerNormXBackpropV2 TBE register"""
    return
@@ -0,0 +1,76 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <vector>
#include <memory>
#include "common/common_test.h"
#include "ops/layer_norm_beta_gamma_backprop_v2.h"
#include "ir/dtype/type.h"
#include "ir/value.h"
#include "abstract/dshape.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace ops {
class TestLayerNormBetaGammaBackpropV2 : public UT::Common {
 public:
  TestLayerNormBetaGammaBackpropV2() {}
  void SetUp() {}
  void TearDown() {}
};

TEST_F(TestLayerNormBetaGammaBackpropV2, test_ops_layer_norm_beta_gamma_backprop_v2_1) {
  auto ln_back = std::make_shared<LayerNormBetaGammaBackpropV2>();
  ln_back->Init(std::vector<int64_t>{1024});
  auto shape_gamma = ln_back->get_shape_gamma();
  EXPECT_EQ(shape_gamma.size(), 1);
  EXPECT_EQ(shape_gamma[0], 1024);
  auto input_dy = TensorConstructUtils::CreateOnesTensor(kFloat16, std::vector<int64_t>{1, 128, 1});
  auto input_res_gamma = TensorConstructUtils::CreateOnesTensor(kFloat32, std::vector<int64_t>{1, 128, 1});
  MS_EXCEPTION_IF_NULL(input_dy);
  MS_EXCEPTION_IF_NULL(input_res_gamma);
  AbstractBasePtrList inputs = {input_dy->ToAbstract(), input_res_gamma->ToAbstract()};
  auto abstract = ln_back->Infer(inputs);
  MS_EXCEPTION_IF_NULL(abstract);
  EXPECT_EQ(abstract->isa<abstract::AbstractTuple>(), true);
  auto shape_ptr = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  EXPECT_EQ(shape_ptr->isa<abstract::TupleShape>(), true);
  auto shape = shape_ptr->cast<abstract::TupleShapePtr>();
  MS_EXCEPTION_IF_NULL(shape);
  auto shape_vec = shape->shape();
  EXPECT_EQ(shape_vec.size(), 2);
  auto shape1 = shape_vec[0]->cast<abstract::ShapePtr>()->shape();
  EXPECT_EQ(shape1.size(), 1);
  EXPECT_EQ(shape1[0], 1024);
  auto shape2 = shape_vec[1]->cast<abstract::ShapePtr>()->shape();
  EXPECT_EQ(shape2.size(), 1);
  EXPECT_EQ(shape2[0], 1024);
  auto type_ptr = abstract->BuildType();
  MS_EXCEPTION_IF_NULL(type_ptr);
  auto type = type_ptr->cast<TuplePtr>();
  MS_EXCEPTION_IF_NULL(type);
  auto type_vec = type->elements();
  MS_EXCEPTION_IF_NULL(type_vec[0]);
  auto data_type = type_vec[0]->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(data_type);
  EXPECT_EQ(data_type->type_id(), kNumberTypeFloat16);
  MS_EXCEPTION_IF_NULL(type_vec[1]);
  auto data1_type = type_vec[1]->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(data1_type);
  EXPECT_EQ(data1_type->type_id(), kNumberTypeFloat16);
}
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,83 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <vector>
#include <memory>
#include "common/common_test.h"
#include "ops/layer_norm_x_backprop_v2.h"
#include "ir/dtype/type.h"
#include "ir/value.h"
#include "abstract/dshape.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace ops {
class TestLayerNormXBackpropV2 : public UT::Common {
 public:
  TestLayerNormXBackpropV2() {}
  void SetUp() {}
  void TearDown() {}
};

TEST_F(TestLayerNormXBackpropV2, test_ops_layer_norm_x_backprop_v2_1) {
  auto ln_back = std::make_shared<LayerNormXBackpropV2>();
  auto input_x = TensorConstructUtils::CreateOnesTensor(kFloat16, std::vector<int64_t>{1, 128, 1024});
  auto input_dy = TensorConstructUtils::CreateOnesTensor(kFloat16, std::vector<int64_t>{1, 128, 1});
  auto input_var = TensorConstructUtils::CreateOnesTensor(kFloat16, std::vector<int64_t>{1, 128, 1});
  auto input_mean = TensorConstructUtils::CreateOnesTensor(kFloat16, std::vector<int64_t>{1, 128, 1});
  auto input_gamma = TensorConstructUtils::CreateOnesTensor(kFloat16, std::vector<int64_t>{1024});
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_dy);
  MS_EXCEPTION_IF_NULL(input_var);
  MS_EXCEPTION_IF_NULL(input_mean);
  MS_EXCEPTION_IF_NULL(input_gamma);
  AbstractBasePtrList inputs = {input_x->ToAbstract(), input_dy->ToAbstract(), input_var->ToAbstract(),
                                input_mean->ToAbstract(), input_gamma->ToAbstract()};
  auto abstract = ln_back->Infer(inputs);
  MS_EXCEPTION_IF_NULL(abstract);
  EXPECT_EQ(abstract->isa<abstract::AbstractTuple>(), true);
  auto shape_ptr = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  EXPECT_EQ(shape_ptr->isa<abstract::TupleShape>(), true);
  auto shape = shape_ptr->cast<abstract::TupleShapePtr>();
  MS_EXCEPTION_IF_NULL(shape);
  auto shape_vec = shape->shape();
  EXPECT_EQ(shape_vec.size(), 2);
  auto shape1 = shape_vec[0]->cast<abstract::ShapePtr>()->shape();
  EXPECT_EQ(shape1.size(), 3);
  EXPECT_EQ(shape1[0], 1);
  EXPECT_EQ(shape1[1], 128);
  EXPECT_EQ(shape1[2], 1024);
  auto shape2 = shape_vec[1]->cast<abstract::ShapePtr>()->shape();
  EXPECT_EQ(shape2.size(), 3);
  EXPECT_EQ(shape2[0], 1);
  EXPECT_EQ(shape2[1], 128);
  EXPECT_EQ(shape2[2], 1);
  auto type_ptr = abstract->BuildType();
  MS_EXCEPTION_IF_NULL(type_ptr);
  auto type = type_ptr->cast<TuplePtr>();
  MS_EXCEPTION_IF_NULL(type);
  auto type_vec = type->elements();
  MS_EXCEPTION_IF_NULL(type_vec[0]);
  auto data_type = type_vec[0]->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(data_type);
  EXPECT_EQ(data_type->type_id(), kNumberTypeFloat16);
  MS_EXCEPTION_IF_NULL(type_vec[1]);
  auto data1_type = type_vec[1]->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(data1_type);
  EXPECT_EQ(data1_type->type_id(), kNumberTypeFloat32);
}
}  // namespace ops
}  // namespace mindspore
@@ -20,8 +20,8 @@ from mindspore.ops import _constants as Constants
 make_tuple = Primitive('MakeTuple')
 tuple_getitem = Primitive(Constants.kTupleGetItem)
 layer_norm_grad = G.LayerNormGrad()
-layer_norm_x_backprop = Primitive('LayerNormXBackprop')
-layer_norm_beta_gamma_backprop = Primitive('LayerNormBetaGammaBackprop')
+layer_norm_x_backprop = Primitive('LayerNormXBackpropV2')
+layer_norm_beta_gamma_backprop = Primitive('LayerNormBetaGammaBackpropV2')
 
 
 class FnDict:
@@ -51,10 +51,12 @@ def test_layer_norm_grad_split(tag):
     @fns
     def after(i0, i1, i2, i3, i4):
         layer_norm_x_output = layer_norm_x_backprop(i0, i1, i2, i3, i4)
-        layer_norm_beta_output = layer_norm_beta_gamma_backprop(i0, i1, i2, i3)
+        x_item0 = tuple_getitem(layer_norm_x_output, 0)
+        x_item1 = tuple_getitem(layer_norm_x_output, 1)
+        layer_norm_beta_output = layer_norm_beta_gamma_backprop(i1, x_item1)
         beta_item0 = tuple_getitem(layer_norm_beta_output, 0)
         beta_item1 = tuple_getitem(layer_norm_beta_output, 1)
-        mt = make_tuple(layer_norm_x_output, beta_item0, beta_item1)
+        mt = make_tuple(x_item0, beta_item0, beta_item1)
         item0 = tuple_getitem(mt, 0)
         item1 = tuple_getitem(mt, 1)
         item2 = tuple_getitem(mt, 2)