!40105 fix code check

Merge pull request !40105 from hezhenhao1/master
i-robot 2022-08-10 05:28:53 +00:00 committed by Gitee
commit 16339e9d16
20 changed files with 128 additions and 124 deletions

View File

@ -239,6 +239,8 @@ Reduction函数
mindspore.ops.amax
mindspore.ops.amin
mindspore.ops.argmin
mindspore.ops.cummax
mindspore.ops.cummin
mindspore.ops.logsumexp
mindspore.ops.max
mindspore.ops.mean
@ -265,8 +267,6 @@ Reduction函数
:template: classtemplate.rst
mindspore.ops.approximate_equal
mindspore.ops.cummax
mindspore.ops.cummin
mindspore.ops.equal
mindspore.ops.ge
mindspore.ops.gt

View File

@ -241,6 +241,8 @@ Reduction Functions
mindspore.ops.amax
mindspore.ops.amin
mindspore.ops.argmin
mindspore.ops.cummax
mindspore.ops.cummin
mindspore.ops.logsumexp
mindspore.ops.max
mindspore.ops.mean
@ -267,8 +269,6 @@ Comparison Functions
:template: classtemplate.rst
mindspore.ops.approximate_equal
mindspore.ops.cummax
mindspore.ops.cummin
mindspore.ops.equal
mindspore.ops.ge
mindspore.ops.gt

View File

@ -793,7 +793,7 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
}
}
if (args_tuple_num == 0) {
(void)args->emplace_back(value);
args->emplace_back(value);
front_index++;
}
}

View File

@ -64,7 +64,7 @@ inline void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_
}
template <typename T>
inline void Copy(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
inline void Copy(T *input, const T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
size_t k1 = i / dim2 % dim0;

View File

@ -58,7 +58,7 @@ class CumSumCpuKernelMod : public NativeCpuKernelMod {
static std::vector<std::pair<KernelAttr, CumSumLaunchFunc>> func_list_;
CumSumLaunchFunc kernel_func_;
int axis_{0};
ShapeVector shape_{};
std::vector<size_t> shape_{};
size_t stride_{0};
size_t stride2_{0};
size_t dims_[3]{0};
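
Note: with shape_ switching from ShapeVector (int64_t dimensions) to std::vector<size_t>, the shape returned by GetShapeVector() has to be converted once during Resize; later hunks in this commit use the repository's LongVecToSizeVec helper for that. The following is only an illustrative sketch of such a conversion, not the MindSpore utility itself.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: convert the int64_t dimensions returned by
// GetShapeVector() into size_t dimensions; negative (dynamic) dims are
// clamped to zero so the cast stays well defined.
std::vector<size_t> ToSizeShape(const std::vector<int64_t> &shape) {
  std::vector<size_t> result;
  result.reserve(shape.size());
  for (int64_t dim : shape) {
    result.push_back(dim > 0 ? static_cast<size_t>(dim) : 0);
  }
  return result;
}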

View File

@ -15,7 +15,7 @@
*/
#include "plugin/device/cpu/kernel/is_close_cpu_kernel.h"
#include <cmath>
#include <functional>
#include <algorithm>
#include <memory>
#include "mindspore/core/ops/is_close.h"
@ -26,15 +26,16 @@ namespace kernel {
namespace {
constexpr size_t kIsCloseInputsNum = 2;
constexpr size_t kIsCloseOutputsNum = 1;
constexpr size_t kIsCloseMaxDim = 10;
template <typename T>
inline bool IsClose(T a, T b, float rtol, float atol, bool equal_nan) {
if (a == b) {
if (std::equal_to<T>()(a, b)) {
return true;
}
if (equal_nan && std::isnan(a) && std::isnan(b)) {
return true;
}
if (atol == 0 && rtol == 0) {
if (std::equal_to<float>()(atol, 0) && std::equal_to<float>()(rtol, 0)) {
return false;
}
auto left_side = std::abs(a - b);
@ -42,27 +43,26 @@ inline bool IsClose(T a, T b, float rtol, float atol, bool equal_nan) {
return std::isfinite(left_side) && left_side <= right_side;
}
void GetBroadCastIndex(const ShapeVector &unaligned_input_shape, const ShapeVector &output_shape,
std::vector<int64_t> *index_list) {
void GetBroadCastIndex(const std::vector<size_t> &unaligned_input_shape, const std::vector<size_t> &output_shape,
std::vector<size_t> *index_list) {
// Given unaligned input shape and output shape, this function returns the mapping
// from indices of output (logical) to corespondingly real input indices (physical).
// The return will write to index_list, whose size is equal to total elements of output.
constexpr int MaxDim = 10;
int64_t logical_shape[MaxDim];
int64_t physical_shape[MaxDim];
int64_t size = 0, output_size = 1;
size_t logical_shape[kIsCloseMaxDim];
size_t physical_shape[kIsCloseMaxDim];
size_t size = 0, output_size = 1;
// Align input shape to output shape by filling one into the outermost dimension.
ShapeVector input_shape(output_shape.size());
std::vector<size_t> input_shape(output_shape.size());
for (size_t i = 0, j = output_shape.size() - unaligned_input_shape.size(); i < output_shape.size(); i++) {
input_shape[i] = i < j ? 1 : unaligned_input_shape[i - j];
}
// Get logical shape and physical shape of input. Moreover, we will merge the dimensions with same
// (logical or physical) property.
for (int i = SizeToInt(output_shape.size()) - 1; i >= 0;) {
int64_t stride = 1;
for (size_t i = output_shape.size(); i > 0;) {
size_t stride = 1;
bool change = false, is_valid = false;
while (i >= 0 && input_shape[i] == output_shape[i]) {
stride *= output_shape[i];
while (i > 0 && input_shape[i - 1] == output_shape[i - 1]) {
stride *= output_shape[i - 1];
change = is_valid = true;
--i;
}
@ -73,8 +73,8 @@ void GetBroadCastIndex(const ShapeVector &unaligned_input_shape, const ShapeVect
}
change = false;
stride = 1;
while (i >= 0 && input_shape[i] == 1) {
stride *= output_shape[i];
while (i > 0 && input_shape[i - 1] == 1) {
stride *= output_shape[i - 1];
change = is_valid = true;
--i;
}
@ -90,13 +90,13 @@ void GetBroadCastIndex(const ShapeVector &unaligned_input_shape, const ShapeVect
}
}
// Get the flatten input indices according to "logical_shape" and "physical_shape".
int64_t offset = 1;
int64_t stride = 1;
size_t offset = 1;
size_t stride = 1;
index_list->resize(output_size);
(*index_list)[0] = 0; // First element is set to 0.
for (int64_t i = 0; i < size; ++i) {
int64_t increment = (logical_shape[i] == physical_shape[i] ? stride : 0);
for (int64_t j = 0; j < (physical_shape[i] - 1) * offset; ++j) {
for (size_t i = 0; i < size; ++i) {
size_t increment = (logical_shape[i] == physical_shape[i] ? stride : 0);
for (size_t j = 0; j < (physical_shape[i] - 1) * offset; ++j) {
(*index_list)[offset + j] = (*index_list)[j] + increment;
}
offset *= physical_shape[i];
@ -125,9 +125,9 @@ int IsCloseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
auto other_shape = inputs.at(kIndex1)->GetShapeVector();
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
auto input_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
auto other_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
auto output_shape = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector());
is_need_broadcast_ = input_shape != other_shape;
if (is_need_broadcast_) {
GetBroadCastIndex(input_shape, output_shape, &index_list1_);
@ -137,7 +137,7 @@ int IsCloseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
}
template <typename T>
bool IsCloseCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool IsCloseCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIsCloseInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIsCloseOutputsNum, kernel_name_);
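
Note: the hunks above replace raw floating-point == comparisons with std::equal_to to satisfy the code check while keeping exact-equality semantics, and they move the broadcast index computation onto size_t shapes. A minimal, self-contained sketch of the same IsClose predicate for float inputs (the kernel's real version is templated over the input type, as shown above):

#include <cmath>
#include <functional>

// Sketch: exact equality goes through std::equal_to so lint does not flag a
// raw floating-point ==; otherwise test |a - b| <= atol + rtol * |b|.
inline bool IsCloseSketch(float a, float b, float rtol, float atol, bool equal_nan) {
  if (std::equal_to<float>()(a, b)) {
    return true;
  }
  if (equal_nan && std::isnan(a) && std::isnan(b)) {
    return true;
  }
  if (std::equal_to<float>()(atol, 0.0f) && std::equal_to<float>()(rtol, 0.0f)) {
    return false;
  }
  float left_side = std::abs(a - b);
  float right_side = atol + rtol * std::abs(b);
  return std::isfinite(left_side) && left_side <= right_side;
}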

View File

@ -55,8 +55,8 @@ class IsCloseCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
bool equal_nan_{true};
// Broadcast related.
std::vector<int64_t> index_list1_{};
std::vector<int64_t> index_list2_{};
std::vector<size_t> index_list1_{};
std::vector<size_t> index_list2_{};
bool is_need_broadcast_{false};
};
} // namespace kernel

View File

@ -173,22 +173,21 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
};
auto element_size = batch_size_ * inner_size_;
constexpr size_t min_block_size = 128;
auto block_size = std::max(min_block_size, element_size / GetActorMgrInnerThreadPool()->GetKernelThreadNum());
ParallelLaunch(task, element_size, block_size, this);
ParallelLaunch(task, element_size, 0, this, pool_);
if (invalid_index_pos != -1) {
std::stringstream indices_ss;
std::stringstream input_shape_ss;
auto pos = LongToSize(invalid_index_pos);
for (size_t i = 0; i < slice_size_; i++) {
if (i > 0) {
indices_ss << ", ";
input_shape_ss << ", ";
}
indices_ss << std::to_string(indices[invalid_index_pos + i]);
indices_ss << std::to_string(indices[pos + i]);
input_shape_ss << std::to_string(input_shape_[i]);
}
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'["
<< indices_ss.str() << "] is out of range[" + input_shape_ss.str() + "].";
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << pos << "-th value of 'indices'[" << indices_ss.str()
<< "] is out of range[" + input_shape_ss.str() + "].";
return false;
}
return true;
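
Note: the change above drops the hand-computed block size and passes 0 so ParallelLaunch chooses its own partitioning over the kernel's thread pool. For reference, a standalone sketch of the block-size rule the removed line implemented (all names here are illustrative, not MindSpore APIs): split element_size evenly across the available threads, but never below a minimum block size.

#include <algorithm>
#include <cstddef>

// Sketch: at least min_block_size elements per block, otherwise divide the
// work evenly over the available kernel threads. The thread_num == 0 guard
// is an added safety check, not part of the original expression.
size_t ComputeBlockSize(size_t element_size, size_t thread_num) {
  constexpr size_t min_block_size = 128;
  if (thread_num == 0) {
    return element_size;
  }
  return std::max(min_block_size, element_size / thread_num);
}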

View File

@ -42,21 +42,21 @@ bool SparseSegmentMeanCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
int SparseSegmentMeanCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
auto x_shape = inputs.at(kIndex0)->GetShapeVector();
auto indices_shape = inputs.at(kIndex1)->GetShapeVector();
auto y_shape = outputs.at(kIndex0)->GetShapeVector();
batch_rank_ = LongToSize(base_operator->get_batch_rank());
batch_size_ = std::accumulate(x_shape.begin(), x_shape.begin() + batch_rank_, size_t(1), std::multiplies{});
outer_size_ = LongToSize(x_shape.at(batch_rank_));
inner_size_ = std::accumulate(x_shape.begin() + batch_rank_ + 1, x_shape.end(), size_t(1), std::multiplies{});
auto x_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
auto indices_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
auto y_shape = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector());
auto batch_rank = base_operator->get_batch_rank();
batch_size_ = std::accumulate(x_shape.begin(), x_shape.begin() + batch_rank, size_t(1), std::multiplies{});
outer_size_ = x_shape.at(LongToSize(batch_rank));
inner_size_ = std::accumulate(x_shape.begin() + batch_rank + 1, x_shape.end(), size_t(1), std::multiplies{});
x_size_ = inner_size_ * outer_size_;
indices_size_ = LongToSize(indices_shape.at(batch_rank_));
y_size_ = std::accumulate(y_shape.begin() + batch_rank_, y_shape.end(), size_t(1), std::multiplies{});
indices_size_ = indices_shape.at(LongToSize(batch_rank));
y_size_ = std::accumulate(y_shape.begin() + batch_rank, y_shape.end(), size_t(1), std::multiplies{});
return ret;
}
@ -107,11 +107,11 @@ bool SparseSegmentMeanCpuKernelMod::LaunchKernel(const std::vector<kernel::Addre
IndexType pre_segment_id = -1;
for (size_t k = 0; k < indices_size_; ++k) {
auto i = index_batch_offset + k;
auto x_offset = x_batch_offset + indices_ptr[i] * inner_size_;
auto y_offset = y_batch_offset + segment_ids_ptr[i] * inner_size_;
auto x_offset = x_batch_offset + LongToSize(indices_ptr[i]) * inner_size_;
auto y_offset = y_batch_offset + LongToSize(segment_ids_ptr[i]) * inner_size_;
// Reset the empty segments by setting output[i] = 0.
for (auto sid = pre_segment_id + 1; sid <= segment_ids_ptr[i]; sid++) {
auto reset_y_ptr = y_ptr + (y_batch_offset + sid * inner_size_);
auto reset_y_ptr = y_ptr + (y_batch_offset + LongToSize(sid) * inner_size_);
std::fill(reset_y_ptr + start, reset_y_ptr + end, DataType(0));
}
pre_segment_id = segment_ids_ptr[i];
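
Note: Resize above decomposes the input shape into batch, outer, and inner extents around batch_rank using std::accumulate. A minimal sketch of that decomposition over a plain std::vector<size_t> shape (the struct and function names are illustrative, not the kernel's members):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

struct ShapeSplit {
  size_t batch_size;  // product of dims [0, batch_rank)
  size_t outer_size;  // dim at batch_rank (the segmented axis)
  size_t inner_size;  // product of dims (batch_rank, end)
};

ShapeSplit SplitShape(const std::vector<size_t> &shape, size_t batch_rank) {
  ShapeSplit s;
  s.batch_size = std::accumulate(shape.begin(), shape.begin() + batch_rank,
                                 size_t(1), std::multiplies<size_t>());
  s.outer_size = shape.at(batch_rank);
  s.inner_size = std::accumulate(shape.begin() + batch_rank + 1, shape.end(),
                                 size_t(1), std::multiplies<size_t>());
  return s;
}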

View File

@ -39,7 +39,7 @@ class SparseSegmentMeanCpuKernelMod : public NativeCpuKernelMod {
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
@ -59,7 +59,6 @@ class SparseSegmentMeanCpuKernelMod : public NativeCpuKernelMod {
size_t indices_size_{1};
size_t x_size_{1};
size_t y_size_{1};
size_t batch_rank_{0};
size_t batch_size_{1};
};
} // namespace kernel

View File

@ -105,6 +105,7 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
const std::vector<KernelTensorPtr> &outputs) override {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Lamb>(base_operator);
kernel_name_ = kernel_ptr->name();
InitResource();
return true;
}
@ -114,7 +115,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
MS_LOG(EXCEPTION) << "For 'Lamb', the number of inputs should be " << INPUT_NUM << ", but got " << inputs.size();
}
InitResource();
InitParamSizeByType();
auto variable_int64_shape = inputs[kVarIndex]->GetShapeVector();

View File

@ -56,8 +56,8 @@ TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector<Abst
// Input 'x' and 'value' must have the same types.
std::map<std::string, TypePtr> args;
(void)args.insert({"x", x_type});
(void)args.insert({"value", value_type});
(void)args.emplace("x", x_type);
(void)args.emplace("value", x_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_data_types, prim_name);
return x_type;
}
@ -78,7 +78,7 @@ abstract::ShapePtr IndexFillInferShape(const PrimitivePtr &primitive, const std:
// Input 'index' must be a scalar/vector.
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto index_rank = SizeToLong(index_shape.size());
(void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name);
CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name);
// Input 'value' must be a tensor with a value or a scalar.
if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>()) {

View File

@ -53,14 +53,16 @@ abstract::ShapePtr IsCloseInferShape(const PrimitivePtr &primitive, const std::v
<< ", 'other' size: " << other_shape[i] << ".";
}
}
if (input_size > MAX)
if (input_size > MAX) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the size of tensor 'input' must be less than [2147483648], but got: "
<< input_size << ".";
if (other_size > MAX)
}
if (other_size > MAX) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the size of tensor 'other' must be less than [2147483648], but got: "
<< other_size << ".";
}
}
return BroadCastInferShape(op_name, input_args);
}

View File

@ -30,7 +30,7 @@
namespace mindspore {
namespace ops {
namespace {
int64_t GetTensorValue(AbstractBasePtr arg, const std::string &prim_name, const std::string &arg_name) {
int64_t GetTensorValue(const AbstractBasePtr &arg, const std::string &prim_name, const std::string &arg_name) {
if (!arg->isa<abstract::AbstractTensor>() || !arg->BuildValue()->isa<tensor::Tensor>()) {
MS_EXCEPTION(TypeError) << "For " << prim_name << ", the input '" << arg_name << "' must be const Tensor.";
}
@ -60,9 +60,9 @@ ShapeVector GetOutputShape(const std::vector<int64_t> &x_shape, int64_t lower_di
MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1], but got k[0] is "
<< lower_diag_index << ", k[1] is " << upper_diag_index << ".";
}
CheckAndConvertUtils::CheckInteger("rank of 'x'", x_rank, kGreaterEqual, number_two, prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of 'x'", x_rank, kGreaterEqual, number_two, prim_name);
auto num_diags = upper_diag_index - lower_diag_index + 1;
if (x_shape[x_rank - number_two] != num_diags) {
if (x_shape[LongToSize(x_rank - number_two)] != num_diags) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input x_shape[-2] doesn't match with k value.";
}
(void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - number_two);
@ -73,10 +73,12 @@ ShapeVector GetOutputShape(const std::vector<int64_t> &x_shape, int64_t lower_di
int64_t max_diag_len = x_shape.back();
int64_t min_num_rows = max_diag_len - std::min(upper_diag_index, int64_t(0));
int64_t min_num_cols = max_diag_len + std::max(lower_diag_index, int64_t(0));
if (row_val != -1 && row_val < min_num_rows)
if (row_val != -1 && row_val < min_num_rows) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows is too small.";
if (col_val != -1 && col_val < min_num_cols)
}
if (col_val != -1 && col_val < min_num_cols) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of columns is too small.";
}
if (row_val == -1 && col_val == -1) {
row_val = std::max(min_num_rows, min_num_cols);
col_val = row_val;
@ -212,8 +214,8 @@ AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const Pri
auto align_ptr = primitive->GetAttr(kAlign);
MS_EXCEPTION_IF_NULL(align_ptr);
auto align = GetValue<std::string>(align_ptr);
CheckAndConvertUtils::CheckString(kAlign, align, {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"},
primitive->name());
(void)CheckAndConvertUtils::CheckString(kAlign, align, {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"},
primitive->name());
auto infer_type = MatrixDiagV3InferType(primitive, input_args);
auto infer_shape = MatrixDiagV3InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
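
Note: the shape inference above derives lower bounds for the output rows and columns from the longest diagonal length and the k range, then checks any user-supplied row/col values against them. A small sketch of just that bound computation, mirroring the two expressions in the hunk (function name and return type are illustrative):

#include <algorithm>
#include <cstdint>
#include <utility>

// Sketch: minimum (rows, cols) that can hold diagonals k[0]..k[1] whose
// longest diagonal has max_diag_len elements.
std::pair<int64_t, int64_t> MinOutputDims(int64_t max_diag_len,
                                          int64_t lower_diag_index,
                                          int64_t upper_diag_index) {
  int64_t min_num_rows = max_diag_len - std::min(upper_diag_index, int64_t(0));
  int64_t min_num_cols = max_diag_len + std::max(lower_diag_index, int64_t(0));
  return {min_num_rows, min_num_cols};
}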

View File

@ -60,7 +60,7 @@ abstract::ShapePtr SparseSegmentMeanInferShape(const PrimitivePtr &prim,
if (!input_args[kInputIndex2]->BuildValue()->isa<tensor::Tensor>()) {
// The real output shape relies on the last value of 'segment_ids', we have already added dependency map.
// The framework ensures the `else` branch will be executed, so min/max shape are not necessary to set.
out_shape[batch_rank] = abstract::Shape::SHP_ANY;
out_shape[LongToSize(batch_rank)] = abstract::Shape::SHP_ANY;
return std::make_shared<abstract::Shape>(out_shape);
} else {
auto segment_ids = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
@ -85,7 +85,7 @@ abstract::ShapePtr SparseSegmentMeanInferShape(const PrimitivePtr &prim,
if (segment_num <= 0) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', the input 'segment_ids' must be non-negative.";
}
out_shape[batch_rank] = segment_num;
out_shape[LongToSize(batch_rank)] = segment_num;
return std::make_shared<abstract::Shape>(out_shape);
}
}

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Linear algebra submodule"""
from __future__ import absolute_import
from .ops import Cholesky
from .ops import Eigh
from .ops import LU

View File

@ -72,9 +72,9 @@ class MinimizeBfgs(nn.Cell):
def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
# Constant tensors which avoid loop unrolling
_BOOL_FALSE = _to_tensor(False)
_INT_ZERO = _to_tensor(0)
_INT_ONE = _to_tensor(1)
const_bool_false = _to_tensor(False)
const_int_zero = _to_tensor(0)
const_int_one = _to_tensor(1)
if maxiter is None:
maxiter = mnp.size(x0) * 200
@ -86,18 +86,18 @@ class MinimizeBfgs(nn.Cell):
state = {
"converged": _norm(g_0, ord_=mnp.inf) < gtol,
"failed": _BOOL_FALSE,
"k": _INT_ZERO,
"nfev": _INT_ONE,
"ngev": _INT_ONE,
"nhev": _INT_ZERO,
"failed": const_bool_false,
"k": const_int_zero,
"nfev": const_int_one,
"ngev": const_int_one,
"nhev": const_int_zero,
"x_k": x0,
"f_k": f_0,
"g_k": g_0,
"H_k": identity,
"old_old_fval": f_0 + _norm(g_0) / 2,
"status": _INT_ZERO,
"line_search_status": _INT_ZERO
"status": const_int_zero,
"line_search_status": const_int_zero
}
while state["k"] < maxiter:

View File

@ -86,13 +86,13 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0,
Tries cubic, quadratic, and bisection methods of zooming.
"""
# Constant tensors which avoid loop unrolling
_FLOAT_ONE = _to_tensor(1., dtype=a_low.dtype)
_BOOL_FALSE = _to_tensor(False)
_INT_ZERO = _to_tensor(0)
const_float_one = _to_tensor(1., dtype=a_low.dtype)
const_bool_false = _to_tensor(False)
const_int_zero = _to_tensor(0)
state = {
"done": _BOOL_FALSE,
"failed": _BOOL_FALSE,
"j": _INT_ZERO,
"done": const_bool_false,
"failed": const_bool_false,
"j": const_int_zero,
"a_low": a_low,
"phi_low": phi_low,
"dphi_low": dphi_low,
@ -101,12 +101,12 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0,
"dphi_high": dphi_high,
"a_rec": (a_low + a_high) / 2.,
"phi_rec": (phi_low + phi_high) / 2.,
"a_star": _FLOAT_ONE,
"a_star": const_float_one,
"phi_star": phi_low,
"dphi_star": dphi_low,
"g_star": g_0,
"nfev": _INT_ZERO,
"ngev": _INT_ZERO,
"nfev": const_int_zero,
"ngev": const_int_zero,
}
if mnp.logical_not(is_run):
@ -195,41 +195,41 @@ class LineSearch(nn.Cell):
return fkk, gkk, mnp.dot(gkk, pk)
# Constant tensors which avoid loop unrolling
_FLOAT_ZERO = _to_tensor(0., dtype=xk.dtype)
_FLOAT_ONE = _to_tensor(1., dtype=xk.dtype)
_BOOL_FALSE = _to_tensor(False)
_INT_ZERO = _to_tensor(0)
_INT_ONE = _to_tensor(1)
const_float_zero = _to_tensor(0., dtype=xk.dtype)
const_float_one = _to_tensor(1., dtype=xk.dtype)
const_bool_false = _to_tensor(False)
const_int_zero = _to_tensor(0)
const_int_one = _to_tensor(1)
if old_fval is None or gfk is None:
nfev, ngev = _INT_ONE, _INT_ONE
phi_0, g_0, dphi_0 = fval_and_grad(_FLOAT_ZERO)
nfev, ngev = const_int_one, const_int_one
phi_0, g_0, dphi_0 = fval_and_grad(const_float_zero)
else:
nfev, ngev = _INT_ZERO, _INT_ZERO
nfev, ngev = const_int_zero, const_int_zero
phi_0, g_0 = old_fval, gfk
dphi_0 = mnp.dot(g_0, pk)
if old_old_fval is None:
start_value = _FLOAT_ONE
start_value = const_float_one
else:
old_phi0 = old_old_fval
candidate_start_value = 1.01 * 2 * (phi_0 - old_phi0) / dphi_0
start_value = mnp.where(
mnp.isfinite(candidate_start_value),
mnp.minimum(candidate_start_value, _FLOAT_ONE),
_FLOAT_ONE
mnp.minimum(candidate_start_value, const_float_one),
const_float_one
)
state = {
"done": _BOOL_FALSE,
"failed": _BOOL_FALSE,
"i": _INT_ONE,
"a_i": _FLOAT_ZERO,
"done": const_bool_false,
"failed": const_bool_false,
"i": const_int_one,
"a_i": const_float_zero,
"phi_i": phi_0,
"dphi_i": dphi_0,
"nfev": nfev,
"ngev": ngev,
"a_star": _FLOAT_ZERO,
"a_star": const_float_zero,
"phi_star": phi_0,
"dphi_star": dphi_0,
"g_star": g_0,

View File

@ -78,19 +78,19 @@ def _batch_gmres(A, b, x0, tol, restart, maxiter, M, atol):
It does not allow for early termination, but has much less overhead on GPUs.
"""
# Constant tensor which avoids loop unrolling
_INT_ZERO = _to_tensor(0)
const_int_zero = _to_tensor(0)
dtype = b.dtype
_, b_norm = _safe_normalize(b)
atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype)
residual = _matvec(M, b - _matvec(A, x0))
unit_residual, residual_norm = _safe_normalize(residual)
k = _INT_ZERO
k = const_int_zero
x = x0
while k < maxiter and residual_norm > atol:
pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),)
V = mnp.pad(unit_residual[..., None], pad_width=pad_width)
H = mnp.eye(restart, restart + 1, dtype=dtype)
k_iter = _INT_ZERO
k_iter = const_int_zero
breakdown = _to_tensor(False)
while k_iter < restart and mnp.logical_not(breakdown):
V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H)
@ -103,7 +103,7 @@ def _batch_gmres(A, b, x0, tol, restart, maxiter, M, atol):
residual = _matvec(M, b - _matvec(A, x))
unit_residual, residual_norm = _safe_normalize(residual)
k += 1
return x, F.select(residual_norm > atol, k, _INT_ZERO)
return x, F.select(residual_norm > atol, k, const_int_zero)
def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol):
@ -112,7 +112,7 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol):
the GMRES process using Givens rotations. This improves numerical stability and gives a free estimate of
the residual norm that allows for early termination within a single "restart".
"""
_INT_ZERO = _to_tensor(0)
const_int_zero = _to_tensor(0)
_, b_norm = _safe_normalize(b)
atol = mnp.maximum(tol * b_norm, atol)
@ -123,7 +123,7 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol):
r = _matvec(M, b - _matvec(A, x0))
r, r_norm = _safe_normalize(r)
iters = _INT_ZERO
iters = const_int_zero
while iters < maxiter and r_norm > atol:
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
dtype = mnp.result_type(b)
@ -134,13 +134,13 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol):
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
beta_vec[0] = r_norm
k = _INT_ZERO
k = const_int_zero
err = r_norm
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
V, R, _ = arnoldi_iteration(k, A, M, V, R)
# Givens rotation
row_k = R[k, :]
i = _INT_ZERO
i = const_int_zero
while i < k:
row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1])
i += 1
@ -168,7 +168,7 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol):
r, r_norm = _safe_normalize(r)
x0 = x
iters += 1
return x0, F.select(r_norm > atol, iters, _INT_ZERO)
return x0, F.select(r_norm > atol, iters, const_int_zero)
class GMRES(nn.Cell):
@ -360,13 +360,13 @@ def _cg(A, b, x0, tol, atol, maxiter, M):
building blocks for iterative methods', 1994, pg. 12-14
"""
# Constant tensor which avoids loop unrolling
_INT_ZERO = _to_tensor(0)
const_int_zero = _to_tensor(0)
atol_ = mnp.maximum(atol, tol * _norm(b))
r = b - _matvec(A, x0)
z = p = _matvec(M, r)
rho = mnp.dot(r, z)
k = _INT_ZERO
k = const_int_zero
x = x0
while k < maxiter and _norm(r) > atol_:
q = _matvec(A, p)
@ -382,7 +382,7 @@ def _cg(A, b, x0, tol, atol, maxiter, M):
k += 1
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
return x, F.select(_norm(r) > atol_, k, const_int_zero)
class CG(nn.Cell):
@ -535,20 +535,20 @@ class BiCGStab(nn.Cell):
def construct(self, b, x0, tol, atol, maxiter):
# Constant tensors which avoid loop unrolling
_INT_ZERO = _to_tensor(0)
_INT_NEG_ONE = _to_tensor(-1)
const_int_zero = _to_tensor(0)
const_int_neg_one = _to_tensor(-1)
_float_one = _to_tensor(1., dtype=b.dtype)
const_float_one = _to_tensor(1., dtype=b.dtype)
atol_ = mnp.maximum(atol, tol * _norm(b))
r = r_tilde = v = p = b - _matvec(self.A, x0)
rho = alpha = omega = _float_one
k = _INT_ZERO
rho = alpha = omega = const_float_one
k = const_int_zero
x = x0
while k < maxiter:
rho_ = mnp.dot(r_tilde, r)
if rho_ == 0. or omega == 0.:
k = _INT_NEG_ONE
k = const_int_neg_one
break
beta = rho_ / rho * (alpha / omega)
@ -572,7 +572,7 @@ class BiCGStab(nn.Cell):
rho = rho_
k += 1
return x, F.select(k == _INT_NEG_ONE or k >= maxiter, k, _INT_ZERO)
return x, F.select(k == const_int_neg_one or k >= maxiter, k, const_int_zero)
def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""internal graph-compatible utility functions"""
from __future__ import absolute_import
from types import FunctionType
from collections.abc import Iterable
from .. import context