Skip to content

increase the use of scala_toolchain #530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions scala/private/rule_impls.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ _java_filetype = FileType([".java"])
_scala_filetype = FileType([".scala"])
_srcjar_filetype = FileType([".srcjar"])

def _get_toolchain(ctx):
return ctx.toolchains['@io_bazel_rules_scala//scala:toolchain_type']

def _adjust_resources_path_by_strip_prefix(path, resource_strip_prefix):
if not path.startswith(resource_strip_prefix):
fail("Resource file %s is not under the specified prefix to strip" % path)
Expand Down Expand Up @@ -89,20 +92,19 @@ rm -f {jar_output}
touch {statsfile}
""" + ijar_cmd

zipper = _get_toolchain(ctx).zipper
cmd = cmd.format(
path = zipper_arg_path.path,
jar_output = ctx.outputs.jar.path,
zipper = ctx.executable._zipper.path,
zipper = zipper.path,
statsfile = ctx.outputs.statsfile.path,
)

outs = [ctx.outputs.jar, ctx.outputs.statsfile]
if buildijar:
outs.extend([ctx.outputs.ijar])

inputs = ctx.files.resources + [
ctx.outputs.manifest, ctx.executable._zipper, zipper_arg_path
]
inputs = ctx.files.resources + [ctx.outputs.manifest, zipper, zipper_arg_path]

ctx.actions.run_shell(
inputs = inputs,
Expand Down Expand Up @@ -134,18 +136,22 @@ def _join_path(args, sep = ","):

def _compile(ctx, cjars, dep_srcjars, buildijar, transitive_compile_jars,
labels, implicit_junit_deps_needed_for_java_compilation):
toolchain = _get_toolchain(ctx)
ijar_output_path = ""
ijar_cmd_path = ""
if buildijar:
ijar_output_path = ctx.outputs.ijar.path
ijar_cmd_path = ctx.executable._ijar.path
ijar_cmd_path = toolchain.ijar.path

java_srcs = _java_filetype.filter(ctx.files.srcs)
sources = _scala_filetype.filter(ctx.files.srcs) + java_srcs
srcjars = _srcjar_filetype.filter(ctx.files.srcs)
all_srcjars = depset(srcjars, transitive = [dep_srcjars])
# look for any plugins:
plugins = _collect_plugin_paths(ctx.attr.plugins)
plugins = ctx.attr.plugins
if len(plugins) == 0:
plugins = toolchain.plugins
plugins = _collect_plugin_paths(plugins)
dependency_analyzer_plugin_jars = []
dependency_analyzer_mode = "off"
compiler_classpath_jars = cjars
Expand Down Expand Up @@ -179,12 +185,12 @@ CurrentTarget: {current_target}
indirect_targets = indirect_targets,
current_target = current_target)

plugin_arg = _join_path(plugins.to_list())
plugins_list = plugins.to_list()
plugin_arg = _join_path(plugins_list)

separator = ctx.configuration.host_path_separator
compiler_classpath = _join_path(compiler_classpath_jars.to_list(), separator)

toolchain = ctx.toolchains['@io_bazel_rules_scala//scala:toolchain_type']
scalacopts = toolchain.scalacopts + ctx.attr.scalacopts

scalac_args = """
Expand Down Expand Up @@ -244,10 +250,15 @@ StatsfileOutput: {statsfile_output}
if buildijar:
outs.extend([ctx.outputs.ijar])
ins = (compiler_classpath_jars.to_list() + dep_srcjars.to_list() +
list(srcjars) + list(sources) + ctx.files.srcs + ctx.files.plugins +
list(srcjars) + list(sources) + ctx.files.srcs + plugins_list +
dependency_analyzer_plugin_jars + classpath_resources +
ctx.files.resources + ctx.files.resource_jars + ctx.files._java_runtime
+ [ctx.outputs.manifest, ctx.executable._ijar, argfile])
+ [ctx.outputs.manifest, toolchain.ijar, argfile])

scalac_jvm_flags = ctx.attr.scalac_jvm_flags
if len(scalac_jvm_flags) == 0:
scalac_jvm_flags = toolchain.scalac_jvm_flags

ctx.actions.run(
inputs = ins,
outputs = outs,
Expand All @@ -264,8 +275,7 @@ StatsfileOutput: {statsfile_output}
# be correctly handled since the executable is a jvm app that will
# consume the flags on startup.
arguments = [
"--jvm_flag=%s" % f
for f in _expand_location(ctx, ctx.attr.scalac_jvm_flags)
"--jvm_flag=%s" % f for f in _expand_location(ctx, scalac_jvm_flags)
] + ["@" + argfile.path],
)

Expand Down Expand Up @@ -375,7 +385,7 @@ def _build_deployable(ctx, jars_list):
ctx.actions.run(
inputs = jars_list,
outputs = [ctx.outputs.deploy_jar],
executable = ctx.executable._singlejar,
executable = _get_toolchain(ctx).singlejar,
mnemonic = "ScalaDeployJar",
progress_message = "scala deployable %s" % ctx.label,
arguments = args)
Expand Down Expand Up @@ -480,8 +490,9 @@ def _collect_jars_from_common_ctx(ctx, extra_deps = [],

dependency_analyzer_is_off = is_dependency_analyzer_off(ctx)

toolchain = _get_toolchain(ctx)
# Get jars from deps
auto_deps = [ctx.attr._scalalib, ctx.attr._scalareflect]
auto_deps = toolchain.default_classpath
deps_jars = collect_jars(ctx.attr.deps + auto_deps + extra_deps,
dependency_analyzer_is_off)
(cjars, transitive_rjars, jars2labels,
Expand Down Expand Up @@ -636,9 +647,10 @@ def scala_binary_impl(ctx):
return out

def scala_repl_impl(ctx):
toolchain = _get_toolchain(ctx)
# need scala-compiler for MainGenericRunner below
jars = _collect_jars_from_common_ctx(
ctx, extra_runtime_deps = [ctx.attr._scalacompiler])
ctx, extra_runtime_deps = toolchain.repl_runtime_deps)
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)

args = " ".join(ctx.attr.scalacopts)
Expand Down
45 changes: 13 additions & 32 deletions scala/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,16 @@ _launcher_template = {
default = Label("@java_stub_template//file")),
}

# TODO, we should resolve these by ctx.toolchains to find the java toolchain.
# but I don't know what the toolchain type is for java
# https://stackoverflow.com/questions/50977784/how-to-resolve-the-java-toolchain-in-bazel
_implicit_deps = {
"_singlejar": attr.label(
executable = True,
cfg = "host",
default = Label("@bazel_tools//tools/jdk:singlejar"),
allow_files = True),
"_ijar": attr.label(
executable = True,
cfg = "host",
default = Label("@bazel_tools//tools/jdk:ijar"),
allow_files = True),
# TODO: putting this into the toolchain seems to cause it to fail from not finding runfiles
"_scalac": attr.label(
executable = True,
cfg = "host",
default = Label("//src/java/io/bazel/rulesscala/scalac"),
allow_files = True),
"_scalalib": attr.label(
default = Label(
"//external:io_bazel_rules_scala/dependency/scala/scala_library"),
allow_files = True),
"_scalacompiler": attr.label(
default = Label(
"//external:io_bazel_rules_scala/dependency/scala/scala_compiler"),
allow_files = True),
"_scalareflect": attr.label(
default = Label(
"//external:io_bazel_rules_scala/dependency/scala/scala_reflect"),
allow_files = True),
"_zipper": attr.label(
executable = True,
cfg = "host",
default = Label("@bazel_tools//tools/zip:zipper"),
"@io_bazel_rules_scala//src/java/io/bazel/rulesscala/scalac"),
allow_files = True),
"_java_toolchain": attr.label(
default = Label("@bazel_tools//tools/jdk:current_java_toolchain")),
Expand All @@ -60,6 +38,7 @@ _implicit_deps = {
}

# Single dep to allow IDEs to pickup all the implicit dependencies.
# these are private deps, but aspects can traverse private deps
_resolve_deps = {
"_scala_toolchain": attr.label_list(
default = [
Expand Down Expand Up @@ -256,31 +235,33 @@ scala_repl = rule(

_SCALA_BUILD_FILE = """
# scala.BUILD
java_import(
load("@io_bazel_rules_scala//scala:scala_import.bzl", "scala_import")

scala_import(
name = "scala-xml",
jars = ["lib/scala-xml_2.11-1.0.5.jar"],
visibility = ["//visibility:public"],
)

java_import(
scala_import(
name = "scala-parser-combinators",
jars = ["lib/scala-parser-combinators_2.11-1.0.4.jar"],
visibility = ["//visibility:public"],
)

java_import(
scala_import(
name = "scala-library",
jars = ["lib/scala-library.jar"],
visibility = ["//visibility:public"],
)

java_import(
scala_import(
name = "scala-compiler",
jars = ["lib/scala-compiler.jar"],
visibility = ["//visibility:public"],
)

java_import(
scala_import(
name = "scala-reflect",
jars = ["lib/scala-reflect.jar"],
visibility = ["//visibility:public"],
Expand Down
48 changes: 46 additions & 2 deletions scala/scala_toolchain.bzl
Original file line number Diff line number Diff line change
@@ -1,8 +1,52 @@
def _scala_toolchain_impl(ctx):
toolchain = platform_common.ToolchainInfo(scalacopts = ctx.attr.scalacopts,)
toolchain = platform_common.ToolchainInfo(
scalacopts = ctx.attr.scalacopts,
plugins = ctx.attr.plugins,
scalac_jvm_flags = ctx.attr.plugins,
singlejar = ctx.executable.singlejar,
ijar = ctx.executable.ijar,
zipper = ctx.executable.zipper,
default_classpath = ctx.attr.default_classpath,
repl_runtime_deps = ctx.attr.repl_runtime_deps,
)
return [toolchain]

scala_toolchain = rule(
_scala_toolchain_impl, attrs = {
_scala_toolchain_impl,
attrs = {
'scalacopts': attr.string_list(),
'plugins': attr.label_list(allow_files = ['.jar']),
'scalac_jvm_flags': attr.string_list(),
'singlejar': attr.label(
executable = True,
cfg = "host",
default = Label("@bazel_tools//tools/jdk:singlejar"),
allow_files = True),
'ijar': attr.label(
executable = True,
cfg = "host",
default = Label("@bazel_tools//tools/jdk:ijar"),
allow_files = True),
'zipper': attr.label(
executable = True,
cfg = "host",
default = Label("@bazel_tools//tools/zip:zipper"),
allow_files = True),
'default_classpath': attr.label_list(
default = [
Label(
"//external:io_bazel_rules_scala/dependency/scala/scala_library"
),
Label(
"//external:io_bazel_rules_scala/dependency/scala/scala_reflect"
),
],
providers = [JavaInfo]),
'repl_runtime_deps': attr.label_list(
default = [
Label(
"//external:io_bazel_rules_scala/dependency/scala/scala_compiler"
),
],
providers = [JavaInfo]),
})