-
Notifications
You must be signed in to change notification settings - Fork 45
Fix bug of graph_net/torch/fx_graph_cache_util.py and valiate samples of dtype generlization pass #513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for your contribution! |
| from torch.fx.passes.shape_prop import ShapeProp | ||
|
|
||
|
|
||
| def parse_immutable_model_path_into_sole_graph_module(model_path, device=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个基础函数不可以改。它的名字没有shape propagate的语义。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些测试代码太复杂了,没法review。应该写成标准的sample_pass的test。可以借鉴#514 这个pr的graph_net/test/subgraph_input_shapes_naive_rewriter_test.sh的写法。
|
graph_net/test/dtype_gen_test.sh
Outdated
| for sample in "$TORCHVISION_ROOT"/*; do | ||
| python3 -m graph_net.model_path_handler --model-path $sample --handler-config=$CONFIG_APPLY | ||
| done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里直接指定了torchvision,这是不对的。应该取自graph_net/config/small100_torch_samples_list.txt。
graph_net/test/dtype_gen_test.sh
Outdated
| else | ||
| echo "FAIL" | ||
| fi | ||
| done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
替换为
done | tee >(grep SUCCESS | wc -l | xargs -I{} echo SUCCESS {}) | tee >(grep FAIL | wc -l | xargs -I{} echo FAIL {}); wait这样表示统计最终的成功与失败的个数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parse_immutable_model_path_into_sole_graph_module()返回的trace_model不包含tensor_meta导致dtype_pass.need_rewrite(traced_model)返回总是false,所以改为调用get_torch_module_and_inputs和parse_sole_graph_module和ShapeProp(traced_model).propagate(*inputs),使trace_model包含tensor_meta
| @@ -1,12 +1,12 @@ | |||
| #!/bin/bash | |||
| # !/bin/bash | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个shabang是标准写法,不应该改。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到

PR Category
Description