Move JWT "kid" field from claims to header

This commit is contained in:
Junhyun Shim 2022-07-13 20:21:27 +02:00
parent 24317aa6be
commit 11a9fe9aff
3 changed files with 36 additions and 44 deletions

View File

@ -161,12 +161,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime,
TEST(true); // Token can't be parsed
return false;
}
if (!t.keyId.present()) {
TEST(true); // Token with no key id
TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "NoKeyID");
return false;
}
auto key = FlowTransport::transport().getPublicKeyByName(t.keyId.get());
auto key = FlowTransport::transport().getPublicKeyByName(t.keyId);
if (!key.present()) {
TEST(true); // Token referencing non-existing key
TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "UnknownKey");
@ -258,10 +253,6 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
[](Arena&, IRandom&, authz::jwt::TokenRef&) { FlowTransport::transport().removeAllPublicKeys(); },
"NoKeyWithSuchName",
},
{
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.keyId.reset(); },
"NoKeyId",
},
{
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.expiresAtUnixTime.reset(); },
"NoExpirationTime",
@ -282,10 +273,6 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
},
"TokenNotYetValid",
},
{
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.keyId.reset(); },
"UnknownKey",
},
{
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); },
"NoTenants",

View File

@ -185,14 +185,13 @@ void appendField(fmt::memory_buffer& b, char const (&name)[NameLen], Optional<Fi
StringRef TokenRef::toStringRef(Arena& arena) {
auto buf = fmt::memory_buffer();
fmt::format_to(std::back_inserter(buf), "alg={}", getAlgorithmName(algorithm));
fmt::format_to(std::back_inserter(buf), "alg={} kid={}", getAlgorithmName(algorithm), keyId.toStringView());
appendField(buf, "iss", issuer);
appendField(buf, "sub", subject);
appendField(buf, "aud", audience);
appendField(buf, "iat", issuedAtUnixTime);
appendField(buf, "exp", expiresAtUnixTime);
appendField(buf, "nbf", notBeforeUnixTime);
appendField(buf, "kid", keyId);
appendField(buf, "jti", tokenId);
appendField(buf, "tenants", tenants);
auto str = new (arena) uint8_t[buf.size()];
@ -231,9 +230,12 @@ StringRef makeTokenPart(Arena& arena, TokenRef tokenSpec) {
header.StartObject();
header.Key("typ");
header.String("JWT");
header.Key("alg");
auto algo = getAlgorithmName(tokenSpec.algorithm);
header.Key("alg");
header.String(algo.data(), algo.size());
auto kid = tokenSpec.keyId.toStringView();
header.Key("kid");
header.String(kid.begin(), kid.size());
header.EndObject();
payload.StartObject();
putField(tokenSpec.issuer, payload, "iss");
@ -242,7 +244,6 @@ StringRef makeTokenPart(Arena& arena, TokenRef tokenSpec) {
putField(tokenSpec.issuedAtUnixTime, payload, "iat");
putField(tokenSpec.expiresAtUnixTime, payload, "exp");
putField(tokenSpec.notBeforeUnixTime, payload, "nbf");
putField(tokenSpec.keyId, payload, "kid");
putField(tokenSpec.tokenId, payload, "jti");
putField(tokenSpec.tenants, payload, "tenants");
payload.EndObject();
@ -279,7 +280,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey) {
return StringRef(out, totalLen);
}
bool parseHeaderPart(TokenRef& token, StringRef b64urlHeader) {
bool parseHeaderPart(Arena& arena, TokenRef& token, StringRef b64urlHeader) {
auto tmpArena = Arena();
auto optHeader = base64url::decode(tmpArena, b64urlHeader);
if (!optHeader.present())
@ -295,24 +296,30 @@ bool parseHeaderPart(TokenRef& token, StringRef b64urlHeader) {
.detail("Offset", d.GetErrorOffset());
return false;
}
auto algItr = d.FindMember("alg");
if (!d.IsObject())
return false;
auto typItr = d.FindMember("typ");
if (d.IsObject() && algItr != d.MemberEnd() && typItr != d.MemberEnd()) {
auto const& alg = algItr->value;
auto const& typ = typItr->value;
if (alg.IsString() && typ.IsString()) {
auto algValue = StringRef(reinterpret_cast<const uint8_t*>(alg.GetString()), alg.GetStringLength());
auto algType = algorithmFromString(algValue);
if (algType == Algorithm::UNKNOWN)
return false;
token.algorithm = algType;
auto typValue = StringRef(reinterpret_cast<const uint8_t*>(typ.GetString()), typ.GetStringLength());
if (typValue != "JWT"_sr)
return false;
return true;
}
}
return false;
if (typItr == d.MemberEnd() || !typItr->value.IsString())
return false;
auto algItr = d.FindMember("alg");
if (algItr == d.MemberEnd() || !algItr->value.IsString())
return false;
auto kidItr = d.FindMember("kid");
if (kidItr == d.MemberEnd() || !kidItr->value.IsString())
return false;
auto const& typ = typItr->value;
auto const& alg = algItr->value;
auto const& kid = kidItr->value;
auto typValue = StringRef(reinterpret_cast<const uint8_t*>(typ.GetString()), typ.GetStringLength());
if (typValue != "JWT"_sr)
return false;
auto algValue = StringRef(reinterpret_cast<const uint8_t*>(alg.GetString()), alg.GetStringLength());
auto algType = algorithmFromString(algValue);
if (algType == Algorithm::UNKNOWN)
return false;
token.algorithm = algType;
token.keyId = StringRef(arena, reinterpret_cast<const uint8_t*>(kid.GetString()), kid.GetStringLength());
return true;
}
template <class FieldType>
@ -382,8 +389,6 @@ bool parsePayloadPart(Arena& arena, TokenRef& token, StringRef b64urlPayload) {
return false;
if (!parseField(arena, token.notBeforeUnixTime, d, "nbf"))
return false;
if (!parseField(arena, token.keyId, d, "kid"))
return false;
if (!parseField(arena, token.tenants, d, "tenants"))
return false;
return true;
@ -409,7 +414,7 @@ bool parseToken(Arena& arena, TokenRef& token, StringRef signedToken) {
auto b64urlSignature = signedToken;
if (b64urlHeader.empty() || b64urlPayload.empty() || b64urlSignature.empty())
return false;
if (!parseHeaderPart(token, b64urlHeader))
if (!parseHeaderPart(arena, token, b64urlHeader))
return false;
if (!parsePayloadPart(arena, token, b64urlPayload))
return false;
@ -432,7 +437,7 @@ bool verifyToken(StringRef signedToken, PublicKey publicKey) {
return false;
auto sig = optSig.get();
auto parsedToken = TokenRef();
if (!parseHeaderPart(parsedToken, b64urlHeader))
if (!parseHeaderPart(arena, parsedToken, b64urlHeader))
return false;
auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm);
if (!checkVerifyAlgorithm(verifyAlgo, publicKey))
@ -446,6 +451,7 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
}
auto ret = TokenRef{};
ret.algorithm = alg;
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
ret.issuer = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1);
ret.subject = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1);
ret.tokenId = genRandomAlphanumStringRef(arena, rng, 31);
@ -457,7 +463,6 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
ret.issuedAtUnixTime = timer_int() / 1'000'000'000ul;
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
auto numTenants = rng.randomInt(1, 3);
auto tenants = new (arena) StringRef[numTenants];
for (auto i = 0; i < numTenants; i++)
@ -553,7 +558,7 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
auto arena = Arena();
auto tokenStr = t.toStringRef(arena);
auto tokenStrExpected =
"alg=ES256 iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 kid=keyId jti=tokenId tenants=[tenant1,tenant2]"_sr;
"alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[tenant1,tenant2]"_sr;
if (tokenStr != tokenStrExpected) {
fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView());
ASSERT(false);

View File

@ -85,6 +85,7 @@ namespace authz::jwt {
struct TokenRef {
// header part ("typ": "JWT" implicitly enforced)
Algorithm algorithm; // alg
StringRef keyId; // kid
// payload part
Optional<StringRef> issuer; // iss
Optional<StringRef> subject; // sub
@ -92,7 +93,6 @@ struct TokenRef {
Optional<uint64_t> issuedAtUnixTime; // iat
Optional<uint64_t> expiresAtUnixTime; // exp
Optional<uint64_t> notBeforeUnixTime; // nbf
Optional<StringRef> keyId; // kid
Optional<StringRef> tokenId; // jti
Optional<VectorRef<StringRef>> tenants; // tenants
// signature part
@ -113,7 +113,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey);
// Parse passed b64url-encoded header part and materialize its contents into tokenOut,
// using memory allocated from arena
bool parseHeaderPart(TokenRef& tokenOut, StringRef b64urlHeaderIn);
bool parseHeaderPart(Arena& arena, TokenRef& tokenOut, StringRef b64urlHeaderIn);
// Parse passed b64url-encoded payload part and materialize its contents into tokenOut,
// using memory allocated from arena