
Creating a heatmap with only 2 colors

  •  1
  • x89  · Tech Community  · 1 year ago

    I have a dataset that looks like this:

    Profession     Australia_F   Australia_M      Canada_F      Canada_M    Kenya_F   Kenya_M
    Author         20            80               55            34          60        23
    Librarian      10            34               89            33          89        12
    Pilot          78            12               67            90          12        55
    

    I want to plot a kind of heatmap from these values. I tried:

    melted_df = pd.melt(df, id_vars='Profession', var_name='Country_Gender', value_name='Number')
    
    
    
    melted_df[['Country', 'Gender']] = melted_df['Country_Gender'].str.split('_', expand=True)
    melted_df['Number'] = pd.to_numeric(melted_df['Number'], errors='coerce')
    
    heatmap_data = melted_df.pivot_table(index='Profession', columns=['Country', 'Gender'], values='Number')
    
    plt.figure(figsize=(10, 8))  
    sns.heatmap(heatmap_data, cmap='coolwarm', annot=True, fmt=".1f", linewidths=.5)
    plt.xlabel('Country and Gender')  
    plt.ylabel('Profession')  
    plt.xticks(rotation=45)  
    plt.tight_layout()  
    plt.savefig('heatmap.png')
    

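    It seems to work, but at the moment it assigns a different color to every cell based on its value. However, I only want two colors in my chart: red and blue.

    What I want is, for each profession (each row), to compare each country's F value with its M value and color the cell with the higher value red.

    For example, for "Author", these three cells should be red:

    Australia_M (80), Canada_F (55), Kenya_F (60)

    while the other 3 cells in that row should be blue. How can I achieve this?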

    2 answers  |  up to 1 year ago
        1
  •  1
  •   JohanC    1 year ago

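    You can use two different dataframes: one for the coloring and one for the text annotations. Create a copy of the original dataframe; comparing the even columns (F) with the odd columns (M) turns it into a dataframe of booleans. These booleans (internally 0 for False and 1 for True) then determine the colors: with the 'coolwarm' colormap, False maps to the blue end and True to the red end.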

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    
    data = {'Profession': ['Author', 'Librarian', 'Pilot'],
            'Australia_F': [20, 10, 78],
            'Australia_M': [80, 34, 12],
            'Canada_F': [55, 89, 67],
            'Canada_M': [34, 33, 90],
            'Kenya_F': [60, 89, 12],
            'Kenya_M': [23, 12, 55]}
    df = pd.DataFrame(data).set_index('Profession')
    df_coloring = df.copy()
    for colF, colM in zip(df_coloring.columns[::2], df_coloring.columns[1::2]):
        df_coloring[colF] = df[colF] > df[colM]
        df_coloring[colM] = df[colM] > df[colF]
    
    sns.set_style('white')
    plt.figure(figsize=(10, 8))
    sns.heatmap(df_coloring, cmap='coolwarm', annot=df, fmt=".1f", linewidths=.5, cbar=False)
    plt.xlabel('Country and Gender')
    plt.ylabel('Profession')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
    

    sns.heatmap with two colors
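
    To check which cells end up red before plotting, the same comparison can be written without the loop. This is only a sketch: the names `is_red` and `red_cells` are made up for illustration, and it assumes the columns come in adjacent `<Country>_F`, `<Country>_M` pairs as in the sample data. Because the comparison is a strict `>`, a tie would leave both cells of a pair False (blue).

```python
import pandas as pd

data = {'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]}
df = pd.DataFrame(data).set_index('Profession')

# Assumption: columns alternate F, M for each country.
f_cols = list(df.columns[::2])
m_cols = list(df.columns[1::2])

# Compare raw arrays so pandas does not try to align the differing column names.
f_vals = df[f_cols].to_numpy()
m_vals = df[m_cols].to_numpy()

# True where a cell is strictly larger than its F/M partner.
is_red = pd.concat(
    [pd.DataFrame(f_vals > m_vals, index=df.index, columns=f_cols),
     pd.DataFrame(m_vals > f_vals, index=df.index, columns=m_cols)],
    axis=1,
)[df.columns]  # restore the original column order

# Hypothetical helper structure: the red cells of each profession.
red_cells = {prof: list(is_red.columns[row.to_numpy(dtype=bool)])
             for prof, row in is_red.iterrows()}
for prof, cells in red_cells.items():
    print(prof, cells)
```

    For "Author" this lists Australia_M, Canada_F and Kenya_F, matching the cells named in the question. The `is_red` frame can be passed to `sns.heatmap` in place of `df_coloring`.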

    Alternatively, you can add extra separation, putting the gender at the top and the country at the bottom:

    sns.set_style('white')
    plt.figure(figsize=(10, 8))
    ax = sns.heatmap(df_coloring, cmap='coolwarm', annot=df, fmt=".0f", linewidths=.5, cbar=False, annot_kws={"size": 22})
    countries = [l.get_text()[:-2] for l in ax.get_xticklabels()[::2]]
    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(ax.get_xticks(), [l.get_text()[-1:] for l in ax.get_xticklabels()])
    ax_top.tick_params(length=0)
    ax.set_xticks(range(1, len(df.columns), 2), countries)
    
    for i in range(0, len(df.columns) + 1, 2):
        ax.axvline(i, lw=4, color='white')
    for i in range(0, len(df) + 1):
        ax.axhline(i, lw=4, color='white')
    ax.set_xlabel('Country and Gender')
    ax.set_ylabel('Profession')
    plt.tight_layout()
    plt.show()
    

    sns.heatmap with extra separations
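
    Both plots use the two ends of the 'coolwarm' palette, which are a dark blue and a dark red rather than plain blue and red. If the exact colors matter, one possible variation (a sketch, not part of the original answer) is to pass a two-entry `ListedColormap` and pin `vmin`/`vmax` to 0 and 1. The name `two_colors` and the output file name are assumptions made for this example.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap

data = {'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]}
df = pd.DataFrame(data).set_index('Profession')

# Boolean frame: True where a cell is larger than its F/M partner.
df_coloring = df.copy()
for colF, colM in zip(df.columns[::2], df.columns[1::2]):
    df_coloring[colF] = df[colF] > df[colM]
    df_coloring[colM] = df[colM] > df[colF]

# Two-entry colormap: the first color is used for 0 (False), the second for 1 (True).
two_colors = ListedColormap(['blue', 'red'])

plt.figure(figsize=(10, 8))
# vmin/vmax pin the scale so the mapping holds even if a plot contains only one of the two values.
ax = sns.heatmap(df_coloring.astype(int), cmap=two_colors, vmin=0, vmax=1,
                 annot=df, fmt=".0f", linewidths=.5, cbar=False)
ax.set_xlabel('Country and Gender')
ax.set_ylabel('Profession')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('heatmap_two_colors.png')  # hypothetical file name
```

    Any two color names or hex codes that matplotlib understands can be used in the list.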

        2
  •  0
  •   Hussein ِAl-Fartousi    1 year ago
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    data = {
        'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]
    }
    df = pd.DataFrame(data)
    melted_df = pd.melt(df, id_vars='Profession', var_name='Country_Gender', value_name='Number')
    melted_df[['Country', 'Gender']] = melted_df['Country_Gender'].str.split('_', expand=True)
    melted_df['Number'] = pd.to_numeric(melted_df['Number'], errors='coerce')
    def assign_color(row):
        if row['Gender'] == 'F':
            return 'red' if row['Number'] > melted_df[(melted_df['Profession'] == row['Profession']) & (melted_df['Country'] == row['Country']) & (melted_df['Gender'] == 'M')]['Number'].values[0] else 'blue'
        else:
            return 'red' if row['Number'] > melted_df[(melted_df['Profession'] == row['Profession']) & (melted_df['Country'] == row['Country']) & (melted_df['Gender'] == 'F')]['Number'].values[0] else 'blue'
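    melted_df['Color'] = melted_df.apply(assign_color, axis=1)
    # 1 marks a 'red' cell (the larger of the F/M pair), 0 a 'blue' one.
    melted_df['IsRed'] = (melted_df['Color'] == 'red').astype(int)
    heatmap_data = melted_df.pivot_table(index='Profession', columns=['Country', 'Gender'], values='Number')
    # Pivot the 0/1 flags the same way so they line up cell by cell with the numbers.
    color_data = melted_df.pivot_table(index='Profession', columns=['Country', 'Gender'], values='IsRed')
    plt.figure(figsize=(10, 8))  
    # Color by the 0/1 flags, but annotate with the original numbers.
    sns.heatmap(color_data, cmap='coolwarm', vmin=0, vmax=1, annot=heatmap_data, fmt=".1f", linewidths=.5, cbar=False, square=True, mask=heatmap_data.isna(), annot_kws={"fontsize":10})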
    plt.xlabel('Country and Gender')  
    plt.ylabel('Profession')  
    plt.xticks(rotation=45)  
    plt.tight_layout()  
    plt.savefig('heatmap.png')