Skip to content

Commit 9f9507d

Browse files
committed
Prevent showencryptedfieldsmap from creating data keys
1 parent 51ac968 commit 9f9507d

File tree

5 files changed

+47
-14
lines changed

5 files changed

+47
-14
lines changed

django_mongodb_backend/management/commands/showencryptedfieldsmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ def handle(self, *args, **options):
3030
for app_config in apps.get_app_configs():
3131
for model in router.get_migratable_models(app_config, db):
3232
if model_has_encrypted_fields(model):
33-
fields = editor._get_encrypted_fields(model)
33+
fields = editor._get_encrypted_fields(model, create_data_keys=False)
3434
encrypted_fields_map[model._meta.db_table] = fields
3535
self.stdout.write(json_util.dumps(encrypted_fields_map, indent=2))

django_mongodb_backend/schema.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,9 @@ def _create_collection(self, model):
477477
# Unencrypted path
478478
db.create_collection(db_table)
479479

480-
def _get_encrypted_fields(self, model, key_alt_name_prefix=None, path_prefix=None):
480+
def _get_encrypted_fields(
481+
self, model, *, key_alt_name_prefix=None, path_prefix=None, create_data_keys=True
482+
):
481483
"""
482484
Return the encrypted fields map for the given model. The "prefix"
483485
arguments are used when this method is called recursively on embedded
@@ -488,12 +490,12 @@ def _get_encrypted_fields(self, model, key_alt_name_prefix=None, path_prefix=Non
488490
key_alt_name_prefix = key_alt_name_prefix or model._meta.db_table
489491
path_prefix = path_prefix or ""
490492
auto_encryption_opts = client._options.auto_encryption_opts
491-
key_vault_db, key_vault_collection = auto_encryption_opts._key_vault_namespace.split(".", 1)
492-
key_vault_collection = client[key_vault_db][key_vault_collection]
493+
_, key_vault_collection = auto_encryption_opts._key_vault_namespace.split(".", 1)
494+
key_vault = self.get_collection(key_vault_collection)
493495
# Create partial unique index on keyAltNames.
494496
# TODO: find a better place for this. It only needs to run once for an
495497
# application's lifetime.
496-
key_vault_collection.create_index(
498+
key_vault.create_index(
497499
"keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}
498500
)
499501
# Select the KMS provider.
@@ -516,22 +518,29 @@ def _get_encrypted_fields(self, model, key_alt_name_prefix=None, path_prefix=Non
516518
field.embedded_model,
517519
key_alt_name_prefix=key_alt_name,
518520
path_prefix=path,
521+
create_data_keys=create_data_keys,
519522
)
520523
# An EmbeddedModelField may not have any encrypted fields.
521524
if embedded_result:
522525
field_list.extend(embedded_result["fields"])
523526
continue
524527
# Populate data for encrypted field.
525528
if getattr(field, "encrypted", False):
526-
data_key = key_vault_collection.find_one({"keyAltNames": key_alt_name})
527-
if data_key:
528-
data_key = data_key["_id"]
529-
else:
529+
if create_data_keys:
530530
data_key = connection.client_encryption.create_data_key(
531531
kms_provider=kms_provider,
532532
key_alt_names=[key_alt_name],
533533
master_key=master_key,
534534
)
535+
else:
536+
data_key = key_vault.find_one({"keyAltNames": key_alt_name})
537+
if data_key:
538+
data_key = data_key["_id"]
539+
else:
540+
raise ImproperlyConfigured(
541+
f"Encryption key {key_alt_name} not found. Have "
542+
f"migrated the {model} model?"
543+
)
535544
field_dict = {
536545
"bsonType": field.db_type(connection),
537546
"path": path,

django_mongodb_backend/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class OperationDebugWrapper:
118118
"create_indexes",
119119
"create_search_index",
120120
"drop",
121+
"find_one",
121122
"index_information",
122123
"insert_many",
123124
"delete_many",

tests/encryption_/test_management.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from io import StringIO
22

33
from bson import json_util
4+
from django.core.exceptions import ImproperlyConfigured
45
from django.core.management import call_command
6+
from django.db import connections
57
from django.test import modify_settings
68

9+
from .models import EncryptionKey
710
from .test_base import EncryptionTestCase
811

912

@@ -96,7 +99,7 @@ class CommandTests(EncryptionTestCase):
9699

97100
def _compare_output(self, expected, actual):
98101
for field in actual["fields"]:
99-
field.pop("keyId", None) # remove dynamic keyId
102+
field.pop("keyId") # remove dynamic keyId
100103
self.assertEqual(expected, actual)
101104

102105
def test_show_encrypted_fields_map(self):
@@ -109,3 +112,20 @@ def test_show_encrypted_fields_map(self):
109112
with self.subTest(model=model_key):
110113
self.assertIn(model_key, command_output)
111114
self._compare_output(expected, command_output[model_key])
115+
116+
def test_missing_key(self):
117+
test_key = "encryption__patient.patient_record.ssn"
118+
msg = (
119+
f"Encryption key {test_key} not found. Have migrated the "
120+
"<class 'encryption_.models.PatientRecord'> model?"
121+
)
122+
EncryptionKey.objects.filter(key_alt_name=test_key).delete()
123+
try:
124+
with self.assertRaisesMessage(ImproperlyConfigured, msg):
125+
call_command("showencryptedfieldsmap", "--database", "encrypted", verbosity=0)
126+
finally:
127+
# Replace the deleted key.
128+
connections["encrypted"].client_encryption.create_data_key(
129+
kms_provider="local",
130+
key_alt_names=[test_key],
131+
)

tests/encryption_/test_schema.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,19 @@ def test_get_encrypted_fields_all_models(self):
9696
checks their encrypted fields map from the schema editor,
9797
and compares to expected BSON type & queries mapping.
9898
"""
99+
# Deleting all keys is only correct only if this test includes all
100+
# test models. This test may not be needed since it's tested when the
101+
# test runner migrates all models. If any subTest fails, the key vault
102+
# will be left in an inconsistent state.
103+
EncryptionKey.objects.all().delete()
99104
connection = connections["encrypted"]
100-
101105
for model_name, expected in self.expected_map.items():
102106
with self.subTest(model=model_name):
103107
model_class = getattr(models, model_name)
104108
with connection.schema_editor() as editor:
105-
client = connection.connection
106-
encrypted_fields = editor._get_encrypted_fields(model_class, client)
109+
encrypted_fields = editor._get_encrypted_fields(model_class)
107110
for field in encrypted_fields["fields"]:
108-
field.pop("keyId", None) # Remove dynamic value
111+
field.pop("keyId") # Remove dynamic value
109112
self.assertEqual(encrypted_fields, expected)
110113

111114
def test_key_creation_and_lookup(self):

0 commit comments

Comments
 (0)