加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
simple_print.lua 3.24 KB
一键复制 编辑 原始数据 按行查看 历史
Clement Farabet 提交于 2015-09-04 16:31 . Whitespace cleanup.
--- Remove every edge touching a node and bridge its neighbours together.
-- Each edge is a table with `source` and `target` fields. All edges whose
-- `source` or `target` equals `node_id` are dropped, then one new edge is
-- appended for every (predecessor, successor) pair of the removed node, so
-- paths that used to run through the node stay connected.
-- The `edges` table is modified in place; surviving edges keep their
-- relative order and the bridging edges are appended after them.
-- @param node_id id of the node being removed
-- @param edges sequence of `{source = ..., target = ...}` tables
-- @return the same `edges` table that was passed in
local function removeNodeFromEdges(node_id, edges)
   local from_nodes = {}
   local to_nodes = {}

   -- Single pass: compact surviving edges towards the front instead of
   -- calling table.remove per hit (which shifts the tail every time).
   local total = #edges
   local kept = 0
   for idx = 1, total do
      local edge = edges[idx]
      local leaves = edge.source == node_id
      local enters = edge.target == node_id
      if leaves and enters then
         -- Self-loop on the removed node: drop it without recording a
         -- neighbour, otherwise the bridging edges below would point back
         -- at the node that no longer exists.
      elseif leaves then
         to_nodes[#to_nodes + 1] = edge.target
      elseif enters then
         from_nodes[#from_nodes + 1] = edge.source
      else
         kept = kept + 1
         edges[kept] = edge
      end
   end
   -- Clear the now-stale tail so `edges` stays a proper sequence.
   for idx = kept + 1, total do
      edges[idx] = nil
   end

   -- Bridge every predecessor to every successor, in insertion order.
   for i = 1, #from_nodes do
      for j = 1, #to_nodes do
         kept = kept + 1
         edges[kept] = {source = from_nodes[i], target = to_nodes[j]}
      end
   end
   return edges
end
--- Tell whether a node should be kept when printing.
-- A node is kept when it has a `data` field holding a `module` whose
-- `torch.typename` is anything other than 'nn.Identity'.
-- @param node table with an optional `data.module` field
-- @return a falsy value when `data` or `data.module` is missing, otherwise
--   a boolean that is true unless the module is an nn.Identity
local function isNodeGood(node)
   local payload = node.data
   if not payload then
      return payload
   end
   local mod = payload.module
   if not mod then
      return mod
   end
   return torch.typename(mod) ~= 'nn.Identity'
end
--- Renumber nodes consecutively and rewrite the edges to match.
-- Every node's `id` is replaced by its position in `nodes`, and each edge's
-- `source` / `target` is translated from the old id to the new one. An edge
-- endpoint whose old id belongs to no node becomes nil.
-- Both tables are modified in place.
-- @param nodes sequence of tables with an `id` field
-- @param edges sequence of `{source = ..., target = ...}` tables
-- @return nodes, edges (the same tables that were passed in)
local function reIndexNodes(nodes, edges)
   -- old id -> new (positional) id
   local new_id_of = {}
   for position, node in ipairs(nodes) do
      new_id_of[node.id] = position
      node.id = position
   end

   for _, link in ipairs(edges) do
      link.source, link.target = new_id_of[link.source], new_id_of[link.target]
   end

   return nodes, edges
end
--- Drop every node that fails isNodeGood and renumber what is left.
-- Nodes are scanned front to back; each rejected node is removed from
-- `nodes` and spliced out of `edges` via removeNodeFromEdges. The result is
-- then passed through reIndexNodes.
-- @param nodes sequence of `{id = ..., orig_node = ...}` tables
-- @param edges sequence of `{source = ..., target = ...}` tables
-- @return whatever reIndexNodes returns for the filtered nodes and edges
local function cleanGraph(nodes, edges)
   local cursor = 1
   while cursor <= #nodes do
      local candidate = nodes[cursor]
      if not isNodeGood(candidate.orig_node) then
         -- Do not advance: the next node slides into this slot.
         table.remove(nodes, cursor)
         edges = removeNodeFromEdges(candidate.id, edges)
      else
         cursor = cursor + 1
      end
   end
   return reIndexNodes(nodes, edges)
end
--- Build flat node and edge lists from a graph object and clean them.
-- For each entry of `graph.nodes` a `{id, orig_node}` record is created, and
-- one edge is emitted from the node's `id` to the `id` of each of its
-- `children`. The lists are then handed to cleanGraph.
-- @param graph object exposing a `nodes` sequence; each node has `id` and
--   a `children` sequence
-- @return nodes, edges as produced by cleanGraph
local function loadGraph(graph)
   local node_list, edge_list = {}, {}
   for _, orig in ipairs(graph.nodes) do
      local id = orig.id
      node_list[#node_list + 1] = {id = id, orig_node = orig}
      for c = 1, #orig.children do
         edge_list[#edge_list + 1] = {source = id, target = orig.children[c].id}
      end
   end
   local nodes, edges = cleanGraph(node_list, edge_list)
   return nodes, edges
end
local M = {}

--- Render a graph as Graphviz DOT source text.
-- The graph is first reduced by loadGraph. Each remaining node is emitted as
-- `n<id>` with a label of the form `Node<id>\n<node:label()>`, followed by
-- one `a -> b;` line per edge.
-- @param graph graph object accepted by loadGraph
-- @param title optional string placed at the top of the drawing
-- @return the DOT document as a single string
function M.todot( graph, title )
   local nodes, edges = loadGraph(graph)

   local out = {}
   local function emit(chunk)
      out[#out + 1] = chunk
   end

   emit('digraph G {\n')
   if title then
      emit('labelloc="t";\nlabel="' .. title .. '";\n')
   end
   emit('node [shape = oval]; ')

   -- DOT identifier of each node, keyed by its position in `nodes`
   -- (which is what the edges refer to in source / target).
   local dot_name = {}
   for pos, entry in ipairs(nodes) do
      local orig = entry.orig_node
      local caption = '"' .. ('Node' .. orig.id .. '\\n' .. orig:label()) .. '"'
      local name = 'n' .. orig.id
      dot_name[pos] = name
      emit('\n' .. name .. '[label=' .. caption .. '];')
   end
   emit('\n')

   for _, link in ipairs(edges) do
      emit(dot_name[link.source] .. ' -> ' .. dot_name[link.target] .. ';\n')
   end

   emit('}')
   return table.concat(out)
end
--- Render a graph to SVG with the Graphviz `dot` command.
-- The DOT text from M.todot is written to `<base>.dot` and `dot -Tsvg` is
-- run to produce `<base>.svg`, where `<base>` is `fname` or, when `fname`
-- is omitted, a name obtained from os.tmpname().
-- When `fname` is omitted the SVG is shown in a qt.QSvgWidget, the
-- intermediate files are deleted and the widget is returned. When `fname`
-- is given the two files are kept and nothing is returned.
-- Raises an error if the .dot file cannot be opened for writing.
-- @param g graph passed through to M.todot
-- @param title optional title passed through to M.todot
-- @param fname optional output path without extension
-- @return the QSvgWidget when `fname` is omitted, otherwise nothing
function M.dot(g,title,fname)
   local gv = M.todot(g, title)

   local base = fname
   if not base then
      -- Load the viewer up front so a missing qtsvg fails before any
      -- temporary file has been created.
      -- NOTE(review): presumably this is what defines the global `qt` used
      -- below; confirm against the qtsvg package.
      require 'qtsvg'
      -- One temporary base name shared by both files. Calling os.tmpname()
      -- per file left two placeholders behind that were never removed.
      base = os.tmpname()
   end
   local fngv = base .. '.dot'
   local fnsvg = base .. '.svg'

   -- assert turns an unwritable path into a readable error message instead
   -- of an "attempt to index a nil value" on the next line.
   local fgv = assert(io.open(fngv, 'w'))
   fgv:write(gv)
   fgv:close()

   -- Quote both paths so names containing spaces survive the shell.
   -- NOTE(review): this is not full shell escaping; `fname` must still come
   -- from a trusted caller.
   os.execute(string.format('dot -Tsvg -o "%s" "%s"', fnsvg, fngv))

   if not fname then
      local qs = qt.QSvgWidget(fnsvg)
      qs:show()
      os.remove(fngv)
      os.remove(fnsvg)
      -- On platforms where os.tmpname() creates the file it names, this
      -- removes that placeholder; elsewhere it is a harmless no-op.
      os.remove(base)
      return qs
   end
end

return M
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化