foundationdb/contrib/mockkms/get_encryption_keys.go

321 lines
10 KiB
Go

/*
* get_encryption_keys.go
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// GetEncryptionKeys handler
// Handler is resposible for the following:
// 1. Parse the incoming HttpRequest and validate JSON request structural sanity
// 2. Ability to handle getEncryptionKeys by 'KeyId' or 'DomainId' as requested
// 3. Ability to inject faults if requested
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net/http"
)
type CipherDetailRes struct {
BaseCipherId uint64 `json:"base_cipher_id"`
EncryptDomainId int64 `json:"encrypt_domain_id"`
BaseCipher string `json:"base_cipher"`
}
type ValidationToken struct {
TokenName string `json:"token_name"`
TokenValue string `json:"token_value"`
}
type CipherDetailReq struct {
BaseCipherId uint64 `json:"base_cipher_id"`
EncryptDomainId int64 `json:"encrypt_domain_id"`
}
type GetEncryptKeysResponse struct {
CipherDetails []CipherDetailRes `json:"cipher_key_details"`
KmsUrls []string `json:"kms_urls"`
}
type GetEncryptKeysRequest struct {
QueryMode string `json:"query_mode"`
CipherDetails []CipherDetailReq `json:"cipher_key_details"`
ValidationTokens []ValidationToken `json:"validation_tokens"`
RefreshKmsUrls bool `json:"refresh_kms_urls"`
}
type cipherMapInstanceSingleton map[uint64][]byte
const (
READ_HTTP_REQUEST_BODY = iota
UNMARSHAL_REQUEST_BODY_JSON
UNSUPPORTED_QUERY_MODE
PARSE_HTTP_REQUEST
MARSHAL_RESPONSE
)
const (
maxCipherKeys = uint64(1024*1024) // Max cipher keys
maxCipherSize = 16 // Max cipher buffer size
)
var (
cipherMapInstance cipherMapInstanceSingleton // Singleton mapping of { baseCipherId -> baseCipher }
)
// const mapping of { Location -> errorString }
func errStrMap() func(int) string {
_errStrMap := map[int]string {
READ_HTTP_REQUEST_BODY : "Http request body read error",
UNMARSHAL_REQUEST_BODY_JSON : "Http request body unmarshal error",
UNSUPPORTED_QUERY_MODE : "Unsupported query_mode",
PARSE_HTTP_REQUEST : "Error parsing GetEncryptionKeys request",
MARSHAL_RESPONSE : "Error marshaling response",
}
return func(key int) string {
return _errStrMap[key]
}
}
// Caller is responsible for thread synchronization. Recommended to be invoked during package::init()
func NewCipherMap(maxKeys uint64, cipherSize int) cipherMapInstanceSingleton {
if cipherMapInstance == nil {
cipherMapInstance = make(map[uint64][]byte)
for i := uint64(1); i<= maxKeys; i++ {
cipher := make([]byte, cipherSize)
rand.Read(cipher)
cipherMapInstance[i] = cipher
}
log.Printf("KMS cipher map populate done, maxCiphers '%d'", maxCipherKeys)
}
return cipherMapInstance
}
func getKmsUrls() (urls []string) {
urlCount := rand.Intn(5) + 1
for i := 1; i <= urlCount; i++ {
url := fmt.Sprintf("https://KMS/%d:%d:%d:%d", i, i, i, i)
urls = append(urls, url)
}
return
}
func isEncryptDomainIdValid(id int64) bool {
if id > 0 || id == -1 || id == -2 {
return true
}
return false
}
func abs(x int64) int64 {
if x < 0 {
return -x
}
return x
}
func getBaseCipherIdFromDomainId(domainId int64) (baseCipherId uint64) {
baseCipherId = uint64(1) + uint64(abs(domainId)) % maxCipherKeys
return
}
func getEncryptionKeysByKeyIds(w http.ResponseWriter, byteArr []byte) {
req := GetEncryptKeysRequest{}
err := json.Unmarshal(byteArr, &req)
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
var e error
if shouldInjectFault(PARSE_HTTP_REQUEST) {
e = fmt.Errorf("[FAULT] %s %s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
} else {
e = fmt.Errorf("%s %s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
var details []CipherDetailRes
for i := 0; i < len(req.CipherDetails); i++ {
var baseCipherId = uint64(req.CipherDetails[i].BaseCipherId)
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
if !isEncryptDomainIdValid(encryptDomainId) {
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
sendErrorResponse(w, e)
return
}
cipher, found := cipherMapInstance[baseCipherId]
if !found {
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
sendErrorResponse(w, e)
return
}
var detail = CipherDetailRes {
BaseCipherId: baseCipherId,
EncryptDomainId: encryptDomainId,
BaseCipher: string(cipher),
}
details = append(details, detail)
}
var urls []string
if req.RefreshKmsUrls {
urls = getKmsUrls()
}
resp := GetEncryptKeysResponse{
CipherDetails: details,
KmsUrls: urls,
}
mResp, err := json.Marshal(resp)
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
var e error
if shouldInjectFault(MARSHAL_RESPONSE) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
fmt.Fprintf(w, string(mResp))
}
func getEncryptionKeysByDomainIds(w http.ResponseWriter, byteArr []byte) {
req := GetEncryptKeysRequest{}
err := json.Unmarshal(byteArr, &req)
if err != nil || shouldInjectFault(PARSE_HTTP_REQUEST) {
var e error
if shouldInjectFault(PARSE_HTTP_REQUEST) {
e = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
} else {
e = fmt.Errorf("%s '%s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
var details []CipherDetailRes
for i := 0; i < len(req.CipherDetails); i++ {
var encryptDomainId = int64(req.CipherDetails[i].EncryptDomainId)
if !isEncryptDomainIdValid(encryptDomainId) {
e := fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId)
sendErrorResponse(w, e)
return
}
var baseCipherId = getBaseCipherIdFromDomainId(encryptDomainId)
cipher, found := cipherMapInstance[baseCipherId]
if !found {
e := fmt.Errorf("BaseCipherId not found '%d'", baseCipherId)
sendErrorResponse(w, e)
return
}
var detail = CipherDetailRes {
BaseCipherId: baseCipherId,
EncryptDomainId: encryptDomainId,
BaseCipher: string(cipher),
}
details = append(details, detail)
}
var urls []string
if req.RefreshKmsUrls {
urls = getKmsUrls()
}
resp := GetEncryptKeysResponse{
CipherDetails: details,
KmsUrls: urls,
}
mResp, err := json.Marshal(resp)
if err != nil || shouldInjectFault(MARSHAL_RESPONSE) {
var e error
if shouldInjectFault(MARSHAL_RESPONSE) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
fmt.Fprintf(w, string(mResp))
}
func handleGetEncryptionKeys(w http.ResponseWriter, r *http.Request) {
byteArr, err := ioutil.ReadAll(r.Body)
if err != nil || shouldInjectFault(READ_HTTP_REQUEST_BODY) {
var e error
if shouldInjectFault(READ_HTTP_REQUEST_BODY) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(READ_HTTP_REQUEST_BODY))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(READ_HTTP_REQUEST_BODY), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
var arbitrary_json map[string]interface{}
err = json.Unmarshal(byteArr, &arbitrary_json)
if err != nil || shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
var e error
if shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
e = fmt.Errorf("[FAULT] %s", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON))
} else {
e = fmt.Errorf("%s err '%v'", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON), err)
}
log.Println(e.Error())
sendErrorResponse(w, e)
return
}
if shouldInjectFault(UNSUPPORTED_QUERY_MODE) {
err = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
sendErrorResponse(w, err)
return
} else if arbitrary_json["query_mode"] == "lookupByKeyId" {
getEncryptionKeysByKeyIds(w, byteArr)
} else if arbitrary_json["query_mode"] == "lookupByDomainId" {
getEncryptionKeysByDomainIds(w, byteArr)
} else {
err = fmt.Errorf("%s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
sendErrorResponse(w, err)
return
}
}
func initEncryptCipherMap() {
cipherMapInstance = NewCipherMap(maxCipherKeys, maxCipherSize)
}