add DeepLM model

This commit is contained in:
mindmender_duan 2021-07-09 09:53:30 +08:00
parent 4b3c490eff
commit 232215ffe0
15 changed files with 1263 additions and 0 deletions

View File

@ -0,0 +1,27 @@
# DeepLM: Large-scale Nonlinear Least Squares on Deep Learning Frameworks (CVPR 2021)
MindSpore-based version of [**DeepLM**](https://openaccess.thecvf.com/content/CVPR2021/papers/Huang_DeepLM_Large-Scale_Nonlinear_Least_Squares_on_Deep_Learning_Frameworks_Using_CVPR_2021_paper.pdf). A BA (bundle adjustment) demo is shown here.
## Run
Make sure the following is installed (v1.3+ is required):
1. [**MindSpore-gpu**](https://www.mindspore.cn/install).
Download the data by running
```shell
cd data
sh download.sh
```
Then, run the example via
```shell
cd ..
sh example.sh
```
## Data Description
A full set of test data can be downloaded from the [**BAL**](http://grail.cs.washington.edu/projects/bal/) page (Software & Data -> Available Dataset).

View File

@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DeepLM io."""
import numpy as np
def load_bal_from_file(filename, feature_dim, camera_dim, point_dim, double=True):
    """Parse a bundle-adjustment problem stored as whitespace-separated text.

    The file is read in this order: one header line holding the number of
    cameras, points and observations; one line per observation
    (``camera_index point_index x y``); then ``camera_dim`` single-value lines
    per camera; then ``point_dim`` single-value lines per point.

    Args:
        filename: Path of the text file to read.
        feature_dim: Number of columns of the observation array.
        camera_dim: Number of parameters stored per camera.
        point_dim: Number of coordinates stored per point.
        double: Use ``np.float64`` when true, ``np.float32`` otherwise.

    Returns:
        Tuple ``(points, cameras, features, point_indices, camera_indices)``;
        the first three are NumPy arrays, the last two are Python lists of int.
    """
    dtype = np.float64 if double else np.float32
    with open(filename, 'r') as f:
        header = f.readline().strip().split()
        num_cameras, num_points, num_observations = [int(token) for token in header]

        point_indices = []
        cam_indices = []
        t_camera = np.zeros((num_cameras, camera_dim)).astype(dtype)
        t_point = np.zeros((num_points, point_dim)).astype(dtype)
        t_feat = np.zeros((num_observations, feature_dim)).astype(dtype)

        for obs in range(num_observations):
            # Progress indicator, rewritten in place every 1000 observations.
            if obs % 1000 == 0:
                print("\r Load observation {} of {}".format(obs, num_observations), end="", flush=True)
            cam_idx, point_idx, x, y = f.readline().strip().split()
            point_indices.append(int(point_idx))
            cam_indices.append(int(cam_idx))
            t_feat[obs] = [float(x), float(y)]

        # Every camera parameter sits on its own line; only the first token is used.
        for cam in range(num_cameras):
            t_camera[cam] = [float(f.readline().strip().split()[0]) for _ in range(camera_dim)]

        # Same layout for the point coordinates.
        for pt in range(num_points):
            t_point[pt] = [float(f.readline().strip().split()[0]) for _ in range(point_dim)]

    t_point_indices = point_indices
    t_cam_indices = cam_indices
    return t_point, t_camera, t_feat, t_point_indices, t_cam_indices

View File

@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DeepLM loss."""
from .rotation import AngleAxisRotatePoint
from .utils import ComputeBase
class Distort(ComputeBase):
    """Scale normalized image coordinates by focal length and a radial factor."""

    def construct(self, xp, yp, cam):
        """Return the predicted image coordinates.

        Columns 6, 7 and 8 of ``cam`` are read as the focal length and the two
        distortion coefficients. With ``r2 = xp^2 + yp^2`` the result is
        ``-focal * coord * (1 + r2 * (l1 + l2 * r2))`` for each coordinate.
        """
        focal = cam[:, 6]
        coeff_1 = cam[:, 7]
        coeff_2 = cam[:, 8]
        radius_sq = self.mul(xp, xp) + self.mul(yp, yp)
        scale = 1.0 + self.mul(radius_sq, (coeff_1 + self.mul(coeff_2, radius_sq)))
        out_x = self.mul(self.mul(-focal, xp), scale)
        out_y = self.mul(self.mul(-focal, yp), scale)
        return out_x, out_y
class SnavelyReprojectionError(ComputeBase):
    """Residual between projected 3-D points and their observed 2-D features."""

    def __init__(self):
        super().__init__()
        self.angleAxisRotatePoint = AngleAxisRotatePoint()
        self.distort = Distort()

    def construct(self, points_ob, cameras_ob, features):
        """Return an ``(N, 2)`` tensor of x/y reprojection residuals.

        Camera columns 0-2 are passed to the angle-axis rotation and columns
        3-5 are added as a translation; the remaining columns are consumed by
        ``Distort``. Rank-3 inputs are reduced by taking index 0 of axis 1.
        """
        if len(points_ob.shape) == 3:
            points_ob = points_ob[:, 0, :]
            cameras_ob = cameras_ob[:, 0, :]

        # Rotate then translate the points.
        rotated = self.angleAxisRotatePoint(cameras_ob[:, :3], points_ob)
        moved = rotated + cameras_ob[:, 3:6]

        # Divide by the third coordinate.
        xp = self.div(moved[:, 0], moved[:, 2])
        yp = self.div(moved[:, 1], moved[:, 2])
        predicted_x, predicted_y = self.distort(xp, yp, cameras_ob)

        err_x = predicted_x - features[:, 0]
        err_y = predicted_y - features[:, 1]
        err_x = self.reshape(err_x, (err_x.shape[0], 1))
        err_y = self.reshape(err_y, (err_y.shape[0], 1))
        return self.concat((err_x, err_y))

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""rotation func."""
from .utils import ComputeBase
class AngleAxisRotatePoint(ComputeBase):
    """Rotate a batch of 3-D points by per-row angle-axis vectors."""

    def construct(self, angle_axis, points):
        """Return the rotated points as an ``(N, 3)`` tensor.

        Rows whose squared angle is positive use Rodrigues' formula
        ``p*cos(t) + (w x p)*sin(t) + w*(w . p)*(1 - cos(t))`` with
        ``w = angle_axis / t``. All other rows use ``p + angle_axis x p``.
        Both results are computed for every row and blended with a mask.
        """
        a_x = angle_axis[:, 0]
        a_y = angle_axis[:, 1]
        a_z = angle_axis[:, 2]
        p_x = points[:, 0]
        p_y = points[:, 1]
        p_z = points[:, 2]

        theta2 = self.sum(self.mul(angle_axis, angle_axis), axis=1)
        mask = (theta2 > 0)
        # Adding (1 - mask) keeps theta non-zero for zero rotations, so the
        # division below stays finite on rows that end up masked out.
        theta = self.sqrt(theta2 + (1 - mask))
        mask = self.reshape(mask, (mask.shape[0], 1))
        mask = self.concat((mask, mask, mask))

        costheta = self.cos(theta)
        sintheta = self.sin(theta)
        theta_inverse = 1.0 / theta

        # Unit rotation axis.
        w_0 = self.mul(a_x, theta_inverse)
        w_1 = self.mul(a_y, theta_inverse)
        w_2 = self.mul(a_z, theta_inverse)

        # w x p
        cross_0 = self.mul(w_1, p_z) - self.mul(w_2, p_y)
        cross_1 = self.mul(w_2, p_x) - self.mul(w_0, p_z)
        cross_2 = self.mul(w_0, p_y) - self.mul(w_1, p_x)

        # (w . p) * (1 - cos(theta))
        tmp = self.mul(self.mul(w_0, p_x) + self.mul(w_1, p_y) + self.mul(w_2, p_z),
                       (1.0 - costheta))

        full_0 = self.mul(p_x, costheta) + self.mul(cross_0, sintheta) + self.mul(w_0, tmp)
        full_1 = self.mul(p_y, costheta) + self.mul(cross_1, sintheta) + self.mul(w_1, tmp)
        full_2 = self.mul(p_z, costheta) + self.mul(cross_2, sintheta) + self.mul(w_2, tmp)
        full_0 = self.reshape(full_0, (full_0.shape[0], 1))
        full_1 = self.reshape(full_1, (full_1.shape[0], 1))
        full_2 = self.reshape(full_2, (full_2.shape[0], 1))
        res1 = self.concat((full_0, full_1, full_2))

        # Fallback: p + angle_axis x p.
        small_0 = p_x + (self.mul(a_y, p_z) - self.mul(a_z, p_y))
        small_1 = p_y + (self.mul(a_z, p_x) - self.mul(a_x, p_z))
        small_2 = p_z + (self.mul(a_x, p_y) - self.mul(a_y, p_x))
        small_0 = self.reshape(small_0, (small_0.shape[0], 1))
        small_1 = self.reshape(small_1, (small_1.shape[0], 1))
        small_2 = self.reshape(small_2, (small_2.shape[0], 1))
        res2 = self.concat((small_0, small_1, small_2))

        return self.mul(res1, mask) + self.mul(res2, 1 - mask)

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""utils."""
import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore.ops import functional as F
class ComputeBase(Cell):
    """Cell base class that exposes commonly used operators as attributes."""

    def __init__(self):
        super().__init__()
        # Element-wise arithmetic.
        self.mul = F.tensor_mul
        self.div = P.Div()
        self.sqrt = F.sqrt
        # Trigonometry.
        self.cos = F.cos
        self.sin = F.sin
        # Reduction and shape manipulation; concat joins along axis 1.
        self.sum = F.reduce_sum
        self.reshape = F.reshape
        self.concat = P.Concat(axis=1)

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
# Abort on the first failing command so that a failed download is not followed
# by an attempt to decompress a missing or partial archive.
set -e
wget http://grail.cs.washington.edu/projects/bal/data/dubrovnik/problem-308-195089-pre.txt.bz2 --no-check-certificate
bzip2 -d problem-308-195089-pre.txt.bz2

View File

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
# Quoted so that a working directory containing spaces is not word-split when
# the script is run with plain `sh`.
export PYTHONPATH="$PYTHONPATH:$(pwd)"
#global lm solver
# Use the problem file that the download script in data/ actually fetches.
python3 examples/BundleAdjuster/bundle_adjuster.py --balFile ./data/problem-308-195089-pre.txt --device cuda

View File

@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DeepLM BA demo."""
import argparse
from time import time
import mindspore as ms
from mindspore.ops.functional import tensor_mul
from mindspore import Tensor
from mindspore import context
from lm_solver.solver import solve
from ba_core.loss import SnavelyReprojectionError
from ba_core.io import load_bal_from_file
# Floating-point precision used for every tensor handed to the solver.
DOUBLE_ENABLE = False
FLOAT_DTYPE = ms.float64 if DOUBLE_ENABLE else ms.float32
INT_DYTPE = ms.int32

parser = argparse.ArgumentParser(description='Bundle adjuster')
parser.add_argument('--balFile', default='data/problem-1723-156502-pre.txt')
parser.add_argument('--device', default='gpu')
args = parser.parse_args()

filename = args.balFile
device = args.device

# Map the command-line device name to a MindSpore device target.
# A membership test is required here: `device == "gpu" or "cuda"` is always
# true because the non-empty string "cuda" is truthy, which made "cpu"
# unselectable.
if device in ("gpu", "cuda"):
    device_target = "GPU"
elif device == "cpu":
    device_target = "CPU"
else:
    parser.error("unsupported --device {!r}; expected 'gpu', 'cuda' or 'cpu'".format(device))

context.set_context(device_target=device_target)
context.set_context(mode=context.PYNATIVE_MODE)

# Load BA data
points, cameras, features, pt_idx, cam_idx = load_bal_from_file(filename=filename,
                                                                feature_dim=2,
                                                                camera_dim=9,
                                                                point_dim=3,
                                                                double=DOUBLE_ENABLE)

points = Tensor(points, FLOAT_DTYPE)
cameras = Tensor(cameras, FLOAT_DTYPE)
features = Tensor(features, FLOAT_DTYPE)
pt_idx = Tensor(pt_idx, INT_DYTPE)
cam_idx = Tensor(cam_idx, INT_DYTPE)

# Optionally use CUDA, move data to device
# NOTE(review): multiplying by one presumably forces each tensor through a
# device op so that it lives on the selected target; confirm.
points = tensor_mul(points, Tensor(1.0, FLOAT_DTYPE))
cameras = tensor_mul(cameras, Tensor(1.0, FLOAT_DTYPE))
features = tensor_mul(features, Tensor(1.0, FLOAT_DTYPE))
pt_idx = tensor_mul(pt_idx, Tensor(1, INT_DYTPE))
cam_idx = tensor_mul(cam_idx, Tensor(1, INT_DYTPE))

# Time the whole solve.
t1 = time()
solve(variables=[points, cameras], constants=[features], indices=[pt_idx, cam_idx],
      fn=SnavelyReprojectionError(), num_iterations=15, num_success_iterations=15)
t2 = time()
print("Time used %f secs." % (t2 - t1))

View File

@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Jacobian operations."""
import numpy as np
import mindspore as ms
from mindspore import Parameter
from mindspore.ops import functional as F
import mindspore.ops.operations as P
from mindspore import Tensor
from .listvec import list_zero
# Operator instances shared by the helpers in this module.
squeeze_1 = P.Squeeze(1)  # removes axis 1
indexadd_0 = P.IndexAdd(axis=0)  # index-add along axis 0
batchmatmul = P.BatchMatMul()
def jacobi_column_square(indices, jacobians, jacobian_scale):
    """Accumulate squared Jacobian entries into ``jacobian_scale``.

    For each variable, the Jacobian is squared element-wise, summed over
    axis 0 and index-added into ``jacobian_scale[k]`` at the rows given by
    ``indices[k]``. The result of the index-add call is discarded, so the
    code relies on the wrapped Parameter being updated in place.

    Returns:
        The same ``jacobian_scale`` list, whose entries are now Parameters.
    """
    for var_id, jacobi in enumerate(jacobians):
        squared = F.reduce_sum(F.mul(jacobi, jacobi), axis=0)
        # TODO: Remove Parameter init.
        jacobian_scale[var_id] = Parameter(jacobian_scale[var_id])
        indexadd_0(jacobian_scale[var_id], squeeze_1(indices[var_id]), squeeze_1(squared))
    return jacobian_scale
def column_inverse_square(jacobian_scale):
    """Overwrite every entry ``s`` of each tensor with ``1 / (1 + sqrt(s))``.

    The tensors are modified through slice assignment; nothing is returned.
    """
    for scale in jacobian_scale:
        rows, *_ = scale.shape
        scale[:rows] = 1.0 / (1.0 + F.sqrt(scale[:rows]))
def jacobi_squared_column_norm(jacobians, indices, variables):
    """Return the squared-entry accumulation of the Jacobians per variable.

    Starts from zero tensors shaped like the flattened ``variables`` and
    delegates the accumulation to ``jacobi_column_square``.
    """
    return jacobi_column_square(indices, jacobians, list_zero(variables))
def jacobi_normalize(jacobi, indices, variables):
    """Scale each Jacobian by ``1 / (1 + sqrt(squared column sum))``.

    The scale is gathered with ``indices[k]``, broadcast along a new leading
    axis of size ``jacobi[k].shape[0]`` and multiplied into ``jacobi[k]``,
    which is replaced in the list.

    Returns:
        The list of per-variable scale tensors.
    """
    jacobian_scale = jacobi_squared_column_norm(jacobi, indices, variables)
    for var_id, raw_scale in enumerate(jacobian_scale):
        jacobian_scale[var_id] = 1.0 / (1.0 + F.sqrt(raw_scale))
        gathered = jacobian_scale[var_id][indices[var_id]]
        target_shape = (jacobi[var_id].shape[0],) + tuple(gathered.shape)
        gathered = P.BroadcastTo(target_shape)(F.expand_dims(gathered, 0))
        jacobi[var_id] = F.mul(jacobi[var_id], gathered)
    return jacobian_scale
def jacobi_block_jt(jacobians, lmDiagonal, indices, res):
    """Accumulate per-row ``J^T J`` blocks plus a squared diagonal into ``res``.

    For each variable the Jacobian is flattened to three axes, the batched
    product ``J^T J`` is formed and index-added into ``res[k]`` at the rows
    given by ``indices[k]``. Afterwards ``diag(lmDiagonal[k] ** 2)`` is added
    to every block.

    ``res`` is NOT cleared here: the contributions are added to whatever it
    already holds, so the caller must zero it first. (The previous version
    contained a loop that only rebound a local name and therefore never
    cleared anything; it has been removed.)

    Args:
        jacobians: List of Jacobian tensors, one per variable.
        lmDiagonal: List of diagonal tensors, indexed like ``res``.
        indices: List of index tensors, one per variable.
        res: List of block accumulators, updated in place; nothing is returned.
    """
    for varid, jacobian in enumerate(jacobians):
        j_plain = F.reshape(jacobian, (jacobian.shape[0], jacobian.shape[1], -1))
        jt_js = batchmatmul(F.transpose(j_plain, (1, 2, 0)), F.transpose(j_plain, (1, 0, 2)))
        # TODO: Remove Parameter init.
        res[varid] = Parameter(res[varid])
        indexadd_0(res[varid], squeeze_1(indices[varid]), jt_js)
    for varid in range(len(res)):
        diagonal = lmDiagonal[varid]
        diagonal = F.reshape(diagonal, (diagonal.shape[0], -1))
        diagonal_pow2 = diagonal * diagonal
        diagonal_pow2 = F.reshape(diagonal_pow2, (*diagonal_pow2.shape, 1))
        # Build the identity in the diagonal's own dtype instead of a
        # hard-coded float32, so double-precision runs stay consistent.
        identity = Tensor(np.eye(diagonal.shape[1]), diagonal.dtype)
        res[varid] = diagonal_pow2 * identity + res[varid]
def jacobi_left_multiply(jacobians, residuals, variables, indices, res):
    """Compute the per-variable product of the transposed Jacobian and residuals.

    Every entry of ``res`` is first replaced by zeros. For each variable the
    Jacobian is multiplied element-wise by the broadcast residuals, summed
    over axis 1 and index-added into ``res[k]`` at ``indices[k]``.

    Returns:
        The same ``res`` list.
    """
    for slot, previous in enumerate(res):
        res[slot] = F.zeros_like(previous)
    for varid in range(len(variables)):
        jac = jacobians[varid]
        jac = F.reshape(jac, (jac.shape[0], jac.shape[1], -1))
        jac = F.transpose(jac, (1, 0, 2))
        resid = F.reshape(residuals, (residuals.shape[0], residuals.shape[1], 1))
        resid = P.BroadcastTo(jac.shape)(resid)
        contribution = F.reduce_sum(F.mul(jac, resid), 1)
        # TODO: Remove Parameter init.
        res[varid] = Parameter(res[varid])
        indexadd_0(res[varid], squeeze_1(indices[varid]), contribution)
    return res
def jacobi_right_multiply(jacobians, residuals, variables, indices, res):
    """Compute the sum over variables of Jacobian times a per-variable vector.

    ``residuals[k]`` is gathered with ``indices[k]``, reshaped to the
    Jacobian's shape with a leading axis of one, broadcast, multiplied
    element-wise and reduced over the flattened trailing axes.

    Returns:
        A new tensor shaped like ``res``; the incoming ``res`` only supplies
        the shape through ``zeros_like``.
    """
    res = F.zeros_like(res)
    for varid in range(len(variables)):
        jacobian = jacobians[varid]
        gathered = residuals[varid][indices[varid]]
        full_shape = tuple(jacobian.shape)
        lead_shape = (1,) + full_shape[1:]
        gathered = P.BroadcastTo(full_shape)(F.reshape(gathered, lead_shape))
        weighted = F.reshape(F.tensor_mul(gathered, jacobian),
                             (jacobian.shape[0], jacobian.shape[1], -1))
        res += F.transpose(F.reduce_sum(weighted, axis=2), (1, 0))
    return res
def jacobi_jt_jd(jacobians, diagonal, p, variables, indices, z, res):
    """Apply the Jacobian, then its transpose, to ``p`` and add ``diagonal^2 * p``.

    Each ``z[k]`` is overwritten with the right-multiply result and fed to the
    left-multiply, which writes into ``res``. Finally ``p[j] * diagonal[j]^2``
    is added to every ``res[j]``. Nothing is returned; ``z`` and ``res`` are
    updated in place.
    """
    for k, z_k in enumerate(z):
        z[k] = jacobi_right_multiply(jacobians, p, variables, indices, z_k)
        res = jacobi_left_multiply(jacobians, z[k], variables, indices, res)
    for j in range(len(res)):
        damping = F.mul(diagonal[j], diagonal[j])
        res[j] += F.mul(p[j], damping)

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""list-vector operations."""
import mindspore.ops.operations as P
import mindspore.ops as ops
from mindspore.ops import functional as F
# Operator instances shared by the list helpers below.
f_arg_max_with_value = P.ArgMaxWithValue()  # called as f(x)[1] below to take the value
f_zeros = P.Zeros()
f_batchmatmul = P.BatchMatMul()
f_matrix_inverse = P.MatrixInverse()
def list_norm(listvec):
    """Return the square root of the summed squares of all tensors in the list."""
    total = 0
    for tensor in listvec:
        squared = F.mul(tensor, tensor)
        total += F.reduce_sum(squared)
    return F.sqrt(total)
def list_max_norm(listvec):
    """Return the largest absolute entry found across the list.

    Entries with an empty shape, or without a ``shape`` attribute, are
    skipped. The maximum is taken by applying the arg-max-with-value operator
    twice. Returns ``0`` when nothing exceeds it.
    """
    largest = 0
    for vec in listvec:
        try:
            if not vec.shape:
                continue
            first_pass = f_arg_max_with_value(F.absolute(vec))[1]
            peak = f_arg_max_with_value(first_pass)[1]
            if peak > largest:
                largest = peak
        except AttributeError:
            continue
    return largest
def list_clamp(listvec, minVal, maxVal):
    """Replace every tensor in the list by its value clipped to [minVal, maxVal]."""
    for pos, tensor in enumerate(listvec):
        listvec[pos] = ops.clip_by_value(tensor, minVal, maxVal)
def list_invert(listvec):
    """Replace every tensor in the list by the output of the matrix-inverse operator."""
    for pos, matrices in enumerate(listvec):
        listvec[pos] = f_matrix_inverse(matrices)
def list_zero(variables):
    """Return one zero tensor per variable, flattened to ``(shape[0], -1)``.

    Each zero tensor keeps the dtype of its variable.
    """
    return [f_zeros(F.reshape(v, (v.shape[0], -1)).shape, v.dtype) for v in variables]
def list_right_multiply(p, r, res):
    """Store the batched product ``p[k] @ r[k]`` in ``res[k]`` for every ``k``.

    ``r[k]`` is reshaped to a batch of column vectors for the product and the
    result is reshaped back to the shape of ``r[k]``.
    """
    for varid, vec in enumerate(r):
        column = F.reshape(vec, (vec.shape[0], -1, 1))
        product = f_batchmatmul(p[varid], column)
        res[varid] = F.reshape(product, (vec.shape))
def list_dot(a, b):
    """Return the sum over all entries of the element-wise products ``a[k] * b[k]``."""
    total = 0
    for pos, left in enumerate(a):
        total += F.reduce_sum(F.mul(left, b[pos]))
    return total

View File

@ -0,0 +1,681 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""DeepLM solver."""
from time import time
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.composite import GradOperation
from mindspore.nn import Cell
from mindspore.common.api import _pynative_exec
from . import jacobian as jb
from . import listvec as lv
# Shared zero-tensor factory, called as _zeros(shape, dtype) below.
_zeros = P.Zeros()
class Strategy:
    """Radius update rule for the LM solver."""

    def __init__(self):
        self.decrease_factor = 2
        self.radius = 1e4

    def reject(self):
        """Shrink the radius by the current factor, then double that factor."""
        self.radius = self.radius / self.decrease_factor
        self.decrease_factor = self.decrease_factor * 2

    def accept(self, quality):
        """Rescale the radius from the step quality and reset the factor.

        The radius is divided by ``max(1/3, 1 - (2 * quality - 1) ** 3)``.
        """
        cubic = (2.0 * quality - 1.0) ** 3
        divisor = max(1.0 / 3.0, 1.0 - cubic)
        self.radius /= divisor
        self.decrease_factor = 2.0
class StepEvaluator:
    """Tracks cost history to judge candidate steps."""

    def __init__(self, reference_cost):
        # All four cost trackers start from the same value.
        self.reference_cost = reference_cost
        self.minimum_cost = reference_cost
        self.current_cost = reference_cost
        self.candidate_cost = reference_cost
        self.accumulated_reference_model_cost_change = 0
        self.accumulated_candidate_model_cost_change = 0
        self.num_consecutive_nonmonotonic_steps = 0
        self.max_consecutive_nonmonotonic_steps = 0

    def accept(self, cost, model_cost_change):
        """Record an accepted step with the given cost and model cost change."""
        self.current_cost = cost
        self.accumulated_candidate_model_cost_change += model_cost_change
        self.accumulated_reference_model_cost_change += model_cost_change

        improved = self.current_cost < self.minimum_cost
        if improved:
            # New best cost: restart the non-monotonic window.
            self.minimum_cost = self.current_cost
            self.num_consecutive_nonmonotonic_steps = 0
            self.candidate_cost = self.current_cost
            self.accumulated_candidate_model_cost_change = 0
        else:
            self.num_consecutive_nonmonotonic_steps += 1
            if self.current_cost > self.candidate_cost:
                self.candidate_cost = self.current_cost
                self.accumulated_candidate_model_cost_change = 0

        # Once the window is exhausted, promote the candidate to reference.
        window_full = (self.num_consecutive_nonmonotonic_steps ==
                       self.max_consecutive_nonmonotonic_steps)
        if window_full:
            self.reference_cost = self.candidate_cost
            self.accumulated_reference_model_cost_change = \
                self.accumulated_candidate_model_cost_change

    def step_quality(self, cost, model_cost_change):
        """Return the larger of the current and the historical relative decrease."""
        current_ratio = (self.current_cost - cost) / model_cost_change
        accumulated = self.accumulated_reference_model_cost_change + model_cost_change
        historical_ratio = (self.reference_cost - cost) / accumulated
        return max(current_ratio, historical_ratio)
class Summary:
    """Mutable record of solver progress."""

    def __init__(self):
        # Gradient statistics.
        self.gradient_norm = 0
        self.gradient_max_norm = 0
        # Cost is unknown until set by the solver.
        self.cost = None
        # Iteration bookkeeping.
        self.lm_iteration = 0
        self.num_consecutive_invalid_steps = 0
        self.step_is_valid = True
        self.relative_decrease = 0
        # 0 = no-converge, 1 = success, 2 = fail
        self.linear_termination_type = 0
        self.lm_termination_type = 0
class LMOption:
    """Tunable settings for the LM solver."""

    def __init__(self):
        # Bounds for the diagonal.
        self.min_diagonal = 1e-6
        self.max_diagonal = 1e32
        # Linear-solver settings.
        self.eta = 1e-3
        self.residual_reset_period = 10
        self.max_linear_iterations = 150
        # Outer-loop limits.
        self.max_num_iterations = 16
        self.max_success_iterations = 16
        self.max_invalid_step = 5
        # Tolerances.
        self.min_relative_decrease = 1e-3
        self.parameter_tolerance = 1e-8
        self.function_tolerance = 1e-6
        # Initial radius and the number of residual rows handled per chunk.
        self.radius = 1e4
        self.span = 4000000
class FunctionBlock:
    """Bundle of one residual function with its variables, constants and indices."""

    def __init__(self, variables, constants, indices, fn=None):
        self.fn = fn
        self.indices = indices
        self.constants = constants
        self.variables = variables
class ResidualFunc(Cell):
    """Cell that returns the sum of one residual column of ``fn``."""

    def __init__(self, fn, constant_para, res_index=0):
        super().__init__()
        self.fn = fn
        self.constant_para = constant_para
        self.res_index = res_index
        self.reduce_sum = F.reduce_sum
        self.mul = F.mul
        self.reshape = F.reshape

    def construct(self, *x):
        """Call ``fn(*x, *constant_para)`` and sum column ``res_index`` of the result."""
        raw = self.fn(*x, *self.constant_para)
        flat = self.reshape(raw, (raw.shape[0], -1))
        column = flat[:, self.res_index]
        return self.reduce_sum(column)
class LMSolver:
"""LM solver"""
def __init__(self, functions, verbose=True, option=None):
    """Collect the distinct variables of ``functions`` and set up solver state.

    Args:
        functions: Iterable of objects exposing ``variables``, ``constants``,
            ``indices`` and ``fn``.
        verbose: Stored on the instance as ``self.verbose``.
        option: Solver settings. A fresh ``LMOption`` is created when omitted.
    """
    # Create the default here rather than in the signature: a default
    # expression is evaluated once, so every solver would otherwise share
    # (and mutate) one LMOption instance.
    if option is None:
        option = LMOption()

    self.start_time = time()
    self.verbose = verbose
    self.functions = functions

    # Deduplicate variables across functions, remembering each one's slot.
    self.variable_dict = {}
    self.variables = []
    self.vrange_func = []
    for func in functions:
        for v in func.variables:
            if v in self.variable_dict:
                continue
            self.variable_dict[v] = len(self.variables)
            self.variables.append(v)
        self.vrange_func.append(range(len(func.variables)))
    self.vranges = range(len(self.variables))
    self.option = option

    # One list of ResidualFunc cells per chunk of at most `span` residual rows.
    # NOTE(review): res_index iterates over self.vranges (the number of
    # variables), not over the residual dimension; confirm this is intended
    # when the two differ.
    self.vids_func = []
    self.res_funcs = []
    for func in functions:
        self.vids_func.append([self.variable_dict[v] for v in func.variables])
        residual_num = func.indices[0].shape[0]
        for start in range(0, residual_num, self.option.span):
            end = min(start + self.option.span, residual_num)
            constants_para = [constant[start:end] for constant in func.constants]
            self.res_funcs.append(
                [ResidualFunc(func.fn, constants_para, index) for index in self.vranges])

    # some mindspore OPs which may be used below
    self.zeros_like = F.zeros_like
    self.reduce_sum = F.reduce_sum
    self.grad_op = GradOperation(get_all=True)

    self.summary = Summary()
    self.strategy = Strategy()
    self.strategy.radius = option.radius
    self.evaluator = StepEvaluator(0)

    self.q_tolerance = 0.0
    self.r_tolerance = 0.0
    self.x_norm = -1
    self.model_cost_change = 0.0
    self.delta = None

    # Work buffers, filled in later.
    self.q = None
    self.z = None
    self.p = None
    self.r = None
    self.xref = None
    self.bref = None
    self.model_residuals = None
    self.gradients = None
    self.jacobians = None
    self.preconditioner = None
def memory_tensor(self, b):
    """Return 8 times the product of the entries of ``b.shape``."""
    size = 8
    for extent in b.shape:
        size *= extent
    return size
def memory_list(self, b):
    """Sum ``memory_tensor`` over every Tensor in a possibly nested list.

    Items that are neither lists nor Tensors contribute nothing.
    """
    total = 0
    for item in b:
        if isinstance(item, list):
            total += self.memory_list(item)
            continue
        if isinstance(item, Tensor):
            total += self.memory_tensor(item)
    return total
def memory(self):
    """Sum the size estimate of every list and Tensor attribute of the solver."""
    total = 0
    for value in self.__dict__.values():
        if isinstance(value, list):
            total += self.memory_list(value)
        if isinstance(value, Tensor):
            total += self.memory_tensor(value)
    return total
def timing(self):
    """Call ``_pynative_exec.sync()`` and then return the current ``time()``."""
    _pynative_exec.sync()
    now = time()
    return now
def initialize_variables(self):
    """Allocate zero-filled work buffers matching the solver's variables.

    Each vector buffer is an independent list from ``lv.list_zero``. The
    preconditioner holds one ``(shape[0], d, d)`` block per variable, where
    ``d`` is the product of the variable's trailing dimensions.
    """
    self.bref = lv.list_zero(self.variables)
    self.xref = lv.list_zero(self.variables)
    self.r = lv.list_zero(self.variables)
    self.z = lv.list_zero(self.variables)
    self.p = lv.list_zero(self.variables)
    self.q = lv.list_zero(self.variables)
    self.gradients = lv.list_zero(self.variables)
    self.jacobian_scale = lv.list_zero(self.variables)
    self.diagonal = lv.list_zero(self.variables)

    blocks = []
    for v in self.variables:
        flat_dim = 1
        for extent in v.shape[1:]:
            flat_dim *= extent
        blocks.append(_zeros((v.shape[0], flat_dim, flat_dim), v.dtype))
    self.preconditioner = blocks
def evaluate_cost(self, candidate):
    """Return the total of ``0.5 * residual^2`` over all functions.

    ``candidate`` is indexed with ``self.vids_func`` to pick each function's
    variables. Residual rows are processed in chunks of ``self.option.span``.
    """
    span = self.option.span
    cost = 0
    for func_id, func in enumerate(self.functions):
        indices = func.indices
        constants = func.constants
        variables = [candidate[j] for j in self.vids_func[func_id]]
        residual_num = indices[0].shape[0]
        for start in range(0, residual_num, span):
            end = min(start + span, residual_num)
            gathered = [var[indices[k][start:end]] for k, var in enumerate(variables)]
            const_chunk = [constant[start:end] for constant in constants]
            residuals = func.fn(*gathered, *const_chunk)
            cost += self.reduce_sum(0.5 * residuals * residuals)
    return cost
def gradient(self, res_funcs_index, residual_index, varIndexed):
    """Return the gradients of one stored ResidualFunc w.r.t. ``varIndexed``.

    The cell is looked up as ``self.res_funcs[res_funcs_index][residual_index]``
    and wrapped with ``self.grad_op`` before being called.
    """
    grad_net = self.grad_op(self.res_funcs[res_funcs_index][residual_index])
    return grad_net(*varIndexed)
def evaluate(self, is_first_time=False):
    """Evaluate residuals, cost, Jacobians and gradients at the current variables.

    Resets the gradient and Jacobian-scale buffers, then for every function:
    allocates its Jacobian/residual storage when ``is_first_time`` is true,
    evaluates the residuals chunk by chunk, fills the Jacobians from
    ``self.gradient`` and accumulates ``self.cost``. Finally the Jacobians are
    normalized and the gradient norms are stored on ``self.summary``.

    NOTE(review): the source this was taken from had lost its indentation;
    the nesting below was reconstructed and should be confirmed.
    """
    # Clear per-variable accumulators.
    for i in self.vranges:
        self.variables[i].grad = None
        self.gradients[i] = F.zeros_like(self.gradients[i])
        self.jacobian_scale[i] = F.zeros_like(self.jacobian_scale[i])
    self.cost = 0
    for func_id in range(len(self.functions)):
        indices = self.functions[func_id].indices
        variables = self.variables
        constants = self.functions[func_id].constants
        fn = self.functions[func_id].fn
        vrange = self.vrange_func[func_id]
        residual_num = indices[0].shape[0]
        if is_first_time:
            # Probe the function with a single row to learn the residual width.
            var_temp = [variables[i][indices[i][:1]] for i in vrange]
            constant_temp = [constants[i][:1] \
                             for i in range(len(constants))]
            # t0 = self.timing()
            residuals_temp = fn(*var_temp, *constant_temp)
            # t1 = self.timing()
            residuals_temp = F.reshape(residuals_temp, (residuals_temp.shape[0], -1))
            residual_dim = residuals_temp.shape[1]
            # Allocate one Jacobian per variable plus residual storage.
            self.functions[func_id].jacobians = []
            max_dim = 0
            for i in vrange:
                v = variables[i]
                v = F.reshape(v, (v.shape[0], -1))
                jacobian = _zeros((residual_dim, *indices[i].shape, *v.shape[1:]), v.dtype)
                if v.shape[1] > max_dim:
                    max_dim = v.shape[1]
                self.functions[func_id].jacobians.append(jacobian)
            self.functions[func_id].buffer = _zeros(
                (indices[0].shape[0], max_dim), variables[0].dtype)
            self.functions[func_id].residuals = _zeros(
                (residual_num, residual_dim), variables[0].dtype)
        span = self.option.span
        # Process the residual rows in chunks of at most `span`.
        # NOTE(review): _index restarts at 0 for every function, while
        # self.res_funcs is one flat list appended across functions in
        # __init__; confirm the lookup is right for more than one function.
        for _index, dim in enumerate(range(0, residual_num, span)):
            start = dim
            end = start + span
            if end > residual_num:
                end = residual_num
            varIndexed = [variables[i][indices[i][start:end]] for i in vrange]
            constants_para = [constants[i][start:end] for i in range(len(constants))]
            residuals = fn(*varIndexed, *constants_para)
            residuals = F.reshape(residuals, (residuals.shape[0], -1))
            cost = self.reduce_sum(0.5 * residuals * residuals)
            # collect gradients and jacobians
            # NOTE(review): `i` selects the residual column (first Jacobian
            # axis) but iterates over vrange, i.e. the number of variables;
            # confirm for functions where those two counts differ.
            for i in vrange:
                grads = self.gradient(_index, i, varIndexed)
                for j in vrange:
                    grad = grads[j]
                    if grad is None:
                        self.functions[func_id].jacobians[j][i, start:end] = self.zeros_like(
                            self.functions[func_id].jacobians[j][i, start:end]
                        )
                        continue
                    grad = F.reshape(grad, (grad.shape[0], grad.shape[1], -1))
                    self.functions[func_id].jacobians[j][i, start:end] = grad
            # Multiplying by 1.0 yields a fresh tensor to store.
            self.functions[func_id].residuals[start:end] = F.tensor_mul(residuals, 1.0)
            self.cost += cost
        residuals = self.functions[func_id].residuals
        gradients = [self.gradients[i] for i in self.vids_func[func_id]]
        jacobians = self.functions[func_id].jacobians
        jb.jacobi_left_multiply(jacobians, residuals, self.variables, indices, gradients)
        self.jacobian_scale = [self.jacobian_scale[i] \
                               for i in self.vids_func[func_id]]
        self.jacobian_scale = jb.jacobi_column_square(indices, jacobians, self.jacobian_scale)
    jb.column_inverse_square(self.jacobian_scale)
    # Normalize the stored Jacobians; this replaces self.jacobian_scale.
    for func_id in range(len(self.functions)):
        indices = self.functions[func_id].indices
        jacobians = self.functions[func_id].jacobians
        self.jacobian_scale = jb.jacobi_normalize(jacobians, indices, self.variables)
    # `gradients` is the local list built for the last function in the loop above.
    self.summary.gradient_norm = lv.list_norm(gradients)
    self.summary.gradient_max_norm = lv.list_max_norm(gradients)
def preconditioner_initialize(self):
    """Replace every preconditioner block with a zero tensor of the same shape."""
    for pos, block in enumerate(self.preconditioner):
        self.preconditioner[pos] = self.zeros_like(block)
def linear_solve(self, lmDiagonal):
    """Run a preconditioned conjugate-gradient loop and return the solution list.

    The right-hand side ``bref`` is produced by ``jb.jacobi_left_multiply``
    over every function block, and the preconditioner is rebuilt with
    ``jb.jacobi_block_jt`` followed by ``lv.list_invert``. The loop then
    iterates on ``self.xref`` (reset to zeros first) until one of the
    stopping tests fires.

    ``self.summary.linear_termination_type`` ends up as:
        0 - the loop was left through the iteration cap
            (``num_iterations > option.max_linear_iterations``);
        1 - ``norm_b == 0``, ``zeta < self.q_tolerance`` or ``norm_r < tol_r``;
        2 - the curvature term ``pq`` was not strictly positive.

    ``self.summary.lm_iteration`` is reset to 0 and
    ``self.summary.num_iterations`` holds the number of loop passes.

    Args:
        lmDiagonal: list of damping terms indexed by variable id (it is
            sliced with ``self.vids_func`` for the preconditioner).
            Presumably one tensor per variable - confirm against the caller.

    Returns:
        The list stored in ``self.xref`` (same object), updated in place.
    """
    xref = self.xref
    bref = self.bref
    r = self.r
    z = self.z
    p = self.p

    # Right-hand side: push each block's residuals through its Jacobian.
    for func in self.functions:
        jacobians = func.jacobians
        residuals = func.residuals
        indices = func.indices
        bref = jb.jacobi_left_multiply(jacobians, residuals, self.variables, indices, bref)

    # Preconditioner: accumulate per block, then invert.
    self.preconditioner_initialize()
    for func_id, func in enumerate(self.functions):
        jacobians = func.jacobians
        indices = func.indices
        lm_diagonal_block = [lmDiagonal[j] for j in self.vids_func[func_id]]
        jb.jacobi_block_jt(jacobians, lm_diagonal_block, indices, self.preconditioner)
    lv.list_invert(self.preconditioner)
    # NOTE(review): from here on `jacobians` / `indices` are those of the LAST
    # function block only; this looks correct only for a single block - confirm.

    self.summary.linear_termination_type = 0
    self.summary.lm_iteration = 0
    norm_b = lv.list_norm(bref)
    for i in self.vranges:
        xref[i] = self.zeros_like(xref[i])
        # multiply by 1 to get a separate tensor instead of aliasing bref[i]
        r[i] = F.mul(bref[i], 1)
    if norm_b == 0:
        self.summary.linear_termination_type = 1
        return xref

    tol_r = self.r_tolerance * norm_b
    rho = 1.0
    q_0 = -0.0
    self.summary.num_iterations = 0
    while True:
        self.summary.num_iterations += 1
        lv.list_right_multiply(self.preconditioner, r, self.z)
        last_rho = rho
        rho = lv.list_dot(r, z)
        if self.summary.num_iterations == 1:
            for i in self.vranges:
                p[i] = F.mul(z[i], 1)
        else:
            beta = rho / last_rho
            for i in self.vranges:
                p[i] *= beta
                p[i] += z[i]
        # presumably writes the system-matrix product of `p` into self.q - confirm in jb
        jb.jacobi_jt_jd(jacobians, lmDiagonal, p, self.variables, indices, self.model_residuals,
                        self.q)
        pq = lv.list_dot(p, self.q)
        # Non-positive curvature: stop before dividing. `pq == 0` is included so
        # that `alpha` can never become inf/NaN and poison xref.
        if pq <= 0:
            self.summary.linear_termination_type = 2
            break
        alpha = rho / pq
        for i in self.vranges:
            xref[i] += alpha * p[i]
        # this is to avoid numerical issues: recompute r every reset steps
        if self.summary.num_iterations % \
                self.option.residual_reset_period == 0:
            jb.jacobi_jt_jd(
                jacobians, lmDiagonal, xref, self.variables, indices, self.model_residuals,
                self.q)
            # Single in-place refresh; keeps `r` aliased to self.r.
            r[:] = [bref[i] - self.q[i] for i in range(len(r))]
        else:
            for i in self.vranges:
                r[i] -= F.mul(alpha, self.q[i])
        q_1 = -1.0 * lv.list_dot(xref, [bref[i] + r[i] for i in range(len(r))])
        zeta = self.summary.num_iterations * (q_1 - q_0) / q_1
        if zeta < self.q_tolerance:
            self.summary.linear_termination_type = 1
            break
        q_0 = q_1
        norm_r = lv.list_norm(r)
        if norm_r < tol_r:
            self.summary.linear_termination_type = 1
            break
        if self.summary.num_iterations > self.option.max_linear_iterations:
            break
    return xref
def compute_trust_region_step(self):
    """Compute one candidate step and decide whether it is usable.

    Sequence, as performed below: zero ``self.diagonal``, let
    ``jb.jacobi_column_square`` fill it per function block, clamp it, derive
    the damping list from it and ``self.strategy.radius``, call
    ``self.linear_solve``, negate the result, refresh ``self.model_residuals``
    and accumulate ``self.model_cost_change``. The step is flagged valid when
    that change is positive; only then ``self.delta`` is written and the
    invalid-step counter is cleared.
    """
    for slot, entry in enumerate(self.diagonal):
        self.diagonal[slot] = self.zeros_like(entry)

    for func_id, func in enumerate(self.functions):
        block_indices = func.indices
        block_jacobians = func.jacobians
        block_diagonal = [self.diagonal[var_id] for var_id in self.vids_func[func_id]]
        jb.jacobi_column_square(block_indices, block_jacobians, block_diagonal)

    lv.list_clamp(self.diagonal, self.option.min_diagonal, self.option.max_diagonal)
    damping = [F.sqrt(entry / self.strategy.radius) for entry in self.diagonal]

    self.q_tolerance = self.option.eta
    self.r_tolerance = -1.0
    step = self.linear_solve(damping)
    # flip the sign in place; `step` is the list handed back by linear_solve
    for var_id in self.vranges:
        step[var_id] = -step[var_id]

    for func_id, func in enumerate(self.functions):
        block_indices = func.indices
        block_jacobians = func.jacobians
        self.model_residuals[func_id] = jb.jacobi_right_multiply(
            block_jacobians, step, self.variables, block_indices,
            self.model_residuals[func_id])

    self.model_cost_change = 0
    for func_id, func in enumerate(self.functions):
        predicted = self.model_residuals[func_id]
        self.model_cost_change += -self.reduce_sum(
            predicted * (func.residuals + predicted * 0.5))

    self.summary.step_is_valid = (self.model_cost_change > 0.0)
    if not self.summary.step_is_valid:
        return
    self.delta = [component * self.jacobian_scale[k] for k, component in enumerate(step)]
    self.summary.num_consecutive_invalid_steps = 0
def solve(self):
    """Run the LM outer loop until a stopping condition fires.

    After the initial ``evaluate``, each pass computes a trust-region step,
    evaluates the cost at the candidate point and either accepts the step
    (variables updated, ``strategy.accept``) or rejects it
    (``strategy.reject``).

    The loop counter is incremented before it is compared with
    ``option.max_num_iterations``, so at most ``max_num_iterations - 1``
    steps are attempted; a cap of zero or less attempts none.

    ``self.summary.lm_termination_type`` is set to:
        2 - more than ``option.max_invalid_step`` consecutive invalid steps;
        1 - step-size tolerance, function tolerance, trust-region radius
            below 1e-32 or ``gradient_max_norm`` below 1e-10.
    It is left untouched when the loop ends through the iteration cap or
    the successful-iteration cap.

    Returns:
        None. Results are left on the instance (``self.variables``,
        ``self.cost``, ``self.summary``).
    """
    self.initialize_variables()
    self.evaluate(True)
    if self.option.max_success_iterations == 0:
        return
    if self.verbose:
        print('\nInitial cost = {}, Memory = {} G'.format(
            self.cost, self.memory() / 1024.0 / 1024.0 / 1024.0))
    # multiply by 1.0 to obtain separate tensors rather than aliases
    self.model_residuals = [F.mul(func.residuals, 1.0) for func in self.functions]
    if self.summary.lm_iteration == 0:
        self.evaluator = StepEvaluator(self.cost)
    outer_iterations = 0
    success_iterations = 0
    self.debug = False
    while True:
        outer_iterations += 1
        # `>=` rather than `==`: identical for caps >= 1, and a cap <= 0 now
        # stops immediately instead of never matching.
        if outer_iterations >= self.option.max_num_iterations:
            break
        self.compute_trust_region_step()
        if not self.summary.step_is_valid:
            self.summary.num_consecutive_invalid_steps += 1
            if self.summary.num_consecutive_invalid_steps \
                    > self.option.max_invalid_step:
                self.summary.lm_termination_type = 2
                return
            self.strategy.reject()
            continue
        candidate_x = [self.variables[i] + F.reshape(
            self.delta[i], self.variables[i].shape) for i in self.vranges]
        cost = self.evaluate_cost(candidate_x)
        # parameter tolerance check
        step_size_tolerance = self.option.parameter_tolerance \
            * (self.x_norm + self.option.parameter_tolerance)
        step_norm = lv.list_norm(self.delta)
        if step_norm < step_size_tolerance:
            self.summary.lm_termination_type = 1
            return
        # function tolerance check
        cost_change = self.cost - cost
        absolute_function_tolerance = \
            self.option.function_tolerance * self.cost
        if abs(cost_change.asnumpy()) < absolute_function_tolerance:
            self.summary.lm_termination_type = 1
            return
        # evaluate relative decrease
        self.summary.relative_decrease = self.evaluator.step_quality(
            cost, self.model_cost_change)
        if self.summary.relative_decrease \
                > self.option.min_relative_decrease:
            for i in self.vranges:
                self.variables[i] += F.reshape(self.delta[i], self.variables[i].shape)
            self.x_norm = lv.list_norm(self.variables)
            self.strategy.accept(self.summary.relative_decrease)
            self.evaluator.accept(cost, self.model_cost_change)
            if self.verbose:
                current_time = self.timing()
                print('iter = {}, cost = {}, radius = {}, CGstep = {}, time = {}'.format(
                    outer_iterations,
                    cost,
                    self.strategy.radius,
                    self.summary.num_iterations,
                    current_time - self.start_time))
            success_iterations += 1
            if success_iterations >= self.option.max_success_iterations:
                # candidate cost becomes the final cost; no re-evaluation on exit
                self.cost = cost
                break
            self.evaluate(True)
        else:
            self.strategy.reject()
            if self.verbose:
                print('iter = %d (rejected)' % outer_iterations)
        if self.strategy.radius < 1e-32 or \
                self.summary.gradient_max_norm < 1e-10:
            self.summary.lm_termination_type = 1
            return
def solve(variables, constants, indices, fn, num_iterations=15,
          num_success_iterations=15, max_linear_iterations=150, verbose=True):
    """Set up a single-function LM problem, run it and return the final cost.

    Args:
        variables: forwarded to ``FunctionBlock`` as ``variables``.
        constants: forwarded to ``FunctionBlock`` as ``constants``.
        indices: sequence of index arrays, forwarded to ``FunctionBlock``.
            Entries with a 1-D shape are replaced IN PLACE by their
            ``(-1, 1)`` reshape after the solver has been constructed.
        fn: forwarded to ``FunctionBlock`` as ``fn``.
        num_iterations: stored in ``option.max_num_iterations``.
        num_success_iterations: stored in ``option.max_success_iterations``.
        max_linear_iterations: stored in ``option.max_linear_iterations``.
        verbose: forwarded to ``LMSolver``.

    Returns:
        ``solver.cost`` after ``solver.solve()``, or ``None`` when ``indices``
        is empty or any index array has zero rows.
    """
    if not indices:
        return None
    if any(index.shape[0] == 0 for index in indices):
        return None

    func = FunctionBlock(variables=variables, constants=constants, indices=indices, fn=fn)
    option = LMOption()
    option.max_linear_iterations = max_linear_iterations
    solver = LMSolver(functions=[func], verbose=verbose, option=option)
    solver.option.max_num_iterations = num_iterations
    solver.option.max_success_iterations = num_success_iterations

    # Promote 1-D index vectors to a single column; written back into the
    # caller's sequence (the same object already handed to FunctionBlock).
    for position, index in enumerate(indices):
        if len(index.shape) == 1:
            indices[position] = F.reshape(index, (-1, 1))

    solver.solve()
    return solver.cost