File no-hadoop.patch of Package xgboost
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 8d4f2c05..7649df65 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -29,18 +29,6 @@
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
<version>${scala-collection-compat.version}</version>
</dependency>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-hdfs</artifactId>
- <version>${hadoop.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-common</artifactId>
- <version>${hadoop.version}</version>
- <scope>provided</scope>
- </dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java
deleted file mode 100644
index 655b9902..00000000
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java
+++ /dev/null
@@ -1,117 +0,0 @@
-package ml.dmlc.xgboost4j.java;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.*;
-import java.util.stream.Collectors;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-
-public class ExternalCheckpointManager {
-
- private Log logger = LogFactory.getLog("ExternalCheckpointManager");
- private String modelSuffix = ".model";
- private Path checkpointPath;
- private FileSystem fs;
-
- public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
- if (checkpointPath == null || checkpointPath.isEmpty()) {
- throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
- " empty checkpoint path");
- }
- this.checkpointPath = new Path(checkpointPath);
- this.fs = fs;
- }
-
- private String getPath(int version) {
- return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
- }
-
- private List<Integer> getExistingVersions() throws IOException {
- if (!fs.exists(checkpointPath)) {
- return new ArrayList<>();
- } else {
- return Arrays.stream(fs.listStatus(checkpointPath))
- .map(path -> path.getPath().getName())
- .filter(fileName -> fileName.endsWith(modelSuffix))
- .map(fileName -> Integer.valueOf(
- fileName.substring(0, fileName.length() - modelSuffix.length())))
- .collect(Collectors.toList());
- }
- }
-
- public void cleanPath() throws IOException {
- fs.delete(checkpointPath, true);
- }
-
- public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
- List<Integer> versions = getExistingVersions();
- if (versions.size() > 0) {
- int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
- String checkpointPath = getPath(latestVersion);
- InputStream in = fs.open(new Path(checkpointPath));
- logger.info("loaded checkpoint from " + checkpointPath);
- Booster booster = XGBoost.loadModel(in);
- booster.setVersion(latestVersion);
- return booster;
- } else {
- return null;
- }
- }
-
- public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
- List<String> prevModelPaths = getExistingVersions().stream()
- .map(this::getPath).collect(Collectors.toList());
- String eventualPath = getPath(boosterToCheckpoint.getVersion());
- String tempPath = eventualPath + "-" + UUID.randomUUID();
- try (OutputStream out = fs.create(new Path(tempPath), true)) {
- boosterToCheckpoint.saveModel(out);
- fs.rename(new Path(tempPath), new Path(eventualPath));
- logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
- prevModelPaths.stream().forEach(path -> {
- try {
- fs.delete(new Path(path), true);
- } catch (IOException e) {
- logger.error("failed to delete outdated checkpoint at " + path, e);
- }
- });
- }
- }
-
- public void cleanUpHigherVersions(int currentRound) throws IOException {
- getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
- try {
- fs.delete(new Path(getPath(v)), true);
- } catch (IOException e) {
- logger.error("failed to clean checkpoint from other training instance", e);
- }
- });
- }
-
- public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
- throws IOException {
- if (checkpointInterval > 0) {
- List<Integer> prevRounds =
- getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
- prevRounds.add(0);
- int firstCheckpointRound = prevRounds.stream()
- .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
- List<Integer> arr = new ArrayList<>();
- for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
- arr.add(i);
- }
- arr.add(numOfRounds);
- return arr;
- } else if (checkpointInterval <= 0) {
- List<Integer> l = new ArrayList<Integer>();
- l.add(numOfRounds);
- return l;
- } else {
- throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
- }
- }
-}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
index bcd0b1b1..3e23c15f 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
@@ -22,7 +22,6 @@ import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.fs.FileSystem;
/**
* trainer for xgboost
@@ -134,34 +133,35 @@ public class XGBoost {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
- private static void saveCheckpoint(
- Booster booster,
- int iter,
- Set<Integer> checkpointIterations,
- ExternalCheckpointManager ecm) throws XGBoostError {
- try {
- if (checkpointIterations.contains(iter)) {
- ecm.updateCheckpoint(booster);
- }
- } catch (Exception e) {
- logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
- throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
- }
- }
-
- public static Booster trainAndSaveCheckpoint(
+ /**
+ * Train a booster given parameters.
+ *
+ * @param dtrain Data to be trained.
+ * @param params Parameters.
+ * @param round Number of boosting iterations.
+ * @param watches a group of items to be evaluated during training, this allows user to watch
+ * performance on the validation set.
+ * @param metrics array containing the evaluation metrics for each matrix in watches for each
+ * iteration
+ * @param earlyStoppingRounds if non-zero, training would be stopped
+ * after a specified number of consecutive
+ * goes to the unexpected direction in any evaluation metric.
+ * @param obj customized objective
+ * @param eval customized evaluation
+ * @param booster train from scratch if set to null; train from an existing booster if not null.
+ * @return The trained booster.
+ */
+ public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
- int numRounds,
+ int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRounds,
- Booster booster,
- int checkpointInterval,
- String checkpointPath,
- FileSystem fs) throws XGBoostError, IOException {
+ Booster booster) throws XGBoostError {
+
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
@@ -169,11 +169,6 @@ public class XGBoost {
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
- Set<Integer> checkpointIterations = new HashSet<>();
- ExternalCheckpointManager ecm = null;
- if (checkpointPath != null) {
- ecm = new ExternalCheckpointManager(checkpointPath, fs);
- }
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey());
@@ -184,7 +179,7 @@ public class XGBoost {
evalMats = mats.toArray(new DMatrix[mats.size()]);
bestIteration = 0;
- metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
+ metrics = metrics == null ? new float[evalNames.length][round] : metrics;
//collect all data matrixs
DMatrix[] allMats;
@@ -209,22 +204,17 @@ public class XGBoost {
booster.setParams(params);
}
- if (ecm != null) {
- checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
- }
-
boolean initial_best_score_flag = false;
boolean max_direction = false;
// begin to train
- for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
+ for (int iter = booster.getVersion() / 2; iter < round; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
- saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
}
@@ -290,44 +280,6 @@ public class XGBoost {
return booster;
}
- /**
- * Train a booster given parameters.
- *
- * @param dtrain Data to be trained.
- * @param params Parameters.
- * @param round Number of boosting iterations.
- * @param watches a group of items to be evaluated during training, this allows user to watch
- * performance on the validation set.
- * @param metrics array containing the evaluation metrics for each matrix in watches for each
- * iteration
- * @param earlyStoppingRounds if non-zero, training would be stopped
- * after a specified number of consecutive
- * goes to the unexpected direction in any evaluation metric.
- * @param obj customized objective
- * @param eval customized evaluation
- * @param booster train from scratch if set to null; train from an existing booster if not null.
- * @return The trained booster.
- */
- public static Booster train(
- DMatrix dtrain,
- Map<String, Object> params,
- int round,
- Map<String, DMatrix> watches,
- float[][] metrics,
- IObjective obj,
- IEvaluation eval,
- int earlyStoppingRounds,
- Booster booster) throws XGBoostError {
- try {
- return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
- earlyStoppingRounds, booster,
- -1, null, null);
- } catch (IOException e) {
- logger.error("training failed in xgboost4j", e);
- throw new XGBoostError("training failed in xgboost4j ", e);
- }
- }
-
private static Integer tryGetIntFromObject(Object o) {
if (o instanceof Integer) {
return (int)o;
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala
deleted file mode 100644
index 240c2387..00000000
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala
-
-import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
-import org.apache.hadoop.fs.FileSystem
-
-class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
- extends JavaECM(checkpointPath, fs) {
-
- def updateCheckpoint(booster: Booster): Unit = {
- super.updateCheckpoint(booster.booster)
- }
-
- def loadCheckpointAsScalaBooster(): Booster = {
- val loadedBooster = super.loadCheckpointAsBooster()
- if (loadedBooster == null) {
- null
- } else {
- new Booster(loadedBooster)
- }
- }
-}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
index 50d86c89..49ce29d8 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
@@ -20,61 +20,12 @@ import java.io.InputStream
import ml.dmlc.xgboost4j.java.{XGBoostError, XGBoost => JXGBoost}
import scala.jdk.CollectionConverters._
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
/**
* XGBoost Scala Training function.
*/
object XGBoost {
- private[scala] def trainAndSaveCheckpoint(
- dtrain: DMatrix,
- params: Map[String, Any],
- numRounds: Int,
- watches: Map[String, DMatrix] = Map(),
- metrics: Array[Array[Float]] = null,
- obj: ObjectiveTrait = null,
- eval: EvalTrait = null,
- earlyStoppingRound: Int = 0,
- prevBooster: Booster,
- checkpointParams: Option[ExternalCheckpointParams]): Booster = {
-
- // we have to filter null value for customized obj and eval
- val jParams: java.util.Map[String, AnyRef] =
- params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava
-
- val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava
- val jBooster = if (prevBooster == null) {
- null
- } else {
- prevBooster.booster
- }
-
- val xgboostInJava = checkpointParams.
- map(cp => {
- JXGBoost.trainAndSaveCheckpoint(
- dtrain.jDMatrix,
- jParams,
- numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
- cp.checkpointInterval,
- cp.checkpointPath,
- new Path(cp.checkpointPath).getFileSystem(new Configuration()))
- }).
- getOrElse(
- JXGBoost.train(
- dtrain.jDMatrix,
- jParams,
- numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
- )
- if (prevBooster == null) {
- new Booster(xgboostInJava)
- } else {
- // Avoid creating a new SBooster with the same JBooster
- prevBooster
- }
- }
-
/**
* Train a booster given parameters.
*
@@ -104,8 +55,23 @@ object XGBoost {
eval: EvalTrait = null,
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
- trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
- booster, None)
+ val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava
+ val jBooster = if (booster == null) {
+ null
+ } else {
+ booster.booster
+ }
+ val xgboostInJava = JXGBoost.train(
+ dtrain.jDMatrix,
+ // we have to filter null value for customized obj and eval
+ params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava,
+ round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
+ if (booster == null) {
+ new Booster(xgboostInJava)
+ } else {
+ // Avoid creating a new SBooster with the same JBooster
+ booster
+ }
}
/**
@@ -160,41 +126,3 @@ object XGBoost {
new Booster(xgboostInJava)
}
}
-
-private[scala] case class ExternalCheckpointParams(
- checkpointInterval: Int,
- checkpointPath: String,
- skipCleanCheckpoint: Boolean)
-
-private[scala] object ExternalCheckpointParams {
-
- def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
- val checkpointPath: String = params.get("checkpoint_path") match {
- case None | Some(null) | Some("") => null
- case Some(path: String) => path
- case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
- s" an instance of String, but current value is ${params("checkpoint_path")}")
- }
-
- val checkpointInterval: Int = params.get("checkpoint_interval") match {
- case None => 0
- case Some(freq: Int) => freq
- case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
- " an instance of Int.")
- }
-
- val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
- case None => false
- case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
- case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
- " an instance of Boolean")
- }
- if (checkpointPath == null || checkpointInterval == 0) {
- None
- } else {
- Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
- }
- }
-}
-
-