代码之家  ›  专栏  ›  技术社区  ›  Tom de Geus

stl类向量类的稳健型caster

  •  0
  • Tom de Geus  · 技术社区  · 7 年前

    我有一个与stl向量非常相似的类(这些差异对于pybind11类型的caster并不重要,因此我将在这里忽略它们)。我已经为这门课写了一个打字机。下面给出了我的代码的一个最小的工作示例。代码下面包含一个显示问题的示例。

    问题是我的施法器很有限(因为我用过 py::array_t )中。原则上,接口接受元组、列表和numpy数组。但是,当我基于typename重载时,对于输入的元组和列表,接口将失败(只需选择第一个重载,即使它是不正确的类型)。

    我的问题是: 我怎样才能使类型脚轮更坚固?是否有一种有效的方式重复使用尽可能多的现有类型的脚轮STL向量类?

    C++代码(包括PYBID11接口)

    #include <iostream>
    #include <vector>
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <pybind11/numpy.h>
    
    namespace py = pybind11;
    
    // class definition
    // ----------------
    
    template<typename T>
    class Vector
    {
    private:
    
      std::vector<T> mData;
    
    public:
    
      Vector(){};
      Vector(size_t N) { mData.resize(N); };
    
      auto   data ()       { return mData.data (); };
      auto   data () const { return mData.data (); };
      auto   begin()       { return mData.begin(); };
      auto   begin() const { return mData.begin(); };
      auto   end  ()       { return mData.end  (); };
      auto   end  () const { return mData.end  (); };
      size_t size () const { return mData.size (); };
    
      std::vector<size_t> shape()   const { return std::vector<size_t>(1, mData.size()); }
      std::vector<size_t> strides() const { return std::vector<size_t>(1, sizeof(T)   ); }
    
      template<typename It> static Vector<T> Copy(It first, It last) {
        Vector out(last-first);
        std::copy(first, last, out.begin());
        return out;
      }
    };
    
    // C++ functions: overload based on type
    // -------------------------------------
    
    Vector<int>    foo(const Vector<int>    &A){ std::cout << "int"    << std::endl; return A; }
    Vector<double> foo(const Vector<double> &A){ std::cout << "double" << std::endl; return A; }
    
    // pybind11 type caster
    // --------------------
    
    namespace pybind11 {
    namespace detail {
    
    template<typename T> struct type_caster<Vector<T>>
    {
    public:
    
      PYBIND11_TYPE_CASTER(Vector<T>, _("Vector<T>"));
    
      bool load(py::handle src, bool convert)
      {
        if ( !convert && !py::array_t<T>::check_(src) ) return false;
    
        auto buf = py::array_t<T, py::array::c_style | py::array::forcecast>::ensure(src);
        if ( !buf ) return false;
    
        auto rank = buf.ndim();
        if ( rank != 1 ) return false;
    
        value = Vector<T>::Copy(buf.data(), buf.data()+buf.size());
    
        return true;
      }
    
      static py::handle cast(const Vector<T>& src, py::return_value_policy policy, py::handle parent)
      {
        py::array a(std::move(src.shape()), std::move(src.strides()), src.data());
    
        return a.release();
      }
    };
    
    }} // namespace pybind11::detail
    
    // Python interface
    // ----------------
    
    PYBIND11_MODULE(example,m)
    {
      m.doc() = "pybind11 example plugin";
    
      m.def("foo", py::overload_cast<const Vector<int   > &>(&foo));
      m.def("foo", py::overload_cast<const Vector<double> &>(&foo));
    }
    

    例子

    import numpy as np
    import example
    
    print(example.foo((1,2,3)))
    print(example.foo((1.5,2.5,3.5)))
    
    print(example.foo(np.array([1,2,3])))
    print(example.foo(np.array([1.5,2.5,3.5])))
    

    输出:

    int
    [1 2 3]
    int
    [1 2 3]
    int
    [1 2 3]
    double
    [1.5 2.5 3.5]
    
    1 回复  |  直到 7 年前
        1
  •  0
  •   Tom de Geus    7 年前

    一个非常简单的解决方案是专业化 pybind11::detail::list_caster 是的。类型施法器现在变得像

    namespace pybind11 {
    namespace detail {
    
    template <typename Type> struct type_caster<Vector<Type>> : list_caster<Vector<Type>, Type> { };
    
    }} // namespace pybind11::detail
    

    请注意,这确实需要 Vector 有方法:

    • clear()
    • push_back(const Type &value)
    • reserve(size_t n) (在测试中似乎是可选的)

    完整示例

    #include <iostream>
    #include <vector>
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <pybind11/numpy.h>
    
    namespace py = pybind11;
    
    // class definition
    // ----------------
    
    template<typename T>
    class Vector
    {
    private:
    
      std::vector<T> mData;
    
    public:
    
      Vector(){};
      Vector(size_t N) { mData.resize(N); };
    
      auto   data ()       { return mData.data (); };
      auto   data () const { return mData.data (); };
      auto   begin()       { return mData.begin(); };
      auto   begin() const { return mData.begin(); };
      auto   end  ()       { return mData.end  (); };
      auto   end  () const { return mData.end  (); };
      size_t size () const { return mData.size (); };
    
      void push_back(const T &value) { mData.push_back(value); }
      void clear() { mData.clear(); }
      void reserve(size_t n) { mData.reserve(n); }
    
      std::vector<size_t> shape()   const { return std::vector<size_t>(1, mData.size()); }
      std::vector<size_t> strides() const { return std::vector<size_t>(1, sizeof(T)   ); }
    
      template<typename It> static Vector<T> Copy(It first, It last) {
        printf("Vector<T>::Copy %s\n", __PRETTY_FUNCTION__);
        Vector out(last-first);
        std::copy(first, last, out.begin());
        return out;
      }
    };
    
    // C++ functions: overload based on type
    // -------------------------------------
    
    Vector<int>    foo(const Vector<int>    &A){ std::cout << "int"    << std::endl; return A; }
    Vector<double> foo(const Vector<double> &A){ std::cout << "double" << std::endl; return A; }
    
    // pybind11 type caster
    // --------------------
    
        namespace pybind11 {
        namespace detail {
    
        template <typename Type> struct type_caster<Vector<Type>> : list_caster<Vector<Type>, Type> { };
    
        }} // namespace pybind11::detail
    
    // Python interface
    // ----------------
    
    PYBIND11_MODULE(example,m)
    {
      m.doc() = "pybind11 example plugin";
    
      m.def("foo", py::overload_cast<const Vector<double> &>(&foo));
      m.def("foo", py::overload_cast<const Vector<int   > &>(&foo));
    
    }
    
    推荐文章