Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,6 @@ def load(
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
please check `monai.data.load_net_with_metadata` for more details.
4. If `return_state_dict` is True, return model weights, only used for compatibility
when `model` and `net_name` are all `None`.

"""
bundle_dir_ = _process_bundle_dir(bundle_dir)
Expand Down
19 changes: 3 additions & 16 deletions tests/bundle/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,19 +268,17 @@ class TestLoad(unittest.TestCase):
@skip_if_quick
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
with skip_if_downloading_fails():
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
bundle_root = os.path.join(tempdir, bundle_name)
# load weights
weights = load(
model_1 = load(
name=bundle_name,
model_file=model_file,
bundle_dir=tempdir,
repo=repo,
source="github",
progress=False,
device=device,
return_state_dict=True,
)
# prepare network
with open(os.path.join(bundle_root, bundle_files[2])) as f:
Expand All @@ -289,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
del net_args["_target_"]
model = getattr(nets, model_name)(**net_args)
model.to(device)
model.load_state_dict(weights)
model.load_state_dict(model_1)
model.eval()

# prepare data and test
Expand All @@ -313,13 +311,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
progress=False,
device=device,
source="github",
return_state_dict=False,
)
model_2.eval()
output_2 = model_2.forward(input_tensor)
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)

# test compatibility with return_state_dict=True.
model_3 = load(
name=bundle_name,
model_file=model_file,
Expand All @@ -328,7 +324,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
device=device,
net_name=model_name,
source="github",
return_state_dict=False,
**net_args,
)
model_3.eval()
Expand All @@ -343,14 +338,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
# download bundle, and load weights from the downloaded path
with tempfile.TemporaryDirectory() as tempdir:
# load weights
model = load(
name=bundle_name,
bundle_dir=tempdir,
source="monaihosting",
progress=False,
device=device,
return_state_dict=False,
)
model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)

# prepare data and test
input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
Expand All @@ -371,7 +359,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override)
source="monaihosting",
progress=False,
device=device,
return_state_dict=False,
net_override=net_override,
)

Expand Down
1 change: 0 additions & 1 deletion tests/ngc_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
version=version,
bundle_dir=tempdir,
remove_prefix=remove_prefix,
return_state_dict=False,
)
assert_allclose(
model.state_dict()[TESTCASE_WEIGHTS["key"]],
Expand Down
Loading