forked from mindspore-Ecosystem/mindspore
cluster standardnormal
This commit is contained in:
parent
de19ef2931
commit
32d3ce532b
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit 78c10ce87a7edbf28b8ccd2b23028cd2126cba61
|
||||
Subproject commit 545ebff8aff5fb7877337f004652b36fe7ca515e
|
|
@ -234,6 +234,7 @@ class PrimLib:
|
|||
'Gather': Prim(OPAQUE),
|
||||
'GatherNd': Prim(OPAQUE),
|
||||
'UnsortedSegmentSum': Prim(OPAQUE),
|
||||
'StandardNormal': Prim(OPAQUE),
|
||||
'UserDefined': Prim(OPAQUE),
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace expanders {
|
||||
class StandardNormal : public OpExpander {
|
||||
public:
|
||||
StandardNormal() {
|
||||
std::initializer_list<std::string> attrs{"seed", "seed2"};
|
||||
validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
|
||||
}
|
||||
~StandardNormal() {}
|
||||
NodePtrList Expand() override {
|
||||
const auto &inputs = gb.Get()->inputs();
|
||||
const auto &input_x = inputs[0];
|
||||
auto shape = MakeValue(outputs_info_[0].shape);
|
||||
auto result =
|
||||
gb.Emit("StandardNormal", {input_x}, {{"shape", shape}, {"seed", attrs_["seed"]}, {"seed2", attrs_["seed2"]}});
|
||||
return {result};
|
||||
}
|
||||
};
|
||||
OP_EXPANDER_REGISTER("StandardNormal", StandardNormal);
|
||||
} // namespace expanders
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -46,6 +46,7 @@ using context::OpLevel_1;
|
|||
constexpr size_t kAssignInputIdx = 1;
|
||||
constexpr size_t kLambOptimizerInputIdx = 12;
|
||||
constexpr size_t kLambWeightInputIdx = 4;
|
||||
constexpr size_t kRandomInputIdx = 1;
|
||||
|
||||
std::vector<PrimitivePtr> GetExpandOps() {
|
||||
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = {
|
||||
|
@ -93,6 +94,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
|||
{kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll},
|
||||
{kGPUDevice, OpLevel_0, prim::kPrimIdentityMath},
|
||||
{kGPUDevice, OpLevel_0, prim::kPrimOnesLike},
|
||||
{kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
|
||||
};
|
||||
const auto &flags = context::GraphKernelFlags::GetInstance();
|
||||
std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level);
|
||||
|
@ -201,6 +203,7 @@ ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
|
|||
{prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
|
||||
{prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambOptimizerInputIdx)},
|
||||
{prim::kLambApplyWeightAssign, std::make_shared<OpUMonadExpander>(kLambWeightInputIdx)},
|
||||
{prim::kPrimStandardNormal, std::make_shared<OpUMonadExpander>(kRandomInputIdx)},
|
||||
};
|
||||
|
||||
for (auto &e : expanders) {
|
||||
|
|
|
@ -543,6 +543,11 @@ void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
|
|||
MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
|
||||
}
|
||||
}
|
||||
|
||||
DShape StandardNormalOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
CHECK_ATTR(attrs, "shape");
|
||||
return GetListInt(attrs.find("shape")->second);
|
||||
}
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -282,6 +282,17 @@ class ComplexOp : public ElemwiseOp {
|
|||
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
|
||||
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
|
||||
};
|
||||
|
||||
class StandardNormalOp : public OpaqueOp {
|
||||
public:
|
||||
StandardNormalOp(const std::string &op, const std::string &node_name) : OpaqueOp("StandardNormal", node_name) {}
|
||||
~StandardNormalOp() = default;
|
||||
|
||||
protected:
|
||||
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
|
||||
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
|
||||
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }
|
||||
};
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -79,6 +79,7 @@ class OpRegistry {
|
|||
Register("CImag", OP_CREATOR(CImagOp));
|
||||
Register("Complex", OP_CREATOR(ComplexOp));
|
||||
Register("Opaque", OP_CREATOR(OpaqueOp));
|
||||
Register("StandardNormal", OP_CREATOR(StandardNormalOp));
|
||||
}
|
||||
~OpRegistry() = default;
|
||||
std::unordered_map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
|
||||
|
|
|
@ -702,6 +702,9 @@ inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitiv
|
|||
inline const PrimitivePtr kPrimDynamicBroadcastGradientArgs =
|
||||
std::make_shared<Primitive>(kDynamicBroadcastGradientArgs);
|
||||
|
||||
// Random
|
||||
inline const PrimitivePtr kPrimStandardNormal = std::make_shared<Primitive>("StandardNormal");
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
|
|
Loading…
Reference in New Issue