/*
* TFTPServer.java
* Copyright (C) 2011,2012 Wannes De Smet
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.xenmaster.connectivity.tftp;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.TimerTask;
import java.util.concurrent.TimeUnit;
import net.wgr.utility.GlobalExecutorService;
import org.apache.commons.net.tftp.ExtendedTFTP;
import org.apache.commons.net.tftp.TFTPAckPacket;
import org.apache.commons.net.tftp.TFTPErrorPacket;
import org.apache.commons.net.tftp.TFTPOptionAckPacket;
import org.apache.commons.net.tftp.TFTPOptionReadRequestPacket;
import org.apache.commons.net.tftp.TFTPPXEDataPacket;
import org.apache.commons.net.tftp.TFTPPacket;
import org.apache.commons.net.tftp.TFTPPacketException;
import org.apache.log4j.Logger;
/**
 * Small TFTP server that answers read requests with data obtained from
 * registered {@link ActivityListener}s.
 *
 * <p>The server binds UDP port 69 on 0.0.0.0 and serves a single client at a
 * time: read requests that arrive while a transfer is in progress are ignored.
 * The last data packet sent is re-sent every 500 ms by a {@link ResendTask}
 * until the transfer moves on or ends.
 *
 * <p>Packet handling happens on the server thread; {@link ResendTask} runs on
 * the executor obtained from {@code GlobalExecutorService}. The class is not
 * otherwise synchronized, e.g. {@link #addListener(ActivityListener)} is
 * expected to be called before {@link #boot()}.
 *
 * @created Oct 27, 2011
 * @author double-u
 */
public class TFTPServer implements Runnable {

    /** Block size used when the client does not negotiate one (RFC 1350). */
    private static final int DEFAULT_BLOCK_SIZE = 512;
    /** Smallest "blksize" value allowed by RFC 2348. */
    private static final int MIN_BLOCK_SIZE = 8;
    /** Largest "blksize" value allowed by RFC 2348. */
    private static final int MAX_BLOCK_SIZE = 65464;
    /** Size of the opcode + block number header preceding the payload. */
    private static final int DATA_HEADER_SIZE = 4;

    protected final Thread thread;
    /** Loop flag of the server thread; volatile because other threads may clear it. */
    protected volatile boolean run;
    protected ExtendedTFTP tftp;
    /** Address of the client currently being served, {@code null} when idle. */
    protected InetAddress clientAddress;
    /** Source of the file currently being transferred. */
    protected InputStream dataInputStream;
    /** Number of the last data block sent; also read by {@link ResendTask}. */
    protected volatile int blockNumber;
    protected int blockSize = DEFAULT_BLOCK_SIZE;
    protected LinkedList<ActivityListener> listeners;
    /** Number of packets received so far. */
    protected int count;
    protected ResendTask resendTask;
    /** Set once {@link #run()} has completed; guarded by the monitor of {@link #thread}. */
    private boolean finished;

    /**
     * Creates the server and schedules the periodic resend task. The server
     * does not listen until {@link #boot()} is called.
     */
    public TFTPServer() {
        tftp = new ExtendedTFTP();
        tftp.beginBufferedOps();
        tftp.setDefaultTimeout(0);
        listeners = new LinkedList<>();
        resendTask = new ResendTask();
        thread = new Thread(this, "TFTP server");
        GlobalExecutorService.get().scheduleAtFixedRate(resendTask, 500, 500, TimeUnit.MILLISECONDS);
    }

    /**
     * Registers a listener that can supply data for read requests. Listeners
     * are consulted most-recently-added first.
     *
     * @param al listener to register
     */
    public void addListener(ActivityListener al) {
        listeners.push(al);
    }

    /** Starts the server thread. */
    public void boot() {
        run = true;
        thread.start();
    }

    /**
     * Server loop: opens the socket when needed, then receives and handles
     * packets until {@link #run} is cleared. Waiters in
     * {@link #waitTillQuit()} are released when the loop ends, however it ends.
     */
    @Override
    public void run() {
        try {
            while (run) {
                if (!tftp.isOpen()) {
                    try {
                        tftp.open(69, InetAddress.getByName("0.0.0.0"));
                        tftp.beginBufferedOps();
                    } catch (SocketException | UnknownHostException ex) {
                        Logger.getLogger(getClass()).error("TFTP listening failed", ex);
                        run = false;
                        return;
                    }
                }

                try {
                    TFTPPacket packet = tftp.bufferedReceive();
                    count++;
                    handlePacket(packet);
                } catch (TFTPPacketException | IOException ex) {
                    Logger.getLogger(getClass()).error("TFTP receive failed", ex);
                }
            }

            tftp.endBufferedOps();
            tftp.close();
        } finally {
            // Send a goodbye to all threads waiting for the tftp server to finish
            synchronized (this.thread) {
                finished = true;
                this.thread.notifyAll();
            }
        }
    }

    /**
     * Dispatches a received packet by type. Read requests are only accepted
     * while idle, acknowledgements only from the current client's address;
     * an error packet aborts whatever transfer is in progress.
     *
     * @param packet packet received from the socket
     * @throws IOException if reading the data source or sending a reply fails
     */
    protected void handlePacket(final TFTPPacket packet) throws IOException {
        switch (packet.getType()) {
            case TFTPPacket.READ_REQUEST:
                if (clientAddress == null) {
                    handleReadRequest((TFTPOptionReadRequestPacket) packet);
                }
                break;
            case TFTPPacket.ACKNOWLEDGEMENT:
                if (packet.getAddress().equals(clientAddress)) {
                    handleAck((TFTPAckPacket) packet);
                }
                break;
            case TFTPPacket.ERROR:
                TFTPErrorPacket tep = (TFTPErrorPacket) packet;
                Logger.getLogger(getClass()).warn("TFTP error : " + tep.getMessage());
                abortTransfer();
                break;
        }
    }

    /**
     * Starts a new transfer: resolves the data source through the listeners,
     * then either acknowledges the requested options or sends the first block.
     */
    private void handleReadRequest(TFTPOptionReadRequestPacket request) throws IOException {
        final InetAddress address = request.getAddress();
        final int port = request.getPort();
        final Logger log = Logger.getLogger(getClass());

        dataInputStream = null;
        // A block size negotiated by a previous client must not leak into this transfer.
        if (blockSize != DEFAULT_BLOCK_SIZE) {
            blockSize = DEFAULT_BLOCK_SIZE;
            // NOTE(review): assumes restartBufferedOps(n) resizes the buffers to n bytes,
            // as its use for "blksize" below suggests - confirm against ExtendedTFTP.
            tftp.restartBufferedOps(DEFAULT_BLOCK_SIZE + DATA_HEADER_SIZE);
        }

        try {
            // getCanonicalHostName() may perform a reverse lookup; only pay for it when needed
            if (log.isDebugEnabled()) {
                log.debug("Request for : " + request.getFilename() + " received from " + address.getCanonicalHostName()
                        + ":" + port);
            }

            // First listener that provides a stream wins
            for (ActivityListener al : listeners) {
                InputStream is = al.pathRequest(request);
                if (is != null) {
                    dataInputStream = is;
                    break;
                }
            }

            if (dataInputStream == null) {
                log.debug("No ActivityListener provided valid InputStream for TFTP request");
                tftp.bufferedSend(new TFTPErrorPacket(address, port, TFTPErrorPacket.FILE_NOT_FOUND, request.getFilename()));
                return;
            }

            if (request.getOptions().size() > 0) {
                HashMap<String, Integer> acks = new HashMap<>();
                for (Map.Entry<String, Integer> entry : request.getOptions().entrySet()) {
                    switch (entry.getKey()) {
                        case "blksize": {
                            final int requested = entry.getValue();
                            // RFC 2348: values below 8 are invalid (option left unacknowledged,
                            // default size stays in effect); the server may answer with a
                            // smaller value than requested, so oversized requests are capped.
                            if (requested >= MIN_BLOCK_SIZE) {
                                blockSize = Math.min(requested, MAX_BLOCK_SIZE);
                                acks.put(entry.getKey(), blockSize);
                                tftp.restartBufferedOps(blockSize + DATA_HEADER_SIZE);
                            }
                            break;
                        }
                        case "tsize":
                            // Client wants to know transfer size.
                            // available() is only an estimate for some stream types.
                            acks.put(entry.getKey(), dataInputStream.available());
                            break;
                    }
                }

                // The client acknowledges the OACK with block 0, which triggers block 1
                blockNumber = 0;
                clientAddress = address;
                tftp.bufferedSend(new TFTPOptionAckPacket(address, port, acks));
                return;
            }

            blockNumber = 0;
            sendNextBlock(address, port);
        } catch (FileNotFoundException ex) {
            tftp.bufferedSend(new TFTPErrorPacket(address, port, TFTPErrorPacket.FILE_NOT_FOUND, ex.getMessage()));
            clientAddress = null;
        }
    }

    /** Sends the next block when the client acknowledged the last one sent. */
    private void handleAck(TFTPAckPacket ackPacket) throws IOException {
        final Logger log = Logger.getLogger(getClass());
        if (log.isDebugEnabled()) {
            log.debug("ACK : " + ackPacket.getBlockNumber() + " ~ " + blockNumber * blockSize + "/" + dataInputStream.available() + " " + ackPacket.getAddress().getCanonicalHostName());
        }

        // Check if client ACKd correctly. Block numbers are 16 bit on the wire, so
        // compare against the wrapped counter to keep transfers beyond 65535 blocks going.
        // NOTE(review): relies on TFTPPXEDataPacket truncating the block number to
        // 16 bits when encoding - confirm.
        if (ackPacket.getBlockNumber() == (blockNumber & 0xFFFF)) {
            sendNextBlock(ackPacket.getAddress(), ackPacket.getPort());
        }
    }

    /**
     * Reads the next block from the data source, sends it and arms the resend
     * task. A block shorter than the block size ends the transfer.
     */
    private void sendNextBlock(InetAddress address, int port) throws IOException {
        final byte[] data = new byte[blockSize];
        final int bytesRead = readBlock(data);
        blockNumber++;

        TFTPPXEDataPacket dataPacket = new TFTPPXEDataPacket(address, port, blockNumber, data, 0, bytesRead);
        final Logger log = Logger.getLogger(getClass());
        if (log.isDebugEnabled()) {
            log.debug("Sending " + blockNumber + " to " + address.getCanonicalHostName() + " with " + bytesRead + " bytes");
        }
        tftp.bufferedSend(dataPacket);

        if (bytesRead < blockSize) {
            // It is done
            clientAddress = null;
            resendTask.chill();
            dataInputStream.close();
        } else {
            // More blocks to send: keep re-sending this one until it is acknowledged
            clientAddress = address;
            resendTask.set(dataPacket, blockNumber - 1);
        }
    }

    /**
     * Fills the buffer from the data source, looping over short reads.
     *
     * @return number of bytes read; less than the buffer length only at end
     * of stream, and never negative
     */
    private int readBlock(byte[] buffer) throws IOException {
        int total = 0;
        while (total < buffer.length) {
            int n = dataInputStream.read(buffer, total, buffer.length - total);
            if (n < 0) {
                break;
            }
            total += n;
        }
        return total;
    }

    /** Drops the current transfer: stops resends, resets state and releases the data source. */
    private void abortTransfer() throws IOException {
        resendTask.chill();
        clientAddress = null;
        blockNumber = 0;

        InputStream stream = dataInputStream;
        dataInputStream = null;
        if (stream != null) {
            stream.close();
        }
    }

    /**
     * Blocks the caller until the server loop has ended.
     *
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public void waitTillQuit() throws InterruptedException {
        synchronized (this.thread) {
            // Loop guards against spurious wakeups and returns at once when the server
            // already stopped. The thread state is checked as well because a terminating
            // thread notifies its own monitor even if run() was overridden.
            while (!finished && this.thread.getState() != Thread.State.TERMINATED) {
                this.thread.wait();
            }
        }
    }

    /**
     * Periodic task that re-sends the last data packet for as long as one is
     * set. Fields are volatile because they are written by the server thread
     * and read by the executor thread.
     */
    protected class ResendTask extends TimerTask {

        protected volatile TFTPPXEDataPacket packet;
        protected volatile int blockCount;

        /**
         * Arms the task.
         *
         * @param packet packet to re-send
         * @param blockCount number of the block acknowledged before this packet was sent
         */
        public void set(TFTPPXEDataPacket packet, int blockCount) {
            this.packet = packet;
            this.blockCount = blockCount;
        }

        /** Disarms the task. */
        public void chill() {
            this.packet = null;
        }

        @Override
        public void run() {
            // Work on a snapshot: chill() may be called concurrently
            final TFTPPXEDataPacket pending = packet;
            if (pending == null) {
                return;
            }

            if (blockNumber > blockCount) {
                try {
                    Logger.getLogger(getClass()).debug("Resending " + blockNumber);
                    tftp.bufferedSend(pending);
                } catch (IOException ex) {
                    chill();
                    Logger.getLogger(getClass()).warn("TFTP resend failed", ex);
                }
            }
        }
    }

    /** Callback through which the server obtains the data for a read request. */
    public static interface ActivityListener {

        /**
         * Called for each accepted read request.
         *
         * @param packet the read request
         * @return stream with the requested data, or {@code null} if this
         * listener cannot serve the request
         */
        public InputStream pathRequest(TFTPOptionReadRequestPacket packet);
    }
}