diff --git a/libs/java/server_common/src/main/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStore.java b/libs/java/server_common/src/main/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStore.java index 69ee50bf7c7..51912dd5d6c 100644 --- a/libs/java/server_common/src/main/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStore.java +++ b/libs/java/server_common/src/main/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStore.java @@ -24,7 +24,9 @@ import com.amazonaws.services.s3.model.S3Object; import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.yahoo.athenz.auth.PrivateKeyStore; +import com.yahoo.athenz.auth.ServerPrivateKey; import com.yahoo.athenz.auth.util.Crypto; +import org.eclipse.jetty.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -79,33 +81,57 @@ public AwsPrivateKeyStore() { } private static AWSKMS initAWSKMS() { - String s3Region = System.getProperty(ATHENZ_PROP_AWS_KMS_REGION); - ///CLOVER:OFF - if (null != s3Region && !s3Region.isEmpty()) { - return AWSKMSClientBuilder.standard().withRegion(s3Region).build(); - } - return AWSKMSClientBuilder.defaultClient(); - ///CLOVER:ON + final String kmsRegion = System.getProperty(ATHENZ_PROP_AWS_KMS_REGION); + return StringUtil.isEmpty(kmsRegion) ? AWSKMSClientBuilder.defaultClient() : AWSKMSClientBuilder.standard().withRegion(kmsRegion).build(); } private static AmazonS3 initAmazonS3() { - String s3Region = System.getProperty(ATHENZ_PROP_AWS_S3_REGION); - ///CLOVER:OFF - if (null != s3Region && !s3Region.isEmpty()) { - return AmazonS3ClientBuilder.standard().withRegion(s3Region).build(); - } - return AmazonS3ClientBuilder.defaultClient(); - ///CLOVER:ON + final String s3Region = System.getProperty(ATHENZ_PROP_AWS_S3_REGION); + return StringUtil.isEmpty(s3Region) ? AmazonS3ClientBuilder.defaultClient() : AmazonS3ClientBuilder.standard().withRegion(s3Region).build(); } public AwsPrivateKeyStore(final AmazonS3 s3, final AWSKMS kms) { this.s3 = s3; this.kms = kms; } - + + @Override + public ServerPrivateKey getPrivateKey(String service, String serverHostName, String serverRegion, String algorithm) { + + final String bucketName; + String keyName; + String keyIdName; + + final String objectSuffix = "." + algorithm.toLowerCase(); + if (ZMS_SERVICE.equals(service)) { + bucketName = System.getProperty(ATHENZ_PROP_ZMS_BUCKET_NAME); + keyName = System.getProperty(ATHENZ_PROP_ZMS_KEY_NAME, ATHENZ_DEFAULT_KEY_NAME) + objectSuffix; + keyIdName = System.getProperty(ATHENZ_PROP_ZMS_KEY_ID_NAME, ATHENZ_DEFAULT_KEY_ID_NAME) + objectSuffix; + } else if (ZTS_SERVICE.equals(service)) { + bucketName = System.getProperty(ATHENZ_PROP_ZTS_BUCKET_NAME); + keyName = System.getProperty(ATHENZ_PROP_ZTS_KEY_NAME, ATHENZ_DEFAULT_KEY_NAME) + objectSuffix; + keyIdName = System.getProperty(ATHENZ_PROP_ZTS_KEY_ID_NAME, ATHENZ_DEFAULT_KEY_ID_NAME) + objectSuffix; + } else { + LOG.error("Unknown service specified: {}", service); + return null; + } + + if (bucketName == null) { + LOG.error("No bucket name specified with system property"); + return null; + } + + PrivateKey pkey = null; + try { + pkey = Crypto.loadPrivateKey(getDecryptedData(bucketName, keyName)); + } catch (Exception ex) { + LOG.error("unable to load private key", ex); + } + return pkey == null ? null : new ServerPrivateKey(pkey, getDecryptedData(bucketName, keyIdName)); + } + @Override - public PrivateKey getPrivateKey(String service, String serverHostName, - StringBuilder privateKeyId) { + public PrivateKey getPrivateKey(String service, String serverHostName, StringBuilder privateKeyId) { String bucketName; String keyName; @@ -158,11 +184,9 @@ private String getDecryptedData(final String bucketName, final String keyName) { byte[] buffer = new byte[1024]; int length; - ///CLOVER:OFF while ((length = s3InputStream.read(buffer)) != -1) { result.write(buffer, 0, length); } - ///CLOVER:ON // if key should be decrypted, do so with KMS if (kmsDecrypt) { diff --git a/libs/java/server_common/src/test/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStoreTest.java b/libs/java/server_common/src/test/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStoreTest.java index 77d44b67ae8..3b35da520bf 100644 --- a/libs/java/server_common/src/test/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStoreTest.java +++ b/libs/java/server_common/src/test/java/com/yahoo/athenz/auth/impl/aws/AwsPrivateKeyStoreTest.java @@ -21,24 +21,28 @@ import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.model.S3Object; import com.amazonaws.services.s3.model.S3ObjectInputStream; +import com.yahoo.athenz.auth.ServerPrivateKey; import org.mockito.Mockito; -import org.testng.Assert; import org.testng.annotations.Test; import java.io.ByteArrayInputStream; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertTrue; +import static org.testng.Assert.*; public class AwsPrivateKeyStoreTest { + private static final String ATHENZ_PROP_ZTS_BUCKET_NAME = "athenz.aws.zts.bucket_name"; private static final String ATHENZ_AWS_KMS_REGION = "athenz.aws.store_kms_region"; + @Test public void testAwsPrivateKeyStore() { System.setProperty("athenz.aws.s3.region", "us-east-1"); @@ -67,7 +71,7 @@ public void testAwsPrivateKeyStore() { String actual = awsPrivateKeyStore.getApplicationSecret(bucketName, keyName); StringBuilder privateKeyId = new StringBuilder(keyName); awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", privateKeyId); - Assert.assertEquals(actual, expected); + assertEquals(actual, expected); Mockito.when(s3Object.getObjectContent()).thenAnswer(invocation -> { throw new IOException("test IOException"); }); awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", privateKeyId); @@ -82,9 +86,7 @@ public void testGetPrivateKey() { AwsPrivateKeyStoreFactory awsPrivateKeyStoreFactory = new AwsPrivateKeyStoreFactory(); assertTrue(awsPrivateKeyStoreFactory.create() instanceof AwsPrivateKeyStore); - AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(); - awsPrivateKeyStore = new AwsPrivateKeyStore(); StringBuilder privateKeyId = new StringBuilder("testPrivateKeyId"); awsPrivateKeyStore.getPrivateKey("zms", "testServerHostName", privateKeyId); awsPrivateKeyStore.getPrivateKey("testService", "testserverHostname", privateKeyId); @@ -119,7 +121,7 @@ public void testGetApplicationSecret() { doReturn(s3).when(spyAWS).getS3(); doReturn(kms).when(spyAWS).getKMS(); String actual = spyAWS.getApplicationSecret(bucketName, keyName); - Assert.assertEquals(actual, expected); + assertEquals(actual, expected); System.clearProperty("athenz.aws.s3.region"); System.clearProperty(ATHENZ_AWS_KMS_REGION); } @@ -136,7 +138,6 @@ public void testGetEncryptedDataException() { AWSKMS kms = mock(AWSKMS.class); S3Object s3Object = mock(S3Object.class); Mockito.when(s3.getObject(bucketName, keyName)).thenReturn(s3Object); - InputStream is = new ByteArrayInputStream( expected.getBytes() ); given(s3Object.getObjectContent()).willAnswer(invocation -> { throw new IOException();}); ByteBuffer buffer = ByteBuffer.wrap(expected.getBytes()); @@ -164,4 +165,161 @@ public void testGetKMS() { assertEquals(privateKeyStore.getKMS(), kms); } + + @Test + public void testGetPrivateKeyAlgorithm() { + + // first valid zms/zts services + + try { + testGetPrivateKeyAlgorithm("zms"); + } catch (IOException ex) { + fail(ex.getMessage()); + } + + try { + testGetPrivateKeyAlgorithm("zts"); + } catch (IOException ex) { + fail(ex.getMessage()); + } + } + + @Test + public void testGetPrivateKeyAlgorithmFailures() { + + // with unknown service we should get a null object + + AmazonS3 s3 = mock(AmazonS3.class); + AWSKMS kms = mock(AWSKMS.class); + AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms); + assertNull(awsPrivateKeyStore.getPrivateKey("msd", "testServerHostName", "us-east-1", "rsa")); + + // with no bucket with should get a null object + + System.clearProperty("athenz.aws.zts.bucket_name"); + System.setProperty("athenz.aws.zts.key_name", "key"); + System.setProperty("athenz.aws.zts.key_id_name", "keyid"); + assertNull(awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", "us-east-1", "rsa")); + + System.clearProperty("athenz.aws.zts.bucket_name"); + System.clearProperty("athenz.aws.zts.key_name"); + System.clearProperty("athenz.aws.zts.key_id_name"); + } + + private void testGetPrivateKeyAlgorithm(final String service) throws IOException { + + final String bucketName = "my_bucket"; + final String keyName = "my_key"; + final String algKeyName = "my_key.rsa"; + final String keyId = "my_key_id"; + final String algKeyId = "my_key_id.rsa"; + final String expectedKeyId = "1"; + + System.setProperty("athenz.aws.s3.region", "us-east-1"); + + System.setProperty("athenz.aws." + service + ".bucket_name", bucketName); + System.setProperty("athenz.aws." + service + ".key_name", keyName); + System.setProperty("athenz.aws." + service + ".key_id_name", keyId); + + AmazonS3 s3 = mock(AmazonS3.class); + AWSKMS kms = mock(AWSKMS.class); + + S3Object s3ObjectKey = mock(S3Object.class); + Mockito.when(s3.getObject(bucketName, algKeyName)).thenReturn(s3ObjectKey); + File privKeyFile = new File("src/test/resources/unit_test_zts_private.pem"); + final String privKey = new String(Files.readAllBytes(privKeyFile.toPath()), StandardCharsets.UTF_8); + InputStream isKey = new ByteArrayInputStream( privKey.getBytes() ); + S3ObjectInputStream s3ObjectKeyInputStream = new S3ObjectInputStream(isKey, null); + Mockito.when(s3ObjectKey.getObjectContent()).thenReturn(s3ObjectKeyInputStream); + + S3Object s3ObjectKeyId = mock(S3Object.class); + Mockito.when(s3.getObject(bucketName, algKeyId)).thenReturn(s3ObjectKeyId); + InputStream isKeyId = new ByteArrayInputStream( expectedKeyId.getBytes() ); + S3ObjectInputStream s3ObjectKeyIdInputStream = new S3ObjectInputStream(isKeyId, null); + Mockito.when(s3ObjectKeyId.getObjectContent()).thenReturn(s3ObjectKeyIdInputStream); + + AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms); + ServerPrivateKey serverPrivateKey = awsPrivateKeyStore.getPrivateKey(service, "testServerHostName", "us-east-1", "rsa"); + assertNotNull(serverPrivateKey); + assertNotNull(serverPrivateKey.getKey()); + assertEquals(serverPrivateKey.getAlgorithm().toString(), "RS256"); + assertEquals(serverPrivateKey.getId(), "1"); + + System.clearProperty("athenz.aws.s3.region"); + + System.clearProperty("athenz.aws." + service + ".bucket_name"); + System.clearProperty("athenz.aws." + service + ".key_name"); + System.clearProperty("athenz.aws." + service + ".key_id_name"); + } + + @Test + public void testGetPrivateKeyAlgorithmInvalidKey() { + + final String bucketName = "my_bucket"; + final String keyName = "my_key"; + final String algKeyName = "my_key.rsa"; + final String keyId = "my_key_id"; + final String algKeyId = "my_key_id.rsa"; + final String expectedKeyId = "1"; + final String privKey = "invalid-key"; + + System.setProperty("athenz.aws.s3.region", "us-east-1"); + + System.setProperty("athenz.aws.zts.bucket_name", bucketName); + System.setProperty("athenz.aws.zts.key_name", keyName); + System.setProperty("athenz.aws.zts.key_id_name", keyId); + + AmazonS3 s3 = mock(AmazonS3.class); + AWSKMS kms = mock(AWSKMS.class); + + S3Object s3ObjectKey = mock(S3Object.class); + Mockito.when(s3.getObject(bucketName, algKeyName)).thenReturn(s3ObjectKey); + InputStream isKey = new ByteArrayInputStream( privKey.getBytes() ); + S3ObjectInputStream s3ObjectKeyInputStream = new S3ObjectInputStream(isKey, null); + Mockito.when(s3ObjectKey.getObjectContent()).thenReturn(s3ObjectKeyInputStream); + + S3Object s3ObjectKeyId = mock(S3Object.class); + Mockito.when(s3.getObject(bucketName, algKeyId)).thenReturn(s3ObjectKeyId); + InputStream isKeyId = new ByteArrayInputStream( expectedKeyId.getBytes() ); + S3ObjectInputStream s3ObjectKeyIdInputStream = new S3ObjectInputStream(isKeyId, null); + Mockito.when(s3ObjectKeyId.getObjectContent()).thenReturn(s3ObjectKeyIdInputStream); + + AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms); + assertNull(awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", "us-east-1", "rsa")); + + System.clearProperty("athenz.aws.s3.region"); + + System.clearProperty("athenz.aws.zts.bucket_name"); + System.clearProperty("athenz.aws.zts.key_name"); + System.clearProperty("athenz.aws.zts.key_id_name"); + } + + @Test + public void testGetPrivateKeyAlgorithmException() { + + final String bucketName = "my_bucket"; + final String keyName = "my_key"; + final String algKeyName = "my_key.rsa"; + final String keyId = "my_key_id"; + + System.setProperty("athenz.aws.s3.region", "us-east-1"); + + System.setProperty("athenz.aws.zts.bucket_name", bucketName); + System.setProperty("athenz.aws.zts.key_name", keyName); + System.setProperty("athenz.aws.zts.key_id_name", keyId); + + AmazonS3 s3 = mock(AmazonS3.class); + AWSKMS kms = mock(AWSKMS.class); + + Mockito.when(s3.getObject(bucketName, algKeyName)).thenThrow(new IndexOutOfBoundsException()); + + AwsPrivateKeyStore awsPrivateKeyStore = new AwsPrivateKeyStore(s3, kms); + assertNull(awsPrivateKeyStore.getPrivateKey("zts", "testServerHostName", "us-east-1", "rsa")); + + System.clearProperty("athenz.aws.s3.region"); + + System.clearProperty("athenz.aws.zts.bucket_name"); + System.clearProperty("athenz.aws.zts.key_name"); + System.clearProperty("athenz.aws.zts.key_id_name"); + } }