当前位置: 首页 > news >正文

[拆轮子] PaddleDetection中__shared__、__inject__ 和 from_config 三者分别做了什么

在上一篇中,PaddleDetection Register装饰器到底做了什么
https://blog.csdn.net/HaoZiHuang/article/details/128668393

已经介绍了 __shared____inject__ 的作用:

  • __inject__ 表示引入全局字典中已经封装好的模块。如loss等。
  • __shared__为了实现一些参数的配置全局共享,这些参数可以被backbone, neck,head,loss等所有注册模块共享。

PaddleDetection 文档是这么说的,可是我还是不太懂。于是看了下源码,建议先看上边那篇文章,里边写了在哪部分 __inject__ 列表 和 __shared__列表被读取的。

标题中的三者都是在 ppdet/core/workspace.pycreate 函数使用的,create 函数用于创建已经被 Register装饰的注册过的类

1. __shared__ 部分

在 create 函数中先进行有效性检验, cls_or_name 可以是类别名称的字符串,也可以是已经写好的类,但在 PaddleDetection 当前版本内容,大概率只是字符串

    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
        	# 如果 cls_or_name 这个类已经注册,则 global_config.values 元素是 SchemaDict
            pass
            
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            # 如果有 __dict__ 则直接返回hhhh( 当前版本用的不多 )
            return global_config[name]
            
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

之后解析 __shared__ 列表中的内容

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and \
                   not isinstance(target_key, SharedConfig):
                continue  # 如果当前当前 target_key 不是SharedConfig, 那么参数已被传入
			
			# 
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]  # 必须在全局设置! __shared__ (num_classes之类的)
            else:
                cls_kwargs[k] = shared_conf.default_value       # 否则就搞默认的

而之后的几行如果在全局配置过,比如这样:
在这里插入图片描述
则读取全局配置的内容

2. from_config 部分

之后执行:

    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

由于 backbone neck head 之间的配置可能存在耦合,于是部分类实例化时,可能需要之前模块的配置,所以要在 architecture 初始化时,创建 neck head 之类的

给个例子看吧,transformer 和 detr_head 创建时除了读取之前 config 的内容,也传入了来自前置模块的内容

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])
        # transformer
        kwargs = {'input_shape': backbone.out_shape}
        transformer = create(cfg['transformer'], **kwargs)
        # head
        kwargs = {
            'hidden_dim': transformer.hidden_dim,
            'nhead': transformer.nhead,
            'input_shape': backbone.out_shape
        }
        detr_head = create(cfg['detr_head'], **kwargs)

        return {
            'backbone': backbone,
            'transformer': transformer,
            "detr_head": detr_head,
        }

3. __inject__ 部分

__inject__ 部分其实与 from_config 很像,都是将类实例化为对象,来看一小部分

在这里插入图片描述

k'loss',之前在 __inject__ 列表中
target_key'DETRLoss' 是一个字符串

	target_key = config[k]
	......
	
    elif isinstance(target_key, str):
        if target_key not in global_config:
            raise ValueError("Missing injection config:", target_key)
        target = global_config[target_key]
        if isinstance(target, SchemaDict):
            cls_kwargs[k] = create(target_key)   # 在此处将类实例化
        elif hasattr(target, '__dict__'):  # serialized object
            cls_kwargs[k] = target

可以看到 from_config 是由于组件之间存在参数耦合,要在前者创建完毕后,将部分参数传给后者,所以要借助 create API 手动实例化

__inject__ 的使用很简单,只许在 __inject__ 中指定对应的参数即可,如上图中指定了 loss 部分,而 loss 参数是 DETRLoss,于是 loss 传入后是一个 实例化的 DETRLoss 对象

4. 附录 create 函数源码

def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Args:
        cls_or_name (type or str): Class of which to create instance.

    Returns: instance of type `cls_or_name`
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
            pass
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            return global_config[name]
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

    config = global_config[name]
    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig): # 如果我指定则就传入指定的
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]  # 必须在全局设置! __shared__ (num_classes之类的)
            else:
                cls_kwargs[k] = shared_conf.default_value       # 否则就搞默认的

    # parse `inject` annoation of registered modules
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**cls_kwargs)

相关文章:

  • 某学校网站建设方案论文/长春网站制作方案定制
  • 菜单微网站/河北seo公司
  • 网站正能量/新闻头条今日要闻10条
  • wordpress文章付费支付宝/种子资源地址
  • 海兴网站建设/网络推广网站电话
  • 专业做网站开发费用/网上营销推广
  • 2022尚硅谷SSM框架跟学(五)Spring基础二
  • Springboot打成JAR包后读取配置文件
  • BetaFlight飞控AOCODARC-F7MINI固件编译
  • Go语言数据结构
  • 营销科学年度复盘|9个数字,见证“科学增长”的力量
  • Elasticsearch(二)--Elasticsearch客户端讲解
  • 【ROS2 入门】ROS 2 参数服务器(parameters)概述
  • 【异常】记一次因错误运用数据冗余,导致的数据不一致的生产事故
  • 【OpenCV 例程 300篇】256. 特征检测之 CenSurE(StarDetector)算法
  • 详细实例说明+典型案例实现 对动态规划法进行全面分析 | C++
  • 【C++】map和set的使用
  • Pytorch DataLoader中的num_workers (选择最合适的num_workers值)