代码之家  ›  专栏  ›  技术社区  ›  backtrack

sklearn管道正确使用

  •  2
  • backtrack  · 技术社区  · 8 年前

    我在python中有一个数据帧,它有一个名为“datetime”的日期时间字段。使用管道和FeatureUnion,我试图提取day、month、weekday和isBusinessday。为了提取这些特性,我编写了自定义代码。

    class itemselector(BaseEstimator, TransformerMixin):
        def __init__(self, key):
            self.key = key
    
        def transform(self, X):
            return X[self.key]
    
        def fit(self, X, y=None):
            return self
    
    
        f_df = Pipeline([
    
           ('union', FeatureUnion([
        ('date', Pipeline([
            ('sitem', itemselector('pickup_datetime')),
            ('sday', Extract_date()),
        ])),
        ('month', Pipeline([
            ('sitem', itemselector('pickup_datetime')),
            ('smonth', Extract_month()),
        ])),
    ])),
    
        ])
    

    当我运行这段代码时,我会得到一个列表作为输出。例如:

    df = f_df.fit_transform(df_train[:5])
    

    输出:

    [14 12 19  6 26  3  6  1  4  3]  // it has both day and month.  it is not expected output 
    

    但我是一个日月分离的人。我该怎么做?我的代码出了什么问题?谁能帮我找到它吗?

    更新

    总结一下我的问题,我得到了输出形状 (10,) (5,2)

    更新1

    class Extract_date(BaseEstimator, TransformerMixin):
        def fit(self, X):
            print('one')
            return self
    
        def transform(self, X):
            return X.apply(lambda y: y.day)
    
    
    class Extract_month(BaseEstimator, TransformerMixin):
        def fit(self, X, **atr):
            print('two')
            return self
    
        def transform(self, X):
            return X.apply(lambda y: y.month)
    
    1 回复  |  直到 7 年前
        1
  •  1
  •   Vivek Kumar    8 年前

    好的 Extract_month Extract_date

    reshape(-1,1) 为了这个。

    因此,改变你的方法如下:

    class Extract_date(BaseEstimator, TransformerMixin):
        ...
        ...
    
        def transform(self, X):
            return X.apply(lambda y: y.day).values.reshape(-1,1)
    
    
    class Extract_month(BaseEstimator, TransformerMixin):
        ...
        ...
    
        def transform(self, X):
            return X.apply(lambda y: y.month).values.reshape(-1,1)