diff --git a/Math/BOLL.py b/Math/BOLL.py index 26322681d..b3c9635ec 100644 --- a/Math/BOLL.py +++ b/Math/BOLL.py @@ -1,28 +1,29 @@ import math - +# 定义一个辅助函数,用于避免零值 def _truncate(x): - return x if x != 0 else 1e-7 - + return x if x != 0 else 1e-7 # 如果 x 为 0,则返回一个很小的数 1e-7 +# 定义 BOLL_Metric 类,用于表示布林带指标 class BOLL_Metric: def __init__(self, ma, theta): - self.theta = _truncate(theta) - self.UP = ma + 2*theta - self.DOWN = _truncate(ma - 2*theta) - self.MID = ma - + self.theta = _truncate(theta) # 存储标准差,避免为零 + self.UP = ma + 2 * theta # 计算布林带上轨 + self.DOWN = _truncate(ma - 2 * theta) # 计算布林带下轨,避免为零 + self.MID = ma # 存储中轨(移动平均) +# 定义 BollModel 类,用于计算布林带 class BollModel: def __init__(self, N=20): - assert N > 1 - self.N = N - self.arr = [] + assert N > 1 # 确保 N 大于 1 + self.N = N # 设置周期 N + self.arr = [] # 存储价格值的列表 + # 添加新的价格值并计算布林带指标 def add(self, value) -> BOLL_Metric: - self.arr.append(value) - if len(self.arr) > self.N: - self.arr = self.arr[-self.N:] - ma = sum(self.arr)/len(self.arr) - theta = math.sqrt(sum((x-ma)**2 for x in self.arr) / len(self.arr)) - return BOLL_Metric(ma, theta) + self.arr.append(value) # 将新值添加到列表中 + if len(self.arr) > self.N: # 如果列表长度超过周期 N + self.arr = self.arr[-self.N:] # 只保留最近 N 个值 + ma = sum(self.arr) / len(self.arr) # 计算移动平均 + theta = math.sqrt(sum((x - ma) ** 2 for x in self.arr) / len(self.arr)) # 计算标准差 + return BOLL_Metric(ma, theta) # 返回计算得到的布林带指标对象 \ No newline at end of file diff --git a/Math/Demark.py b/Math/Demark.py index 82d9b05b8..bbe2b580c 100644 --- a/Math/Demark.py +++ b/Math/Demark.py @@ -1,70 +1,77 @@ import copy from dataclasses import dataclass from typing import List, Literal, Optional, TypedDict - from Common.CEnum import BI_DIR - +# 定义 K线数据类 @dataclass class C_KL: - idx: int - close: float - high: float - low: float + idx: int # K线索引 + close: float # 收盘价 + high: float # 最高价 + low: float # 最低价 + # 根据是否为收盘价和方向返回相应的值 def v(self, is_close: bool, _dir: BI_DIR) -> float: if is_close: return self.close return self.high if _dir == BI_DIR.UP else self.low - +# 定义布林带类型 T_DEMARK_TYPE = Literal['setup', 'countdown'] - +# 定义类型字典,用于存储布林带信息 class T_DEMARK_INDEX(TypedDict): - type: T_DEMARK_TYPE - dir: BI_DIR - idx: int - series: 'CDemarkSetup' - + type: T_DEMARK_TYPE # 类型(setup 或 countdown) + dir: BI_DIR # 方向 + idx: int # 索引 + series: 'CDemarkSetup' # 关联的 CDemarkSetup 对象 +# 定义 CDemarkIndex 类,用于管理布林带数据 class CDemarkIndex: def __init__(self): - self.data: List[T_DEMARK_INDEX] = [] + self.data: List[T_DEMARK_INDEX] = [] # 存储布林带数据的列表 + # 添加新的布林带数据 def add(self, _dir: BI_DIR, _type: T_DEMARK_TYPE, idx: int, series: 'CDemarkSetup'): self.data.append({"dir": _dir, "idx": idx, "type": _type, "series": series}) + # 获取所有 setup 类型的数据 def get_setup(self) -> List[T_DEMARK_INDEX]: return [info for info in self.data if info['type'] == 'setup'] + # 获取所有 countdown 类型的数据 def get_countdown(self) -> List[T_DEMARK_INDEX]: return [info for info in self.data if info['type'] == 'countdown'] + # 更新数据 def update(self, demark_index: 'CDemarkIndex'): self.data.extend(demark_index.data) - +# 定义 CDemarkCountdown 类,用于处理 countdown 逻辑 class CDemarkCountdown: def __init__(self, _dir: BI_DIR, kl_list: List[C_KL], TDST_peak: float): - self.dir = _dir - self.kl_list: List[C_KL] = copy.deepcopy(kl_list) - self.idx = 0 - self.TDST_peak = TDST_peak - self.finish = False + self.dir = _dir # 方向 + self.kl_list: List[C_KL] = copy.deepcopy(kl_list) # K线列表的深拷贝 + self.idx = 0 # 当前计数器索引 + self.TDST_peak = TDST_peak # TDST 峰值 + self.finish = False # 是否完成标志 + # 更新 K线并返回是否成功 def update(self, kl: C_KL) -> bool: if self.finish: return False - self.kl_list.append(kl) + self.kl_list.append(kl) # 添加新的 K线 if len(self.kl_list) <= CDemarkEngine.COUNTDOWN_BIAS: return False if self.idx == CDemarkEngine.MAX_COUNTDOWN: - self.finish = True + self.finish = True # 达到最大计数,标记为完成 return False + # 检查是否超出 TDST 峰值 if (self.dir == BI_DIR.DOWN and kl.high > self.TDST_peak) or (self.dir == BI_DIR.UP and kl.low < self.TDST_peak): self.finish = True return False + # 检查是否满足计数条件 if self.dir == BI_DIR.DOWN and self.kl_list[-1].close < self.kl_list[-1 - CDemarkEngine.COUNTDOWN_BIAS].v(CDemarkEngine.COUNTDOWN_CMP2CLOSE, self.dir): self.idx += 1 return True @@ -73,67 +80,73 @@ def update(self, kl: C_KL) -> bool: return True return False - +# 定义 CDemarkSetup 类,用于处理 setup 逻辑 class CDemarkSetup: def __init__(self, _dir: BI_DIR, kl_list: List[C_KL], pre_kl: C_KL): - self.dir = _dir - self.kl_list: List[C_KL] = copy.deepcopy(kl_list) - self.pre_kl = pre_kl # 跳空时用 - assert len(self.kl_list) == CDemarkEngine.SETUP_BIAS - self.countdown: Optional[CDemarkCountdown] = None - self.setup_finished = False - self.idx = 0 - self.TDST_peak: Optional[float] = None - + self.dir = _dir # 方向 + self.kl_list: List[C_KL] = copy.deepcopy(kl_list) # K线列表的深拷贝 + self.pre_kl = pre_kl # 前一根 K线,用于跳空时的比较 + assert len(self.kl_list) == CDemarkEngine.SETUP_BIAS # 确保 K线数量正确 + self.countdown: Optional[CDemarkCountdown] = None # 初始化 countdown + self.setup_finished = False # setup 是否完成的标志 + self.idx = 0 # 当前索引 + self.TDST_peak: Optional[float] = None # TDST 峰值 self.last_demark_index = CDemarkIndex() # 缓存用 + # 更新 K线并返回 CDemarkIndex def update(self, kl: C_KL) -> CDemarkIndex: - self.last_demark_index = CDemarkIndex() + self.last_demark_index = CDemarkIndex() # 初始化缓存 if not self.setup_finished: - self.kl_list.append(kl) + self.kl_list.append(kl) # 添加新的 K线 + # 根据方向判断是否完成 setup if self.dir == BI_DIR.DOWN: - if self.kl_list[-1].close < self.kl_list[-1-CDemarkEngine.SETUP_BIAS].v(CDemarkEngine.SETUP_CMP2CLOSE, self.dir): - self.add_setup() + if self.kl_list[-1].close < self.kl_list[-1 - CDemarkEngine.SETUP_BIAS].v(CDemarkEngine.SETUP_CMP2CLOSE, self.dir): + self.add_setup() # 添加 setup else: - self.setup_finished = True - elif self.kl_list[-1].close > self.kl_list[-1-CDemarkEngine.SETUP_BIAS].v(CDemarkEngine.SETUP_CMP2CLOSE, self.dir): - self.add_setup() + self.setup_finished = True # 标记为完成 + elif self.kl_list[-1].close > self.kl_list[-1 - CDemarkEngine.SETUP_BIAS].v(CDemarkEngine.SETUP_CMP2CLOSE, self.dir): + self.add_setup() # 添加 setup else: - self.setup_finished = True + self.setup_finished = True # 标记为完成 + # 如果达到 DEMARK_LEN 且没有完成 setup,初始化 countdown if self.idx == CDemarkEngine.DEMARK_LEN and not self.setup_finished and self.countdown is None: self.countdown = CDemarkCountdown(self.dir, self.kl_list[:-1], self.cal_TDST_peak()) + # 更新 countdown if self.countdown is not None and self.countdown.update(kl): self.last_demark_index.add(self.dir, 'countdown', self.countdown.idx, self) - return self.last_demark_index + return self.last_demark_index # 返回缓存的数据 + # 添加 setup 数据 def add_setup(self): self.idx += 1 self.last_demark_index.add(self.dir, 'setup', self.idx, self) + # 计算 TDST 峰值 def cal_TDST_peak(self) -> float: - assert len(self.kl_list) == CDemarkEngine.SETUP_BIAS+CDemarkEngine.DEMARK_LEN - arr = self.kl_list[CDemarkEngine.SETUP_BIAS:CDemarkEngine.SETUP_BIAS+CDemarkEngine.DEMARK_LEN] - assert len(arr) == CDemarkEngine.DEMARK_LEN + assert len(self.kl_list) == CDemarkEngine.SETUP_BIAS + CDemarkEngine.DEMARK_LEN # 确保 K线数量正确 + arr = self.kl_list[CDemarkEngine.SETUP_BIAS:CDemarkEngine.SETUP_BIAS + CDemarkEngine.DEMARK_LEN] # 获取相关 K线 + assert len(arr) == CDemarkEngine.DEMARK_LEN # 确保 K线数量正确 + # 根据方向计算峰值 if self.dir == BI_DIR.DOWN: - res = max(kl.high for kl in arr) + res = max(kl.high for kl in arr) # 计算最高价 if CDemarkEngine.TIAOKONG_ST and arr[0].high < self.pre_kl.close: - res = max(res, self.pre_kl.close) + res = max(res, self.pre_kl.close) # 跳空时更新峰值 else: - res = min(kl.low for kl in arr) + res = min(kl.low for kl in arr) # 计算最低价 if CDemarkEngine.TIAOKONG_ST and arr[0].low > self.pre_kl.close: - res = min(res, self.pre_kl.close) - self.TDST_peak = res - return res - + res = min(res, self.pre_kl.close) # 跳空时更新峰值 + self.TDST_peak = res # 设置 TDST 峰值 + return res # 返回峰值 +# 定义 CDemarkEngine 类,管理整个 Demark 指标的逻辑 class CDemarkEngine: - DEMARK_LEN = 9 - SETUP_BIAS = 4 - COUNTDOWN_BIAS = 2 - MAX_COUNTDOWN = 13 - TIAOKONG_ST = True # 第一根跳空时是否跟前一根的close比 - SETUP_CMP2CLOSE = True - COUNTDOWN_CMP2CLOSE = True + DEMARK_LEN = 9 # Demark 长度 + SETUP_BIAS = 4 # setup 偏移 + COUNTDOWN_BIAS = 2 # countdown 偏移 + MAX_COUNTDOWN = 13 # 最大 countdown 次数 + TIAOKONG_ST = True # 第一根跳空时是否跟前一根的 close 比较 + SETUP_CMP2CLOSE = True # setup 比较收盘价 + COUNTDOWN_CMP2CLOSE = True # countdown 比较收盘价 def __init__( self, @@ -145,6 +158,7 @@ def __init__( setup_cmp2close=True, countdown_cmp2close=True ): + # 初始化参数 CDemarkEngine.DEMARK_LEN = demark_len CDemarkEngine.SETUP_BIAS = setup_bias CDemarkEngine.COUNTDOWN_BIAS = countdown_bias @@ -152,56 +166,63 @@ def __init__( CDemarkEngine.TIAOKONG_ST = tiaokong_st CDemarkEngine.SETUP_CMP2CLOSE = setup_cmp2close CDemarkEngine.COUNTDOWN_CMP2CLOSE = countdown_cmp2close + self.kl_lst: List[C_KL] = [] # K线列表 + self.series: List[CDemarkSetup] = [] # setup 列表 - self.kl_lst: List[C_KL] = [] - self.series: List[CDemarkSetup] = [] - + # 更新 K线数据并返回 CDemarkIndex def update(self, idx: int, close: float, high: float, low: float) -> CDemarkIndex: - self.kl_lst.append(C_KL(idx, close, high, low)) - if len(self.kl_lst) <= CDemarkEngine.SETUP_BIAS+1: - return CDemarkIndex() - - if self.kl_lst[-1].close < self.kl_lst[-1-self.SETUP_BIAS].close: + self.kl_lst.append(C_KL(idx, close, high, low)) # 添加新的 K线 + if len(self.kl_lst) <= CDemarkEngine.SETUP_BIAS + 1: + return CDemarkIndex() # 如果 K线数量不足,返回空的 CDemarkIndex + # 判断当前 K线是否满足 setup 条件 + if self.kl_lst[-1].close < self.kl_lst[-1 - self.SETUP_BIAS].close: + # 如果没有正在进行的下跌 setup,则添加新的下跌 setup if not any(series.dir == BI_DIR.DOWN and not series.setup_finished for series in self.series): - self.series.append(CDemarkSetup(BI_DIR.DOWN, self.kl_lst[-CDemarkEngine.SETUP_BIAS-1:-1], self.kl_lst[-CDemarkEngine.SETUP_BIAS-2])) + self.series.append(CDemarkSetup(BI_DIR.DOWN, self.kl_lst[-CDemarkEngine.SETUP_BIAS - 1:-1], self.kl_lst[-CDemarkEngine.SETUP_BIAS - 2])) + # 标记已经完成的上升 series for series in self.series: if series.dir == BI_DIR.UP and series.countdown is None and not series.setup_finished: series.setup_finished = True - elif self.kl_lst[-1].close > self.kl_lst[-1-self.SETUP_BIAS].close: + elif self.kl_lst[-1].close > self.kl_lst[-1 - self.SETUP_BIAS].close: + # 如果没有正在进行的上升 setup,则添加新的上升 setup if not any(series.dir == BI_DIR.UP and not series.setup_finished for series in self.series): - self.series.append(CDemarkSetup(BI_DIR.UP, self.kl_lst[-CDemarkEngine.SETUP_BIAS-1:-1], self.kl_lst[-CDemarkEngine.SETUP_BIAS-2])) + self.series.append(CDemarkSetup(BI_DIR.UP, self.kl_lst[-CDemarkEngine.SETUP_BIAS - 1:-1], self.kl_lst[-CDemarkEngine.SETUP_BIAS - 2])) + # 标记已经完成的下跌 series for series in self.series: if series.dir == BI_DIR.DOWN and series.countdown is None and not series.setup_finished: series.setup_finished = True + self.clear() # 清理无效的 series + self.clean_series_from_setup_finish() # 清理已完成的 setup + result = self.cal_result() # 计算结果 + self.clear() # 再次清理 + return result # 返回结果 - self.clear() - self.clean_series_from_setup_finish() - - result = self.cal_result() - self.clear() - return result - + # 计算最终的 CDemarkIndex def cal_result(self) -> CDemarkIndex: - demark_index = CDemarkIndex() + demark_index = CDemarkIndex() # 创建新的 CDemarkIndex for series in self.series: - demark_index.update(series.last_demark_index) - return demark_index + demark_index.update(series.last_demark_index) # 更新 demark_index + return demark_index # 返回结果 + # 清理无效的 series def clear(self): + # 清理已完成但没有 countdown 的 series invalid_series = [series for series in self.series if series.setup_finished and series.countdown is None] for s in invalid_series: self.series.remove(s) + # 清理已完成 countdown 的 series invalid_series = [series for series in self.series if series.countdown is not None and series.countdown.finish] for s in invalid_series: self.series.remove(s) + # 清理已完成的 setup def clean_series_from_setup_finish(self): - finished_setup: Optional[int] = None + finished_setup: Optional[int] = None # 标记完成的 setup for series in self.series: - demark_idx = series.update(self.kl_lst[-1]) + demark_idx = series.update(self.kl_lst[-1]) # 更新 series for setup_idx in demark_idx.get_setup(): if setup_idx['idx'] == CDemarkEngine.DEMARK_LEN: - assert finished_setup is None - finished_setup = id(series) + assert finished_setup is None # 确保只有一个完成的 setup + finished_setup = id(series) # 记录完成的 setup if finished_setup is not None: - self.series = [series for series in self.series if id(series) == finished_setup] + self.series = [series for series in self.series if id(series) == finished_setup] # 只保留完成的 setup \ No newline at end of file diff --git a/Math/KDJ.py b/Math/KDJ.py index 0fad01aed..e5a282cff 100644 --- a/Math/KDJ.py +++ b/Math/KDJ.py @@ -1,34 +1,41 @@ class KDJ_Item: def __init__(self, k, d, j): - self.k = k - self.d = d - self.j = j - + self.k = k # K值 + self.d = d # D值 + self.j = j # J值 class KDJ: def __init__(self, period: int = 9): super(KDJ, self).__init__() - self.arr = [] - self.period = period - self.pre_kdj = KDJ_Item(50, 50, 50) + self.arr = [] # 存储高低价的列表 + self.period = period # KDJ计算的周期,默认为9 + self.pre_kdj = KDJ_Item(50, 50, 50) # 初始化前一个KDJ值,默认值为50 + # 添加新的高、低、收盘价,并计算当前的KDJ值 def add(self, high, low, close) -> KDJ_Item: self.arr.append({ - 'high': high, - 'low': low, + 'high': high, # 添加当前的最高价 + 'low': low, # 添加当前的最低价 }) + # 如果存储的价格数据超过周期,则移除最早的数据 if len(self.arr) > self.period: self.arr.pop(0) - hn = max([x['high'] for x in self.arr]) - ln = min([x['low'] for x in self.arr]) - cn = close + # 计算当前周期内的最高价和最低价 + hn = max([x['high'] for x in self.arr]) # 当前周期内的最高价 + ln = min([x['low'] for x in self.arr]) # 当前周期内的最低价 + cn = close # 当前的收盘价 + + # 计算RSV值 rsv = 100 * (cn - ln) / (hn - ln) if hn != ln else 0.0 - cur_k = 2 / 3 * self.pre_kdj.k + 1 / 3 * rsv - cur_d = 2 / 3 * self.pre_kdj.d + 1 / 3 * cur_k - cur_j = 3 * cur_k - 2 * cur_d + # 计算当前的K、D、J值 + cur_k = 2 / 3 * self.pre_kdj.k + 1 / 3 * rsv # 当前K值 + cur_d = 2 / 3 * self.pre_kdj.d + 1 / 3 * cur_k # 当前D值 + cur_j = 3 * cur_k - 2 * cur_d # 当前J值 + + # 创建当前的KDJ项 cur_kdj = KDJ_Item(cur_k, cur_d, cur_j) - self.pre_kdj = cur_kdj + self.pre_kdj = cur_kdj # 更新前一个KDJ值 - return cur_kdj + return cur_kdj # 返回当前的KDJ值 \ No newline at end of file diff --git a/Math/MACD.py b/Math/MACD.py index 15b6fd625..6fd9d57e6 100644 --- a/Math/MACD.py +++ b/Math/MACD.py @@ -1,29 +1,35 @@ from typing import List - +# 定义 CMACD_item 类,表示 MACD 指标的一个数据项 class CMACD_item: def __init__(self, fast_ema, slow_ema, DIF, DEA): - self.fast_ema = fast_ema - self.slow_ema = slow_ema - self.DIF = DIF - self.DEA = DEA - self.macd = 2 * (DIF - DEA) - + self.fast_ema = fast_ema # 快速指数移动平均 + self.slow_ema = slow_ema # 慢速指数移动平均 + self.DIF = DIF # DIF 值 + self.DEA = DEA # DEA 值 + self.macd = 2 * (DIF - DEA) # 计算 MACD 值 +# 定义 CMACD 类,表示 MACD 指标的计算 class CMACD: def __init__(self, fastperiod=12, slowperiod=26, signalperiod=9): - self.macd_info: List[CMACD_item] = [] - self.fastperiod = fastperiod - self.slowperiod = slowperiod - self.signalperiod = signalperiod + self.macd_info: List[CMACD_item] = [] # 存储 MACD 数据项的列表 + self.fastperiod = fastperiod # 快速 EMA 的周期 + self.slowperiod = slowperiod # 慢速 EMA 的周期 + self.signalperiod = signalperiod # 信号线的周期 + # 添加新的价格值并计算相应的 MACD 数据项 def add(self, value) -> CMACD_item: - if not self.macd_info: + if not self.macd_info: # 如果列表为空,初始化第一个数据项 self.macd_info.append(CMACD_item(fast_ema=value, slow_ema=value, DIF=0, DEA=0)) else: + # 计算快速 EMA _fast_ema = (2 * value + (self.fastperiod - 1) * self.macd_info[-1].fast_ema) / (self.fastperiod + 1) + # 计算慢速 EMA _slow_ema = (2 * value + (self.slowperiod - 1) * self.macd_info[-1].slow_ema) / (self.slowperiod + 1) + # 计算 DIF _dif = _fast_ema - _slow_ema + # 计算 DEA _dea = (2 * _dif + (self.signalperiod - 1) * self.macd_info[-1].DEA) / (self.signalperiod + 1) + # 添加新的 MACD 数据项到列表中 self.macd_info.append(CMACD_item(fast_ema=_fast_ema, slow_ema=_slow_ema, DIF=_dif, DEA=_dea)) - return self.macd_info[-1] + return self.macd_info[-1] # 返回最新的 MACD 数据项 \ No newline at end of file diff --git a/Math/RSI.py b/Math/RSI.py index 31ccd140a..39ab33b30 100644 --- a/Math/RSI.py +++ b/Math/RSI.py @@ -1,29 +1,41 @@ class RSI: def __init__(self, period: int = 14): super(RSI, self).__init__() - self.close_arr = [] - self.period = period - self.diff = [] - self.up = [] - self.down = [] + self.close_arr = [] # 存储收盘价的列表 + self.period = period # RSI计算的周期,默认为14 + self.diff = [] # 存储收盘价的差值 + self.up = [] # 存储上涨的平均值 + self.down = [] # 存储下跌的平均值 + # 添加新的收盘价并计算当前的RSI值 def add(self, close): - self.close_arr.append(close) + self.close_arr.append(close) # 将新的收盘价添加到列表中 + + # 如果是第一个收盘价,返回默认的RSI值50.0 if len(self.close_arr) == 1: return 50.0 + + # 计算当前收盘价与前一个收盘价的差值 self.diff.append(self.close_arr[-1] - self.close_arr[-2]) + + # 如果差值列表的长度小于周期,计算初始的上涨和下跌平均值 if len(self.diff) < self.period: - self.up.append(sum(x for x in self.diff if x > 0)/self.period) - self.down.append(sum(-x for x in self.diff if x < 0)/self.period) + self.up.append(sum(x for x in self.diff if x > 0) / self.period) # 计算上涨平均值 + self.down.append(sum(-x for x in self.diff if x < 0) / self.period) # 计算下跌平均值 else: + # 根据当前差值更新上涨和下跌平均值 if self.diff[-1] > 0: - upval = self.diff[-1] - downval = 0.0 + upval = self.diff[-1] # 当前上涨值 + downval = 0.0 # 当前下跌值为0 else: - upval = 0.0 - downval = -self.diff[-1] + upval = 0.0 # 当前上涨值为0 + downval = -self.diff[-1] # 当前下跌值 + + # 更新上涨和下跌的平均值 self.up.append((self.up[-1] * (self.period - 1) + upval) / self.period) self.down.append((self.down[-1] * (self.period - 1) + downval) / self.period) - rs = self.up[-1] / self.down[-1] if self.down[-1] != 0 else 0 - rsi = 100.0 - 100.0 / (1.0 + rs) - return rsi + + # 计算相对强弱指数RSI + rs = self.up[-1] / self.down[-1] if self.down[-1] != 0 else 0 # 避免除以零 + rsi = 100.0 - 100.0 / (1.0 + rs) # 计算RSI值 + return rsi # 返回计算得到的RSI值 \ No newline at end of file diff --git a/Math/TrendLine.py b/Math/TrendLine.py index 60ec487ef..a2f123d4f 100644 --- a/Math/TrendLine.py +++ b/Math/TrendLine.py @@ -1,75 +1,81 @@ import copy from dataclasses import dataclass from math import sqrt - from Common.CEnum import BI_DIR, TREND_LINE_SIDE - +# 定义一个点类 @dataclass class Point: - x: int - y: float + x: int # 点的 x 坐标 + y: float # 点的 y 坐标 + # 计算当前点与另一个点的斜率 def cal_slope(self, p): - return (self.y-p.y)/(self.x-p.x) if self.x != p.x else float("inf") - + return (self.y - p.y) / (self.x - p.x) if self.x != p.x else float("inf") # 避免除以零 +# 定义一条直线类 @dataclass class Line: - p: Point - slope: float + p: Point # 直线上的一个点 + slope: float # 直线的斜率 + # 计算点到直线的距离 def cal_dis(self, p): - return abs(self.slope*p.x - p.y + self.p.y - self.slope*self.p.x) / sqrt(self.slope**2 + 1) - + return abs(self.slope * p.x - p.y + self.p.y - self.slope * self.p.x) / sqrt(self.slope**2 + 1) +# 定义趋势线类 class CTrendLine: def __init__(self, lst, side=TREND_LINE_SIDE.OUTSIDE): - self.line = None - self.side = side - self.cal(lst) + self.line = None # 初始化趋势线 + self.side = side # 趋势线的侧面(内部或外部) + self.cal(lst) # 计算趋势线 + # 计算趋势线 def cal(self, lst): - bench = float('inf') + bench = float('inf') # 初始化基准值为无穷大 + # 根据趋势线的侧面选择点 if self.side == TREND_LINE_SIDE.INSIDE: - all_p = [Point(bi.get_begin_klu().idx, bi.get_begin_val()) for bi in lst[-1::-2]] + all_p = [Point(bi.get_begin_klu().idx, bi.get_begin_val()) for bi in lst[-1::-2]] # 选择内部点 else: - all_p = [Point(bi.get_end_klu().idx, bi.get_end_val()) for bi in lst[-1::-2]] - c_p = copy.copy(all_p) + all_p = [Point(bi.get_end_klu().idx, bi.get_end_val()) for bi in lst[-1::-2]] # 选择外部点 + + c_p = copy.copy(all_p) # 复制所有点 while True: - line, idx = cal_tl(c_p, lst[-1].dir, self.side) - dis = sum(line.cal_dis(p) for p in all_p) - if dis < bench: - bench = dis - self.line = line - c_p = c_p[idx:] - if len(c_p) == 1: + line, idx = cal_tl(c_p, lst[-1].dir, self.side) # 计算趋势线 + dis = sum(line.cal_dis(p) for p in all_p) # 计算所有点到趋势线的距离之和 + if dis < bench: # 如果当前距离小于基准值 + bench = dis # 更新基准值 + self.line = line # 更新趋势线 + c_p = c_p[idx:] # 更新点列表 + if len(c_p) == 1: # 如果只剩下一个点,停止循环 break - +# 初始化峰值斜率 def init_peak_slope(_dir, side): if side == TREND_LINE_SIDE.INSIDE: - return 0 + return 0 # 内部趋势线的初始斜率为0 elif _dir == BI_DIR.UP: - return float("inf") + return float("inf") # 向上趋势线的初始斜率为无穷大 else: - return -float("inf") - + return -float("inf") # 向下趋势线的初始斜率为负无穷大 +# 计算趋势线 def cal_tl(c_p, _dir, side): - p = c_p[0] - peak_slope = init_peak_slope(_dir, side) - idx = 1 - for point_idx, p2 in enumerate(c_p[1:]): - slope = p.cal_slope(p2) + p = c_p[0] # 取第一个点 + peak_slope = init_peak_slope(_dir, side) # 初始化峰值斜率 + idx = 1 # 初始化索引 + for point_idx, p2 in enumerate(c_p[1:]): # 遍历剩余的点 + slope = p.cal_slope(p2) # 计算当前点与下一个点的斜率 + # 根据方向过滤斜率 if (_dir == BI_DIR.UP and slope < 0) or (_dir == BI_DIR.DOWN and slope > 0): continue + # 根据侧面更新峰值斜率 if side == TREND_LINE_SIDE.INSIDE: if (_dir == BI_DIR.UP and slope > peak_slope) or (_dir == BI_DIR.DOWN and slope < peak_slope): - peak_slope = slope - idx = point_idx+1 + peak_slope = slope # 更新峰值斜率 + idx = point_idx + 1 # 更新索引 else: if (_dir == BI_DIR.UP and slope < peak_slope) or (_dir == BI_DIR.DOWN and slope > peak_slope): - peak_slope = slope - idx = point_idx+1 - return Line(p, peak_slope), idx + peak_slope = slope # 更新峰值斜率 + idx = point_idx + 1 # 更新索引 + return Line(p, peak_slope), idx # 返回计算得到的趋势线和索引 \ No newline at end of file diff --git a/Math/TrendModel.py b/Math/TrendModel.py index ee63a10a0..bc6a116ca 100644 --- a/Math/TrendModel.py +++ b/Math/TrendModel.py @@ -1,22 +1,27 @@ from Common.CEnum import TREND_TYPE from Common.ChanException import CChanException, ErrCode - +# 定义趋势模型类 class CTrendModel: def __init__(self, trend_type: TREND_TYPE, T: int): - self.T = T - self.arr = [] - self.type = trend_type + self.T = T # 设置窗口大小 T + self.arr = [] # 初始化存储值的列表 + self.type = trend_type # 设置趋势类型 + # 添加新值并计算相应的趋势值 def add(self, value) -> float: - self.arr.append(value) + self.arr.append(value) # 将新值添加到列表中 + # 如果列表长度超过窗口大小 T,则只保留最近 T 个值 if len(self.arr) > self.T: self.arr = self.arr[-self.T:] - if self.type == TREND_TYPE.MEAN: - return sum(self.arr)/len(self.arr) - elif self.type == TREND_TYPE.MAX: - return max(self.arr) - elif self.type == TREND_TYPE.MIN: - return min(self.arr) + + # 根据趋势类型计算并返回相应的值 + if self.type == TREND_TYPE.MEAN: # 如果趋势类型是均值 + return sum(self.arr) / len(self.arr) # 返回均值 + elif self.type == TREND_TYPE.MAX: # 如果趋势类型是最大值 + return max(self.arr) # 返回最大值 + elif self.type == TREND_TYPE.MIN: # 如果趋势类型是最小值 + return min(self.arr) # 返回最小值 else: - raise CChanException(f"Unknown trendModel Type = {self.type}", ErrCode.PARA_ERROR) + # 如果趋势类型未知,抛出异常 + raise CChanException(f"Unknown trendModel Type = {self.type}", ErrCode.PARA_ERROR) \ No newline at end of file