Skip to content

Commit 8446a3b

Browse files
shinny-pack and shinny-mayanqiong
authored and committed
Update Version 3.2.11
1 parent 41ee293 commit 8446a3b

File tree

8 files changed

+70
-44
lines changed

8 files changed

+70
-44
lines changed

PKG-INFO

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Metadata-Version: 2.1
22
Name: tqsdk
3-
Version: 3.2.10
3+
Version: 3.2.11
44
Summary: TianQin SDK
55
Home-page: https://www.shinnytech.com/tqsdk
66
Author: TianQin

doc/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@
4848
# built documents.
4949
#
5050
# The short X.Y version.
51-
version = u'3.2.10'
51+
version = u'3.2.11'
5252
# The full version, including alpha/beta/rc tags.
53-
release = u'3.2.10'
53+
release = u'3.2.11'
5454

5555
# The language for content autogenerated by Sphinx. Refer to documentation
5656
# for a list of supported languages.

doc/version.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
版本变更
44
=============================
5+
3.2.11 (2022/07/27)
6+
7+
* 增加:下载数据时 csv_file_name 参数支持 str / asyncio.StreamWriter 两种类型
8+
* 修复:vwap_table 手数计算错误的问题
9+
10+
511
3.2.10 (2022/07/20)
612

713
* 增加:增加中证 1000 指数,免费用户可获取该指数行情,参考文档 :ref:`mddatas`

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_tag(self):
3636

3737
setuptools.setup(
3838
name='tqsdk',
39-
version="3.2.10",
39+
version="3.2.11",
4040
description='TianQin SDK',
4141
author='TianQin',
4242
author_email='tianqincn@gmail.com',

tqsdk/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '3.2.10'
1+
__version__ = '3.2.11'

tqsdk/algorithm/time_table_generater.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def vwap_table(api: TqApi, symbol: str, target_pos: int, duration: float,
257257
volume_left = target_volume # 剩余手数
258258
percent_left = 1 # 剩余百分比
259259
for index, value in predicted_percent.items():
260-
volume = round(target_volume * (value / percent_left))
260+
volume = round(volume_left * (value / percent_left))
261261
volume_left -= volume
262262
percent_left -= value
263263
append_time_table = pd.DataFrame([

tqsdk/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def insert_order(self, symbol: str, direction: str, offset: str = "", volume: in
11271127
direction (str): "BUY" 或 "SELL"
11281128
11291129
offset (str): "OPEN", "CLOSE" 或 "CLOSETODAY" \
1130-
(上期所和原油分平今/平昨, 平今用"CLOSETODAY", 平昨用"CLOSE"; 其他交易所直接用"CLOSE" 按照交易所的规则平仓), \
1130+
(上期所和上期能源分平今/平昨, 平今用"CLOSETODAY", 平昨用"CLOSE"; 其他交易所直接用"CLOSE" 按照交易所的规则平仓), \
11311131
股票交易中该参数无需填写
11321132
11331133
volume (int): 下单交易数量, 期货为下单手数, A 股股票为股数

tqsdk/tools/downloader.py

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import csv
77
import os
88
from datetime import date, datetime
9+
from encodings.utf_8 import StreamWriter
910
from typing import Union, List, Optional
1011
import lzma
1112

@@ -42,7 +43,8 @@ class DataDownloader:
4243
"""
4344

4445
def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int, start_dt: Union[date, datetime],
45-
end_dt: Union[date, datetime], csv_file_name: str, adj_type: Union[str, None] = None) -> None:
46+
end_dt: Union[date, datetime], csv_file_name: Union[str, asyncio.StreamWriter],
47+
adj_type: Union[str, None] = None) -> None:
4648
"""
4749
创建历史数据下载器实例
4850
@@ -57,7 +59,10 @@ def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int,
5759
5860
end_dt (date/datetime): 结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
5961
60-
csv_file_name (str): 输出 csv 的文件名
62+
csv_file_name (str/StreamWriter): [必填]输出方式:
63+
* str : 输出 csv 的文件名
64+
65+
* StreamWriter: 直接将内容输出到 StreamWriter
6166
6267
adj_type (str/None): 复权计算方式,默认值为 None。"F" 为前复权;"B" 为后复权;None 表示不复权。只对股票、基金合约有效。
6368
@@ -111,7 +116,10 @@ def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int,
111116
if adj_type not in [None, "F", "B", "FORWARD", "BACK"]:
112117
raise Exception("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权)")
113118
self._adj_type = adj_type[0] if adj_type else adj_type
114-
self._csv_file_name = csv_file_name
119+
if isinstance(csv_file_name, str) or isinstance(csv_file_name, asyncio.StreamWriter):
120+
self._csv_file_name = csv_file_name
121+
else:
122+
raise Exception("csv_file_name 参数只支持 str | StreamWriter 类型")
115123
self._csv_header = self._get_headers()
116124
# 缓存合约对应的复权系数矩阵,每个合约只计算一次
117125
# 含义为截止 datetime 之前(不包含) 应使用 factor 复权
@@ -168,9 +176,12 @@ def _get_data_series(self) -> Optional[pandas.DataFrame]:
168176
"""
169177
if not self._task.done():
170178
return None
171-
if not self._data_series:
172-
self._data_series = pandas.read_csv(self._csv_file_name)
173-
return self._data_series
179+
if isinstance(self._csv_file_name, str):
180+
if not self._data_series:
181+
self._data_series = pandas.read_csv(self._csv_file_name)
182+
return self._data_series
183+
else:
184+
raise Exception('DataDownloader._get_data_series 接口仅支持 csv_file_name 參數为 str 时使用')
174185

175186
async def _ensure_dividend_factor(self, quote, timestamp):
176187
if quote.instrument_id not in self._dividend_cache:
@@ -234,8 +245,8 @@ async def _ensure_dividend_factor(self, quote, timestamp):
234245

235246
async def _run(self):
236247
self._quote_list = await self._api.get_quote_list(self._symbol_list)
237-
self._data_chan = TqChan(self._api)
238-
task = self._api.create_task(self._download_data())
248+
# 下载数据的 async generator
249+
gen = self._download_data()
239250
# cols 是复权需要重新计算的列名
240251
index_datetime_nano = self._csv_header.index("datetime_nano")
241252
if self._dur_nano != 0:
@@ -244,33 +255,43 @@ async def _run(self):
244255
cols = ["last_price", "highest", "lowest"]
245256
cols.extend(f"{x}{i}" for x in ["bid_price", "ask_price"] for i in range(1, 6))
246257
try:
247-
with open(self._csv_file_name, 'w', newline='') as csvfile:
248-
csv_writer = csv.writer(csvfile, dialect='excel')
249-
csv_writer.writerow(self._csv_header)
250-
async for item in self._data_chan:
251-
for quote in self._quote_list:
252-
symbol = quote.instrument_id
253-
if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
254-
# 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子
255-
# 前复权需要提前计算除权因子
256-
await self._ensure_dividend_factor(quote, item[index_datetime_nano])
257-
dividend_cache = self._dividend_cache[symbol]
258-
# dividend_df 和 _data_chan 中取出的数据都是按时间升序排列的,因此可以使用归并算法
259-
if dividend_cache["last_dt"] <= item[index_datetime_nano]:
260-
dividend_df = dividend_cache["df"]
261-
dividend_df = dividend_df[dividend_df["datetime"].gt(item[index_datetime_nano])]
262-
dividend_cache["df"] = dividend_df
263-
dividend_cache["last_dt"] = dividend_df["datetime"].iloc[0]
264-
dividend_cache["factor"] = dividend_df["factor"].iloc[0]
265-
if dividend_cache["factor"] != 1:
266-
item = item.copy()
267-
for c in cols: # datetime_nano
268-
index = self._csv_header.index(f"{symbol}.{c}")
269-
item[index] = item[index] * dividend_cache["factor"]
270-
csv_writer.writerow(item)
258+
if isinstance(self._csv_file_name, asyncio.StreamWriter):
259+
writer = StreamWriter(self._csv_file_name)
260+
else:
261+
writer = open(self._csv_file_name, 'w', newline='')
262+
csv_writer = csv.writer(writer, dialect='excel')
263+
csv_writer.writerow(self._csv_header)
264+
async for item in gen:
265+
for quote in self._quote_list:
266+
symbol = quote.instrument_id
267+
if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
268+
# 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子
269+
# 前复权需要提前计算除权因子
270+
await self._ensure_dividend_factor(quote, item[index_datetime_nano])
271+
dividend_cache = self._dividend_cache[symbol]
272+
# dividend_df 和 _data_chan 中取出的数据都是按时间升序排列的,因此可以使用归并算法
273+
if dividend_cache["last_dt"] <= item[index_datetime_nano]:
274+
dividend_df = dividend_cache["df"]
275+
dividend_df = dividend_df[dividend_df["datetime"].gt(item[index_datetime_nano])]
276+
dividend_cache["df"] = dividend_df
277+
dividend_cache["last_dt"] = dividend_df["datetime"].iloc[0]
278+
dividend_cache["factor"] = dividend_df["factor"].iloc[0]
279+
if dividend_cache["factor"] != 1:
280+
item = item.copy()
281+
for c in cols: # datetime_nano
282+
index = self._csv_header.index(f"{symbol}.{c}")
283+
item[index] = item[index] * dividend_cache["factor"]
284+
csv_writer.writerow(item)
285+
if isinstance(self._csv_file_name, asyncio.StreamWriter):
286+
await self._csv_file_name.drain()
271287
finally:
272-
task.cancel()
273-
await asyncio.gather(task, return_exceptions=True)
288+
if isinstance(self._csv_file_name, asyncio.StreamWriter):
289+
self._csv_file_name.write_eof()
290+
else:
291+
writer.close()
292+
# 这里 `await gen.aclose()` 实际测试代码与文档描述不符,无论文件全部下载完正常退出还是写文件过程中抛出例外退出,都没有再抛出任何例外
293+
# https://docs.python.org/3/reference/expressions.html#agen.aclose
294+
await gen.aclose()
274295

275296
async def _timeout_handle(self, timeout, chart):
276297
await asyncio.sleep(timeout)
@@ -330,7 +351,7 @@ async def _download_data(self):
330351
k = {} if tid == -1 else serials[i]["data"].get(str(tid), {})
331352
for col in data_cols:
332353
row.append(self._get_value(k, col, self._quote_list[i]["price_decs"]))
333-
await self._data_chan.send(row)
354+
yield row
334355
current_id += 1
335356
self._current_dt_nano = item["datetime"]
336357
# 当前 id 已超出订阅范围, 需重新订阅后续数据
@@ -340,7 +361,6 @@ async def _download_data(self):
340361
await self._api._send_chan.send(chart_info)
341362
finally:
342363
# 释放chart资源
343-
await self._data_chan.close()
344364
await self._api._send_chan.send({
345365
"aid": "set_chart",
346366
"chart_id": chart_info["chart_id"],
@@ -349,7 +369,7 @@ async def _download_data(self):
349369
"view_width": 2000,
350370
})
351371
timeout_task.cancel()
352-
await timeout_task
372+
await asyncio.gather(timeout_task, return_exceptions=True)
353373

354374
def _get_headers(self):
355375
data_cols = self._get_data_cols()

0 commit comments

Comments (0)