diff --git a/docs/openapi.yml b/docs/openapi.yml
index a819902cb6..98fe89a91f 100644
--- a/docs/openapi.yml
+++ b/docs/openapi.yml
@@ -494,6 +494,40 @@ components:
items:
type: string
description: "List of options for multiple_choice questions"
+ example:
+ - "Democratic"
+ - "Republican"
+ - "Libertarian"
+ - "Green"
+ - "Other"
+ all_options_ever:
+ type: array
+ items:
+ type: string
+          description: "List of every option that has ever been available for multiple_choice questions"
+ example:
+ - "Democratic"
+ - "Republican"
+ - "Libertarian"
+ - "Green"
+ - "Blue"
+ - "Other"
+ options_history:
+ type: array
+          description: "List of [ISO 8601 timestamp, options] pairs for multiple_choice questions; each entry records the options active from that timestamp onward"
+ items:
+ type: array
+ items:
+ oneOf:
+ - type: string
+ description: "ISO 8601 timestamp when the options became active"
+ - type: array
+ items:
+ type: string
+ description: "Options list active from this timestamp onward"
+ example:
+ - ["0001-01-01T00:00:00", ["a", "b", "c", "other"]]
+ - ["2026-10-22T16:00:00", ["a", "b", "c", "d", "other"]]
status:
type: string
enum: [ upcoming, open, closed, resolved ]
@@ -1306,6 +1340,7 @@ paths:
actual_close_time: "2020-11-01T00:00:00Z"
type: "numeric"
options: null
+ options_history: null
status: "resolved"
resolution: "77289125.94957079"
resolution_criteria: "Resolution Criteria Copy"
@@ -1479,6 +1514,7 @@ paths:
actual_close_time: "2015-12-15T03:34:00Z"
type: "binary"
options: null
+ options_history: null
status: "resolved"
possibilities:
type: "binary"
@@ -1548,6 +1584,16 @@ paths:
- "Libertarian"
- "Green"
- "Other"
+ all_options_ever:
+ - "Democratic"
+ - "Republican"
+ - "Libertarian"
+ - "Green"
+ - "Blue"
+ - "Other"
+ options_history:
+ - ["0001-01-01T00:00:00", ["Democratic", "Republican", "Libertarian", "Other"]]
+ - ["2026-10-22T16:00:00", ["Democratic", "Republican", "Libertarian", "Green", "Other"]]
status: "open"
possibilities: { }
resolution: null
diff --git a/misc/views.py b/misc/views.py
index ce428de8c3..74f6914dfb 100644
--- a/misc/views.py
+++ b/misc/views.py
@@ -113,7 +113,9 @@ def get_site_stats(request):
now_year = datetime.now().year
public_questions = Question.objects.filter_public()
stats = {
- "predictions": Forecast.objects.filter(question__in=public_questions).count(),
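+        # automatic (system-generated) forecasts come from splitting a regular
+        # forecast, so they are excluded from the public prediction count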
+ "predictions": Forecast.objects.filter(question__in=public_questions)
+ .exclude(source=Forecast.SourceChoices.AUTOMATIC)
+ .count(),
"questions": public_questions.count(),
"resolved_questions": public_questions.filter(actual_resolve_time__isnull=False)
.exclude(resolution__in=UnsuccessfulResolutionType)
diff --git a/posts/models.py b/posts/models.py
index 635e4a93ce..2a3f4d0e2d 100644
--- a/posts/models.py
+++ b/posts/models.py
@@ -810,7 +810,11 @@ def update_forecasts_count(self):
Update forecasts count cache
"""
- self.forecasts_count = self.forecasts.filter_within_question_period().count()
+ self.forecasts_count = (
+ self.forecasts.filter_within_question_period()
+ .exclude(source=Forecast.SourceChoices.AUTOMATIC)
+ .count()
+ )
self.save(update_fields=["forecasts_count"])
def update_forecasters_count(self):
diff --git a/questions/admin.py b/questions/admin.py
index dbefab5257..d12545aa6d 100644
--- a/questions/admin.py
+++ b/questions/admin.py
@@ -32,7 +32,12 @@ class QuestionAdmin(CustomTranslationAdmin, DynamicArrayMixin):
"curation_status",
"post_link",
]
- readonly_fields = ["post_link", "view_forecasts"]
+ readonly_fields = [
+ "post_link",
+ "view_forecasts",
+ "options",
+ "options_history",
+ ]
search_fields = [
"id",
"title_original",
diff --git a/questions/migrations/0013_forecast_source.py b/questions/migrations/0013_forecast_source.py
index ccd11208eb..4230d216bf 100644
--- a/questions/migrations/0013_forecast_source.py
+++ b/questions/migrations/0013_forecast_source.py
@@ -15,7 +15,7 @@ class Migration(migrations.Migration):
name="source",
field=models.CharField(
blank=True,
- choices=[("api", "Api"), ("ui", "Ui")],
+ choices=[("api", "Api"), ("ui", "Ui"), ("automatic", "Automatic")],
default="",
max_length=30,
null=True,
diff --git a/questions/migrations/0033_question_options_history.py b/questions/migrations/0033_question_options_history.py
new file mode 100644
index 0000000000..7c4b69a97b
--- /dev/null
+++ b/questions/migrations/0033_question_options_history.py
@@ -0,0 +1,50 @@
+# Generated by Django 5.1.13 on 2025-11-15 19:35
+from datetime import datetime
+
+
+import questions.models
+from django.db import migrations, models
+
+
+def initialize_options_history(apps, schema_editor):
+ Question = apps.get_model("questions", "Question")
+ questions = Question.objects.filter(options__isnull=False)
+ for question in questions:
+ if question.options:
+ question.options_history = [(datetime.min.isoformat(), question.options)]
+ Question.objects.bulk_update(questions, ["options_history"])
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("questions", "0032_alter_aggregateforecast_forecast_values_and_more"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="forecast",
+ name="source",
+ field=models.CharField(
+ blank=True,
+ choices=[("api", "Api"), ("ui", "Ui"), ("automatic", "Automatic")],
+ db_index=True,
+ default="",
+ max_length=30,
+ null=True,
+ ),
+ ),
+ migrations.AddField(
+ model_name="question",
+ name="options_history",
+ field=models.JSONField(
+ blank=True,
+                help_text="For Multiple Choice only.\nlist of tuples: (isoformat_datetime, options_list). (json stores them as lists)\nRecords the history of options over time.\nInitialized with (datetime.min.isoformat(), self.options) upon question creation.\nUpdated whenever options are changed.",
+ null=True,
+ validators=[questions.models.validate_options_history],
+ ),
+ ),
+ migrations.RunPython(
+ initialize_options_history, reverse_code=migrations.RunPython.noop
+ ),
+ ]
diff --git a/questions/models.py b/questions/models.py
index 3849b0f8e4..60edd13fda 100644
--- a/questions/models.py
+++ b/questions/models.py
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
+from django.core.exceptions import ValidationError
from django.db import models
from django.db.models import Count, QuerySet, Q, F, Exists, OuterRef
from django.utils import timezone
@@ -8,7 +9,7 @@
from sql_util.aggregates import SubqueryAggregate
from questions.constants import QuestionStatus
-from questions.types import AggregationMethod
+from questions.types import AggregationMethod, OptionsHistoryType
from scoring.constants import ScoreTypes
from users.models import User
from utils.models import TimeStampedModel, TranslatedModel
@@ -20,6 +21,27 @@
DEFAULT_INBOUND_OUTCOME_COUNT = 200
+def validate_options_history(value):
+    # Expect: [ (isoformat_str, [str, ...]), ... ] or equivalent
+ if not isinstance(value, list):
+ raise ValidationError("Must be a list.")
+ for i, item in enumerate(value):
+ if (
+ not isinstance(item, (list, tuple))
+ or len(item) != 2
+ or not isinstance(item[0], str)
+ or not isinstance(item[1], list)
+ or not all(isinstance(s, str) for s in item[1])
+ ):
+ raise ValidationError(f"Bad item at index {i}: {item!r}")
+ try:
+ datetime.fromisoformat(item[0])
+ except ValueError:
+ raise ValidationError(
+ f"Bad datetime format at index {i}: {item[0]!r}, must be isoformat string"
+ )
+
+
class QuestionQuerySet(QuerySet):
def annotate_forecasts_count(self):
return self.annotate(
@@ -197,8 +219,20 @@ class QuestionType(models.TextChoices):
)
unit = models.CharField(max_length=25, blank=True)
- # list of multiple choice option labels
- options = ArrayField(models.CharField(max_length=200), blank=True, null=True)
+ # multiple choice fields
+ options: list[str] | None = ArrayField(
+ models.CharField(max_length=200), blank=True, null=True
+ )
+ options_history: OptionsHistoryType | None = models.JSONField(
+ null=True,
+ blank=True,
+ validators=[validate_options_history],
+        help_text="""For Multiple Choice only.
+        list of tuples: (isoformat_datetime, options_list). (json stores them as lists)
+        Records the history of options over time.
+        Initialized with (datetime.min.isoformat(), self.options) upon question creation.
+        Updated whenever options are changed.""",
+ )
# Legacy field that will be removed
possibilities = models.JSONField(null=True, blank=True)
@@ -251,6 +285,9 @@ def save(self, **kwargs):
self.zero_point = None
if self.type != self.QuestionType.MULTIPLE_CHOICE:
self.options = None
+ if self.type == self.QuestionType.MULTIPLE_CHOICE and not self.options_history:
+ # initialize options history on first save
+ self.options_history = [(datetime.min.isoformat(), self.options or [])]
return super().save(**kwargs)
@@ -545,8 +582,11 @@ class Forecast(models.Model):
)
class SourceChoices(models.TextChoices):
- API = "api"
- UI = "ui"
+        API = "api"  # made via the API
+        UI = "ui"  # made via the web UI
+ # an automatically assigned forecast
+ # usually this means a regular forecast was split
+ AUTOMATIC = "automatic"
# logging the source of the forecast for data purposes
source = models.CharField(
@@ -555,6 +595,7 @@ class SourceChoices(models.TextChoices):
null=True,
choices=SourceChoices.choices,
default="",
+ db_index=True,
)
distribution_input = models.JSONField(
@@ -596,14 +637,16 @@ def get_prediction_values(self) -> list[float | None]:
return self.probability_yes_per_category
return self.continuous_cdf
- def get_pmf(self) -> list[float]:
+ def get_pmf(self, replace_none: bool = False) -> list[float]:
"""
- gets the PMF for this forecast, replacing None values with 0.0
- Not for serialization use (keep None values in that case)
+ gets the PMF for this forecast
+ replaces None values with 0.0 if replace_none is True
"""
if self.probability_yes:
return [1 - self.probability_yes, self.probability_yes]
if self.probability_yes_per_category:
+ if not replace_none:
+ return self.probability_yes_per_category
return [
v or 0.0 for v in self.probability_yes_per_category
] # replace None with 0.0
@@ -678,18 +721,20 @@ def get_cdf(self) -> list[float | None] | None:
return self.forecast_values
return None
- def get_pmf(self) -> list[float]:
+ def get_pmf(self, replace_none: bool = False) -> list[float | None]:
"""
- gets the PMF for this forecast, replacing None values with 0.0
- Not for serialization use (keep None values in that case)
+ gets the PMF for this forecast
+        replaces None values with 0.0 if replace_none is True
"""
# grab annotation if it exists for efficiency
question_type = getattr(self, "question_type", self.question.type)
- forecast_values = [
- v or 0.0 for v in self.forecast_values
- ] # replace None with 0.0
+ forecast_values = self.forecast_values
+ if question_type == Question.QuestionType.MULTIPLE_CHOICE:
+ if not replace_none:
+ return forecast_values
+ return [v or 0.0 for v in forecast_values] # replace None with 0.0
if question_type in QUESTION_CONTINUOUS_TYPES:
- cdf: list[float] = forecast_values
+ cdf: list[float] = forecast_values # type: ignore
pmf = [cdf[0]]
for i in range(1, len(cdf)):
pmf.append(cdf[i] - cdf[i - 1])
diff --git a/questions/serializers/common.py b/questions/serializers/common.py
index d514b59516..0f7b1d46d2 100644
--- a/questions/serializers/common.py
+++ b/questions/serializers/common.py
@@ -1,5 +1,6 @@
-import logging
from datetime import datetime, timezone as dt_timezone, timedelta
+from collections import Counter
+import logging
import numpy as np
from django.utils import timezone
@@ -17,10 +18,9 @@
AggregateForecast,
Forecast,
)
-from questions.serializers.aggregate_forecasts import (
- serialize_question_aggregations,
-)
-from questions.types import QuestionMovement
+from questions.serializers.aggregate_forecasts import serialize_question_aggregations
+from questions.services.multiple_choice_handlers import get_all_options_from_history
+from questions.types import OptionsHistoryType, QuestionMovement
from users.models import User
from utils.the_math.formulas import (
get_scaled_quartiles_from_cdf,
@@ -40,6 +40,7 @@ class QuestionSerializer(serializers.ModelSerializer):
actual_close_time = serializers.SerializerMethodField()
resolution = serializers.SerializerMethodField()
spot_scoring_time = serializers.SerializerMethodField()
+ all_options_ever = serializers.SerializerMethodField()
class Meta:
model = Question
@@ -58,6 +59,8 @@ class Meta:
"type",
# Multiple-choice Questions only
"options",
+ "all_options_ever",
+ "options_history",
"group_variable",
# Used for Group Of Questions to determine
# whether question is eligible for forecasting
@@ -122,6 +125,10 @@ def get_actual_close_time(self, question: Question):
return min(question.scheduled_close_time, question.actual_resolve_time)
return question.scheduled_close_time
+ def get_all_options_ever(self, question: Question):
+ if question.options_history:
+ return get_all_options_from_history(question.options_history)
+
def get_resolution(self, question: Question):
resolution = question.resolution
@@ -226,9 +233,85 @@ class Meta(QuestionWriteSerializer.Meta):
"cp_reveal_time",
)
+ def validate(self, data: dict):
+ data = super().validate(data)
+
+ if qid := data.get("id"):
+ question = Question.objects.get(id=qid)
+ if data.get("options") != question.options:
+ # if there are user forecasts, we can't update options this way
+ if question.user_forecasts.exists():
+                    raise ValidationError(
+ "Cannot update options through this endpoint while there are "
+ "user forecasts. "
+ "Instead, use /api/questions/update-mc-options/ or the UI on "
+ "the question detail page."
+ )
+
# TODO: add validation for updating continuous question bounds
+class MultipleChoiceOptionsUpdateSerializer(serializers.Serializer):
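+    """
+    Validates an update to a multiple choice question's options. Illustrative
+    payload, assuming the last label is the catch-all option:
+        {"options": ["a", "b", "c", "other"], "grace_period_end": "2026-01-01T00:00:00Z"}
+    grace_period_end is only required when adding options.
+    """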
+ options = serializers.ListField(child=serializers.CharField(), required=True)
+ grace_period_end = serializers.DateTimeField(required=False)
+
+ def validate_new_options(
+ self,
+ new_options: list[str],
+ options_history: OptionsHistoryType,
+ grace_period_end: datetime | None = None,
+ ):
+ ts, current_options = options_history[-1]
+ if new_options == current_options: # no change
+ return
+ if len(new_options) == len(current_options): # renaming
+            if any(v > 1 for v in Counter(new_options).values()):
+                raise ValidationError("new_options includes duplicate labels")
+            if timezone.now() < datetime.fromisoformat(ts).replace(
+                tzinfo=dt_timezone.utc
+            ):
+                raise ValidationError("options cannot change during a grace period")
+ elif len(new_options) < len(current_options): # deletion
+ if len(new_options) < 2:
+ raise ValidationError("Must have 2 or more options")
+ if new_options[-1] != current_options[-1]:
+ raise ValidationError("Cannot delete last option")
+            if any(opt not in current_options for opt in new_options):
+ raise ValidationError(
+ "options cannot change name while some are being deleted"
+ )
+ elif len(new_options) > len(current_options): # addition
+ if not grace_period_end or grace_period_end <= timezone.now():
+ raise ValidationError(
+ "grace_period_end must be in the future if adding options"
+ )
+ if new_options[-1] != current_options[-1]:
+ raise ValidationError("Cannot add option after last option")
+            if any(opt not in new_options for opt in current_options):
+ raise ValidationError(
+ "options cannot change name while some are being added"
+ )
+
+ def validate(self, data: dict) -> dict:
+ question: Question = self.context.get("question")
+ if not question:
+            raise ValidationError("question must be provided in context")
+
+ if question.type != Question.QuestionType.MULTIPLE_CHOICE:
+ raise ValidationError("question must be of multiple choice type")
+
+ options = data.get("options")
+ options_history = question.options_history
+ if not options or not options_history:
+ raise ValidationError(
+ "updating multiple choice questions requires options "
+ "and question must already have options_history"
+ )
+
+ grace_period_end = data.get("grace_period_end")
+ self.validate_new_options(options, options_history, grace_period_end)
+
+ return data
+
+
class ConditionalSerializer(serializers.ModelSerializer):
"""
Contains basic info about conditional questions
@@ -394,7 +477,7 @@ class ForecastWriteSerializer(serializers.ModelSerializer):
probability_yes = serializers.FloatField(allow_null=True, required=False)
probability_yes_per_category = serializers.DictField(
- child=serializers.FloatField(), allow_null=True, required=False
+ child=serializers.FloatField(allow_null=True), allow_null=True, required=False
)
continuous_cdf = serializers.ListField(
child=serializers.FloatField(),
@@ -435,21 +518,47 @@ def binary_validation(self, probability_yes):
)
return probability_yes
- def multiple_choice_validation(self, probability_yes_per_category, options):
+ def multiple_choice_validation(
+ self,
+ probability_yes_per_category: dict[str, float | None],
+ current_options: list[str],
+ options_history: OptionsHistoryType | None,
+ ):
if probability_yes_per_category is None:
raise serializers.ValidationError(
"probability_yes_per_category is required"
)
if not isinstance(probability_yes_per_category, dict):
raise serializers.ValidationError("Forecast must be a dictionary")
- if set(probability_yes_per_category.keys()) != set(options):
- raise serializers.ValidationError("Forecast must include all options")
- values = [float(probability_yes_per_category[option]) for option in options]
- if not all([0.001 <= v <= 0.999 for v in values]) or not np.isclose(
- sum(values), 1
- ):
+ if not set(current_options).issubset(set(probability_yes_per_category.keys())):
+ raise serializers.ValidationError(
+ f"Forecast must reflect current options: {current_options}"
+ )
+ all_options = get_all_options_from_history(options_history)
+ if not set(probability_yes_per_category.keys()).issubset(set(all_options)):
+ raise serializers.ValidationError(
+ "Forecast contains probabilities for unknown options"
+ )
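+        # e.g. with all_options ["a", "b", "c", "other"] and current_options
+        # ["a", "c", "other"], {"a": 0.6, "c": 0.2, "other": 0.2} is accepted and
+        # expanded below to [0.6, None, 0.2, 0.2] in all_options order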
+
+ values: list[float | None] = []
+ for option in all_options:
+ value = probability_yes_per_category.get(option, None)
+ if option in current_options:
+ if (value is None) or (not (0.001 <= value <= 0.999)):
+ raise serializers.ValidationError(
+ "Probabilities for current options must be between 0.001 and 0.999"
+ )
+ elif value is not None:
+ raise serializers.ValidationError(
+                    f"Probability for inactive option '{option}' must be null or absent"
+ )
+ values.append(value)
+ if not np.isclose(sum(filter(None, values)), 1):
raise serializers.ValidationError(
- "All probabilities must be between 0.001 and 0.999 and sum to 1.0"
+ "Forecast values must sum to 1.0. "
+ f"Received {probability_yes_per_category} which is interpreted as "
+ f"values: {values} representing {all_options} "
+ f"with current options {current_options}"
)
return values
@@ -556,7 +665,7 @@ def validate(self, data):
"provided for multiple choice questions"
)
data["probability_yes_per_category"] = self.multiple_choice_validation(
- probability_yes_per_category, question.options
+ probability_yes_per_category, question.options, question.options_history
)
else: # Continuous question
if probability_yes or probability_yes_per_category:
@@ -625,6 +734,21 @@ def serialize_question(
archived_scores = question.user_archived_scores
user_forecasts = question.request_user_forecasts
last_forecast = user_forecasts[-1] if user_forecasts else None
+    # if the user's latest forecast is pre-registered (future-dated),
+    # surface it as the current forecast in place of the one it will replace
+ if question.type == Question.QuestionType.MULTIPLE_CHOICE:
+ # Right now, Multiple Choice is the only type that can have pre-registered
+ # forecasts
+ if last_forecast and last_forecast.start_time > timezone.now():
+ user_forecasts = [
+ f for f in user_forecasts if f.start_time < timezone.now()
+ ]
+ if user_forecasts:
+ last_forecast.start_time = user_forecasts[-1].start_time
+ user_forecasts[-1] = last_forecast
+ else:
+ last_forecast.start_time = timezone.now()
+ user_forecasts = [last_forecast]
if (
last_forecast
and last_forecast.end_time
@@ -639,11 +763,7 @@ def serialize_question(
many=True,
).data,
"latest": (
- MyForecastSerializer(
- user_forecasts[-1],
- ).data
- if user_forecasts
- else None
+ MyForecastSerializer(last_forecast).data if last_forecast else None
),
"score_data": dict(),
}
diff --git a/questions/services/forecasts.py b/questions/services/forecasts.py
index 15aba16fa3..2616dc7f09 100644
--- a/questions/services/forecasts.py
+++ b/questions/services/forecasts.py
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict
-from datetime import timedelta
-from typing import cast, Iterable
+from datetime import datetime, timedelta, timezone as dt_timezone
+from typing import cast, Iterable, Literal
import sentry_sdk
from django.db import transaction
@@ -13,6 +13,7 @@
from posts.models import PostUserSnapshot, PostSubscription
from posts.services.subscriptions import create_subscription_cp_change
from posts.tasks import run_on_post_forecast
+from questions.services.multiple_choice_handlers import get_all_options_from_history
from scoring.models import Score
from users.models import User
from utils.cache import cache_per_object
@@ -34,21 +35,67 @@
def create_forecast(
*,
- question: Question = None,
- user: User = None,
- continuous_cdf: list[float] = None,
- probability_yes: float = None,
- probability_yes_per_category: list[float] = None,
- distribution_input=None,
+ question: Question,
+ user: User,
+ continuous_cdf: list[float] | None = None,
+ probability_yes: float | None = None,
+ probability_yes_per_category: list[float | None] | None = None,
+ distribution_input: dict | None = None,
+ end_time: datetime | None = None,
+ source: Forecast.SourceChoices | Literal[""] | None = None,
**kwargs,
):
now = timezone.now()
post = question.get_post()
+ source = source or ""
+
+ # delete all future-dated predictions, as this one will override them
+ Forecast.objects.filter(question=question, author=user, start_time__gt=now).delete()
+
+ # if the forecast to be created is for a multiple choice question during a grace
+    # period, we need to augment the forecast accordingly (possibly pre-register it)
+ if question.type == Question.QuestionType.MULTIPLE_CHOICE:
+ if not probability_yes_per_category:
+ raise ValueError("probability_yes_per_category required for MC questions")
+ options_history = question.options_history
+ if options_history and len(options_history) > 1:
+ period_end = datetime.fromisoformat(options_history[-1][0]).replace(
+ tzinfo=dt_timezone.utc
+ )
+ if period_end > now:
+ all_options = get_all_options_from_history(question.options_history)
+ prior_options = options_history[-2][1]
+ if end_time is None or end_time > period_end:
+ # create a pre-registration for the given forecast
+ Forecast.objects.create(
+ question=question,
+ author=user,
+ start_time=period_end,
+ end_time=end_time,
+ probability_yes_per_category=probability_yes_per_category,
+ post=post,
+ source=Forecast.SourceChoices.AUTOMATIC,
+ **kwargs,
+ )
+ end_time = period_end
+
+ prior_pmf: list[float | None] = [None] * len(all_options)
+ for i, (option, value) in enumerate(
+ zip(all_options, probability_yes_per_category)
+ ):
+ if value is None:
+ continue
+ if option in prior_options:
+ prior_pmf[i] = (prior_pmf[i] or 0.0) + value
+ else:
+ prior_pmf[-1] = (prior_pmf[-1] or 0.0) + value
+ probability_yes_per_category = prior_pmf
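+                # worked example (mirrors the view tests): with prior options
+                # ["a", "b", "other"] and new options ["a", "b", "c", "other"],
+                # a forecast {"a": 0.6, "b": 0.15, "c": 0.20, "other": 0.05} made
+                # during the grace period is stored now as [0.6, 0.15, None, 0.25]
+                # ("c" folded into the catch-all) and, when no earlier end_time is
+                # given, pre-registered unchanged starting at period_end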
forecast = Forecast.objects.create(
question=question,
author=user,
start_time=now,
+ end_time=end_time,
continuous_cdf=continuous_cdf,
probability_yes=probability_yes,
probability_yes_per_category=probability_yes_per_category,
@@ -56,6 +103,7 @@ def create_forecast(
distribution_input if question.type in QUESTION_CONTINUOUS_TYPES else None
),
post=post,
+ source=source,
**kwargs,
)
# tidy up all forecasts
diff --git a/questions/services/multiple_choice_handlers.py b/questions/services/multiple_choice_handlers.py
new file mode 100644
index 0000000000..092d49b5fa
--- /dev/null
+++ b/questions/services/multiple_choice_handlers.py
@@ -0,0 +1,225 @@
+from datetime import datetime, timezone as dt_timezone
+
+from django.db import transaction
+from django.db.models import Q
+from django.utils import timezone
+
+from questions.models import Question, Forecast
+from questions.types import OptionsHistoryType
+
+
+def get_all_options_from_history(
+ options_history: OptionsHistoryType | None,
+) -> list[str]:
+ """Returns the list of all options ever available. The last value in the list
+ is always the "catch-all" option.
+
+ example:
+ options_history = [
+ ("2020-01-01", ["a", "b", "other"]),
+ ("2020-01-02", ["a", "b", "c", "other"]),
+ ("2020-01-03", ["a", "c", "other"]),
+ ]
+ return ["a", "b", "c", "other"]
+ """
+ if not options_history:
+ raise ValueError("Cannot make master list from empty history")
+ designated_other_label = options_history[0][1][-1]
+ all_labels: list[str] = []
+ for _, options in options_history:
+ for label in options[:-1]:
+ if label not in all_labels:
+ all_labels.append(label)
+ return all_labels + [designated_other_label]
+
+
+def multiple_choice_rename_option(
+ question: Question,
+ old_option: str,
+ new_option: str,
+) -> Question:
+ """
+ Modifies question in place and returns it.
+ Renames multiple choice option in question options and options history.
+ """
+ if question.type != Question.QuestionType.MULTIPLE_CHOICE:
+ raise ValueError("Question must be multiple choice")
+ if not question.options or old_option not in question.options:
+ raise ValueError("Old option not found")
+ if new_option in question.options:
+ raise ValueError("New option already exists")
+ if not question.options_history:
+ raise ValueError("Options history is empty")
+
+ question.options = [
+ new_option if opt == old_option else opt for opt in question.options
+ ]
+ for i, (timestr, options) in enumerate(question.options_history):
+ question.options_history[i] = (
+ timestr,
+ [new_option if opt == old_option else opt for opt in options],
+ )
+
+ return question
+
+
+def multiple_choice_delete_options(
+ question: Question,
+ options_to_delete: list[str],
+ timestep: datetime | None = None,
+) -> Question:
+ """
+ Modifies question in place and returns it.
+ Deletes multiple choice options in question options.
+ Adds a new entry to options_history.
+ Slices all user forecasts at timestep.
+ Triggers recalculation of aggregates.
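+
+    Example (mirrors the tests): deleting "b" from ["a", "b", "other"] ends an
+    active forecast [0.2, 0.3, 0.5] at timestep and creates an AUTOMATIC forecast
+    [0.2, None, 0.8], with the deleted option's mass folded into the last
+    (catch-all) option.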
+ """
+ if not options_to_delete:
+ return question
+ timestep = timestep or timezone.now()
+ if question.type != Question.QuestionType.MULTIPLE_CHOICE:
+ raise ValueError("Question must be multiple choice")
+ if not question.options or not all(
+ [opt in question.options for opt in options_to_delete]
+ ):
+ raise ValueError("Option to delete not found")
+ if not question.options_history:
+ raise ValueError("Options history is empty")
+
+ if (
+ datetime.fromisoformat(question.options_history[-1][0]).replace(
+ tzinfo=dt_timezone.utc
+ )
+ > timestep
+ ):
+ raise ValueError("timestep is before the last options history entry")
+
+ # update question
+ new_options = [opt for opt in question.options if opt not in options_to_delete]
+ all_options = get_all_options_from_history(question.options_history)
+
+ question.options = new_options
+ question.options_history.append((timestep.isoformat(), new_options))
+ question.save()
+
+ # update user forecasts
+ user_forecasts = question.user_forecasts.filter(
+ Q(end_time__isnull=True) | Q(end_time__gt=timestep),
+ start_time__lt=timestep,
+ )
+ forecasts_to_create: list[Forecast] = []
+ for forecast in user_forecasts:
+ # get new PMF
+ previous_pmf = forecast.probability_yes_per_category
+ if len(previous_pmf) != len(all_options):
+ raise ValueError(
+ f"Forecast {forecast.id} PMF length does not match "
+ f"all options {all_options}"
+ )
+ new_pmf: list[float | None] = [None] * len(all_options)
+ for value, label in zip(previous_pmf, all_options):
+ if value is None:
+ continue
+ if label in new_options:
+ new_pmf[all_options.index(label)] = (
+ new_pmf[all_options.index(label)] or 0.0
+ ) + value
+ else:
+ new_pmf[-1] = (
+ new_pmf[-1] or 0.0
+ ) + value # add to catch-all last option
+
+ # slice forecast
+ if forecast.start_time >= timestep:
+ # forecast is completely after timestep, just update PMF
+ forecast.probability_yes_per_category = new_pmf
+ continue
+ forecasts_to_create.append(
+ Forecast(
+ question=question,
+ author=forecast.author,
+ start_time=timestep,
+ end_time=forecast.end_time,
+ probability_yes_per_category=new_pmf,
+ post=forecast.post,
+ source=Forecast.SourceChoices.AUTOMATIC, # mark as automatic forecast
+ )
+ )
+ forecast.end_time = timestep
+
+ with transaction.atomic():
+ Forecast.objects.bulk_update(
+ user_forecasts, ["end_time", "probability_yes_per_category"]
+ )
+ Forecast.objects.bulk_create(forecasts_to_create)
+
+ # trigger recalculation of aggregates
+ from questions.services.forecasts import build_question_forecasts
+
+ build_question_forecasts(question)
+
+ return question
+
+
+def multiple_choice_add_options(
+ question: Question,
+ options_to_add: list[str],
+ grace_period_end: datetime,
+ timestep: datetime | None = None,
+) -> Question:
+ """
+ Modifies question in place and returns it.
+ Adds multiple choice options in question options.
+ Adds a new entry to options_history.
+ Terminates all user forecasts at grace_period_end.
+ Triggers recalculation of aggregates.
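+
+    Example (mirrors the tests): adding "c" to ["a", "b", "other"] appends
+    (grace_period_end, ["a", "b", "c", "other"]) to options_history, rewrites an
+    existing forecast [0.2, 0.3, 0.5] as [0.2, 0.3, None, 0.5], and ends it at
+    grace_period_end.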
+ """
+ if not options_to_add:
+ return question
+ timestep = timestep or timezone.now()
+ if question.type != Question.QuestionType.MULTIPLE_CHOICE:
+ raise ValueError("Question must be multiple choice")
+ if not question.options or any([opt in question.options for opt in options_to_add]):
+ raise ValueError("Option to add already found")
+ if not question.options_history:
+ raise ValueError("Options history is empty")
+
+ if timestep > grace_period_end:
+        raise ValueError("grace_period_end must be at or after timestep")
+ if (
+ datetime.fromisoformat(question.options_history[-1][0]).replace(
+ tzinfo=dt_timezone.utc
+ )
+ > timestep
+ ):
+ raise ValueError("timestep is before the last options history entry")
+
+ # update question
+ new_options = question.options[:-1] + options_to_add + question.options[-1:]
+ question.options = new_options
+ question.options_history.append((grace_period_end.isoformat(), new_options))
+ question.save()
+
+ # update user forecasts
+ user_forecasts = question.user_forecasts.all()
+ for forecast in user_forecasts:
+ pmf = forecast.probability_yes_per_category
+ forecast.probability_yes_per_category = (
+ pmf[:-1] + [None] * len(options_to_add) + [pmf[-1]]
+ )
+ if forecast.start_time < grace_period_end and (
+ forecast.end_time is None or forecast.end_time > grace_period_end
+ ):
+ forecast.end_time = grace_period_end
+ with transaction.atomic():
+ Forecast.objects.bulk_update(
+ user_forecasts, ["probability_yes_per_category", "end_time"]
+ )
+
+ # trigger recalculation of aggregates
+ from questions.services.forecasts import build_question_forecasts
+
+ build_question_forecasts(question)
+
+ return question
diff --git a/questions/types.py b/questions/types.py
index 9556806b41..f87735e520 100644
--- a/questions/types.py
+++ b/questions/types.py
@@ -3,6 +3,8 @@
from django.db import models
from django.db.models import TextChoices
+OptionsHistoryType = list[tuple[str, list[str]]]
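+# e.g. [("0001-01-01T00:00:00", ["a", "b", "other"]),
+#       ("2026-10-22T16:00:00", ["a", "b", "c", "other"])]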
+
class Direction(TextChoices):
UNCHANGED = "unchanged"
diff --git a/scoring/score_math.py b/scoring/score_math.py
index fada04f0d1..546b19d310 100644
--- a/scoring/score_math.py
+++ b/scoring/score_math.py
@@ -20,7 +20,7 @@
@dataclass
class AggregationEntry:
- pmf: np.ndarray | list[float]
+ pmf: np.ndarray | list[float | None]
num_forecasters: int
timestamp: float
@@ -36,7 +36,7 @@ def get_geometric_means(
timesteps.add(forecast.end_time.timestamp())
for timestep in sorted(timesteps):
prediction_values = [
- f.get_pmf()
+ f.get_pmf(replace_none=True)
for f in forecasts
if f.start_time.timestamp() <= timestep
and (f.end_time is None or f.end_time.timestamp() > timestep)
@@ -84,9 +84,12 @@ def evaluate_forecasts_baseline_accuracy(
forecast_coverage = forecast_duration / total_duration
pmf = forecast.get_pmf()
if question_type in ["binary", "multiple_choice"]:
- forecast_score = (
- 100 * np.log(pmf[resolution_bucket] * len(pmf)) / np.log(len(pmf))
- )
+ # forecasts always have `None` assigned to MC options that aren't
+ # available at the time. Detecting these allows us to avoid trying to
+ # follow the question's options_history.
+ options_at_time = len([p for p in pmf if p is not None])
+ p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other
+ forecast_score = 100 * np.log(p * options_at_time) / np.log(options_at_time)
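+            # e.g. pmf=[0.6, 0.15, None, 0.25] with resolution_bucket=2 gives
+            # p=0.25 (read from "Other") and options_at_time=3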
else:
if resolution_bucket in [0, len(pmf) - 1]:
baseline = 0.05
@@ -116,8 +119,13 @@ def evaluate_forecasts_baseline_spot_forecast(
if start <= spot_forecast_timestamp < end:
pmf = forecast.get_pmf()
if question_type in ["binary", "multiple_choice"]:
+ # forecasts always have `None` assigned to MC options that aren't
+ # available at the time. Detecting these allows us to avoid trying to
+ # follow the question's options_history.
+ options_at_time = len([p for p in pmf if p is not None])
+ p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other
forecast_score = (
- 100 * np.log(pmf[resolution_bucket] * len(pmf)) / np.log(len(pmf))
+ 100 * np.log(p * options_at_time) / np.log(options_at_time)
)
else:
if resolution_bucket in [0, len(pmf) - 1]:
@@ -159,17 +167,21 @@ def evaluate_forecasts_peer_accuracy(
continue
pmf = forecast.get_pmf()
+ p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other
interval_scores: list[float | None] = []
for gm in geometric_mean_forecasts:
if forecast_start <= gm.timestamp < forecast_end:
- score = (
+ gmp = (
+ gm.pmf[resolution_bucket] or gm.pmf[-1]
+ ) # if None, read from Other
+ interval_score = (
100
* (gm.num_forecasters / (gm.num_forecasters - 1))
- * np.log(pmf[resolution_bucket] / gm.pmf[resolution_bucket])
+ * np.log(p / gmp)
)
if question_type in QUESTION_CONTINUOUS_TYPES:
- score /= 2
- interval_scores.append(score)
+ interval_score /= 2
+ interval_scores.append(interval_score)
else:
interval_scores.append(None)
@@ -218,10 +230,10 @@ def evaluate_forecasts_peer_spot_forecast(
)
if start <= spot_forecast_timestamp < end:
pmf = forecast.get_pmf()
+ p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other
+ gmp = gm.pmf[resolution_bucket] or gm.pmf[-1] # if None, read from Other
forecast_score = (
- 100
- * (gm.num_forecasters / (gm.num_forecasters - 1))
- * np.log(pmf[resolution_bucket] / gm.pmf[resolution_bucket])
+ 100 * (gm.num_forecasters / (gm.num_forecasters - 1)) * np.log(p / gmp)
)
if question_type in QUESTION_CONTINUOUS_TYPES:
forecast_score /= 2
@@ -260,11 +272,15 @@ def evaluate_forecasts_legacy_relative(
continue
pmf = forecast.get_pmf()
+ p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other
interval_scores: list[float | None] = []
for bf in baseline_forecasts:
if forecast_start <= bf.timestamp < forecast_end:
- score = np.log2(pmf[resolution_bucket] / bf.pmf[resolution_bucket])
- interval_scores.append(score)
+ bfp = (
+ bf.pmf[resolution_bucket] or bf.pmf[-1]
+ ) # if None, read from Other
+ interval_score = np.log2(p / bfp)
+ interval_scores.append(interval_score)
else:
interval_scores.append(None)
@@ -316,7 +332,7 @@ def evaluate_question(
if spot_forecast_time:
spot_forecast_timestamp = min(spot_forecast_time.timestamp(), actual_close_time)
- # We need all user forecasts to calculated GeoMean even
+ # We need all user forecasts to calculate GeoMean even
# if we're only scoring some or none of the users
user_forecasts = question.user_forecasts.all()
if only_include_user_ids:
diff --git a/tests/unit/test_questions/conftest.py b/tests/unit/test_questions/conftest.py
index 7f7ab29e4f..57ebbb3d20 100644
--- a/tests/unit/test_questions/conftest.py
+++ b/tests/unit/test_questions/conftest.py
@@ -9,6 +9,7 @@
__all__ = [
"question_binary",
+ "question_multiple_choice",
"question_numeric",
"conditional_1",
"question_binary_with_forecast_user_1",
@@ -28,6 +29,7 @@ def question_multiple_choice():
return create_question(
question_type=Question.QuestionType.MULTIPLE_CHOICE,
options=["a", "b", "c", "d"],
+ options_history=[("0001-01-01T00:00:00", ["a", "b", "c", "d"])],
)
diff --git a/tests/unit/test_questions/test_models.py b/tests/unit/test_questions/test_models.py
index ba405474ab..74c5e49b3f 100644
--- a/tests/unit/test_questions/test_models.py
+++ b/tests/unit/test_questions/test_models.py
@@ -43,3 +43,14 @@ def test_filter_within_question_period(
Forecast.objects.filter(id=f1.id).filter_within_question_period().exists()
== include
)
+
+
+def test_initialize_multiple_choice_question():
+ question = create_question(
+ question_type=Question.QuestionType.MULTIPLE_CHOICE,
+ options=["a", "b", "other"],
+ )
+ question.save()
+ assert (
+ question.options_history and question.options_history[0][1] == question.options
+ )
diff --git a/tests/unit/test_questions/test_services.py b/tests/unit/test_questions/test_services/test_lifecycle.py
similarity index 100%
rename from tests/unit/test_questions/test_services.py
rename to tests/unit/test_questions/test_services/test_lifecycle.py
diff --git a/tests/unit/test_questions/test_services/test_multiple_choice_handlers.py b/tests/unit/test_questions/test_services/test_multiple_choice_handlers.py
new file mode 100644
index 0000000000..b7f7840c6e
--- /dev/null
+++ b/tests/unit/test_questions/test_services/test_multiple_choice_handlers.py
@@ -0,0 +1,399 @@
+from datetime import datetime
+
+import pytest # noqa
+
+from questions.models import Question, Forecast
+from questions.services.multiple_choice_handlers import (
+ multiple_choice_add_options,
+ multiple_choice_delete_options,
+ multiple_choice_rename_option,
+)
+from tests.unit.utils import datetime_aware as dt
+from users.models import User
+
+
+@pytest.mark.parametrize(
+ "old_option,new_option,expect_success",
+ [
+ ("Option B", "Option D", True),
+ ("Option X", "Option Y", False), # old_option does not exist
+ ("Option A", "Option A", False), # new_option already exists
+ ],
+)
+def test_multiple_choice_rename_option(
+ question_multiple_choice, old_option, new_option, expect_success
+):
+ question = question_multiple_choice
+ question.options = ["Option A", "Option B", "Option C"]
+ question.save()
+
+ if not expect_success:
+ with pytest.raises(ValueError):
+ multiple_choice_rename_option(question, old_option, new_option)
+ return
+ updated_question = multiple_choice_rename_option(question, old_option, new_option)
+
+ assert old_option not in updated_question.options
+ assert new_option in updated_question.options
+ assert len(updated_question.options) == 3
+
+
+@pytest.mark.parametrize(
+ "initial_options,options_to_delete,forecasts,expected_forecasts,expect_success",
+ [
+ (["a", "b", "other"], ["b"], [], [], True), # simplest path
+ (["a", "b", "other"], ["c"], [], [], False), # try to remove absent item
+ (["a", "b", "other"], ["a", "b"], [], [], True), # remove two items
+ (
+ ["a", "b", "other"],
+ ["b"],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=dt(2025, 1, 1),
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ ),
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, None, 0.8],
+ source=Forecast.SourceChoices.AUTOMATIC,
+ ),
+ ],
+ True,
+ ), # happy path
+ (
+ ["a", "b", "c", "other"],
+ ["b", "c"],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.1, 0.4],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=dt(2025, 1, 1),
+ probability_yes_per_category=[0.2, 0.3, 0.1, 0.4],
+ ),
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, None, None, 0.8],
+ source=Forecast.SourceChoices.AUTOMATIC,
+ ),
+ ],
+ True,
+ ), # happy path removing 2
+ (
+ ["a", "b", "other"],
+ ["b"],
+ [
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.8],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.8],
+ ),
+ ],
+ True,
+ ), # forecast is at / after timestep
+ (
+ ["a", "b", "other"],
+ [],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ True,
+ ), # no effect
+ (
+ ["a", "b", "other"],
+ ["b"],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.8],
+ )
+ ],
+ [],
+ False,
+ ), # initial forecast is invalid
+ (
+ ["a", "b", "other"],
+ ["b"],
+ [
+ Forecast(
+ start_time=dt(2023, 1, 1),
+ end_time=dt(2024, 1, 1),
+ probability_yes_per_category=[0.6, 0.15, 0.25],
+ ),
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ ),
+ ],
+ [
+ Forecast(
+ start_time=dt(2023, 1, 1),
+ end_time=dt(2024, 1, 1),
+ probability_yes_per_category=[0.6, 0.15, 0.25],
+ ),
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=dt(2025, 1, 1),
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ ),
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, None, 0.8],
+ source=Forecast.SourceChoices.AUTOMATIC,
+ ),
+ ],
+ True,
+ ), # preserve previous forecasts
+ ],
+)
+def test_multiple_choice_delete_options(
+ question_multiple_choice: Question,
+ user1: User,
+ initial_options: list[str],
+ options_to_delete: list[str],
+ forecasts: list[Forecast],
+ expected_forecasts: list[Forecast],
+ expect_success: bool,
+):
+ question = question_multiple_choice
+ question.options = initial_options
+ question.options_history = [(datetime.min.isoformat(), initial_options)]
+ question.save()
+
+ timestep = dt(2025, 1, 1)
+ for forecast in forecasts:
+ forecast.author = user1
+ forecast.question = question
+ forecast.save()
+
+ if not expect_success:
+ with pytest.raises(ValueError):
+ multiple_choice_delete_options(
+ question, options_to_delete, timestep=timestep
+ )
+ return
+
+ multiple_choice_delete_options(question, options_to_delete, timestep=timestep)
+
+ question.refresh_from_db()
+ expected_options = [opt for opt in initial_options if opt not in options_to_delete]
+ assert question.options == expected_options
+ ts, options = question.options_history[-1]
+ assert ts == (
+ timestep.isoformat() if options_to_delete else datetime.min.isoformat()
+ )
+ assert options == expected_options
+
+ forecasts = question.user_forecasts.order_by("start_time")
+ assert len(forecasts) == len(expected_forecasts)
+ for f, e in zip(forecasts, expected_forecasts):
+ assert f.start_time == e.start_time
+ assert f.end_time == e.end_time
+ assert f.probability_yes_per_category == e.probability_yes_per_category
+ assert f.source == e.source
+
+
+@pytest.mark.parametrize(
+ "initial_options,options_to_add,grace_period_end,forecasts,expected_forecasts,"
+ "expect_success",
+ [
+ (["a", "b", "other"], ["c"], dt(2025, 1, 1), [], [], True), # simplest path
+        (["a", "b", "other"], ["b"], dt(2025, 1, 1), [], [], False),  # duplicate option
+ (["a", "b", "other"], ["c", "d"], dt(2025, 1, 1), [], [], True), # double add
+        # grace_period_end before timestep
+ (["a", "b", "other"], ["c"], dt(1900, 1, 1), [], [], False),
+ (
+ ["a", "b", "other"],
+ ["c"],
+ dt(2025, 1, 1),
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=dt(2025, 1, 1),
+ probability_yes_per_category=[0.2, 0.3, None, 0.5],
+ )
+ ],
+ True,
+ ), # happy path
+ (
+ ["a", "b", "other"],
+ ["c", "d"],
+ dt(2025, 1, 1),
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=dt(2025, 1, 1),
+ probability_yes_per_category=[0.2, 0.3, None, None, 0.5],
+ )
+ ],
+ True,
+ ), # happy path adding two options
+ (
+ ["a", "b", "other"],
+ ["c"],
+ dt(2025, 1, 1),
+ [
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2025, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, None, 0.5],
+ )
+ ],
+ True,
+ ), # forecast starts at /after grace_period_end
+ (
+ ["a", "b", "other"],
+ [],
+ dt(2025, 1, 1),
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ [
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ )
+ ],
+ True,
+ ), # no effect
+ (
+ ["a", "b", "other"],
+ ["c"],
+ dt(2025, 1, 1),
+ [
+ Forecast(
+ start_time=dt(2023, 1, 1),
+ end_time=dt(2024, 1, 1),
+ probability_yes_per_category=[0.6, 0.15, 0.25],
+ ),
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=None,
+ probability_yes_per_category=[0.2, 0.3, 0.5],
+ ),
+ ],
+ [
+ Forecast(
+ start_time=dt(2023, 1, 1),
+ end_time=dt(2024, 1, 1),
+ probability_yes_per_category=[0.6, 0.15, None, 0.25],
+ ),
+ Forecast(
+ start_time=dt(2024, 1, 1),
+ end_time=dt(2025, 1, 1),
+ probability_yes_per_category=[0.2, 0.3, None, 0.5],
+ ),
+ ],
+ True,
+ ), # edit all forecasts including old
+ ],
+)
+def test_multiple_choice_add_options(
+ question_multiple_choice: Question,
+ user1: User,
+ initial_options: list[str],
+ options_to_add: list[str],
+ grace_period_end: datetime,
+ forecasts: list[Forecast],
+ expected_forecasts: list[Forecast],
+ expect_success: bool,
+):
+ question = question_multiple_choice
+ question.options = initial_options
+ question.options_history = [(datetime.min.isoformat(), initial_options)]
+ question.save()
+
+ for forecast in forecasts:
+ forecast.author = user1
+ forecast.question = question
+ forecast.save()
+
+ if not expect_success:
+ with pytest.raises(ValueError):
+ multiple_choice_add_options(
+ question, options_to_add, grace_period_end, timestep=dt(2024, 7, 1)
+ )
+ return
+
+ multiple_choice_add_options(
+ question, options_to_add, grace_period_end, timestep=dt(2024, 7, 1)
+ )
+
+ question.refresh_from_db()
+ expected_options = initial_options[:-1] + options_to_add + initial_options[-1:]
+ assert question.options == expected_options
+ ts, options = question.options_history[-1]
+ assert ts == (
+ grace_period_end.isoformat() if options_to_add else datetime.min.isoformat()
+ )
+ assert options == expected_options
+
+ forecasts = question.user_forecasts.order_by("start_time")
+ assert len(forecasts) == len(expected_forecasts)
+ for f, e in zip(forecasts, expected_forecasts):
+ assert f.start_time == e.start_time
+ assert f.end_time == e.end_time
+ assert f.probability_yes_per_category == e.probability_yes_per_category
+ assert f.source == e.source
diff --git a/tests/unit/test_questions/test_views.py b/tests/unit/test_questions/test_views.py
index 2f009b1452..3e75a4f275 100644
--- a/tests/unit/test_questions/test_views.py
+++ b/tests/unit/test_questions/test_views.py
@@ -10,11 +10,13 @@
from posts.models import Post
from questions.models import Forecast, Question, UserForecastNotification
+from questions.types import OptionsHistoryType
from questions.tasks import check_and_schedule_forecast_widrawal_due_notifications
from tests.unit.test_posts.conftest import * # noqa
from tests.unit.test_posts.factories import factory_post
from tests.unit.test_questions.conftest import * # noqa
from tests.unit.test_questions.factories import create_question
+from users.models import User
class TestQuestionForecast:
@@ -75,30 +77,173 @@ def test_forecast_binary_invalid(self, post_binary_public, user1_client, props):
)
assert response.status_code == 400
+ @freeze_time("2025-01-01")
@pytest.mark.parametrize(
- "props",
+ "options_history,forecast_props,expected",
[
- {"probability_yes_per_category": {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4}},
+ (
+ [("0001-01-01T00:00:00", ["a", "other"])],
+ {
+ "probability_yes_per_category": {
+ "a": 0.6,
+ "other": 0.4,
+ },
+ "end_time": "2026-01-01",
+ },
+ [
+ Forecast(
+ probability_yes_per_category=[0.6, 0.4],
+ start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ ),
+ ],
+ ), # simple path
+ (
+ [("0001-01-01T00:00:00", ["a", "b", "other"])],
+ {
+ "probability_yes_per_category": {
+ "a": 0.6,
+ "b": 0.15,
+ "other": 0.25,
+ },
+ "end_time": "2026-01-01",
+ },
+ [
+ Forecast(
+ probability_yes_per_category=[0.6, 0.15, 0.25],
+ start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ ),
+ ],
+ ), # simple path 3 options
+ (
+ [
+ ("0001-01-01T00:00:00", ["a", "b", "other"]),
+ (datetime(2024, 1, 1).isoformat(), ["a", "other"]),
+ ],
+ {
+ "probability_yes_per_category": {
+ "a": 0.6,
+ "other": 0.4,
+ },
+ "end_time": "2026-01-01",
+ },
+ [
+ Forecast(
+ probability_yes_per_category=[0.6, None, 0.4],
+ start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ ),
+ ],
+ ), # option deletion
+ (
+ [
+ ("0001-01-01T00:00:00", ["a", "b", "other"]),
+ (datetime(2024, 1, 1).isoformat(), ["a", "b", "c", "other"]),
+ ],
+ {
+ "probability_yes_per_category": {
+ "a": 0.6,
+ "b": 0.15,
+ "c": 0.20,
+ "other": 0.05,
+ },
+ "end_time": "2026-01-01",
+ },
+ [
+ Forecast(
+ probability_yes_per_category=[0.6, 0.15, 0.20, 0.05],
+ start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ ),
+ ],
+ ), # option addition
+ (
+ [
+ ("0001-01-01T00:00:00", ["a", "b", "other"]),
+ (datetime(2026, 1, 1).isoformat(), ["a", "b", "c", "other"]),
+ ],
+ {
+ "probability_yes_per_category": {
+ "a": 0.6,
+ "b": 0.15,
+ "c": 0.20,
+ "other": 0.05,
+ },
+ },
+ [
+ Forecast(
+ probability_yes_per_category=[0.6, 0.15, None, 0.25],
+ start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ ),
+ Forecast(
+ probability_yes_per_category=[0.6, 0.15, 0.20, 0.05],
+ start_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=None,
+ source=Forecast.SourceChoices.AUTOMATIC,
+ ),
+ ],
+ ), # forecasting during a grace period
+ (
+ [
+ ("0001-01-01T00:00:00", ["a", "b", "other"]),
+ (datetime(2026, 1, 1).isoformat(), ["a", "b", "c", "other"]),
+ ],
+ {
+ "probability_yes_per_category": {
+ "a": 0.6,
+ "b": 0.15,
+ "c": 0.20,
+ "other": 0.05,
+ },
+ "end_time": "2027-01-01",
+ },
+ [
+ Forecast(
+ probability_yes_per_category=[0.6, 0.15, None, 0.25],
+ start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ ),
+ Forecast(
+ probability_yes_per_category=[0.6, 0.15, 0.20, 0.05],
+ start_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc),
+ end_time=datetime(2027, 1, 1, tzinfo=dt_timezone.utc),
+ source=Forecast.SourceChoices.AUTOMATIC,
+ ),
+ ],
+ ), # forecasting during a grace period with end time
],
)
def test_forecast_multiple_choice(
- self, post_multiple_choice_public, user1, user1_client, props
+ self,
+ post_multiple_choice_public: Post,
+ user1: User,
+ user1_client,
+ options_history: OptionsHistoryType,
+ forecast_props: dict,
+ expected: list[Forecast],
):
+ question = post_multiple_choice_public.question
+ question.options_history = options_history
+ question.options = options_history[-1][1]
+ question.save()
response = user1_client.post(
self.url,
- data=json.dumps(
- [{"question": post_multiple_choice_public.question.id, **props}]
- ),
+ data=json.dumps([{"question": question.id, **forecast_props}]),
content_type="application/json",
)
assert response.status_code == 201
- forecast = Forecast.objects.filter(
- question=post_multiple_choice_public.question, author=user1
- ).first()
- assert forecast
- assert forecast.probability_yes_per_category == list(
- props.get("probability_yes_per_category").values()
- )
+ forecasts = Forecast.objects.filter(
+ question=post_multiple_choice_public.question,
+ author=user1,
+ ).order_by("start_time")
+ assert len(forecasts) == len(expected)
+ for f, e in zip(forecasts, expected):
+ assert f.start_time == e.start_time
+ assert f.end_time == e.end_time
+ assert f.probability_yes_per_category == e.probability_yes_per_category
+ assert f.source == e.source
@pytest.mark.parametrize(
"props",
diff --git a/tests/unit/test_scoring/test_score_math.py b/tests/unit/test_scoring/test_score_math.py
index 23f5f78c71..652dcd9be3 100644
--- a/tests/unit/test_scoring/test_score_math.py
+++ b/tests/unit/test_scoring/test_score_math.py
@@ -47,7 +47,7 @@ def F(q=None, v=None, s=None, e=None):
return forecast
-def A(p: list[float] | None = None, n: int = 0, t: int | None = None):
+def A(p: list[float | None] | None = None, n: int = 0, t: int | None = None):
# Create an AggregationEntry object with basic values
# p: pmf
# n: number of forecasters
@@ -75,6 +75,11 @@ class TestScoreMath:
([F()] * 100, [A(n=100)]),
# maths
([F(v=0.7), F(v=0.8), F(v=0.9)], [A(p=[0.18171206, 0.79581144], n=3)]),
+ # multiple choice forecasts with placeholder 0s
+ (
+ [F(q=QT.MULTIPLE_CHOICE, v=[0.6, 0.15, None, 0.25])] * 2,
+ [A(n=2, p=[0.6, 0.15, 0.0, 0.25])],
+ ),
# start times
([F(), F(s=1)], [A(), A(t=1, n=2)]),
([F(), F(s=1), F(s=2)], [A(), A(t=1, n=2), A(t=2, n=3)]),
@@ -85,7 +90,7 @@ class TestScoreMath:
# numeric
(
[F(q=QT.NUMERIC), F(q=QT.NUMERIC)],
- [A(p=[0] + [1 / 200] * 200 + [0], n=2)],
+ [A(p=[0.0] + [1 / 200] * 200 + [0.0], n=2)],
),
(
[
@@ -103,7 +108,10 @@ def test_get_geometric_means(
result = get_geometric_means(forecasts)
assert len(result) == len(expected)
for ra, ea in zip(result, expected):
- assert all(round(r, 8) == round(e, 8) for r, e in zip(ra.pmf, ea.pmf))
+ assert all(
+ ((r == e) or (round(r, 8) == round(e, 8)))
+ for r, e in zip(ra.pmf, ea.pmf)
+ )
assert ra.num_forecasters == ea.num_forecasters
assert ra.timestamp == ea.timestamp
@@ -131,6 +139,37 @@ def test_get_geometric_means(
([F(v=0.9, s=5)], {}, [S(v=84.79969066 / 2, c=0.5)]), # half coverage
([F(v=2 ** (-1 / 2))], {}, [S(v=50)]),
([F(v=2 ** (-3 / 2))], {}, [S(v=-50)]),
+ # multiple choice w/ placeholder at index 2
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)],
+ )
+ ],
+ {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=0.0)],
+ ), # chosen to have a score of 0 for simplicity
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)],
+ )
+ ],
+ {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=50)],
+ ), # same score as index == 3 since None should read from "Other"
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)],
+ )
+ ],
+ {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=50)],
+ ), # chosen to have a score of 50 for simplicity
# numeric
(
[F(q=QT.NUMERIC)],
@@ -199,6 +238,37 @@ def test_evaluate_forecasts_baseline_accuracy(self, forecasts, args, expected):
([F(v=0.9, s=5)], {}, [S(v=84.79969066, c=1)]),
([F(v=2 ** (-1 / 2))], {}, [S(v=50)]),
([F(v=2 ** (-3 / 2))], {}, [S(v=-50)]),
+ # multiple choice w/ placeholder at index 2
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)],
+ )
+ ],
+ {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=0.0)],
+ ), # chosen to have a score of 0 for simplicity
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)],
+ )
+ ],
+ {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=50)],
+ ), # same score as index == 3 since None should read from "Other"
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)],
+ )
+ ],
+ {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=50)],
+ ), # chosen to have a score of 50 for simplicity
# numeric
(
[F(q=QT.NUMERIC)],
@@ -319,6 +389,64 @@ def test_evaluate_forecasts_baseline_spot_forecast(self, forecasts, args, expect
S(v=100 * (0.5 * 0 + 0.5 * np.log(0.9 / gmean([0.1, 0.5]))), c=0.5),
],
),
+ # multiple choice w/ placeholder at index 2
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[
+ 1 / 3,
+ 1 - (np.e ** (0.25) / 3) - 1 / 3,
+ None,
+ np.e ** (0.25) / 3,
+ ],
+ ),
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 / 3, None, 1 / 3],
+ ),
+ ],
+ {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=0), S(v=0)],
+ ), # chosen to have a score of 0 for simplicity
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[
+ 1 / 3,
+ 1 - (np.e ** (0.25) / 3) - 1 / 3,
+ None,
+ np.e ** (0.25) / 3,
+ ],
+ ),
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 / 3, None, 1 / 3],
+ ),
+ ],
+ {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=25), S(v=-25)],
+            ),  # same score as index == 3 since None should read from "Other"
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[
+ 1 / 3,
+ 1 - (np.e ** (0.25) / 3) - 1 / 3,
+ None,
+ np.e ** (0.25) / 3,
+ ],
+ ),
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 / 3, None, 1 / 3],
+ ),
+ ],
+ {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=25), S(v=-25)],
+ ), # chosen to have a score of 25 for simplicity
# TODO: add tests with base forecasts different from forecasts
],
)
@@ -403,6 +531,64 @@ def test_evaluate_forecasts_peer_accuracy(self, forecasts, args, expected):
{},
[S(v=100 * np.log(0.1 / 0.5)), S(v=100 * np.log(0.5 / 0.1)), S(c=0)],
),
+ # multiple choice w/ placeholder at index 2
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[
+ 1 / 3,
+ 1 - (np.e ** (0.25) / 3) - 1 / 3,
+ None,
+ np.e ** (0.25) / 3,
+ ],
+ ),
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 / 3, None, 1 / 3],
+ ),
+ ],
+ {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=0), S(v=0)],
+ ), # chosen to have a score of 0 for simplicity
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[
+ 1 / 3,
+ 1 - (np.e ** (0.25) / 3) - 1 / 3,
+ None,
+ np.e ** (0.25) / 3,
+ ],
+ ),
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 / 3, None, 1 / 3],
+ ),
+ ],
+ {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=25), S(v=-25)],
+ ), # same score as index == 3 since None should read from "Other"
+ (
+ [
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[
+ 1 / 3,
+ 1 - (np.e ** (0.25) / 3) - 1 / 3,
+ None,
+ np.e ** (0.25) / 3,
+ ],
+ ),
+ F(
+ q=QT.MULTIPLE_CHOICE,
+ v=[1 / 3, 1 / 3, None, 1 / 3],
+ ),
+ ],
+ {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE},
+ [S(v=25), S(v=-25)],
+ ), # chosen to have a score of 25 for simplicity
# TODO: add tests with base forecasts different from forecasts
],
)
diff --git a/tests/unit/test_utils/test_the_math/conftest.py b/tests/unit/test_utils/test_the_math/conftest.py
index b048040bbf..8f150a3813 100644
--- a/tests/unit/test_utils/test_the_math/conftest.py
+++ b/tests/unit/test_utils/test_the_math/conftest.py
@@ -1 +1,4 @@
-from tests.unit.test_questions.conftest import question_binary # noqa
+from tests.unit.test_questions.conftest import (
+    question_binary,  # noqa
+    question_multiple_choice,  # noqa
+)
diff --git a/tests/unit/test_utils/test_the_math/test_aggregations.py b/tests/unit/test_utils/test_the_math/test_aggregations.py
index 73aaa5119e..911c9b4594 100644
--- a/tests/unit/test_utils/test_the_math/test_aggregations.py
+++ b/tests/unit/test_utils/test_the_math/test_aggregations.py
@@ -23,6 +23,12 @@
GoldMedalistsAggregation,
JoinedBeforeDateAggregation,
SingleAggregation,
+ compute_weighted_semi_standard_deviations,
+)
+from utils.typing import (
+ ForecastValues,
+ ForecastsValues,
+ Weights,
)
@@ -46,6 +52,64 @@ def test_summarize_array(array, max_size, expceted_array):
class TestAggregations:
+ @pytest.mark.parametrize(
+ "forecasts_values, weights, expected",
+ [
+ (
+ [[0.5, 0.5]],
+ None,
+ ([0.0, 0.0], [0.0, 0.0]),
+ ), # Trivial
+ (
+ [
+ [0.5, 0.5],
+ [0.5, 0.5],
+ [0.5, 0.5],
+ ],
+ None,
+ ([0.0, 0.0], [0.0, 0.0]),
+        ), # 3 unwavering forecasts
+ (
+ [
+ [0.2, 0.8],
+ [0.5, 0.5],
+ [0.8, 0.2],
+ ],
+ None,
+ ([0.3, 0.3], [0.3, 0.3]),
+        ), # 3 differing forecasts
+ (
+ [
+ [0.6, 0.15, None, 0.25],
+ [0.6, 0.15, None, 0.25],
+ ],
+ None,
+ ([0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]),
+ ), # identical forecasts with placeholders
+ (
+ [
+ [0.4, 0.25, None, 0.35],
+ [0.6, 0.15, None, 0.25],
+ ],
+ None,
+ ([0.1, 0.05, 0.0, 0.05], [0.1, 0.05, 0.0, 0.05]),
+        ), # slightly different forecasts with placeholders
+ ],
+ )
+ def test_compute_weighted_semi_standard_deviations(
+ self,
+ forecasts_values: ForecastsValues,
+ weights: Weights | None,
+ expected: tuple[ForecastValues, ForecastValues],
+ ):
+ result = compute_weighted_semi_standard_deviations(forecasts_values, weights)
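+        # result and expected are (lower deviations, upper deviations) pairs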
+ rl, ru = result
+ el, eu = expected
+ for v, e in zip(rl, el):
+ np.testing.assert_approx_equal(v, e)
+ for v, e in zip(ru, eu):
+ np.testing.assert_approx_equal(v, e)
+
@pytest.mark.parametrize("aggregation_name", [Agg.method for Agg in AGGREGATIONS])
def test_aggregations_initialize(
self, question_binary: Question, aggregation_name: str
@@ -241,46 +305,120 @@ def test_aggregations_initialize(
histogram=None,
),
),
+ # Multiple choice with placeholders
+ (
+ {},
+ ForecastSet(
+ forecasts_values=[
+ [0.6, 0.15, None, 0.25],
+ [0.6, 0.25, None, 0.15],
+ ],
+ timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
+ forecaster_ids=[1, 2],
+ timesteps=[
+ datetime(2022, 1, 1, tzinfo=dt_timezone.utc),
+ datetime(2023, 1, 1, tzinfo=dt_timezone.utc),
+ ],
+ ),
+ True,
+ False,
+ AggregateForecast(
+ start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
+ method=AggregationMethod.UNWEIGHTED,
+ forecast_values=[0.6, 0.20, None, 0.20],
+ interval_lower_bounds=[0.6, 0.15, None, 0.15],
+ centers=[0.6, 0.20, None, 0.20],
+ interval_upper_bounds=[0.6, 0.25, None, 0.25],
+ means=[0.6, 0.20, None, 0.20],
+ forecaster_count=2,
+ ),
+ ),
+ (
+ {},
+ ForecastSet(
+ forecasts_values=[
+ [0.6, 0.15, None, 0.25],
+ [0.6, 0.25, None, 0.15],
+ [0.4, 0.35, None, 0.25],
+ ],
+ timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
+                forecaster_ids=[1, 2, 3],
+                timesteps=[
+                    datetime(2022, 1, 1, tzinfo=dt_timezone.utc),
+                    datetime(2023, 1, 1, tzinfo=dt_timezone.utc),
+                    datetime(2023, 6, 1, tzinfo=dt_timezone.utc),
+                ],
+ ),
+ True,
+ False,
+ AggregateForecast(
+ start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
+ method=AggregationMethod.UNWEIGHTED,
+ forecast_values=[
+ 0.5453965360072925,
+ 0.22730173199635367,
+ None,
+ 0.22730173199635367,
+ ],
+ interval_lower_bounds=[
+ 0.3635976906715284,
+ 0.1363810391978122,
+ None,
+ 0.1363810391978122,
+ ],
+ centers=[
+ 0.5453965360072925,
+ 0.22730173199635367,
+ None,
+ 0.22730173199635367,
+ ],
+ interval_upper_bounds=[
+ 0.5453965360072925,
+ 0.3182224247948951,
+ None,
+ 0.22730173199635367,
+ ],
+ means=[
+ 0.5333333333333333,
+ 0.25,
+ None,
+ 0.21666666666666667,
+ ],
+ forecaster_count=3,
+ ),
+ ),
],
)
def test_UnweightedAggregation(
self,
question_binary: Question,
+ question_multiple_choice: Question,
init_params: dict,
forecast_set: ForecastSet,
include_stats: bool,
histogram: bool,
expected: AggregateForecast,
):
- aggregation = UnweightedAggregation(question=question_binary, **init_params)
- new_aggregation = aggregation.calculate_aggregation_entry(
+ if len(forecast_set.forecasts_values[0]) == 2:
+ question = question_binary
+ else:
+ question = question_multiple_choice
+
+ aggregation = UnweightedAggregation(question=question, **init_params)
+ new_aggregation: AggregateForecast = aggregation.calculate_aggregation_entry(
forecast_set, include_stats, histogram
)
- assert new_aggregation.start_time == expected.start_time
- assert (
- new_aggregation.forecast_values == expected.forecast_values
- ) or np.allclose(new_aggregation.forecast_values, expected.forecast_values)
- assert new_aggregation.forecaster_count == expected.forecaster_count
- assert (
- new_aggregation.interval_lower_bounds == expected.interval_lower_bounds
- ) or np.allclose(
- new_aggregation.interval_lower_bounds, expected.interval_lower_bounds
- )
- assert (new_aggregation.centers == expected.centers) or np.allclose(
- new_aggregation.centers, expected.centers
- )
- assert (
- new_aggregation.interval_upper_bounds == expected.interval_upper_bounds
- ) or np.allclose(
- new_aggregation.interval_upper_bounds, expected.interval_upper_bounds
- )
- assert (new_aggregation.means == expected.means) or np.allclose(
- new_aggregation.means, expected.means
- )
- assert (new_aggregation.histogram == expected.histogram) or np.allclose(
- new_aggregation.histogram, expected.histogram
- )
+        assert new_aggregation.start_time == expected.start_time
+        assert new_aggregation.forecaster_count == expected.forecaster_count
+        for r, e in [
+ (new_aggregation.forecast_values, expected.forecast_values),
+ (new_aggregation.interval_lower_bounds, expected.interval_lower_bounds),
+ (new_aggregation.centers, expected.centers),
+ (new_aggregation.interval_upper_bounds, expected.interval_upper_bounds),
+ (new_aggregation.means, expected.means),
+ (new_aggregation.histogram, expected.histogram),
+ ]:
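+            # None placeholders become NaN so assert_allclose(equal_nan=True)
+            # treats matching placeholders as equal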
+ r = np.where(np.equal(r, None), np.nan, r).astype(float)
+ e = np.where(np.equal(e, None), np.nan, e).astype(float)
+            np.testing.assert_allclose(r, e, equal_nan=True)

@pytest.mark.parametrize(
"init_params, forecast_set, include_stats, histogram, expected",
@@ -468,20 +606,52 @@ def test_UnweightedAggregation(
histogram=None,
),
),
+ # Multiple choice with placeholders
+ (
+ {},
+ ForecastSet(
+ forecasts_values=[
+ [0.6, 0.15, None, 0.25],
+ [0.6, 0.25, None, 0.15],
+ ],
+ timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
+ forecaster_ids=[1, 2],
+ timesteps=[
+ datetime(2022, 1, 1, tzinfo=dt_timezone.utc),
+ datetime(2023, 1, 1, tzinfo=dt_timezone.utc),
+ ],
+ ),
+ True,
+ False,
+ AggregateForecast(
+ start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
+ method=AggregationMethod.UNWEIGHTED,
+ forecast_values=[0.6, 0.20, None, 0.20],
+ interval_lower_bounds=[0.6, 0.15, None, 0.15],
+ centers=[0.6, 0.20, None, 0.20],
+ interval_upper_bounds=[0.6, 0.25, None, 0.25],
+ means=[0.6, 0.20, None, 0.20],
+ forecaster_count=2,
+ ),
+ ),
],
)
def test_RecencyWeightedAggregation(
self,
question_binary: Question,
+ question_multiple_choice: Question,
init_params: dict,
forecast_set: ForecastSet,
include_stats: bool,
histogram: bool,
expected: AggregateForecast,
):
- aggregation = RecencyWeightedAggregation(
- question=question_binary, **init_params
- )
+ if len(forecast_set.forecasts_values[0]) == 2:
+ question = question_binary
+ else:
+ question = question_multiple_choice
+
+ aggregation = RecencyWeightedAggregation(question=question, **init_params)
new_aggregation = aggregation.calculate_aggregation_entry(
forecast_set, include_stats, histogram
)
diff --git a/tests/unit/test_utils/test_the_math/test_formulas.py b/tests/unit/test_utils/test_the_math/test_formulas.py
index 54f78dd357..30bb3d3e13 100644
--- a/tests/unit/test_utils/test_the_math/test_formulas.py
+++ b/tests/unit/test_utils/test_the_math/test_formulas.py
@@ -15,7 +15,12 @@ class TestFormulas:
binary_details = {"type": Question.QuestionType.BINARY}
multiple_choice_details = {
"type": Question.QuestionType.MULTIPLE_CHOICE,
- "options": ["A", "B", "C"],
+ "options": ["a", "c", "Other"],
+ "options_history": [
+ (0, ["a", "b", "Other"]),
+ (100, ["a", "Other"]),
+ (200, ["a", "c", "Other"]),
+ ],
}
numeric_details = {
"type": Question.QuestionType.NUMERIC,
@@ -57,8 +62,10 @@ class TestFormulas:
("", binary_details, None),
(None, binary_details, None),
# Multiple choice questions
- ("A", multiple_choice_details, 0),
- ("C", multiple_choice_details, 2),
+ ("a", multiple_choice_details, 0),
+ ("b", multiple_choice_details, 1),
+ ("c", multiple_choice_details, 2),
+ ("Other", multiple_choice_details, 3),
# Numeric questions
("below_lower_bound", numeric_details, 0),
("-2", numeric_details, 0),
diff --git a/tests/unit/test_utils/test_the_math/test_measures.py b/tests/unit/test_utils/test_the_math/test_measures.py
index b5ee3c8356..ab2273d2f8 100644
--- a/tests/unit/test_utils/test_the_math/test_measures.py
+++ b/tests/unit/test_utils/test_the_math/test_measures.py
@@ -56,14 +56,26 @@
(
[
[0.33, 0.33, 0.34],
- [0.0, 0.5, 0.5],
+ [0.01, 0.49, 0.5],
[0.4, 0.2, 0.4],
[0.2, 0.6, 0.2],
],
[0.1, 0.2, 0.3, 0.4],
[50.0],
- [[0.2, 0.5, 0.37]],
+ [[0.2, 0.49, 0.37]],
+ ),
+ (
+ [
+ [0.33, 0.33, None, 0.34],
+ [0.01, 0.49, None, 0.5],
+ [0.4, 0.2, None, 0.4],
+ [0.2, 0.6, None, 0.2],
+ ],
+ [0.1, 0.2, 0.3, 0.4],
+ [50.0],
+ [[0.2, 0.49, None, 0.37]],
),
+ # multiple choice options with placeholder values
],
)
def test_weighted_percentile_2d(values, weights, percentiles, expected_result):
@@ -73,7 +85,11 @@ def test_weighted_percentile_2d(values, weights, percentiles, expected_result):
result = weighted_percentile_2d(
values=values, weights=weights, percentiles=percentiles
)
- np.testing.assert_allclose(result, expected_result)
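+    # map None placeholders to NaN on both sides so allclose(equal_nan=True)
+    # treats them as equal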
+ result = np.where(np.equal(result, None), np.nan, result).astype(float)
+ expected_result = np.where(
+ np.equal(expected_result, None), np.nan, expected_result
+ ).astype(float)
+ np.testing.assert_allclose(result, expected_result, equal_nan=True)
if weights is None and [percentiles] == [50.0]: # should behave like np.median
numpy_medians = np.median(values, axis=0)
np.testing.assert_allclose(result, [numpy_medians])
@@ -95,6 +111,7 @@ def test_percent_point_function(cdf, percentiles, expected_result):
@pytest.mark.parametrize(
"p1, p2, question, expected_result",
[
+ # binary
(
[0.5, 0.5],
[0.5, 0.5],
@@ -107,6 +124,7 @@ def test_percent_point_function(cdf, percentiles, expected_result):
Question(type="binary"),
sum([-0.1 * np.log2(0.5 / 0.6), 0.1 * np.log2(0.5 / 0.4)]), # 0.05849625
),
+ # multiple choice
(
[0.5, 0.5],
[0.5, 0.5],
@@ -138,6 +156,54 @@ def test_percent_point_function(cdf, percentiles, expected_result):
]
), # 1.3169925
),
+ (
+ [0.2, 0.3, 0.5],
+ [0.2, 0.2, 0.6],
+ Question(type="multiple_choice"),
+ sum(
+ [
+ 0,
+ (0.3 - 0.2) * np.log2(0.3 / 0.2),
+ (0.5 - 0.6) * np.log2(0.5 / 0.6),
+ ]
+ ), # 0.0847996
+ ),
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.2, 0.3, None, 0.5],
+ Question(type="multiple_choice"),
+ 0.0,
+ ), # deal with Nones happily
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.2, 0.3, 0.1, 0.4],
+ Question(type="multiple_choice"),
+ 0.0,
+ ), # no difference across adding an option
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.2, 0.2, 0.1, 0.5],
+ Question(type="multiple_choice"),
+ sum(
+ [
+ 0,
+ (0.3 - 0.2) * np.log2(0.3 / 0.2),
+ (0.5 - 0.6) * np.log2(0.5 / 0.6),
+ ]
+ ), # 0.0847996
+ ), # difference across adding an option
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.1, None, 0.7, 0.2],
+ Question(type="multiple_choice"),
+ sum(
+ [
+ (0.2 - 0.1) * np.log2(0.2 / 0.1),
+ (0.8 - 0.9) * np.log2(0.8 / 0.9),
+ ]
+ ), # 0.1169925
+ ), # difference across removing and adding options
+ # continuous
(
[0.01, 0.5, 0.99],
[0.01, 0.5, 0.99],
@@ -214,6 +280,7 @@ def test_prediction_difference_for_sorting(p1, p2, question, expected_result):
@pytest.mark.parametrize(
"p1, p2, question, expected_result",
[
+ # binary
(
[0.5, 0.5],
[0.5, 0.5],
@@ -230,6 +297,7 @@ def test_prediction_difference_for_sorting(p1, p2, question, expected_result):
(-0.1, (2 / 3) / 1),
],
),
+ # multiple choice
(
[0.5, 0.5],
[0.5, 0.5],
@@ -270,6 +338,61 @@ def test_prediction_difference_for_sorting(p1, p2, question, expected_result):
(-0.3, (1 / 9) / (4 / 6)),
],
),
+ (
+ [0.2, 0.3, 0.5],
+ [0.2, 0.2, 0.6],
+ Question(type="multiple_choice"),
+ [
+ (0.0, (2 / 8) / (2 / 8)),
+ (-0.1, (2 / 8) / (3 / 7)),
+ (0.1, (6 / 4) / (5 / 5)),
+ ],
+ ),
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.2, 0.3, None, 0.5],
+ Question(type="multiple_choice"),
+ [
+ (0.0, (2 / 8) / (2 / 8)),
+ (0.0, (3 / 7) / (3 / 7)),
+ (0.0, 1.0),
+ (0.0, (5 / 5) / (5 / 5)),
+ ],
+ ), # deal with 0.0s happily
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.2, 0.3, 0.1, 0.4],
+ Question(type="multiple_choice"),
+ [
+ (0.0, (2 / 8) / (2 / 8)),
+ (0.0, (3 / 7) / (3 / 7)),
+ (0.0, 1.0),
+ (0.0, (5 / 5) / (5 / 5)),
+ ],
+ ), # no difference across adding an option
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.2, 0.2, 0.1, 0.5],
+ Question(type="multiple_choice"),
+ [
+ (0.0, (2 / 8) / (2 / 8)),
+ (-0.1, (2 / 8) / (3 / 7)),
+ (0.0, 1.0),
+ (0.1, (6 / 4) / (5 / 5)),
+ ],
+ ), # difference across adding an option
+ (
+ [0.2, 0.3, None, 0.5],
+ [0.1, None, 0.7, 0.2],
+ Question(type="multiple_choice"),
+ [
+ (-0.1, (1 / 9) / (2 / 8)),
+ (0.0, 1.0),
+ (0.0, 1.0),
+ (0.1, (9 / 1) / (8 / 2)),
+ ],
+ ), # difference across removing and adding options
+ # continuous
(
[0.0, 0.5, 1.0],
[0.0, 0.5, 1.0],
diff --git a/users/views.py b/users/views.py
index 993613cc96..15f843e161 100644
--- a/users/views.py
+++ b/users/views.py
@@ -326,7 +326,7 @@ def get_forecasting_stats_data(
)
if user is not None:
forecasts = forecasts.filter(author=user)
- forecasts_count = forecasts.count()
+ forecasts_count = forecasts.exclude(source=Forecast.SourceChoices.AUTOMATIC).count()
questions_predicted_count = forecasts.values("question").distinct().count()
score_count = len(scores)
diff --git a/utils/csv_utils.py b/utils/csv_utils.py
index 447bbe1d69..d096d4ad55 100644
--- a/utils/csv_utils.py
+++ b/utils/csv_utils.py
@@ -16,6 +16,7 @@
Forecast,
QUESTION_CONTINUOUS_TYPES,
)
+from questions.services.multiple_choice_handlers import get_all_options_from_history
from questions.types import AggregationMethod
from scoring.models import Score, ArchivedScore
from utils.the_math.aggregations import get_aggregation_history
@@ -328,7 +329,9 @@ def generate_data(
+ "**`Default Project ID`** - the id of the default project for the Post.\n"
+ "**`Label`** - for a group question, this is the sub-question object.\n"
+ "**`Question Type`** - the type of the question. Binary, Multiple Choice, Numeric, Discrete, or Date.\n"
- + "**`MC Options`** - the options for a multiple choice question, if applicable.\n"
+ + "**`MC Options (Current)`** - the current options for a multiple choice question, if applicable.\n"
+ + "**`MC Options (All)`** - the options for a multiple choice question across all time, if applicable.\n"
+        + "**`MC Options History`** - the history of options over time. Each entry is an ISO format time and the list of options in effect from that time.\n"
+ "**`Lower Bound`** - the lower bound of the forecasting range for a continuous question.\n"
+ "**`Open Lower Bound`** - whether the lower bound is open.\n"
+ "**`Upper Bound`** - the upper bound of the forecasting range for a continuous question.\n"
@@ -357,7 +360,9 @@ def generate_data(
"Default Project ID",
"Label",
"Question Type",
- "MC Options",
+ "MC Options (Current)",
+ "MC Options (All)",
+ "MC Options History",
"Lower Bound",
"Open Lower Bound",
"Upper Bound",
@@ -406,7 +411,13 @@ def format_value(val):
post.default_project_id,
question.label,
question.type,
- question.options or None,
+ question.options,
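+            # all options ever offered; aligns with Probability Yes Per Category
+            # in the forecast data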
+ (
+ get_all_options_from_history(question.options_history)
+ if question.options_history
+ else None
+ ),
+ question.options_history or None,
format_value(question.range_min),
question.open_lower_bound,
format_value(question.range_max),
@@ -446,7 +457,7 @@ def format_value(val):
+ "**`End Time`** - the time when the forecast ends. If not populated, the forecast is still active. Note that this can be set in the future indicating an expiring forecast.\n"
+ "**`Forecaster Count`** - if this is an aggregate forecast, how many forecasts contribute to it.\n"
+ "**`Probability Yes`** - the probability of the binary question resolving to 'Yes'\n"
- + "**`Probability Yes Per Category`** - a list of probabilities corresponding to each option for a multiple choice question. Cross-reference 'MC Options' in `question_data.csv`.\n"
+        + "**`Probability Yes Per Category`** - a list of probabilities corresponding to each option for a multiple choice question. Cross-reference 'MC Options (All)' in `question_data.csv`. Note that a Multiple Choice forecast will have None where the corresponding option was not available at the time of the forecast.\n"
+ "**`Continuous CDF`** - the value of the CDF (cumulative distribution function) at each of the locations in the continuous range for a continuous question. Cross-reference 'Continuous Range' in `question_data.csv`.\n"
+ "**`Probability Below Lower Bound`** - the probability of the question resolving below the lower bound for a continuous question.\n"
+ "**`Probability Above Upper Bound`** - the probability of the question resolving above the upper bound for a continuous question.\n"
diff --git a/utils/the_math/aggregations.py b/utils/the_math/aggregations.py
index 40c0193de3..60cfa26ae9 100644
--- a/utils/the_math/aggregations.py
+++ b/utils/the_math/aggregations.py
@@ -489,6 +489,9 @@ def get_range_values(
forecasts_values, weights, [25.0, 50.0, 75.0]
)
centers_array = np.array(centers)
+        centers_array[np.equal(centers_array, 0.0)] = 1.0  # avoid divide by zero
normalized_centers = np.array(aggregation_forecast_values)
normalized_lowers = np.array(lowers)
normalized_lowers[non_nones] = (
@@ -498,7 +501,7 @@ def get_range_values(
)
normalized_uppers = np.array(uppers)
normalized_uppers[non_nones] = (
- normalized_lowers[non_nones]
+ normalized_uppers[non_nones]
* normalized_centers[non_nones]
/ centers_array[non_nones]
)
@@ -641,9 +644,18 @@ def calculate_aggregation_entry(
Question.QuestionType.BINARY,
Question.QuestionType.MULTIPLE_CHOICE,
]:
- aggregation.means = np.average(
- forecast_set.forecasts_values, weights=weights, axis=0
- ).tolist()
+ forecasts_values = np.array(forecast_set.forecasts_values)
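+            # mask placeholder (None) columns as NaN so np.average propagates
+            # NaN there; those entries are converted back to None below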
+ nones = (
+ np.equal(forecasts_values[0], None)
+ if forecasts_values.size
+ else np.array([])
+ )
+ forecasts_values[:, nones] = np.nan
+ means = np.average(forecasts_values, weights=weights, axis=0).astype(
+ object
+ )
+ means[np.isnan(means.astype(float))] = None
+ aggregation.means = means.tolist()
if histogram and self.question.type in [
Question.QuestionType.BINARY,
diff --git a/utils/the_math/formulas.py b/utils/the_math/formulas.py
index 999444794c..d582039269 100644
--- a/utils/the_math/formulas.py
+++ b/utils/the_math/formulas.py
@@ -5,6 +5,7 @@
from questions.constants import UnsuccessfulResolutionType
from questions.models import Question
+from questions.services.multiple_choice_handlers import get_all_options_from_history
from utils.typing import ForecastValues
logger = logging.getLogger(__name__)
@@ -33,7 +34,8 @@ def string_location_to_scaled_location(
if question.type == Question.QuestionType.BINARY:
return 1.0 if string_location == "yes" else 0.0
if question.type == Question.QuestionType.MULTIPLE_CHOICE:
- return float(question.options.index(string_location))
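+        # index against the list of all options ever offered so locations stay
+        # stable when options are added or removed over time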
+ list_of_all_options = get_all_options_from_history(question.options_history)
+ return float(list_of_all_options.index(string_location))
# continuous
if string_location == "below_lower_bound":
return question.range_min - 1.0
diff --git a/utils/the_math/measures.py b/utils/the_math/measures.py
index e20bd381be..7edce08712 100644
--- a/utils/the_math/measures.py
+++ b/utils/the_math/measures.py
@@ -17,16 +17,17 @@ def weighted_percentile_2d(
percentiles: Percentiles = None,
) -> Percentiles:
values = np.array(values)
+ sorted_values = values.copy() # avoid side effects
+ # replace None with np.nan for calculations (return to None at the end)
+ sorted_values[np.equal(sorted_values, None)] = np.nan
+
if weights is None:
ordered_weights = np.ones_like(values)
else:
weights = np.array(weights)
- ordered_weights = weights[values.argsort(axis=0)]
+ ordered_weights = weights[sorted_values.argsort(axis=0)]
percentiles = np.array(percentiles or [50.0])
- sorted_values = values.copy() # avoid side effects
- # replace None with -1.0 for calculations (return to None at the end)
- sorted_values[np.equal(sorted_values, None)] = -1.0
sorted_values.sort(axis=0)
# get the normalized cumulative weights
@@ -52,10 +53,10 @@ def weighted_percentile_2d(
+ sorted_values[right_indexes, column_indicies]
)
)
- # replace -1.0 back to None
+ # replace np.nan back to None
weighted_percentiles = np.array(weighted_percentiles)
weighted_percentiles = np.where(
- weighted_percentiles == -1.0, None, weighted_percentiles
+ np.isnan(weighted_percentiles.astype(float)), None, weighted_percentiles
)
return weighted_percentiles.tolist()
@@ -104,10 +105,22 @@ def prediction_difference_for_sorting(
"""for binary and multiple choice, takes pmfs
for continuous takes cdfs"""
p1, p2 = np.array(p1), np.array(p2)
- p1[np.equal(p1, None)] = -1.0 # replace None with -1.0 for calculations
- p2[np.equal(p2, None)] = -1.0 # replace None with -1.0 for calculations
# Uses Jeffrey's Divergence
- if question_type in ["binary", "multiple_choice"]:
+ if question_type == Question.QuestionType.MULTIPLE_CHOICE:
+ # cover for Nones
+ p1_nones = np.equal(p1, None)
+ p2_nones = np.equal(p2, None)
+ never_nones = np.logical_not(p1_nones | p2_nones)
+ p1_new = p1[never_nones]
+ p2_new = p2[never_nones]
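+        # mass on options offered by only one forecast is folded into the last
+        # ("Other") bucket so both pmfs share the same support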
+ p1_new[-1] += sum(p1[~p1_nones & p2_nones])
+ p2_new[-1] += sum(p2[~p2_nones & p1_nones])
+ p1 = p1_new
+ p2 = p2_new
+ if question_type in [
+ Question.QuestionType.BINARY,
+ Question.QuestionType.MULTIPLE_CHOICE,
+ ]:
return sum([(p - q) * np.log2(p / q) for p, q in zip(p1, p2)])
cdf1 = np.array([1 - np.array(p1), p1])
cdf2 = np.array([1 - np.array(p2), p2])
@@ -123,14 +136,22 @@ def prediction_difference_for_display(
"""for binary and multiple choice, takes pmfs
for continuous takes cdfs"""
p1, p2 = np.array(p1), np.array(p2)
- p1[np.equal(p1, None)] = -1.0 # replace None with -1.0 for calculations
- p2[np.equal(p2, None)] = -1.0 # replace None with -1.0 for calculations
if question.type == "binary":
# single-item list of (pred diff, ratio of odds)
return [(p2[1] - p1[1], (p2[1] / (1 - p2[1])) / (p1[1] / (1 - p1[1])))]
elif question.type == "multiple_choice":
# list of (pred diff, ratio of odds)
- return [(q - p, (q / (1 - q)) / (p / (1 - p))) for p, q in zip(p1, p2)]
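+        # fold mass for options that are None in either forecast into the last
+        # ("Other") bucket; those slots then report no difference and odds ratio 1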
+ for p, q in zip(p1[:-1], p2[:-1]):
+ if p is None or q is None:
+ p1[-1] = (p1[-1] or 0.0) + (p or 0.0)
+ p2[-1] = (p2[-1] or 0.0) + (q or 0.0)
+ arr = []
+ for p, q in zip(p1, p2):
+ if p is None or q is None:
+ arr.append((0.0, 1.0))
+ else:
+ arr.append((q - p, (q / (1 - q)) / (p / (1 - p))))
+ return arr
# total earth mover's distance, assymmetric earth mover's distance
x_locations = unscaled_location_to_scaled_location(
np.linspace(0, 1, len(p1)), question