代码之家  ›  专栏  ›  技术社区  ›  kye

最小数值。广告示例无法编译

  •  1
  • kye  · 技术社区  · 8 年前

    我试图从Numeric.AD编译以下最小示例:

    import Numeric.AD 
    timeAndGrad f l = grad f l
    main = putStrLn "hi"
    

    我遇到了这个错误:

    test.hs:3:24:
        Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
                                           s a)
                                      -> Numeric.AD.Internal.Reverse.Reverse s a’
                    with actual type ‘t’
          because type variable ‘s’ would escape its scope
        This (rigid, skolem) type variable is bound by
          a type expected by the context:
            Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
            f (Numeric.AD.Internal.Reverse.Reverse s a)
            -> Numeric.AD.Internal.Reverse.Reverse s a
          at test.hs:3:19-26
        Relevant bindings include
          l :: f a (bound at test.hs:3:15)
          f :: t (bound at test.hs:3:13)
          timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
        In the first argument of ‘grad’, namely ‘f’
        In the expression: grad f l
    

    你知道为什么会这样吗?从前面的例子来看,我认为这是“扁平化” grad 这是一种类型:

    grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

    但我实际上需要在我的代码中这样做。事实上,这是无法编译的最小示例。我想做的更复杂的事情是这样的:

    example :: SomeType
    example f x args = (do stuff with the gradient and gradient "function")
        where gradient = grad f x
              gradientFn = grad f
              (other where clauses involving gradient and gradient "function")
    

    这里有一个稍微复杂一些的版本,带有类型签名,可以编译。

    {-# LANGUAGE RankNTypes #-}
    
    import Numeric.AD 
    import Numeric.AD.Internal.Reverse
    
    -- compiles but I can't figure out how to use it in code
    grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
    grad2 f l = grad f l
    
    -- compiles with the right type, but the resulting gradient is all 0s...
    grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
    grad2' f l = grad f' l
           where f' = Lift . f . extractAll
           -- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work
    
    extractAll :: [Reverse t a] -> [a]
    extractAll xs = map extract xs
               where extract (Lift x) = x -- non-exhaustive pattern match
    
    dist :: (Show a, Num a, Floating a) => [a] -> a
    dist [x, y] = sqrt(x^2 + y^2)
    
    -- incorrect output: [0.0, 0.0]
    main = putStrLn $ show $ grad2' dist [1,2]
    

    然而,我不知道如何使用第一个版本, grad2 ,在代码中,因为我不知道如何处理 Reverse s a 第二版本, grad2' ,具有正确的类型,因为我使用内部构造函数 Lift 创建 反向s a ,但我不能理解内部(特别是参数 s )工作,因为输出梯度都是0。使用其他构造函数 Reverse (此处未显示)也会产生错误的梯度。

    或者,是否有使用过 ad 密码我认为我的用例非常常见。

    1 回复  |  直到 8 年前
        1
  •  2
  •   leftaroundabout    8 年前

    具有 where f' = Lift . f . extractAll 本质上,您创建了一个自动微分基础类型的后门,它丢弃了所有导数,只保留常量值。如果你用这个 grad ,您得到的结果为零也就不足为奇了!

    明智的方法是使用 毕业生 事实上:

    dist :: Floating a => [a] -> a
    dist [x, y] = sqrt $ x^2 + y^2
    -- preferrable is of course `dist = sqrt . sum . map (^2)`
    
    main = print $ grad dist [1,2]
    -- output: [0.4472135954999579,0.8944271909999159]
    

    你不需要知道任何更复杂的东西来使用自动微分。只要你能区分 Num Floating -多态函数,一切都将按原样工作。如果需要区分作为参数传入的函数,则需要使该参数秩-2多态(另一种选择是切换到 ad 功能,但我敢说这不太优雅,也不会给你带来太多好处)。

    {-# LANGUAGE Rank2Types, UnicodeSyntax #-}
    
    mainWith :: (∀n . Floating n => [n] -> n) -> IO ()
    mainWith f = print $ grad f [1,2]
    
    main = mainWith dist
    
    推荐文章