// NOTE(review): stray web-page text removed ("code pull complete, the page will refresh automatically") — scrape artifact, not source code.
#include <unordered_map>

#include "dnnl.hpp"
using namespace dnnl;
using tag = memory::format_tag;
using dt = memory::data_type;
/// Forward-inference 3x3 convolution (stride 1, no padding) using oneDNN.
///
/// @param image  input activations, NCHW layout: batch x C x irows x icols
/// @param irows  input height
/// @param icols  input width
/// @param C      number of input channels
/// @param filter weights, OIHW layout: K x C x 3 x 3
/// @param K      number of output channels
/// @param batch  batch size
/// @param out    output activations, NCHW layout:
///               batch x K x (irows-2) x (icols-2)
///
/// Compile with -DWINOGRAD to request oneDNN's Winograd algorithm
/// (intended for AVX-512-capable CPUs); otherwise direct convolution
/// is used. Blocks until the result is fully written to `out`.
extern "C"
void winconv(float *__restrict__ image, const int irows, const int icols,
             const int C, float *__restrict__ filter, const int K,
             const int batch, float *__restrict__ out) {
    dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    dnnl::stream engine_stream(engine);

    // A 3x3 kernel with stride 1 and no padding shrinks each spatial
    // dimension by 2.
    const memory::dim N = batch,
                      IH = irows, IW = icols, OH = IH - 2, OW = IW - 2,
                      IC = C, OC = K, KH = 3, KW = 3;
    memory::dims src_dims = {N, IC, IH, IW};
    memory::dims weights_dims = {OC, IC, KH, KW};
    memory::dims dst_dims = {N, OC, OH, OW};
    memory::dims strides_dims = {1, 1};
    memory::dims padding_dims_l = {0, 0};
    memory::dims padding_dims_r = {0, 0};

    // Wrap the caller's buffers in oneDNN memory objects (no copy is made).
    auto user_src_mem = memory({src_dims, dt::f32, tag::nchw}, engine, image);
    auto user_weights_mem
            = memory({weights_dims, dt::f32, tag::oihw}, engine, filter);
    auto user_dst_mem = memory({dst_dims, dt::f32, tag::nchw}, engine, out);

    // format_tag::any lets the primitive choose its preferred (possibly
    // blocked) layout; we reorder into it below if it differs.
    auto conv_src_md = memory::desc(src_dims, dt::f32, tag::any);
    auto conv_weights_md = memory::desc(weights_dims, dt::f32, tag::any);
    auto conv_dst_md = memory::desc(dst_dims, dt::f32, tag::any);

    // Algorithm selection: convolution_winograd when built with -DWINOGRAD,
    // plain direct convolution otherwise.
#ifdef WINOGRAD
#define CNN_ALGO algorithm::convolution_winograd
#else
#define CNN_ALGO algorithm::convolution_direct
#endif
    auto conv_desc = convolution_forward::desc(prop_kind::forward_inference,
            CNN_ALGO, conv_src_md, conv_weights_md,
            conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
#undef CNN_ALGO // do not leak the helper macro past this function

    auto conv_pd = convolution_forward::primitive_desc(conv_desc, engine);

    // Reorder user data into the primitive's chosen layouts when they differ;
    // otherwise the handles keep aliasing the user memory objects.
    auto conv_src_mem = user_src_mem;
    auto conv_weights_mem = user_weights_mem;
    auto conv_dst_mem = user_dst_mem;
    if (conv_pd.src_desc() != user_src_mem.get_desc()) {
        conv_src_mem = memory(conv_pd.src_desc(), engine);
        reorder(user_src_mem, conv_src_mem)
                .execute(engine_stream, user_src_mem, conv_src_mem);
    }
    if (conv_pd.weights_desc() != user_weights_mem.get_desc()) {
        conv_weights_mem = memory(conv_pd.weights_desc(), engine);
        reorder(user_weights_mem, conv_weights_mem)
                .execute(engine_stream, user_weights_mem, conv_weights_mem);
    }
    if (conv_pd.dst_desc() != user_dst_mem.get_desc()) {
        // Scratch destination in the primitive's layout; copied back below.
        conv_dst_mem = memory(conv_pd.dst_desc(), engine);
    }

    // Execute the convolution.
    auto conv_prim = convolution_forward(conv_pd);
    std::unordered_map<int, memory> conv_args;
    conv_args.insert({DNNL_ARG_SRC, conv_src_mem});
    conv_args.insert({DNNL_ARG_WEIGHTS, conv_weights_mem});
    conv_args.insert({DNNL_ARG_DST, conv_dst_mem});
    conv_prim.execute(engine_stream, conv_args);

    // Copy the result back to the caller's NCHW buffer if a scratch
    // destination was used. (The original code re-assigned
    // `user_dst_mem = conv_dst_mem` in an else branch, which was a no-op:
    // the two handles already refer to the same memory object there.)
    if (conv_pd.dst_desc() != user_dst_mem.get_desc()) {
        reorder(conv_dst_mem, user_dst_mem)
                .execute(engine_stream, conv_dst_mem, user_dst_mem);
    }

    // Block until all queued primitives finish so `out` is valid on return.
    engine_stream.wait();
}
// NOTE(review): removed trailing web-page boilerplate (content-moderation
// notice from the hosting site) — scrape artifact, not source code.