Merge pull request #7106 from sfc-gh-ahusain/ahusain-fdb-mock-kms
FDB native MockKMS REST server implementation
This commit is contained in:
commit
524365083d
|
@ -0,0 +1,179 @@
|
||||||
|
/*
|
||||||
|
* fault_injection.go
|
||||||
|
*
|
||||||
|
* This source file is part of the FoundationDB open source project
|
||||||
|
*
|
||||||
|
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Interface supports client to inject fault(s)
|
||||||
|
// Module enables a client to update { FaultLocation -> FaultStatus } mapping in a
|
||||||
|
// thread-safe manner, however, client is responsible to synchronize fault status
|
||||||
|
// updates across 'getEncryptionKeys' REST requests to obtain predictable results.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Fault struct {
|
||||||
|
Location int `json:"fault_location"`
|
||||||
|
Enable bool `json:"enable_fault"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FaultInjectionRequest struct {
|
||||||
|
Faults []Fault `json:"faults"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FaultInjectionResponse struct {
|
||||||
|
Faults []Fault `json:"faults"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type faultLocMap struct {
|
||||||
|
locMap map[int]bool
|
||||||
|
rwLock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
faultLocMapInstance *faultLocMap // Singleton mapping of { FaultLocation -> FaultStatus }
|
||||||
|
)
|
||||||
|
|
||||||
|
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
|
||||||
|
func NewFaultLocMap() *faultLocMap {
|
||||||
|
if faultLocMapInstance == nil {
|
||||||
|
faultLocMapInstance = &faultLocMap{}
|
||||||
|
|
||||||
|
faultLocMapInstance.rwLock = sync.RWMutex{}
|
||||||
|
faultLocMapInstance.locMap = map[int]bool {
|
||||||
|
READ_HTTP_REQUEST_BODY : false,
|
||||||
|
UNMARSHAL_REQUEST_BODY_JSON : false,
|
||||||
|
UNSUPPORTED_QUERY_MODE : false,
|
||||||
|
PARSE_HTTP_REQUEST : false,
|
||||||
|
MARSHAL_RESPONSE : false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return faultLocMapInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLocFaultStatus(loc int) (val bool, found bool) {
|
||||||
|
if faultLocMapInstance == nil {
|
||||||
|
panic("FaultLocMap not intialized")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
faultLocMapInstance.rwLock.RLock()
|
||||||
|
defer faultLocMapInstance.rwLock.RUnlock()
|
||||||
|
val, found = faultLocMapInstance.locMap[loc]
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateLocFaultStatuses(faults []Fault) (updated []Fault, err error) {
|
||||||
|
if faultLocMapInstance == nil {
|
||||||
|
panic("FaultLocMap not intialized")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated = []Fault{}
|
||||||
|
err = nil
|
||||||
|
|
||||||
|
faultLocMapInstance.rwLock.Lock()
|
||||||
|
defer faultLocMapInstance.rwLock.Unlock()
|
||||||
|
for i := 0; i < len(faults); i++ {
|
||||||
|
fault := faults[i]
|
||||||
|
|
||||||
|
oldVal, found := faultLocMapInstance.locMap[fault.Location]
|
||||||
|
if !found {
|
||||||
|
err = fmt.Errorf("Unknown fault_location '%d'", fault.Location)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
faultLocMapInstance.locMap[fault.Location] = fault.Enable
|
||||||
|
log.Printf("Update Location '%d' oldVal '%t' newVal '%t'", fault.Location, oldVal, fault.Enable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the updated faultLocMap
|
||||||
|
for loc, enable := range faultLocMapInstance.locMap {
|
||||||
|
var f Fault
|
||||||
|
f.Location = loc
|
||||||
|
f.Enable = enable
|
||||||
|
updated = append(updated, f)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func jsonifyFaultArr(w http.ResponseWriter, faults []Fault) (jResp string) {
|
||||||
|
resp := FaultInjectionResponse{
|
||||||
|
Faults: faults,
|
||||||
|
}
|
||||||
|
|
||||||
|
mResp, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error marshaling response '%s'", err.Error())
|
||||||
|
sendErrorResponse(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jResp = string(mResp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateFaultLocMap(w http.ResponseWriter, faults []Fault) {
|
||||||
|
updated , err := updateLocFaultStatuses(faults)
|
||||||
|
if err != nil {
|
||||||
|
sendErrorResponse(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, jsonifyFaultArr(w, updated))
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldInjectFault(loc int) bool {
|
||||||
|
status, found := getLocFaultStatus(loc)
|
||||||
|
if !found {
|
||||||
|
log.Printf("Unknown fault_location '%d'", loc)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleUpdateFaultInjection(w http.ResponseWriter, r *http.Request) {
|
||||||
|
byteArr, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Http request body read error '%s'", err.Error())
|
||||||
|
sendErrorResponse(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := FaultInjectionRequest{}
|
||||||
|
err = json.Unmarshal(byteArr, &req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error parsing FaultInjectionRequest '%s'", string(byteArr))
|
||||||
|
sendErrorResponse(w, err)
|
||||||
|
}
|
||||||
|
updateFaultLocMap(w, req.Faults)
|
||||||
|
}
|
||||||
|
|
||||||
|
func initFaultLocMap() {
|
||||||
|
faultLocMapInstance = NewFaultLocMap()
|
||||||
|
log.Printf("FaultLocMap int done")
|
||||||
|
}
|
|
@ -0,0 +1,321 @@
|
||||||
|
/*
|
||||||
|
* get_encryption_keys.go
|
||||||
|
*
|
||||||
|
* This source file is part of the FoundationDB open source project
|
||||||
|
*
|
||||||
|
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// GetEncryptionKeys handler
|
||||||
|
// Handler is resposible for the following:
|
||||||
|
// 1. Parse the incoming HttpRequest and validate JSON request structural sanity
|
||||||
|
// 2. Ability to handle getEncryptionKeys by 'KeyId' or 'DomainId' as requested
|
||||||
|
// 3. Ability to inject faults if requested
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CipherDetailRes struct {
|
||||||
|
BaseCipherId uint64 `json:"base_cipher_id"`
|
||||||
|
EncryptDomainId int64 `json:"encrypt_domain_id"`
|
||||||
|
BaseCipher string `json:"base_cipher"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ValidationToken struct {
|
||||||
|
TokenName string `json:"token_name"`
|
||||||
|
TokenValue string `json:"token_value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CipherDetailReq struct {
|
||||||
|
BaseCipherId uint64 `json:"base_cipher_id"`
|
||||||
|
EncryptDomainId int64 `json:"encrypt_domain_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetEncryptKeysResponse struct {
|
||||||
|
CipherDetails []CipherDetailRes `json:"cipher_key_details"`
|
||||||
|
KmsUrls []string `json:"kms_urls"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetEncryptKeysRequest struct {
|
||||||
|
QueryMode string `json:"query_mode"`
|
||||||
|
CipherDetails []CipherDetailReq `json:"cipher_key_details"`
|
||||||
|
ValidationTokens []ValidationToken `json:"validation_tokens"`
|
||||||
|
RefreshKmsUrls bool `json:"refresh_kms_urls"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type cipherMapInstanceSingleton map[uint64][]byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
READ_HTTP_REQUEST_BODY = iota
|
||||||
|
UNMARSHAL_REQUEST_BODY_JSON
|
||||||
|
UNSUPPORTED_QUERY_MODE
|
||||||
|
PARSE_HTTP_REQUEST
|
||||||
|
MARSHAL_RESPONSE
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxCipherKeys = uint64(1024*1024) // Max cipher keys
|
||||||
|
maxCipherSize = 16 // Max cipher buffer size
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cipherMapInstance cipherMapInstanceSingleton // Singleton mapping of { baseCipherId -> baseCipher }
|
||||||
|
)
|
||||||
|
|
||||||
|
// const mapping of { Location -> errorString }
|
||||||
|
func errStrMap() func(int) string {
|
||||||
|
_errStrMap := map[int]string {
|
||||||
|
READ_HTTP_REQUEST_BODY : "Http request body read error",
|
||||||
|
UNMARSHAL_REQUEST_BODY_JSON : "Http request body unmarshal error",
|
||||||
|
UNSUPPORTED_QUERY_MODE : "Unsupported query_mode",
|
||||||
|
PARSE_HTTP_REQUEST : "Error parsing GetEncryptionKeys request",
|
||||||
|
MARSHAL_RESPONSE : "Error marshaling response",
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(key int) string {
|
||||||
|
return _errStrMap[key]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
|
||||||
|
func NewCipherMap(maxKeys uint64, cipherSize int) cipherMapInstanceSingleton {
|
||||||
|
if cipherMapInstance == nil {
|
||||||
|
cipherMapInstance = make(map[uint64][]byte)
|
||||||
|
|
||||||
|
for i := uint64(1); i<= maxKeys; i++ {
|
||||||
|
cipher := make([]byte, cipherSize)
|
||||||
|
rand.Read(cipher)
|
||||||
|
cipherMapInstance[i] = cipher
|
||||||
|
}
|
||||||
|
log.Printf("KMS cipher map populate done, maxCiphers '%d'", maxCipherKeys)
|
||||||
|
}
|
||||||
|
return cipherMapInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKmsUrls() (urls []string) {
|
||||||
|
urlCount := rand.Intn(5) + 1
|
||||||
|
for i := 1; i <= urlCount; i++ {
|
||||||
|
url := fmt.Sprintf("https://KMS/%d:%d:%d:%d", i, i, i, i)
|
||||||
|
urls = append(urls, url)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func isEncryptDomainIdValid(id int64) bool {
|
||||||
|
if id > 0 || id == -1 || id == -2 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func abs(x int64) int64 {
|
||||||
|
if x < 0 {
|
||||||
|
return -x
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBaseCipherIdFromDomainId(domainId int64) (baseCipherId uint64) {
|
||||||
|
baseCipherId = uint64(1) + uint64(abs(domainId)) % maxCipherKeys
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEncryptionKeysByKeyIds(w http.ResponseWriter, byteArr []byte) {
|
||||||
|
req := GetEncryptKeysRequest{}
|
||||||
|
err := json.Unmarshal(byteArr, &req)
|
||||||
|
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||||
|
var e error
|
||||||
|
if shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||||
|
e = fmt.Errorf("[FAULT] %s %s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
|
||||||
|
} else {
|
||||||
|
e = fmt.Errorf("%s %s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
|
||||||
|
}
|
||||||
|
log.Println(e.Error())
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var details []CipherDetailRes
|
||||||
|
for i := 0; i < len(req.CipherDetails); i++ {
|
||||||
|
var baseCipherId = uint64(req.CipherDetails[i].BaseCipherId)
|
||||||
|
|
||||||
|
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
|
||||||
|
if !isEncryptDomainIdValid(encryptDomainId) {
|
||||||
|
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cipher, found := cipherMapInstance[baseCipherId]
|
||||||
|
if !found {
|
||||||
|
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var detail = CipherDetailRes {
|
||||||
|
BaseCipherId: baseCipherId,
|
||||||
|
EncryptDomainId: encryptDomainId,
|
||||||
|
BaseCipher: string(cipher),
|
||||||
|
}
|
||||||
|
details = append(details, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
var urls []string
|
||||||
|
if req.RefreshKmsUrls {
|
||||||
|
urls = getKmsUrls()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := GetEncryptKeysResponse{
|
||||||
|
CipherDetails: details,
|
||||||
|
KmsUrls: urls,
|
||||||
|
}
|
||||||
|
|
||||||
|
mResp, err := json.Marshal(resp)
|
||||||
|
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
|
||||||
|
var e error
|
||||||
|
if shouldInjectFault(MARSHAL_RESPONSE) {
|
||||||
|
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
|
||||||
|
} else {
|
||||||
|
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
|
||||||
|
}
|
||||||
|
log.Println(e.Error())
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, string(mResp))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEncryptionKeysByDomainIds(w http.ResponseWriter, byteArr []byte) {
|
||||||
|
req := GetEncryptKeysRequest{}
|
||||||
|
err := json.Unmarshal(byteArr, &req)
|
||||||
|
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||||
|
var e error
|
||||||
|
if shouldInjectFault(PARSE_HTTP_REQUEST) {
|
||||||
|
e = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
|
||||||
|
} else {
|
||||||
|
e = fmt.Errorf("%s '%s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
|
||||||
|
}
|
||||||
|
log.Println(e.Error())
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var details []CipherDetailRes
|
||||||
|
for i := 0; i < len(req.CipherDetails); i++ {
|
||||||
|
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
|
||||||
|
if !isEncryptDomainIdValid(encryptDomainId) {
|
||||||
|
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var baseCipherId = getBaseCipherIdFromDomainId(encryptDomainId)
|
||||||
|
cipher, found := cipherMapInstance[baseCipherId]
|
||||||
|
if !found {
|
||||||
|
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var detail = CipherDetailRes {
|
||||||
|
BaseCipherId: baseCipherId,
|
||||||
|
EncryptDomainId: encryptDomainId,
|
||||||
|
BaseCipher: string(cipher),
|
||||||
|
}
|
||||||
|
details = append(details, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
var urls []string
|
||||||
|
if req.RefreshKmsUrls {
|
||||||
|
urls = getKmsUrls()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := GetEncryptKeysResponse{
|
||||||
|
CipherDetails: details,
|
||||||
|
KmsUrls: urls,
|
||||||
|
}
|
||||||
|
|
||||||
|
mResp, err := json.Marshal(resp)
|
||||||
|
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
|
||||||
|
var e error
|
||||||
|
if shouldInjectFault(MARSHAL_RESPONSE) {
|
||||||
|
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
|
||||||
|
} else {
|
||||||
|
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
|
||||||
|
}
|
||||||
|
log.Println(e.Error())
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, string(mResp))
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleGetEncryptionKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
|
byteArr, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil || shouldInjectFault(READ_HTTP_REQUEST_BODY) {
|
||||||
|
var e error
|
||||||
|
if shouldInjectFault(READ_HTTP_REQUEST_BODY) {
|
||||||
|
e = fmt.Errorf("[FAULT] %s", errStrMap()(READ_HTTP_REQUEST_BODY))
|
||||||
|
} else {
|
||||||
|
e = fmt.Errorf("%s err '%v'", errStrMap()(READ_HTTP_REQUEST_BODY), err)
|
||||||
|
}
|
||||||
|
log.Println(e.Error())
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var arbitrary_json map[string]interface{}
|
||||||
|
err = json.Unmarshal(byteArr, &arbitrary_json)
|
||||||
|
if err != nil || shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
|
||||||
|
var e error
|
||||||
|
if shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
|
||||||
|
e = fmt.Errorf("[FAULT] %s", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON))
|
||||||
|
} else {
|
||||||
|
e = fmt.Errorf("%s err '%v'", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON), err)
|
||||||
|
}
|
||||||
|
log.Println(e.Error())
|
||||||
|
sendErrorResponse(w, e)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldInjectFault(UNSUPPORTED_QUERY_MODE) {
|
||||||
|
err = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
|
||||||
|
sendErrorResponse(w, err)
|
||||||
|
return
|
||||||
|
} else if arbitrary_json["query_mode"] == "lookupByKeyId" {
|
||||||
|
getEncryptionKeysByKeyIds(w, byteArr)
|
||||||
|
} else if arbitrary_json["query_mode"] == "lookupByDomainId" {
|
||||||
|
getEncryptionKeysByDomainIds(w, byteArr)
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("%s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
|
||||||
|
sendErrorResponse(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initEncryptCipherMap() {
|
||||||
|
cipherMapInstance = NewCipherMap(maxCipherKeys, maxCipherSize)
|
||||||
|
}
|
|
@ -0,0 +1,66 @@
|
||||||
|
/*
|
||||||
|
* mock_kms.go
|
||||||
|
*
|
||||||
|
* This source file is part of the FoundationDB open source project
|
||||||
|
*
|
||||||
|
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// FoundationDB Mock KMS (Key Management Solution/Service) interface
|
||||||
|
// Interface runs an HTTP server handling REST calls simulating FDB communications
|
||||||
|
// with an external KMS.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KMS supported endpoints
|
||||||
|
const (
|
||||||
|
getEncryptionKeysEndpoint = "/getEncryptionKeys"
|
||||||
|
updateFaultInjectionEndpoint = "/updateFaultInjection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Routine is responsible to instantiate data-structures necessary for MockKMS functioning
|
||||||
|
func init () {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(2)
|
||||||
|
go func(){
|
||||||
|
initEncryptCipherMap()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func(){
|
||||||
|
initFaultLocMap()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
rand.Seed(time.Now().UTC().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
http.NewServeMux()
|
||||||
|
http.HandleFunc(getEncryptionKeysEndpoint, handleGetEncryptionKeys)
|
||||||
|
http.HandleFunc(updateFaultInjectionEndpoint, handleUpdateFaultInjection)
|
||||||
|
|
||||||
|
log.Fatal(http.ListenAndServe(":5001", nil))
|
||||||
|
}
|
|
@ -0,0 +1,302 @@
|
||||||
|
/*
|
||||||
|
* mockkms_test.go
|
||||||
|
*
|
||||||
|
* This source file is part of the FoundationDB open source project
|
||||||
|
*
|
||||||
|
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// MockKMS unit tests, the coverage includes:
|
||||||
|
// 1. Mock HttpRequest creation and HttpResponse writer.
|
||||||
|
// 2. Construct fake request to validate the following scenarions:
|
||||||
|
// 2.1. Request with "unsupported query mode"
|
||||||
|
// 2.2. Get encryption keys by KeyIds; with and without 'RefreshKmsUrls' flag.
|
||||||
|
// 2.2. Get encryption keys by DomainIds; with and without 'RefreshKmsUrls' flag.
|
||||||
|
// 2.3. Random fault injection and response validation
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ByKeyIdReqWithRefreshUrls = `{
|
||||||
|
"query_mode": "lookupByKeyId",
|
||||||
|
"cipher_key_details": [
|
||||||
|
{
|
||||||
|
"base_cipher_id": 77,
|
||||||
|
"encrypt_domain_id": 76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"base_cipher_id": 2,
|
||||||
|
"encrypt_domain_id": -1
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"validation_tokens": [
|
||||||
|
{
|
||||||
|
"token_name": "1",
|
||||||
|
"token_value":"12344"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_name": "2",
|
||||||
|
"token_value":"12334"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"refresh_kms_urls": true
|
||||||
|
}`
|
||||||
|
ByKeyIdReqWithoutRefreshUrls = `{
|
||||||
|
"query_mode": "lookupByKeyId",
|
||||||
|
"cipher_key_details": [
|
||||||
|
{
|
||||||
|
"base_cipher_id": 77,
|
||||||
|
"encrypt_domain_id": 76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"base_cipher_id": 2,
|
||||||
|
"encrypt_domain_id": -1
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"validation_tokens": [
|
||||||
|
{
|
||||||
|
"token_name": "1",
|
||||||
|
"token_value":"12344"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_name": "2",
|
||||||
|
"token_value":"12334"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"refresh_kms_urls": false
|
||||||
|
}`
|
||||||
|
ByDomainIdReqWithRefreshUrls = `{
|
||||||
|
"query_mode": "lookupByDomainId",
|
||||||
|
"cipher_key_details": [
|
||||||
|
{
|
||||||
|
"encrypt_domain_id": 76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"encrypt_domain_id": -1
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"validation_tokens": [
|
||||||
|
{
|
||||||
|
"token_name": "1",
|
||||||
|
"token_value":"12344"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_name": "2",
|
||||||
|
"token_value":"12334"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"refresh_kms_urls": true
|
||||||
|
}`
|
||||||
|
ByDomainIdReqWithoutRefreshUrls = `{
|
||||||
|
"query_mode": "lookupByDomainId",
|
||||||
|
"cipher_key_details": [
|
||||||
|
{
|
||||||
|
"encrypt_domain_id": 76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"encrypt_domain_id": -1
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"validation_tokens": [
|
||||||
|
{
|
||||||
|
"token_name": "1",
|
||||||
|
"token_value":"12344"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_name": "2",
|
||||||
|
"token_value":"12334"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"refresh_kms_urls": false
|
||||||
|
}`
|
||||||
|
UnsupportedQueryMode= `{
|
||||||
|
"query_mode": "foo_mode",
|
||||||
|
"cipher_key_details": [
|
||||||
|
{
|
||||||
|
"encrypt_domain_id": 76
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"encrypt_domain_id": -1
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"validation_tokens": [
|
||||||
|
{
|
||||||
|
"token_name": "1",
|
||||||
|
"token_value":"12344"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"token_name": "2",
|
||||||
|
"token_value":"12334"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"refresh_kms_urls": false
|
||||||
|
}`
|
||||||
|
)
|
||||||
|
|
||||||
|
func unmarshalValidResponse(data []byte, t *testing.T) (resp GetEncryptKeysResponse) {
|
||||||
|
resp = GetEncryptKeysResponse{}
|
||||||
|
err := json.Unmarshal(data, &resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error unmarshaling valid response '%s' error '%v'", string(data), err)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalErrorResponse(data []byte, t *testing.T) (resp ErrorResponse) {
|
||||||
|
resp = ErrorResponse{}
|
||||||
|
err := json.Unmarshal(data, &resp)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error unmarshaling error response resp '%s' error '%v'", string(data), err)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkGetEncyptKeysResponseValidity(resp GetEncryptKeysResponse, t *testing.T) {
|
||||||
|
if len(resp.CipherDetails) != 2 {
|
||||||
|
t.Errorf("Unexpected CipherDetails count, expected '%d' actual '%d'", 2, len(resp.CipherDetails))
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
baseCipherIds := [...]uint64 {uint64(77), uint64(2)}
|
||||||
|
encryptDomainIds := [...]int64 {int64(76), int64(-1)}
|
||||||
|
|
||||||
|
for i := 0; i < len(resp.CipherDetails); i++ {
|
||||||
|
if resp.CipherDetails[i].BaseCipherId != baseCipherIds[i] {
|
||||||
|
t.Errorf("Mismatch BaseCipherId, expected '%d' actual '%d'", baseCipherIds[i], resp.CipherDetails[i].BaseCipherId)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
if resp.CipherDetails[i].EncryptDomainId != encryptDomainIds[i] {
|
||||||
|
t.Errorf("Mismatch EncryptDomainId, expected '%d' actual '%d'", encryptDomainIds[i], resp.CipherDetails[i].EncryptDomainId)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
if len(resp.CipherDetails[i].BaseCipher) == 0 {
|
||||||
|
t.Error("Empty BaseCipher")
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runQueryExpectingErrorResponse(payload string, url string, errSubStr string, t *testing.T) {
|
||||||
|
body := strings.NewReader(payload)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, url, body)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handleGetEncryptionKeys(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
data, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := unmarshalErrorResponse(data, t)
|
||||||
|
if !strings.Contains(resp.Err.Detail, errSubStr) {
|
||||||
|
t.Errorf("Unexpected error response '%s'", resp.Err.Detail)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runQueryExpectingValidResponse(payload string, url string, t *testing.T) {
|
||||||
|
body := strings.NewReader(payload)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, url, body)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handleGetEncryptionKeys(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
data, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := unmarshalValidResponse(data, t)
|
||||||
|
checkGetEncyptKeysResponseValidity(resp, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsupportedQueryMode(t *testing.T) {
|
||||||
|
runQueryExpectingErrorResponse(UnsupportedQueryMode, getEncryptionKeysEndpoint, errStrMap()(UNSUPPORTED_QUERY_MODE), t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEncryptionKeysByKeyIdsWithRefreshUrls(t *testing.T) {
|
||||||
|
runQueryExpectingValidResponse(ByKeyIdReqWithRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEncryptionKeysByKeyIdsWithoutRefreshUrls(t *testing.T) {
|
||||||
|
runQueryExpectingValidResponse(ByKeyIdReqWithoutRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEncryptionKeysByDomainIdsWithRefreshUrls(t *testing.T) {
|
||||||
|
runQueryExpectingValidResponse(ByDomainIdReqWithRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEncryptionKeysByDomainIdsWithoutRefreshUrls(t *testing.T) {
|
||||||
|
runQueryExpectingValidResponse(ByDomainIdReqWithoutRefreshUrls, getEncryptionKeysEndpoint, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFaultInjection(t *testing.T) {
|
||||||
|
numIterations := rand.Intn(701) + 86
|
||||||
|
|
||||||
|
for i := 0; i < numIterations; i++ {
|
||||||
|
loc := rand.Intn(MARSHAL_RESPONSE + 1)
|
||||||
|
f := Fault{}
|
||||||
|
f.Location = loc
|
||||||
|
f.Enable = true
|
||||||
|
|
||||||
|
var faults []Fault
|
||||||
|
faults = append(faults, f)
|
||||||
|
fW := httptest.NewRecorder()
|
||||||
|
body := strings.NewReader(jsonifyFaultArr(fW, faults))
|
||||||
|
fReq := httptest.NewRequest(http.MethodPost, updateFaultInjectionEndpoint, body)
|
||||||
|
handleUpdateFaultInjection(fW, fReq)
|
||||||
|
if !shouldInjectFault(loc) {
|
||||||
|
t.Errorf("Expected fault enabled for loc '%d'", loc)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload string
|
||||||
|
lottery := rand.Intn(100)
|
||||||
|
if lottery < 25 {
|
||||||
|
payload = ByKeyIdReqWithRefreshUrls
|
||||||
|
} else if lottery >= 25 && lottery < 50 {
|
||||||
|
payload = ByKeyIdReqWithoutRefreshUrls
|
||||||
|
} else if lottery >= 50 && lottery < 75 {
|
||||||
|
payload = ByDomainIdReqWithRefreshUrls
|
||||||
|
} else {
|
||||||
|
payload = ByDomainIdReqWithoutRefreshUrls
|
||||||
|
}
|
||||||
|
runQueryExpectingErrorResponse(payload, getEncryptionKeysEndpoint, errStrMap()(loc), t)
|
||||||
|
|
||||||
|
// reset Fault
|
||||||
|
faults[0].Enable = false
|
||||||
|
fW = httptest.NewRecorder()
|
||||||
|
body = strings.NewReader(jsonifyFaultArr(fW, faults))
|
||||||
|
fReq = httptest.NewRequest(http.MethodPost, updateFaultInjectionEndpoint, body)
|
||||||
|
handleUpdateFaultInjection(fW, fReq)
|
||||||
|
if shouldInjectFault(loc) {
|
||||||
|
t.Errorf("Expected fault disabled for loc '%d'", loc)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*
|
||||||
|
* utils.go
|
||||||
|
*
|
||||||
|
* This source file is part of the FoundationDB open source project
|
||||||
|
*
|
||||||
|
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ErrorDetail struct {
|
||||||
|
Detail string `json:"details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Err ErrorDetail `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendErrorResponse(w http.ResponseWriter, err error) {
|
||||||
|
e := ErrorDetail{}
|
||||||
|
e.Detail = fmt.Sprintf("Error: %s", err.Error())
|
||||||
|
resp := ErrorResponse{
|
||||||
|
Err: e,
|
||||||
|
}
|
||||||
|
|
||||||
|
mResp,err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error marshalling error response %s", err.Error())
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, string(mResp))
|
||||||
|
}
|
Loading…
Reference in New Issue