/** * Copyright 2015 StreamSets Inc. * * Licensed under the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.streamsets.pipeline.lib.http; import com.google.common.collect.ImmutableList; import com.streamsets.pipeline.api.OnRecordError; import com.streamsets.pipeline.api.Stage; import com.streamsets.pipeline.lib.tls.TlsConfigBean; import com.streamsets.pipeline.lib.tls.TlsConnectionType; import com.streamsets.pipeline.sdk.ContextInfoCreator; import com.streamsets.pipeline.stage.util.tls.TLSTestUtils; import com.streamsets.testing.NetworkUtils; import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.File; import java.io.FileInputStream; import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; import java.security.KeyPair; import java.security.KeyStore; import java.security.cert.Certificate; import java.util.UUID; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; public class TestReceiverServer { @Test public void testLifecycleHttp() throws Exception { final int port = NetworkUtils.getRandomPort(); HttpConfigs configs = new HttpConfigs("g", "p") { @Override public int getPort() { return port; } @Override public int getMaxConcurrentRequests() { return 10; } @Override public String getAppId() { return "id"; } @Override public int getMaxHttpRequestSizeKB() { return 1; } @Override public boolean isTlsEnabled() { return false; } @Override public boolean isAppIdViaQueryParamAllowed() { return false; } @Override public TlsConfigBean getTlsConfigBean() { return null; } }; HttpReceiver receiver = Mockito.mock(HttpReceiverWithFragmenterWriter.class); Mockito.when(receiver.getAppId()).thenReturn("id"); Mockito.when(receiver.getUriPath()).thenReturn("/path"); BlockingQueue<Exception> exQueue = new ArrayBlockingQueue<>(10); HttpReceiverServer server = new HttpReceiverServer(configs, receiver, exQueue); Assert.assertTrue(server.getJettyServerThreads(1) > 1); Assert.assertTrue(server.getJettyServerMaxThreads() > 10); Assert.assertTrue(server.getJettyServerMinThreads() >= server.getJettyServerThreads(1)); Stage.Context context = ContextInfoCreator.createSourceContext("i", false, OnRecordError.TO_ERROR, ImmutableList.of("a")); try { Assert.assertTrue(server.init(context).isEmpty()); // valid ping HttpURLConnection conn = (HttpURLConnection) new URL("http://localhost:" + port + "/path").openConnection(); conn.setRequestProperty(HttpConstants.X_SDC_APPLICATION_ID_HEADER, "id"); Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); Assert.assertEquals(HttpConstants.X_SDC_PING_VALUE, conn.getHeaderField(HttpConstants.X_SDC_PING_HEADER)); // invalid ping conn = (HttpURLConnection) new URL("http://localhost:" + port + "/path").openConnection(); conn.setRequestProperty(HttpConstants.X_SDC_APPLICATION_ID_HEADER, "invalid"); Assert.assertEquals(HttpURLConnection.HTTP_FORBIDDEN, conn.getResponseCode()); // valid post Mockito.reset(receiver); Mockito.when(receiver.getAppId()).thenReturn("id"); Mockito.when(receiver.validate(Mockito.any(HttpServletRequest.class), Mockito.any(HttpServletResponse.class))) .thenReturn(true); conn = (HttpURLConnection) new URL("http://localhost:" + port + "/path").openConnection(); conn.setRequestProperty(HttpConstants.X_SDC_APPLICATION_ID_HEADER, "id"); conn.setDoOutput(true); conn.setRequestMethod("POST"); conn.getOutputStream().write("abc".getBytes()); Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode()); Mockito.verify(receiver, Mockito.times(1)).validate(Mockito.any(HttpServletRequest.class), Mockito.any (HttpServletResponse.class)); Mockito.verify(receiver, Mockito.times(1)).process(Mockito.any(HttpServletRequest.class), Mockito.any (InputStream.class)); // invalid post Mockito.reset(receiver); Mockito.when(receiver.getAppId()).thenReturn("id"); Mockito.when(receiver.validate(Mockito.any(HttpServletRequest.class), Mockito.any(HttpServletResponse.class))) .thenReturn(false); conn = (HttpURLConnection) new URL("http://localhost:" + port + "/path").openConnection(); conn.setRequestProperty(HttpConstants.X_SDC_APPLICATION_ID_HEADER, "id"); conn.setDoOutput(true); conn.setRequestMethod("POST"); conn.getOutputStream().write("abc".getBytes()); conn.getResponseCode(); Mockito.verify(receiver, Mockito.times(1)).validate(Mockito.any(HttpServletRequest.class), Mockito.any (HttpServletResponse.class)); Mockito.verify(receiver, Mockito.times(0)).process(Mockito.any(HttpServletRequest.class), Mockito.any (InputStream.class)); } finally { server.destroy(); } } @Test public void testLifecycleHttps() throws Exception { // setup TLS String hostname = TLSTestUtils.getHostname(); File testDir = new File("target", UUID.randomUUID().toString()).getAbsoluteFile(); final File keyStore = new File(testDir, "keystore.jks"); Assert.assertTrue(testDir.mkdirs()); final String keyStorePassword = "keystore"; final File trustStore = new File(testDir, "truststore.jks"); KeyPair keyPair = TLSTestUtils.generateKeyPair(); Certificate cert = TLSTestUtils.generateCertificate("CN=" + hostname, keyPair, 30); TLSTestUtils.createKeyStore(keyStore.toString(), keyStorePassword, "web", keyPair.getPrivate(), cert); TLSTestUtils.createTrustStore(trustStore.toString(), "truststore", "web", cert); final int port = NetworkUtils.getRandomPort(); final HttpConfigs configs = new HttpConfigs("g", "p") { @Override public int getPort() { return port; } @Override public int getMaxConcurrentRequests() { return 10; } @Override public String getAppId() { return "id"; } @Override public int getMaxHttpRequestSizeKB() { return 1; } @Override public boolean isTlsEnabled() { return true; } @Override public boolean isAppIdViaQueryParamAllowed() { return false; } @Override public TlsConfigBean getTlsConfigBean() { final TlsConfigBean tlsConfigBean = new TlsConfigBean(TlsConnectionType.SERVER); tlsConfigBean.keyStoreFilePath = keyStore.getAbsolutePath(); tlsConfigBean.keyStorePassword = keyStorePassword; return tlsConfigBean; } }; HttpReceiver receiver = Mockito.mock(HttpReceiverWithFragmenterWriter.class); Mockito.when(receiver.getAppId()).thenReturn("id"); Mockito.when(receiver.getUriPath()).thenReturn("/path"); BlockingQueue<Exception> exQueue = new ArrayBlockingQueue<>(10); HttpReceiverServer server = new HttpReceiverServer(configs, receiver, exQueue); Assert.assertTrue(server.getJettyServerThreads(1) > 1); Assert.assertTrue(server.getJettyServerMaxThreads() > 10); Assert.assertTrue(server.getJettyServerMinThreads() >= server.getJettyServerThreads(1)); ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); final Stage.Context context = ContextInfoCreator.createSourceContext("i", false, OnRecordError.TO_ERROR, ImmutableList.of("a")); try { Assert.assertTrue(configs.init(context).isEmpty()); Assert.assertTrue(server.init(context).isEmpty()); Future<Boolean> future = executor.submit(new Callable<Boolean>() { @Override public Boolean call() throws Exception { HttpURLConnection conn = getConnection("/path", configs.getAppId(), context, TLSTestUtils.getHostname() + ":" + configs.getPort(), trustStore.toString(), "truststore" ); return conn.getResponseCode() == HttpURLConnection.HTTP_OK && HttpConstants.X_SDC_PING_VALUE.equals(conn.getHeaderField(HttpConstants.X_SDC_PING_HEADER)); } }); Assert.assertTrue(future.get(5, TimeUnit.SECONDS)); } finally { server.destroy(); executor.shutdownNow(); } } static final HostnameVerifier ACCEPT_ALL_HOSTNAME_VERIFIER = new HostnameVerifier() { @Override public boolean verify(String s, SSLSession sslSession) { return true; } }; private HttpURLConnection getConnection( String path, String appId, Stage.Context context, String hostPort, String trustStoreFile, String trustStorePassword ) throws Exception { URL url = new URL("https://" + hostPort.trim() + path); HttpURLConnection conn = (HttpURLConnection) url.openConnection(); conn.setConnectTimeout(1000); conn.setReadTimeout(1000); HttpsURLConnection sslConn = (HttpsURLConnection) conn; sslConn.setSSLSocketFactory(createSSLSocketFactory(context, trustStoreFile, trustStorePassword)); sslConn.setHostnameVerifier(ACCEPT_ALL_HOSTNAME_VERIFIER); conn.setRequestProperty(HttpConstants.X_SDC_APPLICATION_ID_HEADER, appId); return conn; } private String SSL_CERTIFICATE = "SunX509"; private String[] SSL_ENABLED_PROTOCOLS = {"TLSv1.2"}; private SSLSocketFactory createSSLSocketFactory( Stage.Context context, String trustStoreFile, String trustStorePassword ) throws Exception { SSLSocketFactory sslSocketFactory; if (trustStoreFile.isEmpty()) { sslSocketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault(); } else { KeyStore keystore = KeyStore.getInstance("jks"); try (InputStream is = new FileInputStream(new File(context.getResourcesDirectory(), trustStoreFile))) { keystore.load(is, trustStorePassword.toCharArray()); } KeyManagerFactory keyMgrFactory = KeyManagerFactory.getInstance(SSL_CERTIFICATE); keyMgrFactory.init(keystore, trustStorePassword.toCharArray()); KeyManager[] keyManagers = keyMgrFactory.getKeyManagers(); TrustManager[] trustManagers = new TrustManager[1]; TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(SSL_CERTIFICATE); trustManagerFactory.init(keystore); for (TrustManager trustManager1 : trustManagerFactory.getTrustManagers()) { if (trustManager1 instanceof X509TrustManager) { trustManagers[0] = trustManager1; break; } } SSLContext sslContext = SSLContext.getInstance("TLS"); sslContext.init(keyManagers, trustManagers, null); sslContext.getDefaultSSLParameters().setProtocols(SSL_ENABLED_PROTOCOLS); sslSocketFactory = sslContext.getSocketFactory(); } return sslSocketFactory; } }