1+ @file:OptIn(RequiresKotlinCompilerEmbeddable ::class )
2+
13package kotlinx.benchmark.gradle
24
35import com.squareup.kotlinpoet.*
6+ import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
7+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.measureAnnotationFQN
8+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.paramAnnotationFQN
9+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.setupAnnotationFQN
10+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.teardownAnnotationFQN
11+ import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.warmupAnnotationFQN
12+ import kotlinx.benchmark.gradle.internal.generator.RequiresKotlinCompilerEmbeddable
413import java.io.File
514import java.util.*
615
@@ -10,7 +19,11 @@ internal fun generateBenchmarkSourceFiles(
1019) {
1120 classDescriptors.forEach { descriptor ->
1221 if (descriptor.visibility == Visibility .PUBLIC && ! descriptor.isAbstract) {
13- generateDescriptorFile(descriptor, targetDir)
22+ if (descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty()) {
23+ generateParameterizedDescriptorFile(descriptor, targetDir)
24+ } else {
25+ generateDescriptorFile(descriptor, targetDir)
26+ }
1427 }
1528 }
1629}
@@ -27,6 +40,12 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
2740 .addImport(" androidx.benchmark" , " BenchmarkState" )
2841 .addImport(" androidx.benchmark" , " ExperimentalBenchmarkStateApi" )
2942
43+ if (descriptor.hasSetupOrTeardownMethods()) {
44+ fileSpecBuilder
45+ .addImport(" org.junit" , " Before" )
46+ .addImport(" org.junit" , " After" )
47+ }
48+
3049 val typeSpecBuilder = TypeSpec .classBuilder(descriptorName)
3150 .addAnnotation(
3251 AnnotationSpec .builder(ClassName (" org.junit.runner" , " RunWith" ))
@@ -40,7 +59,122 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
4059 fileSpecBuilder.build().writeTo(androidTestDir)
4160}
4261
43- private fun addBenchmarkMethods (typeSpecBuilder : TypeSpec .Builder , descriptor : ClassAnnotationsDescriptor ) {
62+ private fun generateParameterizedDescriptorFile (descriptor : ClassAnnotationsDescriptor , androidTestDir : File ) {
63+ val descriptorName = " ${descriptor.name} _Descriptor"
64+ val packageName = descriptor.packageName
65+ val fileSpecBuilder = FileSpec .builder(packageName, descriptorName)
66+ .addImport(" org.junit.runner" , " RunWith" )
67+ .addImport(" org.junit.runners" , " Parameterized" )
68+ .addImport(" androidx.benchmark" , " BenchmarkState" )
69+ .addImport(" androidx.benchmark" , " ExperimentalBenchmarkStateApi" )
70+ .addImport(" org.junit" , " Test" )
71+
72+ if (descriptor.hasSetupOrTeardownMethods()) {
73+ fileSpecBuilder
74+ .addImport(" org.junit" , " Before" )
75+ .addImport(" org.junit" , " After" )
76+ }
77+
78+ fileSpecBuilder.addAnnotation(
79+ AnnotationSpec .builder(ClassName (" org.junit.runner" , " RunWith" ))
80+ .addMember(" %T::class" , ClassName (" org.junit.runners" , " Parameterized" ))
81+ .build()
82+ )
83+
84+ // Generate constructor
85+ val constructorSpec = FunSpec .constructorBuilder()
86+ val paramFields = descriptor.getSpecificField(paramAnnotationFQN)
87+ paramFields.forEach { param ->
88+ constructorSpec.addParameter(param.name, getTypeName(param.type))
89+ }
90+
91+ val typeSpecBuilder = TypeSpec .classBuilder(descriptorName)
92+ .primaryConstructor(constructorSpec.build())
93+ .addProperties(paramFields.map { param ->
94+ PropertySpec .builder(param.name, getTypeName(param.type))
95+ .initializer(param.name)
96+ .addModifiers(KModifier .PRIVATE )
97+ .build()
98+ })
99+
100+ addBenchmarkMethods(typeSpecBuilder, descriptor, true )
101+
102+ // Generate companion object with parameters
103+ val companionSpec = TypeSpec .companionObjectBuilder()
104+ .addFunction(generateParametersFunction(paramFields))
105+ .build()
106+
107+ typeSpecBuilder.addType(companionSpec)
108+
109+ fileSpecBuilder.addType(typeSpecBuilder.build())
110+ fileSpecBuilder.build().writeTo(androidTestDir)
111+ }
112+
113+ private fun generateParametersFunction (paramFields : List <FieldAnnotationsDescriptor >): FunSpec {
114+ val dataFunctionBuilder = FunSpec .builder(" data" )
115+ .addAnnotation(JvmStatic ::class )
116+ .returns(
117+ ClassName (" java.util" , " Collection" )
118+ .parameterizedBy(
119+ ClassName (" kotlin" , " Array" )
120+ .parameterizedBy(ANY )
121+ )
122+ )
123+
124+ val paramNameAndIndex = paramFields.mapIndexed { index, param ->
125+ " ${param.name} ={${index} }"
126+ }.joinToString(" , " )
127+
128+ val paramAnnotationValue = " {index}: $paramNameAndIndex "
129+
130+ dataFunctionBuilder.addAnnotation(
131+ AnnotationSpec .builder(ClassName (" org.junit.runners" , " Parameterized.Parameters" ))
132+ .addMember(" name = \" %L\" " , paramAnnotationValue)
133+ .build()
134+ )
135+
136+ val paramValueLists = paramFields.map { param ->
137+ val values = param.annotations
138+ .find { it.name == paramAnnotationFQN }
139+ ?.parameters?.get(" value" ) as List <* >
140+
141+ values.map { value ->
142+ if (param.type == " java.lang.String" ) {
143+ " \"\"\" $value \"\"\" "
144+ } else {
145+ value.toString()
146+ }
147+ }
148+ }
149+
150+ val cartesianProduct = cartesianProduct(paramValueLists as List <List <Any >>)
151+
152+ val returnStatement = StringBuilder (" return listOf(\n " )
153+ cartesianProduct.forEachIndexed { index, combination ->
154+ val arrayContent = combination.joinToString(" , " )
155+ returnStatement.append(" arrayOf($arrayContent )" )
156+ if (index != cartesianProduct.size - 1 ) {
157+ returnStatement.append(" ,\n " )
158+ }
159+ }
160+ returnStatement.append(" \n )" )
161+ dataFunctionBuilder.addStatement(returnStatement.toString())
162+
163+ return dataFunctionBuilder.build()
164+ }
165+
166+ private fun cartesianProduct (lists : List <List <Any >>): List <List <Any >> {
167+ if (lists.isEmpty()) return emptyList()
168+ return lists.fold(listOf (listOf<Any >())) { acc, list ->
169+ acc.flatMap { prefix -> list.map { value -> prefix + value } }
170+ }
171+ }
172+
173+ private fun addBenchmarkMethods (
174+ typeSpecBuilder : TypeSpec .Builder ,
175+ descriptor : ClassAnnotationsDescriptor ,
176+ isParameterized : Boolean = false
177+ ) {
44178 val className = " ${descriptor.packageName} .${descriptor.name} "
45179 val propertyName = descriptor.name.decapitalize(Locale .getDefault())
46180
@@ -55,70 +189,106 @@ private fun addBenchmarkMethods(typeSpecBuilder: TypeSpec.Builder, descriptor: C
55189 descriptor.methods
56190 .filter { it.visibility == Visibility .PUBLIC && it.parameters.isEmpty() }
57191 .filterNot { method ->
58- method.annotations.any { annotation -> annotation.name == " kotlinx.benchmark.Param " }
192+ method.annotations.any { annotation -> annotation.name == paramAnnotationFQN }
59193 }
60194 .forEach { method ->
61195 when {
62- method.annotations.any { it.name == " kotlinx.benchmark.Setup " || it.name == " kotlinx.benchmark.TearDown " } -> {
196+ method.annotations.any { it.name == setupAnnotationFQN || it.name == teardownAnnotationFQN } -> {
63197 generateNonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
64198 }
199+
200+ isParameterized && descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty() -> {
201+ generateParameterizedMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
202+ }
203+
65204 else -> {
66205 generateMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
67206 }
68207 }
69208 }
70209}
71210
72- private fun generateMeasurableMethod (
211+ private fun generateCommonMeasurableMethod (
73212 descriptor : ClassAnnotationsDescriptor ,
74213 method : MethodAnnotationsDescriptor ,
75214 propertyName : String ,
76- typeSpecBuilder : TypeSpec .Builder
215+ typeSpecBuilder : TypeSpec .Builder ,
216+ isParameterized : Boolean
77217) {
78218 val measurementIterations = descriptor.annotations
79- .find { it.name == " kotlinx.benchmark.Measurement " }
219+ .find { it.name == measureAnnotationFQN }
80220 ?.parameters?.get(" iterations" ) as ? Int ? : 5
81221 val warmupIterations = descriptor.annotations
82- .find { it.name == " kotlinx.benchmark.Warmup " }
222+ .find { it.name == warmupAnnotationFQN }
83223 ?.parameters?.get(" iterations" ) as ? Int ? : 5
84224
225+ val methodName = " ${descriptor.packageName} .${descriptor.name} .${method.name} "
226+
85227 val methodSpecBuilder = FunSpec .builder(" benchmark_${descriptor.name} _${method.name} " )
86228 .addAnnotation(ClassName (" org.junit" , " Test" ))
87229 .addAnnotation(
88230 AnnotationSpec .builder(ClassName (" kotlin" , " OptIn" ))
89231 .addMember(" %T::class" , ClassName (" androidx.benchmark" , " ExperimentalBenchmarkStateApi" ))
90232 .build()
91233 )
92- // TODO: Add warmupCount and repeatCount parameters
234+
235+ if (isParameterized) {
236+ descriptor.getSpecificField(paramAnnotationFQN).forEach { field ->
237+ methodSpecBuilder.addStatement(" $propertyName .${field.name} = ${field.name} " )
238+ }
239+ }
240+
241+ methodSpecBuilder
93242 .addStatement(
94243 " val state = %T(warmupCount = $warmupIterations , repeatCount = $measurementIterations )" ,
95244 ClassName (" androidx.benchmark" , " BenchmarkState" )
96245 )
246+ .addStatement(" println(\" Android: $methodName \" )" )
97247 .beginControlFlow(" while (state.keepRunning())" )
98248 .addStatement(" $propertyName .${method.name} ()" )
99249 .endControlFlow()
100250 .addStatement(" val measurementResult = state.getMeasurementTimeNs()" )
101251 .beginControlFlow(" measurementResult.forEachIndexed { index, time ->" )
102252 .addStatement(" println(\" Iteration \$ {index + 1}: \$ time ns\" )" )
103253 .endControlFlow()
254+
104255 typeSpecBuilder.addFunction(methodSpecBuilder.build())
105256}
106257
258+ private fun generateParameterizedMeasurableMethod (
259+ descriptor : ClassAnnotationsDescriptor ,
260+ method : MethodAnnotationsDescriptor ,
261+ propertyName : String ,
262+ typeSpecBuilder : TypeSpec .Builder
263+ ) {
264+ generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = true )
265+ }
266+
267+ private fun generateMeasurableMethod (
268+ descriptor : ClassAnnotationsDescriptor ,
269+ method : MethodAnnotationsDescriptor ,
270+ propertyName : String ,
271+ typeSpecBuilder : TypeSpec .Builder
272+ ) {
273+ generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = false )
274+ }
275+
276+
107277private fun generateNonMeasurableMethod (
108278 descriptor : ClassAnnotationsDescriptor ,
109279 method : MethodAnnotationsDescriptor ,
110280 propertyName : String ,
111281 typeSpecBuilder : TypeSpec .Builder
112282) {
113283 when (method.annotations.first().name) {
114- " kotlinx.benchmark.Setup " -> {
284+ setupAnnotationFQN -> {
115285 val methodSpecBuilder = FunSpec .builder(" benchmark_${descriptor.name} _setUp" )
116286 .addAnnotation(ClassName (" org.junit" , " Before" ))
117287 .addStatement(" $propertyName .${method.name} ()" )
118288 typeSpecBuilder.addFunction(methodSpecBuilder.build())
119289 }
120290
121- " kotlinx.benchmark.TearDown " -> {
291+ teardownAnnotationFQN -> {
122292 val methodSpecBuilder = FunSpec .builder(" benchmark_${descriptor.name} _tearDown" )
123293 .addAnnotation(ClassName (" org.junit" , " After" ))
124294 .addStatement(" $propertyName .${method.name} ()" )
@@ -127,49 +297,16 @@ private fun generateNonMeasurableMethod(
127297 }
128298}
129299
130- private fun updateAndroidDependencies (buildGradleFile : File , dependencies : List <Pair <String , String ?>>) {
131- if (buildGradleFile.exists()) {
132- val buildGradleContent = buildGradleFile.readText()
133-
134- if (buildGradleContent.contains(" android {" )) {
135- val androidBlockStart = buildGradleContent.indexOf(" android {" )
136- val androidBlockEnd = buildGradleContent.lastIndexOf(" }" ) + 1
137- val androidBlockContent = buildGradleContent.substring(androidBlockStart, androidBlockEnd)
138-
139- val newDependencies = dependencies.filterNot { (dependency, version) ->
140- val dependencyString = version?.let { """ $dependency :$version """ } ? : dependency
141- androidBlockContent.contains(dependencyString)
142- }
143- if (newDependencies.isNotEmpty()) {
144- val updatedAndroidBlockContent = if (androidBlockContent.contains(" dependencies {" )) {
145- val dependenciesBlockStart = androidBlockContent.indexOf(" dependencies {" )
146- val dependenciesBlockEnd = androidBlockContent.indexOf(" }" , dependenciesBlockStart) + 1
147- val dependenciesBlockContent =
148- androidBlockContent.substring(dependenciesBlockStart, dependenciesBlockEnd)
149-
150- val newDependenciesString = newDependencies.joinToString(" \n " ) { (dependency, version) ->
151- version?.let { """ androidTestImplementation("$dependency :$version ")""" }
152- ? : """ androidTestImplementation(files("$dependency "))"""
153- }
154- androidBlockContent.replace(
155- dependenciesBlockContent,
156- dependenciesBlockContent.replace(
157- " dependencies {" ,
158- " dependencies {\n $newDependenciesString "
159- )
160- )
161- } else {
162- val newDependenciesString = newDependencies.joinToString(" \n " ) { (dependency, version) ->
163- version?.let { """ androidTestImplementation("$dependency :$version ")""" }
164- ? : """ androidTestImplementation(files("$dependency "))"""
165- }
166- androidBlockContent.replace(" {" , " {\n dependencies {\n $newDependenciesString \n }\n " )
167- }
168-
169- val updatedBuildGradleContent =
170- buildGradleContent.replace(androidBlockContent, updatedAndroidBlockContent)
171- buildGradleFile.writeText(updatedBuildGradleContent)
172- }
173- }
300+ private fun getTypeName (type : String ): TypeName {
301+ return when (type) {
302+ " int" -> Int ::class .asTypeName()
303+ " long" -> Long ::class .asTypeName()
304+ " boolean" -> Boolean ::class .asTypeName()
305+ " float" -> Float ::class .asTypeName()
306+ " double" -> Double ::class .asTypeName()
307+ " char" -> Char ::class .asTypeName()
308+ " byte" -> Byte ::class .asTypeName()
309+ " short" -> Short ::class .asTypeName()
310+ else -> ClassName .bestGuess(type)
174311 }
175- }
312+ }
0 commit comments