/* * Copyright (C) 2010-2014 Mathias Unberath * CONRAD is developed as an Open Source project under the GNU General Public License (GPL). */ package edu.stanford.rsl.apps.activeshapemodel; import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.StringTokenizer; import edu.stanford.rsl.conrad.geometry.shapes.activeshapemodels.GPA; import edu.stanford.rsl.conrad.geometry.shapes.activeshapemodels.KPCA; import edu.stanford.rsl.conrad.geometry.shapes.activeshapemodels.PCA; import edu.stanford.rsl.conrad.geometry.shapes.activeshapemodels.kernels.GaussianKernel; import edu.stanford.rsl.conrad.geometry.shapes.mesh.DataMatrix; import edu.stanford.rsl.conrad.geometry.shapes.mesh.Mesh; import edu.stanford.rsl.conrad.io.ParseInfoFile; import edu.stanford.rsl.conrad.numerics.SimpleMatrix; public class CreateARFF { /** * Lists all the heart components available for modeling. * @author Mathias Unberath * */ public enum heartComponents{ //MYOCARDIUM("myocardium", 4660, 9316), LEFT_VENTRICLE("leftVentricle", 4038, 8072); private String filename; private int numVertices; private int numTriangles; private heartComponents(String fn, int numVertices, int numTriangles){ this.filename = fn; this.numVertices = numVertices; this.numTriangles = numTriangles; } public String getName(){ return this.filename; } public int getNumVertices(){ return this.numVertices; } public int getNumTriangles(){ return this.numTriangles; } } /** * Path to model data, i.e. the meshes stored in folders corresponding to their heart phase. * The naming convention needs to follow: * .../study_id/.../phase_#/meshname.vtk */ public static final File DATA_PATH = new File("E:\\_uni_\\Masterthesis\\Data\\"); /** * Path to the folder, where the heart model PCA files are stored. */ public static final String HEART_MODEL_BASE = "C:\\research\\data\\Test\\WEKA\\"; /** * Number of phases obtained in the dynamic CT scan. */ public static final int numPhases = 10; /** * Number of components in the whole heart model. */ private static final int numModelComponents = 2; /** * The vertex dimension of the model's vertices. */ public static final int vertexDimension = 3; /** * Variation threshold. */ static double variationTh = 0.9; /** * Keyword indicating folders containing phases. */ private static final String PHASE_KEY = "phase_"; /** * Keyword indicating folders containing phase-folders that have already been segmented and contain meshes. */ private static final String ANALYSIS_KEY = "analysis"; private static int[] vertexOffs = new int[heartComponents.values().length]; private static int totalVertices; private static int[] triangleOffs = new int[heartComponents.values().length]; private static int totalTriangles; private static int[] principalComponents = new int[numPhases]; @SuppressWarnings("unused") private static double radialKernel = 5; private static int numComponents = 2; //========================================================================================== // METHODS //========================================================================================== public static void main(String[] args) throws Exception{ /* radialKernel = UserUtil.queryDouble("Specify sigma for Gaussian kernel:", 5.0); numComponents = UserUtil.queryInt("Specify number of principal components for projetion:", 3); variationTh = UserUtil.queryDouble("Select variance threshold for principal component dimensionality reduction:", 0.9); */ // initialize offset array for easier use later on int c = 0; totalVertices = 0; totalTriangles = 0; for(heartComponents hc : heartComponents.values()){ vertexOffs[c] = totalVertices; triangleOffs[c] = totalTriangles; totalVertices += hc.getNumVertices(); totalTriangles += hc.getNumTriangles(); c++; } // get all valid folders ArrayList<String> folders = getValidFolders(); // run GPA and PCA ArrayList< double[] > scores = performPCA(folders); String pFn = HEART_MODEL_BASE + "CCmExamples.ccm"; writeScores(pFn, folders, scores); writeARFF(folders,scores); } /** * This method performs GPA and PCA on all heart model components. It uses data from all folders passed to it via the folders list. * @param folders The folders being used for input. */ private static ArrayList< double[] > performPCA(ArrayList<String> folders){ // output folder will be named: CardiacModel //String outFolder = HEART_MODEL_BASE + "\\" + "CardiacModel\\"; new File(HEART_MODEL_BASE).mkdirs(); // create ONE mesh object at each phase for each training set and perform GPA and PCA on ALL heart components at the same time for(int i = 0; i < numPhases; i++){ GPA populationGPA = new GPA(folders.size()); for(int j = 0; j < folders.size(); j++){ int componentCount = 0; SimpleMatrix currentVert = new SimpleMatrix(totalVertices,3); if(j == 0){ SimpleMatrix currentTriangles = new SimpleMatrix(totalTriangles,3); for(heartComponents hc : heartComponents.values()){ String currentComponent = "wrp_" + hc.getName() + ".vtk"; String currentFile = folders.get(j) + "\\" + PHASE_KEY + i + "\\" + currentComponent; Mesh currentMesh = new Mesh(currentFile); currentVert.setSubMatrixValue(vertexOffs[componentCount], 0, currentMesh.getPoints()); currentTriangles.setSubMatrixValue(triangleOffs[componentCount], 0, currentMesh.getConnectivity()); componentCount++; } populationGPA.setConnectivity(currentTriangles); }else{ for(heartComponents hc : heartComponents.values()){ String currentComponent = "wrp_" + hc.getName() + ".vtk"; String currentFile = folders.get(j) + "\\" + PHASE_KEY + i + "\\" + currentComponent; Mesh currentMesh = new Mesh(currentFile); currentVert.setSubMatrixValue(vertexOffs[componentCount], 0, currentMesh.getPoints()); componentCount++; } } populationGPA.addElement(j, currentVert); } populationGPA.runGPA(); // create PCA file for each phase KPCA kPCA = new KPCA(new DataMatrix(populationGPA)); //kPCA.setKernel( new RadialKernel(900) ); //kPCA.setKernel( new PolynomialKernel(1,1,0) ); kPCA.setKernel( new GaussianKernel(25) ); kPCA.numProjections = numComponents; kPCA.run(); principalComponents[i] = kPCA.numProjections; // project each shape System.out.println("Projecting shapes."); kPCA.projectTrainingSets(); ArrayList< double[] > scores = new ArrayList< double[] >(); for(int j = 0; j < folders.size(); j++){ scores.add(kPCA.scores.getCol(j).copyAsDoubleArray()); } String scoresFilename = HEART_MODEL_BASE + "\\" + "phase_" + Integer.valueOf(i) + ".ccs"; writeScores(scoresFilename, folders, scores); System.out.println("______________________________________"); System.out.println("Finished work on phase: " + i); } System.out.println("\n\n"); System.out.println("Parameter PCA starting."); int totPC = 0; for(int i = 0; i < numPhases; i++){ totPC += principalComponents[i]; } SimpleMatrix scores = new SimpleMatrix(totPC, folders.size()); int offs = 0; for(int i = 0; i < numPhases; i++){ offs += (i == 0) ? 0 : principalComponents[i-1]; String cnfgFile = HEART_MODEL_BASE + "\\" + "phase_" + Integer.valueOf(i) + ".ccs"; scores.setSubMatrixValue(offs, 0, parseScores(cnfgFile)); } PCA scorePCA = new PCA(scores, 1); scorePCA.variationThreshold = variationTh; //scorePCA.setKernel( new GaussianKernel(Math.sqrt(2)*2)); //scorePCA.numProjections = 6; scorePCA.run(); //scorePCA.projectTrainingSets(); //SimpleMatrix s2 = scorePCA.scores; ArrayList< double[] > proj = new ArrayList< double[] >(); for(int i = 0; i < folders.size(); i++){ //proj.add(s2.getCol(i).copyAsDoubleArray()); proj.add(scorePCA.projectTrainingShape(i)); } System.out.println("Done."); return proj; } /** * Reads the scores for each data set from a .ccs file. * @param filename * @return */ private static SimpleMatrix parseScores(String filename){ try { FileReader fr = new FileReader(filename); BufferedReader br = new BufferedReader(fr); String line = br.readLine(); StringTokenizer tok = new StringTokenizer(line); tok.nextToken(); // skip "NUM_SAMPLES:" int numSamples = Integer.parseInt(tok.nextToken()); line = br.readLine(); tok = new StringTokenizer(line); tok.nextToken(); // skip "NUM_PRINCIPAL_COMPONENTS:" int numPC = Integer.parseInt(tok.nextToken()); SimpleMatrix m = new SimpleMatrix(numPC, numSamples); for(int i = 0; i < numSamples; i++){ line = br.readLine(); tok = new StringTokenizer(line); tok.nextToken(); // skip "<STUDY_NAME>:" for(int j = 0; j < numPC; j++){ m.setElementValue(j, i, Double.parseDouble(tok.nextToken())); } } br.close(); fr.close(); return m; } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } return new SimpleMatrix(); } private static void writeARFF(ArrayList<String> folders, ArrayList<double[]> scores ){ System.out.println("\n\n"); System.out.println("Writing file."); // output folder will be named: CardiacModel String outFolder = HEART_MODEL_BASE + "\\"; new File(outFolder).mkdirs(); String att = "@ATTRIBUTE "; try { PrintWriter writer = new PrintWriter(outFolder + "DilatedCardiomyopathy.arff","UTF-8"); writer.println("@RELATION " + "DCM"); writer.println(); //writer.println(att + "GENDER"+ "\t" +"{M,F}"); //writer.println(att + "AGE"+ "\t" +"NUMERIC"); for(int i = 0; i < scores.get(0).length; i++){ String line = att + "PC-" + Integer.valueOf(i) + "\t" + "NUMERIC"; writer.println(line); } // NAD nothing abnormal detected, HT hypertrophic cardiomyopathy, DC dilated cardiomyopathy writer.println(att + "DILATED"+ "\t" +"{NAD,DC}"); writer.println("@DATA"); for(int i = 0; i < folders.size(); i++){ String infoFile = folders.get(i).substring(0, folders.get(0).lastIndexOf("\\")+1) + "info.txt"; ParseInfoFile info = new ParseInfoFile(infoFile); String line = "";//info.gender +"," + info.age; for(int k = 0; k < scores.get(0).length; k++){ if(k==0){ line += scores.get(i)[k]; }else{ line += "," + scores.get(i)[k]; } } line += "," + info.attribute; writer.println(line); } writer.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (UnsupportedEncodingException e) { e.printStackTrace(); } System.out.println("Done."); } /** * Writes the scores to file. * @param filename * @param names * @param scores */ private static void writeScores(String filename, ArrayList<String> names, ArrayList<double[]> scores){ try { PrintWriter writer = new PrintWriter(filename,"UTF-8"); writer.println("NUM_SAMPLES: " + names.size()); writer.println("NUM_PRINCIPAL_COMPONENTS: " + scores.get(0).length); for(int i = 0; i < names.size(); i++){ String line = names.get(i); for( int j = 0; j < scores.get(0).length; j++){ line += " " + Double.valueOf(scores.get(i)[j]); } writer.println(line); } writer.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (UnsupportedEncodingException e) { e.printStackTrace(); } } /** * Searches all sub-directories for the keyword ANALYSIS_KEY. This keyword indicates, that registration has been performed and meshes exist and * are assumed to be stored in this folder. * @return An ArrayList containing the folders. */ private static ArrayList<String> getValidFolders(){ String[] fl = DATA_PATH.list(); ArrayList<String> validFolders = new ArrayList<String>(); for(int i = 0; i < fl.length; i++){ File fi = new File(DATA_PATH + "\\" + fl[i]); if(fi.isDirectory()){ String[] list = fi.list(); for(int j = 0; j < list.length; j++){ if(list[j].contains(ANALYSIS_KEY)){ validFolders.add(DATA_PATH + "\\" + fl[i] + "\\" + list[j]); } } } } return checkIfAllFoldersValid(validFolders); } /** * This method checks if all directories in the ArrayList contain the necessary amount of phases. * @param f The ArrayList of directories to be checked. * @return An ArrayList containing only the valid directories. */ private static ArrayList<String> checkIfAllFoldersValid(ArrayList<String> f){ ArrayList<String> fl = new ArrayList<String>(); for(int i = 0; i < f.size(); i++){ File file = new File(f.get(i)); String[] list = file.list(); int[] count = new int[numPhases]; int cnt = 0; for(int j = 0; j < list.length; j++){ File check = new File(f.get(i) + "\\" + list[j]); if(check.isDirectory() && check.getName().contains(PHASE_KEY)){ int strPos = check.getName().indexOf(PHASE_KEY) + PHASE_KEY.length(); int idx = Integer.valueOf(check.getName().substring(strPos)); count[idx] = 1; } } for(int k = 0; k < numPhases; k++){ cnt += count[k]; } if(cnt == numPhases){ fl.add(f.get(i)); }else{ System.out.println("Missing files in dataset: " + f.get(i)); } } return fl; } }