代码之家  ›  专栏  ›  技术社区  ›  seralouk

Seaborn JointPlot为每个类添加颜色

  •  4
  • seralouk  · 技术社区  · 7 年前

    我想用Seaborn绘制两个变量的相关图 jointplot 是的。我试过很多不同的方法,但是我不能根据班级的情况给分数加颜色。

    这是我的代码:

    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    sns.set()
    
    X = np.array([5.2945 , 3.6013 , 3.9675 , 5.1602 , 4.1903 , 4.4995 , 4.5234 ,
                  4.6618 , 0.76131, 0.42036, 0.71092, 0.60899, 0.66451, 0.55388,
                  0.63863, 0.62504, 0.     , 0.     , 0.49364, 0.44828, 0.43066,
                  0.57368, 0.     , 0.     , 0.64824, 0.65166, 0.64968, 0.     ,
                  0.     , 0.52522, 0.58259, 1.1309 , 0.     , 0.     , 1.0514 ,
                  0.7519 , 0.78745, 0.94873, 1.0169 , 0.     , 0.     , 1.0416 ,
                  0.     , 0.     , 0.93648, 0.92801, 0.     , 0.     , 0.89594,
                  0.     , 0.80455, 1.0103 ])
    
    y = np.array([ 93, 115, 107, 115, 110, 107, 102, 113,  95, 101, 116,  74, 102,
                   102,  78,  85, 108, 110, 109,  80,  91,  88,  99, 110, 108,  96,
                   105,  93, 107,  98,  88,  75, 106,  92,  82,  84,  84,  92, 115,
                   107,  97, 115,  85, 133, 100,  65,  96, 105, 112, 107, 107, 105])
    
    ax = sns.jointplot(X, y, kind='reg' )
    ax.set_axis_labels(xlabel='Brain scores', ylabel='Cognitive scores')
    plt.tight_layout()
    plt.show()
    

    enter image description here

    现在,我想根据一个类变量为每个点添加颜色 classes 是的。

    2 回复  |  直到 7 年前
        1
  •  5
  •   ImportanceOfBeingErnest    7 年前

    显而易见的解决办法是 regplot 只画回归线,而不画点,然后通过一个通常的散点图来添加这些点,散点图的颜色是 c 争论。

    g = sns.jointplot(X, y, kind='reg', scatter = False )
    g.ax_joint.scatter(X,y, c=classes)
    

    enter image description here

        2
  •  1
  •   seralouk    7 年前

    我设法找到了一个正是我所需要的解决办法。感谢@importanceofbeingernest让我想到 regplot 只画回归线。

    解决方案:

    import pandas as pd
    
    classes = np.array([1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2.,
                        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 
                        2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 
                        3., 3., 3., 3., 3., 3., 3.])
    
    df = pd.DataFrame(map(list, zip(*[X.T, y.ravel().T])))
    df = df.reset_index()
    df['index'] = classes[:]
    
    g = sns.jointplot(X, y, kind='reg', scatter = False )
    for i, subdata in df.groupby("index"):
        sns.kdeplot(subdata.iloc[:,1], ax=g.ax_marg_x, legend=False)
        sns.kdeplot(subdata.iloc[:,2], ax=g.ax_marg_y, vertical=True, legend=False)
        g.ax_joint.plot(subdata.iloc[:,1], subdata.iloc[:,2], "o", ms = 8)
    plt.tight_layout()
    plt.show()
    

    enter image description here