@@ -66,13 +66,17 @@ object TensorFlow {
6666 val classLoader = Thread .currentThread.getContextClassLoader
6767
6868 // Check if a TensorFlow native framework library resources are provided and load them.
69- (makeResourceNames(LIB_FRAMEWORK_NAME ) ++ makeResourceNames(LIB_NAME )).map {
70- case (name, path) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
69+ (makeResourceNames(LIB_FRAMEWORK_NAME ) ++ makeResourceNames(LIB_NAME )).foreach {
70+ case (name, path, preLoad) =>
71+ val resource = extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
72+ if (preLoad) {
73+ resource.foreach(r => System .load(r.toAbsolutePath.toString))
74+ }
7175 }
7276
7377 // Load the TensorFlow JNI bindings from the appropriate resource.
7478 val jniPaths = makeResourceNames(JNI_LIB_NAME ).flatMap {
75- case (name, path) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
79+ case (name, path, _ ) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
7680 }
7781 if (jniPaths.isEmpty) {
7882 throw new UnsatisfiedLinkError (
@@ -92,7 +96,7 @@ object TensorFlow {
9296
9397 // Load the TensorFlow ops library from the appropriate resource.
9498 val opsPaths = makeResourceNames(OPS_LIB_NAME ).flatMap {
95- case (name, path) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
99+ case (name, path, _ ) => extractResource(name, classLoader.getResourceAsStream(path), tempDirectory)
96100 }
97101 opsPaths.foreach(path => loadOpLibrary(path.toAbsolutePath.toString))
98102 }
@@ -110,29 +114,29 @@ object TensorFlow {
110114
111115 /** Maps the provided library name to a set of filenames, similar to [[System.mapLibraryName ]], but considering all
112116 * combinations of `dylib` and `so` extensions, along with versioning for TensorFlow 2.x. */
113- private def mapLibraryName (lib : String ): Seq [String ] = {
117+ private def mapLibraryName (lib : String ): Seq [( String , Boolean ) ] = {
114118 if (platform == " windows" ) {
115- Seq (s " $lib.dll " , s " $lib.lib " )
119+ Seq (( s " $lib.dll " , true ), ( s " $lib.lib " , false ) )
116120 } else if (lib == JNI_LIB_NAME || lib == OPS_LIB_NAME ) {
117- Seq (s " lib $lib.so " )
121+ Seq (( s " lib $lib.so " , false ) )
118122 } else {
119123 Seq (
120- s " lib $lib.so.link " ,
121- s " lib $lib.so.2.link " ,
122- s " lib $lib.so.2.3.0 " ,
123- s " lib $lib.dylib.link " ,
124- s " lib $lib.2.dylib.link " ,
125- s " lib $lib.2.3.0.dylib " ,
124+ ( s " lib $lib.so.link " , false ) ,
125+ ( s " lib $lib.so.2.link " , false ) ,
126+ ( s " lib $lib.so.2.3.0 " , false ) ,
127+ ( s " lib $lib.dylib.link " , false ) ,
128+ ( s " lib $lib.2.dylib.link " , false ) ,
129+ ( s " lib $lib.2.3.0.dylib " , false ) ,
126130 )
127131 }
128132 }
129133
130134 /** Generates the resource names and paths for the specified library. */
131- private def makeResourceNames (lib : String ): Seq [(String , String )] = {
135+ private def makeResourceNames (lib : String ): Seq [(String , String , Boolean )] = {
132136 if (lib == LIB_NAME || lib == LIB_FRAMEWORK_NAME ) {
133- mapLibraryName(lib).map(name => (name, name))
137+ mapLibraryName(lib).map { case (name, preLoad) => (name, name, preLoad) }
134138 } else {
135- mapLibraryName(lib).map(name => (name, s " native/ $platform/ $name" ))
139+ mapLibraryName(lib).map { case (name, preLoad) => (name, s " native/ $platform/ $name" , preLoad) }
136140 }
137141 }
138142
0 commit comments