diff --git a/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBModule.scala b/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBModule.scala index d7979097e83e..bec7577645c1 100644 --- a/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBModule.scala +++ b/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBModule.scala @@ -22,6 +22,8 @@ trait ScalaPBModule extends ScalaModule { def scalaPBVersion: T[String] + def scalaPBGenerators: T[Seq[Generator]] = Seq[Generator](ScalaGen) + def scalaPBFlatPackage: T[Boolean] = Task { false } def scalaPBJavaConversions: T[Boolean] = Task { false } @@ -141,7 +143,8 @@ trait ScalaPBModule extends ScalaModule { scalaPBSources().map(_.path), scalaPBOptions(), Task.dest, - scalaPBCompileOptions() + scalaPBCompileOptions(), + scalaPBGenerators() ) } } diff --git a/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBWorker.scala b/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBWorker.scala index d067f0b35bb9..b1bc99a8cbd1 100644 --- a/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBWorker.scala +++ b/contrib/scalapblib/src/mill/contrib/scalapblib/ScalaPBWorker.scala @@ -5,6 +5,7 @@ import java.io.File import mill.api.PathRef import mill.api.{Discover, ExternalModule} +import upickle.default.ReadWriter class ScalaPBWorker { @@ -15,16 +16,18 @@ class ScalaPBWorker { sources: Seq[File], scalaPBOptions: String, generatedDirectory: File, - otherArgs: Seq[String] + otherArgs: Seq[String], + generators: Seq[Generator] ): Unit = { val pbcClasspath = scalaPBClasspath.map(_.path).toVector mill.util.Jvm.withClassLoader(pbcClasspath, null) { cl => val scalaPBCompilerClass = cl.loadClass("scalapb.ScalaPBC") val mainMethod = scalaPBCompilerClass.getMethod("main", classOf[Array[java.lang.String]]) - val opts = if (scalaPBOptions.isEmpty) "" else scalaPBOptions + ":" - val args = otherArgs ++ Seq( - s"--scala_out=${opts}${generatedDirectory.getCanonicalPath}" - ) ++ roots.map(root => s"--proto_path=${root.getCanonicalPath}") ++ sources.map( + val args = otherArgs ++ generators.map { gen => + val opts = if (scalaPBOptions.isEmpty || !gen.supportsScalaPBOptions) "" + else scalaPBOptions + ":" + s"${gen.generator}=$opts${generatedDirectory.getCanonicalPath}" + } ++ roots.map(root => s"--proto_path=${root.getCanonicalPath}") ++ sources.map( _.getCanonicalPath ) ctx.log.debug(s"ScalaPBC args: ${args.mkString(" ")}") @@ -70,7 +73,8 @@ class ScalaPBWorker { scalaPBSources: Seq[os.Path], scalaPBOptions: String, dest: os.Path, - scalaPBCExtraArgs: Seq[String] + scalaPBCExtraArgs: Seq[String], + generators: Seq[Generator] )(implicit ctx: mill.api.TaskCtx): mill.api.Result[PathRef] = { val compiler = scalaPB(scalaPBClasspath) val sources = scalaPBSources.flatMap { @@ -86,7 +90,14 @@ class ScalaPBWorker { Seq(ioFile) } val roots = scalaPBSources.map(_.toIO).filter(_.isDirectory) - compiler.compileScalaPB(roots, sources, scalaPBOptions, dest.toIO, scalaPBCExtraArgs) + compiler.compileScalaPB( + roots, + sources, + scalaPBOptions, + dest.toIO, + scalaPBCExtraArgs, + generators + ) mill.api.Result.Success(PathRef(dest)) } } @@ -97,10 +108,25 @@ trait ScalaPBWorkerApi { source: Seq[File], scalaPBOptions: String, generatedDirectory: File, - otherArgs: Seq[String] + otherArgs: Seq[String], + generators: Seq[Generator] ): Unit } +sealed trait Generator derives ReadWriter { + def generator: String + def supportsScalaPBOptions: Boolean +} +case object ScalaGen extends Generator { + override def generator: String = "--scala_out" + override def supportsScalaPBOptions: Boolean = true +} +case object JavaGen extends Generator { + override def generator: String = "--java_out" + override def supportsScalaPBOptions: Boolean = + false // Java options are specified directly in the proto file +} + object ScalaPBWorkerApi extends ExternalModule { def scalaPBWorker: Worker[ScalaPBWorker] = Task.Worker { new ScalaPBWorker() } lazy val millDiscover = Discover[this.type] diff --git a/contrib/scalapblib/test/src/mill/contrib/scalapblib/TutorialTests.scala b/contrib/scalapblib/test/src/mill/contrib/scalapblib/TutorialTests.scala index 8a32a6df10e9..30c41fe568da 100644 --- a/contrib/scalapblib/test/src/mill/contrib/scalapblib/TutorialTests.scala +++ b/contrib/scalapblib/test/src/mill/contrib/scalapblib/TutorialTests.scala @@ -11,7 +11,9 @@ import utest.{TestSuite, Tests, assert, *} object TutorialTests extends TestSuite { val testScalaPbVersion = "0.11.7" - trait TutorialBase extends TestRootModule + trait TutorialBase extends TestRootModule { + val core: TutorialModule + } trait TutorialModule extends ScalaPBModule { def scalaVersion = sys.props.getOrElse("TEST_SCALA_2_12_VERSION", ???) @@ -70,12 +72,28 @@ object TutorialTests extends TestSuite { lazy val millDiscover = Discover[this.type] } + object TutorialWithJavaGen extends TutorialBase { + object core extends TutorialModule { + override def scalaPBGenerators = Seq(JavaGen) + } + + lazy val millDiscover = Discover[this.type] + } + + object TutorialWithScalaAndJavaGen extends TutorialBase { + object core extends TutorialModule { + override def scalaPBGenerators = Seq[Generator](ScalaGen, JavaGen) + } + + lazy val millDiscover = Discover[this.type] + } + val resourcePath: os.Path = os.Path(sys.env("MILL_TEST_RESOURCE_DIR")) def protobufOutPath(eval: UnitTester): os.Path = eval.outPath / "core/compileScalaPB.dest/com/example/tutorial" - def compiledSourcefiles: Seq[os.RelPath] = Seq[os.RelPath]( + def compiledScalaSourcefiles: Seq[os.RelPath] = Seq[os.RelPath]( os.rel / "AddressBook.scala", os.rel / "Person.scala", os.rel / "TutorialProto.scala", @@ -83,6 +101,39 @@ object TutorialTests extends TestSuite { os.rel / "IncludeProto.scala" ) + def compiledJavaSourcefiles: Seq[os.RelPath] = Seq[os.RelPath]( + os.rel / "AddressBookProtos.java", + os.rel / "IncludeOuterClass.java" + ) + + // Helper function to test compilation with different generators + def testCompilation( + module: TutorialBase, + expectedFiles: Seq[os.RelPath] + ): Unit = { + UnitTester(module, resourcePath).scoped { eval => + if (!mill.constants.Util.isWindows) { + val Right(result) = eval.apply(module.core.compileScalaPB): @unchecked + + val outPath = protobufOutPath(eval) + val outputFiles = os.walk(result.value.path).filter(os.isFile) + val expectedSourcefiles = expectedFiles.map(outPath / _) + + assert( + result.value.path == eval.outPath / "core/compileScalaPB.dest", + outputFiles.nonEmpty, + outputFiles.forall(expectedSourcefiles.contains), + outputFiles.size == outputFiles.size, + result.evalCount > 0 + ) + + // don't recompile if nothing changed + val Right(result2) = eval.apply(module.core.compileScalaPB): @unchecked + assert(result2.evalCount == 0) + } + } + } + def tests: Tests = Tests { test("scalapbVersion") { @@ -97,30 +148,12 @@ object TutorialTests extends TestSuite { } test("compileScalaPB") { - test("calledDirectly") - UnitTester(Tutorial, resourcePath).scoped { eval => - if (!mill.constants.Util.isWindows) { - val Right(result) = eval.apply(Tutorial.core.compileScalaPB): @unchecked - - val outPath = protobufOutPath(eval) - - val outputFiles = os.walk(result.value.path).filter(os.isFile) - - val expectedSourcefiles = compiledSourcefiles.map(outPath / _) - - assert( - result.value.path == eval.outPath / "core/compileScalaPB.dest", - outputFiles.nonEmpty, - outputFiles.forall(expectedSourcefiles.contains), - outputFiles.size == 5, - result.evalCount > 0 - ) - - // don't recompile if nothing changed - val Right(result2) = eval.apply(Tutorial.core.compileScalaPB): @unchecked - - assert(result2.evalCount == 0) - } - } + test("scalaGen") - testCompilation(Tutorial, compiledScalaSourcefiles) + test("javaGen") - testCompilation(TutorialWithJavaGen, compiledJavaSourcefiles) + test("scalaAndJavaGen") - testCompilation( + TutorialWithScalaAndJavaGen, + compiledScalaSourcefiles ++ compiledJavaSourcefiles + ) test("calledWithSpecificFile") - UnitTester( TutorialWithSpecificSources, @@ -165,7 +198,7 @@ object TutorialTests extends TestSuite { // // // val outputFiles = os.walk(outPath).filter(_.isFile) // -// // val expectedSourcefiles = compiledSourcefiles.map(outPath / _) +// // val expectedSourcefiles = compiledScalaSourcefiles.map(outPath / _) // // // assert( // // outputFiles.nonEmpty,