|
| 1 | +package edu.illinois.cs.cogcomp.sl.util; |
| 2 | + |
| 3 | +import org.slf4j.Logger; |
| 4 | +import org.slf4j.LoggerFactory; |
| 5 | + |
| 6 | +import java.io.*; |
| 7 | +import java.net.URISyntaxException; |
| 8 | +import java.net.URL; |
| 9 | +import java.net.URLDecoder; |
| 10 | +import java.util.ArrayList; |
| 11 | +import java.util.Enumeration; |
| 12 | +import java.util.List; |
| 13 | +import java.util.jar.JarEntry; |
| 14 | +import java.util.jar.JarFile; |
| 15 | +import java.util.zip.GZIPInputStream; |
| 16 | +import java.util.zip.GZIPOutputStream; |
| 17 | + |
| 18 | +public class WeightVectorUtils { |
| 19 | + |
| 20 | + private final static Logger log = LoggerFactory.getLogger(WeightVectorUtils.class); |
| 21 | + |
| 22 | + public static void save(String fileName, WeightVector wv) throws IOException { |
| 23 | + BufferedOutputStream stream = |
| 24 | + new BufferedOutputStream(new GZIPOutputStream(new FileOutputStream(fileName))); |
| 25 | + |
| 26 | + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream)); |
| 27 | + |
| 28 | + float[] w = wv.getWeightArray(); |
| 29 | + |
| 30 | + writer.write("WeightVector"); |
| 31 | + writer.newLine(); |
| 32 | + |
| 33 | + writer.write(w.length + ""); |
| 34 | + writer.newLine(); |
| 35 | + |
| 36 | + int numNonZero = 0; |
| 37 | + for (int index = 0; index < w.length; index++) { |
| 38 | + if (w[index] != 0) { |
| 39 | + writer.write(index + ":" + w[index]); |
| 40 | + writer.newLine(); |
| 41 | + numNonZero++; |
| 42 | + } |
| 43 | + } |
| 44 | + |
| 45 | + writer.close(); |
| 46 | + |
| 47 | + log.info("Number of non zero weights: " + numNonZero); |
| 48 | + } |
| 49 | + |
| 50 | + public static WeightVector load(String fileName) { |
| 51 | + try { |
| 52 | + GZIPInputStream zipin = new GZIPInputStream(new FileInputStream(fileName)); |
| 53 | + |
| 54 | + BufferedReader reader = new BufferedReader(new InputStreamReader(zipin)); |
| 55 | + |
| 56 | + String line; |
| 57 | + |
| 58 | + line = reader.readLine().trim(); |
| 59 | + if (!line.equals("WeightVector")) { |
| 60 | + reader.close(); |
| 61 | + throw new IOException("Invalid model file."); |
| 62 | + } |
| 63 | + |
| 64 | + line = reader.readLine().trim(); |
| 65 | + int size = Integer.parseInt(line); |
| 66 | + |
| 67 | + WeightVector w = new WeightVector(size); |
| 68 | + |
| 69 | + while ((line = reader.readLine()) != null) { |
| 70 | + line = line.trim(); |
| 71 | + String[] parts = line.split(":"); |
| 72 | + int index = Integer.parseInt(parts[0]); |
| 73 | + float value = Float.parseFloat(parts[1]); |
| 74 | + w.setElement(index, value); |
| 75 | + } |
| 76 | + |
| 77 | + zipin.close(); |
| 78 | + |
| 79 | + return w; |
| 80 | + } catch (Exception e) { |
| 81 | + log.error("Error loading model file {}", fileName); |
| 82 | + System.exit(-1); |
| 83 | + } |
| 84 | + return null; |
| 85 | + } |
| 86 | + |
| 87 | + public static WeightVector loadWeightVectorFromClassPath(String fileName) { |
| 88 | + try { |
| 89 | + Class<WeightVectorUtils> clazz = WeightVectorUtils.class; |
| 90 | + List<URL> list = lsResources(clazz, fileName); |
| 91 | + |
| 92 | + if (list.size() == 0) { |
| 93 | + log.error("File {} not found on the classpath", fileName); |
| 94 | + throw new Exception("File not found on classpath"); |
| 95 | + } |
| 96 | + InputStream stream = list.get(0).openStream(); |
| 97 | + |
| 98 | + GZIPInputStream zipin = new GZIPInputStream(stream); |
| 99 | + |
| 100 | + BufferedReader reader = new BufferedReader(new InputStreamReader(zipin)); |
| 101 | + |
| 102 | + String line; |
| 103 | + |
| 104 | + line = reader.readLine().trim(); |
| 105 | + if (!line.equals("WeightVector")) { |
| 106 | + reader.close(); |
| 107 | + throw new IOException("Invalid model file."); |
| 108 | + } |
| 109 | + |
| 110 | + line = reader.readLine().trim(); |
| 111 | + int size = Integer.parseInt(line); |
| 112 | + |
| 113 | + WeightVector w = new WeightVector(size); |
| 114 | + |
| 115 | + while ((line = reader.readLine()) != null) { |
| 116 | + line = line.trim(); |
| 117 | + String[] parts = line.split(":"); |
| 118 | + int index = Integer.parseInt(parts[0]); |
| 119 | + float value = Float.parseFloat(parts[1]); |
| 120 | + w.setElement(index, value); |
| 121 | + } |
| 122 | + |
| 123 | + zipin.close(); |
| 124 | + return w; |
| 125 | + } catch (Exception e) { |
| 126 | + log.error("Error loading model file {}", fileName); |
| 127 | + System.exit(-1); |
| 128 | + } |
| 129 | + return null; |
| 130 | + } |
| 131 | + |
| 132 | + |
| 133 | + /** |
| 134 | + * Lists resources that are contained within a path. This works for any resource on the |
| 135 | + * classpath, either in the file system or in a jar file. The function returns a list of URLs, |
| 136 | + * connections to which can be opened for reading. |
| 137 | + * <p> |
| 138 | + * <b>NB</b>: This method works only for full file names. If you need to list the files of a |
| 139 | + * directory contained in the classpath use lsResourcesDir(Class, String) in illinois-core-utilities |
| 140 | + * |
| 141 | + * @param clazz The class whose path is scanned |
| 142 | + * @param path The name of the resource(s) to be returned |
| 143 | + * @return A list of URLs |
| 144 | + */ |
| 145 | + public static List<URL> lsResources(Class clazz, String path) throws URISyntaxException, |
| 146 | + IOException { |
| 147 | + URL dirURL = clazz.getResource(path); |
| 148 | + |
| 149 | + if (dirURL == null) { |
| 150 | + ClassLoader loader = Thread.currentThread().getContextClassLoader(); |
| 151 | + dirURL = loader.getResource(path); |
| 152 | + } |
| 153 | + |
| 154 | + if (dirURL == null) { |
| 155 | + return new ArrayList<>(); |
| 156 | + } |
| 157 | + |
| 158 | + String dirPath = dirURL.getPath(); |
| 159 | + if (dirURL.getProtocol().equals("file")) { |
| 160 | + String[] list = new File(dirURL.toURI()).list(); |
| 161 | + List<URL> urls = new ArrayList<>(); |
| 162 | + |
| 163 | + if (list == null) { |
| 164 | + // if the list is null, but the dirURL is not, then dirURL is |
| 165 | + // actually a file! |
| 166 | + urls.add(dirURL); |
| 167 | + } else { |
| 168 | + for (String l : list) { |
| 169 | + URL url = (new File(dirPath + File.separator + l)).toURI().toURL(); |
| 170 | + urls.add(url); |
| 171 | + } |
| 172 | + } |
| 173 | + return urls; |
| 174 | + } |
| 175 | + |
| 176 | + if (dirURL.getProtocol().equals("jar")) { |
| 177 | + int exclamation = dirPath.indexOf("!"); |
| 178 | + String jarPath = dirPath.substring(5, exclamation); |
| 179 | + String jarRoot = dirPath.substring(0, exclamation + 1); |
| 180 | + |
| 181 | + JarFile jar = new JarFile(URLDecoder.decode(jarPath, "UTF-8")); |
| 182 | + Enumeration<JarEntry> entries = jar.entries(); |
| 183 | + |
| 184 | + List<URL> urls = new ArrayList<>(); |
| 185 | + while (entries.hasMoreElements()) { |
| 186 | + JarEntry element = entries.nextElement(); |
| 187 | + |
| 188 | + String name = element.getName(); |
| 189 | + |
| 190 | + // Because the path string comes from JarEntry, We SHOULD use |
| 191 | + // '/'' instead of File.SEPERATOR. |
| 192 | + // And it seems that the only way to figure out if a JarEntry |
| 193 | + // path is a folder or file is to check the last character. |
| 194 | + if (name.startsWith(path) && !name.equals(path + "/")) { |
| 195 | + URL url = new URL("jar:" + jarRoot + "/" + name); |
| 196 | + urls.add(url); |
| 197 | + } |
| 198 | + } |
| 199 | + return urls; |
| 200 | + } |
| 201 | + throw new UnsupportedOperationException("Cannot list files for URL " + dirURL); |
| 202 | + } |
| 203 | +} |
0 commit comments