我正在研究一个涉及正交化多组向量的问题。最后,为了我的目的,它需要是jax.grad,因为它也将进入梯度下降算法。
一般来说,函数看起来像这样
输入:
A(…,M,N)
-
输入张量,其中第一个未知数量的索引枚举了我们想要QR反编译的不同矩阵。对于每个矩阵,列是我们想要制作正交的向量(与标准输入方案相同)
退货
Q(…,M,N)
最初我只打算使用jax.np.qr函数和Q矩阵。但是,我使用这种分解计算的量对Q矩阵列的符号非常敏感。输出是输入的平滑函数非常重要。(结果是我要进行傅里叶变换)
我的问题源于(我认为)qr函数输出不是输入的平滑函数。输入矩阵的微小变化会导致Q矩阵彼此非常接近,直到一个符号。但这不会给我带来我所期望的行为。(我正在计算wannier函数,这些函数是与规范相关的量,这些符号是规范选择的一部分)
通过查阅文档,我没有看到任何明显可以解决这个问题的关键字。我想知道是否有人以前做过类似的事情,或者知道我在哪里可以找到这样的东西。
我可以使用手写的标准Gram-Schmidt程序来工作,但这个函数将在相当大的矩阵网格上被调用很多次。因此,如果一个主要的库已经实现了这样的东西,它将比我能写的任何糟糕的代码都要好得多。。。