diff --git a/analyzer/src/main/scala/com/avsystem/commons/analyzer/ExplicitGenerics.scala b/analyzer/src/main/scala/com/avsystem/commons/analyzer/ExplicitGenerics.scala index 585f58ce5..b84f3a9e8 100644 --- a/analyzer/src/main/scala/com/avsystem/commons/analyzer/ExplicitGenerics.scala +++ b/analyzer/src/main/scala/com/avsystem/commons/analyzer/ExplicitGenerics.scala @@ -5,24 +5,51 @@ import scala.tools.nsc.Global class ExplicitGenerics(g: Global) extends AnalyzerRule(g, "explicitGenerics") { - import global._ + import global.* lazy val explicitGenericsAnnotTpe = classType("com.avsystem.commons.annotation.explicitGenerics") - def analyze(unit: CompilationUnit) = if (explicitGenericsAnnotTpe != NoType) { + + private def fail(pos: Position, symbol: Symbol): Unit = + report(pos, s"$symbol requires that its type arguments are explicit (not inferred)") + + def analyze(unit: CompilationUnit): Unit = if (explicitGenericsAnnotTpe != NoType) { def requiresExplicitGenerics(sym: Symbol): Boolean = sym != NoSymbol && (sym :: sym.overrides).flatMap(_.annotations).exists(_.tree.tpe <:< explicitGenericsAnnotTpe) + def applyOfAnnotatedCompanion(preSym: Symbol): Boolean = + preSym != NoSymbol && preSym.isMethod && preSym.name == TermName("apply") && { + val owner = preSym.owner + val companionCls = + if (owner.isModuleClass) owner.companionClass + else if (owner.isModule) owner.moduleClass.companionClass + else NoSymbol + requiresExplicitGenerics(companionCls) + } + def analyzeTree(tree: Tree): Unit = analyzer.macroExpandee(tree) match { case `tree` | EmptyTree => tree match { - case t@TypeApply(pre, args) if requiresExplicitGenerics(pre.symbol) => + case t@TypeApply(pre, args) if requiresExplicitGenerics(pre.symbol) || applyOfAnnotatedCompanion(pre.symbol) => val inferredTypeParams = args.forall { case tt: TypeTree => tt.original == null || tt.original == EmptyTree case _ => false } if (inferredTypeParams) { - report(t.pos, s"${pre.symbol} requires that its type arguments are explicit (not inferred)") + // If we're on companion.apply, report on the class symbol for clearer message + val targetSym = if (applyOfAnnotatedCompanion(pre.symbol)) pre.symbol.owner.companionClass else pre.symbol + fail(t.pos, targetSym) + } + case n@New(tpt) if requiresExplicitGenerics(tpt.tpe.typeSymbol) => + val explicitTypeArgsProvided = tpt match { + case tt: TypeTree => tt.original match { + case AppliedTypeTree(_, args) if args.nonEmpty => true + case _ => false + } + case _ => false + } + if (!explicitTypeArgsProvided) { + fail(n.pos, tpt.tpe.typeSymbol) } case _ => } @@ -30,6 +57,7 @@ class ExplicitGenerics(g: Global) extends AnalyzerRule(g, "explicitGenerics") { case prevTree => analyzeTree(prevTree) } + analyzeTree(unit.body) } } diff --git a/analyzer/src/test/scala/com/avsystem/commons/analyzer/ExplicitGenericsTest.scala b/analyzer/src/test/scala/com/avsystem/commons/analyzer/ExplicitGenericsTest.scala index 575bbce51..9c7181e73 100644 --- a/analyzer/src/test/scala/com/avsystem/commons/analyzer/ExplicitGenericsTest.scala +++ b/analyzer/src/test/scala/com/avsystem/commons/analyzer/ExplicitGenericsTest.scala @@ -41,4 +41,47 @@ final class ExplicitGenericsTest extends AnyFunSuite with AnalyzerTest { |val x = TestUtils.genericMacro[Int](123) |""".stripMargin) } + + test("inferred in constructor should be rejected") { + assertErrors(2, + scala""" + |import com.avsystem.commons.analyzer.TestUtils + | + |val x = new TestUtils.GenericClass() + |val y = new TestUtils.GenericCaseClass(123) + |""".stripMargin) + } + + + test("inferred in apply when constructor marked should be rejected") { + assertErrors(1, + scala""" + |import com.avsystem.commons.analyzer.TestUtils + | + |val x = TestUtils.GenericCaseClass(123) + |""".stripMargin) + } + + test("explicit in constructor should not be rejected") { + assertNoErrors( + scala""" + |import com.avsystem.commons.analyzer.TestUtils + | + |val x = new TestUtils.GenericClass[Int]() + |""".stripMargin) + } + + test("not marked should not be rejected") { + assertNoErrors( + scala""" + |def method[T](e: T) = e + |class NotMarkedGenericClass[T] + |final case class NotMarkedGenericCaseClass[T](arg: T) + | + |val w = method(123) + |val x = new NotMarkedGenericClass() + |val y = NotMarkedGenericCaseClass(123) + |val z = new NotMarkedGenericClass() + |""".stripMargin) + } } diff --git a/analyzer/src/test/scala/com/avsystem/commons/analyzer/TestUtils.scala b/analyzer/src/test/scala/com/avsystem/commons/analyzer/TestUtils.scala index 9a4e00c70..dfb82d724 100644 --- a/analyzer/src/test/scala/com/avsystem/commons/analyzer/TestUtils.scala +++ b/analyzer/src/test/scala/com/avsystem/commons/analyzer/TestUtils.scala @@ -28,4 +28,10 @@ object TestUtils { def genericMethod[T](arg: T): T = arg @explicitGenerics def genericMacro[T](arg: T): T = macro genericMacroImpl[T] + + @explicitGenerics + class GenericClass[T] + + @explicitGenerics + case class GenericCaseClass[T](arg: T) }