加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
0025-Pin-server-Fix-Pass-DoOptimize-method-and-struct-sel.patch 13.76 KB
一键复制 编辑 原始数据 按行查看 历史
From eb3fe5b526da6c6f514d40c688c72771b4a8b766 Mon Sep 17 00:00:00 2001
From: dingguangya <dingguangya1@huawei.com>
Date: Wed, 15 Mar 2023 10:00:18 +0800
Subject: [PATCH] [Pin-server] Fix Pass DoOptimize method and struct
self-contained
diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td
index 28710cf..c06a4ad 100644
--- a/include/Dialect/PluginOps.td
+++ b/include/Dialect/PluginOps.td
@@ -59,7 +59,8 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
let arguments = (ins UI64Attr:$id,
StrAttr:$funcName,
OptionalAttr<BoolAttr>:$declaredInline,
- TypeAttr:$type);
+ TypeAttr:$type,
+ OptionalAttr<BoolAttr>:$validType);
let regions = (region AnyRegion:$bodyRegion);
// Add custom build methods for the operation. These method populates
@@ -69,7 +70,8 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
OpBuilderDAG<(ins "uint64_t":$id,
"StringRef":$funcName,
"bool":$declaredInline,
- "Type":$type)>
+ "Type":$type, "bool":$validType)>,
+ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$funcName, "bool":$declaredInline, "bool":$validType)>
];
let extraClassDeclaration = [{
diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h
index 9693294..5c5f54a 100644
--- a/include/Dialect/PluginTypes.h
+++ b/include/Dialect/PluginTypes.h
@@ -190,12 +190,10 @@ public:
static bool isValidElementType(Type type);
- static PluginStructType get(MLIRContext *context, std::string name, ArrayRef<Type> elements, ArrayRef<std::string> elemNames);
+ static PluginStructType get(MLIRContext *context, std::string name, ArrayRef<std::string> elemNames);
std::string getName();
- ArrayRef<Type> getBody();
-
ArrayRef<std::string> getElementNames();
}; // class PluginStructType
diff --git a/include/user/StructReorder.h b/include/user/StructReorder.h
index d3e4486..573dc3c 100644
--- a/include/user/StructReorder.h
+++ b/include/user/StructReorder.h
@@ -34,10 +34,10 @@ public:
int DoOptimize()
{
- uint64_t *fun = (uint64_t *)GetFuncAddr();
+ uint64_t fun = (uint64_t)GetFuncAddr();
return DoOptimize(fun);
}
- int DoOptimize(uint64_t *fun);
+ int DoOptimize(uint64_t fun);
};
}
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 40dae3a..a3462ed 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -118,15 +118,26 @@ bool CGnodeOp::IsRealSymbol()
// ===----------------------------------------------------------------------===//
void FunctionOp::build(OpBuilder &builder, OperationState &state,
- uint64_t id, StringRef funcName, bool declaredInline, Type type)
+ uint64_t id, StringRef funcName, bool declaredInline, Type type, bool validType)
{
state.addRegion();
state.addAttribute("id", builder.getI64IntegerAttr(id));
state.addAttribute("funcName", builder.getStringAttr(funcName));
state.addAttribute("declaredInline", builder.getBoolAttr(declaredInline));
+ state.addAttribute("validType", builder.getBoolAttr(validType));
if (type) state.addAttribute("type", TypeAttr::get(type));
}
+void FunctionOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, StringRef funcName, bool declaredInline, bool validType)
+{
+ state.addRegion();
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("funcName", builder.getStringAttr(funcName));
+ state.addAttribute("declaredInline", builder.getBoolAttr(declaredInline));
+ state.addAttribute("validType", builder.getBoolAttr(validType));
+}
+
Type FunctionOp::getResultType()
{
PluginIR::PluginFunctionType resultType = type().dyn_cast<PluginIR::PluginFunctionType>();
diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp
index 6035c4f..416dbd7 100644
--- a/lib/Dialect/PluginTypes.cpp
+++ b/lib/Dialect/PluginTypes.cpp
@@ -146,29 +146,28 @@ namespace detail {
};
struct PluginStructTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<std::string, ArrayRef<Type>, ArrayRef<std::string>>;
+ using KeyTy = std::tuple<std::string, ArrayRef<std::string>>;
- PluginStructTypeStorage(std::string name, ArrayRef<Type> elements, ArrayRef<std::string> elemNames)
- : name(name), elements(elements), elemNames(elemNames) {}
+ PluginStructTypeStorage(std::string name, ArrayRef<std::string> elemNames)
+ : name(name), elemNames(elemNames) {}
static PluginStructTypeStorage *construct(TypeStorageAllocator &allocator, KeyTy key)
{
return new (allocator.allocate<PluginStructTypeStorage>())
- PluginStructTypeStorage(std::get<0>(key), allocator.copyInto(std::get<1>(key)), allocator.copyInto(std::get<2>(key)));
+ PluginStructTypeStorage(std::get<0>(key), allocator.copyInto(std::get<1>(key)));
}
static unsigned hashKey(const KeyTy &key) {
// LLVM doesn't like hashing bools in tuples.
- return llvm::hash_combine(std::get<0>(key), std::get<1>(key), std::get<2>(key));
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
}
bool operator==(const KeyTy &key) const
{
- return std::make_tuple(name, elements, elemNames) == key;
+ return std::make_tuple(name, elemNames) == key;
}
std::string name;
- ArrayRef<Type> elements;
ArrayRef<std::string> elemNames;
};
}
@@ -493,9 +492,9 @@ bool PluginStructType::isValidElementType(Type type) {
return !type.isa<PluginVoidType, PluginFunctionType>();
}
-PluginStructType PluginStructType::get(MLIRContext *context, std::string name, ArrayRef<Type> elements, ArrayRef<std::string> elemNames)
+PluginStructType PluginStructType::get(MLIRContext *context, std::string name, ArrayRef<std::string> elemNames)
{
- return Base::get(context, name, elements, elemNames);
+ return Base::get(context, name, elemNames);
}
std::string PluginStructType::getName()
@@ -503,11 +502,6 @@ std::string PluginStructType::getName()
return getImpl()->name;
}
-ArrayRef<Type> PluginStructType::getBody()
-{
- return getImpl()->elements;
-}
-
ArrayRef<std::string> PluginStructType::getElementNames()
{
return getImpl()->elemNames;
diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp
index c7ce788..0392b75 100755
--- a/lib/PluginServer/PluginJson.cpp
+++ b/lib/PluginServer/PluginJson.cpp
@@ -58,12 +58,6 @@ Json::Value PluginJson::TypeJsonSerialize (PluginIR::PluginTypeBase type)
std::string tyName = Ty.getName();
item["structtype"] = tyName;
size_t paramIndex = 0;
- ArrayRef<Type> paramsType = Ty.getBody();
- for (auto ty :paramsType) {
- std::string paramStr = "elemType" + std::to_string(paramIndex++);
- item["structelemType"][paramStr] = TypeJsonSerialize(ty.dyn_cast<PluginIR::PluginTypeBase>());
- }
- paramIndex = 0;
ArrayRef<std::string> paramsNames = Ty.getElementNames();
for (auto name :paramsNames) {
std::string paramStr = "elemName" + std::to_string(paramIndex++);
@@ -281,7 +275,6 @@ bool PluginJson::ProcessBlock(mlir::Block* block, mlir::Region& rg, const Json::
} else if (opCode == AsmOp::getOperationName().str()) {
AsmOpJsonDeserialize(opJson.toStyledString());
} else if (opCode == SwitchOp::getOperationName().str()) {
- printf("switch op deserialize\n");
SwitchOpJsonDeserialize(opJson.toStyledString());
} else if (opCode == GotoOp::getOperationName().str()) {
GotoOpJsonDeSerialize(opJson.toStyledString());
@@ -368,9 +361,17 @@ void PluginJson::FuncOpJsonDeSerialize(
bool declaredInline = false;
if (funcAttributes["declaredInline"] == "1") declaredInline = true;
auto location = opBuilder.getUnknownLoc();
- PluginIR::PluginTypeBase retType = TypeJsonDeSerialize(node["retType"].toStyledString());
- FunctionOp fOp = opBuilder.create<FunctionOp>(
- location, id, funcAttributes["funcName"], declaredInline, retType);
+ bool validType = false;
+ FunctionOp fOp;
+ if (funcAttributes["validType"] == "1") {
+ validType = true;
+ PluginIR::PluginTypeBase retType = TypeJsonDeSerialize(node["retType"].toStyledString());
+ fOp = opBuilder.create<FunctionOp>(
+ location, id, funcAttributes["funcName"], declaredInline, retType, validType);
+ } else {
+ fOp = opBuilder.create<FunctionOp>(location, id, funcAttributes["funcName"], declaredInline, validType);
+ }
+
mlir::Region &bodyRegion = fOp.bodyRegion();
Json::Value regionJson = node["region"];
Json::Value::Members bbMember = regionJson.getMemberNames();
@@ -445,21 +446,14 @@ PluginIR::PluginTypeBase PluginJson::TypeJsonDeSerialize(const string& data)
baseType = PluginIR::PluginFunctionType::get(PluginServer::GetInstance()->GetContext(), returnTy, typelist);
} else if (id == static_cast<uint64_t>(PluginIR::StructTyID)) {
std::string tyName = type["structtype"].asString();
- llvm::SmallVector<Type> typelist;
- Json::Value::Members elemTypeNum = type["structelemType"].getMemberNames();
- for (size_t paramIndex = 0; paramIndex < elemTypeNum.size(); paramIndex++) {
- string Key = "elemType" + std::to_string(paramIndex);
- mlir::Type paramTy = TypeJsonDeSerialize(type["structelemType"][Key].toStyledString());
- typelist.push_back(paramTy);
- }
llvm::SmallVector<std::string> names;
Json::Value::Members elemNameNum = type["structelemName"].getMemberNames();
- for (size_t paramIndex = 0; paramIndex < elemTypeNum.size(); paramIndex++) {
+ for (size_t paramIndex = 0; paramIndex < elemNameNum.size(); paramIndex++) {
std::string Key = "elemName" + std::to_string(paramIndex);
std::string elemName = type["structelemName"][Key].asString();
names.push_back(elemName);
}
- baseType = PluginIR::PluginStructType::get(PluginServer::GetInstance()->GetContext(), tyName, typelist, names);
+ baseType = PluginIR::PluginStructType::get(PluginServer::GetInstance()->GetContext(), tyName, names);
}
else {
if (PluginTypeId == PluginIR::VoidTyID) {
diff --git a/user/ArrayWidenPass.cpp b/user/ArrayWidenPass.cpp
index 591ecdb..db8223a 100644
--- a/user/ArrayWidenPass.cpp
+++ b/user/ArrayWidenPass.cpp
@@ -1475,13 +1475,13 @@ static void convertToNewLoop(LoopOp* loop, FunctionOp* funcOp)
return;
}
-static void ProcessArrayWiden(uint64_t *fun)
+static void ProcessArrayWiden(uint64_t fun)
{
std::cout << "Running first pass, awiden\n";
PluginServerAPI pluginAPI;
- FunctionOp funcOp = pluginAPI.GetFunctionOpById((uint64_t)fun);
+ FunctionOp funcOp = pluginAPI.GetFunctionOpById(fun);
if (funcOp == nullptr) return;
context = funcOp.getOperation()->getContext();
@@ -1498,7 +1498,7 @@ static void ProcessArrayWiden(uint64_t *fun)
}
}
-int ArrayWidenPass::DoOptimize(uint64_t *fun)
+int ArrayWidenPass::DoOptimize(uint64_t fun)
{
ProcessArrayWiden(fun);
return 0;
diff --git a/user/InlineFunctionPass.cpp b/user/InlineFunctionPass.cpp
index a51e6fe..cc4c7c4 100755
--- a/user/InlineFunctionPass.cpp
+++ b/user/InlineFunctionPass.cpp
@@ -30,7 +30,7 @@ static void UserOptimizeFunc(void)
vector<FunctionOp> allFunction = pluginAPI.GetAllFunc();
int count = 0;
for (size_t i = 0; i < allFunction.size(); i++) {
- if (allFunction[i].declaredInlineAttr().getValue())
+ if (allFunction[i] && allFunction[i].declaredInlineAttr().getValue())
count++;
}
fprintf(stderr, "declaredInline have %d functions were declared.\n", count);
diff --git a/user/LocalVarSummeryPass.cpp b/user/LocalVarSummeryPass.cpp
index c336487..04d4ac9 100755
--- a/user/LocalVarSummeryPass.cpp
+++ b/user/LocalVarSummeryPass.cpp
@@ -44,6 +44,7 @@ static void LocalVarSummery(void)
}
mlir::Plugin::FunctionOp funcOp = allFunction[i];
printf("func name is :%s\n", funcOp.funcNameAttr().getValue().str().c_str());
+ if (funcOp.validTypeAttr().getValue()) {
mlir::Type dgyty = funcOp.type();
if (auto ty = dgyty.dyn_cast<PluginIR::PluginFunctionType>()) {
if(auto stTy = ty.getReturnType().dyn_cast<PluginIR::PluginStructType>()) {
@@ -69,6 +70,7 @@ static void LocalVarSummery(void)
printf("\n Param type id : %d\n", ty.dyn_cast<PluginIR::PluginTypeBase>().getPluginTypeID());
}
}
+ }
for (size_t j = 0; j < decls.size(); j++) {
auto decl = decls[j];
string name = decl.symNameAttr().getValue().str();
diff --git a/user/StructReorder.cpp b/user/StructReorder.cpp
index f4e824e..ab2f086 100644
--- a/user/StructReorder.cpp
+++ b/user/StructReorder.cpp
@@ -177,10 +177,10 @@ static bool handle_type(PluginIR::PluginTypeBase type)
return false;
}
-static void ProcessStructReorder(uint64_t *fun)
+static void ProcessStructReorder(uint64_t fun)
{
fprintf(stderr, "Running first pass, structreoder\n");
-
+
PluginServerAPI pluginAPI;
vector<CGnodeOp> allnodes = pluginAPI.GetAllCGnode();
fprintf(stderr, "allnodes size is %d\n", allnodes.size());
@@ -222,7 +222,7 @@ static void ProcessStructReorder(uint64_t *fun)
}
-int StructReorderPass::DoOptimize(uint64_t *fun)
+int StructReorderPass::DoOptimize(uint64_t fun)
{
ProcessStructReorder(fun);
return 0;
--
2.33.0
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化