[Taking the wheel apart] What __shared__, __inject__, and from_config each do in PaddleDetection
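In the previous post, "What does the PaddleDetection Register decorator actually do" (https://blog.csdn.net/HaoZiHuang/article/details/128668393), I already introduced what __shared__ and __inject__ do:
__inject__ pulls in a module that has already been wrapped and registered in the global dictionary, such as a loss.
__shared__ makes certain parameters globally shared, so that every registered module (backbone, neck, head, loss, etc.) can use the same value.
That is how the PaddleDetection documentation puts it, but I still didn't really understand it, so I read the source. I suggest reading the post above first, since it shows where the __inject__ list and the __shared__ list are read.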
All three items in the title are used by the create function in ppdet/core/workspace.py. create instantiates classes that have already been registered with the Register decorator.
1. The __shared__ part
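create first validates its input. cls_or_name can be either the name of a class as a string or the class itself, but in the current version of PaddleDetection it is most likely just a string.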
```python
# excerpt from the body of create()
assert type(cls_or_name) in [type, str
                             ], "should be a class or name of a class"
name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
if name in global_config:
    if isinstance(global_config[name], SchemaDict):
        # the class is registered, so its global_config entry is a SchemaDict
        pass
    elif hasattr(global_config[name], "__dict__"):
        # support instance return directly
        # an object stored in the config is returned as-is (rarely used in the current version)
        return global_config[name]
    else:
        raise ValueError("The module {} is not registered".format(name))
else:
    raise ValueError("The module {} is not registered".format(name))
```
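The name line uses the old and/or idiom to pick either the string itself or the class's __name__. Below is a minimal sketch of the same logic written out explicitly; module_name and ResNet are hypothetical names for illustration only. One subtle difference: the and/or form breaks on an empty string, because "" is falsy and the expression falls through to "".__name__, raising AttributeError instead of the intended "not registered" error, while the explicit version just returns "".

```python
def module_name(cls_or_name) -> str:
    """Return the registry key that create() would look up (hypothetical helper)."""
    # Same validation as create(): only an exact `type` or `str` is accepted
    assert type(cls_or_name) in (type, str), "should be a class or name of a class"
    # Explicit form of `type(x) == str and x or x.__name__`
    if isinstance(cls_or_name, str):
        return cls_or_name
    return cls_or_name.__name__


class ResNet:  # hypothetical registered class, for illustration only
    pass


print(module_name("ResNet"))  # ResNet
print(module_name(ResNet))    # ResNet
```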
It then parses the entries in the __shared__ list:
```python
# parse `shared` annotation of registered modules
if getattr(config, 'shared', None):
    for k in config.shared:
        target_key = config[k]
        shared_conf = config.schema[k].default
        assert isinstance(shared_conf, SharedConfig)
        if target_key is not None and \
                not isinstance(target_key, SharedConfig):
            continue  # target_key is not a SharedConfig, so a value was given explicitly
        elif shared_conf.key in global_config:
            # `key` is present in config
            cls_kwargs[k] = global_config[shared_conf.key]  # value set in the global config (e.g. num_classes)
        else:
            cls_kwargs[k] = shared_conf.default_value  # otherwise use the default
```
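So for each shared parameter: if the module's own config already gives a concrete value, that value is kept; otherwise, if the key is set in the global config, the globally configured value is read; and if neither holds, the default value is used.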
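A minimal sketch of that precedence on plain dicts. SharedConfig here is a hypothetical stand-in for ppdet's placeholder class (just a key plus a default), and resolve_shared is a made-up helper; the `declared` dict plays the role of config.schema[k].default in the real code.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class SharedConfig:
    """Hypothetical stand-in for ppdet's SharedConfig placeholder."""
    key: str                   # name to look up in the global config
    default_value: Any = None  # used when the global config lacks `key`


def resolve_shared(module_cfg: Dict[str, Any],
                   declared: Dict[str, SharedConfig],
                   global_config: Dict[str, Any]) -> Dict[str, Any]:
    """Mimic the __shared__ loop in create(); `declared` replaces config.schema."""
    cls_kwargs = dict(module_cfg)
    for k, shared_conf in declared.items():
        target_key = module_cfg.get(k)
        if target_key is not None and not isinstance(target_key, SharedConfig):
            continue  # a value was given explicitly for this module: keep it
        elif shared_conf.key in global_config:
            cls_kwargs[k] = global_config[shared_conf.key]  # global value
        else:
            cls_kwargs[k] = shared_conf.default_value  # fall back to the default
    return cls_kwargs


declared = {"num_classes": SharedConfig("num_classes", 80)}
placeholder = {"num_classes": declared["num_classes"]}  # nothing set per-module

print(resolve_shared(placeholder, declared, {"num_classes": 20}))        # {'num_classes': 20}
print(resolve_shared({"num_classes": 5}, declared, {"num_classes": 20}))  # {'num_classes': 5}
print(resolve_shared(placeholder, declared, {}))                          # {'num_classes': 80}
```

The order is: explicit per-module value, then the global key, then the declared default.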
2. The from_config part
Next, create runs:
```python
if getattr(cls, 'from_config', None):
    cls_kwargs.update(cls.from_config(config, **kwargs))
```
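Because the configs of the backbone, neck, and head can be coupled, some classes need settings from modules built before them when they are instantiated. That is why things like the neck and head are created inside the architecture's initialization.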
Here is an example. When transformer and detr_head are created, besides reading their own config, they also receive values from the upstream modules:
```python
@classmethod
def from_config(cls, cfg, *args, **kwargs):
    # backbone
    backbone = create(cfg['backbone'])
    # transformer
    kwargs = {'input_shape': backbone.out_shape}
    transformer = create(cfg['transformer'], **kwargs)
    # head
    kwargs = {
        'hidden_dim': transformer.hidden_dim,
        'nhead': transformer.nhead,
        'input_shape': backbone.out_shape
    }
    detr_head = create(cfg['detr_head'], **kwargs)

    return {
        'backbone': backbone,
        'transformer': transformer,
        "detr_head": detr_head,
    }
```
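To see the coupling in isolation, here is a minimal sketch with a hypothetical toy registry; REGISTRY, GLOBAL_CONFIG, ToyBackbone, ToyHead and ToyArch are all made up for illustration and are not ppdet's API. It mirrors one detail visible in the appendix code: extra keyword arguments passed to create only reach a class through that class's own from_config, so a downstream module needs a from_config to accept values computed upstream.

```python
from typing import Any, Dict

# Hypothetical toy registry and config, standing in for ppdet's global_config
REGISTRY: Dict[str, type] = {}
GLOBAL_CONFIG: Dict[str, Dict[str, Any]] = {}


def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls


def create(name: str, **kwargs):
    cls = REGISTRY[name]
    cls_kwargs = dict(GLOBAL_CONFIG.get(name, {}))
    # As in ppdet's create(), extra kwargs only reach the class via from_config
    # (ppdet passes its SchemaDict config; this sketch passes a plain dict)
    if getattr(cls, "from_config", None):
        cls_kwargs.update(cls.from_config(cls_kwargs, **kwargs))
    return cls(**cls_kwargs)


@register
class ToyBackbone:
    def __init__(self, depth=50):
        self.depth = depth
        self.out_channels = [512, 1024, 2048]


@register
class ToyHead:
    def __init__(self, in_channels, num_classes=80):
        self.in_channels = in_channels
        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_channels=None):
        # Receives a value computed from the upstream module
        return {"in_channels": input_channels}


@register
class ToyArch:
    def __init__(self, backbone, head):
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_config(cls, cfg):
        # cfg holds module names; build them in dependency order
        backbone = create(cfg["backbone"])
        head = create(cfg["head"], input_channels=backbone.out_channels[-1])
        return {"backbone": backbone, "head": head}


GLOBAL_CONFIG.update({
    "ToyArch": {"backbone": "ToyBackbone", "head": "ToyHead"},
    "ToyBackbone": {"depth": 50},
    "ToyHead": {"num_classes": 20},
})

model = create("ToyArch")
print(model.head.in_channels, model.head.num_classes)  # 2048 20
```

ToyArch's own config only holds module names; from_config replaces them with built objects, and ToyHead gets in_channels from the backbone rather than from its config.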
3. The __inject__ part
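The __inject__ part is actually quite similar to from_config: both turn a class into an instantiated object. Let's look at a small piece of it. Here k is 'loss', which was listed in __inject__, and target_key is 'DETRLoss', a string: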
```python
target_key = config[k]
# ......
elif isinstance(target_key, str):
    if target_key not in global_config:
        raise ValueError("Missing injection config:", target_key)
    target = global_config[target_key]
    if isinstance(target, SchemaDict):
        cls_kwargs[k] = create(target_key)  # the class is instantiated here
    elif hasattr(target, '__dict__'):  # serialized object
        cls_kwargs[k] = target
```
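As you can see, from_config exists because components have coupled parameters: once the upstream module is built, some of its attributes have to be passed to the downstream one, so the create API is called manually to instantiate it.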
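__inject__, on the other hand, is simple to use: you only need to list the parameter name in __inject__. In the example above, loss is listed and its configured value is DETRLoss, so the loss argument the class receives is an instantiated DETRLoss object.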
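A minimal sketch of the two forms __inject__ accepts, based on the appendix code below: a plain string naming another registered module, or a dict with a name key plus field overrides. ToyLoss, ToyHead, CLASSES and GLOBAL_CONFIG are hypothetical stand-ins; the real code works on SchemaDict entries and also accepts already-built objects, which this sketch omits.

```python
from typing import Any, Dict


class ToyLoss:
    def __init__(self, loss_coeff=1.0):
        self.loss_coeff = loss_coeff


class ToyHead:
    __inject__ = ["loss"]  # `loss` is built from the config entry it names

    def __init__(self, num_classes=80, loss=None):
        self.num_classes = num_classes
        self.loss = loss


# Hypothetical stand-ins for ppdet's registry and global_config
CLASSES = {c.__name__: c for c in (ToyLoss, ToyHead)}
GLOBAL_CONFIG: Dict[str, Dict[str, Any]] = {
    "ToyHead": {"num_classes": 20, "loss": "ToyLoss"},
    "ToyLoss": {"loss_coeff": 1.0},
}


def create(name: str):
    cls = CLASSES[name]
    cls_kwargs = dict(GLOBAL_CONFIG[name])
    for k in getattr(cls, "__inject__", []):
        target_key = cls_kwargs.get(k)
        if target_key is None:
            continue  # optional dependency
        if isinstance(target_key, dict):
            # dict form: {'name': <registered class>, <field overrides>...}
            if "name" not in target_key:
                continue
            inject_name = str(target_key["name"])
            if inject_name not in GLOBAL_CONFIG:
                raise ValueError("Missing injection name {}".format(k))
            # Like ppdet, the overrides are written into the global entry
            for i, v in target_key.items():
                if i != "name":
                    GLOBAL_CONFIG[inject_name][i] = v
            cls_kwargs[k] = create(inject_name)
        elif isinstance(target_key, str):
            # string form: the name of another registered module
            if target_key not in GLOBAL_CONFIG:
                raise ValueError("Missing injection config: {}".format(target_key))
            cls_kwargs[k] = create(target_key)
        else:
            raise ValueError("Unsupported injection type: {!r}".format(target_key))
    return cls(**cls_kwargs)


head = create("ToyHead")
print(type(head.loss).__name__, head.loss.loss_coeff)  # ToyLoss 1.0

GLOBAL_CONFIG["ToyHead"]["loss"] = {"name": "ToyLoss", "loss_coeff": 2.0}
head2 = create("ToyHead")
print(head2.loss.loss_coeff)  # 2.0
```

Note that the dict form writes its overrides back into the global entry for ToyLoss, just as the appendix code does with target[i] = v, so later creations of ToyLoss see loss_coeff = 2.0 as well.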
4. Appendix: source code of the create function
```python
# excerpt from ppdet/core/workspace.py; global_config, SchemaDict and
# SharedConfig come from the surrounding module
def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Args:
        cls_or_name (type or str): Class of which to create instance.

    Returns: instance of type `cls_or_name`
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
            pass
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            return global_config[name]
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

    config = global_config[name]
    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])

    # parse `shared` annotation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig):
                # an explicitly specified value is passed through as-is
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config (e.g. num_classes set globally)
                cls_kwargs[k] = global_config[shared_conf.key]
            else:
                cls_kwargs[k] = shared_conf.default_value  # otherwise use the default

    # parse `inject` annotation of registered modules
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)

    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**cls_kwargs)
```
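The last lines of create carry a note about keeping created modules from modifying reference-type values (lists, dicts) in the global config, but the deepcopy line is commented out, and cls_kwargs is filled with a plain dict.update, which copies only the top level. A minimal sketch of why that matters, assuming a hypothetical ToyNeck that appends to a list it receives; create_shallow and create_deep are made-up helpers, and deep-copying the config entry is our own illustration (the commented-out line in the source targets kwargs instead).

```python
import copy
from typing import Any, Dict

# Hypothetical global entry holding a mutable (list) value
global_config: Dict[str, Dict[str, Any]] = {"ToyNeck": {"in_channels": [256, 512]}}


class ToyNeck:
    def __init__(self, in_channels):
        self.in_channels = in_channels
        self.in_channels.append(1024)  # the module mutates its own argument


def create_shallow(name: str) -> ToyNeck:
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])  # top-level copy only, as in create()
    return ToyNeck(**cls_kwargs)


def create_deep(name: str) -> ToyNeck:
    cls_kwargs = copy.deepcopy(global_config[name])  # nested values copied too
    return ToyNeck(**cls_kwargs)


create_deep("ToyNeck")
print(global_config["ToyNeck"]["in_channels"])  # [256, 512]
create_shallow("ToyNeck")
print(global_config["ToyNeck"]["in_channels"])  # [256, 512, 1024]
```

With the shallow copy, the list inside the instance and the list in the global config are the same object, so the mutation leaks back into the config.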