6
6
import csv
7
7
import os
8
8
from datetime import date , datetime
9
+ from encodings .utf_8 import StreamWriter
9
10
from typing import Union , List , Optional
10
11
import lzma
11
12
@@ -42,7 +43,8 @@ class DataDownloader:
42
43
"""
43
44
44
45
def __init__ (self , api : TqApi , symbol_list : Union [str , List [str ]], dur_sec : int , start_dt : Union [date , datetime ],
45
- end_dt : Union [date , datetime ], csv_file_name : str , adj_type : Union [str , None ] = None ) -> None :
46
+ end_dt : Union [date , datetime ], csv_file_name : Union [str , asyncio .StreamWriter ],
47
+ adj_type : Union [str , None ] = None ) -> None :
46
48
"""
47
49
创建历史数据下载器实例
48
50
@@ -57,7 +59,10 @@ def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int,
57
59
58
60
end_dt (date/datetime): 结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
59
61
60
- csv_file_name (str): 输出 csv 的文件名
62
+ csv_file_name (str/asyncio.StreamWriter): [必填]输出方式:
63
+ * str : 输出 csv 的文件名
64
+
65
+ * asyncio.StreamWriter: 直接将内容输出到 asyncio.StreamWriter
61
66
62
67
adj_type (str/None): 复权计算方式,默认值为 None。"F" 为前复权;"B" 为后复权;None 表示不复权。只对股票、基金合约有效。
63
68
@@ -111,7 +116,10 @@ def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int,
111
116
if adj_type not in [None , "F" , "B" , "FORWARD" , "BACK" ]:
112
117
raise Exception ("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权)" )
113
118
self ._adj_type = adj_type [0 ] if adj_type else adj_type
114
- self ._csv_file_name = csv_file_name
119
+ if isinstance (csv_file_name , str ) or isinstance (csv_file_name , asyncio .StreamWriter ):
120
+ self ._csv_file_name = csv_file_name
121
+ else :
122
+ raise Exception ("csv_file_name 参数只支持 str | StreamWriter 类型" )
115
123
self ._csv_header = self ._get_headers ()
116
124
# 缓存合约对应的复权系数矩阵,每个合约只计算一次
117
125
# 含义为截止 datetime 之前(不包含) 应使用 factor 复权
@@ -168,9 +176,12 @@ def _get_data_series(self) -> Optional[pandas.DataFrame]:
168
176
"""
169
177
if not self ._task .done ():
170
178
return None
171
- if not self ._data_series :
172
- self ._data_series = pandas .read_csv (self ._csv_file_name )
173
- return self ._data_series
179
+ if isinstance (self ._csv_file_name , str ):
180
+ if not self ._data_series :
181
+ self ._data_series = pandas .read_csv (self ._csv_file_name )
182
+ return self ._data_series
183
+ else :
184
+ raise Exception ('DataDownloader._get_data_series 接口仅支持 csv_file_name 參數为 str 时使用' )
174
185
175
186
async def _ensure_dividend_factor (self , quote , timestamp ):
176
187
if quote .instrument_id not in self ._dividend_cache :
@@ -234,8 +245,8 @@ async def _ensure_dividend_factor(self, quote, timestamp):
234
245
235
246
async def _run (self ):
236
247
self ._quote_list = await self ._api .get_quote_list (self ._symbol_list )
237
- self . _data_chan = TqChan ( self . _api )
238
- task = self ._api . create_task ( self . _download_data () )
248
+ # 下载数据的 async generator
249
+ gen = self ._download_data ()
239
250
# cols 是复权需要重新计算的列名
240
251
index_datetime_nano = self ._csv_header .index ("datetime_nano" )
241
252
if self ._dur_nano != 0 :
@@ -244,33 +255,43 @@ async def _run(self):
244
255
cols = ["last_price" , "highest" , "lowest" ]
245
256
cols .extend (f"{ x } { i } " for x in ["bid_price" , "ask_price" ] for i in range (1 , 6 ))
246
257
try :
247
- with open (self ._csv_file_name , 'w' , newline = '' ) as csvfile :
248
- csv_writer = csv .writer (csvfile , dialect = 'excel' )
249
- csv_writer .writerow (self ._csv_header )
250
- async for item in self ._data_chan :
251
- for quote in self ._quote_list :
252
- symbol = quote .instrument_id
253
- if self ._adj_type and quote .ins_class in ["STOCK" , "FUND" ]:
254
- # 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子
255
- # 前复权需要提前计算除权因子
256
- await self ._ensure_dividend_factor (quote , item [index_datetime_nano ])
257
- dividend_cache = self ._dividend_cache [symbol ]
258
- # dividend_df 和 _data_chan 中取出的数据都是按时间升序排列的,因此可以使用归并算法
259
- if dividend_cache ["last_dt" ] <= item [index_datetime_nano ]:
260
- dividend_df = dividend_cache ["df" ]
261
- dividend_df = dividend_df [dividend_df ["datetime" ].gt (item [index_datetime_nano ])]
262
- dividend_cache ["df" ] = dividend_df
263
- dividend_cache ["last_dt" ] = dividend_df ["datetime" ].iloc [0 ]
264
- dividend_cache ["factor" ] = dividend_df ["factor" ].iloc [0 ]
265
- if dividend_cache ["factor" ] != 1 :
266
- item = item .copy ()
267
- for c in cols : # datetime_nano
268
- index = self ._csv_header .index (f"{ symbol } .{ c } " )
269
- item [index ] = item [index ] * dividend_cache ["factor" ]
270
- csv_writer .writerow (item )
258
+ if isinstance (self ._csv_file_name , asyncio .StreamWriter ):
259
+ writer = StreamWriter (self ._csv_file_name )
260
+ else :
261
+ writer = open (self ._csv_file_name , 'w' , newline = '' )
262
+ csv_writer = csv .writer (writer , dialect = 'excel' )
263
+ csv_writer .writerow (self ._csv_header )
264
+ async for item in gen :
265
+ for quote in self ._quote_list :
266
+ symbol = quote .instrument_id
267
+ if self ._adj_type and quote .ins_class in ["STOCK" , "FUND" ]:
268
+ # 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子
269
+ # 前复权需要提前计算除权因子
270
+ await self ._ensure_dividend_factor (quote , item [index_datetime_nano ])
271
+ dividend_cache = self ._dividend_cache [symbol ]
272
+ # dividend_df 和 _data_chan 中取出的数据都是按时间升序排列的,因此可以使用归并算法
273
+ if dividend_cache ["last_dt" ] <= item [index_datetime_nano ]:
274
+ dividend_df = dividend_cache ["df" ]
275
+ dividend_df = dividend_df [dividend_df ["datetime" ].gt (item [index_datetime_nano ])]
276
+ dividend_cache ["df" ] = dividend_df
277
+ dividend_cache ["last_dt" ] = dividend_df ["datetime" ].iloc [0 ]
278
+ dividend_cache ["factor" ] = dividend_df ["factor" ].iloc [0 ]
279
+ if dividend_cache ["factor" ] != 1 :
280
+ item = item .copy ()
281
+ for c in cols : # datetime_nano
282
+ index = self ._csv_header .index (f"{ symbol } .{ c } " )
283
+ item [index ] = item [index ] * dividend_cache ["factor" ]
284
+ csv_writer .writerow (item )
285
+ if isinstance (self ._csv_file_name , asyncio .StreamWriter ):
286
+ await self ._csv_file_name .drain ()
271
287
finally :
272
- task .cancel ()
273
- await asyncio .gather (task , return_exceptions = True )
288
+ if isinstance (self ._csv_file_name , asyncio .StreamWriter ):
289
+ self ._csv_file_name .write_eof ()
290
+ else :
291
+ writer .close ()
292
+ # 这里 `await gen.aclose()` 实际测试代码与文档描述不符,无论文件全部下载完正常退出还是写文件过程中抛出例外退出,都没有再抛出任何例外
293
+ # https://docs.python.org/3/reference/expressions.html#agen.aclose
294
+ await gen .aclose ()
274
295
275
296
async def _timeout_handle (self , timeout , chart ):
276
297
await asyncio .sleep (timeout )
@@ -330,7 +351,7 @@ async def _download_data(self):
330
351
k = {} if tid == - 1 else serials [i ]["data" ].get (str (tid ), {})
331
352
for col in data_cols :
332
353
row .append (self ._get_value (k , col , self ._quote_list [i ]["price_decs" ]))
333
- await self . _data_chan . send ( row )
354
+ yield row
334
355
current_id += 1
335
356
self ._current_dt_nano = item ["datetime" ]
336
357
# 当前 id 已超出订阅范围, 需重新订阅后续数据
@@ -340,7 +361,6 @@ async def _download_data(self):
340
361
await self ._api ._send_chan .send (chart_info )
341
362
finally :
342
363
# 释放chart资源
343
- await self ._data_chan .close ()
344
364
await self ._api ._send_chan .send ({
345
365
"aid" : "set_chart" ,
346
366
"chart_id" : chart_info ["chart_id" ],
@@ -349,7 +369,7 @@ async def _download_data(self):
349
369
"view_width" : 2000 ,
350
370
})
351
371
timeout_task .cancel ()
352
- await timeout_task
372
+ await asyncio . gather ( timeout_task , return_exceptions = True )
353
373
354
374
def _get_headers (self ):
355
375
data_cols = self ._get_data_cols ()
0 commit comments