Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

option to return id token in json output instead of redirect uri #2166

Merged
merged 1 commit into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions checkstyle-suppressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,10 @@
<suppress checks="MethodName"
files="OAuthConfig.java"
lines="1-9999"/>
<suppress checks="MemberName"
files="OIDCResponse.java"
lines="1-9999"/>
<suppress checks="MethodName"
files="OIDCResponse.java"
lines="1-9999"/>
</suppressions>
13 changes: 9 additions & 4 deletions clients/go/zts/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1156,17 +1156,22 @@ func (client ZTSClient) PostAccessTokenRequest(request AccessTokenRequest) (*Acc
}
}

func (client ZTSClient) GetOIDCResponse(responseType string, clientId ServiceName, redirectUri string, scope string, state EntityName, nonce EntityName, keyType SimpleName, fullArn *bool, expiryTime *int32) (*OIDCResponse, string, error) {
func (client ZTSClient) GetOIDCResponse(responseType string, clientId ServiceName, redirectUri string, scope string, state EntityName, nonce EntityName, keyType SimpleName, fullArn *bool, expiryTime *int32, output SimpleName) (*OIDCResponse, string, error) {
var data *OIDCResponse
url := client.URL + "/oauth2/auth" + encodeParams(encodeStringParam("response_type", string(responseType), ""), encodeStringParam("client_id", string(clientId), ""), encodeStringParam("redirect_uri", string(redirectUri), ""), encodeStringParam("scope", string(scope), ""), encodeStringParam("state", string(state), ""), encodeStringParam("nonce", string(nonce), ""), encodeStringParam("keyType", string(keyType), ""), encodeOptionalBoolParam("fullArn", fullArn), encodeOptionalInt32Param("expiryTime", expiryTime))
url := client.URL + "/oauth2/auth" + encodeParams(encodeStringParam("response_type", string(responseType), ""), encodeStringParam("client_id", string(clientId), ""), encodeStringParam("redirect_uri", string(redirectUri), ""), encodeStringParam("scope", string(scope), ""), encodeStringParam("state", string(state), ""), encodeStringParam("nonce", string(nonce), ""), encodeStringParam("keyType", string(keyType), ""), encodeOptionalBoolParam("fullArn", fullArn), encodeOptionalInt32Param("expiryTime", expiryTime), encodeStringParam("output", string(output), ""))
resp, err := client.httpGet(url, nil)
if err != nil {
return nil, "", err
}
defer resp.Body.Close()
switch resp.StatusCode {
case 302:
data = nil
case 200, 302:
if 302 != resp.StatusCode {
err = json.NewDecoder(resp.Body).Decode(&data)
if err != nil {
return nil, "", err
}
}
location := resp.Header.Get(rdl.FoldHttpHeaderName("Location"))
return data, location, nil
default:
Expand Down
42 changes: 37 additions & 5 deletions clients/go/zts/model.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions clients/go/zts/zts_schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion clients/java/zts/examples/tls-support/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@
</plugins>
</build>

</project>
</project>
124 changes: 49 additions & 75 deletions clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateParsingException;
import java.security.cert.X509Certificate;
Expand All @@ -46,7 +45,6 @@
import com.oath.auth.KeyRefresherException;
import com.oath.auth.KeyRefresherListener;
import com.oath.auth.Utils;
import com.yahoo.athenz.auth.token.IdToken;
import com.yahoo.athenz.auth.token.jwts.JwtsSigningKeyResolver;
import org.apache.http.HttpHost;
import org.apache.http.client.config.RequestConfig;
Expand Down Expand Up @@ -169,7 +167,7 @@ public class ZTSClient implements Closeable {
final static ConcurrentHashMap<String, RoleToken> ROLE_TOKEN_CACHE = new ConcurrentHashMap<>();
final static ConcurrentHashMap<String, AccessTokenResponseCacheEntry> ACCESS_TOKEN_CACHE = new ConcurrentHashMap<>();
final static ConcurrentHashMap<String, AWSTemporaryCredentials> AWS_CREDS_CACHE = new ConcurrentHashMap<>();
final static ConcurrentHashMap<String, IdTokenCacheEntry> ID_TOKEN_CACHE = new ConcurrentHashMap<>();
final static ConcurrentHashMap<String, OIDCResponse> ID_TOKEN_CACHE = new ConcurrentHashMap<>();

private static final Queue<PrefetchTokenScheduledItem> PREFETCH_SCHEDULED_ITEMS = new ConcurrentLinkedQueue<>();
private static Timer FETCH_TIMER;
Expand Down Expand Up @@ -1752,13 +1750,13 @@ static void processPrefetchTask(PrefetchTokenScheduledItem item, ZTSClient itemZ

case ID:

String idToken = itemZtsClient.getIDToken(item.responseType, item.idTokenServiceName,
OIDCResponse oidcResponse = itemZtsClient.getIDToken(item.responseType, item.idTokenServiceName,
item.redirectUri, item.scope, item.state, item.keyType, item.fullArn,
item.maxDuration, true);

if (!idToken.isEmpty()) {
item.setExpiresAtUTC(extractIdTokenExpiry(idToken));
}
// update the expiry time

item.setExpiresAtUTC(oidcResponse.getExpiration_time());
}

// update the fetch/fail times
Expand Down Expand Up @@ -1926,28 +1924,16 @@ public boolean prefetchAccessToken(String domainName, List<String> roleNames,
public boolean prefetchIdToken(String responseType, String clientId, String redirectUri, String scope,
String state, String keyType, Boolean fullArn, Integer expiryTime) {

String idToken = getIDToken(responseType, clientId, redirectUri, scope, state,
OIDCResponse oidcResponse = getIDToken(responseType, clientId, redirectUri, scope, state,
keyType, fullArn, expiryTime, true);
if (isEmpty(idToken)) {
if (oidcResponse == null) {
LOG.error("PrefetchToken: No id token fetchable for client id={} and scope={}", clientId, scope);
return false;
}

return prefetchToken(null, null, null, null, expiryTime, null, null, clientId, null,
responseType, redirectUri, scope, state, keyType, fullArn,
extractIdTokenExpiry(idToken), TokenType.ID);
}

protected static long extractIdTokenExpiry(final String idToken) {

// strip out the signature part
int idx = idToken.lastIndexOf('.');
if (idx == -1) {
return 0;
}

IdToken token = new IdToken(idToken.substring(0, idx + 1), (PublicKey) null);
return token.getExpiryTime();
oidcResponse.getExpiration_time(), TokenType.ID);
}

boolean prefetchToken(String domainName, String roleName, List<String> roleNames,
Expand Down Expand Up @@ -3094,7 +3080,7 @@ public Info getInfo() {
* server default timeout.
* @return ZTS generated ID Token String. ZTSClientException will be thrown in case of failure
*/
public String getIDToken(String domainName, String roleName, String clientId, String redirectUriSuffix,
public OIDCResponse getIDToken(String domainName, String roleName, String clientId, String redirectUriSuffix,
boolean fullArn, Integer expiryTime) {
return getIDToken(domainName, Collections.singletonList(roleName), clientId,
redirectUriSuffix, fullArn, expiryTime);
Expand All @@ -3116,7 +3102,7 @@ public String getIDToken(String domainName, String roleName, String clientId, St
* server default timeout.
* @return ZTS generated ID Token String. ZTSClientException will be thrown in case of failure
*/
public String getIDToken(String domainName, List<String> roleNames, String clientId, String redirectUriSuffix,
public OIDCResponse getIDToken(String domainName, List<String> roleNames, String clientId, String redirectUriSuffix,
boolean fullArn, Integer expiryTime) {
final String redirectUri = generateRedirectUri(clientId, redirectUriSuffix);
final String scope = generateIdTokenScope(domainName, roleNames);
Expand All @@ -3140,7 +3126,7 @@ public String getIDToken(String domainName, List<String> roleNames, String clien
* server default timeout.
* @return ZTS generated ID Token String. ZTSClientException will be thrown in case of failure
*/
public String getIDToken(String responseType, String clientId, String redirectUri, String scope, String state,
public OIDCResponse getIDToken(String responseType, String clientId, String redirectUri, String scope, String state,
String keyType, Boolean fullArn, Integer expiryTime, boolean ignoreCache) {

// check for required attributes
Expand All @@ -3149,7 +3135,7 @@ public String getIDToken(String responseType, String clientId, String redirectUr
throw new ZTSClientException(ResourceException.BAD_REQUEST, "missing required attribute(s)");
}

String idToken;
OIDCResponse oidcResponse = null;

// first lookup in our cache to see if it can be satisfied
// only if we're not asked to ignore the cache
Expand All @@ -3158,18 +3144,18 @@ public String getIDToken(String responseType, String clientId, String redirectUr
if (!cacheDisabled) {
cacheKey = getIdTokenCacheKey(responseType, clientId, redirectUri, scope, state, keyType, fullArn);
if (cacheKey != null && !ignoreCache) {
idToken = lookupIdTokenResponseInCache(cacheKey, expiryTime);
if (idToken != null) {
return idToken;
oidcResponse = lookupIdTokenResponseInCache(cacheKey, expiryTime);
if (oidcResponse != null) {
return oidcResponse;
}
// start prefetch for this token if prefetch is enabled
if (enablePrefetch && prefetchAutoEnable) {
if (prefetchIdToken(responseType, clientId, redirectUri, scope,
state, keyType, fullArn, expiryTime)) {
idToken = lookupIdTokenResponseInCache(cacheKey, expiryTime);
oidcResponse = lookupIdTokenResponseInCache(cacheKey, expiryTime);
}
if (idToken != null) {
return idToken;
if (oidcResponse != null) {
return oidcResponse;
}
LOG.error("GetIdToken: cache prefetch and lookup error");
}
Expand All @@ -3181,10 +3167,8 @@ public String getIDToken(String responseType, String clientId, String redirectUr
updateServicePrincipal();
try {
Map<String, List<String>> responseHeaders = new HashMap<>();
ztsClient.getOIDCResponse(responseType, clientId, redirectUri, scope,
state, Crypto.randomSalt(), keyType, fullArn, expiryTime, responseHeaders);

idToken = extractIdTokenFromLocation(responseHeaders, redirectUri, state);
oidcResponse = ztsClient.getOIDCResponse(responseType, clientId, redirectUri, scope,
state, Crypto.randomSalt(), keyType, fullArn, expiryTime, "json", responseHeaders);

} catch (ResourceException ex) {

Expand All @@ -3193,9 +3177,9 @@ public String getIDToken(String responseType, String clientId, String redirectUr
// if we have an entry in our cache then we'll return that
// instead of returning failure

idToken = lookupIdTokenResponseInCache(cacheKey, -1);
if (idToken != null) {
return idToken;
oidcResponse = lookupIdTokenResponseInCache(cacheKey, -1);
if (oidcResponse != null) {
return oidcResponse;
}
}
throw new ZTSClientException(ex.getCode(), ex.getData());
Expand All @@ -3206,9 +3190,9 @@ public String getIDToken(String responseType, String clientId, String redirectUr
// instead of returning failure

if (cacheKey != null && !ignoreCache) {
idToken = lookupIdTokenResponseInCache(cacheKey, -1);
if (idToken != null) {
return idToken;
oidcResponse = lookupIdTokenResponseInCache(cacheKey, -1);
if (oidcResponse != null) {
return oidcResponse;
}
}
throw new ZTSClientException(ResourceException.BAD_REQUEST, ex.getMessage());
Expand All @@ -3223,65 +3207,55 @@ public String getIDToken(String responseType, String clientId, String redirectUr
keyType, fullArn);
}
if (cacheKey != null) {
ID_TOKEN_CACHE.put(cacheKey, new IdTokenCacheEntry(idToken, extractIdTokenExpiry(idToken)));
ID_TOKEN_CACHE.put(cacheKey, oidcResponse);
}
}

return idToken;
return oidcResponse;
}

String lookupIdTokenResponseInCache(String cacheKey, Integer expiryTime) {
OIDCResponse lookupIdTokenResponseInCache(String cacheKey, Integer expirySeconds) {

IdTokenCacheEntry idTokenCacheEntry = ID_TOKEN_CACHE.get(cacheKey);
if (idTokenCacheEntry == null) {
OIDCResponse oidcResponse = ID_TOKEN_CACHE.get(cacheKey);
if (oidcResponse == null) {
if (LOG.isInfoEnabled()) {
LOG.info("LookupIdTokenResponseInCache: cache-lookup key: {} result: not found", cacheKey);
}
return null;
}

long now = System.currentTimeMillis() / 1000;
long tokenExpiryTime = oidcResponse.getExpiration_time();

// default timeout for id tokens is 1 hour

if (expiryTime == null) {
expiryTime = 60 * 60;
if (expirySeconds == null) {
expirySeconds = 60 * 60;
}

// if our expiry seconds is -1 then we should return
// our cached object as long as it's not expired

if (expirySeconds == -1 && tokenExpiryTime > now) {
return oidcResponse;
}

// before returning our cache hit we need to make sure it
// was at least 1/4th time left before the token expires
// if the expiryTime is -1 then we return the token as
// long as it's not expired

if (idTokenCacheEntry.isExpired(expiryTime)) {
if (idTokenCacheEntry.isExpired(-1)) {
if (tokenExpiryTime < now + expirySeconds / 4) {

// if the token is completely expired then we'll remove it from the cache

if (tokenExpiryTime <= now) {
ID_TOKEN_CACHE.remove(cacheKey);
}
return null;
}

return idTokenCacheEntry.getIdToken();
}

String extractIdTokenFromLocation(Map<String, List<String>> responseHeaders, final String redirectUri,
final String state) {

//the format of the location header is <redirect-uri>#id_token=<token>&state=<state>
List<String> locationValues = responseHeaders.get("location");
if (isEmpty(locationValues)) {
return "";
}
final String location = locationValues.get(0);
final String prefix = redirectUri + "#id_token=";
if (!location.startsWith(prefix)) {
return "";
}
if (isEmpty(state)) {
return location.substring(prefix.length());
}
final String suffix = "&state=" + state;
if (!location.endsWith(suffix)) {
return "";
}
return location.substring(prefix.length(), prefix.length() + location.length() - prefix.length() - suffix.length());
return oidcResponse;
}

String generateIdTokenScope(final String domainName, List<String> roleNames) {
Expand Down
Loading