Skip to content

add scripting support similar to scala2 scripting #11180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ object Contexts {
/** Sourcefile corresponding to given abstract file, memoized */
def getSource(file: AbstractFile, codec: => Codec = Codec(settings.encoding.value)) = {
util.Stats.record("Context.getSource")
base.sources.getOrElseUpdate(file, new SourceFile(file, codec))
base.sources.getOrElseUpdate(file, SourceFile(file, codec))
}

/** SourceFile with given path name, memoized */
Expand Down
44 changes: 40 additions & 4 deletions compiler/src/dotty/tools/dotc/util/SourceFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,36 @@ object ScriptSourceFile {
@sharable private val headerPattern = Pattern.compile("""^(::)?!#.*(\r|\n|\r\n)""", Pattern.MULTILINE)
private val headerStarts = List("#!", "::#!")

/** Return true if has a script header */
def hasScriptHeader(content: Array[Char]): Boolean = {
headerStarts exists (content startsWith _)
}

def apply(file: AbstractFile, content: Array[Char]): SourceFile = {
/** Length of the script header from the given content, if there is one.
* The header begins with "#!" or "::#!" and ends with a line starting
* with "!#" or "::!#".
* The header begins with "#!" or "::#!" and is either a single line,
* or it ends with a line starting with "!#" or "::!#", if present.
*/
val headerLength =
if (headerStarts exists (content startsWith _)) {
// convert initial hash-bang line to a comment
val matcher = headerPattern matcher content.mkString
if (matcher.find) matcher.end
else throw new IOException("script file does not close its header with !# or ::!#")
else content.indexOf('\n') // end of first line
}
else 0
new SourceFile(file, content drop headerLength) {

// overwrite hash-bang lines with all spaces
val hashBangLines = content.take(headerLength).mkString.split("\\r?\\n")
if hashBangLines.nonEmpty then
for i <- 0 until headerLength do
content(i) match {
case '\r' | '\n' =>
case _ =>
content(i) = ' '
}

new SourceFile(file, content) {
override val underlying = new SourceFile(this.file, this.content)
}
}
Expand Down Expand Up @@ -245,6 +262,25 @@ object SourceFile {
else
sourcePath.toString
}

/** Return true if file is a script:
* if filename extension is not .scala and has a script header.
*/
def isScript(file: AbstractFile, content: Array[Char]): Boolean =
if file.hasExtension(".scala") then
false
else
ScriptSourceFile.hasScriptHeader(content)

def apply(file: AbstractFile, codec: Codec): SourceFile =
// see note above re: Files.exists is remarkably slow
val chars = try new String(file.toByteArray, codec.charSet).toCharArray
catch case _: java.nio.file.NoSuchFileException => Array[Char]()
if isScript(file, chars) then
ScriptSourceFile(file, chars)
else
new SourceFile(file, chars)

}

@sharable object NoSource extends SourceFile(NoAbstractFile, Array[Char]()) {
Expand Down
125 changes: 118 additions & 7 deletions compiler/src/dotty/tools/scripting/Main.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,133 @@
package dotty.tools.scripting

import java.io.File
import java.nio.file.Path
import java.net.URLClassLoader
import java.lang.reflect.{ Modifier, Method }

/** Main entry point to the Scripting execution engine */
object Main:
/** All arguments before -script <target_script> are compiler arguments.
All arguments afterwards are script arguments.*/
def distinguishArgs(args: Array[String]): (Array[String], File, Array[String]) =
val (compilerArgs, rest) = args.splitAt(args.indexOf("-script"))
private def distinguishArgs(args: Array[String]): (Array[String], File, Array[String], Boolean) =
// NOTE: if -script <scriptName> not present, quit with error.
val (leftArgs, rest) = args.splitAt(args.indexOf("-script"))
if( rest.size < 2 ) then
sys.error(s"missing: -script <scriptName>")

val file = File(rest(1))
val scriptArgs = rest.drop(2)
(compilerArgs, file, scriptArgs)
var saveJar = false
val compilerArgs = leftArgs.filter {
case "-save" | "-savecompiled" =>
saveJar = true
false
case _ =>
true
}
(compilerArgs, file, scriptArgs, saveJar)
end distinguishArgs

def main(args: Array[String]): Unit =
val (compilerArgs, scriptFile, scriptArgs) = distinguishArgs(args)
try ScriptingDriver(compilerArgs, scriptFile, scriptArgs).compileAndRun()
val (compilerArgs, scriptFile, scriptArgs, saveJar) = distinguishArgs(args)
try ScriptingDriver(compilerArgs, scriptFile, scriptArgs).compileAndRun { (outDir:Path, classpath:String) =>
val classFiles = outDir.toFile.listFiles.toList match {
case Nil => sys.error(s"no files below [$outDir]")
case list => list
}

val (mainClassName, mainMethod) = detectMainClassAndMethod(outDir, classpath, scriptFile)

if saveJar then
// write a standalone jar to the script parent directory
writeJarfile(outDir, scriptFile, scriptArgs, classpath, mainClassName)

// invoke the compiled script main method
mainMethod.invoke(null, scriptArgs)
}
catch
case ScriptingException(msg) =>
println(s"Error: $msg")
case e:Exception =>
e.printStackTrace
println(s"Error: ${e.getMessage}")
sys.exit(1)

case e: java.lang.reflect.InvocationTargetException =>
throw e.getCause

private def writeJarfile(outDir: Path, scriptFile: File, scriptArgs:Array[String],
classpath:String, mainClassName: String): Unit =
val jarTargetDir: Path = Option(scriptFile.toPath.getParent) match {
case None => sys.error(s"no parent directory for script file [$scriptFile]")
case Some(parent) => parent
}

def scriptBasename = scriptFile.getName.takeWhile(_!='.')
val jarPath = s"$jarTargetDir/$scriptBasename.jar"

val cpPaths = classpath.split(pathsep).map {
// protect relative paths from being converted to absolute
case str if str.startsWith(".") && File(str).isDirectory => s"${str.withSlash}/"
case str if str.startsWith(".") => str.withSlash
case str => File(str).toURI.toURL.toString
}

import java.util.jar.Attributes.Name
val cpString:String = cpPaths.distinct.mkString(" ")
val manifestAttributes:Seq[(Name, String)] = Seq(
(Name.MANIFEST_VERSION, "1.0.0"),
(Name.MAIN_CLASS, mainClassName),
(Name.CLASS_PATH, cpString),
)
import dotty.tools.io.{Jar, Directory}
val jar = new Jar(jarPath)
val writer = jar.jarWriter(manifestAttributes:_*)
writer.writeAllFrom(Directory(outDir))
end writeJarfile

private def detectMainClassAndMethod(outDir: Path, classpath: String,
scriptFile: File): (String, Method) =
val outDirURL = outDir.toUri.toURL
val classpathUrls = classpath.split(pathsep).map(File(_).toURI.toURL)
val cl = URLClassLoader(classpathUrls :+ outDirURL)

def collectMainMethods(target: File, path: String): List[(String, Method)] =
val nameWithoutExtension = target.getName.takeWhile(_ != '.')
val targetPath =
if path.nonEmpty then s"${path}.${nameWithoutExtension}"
else nameWithoutExtension

if target.isDirectory then
for
packageMember <- target.listFiles.toList
membersMainMethod <- collectMainMethods(packageMember, targetPath)
yield membersMainMethod
else if target.getName.endsWith(".class") then
val cls = cl.loadClass(targetPath)
try
val method = cls.getMethod("main", classOf[Array[String]])
if Modifier.isStatic(method.getModifiers) then List((cls.getName, method)) else Nil
catch
case _: java.lang.NoSuchMethodException => Nil
else Nil
end collectMainMethods

val candidates = for
file <- outDir.toFile.listFiles.toList
method <- collectMainMethods(file, "")
yield method

candidates match
case Nil =>
throw ScriptingException(s"No main methods detected in script ${scriptFile}")
case _ :: _ :: _ =>
throw ScriptingException("A script must contain only one main method. " +
s"Detected the following main methods:\n${candidates.mkString("\n")}")
case m :: Nil => m
end match
end detectMainClassAndMethod

def pathsep = sys.props("path.separator")

extension(pathstr:String) {
def withSlash:String = pathstr.replace('\\', '/')
}
77 changes: 26 additions & 51 deletions compiler/src/dotty/tools/scripting/ScriptingDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,51 @@ package dotty.tools.scripting

import java.nio.file.{ Files, Path }
import java.io.File
import java.net.{ URL, URLClassLoader }
import java.lang.reflect.{ Modifier, Method }

import scala.jdk.CollectionConverters._

import dotty.tools.dotc.{ Driver, Compiler }
import dotty.tools.dotc.core.Contexts, Contexts.{ Context, ContextBase, ctx }
import dotty.tools.dotc.config.CompilerCommand
import dotty.tools.dotc.{ Driver }
import dotty.tools.dotc.core.Contexts, Contexts.{ Context, ctx }
import dotty.tools.io.{ PlainDirectory, Directory }
import dotty.tools.dotc.reporting.Reporter
import dotty.tools.dotc.config.Settings.Setting._

import sys.process._
import dotty.tools.dotc.util.ScriptSourceFile
import dotty.tools.io.AbstractFile

class ScriptingDriver(compilerArgs: Array[String], scriptFile: File, scriptArgs: Array[String]) extends Driver:
def compileAndRun(): Unit =
def compileAndRun(pack:(Path, String) => Unit = null): Unit =
val outDir = Files.createTempDirectory("scala3-scripting")
val (toCompile, rootCtx) = setup(compilerArgs :+ scriptFile.getAbsolutePath, initCtx.fresh)

given Context = rootCtx.fresh.setSetting(rootCtx.settings.outputDir,
new PlainDirectory(Directory(outDir)))

if doCompile(newCompiler, toCompile).hasErrors then
throw ScriptingException("Errors encountered during compilation")

try detectMainMethod(outDir, ctx.settings.classpath.value).invoke(null, scriptArgs)
val result = doCompile(newCompiler, toCompile)
if result.hasErrors then
throw ScriptingException(s"Errors encountered during compilation to dir [$outDir]")

try
if outDir.toFile.listFiles.toList.isEmpty then
sys.error(s"no files generated by compiling script ${scriptFile}")

Option(pack) match {
case None =>
case Some(func) =>
val javaClasspath = sys.props("java.class.path")
val pathsep = sys.props("path.separator")
val runtimeClasspath = s"${ctx.settings.classpath.value}$pathsep$javaClasspath"
func(outDir, runtimeClasspath)
}
catch
case e: java.lang.reflect.InvocationTargetException =>
throw e.getCause
finally
deleteFile(outDir.toFile)

def content(file: Path): Array[Char] = new String(Files.readAllBytes(file)).toCharArray
def scriptSource(file: Path) = ScriptSourceFile(AbstractFile.getFile(file), content(file))

end compileAndRun

private def deleteFile(target: File): Unit =
Expand All @@ -41,46 +56,6 @@ class ScriptingDriver(compilerArgs: Array[String], scriptFile: File, scriptArgs:
target.delete()
end deleteFile

private def detectMainMethod(outDir: Path, classpath: String): Method =
val outDirURL = outDir.toUri.toURL
val classpathUrls = classpath.split(":").map(File(_).toURI.toURL)
val cl = URLClassLoader(classpathUrls :+ outDirURL)

def collectMainMethods(target: File, path: String): List[Method] =
val nameWithoutExtension = target.getName.takeWhile(_ != '.')
val targetPath =
if path.nonEmpty then s"${path}.${nameWithoutExtension}"
else nameWithoutExtension

if target.isDirectory then
for
packageMember <- target.listFiles.toList
membersMainMethod <- collectMainMethods(packageMember, targetPath)
yield membersMainMethod
else if target.getName.endsWith(".class") then
val cls = cl.loadClass(targetPath)
try
val method = cls.getMethod("main", classOf[Array[String]])
if Modifier.isStatic(method.getModifiers) then List(method) else Nil
catch
case _: java.lang.NoSuchMethodException => Nil
else Nil
end collectMainMethods

val candidates = for
file <- outDir.toFile.listFiles.toList
method <- collectMainMethods(file, "")
yield method

candidates match
case Nil =>
throw ScriptingException("No main methods detected in your script")
case _ :: _ :: _ =>
throw ScriptingException("A script must contain only one main method. " +
s"Detected the following main methods:\n${candidates.mkString("\n")}")
case m :: Nil => m
end match
end detectMainMethod
end ScriptingDriver

case class ScriptingException(msg: String) extends RuntimeException(msg)
19 changes: 19 additions & 0 deletions compiler/test-resources/scripting/hashBang.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env scala
# comment
STUFF=nada
!#

def main(args: Array[String]): Unit =
System.err.printf("mainClassFromStack: %s\n",mainFromStack)
//assert(mainFromStack.contains("HashBang"),s"fromStack[$mainFromStack]")

lazy val mainFromStack:String = {
val result = new java.io.StringWriter()
new RuntimeException("stack").printStackTrace(new java.io.PrintWriter(result))
val stack = result.toString.split("[\r\n]+").toList
//for( s <- stack ){ System.err.printf("[%s]\n",s) }
stack.filter { str => str.contains(".main(") }.map {
_.replaceAll(".*[(]","").
replaceAll("[:)].*","")
}.distinct.take(1).mkString("")
}
22 changes: 22 additions & 0 deletions compiler/test-resources/scripting/mainClassOnStack.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env scala
export STUFF=nada
lots of other stuff that isn't valid scala
!#
object Zoo {
def main(args: Array[String]): Unit =
printf("script.name: %s\n",sys.props("script.name"))
printf("mainClassFromStack: %s\n",mainFromStack)
assert(mainFromStack == "Zoo",s"fromStack[$mainFromStack]")

lazy val mainFromStack:String = {
val result = new java.io.StringWriter()
new RuntimeException("stack").printStackTrace(new java.io.PrintWriter(result))
val stack = result.toString.split("[\r\n]+").toList
// for( s <- stack ){ System.err.printf("[%s]\n",s) }
val shortStack = stack.filter { str => str.contains(".main(") && ! str.contains("$") }.map {
_.replaceAll("[.].*","").replaceAll("\\s+at\\s+","")
}
// for( s <- shortStack ){ System.err.printf("[%s]\n",s) }
shortStack.take(1).mkString("|")
}
}
6 changes: 6 additions & 0 deletions compiler/test-resources/scripting/scriptName.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env scala

def main(args: Array[String]): Unit =
val name = sys.props("script.name")
printf("script.name: %s\n",name)
assert(name == "scriptName.scala")
Loading