代码之家  ›  专栏  ›  技术社区  ›  mike

从sympy生成优化的八度码

  •  2
  • mike  · 技术社区  · 7 年前

    我有一些巨大的矩阵要导出,其中只包含sin(q)、cos(q)和这些的和/mul。sympy可以计算这个并导出到八度-这太棒了! 不过,既然这些都是大婚姻,我需要一些 cse 或者更好的专门优化。

    我发现 this great tutorial for C code with cse . 所以我尝试移植它,但在打印机类的一些细节上失败了。我认为这是一个无限递归,导致 RecursionError: maximum recursion depth exceeded .

    我的问题是:有没有一个例子,辛八度码和优化是如何结合在一起的?或者有人能帮我把随车随车行驶吗?

    import sympy as sp
    t = sp.symbols('t')
    
    from sympy.printing.octave import OctaveCodePrinter
    from sympy.printing.octave import Assignment
    class matlabMatrixPrinter(OctaveCodePrinter):
    
        def _print_ImmutableDenseMatrix(self, expr):
            sub_exprs, simplified = sp.cse(expr)
            lines = []
            for var, sub_expr in sub_exprs:
                lines.append( self._print(Assignment(var, sub_expr)))
            M = sp.MatrixSymbol('M', *expr.shape)
            return '\n'.join(lines) + '\n' + self._print(Assignment(M, expr))
    
    tmp = sp.sin(t)+sp.sin(t)**2
    tmp = sp.ImmutableDenseMatrix((1,1,tmp))
    se, ex = sp.cse(tmp)
    print((ex,se))
    print('\n')
    #tmp = sp.Matrix([2*sp.sin(t),sp.sin(t)])
    p = matlabMatrixPrinter()
    print(p.doprint(tmp))
    

    编辑:我现在知道了,RETURN语句中的第二个赋值也会运行函数print-immutabledensematrix,所以这最终是一个递归。我不知道为什么在本教程中这对于C代码来说没有问题,但是这里它是递归运行的。这似乎只是一个简化表达式本身的问题,它不能调用self.\u print函数。也许有人知道一些关于这些打印机的知识,以及应该如何打印矩阵和这项任务?你说什么?

    1 回复  |  直到 7 年前
        1
  •  0
  •   mike    7 年前

    经过大量的实验,我觉得我仍然只是理解了代码打印机故意工作流程背后的一些意图。然而,我写了一个子类,它完全按照我的意愿来做(小心,因为这可能不适用于除母系以外的任何事物!).

    也许这对某人有用!对我来说,它肯定能证明sympy是一种工作工具,因为除此之外 sin 评估绝对是不可行的代码。

    我仍然对某人的评论和想法非常感兴趣,谁知道应该如何实现这些特性呢!

    import sympy as sp
    t = sp.symbols('t')
    from sympy.printing.octave import OctaveCodePrinter
    from sympy.printing.octave import Assignment
    class matlabMatrixPrinter(OctaveCodePrinter):
        def print2(self,expr_list,names=None):
            sub_exprs, simplified = sp.cse(expr_list)
            lines = []
            for var, sub_expr in sub_exprs:
                lines.append(self._print(Assignment(var, sub_expr)))
            lines.append('')
            for k,expr in enumerate(simplified):
                if names:
                    M = sp.MatrixSymbol(names[k],*expr.shape)
                else:
                    M = sp.MatrixSymbol('M{k}'.format(k=k), *expr.shape)
                lines.append(self._print(Assignment(M,expr)))
            result = ''
            return '\n'.join(lines)
    
    tmp = sp.Matrix([sp.sin(t)+sp.sin(t)**2 ])
    tmp2 = sp.Matrix([sp.sin(t),sp.cos(t),2*sp.sin(t),sp.cos(t)**2])
    
    p = matlabMatrixPrinter()
    #print(p.print2([tmp,tmp2]))
    print(p.print2([tmp,tmp2],['scalar_matrix','matrix']));
    

    这给出了预期的输出:

    x0 = sin(t);
    x1 = cos(t);
    scalar_matrix = x0.^2 + x0;
    matrix = [x0; x1; 2*x0; x1.^2];
    

    如上所述:自担风险使用:)

    推荐文章