/*
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
package com.facebook.common.streams;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
/**
* Reads the wrapped InputStream only until a specified number of bytes, the 'limit' is reached.
*/
public class LimitedInputStream extends FilterInputStream {
private int mBytesToRead;
private int mBytesToReadWhenMarked;
public LimitedInputStream(InputStream inputStream, int limit) {
super(inputStream);
if (inputStream == null) {
throw new NullPointerException();
}
if (limit < 0) {
throw new IllegalArgumentException("limit must be >= 0");
}
mBytesToRead = limit;
mBytesToReadWhenMarked = -1;
}
@Override
public int read() throws IOException {
if (mBytesToRead == 0) {
return -1;
}
final int readByte = in.read();
if (readByte != -1) {
mBytesToRead--;
}
return readByte;
}
@Override
public int read(byte[] buffer, int byteOffset, int byteCount) throws IOException {
if (mBytesToRead == 0) {
return -1;
}
final int maxBytesToRead = Math.min(byteCount, mBytesToRead);
final int bytesRead = in.read(buffer, byteOffset, maxBytesToRead);
if (bytesRead > 0) {
mBytesToRead -= bytesRead;
}
return bytesRead;
}
@Override
public long skip(long byteCount) throws IOException {
final long maxBytesToSkip = Math.min(byteCount, mBytesToRead);
final long bytesSkipped = in.skip(maxBytesToSkip);
mBytesToRead -= bytesSkipped;
return bytesSkipped;
}
@Override
public int available() throws IOException {
return Math.min(in.available(), mBytesToRead);
}
@Override
public void mark(int readLimit) {
if (in.markSupported()) {
in.mark(readLimit);
mBytesToReadWhenMarked = mBytesToRead;
}
}
@Override
public void reset() throws IOException {
if (!in.markSupported()) {
throw new IOException("mark is not supported");
}
if (mBytesToReadWhenMarked == -1) {
throw new IOException("mark not set");
}
in.reset();
mBytesToRead = mBytesToReadWhenMarked;
}
}