Merge pull request #7106 from sfc-gh-ahusain/ahusain-fdb-mock-kms
FDB native MockKMS REST server implementation
This commit is contained in:
commit
524365083d
|
@ -0,0 +1,179 @@
|
|||
/*
|
||||
* fault_injection.go
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Interface supports client to inject fault(s)
|
||||
// Module enables a client to update { FaultLocation -> FaultStatus } mapping in a
|
||||
// thread-safe manner, however, client is responsible to synchronize fault status
|
||||
// updates across 'getEncryptionKeys' REST requests to obtain predictable results.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Fault struct {
|
||||
Location int `json:"fault_location"`
|
||||
Enable bool `json:"enable_fault"`
|
||||
}
|
||||
|
||||
type FaultInjectionRequest struct {
|
||||
Faults []Fault `json:"faults"`
|
||||
}
|
||||
|
||||
type FaultInjectionResponse struct {
|
||||
Faults []Fault `json:"faults"`
|
||||
}
|
||||
|
||||
type faultLocMap struct {
|
||||
locMap map[int]bool
|
||||
rwLock sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
faultLocMapInstance *faultLocMap // Singleton mapping of { FaultLocation -> FaultStatus }
|
||||
)
|
||||
|
||||
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
|
||||
func NewFaultLocMap() *faultLocMap {
|
||||
if faultLocMapInstance == nil {
|
||||
faultLocMapInstance = &faultLocMap{}
|
||||
|
||||
faultLocMapInstance.rwLock = sync.RWMutex{}
|
||||
faultLocMapInstance.locMap = map[int]bool {
|
||||
READ_HTTP_REQUEST_BODY : false,
|
||||
UNMARSHAL_REQUEST_BODY_JSON : false,
|
||||
UNSUPPORTED_QUERY_MODE : false,
|
||||
PARSE_HTTP_REQUEST : false,
|
||||
MARSHAL_RESPONSE : false,
|
||||
}
|
||||
}
|
||||
return faultLocMapInstance
|
||||
}
|
||||
|
||||
func getLocFaultStatus(loc int) (val bool, found bool) {
|
||||
if faultLocMapInstance == nil {
|
||||
panic("FaultLocMap not intialized")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
faultLocMapInstance.rwLock.RLock()
|
||||
defer faultLocMapInstance.rwLock.RUnlock()
|
||||
val, found = faultLocMapInstance.locMap[loc]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func updateLocFaultStatuses(faults []Fault) (updated []Fault, err error) {
|
||||
if faultLocMapInstance == nil {
|
||||
panic("FaultLocMap not intialized")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
updated = []Fault{}
|
||||
err = nil
|
||||
|
||||
faultLocMapInstance.rwLock.Lock()
|
||||
defer faultLocMapInstance.rwLock.Unlock()
|
||||
for i := 0; i < len(faults); i++ {
|
||||
fault := faults[i]
|
||||
|
||||
oldVal, found := faultLocMapInstance.locMap[fault.Location]
|
||||
if !found {
|
||||
err = fmt.Errorf("Unknown fault_location '%d'", fault.Location)
|
||||
return
|
||||
}
|
||||
faultLocMapInstance.locMap[fault.Location] = fault.Enable
|
||||
log.Printf("Update Location '%d' oldVal '%t' newVal '%t'", fault.Location, oldVal, fault.Enable)
|
||||
}
|
||||
|
||||
// return the updated faultLocMap
|
||||
for loc, enable := range faultLocMapInstance.locMap {
|
||||
var f Fault
|
||||
f.Location = loc
|
||||
f.Enable = enable
|
||||
updated = append(updated, f)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func jsonifyFaultArr(w http.ResponseWriter, faults []Fault) (jResp string) {
|
||||
resp := FaultInjectionResponse{
|
||||
Faults: faults,
|
||||
}
|
||||
|
||||
mResp, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Printf("Error marshaling response '%s'", err.Error())
|
||||
sendErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
jResp = string(mResp)
|
||||
return
|
||||
}
|
||||
|
||||
func updateFaultLocMap(w http.ResponseWriter, faults []Fault) {
|
||||
updated , err := updateLocFaultStatuses(faults)
|
||||
if err != nil {
|
||||
sendErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, jsonifyFaultArr(w, updated))
|
||||
}
|
||||
|
||||
func shouldInjectFault(loc int) bool {
|
||||
status, found := getLocFaultStatus(loc)
|
||||
if !found {
|
||||
log.Printf("Unknown fault_location '%d'", loc)
|
||||
return false
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
func handleUpdateFaultInjection(w http.ResponseWriter, r *http.Request) {
|
||||
byteArr, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
log.Printf("Http request body read error '%s'", err.Error())
|
||||
sendErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := FaultInjectionRequest{}
|
||||
err = json.Unmarshal(byteArr, &req)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing FaultInjectionRequest '%s'", string(byteArr))
|
||||
sendErrorResponse(w, err)
|
||||
}
|
||||
updateFaultLocMap(w, req.Faults)
|
||||
}
|
||||
|
||||
func initFaultLocMap() {
|
||||
faultLocMapInstance = NewFaultLocMap()
|
||||
log.Printf("FaultLocMap int done")
|
||||
}
|
|
@ -0,0 +1,321 @@
|
|||
/*
|
||||
* get_encryption_keys.go
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// GetEncryptionKeys handler
|
||||
// Handler is resposible for the following:
|
||||
// 1. Parse the incoming HttpRequest and validate JSON request structural sanity
|
||||
// 2. Ability to handle getEncryptionKeys by 'KeyId' or 'DomainId' as requested
|
||||
// 3. Ability to inject faults if requested
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type CipherDetailRes struct {
|
||||
BaseCipherId uint64 `json:"base_cipher_id"`
|
||||
EncryptDomainId int64 `json:"encrypt_domain_id"`
|
||||
BaseCipher string `json:"base_cipher"`
|
||||
}
|
||||
|
||||
type ValidationToken struct {
|
||||
TokenName string `json:"token_name"`
|
||||
TokenValue string `json:"token_value"`
|
||||
}
|
||||
|
||||
type CipherDetailReq struct {
|
||||
BaseCipherId uint64 `json:"base_cipher_id"`
|
||||
EncryptDomainId int64 `json:"encrypt_domain_id"`
|
||||
}
|
||||
|
||||
type GetEncryptKeysResponse struct {
|
||||
CipherDetails []CipherDetailRes `json:"cipher_key_details"`
|
||||
KmsUrls []string `json:"kms_urls"`
|
||||
}
|
||||
|
||||
type GetEncryptKeysRequest struct {
|
||||
QueryMode string `json:"query_mode"`
|
||||
CipherDetails []CipherDetailReq `json:"cipher_key_details"`
|
||||
ValidationTokens []ValidationToken `json:"validation_tokens"`
|
||||
RefreshKmsUrls bool `json:"refresh_kms_urls"`
|
||||
}
|
||||
|
||||
type cipherMapInstanceSingleton map[uint64][]byte
|
||||
|
||||
const (
|
||||
READ_HTTP_REQUEST_BODY = iota
|
||||
UNMARSHAL_REQUEST_BODY_JSON
|
||||
UNSUPPORTED_QUERY_MODE
|
||||
PARSE_HTTP_REQUEST
|
||||
MARSHAL_RESPONSE
|
||||
)
|
||||
|
||||
const (
|
||||
maxCipherKeys = uint64(1024*1024) // Max cipher keys
|
||||
maxCipherSize = 16 // Max cipher buffer size
|
||||
)
|
||||
|
||||
var (
|
||||
cipherMapInstance cipherMapInstanceSingleton // Singleton mapping of { baseCipherId -> baseCipher }
|
||||
)
|
||||
|
||||
// const mapping of { Location -> errorString }
|
||||
func errStrMap() func(int) string {
|
||||
_errStrMap := map[int]string {
|
||||
READ_HTTP_REQUEST_BODY : "Http request body read error",
|
||||
UNMARSHAL_REQUEST_BODY_JSON : "Http request body unmarshal error",
|
||||
UNSUPPORTED_QUERY_MODE : "Unsupported query_mode",
|
||||
PARSE_HTTP_REQUEST : "Error parsing GetEncryptionKeys request",
|
||||
MARSHAL_RESPONSE : "Error marshaling response",
|
||||
}
|
||||
|
||||
return func(key int) string {
|
||||
return _errStrMap[key]
|
||||
}
|
||||
}
|
||||
|
||||
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
|
||||
func NewCipherMap(maxKeys uint64, cipherSize int) cipherMapInstanceSingleton {
|
||||
if cipherMapInstance == nil {
|
||||
cipherMapInstance = make(map[uint64][]byte)
|
||||
|
||||
for i := uint64(1); i<= maxKeys; i++ {
|
||||
cipher := make([]byte, cipherSize)
|
||||
rand.Read(cipher)
|
||||
cipherMapInstance[i] = cipher
|
||||
}
|
||||
log.Printf("KMS cipher map populate done, maxCiphers '%d'", maxCipherKeys)
|
||||
}
|
||||
return cipherMapInstance
|
||||
}
|
||||
|
||||
func getKmsUrls() (urls []string) {
|
||||
urlCount := rand.Intn(5) + 1
|
||||
for i := 1; i <= urlCount; i++ {
|
||||
url := fmt.Sprintf("https://KMS/%d:%d:%d:%d", i, i, i, i)
|
||||
urls = append(urls, url)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func isEncryptDomainIdValid(id int64) bool {
|
||||
if id > 0 || id == -1 || id == -2 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func abs(x int64) int64 {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func getBaseCipherIdFromDomainId(domainId int64) (baseCipherId uint64) {
|
||||
baseCipherId = uint64(1) + uint64(abs(domainId)) % maxCipherKeys
|
||||
return
|
||||
}
|
||||
|
||||
func getEncryptionKeysByKeyIds(w http.ResponseWriter, byteArr []byte) {
|
||||
req := GetEncryptKeysRequest{}
|
||||
err := json.Unmarshal(byteArr, &req)
|
||||
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||
var e error
|
||||
if shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||
e = fmt.Errorf("[FAULT] %s %s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
|
||||
} else {
|
||||
e = fmt.Errorf("%s %s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
|
||||
}
|
||||
log.Println(e.Error())
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
var details []CipherDetailRes
|
||||
for i := 0; i < len(req.CipherDetails); i++ {
|
||||
var baseCipherId = uint64(req.CipherDetails[i].BaseCipherId)
|
||||
|
||||
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
|
||||
if !isEncryptDomainIdValid(encryptDomainId) {
|
||||
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
cipher, found := cipherMapInstance[baseCipherId]
|
||||
if !found {
|
||||
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
var detail = CipherDetailRes {
|
||||
BaseCipherId: baseCipherId,
|
||||
EncryptDomainId: encryptDomainId,
|
||||
BaseCipher: string(cipher),
|
||||
}
|
||||
details = append(details, detail)
|
||||
}
|
||||
|
||||
var urls []string
|
||||
if req.RefreshKmsUrls {
|
||||
urls = getKmsUrls()
|
||||
}
|
||||
|
||||
resp := GetEncryptKeysResponse{
|
||||
CipherDetails: details,
|
||||
KmsUrls: urls,
|
||||
}
|
||||
|
||||
mResp, err := json.Marshal(resp)
|
||||
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
|
||||
var e error
|
||||
if shouldInjectFault(MARSHAL_RESPONSE) {
|
||||
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
|
||||
} else {
|
||||
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
|
||||
}
|
||||
log.Println(e.Error())
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, string(mResp))
|
||||
}
|
||||
|
||||
func getEncryptionKeysByDomainIds(w http.ResponseWriter, byteArr []byte) {
|
||||
req := GetEncryptKeysRequest{}
|
||||
err := json.Unmarshal(byteArr, &req)
|
||||
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||
var e error
|
||||
if shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||
e = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
|
||||
} else {
|
||||
e = fmt.Errorf("%s '%s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
|
||||
}
|
||||
log.Println(e.Error())
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
var details []CipherDetailRes
|
||||
for i := 0; i < len(req.CipherDetails); i++ {
|
||||
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
|
||||
if !isEncryptDomainIdValid(encryptDomainId) {
|
||||
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
var baseCipherId = getBaseCipherIdFromDomainId(encryptDomainId)
|
||||
cipher, found := cipherMapInstance[baseCipherId]
|
||||
if !found {
|
||||
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
var detail = CipherDetailRes {
|
||||
BaseCipherId: baseCipherId,
|
||||
EncryptDomainId: encryptDomainId,
|
||||
BaseCipher: string(cipher),
|
||||
}
|
||||
details = append(details, detail)
|
||||
}
|
||||
|
||||
var urls []string
|
||||
if req.RefreshKmsUrls {
|
||||
urls = getKmsUrls()
|
||||
}
|
||||
|
||||
resp := GetEncryptKeysResponse{
|
||||
CipherDetails: details,
|
||||
KmsUrls: urls,
|
||||
}
|
||||
|
||||
mResp, err := json.Marshal(resp)
|
||||
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
|
||||
var e error
|
||||
if shouldInjectFault(MARSHAL_RESPONSE) {
|
||||
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
|
||||
} else {
|
||||
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
|
||||
}
|
||||
log.Println(e.Error())
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, string(mResp))
|
||||
}
|
||||
|
||||
func handleGetEncryptionKeys(w http.ResponseWriter, r *http.Request) {
|
||||
byteArr, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil || shouldInjectFault(READ_HTTP_REQUEST_BODY) {
|
||||
var e error
|
||||
if shouldInjectFault(READ_HTTP_REQUEST_BODY) {
|
||||
e = fmt.Errorf("[FAULT] %s", errStrMap()(READ_HTTP_REQUEST_BODY))
|
||||
} else {
|
||||
e = fmt.Errorf("%s err '%v'", errStrMap()(READ_HTTP_REQUEST_BODY), err)
|
||||
}
|
||||
log.Println(e.Error())
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
var arbitrary_json map[string]interface{}
|
||||
err = json.Unmarshal(byteArr, &arbitrary_json)
|
||||
if err != nil || shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
|
||||
var e error
|
||||
if shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
|
||||
e = fmt.Errorf("[FAULT] %s", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON))
|
||||
} else {
|
||||
e = fmt.Errorf("%s err '%v'", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON), err)
|
||||
}
|
||||
log.Println(e.Error())
|
||||
sendErrorResponse(w, e)
|
||||
return
|
||||
}
|
||||
|
||||
if shouldInjectFault(UNSUPPORTED_QUERY_MODE) {
|
||||
err = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
|
||||
sendErrorResponse(w, err)
|
||||
return
|
||||
} else if arbitrary_json["query_mode"] == "lookupByKeyId" {
|
||||
getEncryptionKeysByKeyIds(w, byteArr)
|
||||
} else if arbitrary_json["query_mode"] == "lookupByDomainId" {
|
||||
getEncryptionKeysByDomainIds(w, byteArr)
|
||||
} else {
|
||||
err = fmt.Errorf("%s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
|
||||
sendErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func initEncryptCipherMap() {
|
||||
cipherMapInstance = NewCipherMap(maxCipherKeys, maxCipherSize)
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* mock_kms.go
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// FoundationDB Mock KMS (Key Management Solution/Service) interface
|
||||
// Interface runs an HTTP server handling REST calls simulating FDB communications
|
||||
// with an external KMS.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KMS supported endpoints
|
||||
const (
|
||||
getEncryptionKeysEndpoint = "/getEncryptionKeys"
|
||||
updateFaultInjectionEndpoint = "/updateFaultInjection"
|
||||
)
|
||||
|
||||
// Routine is responsible to instantiate data-structures necessary for MockKMS functioning
|
||||
func init () {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(2)
|
||||
go func(){
|
||||
initEncryptCipherMap()
|
||||
wg.Done()
|
||||
}()
|
||||
go func(){
|
||||
initFaultLocMap()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
rand.Seed(time.Now().UTC().UnixNano())
|
||||
}
|
||||
|
||||
func main() {
|
||||
http.NewServeMux()
|
||||
http.HandleFunc(getEncryptionKeysEndpoint, handleGetEncryptionKeys)
|
||||
http.HandleFunc(updateFaultInjectionEndpoint, handleUpdateFaultInjection)
|
||||
|
||||
log.Fatal(http.ListenAndServe(":5001", nil))
|
||||
}
|
|
@ -0,0 +1,302 @@
|
|||
/*
|
||||
* mockkms_test.go
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// MockKMS unit tests, the coverage includes:
|
||||
// 1. Mock HttpRequest creation and HttpResponse writer.
|
||||
// 2. Construct fake request to validate the following scenarions:
|
||||
// 2.1. Request with "unsupported query mode"
|
||||
// 2.2. Get encryption keys by KeyIds; with and without 'RefreshKmsUrls' flag.
|
||||
// 2.2. Get encryption keys by DomainIds; with and without 'RefreshKmsUrls' flag.
|
||||
// 2.3. Random fault injection and response validation
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
ByKeyIdReqWithRefreshUrls = `{
|
||||
"query_mode": "lookupByKeyId",
|
||||
"cipher_key_details": [
|
||||
{
|
||||
"base_cipher_id": 77,
|
||||
"encrypt_domain_id": 76
|
||||
},
|
||||
{
|
||||
"base_cipher_id": 2,
|
||||
"encrypt_domain_id": -1
|
||||
}
|
||||
],
|
||||
"validation_tokens": [
|
||||
{
|
||||
"token_name": "1",
|
||||
"token_value":"12344"
|
||||
},
|
||||
{
|
||||
"token_name": "2",
|
||||
"token_value":"12334"
|
||||
}
|
||||
],
|
||||
"refresh_kms_urls": true
|
||||
}`
|
||||
ByKeyIdReqWithoutRefreshUrls = `{
|
||||
"query_mode": "lookupByKeyId",
|
||||
"cipher_key_details": [
|
||||
{
|
||||
"base_cipher_id": 77,
|
||||
"encrypt_domain_id": 76
|
||||
},
|
||||
{
|
||||
"base_cipher_id": 2,
|
||||
"encrypt_domain_id": -1
|
||||
}
|
||||
],
|
||||
"validation_tokens": [
|
||||
{
|
||||
"token_name": "1",
|
||||
"token_value":"12344"
|
||||
},
|
||||
{
|
||||
"token_name": "2",
|
||||
"token_value":"12334"
|
||||
}
|
||||
],
|
||||
"refresh_kms_urls": false
|
||||
}`
|
||||
ByDomainIdReqWithRefreshUrls = `{
|
||||
"query_mode": "lookupByDomainId",
|
||||
"cipher_key_details": [
|
||||
{
|
||||
"encrypt_domain_id": 76
|
||||
},
|
||||
{
|
||||
"encrypt_domain_id": -1
|
||||
}
|
||||
],
|
||||
"validation_tokens": [
|
||||
{
|
||||
"token_name": "1",
|
||||
"token_value":"12344"
|
||||
},
|
||||
{
|
||||
"token_name": "2",
|
||||
"token_value":"12334"
|
||||
}
|
||||
],
|
||||
"refresh_kms_urls": true
|
||||
}`
|
||||
ByDomainIdReqWithoutRefreshUrls = `{
|
||||
"query_mode": "lookupByDomainId",
|
||||
"cipher_key_details": [
|
||||
{
|
||||
"encrypt_domain_id": 76
|
||||
},
|
||||
{
|
||||
"encrypt_domain_id": -1
|
||||
}
|
||||
],
|
||||
"validation_tokens": [
|
||||
{
|
||||
"token_name": "1",
|
||||
"token_value":"12344"
|
||||
},
|
||||
{
|
||||
"token_name": "2",
|
||||
"token_value":"12334"
|
||||
}
|
||||
],
|
||||
"refresh_kms_urls": false
|
||||
}`
|
||||
UnsupportedQueryMode= `{
|
||||
"query_mode": "foo_mode",
|
||||
"cipher_key_details": [
|
||||
{
|
||||
"encrypt_domain_id": 76
|
||||
},
|
||||
{
|
||||
"encrypt_domain_id": -1
|
||||
}
|
||||
],
|
||||
"validation_tokens": [
|
||||
{
|
||||
"token_name": "1",
|
||||
"token_value":"12344"
|
||||
},
|
||||
{
|
||||
"token_name": "2",
|
||||
"token_value":"12334"
|
||||
}
|
||||
],
|
||||
"refresh_kms_urls": false
|
||||
}`
|
||||
)
|
||||
|
||||
func unmarshalValidResponse(data []byte, t *testing.T) (resp GetEncryptKeysResponse) {
|
||||
resp = GetEncryptKeysResponse{}
|
||||
err := json.Unmarshal(data, &resp)
|
||||
if err != nil {
|
||||
t.Errorf("Error unmarshaling valid response '%s' error '%v'", string(data), err)
|
||||
t.Fail()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func unmarshalErrorResponse(data []byte, t *testing.T) (resp ErrorResponse) {
|
||||
resp = ErrorResponse{}
|
||||
err := json.Unmarshal(data, &resp)
|
||||
if err != nil {
|
||||
t.Errorf("Error unmarshaling error response resp '%s' error '%v'", string(data), err)
|
||||
t.Fail()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func checkGetEncyptKeysResponseValidity(resp GetEncryptKeysResponse, t *testing.T) {
|
||||
if len(resp.CipherDetails) != 2 {
|
||||
t.Errorf("Unexpected CipherDetails count, expected '%d' actual '%d'", 2, len(resp.CipherDetails))
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
baseCipherIds := [...]uint64 {uint64(77), uint64(2)}
|
||||
encryptDomainIds := [...]int64 {int64(76), int64(-1)}
|
||||
|
||||
for i := 0; i < len(resp.CipherDetails); i++ {
|
||||
if resp.CipherDetails[i].BaseCipherId != baseCipherIds[i] {
|
||||
t.Errorf("Mismatch BaseCipherId, expected '%d' actual '%d'", baseCipherIds[i], resp.CipherDetails[i].BaseCipherId)
|
||||
t.Fail()
|
||||
}
|
||||
if resp.CipherDetails[i].EncryptDomainId != encryptDomainIds[i] {
|
||||
t.Errorf("Mismatch EncryptDomainId, expected '%d' actual '%d'", encryptDomainIds[i], resp.CipherDetails[i].EncryptDomainId)
|
||||
t.Fail()
|
||||
}
|
||||
if len(resp.CipherDetails[i].BaseCipher) == 0 {
|
||||
t.Error("Empty BaseCipher")
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runQueryExpectingErrorResponse(payload string, url string, errSubStr string, t *testing.T) {
|
||||
body := strings.NewReader(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, url, body)
|
||||
w := httptest.NewRecorder()
|
||||
handleGetEncryptionKeys(w, req)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error %v", err)
|
||||
}
|
||||
|
||||
resp := unmarshalErrorResponse(data, t)
|
||||
if !strings.Contains(resp.Err.Detail, errSubStr) {
|
||||
t.Errorf("Unexpected error response '%s'", resp.Err.Detail)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
func runQueryExpectingValidResponse(payload string, url string, t *testing.T) {
|
||||
body := strings.NewReader(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, url, body)
|
||||
w := httptest.NewRecorder()
|
||||
handleGetEncryptionKeys(w, req)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Error %v", err)
|
||||
}
|
||||
|
||||
resp := unmarshalValidResponse(data, t)
|
||||
checkGetEncyptKeysResponseValidity(resp, t)
|
||||
}
|
||||
|
||||
func TestUnsupportedQueryMode(t *testing.T) {
|
||||
runQueryExpectingErrorResponse(UnsupportedQueryMode, getEncryptionKeysEndpoint, errStrMap()(UNSUPPORTED_QUERY_MODE), t)
|
||||
}
|
||||
|
||||
func TestGetEncryptionKeysByKeyIdsWithRefreshUrls(t *testing.T) {
|
||||
runQueryExpectingValidResponse(ByKeyIdReqWithRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||
}
|
||||
|
||||
func TestGetEncryptionKeysByKeyIdsWithoutRefreshUrls(t *testing.T) {
|
||||
runQueryExpectingValidResponse(ByKeyIdReqWithoutRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||
}
|
||||
|
||||
func TestGetEncryptionKeysByDomainIdsWithRefreshUrls(t *testing.T) {
|
||||
runQueryExpectingValidResponse(ByDomainIdReqWithRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||
}
|
||||
|
||||
func TestGetEncryptionKeysByDomainIdsWithoutRefreshUrls(t *testing.T) {
|
||||
runQueryExpectingValidResponse(ByDomainIdReqWithoutRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||
}
|
||||
|
||||
func TestFaultInjection(t *testing.T) {
|
||||
numIterations := rand.Intn(701) + 86
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
loc := rand.Intn(MARSHAL_RESPONSE + 1)
|
||||
f := Fault{}
|
||||
f.Location = loc
|
||||
f.Enable = true
|
||||
|
||||
var faults []Fault
|
||||
faults = append(faults, f)
|
||||
fW := httptest.NewRecorder()
|
||||
body := strings.NewReader(jsonifyFaultArr(fW, faults))
|
||||
fReq := httptest.NewRequest(http.MethodPost, updateFaultInjectionEndpoint, body)
|
||||
handleUpdateFaultInjection(fW, fReq)
|
||||
if !shouldInjectFault(loc) {
|
||||
t.Errorf("Expected fault enabled for loc '%d'", loc)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
var payload string
|
||||
lottery := rand.Intn(100)
|
||||
if lottery < 25 {
|
||||
payload = ByKeyIdReqWithRefreshUrls
|
||||
} else if lottery >= 25 && lottery < 50 {
|
||||
payload = ByKeyIdReqWithoutRefreshUrls
|
||||
} else if lottery >= 50 && lottery < 75 {
|
||||
payload = ByDomainIdReqWithRefreshUrls
|
||||
} else {
|
||||
payload = ByDomainIdReqWithoutRefreshUrls
|
||||
}
|
||||
runQueryExpectingErrorResponse(payload, getEncryptionKeysEndpoint, errStrMap()(loc), t)
|
||||
|
||||
// reset Fault
|
||||
faults[0].Enable = false
|
||||
fW = httptest.NewRecorder()
|
||||
body = strings.NewReader(jsonifyFaultArr(fW, faults))
|
||||
fReq = httptest.NewRequest(http.MethodPost, updateFaultInjectionEndpoint, body)
|
||||
handleUpdateFaultInjection(fW, fReq)
|
||||
if shouldInjectFault(loc) {
|
||||
t.Errorf("Expected fault disabled for loc '%d'", loc)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* utils.go
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ErrorDetail struct {
|
||||
Detail string `json:"details"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Err ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
func sendErrorResponse(w http.ResponseWriter, err error) {
|
||||
e := ErrorDetail{}
|
||||
e.Detail = fmt.Sprintf("Error: %s", err.Error())
|
||||
resp := ErrorResponse{
|
||||
Err: e,
|
||||
}
|
||||
|
||||
mResp,err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Printf("Error marshalling error response %s", err.Error())
|
||||
panic(err)
|
||||
}
|
||||
fmt.Fprintf(w, string(mResp))
|
||||
}
|
Loading…
Reference in New Issue