代码之家  ›  专栏  ›  技术社区  ›  Spectacles4

给定区间列表的非重叠对数计数的最佳方法

  •  3
  • Spectacles4  · 技术社区  · 1 年前

    我试图在给定间隔列表的情况下计算不重叠对的数量。

    例如

    [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]
    

    共有8对:

    ((1, 8), (11, 13))
    ((1, 8), (13, 14))
    ((1, 8), (9, 15))
    ((7, 9), (11, 13))
    ((7, 9), (13, 14))
    ((3, 10), (11, 13))
    ((3, 10), (13, 14))
    ((7, 12), (13, 14))
    

    我似乎想不出更好的解决方案,除了通过将所有东西与几乎所有其他东西进行比较来强行使用它,从而得到O(n^2)解决方案。

    def count_non_overlapping_pairs(intervals):
        intervals = list(set(intervals))  # deduplicate any intervals
        intervals.sort(key=lambda x: x[1])
        pairs = 0
        for i in range(len(intervals)):
            for j in range(i+1, len(intervals)):
                if intervals[i][1] < intervals[j][0]:
                    pairs += 1
        return pairs
    

    还有比这更优化的解决方案吗?

    3 回复  |  直到 1 年前
        1
  •  0
  •   Luatic    1 年前

    对间隔进行排序,按起点排序一次,按终点排序一次。

    现在,给定一个区间,使用按终点排序的区间中的起点执行二进制搜索。你得到的索引告诉你之前有多少个非重叠区间:在区间开始之前结束的所有区间都是非重叠的。

    对终点做同样的操作:在按起点排序的区间数组中进行二进制搜索。间隔结束后开始的所有间隔都是不重叠的。

    所有其他间隔要么在间隔结束之前开始,但在间隔开始之后;要么在间隔开始后结束,但在它之前开始。

    对每个间隔执行此操作,并对结果求和。一定要减半,不要把间隔数两次。如下所示:

    from bisect import bisect_left, bisect_right
    
    def count_non_overlapping_pairs(intervals):
        starts = sorted(interval[0] for interval in intervals)
        ends = sorted(interval[1] for interval in intervals)
        def count_non_overlapping(interval):
            before = bisect_left(ends, interval[0])
            after = len(ends) - bisect_right(starts, interval[1])
            return before + after
        # halve, because we don't want to count (a, b) and (b, a)
        return sum(map(count_non_overlapping, intervals)) // 2
    
    print(count_non_overlapping_pairs([
        (1, 8),
        (7, 9),
        (3, 10),
        (7, 12),
        (11, 13),
        (13, 14),
        (9, 15)
    ]))
    # prints 8
    

    总的来说,你会得到O(n-logn):两个排序,O(n)乘以两个O(logn)二进制搜索。


    现在观察到,这一半甚至不需要——如果a和b不重叠,那么如果b计算之前的间隔就足够了;a不需要计算它之后的间隔,并且丢弃一半的代码。然后简化为:

    # Counting the intervals before suffices
    def count_non_overlapping_pairs(intervals):
        ends = sorted(interval[1] for interval in intervals)
        def count_before(interval):
            return bisect_left(ends, interval[0])
        return sum(map(count_before, intervals))
    

    对称地,您也可以只计算间隔之后的间隔。

        2
  •  0
  •   Andrej Kesely    1 年前

    您可以对列表进行排序并进行二进制搜索(使用 bisect 单元

    from bisect import bisect
    
    lst = [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]
    lst.sort()
    
    out = []
    for a, b in lst:
        for t in lst[bisect(lst, (b, float("inf"))) :]:
            out.append(((a, b), t))
    print(*out, sep="\n")
    

    打印:

    ((1, 8), (9, 15))
    ((1, 8), (11, 13))
    ((1, 8), (13, 14))
    ((3, 10), (11, 13))
    ((3, 10), (13, 14))
    ((7, 9), (11, 13))
    ((7, 9), (13, 14))
    ((7, 12), (13, 14))
    

    编辑:如果您只想计数:

    lst = [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]
    lst.sort()
    
    cnt = sum(len(lst) - bisect(lst, (b, float("inf"))) for a, b in lst)
    print(cnt)
    

    打印:

    8
    
        3
  •  0
  •   ewz93    1 年前

    我能想到的加快速度的一种方法是,只要满足条件,就停止内部循环,并为所有剩余元素加一,同时确保列表按每个元组对的第二个值和第一个值排序:

    def count_non_overlapping_pairs(intervals):
        intervals = sorted(set(intervals), key=lambda x: (x[1], x[0]))
        pairs = 0
        for i in range(len(intervals)):
            for j in range(i+1, len(intervals)):
                if intervals[i][1] < intervals[j][0]:
                    pairs += len(intervals) - j
                    break
        return pairs
    

    尽管我想@Luatics使用二进制搜索的解决方案可能对大列表更快。