
Concise way to declare a numpy matrix

  •  1
  • riqitang  · Tech Community  · 7 years ago

    What is a short, readable way to declare a 999x999 numpy matrix in which every row is [1,2,3,...,999]? The final matrix should be:

    [[1,2,3,...,999]
    [1,2,3,...,999]
    ...
    [1,2,3,...,999]]
    
    2 answers  |  last activity 7 years ago
        1
  •  5
  •   jpp    7 years ago

    You can use numpy.tile:

    import numpy as np
    
    # Repeat the row 0..9 five times: 5 copies along axis 0, 1 copy along axis 1
    res = np.tile(range(10), (5, 1))
    
    print(repr(res))
    
    array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
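    The example above uses `range(10)` and 5 rows, so each row starts at 0. A minimal sketch adapting it to the exact matrix in the question (999 rows, each `[1, 2, ..., 999]`), assuming `np.arange(1, n + 1)` as the row; the name `n` is illustrative:

```python
import numpy as np

n = 999  # matrix size from the question

# Row [1, 2, ..., 999]; arange excludes its stop value, hence n + 1
row = np.arange(1, n + 1)

# Stack the row n times vertically: n copies along axis 0, 1 copy along axis 1
res = np.tile(row, (n, 1))

print(res.shape)  # (999, 999)
```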
    

    Alternatively, you can add the row to an array of zeros:

    res = np.zeros((5, 10)) + range(10)  # the row is broadcast across all 5 rows
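    Because `np.zeros` defaults to `float64`, the sum above is a float array even though the row holds integers. A sketch that keeps integer entries by passing `dtype=int` to `np.zeros` (the dtype choice is an assumption; the original answer does not specify one):

```python
import numpy as np

# Default zeros are float64, so the broadcast sum is float64 too
res_float = np.zeros((5, 10)) + range(10)

# Integer zeros plus an integer row broadcast to an integer array
res_int = np.zeros((5, 10), dtype=int) + np.arange(10)

print(res_float.dtype)  # float64
print(res_int.dtype)    # the platform's default integer type
```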
    
        2
  •  1
  •   AGN Gazer    7 years ago

    @jpp's answer is elegant, but the following solution is faster:

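    import numpy as np
    nrows, ncols = 1000, 1000
    res = np.empty((nrows, ncols))  # allocate an uninitialized float64 array
    res[:, :] = np.arange(ncols)    # broadcast the row into every row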
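    A short sketch checking that this approach and `np.tile` build the same matrix (the names `a` and `b` are illustrative). Note that `np.empty` defaults to `float64`, while `np.tile` of an integer range gives an integer array, so the values match but the dtypes differ:

```python
import numpy as np

nrows, ncols = 1000, 1000

# Preallocate, then broadcast-assign the row (this answer's approach)
a = np.empty((nrows, ncols))
a[:, :] = np.arange(ncols)

# Repeat the row with np.tile (@jpp's approach)
b = np.tile(range(ncols), (nrows, 1))

print(np.array_equal(a, b))  # True: same values
print(a.dtype, b.dtype)      # float64 vs. an integer dtype
```

    Passing `dtype=int` to `np.empty` would give an integer result; how that affects the timings below has not been measured here.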
    

    Timings:

    %timeit a = np.empty((1000,1000)); a[:, :] = np.arange(1000)
    445 µs ± 9.08 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    %timeit np.tile(range(1000), (1000, 1))
    1.43 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    Further timing tests:

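    Following @jpp's comment, I added another test run directly in the Python interpreter (unlike the original test, which ran in a Jupyter notebook because one was already up and running at the time):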

    >>> import sys
    >>> print(sys.version)
    3.6.5 |Anaconda, Inc.| (default, Apr 26 2018, 08:42:37) 
    [GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)]
    >>> import numpy as np
    >>> print(np.__version__)
    1.13.3
    >>> import timeit
    >>> t = timeit.repeat('res = np.empty((nrows, ncols)); res[:, :] = np.arange(ncols)', setup='import numpy as np; nrows=ncols=1000', number=100, repeat=50)
    >>> print(min(t), max(t), np.mean(t), np.std(t))
    0.04336756598786451 0.053294404002372175 0.0459639201409 0.00240180447219
    >>> t = timeit.repeat('res = np.tile(range(ncols), (nrows, 1))', setup='import numpy as np; nrows=ncols=1000', number=100, repeat=50)
    >>> print(min(t), max(t), np.mean(t), np.std(t))
    0.05032560401014052 0.05859642301220447 0.0530669655403 0.00225117881195
    

    The results with numpy 1.14.5 are almost identical:

    >>> import sys
    >>> print(sys.version)
    3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 11:07:29) 
    [GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)]
    >>> import numpy as np
    >>> print(np.__version__)
    1.14.5
    >>> import timeit
    >>> t = timeit.repeat('res = np.empty((nrows, ncols)); res[:, :] = np.arange(ncols)', setup='import numpy as np; nrows=ncols=1000', number=100, repeat=50)
    >>> print(min(t), max(t), np.mean(t), np.std(t))
    0.04360878499574028 0.05562149798788596 0.04657964294136036 0.0025253372244474614
    >>> t = timeit.repeat('res = np.tile(range(ncols), (nrows, 1))', setup='import numpy as np; nrows=ncols=1000', number=100, repeat=50)
    >>> print(min(t), max(t), np.mean(t), np.std(t))
    0.05024543400213588 0.06169128899637144 0.05339125283906469 0.00276210097759817