cluster standardnormal

This commit is contained in:
Yang Jiao 2021-09-08 21:53:52 +08:00
parent de19ef2931
commit 32d3ce532b
8 changed files with 69 additions and 1 deletions

2
akg

@ -1 +1 @@
Subproject commit 78c10ce87a7edbf28b8ccd2b23028cd2126cba61
Subproject commit 545ebff8aff5fb7877337f004652b36fe7ca515e

View File

@ -234,6 +234,7 @@ class PrimLib:
'Gather': Prim(OPAQUE),
'GatherNd': Prim(OPAQUE),
'UnsortedSegmentSum': Prim(OPAQUE),
'StandardNormal': Prim(OPAQUE),
'UserDefined': Prim(OPAQUE),
}

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
namespace mindspore {
namespace opt {
namespace expanders {
class StandardNormal : public OpExpander {
public:
StandardNormal() {
std::initializer_list<std::string> attrs{"seed", "seed2"};
validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
}
~StandardNormal() {}
NodePtrList Expand() override {
const auto &inputs = gb.Get()->inputs();
const auto &input_x = inputs[0];
auto shape = MakeValue(outputs_info_[0].shape);
auto result =
gb.Emit("StandardNormal", {input_x}, {{"shape", shape}, {"seed", attrs_["seed"]}, {"seed2", attrs_["seed2"]}});
return {result};
}
};
OP_EXPANDER_REGISTER("StandardNormal", StandardNormal);
} // namespace expanders
} // namespace opt
} // namespace mindspore

View File

@ -46,6 +46,7 @@ using context::OpLevel_1;
constexpr size_t kAssignInputIdx = 1;
constexpr size_t kLambOptimizerInputIdx = 12;
constexpr size_t kLambWeightInputIdx = 4;
constexpr size_t kRandomInputIdx = 1;
std::vector<PrimitivePtr> GetExpandOps() {
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = {
@ -93,6 +94,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
{kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll},
{kGPUDevice, OpLevel_0, prim::kPrimIdentityMath},
{kGPUDevice, OpLevel_0, prim::kPrimOnesLike},
{kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
};
const auto &flags = context::GraphKernelFlags::GetInstance();
std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level);
@ -201,6 +203,7 @@ ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
{prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
{prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambOptimizerInputIdx)},
{prim::kLambApplyWeightAssign, std::make_shared<OpUMonadExpander>(kLambWeightInputIdx)},
{prim::kPrimStandardNormal, std::make_shared<OpUMonadExpander>(kRandomInputIdx)},
};
for (auto &e : expanders) {

View File

@ -543,6 +543,11 @@ void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
}
}
DShape StandardNormalOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
CHECK_ATTR(attrs, "shape");
return GetListInt(attrs.find("shape")->second);
}
} // namespace graphkernel
} // namespace opt
} // namespace mindspore

View File

@ -282,6 +282,17 @@ class ComplexOp : public ElemwiseOp {
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
};
class StandardNormalOp : public OpaqueOp {
public:
StandardNormalOp(const std::string &op, const std::string &node_name) : OpaqueOp("StandardNormal", node_name) {}
~StandardNormalOp() = default;
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }
};
} // namespace graphkernel
} // namespace opt
} // namespace mindspore

View File

@ -79,6 +79,7 @@ class OpRegistry {
Register("CImag", OP_CREATOR(CImagOp));
Register("Complex", OP_CREATOR(ComplexOp));
Register("Opaque", OP_CREATOR(OpaqueOp));
Register("StandardNormal", OP_CREATOR(StandardNormalOp));
}
~OpRegistry() = default;
std::unordered_map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;

View File

@ -702,6 +702,9 @@ inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitiv
inline const PrimitivePtr kPrimDynamicBroadcastGradientArgs =
std::make_shared<Primitive>(kDynamicBroadcastGradientArgs);
// Random
inline const PrimitivePtr kPrimStandardNormal = std::make_shared<Primitive>("StandardNormal");
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)