Description
Problem: The following fails when the model and base_model are both in the safetensors format:
import logging
import weightwatcher as ww

model = "Mistral-7B-v0.3"
base_model = "Mistral-7B-Instruct-v0.3"
watcher = ww.WeightWatcher(model=model, base_model=base_model, log_level=logging.DEBUG)
description = watcher.describe(max_evals=50000, layers=[286])
Why: Safetensors scrambles the order of the layers because it stores the layers sorted by name.
To fix this, the code internally reads a layer_map, i.e. pytorch_model.bin.index.json
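As a rough illustration of the ordering problem (this is not WeightWatcher's code), the sketch below sorts some made-up layer names the way a name-sorted file would, then restores the model order using a layer_map built from an index-style dict. The layer names, the `build_layer_map` helper, and the assumption that the index lists names in model order are all hypothetical.

```python
# Toy illustration: name-sorted storage loses the model's layer order.
# All layer names and the index contents below are made up.

# Order in which the layers appear in the model itself.
model_order = [f"model.layers.{i}.mlp.weight" for i in range(12)]

# A plain string sort puts "layers.10" and "layers.11" before "layers.2".
stored_order = sorted(model_order)


def build_layer_map(index):
    """Hypothetical helper: map each layer name to its position in the index."""
    return {name: pos for pos, name in enumerate(index["weight_map"])}


# An index-style dict that lists the names in model order (an assumption here).
index = {"weight_map": {name: "shard-0.safetensors" for name in model_order}}
layer_map = build_layer_map(index)

# Sorting the stored names by their layer_map position recovers model order.
restored = sorted(stored_order, key=layer_map.__getitem__)

print(stored_order[:4])
print(restored == model_order)
```

The point is only that a layer's position in the file and its position in the model can disagree, so any id derived from iteration order needs to be translated before it is compared with a user-supplied id.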
Before any filters are applied, the code first iterates over all the layers (i.e. over all safetensors files) to see which layers are skipped.
Then a second pass is made, by which point (I think) the filters have already been applied.
So in the case above, the first layer_id is not 286, it is 0.
I think what happens
Proposed Solutions:
WWLayerIterator.apply_filters() only checks ww_layer.layer_id and ww_layer.name against the filter_ids.
Don't check ww_layer.layer_id against filter_ids directly. Instead:
1. Look up the key in the layer map, and check the layer against the ww_layer.name. The names are usually unique for safetensors, but not in general.
2. Change the current layer_id to the correct layer_id, but I'm not sure how yet since there are 2 passes over the iterator.
3. Only apply the filters if layer_map is None, so they are only applied on the first pass.
4. Hack like 3, but skip all filters if layer_map is not None and layer_map['source']==SAFETENSORS.
5. For debugging, don't use layer ids, use names for safetensors. THIS DOES NOT WORK EITHER
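The first idea in the list above (translate the requested ids into names through the layer map, then match on the name) could be sketched roughly as follows. This is a toy model, not the real WWLayerIterator: `ToyLayer`, `names_for_filter_ids`, `apply_name_filter`, and the shape of `layer_map` (layer name to model-order id) are all assumptions made for the example.

```python
from dataclasses import dataclass


@dataclass
class ToyLayer:
    """Stand-in for a ww_layer; only the two fields the filter needs."""
    layer_id: int
    name: str


def names_for_filter_ids(layer_map, filter_ids):
    """Translate requested layer ids into layer names via the layer map.

    layer_map is assumed to map layer name -> layer id in model order.
    """
    wanted = set(filter_ids)
    return {name for name, lid in layer_map.items() if lid in wanted}


def apply_name_filter(layers, layer_map, filter_ids):
    """Keep layers whose name matches a requested id, ignoring layer.layer_id."""
    wanted_names = names_for_filter_ids(layer_map, filter_ids)
    return [layer for layer in layers if layer.name in wanted_names]


# Model order: blk.0, blk.1, ..., blk.10 (made-up names).
names = [f"blk.{i}" for i in range(11)]
layer_map = {name: i for i, name in enumerate(names)}

# The iterator sees the layers in name-sorted order and numbers them from 0.
layers = [ToyLayer(layer_id=i, name=n) for i, n in enumerate(sorted(names))]

naive = [layer.name for layer in layers if layer.layer_id in [2]]
by_name = [layer.name for layer in apply_name_filter(layers, layer_map, [2])]

print(naive)    # the id-based check picks whichever layer happened to be third
print(by_name)  # the name-based check picks the layer the user asked for
```

As noted in the list, this only holds up when layer names are unique, which is usually true for safetensors but not in general.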
Not sure if any of these will work: the Test_Albert_DeltaLayerIterator unit test does work, and 3 breaks some of the tests, but not all:
- BREAKS:
test_deltas_layer_id_filters
test_delta_layer_iterator_with_filters
THESE TESTS FAIL
test_safetensors_deltas_with_filter_by_ids
test_safetensors_deltas_with_filter_by_names
adding now...does not work yet
it ONLY works with 3. !!!
- CAN THIS WORK?
Modifying the layer_map could be bad, since we might iterate over it.
It may be better to set the source in the iterator.
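A minimal sketch of that last idea, assuming a toy iterator rather than the real WWLayerIterator: the source is stored as an attribute on the iterator, and the shared layer_map is only ever read, so a second pass sees exactly the same map. The class name, the `SAFETENSORS` constant, and the `layer_map` shape (layer name to model-order id) are hypothetical.

```python
SAFETENSORS = "safetensors"  # hypothetical constant for this example


class ToyLayerIterator:
    """Toy iterator that keeps its own source and never mutates layer_map."""

    def __init__(self, layer_names, layer_map=None, source=None, filter_ids=()):
        self.layer_names = list(layer_names)
        self.layer_map = layer_map          # shared object; only read here
        self.source = source                # stored on the iterator itself
        self.filter_ids = set(filter_ids)

    def _keep(self, layer_id, name):
        if not self.filter_ids:
            return True
        if self.source == SAFETENSORS and self.layer_map is not None:
            # Resolve the model-order id by name instead of trusting layer_id.
            return self.layer_map.get(name) in self.filter_ids
        return layer_id in self.filter_ids

    def __iter__(self):
        for layer_id, name in enumerate(self.layer_names):
            if self._keep(layer_id, name):
                yield name


names = [f"blk.{i}" for i in range(11)]
layer_map = {name: i for i, name in enumerate(names)}
snapshot = dict(layer_map)

it = ToyLayerIterator(sorted(names), layer_map, source=SAFETENSORS, filter_ids=[2])
first_pass = list(it)
second_pass = list(it)

print(first_pass, second_pass, layer_map == snapshot)
```

Because nothing is written into layer_map, both passes over the iterator apply the same filter and give the same result.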