forked from mindspore-Ecosystem/mindspore
optimize scalar to tensor function
parent d936779f42
commit 0647b8b7db
@@ -75,7 +75,7 @@ AnfNodePtr CreateInt32Tensor(int64_t value) {
   if (it != int_tensor_map.end()) {
     return it->second;
   }
-  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(py::int_(value), kInt32);
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt32);
   ValuePtr value_ptr = MakeValue(tensor_ptr);
   auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
   int_tensor_map[value] = anf_node_ptr;
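Note: the new line passes the int64_t value straight to the Tensor constructor instead of routing it through py::int_. Below is a small self-contained sketch of the same lookup-then-construct caching pattern; Node and CreateCachedIntNode are illustrative names, not MindSpore APIs.

#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

// A cheap "node" wrapping a scalar, plus a cache so the same integer value is
// only ever materialized once, mirroring the int_tensor_map lookup in
// CreateInt32Tensor.
struct Node {
  explicit Node(int64_t v) : value(v) {}
  int64_t value;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr CreateCachedIntNode(int64_t value) {
  static std::unordered_map<int64_t, NodePtr> cache;
  auto it = cache.find(value);
  if (it != cache.end()) {
    return it->second;
  }
  auto node = std::make_shared<Node>(value);  // build directly from the int64_t, no Python object detour
  cache[value] = node;
  return node;
}

int main() {
  auto a = CreateCachedIntNode(3);
  auto b = CreateCachedIntNode(3);
  std::cout << std::boolalpha << (a == b) << '\n';  // true: the second call hits the cache
  return 0;
}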
@@ -382,7 +382,7 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
     tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
     *tensor_mask = kValueNodeTensorMask;
   } else if (py::isinstance<py::int_>(input_object)) {
-    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt64);
+    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
     *tensor_mask = kValueNodeTensorMask;
   } else if (py::isinstance<py::array>(input_object)) {
     tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
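Note: py::cast<int64_t>(...) yields a plain C++ integer, so the Tensor constructor no longer needs to accept a py::int_ handle. A minimal pybind11 sketch of that conversion, assuming pybind11 and an embedded Python runtime are available; this is not MindSpore code.

#include <cstdint>
#include <iostream>
#include <pybind11/embed.h>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};  // start an embedded Python interpreter
  py::object obj = py::int_(42);
  if (py::isinstance<py::int_>(obj)) {
    auto value = py::cast<int64_t>(obj);  // plain C++ integer, no py::int_ temporary kept
    std::cout << value << '\n';           // 42
  }
  return 0;
}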
@@ -20,16 +20,13 @@
 #include <string>
 #include <memory>
 #include <algorithm>
-#include <list>
 #include <utility>
 #include <cfloat>

-#include "abstract/abstract_value.h"
 #include "ir/value.h"
 #include "ir/tensor.h"
 #include "ir/param_info.h"
 #include "utils/ms_context.h"
-#include "utils/shape_utils.h"

 namespace mindspore {
 bool ValueToBool(const ValuePtr &v, bool *value) {
@@ -37,13 +34,13 @@ bool ValueToBool(const ValuePtr &v, bool *value) {
   if (v->isa<BoolImm>()) {
     *value = v->cast<BoolImmPtr>()->value();
   } else if (v->isa<Int32Imm>()) {
-    *value = v->cast<Int32ImmPtr>()->value() == 0 ? false : true;
+    *value = v->cast<Int32ImmPtr>()->value() != 0;
   } else if (v->isa<UInt32Imm>()) {
-    *value = v->cast<UInt32ImmPtr>()->value() == 0 ? false : true;
+    *value = v->cast<UInt32ImmPtr>()->value() != 0;
   } else if (v->isa<FP32Imm>()) {
-    *value = v->cast<FP32ImmPtr>()->value() == 0 ? false : true;
+    *value = v->cast<FP32ImmPtr>()->value() != 0;
   } else if (v->isa<FP64Imm>()) {
-    *value = v->cast<FP64ImmPtr>()->value() == 0 ? false : true;
+    *value = v->cast<FP64ImmPtr>()->value() != 0;
   } else if (v->isa<tensor::Tensor>()) {
     auto tensor = v->cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(tensor);
@@ -65,11 +62,11 @@ bool BaseRefToInt(const ValuePtr &v, int64_t *value) {
   auto tensor = v->cast<tensor::TensorPtr>();
   (void)tensor->data_sync();
   if (tensor->Dtype()->ToString() == "Int32") {
-    int32_t *tensor_data = static_cast<int32_t *>(tensor->data_c());
+    auto *tensor_data = static_cast<int32_t *>(tensor->data_c());
     auto vb = tensor_data[0];
     *value = static_cast<int64_t>(vb);
   } else if (tensor->Dtype()->ToString() == "Int64") {
-    int64_t *tensor_data = static_cast<int64_t *>(tensor->data_c());
+    auto *tensor_data = static_cast<int64_t *>(tensor->data_c());
     auto vb = tensor_data[0];
     *value = vb;
   } else {
@@ -86,39 +83,19 @@ bool BaseRefToBool(const BaseRef &v, bool *value) {
     return ValueToBool(utils::cast<ValuePtr>(v), value);
   } else if (utils::isa<bool>(v)) {
     auto vb = utils::cast<bool>(v);
-    if (vb == true) {
-      *value = true;
-    } else {
-      *value = false;
-    }
+    *value = vb;
   } else if (utils::isa<int>(v)) {
     auto vb = utils::cast<int>(v);
-    if (vb == 0) {
-      *value = false;
-    } else {
-      *value = true;
-    }
+    *value = vb != 0;
   } else if (utils::isa<unsigned int>(v)) {
     auto vb = utils::cast<unsigned int>(v);
-    if (vb == 0) {
-      *value = false;
-    } else {
-      *value = true;
-    }
+    *value = vb != 0;
   } else if (utils::isa<float>(v)) {
     auto vb = utils::cast<float>(v);
-    if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) {
-      *value = false;
-    } else {
-      *value = true;
-    }
+    *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON);
   } else if (utils::isa<double>(v)) {
     auto vb = utils::cast<double>(v);
-    if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) {
-      *value = false;
-    } else {
-      *value = true;
-    }
+    *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON);
   } else {
     MS_LOG(DEBUG) << "value is not supported to cast to be bool";
     return false;
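Note: every if/else branch that only assigned true or false collapses into a single boolean expression, including the epsilon-tolerance check for floating point. A tiny self-contained illustration of the same rewrite; plain functions stand in for the BaseRef dispatch.

#include <cfloat>
#include <iostream>

// Direct boolean assignment instead of "if (...) *value = true; else *value = false;".
bool IntToBool(int v) { return v != 0; }
bool FloatToBool(float v) { return !(v >= -FLT_EPSILON && v <= FLT_EPSILON); }

int main() {
  std::cout << std::boolalpha << IntToBool(0) << ' ' << FloatToBool(1e-9f) << '\n';  // false false
  std::cout << std::boolalpha << IntToBool(7) << ' ' << FloatToBool(0.5f) << '\n';   // true true
  return 0;
}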
@@ -187,13 +164,13 @@ bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMap
   return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node);
 }

-bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph,
+bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph,
                   NodeMapEquiv *const equiv_node) {
   std::unordered_set<AnfNodePtr> done;
   std::stack<std::pair<AnfNodePtr, AnfNodePtr>> todo;

   todo.push(std::make_pair(root1, root2));
-  while (todo.size() > 0) {
+  while (!todo.empty()) {
     AnfNodePtr node1 = todo.top().first;
     if (done.count(node1) > 0) {
       todo.pop();
@@ -231,7 +208,7 @@ bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equ
   }
 }  // namespace

-bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph,
+bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph,
                 NodeMapEquiv *const equiv_node) {
   auto fg1_fg2 = std::make_pair(fg1, fg2);
   if (equiv_func_graph == nullptr) {
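Note: SameSubgraph and Isomorphic now take their node/graph smart pointers by const reference, which avoids copying the shared_ptr (and bumping its atomic reference count) on every call. A minimal sketch of the difference, using a hypothetical FuncGraph stand-in rather than the real MindSpore types.

#include <iostream>
#include <memory>

struct FuncGraph { int id = 0; };
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

long UseByValue(FuncGraphPtr g) { return g.use_count(); }            // copies the shared_ptr
long UseByConstRef(const FuncGraphPtr &g) { return g.use_count(); }  // no copy, no refcount bump

int main() {
  auto g = std::make_shared<FuncGraph>();
  std::cout << UseByValue(g) << '\n';     // 2: the parameter holds a second reference
  std::cout << UseByConstRef(g) << '\n';  // 1: only the caller's reference exists
  return 0;
}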
@@ -267,23 +244,35 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
   if (scalar == nullptr) {
     MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
   }
-  tensor::TensorPtr tensor = nullptr;
-  if (scalar->isa<FloatImm>()) {
-    tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
-  } else if (scalar->isa<Int32Imm>()) {
-    tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
-  } else if (scalar->isa<Int64Imm>()) {
-    tensor = std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), kInt64);
-  } else if (scalar->isa<BoolImm>()) {
-    const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
-    tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
-  } else {
-    auto type = scalar->type();
-    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
-    MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str;
+  TypePtr data_type = scalar->type();
+  MS_EXCEPTION_IF_NULL(data_type);
+  TypeId type_id = data_type->type_id();
+  switch (type_id) {
+    case kNumberTypeBool:
+      return std::make_shared<tensor::Tensor>(GetValue<bool>(scalar), data_type);
+    case kNumberTypeInt8:
+      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
+    case kNumberTypeInt16:
+      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
+    case kNumberTypeInt32:
+      return std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
+    case kNumberTypeInt64:
+      return std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), data_type);
+    case kNumberTypeUInt8:
+      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
+    case kNumberTypeUInt16:
+      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
+    case kNumberTypeUInt32:
+      return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
+    case kNumberTypeUInt64:
+      return std::make_shared<tensor::Tensor>(GetValue<uint64_t>(scalar), data_type);
+    case kNumberTypeFloat32:
+      return std::make_shared<tensor::Tensor>(GetValue<float>(scalar), data_type);
+    case kNumberTypeFloat64:
+      return std::make_shared<tensor::Tensor>(GetValue<double>(scalar), data_type);
+    default:
+      MS_LOG(EXCEPTION) << "When converting a scalar to a tensor, the scalar type " << data_type << " is not supported.";
   }
-  MS_EXCEPTION_IF_NULL(tensor);
-  return tensor;
 }

 void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
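Note: the isa<...> if/else chain becomes a single switch on the scalar's TypeId, with every signed integer widened to int64_t, every unsigned integer to uint64_t, and floats passed through, so each case constructs the tensor directly with the original data type attached. Below is a self-contained sketch of that dispatch shape; TypeTag, Widened, and WidenScalar are illustrative names, not the MindSpore API.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <variant>

// One switch on a type tag replaces a chain of type checks; narrow integer
// types funnel into a few wide storage types before boxing.
enum class TypeTag { kBool, kInt8, kInt32, kInt64, kUInt8, kUInt32, kFloat32, kFloat64 };
using Widened = std::variant<bool, int64_t, uint64_t, float, double>;

Widened WidenScalar(TypeTag tag, const void *data) {
  switch (tag) {
    case TypeTag::kBool:    return *static_cast<const bool *>(data);
    case TypeTag::kInt8:    return static_cast<int64_t>(*static_cast<const int8_t *>(data));
    case TypeTag::kInt32:   return static_cast<int64_t>(*static_cast<const int32_t *>(data));
    case TypeTag::kInt64:   return *static_cast<const int64_t *>(data);
    case TypeTag::kUInt8:   return static_cast<uint64_t>(*static_cast<const uint8_t *>(data));
    case TypeTag::kUInt32:  return static_cast<uint64_t>(*static_cast<const uint32_t *>(data));
    case TypeTag::kFloat32: return *static_cast<const float *>(data);
    case TypeTag::kFloat64: return *static_cast<const double *>(data);
  }
  throw std::invalid_argument("unsupported scalar type");  // mirrors the default: branch above
}

int main() {
  int32_t x = -5;
  Widened w = WidenScalar(TypeTag::kInt32, &x);
  std::cout << std::get<int64_t>(w) << '\n';  // -5, stored as int64_t
  return 0;
}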
@@ -301,7 +290,7 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
       }
     }
   } else if (value->isa<tensor::Tensor>()) {
-    tensor::TensorPtr tensor = value->cast<tensor::TensorPtr>();
+    auto tensor = value->cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(tensor);
     tensors->push_back(tensor);
   }
@@ -57,7 +57,8 @@ enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 };
 using FuncGraphPairMapEquiv = std::unordered_map<std::pair<FuncGraphPtr, FuncGraphPtr>, EquivState, PairHasher>;
 using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>;

-bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node);
+bool Isomorphic(const FuncGraphPtr &g1, const FuncGraphPtr &g2, FuncGraphPairMapEquiv *equiv_func_graph,
+                NodeMapEquiv *equiv_node);

 tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar);

@@ -491,6 +491,16 @@ Tensor::Tensor(double input, const TypePtr &data_type)
       data_(MakeTensorData(data_type_, {}, input)),
       id_(MakeId()) {}

+Tensor::Tensor(uint64_t input, const TypePtr &data_type)
+    : MetaTensor(TypeIdOf(data_type, kNumberTypeUInt64), {}),
+      data_(MakeTensorData(data_type_, {}, input)),
+      id_(MakeId()) {}
+
+Tensor::Tensor(bool input, const TypePtr &data_type)
+    : MetaTensor(TypeIdOf(data_type, kNumberTypeBool), {}),
+      data_(MakeTensorData(data_type_, {}, input)),
+      id_(MakeId()) {}
+
 bool Tensor::operator==(const Tensor &tensor) const {
   return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
 }
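Note: tensor.cc gains dedicated 0-dimensional constructors for uint64_t and bool scalars, which is what lets ScalarToTensor and the scalar-conversion helpers pass native C++ values straight through. A hypothetical stand-in class showing how such overloads resolve; ScalarTensor is illustrative, not the real mindspore::tensor::Tensor.

#include <cstdint>
#include <iostream>

// Overloaded scalar constructors: the argument's static type selects the
// stored kind, so callers pass an explicitly typed value (as the real code
// does) to avoid ambiguity with a bare int literal.
class ScalarTensor {
 public:
  explicit ScalarTensor(int64_t v) : kind_("int64"), as_double_(static_cast<double>(v)) {}
  explicit ScalarTensor(double v) : kind_("float64"), as_double_(v) {}
  explicit ScalarTensor(uint64_t v) : kind_("uint64"), as_double_(static_cast<double>(v)) {}
  explicit ScalarTensor(bool v) : kind_("bool"), as_double_(v ? 1.0 : 0.0) {}

  const char *kind() const { return kind_; }
  double value() const { return as_double_; }

 private:
  const char *kind_;
  double as_double_;
};

int main() {
  ScalarTensor t(true);          // picks the bool overload
  ScalarTensor u(uint64_t{42});  // picks the uint64_t overload
  std::cout << t.kind() << ' ' << u.kind() << '\n';  // bool uint64
  return 0;
}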
@@ -172,6 +172,18 @@ class Tensor : public MetaTensor {
   // param data_type [TypeId] data type
   explicit Tensor(double input, const TypePtr &data_type = nullptr);

+  // brief Create 0 dimension tensor from a uint scalar.
+  //
+  // param input [uint] the data for tensor
+  // param data_type [TypeId] data type
+  explicit Tensor(uint64_t input, const TypePtr &data_type = nullptr);
+
+  // brief Create 0 dimension tensor from a bool scalar.
+  //
+  // param input [bool] the data for tensor
+  // param data_type [TypeId] data type
+  explicit Tensor(bool input, const TypePtr &data_type = nullptr);
+
   ~Tensor() override = default;

   MS_DECLARE_PARENT(Tensor, MetaTensor);
@@ -88,6 +88,7 @@ class L1Regularizer(Cell):
         l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
         return l1_regularization

+
 class Dropout(Cell):
     r"""
     Dropout layer for the input.
@@ -210,6 +211,7 @@ class Flatten(Cell):
     def construct(self, x):
         return F.reshape(x, (F.shape(x)[0], -1))

+
 @constexpr
 def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel):
     """get broadcast_weight_bias shape"""
@@ -217,6 +219,7 @@ def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel):
     broad_bias_shape = x_shape[:-1] + (out_channel,)
     return broad_weight_shape, broad_bias_shape

+
 class Dense(Cell):
     r"""
     The dense connected layer.
@@ -262,6 +265,7 @@ class Dense(Cell):
        [[ 2.5246444  2.2738023  0.5711005 -3.9399147 ]
         [ 1.0739875  4.0155234  0.94188046 -5.459526 ]]
     """
+
     @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels'])
     def __init__(self,
                  in_channels,
@@ -323,7 +327,6 @@ class Dense(Cell):
             x = self.activation(x)
         return x

-
     def extend_repr(self):
         s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
         if self.has_bias:
@@ -339,11 +342,13 @@ def _is_equal_one(x):
         return False
     return bool(x.asnumpy().mean() == 1.0)

+
 @constexpr
 def _dtype_check(x_dtype):
     if x_dtype not in [mstype.float32, mstype.float16]:
         raise TypeError("The input type must be float32 or float16.")

+
 @constexpr
 def _is_float_dtype(dtype):
     if dtype in [mstype.float32, mstype.float16]:
@@ -539,7 +544,6 @@ class OneHot(Cell):
         return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))


-
 class Pad(Cell):
     """
     Pads the input tensor according to the paddings and mode.
@@ -672,6 +676,7 @@ class Interpolate(Cell):
         >>> print(result.shape)
         (1, 1, 5, 5)
     """
+
     def __init__(self):
         super(Interpolate, self).__init__()

@@ -767,6 +772,7 @@ class Tril(Cell):
         [[1 0]
          [3 4]]
     """
+
     def __init__(self):
         super(Tril, self).__init__()
         self.dtype = P.DType()
@@ -809,6 +815,7 @@ class Triu(Cell):
         [[1 2]
          [0 4]]
     """
+
     def __init__(self):
         super(Triu, self).__init__()
         self.dtype = P.DType()
@@ -859,6 +866,7 @@ class MatrixDiag(Cell):
        [[ 1.  0.]
         [ 0. -1.]]
     """
+
     def __init__(self):
         super(MatrixDiag, self).__init__()
         self.matrix_diag = inner.MatrixDiag()
@@ -895,6 +903,7 @@ class MatrixDiagPart(Cell):
         [-1.  1.]
         [-1.  1.]]
     """
+
     def __init__(self):
         super(MatrixDiagPart, self).__init__()
         self.matrix_diag_part = inner.MatrixDiagPart()
@@ -936,6 +945,7 @@ class MatrixSetDiag(Cell):
        [[-1.  0.]
         [ 0.  1.]]]
     """
+
     def __init__(self):
         super(MatrixSetDiag, self).__init__()
         self.matrix_set_diag = inner.MatrixSetDiag()
@@ -407,7 +407,7 @@ class ParameterUpdate(Cell):
         >>> param = network.parameters_dict()['weight']
         >>> update = nn.ParameterUpdate(param)
         >>> update.phase = "update_param"
-        >>> weight = Tensor(0.001, mindspore.float32)
+        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
         >>> update(weight)
     """
