1000字范文,内容丰富有趣,学习的好帮手!
1000字范文 > python股票回测源码_股票量化交易回测框架pyalgotrade源码阅读(一)

python股票回测源码_股票量化交易回测框架pyalgotrade源码阅读(一)

时间:2020-06-04 16:30:40

相关推荐

python股票回测源码_股票量化交易回测框架pyalgotrade源码阅读(一)

PyAlgoTrade是什么呢?

一个股票量化交易的策略回测框架。

而作者的说明如下。

To make it easy to backtest stock trading strategies.

简单的来说,是一个用于验证自己交易策略的框架。

适用以下场景:

我有个前无古人后无来者的想法,我觉得我按照这个想法去买股票稳赚不赔,但是为了稳妥起见,我需要测试一下这个我的这个想法到底用没有用,怎么测试呢?

大概下面两种方法

一:弄个模拟交易的软件,每天按照自己的想法买入卖出,然后看看一个月或者一年后的收益如何。

优点:更贴近现实,至少当下的现实

缺点:测试周期大,数据有限

二:我相信我的这个想法不是针对现在或者未来有用,甚至是在以前应该也是起作用的,那么我可以将历史数据调出来,用于测试,看看在历史行情中收益如何。

优点:数据充分,可以反复测试。

缺点:可能不能贴近现实

而pyalgotrade就是为了提供给使用者基于历史数据回测的框架,即为了让你更好的使用上述的第二种方法。

注:无论怎么测,肯定都有偏差的, 因为都是猜,就像×××,你算好了各种概率,想好了各种策略,但是你能保证的只是你赢钱的概率大一些,而不是必赢,因为在没有欺诈的情况下,未来是不可测,也不能确定的,谁也不能预知未来~吧~

文章目录官方示例

设计模式之观察者模式

源码解析

官方示例

sma_crossover.py文件frompyalgotradeimportstrategy

frompyalgotrade.technicalimportma

frompyalgotrade.technicalimportcross

classSMACrossOver(strategy.BacktestingStrategy):

def__init__(self,feed,instrument,smaPeriod):

super(SMACrossOver,self).__init__(feed)

self.__instrument=instrument

self.__position=None

#We'lluseadjustedclosevaluesinsteadofregularclosevalues.

self.setUseAdjustedValues(True)

self.__prices=feed[instrument].getPriceDataSeries()

self.__sma=ma.SMA(self.__prices,smaPeriod)

defgetSMA(self):

returnself.__sma

defonEnterCanceled(self,position):

self.__position=None

defonExitOk(self,position):

self.__position=None

defonExitCanceled(self,position):

#Iftheexitwascanceled,re-submitit.

self.__position.exitMarket()

defonBars(self,bars):

#Ifapositionwasnotopened,checkifweshouldenteralongposition.

ifself.__positionisNone:

ifcross.cross_above(self.__prices,self.__sma)>0:

shares=int(self.getBroker().getCash()*0.9/bars[self.__instrument].getPrice())

#Enterabuymarketorder.Theorderisgoodtillcanceled.

self.__position=self.enterLong(self.__instrument,shares,True)

#Checkifwehavetoexittheposition.

elifnotself.__position.exitActive()andcross.cross_below(self.__prices,self.__sma)>0:

self.__position.exitMarket()

sma_crossover_sample.pyimportsma_crossover

frompyalgotradeimportplotter

frompyalgotrade.toolsimportyahoofinance

frompyalgotrade.stratanalyzerimportsharpe

defmain(plot):

instrument="aapl"

smaPeriod=163

#Downloadthebars.

feed=yahoofinance.build_feed([instrument],,,".")

strat=sma_crossover.SMACrossOver(feed,instrument,smaPeriod)

sharpeRatioAnalyzer=sharpe.SharpeRatio()

strat.attachAnalyzer(sharpeRatioAnalyzer)

ifplot:

plt=plotter.StrategyPlotter(strat,True,False,True)

plt.getInstrumentSubplot(instrument).addDataSeries("sma",strat.getSMA())

strat.run()

print"Sharperatio:%.2f"%sharpeRatioAnalyzer.getSharpeRatio(0.05)

ifplot:

plt.plot()

if__name__=="__main__":

main(True)

上面的代码主要做一件这样的事。

创建了一个策略,这个策略就是你的想法,这个想法是什么呢?

想法是,当价格高于近163日内的平均价格就买入,低于近163日内的平均价格就卖出(平仓)。

其实还做了其他的事,比如策略分析之类的,但是这篇文章暂时忽略。

设计模式之观察者模式#!/usr/bin/python

#coding:utf8

'''

Observer

'''

classSubject(object):

def__init__(self):

self._observers=[]

defattach(self,observer):

ifnotobserverinself._observers:

self._observers.append(observer)

defdetach(self,observer):

try:

self._observers.remove(observer)

exceptValueError:

pass

defnotify(self,modifier=None):

forobserverinself._observers:

ifmodifier!=observer:

observer.update(self)

#Exampleusage

classData(Subject):

def__init__(self,name=''):

Subject.__init__(self)

self.name=name

self._data=0

@property

defdata(self):

returnself._data

@data.setter

defdata(self,value):

self._data=value

self.notify()

classHexViewer:

defupdate(self,subject):

print('HexViewer:Subject%shasdata0x%x'%

(subject.name,subject.data))

classDecimalViewer:

defupdate(self,subject):

print('DecimalViewer:Subject%shasdata%d'%

(subject.name,subject.data))

#Exampleusage...

defmain():

data1=Data('Data1')

data2=Data('Data2')

view1=DecimalViewer()

view2=HexViewer()

data1.attach(view1)

data1.attach(view2)

data2.attach(view2)

data2.attach(view1)

print("SettingData1=10")

data1.data=10

print("SettingData2=15")

data2.data=15

print("SettingData1=3")

data1.data=3

print("SettingData2=5")

data2.data=5

print("DetachHexViewerfromdata1anddata2.")

data1.detach(view2)

data2.detach(view2)

print("SettingData1=10")

data1.data=10

print("SettingData2=15")

data2.data=15

if__name__=='__main__':

main()

意图:

定义对象间的一种一对多的依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都得到通知并被自动更新。

适用性:

当一个抽象模型有两个方面,其中一个方面依赖于另一方面。将这二者封装在独立的对象中以使它们可以各自独立地改变和复用。

当对一个对象的改变需要同时改变其它对象,而不知道具体有多少对象有待改变。

当一个对象必须通知其它对象,而它又不能假定其它对象是谁。换言之,你不希望这些对象是紧密耦合的。

如果你看得懂就略过吧。

上面的代码想做个上面事情呢?

想达到事件的目的,即,在更新数据的时候,会触发相关的事件。

上面定义了主要三个种类型的类,subject,data,viewer。

其中subject是data的父类。

通过attach的操作,将不同的viewer加入到self.__observers列表里面,当data对象要更新数据的时候,就回调用notify方法,而notify方法则会遍历self.__observers列表的每个observer,然后依次调用其update方法。

这也是为毛hexViewer,DecimalViewer都要实现自身的update方法。

为毛要这么写?

前人总结的经验~

能不能不这么写?

可以的。

如果看不懂这个设计模式,那么pyalgotrade的源码看起来可能会比较吃力,但是也只是可能而已,因为很多人看不懂,只是因为没有实际的有用场景而已。

源码解析

首先是框架,看一遍,比如那些模块,不过个人经验之谈就是,看完之后,一般都会有一下迷思。

为毛这么写?

这里到底想干什么?

这么复杂有毛用~

恩,我也是这种感觉~

一般是pdb跟一遍流程或者一个一个找继承关系。

pdb这里就不讲了,主要就是跟每个方法调用死磕到底,当然了,你也许有你得方法,我比较较真就是这样看源代码的,至少现在是这样的。

在看源代码之前,官方文档,示例什么的最好也看一下,这样才能跟接近作者的意思。

这里面有个对象,需要着重声明,那就是bar。

什么是bar呢?

每个bar都是一个时刻股票各个价格的集合,即,当前价格,当前时间,最高价,最低价,成交量什么的。

而这些属性都是通过get_xxx的方法获取的。

获取数据

很明显数据是通过下面这行代码获取的。feed=yahoofinance.build_feed([instrument],,,".")

build_feed方法在tools/yahoofinance.pydefbuild_feed(instruments,fromYear,toYear,storage,frequency=bar.Frequency.DAY,timezone=None,skipErrors=False):

logger=pyalgotrade.logger.getLogger("yahoofinance")

logger=pyalgotrade.logger.getLogger("yahoofinance")

ret=yahoofeed.Feed(frequency,timezone)

foryearinrange(fromYear,toYear+1):

forinstrumentininstruments:

fileName=os.path.join(storage,"%s-%d-yahoofinance.csv"%(instrument,year))

ifnotos.path.exists(fileName):

logger.info("Downloading%s%dto%s"%(instrument,year,fileName))

try:

iffrequency==bar.Frequency.DAY:

download_daily_bars(instrument,year,fileName)

eliffrequency==bar.Frequency.WEEK:

download_weekly_bars(instrument,year,fileName)

else:

raiseException("Invalidfrequency")

exceptException,e:

ifskipErrors:

logger.error(str(e))

continue

else:

raisee

ret.addBarsFromCSV(instrument,fileName)

returnret

在build_feed函数里面又根据情况调用了相应的下载函数defdownload_csv(instrument,begin,end,frequency):

url="http://ichart./table.csv?s=%s&a=%d&b=%d&c=%d&d=%d&e=%d&f=%d&g=%s&ignore=.csv"%(instrument,__adjust_month(begin.month),begin.day,begin.year,__adjust_month(end.month),end.day,end.year,frequency)

returncsvutils.download_csv(url)

而最终执行的下载函数为download_csv,通过这个函数我们可以访问yahoo的api,最终下载函数,当然了,可以进一步的查看csvutils.download_csv函数。

这里我们知道数据是通过download_csv这个函数,将相应的股票代码,开始结束时间及频率传入,然后访问相应的url,得到相应的数据。

feed对象

在tools/yahoofinance.py中我们可以看到,返回的结果并不是一个csv的对象,而是一个ret即,Feed对象,而Feed对象通过addBarsFromCSV将下载的数据加载到内存。

从这里你也许会开始抓狂了为毛一层一层的继承。

其中yahoofeed.Feed在barfeed/yahoofeed.pyclassFeed(csvfeed.BarFeed):

defaddBarsFromCSV(self,instrument,path,timezone=None):

rowParser=RowParser(

self.getDailyBarTime(),self.getFrequency(),timezone,self.__sanitizeBars,self.__barClass

)

super(Feed,self).addBarsFromCSV(instrument,path,rowParser)

上面调用了父类的addBarsFromCSV方法。

父类的addBarsFromCSV在barfeed/csvfeed.pyclassBarFeed(membf.BarFeed):

defaddBarsFromCSV(self,instrument,path,rowParser):

#Loadthecsvfile

loadedBars=[]

reader=csvutils.FastDictReader(open(path,"r"),fieldnames=rowParser.getFieldNames(),delimiter=rowParser.getDelimiter())

forrowinreader:

bar_=rowParser.parseBar(row)

ifbar_isnotNoneand(self.__barFilterisNoneorself.__barFilter.includeBar(bar_)):

loadedBars.append(bar_)

self.addBarsFromSequence(instrument,loadedBars)

然后csvfeed又调用了父类的方法~

值得注意的是,上面的rowParser.parseBar方法在子类实现的 。。。后面会在提及。

addBarsFromSequence方法在barfeed/membf.pyclassBarFeed(barfeed.BaseBarFeed):

defaddBarsFromSequence(self,instrument,bars):

ifself.__started:

raiseException("Can'taddmorebarsonceyoustartedconsumingbars")

self.__bars.setdefault(instrument,[])

self.__nextPos.setdefault(instrument,0)

#Addandsortthebars

self.__bars[instrument].extend(bars)

barCmp=lambdax,y:cmp(x.getDateTime(),y.getDateTime())

self.__bars[instrument].sort(barCmp)

self.registerInstrument(instrument)

然后又调用了父类的方法~

值得注意的是这里将yahoo的数据存在了self.__bars中,至于bars是什么对象,后面再说。

registerInstrument方法在barfeed/__init__.pyclassBaseBarFeed(feed.BaseFeed):

defregisterInstrument(self,instrument):

self.__defaultInstrument=instrument

self.registerDataSeries(instrument)

然后又调用了父类的方法~

registerDataSeries方法在feed/__init__.pyclassBaseFeed(observer.Subject):

def__init__(self,maxLen):

super(BaseFeed,self).__init__()

maxLen=dataseries.get_checked_max_len(maxLen)

self.__ds={}

self.__event=observer.Event()

self.__maxLen=maxLen

defregisterDataSeries(self,key):

ifkeynotinself.__ds:

self.__ds[key]=self.createDataSeries(key,self.__maxLen)

恩,这里就是逻辑的终点了,虽然它还是继承,不过pyalgotrade里面大多数对象都是是继承observer.Subject,之所以继承,是为了完成类似观察者的设计模式里面的事件操作。

简单总结一下继承关系。

barfeed/yahoofeed.Feed -> barfeed/csvfeed.BarFeed -> barfeed/membf.BarFeed -> barfeed/__init__.py.BaseFeed -> feed/__init.py.BaseFeed

然后yahoo的数据结果,最终是由RowParser的parseBar方法依次导入,而RowPaser.parseBar方法是在barfeed/yahoofeed.py中。

然后我们再来走一遍加载数据的流程,不过这次不只是整个逻辑,而这次我们关注于具体的数据是啥。

其中barfeed/yahoofeed里面的RowParser的逻辑及parsrBar的具体的具体实现,截取如下。classRowParser(csvfeed.RowParser):

def__init__(self,dailyBarTime,frequency,timezone=None,sanitize=False,barClass=bar.BasicBar):

self.__dailyBarTime=dailyBarTime

self.__frequency=frequency

self.__timezone=timezone

self.__sanitize=sanitize

self.__barClass=barClass

def__parseDate(self,dateString):

ret=parse_date(dateString)

#TimeonYahoo!FinanceCSVfilesisempty.Iftoldtosetone,doit.

ifself.__dailyBarTimeisnotNone:

ret=bine(ret,self.__dailyBarTime)

#Localizethedatetimeifatimezonewasgiven.

ifself.__timezone:

ret=dt.localize(ret,self.__timezone)

returnret

defgetFieldNames(self):

#Itisexpectedforthefirstrowtohavethefieldnames.

returnNone

defgetDelimiter(self):

return","

defparseBar(self,csvRowDict):

dateTime=self.__parseDate(csvRowDict["Date"])

close=float(csvRowDict["Close"])

open_=float(csvRowDict["Open"])

high=float(csvRowDict["High"])

low=float(csvRowDict["Low"])

volume=float(csvRowDict["Volume"])

adjClose=float(csvRowDict["AdjClose"])

ifself.__sanitize:

open_,high,low,close=common.sanitize_ohlc(open_,high,low,close)

returnself.__barClass(dateTime,open_,high,low,close,volume,adjClose,self.__frequency)

其中解析后返回的结果是一个bar.BasicBar对象。

然后调用父类barfeed/csvfeed里面的addBarsFromCSV方法,得到一个bar.BasicBar对象的列表,即loadBars。传入继承于父类的addBarsFromSequence方法,截取如下。classBarFeed(membf.BarFeed):

defaddBarsFromCSV(self,instrument,path,rowParser):

#Loadthecsvfile

loadedBars=[]

reader=csvutils.FastDictReader(open(path,"r"),fieldnames=rowParser.getFieldNames(),delimiter=rowParser.getDelimiter())

forrowinreader:

bar_=rowParser.parseBar(row)

ifbar_isnotNoneand(self.__barFilterisNoneorself.__barFilter.includeBar(bar_)):

loadedBars.append(bar_)

self.addBarsFromSequence(instrument,loadedBars)

下面则是处理addBarsFromSequence的操作,主要是创建了一个self.__bars的字典,每个股票代码对应相应时间段的bar.BasicBar对象的列表,然后调用父类的registerInstrument方法,传入相应的股票代码。

barfeed/membf.py --> BarFeedclassBarFeed(barfeed.BaseBarFeed):

defaddBarsFromSequence(self,instrument,bars):

ifself.__started:

raiseException("Can'taddmorebarsonceyoustartedconsumingbars")

self.__bars.setdefault(instrument,[])

self.__nextPos.setdefault(instrument,0)

#Addandsortthebars

self.__bars[instrument].extend(bars)

barCmp=lambdax,y:cmp(x.getDateTime(),y.getDateTime())

self.__bars[instrument].sort(barCmp)

self.registerInstrument(instrument)

下面则是registerInstrument的具体逻辑,即注册DataSeries对象,而registerDataSeries方法是在父类实现。

barfeed/__init__.py --->BaseBarFeedBaseBarFeed(feed.BaseFeed):

defregisterInstrument(self,instrument):

self.__defaultInstrument=instrument

self.registerDataSeries(instrument)

下面则是最终的registerDataSeries操作,创建了一个dataseries的对象。

feed/__init__.py --->BaseFeedclassBaseFeed(observer.Subject):

defregisterDataSeries(self,key):

ifkeynotinself.__ds:

self.__ds[key]=self.createDataSeries(key,self.__maxLen)

而createDataSeries方法并没有在基类中实现。@abc.abstractmethod

defcreateDataSeries(self,key,maxLen):

raiseNotImplementedError()

createDataSeries的具体实现则是在barfeed/__init__.py --->BaseBarFeeddefcreateDataSeries(self,key,maxLen):

ret=bards.BarDataSeries(maxLen)

ret.setUseAdjustedValues(self.__useAdjustedValues)

returnret

所以最终,feed对象有两个重要的数据集。

一:

self.__bars

里面的数据结构大概是{"instrument_xx":[bar1,bar2,bar3]}

self.__ds = {}

里面的数据结构大概是self.__ds = {"instrument_xx": dataseries_xx}

其中instrument指特定的股票代码,比如aapl,bar1,bar2则是bar.BasicBar对象,dataseries则是bards.BarDataSeries对象。

至于bar.BasicBar以及dataseries的数据结构到底是什么,大家可以自行瞧瞧。

值得注意的是,父类与基类之间数据获取不会通过共享变量的方式获得,比如最终通过基类self.__ds的数据是通过基类的getKeys的方法暴露给子类去获取实际的数据。。

策略

初始化策略strat=sma_crossover.SMACrossOver(feed,instrument,smaPeriod)

策略最终继承于strategy.BacktestingStrategy

analyzer

创建一个stratanalyzer的实例并attachsharpeRatioAnalyzer=sharpe.SharpeRatio()

strat.attachAnalyzer(sharpeRatioAnalyzer)

analyzer这里暂时不说,因为,这里主要将具体的策略实现,以及feed对象,analyzer以及broker的内容会放在下一篇文章讲。

run

运行策略。strat.run()

run方法在strategy/__init__.py里面的BaseStrategy类。classBaseStrategy(object):

defrun(self):

"""Callonce(**andonlyonce**)torunthestrategy."""

self.__dispatcher.run()

ifself.__barFeed.getCurrentBars()isnotNone:

self.onFinish(self.__barFeed.getCurrentBars())

else:

raiseException("Feedwasempty")

而run方法会调用self.__dispatcher的run方法,即dispatcher.py里面的Dispatcher类,在说Dispatcher类之前,我们得先看看BaseStrategy在初始化的时候到底初始化了啥。classBaseStrategy(object):

def__init__(self,barFeed,broker):

self.__barFeed=barFeed

self.__broker=broker

self.__activePositions=set()

self.__orderToPosition={}

self.__barsProcessedEvent=observer.Event()

self.__analyzers=[]

self.__namedAnalyzers={}

self.__resampledBarFeeds=[]

self.__dispatcher=dispatcher.Dispatcher()

self.__broker.getOrderUpdatedEvent().subscribe(self.__onOrderEvent)

self.__barFeed.getNewValuesEvent().subscribe(self.__onBars)

self.__dispatcher.getStartEvent().subscribe(self.onStart)

self.__dispatcher.getIdleEvent().subscribe(self.__onIdle)

#Itisimportanttodispatchbrokereventsbeforefeedevents,speciallyifwe'rebacktesting.

self.__dispatcher.addSubject(self.__broker)

self.__dispatcher.addSubject(self.__barFeed)

绑定barFeed,broker到self,初始化__activePositions,OderToPosition,__analyzers,__namedAnlyzers,__resampledBarFeeds的值,并初始化一个observer.Event的实例。

创建一个dispatcher的实例,并在dispatcher的初始化过程中创建两个observer.Event,observer.Event的实例。

其中broker实例通过getOrderUpdatedEvent方法得到一个event实例,并订阅策略的onOrderEvent的事件

barFeed实例通过getNewValuesEvent方法得到一个event实例,并订阅策略的onBars的事件。

dispatcher的实例分别获得startEvent,IdleEvent并订阅onStart,__onIdle事件。

最后dispatcher实例将broker,barFeed两个subject分别加入到dispatcher的subjects列表中。

然后我们在回到Dispatcher类的run方法,这里主要是首先遍历自己__subjects列表里面的subject,然后调用每个subject的start方法,由BaseStrategy类的初始化方法可知,dispatcher加入了两个subject,分别是broker,barFeed。

具体实现如下。classDispatcher(object):

defrun(self):

try:

forsubjectinself.__subjects:

subject.start()

self.__startEvent.emit()

whilenotself.__stop:

eof,eventsDispatched=self.__dispatch()

ifeof:

self.__stop=True

elifnoteventsDispatched:

self.__idleEvent.emit()

finally:

forsubjectinself.__subjects:

subject.stop()

forsubjectinself.__subjects:

subject.join()

整个回测策略的逻辑基本就是在dispatcher调度各个subject并触发事件的过程。

调用完每个subject的start方法后,执行自身的self.__startEvent.emit方法。

然后通过while循环启动整个运转逻辑。

在循环结束后依次启动每个subject并等待所有subject关闭。

现在再次回到初始化过程,查看各个event,subject的内容到底是什么。self.__broker.getOrderUpdatedEvent().subscribe(self.__onOrderEvent)

def__onOrderEvent(self,broker_,orderEvent):

order=orderEvent.getOrder()

self.onOrderUpdated(order)

self.__barFeed.getNewValuesEvent().subscribe(self.__onBars)

def__onBars(self,dateTime,bars):

#THEORDERHEREISVERYIMPORTANT

#1:Letanalyzersprocessbars.

self.__notifyAnalyzers(lambdas:s.beforeOnBars(self,bars))

#2:Letthestrategyprocesscurrentbarsandsubmitorders.

self.onBars(bars)

#3:Notifythatthebarswereprocessed.

self.__barsProcessedEvent.emit(self,bars)

self.__dispatcher.getStartEvent().subscribe(self.onStart)

defonStart(self):

"""Override(optional)togetnotifiedwhenthestrategystartsexecuting.Thedefaultimplementationisempty."""

pass

self.__dispatcher.getIdleEvent().subscribe(self.__onIdle)

def__onIdle(self):

#Forcearesamplechecktoavoiddependingsolelyontheunderlying

#barfeedevents.

forresampledBarFeedinself.__resampledBarFeeds:

resampledBarFeed.checkNow(self.getCurrentDateTime())

self.onIdle()

上面是各个event订阅的subject,是相应的handler函数。

然后现在瞧瞧每个subject的start方法。

其中observer.py里面定义的Subject类似一个抽象工厂,只是定义了各个方法但是并没有实现具体方法的逻辑。

我们首先来看看broker这个subject的start方法的处理逻辑。

而继承observer.Subject的Broker也只是一个抽象工厂,定义了一系列的接口。

在此策略中,我们据代码得知,我们初始化的broker是一个backtesting的broker,代码如下。classBacktestingStrategy(BaseStrategy):

def__init__(self,barFeed,cash_or_brk=1000000):

#ThebrokershouldsubscribetobarFeedeventsbeforethestrategy.

#Thisistoavoidexecutingorderssubmittedinthecurrenttick.

ifisinstance(cash_or_brk,pyalgotrade.broker.Broker):

broker=cash_or_brk

else:

broker=backtesting.Broker(cash_or_brk,barFeed)

查看backtesting的broker

broker/backtesting.py

classBroker(broker.Broker):

defstart(self):

super(Broker,self).start()

查看backtesting的broker -> broker/backtesting.pyclassBroker(broker.Broker):

defstart(self):

super(Broker,self).start()

其中基类的start如下observer.py

classSubject(object):

@abc.abstractmethod

defstart(self):

pass

然后再来看barFeed的subject的start

其中barFeed也没有自己定义start方法,即,start方法也是如上。

在每个subject调用start方法后,dispatcher就会调用自身self.__startEvent.emit。然后到循环eof, eventsDispatched = self.__dispatch()def__dispatch(self):

smallestDateTime=None

eof=True

eventsDispatched=False

#Scanforthelowestdatetime.

forsubjectinself.__subjects:

ifnotsubject.eof():

eof=False

smallestDateTime=utils.safe_min(smallestDateTime,subject.peekDateTime())

再次实例创建的feed为yahoofeed

而依次继承于csvfeed.BarFeed,membf.BarFeed,barfeed.BaseBaseFeed,feed.BaseFeed

其中membf.BarFeed,BaseBarFeed都实现了eof方法。

通过代码追踪,我们发现eof主要为了判断是否以及迭代完每一个bar

代码如下defeof(self):

ret=True

#Checkifthereisatleastonemorebartoreturn.

forinstrument,barsinself.__bars.iteritems():

nextPos=self.__nextPos[instrument]

ifnextPos

ret=False

break

returnret

其中self.__nextPos在addBarsFromSequence函数里面已经将其定义为0,也就是说,这个nextPos是为了在迭代每个bar的同时记录迭代的位置,即索引位置。

当判断完eof之后,则调用__dispatchSubject方法,迭代每个subject并调用其dispatch方法。

其中dispatch的实现在基类feed/__init__.pyclassBaseFeed(observer.Subject):

defdispatch(self):

dateTime,values=self.getNextValuesAndUpdateDS()

ifdateTimeisnotNone:

self.__event.emit(dateTime,values)

returndateTimeisnotNone

getNextValuesAndUpdateDS方法实现在feed/__init__.pydefgetNextValuesAndUpdateDS(self):

dateTime,values=self.getNextValues()

ifdateTimeisnotNone:

forkey,valueinvalues.items():

#Getorcreatethedatseriesforeachkey.

try:

ds=self.__ds[key]

exceptKeyError:

ds=self.createDataSeries(key,self.__maxLen)

self.__ds[key]=ds

ds.appendWithDateTime(dateTime,value)

return(dateTime,values)

def__iter__(self):

returnfeed_iterator(self)

而getNextValues的方法实现在barfeed/__init__.pyclassBaseBarFeed(feed.BaseFeed):

defgetNextValues(self):

dateTime=None

bars=self.getNextBars()

ifbarsisnotNone:

dateTime=bars.getDateTime()

#Checkthatcurrentbardatetimesaregreaterthanthepreviousone.

ifself.__currentBarsisnotNoneandself.__currentBars.getDateTime()>=dateTime:

raiseException(

"Bardatetimesarenotinorder.Previousdatetimewas%sandcurrentdatetimeis%s"%(

self.__currentBars.getDateTime(),

dateTime

)

)

#Updateself.__currentBarsandself.__lastBars

self.__currentBars=bars

forinstrumentinbars.getInstruments():

self.__lastBars[instrument]=bars[instrument]

return(dateTime,bars)

其中 getNextBars的方法实现在barfeed/membf.pyclassBarFeed(barfeed.BaseBarFeed):

defgetNextBars(self):

#Allbarsmusthavethesamedatetime.Wewillreturnalltheoneswiththesmallestdatetime.

smallestDateTime=self.peekDateTime()

ifsmallestDateTimeisNone:

returnNone

#Makeasecondpasstogetallthebarsthathadthesmallestdatetime.

ret={}

forinstrument,barsinself.__bars.iteritems():

nextPos=self.__nextPos[instrument]

ifnextPos

ret[instrument]=bars[nextPos]

self.__nextPos[instrument]+=1

ifself.__currDateTime==smallestDateTime:

raiseException("Duplicatebarsfoundfor%son%s"%(ret.keys(),smallestDateTime))

self.__currDateTime=smallestDateTime

returnbar.Bars(ret)

其中Bars对象则是对bar的进一层封装

提供方法如下。def__getitem__(self,instrument):

returnself.__barDict[instrument]

def__contains__(self,instrument):

returninstrumentinself.__barDict

defitems(self):

defkeys(self):

defgetInstruments(self):

defgetDateTime(self):

defgetBar(self,instrument):

至此,我们了解到了feed对象,以及每个bar是怎么迭代的,但是还没有看到每个bar的处理操作。

所以在回到feed的dispatch方法,处理流程如下defdispatch(self):

dateTime,values=self.getNextValuesAndUpdateDS()

ifdateTimeisnotNone:

self.__event.emit(dateTime,values)

returndateTimeisnotNone

需要着重说明的就是self.__event.emit(dateTime, values)

其中values是一个bar.Bars实例。

broker的dispatch方法defdispatch(self):

#Alleventswerealreadyemittedwhilehandlingbarfeedevents.

pass

这里,我们可以看到如果dataTime不是None的话,就会通过emit提交时间

而feed里面注册了__onBars的handlers

所以在每次迭代的时候都会触发event的emit操作,即执行每个在feed中注册了的handler,这里只注册了一个handler--->__onBarsdef__onBars(self,dateTime,bars):

#THEORDERHEREISVERYIMPORTANT

#1:Letanalyzersprocessbars.

self.__notifyAnalyzers(lambdas:s.beforeOnBars(self,bars))

#2:Letthestrategyprocesscurrentbarsandsubmitorders.

self.onBars(bars)

#3:Notifythatthebarswereprocessed.

self.__barsProcessedEvent.emit(self,bars)

所以迭代每一个bar的时候,都会执行onBar的函数。

而onBar函数是自己定义的,在本示例中,onBar的函数内容如下defonBars(self,bars):

defonBars(self,bars):

#Ifapositionwasnotopened,checkifweshouldenteralongposition.

ifself.__positionisNone:

ifcross.cross_above(self.__prices,self.__sma)>0:

shares=int(self.getBroker().getCash()*0.9/bars[self.__instrument].getPrice())

#Enterabuymarketorder.Theorderisgoodtillcanceled.

self.__position=self.enterLong(self.__instrument,shares,True)

#Checkifwehavetoexittheposition.

elifnotself.__position.exitActive()andcross.cross_below(self.__prices,self.__sma)>0:

self.__position.exitMarket()

bar是每个指定频率的open,close,low,high,adj close,volume数据集合对象。

DataSeries是一个随着迭代,不断增加datetime,以及bar的序列。

而technical的触发是在feed/__init__.py里面的ds.appendWithDateTime。defgetNextValuesAndUpdateDS(self):

dateTime,values=self.getNextValues()

ifdateTimeisnotNone:

forkey,valueinvalues.items():

#Getorcreatethedatseriesforeachkey.

try:

ds=self.__ds[key]

exceptKeyError:

ds=self.createDataSeries(key,self.__maxLen)

self.__ds[key]=ds

ds.appendWithDateTime(dateTime,value)

return(dateTime,values)

然后ma.pyclassSMA(technical.EventBasedFilter):

def__init__(self,dataSeries,period,maxLen=None):

super(SMA,self).__init__(dataSeries,SMAEventWindow(period),maxLen)

然后technical/__init__.pyclassEventBasedFilter(dataseries.SequenceDataSeries):

def__init__(self,windowSize,dtype=float,skipNone=True):

assert(windowSize>0)

assert(isinstance(windowSize,int))

self.__values=collections.NumPyDeque(windowSize,dtype)

self.__windowSize=windowSize

self.__skipNone=skipNone

def__onNewValue(self,dataSeries,dateTime,value):

#Lettheeventwindowperformcalculations.

self.__eventWindow.onNewValue(dateTime,value)

#Gettheresultingvalue

newValue=self.__eventWindow.getValue()

#Addthenewvalue.

self.appendWithDateTime(dateTime,newValue)

而__eventWindow.onNewValue在technical/ma.pyclassSMAEventWindow(technical.EventWindow):

def__init__(self,period):

assert(period>0)

super(SMAEventWindow,self).__init__(period)

self.__value=None

defonNewValue(self,dateTime,value):

firstValue=None

iflen(self.getValues())>0:

firstValue=self.getValues()[0]

assert(firstValueisnotNone)

super(SMAEventWindow,self).onNewValue(dateTime,value)

ifvalueisnotNoneandself.windowFull():

ifself.__valueisNone:

self.__value=self.getValues().mean()

else:

self.__value=self.__value+value/float(self.getWindowSize())-firstValue/float(self.getWindowSize())

defgetValue(self):

returnself.__value

至此基于pyalgotrade的一个简单示例,按照其执行流程的源码解读到此完毕。

后记:后面有点乱了,写篇文章还是蛮费时间的,太长了,pyalgotrade的源码解读估计还得写一段时间去了。

参考链接:

如果觉得不错,并有所收获,请我喝杯茶呗

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。