Skip to content

Commit 4c1c95a

Browse files
committed
[JNI] Fixed the loading for windows libraries.
1 parent 3f23f5f commit 4c1c95a

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

modules/jni/src/main/scala/org/platanios/tensorflow/jni/TensorFlow.scala

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,17 @@ object TensorFlow {
6666
val classLoader = Thread.currentThread.getContextClassLoader
6767

6868
// Check if a TensorFlow native framework library resources are provided and load them.
69-
(makeResourceNames(LIB_FRAMEWORK_NAME) ++ makeResourceNames(LIB_NAME)).map {
70-
case (name, path) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
69+
(makeResourceNames(LIB_FRAMEWORK_NAME) ++ makeResourceNames(LIB_NAME)).foreach {
70+
case (name, path, preLoad) =>
71+
val resource = extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
72+
if (preLoad) {
73+
resource.foreach(r => System.load(r.toAbsolutePath.toString))
74+
}
7175
}
7276

7377
// Load the TensorFlow JNI bindings from the appropriate resource.
7478
val jniPaths = makeResourceNames(JNI_LIB_NAME).flatMap {
75-
case (name, path) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
79+
case (name, path, _) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
7680
}
7781
if (jniPaths.isEmpty) {
7882
throw new UnsatisfiedLinkError(
@@ -92,7 +96,7 @@ object TensorFlow {
9296

9397
// Load the TensorFlow ops library from the appropriate resource.
9498
val opsPaths = makeResourceNames(OPS_LIB_NAME).flatMap {
95-
case (name, path) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
99+
case (name, path, _) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
96100
}
97101
opsPaths.foreach(path => loadOpLibrary(path.toAbsolutePath.toString))
98102
}
@@ -110,29 +114,29 @@ object TensorFlow {
110114

111115
/** Maps the provided library name to a set of filenames, similar to [[System.mapLibraryName]], but considering all
112116
* combinations of `dylib` and `so` extensions, along with versioning for TensorFlow 2.x. */
113-
private def mapLibraryName(lib: String): Seq[String] = {
117+
private def mapLibraryName(lib: String): Seq[(String, Boolean)] = {
114118
if (platform == "windows") {
115-
Seq(s"$lib.dll", s"$lib.lib")
119+
Seq((s"$lib.dll", true), (s"$lib.lib", false))
116120
} else if (lib == JNI_LIB_NAME || lib == OPS_LIB_NAME) {
117-
Seq(s"lib$lib.so")
121+
Seq((s"lib$lib.so", false))
118122
} else {
119123
Seq(
120-
s"lib$lib.so.link",
121-
s"lib$lib.so.2.link",
122-
s"lib$lib.so.2.3.0",
123-
s"lib$lib.dylib.link",
124-
s"lib$lib.2.dylib.link",
125-
s"lib$lib.2.3.0.dylib",
124+
(s"lib$lib.so.link", false),
125+
(s"lib$lib.so.2.link", false),
126+
(s"lib$lib.so.2.3.0", false),
127+
(s"lib$lib.dylib.link", false),
128+
(s"lib$lib.2.dylib.link", false),
129+
(s"lib$lib.2.3.0.dylib", false),
126130
)
127131
}
128132
}
129133

130134
/** Generates the resource names and paths for the specified library. */
131-
private def makeResourceNames(lib: String): Seq[(String, String)] = {
135+
private def makeResourceNames(lib: String): Seq[(String, String, Boolean)] = {
132136
if (lib == LIB_NAME || lib == LIB_FRAMEWORK_NAME) {
133-
mapLibraryName(lib).map(name => (name, name))
137+
mapLibraryName(lib).map { case (name, preLoad) => (name, name, preLoad) }
134138
} else {
135-
mapLibraryName(lib).map(name => (name, s"native/$platform/$name"))
139+
mapLibraryName(lib).map { case (name, preLoad) => (name, s"native/$platform/$name", preLoad) }
136140
}
137141
}
138142

0 commit comments

Comments
 (0)