当前位置: 首页 > news >正文

MindSpore 实现unflod和flod

本文写于2022年12月23日。因为MindSpore框架在不断更新,可能你看到这篇文章的时候已经不再适用,或者有更好的实现方式。

unflod的实现

unflod的实现比较简单,因为已经有nn的接口了。实现方法可以参考我的另一篇博文

MindSpore和Python中nn.Unfold的区别_失落的熊熊的博客-CSDN博客_python unfold

flod的实现

flod的实现就有点不那么容易了,因为还没有开发出接口。已经提了issue,并得到了回答,官方说大概要2022年Q1才能完成,并给了我另一个解决方法,使用Col2Im算子。而Col2lm算子的使用也是一个令人头大的事情。

希望添加flod算子 · Issue #I663W0 · MindSpore/mindspore - Gitee.com

官方的文档可以说写的真的非常不清楚,甚至给了我很大的误导。大家可以先看一下官方的文档。

mindspore.Tensor — MindSpore master documentation

mindspore.ops.col2im — MindSpore master documentation

根据我的理解col2im的input需要四个维度。分别是(bs,c,kernel_size*kernel_size,n)。其中bs是batch_size;c是输出的深度,也就是unflod之前的深度;kernel_size是核大小;n为滑动窗口的数量,这个值是计算得出的。

下面给大家我进行测试的一个示例代码,辅助大家对unflod和flod的理解。

import mindspore as ms
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype

opReshape = ms.ops.Reshape()
unfold = nn.Unfold(ksizes=[1, 3, 3, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='valid')
x = ms.Tensor(input_data=np.random.rand(1, 2, 4, 4), dtype=mstype.float32) # 1,2,4,4
print(x)
x_unflod = unfold(x) # 1,18,2,2
x_unflod = opReshape(x_unflod,(1,2,9,-1))
[bs,c,h,w]=x.shape
output_size = ms.Tensor(input_data=[h, w], dtype=mstype.int32)
x_flod = ms.ops.col2im(x_unflod, output_size, kernel_size=[3, 3], dilation=[1, 1],
                       padding_value=[0, 0], stride=[1, 1])
print(x_flod)

而正如文章Pytorch unfold和fold_comea23的博客-CSDN博客_pytorch unfold

 所说,只有卷积不重叠的时候unflod和flod才是互逆的,如果想要得到互逆的结果,可以修改参数以实现。

opReshape = ms.ops.Reshape()
unfold = nn.Unfold(ksizes=[1, 3, 3, 1], strides=[1, 3, 3, 1], rates=[1, 1, 1, 1], padding='valid')
x = ms.Tensor(input_data=np.random.rand(1, 1, 6, 6), dtype=mstype.float32) # 1,2,4,4
print(x)
x_unflod = unfold(x) # 1,9,2,2

x_unflod = opReshape(x_unflod,(1,1,9,-1))

[bs,c,h,w]=x.shape
output_size = ms.Tensor(input_data=[h, w], dtype=mstype.int32)
x_flod = ms.ops.col2im(x_unflod, output_size, kernel_size=[3, 3], dilation=[1, 1],
                       padding_value=[0, 0], stride=[3, 3])
print(x_flod)

存在问题:

在使用中,我还是遇到了问题。我在反向传播时用了col2im,然后出现了错误。RuntimeError: Illegal primitive: Primitive Col2Im's bprop not defined.

def construct 中代码如下:

output_size = Tensor(input_data=[h, w], dtype=mstype.int32)
part_ref_rerang = part_ref_rerang_unflod.col2im(output_size, kernel_size=[3, 3], dilation=[1, 1],
padding_value=[1, 1], stride=[1, 1])


报错信息如下:
 

Traceback (most recent call last):
File "main.py", line 125, in
main()
File "main.py", line 108, in main
train_loop(model, train_dataset, loss_ce, optimizer, args) # ________
File "/tmp/pycharm_project_894/train_loop.py", line 80, in train_loop
loss, logits = train_step(data, label)
File "/tmp/pycharm_project_894/train_loop.py", line 70, in train_step
(loss, logits), grads = grad_fn(data, label)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/ops/functional.py", line 455, in inner_aux_grad_fn
return res, grad_weight(aux_fn, weights)(*args)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/ops/composite/base.py", line 530, in after_grad
return grad(fn, weights)(*args, **kwargs)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/common/api.py", line 98, in wrapper
results = fn(*arg, **kwargs)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/ops/composite/base.py", line 517, in after_grad
pynative_executor.grad(grad, fn, weights, grad_position, *args, **kwargs)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/common/api.py", line 819, in grad
self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
RuntimeError: Illegal primitive: Primitive Col2Im's bprop not defined.

问题已经回复了issue,希望后续可以得到一个解决。

希望添加flod算子 · Issue #I663W0 · MindSpore/mindspore - Gitee.com

相关文章:

  • 演示:基于WPF的自绘的中国地铁轨道控件
  • 鸿蒙OpenHarmony【轻量系统芯片移植案例】标准系统方案之扬帆移植案例
  • Docker学习笔记(四)单主机网络
  • 海外云手机怎么实现TikTok多账号防关联?
  • 太阳能光伏板航拍红外图像缺陷分类数据集
  • 适合骑行的开放式耳机哪个品牌好?四款开放式蓝牙耳机推荐
  • 【Unity】实现从Excel读取数据制作年份选择器
  • 请求包的大小会影响Redis每秒处理请求数量
  • FPS游戏之漫谈开房间流程
  • C语言统计成绩
  • kafka平滑升级过程指导
  • 【吴恩达·机器学习】第四章:详解神经网络:推理和训练
  • 一文看懂Linux内核页缓存(Page Cache)
  • 安卓面经_安卓基础面全解析<16/30>之线程池全解析
  • 电脑Tab键有什么功能?分享Tab键的6个妙用
  • 四、网络层(六)移动IP
  • 元数据相关的术语,你知道几个?
  • Jmeter实现websocket协议接口测试
  • 直播弹幕系统(五)- 整合Stomp替换原生WebSocket方案探究
  • 【关于时间序列的ML】项目 8 :使用 Facebook Prophet 模型预测股票价格
  • 洛谷 CF1743APassword 题解
  • element plus + vue3表单第一次数据未清空的bug问题解决
  • 电力系统两阶段随机优化(Matlab实现)
  • 基于GINA/凭证提供程序的自助密码管理
  • 如何通过引用传递变量?
  • C++虚函数与多态
  • 获取rdp保存的凭证
  • 谁能主宰智能驾驶赛道?「芯片+感知」是第一主角
  • 【有营养的算法笔记】从推导证明的角度深剖前缀和与差分算法
  • 3D格式转换工具HOOPS Exchange助力3D 打印软件实现质的飞跃
  • 需求的收集,筛选和排序
  • 【Kotlin 协程】Flow 异步流 ④ ( 流的构建器函数 | flow 构建器函数 | flowOf 构建器函数 | asFlow 构建器函数 )