代码之家  ›  专栏  ›  技术社区  ›  Sergio

python中大量数字的乘法

  •  0
  • Sergio  · 技术社区  · 1 年前

    我正在为自己开发一个小型python程序,我需要一个算法来快速将一个巨大的数组与数字相乘(超过66万个数字,每个数字为9位)。结果数字超过400万位。目前我正在使用math.prod,它可以在大约10分钟内计算出来,但这太慢了,特别是如果我想增加数字的数量。

    我检查了一些更快乘法的算法,例如SchnhageStrassen算法和ToomBook乘法,但我不明白它们是如何工作的,也不知道如何进行。我尝试了在互联网上找到的一些版本,但它们运行得不太好,甚至更慢。我想知道是否有人知道如何更快地将这些数字相乘,或者可以解释一下如何使用一些数学方法来实现这一点?

    1 回复  |  直到 1 年前
        1
  •  2
  •   Ry- Vincenzo Alcamo    1 年前

    math.prod 将一次累积一个产品编号。你可以通过递归划分列表来做得更好,例如,取每一半的乘积,这样可以减小中间产品的大小。

    对我来说,这将在几秒钟后运行:

    import math
    
    
    def recursive_prod(ns, r):
        if len(r) <= 10:  # arbitrary small base case
            return math.prod(ns[i] for i in r)
    
        split_at = len(r) // 2
        return recursive_prod(ns, r[:split_at]) * recursive_prod(ns, r[split_at:])
    
    
    import random
    ns = [random.randrange(1_000_000_000) for _ in range(660_000)]
    p = recursive_prod(ns, range(len(ns)))
    

    乘法操作数的大小比线性时间长,其中大小大致与对数相同,log(a*b)=log a+log b,所以你可以认为这大致类似于级联的运行时:

    def single_accumulator(seq):
        acc = []
    
        for x in seq:
            acc = acc + x  # note: `+=` works in place
    
        return acc
    
    
    def recursive_concat(seq):
        if len(seq) == 1:
            return seq[0]
    
        split_at = len(seq) // 2
        return recursive_concat(seq[:split_at]) + recursive_concat(seq[split_at:])
    
    
    #single_accumulator(...)
    recursive_concat([[1], [2], [3], ...])
    
        2
  •  0
  •   no comment Pratyush Goutam    1 年前

    我一直在使用一个棘手的方法,总是将最古老的两个尚未相乘的数字相乘,直到只剩下一个:

    from operator import mul
    
    def prod(ns):
        ns = list(ns)
        it = iter(ns)
        ns += map(mul, it, it)
        return ns[-1]
    

    大约和@Ry一样快。这种速度来自Karatsuba在将两个大数相乘时发挥的作用,而不是将一个大数和一个小数相乘。

    并使用 decimal 由于更快的乘法算法,速度似乎快了五倍(如果你想用十进制打印,它的优点是打印速度更快):

    from operator import mul
    from decimal import *
    
    setcontext(Context(prec=MAX_PREC, Emax=MAX_EMAX))
    
    def prod(ns):
        ns = list(map(Decimal, ns))
        it = iter(ns)
        ns += map(mul, it, it)
        return ns[-1]