在前几节中,我们积累了一些代码,为了方便以后调用,并使界面更清晰,我们决定将这些代码整合,并放入一个新的Python文件中,从而创建我们自己的函数库。接下来,让我们一步步来完成这个过程。
1. 修改 get_prices
函数
首先,我们将之前的 get_prices
函数稍作修改,这里加入了 interval
变量,以便更方便地调用周K线和月K线。
def get_prices(stock_symbol, interval="1d"):
# 1wk for 1 week, 1mo for 1 month
data = yf.download(stock_symbol, interval=interval)
return data
2. 更新 get_moving_average
函数
接着,在 get_moving_average
函数中,我们加入了一个 mode
变量,使其可以输出简单移动平均(SMA)或指数移动平均(EMA)。
def get_moving_average(prices,
window_size,
mode="sma"):
if mode == "sma":
sma = prices['Adj Close'].rolling(window=window_size).mean()
return sma
elif mode == "ema":
ema = prices['Adj Close'].emw(span=window_size).mean()
return ema
else:
warnings.warn(f"{mode} is not a known mode!!")
3. 将 plot_in_chart
函数写入函数库
我们将上一节的交易视图图表作为默认股价显示器,并将其写入函数 plot_in_chart
中。
def plot_in_chart(prices,
ma_windows_size,
ma_mode="sma"):
chart = Chart()
chart.set(prices)
chart.legend(visible=True)
column_name = f'{ma_mode} {ma_windows_size}'.upper()
line = chart.create_line(column_name)
sma = pd.DataFrame({"Date": prices.index,
column_name: get_moving_average(prices, ma_windows_size, ma_mode)})
line.set(sma)
chart.show(block=True)
4. 整合所有代码
最后,将所有代码整合在一起,并显示10月均线和月K线。下面是完整的代码示例:
import warnings
import pandas as pd
import yfinance as yf
from lightweight_charts import Chart
def get_prices(stock_symbol, interval="1d"):
# 1wk for 1 week, 1mo for 1 month
data = yf.download(stock_symbol, interval=interval)
return data
def get_moving_average(prices,
window_size,
mode="sma"):
if mode == "sma":
sma = prices['Adj Close'].rolling(window=window_size).mean()
return sma
elif mode == "ema":
ema = prices['Adj Close'].emw(span=window_size).mean()
return ema
else:
warnings.warn(f"{mode} is not a known mode!!")
def plot_in_chart(prices,
ma_windows_size,
ma_mode="sma"):
chart = Chart()
chart.set(prices)
chart.legend(visible=True)
column_name = f'{ma_mode} {ma_windows_size}'.upper()
line = chart.create_line(column_name)
sma = pd.DataFrame({"Date": prices.index,
column_name: get_moving_average(prices, ma_windows_size, ma_mode)})
line.set(sma)
chart.show(block=True)
def main():
prices = get_prices("^NDX", "1mo")
plot_in_chart(prices, 10)
if __name__ == '__main__':
main()
5. 创建函数库文件
现在,我们新建一个Python文件,比如我起名utils.py,然后将所有函数放入一个新的Python文件中,以便更好地组织代码。

# utils.py
import warnings
import pandas as pd
import yfinance as yf
from lightweight_charts import Chart
def get_prices(stock_symbol, interval="1d"):
# 1wk for 1 week, 1mo for 1 month
data = yf.download(stock_symbol, interval=interval)
return data
def get_moving_average(prices, window_size, mode="sma"):
if mode == "sma":
sma = prices['Adj Close'].rolling(window=window_size).mean()
return sma
elif mode == "ema":
ema = prices['Adj Close'].emw(span=window_size).mean()
return ema
else:
warnings.warn(f"{mode} is not a known mode!!")
def plot_in_chart(prices, ma_windows_size, ma_mode="sma"):
chart = Chart()
chart.set(prices)
chart.legend(visible=True)
column_name = f'{ma_mode} {ma_windows_size}'.upper()
line = chart.create_line(column_name)
sma = pd.DataFrame({"Date": prices.index,
column_name: get_moving_average(prices, ma_windows_size, ma_mode)})
line.set(sma)
chart.show(block=True)
6. 使用自定义函数库
最后,我们在 main.py
文件中引入自定义的函数库,并调用其中的函数。
# main.py
from utils import get_prices, plot_in_chart
def main():
prices = get_prices("^NDX", "1mo")
plot_in_chart(prices, 10)
if __name__ == '__main__':
main()
通过以上步骤,我们成功地整理了代码,使其更加简洁易读。现在,我们可以更方便地调用这些函数,并且将它们用于其他项目中。
学习了,简单易懂,每一期我都有认真看。感谢大佬的分享
很高兴听到这些内容对你有帮助:)