package org.zalando.stups.fullstop.jobs.rds;

import com.amazonaws.regions.Region;
import com.amazonaws.services.rds.AmazonRDSClient;
import com.amazonaws.services.rds.model.DBInstance;
import com.amazonaws.services.rds.model.DescribeDBInstancesRequest;
import com.amazonaws.services.rds.model.DescribeDBInstancesResult;
import com.amazonaws.services.rds.model.Endpoint;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.zalando.stups.fullstop.aws.ClientProvider;
import org.zalando.stups.fullstop.jobs.common.AccountIdSupplier;
import org.zalando.stups.fullstop.jobs.config.JobsProperties;
import org.zalando.stups.fullstop.violation.Violation;
import org.zalando.stups.fullstop.violation.ViolationSink;

import java.util.Map;

import static com.google.common.collect.Lists.newArrayList;
import static com.google.common.collect.Sets.newHashSet;
import static org.junit.Assert.assertArrayEquals;
import static org.mockito.Mockito.*;

/**
 * Unit tests for {@link FetchRdsJob}.
 *
 * <p>All collaborators are Mockito mocks. {@link #tearDown()} calls
 * {@code verifyNoMoreInteractions} on every mock, so each test must explicitly verify every
 * interaction it expects the job to perform.
 */
public class FetchRdsJobTest {

    private ClientProvider clientProviderMock;

    private JobsProperties jobsPropertiesMock;

    private ViolationSink violationSinkMock;

    private AmazonRDSClient amazonRDSClientMock;

    // Real (not mocked) result object that the stubbed describeDBInstances call returns.
    private DescribeDBInstancesResult describeDBInstancesResult;

    private AccountIdSupplier accountIdSupplierMock;

    /**
     * Builds the mocks and a fixture of three RDS instances: two publicly accessible ones with
     * distinct endpoints, and one non-public instance.
     */
    @Before
    public void setUp() throws Exception {
        this.clientProviderMock = mock(ClientProvider.class);
        this.jobsPropertiesMock = mock(JobsProperties.class);
        this.violationSinkMock = mock(ViolationSink.class);
        this.amazonRDSClientMock = mock(AmazonRDSClient.class);
        this.accountIdSupplierMock = mock(AccountIdSupplier.class);

        when(accountIdSupplierMock.get()).thenReturn(newHashSet("54321"));

        // JobsProperties: a single whitelisted region, so the client is fetched exactly once.
        when(jobsPropertiesMock.getWhitelistedRegions()).thenReturn(newArrayList("eu-west-1"));

        // DB instances
        final Endpoint publicEndpoint = new Endpoint();
        publicEndpoint.setAddress("aws.db.cn");

        // The non-public instance gets its own address. Because of this, a job that wrongly
        // reported it (instead of the public one) would produce different metadata and fail
        // the #479 assertion in testCheck.
        final Endpoint privateEndpoint = new Endpoint();
        privateEndpoint.setAddress("aws.private-db.cn");

        final Endpoint publicEndpoint2 = new Endpoint();
        publicEndpoint2.setAddress("aws.db2.cn");

        final DBInstance dbInstance1 = new DBInstance();
        dbInstance1.setPubliclyAccessible(true);
        dbInstance1.setEndpoint(publicEndpoint);

        final DBInstance dbInstance2 = new DBInstance();
        dbInstance2.setPubliclyAccessible(false);
        dbInstance2.setEndpoint(privateEndpoint);

        final DBInstance dbInstance3 = new DBInstance();
        dbInstance3.setPubliclyAccessible(true);
        dbInstance3.setEndpoint(publicEndpoint2);

        describeDBInstancesResult = new DescribeDBInstancesResult();
        describeDBInstancesResult.setDBInstances(newArrayList(dbInstance1, dbInstance2, dbInstance3));

        // ClientProvider hands out the mocked RDS client for any account/region.
        when(clientProviderMock.getClient(any(), any(String.class), any(Region.class)))
                .thenReturn(amazonRDSClientMock);
    }

    /** Fails the test if the job touched any collaborator in a way the test did not verify. */
    @After
    public void tearDown() throws Exception {
        verifyNoMoreInteractions(
                accountIdSupplierMock,
                clientProviderMock,
                jobsPropertiesMock,
                violationSinkMock,
                amazonRDSClientMock);
    }

    /**
     * Runs the job against one account and one region. Expects exactly one violation for each
     * of the two publicly accessible instances, in fixture order.
     */
    @Test
    public void testCheck() throws Exception {
        final FetchRdsJob fetchRdsJob =
                new FetchRdsJob(accountIdSupplierMock, clientProviderMock, jobsPropertiesMock, violationSinkMock);

        when(amazonRDSClientMock.describeDBInstances(any(DescribeDBInstancesRequest.class)))
                .thenReturn(describeDBInstancesResult);

        fetchRdsJob.run();

        final ArgumentCaptor<Violation> violations = ArgumentCaptor.forClass(Violation.class);
        verify(violationSinkMock, times(2)).put(violations.capture());
        verify(accountIdSupplierMock, times(1)).get();
        verify(amazonRDSClientMock, times(1)).describeDBInstances(any(DescribeDBInstancesRequest.class));
        verify(jobsPropertiesMock, times(1)).getWhitelistedRegions();
        verify(clientProviderMock, times(1)).getClient(any(), any(String.class), any(Region.class));

        // Regression test for #479: make sure that the metadata lists the correct endpoints.
        // Only Map#get is needed, so a wildcard cast avoids an unchecked generic cast.
        assertArrayEquals(
                new String[] {"aws.db.cn", "aws.db2.cn"},
                violations.getAllValues().stream()
                        .map(v -> ((Map<?, ?>) v.getMetaInfo()).get("unsecuredDatabase"))
                        .toArray());
    }
}