# ml-solarwind / plots_shapley.py
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 24 13:59:57 2020

@author: baum_c4
"""
import numpy as np
import pandas as pd

import pickle
import matplotlib.colors as mcol
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib

with open('plot_learningset.pickle', 'rb') as f:# this datafile contains the original features (not standardized) _o, the medians _m, and standard deviation _s
    [learnvector_o,learnvector_m,learnvector_s,timevector]=pickle.load(f)
    
    
with open('plot_Shapley.pickle', 'rb') as f:
    shap_values=pickle.load(f)    
    
    
font = { 'size'   : 12}

matplotlib.rc('font', **font)

cm1 = mcol.LinearSegmentedColormap.from_list("MyCmapName",["r","m","b"])
f,ax = plt.subplots(1,3,figsize=(20*0.39,8*0.39))

c=ax[0].scatter(learnvector_o[0]/6371,shap_values[:,0]/60,c=learnvector_o[3],s=3,vmin=-900,vmax=-300,cmap=cm1) 
#cbar=f.colorbar(c,ax=ax[0])
#cbar.set_label('SW speed, X direction [km/s]')
#ax.set_rmax(2)
#ax.set_rticks([0.5, 1, 1.5, 2])  # less radial ticks
#ax.set_rlabel_position(-22.5)  # get radial labels away from plotted line
ax[0].set_xlabel('ACE position in X [Re]')
ax[0].set_ylabel('Impact on model output [min]')
ax[0].set_xlim(215,255)
ax[0].set_xticks(np.arange(215, 256, step=10))
#ax[0].set_title('ACE Position in GSE')
#plt.xticks(np.arange(215, 256, step=10))
#plt.yticks(np.arange(-45, 46, step=15))
ax[0].grid(True)



c=ax[1].scatter(learnvector_o[1]/6371,shap_values[:,1]/60,c=learnvector_o[3],s=3,vmin=-900,vmax=-300,cmap=cm1) 
#cbar=plt.colorbar(c,ax=ax[1])
#cbar.set_label('SW speed, X direction [km/s]')
#ax.set_rmax(2)
#ax.set_rticks([0.5, 1, 1.5, 2])  # less radial ticks
#ax.set_rlabel_position(-22.5)  # get radial labels away from plotted line
ax[1].set_xlabel('ACE position in Y [Re]')
ax[1].set_xlim(-50,50)
#ax[1].set_ylabel('Impact on model output [min]')
#ax[0].set_title('ACE Position in GSE')
#plt.xticks(np.arange(215, 256, step=10))
#plt.yticks(np.arange(-45, 46, step=15))
ax[1].grid(True)


c=ax[2].scatter(learnvector_o[2]/6371,shap_values[:,2]/60,c=learnvector_o[3],s=3,vmin=-900,vmax=-300,cmap=cm1) 
cbar=plt.colorbar(c,ax=ax[2])
cbar.set_label('Solar wind v$_X$ [km/s]')
#ax.set_rmax(2)
#ax.set_rticks([0.5, 1, 1.5, 2])  # less radial ticks
#ax.set_rlabel_position(-22.5)  # get radial labels away from plotted line
ax[2].set_xlabel('ACE position in Z [Re]')
ax[2].set_xlim(-25,25)
#ax[2].set_ylabel('Impact on model output [min]')
#ax[0].set_title('ACE Position in GSE')
#plt.xticks(np.arange(215, 256, step=10))
#plt.yticks(np.arange(-45, 46, step=15))
ax[2].grid(True)

plt.tight_layout()
plt.savefig('plot_shapley_pos.pdf',bbox_inches='tight')
plt.show()    



cm1 = mcol.LinearSegmentedColormap.from_list("MyCmapName",["r","m","b"])
f,ax = plt.subplots(1,3,figsize=(20*0.39,8*0.39))

c=ax[0].scatter(learnvector_o[4],shap_values[:,4]/60,c=learnvector_o[3],s=3,vmin=-900,vmax=-300,cmap=cm1) 
#cbar=f.colorbar(c,ax=ax[0])
#cbar.set_label('SW speed, X direction [km/s]')
#ax.set_rmax(2)
#ax.set_rticks([0.5, 1, 1.5, 2])  # less radial ticks
#ax.set_rlabel_position(-22.5)  # get radial labels away from plotted line
ax[0].set_xlabel('Solar wind v$_Y$ [km/s]')
ax[0].set_ylabel('Impact on model output [min]')
ax[0].set_xlim(-150,150)
#ax[0].set_title('ACE Position in GSE')
#plt.xticks(np.arange(215, 256, step=10))
#plt.yticks(np.arange(-45, 46, step=15))
ax[0].grid(True)



c=ax[1].scatter(learnvector_o[5],shap_values[:,5]/60,c=learnvector_o[3],s=3,vmin=-900,vmax=-300,cmap=cm1) 
#cbar=plt.colorbar(c,ax=ax[1])
#cbar.set_label('SW speed, X direction [km/s]')
#ax.set_rmax(2)
#ax.set_rticks([0.5, 1, 1.5, 2])  # less radial ticks
#ax.set_rlabel_position(-22.5)  # get radial labels away from plotted line
ax[1].set_xlabel('Solar wind v$_Z$ [km/s]')
#ax[1].set_ylabel('Impact on model output [min]')
#ax[0].set_title('ACE Position in GSE')
#ax[1].set_xticks(np.arange(-200, 200, step=200))
ax[1].set_xlim(-200,200)
#plt.yticks(np.arange(-45, 46, step=15))
ax[1].grid(True)


c=ax[2].scatter(learnvector_o[6],shap_values[:,6]/60,c=learnvector_o[3],s=3,vmin=-900,vmax=-300,cmap=cm1) 
cbar=plt.colorbar(c,ax=ax[2])
cbar.set_label('Solar wind v$_X$ [km/s]')
#ax.set_rmax(2)
#ax.set_rticks([0.5, 1, 1.5, 2])  # less radial ticks
#ax.set_rlabel_position(-22.5)  # get radial labels away from plotted line
ax[2].set_xlabel('DST [nT]')
ax[2].set_xlim(-125,50)
ax[2].set_xticks(np.arange(-100,51,50))
#ax[2].set_ylabel('Impact on model output [min]')
#ax[0].set_title('ACE Position in GSE')
#plt.xticks(np.arange(215, 256, step=10))
#plt.yticks(np.arange(-45, 46, step=15))
ax[2].grid(True)

plt.tight_layout()
plt.savefig('plot_shapley_speed.pdf',bbox_inches='tight')
plt.show()    



f,ax = plt.subplots(1,1,figsize=(8*0.39,8*0.39))

c=ax.scatter(learnvector_o[3],shap_values[:,3]/60,s=3,color='k') 
ax.set_xlabel('Solar wind $v_X$ [km/s]')
ax.set_ylabel('Impact on model output [min]')
#ax[0].set_title('ACE Position in GSE')
ax.set_xticks(np.arange(-900, -200, step=300))
#plt.yticks(np.arange(-45, 46, step=15))
ax.grid(True)

plt.tight_layout()
plt.savefig('plot_shapley_xspeed.pdf',bbox_inches='tight')
plt.show()