Move JWT "kid" field from claims to header

This commit is contained in:
Junhyun Shim 2022-07-13 20:21:27 +02:00
parent 24317aa6be
commit 11a9fe9aff
3 changed files with 36 additions and 44 deletions

View File

@ -161,12 +161,7 @@ bool TokenCacheImpl::validateAndAdd(double currentTime,
TEST(true); // Token can't be parsed TEST(true); // Token can't be parsed
return false; return false;
} }
if (!t.keyId.present()) { auto key = FlowTransport::transport().getPublicKeyByName(t.keyId);
TEST(true); // Token with no key id
TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "NoKeyID");
return false;
}
auto key = FlowTransport::transport().getPublicKeyByName(t.keyId.get());
if (!key.present()) { if (!key.present()) {
TEST(true); // Token referencing non-existing key TEST(true); // Token referencing non-existing key
TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "UnknownKey"); TraceEvent(SevWarn, "InvalidToken").detail("From", peer).detail("Reason", "UnknownKey");
@ -258,10 +253,6 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
[](Arena&, IRandom&, authz::jwt::TokenRef&) { FlowTransport::transport().removeAllPublicKeys(); }, [](Arena&, IRandom&, authz::jwt::TokenRef&) { FlowTransport::transport().removeAllPublicKeys(); },
"NoKeyWithSuchName", "NoKeyWithSuchName",
}, },
{
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.keyId.reset(); },
"NoKeyId",
},
{ {
[](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.expiresAtUnixTime.reset(); }, [](Arena&, IRandom&, authz::jwt::TokenRef& token) { token.expiresAtUnixTime.reset(); },
"NoExpirationTime", "NoExpirationTime",
@ -282,10 +273,6 @@ TEST_CASE("/fdbrpc/authz/TokenCache/BadTokens") {
}, },
"TokenNotYetValid", "TokenNotYetValid",
}, },
{
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.keyId.reset(); },
"UnknownKey",
},
{ {
[](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); }, [](Arena& arena, IRandom&, authz::jwt::TokenRef& token) { token.tenants.reset(); },
"NoTenants", "NoTenants",

View File

@ -185,14 +185,13 @@ void appendField(fmt::memory_buffer& b, char const (&name)[NameLen], Optional<Fi
StringRef TokenRef::toStringRef(Arena& arena) { StringRef TokenRef::toStringRef(Arena& arena) {
auto buf = fmt::memory_buffer(); auto buf = fmt::memory_buffer();
fmt::format_to(std::back_inserter(buf), "alg={}", getAlgorithmName(algorithm)); fmt::format_to(std::back_inserter(buf), "alg={} kid={}", getAlgorithmName(algorithm), keyId.toStringView());
appendField(buf, "iss", issuer); appendField(buf, "iss", issuer);
appendField(buf, "sub", subject); appendField(buf, "sub", subject);
appendField(buf, "aud", audience); appendField(buf, "aud", audience);
appendField(buf, "iat", issuedAtUnixTime); appendField(buf, "iat", issuedAtUnixTime);
appendField(buf, "exp", expiresAtUnixTime); appendField(buf, "exp", expiresAtUnixTime);
appendField(buf, "nbf", notBeforeUnixTime); appendField(buf, "nbf", notBeforeUnixTime);
appendField(buf, "kid", keyId);
appendField(buf, "jti", tokenId); appendField(buf, "jti", tokenId);
appendField(buf, "tenants", tenants); appendField(buf, "tenants", tenants);
auto str = new (arena) uint8_t[buf.size()]; auto str = new (arena) uint8_t[buf.size()];
@ -231,9 +230,12 @@ StringRef makeTokenPart(Arena& arena, TokenRef tokenSpec) {
header.StartObject(); header.StartObject();
header.Key("typ"); header.Key("typ");
header.String("JWT"); header.String("JWT");
header.Key("alg");
auto algo = getAlgorithmName(tokenSpec.algorithm); auto algo = getAlgorithmName(tokenSpec.algorithm);
header.Key("alg");
header.String(algo.data(), algo.size()); header.String(algo.data(), algo.size());
auto kid = tokenSpec.keyId.toStringView();
header.Key("kid");
header.String(kid.begin(), kid.size());
header.EndObject(); header.EndObject();
payload.StartObject(); payload.StartObject();
putField(tokenSpec.issuer, payload, "iss"); putField(tokenSpec.issuer, payload, "iss");
@ -242,7 +244,6 @@ StringRef makeTokenPart(Arena& arena, TokenRef tokenSpec) {
putField(tokenSpec.issuedAtUnixTime, payload, "iat"); putField(tokenSpec.issuedAtUnixTime, payload, "iat");
putField(tokenSpec.expiresAtUnixTime, payload, "exp"); putField(tokenSpec.expiresAtUnixTime, payload, "exp");
putField(tokenSpec.notBeforeUnixTime, payload, "nbf"); putField(tokenSpec.notBeforeUnixTime, payload, "nbf");
putField(tokenSpec.keyId, payload, "kid");
putField(tokenSpec.tokenId, payload, "jti"); putField(tokenSpec.tokenId, payload, "jti");
putField(tokenSpec.tenants, payload, "tenants"); putField(tokenSpec.tenants, payload, "tenants");
payload.EndObject(); payload.EndObject();
@ -279,7 +280,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey) {
return StringRef(out, totalLen); return StringRef(out, totalLen);
} }
bool parseHeaderPart(TokenRef& token, StringRef b64urlHeader) { bool parseHeaderPart(Arena& arena, TokenRef& token, StringRef b64urlHeader) {
auto tmpArena = Arena(); auto tmpArena = Arena();
auto optHeader = base64url::decode(tmpArena, b64urlHeader); auto optHeader = base64url::decode(tmpArena, b64urlHeader);
if (!optHeader.present()) if (!optHeader.present())
@ -295,24 +296,30 @@ bool parseHeaderPart(TokenRef& token, StringRef b64urlHeader) {
.detail("Offset", d.GetErrorOffset()); .detail("Offset", d.GetErrorOffset());
return false; return false;
} }
auto algItr = d.FindMember("alg"); if (!d.IsObject())
return false;
auto typItr = d.FindMember("typ"); auto typItr = d.FindMember("typ");
if (d.IsObject() && algItr != d.MemberEnd() && typItr != d.MemberEnd()) { if (typItr == d.MemberEnd() || !typItr->value.IsString())
auto const& alg = algItr->value; return false;
auto const& typ = typItr->value; auto algItr = d.FindMember("alg");
if (alg.IsString() && typ.IsString()) { if (algItr == d.MemberEnd() || !algItr->value.IsString())
auto algValue = StringRef(reinterpret_cast<const uint8_t*>(alg.GetString()), alg.GetStringLength()); return false;
auto algType = algorithmFromString(algValue); auto kidItr = d.FindMember("kid");
if (algType == Algorithm::UNKNOWN) if (kidItr == d.MemberEnd() || !kidItr->value.IsString())
return false; return false;
token.algorithm = algType; auto const& typ = typItr->value;
auto typValue = StringRef(reinterpret_cast<const uint8_t*>(typ.GetString()), typ.GetStringLength()); auto const& alg = algItr->value;
if (typValue != "JWT"_sr) auto const& kid = kidItr->value;
return false; auto typValue = StringRef(reinterpret_cast<const uint8_t*>(typ.GetString()), typ.GetStringLength());
return true; if (typValue != "JWT"_sr)
} return false;
} auto algValue = StringRef(reinterpret_cast<const uint8_t*>(alg.GetString()), alg.GetStringLength());
return false; auto algType = algorithmFromString(algValue);
if (algType == Algorithm::UNKNOWN)
return false;
token.algorithm = algType;
token.keyId = StringRef(arena, reinterpret_cast<const uint8_t*>(kid.GetString()), kid.GetStringLength());
return true;
} }
template <class FieldType> template <class FieldType>
@ -382,8 +389,6 @@ bool parsePayloadPart(Arena& arena, TokenRef& token, StringRef b64urlPayload) {
return false; return false;
if (!parseField(arena, token.notBeforeUnixTime, d, "nbf")) if (!parseField(arena, token.notBeforeUnixTime, d, "nbf"))
return false; return false;
if (!parseField(arena, token.keyId, d, "kid"))
return false;
if (!parseField(arena, token.tenants, d, "tenants")) if (!parseField(arena, token.tenants, d, "tenants"))
return false; return false;
return true; return true;
@ -409,7 +414,7 @@ bool parseToken(Arena& arena, TokenRef& token, StringRef signedToken) {
auto b64urlSignature = signedToken; auto b64urlSignature = signedToken;
if (b64urlHeader.empty() || b64urlPayload.empty() || b64urlSignature.empty()) if (b64urlHeader.empty() || b64urlPayload.empty() || b64urlSignature.empty())
return false; return false;
if (!parseHeaderPart(token, b64urlHeader)) if (!parseHeaderPart(arena, token, b64urlHeader))
return false; return false;
if (!parsePayloadPart(arena, token, b64urlPayload)) if (!parsePayloadPart(arena, token, b64urlPayload))
return false; return false;
@ -432,7 +437,7 @@ bool verifyToken(StringRef signedToken, PublicKey publicKey) {
return false; return false;
auto sig = optSig.get(); auto sig = optSig.get();
auto parsedToken = TokenRef(); auto parsedToken = TokenRef();
if (!parseHeaderPart(parsedToken, b64urlHeader)) if (!parseHeaderPart(arena, parsedToken, b64urlHeader))
return false; return false;
auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm); auto [verifyAlgo, digest] = getMethod(parsedToken.algorithm);
if (!checkVerifyAlgorithm(verifyAlgo, publicKey)) if (!checkVerifyAlgorithm(verifyAlgo, publicKey))
@ -446,6 +451,7 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
} }
auto ret = TokenRef{}; auto ret = TokenRef{};
ret.algorithm = alg; ret.algorithm = alg;
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
ret.issuer = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1); ret.issuer = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1);
ret.subject = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1); ret.subject = genRandomAlphanumStringRef(arena, rng, MaxIssuerNameLenPlus1);
ret.tokenId = genRandomAlphanumStringRef(arena, rng, 31); ret.tokenId = genRandomAlphanumStringRef(arena, rng, 31);
@ -457,7 +463,6 @@ TokenRef makeRandomTokenSpec(Arena& arena, IRandom& rng, Algorithm alg) {
ret.issuedAtUnixTime = timer_int() / 1'000'000'000ul; ret.issuedAtUnixTime = timer_int() / 1'000'000'000ul;
ret.notBeforeUnixTime = ret.issuedAtUnixTime.get(); ret.notBeforeUnixTime = ret.issuedAtUnixTime.get();
ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1); ret.expiresAtUnixTime = ret.issuedAtUnixTime.get() + rng.randomInt(360, 1080 + 1);
ret.keyId = genRandomAlphanumStringRef(arena, rng, MaxKeyNameLenPlus1);
auto numTenants = rng.randomInt(1, 3); auto numTenants = rng.randomInt(1, 3);
auto tenants = new (arena) StringRef[numTenants]; auto tenants = new (arena) StringRef[numTenants];
for (auto i = 0; i < numTenants; i++) for (auto i = 0; i < numTenants; i++)
@ -553,7 +558,7 @@ TEST_CASE("/fdbrpc/TokenSign/JWT/ToStringRef") {
auto arena = Arena(); auto arena = Arena();
auto tokenStr = t.toStringRef(arena); auto tokenStr = t.toStringRef(arena);
auto tokenStrExpected = auto tokenStrExpected =
"alg=ES256 iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 kid=keyId jti=tokenId tenants=[tenant1,tenant2]"_sr; "alg=ES256 kid=keyId iss=issuer sub=subject aud=[aud1,aud2,aud3] iat=123 exp=456 nbf=789 jti=tokenId tenants=[tenant1,tenant2]"_sr;
if (tokenStr != tokenStrExpected) { if (tokenStr != tokenStrExpected) {
fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView()); fmt::print("Expected: {}\nGot : {}\n", tokenStrExpected.toStringView(), tokenStr.toStringView());
ASSERT(false); ASSERT(false);

View File

@ -85,6 +85,7 @@ namespace authz::jwt {
struct TokenRef { struct TokenRef {
// header part ("typ": "JWT" implicitly enforced) // header part ("typ": "JWT" implicitly enforced)
Algorithm algorithm; // alg Algorithm algorithm; // alg
StringRef keyId; // kid
// payload part // payload part
Optional<StringRef> issuer; // iss Optional<StringRef> issuer; // iss
Optional<StringRef> subject; // sub Optional<StringRef> subject; // sub
@ -92,7 +93,6 @@ struct TokenRef {
Optional<uint64_t> issuedAtUnixTime; // iat Optional<uint64_t> issuedAtUnixTime; // iat
Optional<uint64_t> expiresAtUnixTime; // exp Optional<uint64_t> expiresAtUnixTime; // exp
Optional<uint64_t> notBeforeUnixTime; // nbf Optional<uint64_t> notBeforeUnixTime; // nbf
Optional<StringRef> keyId; // kid
Optional<StringRef> tokenId; // jti Optional<StringRef> tokenId; // jti
Optional<VectorRef<StringRef>> tenants; // tenants Optional<VectorRef<StringRef>> tenants; // tenants
// signature part // signature part
@ -113,7 +113,7 @@ StringRef signToken(Arena& arena, TokenRef tokenSpec, PrivateKey privateKey);
// Parse passed b64url-encoded header part and materialize its contents into tokenOut, // Parse passed b64url-encoded header part and materialize its contents into tokenOut,
// using memory allocated from arena // using memory allocated from arena
bool parseHeaderPart(TokenRef& tokenOut, StringRef b64urlHeaderIn); bool parseHeaderPart(Arena& arena, TokenRef& tokenOut, StringRef b64urlHeaderIn);
// Parse passed b64url-encoded payload part and materialize its contents into tokenOut, // Parse passed b64url-encoded payload part and materialize its contents into tokenOut,
// using memory allocated from arena // using memory allocated from arena