Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add get private key per algorithm support for aws s3 keystore #1756

Merged
merged 1 commit into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectInputStream;
import com.yahoo.athenz.auth.PrivateKeyStore;
import com.yahoo.athenz.auth.ServerPrivateKey;
import com.yahoo.athenz.auth.util.Crypto;
import org.eclipse.jetty.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -79,33 +81,57 @@ public AwsPrivateKeyStore() {
}

private static AWSKMS initAWSKMS() {
String s3Region = System.getProperty(ATHENZ_PROP_AWS_KMS_REGION);
///CLOVER:OFF
if (null != s3Region && !s3Region.isEmpty()) {
return AWSKMSClientBuilder.standard().withRegion(s3Region).build();
}
return AWSKMSClientBuilder.defaultClient();
///CLOVER:ON
final String kmsRegion = System.getProperty(ATHENZ_PROP_AWS_KMS_REGION);
return StringUtil.isEmpty(kmsRegion) ? AWSKMSClientBuilder.defaultClient() : AWSKMSClientBuilder.standard().withRegion(kmsRegion).build();
}

private static AmazonS3 initAmazonS3() {
String s3Region = System.getProperty(ATHENZ_PROP_AWS_S3_REGION);
///CLOVER:OFF
if (null != s3Region && !s3Region.isEmpty()) {
return AmazonS3ClientBuilder.standard().withRegion(s3Region).build();
}
return AmazonS3ClientBuilder.defaultClient();
///CLOVER:ON
final String s3Region = System.getProperty(ATHENZ_PROP_AWS_S3_REGION);
return StringUtil.isEmpty(s3Region) ? AmazonS3ClientBuilder.defaultClient() : AmazonS3ClientBuilder.standard().withRegion(s3Region).build();
}

public AwsPrivateKeyStore(final AmazonS3 s3, final AWSKMS kms) {
this.s3 = s3;
this.kms = kms;
}


@Override
public ServerPrivateKey getPrivateKey(String service, String serverHostName, String serverRegion, String algorithm) {

final String bucketName;
String keyName;
String keyIdName;

final String objectSuffix = "." + algorithm.toLowerCase();
if (ZMS_SERVICE.equals(service)) {
bucketName = System.getProperty(ATHENZ_PROP_ZMS_BUCKET_NAME);
keyName = System.getProperty(ATHENZ_PROP_ZMS_KEY_NAME, ATHENZ_DEFAULT_KEY_NAME) + objectSuffix;
keyIdName = System.getProperty(ATHENZ_PROP_ZMS_KEY_ID_NAME, ATHENZ_DEFAULT_KEY_ID_NAME) + objectSuffix;
} else if (ZTS_SERVICE.equals(service)) {
bucketName = System.getProperty(ATHENZ_PROP_ZTS_BUCKET_NAME);
keyName = System.getProperty(ATHENZ_PROP_ZTS_KEY_NAME, ATHENZ_DEFAULT_KEY_NAME) + objectSuffix;
keyIdName = System.getProperty(ATHENZ_PROP_ZTS_KEY_ID_NAME, ATHENZ_DEFAULT_KEY_ID_NAME) + objectSuffix;
} else {
LOG.error("Unknown service specified: {}", service);
return null;
}

if (bucketName == null) {
LOG.error("No bucket name specified with system property");
return null;
}

PrivateKey pkey = null;
try {
pkey = Crypto.loadPrivateKey(getDecryptedData(bucketName, keyName));
} catch (Exception ex) {
LOG.error("unable to load private key", ex);
}
return pkey == null ? null : new ServerPrivateKey(pkey, getDecryptedData(bucketName, keyIdName));
}

@Override
public PrivateKey getPrivateKey(String service, String serverHostName,
StringBuilder privateKeyId) {
public PrivateKey getPrivateKey(String service, String serverHostName, StringBuilder privateKeyId) {

String bucketName;
String keyName;
Expand Down Expand Up @@ -158,11 +184,9 @@ private String getDecryptedData(final String bucketName, final String keyName) {

byte[] buffer = new byte[1024];
int length;
///CLOVER:OFF
while ((length = s3InputStream.read(buffer)) != -1) {
result.write(buffer, 0, length);
}
///CLOVER:ON
// if key should be decrypted, do so with KMS

if (kmsDecrypt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,28 @@
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectInputStream;
import com.yahoo.athenz.auth.ServerPrivateKey;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;

import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.*;

public class AwsPrivateKeyStoreTest {

private static final String ATHENZ_PROP_ZTS_BUCKET_NAME = "athenz.aws.zts.bucket_name";
private static final String ATHENZ_AWS_KMS_REGION = "athenz.aws.store_kms_region";

@Test
public void testAwsPrivateKeyStore() {
System.setProperty("athenz.aws.s3.region", "us-east-1");
Expand Down Expand Up @@ -67,7 +71,7 @@ public void testAwsPrivateKeyStore() {
String actual = awsPrivateKeyStore.getApplicationSecret(bucketName, keyName);
StringBuilder privateKeyId = new StringBuilder(keyName);
awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", privateKeyId);
Assert.assertEquals(actual, expected);
assertEquals(actual, expected);
Mockito.when(s3Object.getObjectContent()).thenAnswer(invocation -> { throw new IOException("test IOException"); });
awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", privateKeyId);

Expand All @@ -82,9 +86,7 @@ public void testGetPrivateKey() {
AwsPrivateKeyStoreFactory awsPrivateKeyStoreFactory = new AwsPrivateKeyStoreFactory();
assertTrue(awsPrivateKeyStoreFactory.create() instanceof AwsPrivateKeyStore);


AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore();
awsPrivateKeyStore = new AwsPrivateKeyStore();
StringBuilder privateKeyId = new StringBuilder("testPrivateKeyId");
awsPrivateKeyStore.getPrivateKey("zms", "testServerHostName", privateKeyId);
awsPrivateKeyStore.getPrivateKey("testService", "testserverHostname", privateKeyId);
Expand Down Expand Up @@ -119,7 +121,7 @@ public void testGetApplicationSecret() {
doReturn(s3).when(spyAWS).getS3();
doReturn(kms).when(spyAWS).getKMS();
String actual = spyAWS.getApplicationSecret(bucketName, keyName);
Assert.assertEquals(actual, expected);
assertEquals(actual, expected);
System.clearProperty("athenz.aws.s3.region");
System.clearProperty(ATHENZ_AWS_KMS_REGION);
}
Expand All @@ -136,7 +138,6 @@ public void testGetEncryptedDataException() {
AWSKMS kms = mock(AWSKMS.class);
S3Object s3Object = mock(S3Object.class);
Mockito.when(s3.getObject(bucketName, keyName)).thenReturn(s3Object);
InputStream is = new ByteArrayInputStream( expected.getBytes() );
given(s3Object.getObjectContent()).willAnswer(invocation -> { throw new IOException();});

ByteBuffer buffer = ByteBuffer.wrap(expected.getBytes());
Expand Down Expand Up @@ -164,4 +165,161 @@ public void testGetKMS() {

assertEquals(privateKeyStore.getKMS(), kms);
}

@Test
public void testGetPrivateKeyAlgorithm() {

// first valid zms/zts services

try {
testGetPrivateKeyAlgorithm("zms");
} catch (IOException ex) {
fail(ex.getMessage());
}

try {
testGetPrivateKeyAlgorithm("zts");
} catch (IOException ex) {
fail(ex.getMessage());
}
}

@Test
public void testGetPrivateKeyAlgorithmFailures() {

// with unknown service we should get a null object

AmazonS3 s3 = mock(AmazonS3.class);
AWSKMS kms = mock(AWSKMS.class);
AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms);
assertNull(awsPrivateKeyStore.getPrivateKey("msd", "testServerHostName", "us-east-1", "rsa"));

// with no bucket with should get a null object

System.clearProperty("athenz.aws.zts.bucket_name");
System.setProperty("athenz.aws.zts.key_name", "key");
System.setProperty("athenz.aws.zts.key_id_name", "keyid");
assertNull(awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", "us-east-1", "rsa"));

System.clearProperty("athenz.aws.zts.bucket_name");
System.clearProperty("athenz.aws.zts.key_name");
System.clearProperty("athenz.aws.zts.key_id_name");
}

private void testGetPrivateKeyAlgorithm(final String service) throws IOException {

final String bucketName = "my_bucket";
final String keyName = "my_key";
final String algKeyName = "my_key.rsa";
final String keyId = "my_key_id";
final String algKeyId = "my_key_id.rsa";
final String expectedKeyId = "1";

System.setProperty("athenz.aws.s3.region", "us-east-1");

System.setProperty("athenz.aws." + service + ".bucket_name", bucketName);
System.setProperty("athenz.aws." + service + ".key_name", keyName);
System.setProperty("athenz.aws." + service + ".key_id_name", keyId);

AmazonS3 s3 = mock(AmazonS3.class);
AWSKMS kms = mock(AWSKMS.class);

S3Object s3ObjectKey = mock(S3Object.class);
Mockito.when(s3.getObject(bucketName, algKeyName)).thenReturn(s3ObjectKey);
File privKeyFile = new File("src/test/resources/unit_test_zts_private.pem");
final String privKey = new String(Files.readAllBytes(privKeyFile.toPath()), StandardCharsets.UTF_8);
InputStream isKey = new ByteArrayInputStream( privKey.getBytes() );
S3ObjectInputStream s3ObjectKeyInputStream = new S3ObjectInputStream(isKey, null);
Mockito.when(s3ObjectKey.getObjectContent()).thenReturn(s3ObjectKeyInputStream);

S3Object s3ObjectKeyId = mock(S3Object.class);
Mockito.when(s3.getObject(bucketName, algKeyId)).thenReturn(s3ObjectKeyId);
InputStream isKeyId = new ByteArrayInputStream( expectedKeyId.getBytes() );
S3ObjectInputStream s3ObjectKeyIdInputStream = new S3ObjectInputStream(isKeyId, null);
Mockito.when(s3ObjectKeyId.getObjectContent()).thenReturn(s3ObjectKeyIdInputStream);

AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms);
ServerPrivateKey serverPrivateKey = awsPrivateKeyStore.getPrivateKey(service, "testServerHostName", "us-east-1", "rsa");
assertNotNull(serverPrivateKey);
assertNotNull(serverPrivateKey.getKey());
assertEquals(serverPrivateKey.getAlgorithm().toString(), "RS256");
assertEquals(serverPrivateKey.getId(), "1");

System.clearProperty("athenz.aws.s3.region");

System.clearProperty("athenz.aws." + service + ".bucket_name");
System.clearProperty("athenz.aws." + service + ".key_name");
System.clearProperty("athenz.aws." + service + ".key_id_name");
}

@Test
public void testGetPrivateKeyAlgorithmInvalidKey() {

final String bucketName = "my_bucket";
final String keyName = "my_key";
final String algKeyName = "my_key.rsa";
final String keyId = "my_key_id";
final String algKeyId = "my_key_id.rsa";
final String expectedKeyId = "1";
final String privKey = "invalid-key";

System.setProperty("athenz.aws.s3.region", "us-east-1");

System.setProperty("athenz.aws.zts.bucket_name", bucketName);
System.setProperty("athenz.aws.zts.key_name", keyName);
System.setProperty("athenz.aws.zts.key_id_name", keyId);

AmazonS3 s3 = mock(AmazonS3.class);
AWSKMS kms = mock(AWSKMS.class);

S3Object s3ObjectKey = mock(S3Object.class);
Mockito.when(s3.getObject(bucketName, algKeyName)).thenReturn(s3ObjectKey);
InputStream isKey = new ByteArrayInputStream( privKey.getBytes() );
S3ObjectInputStream s3ObjectKeyInputStream = new S3ObjectInputStream(isKey, null);
Mockito.when(s3ObjectKey.getObjectContent()).thenReturn(s3ObjectKeyInputStream);

S3Object s3ObjectKeyId = mock(S3Object.class);
Mockito.when(s3.getObject(bucketName, algKeyId)).thenReturn(s3ObjectKeyId);
InputStream isKeyId = new ByteArrayInputStream( expectedKeyId.getBytes() );
S3ObjectInputStream s3ObjectKeyIdInputStream = new S3ObjectInputStream(isKeyId, null);
Mockito.when(s3ObjectKeyId.getObjectContent()).thenReturn(s3ObjectKeyIdInputStream);

AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms);
assertNull(awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", "us-east-1", "rsa"));

System.clearProperty("athenz.aws.s3.region");

System.clearProperty("athenz.aws.zts.bucket_name");
System.clearProperty("athenz.aws.zts.key_name");
System.clearProperty("athenz.aws.zts.key_id_name");
}

@Test
public void testGetPrivateKeyAlgorithmException() {

final String bucketName = "my_bucket";
final String keyName = "my_key";
final String algKeyName = "my_key.rsa";
final String keyId = "my_key_id";

System.setProperty("athenz.aws.s3.region", "us-east-1");

System.setProperty("athenz.aws.zts.bucket_name", bucketName);
System.setProperty("athenz.aws.zts.key_name", keyName);
System.setProperty("athenz.aws.zts.key_id_name", keyId);

AmazonS3 s3 = mock(AmazonS3.class);
AWSKMS kms = mock(AWSKMS.class);

Mockito.when(s3.getObject(bucketName, algKeyName)).thenThrow(new IndexOutOfBoundsException());

AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms);
assertNull(awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", "us-east-1", "rsa"));

System.clearProperty("athenz.aws.s3.region");

System.clearProperty("athenz.aws.zts.bucket_name");
System.clearProperty("athenz.aws.zts.key_name");
System.clearProperty("athenz.aws.zts.key_id_name");
}
}