Skip to content

Commit

Permalink
update
Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jan 3, 2025
1 parent 91ef7a2 commit c9a074e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import java.nio.file.{Files, Paths}
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf._
import org.apache.commons.logging.LogFactory

import ml.dmlc.xgboost4j.java.{ColumnBatch, CudfColumnBatch}
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
Expand Down Expand Up @@ -66,19 +65,16 @@ case class HostMemoryBufferInfo(hostMemoryBuffer: HostMemoryBuffer, size: Long)
// The data will be cached into host memory
private[spark] class HostExternalMemoryIterator()
extends ExternalMemory[HostMemoryBufferInfo] {
private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")
private lazy val allocator = DefaultHostMemoryAllocator.get()

class XGBoostHostBufferConsumer extends HostBufferConsumer {
private val hostBuffers = ArrayBuffer.empty[HostMemoryBufferInfo]

// Callback invoked by cuDF's Arrow IPC chunked writer each time it emits a
// host-memory buffer; records the buffer and its byte size for later
// concatenation in getHostMemoryBuffer.
// NOTE(review): this text appears to be a diff view — the logger line below
// looks like a deletion in the underlying commit (the logger field is removed
// elsewhere in the same change); confirm against the actual file.
override def handleBuffer(hostMemoryBuffer: HostMemoryBuffer, l: Long): Unit = {
logger.info("XGBoostHostBufferConsumer handleBuffer: " + l)
hostBuffers.append(HostMemoryBufferInfo(hostMemoryBuffer = hostMemoryBuffer, size = l))
}

def getHostMemoryBuffer: HostMemoryBufferInfo = {
logger.info("getHostMemoryBuffer size: " + hostBuffers.size)
if (hostBuffers.size == 1) {
hostBuffers(0)
} else if (hostBuffers.size > 1) {
Expand All @@ -92,7 +88,6 @@ private[spark] class HostExternalMemoryIterator()
offset += h.size
}
}
logger.info("getHostMemoryBuffer -- ")
HostMemoryBufferInfo(buffer, totalSize)
} else {
throw new RuntimeException("No data") // Unreachable
Expand All @@ -107,22 +102,19 @@ private[spark] class HostExternalMemoryIterator()
* @return the content
*/
// Serializes a cuDF Table into host memory via Arrow IPC chunked writing.
// A fresh XGBoostHostBufferConsumer collects every buffer the writer emits,
// and getHostMemoryBuffer concatenates them (if needed) into one contiguous
// HostMemoryBufferInfo that is returned to the caller.
// NOTE(review): the two logger.info lines below look like deletions in the
// underlying commit (this is a diff view); confirm against the actual file.
override def convertTable(table: Table): HostMemoryBufferInfo = {
logger.info("HostMemoryBufferInfo convertTable ++")
// Arrow IPC requires column names; synthesize "0", "1", ... from indices.
val names = (0 until table.getNumberOfColumns).map(_.toString)
val options = ArrowIPCWriterOptions.builder().withNotNullableColumnNames(names: _*).build()
val consumer = new XGBoostHostBufferConsumer()
// withResource closes the writer after the table has been written.
withResource(Table.writeArrowIPCChunked(options, consumer)) { writer =>
writer.write(table)
}
logger.info("HostMemoryBufferInfo convertTable - table: " + table)
consumer.getHostMemoryBuffer
}

class XGBoostHostBufferProvider(bufferInfo: HostMemoryBufferInfo) extends HostBufferProvider {
var offset = 0L

override def readInto(hostMemoryBuffer: HostMemoryBuffer, l: Long): Long = {
logger.info("XGBoostHostBufferProvider readInto: " + l)
val amountLeft = bufferInfo.size - offset
val amountToCopy = Math.max(0, Math.min(l, amountLeft))
if (amountToCopy > 0) {
Expand All @@ -144,7 +136,6 @@ private[spark] class HostExternalMemoryIterator()
* @return Table
*/
override def loadTable(content: HostMemoryBufferInfo): Table = {
logger.info("HostExternalMemoryIterator loadTable +")
val tables = ArrayBuffer.empty[Table]
withResource(new XGBoostHostBufferProvider(content)) { provider =>
withResource(Table.readArrowIPCChunked(provider)) { reader =>
Expand All @@ -153,7 +144,6 @@ private[spark] class HostExternalMemoryIterator()
tables.append(table.get)
table = Option(reader.getNextIfAvailable)
}
logger.info("HostExternalMemoryIterator loadTable size: " + tables.size)
if (tables.size == 1) {
tables(0)
} else if (tables.size > 1) {
Expand Down Expand Up @@ -183,7 +173,7 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
val path = Paths.get(dirPath)
if (!Files.exists(path)) {
Files.createDirectories(path)
println(s"Directory created: $dirPath")
println(s"Directory created at: $dirPath")
} else {
println(s"Directory already exists: $dirPath")
}
Expand Down Expand Up @@ -275,7 +265,6 @@ private[spark] class ExternalMemoryIterator(val input: Iterator[Table],
val indices: ColumnIndices,
val path: Option[String] = None)
extends Iterator[ColumnBatch] with AutoCloseable {
private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")

private var iter = input

Expand All @@ -302,14 +291,12 @@ private[spark] class ExternalMemoryIterator(val input: Iterator[Table],
if (iter == input) {
externalMemory.cacheTable(batch.table)
}
logger.info("ExternalMemoryIterator next: table" + batch.table)
val xx = new CudfColumnBatch(
new CudfColumnBatch(
batch.select(indices.featureIds.get),
batch.select(indices.labelId),
batch.select(indices.weightId.getOrElse(-1)),
batch.select(indices.marginId.getOrElse(-1)),
batch.select(indices.groupId.getOrElse(-1)));
xx
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
}

private class GpuColumnBatch(val table: Table) extends AutoCloseable {
private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")

def select(index: Int): Table = {
select(Seq(index))
Expand All @@ -310,13 +309,5 @@ private class GpuColumnBatch(val table: Table) extends AutoCloseable {
new Table(indices.map(table.getColumn): _*)
}

override def close(): Unit = {
logger.info("GpuColumnBatch close +")
if (Option(table).isDefined) {
logger.info("GpuColumnBatch close ++")
table.close()
}
logger.info("GpuColumnBatch close -")

}
override def close(): Unit = Option(table).foreach(_.close())
}

0 comments on commit c9a074e

Please sign in to comment.