-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtools.py
More file actions
31 lines (29 loc) · 1.23 KB
/
tools.py
File metadata and controls
31 lines (29 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from rdkit import Chem
from torch_geometric.data import Data
import torch
def smiles_to_graph(smiles):
    """Convert a SMILES string into a ``torch_geometric`` ``Data`` graph.

    Atoms become nodes and bonds become edges. Every bond is stored in both
    directions, so the resulting graph is undirected.

    Node feature columns, in order (all stored as ``torch.float``):
        0. ``atom.GetAtomicNum()``
        1. ``atom.GetDegree()``
        2. ``atom.GetFormalCharge()``
        3. ``atom.IsInRing()`` (0 or 1)
        4. ``atom.GetIsAromatic()`` (0 or 1)
        5. ``int(atom.GetHybridization())``
        6. ``atom.GetTotalNumHs()``
        7. ``atom.IsInRingSize(5)`` (0 or 1)
        8. ``atom.IsInRingSize(6)`` (0 or 1)

    Args:
        smiles: SMILES representation of the molecule.

    Returns:
        A ``Data`` object with ``x`` of shape ``[num_nodes, 9]`` and
        ``edge_index`` of shape ``[2, 2 * num_bonds]`` (``torch.long``).

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    num_node_features = 9

    mol = Chem.MolFromSmiles(smiles)
    # RDKit signals a parse failure by returning None instead of raising;
    # fail here with a clear message rather than on the first method call.
    if mol is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # Edge index [2, 2 * num_bonds]: each bond is added as (a -> b) and (b -> a).
    sources, targets = [], []
    for bond in mol.GetBonds():
        begin_idx, end_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        sources.extend([begin_idx, end_idx])
        targets.extend([end_idx, begin_idx])
    edge_index = torch.tensor([sources, targets], dtype=torch.long)

    # Node features [num_nodes, num_node_features].
    node_rows = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.IsInRing()),
            int(atom.GetIsAromatic()),
            int(atom.GetHybridization()),
            atom.GetTotalNumHs(),
            int(atom.IsInRingSize(5)),
            int(atom.IsInRingSize(6)),
        ]
        for atom in mol.GetAtoms()
    ]
    # The explicit reshape keeps the 2-D layout even when the molecule has no
    # atoms (torch.tensor([]) alone would be 1-D with shape [0]).
    nodes_features = torch.tensor(node_rows, dtype=torch.float).reshape(
        len(node_rows), num_node_features
    )

    return Data(x=nodes_features, edge_index=edge_index)