package org.zalando.riptide;
import org.junit.Test;
import org.springframework.core.io.InputStreamResource;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.test.web.client.MockRestServiceServer;
import org.springframework.web.client.AsyncRestTemplate;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static java.util.Collections.singletonList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.springframework.http.MediaType.APPLICATION_OCTET_STREAM;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess;
import static org.zalando.riptide.Bindings.on;
import static org.zalando.riptide.Navigators.contentType;
/**
 * Verifies that a response body can be consumed as a raw {@link InputStream} through {@link Rest},
 * and that the stream handed to the caller is the very stream backing the mocked response.
 */
public final class InputStreamTest {

    /**
     * Minimal read-only converter that exposes the response body as-is.
     *
     * <p>{@link #read} returns the stream obtained from {@link HttpInputMessage#getBody()} without
     * buffering or copying it, which is what allows the tests to compare stream identity.
     */
    static class InputStreamHttpMessageConverter implements HttpMessageConverter<InputStream> {

        /**
         * Accepts any {@link InputStream}-compatible target type whose media type is unspecified
         * ({@code null}) or included by one of {@link #getSupportedMediaTypes()}.
         */
        @Override
        public boolean canRead(final Class<?> clazz, final MediaType mediaType) {
            if (!InputStream.class.isAssignableFrom(clazz)) {
                return false;
            }
            // A null media type means "not specified" and is accepted, mirroring Spring's own converters.
            if (mediaType == null) {
                return true;
            }
            // Keep canRead consistent with the media types this converter advertises.
            for (final MediaType supported : getSupportedMediaTypes()) {
                if (supported.includes(mediaType)) {
                    return true;
                }
            }
            return false;
        }

        /** This converter is read-only. */
        @Override
        public boolean canWrite(final Class<?> clazz, final MediaType mediaType) {
            return false;
        }

        @Override
        public List<MediaType> getSupportedMediaTypes() {
            return singletonList(APPLICATION_OCTET_STREAM);
        }

        /** Returns the message body stream itself (no copy). */
        @Override
        public InputStream read(final Class<? extends InputStream> clazz, final HttpInputMessage inputMessage)
                throws IOException, HttpMessageNotReadableException {
            return inputMessage.getBody();
        }

        /**
         * Never supported; {@link #canWrite} always returns {@code false}, so callers honouring
         * that contract never reach this method.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void write(final InputStream t, final MediaType contentType, final HttpOutputMessage outputMessage)
                throws IOException, HttpMessageNotWritableException {
            throw new UnsupportedOperationException("InputStreamHttpMessageConverter is read-only");
        }
    }

    /**
     * Test double that fails loudly instead of silently tolerating misuse: a second {@link #close()}
     * or any read-style call after closing throws {@link IOException}. This deliberately deviates
     * from the lenient {@link java.io.Closeable#close()} contract so that a premature or duplicate
     * close by the code under test becomes visible.
     */
    static class CloseOnceInputStream extends InputStream {

        private final InputStream inputStream;

        // Guarded by "this": every accessor below is synchronized.
        private boolean isClosed;

        public CloseOnceInputStream(final InputStream inputStream) {
            this.inputStream = inputStream;
        }

        public CloseOnceInputStream(final byte[] buf) {
            this(new ByteArrayInputStream(buf));
        }

        private void checkClosed() throws IOException {
            if (isClosed) {
                throw new IOException("Stream is already closed");
            }
        }

        /**
         * Closes the delegate exactly once.
         *
         * @throws IOException if this stream has already been closed
         */
        @Override
        public synchronized void close() throws IOException {
            // synchronized like the other accessors, so reads/writes of isClosed use one lock.
            checkClosed();
            isClosed = true;
            inputStream.close();
        }

        // mark(int) cannot throw IOException, so it is not guarded by checkClosed().
        @Override
        public synchronized void mark(final int readlimit) {
            inputStream.mark(readlimit);
        }

        @Override
        public synchronized void reset() throws IOException {
            checkClosed();
            inputStream.reset();
        }

        @Override
        public boolean markSupported() {
            return inputStream.markSupported();
        }

        @Override
        public synchronized int read() throws IOException {
            checkClosed();
            return inputStream.read();
        }

        @Override
        public synchronized int read(final byte[] b, final int off, final int len) throws IOException {
            checkClosed();
            return inputStream.read(b, off, len);
        }

        @Override
        public synchronized long skip(final long n) throws IOException {
            checkClosed();
            return inputStream.skip(n);
        }

        @Override
        public synchronized int available() throws IOException {
            checkClosed();
            return inputStream.available();
        }
    }

    private final URI url = URI.create("https://api.example.com/blobs/123");

    private final Rest unit;

    private final MockRestServiceServer server;

    /**
     * Wires {@link Rest} to the same async request factory that the mock server is bound to,
     * so every request made by {@code unit} is answered by {@code server}.
     */
    public InputStreamTest() {
        final AsyncRestTemplate template = new AsyncRestTemplate();
        this.server = MockRestServiceServer.createServer(template);
        this.unit = Rest.builder()
                .requestFactory(template.getAsyncRequestFactory())
                .converter(new InputStreamHttpMessageConverter())
                .baseUrl("https://api.example.com")
                .build();
    }

    /** Sanity check of the test double: a second close must fail. */
    @Test
    public void shouldAllowCloseOnce() throws IOException {
        final InputStream content = new CloseOnceInputStream(new byte[]{'b', 'l', 'o', 'b'});
        content.close();
        try {
            content.close();
            fail("Should prevent multiple close calls");
        } catch (final IOException e) {
            assertEquals("Stream is already closed", e.getMessage());
        }
    }

    /** Sanity check of the test double: reading after close must fail. */
    @Test
    public void shouldNotAllowReadAfterClose() throws IOException {
        final InputStream content = new CloseOnceInputStream(new byte[]{'b', 'l', 'o', 'b'});
        content.close();
        try {
            //noinspection ResultOfMethodCallIgnored
            content.read();
            fail("Should prevent read calls after close");
        } catch (final IOException e) {
            assertEquals("Stream is already closed", e.getMessage());
        }
    }

    /**
     * The stream delivered to the callback must be the original response body and must still be
     * open (a premature close would make the read below throw).
     */
    @Test
    public void shouldExtractOriginalBody() throws IOException {
        final InputStream content = new CloseOnceInputStream(new byte[]{'b', 'l', 'o', 'b'});

        server.expect(requestTo(url)).andRespond(
                withSuccess()
                        .body(new InputStreamResource(content))
                        .contentType(APPLICATION_OCTET_STREAM));

        final AtomicReference<InputStream> capture = new AtomicReference<>();

        unit.get(url)
                .dispatch(contentType(),
                        on(APPLICATION_OCTET_STREAM).call(InputStream.class, capture::set))
                .join();

        final InputStream inputStream = capture.get();

        // InputStream does not override equals(), so this asserts object identity.
        assertEquals(content, inputStream);

        final int ch1 = inputStream.read();
        assertEquals('b', ch1);
    }
}