@@ -15,13 +15,16 @@ import Trees.*
15
15
import Types .*
16
16
import Symbols .*
17
17
import Names .*
18
+ import StdNames .str
18
19
import NameOps .*
19
20
import inlines .Inlines
20
21
import transform .ValueClasses
21
22
import transform .SymUtils .*
22
- import dotty .tools .io .{File , FileExtension }
23
+ import dotty .tools .io .{File , FileExtension , JarArchive }
24
+ import util .{Property , SourceFile }
23
25
import java .io .PrintWriter
24
26
27
+ import ExtractAPI .NonLocalClassSymbolsInCurrentUnits
25
28
26
29
import scala .collection .mutable
27
30
import scala .util .hashing .MurmurHash3
@@ -62,13 +65,59 @@ class ExtractAPI extends Phase {
62
65
// definitions, and `PostTyper` does not change definitions).
63
66
override def runsAfter : Set [String ] = Set (transform.PostTyper .name)
64
67
68
+ override def runOn (units : List [CompilationUnit ])(using Context ): List [CompilationUnit ] =
69
+ val nonLocalClassSymbols = new mutable.HashSet [Symbol ]
70
+ val ctx0 = ctx.withProperty(NonLocalClassSymbolsInCurrentUnits , Some (nonLocalClassSymbols))
71
+ val units0 = super .runOn(units)(using ctx0)
72
+ ctx.withIncCallback(recordNonLocalClasses(nonLocalClassSymbols, _))
73
+ units0
74
+ end runOn
75
+
76
+ private def recordNonLocalClasses (nonLocalClassSymbols : mutable.HashSet [Symbol ], cb : interfaces.IncrementalCallback )(using Context ): Unit =
77
+ for cls <- nonLocalClassSymbols if ! cls.isLocal do
78
+ val sourceFile = cls.source
79
+ if sourceFile.exists && cls.isDefinedInCurrentRun then
80
+ recordNonLocalClass(cls, sourceFile, cb)
81
+ cb.apiPhaseCompleted()
82
+ cb.dependencyPhaseCompleted()
83
+
84
+ private def recordNonLocalClass (cls : Symbol , sourceFile : SourceFile , cb : interfaces.IncrementalCallback )(using Context ): Unit =
85
+ def registerProductNames (fullClassName : String , binaryClassName : String ) =
86
+ val pathToClassFile = s " ${binaryClassName.replace('.' , java.io.File .separatorChar)}.class "
87
+
88
+ val classFile = {
89
+ ctx.settings.outputDir.value match {
90
+ case jar : JarArchive =>
91
+ new java.io.File (s " $jar! $pathToClassFile" )
92
+ case outputDir =>
93
+ new java.io.File (outputDir.file, pathToClassFile)
94
+ }
95
+ }
96
+
97
+ cb.generatedNonLocalClass(sourceFile, classFile.toPath(), binaryClassName, fullClassName)
98
+ end registerProductNames
99
+
100
+ val fullClassName = atPhase(sbtExtractDependenciesPhase) {
101
+ ExtractDependencies .classNameAsString(cls)
102
+ }
103
+ val binaryClassName = cls.binaryClassName
104
+ registerProductNames(fullClassName, binaryClassName)
105
+
106
+ // Register the names of top-level module symbols that emit two class files
107
+ val isTopLevelUniqueModule =
108
+ cls.owner.is(PackageClass ) && cls.is(ModuleClass ) && cls.companionClass == NoSymbol
109
+ if isTopLevelUniqueModule || cls.isPackageObject then
110
+ registerProductNames(fullClassName, binaryClassName.stripSuffix(str.MODULE_SUFFIX ))
111
+ end recordNonLocalClass
112
+
65
113
override def run (using Context ): Unit = {
66
114
val unit = ctx.compilationUnit
67
115
val sourceFile = unit.source
68
116
ctx.withIncCallback: cb =>
69
117
cb.startSource(sourceFile)
70
118
71
- val apiTraverser = new ExtractAPICollector
119
+ val nonLocalClassSymbols = ctx.property(NonLocalClassSymbolsInCurrentUnits ).get
120
+ val apiTraverser = ExtractAPICollector (nonLocalClassSymbols)
72
121
val classes = apiTraverser.apiSource(unit.tpdTree)
73
122
val mainClasses = apiTraverser.mainClasses
74
123
@@ -92,6 +141,8 @@ object ExtractAPI:
92
141
val name : String = " sbt-api"
93
142
val description : String = " sends a representation of the API of classes to sbt"
94
143
144
+ private val NonLocalClassSymbolsInCurrentUnits : Property .Key [mutable.HashSet [Symbol ]] = Property .Key ()
145
+
95
146
/** Extracts full (including private members) API representation out of Symbols and Types.
96
147
*
97
148
* The exact representation used for each type is not important: the only thing
@@ -134,7 +185,7 @@ object ExtractAPI:
134
185
* without going through an intermediate representation, see
135
186
* http://www.scala-sbt.org/0.13/docs/Understanding-Recompilation.html#Hashing+an+API+representation
136
187
*/
137
- private class ExtractAPICollector (using Context ) extends ThunkHolder {
188
+ private class ExtractAPICollector (nonLocalClassSymbols : mutable. HashSet [ Symbol ])( using Context ) extends ThunkHolder {
138
189
import tpd .*
139
190
import xsbti .api
140
191
@@ -252,6 +303,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder {
252
303
childrenOfSealedClass, topLevel, tparams)
253
304
254
305
allNonLocalClassesInSrc += cl
306
+ nonLocalClassSymbols += sym
255
307
256
308
if (sym.isStatic && ! sym.is(Trait ) && ctx.platform.hasMainMethod(sym)) {
257
309
// If sym is an object, all main methods count, otherwise only @static ones count.
0 commit comments