
Transfer Learning with CNTK: Image Classification (Partial Source Code Included)

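Contents

  • Preface
  • I. What Is Transfer Learning
  • II. Implementation
    • 1. Pretrained models
    • 2. Code implementation
    • 1. Variable definitions
    • 2. Building the network
    • 3. Training, validating, and testing the network
  • III. Results
  • IV. Visualizing the CNTK Network Structure (ResNet50)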


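Preface

This article implements transfer learning with CNTK, using image classification as the example and a ResNet model as the base network.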


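I. What Is Transfer Learning

Put simply, transfer learning means standing on the shoulders of giants: take a model that has already been trained well (one with strong image feature extraction), adjust its structure slightly to fit your own dataset, and continue training from the existing weights. Because the feature extractor is already strong, training converges faster.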

II. Implementation

My earlier CNTK article, "Image Classification with CNTK/C#", already used the transfer learning approach described here.

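1. Pretrained models

This article uses the network architectures that ship with CNTK. The pretrained models listed below can be downloaded for free: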

  • AlexNet
  • InceptionV3
  • ResNet18
  • ResNet34
  • ResNet50
  • ResNet101
  • ResNet152
  • Fast-RCNN

2. Code implementation

This section is aimed at readers with some C# experience.

1. Variable definitions

// Names of the input and output nodes in the pretrained network, used to customize the network's input and output structure
private static string featureNodeName = "features";
private static string lastHiddenNodeName = "z.x";
private static string predictionNodeName = "prediction";
private static string pre_Model = "./PreModel/ResNet50_ImageNet_CNTK.model";
// Parameters for the training images
private static string ImageDir_Train = @"./DataSet_Classification_Chess/DataImage";
private static string ImageDir_Test = @"./DataSet_Classification_Chess/test";
private static int[] imageDims = new int[] { 224, 224, 3 };
private static string[] classes_names = new string[] { };
private static int TrainNum = 400;
private static int batch_size = 4;
private static float learning_rate = 0.0001F;
private static float momentum = 0.9F;
private static float l2RegularizationWeight = 0.05F;
private static string model_path = "./resultModel";
private static bool useGPU = true;
private static string ext = "bmp";
private static DeviceDescriptor device = useGPU ? DeviceDescriptor.GPUDevice(0) : DeviceDescriptor.CPUDevice;
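
The definitions above leave classes_names empty; it is filled in later by CreateDataList, whose source is not included in this article. As a rough idea of what such a helper might do, here is a minimal sketch. It assumes each sample is already known as an (image path, class index) pair and that the list files use CNTK's image map-file layout of one "path<TAB>label" line per image. DataListSketch and Split are hypothetical names, not the author's code, and the real helper may split differently (for example per class).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper, not part of the original article's source.
public static class DataListSketch
{
    // Shuffles the samples with a fixed seed, formats each one as
    // "<path>\t<classIndex>", and splits the lines into train/validation parts.
    public static (List<string> Train, List<string> Val) Split(
        IEnumerable<(string Path, int ClassIndex)> samples,
        double trainRatio,
        int seed = 0)
    {
        var rng = new Random(seed);

        // OrderBy evaluates the key once per element, so this is a simple shuffle.
        List<string> lines = samples
            .OrderBy(_ => rng.Next())
            .Select(s => s.Path + "\t" + s.ClassIndex)
            .ToList();

        // Number of lines that go to the training list (assumed overall split, not per class).
        int trainCount = (int)Math.Round(lines.Count * trainRatio);

        return (lines.Take(trainCount).ToList(), lines.Skip(trainCount).ToList());
    }
}
```

With a ratio of 0.9, as used in the main function further down, nine out of every ten lines would land in the training list and the remainder in the validation list; each list could then be written out with File.WriteAllLines.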

2. Building the network

The pretrained network is modified by looking up nodes by name:

private static Function CreateTransferLearningModel(string baseModelFile, int numClasses, DeviceDescriptor device, out Variable imageInput, out Variable labelInput, out Function trainingLoss, out Function predictionError)
{
    Function baseModel = Function.Load(baseModelFile, device);

    imageInput = Variable.InputVariable(imageDims, DataType.Float);
    labelInput = Variable.InputVariable(new int[] { numClasses }, DataType.Float);
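    // Subtract an approximate mean pixel value from the raw input
    Function normalizedFeatureNode = CNTKLib.Minus(imageInput, Constant.Scalar(DataType.Float, 114.0F));

    // Look up the original input node and the last hidden node of the pretrained network by name
    Variable oldFeatureNode = baseModel.Arguments.Single(a => a.Name == featureNodeName);
    Function lastNode = baseModel.FindByName(lastHiddenNodeName);

    // Clone everything up to the last hidden node with frozen weights, feeding it the normalized input
    Function clonedLayer = CNTKLib.AsComposite(lastNode).Clone(
        ParameterCloningMethod.Freeze,
        new Dictionary<Variable, Variable>() { { oldFeatureNode, normalizedFeatureNode } });

    // Attach a new trainable output layer sized for the custom classes
    Function clonedModel = Dense(clonedLayer, numClasses, device, Activation.None, predictionNodeName);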

    trainingLoss = CNTKLib.CrossEntropyWithSoftmax(clonedModel, labelInput);
    predictionError = CNTKLib.ClassificationError(clonedModel, labelInput);

    return clonedModel;
}

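3. Training, validating, and testing the network

See the article mentioned above for background. I won't explain every line here; the body of the main function follows directly: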

classes_names = CreateDataList(ImageDir_Train, 0.9, Path.Combine(model_path, "train_data.txt"), Path.Combine(model_path, "val_data.txt"));
            
MinibatchSource minibatchSource = CreateMinibatchSource(Path.Combine(model_path, "train_data.txt"), imageDims, classes_names.Length);

// Build the transfer learning network
Variable imageInput, labelInput;
Function trainingLoss, predictionError;
Function transferLearningModel = CreateTransferLearningModel(pre_Model, classes_names.Length, device, out imageInput, out labelInput, out trainingLoss, out predictionError);

// Learner settings (learning rate, momentum, L2 regularization)
AdditionalLearningOptions additionalLearningOptions = new AdditionalLearningOptions() { l2RegularizationWeight = l2RegularizationWeight };
IList<Learner> parameterLearners = new List<Learner>() {
    Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
    new TrainingParameterScheduleDouble(learning_rate, 0),
    new TrainingParameterScheduleDouble(momentum, 0),
    true,
    additionalLearningOptions)};

// Create the trainer
var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);

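// Train the model
int outputFrequencyInMinibatches = 1;
int data_length = readFileLines(Path.Combine(model_path, "train_data.txt"));
// TrainNum starts out as the number of epochs; convert it to the total number of minibatch iterations
TrainNum = Convert.ToInt32(TrainNum * data_length / batch_size);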
for (int minibatchCount = 0; minibatchCount < TrainNum; ++minibatchCount)
{
    var minibatchData = minibatchSource.GetNextMinibatch((uint)batch_size, device);

    trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
    {
        { imageInput, minibatchData[minibatchSource.StreamInfo("image")] },
        { labelInput, minibatchData[minibatchSource.StreamInfo("labels")] } 
    }, device);
    PrintTrainingProgress(trainer, minibatchCount, TrainNum, outputFrequencyInMinibatches);
}

// Save the model
transferLearningModel.Save(Path.Combine(model_path, "Ctu_Classification.model"));

Console.ReadLine();

// Validate the model
ValidateModelWithMinibatchSource(Path.Combine(model_path, "Ctu_Classification.model"), Path.Combine(model_path, "val_data.txt"), imageDims, classes_names.Length, device);

Console.ReadLine();

// Run prediction on the test images
Function model = Function.Load(Path.Combine(model_path, "Ctu_Classification.model"), device);
string[] all_image = Directory.GetFiles(ImageDir_Test, $"*.{ext}");
foreach (string file in all_image)
{
    var inputValue = new Value(new NDArrayView(imageDims, Load(imageDims[0], imageDims[1], file), device));
    var inputDataMap = new Dictionary<Variable, Value>() { { model.Arguments[0], inputValue } };
    var outputDataMap = new Dictionary<Variable, Value>() {
            { model.Output, null }
        };
    model.Evaluate(inputDataMap, outputDataMap, device);
    var outputData = outputDataMap[model.Output].GetDenseData<float>(model.Output).First();

    var output = outputData.Select(x => (double)x).ToArray();
    var classIndex = Array.IndexOf(output, output.Max());
    var className = classes_names[classIndex];
    Console.WriteLine(file + " : " + className);
}

Console.ReadLine();
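
The prediction loop above picks the class with the largest output value. Since the new output layer is created with Activation.None, those values are presumably raw scores rather than probabilities. If a confidence value is wanted next to the class name, one option is to apply a softmax to the scores first. The sketch below is plain C# with no CNTK dependency; PredictionSketch, Softmax, and ArgMax are hypothetical names introduced only for illustration.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of the original article's source.
public static class PredictionSketch
{
    // Numerically stable softmax: subtract the maximum score before exponentiating.
    public static double[] Softmax(double[] scores)
    {
        double max = scores.Max();
        double[] exps = scores.Select(s => Math.Exp(s - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Index of the largest score; the first one wins on ties,
    // which matches Array.IndexOf(output, output.Max()).
    public static int ArgMax(double[] scores)
    {
        int best = 0;
        for (int i = 1; i < scores.Length; i++)
        {
            if (scores[i] > scores[best]) best = i;
        }
        return best;
    }
}
```

Softmax does not change which class wins, so the printed class name stays the same; it only turns the winning score into a value between 0 and 1 that can be reported as a confidence.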

III. Results

[Image: classification results]

IV. Visualizing the CNTK Network Structure (ResNet50)

[Image: ResNet50 network structure]
