''' ██████  ██████  ███████  █████  ████████  ██████  █████  ███  ██ ███████  ██       ██   ██ ██      ██   ██    ██     ██   ██ ██   ██ ████  ██ ██       ██  ███ ██████  █████  ███████  ██  ██  ██ ███████ ██ ██  ██ █████  ██  ██ ██   ██ ██     ██   ██  ██  ██  ██ ██   ██ ██  ██ ██ ██      ██████  ██  ██ ███████ ██  ██  ██  ██████  ██  ██ ██   ████ ███████  This code is licensed under CC BY-SA '''
import csv
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors

# Input table: column 0 = SMILES, column 1 = reference value,
# column 2 = predicted value.  The first row is a header.
CSV_PATH = 'C:/Path/to/Treats.csv'


def _compile_smarts(pattern):
    """Compile a SMARTS string once, failing loudly if it is invalid.

    Chem.MolFromSmarts returns None for an unparsable pattern; passing that
    None to HasSubstructMatch later would fail far from the real cause.
    """
    query = Chem.MolFromSmarts(pattern)
    if query is None:
        raise ValueError(f"Invalid SMARTS pattern: {pattern!r}")
    return query


# Substructure queries, compiled once instead of once per CSV row.
# Any aromatic atom (not just c/n/o).
AROMATIC_ATOM = _compile_smarts('a')
# Nitro group bonded to a nitrogen.
NITRAMINE = _compile_smarts('[N,n][N+](=O)[O-]')
# Same, but the nitrogen carrying the nitro group is a ring member.  The
# nitro nitrogen itself is exocyclic, so only the first atom is ring-tested.
CYCLIC_NITRAMINE = _compile_smarts('[N,n;R][N+](=O)[O-]')
NITRO = _compile_smarts('[N+](=O)[O-]')
# Two-connected oxygen bridging two carbons (C-O-C): ethers and the alkoxy
# oxygen of esters.  Hydroxyl oxygens (C-OH) deliberately do not match.
ETHER_OR_ESTER = _compile_smarts('[#6][OX2][#6]')
PEROXIDE = _compile_smarts('[#8][#8]')
TRINITROMETHYL = _compile_smarts('C([N+](=O)[O-])([N+](=O)[O-])[N+](=O)[O-]')
# Both common charge-separated drawings of an azide; the first atom may be
# neutral (covalent azide) or anionic (free azide ion).
AZIDE = _compile_smarts('[$([#7]=[N+]=[N-]),$([#7-][N+]#N)]')
# Nitrate bound to carbon through oxygen; excludes nitrate anion/nitric acid.
NITRATE_ESTER = _compile_smarts('[#6]O[N+]([O-])=O')


def classify_molecule(mol, smiles):
    """Return the names of every compound class *mol* belongs to.

    Args:
        mol: RDKit molecule parsed from *smiles*.
        smiles: The original SMILES string (used for the fragment test).

    Returns:
        A list of keys of ``compound_classes``; a molecule can be in several.
    """
    classes = []

    # A '.' separates disconnected fragments in SMILES, so its absence is
    # used here as the test for the 'Neutral molecules' class.
    if '.' not in smiles:
        classes.append('Neutral molecules')

    if mol.HasSubstructMatch(AROMATIC_ATOM):
        classes.append('Aromatic')
    else:
        classes.append('Non aromatic')

    if mol.HasSubstructMatch(NITRAMINE):
        if mol.HasSubstructMatch(CYCLIC_NITRAMINE):
            classes.append('Cyclic nitramines')
        else:
            classes.append('Acyclic nitramines')

    if mol.HasSubstructMatch(NITRO):
        classes.append('Molecules with nitro groups')
    else:
        classes.append('Molecules without nitro groups')

    if mol.HasSubstructMatch(ETHER_OR_ESTER):
        classes.append('Ethers and esters')
    if mol.HasSubstructMatch(PEROXIDE):
        classes.append('Peroxides')
    if mol.HasSubstructMatch(TRINITROMETHYL):
        classes.append('Molecules with -C(NO2)3 groups')
    if mol.HasSubstructMatch(AZIDE):
        classes.append('Azides')
    if mol.HasSubstructMatch(NITRATE_ESTER):
        classes.append('Nitrate esters')

    return classes


def oxygen_balance(mol):
    """Return ``(oxygen_balance, molecular_weight)`` for *mol*.

    The balance is computed as ``1600 * (O - 2*C - H/2) / weight`` where
    O, C and H are atom counts.
    """
    num_c = num_h = num_o = 0
    for atom in mol.GetAtoms():
        symbol = atom.GetSymbol()
        if symbol == 'C':
            num_c += 1
        elif symbol == 'O':
            num_o += 1
        elif symbol == 'H':
            # Hydrogens that survived parsing as real graph atoms.
            num_h += 1
        # Molecules parsed from SMILES normally carry their hydrogens as
        # per-atom counts rather than as atoms, so counting 'H' symbols alone
        # would always give zero.
        num_h += atom.GetTotalNumHs()

    # NOTE(review): ExactMolWt is the monoisotopic mass; Descriptors.MolWt
    # (average mass) may be what is intended here - confirm.
    weight = Descriptors.ExactMolWt(mol)
    balance = ((num_o - 2 * num_c - num_h / 2) * 1600) / weight
    return balance, weight


def error_metrics(values):
    """Return ``(MAE, MAPE, R2)`` for a non-empty list of (ref, pred) pairs.

    MAPE is relative to ``abs(ref)`` so negative reference values do not
    cancel positive ones.  R2 is NaN when the reference values have zero
    variance (e.g. a class with a single member), where it is undefined.

    Raises:
        ZeroDivisionError: if any reference value is zero.
    """
    count = len(values)
    mae = sum(abs(ref - pred) for ref, pred in values) / count
    mape = sum(abs(ref - pred) / abs(ref) * 100 for ref, pred in values) / count

    refs = np.array([ref for ref, _ in values])
    preds = np.array([pred for _, pred in values])
    ssr = np.sum((refs - preds) ** 2)  # Sum of squared residuals
    sst = np.sum((refs - np.mean(refs)) ** 2)  # Total sum of squares
    r2 = 1 - (ssr / sst) if sst > 0 else float('nan')
    return mae, mape, r2


# Define the compound classes
compound_classes = {
    'Neutral molecules': [],
    'Aromatic': [],
    'Non aromatic': [],
    'Cyclic nitramines': [],
    'Acyclic nitramines': [],
    'Molecules with nitro groups': [],
    'Molecules without nitro groups': [],
    'Ethers and esters': [],
    'Peroxides': [],
    'Molecules with -C(NO2)3 groups': [],
    'Azides': [],
    'Nitrate esters': [],
}
oxygen_balances = []
deviations = []
mw = []

# Read the SMILES strings and reference/predicted values from a CSV file.
# newline='' is what the csv module expects for correct quoted-newline handling.
with open(CSV_PATH, newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header row (tolerating an empty file)
    for row in reader:
        if not row:
            continue  # blank line, typically at the end of the file

        smiles = row[0]
        ref_value = float(row[1])
        pred_value = float(row[2])

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue  # skip invalid SMILES strings

        if ref_value == 0:
            # Percentage errors are undefined for a zero reference value.
            print(f"Skipping {smiles}: reference value is zero", file=sys.stderr)
            continue

        # Determine the compound classes of the molecule
        for class_name in classify_molecule(mol, smiles):
            compound_classes[class_name].append((ref_value, pred_value))

        # Extract the oxygen balance and per-molecule signed % deviation
        ob, mwt = oxygen_balance(mol)
        mw.append(mwt)
        oxygen_balances.append(ob)
        deviations.append(100 * (pred_value - ref_value) / ref_value)

# Calculate MAE, MAPE, and R-squared for each compound class
print("Compound Class, Number of occurrences, MAE, MAPE, R2")
for compound_class, values in compound_classes.items():
    if not values:
        continue
    avg_deviation_abs, avg_deviation_percent, r2 = error_metrics(values)
    compound_class_size = len(values)
    print(compound_class, ',', compound_class_size, ',', avg_deviation_abs, ',', avg_deviation_percent, ',', r2)

# Scatter plot of the per-molecule percentage deviation against weight.
if mw:
    plt.figure(figsize=(8, 6))
    plt.scatter(mw, deviations, s=10, color='#FF91AF')
    plt.xlabel('Molarweight')
    plt.ylabel('MAPE')
    plt.title('MAPE vs Molarweight ')
    plt.xlim(min(mw), max(mw))
    plt.ylim(min(deviations), max(deviations))
    plt.show()
else:
    # min()/max() on empty data would raise; there is nothing to draw.
    print("No valid rows found; skipping plot", file=sys.stderr)