#include <iostream>  // std::cout, std::cerr
#include <memory>
#include <optional>  // std::optional (used by the commented-out ConvBlock)
#include <stdexcept>
#include <string>
#include <vector>
#include <cmath>
#include <torch/torch.h>
using namespace torch;
using namespace std;
namespace nn = torch::nn;
//
//// 3x3 Convolution with padding
//nn::Conv2d conv3x3(int64_t in_planes, int64_t out_planes, int64_t stride = 1, int64_t padding = 1, bool bias = false) {
// return nn::Conv2d(
// nn::Conv2dOptions(in_planes, out_planes, 3).stride(stride).padding(padding).bias(false));
//}
//
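//// ConvBlock: a multi-scale residual block. Three 3x3 convolutions emit
//// out_planes/2, out_planes/4, and out_planes/4 channels; concatenating them
//// restores out_planes channels before the residual addition. A 1x1 branch
//// remaps the residual when in_planes != out_planes.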
//struct ConvBlockImpl : nn::Module {
// nn::BatchNorm2d bn1, bn2, bn3;
// nn::Conv2d conv1, conv2, conv3;
// std::optional<nn::Sequential> downsample;
//
// ConvBlockImpl(int64_t in_planes, int64_t out_planes)
// : bn1(in_planes),
// conv1(conv3x3(in_planes, out_planes / 2)),
// bn2(out_planes / 2),
// conv2(conv3x3(out_planes / 2, out_planes / 4)),
// bn3(out_planes / 4),
// conv3(conv3x3(out_planes / 4, out_planes / 4)) {
//
// if (in_planes != out_planes) {
// downsample = nn::Sequential(
// nn::BatchNorm2d(in_planes),
// nn::ReLU(true),
// nn::Conv2d(nn::Conv2dOptions(in_planes, out_planes, 1).stride(1).bias(false))
// );
//        }  // otherwise leave downsample disengaged (std::nullopt)
//
// register_module("bn1", bn1);
// register_module("conv1", conv1);
// register_module("bn2", bn2);
// register_module("conv2", conv2);
// register_module("bn3", bn3);
// register_module("conv3", conv3);
// if (downsample.has_value()) {
// register_module("downsample", downsample.value());
// }
// }
//
// torch::Tensor forward(torch::Tensor x) {
// auto residual = x;
//
// auto out1 = conv1->forward(relu(bn1->forward(x)));
// auto out2 = conv2->forward(relu(bn2->forward(out1)));
// auto out3 = conv3->forward(relu(bn3->forward(out2)));
//
// out3 = torch::cat({out1, out2, out3}, 1);
//
// if (downsample.has_value()) {
// residual = downsample.value()->forward(residual);
// }
//
// out3 += residual;
//
// return out3;
// }
//};
//
//TORCH_MODULE(ConvBlock);
//
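//// HourGlass: a recursive encoder-decoder. At each level, ConvBlocks process
//// a full-resolution path (b1) and a 2x average-pooled path (b2, then the
//// recursion or b2_plus at the innermost level, then b3); the pooled path is
//// upsampled 2x (nearest) and summed with the full-resolution path.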
//struct HourGlassImpl : nn::Module {
// int64_t num_modules;
// int64_t depth;
// int64_t features;
//
// HourGlassImpl(int64_t num_modules, int64_t depth, int64_t num_features)
// : num_modules(num_modules), depth(depth), features(num_features) {
// _generate_network(depth);
// }
//
// void _generate_network(int64_t level) {
// register_module("b1_" + std::to_string(level), ConvBlock(features, features));
// register_module("b2_" + std::to_string(level), ConvBlock(features, features));
//
// if (level > 1) {
// _generate_network(level - 1);
// } else {
// register_module("b2_plus_" + std::to_string(level), ConvBlock(features, features));
// }
//
// register_module("b3_" + std::to_string(level), ConvBlock(features, features));
// }
//
//    torch::Tensor _forward(int64_t level, torch::Tensor inp) {
//        // LibTorch has no Python-style _modules map; look the per-level
//        // blocks up through named_children() instead.
//        auto children = named_children();
//        auto up1 = children["b1_" + std::to_string(level)]->as<ConvBlock>()->forward(inp);
//        auto low1 = avg_pool2d(inp, 2, 2);
//        auto low2 = children["b2_" + std::to_string(level)]->as<ConvBlock>()->forward(low1);
//
//        if (level > 1) {
//            low2 = _forward(level - 1, low2);
//        } else {
//            low2 = children["b2_plus_" + std::to_string(level)]->as<ConvBlock>()->forward(low2);
//        }
//
//        auto low3 = children["b3_" + std::to_string(level)]->as<ConvBlock>()->forward(low2);
//        auto up2 = nn::functional::interpolate(
//            low3, nn::functional::InterpolateFuncOptions()
//                      .scale_factor(std::vector<double>{2.0, 2.0})
//                      .mode(torch::kNearest));
//
//        return up1 + up2;
//    }
//
// torch::Tensor forward(torch::Tensor x) {
// return _forward(depth, x);
// }
//};
//
//TORCH_MODULE(HourGlass);
//
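//// FAN: apparently a LibTorch port of the Face Alignment Network, a stack of
//// num_modules hourglasses. Every stage emits a 68-channel heatmap (one per
//// facial landmark); between stages, remapped features and heatmaps are added
//// back into the running representation for intermediate supervision.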
//struct FANImpl : nn::Module {
// int64_t num_modules;
//
// FANImpl(int64_t num_modules = 1)
// : num_modules(num_modules) {
//        conv1 = register_module("conv1", nn::Conv2d(nn::Conv2dOptions(3, 64, 7).stride(2).padding(3)));
// bn1 = register_module("bn1", nn::BatchNorm2d(64));
// conv2 = register_module("conv2", ConvBlock(64, 128));
// conv3 = register_module("conv3", ConvBlock(128, 128));
// conv4 = register_module("conv4", ConvBlock(128, 256));
//
//        // add_module() is Python-only; in C++ the equivalent is register_module().
//        for (int64_t hg_module = 0; hg_module < num_modules; ++hg_module) {
//            register_module("m" + std::to_string(hg_module), HourGlass(1, 4, 256));
//            register_module("top_m_" + std::to_string(hg_module), ConvBlock(256, 256));
//            register_module("conv_last" + std::to_string(hg_module),
//                            nn::Conv2d(nn::Conv2dOptions(256, 256, 1).stride(1).padding(0)));
//            register_module("bn_end" + std::to_string(hg_module), nn::BatchNorm2d(256));
//            register_module("l" + std::to_string(hg_module),
//                            nn::Conv2d(nn::Conv2dOptions(256, 68, 1).stride(1).padding(0)));
//
//            if (hg_module < num_modules - 1) {
//                register_module("bl" + std::to_string(hg_module),
//                                nn::Conv2d(nn::Conv2dOptions(256, 256, 1).stride(1).padding(0)));
//                register_module("al" + std::to_string(hg_module),
//                                nn::Conv2d(nn::Conv2dOptions(68, 256, 1).stride(1).padding(0)));
//            }
//        }
// }
//
//    torch::Tensor forward(torch::Tensor x) {
//        x = relu(bn1->forward(conv1->forward(x)));
//        x = avg_pool2d(conv2->forward(x), 2, 2);
//        x = conv3->forward(x);
//        x = conv4->forward(x);
//
//        auto previous = x;
//        std::vector<torch::Tensor> outputs;
//        // Fetch the per-stage modules registered by name (LibTorch has no
//        // Python-style _modules map).
//        auto children = named_children();
//
//        for (int64_t i = 0; i < num_modules; ++i) {
//            auto hg = children["m" + std::to_string(i)]->as<HourGlass>()->forward(previous);
//
//            auto ll = children["top_m_" + std::to_string(i)]->as<ConvBlock>()->forward(hg);
//            ll = relu(children["bn_end" + std::to_string(i)]->as<nn::BatchNorm2d>()->forward(
//                children["conv_last" + std::to_string(i)]->as<nn::Conv2d>()->forward(ll)));
//
//            // Per-stage 68-channel landmark heatmaps.
//            auto tmp_out = children["l" + std::to_string(i)]->as<nn::Conv2d>()->forward(ll);
//            outputs.push_back(tmp_out);
//
//            if (i < num_modules - 1) {
//                ll = children["bl" + std::to_string(i)]->as<nn::Conv2d>()->forward(ll);
//                auto tmp_out_ = children["al" + std::to_string(i)]->as<nn::Conv2d>()->forward(tmp_out);
//                previous = previous + ll + tmp_out_;
//            }
//        }
//
//        return torch::stack(outputs);
//    }
//
//    // These holders are assigned in the constructor body, so they need the
//    // empty {nullptr} initializer here.
//    nn::Conv2d conv1{nullptr};
//    nn::BatchNorm2d bn1{nullptr};
//    ConvBlock conv2{nullptr}, conv3{nullptr}, conv4{nullptr};
//};
//
//TORCH_MODULE(FAN);
//
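//// Bottleneck: the standard ResNet bottleneck (1x1 reduce, 3x3, 1x1 expand by
//// a factor of 4), with an optional downsample branch for the residual when
//// the spatial size or channel count changes.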
//struct BottleneckImpl : nn::Module {
// static const int expansion = 4;
//    nn::Conv2d conv1, conv2, conv3;
// nn::BatchNorm2d bn1, bn2, bn3;
// nn::ReLU relu;
// nn::Sequential downsample;
// int64_t stride;
//
//    BottleneckImpl(int64_t inplanes, int64_t planes, int64_t stride = 1, nn::Sequential downsample = nullptr)
//        : conv1(nn::Conv2d(nn::Conv2dOptions(inplanes, planes, 1).bias(false))),
//          bn1(planes),
//          conv2(nn::Conv2d(nn::Conv2dOptions(planes, planes, 3).stride(stride).padding(1).bias(false))),
//          bn2(planes),
//          conv3(nn::Conv2d(nn::Conv2dOptions(planes, planes * 4, 1).bias(false))),
//          bn3(planes * 4),
//          relu(true),
//          downsample(downsample),
// stride(stride) {
//
// register_module("conv1", conv1);
// register_module("bn1", bn1);
// register_module("conv2", conv2);
// register_module("bn2", bn2);
// register_module("conv3", conv3);
// register_module("bn3", bn3);
// register_module("relu", relu);
//        // is_empty() is the idiomatic emptiness check for a ModuleHolder.
//        if (!downsample.is_empty()) {
// register_module("downsample", downsample);
// }
// }
//
// torch::Tensor forward(torch::Tensor x) {
// auto residual = x;
//
// auto out = relu(bn1->forward(conv1->forward(x)));
// out = relu(bn2->forward(conv2->forward(out)));
// out = bn3->forward(conv3->forward(out));
//
//        if (!downsample.is_empty()) {
// residual = downsample->forward(x);
// }
//
// out += residual;
// out = relu(out);
//
// return out;
// }
//};
//
//TORCH_MODULE(Bottleneck);
//
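//// ResNetDepth: a ResNet-152-style trunk (3/8/36/3 bottleneck blocks) whose
//// input is RGB plus 68 landmark-heatmap channels; in the original
//// face-alignment code this head likely regresses per-landmark depth.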
//struct ResNetDepthImpl : nn::Module {
// int64_t inplanes;
//    nn::Conv2d conv1;
// nn::BatchNorm2d bn1;
// nn::ReLU relu;
// nn::MaxPool2d maxpool;
// nn::Sequential layer1, layer2, layer3, layer4;
// nn::AvgPool2d avgpool;
// nn::Linear fc;
//
//    ResNetDepthImpl(int64_t num_classes = 68)
//        : inplanes(64),
//          conv1(nn::Conv2d(nn::Conv2dOptions(3 + 68, 64, 7).stride(2).padding(3).bias(false))),
//          bn1(64),
//          relu(true),
//          // MaxPool2dImpl takes an options object, not (kernel, stride, padding).
//          maxpool(nn::MaxPool2dOptions(3).stride(2).padding(1)),
//          avgpool(7),
//          fc(512 * BottleneckImpl::expansion, num_classes) {
//
// layer1 = _make_layer(64, 3);
// layer2 = _make_layer(128, 8, 2);
// layer3 = _make_layer(256, 36, 2);
// layer4 = _make_layer(512, 3, 2);
//
// register_module("conv1", conv1);
// register_module("bn1", bn1);
// register_module("relu", relu);
// register_module("maxpool", maxpool);
// register_module("layer1", layer1);
// register_module("layer2", layer2);
// register_module("layer3", layer3);
// register_module("layer4", layer4);
// register_module("avgpool", avgpool);
// register_module("fc", fc);
// }
//
// nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1) {
// nn::Sequential downsample = nullptr;
// if (stride != 1 || inplanes != planes * BottleneckImpl::expansion) {
// downsample = nn::Sequential(
//                nn::Conv2d(nn::Conv2dOptions(inplanes, planes * BottleneckImpl::expansion, 1).stride(stride).bias(false)),
// nn::BatchNorm2d(planes * BottleneckImpl::expansion)
// );
// }
//
// nn::Sequential layers;
// layers->push_back(Bottleneck(inplanes, planes, stride, downsample));
// inplanes = planes * BottleneckImpl::expansion;
// for (int64_t i = 1; i < blocks; ++i) {
// layers->push_back(Bottleneck(inplanes, planes));
// }
//
// return layers;
// }
//
// torch::Tensor forward(torch::Tensor x) {
// x = relu(bn1->forward(conv1->forward(x)));
// x = maxpool->forward(x);
//
// x = layer1->forward(x);
// x = layer2->forward(x);
// x = layer3->forward(x);
// x = layer4->forward(x);
//
// x = avgpool->forward(x);
// x = x.view({x.size(0), -1});
// x = fc->forward(x);
//
// return x;
// }
//};
//
//TORCH_MODULE(ResNetDepth);
namespace coastal {
class BaseModuleImpl : public torch::nn::Module {
public:
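        // Look up a submodule by the name it was registered under and downcast
        // it to its concrete Impl type. Throws if the name is unknown; the
        // cast yields nullptr if the module exists but has a different type.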
template<typename ModuleType>
std::shared_ptr<ModuleType> get_module(const std::string &name) {
auto modules = named_modules();
std::shared_ptr<torch::nn::Module> *module = modules.find(name);
if (module == nullptr) {
throw std::runtime_error("Module not found");
}
return std::dynamic_pointer_cast<ModuleType>(*module);
}
};
}
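// A minimal alternative sketch (hypothetical helper, not part of the original
// code): the same lookup done through named_children(), which searches only
// direct submodules and returns nullptr instead of throwing on a miss.
template <typename ModuleType>
std::shared_ptr<ModuleType> find_child(torch::nn::Module &parent, const std::string &name) {
    for (const auto &child : parent.named_children()) {
        if (child.key() == name) {
            // dynamic_pointer_cast yields nullptr when the stored module is
            // not a ModuleType, mirroring get_module's cast without the throw.
            return std::dynamic_pointer_cast<ModuleType>(child.value());
        }
    }
    return nullptr;  // no direct child registered under that name
}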
class MyModel : public coastal::BaseModuleImpl {
public:
MyModel() {
        // Register the child module.
        auto conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 10, 5)));
        // Dereference to pretty-print the module rather than its pointer address.
        std::cout << "conv1: " << *conv1 << std::endl;
}
torch::Tensor forward(torch::Tensor x) {
        return get_module<torch::nn::Conv2dImpl>("conv1")->forward(x);
}
};
int main() {
try {
        // Create the model instance.
        auto model = std::make_shared<MyModel>();
        // Create the input tensor.
        auto input = torch::randn({1, 1, 28, 28});
        // Run the forward pass.
        auto output = model->forward(input);
        // Print the size of the output tensor.
std::cout << output.sizes() << std::endl;
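        // Expected: [1, 10, 24, 24], since a 5x5 kernel with stride 1 and no
        // padding maps 28x28 to 24x24 (28 - 5 + 1 = 24).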
} catch (const std::exception &e) {
std::cerr << "Exception: " << e.what() << std::endl;
}
return 0;
}