代码之家  ›  专栏  ›  技术社区  ›  Yuxi L

在matplotlib中用网格线绘制热图会错过网格线

  •  0
  • Yuxi L  · 技术社区  · 5 月前

    我正试图绘制一张有网格线的热图。这是我的代码(改编自 this post ):

    # Plot a heatmap with gridlines
    
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from functional import seq
    
    arr = np.random.randn(3, 20)
    
    plt.tight_layout()
    ax = plt.subplot(111)
    ax.imshow(arr, cmap='viridis')
    
    xr = ax.get_xlim()
    yr = ax.get_ylim()
    ax.set_xticks(np.arange(max(xr))-0.5, minor=True)
    ax.set_yticks(np.arange(max(yr))-0.5, minor=True)
    ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1)
    ax.tick_params(which='major', bottom=False, left=False)
    ax.tick_params(which='minor', bottom=False, left=False)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    
    
    plt.show()
    

    我得到以下情节(裁剪为内容):

    bad heatplot

    第3列、第8列和倒数第三列之后没有垂直网格线。由于这个脚本使用小刻度来绘制网格线,我还想打印这些刻度,并在不隐藏标签的情况下绘制相同的图:

    print('xticks:', np.arange(max(xr))-0.5)
    
    xticks: [-0.5  0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5  9.5 10.5 11.5 12.5
     13.5 14.5 15.5 16.5 17.5 18.5]
    

    bad heatmap with labels

    这表明所有必要的滴答声都在那里。这里的问题是什么?

    2 回复  |  直到 5 月前
        1
  •  1
  •   David K.    5 月前

    我不知道为什么会这样。使用主网格可以解决此问题:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    arr = np.random.randn(3, 20)
    
    plt.tight_layout()
    ax = plt.subplot(111)
    ax.imshow(arr, cmap='viridis')
    
    xr = ax.get_xlim()
    yr = ax.get_ylim()
    ax.set_xticks(np.arange(max(xr))-0.5, minor=False)
    ax.set_yticks(np.arange(max(yr))-0.5, minor=False)
    
    ax.grid(which='major', snap=False, color='k', linestyle='-', linewidth=1)
    
    ax.tick_params(which='major', bottom=False, left=False)
    ax.tick_params(which='minor', bottom=False, left=False)
    
    plt.show()
    
        2
  •  1
  •   rehaqds    5 月前

    实际上,这有点棘手,但问题是,在你的代码中,你只设置了小刻度。由于你没有设置主要的刻度,matplotlib为你做了这件事。

    如果你用print(ax.get_xticks())打印刻度的位置,你会得到:

    [-2.5  0.   2.5  5.   7.5 10.  12.5 15.  17.5 20. ]
    

    如果你查看你的图,你会注意到,在x=2.5、7.5、12.5和17.5处,每个有主要刻度的位置都没有小网格。如果我们仔细想想,这是有道理的。主要刻度是主要刻度(如果我们不定义它们,matplotlib可以为我们定义),然后在主要刻度之间有次要刻度。

    因此,要有一个漂亮的网格,只需在设置小刻度之前在代码中添加以下内容:

    ax.set_xticks(np.arange(max(xr)), minor=False)
    ax.set_yticks(np.arange(max(yr)), minor=False)
    

    plot