diff --git a/pyproject.toml b/pyproject.toml index d32ead0..3003ee6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,4 +48,4 @@ local_scheme = "no-local-version" "Homepage" = "https://github.com/meta-pytorch/tritonparse" [project.scripts] -tritonparse = "tritonparse.run:main" +tritonparseoss = "tritonparse.cli:main" diff --git a/run.py b/run.py index 3e7dfa2..0d4de97 100755 --- a/run.py +++ b/run.py @@ -1,79 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. -import argparse -from importlib.metadata import PackageNotFoundError, version - -from .reproducer.cli import _add_reproducer_args -from .reproducer.orchestrator import reproduce -from .utils import _add_parse_args, unified_parse - - -def _get_package_version() -> str: - try: - return version("tritonparse") - except PackageNotFoundError: - return "0+unknown" - - -def main(): - pkg_version = _get_package_version() - - parser = argparse.ArgumentParser( - prog="tritonparse", - description=( - "TritonParse: parse structured logs and generate minimal reproducers" - ), - epilog=( - "Examples:\n" - " tritonparse parse /path/to/logs --out parsed_output\n" - " tritonparse reproduce /path/to/trace.ndjson --line 1 --out-dir repro_output\n" - ), - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "--version", - action="version", - version=f"%(prog)s {pkg_version}", - help="Show program's version number and exit", - ) - - subparsers = parser.add_subparsers(dest="command", required=True) - - # parse subcommand - parse_parser = subparsers.add_parser( - "parse", - help="Parse triton structured logs", - conflict_handler="resolve", - ) - _add_parse_args(parse_parser) - parse_parser.set_defaults(func="parse") - - # reproduce subcommand - repro_parser = subparsers.add_parser( - "reproduce", - help="Build reproducer from trace file", - ) - _add_reproducer_args(repro_parser) - repro_parser.set_defaults(func="reproduce") - - args = parser.parse_args() - - if args.func == "parse": - parse_args = { - k: v for k, v in vars(args).items() if k not in ["command", "func"] - } - unified_parse(**parse_args) - elif args.func == "reproduce": - reproduce( - input_path=args.input, - line_index=args.line, - out_dir=args.out_dir, - template=args.template, - ) - else: - raise RuntimeError(f"Unknown command: {args.func}") - +from tritonparse.cli import main if __name__ == "__main__": # Do not add code here, it won't be run. Add them to the function called below. diff --git a/tritonparse/__main__.py b/tritonparse/__main__.py index ef015cd..2f05ddc 100644 --- a/tritonparse/__main__.py +++ b/tritonparse/__main__.py @@ -1,4 +1,4 @@ -from .run import main +from .cli import main if __name__ == "__main__": diff --git a/tritonparse/cli.py b/tritonparse/cli.py new file mode 100644 index 0000000..089d263 --- /dev/null +++ b/tritonparse/cli.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +import argparse +from importlib.metadata import PackageNotFoundError, version + +from .common import is_fbcode +from .reproducer.cli import _add_reproducer_args +from .reproducer.orchestrator import reproduce +from .utils import _add_parse_args, unified_parse + + +def _get_package_version() -> str: + try: + return version("tritonparse") + except PackageNotFoundError: + return "0+unknown" + + +def main(): + pkg_version = _get_package_version() + + # Use different command name for fbcode vs OSS + prog_name = "tritonparse" if is_fbcode() else "tritonparseoss" + + parser = argparse.ArgumentParser( + prog=prog_name, + description=( + "TritonParse: parse structured logs and generate minimal reproducers" + ), + epilog=( + "Examples:\n" + f" {prog_name} parse /path/to/logs --out parsed_output\n" + f" {prog_name} reproduce /path/to/trace.ndjson --line 2 --out-dir repro_output\n" + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--version", + action="version", + version=f"%(prog)s {pkg_version}", + help="Show program's version number and exit", + ) + + subparsers = parser.add_subparsers(dest="command", required=True) + + # parse subcommand + parse_parser = subparsers.add_parser( + "parse", + help="Parse triton structured logs", + conflict_handler="resolve", + ) + _add_parse_args(parse_parser) + parse_parser.set_defaults(func="parse") + + # reproduce subcommand + repro_parser = subparsers.add_parser( + "reproduce", + help="Build reproducer from trace file", + ) + _add_reproducer_args(repro_parser) + repro_parser.set_defaults(func="reproduce") + + args = parser.parse_args() + + if args.func == "parse": + parse_args = { + k: v for k, v in vars(args).items() if k not in ["command", "func"] + } + unified_parse(**parse_args) + elif args.func == "reproduce": + reproduce( + input_path=args.input, + line_index=args.line, + out_dir=args.out_dir, + template=args.template, + ) + else: + raise RuntimeError(f"Unknown command: {args.func}") + + +if __name__ == "__main__": + # Do not add code here, it won't be run. Add them to the function called below. + main() # pragma: no cover