diff --git a/jvm/src/test/scala/scala/xml/XMLTest.scala b/jvm/src/test/scala/scala/xml/XMLTest.scala
index 812b5368..032b341a 100644
--- a/jvm/src/test/scala/scala/xml/XMLTest.scala
+++ b/jvm/src/test/scala/scala/xml/XMLTest.scala
@@ -657,6 +657,34 @@ class XMLTestJVM {
def namespaceAware2: Unit =
roundtrip(namespaceAware = true, """""")
+ @UnitTest
+ def useXMLReaderWithXMLFilter(): Unit = {
+ val parent: org.xml.sax.XMLReader = javax.xml.parsers.SAXParserFactory.newInstance.newSAXParser.getXMLReader
+ val filter: org.xml.sax.XMLFilter = new org.xml.sax.helpers.XMLFilterImpl(parent) {
+ override def characters(ch: Array[Char], start: Int, length: Int): Unit = {
+ for (i <- 0 until length) if (ch(start+i) == 'a') ch(start+i) = 'b'
+ super.characters(ch, start, length)
+ }
+ }
+ assertEquals(XML.withXMLReader(filter).loadString("caffeeaaay").toString, "cbffeebbby")
+ }
+
+ @UnitTest
+ def checkThatErrorHandlerIsNotOverwritten(): Unit = {
+ var gotAnError: Boolean = false
+ XML.reader.setErrorHandler(new org.xml.sax.ErrorHandler {
+ override def warning(e: SAXParseException): Unit = gotAnError = true
+ override def error(e: SAXParseException): Unit = gotAnError = true
+ override def fatalError(e: SAXParseException): Unit = gotAnError = true
+ })
+ try {
+ XML.loadString("")
+ } catch {
+ case _: org.xml.sax.SAXParseException =>
+ }
+ assertTrue(gotAnError)
+ }
+
@UnitTest
def nodeSeqNs: Unit = {
val x = {
diff --git a/shared/src/main/scala/scala/xml/XML.scala b/shared/src/main/scala/scala/xml/XML.scala
index 46b5a0ec..531afcc9 100755
--- a/shared/src/main/scala/scala/xml/XML.scala
+++ b/shared/src/main/scala/scala/xml/XML.scala
@@ -14,8 +14,8 @@ package scala
package xml
import factory.XMLLoader
-import java.io.{ File, FileDescriptor, FileInputStream, FileOutputStream }
-import java.io.{ InputStream, Reader, StringReader }
+import java.io.{File, FileDescriptor, FileInputStream, FileOutputStream}
+import java.io.{InputStream, Reader, StringReader}
import java.nio.channels.Channels
import scala.util.control.Exception.ultimately
@@ -72,6 +72,10 @@ object XML extends XMLLoader[Elem] {
def withSAXParser(p: SAXParser): XMLLoader[Elem] =
new XMLLoader[Elem] { override val parser: SAXParser = p }
+ /** Returns an XMLLoader whose load* methods will use the supplied XMLReader. */
+ def withXMLReader(r: XMLReader): XMLLoader[Elem] =
+ new XMLLoader[Elem] { override val reader: XMLReader = r }
+
/**
* Saves a node to a file with given filename using given encoding
* optionally with xmldecl and doctype declaration.
diff --git a/shared/src/main/scala/scala/xml/factory/XMLLoader.scala b/shared/src/main/scala/scala/xml/factory/XMLLoader.scala
index 620e1b6e..5cdc9461 100644
--- a/shared/src/main/scala/scala/xml/factory/XMLLoader.scala
+++ b/shared/src/main/scala/scala/xml/factory/XMLLoader.scala
@@ -14,7 +14,7 @@ package scala
package xml
package factory
-import org.xml.sax.SAXNotRecognizedException
+import org.xml.sax.{SAXNotRecognizedException, XMLReader}
import javax.xml.parsers.SAXParserFactory
import parsing.{FactoryAdapter, NoBindingFactoryAdapter}
import java.io.{File, FileDescriptor, InputStream, Reader}
@@ -46,59 +46,77 @@ trait XMLLoader[T <: Node] {
/* Override this to use a different SAXParser. */
def parser: SAXParser = parserInstance.get
+ /* Override this to use a different XMLReader. */
+ def reader: XMLReader = parser.getXMLReader
+
/**
* Loads XML from the given InputSource, using the supplied parser.
* The methods available in scala.xml.XML use the XML parser in the JDK.
*/
- def loadXML(source: InputSource, parser: SAXParser): T = {
- val result: FactoryAdapter = parse(source, parser)
+ def loadXML(source: InputSource, parser: SAXParser): T = loadXML(source, parser.getXMLReader)
+
+ def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = loadXMLNodes(source, parser.getXMLReader)
+
+ private def loadXML(source: InputSource, reader: XMLReader): T = {
+ val result: FactoryAdapter = parse(source, reader)
result.rootElem.asInstanceOf[T]
}
-
- def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = {
- val result: FactoryAdapter = parse(source, parser)
+
+ private def loadXMLNodes(source: InputSource, reader: XMLReader): Seq[Node] = {
+ val result: FactoryAdapter = parse(source, reader)
result.prolog ++ (result.rootElem :: result.epilogue)
}
- private def parse(source: InputSource, parser: SAXParser): FactoryAdapter = {
+ private def parse(source: InputSource, reader: XMLReader): FactoryAdapter = {
+ if (source == null) throw new IllegalArgumentException("InputSource cannot be null")
+
val result: FactoryAdapter = adapter
+ reader.setContentHandler(result)
+ reader.setDTDHandler(result)
+ /* Do not overwrite pre-configured EntityResolver. */
+ if (reader.getEntityResolver == null) reader.setEntityResolver(result)
+ /* Do not overwrite pre-configured ErrorHandler. */
+ if (reader.getErrorHandler == null) reader.setErrorHandler(result)
+
try {
- parser.setProperty("http://xml.org/sax/properties/lexical-handler", result)
+ reader.setProperty("http://xml.org/sax/properties/lexical-handler", result)
} catch {
case _: SAXNotRecognizedException =>
}
result.scopeStack = TopScope :: result.scopeStack
- parser.parse(source, result)
+ reader.parse(source)
result.scopeStack = result.scopeStack.tail
result
}
+ /** loads XML from given InputSource. */
+ def load(source: InputSource): T = loadXML(source, reader)
+
/** Loads XML from the given file, file descriptor, or filename. */
- def loadFile(file: File): T = loadXML(fromFile(file), parser)
- def loadFile(fd: FileDescriptor): T = loadXML(fromFile(fd), parser)
- def loadFile(name: String): T = loadXML(fromFile(name), parser)
+ def loadFile(file: File): T = load(fromFile(file))
+ def loadFile(fd: FileDescriptor): T = load(fromFile(fd))
+ def loadFile(name: String): T = load(fromFile(name))
- /** loads XML from given InputStream, Reader, sysID, InputSource, or URL. */
- def load(is: InputStream): T = loadXML(fromInputStream(is), parser)
- def load(reader: Reader): T = loadXML(fromReader(reader), parser)
- def load(sysID: String): T = loadXML(fromSysId(sysID), parser)
- def load(source: InputSource): T = loadXML(source, parser)
- def load(url: URL): T = loadXML(fromInputStream(url.openStream()), parser)
+ /** loads XML from given InputStream, Reader, sysID, or URL. */
+ def load(is: InputStream): T = load(fromInputStream(is))
+ def load(reader: Reader): T = load(fromReader(reader))
+ def load(sysID: String): T = load(fromSysId(sysID))
+ def load(url: URL): T = load(fromInputStream(url.openStream()))
/** Loads XML from the given String. */
- def loadString(string: String): T = loadXML(fromString(string), parser)
+ def loadString(string: String): T = load(fromString(string))
/** Load XML nodes, including comments and processing instructions that precede and follow the root element. */
- def loadFileNodes(file: File): Seq[Node] = loadXMLNodes(fromFile(file), parser)
- def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadXMLNodes(fromFile(fd), parser)
- def loadFileNodes(name: String): Seq[Node] = loadXMLNodes(fromFile(name), parser)
- def loadNodes(is: InputStream): Seq[Node] = loadXMLNodes(fromInputStream(is), parser)
- def loadNodes(reader: Reader): Seq[Node] = loadXMLNodes(fromReader(reader), parser)
- def loadNodes(sysID: String): Seq[Node] = loadXMLNodes(fromSysId(sysID), parser)
- def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, parser)
- def loadNodes(url: URL): Seq[Node] = loadXMLNodes(fromInputStream(url.openStream()), parser)
- def loadStringNodes(string: String): Seq[Node] = loadXMLNodes(fromString(string), parser)
+ def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, reader)
+ def loadFileNodes(file: File): Seq[Node] = loadNodes(fromFile(file))
+ def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadNodes(fromFile(fd))
+ def loadFileNodes(name: String): Seq[Node] = loadNodes(fromFile(name))
+ def loadNodes(is: InputStream): Seq[Node] = loadNodes(fromInputStream(is))
+ def loadNodes(reader: Reader): Seq[Node] = loadNodes(fromReader(reader))
+ def loadNodes(sysID: String): Seq[Node] = loadNodes(fromSysId(sysID))
+ def loadNodes(url: URL): Seq[Node] = loadNodes(fromInputStream(url.openStream()))
+ def loadStringNodes(string: String): Seq[Node] = loadNodes(fromString(string))
}
diff --git a/shared/src/main/scala/scala/xml/package.scala b/shared/src/main/scala/scala/xml/package.scala
index 7847f63b..d25a80e2 100644
--- a/shared/src/main/scala/scala/xml/package.scala
+++ b/shared/src/main/scala/scala/xml/package.scala
@@ -80,5 +80,6 @@ package object xml {
type SAXParseException = org.xml.sax.SAXParseException
type EntityResolver = org.xml.sax.EntityResolver
type InputSource = org.xml.sax.InputSource
+ type XMLReader = org.xml.sax.XMLReader
type SAXParser = javax.xml.parsers.SAXParser
}
diff --git a/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala b/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala
index 85389711..4fe936a4 100755
--- a/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala
+++ b/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala
@@ -98,8 +98,9 @@ trait MarkupParser extends MarkupParserCommon with TokenTests {
var extIndex = -1
/** holds temporary values of pos */
- // Note: this is clearly an override, but if marked as such it causes a "...cannot override a mutable variable"
- // error with Scala 3; does it work with Scala 3 if not explicitly marked as an override remains to be seen...
+ // Note: if marked as an override, this causes a "...cannot override a mutable variable" error with Scala 3;
+ // SethTisue noted on Oct 14, 2021 that lampepfl/dotty#13744 should fix it - and it probably did,
+ // but Scala XML still builds against Scala 3 version that has this bug, so this still can not be marked as an override :(
var tmppos: Int = _
/** holds the next character */