Merge pull request #7106 from sfc-gh-ahusain/ahusain-fdb-mock-kms

FDB native MockKMS REST server implementation
This commit is contained in:
Markus Pilman 2022-05-10 09:10:54 -07:00 committed by GitHub
commit 524365083d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 919 additions and 0 deletions

View File

@ -0,0 +1,179 @@
/*
* fault_injection.go
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Interface supports client to inject fault(s)
// Module enables a client to update { FaultLocation -> FaultStatus } mapping in a
// thread-safe manner, however, client is responsible to synchronize fault status
// updates across 'getEncryptionKeys' REST requests to obtain predictable results.
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"sync"
)
type Fault struct {
Location int `json:"fault_location"`
Enable bool `json:"enable_fault"`
}
type FaultInjectionRequest struct {
Faults []Fault `json:"faults"`
}
type FaultInjectionResponse struct {
Faults []Fault `json:"faults"`
}
type faultLocMap struct {
locMap map[int]bool
rwLock sync.RWMutex
}
var (
faultLocMapInstance *faultLocMap // Singleton mapping of { FaultLocation -> FaultStatus }
)
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
func NewFaultLocMap() *faultLocMap {
if faultLocMapInstance == nil {
faultLocMapInstance = &faultLocMap{}
faultLocMapInstance.rwLock = sync.RWMutex{}
faultLocMapInstance.locMap = map[int]bool {
READ_HTTP_REQUEST_BODY : false,
UNMARSHAL_REQUEST_BODY_JSON : false,
UNSUPPORTED_QUERY_MODE : false,
PARSE_HTTP_REQUEST : false,
MARSHAL_RESPONSE : false,
}
}
return faultLocMapInstance
}
func getLocFaultStatus(loc int) (val bool, found bool) {
if faultLocMapInstance == nil {
panic("FaultLocMap not intialized")
os.Exit(1)
}
faultLocMapInstance.rwLock.RLock()
defer faultLocMapInstance.rwLock.RUnlock()
val, found = faultLocMapInstance.locMap[loc]
if !found {
return
}
return
}
func updateLocFaultStatuses(faults []Fault) (updated []Fault, err error) {
if faultLocMapInstance == nil {
panic("FaultLocMap not intialized")
os.Exit(1)
}
updated = []Fault{}
err = nil
faultLocMapInstance.rwLock.Lock()
defer faultLocMapInstance.rwLock.Unlock()
for i := 0; i < len(faults); i++ {
fault := faults[i]
oldVal, found := faultLocMapInstance.locMap[fault.Location]
if !found {
err = fmt.Errorf("Unknown fault_location '%d'", fault.Location)
return
}
faultLocMapInstance.locMap[fault.Location] = fault.Enable
log.Printf("Update Location '%d' oldVal '%t' newVal '%t'", fault.Location, oldVal, fault.Enable)
}
// return the updated faultLocMap
for loc, enable := range faultLocMapInstance.locMap {
var f Fault
f.Location = loc
f.Enable = enable
updated = append(updated, f)
}
return
}
func jsonifyFaultArr(w http.ResponseWriter, faults []Fault) (jResp string) {
resp := FaultInjectionResponse{
Faults: faults,
}
mResp, err := json.Marshal(resp)
if err != nil {
log.Printf("Error marshaling response '%s'", err.Error())
sendErrorResponse(w, err)
return
}
jResp = string(mResp)
return
}
func updateFaultLocMap(w http.ResponseWriter, faults []Fault) {
updated , err := updateLocFaultStatuses(faults)
if err != nil {
sendErrorResponse(w, err)
return
}
fmt.Fprintf(w, jsonifyFaultArr(w, updated))
}
func shouldInjectFault(loc int) bool {
status, found := getLocFaultStatus(loc)
if !found {
log.Printf("Unknown fault_location '%d'", loc)
return false
}
return status
}
func handleUpdateFaultInjection(w http.ResponseWriter, r *http.Request) {
byteArr, err := ioutil.ReadAll(r.Body)
if err != nil {
log.Printf("Http request body read error '%s'", err.Error())
sendErrorResponse(w, err)
return
}
req := FaultInjectionRequest{}
err = json.Unmarshal(byteArr, &req)
if err != nil {
log.Printf("Error parsing FaultInjectionRequest '%s'", string(byteArr))
sendErrorResponse(w, err)
}
updateFaultLocMap(w, req.Faults)
}
func initFaultLocMap() {
faultLocMapInstance = NewFaultLocMap()
log.Printf("FaultLocMap int done")
}

View File

@ -0,0 +1,321 @@
/*
* get_encryption_keys.go
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// GetEncryptionKeys handler
// Handler is resposible for the following:
// 1. Parse the incoming HttpRequest and validate JSON request structural sanity
// 2. Ability to handle getEncryptionKeys by 'KeyId' or 'DomainId' as requested
// 3. Ability to inject faults if requested
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net/http"
)
type CipherDetailRes struct {
BaseCipherId uint64 `json:"base_cipher_id"`
EncryptDomainId int64 `json:"encrypt_domain_id"`
BaseCipher string `json:"base_cipher"`
}
type ValidationToken struct {
TokenName string `json:"token_name"`
TokenValue string `json:"token_value"`
}
type CipherDetailReq struct {
BaseCipherId uint64 `json:"base_cipher_id"`
EncryptDomainId int64 `json:"encrypt_domain_id"`
}
type GetEncryptKeysResponse struct {
CipherDetails []CipherDetailRes `json:"cipher_key_details"`
KmsUrls []string `json:"kms_urls"`
}
type GetEncryptKeysRequest struct {
QueryMode string `json:"query_mode"`
CipherDetails []CipherDetailReq `json:"cipher_key_details"`
ValidationTokens []ValidationToken `json:"validation_tokens"`
RefreshKmsUrls bool `json:"refresh_kms_urls"`
}
type cipherMapInstanceSingleton map[uint64][]byte
const (
READ_HTTP_REQUEST_BODY = iota
UNMARSHAL_REQUEST_BODY_JSON
UNSUPPORTED_QUERY_MODE
PARSE_HTTP_REQUEST
MARSHAL_RESPONSE
)
const (
maxCipherKeys = uint64(1024*1024) // Max cipher keys
maxCipherSize = 16 // Max cipher buffer size
)
var (
cipherMapInstance cipherMapInstanceSingleton // Singleton mapping of { baseCipherId -> baseCipher }
)
// const mapping of { Location -> errorString }
func errStrMap() func(int) string {
_errStrMap := map[int]string {
READ_HTTP_REQUEST_BODY : "Http request body read error",
UNMARSHAL_REQUEST_BODY_JSON : "Http request body unmarshal error",
UNSUPPORTED_QUERY_MODE : "Unsupported query_mode",
PARSE_HTTP_REQUEST : "Error parsing GetEncryptionKeys request",
MARSHAL_RESPONSE : "Error marshaling response",
}
return func(key int) string {
return _errStrMap[key]
}
}
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
func NewCipherMap(maxKeys uint64, cipherSize int) cipherMapInstanceSingleton {
if cipherMapInstance == nil {
cipherMapInstance = make(map[uint64][]byte)
for i := uint64(1); i<= maxKeys; i++ {
cipher := make([]byte, cipherSize)
rand.Read(cipher)
cipherMapInstance[i] = cipher
}
log.Printf("KMS cipher map populate done, maxCiphers '%d'", maxCipherKeys)
}
return cipherMapInstance
}
func getKmsUrls() (urls []string) {
urlCount := rand.Intn(5) + 1
for i := 1; i <= urlCount; i++ {
url := fmt.Sprintf("https://KMS/%d:%d:%d:%d", i, i, i, i)
urls = append(urls, url)
}
return
}
func isEncryptDomainIdValid(id int64) bool {
if id > 0 || id == -1 || id == -2 {
return true
}
return false
}
func abs(x int64) int64 {
if x < 0 {
return -x
}
return x
}
func getBaseCipherIdFromDomainId(domainId int64) (baseCipherId uint64) {
baseCipherId = uint64(1) + uint64(abs(domainId)) % maxCipherKeys
return
}
func getEncryptionKeysByKeyIds(w http.ResponseWriter, byteArr []byte) {
req := GetEncryptKeysRequest{}
err := json.Unmarshal(byteArr, &req)
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
var e error
if shouldInjectFault(PARSE_HTTP_REQUEST) {
e = fmt.Errorf("[FAULT] %s %s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
} else {
e = fmt.Errorf("%s %s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
var details []CipherDetailRes
for i := 0; i < len(req.CipherDetails); i++ {
var baseCipherId = uint64(req.CipherDetails[i].BaseCipherId)
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
if !isEncryptDomainIdValid(encryptDomainId) {
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
sendErrorResponse(w, e)
return
}
cipher, found := cipherMapInstance[baseCipherId]
if !found {
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
sendErrorResponse(w, e)
return
}
var detail = CipherDetailRes {
BaseCipherId: baseCipherId,
EncryptDomainId: encryptDomainId,
BaseCipher: string(cipher),
}
details = append(details, detail)
}
var urls []string
if req.RefreshKmsUrls {
urls = getKmsUrls()
}
resp := GetEncryptKeysResponse{
CipherDetails: details,
KmsUrls: urls,
}
mResp, err := json.Marshal(resp)
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
var e error
if shouldInjectFault(MARSHAL_RESPONSE) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
fmt.Fprintf(w, string(mResp))
}
func getEncryptionKeysByDomainIds(w http.ResponseWriter, byteArr []byte) {
req := GetEncryptKeysRequest{}
err := json.Unmarshal(byteArr, &req)
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
var e error
if shouldInjectFault(PARSE_HTTP_REQUEST) {
e = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
} else {
e = fmt.Errorf("%s '%s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
var details []CipherDetailRes
for i := 0; i < len(req.CipherDetails); i++ {
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
if !isEncryptDomainIdValid(encryptDomainId) {
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
sendErrorResponse(w, e)
return
}
var baseCipherId = getBaseCipherIdFromDomainId(encryptDomainId)
cipher, found := cipherMapInstance[baseCipherId]
if !found {
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
sendErrorResponse(w, e)
return
}
var detail = CipherDetailRes {
BaseCipherId: baseCipherId,
EncryptDomainId: encryptDomainId,
BaseCipher: string(cipher),
}
details = append(details, detail)
}
var urls []string
if req.RefreshKmsUrls {
urls = getKmsUrls()
}
resp := GetEncryptKeysResponse{
CipherDetails: details,
KmsUrls: urls,
}
mResp, err := json.Marshal(resp)
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
var e error
if shouldInjectFault(MARSHAL_RESPONSE) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
fmt.Fprintf(w, string(mResp))
}
func handleGetEncryptionKeys(w http.ResponseWriter, r *http.Request) {
byteArr, err := ioutil.ReadAll(r.Body)
if err != nil || shouldInjectFault(READ_HTTP_REQUEST_BODY) {
var e error
if shouldInjectFault(READ_HTTP_REQUEST_BODY) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(READ_HTTP_REQUEST_BODY))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(READ_HTTP_REQUEST_BODY), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
var arbitrary_json map[string]interface{}
err = json.Unmarshal(byteArr, &arbitrary_json)
if err != nil || shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
var e error
if shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
if shouldInjectFault(UNSUPPORTED_QUERY_MODE) {
err = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
sendErrorResponse(w, err)
return
} else if arbitrary_json["query_mode"] == "lookupByKeyId" {
getEncryptionKeysByKeyIds(w, byteArr)
} else if arbitrary_json["query_mode"] == "lookupByDomainId" {
getEncryptionKeysByDomainIds(w, byteArr)
} else {
err = fmt.Errorf("%s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
sendErrorResponse(w, err)
return
}
}
func initEncryptCipherMap() {
cipherMapInstance = NewCipherMap(maxCipherKeys, maxCipherSize)
}

View File

@ -0,0 +1,66 @@
/*
* mock_kms.go
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// FoundationDB Mock KMS (Key Management Solution/Service) interface
// Interface runs an HTTP server handling REST calls simulating FDB communications
// with an external KMS.
package main
import (
"log"
"math/rand"
"net/http"
"sync"
"time"
)
// KMS supported endpoints
const (
getEncryptionKeysEndpoint = "/getEncryptionKeys"
updateFaultInjectionEndpoint = "/updateFaultInjection"
)
// Routine is responsible to instantiate data-structures necessary for MockKMS functioning
func init () {
var wg sync.WaitGroup
wg.Add(2)
go func(){
initEncryptCipherMap()
wg.Done()
}()
go func(){
initFaultLocMap()
wg.Done()
}()
wg.Wait()
rand.Seed(time.Now().UTC().UnixNano())
}
func main() {
http.NewServeMux()
http.HandleFunc(getEncryptionKeysEndpoint, handleGetEncryptionKeys)
http.HandleFunc(updateFaultInjectionEndpoint, handleUpdateFaultInjection)
log.Fatal(http.ListenAndServe(":5001", nil))
}

View File

@ -0,0 +1,302 @@
/*
* mockkms_test.go
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// MockKMS unit tests, the coverage includes:
// 1. Mock HttpRequest creation and HttpResponse writer.
// 2. Construct fake request to validate the following scenarions:
// 2.1. Request with "unsupported query mode"
// 2.2. Get encryption keys by KeyIds; with and without 'RefreshKmsUrls' flag.
// 2.2. Get encryption keys by DomainIds; with and without 'RefreshKmsUrls' flag.
// 2.3. Random fault injection and response validation
package main
import (
"encoding/json"
"io/ioutil"
"math/rand"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
const (
ByKeyIdReqWithRefreshUrls = `{
"query_mode": "lookupByKeyId",
"cipher_key_details": [
{
"base_cipher_id": 77,
"encrypt_domain_id": 76
},
{
"base_cipher_id": 2,
"encrypt_domain_id": -1
}
],
"validation_tokens": [
{
"token_name": "1",
"token_value":"12344"
},
{
"token_name": "2",
"token_value":"12334"
}
],
"refresh_kms_urls": true
}`
ByKeyIdReqWithoutRefreshUrls = `{
"query_mode": "lookupByKeyId",
"cipher_key_details": [
{
"base_cipher_id": 77,
"encrypt_domain_id": 76
},
{
"base_cipher_id": 2,
"encrypt_domain_id": -1
}
],
"validation_tokens": [
{
"token_name": "1",
"token_value":"12344"
},
{
"token_name": "2",
"token_value":"12334"
}
],
"refresh_kms_urls": false
}`
ByDomainIdReqWithRefreshUrls = `{
"query_mode": "lookupByDomainId",
"cipher_key_details": [
{
"encrypt_domain_id": 76
},
{
"encrypt_domain_id": -1
}
],
"validation_tokens": [
{
"token_name": "1",
"token_value":"12344"
},
{
"token_name": "2",
"token_value":"12334"
}
],
"refresh_kms_urls": true
}`
ByDomainIdReqWithoutRefreshUrls = `{
"query_mode": "lookupByDomainId",
"cipher_key_details": [
{
"encrypt_domain_id": 76
},
{
"encrypt_domain_id": -1
}
],
"validation_tokens": [
{
"token_name": "1",
"token_value":"12344"
},
{
"token_name": "2",
"token_value":"12334"
}
],
"refresh_kms_urls": false
}`
UnsupportedQueryMode= `{
"query_mode": "foo_mode",
"cipher_key_details": [
{
"encrypt_domain_id": 76
},
{
"encrypt_domain_id": -1
}
],
"validation_tokens": [
{
"token_name": "1",
"token_value":"12344"
},
{
"token_name": "2",
"token_value":"12334"
}
],
"refresh_kms_urls": false
}`
)
func unmarshalValidResponse(data []byte, t *testing.T) (resp GetEncryptKeysResponse) {
resp = GetEncryptKeysResponse{}
err := json.Unmarshal(data, &resp)
if err != nil {
t.Errorf("Error unmarshaling valid response '%s' error '%v'", string(data), err)
t.Fail()
}
return
}
func unmarshalErrorResponse(data []byte, t *testing.T) (resp ErrorResponse) {
resp = ErrorResponse{}
err := json.Unmarshal(data, &resp)
if err != nil {
t.Errorf("Error unmarshaling error response resp '%s' error '%v'", string(data), err)
t.Fail()
}
return
}
func checkGetEncyptKeysResponseValidity(resp GetEncryptKeysResponse, t *testing.T) {
if len(resp.CipherDetails) != 2 {
t.Errorf("Unexpected CipherDetails count, expected '%d' actual '%d'", 2, len(resp.CipherDetails))
t.Fail()
}
baseCipherIds := [...]uint64 {uint64(77), uint64(2)}
encryptDomainIds := [...]int64 {int64(76), int64(-1)}
for i := 0; i < len(resp.CipherDetails); i++ {
if resp.CipherDetails[i].BaseCipherId != baseCipherIds[i] {
t.Errorf("Mismatch BaseCipherId, expected '%d' actual '%d'", baseCipherIds[i], resp.CipherDetails[i].BaseCipherId)
t.Fail()
}
if resp.CipherDetails[i].EncryptDomainId != encryptDomainIds[i] {
t.Errorf("Mismatch EncryptDomainId, expected '%d' actual '%d'", encryptDomainIds[i], resp.CipherDetails[i].EncryptDomainId)
t.Fail()
}
if len(resp.CipherDetails[i].BaseCipher) == 0 {
t.Error("Empty BaseCipher")
t.Fail()
}
}
}
func runQueryExpectingErrorResponse(payload string, url string, errSubStr string, t *testing.T) {
body := strings.NewReader(payload)
req := httptest.NewRequest(http.MethodPost, url, body)
w := httptest.NewRecorder()
handleGetEncryptionKeys(w, req)
res := w.Result()
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Error %v", err)
}
resp := unmarshalErrorResponse(data, t)
if !strings.Contains(resp.Err.Detail, errSubStr) {
t.Errorf("Unexpected error response '%s'", resp.Err.Detail)
t.Fail()
}
}
func runQueryExpectingValidResponse(payload string, url string, t *testing.T) {
body := strings.NewReader(payload)
req := httptest.NewRequest(http.MethodPost, url, body)
w := httptest.NewRecorder()
handleGetEncryptionKeys(w, req)
res := w.Result()
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Error %v", err)
}
resp := unmarshalValidResponse(data, t)
checkGetEncyptKeysResponseValidity(resp, t)
}
func TestUnsupportedQueryMode(t *testing.T) {
runQueryExpectingErrorResponse(UnsupportedQueryMode, getEncryptionKeysEndpoint, errStrMap()(UNSUPPORTED_QUERY_MODE), t)
}
func TestGetEncryptionKeysByKeyIdsWithRefreshUrls(t *testing.T) {
runQueryExpectingValidResponse(ByKeyIdReqWithRefreshUrls, getEncryptionKeysEndpoint, t)
}
func TestGetEncryptionKeysByKeyIdsWithoutRefreshUrls(t *testing.T) {
runQueryExpectingValidResponse(ByKeyIdReqWithoutRefreshUrls, getEncryptionKeysEndpoint, t)
}
func TestGetEncryptionKeysByDomainIdsWithRefreshUrls(t *testing.T) {
runQueryExpectingValidResponse(ByDomainIdReqWithRefreshUrls, getEncryptionKeysEndpoint, t)
}
func TestGetEncryptionKeysByDomainIdsWithoutRefreshUrls(t *testing.T) {
runQueryExpectingValidResponse(ByDomainIdReqWithoutRefreshUrls, getEncryptionKeysEndpoint, t)
}
func TestFaultInjection(t *testing.T) {
numIterations := rand.Intn(701) + 86
for i := 0; i < numIterations; i++ {
loc := rand.Intn(MARSHAL_RESPONSE + 1)
f := Fault{}
f.Location = loc
f.Enable = true
var faults []Fault
faults = append(faults, f)
fW := httptest.NewRecorder()
body := strings.NewReader(jsonifyFaultArr(fW, faults))
fReq := httptest.NewRequest(http.MethodPost, updateFaultInjectionEndpoint, body)
handleUpdateFaultInjection(fW, fReq)
if !shouldInjectFault(loc) {
t.Errorf("Expected fault enabled for loc '%d'", loc)
t.Fail()
}
var payload string
lottery := rand.Intn(100)
if lottery < 25 {
payload = ByKeyIdReqWithRefreshUrls
} else if lottery >= 25 && lottery < 50 {
payload = ByKeyIdReqWithoutRefreshUrls
} else if lottery >= 50 && lottery < 75 {
payload = ByDomainIdReqWithRefreshUrls
} else {
payload = ByDomainIdReqWithoutRefreshUrls
}
runQueryExpectingErrorResponse(payload, getEncryptionKeysEndpoint, errStrMap()(loc), t)
// reset Fault
faults[0].Enable = false
fW = httptest.NewRecorder()
body = strings.NewReader(jsonifyFaultArr(fW, faults))
fReq = httptest.NewRequest(http.MethodPost, updateFaultInjectionEndpoint, body)
handleUpdateFaultInjection(fW, fReq)
if shouldInjectFault(loc) {
t.Errorf("Expected fault disabled for loc '%d'", loc)
t.Fail()
}
}
}

View File

@ -0,0 +1,51 @@
/*
* utils.go
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
)
type ErrorDetail struct {
Detail string `json:"details"`
}
type ErrorResponse struct {
Err ErrorDetail `json:"error"`
}
func sendErrorResponse(w http.ResponseWriter, err error) {
e := ErrorDetail{}
e.Detail = fmt.Sprintf("Error: %s", err.Error())
resp := ErrorResponse{
Err: e,
}
mResp,err := json.Marshal(resp)
if err != nil {
log.Printf("Error marshalling error response %s", err.Error())
panic(err)
}
fmt.Fprintf(w, string(mResp))
}