update auth check and use listv2 for github (#744)

This commit is contained in:
Vistaar Juneja 2023-10-30 18:07:38 +00:00 committed by Harness
parent bd31faee07
commit 081d79717a
1 changed files with 63 additions and 41 deletions

View File

@ -106,51 +106,43 @@ func hash(s string) string {
return base32.StdEncoding.EncodeToString(h.Sum(nil)[:10])
}
func oauthTransport(token string) (http.RoundTripper, error) {
func oauthTransport(token string, scheme string) http.RoundTripper {
if token == "" {
return nil, errors.New("no token provided")
return nil
}
return &oauth2.Transport{
Scheme: scheme,
Source: oauth2.StaticTokenSource(&scm.Token{Token: token}),
}, nil
}
func gogsTransport(token string) (http.RoundTripper, error) {
if token == "" {
return nil, errors.New("no token provided")
}
return &oauth2.Transport{
Scheme: oauth2.SchemeToken,
Source: oauth2.StaticTokenSource(&scm.Token{Token: token}),
}, nil
}
func authHeaderTransport(token string) (http.RoundTripper, error) {
func authHeaderTransport(token string) http.RoundTripper {
if token == "" {
return nil, errors.New("no token provided")
return nil
}
return &transport.Authorization{
Scheme: "token",
Credentials: token,
}, nil
}
}
func basicAuthTransport(username, password string) (http.RoundTripper, error) {
if username == "" || password == "" {
return nil, errors.New("username or password not provided")
}
func basicAuthTransport(username, password string) http.RoundTripper {
return &transport.BasicAuth{
Username: username,
Password: password,
}, nil
}
}
// getScmClientWithTransport creates an SCM client along with the necessary transport
// layer depending on the provider. For example, for bitbucket we support app passwords
// so the auth transport is BasicAuth whereas it's Oauth for other providers.
func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolint:gocognit
// It validates that auth credentials are provided if authReq is true.
func getScmClientWithTransport(provider Provider, authReq bool) (*scm.Client, error) { //nolint:gocognit
if authReq && (provider.Username == "" || provider.Password == "") {
return nil, usererror.BadRequest("scm provider authentication credentials missing")
}
var c *scm.Client
var err, transportErr error
var err error
var transport http.RoundTripper
switch provider.Type {
case "":
@ -165,7 +157,7 @@ func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolin
} else {
c = github.NewDefault()
}
transport, transportErr = oauthTransport(provider.Password)
transport = oauthTransport(provider.Password, oauth2.SchemeBearer)
case ProviderTypeGitLab:
if provider.Host != "" {
@ -176,7 +168,7 @@ func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolin
} else {
c = gitlab.NewDefault()
}
transport, transportErr = oauthTransport(provider.Password)
transport = oauthTransport(provider.Password, oauth2.SchemeBearer)
case ProviderTypeBitbucket:
if provider.Host != "" {
@ -187,7 +179,7 @@ func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolin
} else {
c = bitbucket.NewDefault()
}
transport, transportErr = basicAuthTransport(provider.Username, provider.Password)
transport = basicAuthTransport(provider.Username, provider.Password)
case ProviderTypeStash:
if provider.Host != "" {
@ -198,7 +190,7 @@ func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolin
} else {
c = stash.NewDefault()
}
transport, transportErr = oauthTransport(provider.Password)
transport = oauthTransport(provider.Password, oauth2.SchemeBearer)
case ProviderTypeGitea:
if provider.Host == "" {
@ -208,7 +200,7 @@ func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolin
if err != nil {
return nil, fmt.Errorf("scm provider Host invalid: %w", err)
}
transport, transportErr = authHeaderTransport(provider.Password)
transport = authHeaderTransport(provider.Password)
case ProviderTypeGogs:
if provider.Host == "" {
@ -218,23 +210,22 @@ func getScmClientWithTransport(provider Provider) (*scm.Client, error) { //nolin
if err != nil {
return nil, fmt.Errorf("scm provider Host invalid: %w", err)
}
transport, transportErr = gogsTransport(provider.Password)
transport = oauthTransport(provider.Password, oauth2.SchemeToken)
default:
return nil, fmt.Errorf("unsupported scm provider: %s", provider)
}
if transportErr != nil {
return nil, fmt.Errorf("could not create transport: %w", transportErr)
// override default transport if available
if transport != nil {
c.Client = &http.Client{Transport: transport}
}
c.Client = &http.Client{Transport: transport}
return c, nil
}
func LoadRepositoryFromProvider(ctx context.Context, provider Provider, repoSlug string) (RepositoryInfo, error) {
scmClient, err := getScmClientWithTransport(provider)
scmClient, err := getScmClientWithTransport(provider, false)
if err != nil {
return RepositoryInfo{}, usererror.BadRequestf("could not create client: %s", err)
}
@ -257,12 +248,14 @@ func LoadRepositoryFromProvider(ctx context.Context, provider Provider, repoSlug
}, nil
}
//nolint:gocognit
func LoadRepositoriesFromProviderSpace(
ctx context.Context,
provider Provider,
spaceSlug string,
) ([]RepositoryInfo, error) {
scmClient, err := getScmClientWithTransport(provider)
var err error
scmClient, err := getScmClientWithTransport(provider, true)
if err != nil {
return nil, usererror.BadRequestf("could not create client: %s", err)
}
@ -271,11 +264,37 @@ func LoadRepositoriesFromProviderSpace(
return nil, usererror.BadRequest("provider space identifier is missing")
}
opts := scm.ListOptions{
Size: 100,
}
var optsv2 scm.RepoListOptions
listv2 := false
if provider.Type == ProviderTypeGitHub {
listv2 = true
optsv2 = scm.RepoListOptions{
ListOptions: opts,
RepoSearchTerm: scm.RepoSearchTerm{
User: spaceSlug,
},
}
}
repos := make([]RepositoryInfo, 0)
opts := scm.ListOptions{Size: 100}
var scmRepos []*scm.Repository
var scmResp *scm.Response
for {
scmRepos, scmResp, err := scmClient.Repositories.List(ctx, opts)
if listv2 {
scmRepos, scmResp, err = scmClient.Repositories.ListV2(ctx, optsv2)
optsv2.Page = scmResp.Page.Next
optsv2.URL = scmResp.Page.NextURL
} else {
scmRepos, scmResp, err = scmClient.Repositories.List(ctx, opts)
opts.Page = scmResp.Page.Next
opts.URL = scmResp.Page.NextURL
}
if err = convertSCMError(provider, spaceSlug, scmResp, err); err != nil {
return nil, err
}
@ -299,11 +318,14 @@ func LoadRepositoriesFromProviderSpace(
})
}
opts.Page = scmResp.Page.Next
opts.URL = scmResp.Page.NextURL
if opts.Page == 0 && opts.URL == "" {
break
if listv2 {
if optsv2.Page == 0 && optsv2.URL == "" {
break
}
} else {
if opts.Page == 0 && opts.URL == "" {
break
}
}
}