加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
0012-Pin-server-Support-Vectortype.patch 5.97 KB
一键复制 编辑 原始数据 按行查看 历史
From 50801883d5d5e89d44026ead6e38e149caa346b2 Mon Sep 17 00:00:00 2001
From: d00573793 <dingguangya1@huawei.com>
Date: Sat, 25 Feb 2023 15:31:02 +0800
Subject: [PATCH 12/23] [Pin-server] Support Vectortype
diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h
index 3f7f14b..9693294 100644
--- a/include/Dialect/PluginTypes.h
+++ b/include/Dialect/PluginTypes.h
@@ -55,8 +55,7 @@ enum PluginTypeID {
PointerTyID, ///< Pointers
StructTyID, ///< Structures
ArrayTyID, ///< Arrays
- FixedVectorTyID, ///< Fixed width SIMD vector type
- ScalableVectorTyID ///< Scalable SIMD vector type
+ VectorTyID, ///< Arrays
};
class PluginTypeBase : public Type {
@@ -146,6 +145,21 @@ public:
unsigned getNumElements();
}; // class PluginArrayType
+class PluginVectorType : public Type::TypeBase<PluginVectorType, PluginTypeBase, detail::PluginTypeAndSizeStorage> {
+public:
+ using Base::Base;
+
+ PluginTypeID getPluginTypeID ();
+
+ static bool isValidElementType(Type type);
+
+ static PluginVectorType get(MLIRContext *context, Type elementType, unsigned numElements);
+
+ Type getElementType();
+
+ unsigned getNumElements();
+}; // class PluginVectorType
+
class PluginFunctionType : public Type::TypeBase<PluginFunctionType, PluginTypeBase, detail::PluginFunctionTypeStorage> {
public:
using Base::Base;
diff --git a/lib/Dialect/PluginDialect.cpp b/lib/Dialect/PluginDialect.cpp
index ba8e4fe..95a78da 100644
--- a/lib/Dialect/PluginDialect.cpp
+++ b/lib/Dialect/PluginDialect.cpp
@@ -38,6 +38,7 @@ void PluginDialect::initialize()
PluginIR::PluginFloatType,
PluginIR::PluginPointerType,
PluginIR::PluginArrayType,
+ PluginIR::PluginVectorType,
PluginIR::PluginFunctionType,
PluginIR::PluginStructType,
PluginIR::PluginBooleanType,
diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp
index 337fc49..89e4b1a 100644
--- a/lib/Dialect/PluginTypes.cpp
+++ b/lib/Dialect/PluginTypes.cpp
@@ -199,6 +199,9 @@ PluginTypeID PluginTypeBase::getPluginTypeID ()
if (auto Ty = dyn_cast<PluginIR::PluginArrayType>()) {
return Ty.getPluginTypeID ();
}
+ if (auto Ty = dyn_cast<PluginIR::PluginVectorType>()) {
+ return Ty.getPluginTypeID ();
+ }
if (auto Ty = dyn_cast<PluginIR::PluginFunctionType>()) {
return Ty.getPluginTypeID ();
}
@@ -406,6 +409,35 @@ unsigned PluginArrayType::getNumElements()
return getImpl()->numElements;
}
+// ===----------------------------------------------------------------------===//
+// Plugin Vector Type
+// ===----------------------------------------------------------------------===//
+
+PluginTypeID PluginVectorType::getPluginTypeID()
+{
+ return PluginTypeID::ArrayTyID;
+}
+
+bool PluginVectorType::isValidElementType(Type type)
+{
+ return type.isa<PluginIntegerType, PluginFloatType>();
+}
+
+PluginVectorType PluginVectorType::get(MLIRContext *context, Type elementType, unsigned numElements)
+{
+ return Base::get(context, elementType, numElements);
+}
+
+Type PluginVectorType::getElementType()
+{
+ return getImpl()->elementType;
+}
+
+unsigned PluginVectorType::getNumElements()
+{
+ return getImpl()->numElements;
+}
+
// ===----------------------------------------------------------------------===//
// Plugin Function Type
// ===----------------------------------------------------------------------===//
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index e3435b0..a471ddf 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -347,6 +347,8 @@ PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type)
return PluginIR::PluginTypeID::PointerTyID;
} else if (type == "ArrayTy") {
return PluginIR::PluginTypeID::ArrayTyID;
+ } else if (type == "VectorTy") {
+ return PluginIR::PluginTypeID::VectorTyID;
} else if (type == "FunctionTy") {
return PluginIR::PluginTypeID::FunctionTyID;
} else if (type == "StructTy") {
diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp
index e1beddf..4d41351 100755
--- a/lib/PluginServer/PluginJson.cpp
+++ b/lib/PluginServer/PluginJson.cpp
@@ -343,6 +343,10 @@ PluginIR::PluginTypeBase PluginJson::TypeJsonDeSerialize(const string& data)
mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString());
uint64_t elemNum = GetID(type["arraysize"]);
baseType = PluginIR::PluginArrayType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum);
+ } else if (id == static_cast<uint64_t>(PluginIR::VectorTyID)) {
+ mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString());
+ uint64_t elemNum = GetID(type["vectorelemnum"]);
+ baseType = PluginIR::PluginVectorType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum);
} else if (id == static_cast<uint64_t>(PluginIR::FunctionTyID)) {
mlir::Type returnTy = TypeJsonDeSerialize(type["fnreturntype"].toStyledString());
llvm::SmallVector<Type> typelist;
diff --git a/user/LocalVarSummeryPass.cpp b/user/LocalVarSummeryPass.cpp
index 2e157e3..ccee9f7 100755
--- a/user/LocalVarSummeryPass.cpp
+++ b/user/LocalVarSummeryPass.cpp
@@ -61,6 +61,11 @@ static void LocalVarSummery(void)
printf("\n struct argname is : %s\n", pName.c_str());
}
}
+ if(auto stTy = ty.getReturnType().dyn_cast<PluginIR::PluginVectorType>()) {
+ printf("func return type is PluginVectorType\n");
+ printf(" vector elem num : %d\n", stTy.getNumElements());
+ printf(" vector elem type id : %d\n", stTy.getElementType().dyn_cast<PluginIR::PluginTypeBase>().getPluginTypeID());
+ }
size_t paramIndex = 0;
llvm::ArrayRef<mlir::Type> paramsType = ty.getParams();
for (auto ty : ty.getParams()) {
--
2.33.0
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化