Transfer Learning with CNTK for Image Classification (with partial source code)
Table of Contents
- Preface
- I. What Is Transfer Learning
- II. Implementation
- 1. Pretrained Model
- 2. Code Implementation
- 1. Variable Definitions
- 2. Network Construction
- 3. Training / Validation / Testing
- III. Results
- IV. CNTK Network Structure Visualization (ResNet50)
Preface
This article implements transfer learning with CNTK for image classification, using a ResNet model as the base.
I. What Is Transfer Learning
Put simply, transfer learning is learning while standing on the shoulders of giants: take a model that is already well trained (one with strong image feature extraction), adapt its structure slightly for your own dataset, and continue training from the existing weights. The benefit is that the feature extractor is already strong, which speeds up training.
II. Implementation
First, see my earlier CNTK article, 基于CNTK/C#实现图像分类 (Image Classification with CNTK/C#); the method used there is exactly this transfer-learning approach.
1. Pretrained Model
This article uses the pretrained network models that ship with CNTK. The following models are available as free downloads:
- AlexNet
- InceptionV3
- ResNet18
- ResNet34
- ResNet50
- ResNet101
- ResNet152
- Fast R-CNN
2. Code Implementation
The following assumes some familiarity with C#.
1. Variable Definitions
// Node names in the pretrained network: the feature input, the last hidden node to keep, and the new prediction layer;
// these are used to rewire the network's input and output for transfer learning
private static string featureNodeName = "features";
private static string lastHiddenNodeName = "z.x";
private static string predictionNodeName = "prediction";
// path to the pretrained ResNet50 model
private static string pre_Model = "./PreModel/ResNet50_ImageNet_CNTK.model";
// data set and training parameters
private static string ImageDir_Train = "./DataSet_Classification_Chess/DataImage";
private static string ImageDir_Test = "./DataSet_Classification_Chess/test";
private static int[] imageDims = new int[] { 224, 224, 3 };     // network input: width, height, channels
private static string[] classes_names = new string[] { };      // filled in by CreateDataList
private static int TrainNum = 400;                              // training epochs, converted to a minibatch count in Main
private static int batch_size = 4;
private static float learning_rate = 0.0001F;
private static float momentum = 0.9F;
private static float l2RegularizationWeight = 0.05F;
private static string model_path = "./resultModel";             // output directory for the list files and the trained model
private static bool useGPU = true;
private static string ext = "bmp";                              // extension of the test images
private static DeviceDescriptor device = useGPU ? DeviceDescriptor.GPUDevice(0) : DeviceDescriptor.CPUDevice;
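The node names "features" and "z.x" are specific to the ResNet models shipped with CNTK; other pretrained models may expose different names. Below is a minimal sketch, assuming you only want to print a downloaded model's input and output nodes to pick suitable names (PrintModelNodes is a hypothetical helper, not part of the original project):
private static void PrintModelNodes(string modelFile, DeviceDescriptor device)
{
    Function model = Function.Load(modelFile, device);
    // the input variables are the candidates for featureNodeName
    foreach (Variable input in model.Arguments)
        Console.WriteLine($"input: {input.Name}  shape: {string.Join("x", input.Shape.Dimensions)}");
    // the outputs correspond to the original prediction layer of the pretrained network
    foreach (Variable output in model.Outputs)
        Console.WriteLine($"output: {output.Name}  shape: {string.Join("x", output.Shape.Dimensions)}");
    // a candidate lastHiddenNodeName (e.g. "z.x" for these ResNet models) can then be checked with model.FindByName(...)
}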
2. Network Construction
The pretrained network is rewired by node name: the layers up to lastHiddenNodeName are cloned with frozen parameters, the original feature input is replaced with a mean-normalized input, and a new Dense classification layer is attached (a sketch of the Dense helper follows the method):
private static Function CreateTransferLearningModel(string baseModelFile, int numClasses, DeviceDescriptor device, out Variable imageInput, out Variable labelInput, out Function trainingLoss, out Function predictionError)
{
    Function baseModel = Function.Load(baseModelFile, device);
    imageInput = Variable.InputVariable(imageDims, DataType.Float);
    labelInput = Variable.InputVariable(new int[] { numClasses }, DataType.Float);
    // subtract the mean value used when the ImageNet model was trained
    Function normalizedFeatureNode = CNTKLib.Minus(imageInput, Constant.Scalar(DataType.Float, 114.0F));
    // locate the original feature input and the last hidden node of the pretrained network
    Variable oldFeatureNode = baseModel.Arguments.Single(a => a.Name == featureNodeName);
    Function lastNode = baseModel.FindByName(lastHiddenNodeName);
    // clone everything up to the last hidden node with frozen weights, rewiring the input to the normalized one
    Function clonedLayer = CNTKLib.AsComposite(lastNode).Clone(
        ParameterCloningMethod.Freeze,
        new Dictionary<Variable, Variable>() { { oldFeatureNode, normalizedFeatureNode } });
    // attach a new, trainable fully connected classification layer
    Function clonedModel = Dense(clonedLayer, numClasses, device, Activation.None, predictionNodeName);
    trainingLoss = CNTKLib.CrossEntropyWithSoftmax(clonedModel, labelInput);
    predictionError = CNTKLib.ClassificationError(clonedModel, labelInput);
    return clonedModel;
}
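Dense and the Activation enum are not part of the CNTK C# library itself; they are small helpers in the style of the official CNTK C# examples. A minimal sketch under that assumption (the exact initialization used in the original project is not shown in the article):
public enum Activation { None, ReLU, Sigmoid, Tanh }

// requires: using CNTK; using System.Linq;
private static Function Dense(Variable input, int outputDim, DeviceDescriptor device,
    Activation activation = Activation.None, string outputName = "")
{
    // flatten the pooled feature map of the cloned network into a vector
    if (input.Shape.Rank != 1)
    {
        int newDim = input.Shape.Dimensions.Aggregate((d1, d2) => d1 * d2);
        input = CNTKLib.Reshape(input, new int[] { newDim });
    }
    int inputDim = input.Shape[0];
    // W (outputDim x inputDim) and b (outputDim) are the only trainable parameters of the transfer model
    var weights = new Parameter(new int[] { outputDim, inputDim }, DataType.Float,
        CNTKLib.GlorotUniformInitializer(
            CNTKLib.DefaultParamInitScale,
            CNTKLib.SentinelValueForInferParamInitRank,
            CNTKLib.SentinelValueForInferParamInitRank, 1),
        device, "timesParam");
    var bias = new Parameter(new int[] { outputDim }, 0.0f, device, "plusParam");
    Function fullyConnected = CNTKLib.Plus(bias, CNTKLib.Times(weights, input), outputName);
    switch (activation)
    {
        case Activation.ReLU: return CNTKLib.ReLU(fullyConnected);
        case Activation.Sigmoid: return CNTKLib.Sigmoid(fullyConnected);
        case Activation.Tanh: return CNTKLib.Tanh(fullyConnected);
        default: return fullyConnected;
    }
}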
3. Training / Validation / Testing
See the article linked above for a line-by-line explanation; here is the body of the Main method. Sketches of two of the helper methods it relies on (CreateMinibatchSource and the image Load helper) are given at the end of this section.
// generate the train/val list files from the training image folder and build the minibatch source
classes_names = CreateDataList(ImageDir_Train, 0.9, Path.Combine(model_path, "train_data.txt"), Path.Combine(model_path, "val_data.txt"));
MinibatchSource minibatchSource = CreateMinibatchSource(Path.Combine(model_path, "train_data.txt"), imageDims, classes_names.Length);
// build the transfer-learning network from the pretrained model
Variable imageInput, labelInput;
Function trainingLoss, predictionError;
Function transferLearningModel = CreateTransferLearningModel(pre_Model, classes_names.Length, device, out imageInput, out labelInput, out trainingLoss, out predictionError);
// learner configuration: momentum SGD with L2 regularization
AdditionalLearningOptions additionalLearningOptions = new AdditionalLearningOptions() { l2RegularizationWeight = l2RegularizationWeight };
IList<Learner> parameterLearners = new List<Learner>() {
    Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
        new TrainingParameterScheduleDouble(learning_rate, 0),
        new TrainingParameterScheduleDouble(momentum, 0),
        true,
        additionalLearningOptions)};
// create the trainer
var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);
// training loop: the epoch count in TrainNum is converted to a total minibatch count
int outputFrequencyInMinibatches = 1;
int data_length = readFileLines(Path.Combine(model_path, "train_data.txt"));
TrainNum = Convert.ToInt32(TrainNum * data_length / batch_size);
for (int minibatchCount = 0; minibatchCount < TrainNum; ++minibatchCount)
{
    var minibatchData = minibatchSource.GetNextMinibatch((uint)batch_size, device);
    trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
    {
        { imageInput, minibatchData[minibatchSource.StreamInfo("image")] },
        { labelInput, minibatchData[minibatchSource.StreamInfo("labels")] }
    }, device);
    PrintTrainingProgress(trainer, minibatchCount, TrainNum, outputFrequencyInMinibatches);
}
// save the model
transferLearningModel.Save(Path.Combine(model_path, "Ctu_Classification.model"));
Console.ReadLine();
// validate the model on the held-out list
ValidateModelWithMinibatchSource(Path.Combine(model_path, "Ctu_Classification.model"), Path.Combine(model_path, "val_data.txt"), imageDims, classes_names.Length, device);
Console.ReadLine();
// predict on the test images
Function model = Function.Load(Path.Combine(model_path, "Ctu_Classification.model"), device);
string[] all_image = Directory.GetFiles(ImageDir_Test, $"*.{ext}");
foreach (string file in all_image)
{
    // load the image as a flat float array and wrap it in a Value of shape imageDims
    var inputValue = new Value(new NDArrayView(imageDims, Load(imageDims[0], imageDims[1], file), device));
    var inputDataMap = new Dictionary<Variable, Value>() { { model.Arguments[0], inputValue } };
    var outputDataMap = new Dictionary<Variable, Value>() { { model.Output, null } };
    model.Evaluate(inputDataMap, outputDataMap, device);
    // pick the class with the highest score
    var outputData = outputDataMap[model.Output].GetDenseData<float>(model.Output).First();
    var output = outputData.Select(x => (double)x).ToArray();
    var classIndex = Array.IndexOf(output, output.Max());
    var className = classes_names[classIndex];
    Console.WriteLine(file + " : " + className);
}
Console.ReadLine();
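The Main method also uses several helpers from the earlier article that are not reproduced here (CreateDataList, CreateMinibatchSource, readFileLines, PrintTrainingProgress, ValidateModelWithMinibatchSource and the image Load function). Two of them are sketched below; these are plausible implementations under stated assumptions, not the original project's code.
CreateMinibatchSource, assuming train_data.txt is in the standard CNTK map-file format (one "<image path>\t<label index>" line per sample), can be built on the built-in ImageDeserializer:
private static MinibatchSource CreateMinibatchSource(string mapFilePath, int[] imageDims, int numClasses)
{
    // scale every image to the network input size (width, height, channels)
    var transforms = new List<CNTKDictionary>
    {
        CNTKLib.ReaderScale(imageDims[0], imageDims[1], imageDims[2])
    };
    // the stream names "image" and "labels" must match the ones used in TrainMinibatch above
    var deserializerConfiguration = CNTKLib.ImageDeserializer(mapFilePath,
        "labels", (uint)numClasses,
        "image", transforms);
    var config = new MinibatchSourceConfig(new List<CNTKDictionary> { deserializerConfiguration })
    {
        MaxSweeps = MinibatchSource.InfinitelyRepeat    // keep sweeping so the epoch-style loop above never runs dry
    };
    return CNTKLib.CreateCompositeMinibatchSource(config);
}
The Load helper used in the prediction loop can be a simple System.Drawing based loader that resizes the image and flattens it into a float array matching the [width, height, channels] input shape; BGR channel order is assumed here, matching the ImageNet-trained CNTK models:
private static float[] Load(int width, int height, string imagePath)
{
    using (var src = new System.Drawing.Bitmap(imagePath))
    using (var resized = new System.Drawing.Bitmap(src, new System.Drawing.Size(width, height)))
    {
        var data = new float[width * height * 3];
        for (int c = 0; c < 3; c++)
            for (int y = 0; y < height; y++)
                for (int x = 0; x < width; x++)
                {
                    var px = resized.GetPixel(x, y);
                    float v = c == 0 ? px.B : (c == 1 ? px.G : px.R);   // B, G, R planes
                    data[c * width * height + y * width + x] = v;       // layout matches shape [width, height, channels]
                }
        // the 114 mean subtraction happens inside the network (normalizedFeatureNode), so raw pixel values are returned
        return data;
    }
}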