代码之家  ›  专栏  ›  技术社区  ›  Krystal

如何在三维绘图中绘制多维数组?

  •  0
  • Krystal  · 技术社区  · 1 年前

    我有一个多维阵列(P x N x M),我想在3D图中绘制每个N x M阵列,使P个图像沿着z轴堆叠。

    你知道用Python怎么做吗?

    提前感谢

    2 回复  |  直到 1 年前
        1
  •  1
  •   Geom    1 年前

    如果你想让N x M个阵列作为“热图”沿着z轴堆叠在一起,这是一种方法:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Generate some dummy arrays
    P, N, M = 5, 10, 10
    data = np.random.rand(P, N, M)
    
    # Create a 3D figure
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    
    # Create meshgrid for x, y values
    x, y = np.meshgrid(np.arange(M), np.arange(N))
    
    # Plot each N x M array as a heatmap at different heights along the z-axis
    for p in range(P):
        heatmap = data[p]
        ax.plot_surface(x, y, np.full_like(heatmap, p), facecolors=plt.cm.viridis(heatmap), rstride=1, cstride=1, antialiased=True, shade=False)
    
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('P')
    ax.set_title('Stacked Heatmaps')
    plt.show()
    

    结果:

    enter image description here

        2
  •  0
  •   Aditya Tiwari    1 年前

    您可以通过使用Matplotlib的Axes3D模块来实现这一点。该代码将生成3D散点图,其中来自P x N x M阵列的每个2D切片沿z轴堆叠在不同的高度(由z变量控制)。散点图中每个点的颜色表示相应切片中的值,并添加一个颜色条来指示数据值。

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Example multidimensional array of shape (P, N, M)
    # Replace this with your actual data
    P, N, M = 5, 10, 10
    data = np.random.rand(P, N, M)
    
    # Create a 3D scatter plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    
    # Create a meshgrid for the x and y values
    x, y = np.meshgrid(range(N), range(M))
    
    for p in range(P):
      # Flatten the 2D slice and stack it at the height of p
      z = np.full((N, M), p)
      ax.scatter(x, y, z, c=data[p].ravel(), cmap='viridis')
    
    # Set labels for each axis
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    
    # Customize the colorbar
    norm = plt.Normalize(data.min(), data.max())
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
    sm.set_array([])
    fig.colorbar(sm, label='Data Values')
    
    plt.show()
    
    

    结果: enter image description here