当前位置: 首页 > news >正文

基于paddlex图像分类模型训练(一):图像分类数据集切分:文件夹转化为imagenet训练格式

背景

在使用paddlex GUI训练图像分类时,内部自动对导入的分类文件夹进行细分,本文主要介绍其图像分类数据切分源码,或可作为其他项目储备代码:https://github.com/PaddlePaddle/PaddleX/blob/develop/paddlex/tools/dataset_split/imagenet_split.py
在这里插入图片描述

数据集形式

PaddleClas 数据切分说明 使用 txt 格式文件指定训练集和测试集,以 ImageNet1k 数据集为例,其中 train_list.txt 和 val_list.txt 的格式形如:

# 每一行采用"空格"分隔图像路径与标注

# 下面是 train_list.txt 中的格式样例
train/n01440764/n01440764_10026.JPEG 0
...
# 下面是 val_list.txt 中的格式样例
val/ILSVRC2012_val_00000001.JPEG 65
...

切分前:分类好的文件夹
切分后:train_list.txt,text_list.txt,val_list.txt,labels.txt

相关代码 (mian.py,imagenet_split.py, utils.py)

# -*- coding: utf-8 -*-
# @Time : 2023/1/18 15:30
# @Author : XyZeng

import imagenet_split


def test_cls_split( input_dir='./datasets/cls_2',
                    val_percent=0.2,
                    test_percent=0,
                    ):
    '''
        val_num = int(image_num * val_percent)
        test_num = int(image_num * test_percent)
        train_num = image_num - val_num - test_num
    '''

    imagenet_split.split_imagenet_dataset(
        dataset_dir=input_dir,
        val_percent=val_percent,
        test_percent=test_percent,
        save_dir=input_dir

    )

if __name__ == '__main__':
    # input_dir = './datasets/cls_2'
    test_cls_split()

imagenet_split.py

用于将文件夹整理好的分类图像,转换为可用于训练的txt

import random
from utils  import *

def split_imagenet_dataset(dataset_dir, val_percent, test_percent, save_dir):
    all_files = list_files(dataset_dir)
    label_list = list()
    train_image_anno_list = list()
    val_image_anno_list = list()
    test_image_anno_list = list()
    for file in all_files:
        if not is_pic(file):
            continue
        label, image_name = osp.split(file)
        if label not in label_list:
            label_list.append(label)
    label_list = sorted(label_list)

    for i in range(len(label_list)):
        image_list = list_files(osp.join(dataset_dir, label_list[i]))
        image_anno_list = list()
        for img in image_list:
            image_anno_list.append([osp.join(label_list[i], img), i])
        random.shuffle(image_anno_list)
        image_num = len(image_anno_list)
        val_num = int(image_num * val_percent)
        test_num = int(image_num * test_percent)
        train_num = image_num - val_num - test_num

        train_image_anno_list += image_anno_list[:train_num]
        val_image_anno_list += image_anno_list[train_num:train_num + val_num]
        test_image_anno_list += image_anno_list[train_num + val_num:]

    with open(
            osp.join(save_dir, 'train_list.txt'), mode='w',
            encoding='utf-8') as f:
        for x in train_image_anno_list:
            file, label = x
            f.write('{} {}\n'.format(file, label))
    with open(
            osp.join(save_dir, 'val_list.txt'), mode='w',
            encoding='utf-8') as f:
        for x in val_image_anno_list:
            file, label = x
            f.write('{} {}\n'.format(file, label))
    if len(test_image_anno_list):
        with open(
                osp.join(save_dir, 'test_list.txt'), mode='w',
                encoding='utf-8') as f:
            for x in test_image_anno_list:
                file, label = x
                f.write('{} {}\n'.format(file, label))
    # 创建label标签
    with open(
            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
        for l in sorted(label_list):
            f.write('{}\n'.format(l))

    return len(train_image_anno_list), len(val_image_anno_list), len(
        test_image_anno_list)

utils.py

后续可以添加自己各种功能函数

import os
import os.path as osp

def is_pic(filename):
    """ 判断文件是否为图片格式

    Args:
        filename: 文件路径
    """
    suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
    suffix = filename.strip().split('.')[-1]
    if suffix not in suffixes:
        return False
    return True


def list_files(dirname):
    """ 列出目录下所有文件(包括所属的一级子目录下文件)

    Args:
        dirname: 目录路径
    """

    def filter_file(f):
        if f.startswith('.'):
            return True
        return False

    all_files = list()
    dirs = list()
    for f in os.listdir(dirname):
        if filter_file(f):
            continue
        if osp.isdir(osp.join(dirname, f)):
            dirs.append(f)
        else:
            all_files.append(f)
    for d in dirs:
        for f in os.listdir(osp.join(dirname, d)):
            if filter_file(f):
                continue
            if osp.isdir(osp.join(dirname, d, f)):
                continue
            all_files.append(osp.join(d, f))
    return all_files

相关文章:

  • 阿里巴巴网站怎么做/seo推广优势
  • 信用中国官网企业查询/2022年搜索引擎优化指南
  • 校园网站策划书/百度链接提交收录入口
  • 提供营销网站建设公司/百度推广一个月多少钱
  • 漫画做视频在线观看网站/营销背景包括哪些内容
  • 无锡网站建设 app/青岛谷歌优化公司
  • 【021·未解】1947. 最大兼容性评分和【暴力回溯】
  • (1分钟速览)KBM-SLAM 论文阅读笔记
  • Linux:查看服务器信息,CPU、内存、系统版本、内核版本等
  • 8.框架Spring
  • hutool日常用法
  • 考研数学你必须要懂的事情
  • 一些工具软件的使用
  • IDEA创建SpringBoot的Web项目,并使用外部Tomcat
  • chromecast激活
  • 深信服某次面试题
  • IT运维服务体系的总体架构是什么?
  • 博弈论-多智能体强化学习基础