Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,6 @@ def load(
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
please check `monai.data.load_net_with_metadata` for more details.
4. If `return_state_dict` is True, return model weights, only used for compatibility
when `model` and `net_name` are all `None`.

"""
bundle_dir_ = _process_bundle_dir(bundle_dir)
Expand Down
19 changes: 3 additions & 16 deletions tests/bundle/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,19 +268,17 @@ class TestLoad(unittest.TestCase):
@skip_if_quick
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
with skip_if_downloading_fails():
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
bundle_root = os.path.join(tempdir, bundle_name)
# load weights
weights = load(
model_1 = load(
name=bundle_name,
model_file=model_file,
bundle_dir=tempdir,
repo=repo,
source="github",
progress=False,
device=device,
return_state_dict=True,
)
# prepare network
with open(os.path.join(bundle_root, bundle_files[2])) as f:
Expand All @@ -289,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
del net_args["_target_"]
model = getattr(nets, model_name)(**net_args)
model.to(device)
model.load_state_dict(weights)
model.load_state_dict(model_1)
model.eval()

# prepare data and test
Expand All @@ -313,13 +311,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
progress=False,
device=device,
source="github",
return_state_dict=False,
)
model_2.eval()
output_2 = model_2.forward(input_tensor)
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

# test compatibility with return_state_dict=True.
model_3 = load(
name=bundle_name,
model_file=model_file,
Expand All @@ -328,7 +324,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
device=device,
net_name=model_name,
source="github",
return_state_dict=False,
**net_args,
)
model_3.eval()
Expand All @@ -343,14 +338,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
# load weights
model = load(
name=bundle_name,
bundle_dir=tempdir,
source="monaihosting",
progress=False,
device=device,
return_state_dict=False,
)
model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)

# prepare data and test
input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
Expand All @@ -371,7 +359,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
source="monaihosting",
progress=False,
device=device,
return_state_dict=False,
net_override=net_override,
)

Expand Down
7 changes: 1 addition & 6 deletions tests/ngc_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))

model = load(
name=bundle_name,
source="ngc",
version=version,
bundle_dir=tempdir,
remove_prefix=remove_prefix,
return_state_dict=False,
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
)
assert_allclose(
model.state_dict()[TESTCASE_WEIGHTS["key"]],
Expand Down
Loading