簡介:本文介紹如何將Pytorch Geometric運行過程中得到的data或者tensor轉換成networkx可以處理的格式,進行可視化。
其中Pytorch geometric的地址為:https://pytorch-geometric.readthedocs.io/en/latest/
方法一根據networkx的文檔:https://networkx.github.io/documentation/networkx-1.10/reference/generated/networkx.drawing.nx_pylab.draw_networkx.html
我們可以寫出來一個非常簡單的例子,如下(代碼可以左右滑動):
import networkx as nx
import matplotlib.pyplot as plt
G = nx.Graph()
edge_index = [(1, 2), (1, 3), (2, 3), (3, 4)]
G.add_edges_from(edge_index)
nx.draw(G)
plt.show()
運行程序之後,可以得到下面的圖,(偷了一個懶,沒有加label之類的信息)
這個例子給我們的啟發:我們可以將PyG得到的edge_index轉成numpy的格式,然後傳給nx,下面是根據這個寫的一個函數:
在PyG中,邊的表示放在了edge_index中,由一個二維的矩陣構成,edge_index[0]表示節點edge_index[1]表示另一個節點。
def draw(edge_index, name=None):
G = nx.Graph(node_size=15, font_size=8)
src = edge_index[0].cpu().numpy()
dst = edge_index[1].cpu().numpy()
edgelist = zip(src, dst)
for i, j in edgelist:
G.add_edge(i, j)
plt.figure(figsize=(20, 14))
nx.draw_networkx(G)
plt.savefig('{}.png'.format(name if name else 'path'))
註:該方法可以用於模型中的forward函數,用於分析cov,pool等操作
再寫一個與上面思想一致,可以直接運行的一個例子
from torch_geometric.datasets import KarateClub
import networkx as nx
import matplotlib.pyplot as plt
dataset = KarateClub()
edge, x, y = dataset[0]
x_np = x[1].numpy()
y_np = y[1].numpy()
g = nx.Graph()
name, edgeinfo = edge
src = edgeinfo[0].numpy()
dst = edgeinfo[1].numpy()
edgelist = zip(src, dst)
for i, j in edgelist:
g.add_edge(i, j)
nx.draw(g)
plt.savefig('test.png')
plt.show()
其實,torch_geometric.utils中已經帶有to_networkx的函數可以直接將格式為torch_geometric.data.Data 的數據轉換為networkx.DiGraph的格式,該格式可以直接networkx處理,但是我們提前要得到torch_geometric.data.Data的數據格式
import networkx as nx
from torch_geometric.utils.convert import to_networkx
def draw(Data):
G = to_networkx(Data)
nx.draw(G)
plt.savefig("path.png")
plt.show()
註:這個一般可以用於在model加載數據之前數據的分析,比如下面的例子
for i, data in enumerate(train_loader):
draw(data)
data = data.to(args.device)
out = model(data)
loss = F.nll_loss(out, data.y)
print("Training loss:{}".format(loss.item()))
loss.backward()
optimizer.step()
optimizer.zero_grad()
上面的函數是在graph classification進行分析的一段代碼,可以把batch size的設置為1,那麼for循環中得到就是一個graph的數據,在把數據feed給模型之前,我們可以通過該方法分析一下原始的數據是什麼樣子的。
如果有什麼疑問加,請聯繫(同時,可以加入微信群一起學習哈~)