
How to iterate over a DataFrame with matplotlib to chart consecutive pairs of columns

    3kstc · 7 years ago

    I have a DataFrame, df.

    The current output is [2781 rows x 10 columns], using a date range from start_date = '2006-01-01' to end_date = '2016-12-31'. print(df) shows the DataFrame as follows:

    Current output:

                ANZ Price  ANZ 3 day SMA  CBA Price  CBA 3 day SMA  MQG Price   MQG 3 day SMA  NAB Price  NAB 3 day SMA  WBC Price  WBC 3 day SMA 
    Date
    2006-01-02  23.910000            NaN  42.569401            NaN  66.558502             NaN  30.792999            NaN  22.566401            NaN
    2006-01-03  24.040001            NaN  42.619099            NaN  66.086403             NaN  30.935699            NaN  22.705400            NaN
    2006-01-04  24.180000      24.043334  42.738400      42.642300  66.587997       66.410967  31.078400      30.935699  22.784901      22.685567 
    2006-01-05  24.219999      24.146667  42.708599      42.688699  66.558502       66.410967  30.964300      30.992800  22.794800      22.761700
    ...               ...             ...       ...            ...        ...            ...         ...            ...        ...            ...
    2016-12-27   87.346667     30.670000  30.706666      32.869999  32.729999       87.346667  30.670000      30.706666  32.869999      32.729999
    2016-12-28   87.456667     31.000000  30.773333      32.980000  32.829999       87.456667  31.000000      30.773333  32.980000      32.829999
    2016-12-29   87.520002     30.670000  30.780000      32.599998  32.816666       87.520002  30.670000      30.780000  32.599998      32.816666
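The "3 day SMA" columns are trailing rolling means, which is why the first two rows are NaN: a 3-row window is not full until the third trading day. The sketch below checks this against the first three ANZ closes copied from the table; it uses only those three values, not the full download.

```python
import pandas as pd

# First three ANZ closing prices, copied from the table above
prices = pd.Series(
    [23.910000, 24.040001, 24.180000],
    index=pd.to_datetime(["2006-01-02", "2006-01-03", "2006-01-04"]),
    name="ANZ Price",
)

# Trailing 3-day simple moving average: NaN until the window holds 3 values
sma = prices.rolling(window=3).mean()

# The third value is (23.91 + 24.040001 + 24.18) / 3, about 24.043334,
# which matches the "ANZ 3 day SMA" entry for 2006-01-04
print(sma)
```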
    

    #!/usr/bin/python3
    from pandas_datareader import data
    import pandas as pd
    import itertools as it
    import os
    import numpy as np
    import fix_yahoo_finance as yf
    import matplotlib.pyplot as plt
    yf.pdr_override()
    
    stock_list = sorted(["ANZ.AX", "WBC.AX", "MQG.AX", "CBA.AX", "NAB.AX"])
    number_of_decimal_places = 8
    moving_average_period = 3
    
    def get_moving_average(df, stock_name):
        # Trailing rolling mean over the last `moving_average_period` rows
        df2 = df.rolling(window=moving_average_period).mean()
        df2.rename(columns={stock_name: stock_name.replace("Price", str(moving_average_period) + " day SMA")}, inplace=True)
        # Both frames share the same index; join_axes was removed in pandas 1.0
        df = pd.concat([df, df2], axis=1)
        return df
    
    
    # Function to get the closing price of the individual stocks
    # from the stock_list list
    def get_closing_price(stock_name, specific_close):
        symbol = stock_name
        start_date = '2006-01-01'
        end_date = '2016-12-31'
        df = data.get_data_yahoo(symbol, start_date, end_date)
        sym = symbol + " "
        print(sym * 10)
        df = df.drop(['Open', 'High', 'Low', 'Adj Close', 'Volume'], axis=1)
    
        df = df.rename(columns={'Close': specific_close})
    
        # https://stackoverflow.com/questions/16729483/converting-strings-to-floats-in-a-dataframe
        # df[specific_close] = df[specific_close].astype('float64')
        # print(type(df[specific_close]))
        return df
    
    
    # Creates a big DataFrame with all the stock's Closing
    # Price returns the DataFrame
    def get_all_close_prices(directory):
        count = 0
        for stock_name in stock_list:
            specific_close = stock_name.replace(".AX", "") + " Price"
            if not count:
                prev_df = get_closing_price(stock_name, specific_close)
                prev_df = get_moving_average(prev_df,  specific_close)
            else:
                new_df = get_closing_price(stock_name, specific_close)
                new_df = get_moving_average(new_df, specific_close)
                # https://stackoverflow.com/questions/11637384/pandas-join-merge-concat-two-dataframes
                prev_df = prev_df.join(new_df)
            count += 1
        # prev_df.to_csv(directory)
    
        df = pd.DataFrame(prev_df, columns=list(prev_df))
        df = df.apply(pd.to_numeric)
        convert_df_to_csv(df, directory)
        return df
    
    
    def convert_df_to_csv(df, directory):
        df.to_csv(directory)
    
    def main():
        # FINDS THE CURRENT DIRECTORY AND CREATES THE CSV TO DUMP THE DF
        csv_in_current_directory = os.getcwd() + "/stock_output.csv"
    
        csv_in_current_directory_dow_distribution = os.getcwd() + "/dow_distribution.csv"
        # FUNCTION THAT GETS ALL THE CLOSING PRICES OF THE STOCKS
        # AND RETURNS IT AS ONE COMPLETE DATAFRAME
        df = get_all_close_prices(csv_in_current_directory)    
        print(df)
    
    
    # Main line of code
    if __name__ == "__main__":
        main()
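Instead of the count-based join loop in get_all_close_prices, the combined frame can also be built by collecting one two-column frame per stock and concatenating once along the columns. This is a sketch under assumptions: make_fake_closes is a hypothetical stand-in for the Yahoo download (synthetic prices so it runs offline), and price_and_sma is an illustrative helper, not part of the question's code.

```python
import numpy as np
import pandas as pd

stock_list = sorted(["ANZ.AX", "WBC.AX", "MQG.AX", "CBA.AX", "NAB.AX"])
moving_average_period = 3

def make_fake_closes(n_days=10, seed=0):
    """Hypothetical stand-in for the Yahoo download: synthetic daily closes."""
    rng = np.random.default_rng(seed)
    index = pd.bdate_range("2006-01-02", periods=n_days, name="Date")
    return pd.Series(20 + rng.normal(0, 0.5, n_days).cumsum(), index=index)

def price_and_sma(close, ticker):
    """Return a frame with '<TICKER> Price' and '<TICKER> N day SMA' columns."""
    name = ticker.replace(".AX", "")
    return pd.DataFrame({
        f"{name} Price": close,
        f"{name} {moving_average_period} day SMA": close.rolling(moving_average_period).mean(),
    })

frames = [price_and_sma(make_fake_closes(seed=i), t) for i, t in enumerate(stock_list)]

# A single concat along the columns; all frames share the same DatetimeIndex,
# so rows line up by date
df = pd.concat(frames, axis=1)
print(df.shape)
```

Building a list and concatenating once avoids the special case for the first stock and keeps the column order the same as stock_list.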
    

    Question:

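    From this df I would like to make one multi-line chart per stock, each showing several lines (the price and its SMA). How can I do this with matplotlib? Can it be done with a for loop that saves each chart as the loop iterates? If so, how?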

    2 answers
    Answer 1 · Stef · 7 years ago

    First, import pyplot: import matplotlib.pyplot as plt.

    Individual plots

    df.plot(y=[0,1])
    df.plot(y=[2,3])
    df.plot(y=[4,5])
    df.plot(y=[6,7])
    df.plot(y=[8,9]) 
    
    plt.show()
    

    You can also save each plot inside a loop:

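    # Step through the columns in (price, SMA) pairs: 0-1, 2-3, ...
    for i in range(0, len(df.columns), 2):
        ax = df.plot(y=[i, i + 1])
        ax.figure.savefig('{}.png'.format(i))
        plt.close(ax.figure)  # close each figure so they do not pile up in memory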
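A variant of the loop that names each file after the ticker instead of the column index. This is a sketch: it builds a small synthetic frame with the question's column layout (price column followed by its SMA column), uses the non-interactive Agg backend so it runs without a display, and writes into a temporary directory chosen for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend: render to files without opening windows
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic frame with the same layout as the question: price, then SMA, per ticker
index = pd.bdate_range("2006-01-02", periods=20, name="Date")
df = pd.DataFrame(index=index)
for i, name in enumerate(["ANZ", "CBA", "MQG"]):
    close = pd.Series(20.0 + i + np.arange(20) * 0.1, index=index)
    df[f"{name} Price"] = close
    df[f"{name} 3 day SMA"] = close.rolling(3).mean()

out_dir = tempfile.mkdtemp()
saved = []
for i in range(0, len(df.columns), 2):
    pair = list(df.columns[i:i + 2])   # e.g. ['ANZ Price', 'ANZ 3 day SMA']
    ticker = pair[0].split()[0]        # e.g. 'ANZ'
    ax = df.plot(y=pair, title=f"{ticker} price vs. SMA")
    path = os.path.join(out_dir, f"{ticker}.png")
    ax.figure.savefig(path)
    plt.close(ax.figure)               # free the figure before the next iteration
    saved.append(path)

print(saved)
```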
    

    Subplots

    fig, axes = plt.subplots(nrows=2, ncols=3)
    
    df.plot(ax=axes[0,0],y=[0,1])
    df.plot(ax=axes[0,1],y=[2,3])
    df.plot(ax=axes[0,2],y=[4,5])
    df.plot(ax=axes[1,0],y=[6,7])
    df.plot(ax=axes[1,1],y=[8,9])
    
    plt.show()  
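The 2 x 3 grid above leaves the sixth cell empty. A sketch that sizes the grid from the number of stocks and hides any leftover cells; the ticker list and synthetic prices are assumptions for illustration, and squeeze=False keeps axes two-dimensional even when only one row is needed.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend for a script or test run
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

index = pd.bdate_range("2006-01-02", periods=20, name="Date")
tickers = ["ANZ", "CBA", "MQG", "NAB", "WBC"]
df = pd.DataFrame(index=index)
for i, name in enumerate(tickers):
    close = pd.Series(20.0 + i + np.arange(20) * 0.1, index=index)
    df[f"{name} Price"] = close
    df[f"{name} 3 day SMA"] = close.rolling(3).mean()

ncols = 3
nrows = math.ceil(len(tickers) / ncols)  # 5 tickers -> 2 rows
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 6), squeeze=False)
flat_axes = axes.flatten()

# One stock per cell, selected by column label rather than position
for ax, name in zip(flat_axes, tickers):
    df.plot(ax=ax, y=[f"{name} Price", f"{name} 3 day SMA"], title=name)

# Hide the grid cells that have no stock assigned
for ax in flat_axes[len(tickers):]:
    ax.set_visible(False)

fig.tight_layout()
```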
    

    See https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.plot.html for the options available to customize the plots.
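A few of those options applied to one stock pair, as a sketch on synthetic data: style, grid, figsize and title are DataFrame.plot keyword arguments, while the axis labels are set on the Axes object that df.plot returns.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for a script or test run
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

index = pd.bdate_range("2006-01-02", periods=20, name="Date")
close = pd.Series(24.0 + np.arange(20) * 0.05, index=index)
df = pd.DataFrame({"ANZ Price": close, "ANZ 3 day SMA": close.rolling(3).mean()})

ax = df.plot(
    y=["ANZ Price", "ANZ 3 day SMA"],
    style=["-", "--"],   # solid line for the price, dashed for the SMA
    grid=True,
    figsize=(10, 4),
    title="ANZ Price vs. ANZ 3 day SMA",
)
ax.set_xlabel("Time")
ax.set_ylabel("Price ($)")
```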

    Answer 2 · 3kstc · 7 years ago

    The best approach I found is a function that adapts to the sizes of the lists (stock_list and moving_average_period_list):

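    import matplotlib.pyplot as plt

    # Assumes the globals stock_list and moving_average_period_list, and a df whose
    # columns are laid out per stock as: price, then one SMA column per period.
    def generate_SMA_graphs(df):
        columnNames = list(df.columns)
        print("CN:\t", columnNames)
        print(len(columnNames))

        for count, stock in enumerate(stock_list):
            # Offset of this stock's price column (price + one column per SMA period)
            stock_iter = count * (len(moving_average_period_list) + 1)
            sma_iter = stock_iter + 1
            for moving_average_period in moving_average_period_list:
                # df.plot creates its own figure; keep a handle so it can be closed
                ax = df.plot(y=[columnNames[stock_iter], columnNames[sma_iter]])
                fig = ax.figure
                ax.set_xlabel('Time')
                ax.set_ylabel('Price ($)')
                graph_title = columnNames[stock_iter] + " vs. " + columnNames[sma_iter]
                ax.set_title(graph_title)
                ax.grid(True)
                fig.savefig(graph_title.replace(" ", "") + ".png")
                print("\t\t\t\tCompleted: ", graph_title)
                plt.close(fig)
                sma_iter += 1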
    

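    With the code above, no matter how long the lists are (the stock list or the SMA period list), the function generates one chart per (stock, SMA) pair, comparing the stock's raw price with each of its SMAs.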
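The function above relies on column positions (stock_iter and sma_iter arithmetic), so it breaks if the column order changes. A hypothetical name-based variant pairs each "N day SMA" column with its stock's "Price" column by parsing the column label. This is a sketch on synthetic data with assumed periods 3, 5 and 10; generate_sma_graphs_by_name is an illustrative name, not the answerer's code.

```python
import os
import re
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend: render to files without opening windows
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic frame: each ticker has a price column plus several SMA columns
index = pd.bdate_range("2006-01-02", periods=30, name="Date")
df = pd.DataFrame(index=index)
for i, name in enumerate(["ANZ", "CBA"]):
    close = pd.Series(20.0 + i + np.arange(30) * 0.1, index=index)
    df[f"{name} Price"] = close
    for period in (3, 5, 10):
        df[f"{name} {period} day SMA"] = close.rolling(period).mean()

def generate_sma_graphs_by_name(df, out_dir):
    """Hypothetical variant: pair each '<T> N day SMA' column with '<T> Price'."""
    sma_pattern = re.compile(r"^(\S+) (\d+) day SMA$")
    written = []
    for sma_col in df.columns:
        match = sma_pattern.match(sma_col)
        if not match:
            continue  # skip the price columns
        price_col = f"{match.group(1)} Price"
        title = f"{price_col} vs. {sma_col}"
        ax = df.plot(y=[price_col, sma_col], grid=True, title=title)
        ax.set_xlabel("Time")
        ax.set_ylabel("Price ($)")
        path = os.path.join(out_dir, title.replace(" ", "") + ".png")
        ax.figure.savefig(path)
        plt.close(ax.figure)  # avoid accumulating open figures
        written.append(path)
    return written

paths = generate_sma_graphs_by_name(df, tempfile.mkdtemp())
print(len(paths))
```

Matching on labels means the function keeps working if stocks or SMA periods are added, removed, or reordered, as long as the column names follow the "<TICKER> Price" / "<TICKER> N day SMA" convention.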