1+ /*
2+ * Copyright 2024 The Android Open Source Project
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * https://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+
17+ package com.example.platform.media.video
18+
19+ import android.content.Context
20+ import android.graphics.Bitmap
21+ import android.graphics.BitmapFactory
22+ import android.graphics.Matrix
23+ import androidx.media3.common.GlTextureInfo
24+ import androidx.media3.common.VideoFrameProcessingException
25+ import androidx.media3.common.util.GlRect
26+ import androidx.media3.common.util.GlUtil
27+ import androidx.media3.common.util.Size
28+ import androidx.media3.common.util.UnstableApi
29+ import androidx.media3.common.util.Util
30+ import androidx.media3.effect.ByteBufferGlEffect
31+ import com.google.common.collect.ImmutableMap
32+ import com.google.common.util.concurrent.ListenableFuture
33+ import com.google.common.util.concurrent.ListeningExecutorService
34+ import com.google.common.util.concurrent.MoreExecutors
35+ import org.tensorflow.lite.DataType
36+ import org.tensorflow.lite.Interpreter
37+ import org.tensorflow.lite.InterpreterApi
38+ import org.tensorflow.lite.gpu.CompatibilityList
39+ import org.tensorflow.lite.gpu.GpuDelegate
40+ import org.tensorflow.lite.support.common.FileUtil
41+ import org.tensorflow.lite.support.common.ops.DequantizeOp
42+ import org.tensorflow.lite.support.common.ops.NormalizeOp
43+ import org.tensorflow.lite.support.image.ImageProcessor
44+ import org.tensorflow.lite.support.image.TensorImage
45+ import org.tensorflow.lite.support.image.ops.ResizeOp
46+ import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
47+ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
48+ import java.util.concurrent.Future
49+
50+ @UnstableApi
51+ class StyleTransferEffect (context : Context , styleAssetFileName : String ) : ByteBufferGlEffect.Processor<Bitmap> {
52+
53+ private val transformInterpreter: InterpreterApi
54+ private val inputTransformTargetHeight: Int
55+ private val inputTransformTargetWidth: Int
56+ private val outputTransformShape: IntArray
57+
58+ private var preProcess: ListeningExecutorService = MoreExecutors .listeningDecorator(
59+ Util .newSingleThreadExecutor(" preProcess" ))
60+ private var postProcess: ListeningExecutorService = MoreExecutors .listeningDecorator(
61+ Util .newSingleThreadExecutor(" postProcess" ))
62+ private var tfRun: ListeningExecutorService = MoreExecutors .listeningDecorator(
63+ Util .newSingleThreadExecutor(" tfRun" ))
64+
65+ private val predictOutput: TensorBuffer
66+
67+ private var inputWidth: Int = 0
68+ private var inputHeight: Int = 0
69+
70+
71+ init {
72+ val options = Interpreter .Options ()
73+ val compatibilityList = CompatibilityList ()
74+ val gpuDelegateOptions = compatibilityList.bestOptionsForThisDevice
75+ val gpuDelegate = GpuDelegate (gpuDelegateOptions)
76+ options.addDelegate(gpuDelegate)
77+ val predictModel = " predict_float16.tflite"
78+ val transferModel = " transfer_float16.tflite"
79+ val predictInterpreter = Interpreter (FileUtil .loadMappedFile(context, predictModel), options)
80+ transformInterpreter = InterpreterApi .create(FileUtil .loadMappedFile(context, transferModel), options)
81+ val inputPredictTargetHeight = predictInterpreter.getInputTensor(0 ).shape()[1 ]
82+ val inputPredictTargetWidth = predictInterpreter.getInputTensor(0 ).shape()[2 ]
83+ val outputPredictShape = predictInterpreter.getOutputTensor(0 ).shape()
84+
85+ inputTransformTargetHeight = transformInterpreter.getInputTensor(0 ).shape()[1 ]
86+ inputTransformTargetWidth = transformInterpreter.getInputTensor(0 ).shape()[2 ]
87+ outputTransformShape = transformInterpreter.getOutputTensor(0 ).shape()
88+
89+ val inputStream = context.assets.open(styleAssetFileName)
90+ val styleImage = BitmapFactory .decodeStream(inputStream)
91+ inputStream.close()
92+ val styleTensorImage = getScaledTensorImage(styleImage, inputPredictTargetWidth, inputPredictTargetHeight)
93+ predictOutput = TensorBuffer .createFixedSize(outputPredictShape, DataType .FLOAT32 )
94+ predictInterpreter.run (styleTensorImage.buffer, predictOutput.buffer)
95+ }
96+
97+ override fun configure (inputWidth : Int , inputHeight : Int ): Size {
98+ this .inputWidth = inputWidth
99+ this .inputHeight = inputHeight
100+ return Size (inputTransformTargetWidth, inputTransformTargetHeight)
101+ }
102+
103+ override fun getScaledRegion (presentationTimeUs : Long ): GlRect {
104+ val minSide = minOf(inputWidth, inputHeight)
105+ return GlRect (0 , 0 , minSide, minSide)
106+ }
107+
108+ override fun processImage (
109+ image : ByteBufferGlEffect .Image ,
110+ presentationTimeUs : Long ,
111+ ): ListenableFuture <Bitmap > {
112+ val tensorImageFuture = preProcess(image)
113+ val tensorBufferFuture = tfRun(tensorImageFuture)
114+ return postProcess(tensorBufferFuture)
115+ }
116+
117+ override fun release () {}
118+
119+ override fun finishProcessingAndBlend (
120+ outputFrame : GlTextureInfo ,
121+ presentationTimeUs : Long ,
122+ result : Bitmap ,
123+ ) {
124+ try {
125+ copyBitmapToFbo(result, outputFrame, getScaledRegion(presentationTimeUs))
126+ } catch (e: GlUtil .GlException ) {
127+ throw VideoFrameProcessingException .from(e)
128+ }
129+ }
130+
131+ private fun preProcess (image : ByteBufferGlEffect .Image ): ListenableFuture <TensorImage > {
132+ return preProcess.submit<TensorImage > {
133+ val bitmap = image.copyToBitmap()
134+ getScaledTensorImage(bitmap, inputTransformTargetWidth, inputTransformTargetHeight)
135+ }
136+ }
137+
138+ private fun tfRun (tensorImageFuture : Future <TensorImage >): ListenableFuture <TensorBuffer > {
139+ return tfRun.submit<TensorBuffer > {
140+ val tensorImage = tensorImageFuture.get()
141+ val outputImage = TensorBuffer .createFixedSize(outputTransformShape, DataType .FLOAT32 )
142+
143+ transformInterpreter.runForMultipleInputsOutputs(
144+ arrayOf(tensorImage.buffer, predictOutput.buffer),
145+ ImmutableMap .builder<Int , Any >().put(0 , outputImage.buffer).build()
146+ )
147+ outputImage
148+ }
149+ }
150+
151+ private fun postProcess (futureOutputImage : ListenableFuture <TensorBuffer >): ListenableFuture <Bitmap > {
152+ return postProcess.submit<Bitmap > {
153+ val outputImage = futureOutputImage.get()
154+ val imagePostProcessor = ImageProcessor .Builder ()
155+ .add(DequantizeOp (0f , 255f ))
156+ .build()
157+ val outputTensorImage = TensorImage (DataType .FLOAT32 )
158+ outputTensorImage.load(outputImage)
159+ imagePostProcessor.process(outputTensorImage).bitmap
160+ }
161+ }
162+
163+ private fun getScaledTensorImage (bitmap : Bitmap , targetWidth : Int , targetHeight : Int ): TensorImage {
164+ val cropSize = minOf(bitmap.width, bitmap.height)
165+ val imageProcessor = ImageProcessor .Builder ()
166+ .add(ResizeWithCropOrPadOp (cropSize, cropSize))
167+ .add(ResizeOp (targetHeight, targetWidth, ResizeOp .ResizeMethod .BILINEAR ))
168+ .add(NormalizeOp (0f , 255f ))
169+ .build()
170+ val tensorImage = TensorImage (DataType .FLOAT32 )
171+ tensorImage.load(bitmap)
172+ return imageProcessor.process(tensorImage)
173+ }
174+
175+ private fun copyBitmapToFbo (bitmap : Bitmap , textureInfo : GlTextureInfo , rect : GlRect ) {
176+ val bitmapToGl = Matrix ().apply { setScale(1f , - 1f ) }
177+ val texId = GlUtil .createTexture(bitmap.width, bitmap.height, false )
178+ val fboId = GlUtil .createFboForTexture(texId)
179+ GlUtil .setTexture(texId,
180+ Bitmap .createBitmap(bitmap, 0 , 0 , bitmap.width, bitmap.height, bitmapToGl, true ))
181+ GlUtil .blitFrameBuffer(fboId, GlRect (0 , 0 , bitmap.width, bitmap.height), textureInfo.fboId, rect)
182+ GlUtil .deleteTexture(texId)
183+ GlUtil .deleteFbo(fboId)
184+ }
185+ }
0 commit comments