【机器学习】Java 代码实现 CART 决策树算法
文章目录
- 一、决策树算法
- 二、CART 决策树
- 三、Java 代码实现
- 3.1 TrainDataSet
- 3.2 DataType
- 3.3 PredictResult
- 3.4 CartDecisionTree
- 3.5 Run
一、决策树算法
关于决策树算法的详细介绍可以参考我的另一篇博客:【机器学习】Decision Tree 决策树算法详解 + Python代码实战
二、CART 决策树
CART(classification and regression tree)树:又称为分类回归树,从名字可以发现,CART树既可用于分类,也可以用于回归。
当数据集的因变量是离散值时,可以采用CART分类树进行拟合,用叶节点概率最大的类别作为该节点的预测类别。
当数据集的因变量是连续值时,可以采用CART回归树进行拟合,用叶节点的均值作为该节点预测值。
本文实现的是 CART 分类树,其实就是采用基尼(GINI)系数作为特征划分依据的决策树。
三、Java 代码实现
3.1 TrainDataSet
TrainDataSet:训练数据集存放对象
/**
 * Container for the training data: a list of feature vectors, the per-column
 * data types, and the parallel list of class labels.
 */
public class TrainDataSet {
    /**
     * Feature vectors, one Object[] per sample (String or boxed Double entries).
     **/
    public List<Object[]> features = new ArrayList<>();
    /**
     * Data type of each feature column, indexed by feature position.
     **/
    DataType[] dataTypes;
    /**
     * Class label of each sample; parallel to {@code features}.
     **/
    public List<String> labels = new ArrayList<>();
    /**
     * Dimensionality of every feature vector (equals dataTypes.length).
     **/
    public int featureDim;

    /**
     * @param dataTypes the type (String/Number) of each feature column; its
     *                  length fixes the expected feature-vector dimension
     */
    public TrainDataSet(DataType[] dataTypes) {
        this.dataTypes = dataTypes;
        this.featureDim = dataTypes.length;
    }

    /** @return number of samples currently stored */
    public int size() {
        return labels.size();
    }

    /** @return the feature vector of the sample at {@code index} */
    public Object[] getFeature(int index) {
        return features.get(index);
    }

    /** @return the label of the sample at {@code index} */
    public String getLabel(int index) {
        return labels.get(index);
    }

    /**
     * Adds one sample.
     *
     * @param feature feature vector; must have exactly {@code featureDim} entries
     * @param label   class label of the sample
     * @throws IllegalArgumentException if the vector dimension does not match
     */
    public void addData(Object[] feature, String label) {
        if (featureDim != feature.length) {
            throwDimensionMismatchException(feature.length);
        }
        features.add(feature);
        labels.add(label);
    }

    /**
     * Reports a feature-vector dimension mismatch.
     * Uses IllegalArgumentException (a RuntimeException subclass, so existing
     * callers catching RuntimeException are unaffected) — the standard type
     * for an invalid argument.
     *
     * @param errorLen the (wrong) dimension that was passed in
     */
    public void throwDimensionMismatchException(int errorLen) {
        throw new IllegalArgumentException("DimensionMismatchError: 你应该传入维度为 " + featureDim + " 的特征向量 , 但你传入了维度为 " + errorLen + " 的特征向量");
    }
}
3.2 DataType
DataType:一个枚举类,用来指示特征是数值型的还是字符型的
/**
 * Kind of value a feature column holds: categorical text or a numeric value.
 */
public enum DataType {
    /** Categorical (string) feature. */
    String,
    /** Numeric (double) feature. */
    Number
}
3.3 PredictResult
PredictResult:预测结果对象
/**
 * Result of a single prediction: the most probable label plus the full
 * per-class probability vector.
 */
public class PredictResult {
    // All class labels, aligned index-by-index with predictArr.
    String[] labelArr;
    // The most probable class label.
    String predictLabel;
    // Probability of each class, aligned with labelArr.
    double[] predictArr;

    public PredictResult(String[] labelArr, String predictLabel, double[] predictArr) {
        this.labelArr = labelArr;
        this.predictLabel = predictLabel;
        this.predictArr = predictArr;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("PredictResult{");
        sb.append("predictLabel='").append(predictLabel).append('\'');
        sb.append(", predictArr=").append(CartDecisionTree.predictArrToString(predictArr, labelArr));
        sb.append('}');
        return sb.toString();
    }
}
3.4 CartDecisionTree
CartDecisionTree:CART 决策树算法对象
/**
 * CART classification tree: splits on the feature/threshold (or category set)
 * that minimizes the weighted Gini index, built breadth-first with optional
 * pre-pruning (depth, leaf count, minimum node size).
 */
public class CartDecisionTree {
    /**
     * The training data set the tree is fitted on.
     **/
    TrainDataSet trainDataSet;
    /**
     * All distinct class labels found in the training data.
     **/
    String[] labelArr;
    /**
     * Pre-pruning: maximum depth of the tree (null = no limit).
     **/
    Integer maxDeep;
    /**
     * Pre-pruning: cap on the number of leaf nodes (null = no limit).
     **/
    Integer maxLeafNum;
    /**
     * Pre-pruning: minimum number of samples each child node must keep (null = no limit).
     **/
    Integer minDataSize;
    /**
     * Root of the decision tree.
     **/
    DecisionTree root;

    /**
     * @param trainDataSet training data
     * @param maxDeep      depth limit, null for none
     * @param maxLeafNum   leaf-count limit, null for none
     * @param minDataSize  minimum samples per child node, null for none
     */
    public CartDecisionTree(TrainDataSet trainDataSet, Integer maxDeep, Integer maxLeafNum, Integer minDataSize) {
        this.trainDataSet = trainDataSet;
        this.maxDeep = maxDeep;
        this.maxLeafNum = maxLeafNum;
        this.minDataSize = minDataSize;
        // Deduplicate the labels to obtain every class.
        HashSet<String> labelSet = new HashSet<>(trainDataSet.labels);
        this.labelArr = new String[labelSet.size()];
        int i = 0;
        for (String label : labelSet) {
            this.labelArr[i++] = label;
        }
    }

    /** Initializes the tree with a single root leaf holding all training data. */
    public void initModel() {
        root = createLeafNode(1, trainDataSet.features, trainDataSet.labels, new boolean[trainDataSet.featureDim]);
    }

    /**
     * Walks the tree from the root following the split conditions and returns
     * the prediction of the node reached.
     *
     * @param features feature vector to classify
     * @return predicted label and per-class probabilities
     */
    public PredictResult predict(Object[] features) {
        DecisionTree tree = root;
        while (tree.condition != null) {
            int featureIndex = tree.condition.featureIndex;
            if (trainDataSet.dataTypes[featureIndex].equals(DataType.String)) {
                // Categorical split: descend into the child whose category matches.
                List<String> stringList = (List<String>) tree.condition.conditionValue;
                DecisionTree next = null;
                for (int i = 0; i < stringList.size(); i++) {
                    if (stringList.get(i).equals(features[featureIndex])) {
                        next = tree.children.get(i);
                        break;
                    }
                }
                // BUG FIX: a category never seen at this node matched no child, so the
                // loop previously spun forever. Fall back to this node's own prediction.
                if (next == null) {
                    break;
                }
                tree = next;
            } else {
                // Numeric split: left child is >= threshold, right child is <.
                if ((double) features[featureIndex] >= (double) tree.condition.conditionValue) {
                    tree = tree.children.get(0);
                } else {
                    tree = tree.children.get(1);
                }
            }
        }
        return new PredictResult(labelArr, tree.predictLabel, tree.predictArr);
    }

    /**
     * Fits the tree to the training data by breadth-first expansion, applying
     * the configured pre-pruning rules, then prints the resulting tree.
     */
    public void fit() {
        initModel();
        // BFS construction of the tree.
        Queue<DecisionTree> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            // Take the next incomplete node (has depth, data and majority label,
            // but no split condition or children yet).
            DecisionTree decisionTree = queue.poll();
            // Scan all features and keep the split with the smallest Gini index.
            BranchResult bestBranchResult = null;
            // Pre-pruning: respect the depth limit.
            if (maxDeep == null || decisionTree.deep + 1 <= maxDeep) {
                for (int featureIndex = 0; featureIndex < trainDataSet.featureDim; featureIndex++) {
                    // Only evaluate features that are not tabu at this node.
                    if (!decisionTree.featureTabuArr[featureIndex]) {
                        BranchResult branchResult = null;
                        if (trainDataSet.dataTypes[featureIndex].equals(DataType.String)) {
                            // Categorical feature split.
                            branchResult = stringBranch(featureIndex, decisionTree);
                        } else {
                            // Numeric feature split.
                            branchResult = numberBranch(featureIndex, decisionTree);
                        }
                        if (branchResult != null) {
                            if (bestBranchResult == null || branchResult.gini < bestBranchResult.gini) {
                                bestBranchResult = branchResult;
                            }
                        }
                    }
                }
            }
            // Attach the best split's children and enqueue them for further expansion.
            if (bestBranchResult != null) {
                decisionTree.children.addAll(bestBranchResult.decisionTreeList);
                decisionTree.condition = bestBranchResult.condition;
                for (DecisionTree child : decisionTree.children) {
                    // A categorical split consumes the feature entirely, so mark it
                    // tabu for all descendants; numeric features can be reused.
                    if (bestBranchResult.condition.dataType.equals(DataType.String)) {
                        child.featureTabuArr = decisionTree.featureTabuArr.clone();
                        child.featureTabuArr[bestBranchResult.condition.featureIndex] = true;
                    } else {
                        child.featureTabuArr = decisionTree.featureTabuArr.clone();
                    }
                    child.deep = decisionTree.deep + 1;
                }
                queue.addAll(decisionTree.children);
            }
            // Pre-pruning: stop expanding once the pending-node count reaches the
            // leaf limit (approximate — counts queued nodes, not final leaves).
            if (maxLeafNum != null && queue.size() >= maxLeafNum) {
                break;
            }
        }
        root.print();
    }

    /**
     * Tries every midpoint between consecutive distinct values of a numeric
     * feature and returns the binary split with the smallest weighted Gini
     * index, or null if no admissible split exists.
     */
    private BranchResult numberBranch(int featureIndex, DecisionTree decisionTree) {
        BranchResult bestBranchResult = null;
        // Tracks values already considered. Keys are the numbers formatted to 6
        // decimal places, which both sidesteps floating-point hash instability
        // and deduplicates the feature values.
        HashSet<String> valueSet = new HashSet<>();
        // Collect the distinct feature values.
        List<Double> valueList = new ArrayList<>();
        for (int i = 0; i < decisionTree.features.size(); i++) {
            double v = (double) decisionTree.features.get(i)[featureIndex];
            if (valueSet.add(String.format("%.6f", v))) {
                valueList.add(v);
            }
        }
        // Sort them so consecutive midpoints are the candidate thresholds.
        Collections.sort(valueList);
        // Try each midpoint as a split threshold.
        valueSet = new HashSet<>();
        for (int i = 0; i < valueList.size() - 1; i++) {
            double mid = (valueList.get(i) + valueList.get(i + 1)) / 2.0;
            if (valueSet.add(String.format("%.6f", mid))) {
                // Per-side label counts for the Gini computation.
                Map<String, Integer> leftMap = new HashMap<>();
                Map<String, Integer> rightMap = new HashMap<>();
                // Left child holds samples >= mid, right child holds samples < mid.
                DecisionTree left = new DecisionTree();
                DecisionTree right = new DecisionTree();
                // Distribute the samples to the two children.
                for (int j = 0; j < decisionTree.features.size(); j++) {
                    if ((double) decisionTree.features.get(j)[featureIndex] >= mid) {
                        left.features.add(decisionTree.features.get(j));
                        left.labels.add(decisionTree.labels.get(j));
                        leftMap.merge(decisionTree.labels.get(j), 1, Integer::sum);
                    } else {
                        right.features.add(decisionTree.features.get(j));
                        right.labels.add(decisionTree.labels.get(j));
                        rightMap.merge(decisionTree.labels.get(j), 1, Integer::sum);
                    }
                }
                // Pre-pruning: both children must keep enough samples.
                if (minDataSize != null) {
                    if (left.labels.size() < minDataSize || right.labels.size() < minDataSize) {
                        continue;
                    }
                }
                // Weighted Gini index of the split.
                double leftGINI = (double) left.labels.size() / decisionTree.labels.size() * GINI(left.labels.size(), leftMap);
                double rightGINI = (double) right.labels.size() / decisionTree.labels.size() * GINI(right.labels.size(), rightMap);
                double gini = leftGINI + rightGINI;
                if (bestBranchResult == null || gini < bestBranchResult.gini) {
                    bestBranchResult = new BranchResult();
                    bestBranchResult.gini = gini;
                    // Left child is >= , right child is <.
                    calcPredictLabelAndArr(left, leftMap);
                    calcPredictLabelAndArr(right, rightMap);
                    bestBranchResult.decisionTreeList.add(left);
                    bestBranchResult.decisionTreeList.add(right);
                    bestBranchResult.condition = new Condition(DataType.Number, featureIndex, mid);
                }
            }
        }
        return bestBranchResult;
    }

    /**
     * Splits on a categorical feature, producing one child per distinct value,
     * or null if the feature is constant here or a child would be too small.
     */
    private BranchResult stringBranch(int featureIndex, DecisionTree decisionTree) {
        // Group the samples by the value of the chosen feature.
        Map<String, DecisionTree> decisionTreeListMap = new HashMap<>();
        // Per feature value: counts of each label, used for the Gini computation.
        Map<String, Map<String, Integer>> giniCalcMap = new HashMap<>();
        for (int i = 0; i < decisionTree.features.size(); i++) {
            String key = (String) (decisionTree.features.get(i)[featureIndex]);
            if (!decisionTreeListMap.containsKey(key)) {
                decisionTreeListMap.put(key, new DecisionTree());
                giniCalcMap.put(key, new HashMap<>());
            }
            decisionTreeListMap.get(key).features.add(decisionTree.features.get(i));
            decisionTreeListMap.get(key).labels.add(decisionTree.labels.get(i));
            giniCalcMap.get(key).merge(decisionTree.labels.get(i), 1, Integer::sum);
        }
        // A single group means the feature is constant at this node — nothing to split.
        if (decisionTreeListMap.size() <= 1) {
            return null;
        }
        // Pre-pruning: every child must keep enough samples.
        if (minDataSize != null) {
            for (String key : decisionTreeListMap.keySet()) {
                if (decisionTreeListMap.get(key).labels.size() < minDataSize) {
                    return null;
                }
            }
        }
        // Compute the weighted Gini index and assemble the BranchResult.
        List<DecisionTree> decisionTreeList = new ArrayList<>();
        List<String> conditionValue = new ArrayList<>();
        double gini = 0d;
        for (String key : decisionTreeListMap.keySet()) {
            DecisionTree tree = decisionTreeListMap.get(key);
            calcPredictLabelAndArr(tree, giniCalcMap.get(key));
            // Children and condition values are appended in the same iteration,
            // keeping them index-aligned for predict().
            decisionTreeList.add(tree);
            conditionValue.add(key);
            double rate = ((double) tree.labels.size() / decisionTree.labels.size());
            gini += (rate * GINI(tree.labels.size(), giniCalcMap.get(key)));
        }
        BranchResult branchResult = new BranchResult();
        branchResult.gini = gini;
        branchResult.decisionTreeList = decisionTreeList;
        branchResult.condition = new Condition(DataType.String, featureIndex, conditionValue);
        return branchResult;
    }

    /**
     * Gini index of one partition: 1 - sum(p_k^2).
     *
     * @param totalCnt total sample count in the partition
     * @param map      label -> count within the partition
     */
    private double GINI(int totalCnt, Map<String, Integer> map) {
        double gini = 1d;
        for (String key : map.keySet()) {
            gini -= Math.pow(((double) map.get(key) / totalCnt), 2);
        }
        return gini;
    }

    /**
     * Builds a leaf node holding the given data, with its majority label and
     * class probabilities precomputed.
     */
    private DecisionTree createLeafNode(int deep, List<Object[]> features, List<String> labels, boolean[] featureTabuArr) {
        DecisionTree leaf = new DecisionTree();
        leaf.features = features;
        leaf.labels = labels;
        leaf.deep = deep;
        leaf.featureTabuArr = featureTabuArr;
        calcPredictLabelAndArr(leaf, labels);
        return leaf;
    }

    /**
     * Computes the majority label and per-class probability vector of a node
     * from its raw label list.
     */
    private void calcPredictLabelAndArr(DecisionTree tree, List<String> labels) {
        Map<String, Integer> map = new HashMap<>();
        String mostLabel = null;
        int mostNum = -1;
        for (String label : labels) {
            int cnt = map.merge(label, 1, Integer::sum);
            if (cnt > mostNum) {
                mostNum = cnt;
                mostLabel = label;
            }
        }
        if (mostNum == -1) {
            throw new RuntimeException("没找到最多的标签");
        }
        int totalCnt = 0;
        for (String label : map.keySet()) {
            totalCnt += map.get(label);
        }
        double[] predictArr = new double[labelArr.length];
        for (int i = 0; i < labelArr.length; i++) {
            predictArr[i] = (double) map.getOrDefault(labelArr[i], 0) / totalCnt;
        }
        tree.predictLabel = mostLabel;
        tree.predictArr = predictArr;
    }

    /**
     * Computes the majority label and per-class probability vector of a node
     * from precomputed label counts.
     */
    private void calcPredictLabelAndArr(DecisionTree tree, Map<String, Integer> map) {
        String mostLabel = null;
        int mostNum = -1;
        int totalCnt = 0;
        for (String label : map.keySet()) {
            if (map.get(label) > mostNum) {
                mostNum = map.get(label);
                mostLabel = label;
            }
            totalCnt += map.get(label);
        }
        if (mostNum == -1) {
            throw new RuntimeException("没找到最多的标签");
        }
        double[] predictArr = new double[labelArr.length];
        for (int i = 0; i < labelArr.length; i++) {
            predictArr[i] = (double) map.getOrDefault(labelArr[i], 0) / totalCnt;
        }
        tree.predictLabel = mostLabel;
        tree.predictArr = predictArr;
    }

    /**
     * Renders a probability vector as "[ label:p , ... ]".
     *
     * @throws RuntimeException if the two arrays differ in length
     */
    public static String predictArrToString(double[] predictArr, String[] labelArr) {
        if (predictArr.length != labelArr.length) {
            throw new RuntimeException("传入的概率矩阵维度和类型数组长度不一致: " + predictArr.length + " != " + labelArr.length);
        }
        StringBuilder str = new StringBuilder("[ ");
        for (int i = 0; i < labelArr.length - 1; i++) {
            str.append(labelArr[i]).append(":").append(String.format("%.2f", predictArr[i])).append(" , ");
        }
        str.append(labelArr[labelArr.length - 1]).append(":").append(String.format("%.2f", predictArr[predictArr.length - 1])).append(" ]");
        return str.toString();
    }

    // Result of evaluating one candidate split. Static: holds no reference to
    // the enclosing tree.
    static class BranchResult {
        // Weighted Gini index of the split.
        double gini;
        // Child nodes produced by the split.
        List<DecisionTree> decisionTreeList = new ArrayList<>();
        // The split condition itself.
        Condition condition;
    }

    // A tree node. Remains a non-static inner class because printSelf() reads
    // the enclosing tree's labelArr.
    class DecisionTree {
        Condition condition;
        List<Object[]> features = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        List<DecisionTree> children = new ArrayList<>();
        // Majority label of the samples reaching this node.
        String predictLabel;
        // Probability of each class at this node, aligned with labelArr.
        double[] predictArr;
        // Depth of this node (root = 1).
        int deep;
        /**
         * Marks features that must not be split on again below this node.
         **/
        boolean[] featureTabuArr;

        // Pre-order traversal printing each node.
        public void print() {
            this.printSelf();
            for (DecisionTree child : children) {
                child.print();
            }
        }

        public void printSelf() {
            if (condition != null) {
                System.out.println("deep: " + deep + " , predictLabel: " + predictLabel + " , predictArr: " + predictArrToString(predictArr, labelArr) + " , featureIndex: " + condition.featureIndex + " , condition: " + condition.conditionValue);
            } else {
                System.out.print("deep: " + deep + " , predictLabel: " + predictLabel + " , predictArr: " + predictArrToString(predictArr, labelArr) + " , features: ");
                for (Object[] feature : features) {
                    System.out.print(Arrays.toString(feature) + " , ");
                }
                System.out.println("labels: " + labels);
            }
        }
    }

    // A split condition. Static: holds no reference to the enclosing tree.
    static class Condition {
        DataType dataType;
        int featureIndex;
        /**
         * If dataType is String, conditionValue is a List&lt;String&gt; of the
         * categories, index-aligned with the children.
         * If dataType is Number, conditionValue is a double threshold; the left
         * child is &gt;= , the right child is &lt;.
         **/
        Object conditionValue;

        public Condition(DataType dataType, int featureIndex, Object conditionValue) {
            this.dataType = dataType;
            this.featureIndex = featureIndex;
            this.conditionValue = conditionValue;
        }
    }
}
3.5 Run
Run:运行算法的类
/**
 * Demo driver: fits a CART tree on three small data sets (categorical-only,
 * numeric-only, and mixed features) and prints predictions.
 */
public class Run {
    public static void main(String[] args) {
        // Classification with purely categorical features.
        testStringData();
        // Classification with purely numeric features.
        testNumberData();
        // Classification with mixed categorical and numeric features.
        testStringAndNumberData();
    }

    public static void testStringData() {
        System.out.println("================================================================== 测试纯文本的分类 ==================================================================");
        // Data set with two categorical feature columns.
        TrainDataSet trainDataSet = new TrainDataSet(new DataType[]{DataType.String, DataType.String});
        Object[][] samples = {
                {"是", "单身"}, {"是", "已婚"}, {"否", "单身"}, {"是", "已婚"}, {"否", "离异"},
                {"否", "已婚"}, {"是", "离异"}, {"否", "单身"}, {"否", "离异"}, {"否", "单身"}
        };
        String[] labels = {"否", "否", "否", "否", "是", "是", "否", "是", "否", "是"};
        for (int i = 0; i < samples.length; i++) {
            trainDataSet.addData(samples[i], labels[i]);
        }
        fitAndPredict(trainDataSet, new Object[][]{
                {"是", "离异"}, {"否", "单身"}, {"否", "离异"}, {"否", "单身"}
        });
    }

    public static void testNumberData() {
        System.out.println("================================================================== 测试纯数字的分类 ==================================================================");
        // Data set with a single numeric feature column.
        TrainDataSet trainDataSet = new TrainDataSet(new DataType[]{DataType.Number});
        Object[][] samples = {
                {125d}, {100d}, {70d}, {120d}, {95d}, {60d}, {200d}, {85d}, {75d}, {90d}
        };
        String[] labels = {"否", "否", "否", "否", "是", "是", "否", "是", "否", "是"};
        for (int i = 0; i < samples.length; i++) {
            trainDataSet.addData(samples[i], labels[i]);
        }
        fitAndPredict(trainDataSet, new Object[][]{
                {200d}, {85d}, {75d}, {90d}
        });
    }

    public static void testStringAndNumberData() {
        System.out.println("================================================================== 测试文本和数字混合的分类 ==================================================================");
        // Data set mixing two categorical columns with one numeric column.
        TrainDataSet trainDataSet = new TrainDataSet(new DataType[]{DataType.String, DataType.String, DataType.Number});
        Object[][] samples = {
                {"是", "单身", 125d}, {"是", "已婚", 100d}, {"否", "单身", 70d}, {"是", "已婚", 120d}, {"否", "离异", 95d},
                {"否", "已婚", 60d}, {"是", "离异", 200d}, {"否", "单身", 85d}, {"否", "离异", 75d}, {"否", "单身", 90d}
        };
        String[] labels = {"否", "否", "否", "否", "是", "是", "否", "是", "否", "是"};
        for (int i = 0; i < samples.length; i++) {
            trainDataSet.addData(samples[i], labels[i]);
        }
        fitAndPredict(trainDataSet, new Object[][]{
                {"是", "离异", 200d}, {"否", "单身", 85d}, {"否", "离异", 75d}, {"否", "单身", 90d}
        });
    }

    // Fits an unpruned tree on the data set, reports the training time, then
    // prints the prediction for each test feature vector.
    private static void fitAndPredict(TrainDataSet trainDataSet, Object[][] testFeatures) {
        long startTime = System.currentTimeMillis();
        CartDecisionTree cartDecisionTree = new CartDecisionTree(trainDataSet, null, null, null);
        cartDecisionTree.fit();
        System.out.println("训练用时: " + (System.currentTimeMillis() - startTime) / 1000d + " s");
        System.out.println("用训练好的模型进行预测: ");
        for (Object[] features : testFeatures) {
            System.out.println("TestData: " + Arrays.toString(features) + " : " + cartDecisionTree.predict(features));
        }
    }
}
运行输出如下:
================================================================== 测试纯文本的分类 ==================================================================
deep: 1 , predictLabel: 否 , predictArr: [ 否:0.60 , 是:0.40 ] , featureIndex: 0 , condition: [否, 是]
deep: 2 , predictLabel: 是 , predictArr: [ 否:0.33 , 是:0.67 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 3 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [否, 已婚] , labels: [是]
deep: 3 , predictLabel: 否 , predictArr: [ 否:0.50 , 是:0.50 ] , features: [否, 离异] , [否, 离异] , labels: [是, 否]
deep: 3 , predictLabel: 是 , predictArr: [ 否:0.33 , 是:0.67 ] , features: [否, 单身] , [否, 单身] , [否, 单身] , labels: [否, 是, 是]
deep: 2 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 已婚] , [是, 已婚] , labels: [否, 否]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 离异] , labels: [否]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 单身] , labels: [否]
训练用时: 0.026 s
用训练好的模型进行预测:
TestData: [是, 离异] : PredictResult{predictLabel='否', predictArr=[ 否:1.00 , 是:0.00 ]}
TestData: [否, 单身] : PredictResult{predictLabel='是', predictArr=[ 否:0.33 , 是:0.67 ]}
TestData: [否, 离异] : PredictResult{predictLabel='否', predictArr=[ 否:0.50 , 是:0.50 ]}
TestData: [否, 单身] : PredictResult{predictLabel='是', predictArr=[ 否:0.33 , 是:0.67 ]}
================================================================== 测试纯数字的分类 ==================================================================
deep: 1 , predictLabel: 否 , predictArr: [ 否:0.60 , 是:0.40 ] , featureIndex: 0 , condition: 97.5
deep: 2 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 0 , condition: 110.0
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 0 , condition: 122.5
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 0 , condition: 162.5
deep: 5 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [200.0] , labels: [否]
deep: 5 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [125.0] , labels: [否]
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [120.0] , labels: [否]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [100.0] , labels: [否]
deep: 2 , predictLabel: 是 , predictArr: [ 否:0.33 , 是:0.67 ] , featureIndex: 0 , condition: 80.0
deep: 3 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , featureIndex: 0 , condition: 87.5
deep: 4 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , featureIndex: 0 , condition: 92.5
deep: 5 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [95.0] , labels: [是]
deep: 5 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [90.0] , labels: [是]
deep: 4 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [85.0] , labels: [是]
deep: 3 , predictLabel: 否 , predictArr: [ 否:0.67 , 是:0.33 ] , featureIndex: 0 , condition: 65.0
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 0 , condition: 72.5
deep: 5 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [75.0] , labels: [否]
deep: 5 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [70.0] , labels: [否]
deep: 4 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [60.0] , labels: [是]
训练用时: 0.007 s
用训练好的模型进行预测:
TestData: [200.0] : PredictResult{predictLabel='否', predictArr=[ 否:1.00 , 是:0.00 ]}
TestData: [85.0] : PredictResult{predictLabel='是', predictArr=[ 否:0.00 , 是:1.00 ]}
TestData: [75.0] : PredictResult{predictLabel='否', predictArr=[ 否:1.00 , 是:0.00 ]}
TestData: [90.0] : PredictResult{predictLabel='是', predictArr=[ 否:0.00 , 是:1.00 ]}
================================================================== 测试文本和数字混合的分类 ==================================================================
deep: 1 , predictLabel: 否 , predictArr: [ 否:0.60 , 是:0.40 ] , featureIndex: 0 , condition: [否, 是]
deep: 2 , predictLabel: 是 , predictArr: [ 否:0.33 , 是:0.67 ] , featureIndex: 2 , condition: 80.0
deep: 3 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , featureIndex: 1 , condition: [离异, 单身]
deep: 4 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [否, 离异, 95.0] , labels: [是]
deep: 4 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , featureIndex: 2 , condition: 87.5
deep: 5 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [否, 单身, 90.0] , labels: [是]
deep: 5 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [否, 单身, 85.0] , labels: [是]
deep: 3 , predictLabel: 否 , predictArr: [ 否:0.67 , 是:0.33 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 4 , predictLabel: 是 , predictArr: [ 否:0.00 , 是:1.00 ] , features: [否, 已婚, 60.0] , labels: [是]
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [否, 离异, 75.0] , labels: [否]
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [否, 单身, 70.0] , labels: [否]
deep: 2 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , featureIndex: 2 , condition: 110.0
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 已婚, 120.0] , labels: [否]
deep: 4 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 已婚, 100.0] , labels: [否]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 离异, 200.0] , labels: [否]
deep: 3 , predictLabel: 否 , predictArr: [ 否:1.00 , 是:0.00 ] , features: [是, 单身, 125.0] , labels: [否]
训练用时: 0.011 s
用训练好的模型进行预测:
TestData: [是, 离异, 200.0] : PredictResult{predictLabel='否', predictArr=[ 否:1.00 , 是:0.00 ]}
TestData: [否, 单身, 85.0] : PredictResult{predictLabel='是', predictArr=[ 否:0.00 , 是:1.00 ]}
TestData: [否, 离异, 75.0] : PredictResult{predictLabel='否', predictArr=[ 否:1.00 , 是:0.00 ]}
TestData: [否, 单身, 90.0] : PredictResult{predictLabel='是', predictArr=[ 否:0.00 , 是:1.00 ]}