|
9 | 9 | # except for python 2.7 standard library and Spark 2.1 |
10 | 10 | import sys |
11 | 11 | from datetime import datetime, timedelta, tzinfo |
| 12 | +import yaml |
12 | 13 | from time import localtime, strftime |
13 | 14 | from types import MethodType |
14 | 15 |
|
@@ -1606,6 +1607,39 @@ def parse_arguments(args): |
1606 | 1607 | return options |
1607 | 1608 |
|
1608 | 1609 |
|
| 1610 | +def parse_arguments_from_yaml_file(args): |
| 1611 | + """ |
| 1612 | + This function accepts the path to a config file |
| 1613 | + and extracts the needed arguments for the metastore migration |
| 1614 | + ---------- |
| 1615 | + Return: |
| 1616 | + Dictionary of config options |
| 1617 | + """ |
| 1618 | + parser = argparse.ArgumentParser(prog=args[0]) |
| 1619 | + parser.add_argument('-f', '--config_file', required=True, default='artifacts/config.yaml`', help='Provide yaml configuration file path to read migration arguments from. Default path: `artifacts/config.yaml`') |
| 1620 | + options = get_options(parser, args) |
| 1621 | + config_file_path = options['config_file'] |
| 1622 | + ## read the yaml file |
| 1623 | + with open(config_file_path, 'r') as yaml_file_stream: |
| 1624 | + config_options = yaml.load(yaml_file_stream) |
| 1625 | + |
| 1626 | + if config_options['mode'] == FROM_METASTORE: |
| 1627 | + validate_options_in_mode( |
| 1628 | + options=config_options, mode=FROM_METASTORE, |
| 1629 | + required_options=['output_path'], |
| 1630 | + not_allowed_options=['input_path'] |
| 1631 | + ) |
| 1632 | + elif config_options['mode'] == TO_METASTORE: |
| 1633 | + validate_options_in_mode( |
| 1634 | + options=config_options, mode=TO_METASTORE, |
| 1635 | + required_options=['input_path'], |
| 1636 | + not_allowed_options=['output_path'] |
| 1637 | + ) |
| 1638 | + else: |
| 1639 | + raise AssertionError('unknown mode ' + options['mode']) |
| 1640 | + |
| 1641 | + return config_options |
| 1642 | + |
1609 | 1643 | def get_spark_env(): |
1610 | 1644 | try: |
1611 | 1645 | sc = SparkContext.getOrCreate() |
@@ -1733,7 +1767,10 @@ def validate_aws_regions(region): |
1733 | 1767 |
|
1734 | 1768 |
|
1735 | 1769 | def main(): |
1736 | | - options = parse_arguments(sys.argv) |
| 1770 | + # options = parse_arguments(sys.argv) |
| 1771 | + |
| 1772 | + ## This now reads options from path to config yaml file |
| 1773 | + options = parse_arguments_from_yaml_file(sys.argv) |
1737 | 1774 |
|
1738 | 1775 | connection = {"url": options["jdbc_url"], "user": options["jdbc_username"], "password": options["jdbc_password"]} |
1739 | 1776 | db_prefix = options.get("database_prefix") or "" |
|
0 commit comments