代码之家  ›  专栏  ›  技术社区  ›  Attack68

不带副本的自定义类型的Rust ndarray点

  •  0
  • Attack68  · 技术社区  · 3 年前

    在里面 python NumPy 我可以使用运算符重载创建自定义数据类型,并使用线性代数函数(例如 matmul , dot , tensordot 等等)

    import numpy as np
    
    
    class MyType:
        def __init__(self, a, b):
            self.a, self.b = a, b
    
        def __add__(self, other):
            return MyType(a=self.a + other.a, b=self.b + other.b)
    
        def __mul__(self, other):
            return MyType(a=self.a + other.a, b=self.b * other.b)
    
        def __repr__(self):
            return "<{}, {}>".format(self.a, self.b)
    
    
    arr1 = np.array([MyType("A", 2.0), MyType("B", 3.0)])
    arr2 = np.array([MyType("C", 2.0), MyType("D", 4.0)])
    
    >>> arr1 + arr2
    [<AC, 4.0> <BD, 7.0>]
    >>> arr1 * arr2
    [<AC, 4.0> <BD, 12.0>]
    >>> arr1.dot(arr2))
    <ACBD, 16.0>
    

    在里面 Rust 我可以创建相同的重载,但 ndarray 大木箱 Dot 函数不起作用,因为 MyType 不满足特征界限。

    use std::ops;
    use ndarray::{arr1};
    
    #[derive(Clone, Debug)]
    pub struct MyType {
        a : String,
        b : f64
    }
    
    impl ops::Add<MyType> for MyType {
        type Output = MyType;
        fn add(self, other: MyType) -> MyType {
            MyType {a: [self.a, other.a].join(""), b: self.b + other.b}
        }
    }
    
    impl ops::Mul<MyType> for MyType {
        type Output = MyType;
        fn mul(self, other: MyType) -> MyType {
            MyType {a: [self.a, other.a].join(""), b: self.b * other.b}
        }
    }
    
    fn main() {
        let arr_1 = arr1(&[
            MyType {a: "A".to_string(), b: 2.0}, MyType {a: "B".to_string(), b: 3.0}
        ]);
        let arr_2 = arr1(&[
            MyType {a: "C".to_string(), b: 2.0}, MyType {a: "D".to_string(), b: 4.0}
        ]);
        >>> println!("{:?}", arr_1.clone() + arr_2.clone());
        [MyType { a: "AC", b: 4.0 }, MyType { a: "BD", b: 7.0 }]
        >>> println!("{:?}", arr_1 * arr_2);
        [MyType { a: "AC", b: 4.0 }, MyType { a: "BD", b: 12.0 }]
        >>> println!("{:?}", arr_1.dot(&arr_2));
        error[E0599]: the method `dot` exists for struct `ArrayBase<OwnedRepr<MyType>, Dim<[usize; 1]>>`, but its trait bounds were not satisfied
    }
    

    重要的特征似乎是 pub trait LinalgScalar: 'static + Copy + Zero + One + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> + Div<Output = Self> { }

    我可以实现所有这些,除了 Copy ,因为我不确定的属性 MyType 可以复制,我不相信我可以允许静态的一生。这是否意味着我需要实现我自己的特定形式的 功能无疑会比默认功能效率低得多?

    0 回复  |  直到 3 年前