当前位置: 首页 > news >正文

【机器学习】Java 代码实现 CART 决策树算法

文章目录

  • 一、决策树算法
  • 二、CART 决策树
  • 三、Java 代码实现
    • 3.1 TrainDataSet
    • 3.2 DataType
    • 3.3 PredictResult
    • 3.4 CartDecisionTree
    • 3.5 Run


一、决策树算法

关于决策树算法的详细介绍可以参考我的另一篇博客:【机器学习】Decision Tree 决策树算法详解 + Python代码实战


二、CART 决策树

CART(classification and regression tree)树:又称为分类回归树,从名字可以发现,CART树既可用于分类,也可以用于回归。

当数据集的因变量是离散值时,可以采用CART分类树进行拟合,用叶节点概率最大的类别作为该节点的预测类别。

当数据集的因变量是连续值时,可以采用CART回归树进行拟合,用叶节点的均值作为该节点预测值。

本文实现的是 CART 分类树,其实就是采用 GINI 系数作为特征选择依据的决策树


三、Java 代码实现

3.1 TrainDataSet

TrainDataSet:训练数据集存放对象

public class TrainDataSet {

    /**
     * 特征集合
     **/
    public List<Object[]> features = new ArrayList<>();
    /**
     * 数据类型数组
     **/
    DataType[] dataTypes;
    /**
     * 标签集合
     **/
    public List<String> labels = new ArrayList<>();
    /**
     * 特征向量维度
     **/
    public int featureDim;

    public TrainDataSet(DataType[] dataTypes) {
        this.dataTypes = dataTypes;
        this.featureDim = dataTypes.length;
    }

    public int size() {
        return labels.size();
    }

    public Object[] getFeature(int index) {
        return features.get(index);
    }

    public String getLabel(int index) {
        return labels.get(index);
    }

    public void addData(Object[] feature, String label) {
        if (featureDim != feature.length) {
            throwDimensionMismatchException(feature.length);
        }
        features.add(feature);
        labels.add(label);
    }

    public void throwDimensionMismatchException(int errorLen) {
        throw new RuntimeException("DimensionMismatchError: 你应该传入维度为 " + featureDim + " 的特征向量 , 但你传入了维度为 " + errorLen + " 的特征向量");
    }

}

3.2 DataType

DataType:一个枚举类,用来指示特征是数值型的还是字符型的

public enum DataType {
    // 字符串
    String,
    // 数字
    Number;
}

3.3 PredictResult

PredictResult:预测结果对象

public class PredictResult {
    String[] labelArr;
    String predictLabel;
    double[] predictArr;

    public PredictResult(String[] labelArr, String predictLabel, double[] predictArr) {
        this.labelArr = labelArr;
        this.predictLabel = predictLabel;
        this.predictArr = predictArr;
    }

    @Override
    public String toString() {
        return "PredictResult{" +
                "predictLabel='" + predictLabel + '\'' +
                ", predictArr=" + CartDecisionTree.predictArrToString(predictArr, labelArr) +
                '}';
    }
}

3.4 CartDecisionTree

CartDecisionTree:CART 决策树算法对象

public class CartDecisionTree {
    /**
     * 训练的数据集
     **/
    TrainDataSet trainDataSet;
    /**
     * 所有分类类型集合
     **/
    String[] labelArr;
    /**
     * 限制树的深度
     **/
    Integer maxDeep;
    /**
     * 限制叶子节点的个数
     **/
    Integer maxLeafNum;
    /**
     * 限制每个节点的样本数
     **/
    Integer minDataSize;
    /**
     * 决策树
     **/
    DecisionTree root;


    public CartDecisionTree(TrainDataSet trainDataSet, Integer maxDeep, Integer maxLeafNum, Integer minDataSize) {
        this.trainDataSet = trainDataSet;
        this.maxDeep = maxDeep;
        this.maxLeafNum = maxLeafNum;
        this.minDataSize = minDataSize;
        // 将 Label 去重,获取所有类别
        HashSet<String> labelSet = new HashSet<>(trainDataSet.labels);
        this.labelArr = new String[labelSet.size()];
        int i = 0;
        for (String label : labelSet) {
            this.labelArr[i++] = label;
        }
    }

    public void initModel() {
        root = createLeafNode(1, trainDataSet.features, trainDataSet.labels, new boolean[trainDataSet.featureDim]);
    }

    // 传入特征向量,返回预测值
    public PredictResult predict(Object[] features) {
        DecisionTree tree = root;
        while (tree.condition != null) {
            int featureIndex = tree.condition.featureIndex;
            if (trainDataSet.dataTypes[featureIndex].equals(DataType.String)) {
                // 字符类型的分支下走
                List<String> stringList = (List<String>) tree.condition.conditionValue;
                for (int i = 0; i < stringList.size(); i++) {
                    if (stringList.get(i).equals((String) features[featureIndex])) {
                        tree = tree.children.get(i);
                        break;
                    }
                }
            } else {
                // 数字类型的分支下走:左子树为 >= ,右子树为 <
                if ((double) features[featureIndex] >= (double) tree.condition.conditionValue) {
                    tree = tree.children.get(0);
                } else {
                    tree = tree.children.get(1);
                }
            }
        }
        return new PredictResult(labelArr, tree.predictLabel, tree.predictArr);
    }

    public void fit() {
        initModel();
        // BFS 构建决策树
        Queue<DecisionTree> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            // 从队列中取出当前要构建的不完整的决策树(只有树深度、数据和最可能的标签,还没有分支条件和子节点)
            DecisionTree decisionTree = queue.poll();
            // 对特征进行遍历,找到分支后GINI系数最小的特征进行分支
            BranchResult bestBranchResult = null;
            // 预剪枝:限制树的深度
            if (maxDeep == null || decisionTree.deep + 1 <= maxDeep) {
                for (int featureIndex = 0; featureIndex < trainDataSet.featureDim; featureIndex++) {
                    // 只对没被禁忌的特征进行计算
                    if (!decisionTree.featureTabuArr[featureIndex]) {
                        BranchResult branchResult = null;
                        if (trainDataSet.dataTypes[featureIndex].equals(DataType.String)) {
                            // 对字符串类型的特征进行分支
                            branchResult = stringBranch(featureIndex, decisionTree);
                        } else {
                            // 对数字类型的特征进行分支
                            branchResult = numberBranch(featureIndex, decisionTree);
                        }
                        if (branchResult != null) {
                            if (bestBranchResult == null || branchResult.gini < bestBranchResult.gini) {
                                bestBranchResult = branchResult;
                            }
                        }
                    }
                }
            }
            // 将最佳分支结果中的节点加入队列,并将其加入当前节点的子节点集合
            if (bestBranchResult != null) {
                decisionTree.children.addAll(bestBranchResult.decisionTreeList);
                decisionTree.condition = bestBranchResult.condition;
                for (DecisionTree child : decisionTree.children) {
                    // 如果最佳分支是字符串型分支,那么可以直接禁忌,之后不用再对那个特征进行分支
                    if (bestBranchResult.condition.dataType.equals(DataType.String)) {
                        child.featureTabuArr = decisionTree.featureTabuArr.clone();
                        child.featureTabuArr[bestBranchResult.condition.featureIndex] = true;
                    } else {
                        child.featureTabuArr = decisionTree.featureTabuArr.clone();
                    }
                    child.deep = decisionTree.deep + 1;
                }
                queue.addAll(decisionTree.children);
            }
            // 预剪枝:限制叶子节点的个数
            if (maxLeafNum != null && queue.size() >= maxLeafNum) {
                break;
            }
        }
        root.print();
    }

    // 对数字类型的特征进行分支(对每两个数字中间的值进行分支)
    private BranchResult numberBranch(int featureIndex, DecisionTree decisionTree) {
        BranchResult bestBranchResult = null;
        // 记录已经计算过的二分实数(用字符串是为了避免浮点型变量精度带来的 hash 失效,字符串是数字保留 6 位有效数字的结果),还有一个辅助作用,就是对数值特征值进行去重
        HashSet<String> valueSet = new HashSet<>();
        // 首先获取所有数值
        List<Double> valueList = new ArrayList<>();
        for (int i = 0; i < decisionTree.features.size(); i++) {
            double v = (double) decisionTree.features.get(i)[featureIndex];
            if (valueSet.add(String.format("%.6f", v))) {
                valueList.add(v);
            }
        }
        // 然后排序
        Collections.sort(valueList);
        // 然后就选取中间值进行分支尝试
        valueSet = new HashSet<>();
        for (int i = 0; i < valueList.size() - 1; i++) {
            double mid = (valueList.get(i) + valueList.get(i + 1)) / 2.0;
            if (valueSet.add(String.format("%.6f", mid))) {
                // 辅助计算 GINI 系数
                Map<String, Integer> leftMap = new HashMap<>();
                Map<String, Integer> rightMap = new HashMap<>();
                // 初始化左右节点 左子树为 >= ,右子树为 <
                DecisionTree left = new DecisionTree();
                DecisionTree right = new DecisionTree();
                // 向左右节点加入数据
                for (int j = 0; j < decisionTree.features.size(); j++) {
                    if ((double) decisionTree.features.get(j)[featureIndex] >= mid) {
                        left.features.add(decisionTree.features.get(j));
                        left.labels.add(decisionTree.labels.get(j));
                        Integer cnt = leftMap.getOrDefault(decisionTree.labels.get(j), null);
                        leftMap.put(decisionTree.labels.get(j), cnt == null ? 1 : cnt + 1);
                    } else {
                        right.features.add(decisionTree.features.get(j));
                        right.labels.add(decisionTree.labels.get(j));
                        Integer cnt = rightMap.getOrDefault(decisionTree.labels.get(j), null);
                        rightMap.put(decisionTree.labels.get(j), cnt == null ? 1 : cnt + 1);
                    }
                }
                // 预剪枝:限制每个节点的样本数
                if (minDataSize != null) {
                    if (left.labels.size() < minDataSize || right.labels.size() < minDataSize) {
                        continue;
                    }
                }
                // 计算 GINI 系数
                double leftGINI = (double) left.labels.size() / decisionTree.labels.size() * GINI(left.labels.size(), leftMap);
                double rightGINI = (double) right.labels.size() / decisionTree.labels.size() * GINI(right.labels.size(), rightMap);
                double gini = leftGINI + rightGINI;
                if (bestBranchResult == null || gini < bestBranchResult.gini) {
                    bestBranchResult = new BranchResult();
                    bestBranchResult.gini = gini;
                    // 左子树为 >= ,右子树为 <
                    calcPredictLabelAndArr(left, leftMap);
                    calcPredictLabelAndArr(right, rightMap);
                    bestBranchResult.decisionTreeList.add(left);
                    bestBranchResult.decisionTreeList.add(right);
                    bestBranchResult.condition = new Condition(DataType.Number, featureIndex, mid);
                }
            }
        }
        return bestBranchResult;
    }

    // 对字符串类型的特征进行分支
    private BranchResult stringBranch(int featureIndex, DecisionTree decisionTree) {
        // 开始根据指定特征进行分组
        Map<String, DecisionTree> decisionTreeListMap = new HashMap<>();
        // 辅助 GINI 系数计算的 Map ,存储当前 feature 的不同取值下,label 的不同取值的个数
        Map<String, Map<String, Integer>> giniCalcMap = new HashMap<>();
        for (int i = 0; i < decisionTree.features.size(); i++) {
            String key = (String) (decisionTree.features.get(i)[featureIndex]);
            if (!decisionTreeListMap.containsKey(key)) {
                decisionTreeListMap.put(key, new DecisionTree());
                giniCalcMap.put(key, new HashMap<>());
            }
            decisionTreeListMap.get(key).features.add(decisionTree.features.get(i));
            decisionTreeListMap.get(key).labels.add(decisionTree.labels.get(i));
            Integer cnt = giniCalcMap.get(key).getOrDefault(decisionTree.labels.get(i), null);
            giniCalcMap.get(key).put(decisionTree.labels.get(i), cnt == null ? 1 : cnt + 1);
        }
        // 如果 decisionTreeListMap 的  size 为 1,说明当前节点当前特征已经纯了,那么就不用对这个特征进行分支了,所以可以直接返回 null
        if (decisionTreeListMap.size() <= 1) {
            return null;
        }
        // 预剪枝:限制每个节点的样本数
        if (minDataSize != null) {
            for (String key : decisionTreeListMap.keySet()) {
                if (decisionTreeListMap.get(key).labels.size() < minDataSize) {
                    return null;
                }
            }
        }
        // 计算GINI系数,并生成 BranchResult
        List<DecisionTree> decisionTreeList = new ArrayList<>();
        List<String> conditionValue = new ArrayList<>();
        double gini = 0d;
        for (String key : decisionTreeListMap.keySet()) {
            DecisionTree tree = decisionTreeListMap.get(key);
            calcPredictLabelAndArr(tree, giniCalcMap.get(key));
            decisionTreeList.add(tree);
            conditionValue.add(key);
            // 计算 GINI 系数
            double rate = ((double) tree.labels.size() / decisionTree.labels.size());
            gini += (rate * GINI(tree.labels.size(), giniCalcMap.get(key)));
        }
        BranchResult branchResult = new BranchResult();
        branchResult.gini = gini;
        branchResult.decisionTreeList = decisionTreeList;
        branchResult.condition = new Condition(DataType.String, featureIndex, conditionValue);
        return branchResult;
    }

    // 根据一个特征取值的 GINI 系数
    private double GINI(int totalCnt, Map<String, Integer> map) {
        double gini = 1d;
        for (String key : map.keySet()) {
            gini -= Math.pow(((double) map.get(key) / totalCnt), 2);
        }
        return gini;
    }

    // 获取一个叶子节点
    private DecisionTree createLeafNode(int deep, List<Object[]> features, List<String> labels, boolean[] featureTabuArr) {
        DecisionTree leaf = new DecisionTree();
        leaf.features = features;
        leaf.labels = labels;
        leaf.deep = deep;
        leaf.featureTabuArr = featureTabuArr;
        calcPredictLabelAndArr(leaf, labels);
        return leaf;
    }

    // 获取最多的标签
    private void calcPredictLabelAndArr(DecisionTree tree, List<String> labels) {
        Map<String, Integer> map = new HashMap<>();
        String mostLabel = null;
        int mostNum = -1;
        for (String label : labels) {
            Integer num = map.getOrDefault(label, null);
            map.put(label, num == null ? 1 : num + 1);
            if (map.get(label) > mostNum) {
                mostNum = map.get(label);
                mostLabel = label;
            }
        }
        if (mostNum == -1) {
            throw new RuntimeException("没找到最多的标签");
        }
        int totalCnt = 0;
        for (String label : map.keySet()) {
            totalCnt += map.get(label);
        }
        double[] predictArr = new double[labelArr.length];
        for (int i = 0; i < labelArr.length; i++) {
            predictArr[i] = (double) map.getOrDefault(labelArr[i], 0) / totalCnt;
        }
        tree.predictLabel = mostLabel;
        tree.predictArr = predictArr;
    }

    private void calcPredictLabelAndArr(DecisionTree tree, Map<String, Integer> map) {
        String mostLabel = null;
        int mostNum = -1;
        int totalCnt = 0;
        for (String label : map.keySet()) {
            if (map.get(label) > mostNum) {
                mostNum = map.get(label);
                mostLabel = label;
            }
            totalCnt += map.get(label);
        }
        if (mostNum == -1) {
            throw new RuntimeException("没找到最多的标签");
        }
        double[] predictArr = new double[labelArr.length];
        for (int i = 0; i < labelArr.length; i++) {
            predictArr[i] = (double) map.getOrDefault(labelArr[i], 0) / totalCnt;
        }
        tree.predictLabel = mostLabel;
        tree.predictArr = predictArr;
    }

    // 将概率向量字符串化
    public static String predictArrToString(double[] predictArr, String[] labelArr) {
        if (predictArr.length != labelArr.length) {
            throw new RuntimeException("传入的概率矩阵维度和类型数组长度不一致: " + predictArr.length + " != " + labelArr.length);
        }
        StringBuilder str = new StringBuilder("[ ");
        for (int i = 0; i < labelArr.length - 1; i++) {
            str.append(labelArr[i]).append(":").append(String.format("%.2f", predictArr[i])).append(" , ");
        }
        str.append(labelArr[labelArr.length - 1]).append(":").append(String.format("%.2f", predictArr[predictArr.length - 1])).append(" ]");
        return str.toString();
    }

    // 分支结果
    class BranchResult {
        // GINI系数
        double gini;
        // 分支后的节点集合
        List<DecisionTree> decisionTreeList = new ArrayList<>();
        // 分支条件
        Condition condition;
    }

    // 决策树
    class DecisionTree {
        Condition condition;
        List<Object[]> features = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        List<DecisionTree> children = new ArrayList<>();
        // 特征走到当前节点最可能是的 Label
        String predictLabel;
        // 预测属于每个类别的概率
        double[] predictArr;
        // 树当前的深度
        int deep;
        /**
         * 记录哪些特征不用分支
         **/
        boolean[] featureTabuArr;

        // 前序遍历输出自身信息
        public void print() {
            this.printSelf();
            for (DecisionTree child : children) {
                child.print();
            }
        }

        public void printSelf() {
            if (condition != null) {
                System.out.println("deep: " + deep + " , predictLabel: " + predictLabel + " , predictArr: " + predictArrToString(predictArr, labelArr) + " , featureIndex: " + condition.featureIndex + " , condition: " + condition.conditionValue);
            } else {
                System.out.print("deep: " + deep + " , predictLabel: " + predictLabel + " , predictArr: " + predictArrToString(predictArr, labelArr) + " , features: ");
                for (Object[] feature : features) {
                    System.out.print(Arrays.toString(feature) + " , ");
                }
                System.out.println("labels: " + labels);
            }
        }

    }

    // 分支条件
    class Condition {
        DataType dataType;
        int featureIndex;
        /**
         * 如果 dataType 为 String,则 conditionValue 为 List<String>
         * 如果 dataType 为 Number,则 conditionValue 为 double,且左子树为 >= ,右子树为 <
         **/
        Object conditionValue;

        public Condition(DataType dataType, int featureIndex, Object conditionValue) {
            this.dataType = dataType;
            this.featureIndex = featureIndex;
            this.conditionValue = conditionValue;
        }
    }

}

3.5 Run

Run:运行算法的类

public class Run {
    public static void main(String[] args) {
        // 测试纯文本的分类
        testStringData();
        // 测试纯数字的分类
        testNumberData();
        // 测试文本和数字混合的分类
        testStringAndNumberData();
    }

    public static void testStringData() {
        System.out.println("================================================================== 测试纯文本的分类 ==================================================================");
        // 构建纯文本特征的数据集
        TrainDataSet trainDataSet = new TrainDataSet(new DataType[]{DataType.String, DataType.String});
        trainDataSet.addData(new Object[]{"是", "单身"}, "否");
        trainDataSet.addData(new Object[]{"是", "已婚"}, "否");
        trainDataSet.addData(new Object[]{"否", "单身"}, "否");
        trainDataSet.addData(new Object[]{"是", "已婚"}, "否");
        trainDataSet.addData(new Object[]{"否", "离异"}, "是");
        trainDataSet.addData(new Object[]{"否", "已婚"}, "是");
        trainDataSet.addData(new Object[]{"是", "离异"}, "否");
        trainDataSet.addData(new Object[]{"否", "单身"}, "是");
        trainDataSet.addData(new Object[]{"否", "离异"}, "否");
        trainDataSet.addData(new Object[]{"否", "单身"}, "是");
        long startTime = System.currentTimeMillis();
        CartDecisionTree cartDecisionTree = new CartDecisionTree(trainDataSet, null, null, null);
        cartDecisionTree.fit();
        System.out.println("训练用时: " + (System.currentTimeMillis() - startTime) / 1000d + " s");
        System.out.println("用训练好的模型进行预测: ");
        System.out.println("TestData: " + Arrays.toString(new String[]{"是", "离异"}) + " : " + cartDecisionTree.predict(new Object[]{"是", "离异"}));
        System.out.println("TestData: " + Arrays.toString(new String[]{"否", "单身"}) + " : " + cartDecisionTree.predict(new Object[]{"否", "单身"}));
        System.out.println("TestData: " + Arrays.toString(new String[]{"否", "离异"}) + " : " + cartDecisionTree.predict(new Object[]{"否", "离异"}));
        System.out.println("TestData: " + Arrays.toString(new String[]{"否", "单身"}) + " : " + cartDecisionTree.predict(new Object[]{"否", "单身"}));
    }

    public static void testNumberData() {
        System.out.println("================================================================== 测试纯数字的分类 ==================================================================");
        // 构建纯文本特征的数据集
        TrainDataSet trainDataSet = new TrainDataSet(new DataType[]{DataType.Number});
        trainDataSet.addData(new Object[]{125d}, "否");
        trainDataSet.addData(new Object[]{100d}, "否");
        trainDataSet.addData(new Object[]{70d}, "否");
        trainDataSet.addData(new Object[]{120d}, "否");
        trainDataSet.addData(new Object[]{95d}, "是");
        trainDataSet.addData(new Object[]{60d}, "是");
        trainDataSet.addData(new Object[]{200d}, "否");
        trainDataSet.addData(new Object[]{85d}, "是");
        trainDataSet.addData(new Object[]{75d}, "否");
        trainDataSet.addData(new Object[]{90d}, "是");
        long startTime = System.currentTimeMillis();
        CartDecisionTree cartDecisionTree = new CartDecisionTree(trainDataSet, null, null, null);
        cartDecisionTree.fit();
        System.out.println("训练用时: " + (System.currentTimeMillis() - startTime) / 1000d + " s");
        System.out.println("用训练好的模型进行预测: ");
        System.out.println("TestData: " + Arrays.toString(new double[]{200d}) + " : " + cartDecisionTree.predict(new Object[]{200d}));
        System.out.println("TestData: " + Arrays.toString(new double[]{85d}) + " : " + cartDecisionTree.predict(new Object[]{85d}));
        System.out.println("TestData: " + Arrays.toString(new double[]{75d}) + " : " + cartDecisionTree.predict(new Object[]{75d}));
        System.out.println("TestData: " + Arrays.toString(new double[]{90d}) + " : " + cartDecisionTree.predict(new Object[]{90d}));
    }

    public static void testStringAndNumberData() {
        System.out.println("================================================================== 测试文本和数字混合的分类 ==================================================================");
        // 构建纯文本特征的数据集
        TrainDataSet trainDataSet = new TrainDataSet(new DataType[]{DataType.String, DataType.String, DataType.Number});
        trainDataSet.addData(new Object[]{"是", "单身", 125d}, "否");
        trainDataSet.addData(new Object[]{"是", "已婚", 100d}, "否");
        trainDataSet.addData(new Object[]{"否", "单身", 70d}, "否");
        trainDataSet.addData(new Object[]{"是", "已婚", 120d}, "否");
        trainDataSet.addData(new Object[]{"否", "离异", 95d}, "是");
        trainDataSet.addData(new Object[]{"否", "已婚", 60d}, "是");
        trainDataSet.addData(new Object[]{"是", "离异", 200d}, "否");
        trainDataSet.addData(new Object[]{"否", "单身", 85d}, "是");
        trainDataSet.addData(new Object[]{"否", "离异", 75d}, "否");
        trainDataSet.addData(new Object[]{"否", "单身", 90d}, "是");
        long startTime = System.currentTimeMillis();
        CartDecisionTree cartDecisionTree = new CartDecisionTree(trainDataSet, null, null, null);
        cartDecisionTree.fit();
        System.out.println("训练用时: " + (System.currentTimeMillis() - startTime) / 1000d + " s");
        System.out.println("用训练好的模型进行预测: ");
        System.out.println("TestData: " + Arrays.toString(new Object[]{"是", "离异", 200d}) + " : " + cartDecisionTree.predict(new Object[]{"是", "离异", 200d}));
        System.out.println("TestData: " + Arrays.toString(new Object[]{"否", "单身", 85d}) + " : " + cartDecisionTree.predict(new Object[]{"否", "单身", 85d}));
        System.out.println("TestData: " + Arrays.toString(new Object[]{"否", "离异", 75d}) + " : " + cartDecisionTree.predict(new Object[]{"否", "离异", 75d}));
        System.out.println("TestData: " + Arrays.toString(new Object[]{"否", "单身", 90d}) + " : " + cartDecisionTree.predict(new Object[]{"否", "单身", 90d}));
    }

}

运行输出如下

================================================================== 测试纯文本的分类 ==================================================================
deep: 1 , predictLabel:, predictArr: [:0.60 ,:0.40 ] , featureIndex: 0 , condition: [,]
deep: 2 , predictLabel:, predictArr: [:0.33 ,:0.67 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 3 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [, 已婚] , labels: []
deep: 3 , predictLabel:, predictArr: [:0.50 ,:0.50 ] , features: [, 离异] , [, 离异] , labels: [,]
deep: 3 , predictLabel:, predictArr: [:0.33 ,:0.67 ] , features: [, 单身] , [, 单身] , [, 单身] , labels: [,,]
deep: 2 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 已婚] , [, 已婚] , labels: [,]
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 离异] , labels: []
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 单身] , labels: []
训练用时: 0.026 s
用训练好的模型进行预测: 
TestData: [, 离异] : PredictResult{predictLabel='否', predictArr=[:1.00 ,:0.00 ]}
TestData: [, 单身] : PredictResult{predictLabel='是', predictArr=[:0.33 ,:0.67 ]}
TestData: [, 离异] : PredictResult{predictLabel='否', predictArr=[:0.50 ,:0.50 ]}
TestData: [, 单身] : PredictResult{predictLabel='是', predictArr=[:0.33 ,:0.67 ]}
================================================================== 测试纯数字的分类 ==================================================================
deep: 1 , predictLabel:, predictArr: [:0.60 ,:0.40 ] , featureIndex: 0 , condition: 97.5
deep: 2 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 0 , condition: 110.0
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 0 , condition: 122.5
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 0 , condition: 162.5
deep: 5 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [200.0] , labels: []
deep: 5 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [125.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [120.0] , labels: []
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [100.0] , labels: []
deep: 2 , predictLabel:, predictArr: [:0.33 ,:0.67 ] , featureIndex: 0 , condition: 80.0
deep: 3 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , featureIndex: 0 , condition: 87.5
deep: 4 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , featureIndex: 0 , condition: 92.5
deep: 5 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [95.0] , labels: []
deep: 5 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [90.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [85.0] , labels: []
deep: 3 , predictLabel:, predictArr: [:0.67 ,:0.33 ] , featureIndex: 0 , condition: 65.0
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 0 , condition: 72.5
deep: 5 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [75.0] , labels: []
deep: 5 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [70.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [60.0] , labels: []
训练用时: 0.007 s
用训练好的模型进行预测: 
TestData: [200.0] : PredictResult{predictLabel='否', predictArr=[:1.00 ,:0.00 ]}
TestData: [85.0] : PredictResult{predictLabel='是', predictArr=[:0.00 ,:1.00 ]}
TestData: [75.0] : PredictResult{predictLabel='否', predictArr=[:1.00 ,:0.00 ]}
TestData: [90.0] : PredictResult{predictLabel='是', predictArr=[:0.00 ,:1.00 ]}
================================================================== 测试文本和数字混合的分类 ==================================================================
deep: 1 , predictLabel:, predictArr: [:0.60 ,:0.40 ] , featureIndex: 0 , condition: [,]
deep: 2 , predictLabel:, predictArr: [:0.33 ,:0.67 ] , featureIndex: 2 , condition: 80.0
deep: 3 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , featureIndex: 1 , condition: [离异, 单身]
deep: 4 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [, 离异, 95.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , featureIndex: 2 , condition: 87.5
deep: 5 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [, 单身, 90.0] , labels: []
deep: 5 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [, 单身, 85.0] , labels: []
deep: 3 , predictLabel:, predictArr: [:0.67 ,:0.33 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 4 , predictLabel:, predictArr: [:0.00 ,:1.00 ] , features: [, 已婚, 60.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 离异, 75.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 单身, 70.0] , labels: []
deep: 2 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 1 , condition: [已婚, 离异, 单身]
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , featureIndex: 2 , condition: 110.0
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 已婚, 120.0] , labels: []
deep: 4 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 已婚, 100.0] , labels: []
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 离异, 200.0] , labels: []
deep: 3 , predictLabel:, predictArr: [:1.00 ,:0.00 ] , features: [, 单身, 125.0] , labels: []
训练用时: 0.011 s
用训练好的模型进行预测: 
TestData: [, 离异, 200.0] : PredictResult{predictLabel='否', predictArr=[:1.00 ,:0.00 ]}
TestData: [, 单身, 85.0] : PredictResult{predictLabel='是', predictArr=[:0.00 ,:1.00 ]}
TestData: [, 离异, 75.0] : PredictResult{predictLabel='否', predictArr=[:1.00 ,:0.00 ]}
TestData: [, 单身, 90.0] : PredictResult{predictLabel='是', predictArr=[:0.00 ,:1.00 ]}

相关文章:

  • 百度提交入口网站网址/营销广告
  • 精准营销数据/绍兴seo推广
  • wordpress 调用相册/百度一下点击搜索
  • 网页制作基础教程图片/网络seo优化平台
  • wordpress 阿里云 邮件注册/seo网站自动推广
  • 做网站六安/国内建站平台
  • 开关电源详解
  • 车载测试面试题一览
  • 国内最全的Spring Boot系列之六
  • Go语言初始
  • 【Spring6源码・AOP】代理对象的创建
  • Barra模型因子的构建及应用系列二之Beta因子
  • 一个关于image访问图片跨域的问题
  • 【关于Linux中----进程间通信方式之system V共享内存】
  • 【云原生进阶之容器】第五章容器运行时5.1节--容器运行时总述
  • 自动化 | 这些常用测试平台,你们公司在用的是哪些呢?
  • 【MySQL】深入学习B+索引的使用
  • Anaconda安装、opencv环境配置、jupyter notebook使用虚拟环境