feat: 支持授权角色和用户访问主机

This commit is contained in:
zze326 2023-09-27 14:24:17 +08:00
parent cc2deb3305
commit 2d24a3ef85
21 changed files with 195 additions and 23 deletions

1
.gitignore vendored
View File

@ -18,3 +18,4 @@ temp.yaml
bin
**/config/config-prod.yaml
dist
host-sessions

View File

@ -12,6 +12,7 @@ import (
type IHostGroupV1 interface {
GetLst(ctx context.Context, req *v1.GetLstReq) (res *v1.GetLstRes, err error)
GetPartialList(ctx context.Context, req *v1.GetPartialListReq) (res *v1.GetPartialListRes, err error)
Add(ctx context.Context, req *v1.AddReq) (res *v1.AddRes, err error)
Upt(ctx context.Context, req *v1.UptReq) (res *v1.UptRes, err error)
Del(ctx context.Context, req *v1.DelReq) (res *v1.DelRes, err error)

View File

@ -15,6 +15,14 @@ type GetLstRes struct {
List []*entity.HostGroup `json:"list"`
}
type GetPartialListReq struct {
g.Meta `method:"get" path:"/host-group/partial-list" summary:"获取主机组列表(部分字段)" tags:"主机组"`
}
type GetPartialListRes struct {
List []*mid.HostGroupPartial `json:"list"`
}
type AddReq struct {
g.Meta `method:"post" path:"/host-group" summary:"添加主机组" tags:"主机组"`
*mid.HostGroup

View File

@ -11,6 +11,7 @@ import (
)
type IUserV1 interface {
GetLst(ctx context.Context, req *v1.GetLstReq) (res *v1.GetLstRes, err error)
GetPageLst(ctx context.Context, req *v1.GetPageLstReq) (res *v1.GetPageLstRes, err error)
Add(ctx context.Context, req *v1.AddReq) (res *v1.AddRes, err error)
Upt(ctx context.Context, req *v1.UptReq) (res *v1.UptRes, err error)
@ -18,5 +19,3 @@ type IUserV1 interface {
UptEnabled(ctx context.Context, req *v1.UptEnabledReq) (res *v1.UptEnabledRes, err error)
Del(ctx context.Context, req *v1.DelReq) (res *v1.DelRes, err error)
}

View File

@ -7,6 +7,14 @@ import (
"github.com/gogf/gf/v2/frame/g"
)
type GetLstReq struct {
g.Meta `method:"get" path:"/user/list" summary:"获取用户列表" tags:"用户"`
}
type GetLstRes struct {
List []*entity.User `json:"list"`
}
type GetPageLstReq struct {
g.Meta `method:"get" path:"/user/page-list" summary:"分页获取用户列表" tags:"用户"`
*api.PageLstReq

View File

@ -2,12 +2,30 @@ package host
import (
"context"
"devops-super/internal/model/do"
"devops-super/internal/service"
"github.com/gogf/gf/v2/errors/gerror"
"devops-super/api/host/v1"
)
func (c *ControllerV1) TestSsh(ctx context.Context, req *v1.TestSshReq) (res *v1.TestSshRes, err error) {
err = service.Host().TestSSH(ctx, req.Id)
eHost, err := service.Host().Get(ctx, &do.Host{Id: req.Id})
if err != nil {
return nil, err
}
if eHost == nil {
return nil, gerror.New("主机不存在")
}
can, err := service.Host().CanAccess(ctx, eHost)
if err != nil {
return nil, err
}
if !can {
return nil, gerror.New("未授权")
}
err = service.Host().TestSSH(ctx, eHost)
return
}

View File

@ -17,6 +17,16 @@ func (c *ControllerV1) WsSftpFileManager(ctx context.Context, req *v1.WsSftpFile
if eHost == nil {
return nil, gerror.New("主机不存在")
}
can, err := service.Host().CanAccess(ctx, eHost)
if err != nil {
return nil, err
}
if !can {
return nil, gerror.New("未授权")
}
err = service.Host().WsSftpFileManager(ctx, eHost)
return
}

View File

@ -17,6 +17,16 @@ func (c *ControllerV1) WsTerminal(ctx context.Context, req *v1.WsTerminalReq) (r
if eHost == nil {
return nil, gerror.New("主机不存在")
}
can, err := service.Host().CanAccess(ctx, eHost)
if err != nil {
return nil, err
}
if !can {
return nil, gerror.New("未授权")
}
err = service.Host().WsTerminal(ctx, eHost)
return
}

View File

@ -0,0 +1,35 @@
package host_group
import (
"context"
"devops-super/internal/model/mid"
"devops-super/internal/service"
"devops-super/api/host_group/v1"
)
func (c *ControllerV1) GetPartialList(ctx context.Context, req *v1.GetPartialListReq) (res *v1.GetPartialListRes, err error) {
var resList []*mid.HostGroupPartial
eHostGroupList, err := service.HostGroup().GetLst(ctx, "")
if err != nil {
return nil, err
}
for _, eHostGorup := range eHostGroupList {
hostCount, err := service.Host().GetCountByHostGroupId(ctx, eHostGorup.Id)
if err != nil {
return nil, err
}
resList = append(resList, &mid.HostGroupPartial{
Id: eHostGorup.Id,
Name: eHostGorup.Name,
ParentId: eHostGorup.ParentId,
HostCount: hostCount,
})
}
res = &v1.GetPartialListRes{
List: resList,
}
return
}

View File

@ -0,0 +1,14 @@
package user
import (
"context"
"devops-super/internal/service"
"devops-super/api/user/v1"
)
func (c *ControllerV1) GetLst(ctx context.Context, req *v1.GetLstReq) (res *v1.GetLstRes, err error) {
res = new(v1.GetLstRes)
res.List, err = service.User().GetLst(ctx)
return
}

View File

@ -25,6 +25,8 @@ type HostGroupColumns struct {
Rank string // 排序
ParentId string // 上级主机组 id
UpdatedAt string // 更新时间
RoleIds string // 可访问的角色 id 列表
UserIds string // 可访问的用户 id 列表
}
// hostGroupColumns holds the columns for table host_group.
@ -34,6 +36,8 @@ var hostGroupColumns = HostGroupColumns{
Rank: "rank",
ParentId: "parent_id",
UpdatedAt: "updated_at",
RoleIds: "role_ids",
UserIds: "user_ids",
}
// NewHostGroupDao creates and returns a new DAO object for table data access.

View File

@ -12,6 +12,7 @@ import (
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
"os"
"path/filepath"
@ -56,6 +57,10 @@ func (*sHost) Upt(ctx context.Context, in *do.Host) (err error) {
return
}
func (*sHost) GetCountByHostGroupId(ctx context.Context, hostGroupId int) (int, error) {
return dao.Host.Ctx(ctx).Where(cols.HostGroupId, hostGroupId).Count()
}
func (*sHost) GetPageLst(ctx context.Context, in *api.PageLstReq) (out *api.PageLstRes[*entity.Host], err error) {
out = &api.PageLstRes[*entity.Host]{}
m := dao.Host.Ctx(ctx).Safe(true)
@ -82,13 +87,8 @@ func (*sHost) Del(ctx context.Context, in *do.Host) (err error) {
return
}
func (s *sHost) TestSSH(ctx context.Context, id int) (err error) {
eHost, err := s.Get(ctx, &do.Host{Id: id})
if err != nil {
return err
}
client, err := s.SshClient(eHost)
func (s *sHost) TestSSH(ctx context.Context, in *entity.Host) (err error) {
client, err := s.SshClient(in)
if err != nil {
return err
}
@ -135,3 +135,37 @@ func (s *sHost) DownloadFile(ctx context.Context, in *mid.DownloadFileIn) error
return nil
}
func (s *sHost) CanAccess(ctx context.Context, in *entity.Host) (bool, error) {
if service.CurrentUser(ctx).IsAdmin() {
return true, nil
}
// 1. 获取机器所属主机组 in
eHostGroup, err := service.HostGroup().Get(ctx, &do.HostGroup{Id: in.HostGroupId})
if err != nil {
return false, err
}
// 2. 获取主机组授权的角色和用户 eHostGroup.RoleIds eHostGroup.UserIds
// 3. 获取当前用户的角色
eUser, err := service.User().Get(ctx, &do.User{Id: service.CurrentUser(ctx).UserId})
if err != nil {
return false, err
}
// 4. 如果当前用户存在于主机组授权的用户列表,则有权限
for _, hostGroupUserId := range eHostGroup.UserIds.Array() {
if eUser.Id == gconv.Int(hostGroupUserId) {
return true, nil
}
}
// 5. 如果当前用户拥有的角色存在于主机组授权的角色,则有权限
for _, userRoleId := range eUser.RoleIds.Array() {
for _, hostGroupRoleId := range eHostGroup.RoleIds.Array() {
if userRoleId == hostGroupRoleId {
return true, nil
}
}
}
return false, nil
}

View File

@ -26,6 +26,7 @@ func (s *sHost) WsSftpFileManager(ctx context.Context, in *entity.Host) (err err
if wsCtx.ws, err = wsCtx.request.WebSocket(); err != nil {
return err
}
defer wsCtx.ws.Close()
sftpClient, err := s.SftpClient(in)
if err != nil {

View File

@ -44,6 +44,11 @@ func (*sUser) Upt(ctx context.Context, in *do.User) (err error) {
return
}
func (*sUser) GetLst(ctx context.Context) (out []*entity.User, err error) {
err = dao.User.Ctx(ctx).FieldsEx(cols.RoleIds, cols.Phone, cols.Password, cols.DeptId).OrderDesc(cols.Id).Scan(&out)
return
}
func (*sUser) GetPageLst(ctx context.Context, in *api.PageLstReq) (out *api.PageLstRes[*entity.User], err error) {
out = &api.PageLstRes[*entity.User]{}
m := dao.User.Ctx(ctx).Safe(true)

View File

@ -11,3 +11,7 @@ type RequestUser struct {
RealName string `json:"realName"`
Username string `json:"username"`
}
func (u *RequestUser) IsAdmin() bool {
return u.Username == "admin"
}

View File

@ -5,6 +5,7 @@
package do
import (
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
@ -17,4 +18,6 @@ type HostGroup struct {
Rank interface{} // 排序
ParentId interface{} // 上级主机组 id
UpdatedAt *gtime.Time // 更新时间
RoleIds *gjson.Json // 可访问的角色 id 列表
UserIds *gjson.Json // 可访问的用户 id 列表
}

View File

@ -5,6 +5,7 @@
package entity
import (
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/os/gtime"
)
@ -15,4 +16,6 @@ type HostGroup struct {
Rank int `json:"rank" description:"排序"` // 排序
ParentId int `json:"parentId" description:"上级主机组 id"` // 上级主机组 id
UpdatedAt *gtime.Time `json:"updatedAt" description:"更新时间"` // 更新时间
RoleIds *gjson.Json `json:"roleIds" description:"可访问的角色 id 列表"` // 可访问的角色 id 列表
UserIds *gjson.Json `json:"userIds" description:"可访问的用户 id 列表"` // 可访问的用户 id 列表
}

View File

@ -11,7 +11,7 @@ import (
// User is the golang structure for table user.
type User struct {
Id uint `json:"id" description:""` //
Id int `json:"id" description:""` //
Username string `json:"username" description:"用户名"` // 用户名
Password string `json:"password" description:"密码"` // 密码
Phone string `json:"phone" description:"手机号码"` // 手机号码

View File

@ -1,7 +1,18 @@
package mid
import "github.com/gogf/gf/v2/encoding/gjson"
type HostGroup struct {
Name string `v:"required" json:"name"`
Rank int `v:"required" json:"rank"`
ParentId int `json:"parentId"`
RoleIds *gjson.Json `json:"roleIds"`
UserIds *gjson.Json `json:"userIds"`
}
type HostGroupPartial struct {
Id int `json:"id"`
Name string `json:"name"`
ParentId int `json:"parentId"`
HostCount int `json:"hostCount"`
}

View File

@ -20,11 +20,13 @@ type (
IHost interface {
Add(ctx context.Context, in *entity.Host) (err error)
Upt(ctx context.Context, in *do.Host) (err error)
GetCountByHostGroupId(ctx context.Context, hostGroupId int) (int, error)
GetPageLst(ctx context.Context, in *api.PageLstReq) (out *api.PageLstRes[*entity.Host], err error)
Get(ctx context.Context, in *do.Host) (out *entity.Host, err error)
Del(ctx context.Context, in *do.Host) (err error)
TestSSH(ctx context.Context, id int) (err error)
TestSSH(ctx context.Context, in *entity.Host) (err error)
DownloadFile(ctx context.Context, in *mid.DownloadFileIn) error
CanAccess(ctx context.Context, in *entity.Host) (bool, error)
WsSftpFileManager(ctx context.Context, in *entity.Host) (err error)
SftpClient(in *entity.Host) (*sftp.Client, error)
WsTerminal(ctx context.Context, in *entity.Host) error

View File

@ -17,6 +17,7 @@ type (
IUser interface {
Add(ctx context.Context, in *entity.User) (err error)
Upt(ctx context.Context, in *do.User) (err error)
GetLst(ctx context.Context) (out []*entity.User, err error)
GetPageLst(ctx context.Context, in *api.PageLstReq) (out *api.PageLstRes[*entity.User], err error)
Get(ctx context.Context, userDo *do.User) (out *entity.User, err error)
GetComb(ctx context.Context, userDo *do.User) (out *comb.User, err error)