1000字范文,内容丰富有趣,学习的好帮手!
1000字范文 > 股票接口指数行情模型建立代码分享

股票接口指数行情模型建立代码分享

时间:2023-12-30 06:42:37

相关推荐

股票接口指数行情模型建立代码分享

基于行情数据,我们构建一个多标签的深度回归模型,将日内技术特征、时间特征、国内市场日间特征、美国市场日间特征、国内市场股票估值特征等因子联合建模,希望学到他们之间复杂的关系。主体的代码框架如下,其中TickerEncoder是一个Embedding模块,负责将所有涉及到的股票ticker编码然后映射成一个向量,TimeEncoder负责将时间编码成一个向量,FeatureEncoder则将输入的特征数据进行加工后进行池化降维,然后统一拼接到一起,通过分类器预测出36个类别。

class DeepPredictor(nn.Module):

def __init__(self, ticker_dim=16, depth=1, output_dim=32, class_num=36):

super(DeepPredictor, self).__init__()

self.ticker_encoder = TickerEncoder(dim=ticker_dim)

self.time_encoder = TimeEncoder(dim=4, out_dim=output_dim)

self.intraday_tech_encoder = FeatureEncoder(self.ticker_encoder, tech_dim=18, ticker_dim=ticker_dim, output_dim=output_dim, depth=depth)

self.daily_tech_cn_encoder = FeatureEncoder(self.ticker_encoder, tech_dim=23, ticker_dim=ticker_dim, output_dim=output_dim, depth=depth)

self.daily_tech_us_encoder = FeatureEncoder(self.ticker_encoder, tech_dim=23, ticker_dim=ticker_dim, output_dim=output_dim, depth=depth)

self.daily_fin_cn_encoder = FeatureEncoder(self.ticker_encoder, tech_dim=23, ticker_dim=ticker_dim, output_dim=output_dim, depth=depth)

self.pooling = nn.AdaptiveAvgPool1d(1)

self.pooling2 = nn.AdaptiveMaxPool1d(1)

self.classifier = nn.Linear(in_features=output_dim * 6, out_features=class_num)

def forward(self, timestamp, intraday_tech_fea, intraday_tech_ticker, daily_tech_fea_cn, daily_tech_ticker_cn, daily_fin_fea_cn, daily_fin_ticker_cn, daily_tech_fea_us, daily_tech_ticker_us):

time_out = self.time_encoder(timestamp.squeeze(0)) # seq_len * dim

intraday_tech_out = self.intraday_tech_encoder(intraday_tech_fea.squeeze(0), intraday_tech_ticker.squeeze(0)) # seq_len * ticker_len * dim

intraday_tech_out1 = self.pooling(intraday_tech_out.transpose(1, 2)).squeeze(2) # seq_len * dim

intraday_tech_out2 = self.pooling2(intraday_tech_out.transpose(1, 2)).squeeze(2) # seq_len * dim

seq_len = intraday_tech_out.shape[0]

daily_tech_cn_out = self.daily_tech_cn_encoder(daily_tech_fea_cn, daily_tech_ticker_cn) # seq_len * ticker_len * dim

daily_tech_cn_out = self.pooling(daily_tech_cn_out.transpose(1, 2)).squeeze(2).repeat(seq_len, 1) # seq_len * dim

daily_fin_cn_out = self.daily_fin_cn_encoder(daily_fin_fea_cn, daily_fin_ticker_cn) # seq_len * ticker_len * dim

daily_fin_cn_out = self.pooling(daily_fin_cn_out.transpose(1, 2)).squeeze(2).repeat(seq_len, 1) # seq_len * dim

daily_tech_us_out = self.daily_tech_us_encoder(daily_tech_fea_us, daily_tech_ticker_us) # seq_len * ticker_len * dim

daily_tech_us_out = self.pooling(daily_tech_us_out.transpose(1, 2)).squeeze(2).repeat(seq_len, 1) # seq_len * dim

feature = torch.cat([time_out, intraday_tech_out1, intraday_tech_out2, daily_tech_cn_out, daily_fin_cn_out, daily_tech_us_out], dim=1) # seq_len * (dim * 6)

output = self.classifier(feature) # seq_len * class_num

return output

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。