5
5
6
6
from tqsdk .datetime import _get_expire_rest_days
7
7
from tqsdk .datetime_state import TqDatetimeState
8
- from tqsdk .diff import _simple_merge_diff
8
+ from tqsdk .diff import _simple_merge_diff , _is_key_exist , _simple_merge_diff_and_collect_paths , _get_obj
9
9
10
10
11
11
class DataExtension (object ):
@@ -25,6 +25,14 @@ class DataExtension(object):
25
25
'pos_short': int,
26
26
'pos': int
27
27
}
28
+ },
29
+ ordres: {
30
+ *: {
31
+ 'is_dead': bool
32
+ 'is_online': bool
33
+ 'is_error': bool
34
+ 'trade_price': float
35
+ }
28
36
}
29
37
}
30
38
}
@@ -35,6 +43,30 @@ def __init__(self, api):
35
43
self ._api = api
36
44
self ._data = {'trade' : {}} # 数据截面, 现在的功能只需要记录 trade
37
45
self ._diffs = []
46
+ self ._diffs_paths = set ()
47
+ self ._prototype = {
48
+ "trade" : {
49
+ "*" : {
50
+ "orders" : {
51
+ "*" : {
52
+ "status" : None ,
53
+ "exchange_order_id" : None
54
+ }
55
+ },
56
+ "trades" : {
57
+ "*" : None
58
+ },
59
+ "positions" : {
60
+ "*" : {
61
+ "pos_long_his" : None ,
62
+ "pos_long_today" : None ,
63
+ "pos_short_his" : None ,
64
+ "pos_short_today" : None
65
+ }
66
+ }
67
+ }
68
+ }
69
+ }
38
70
39
71
async def _run (self , api_send_chan , api_recv_chan , md_send_chan , md_recv_chan ):
40
72
self ._logger = self ._api ._logger .getChild ("DataExtension" )
@@ -74,10 +106,16 @@ async def _md_handler(self):
74
106
75
107
async def _md_recv (self , pack ):
76
108
"""将行情数据和交易数据合并至 self._data """
77
- for d in pack .get ("data" , {} ):
109
+ for d in pack .get ("data" , [] ):
78
110
self ._datetime_state .update_state (d )
79
111
if d .get ('trade' , None ):
80
- _simple_merge_diff (self ._data ['trade' ], d ['trade' ], reduce_diff = False )
112
+ _simple_merge_diff_and_collect_paths (
113
+ result = self ._data ['trade' ],
114
+ diff = d ['trade' ],
115
+ path = ('trade' , ),
116
+ diff_paths = self ._diffs_paths ,
117
+ prototype = self ._prototype ['trade' ]
118
+ )
81
119
self ._diffs .append (d )
82
120
83
121
def _generate_ext_diff (self ):
@@ -86,43 +124,79 @@ def _generate_ext_diff(self):
86
124
此函数在 send_diff() 才会调用, self._datetime_state.data_ready 一定为 True,
87
125
调用 self._datetime_state.get_current_dt() 一定有正确的当前时间
88
126
"""
89
- pend_diff = {}
90
127
for d in self ._diffs :
91
128
if d .get ('quotes' , None ):
92
- _simple_merge_diff (pend_diff , self ._update_quotes (d ), reduce_diff = False )
93
- if d .get ('trade' , None ):
94
- _simple_merge_diff (pend_diff , self ._update_positions (d ), reduce_diff = False )
129
+ self ._update_quotes (d )
130
+ pend_diff = {}
131
+ _simple_merge_diff (pend_diff , self ._get_positions_pend_diff (), reduce_diff = False )
132
+ orders_set = set () # 计算过委托单,is_dead、is_online、is_error
133
+ orders_price_set = set () # 根据成交计算哪些 order 需要重新计算平均成交价 trade_price
134
+ for path in self ._diffs_paths :
135
+ if path [2 ] == 'orders' :
136
+ _ , account_key , _ , order_id , _ = path
137
+ if (account_key , order_id ) not in orders_set :
138
+ orders_set .add ((account_key , order_id ))
139
+ order = _get_obj (self ._data , ['trade' , account_key , 'orders' , order_id ])
140
+ if order :
141
+ pend_order = pend_diff .setdefault ('trade' , {}).setdefault (account_key , {}).setdefault ('orders' , {}).setdefault (order_id , {})
142
+ pend_order ['is_dead' ] = order ['status' ] == "FINISHED"
143
+ pend_order ['is_online' ] = order ['exchange_order_id' ] != "" and order ['status' ] == "ALIVE"
144
+ pend_order ['is_error' ] = order ['exchange_order_id' ] == "" and order ['status' ] == "FINISHED"
145
+ elif path [2 ] == 'trades' :
146
+ _ , account_key , _ , trade_id = path
147
+ trade = _get_obj (self ._data , path )
148
+ order_id = trade .get ('order_id' , '' )
149
+ if order_id :
150
+ orders_price_set .add (('trade' , account_key , 'orders' , order_id ))
151
+ for path in orders_price_set :
152
+ _ , account_key , _ , order_id = path
153
+ trade_price = self ._get_trade_price (account_key , order_id )
154
+ if trade_price == trade_price :
155
+ pend_order = pend_diff .setdefault ('trade' , {}).setdefault (account_key , {}).setdefault ('orders' , {}).setdefault (order_id , {})
156
+ pend_order ['trade_price' ] = trade_price
157
+ self ._diffs_paths = set ()
95
158
return pend_diff
96
159
97
160
def _update_quotes (self , diff ):
98
- pend_diff = {}
99
161
for symbol in diff ['quotes' ]:
100
- expire_datetime = diff ['quotes' ].get (symbol , {}).get ('expire_datetime' , float ('nan' ))
101
- if expire_datetime == expire_datetime :
102
- # expire_rest_days 距离到期日的剩余天数(自然日天数)
103
- # 正数表示距离到期日的剩余天数,0表示到期日当天,负数表示距离到期日已经过去的天数
162
+ if not _is_key_exist (diff , path = ['quotes' , symbol ], key = ['expire_datetime' ]):
163
+ continue
164
+ expire_datetime = diff ['quotes' ][symbol ]['expire_datetime' ]
165
+ if expire_datetime and expire_datetime == expire_datetime : # 排除 None 和 nan
166
+ # expire_rest_days 距离到期日的剩余天数(自然日天数),正数表示距离到期日的剩余天数,0表示到期日当天,负数表示距离到期日已经过去的天数
167
+ # 直接修改在 diff 里面的数据,当 diffs 里有多个对同个合约的修改时,保持数据截面的一致
104
168
expire_rest_days = _get_expire_rest_days (expire_datetime , self ._datetime_state .get_current_dt () / 1e9 )
105
- pend_diff [symbol ] = {'expire_rest_days' : expire_rest_days }
106
- return {'quotes' : pend_diff } if pend_diff else {}
169
+ diff ['quotes' ][symbol ]['expire_rest_days' ] = expire_rest_days
107
170
108
- def _update_positions (self , diff ):
171
+ def _get_positions_pend_diff (self ):
109
172
pend_diff = {}
110
- for account_key in diff ['trade' ]:
111
- for symbol in diff ['trade' ]. get ( account_key , {}) .get ('positions' , {}):
112
- pos = diff [ 'trade' ][ account_key ][ ' positions' ][ symbol ]
113
- if 'pos_long_his' in pos or 'pos_long_today' in pos or 'pos_short_his' in pos or 'pos_short_today' in pos :
114
- data_pos = self . _data [ 'trade' ][ account_key ][ 'positions' ][ symbol ]
115
- pos_long = data_pos [ 'pos_long_his' ] + data_pos [ 'pos_long_today' ]
116
- pos_short = data_pos [ 'pos_short_his ' ] + data_pos [ 'pos_short_today ' ]
117
- pend_diff . setdefault ( account_key , {})
118
- pend_diff [ account_key ] .setdefault ('positions' , {})
173
+ for account_key in self . _data ['trade' ]:
174
+ positions = self . _data ['trade' ][ account_key ] .get ('positions' , {})
175
+ for symbol , pos in positions . items ():
176
+ paths = [( 'trade' , account_key , 'positions' , symbol ) + ( key , )
177
+ for key in [ 'pos_long_his' , 'pos_long_today' , 'pos_short_his' , 'pos_short_today' ] ]
178
+ if any ([ p in self . _diffs_paths for p in paths ]):
179
+ pos_long = pos [ 'pos_long_his ' ] + pos [ 'pos_long_today ' ]
180
+ pos_short = pos [ 'pos_short_his' ] + pos [ 'pos_short_today' ]
181
+ pend_diff . setdefault ( account_key , {}) .setdefault ('positions' , {})
119
182
pend_diff [account_key ]['positions' ][symbol ] = {
120
183
'pos_long' : pos_long ,
121
184
'pos_short' : pos_short ,
122
185
'pos' : pos_long - pos_short
123
186
}
124
187
return {'trade' : pend_diff } if pend_diff else {}
125
188
189
+ def _get_trade_price (self , account_key , order_id ):
190
+ # 计算某个 order_id 对应的成交均价
191
+ trades = self ._data ['trade' ][account_key ]['trades' ]
192
+ trade_id_list = [t_id for t_id in trades .keys () if trades [t_id ]['order_id' ] == order_id ]
193
+ sum_volume = sum ([trades [t_id ]['volume' ] for t_id in trade_id_list ])
194
+ if sum_volume == 0 :
195
+ return float ('nan' )
196
+ else :
197
+ sum_amount = sum ([trades [t_id ]['volume' ] * trades [t_id ]['price' ] for t_id in trade_id_list ])
198
+ return sum_amount / sum_volume
199
+
126
200
async def _send_diff (self ):
127
201
if self ._datetime_state .data_ready and self ._pending_peek and self ._diffs :
128
202
# 生成增量业务截面, 该截面包含补充的字段,只在真正需要给下游发送数据时,才将需要发送的数据放在 _diffs 中
0 commit comments