Skip to content

Commit 8b3a832

Browse files
authored
Merge pull request #20 from danyaljj/weight-vector-utils
Weight vector utils
2 parents f8a7e2c + 0628dfb commit 8b3a832

File tree

3 files changed

+204
-2
lines changed

3 files changed

+204
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Maven Coordinates
1616
-----------------
1717
To use Illinois-SL in your project add the following to your pom,
1818

19-
```
19+
```xml
2020
<dependencies>
2121
...
2222
<dependency>

pom.xml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
</repository>
2828
</distributionManagement>
2929

30-
3130
<build>
3231
<plugins>
3332
<plugin>
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package edu.illinois.cs.cogcomp.sl.util;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
6+
import java.io.*;
7+
import java.net.URISyntaxException;
8+
import java.net.URL;
9+
import java.net.URLDecoder;
10+
import java.util.ArrayList;
11+
import java.util.Enumeration;
12+
import java.util.List;
13+
import java.util.jar.JarEntry;
14+
import java.util.jar.JarFile;
15+
import java.util.zip.GZIPInputStream;
16+
import java.util.zip.GZIPOutputStream;
17+
18+
public class WeightVectorUtils {
19+
20+
private final static Logger log = LoggerFactory.getLogger(WeightVectorUtils.class);
21+
22+
public static void save(String fileName, WeightVector wv) throws IOException {
23+
BufferedOutputStream stream =
24+
new BufferedOutputStream(new GZIPOutputStream(new FileOutputStream(fileName)));
25+
26+
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream));
27+
28+
float[] w = wv.getWeightArray();
29+
30+
writer.write("WeightVector");
31+
writer.newLine();
32+
33+
writer.write(w.length + "");
34+
writer.newLine();
35+
36+
int numNonZero = 0;
37+
for (int index = 0; index < w.length; index++) {
38+
if (w[index] != 0) {
39+
writer.write(index + ":" + w[index]);
40+
writer.newLine();
41+
numNonZero++;
42+
}
43+
}
44+
45+
writer.close();
46+
47+
log.info("Number of non zero weights: " + numNonZero);
48+
}
49+
50+
public static WeightVector load(String fileName) {
51+
try {
52+
GZIPInputStream zipin = new GZIPInputStream(new FileInputStream(fileName));
53+
54+
BufferedReader reader = new BufferedReader(new InputStreamReader(zipin));
55+
56+
String line;
57+
58+
line = reader.readLine().trim();
59+
if (!line.equals("WeightVector")) {
60+
reader.close();
61+
throw new IOException("Invalid model file.");
62+
}
63+
64+
line = reader.readLine().trim();
65+
int size = Integer.parseInt(line);
66+
67+
WeightVector w = new WeightVector(size);
68+
69+
while ((line = reader.readLine()) != null) {
70+
line = line.trim();
71+
String[] parts = line.split(":");
72+
int index = Integer.parseInt(parts[0]);
73+
float value = Float.parseFloat(parts[1]);
74+
w.setElement(index, value);
75+
}
76+
77+
zipin.close();
78+
79+
return w;
80+
} catch (Exception e) {
81+
log.error("Error loading model file {}", fileName);
82+
System.exit(-1);
83+
}
84+
return null;
85+
}
86+
87+
public static WeightVector loadWeightVectorFromClassPath(String fileName) {
88+
try {
89+
Class<WeightVectorUtils> clazz = WeightVectorUtils.class;
90+
List<URL> list = lsResources(clazz, fileName);
91+
92+
if (list.size() == 0) {
93+
log.error("File {} not found on the classpath", fileName);
94+
throw new Exception("File not found on classpath");
95+
}
96+
InputStream stream = list.get(0).openStream();
97+
98+
GZIPInputStream zipin = new GZIPInputStream(stream);
99+
100+
BufferedReader reader = new BufferedReader(new InputStreamReader(zipin));
101+
102+
String line;
103+
104+
line = reader.readLine().trim();
105+
if (!line.equals("WeightVector")) {
106+
reader.close();
107+
throw new IOException("Invalid model file.");
108+
}
109+
110+
line = reader.readLine().trim();
111+
int size = Integer.parseInt(line);
112+
113+
WeightVector w = new WeightVector(size);
114+
115+
while ((line = reader.readLine()) != null) {
116+
line = line.trim();
117+
String[] parts = line.split(":");
118+
int index = Integer.parseInt(parts[0]);
119+
float value = Float.parseFloat(parts[1]);
120+
w.setElement(index, value);
121+
}
122+
123+
zipin.close();
124+
return w;
125+
} catch (Exception e) {
126+
log.error("Error loading model file {}", fileName);
127+
System.exit(-1);
128+
}
129+
return null;
130+
}
131+
132+
133+
/**
134+
* Lists resources that are contained within a path. This works for any resource on the
135+
* classpath, either in the file system or in a jar file. The function returns a list of URLs,
136+
* connections to which can be opened for reading.
137+
* <p>
138+
* <b>NB</b>: This method works only for full file names. If you need to list the files of a
139+
* directory contained in the classpath use lsResourcesDir(Class, String) in illinois-core-utilities
140+
*
141+
* @param clazz The class whose path is scanned
142+
* @param path The name of the resource(s) to be returned
143+
* @return A list of URLs
144+
*/
145+
public static List<URL> lsResources(Class clazz, String path) throws URISyntaxException,
146+
IOException {
147+
URL dirURL = clazz.getResource(path);
148+
149+
if (dirURL == null) {
150+
ClassLoader loader = Thread.currentThread().getContextClassLoader();
151+
dirURL = loader.getResource(path);
152+
}
153+
154+
if (dirURL == null) {
155+
return new ArrayList<>();
156+
}
157+
158+
String dirPath = dirURL.getPath();
159+
if (dirURL.getProtocol().equals("file")) {
160+
String[] list = new File(dirURL.toURI()).list();
161+
List<URL> urls = new ArrayList<>();
162+
163+
if (list == null) {
164+
// if the list is null, but the dirURL is not, then dirURL is
165+
// actually a file!
166+
urls.add(dirURL);
167+
} else {
168+
for (String l : list) {
169+
URL url = (new File(dirPath + File.separator + l)).toURI().toURL();
170+
urls.add(url);
171+
}
172+
}
173+
return urls;
174+
}
175+
176+
if (dirURL.getProtocol().equals("jar")) {
177+
int exclamation = dirPath.indexOf("!");
178+
String jarPath = dirPath.substring(5, exclamation);
179+
String jarRoot = dirPath.substring(0, exclamation + 1);
180+
181+
JarFile jar = new JarFile(URLDecoder.decode(jarPath, "UTF-8"));
182+
Enumeration<JarEntry> entries = jar.entries();
183+
184+
List<URL> urls = new ArrayList<>();
185+
while (entries.hasMoreElements()) {
186+
JarEntry element = entries.nextElement();
187+
188+
String name = element.getName();
189+
190+
// Because the path string comes from JarEntry, We SHOULD use
191+
// '/'' instead of File.SEPERATOR.
192+
// And it seems that the only way to figure out if a JarEntry
193+
// path is a folder or file is to check the last character.
194+
if (name.startsWith(path) && !name.equals(path + "/")) {
195+
URL url = new URL("jar:" + jarRoot + "/" + name);
196+
urls.add(url);
197+
}
198+
}
199+
return urls;
200+
}
201+
throw new UnsupportedOperationException("Cannot list files for URL " + dirURL);
202+
}
203+
}

0 commit comments

Comments
 (0)