Skip to content

Commit 697fbc2

Browse files
committed
Fix equity smoothing
1 parent 30d60a3 commit 697fbc2

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

investing_algorithm_framework/services/backtesting/backtest_service.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import polars as pl
1212

1313
from investing_algorithm_framework.domain import BacktestRun, OrderType, \
14-
TimeUnit, Trade, OperationalException, BacktestDateRange, \
14+
TimeUnit, Trade, OperationalException, BacktestDateRange, TimeFrame, \
1515
Backtest, TradeStatus, PortfolioSnapshot, Order, OrderStatus, OrderSide, \
1616
Portfolio, DataType, generate_backtest_summary_metrics
1717
from investing_algorithm_framework.services.data_providers import \
@@ -157,16 +157,23 @@ def create_vector_backtest(
157157
# Build master index (union of all indices in signal dict)
158158
index = pd.Index([])
159159

160+
most_granular_ohlcv_data_source = \
161+
self._get_most_granular_ohlcv_data_source(strategy.data_sources)
162+
most_granular_ohlcv_data = self._data_provider_service.get_ohlcv_data(
163+
symbol=most_granular_ohlcv_data_source.symbol,
164+
start_date=backtest_date_range.start_date,
165+
end_date=backtest_date_range.end_date,
166+
pandas=True
167+
)
168+
160169
# Make sure to filter out the buy and sell signals that are before
161170
# the backtest start date
162171
buy_signals = {k: v[v.index >= backtest_date_range.start_date]
163172
for k, v in buy_signals.items()}
164173
sell_signals = {k: v[v.index >= backtest_date_range.start_date]
165174
for k, v in sell_signals.items()}
166175

167-
for sig in list(buy_signals.values()) + list(sell_signals.values()):
168-
index = index.union(sig.index)
169-
176+
index = index.union(most_granular_ohlcv_data.index)
170177
index = index.sort_values()
171178

172179
# Initialize trades and portfolio values
@@ -289,6 +296,7 @@ def create_vector_backtest(
289296
)
290297
last_trade = trade
291298
trades.append(trade)
299+
unallocated -= capital_for_trade
292300

293301
# If we are in a position, and we get a sell signal
294302
if current_signal == -1 and last_trade is not None:
@@ -319,6 +327,7 @@ def create_vector_backtest(
319327
"net_gain": net_gain_val
320328
}
321329
)
330+
unallocated += last_trade.available_amount * current_price
322331
last_trade = None
323332

324333
# Create portfolio snapshots
@@ -334,13 +343,17 @@ def create_vector_backtest(
334343

335344
# Datetime is the index for pandas DataFrame, find the
336345
# closest timestamp that is less than or equal to ts
337-
prices = ohlcv.loc[ohlcv.index <= ts, "Close"].values
338-
339-
if len(prices) == 0:
340-
# No price data for this timestamp
341-
price = trade.open_price
342-
else:
343-
price = prices[-1]
346+
# prices = ohlcv.loc[ohlcv.index <= ts, "Close"].values
347+
#
348+
# if len(prices) == 0:
349+
# # No price data for this timestamp
350+
# price = trade.open_price
351+
# else:
352+
# price = prices[-1]
353+
try:
354+
price = ohlcv.loc[:ts, "Close"].iloc[-1]
355+
except IndexError:
356+
continue # skip if no price yet
344357

345358
invested_value += trade.filled_amount * price
346359
total_value = invested_value + unallocated
@@ -546,10 +559,16 @@ def _get_most_granular_ohlcv_data_source(data_sources):
546559
The most granular data source.
547560
"""
548561
granularity_order = {
549-
TimeUnit.SECOND: 1,
550-
TimeUnit.MINUTE: 2,
551-
TimeUnit.HOUR: 3,
552-
TimeUnit.DAY: 4
562+
TimeFrame.ONE_MINUTE: 1,
563+
TimeFrame.FIVE_MINUTE: 5,
564+
TimeFrame.FIFTEEN_MINUTE: 15,
565+
TimeFrame.ONE_HOUR: 60,
566+
TimeFrame.TWO_HOUR: 120,
567+
TimeFrame.FOUR_HOUR: 240,
568+
TimeFrame.TWELVE_HOUR: 720,
569+
TimeFrame.ONE_DAY: 1440,
570+
TimeFrame.ONE_WEEK: 10080,
571+
TimeFrame.ONE_MONTH: 43200
553572
}
554573

555574
most_granular = None
@@ -564,8 +583,8 @@ def _get_most_granular_ohlcv_data_source(data_sources):
564583

565584
for source in ohlcv_data_sources:
566585

567-
if granularity_order[source.time_unit] < highest_granularity:
568-
highest_granularity = granularity_order[source.time_unit]
586+
if granularity_order[source.time_frame] < highest_granularity:
587+
highest_granularity = granularity_order[source.time_frame]
569588
most_granular = source
570589

571590
return most_granular

0 commit comments

Comments
 (0)