diff --git a/.gitignore b/.gitignore
index 3df34cb..7e283bd 100755
--- a/.gitignore
+++ b/.gitignore
@@ -189,3 +189,10 @@ dwh_media/
 *.lock
 *.env
 dex/dex-data/dex.db
+
+# training output files
+*.pkl
+*.xlsx
+*.csv
+*.pdf
+*.json
\ No newline at end of file
diff --git a/README.md b/README.md
index 572128e..95a645b 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ pip install -r requirements.txt
 
 # input data
 
-csv input file with at least the following columns:
+input file (CSV or XLSX) with at least the following columns:
 | column | description |
 | ------------- | ------------- |
 | Main | Main category |
@@ -30,7 +30,7 @@ See python train.py for all options.
 
 To train Middle and Sub categoeries use:
 ```
-python train.py --csv file.csv --columns Middle,Sub
+python train.py --input-file file.csv --columns Middle,Sub
 ```
 This step will generate a categories `json` file. Use this file to load the categories in the backend.
 ```
@@ -39,7 +39,7 @@ python manage.py load_categories
 
 To train Middle category use:
 ```
-python train.py --csv file.csv --columns Middle
+python train.py --input-file file.csv --columns Middle
 ```
 
 Rename resulting files to "main_model.pkl, sub_model.pkl, main_slugs.pkl, sub_slugs.pkl" and copy the pkl files into the classification endpoint.
diff --git a/app/engine.py b/app/engine.py
index 56dc2ba..d6b46ca 100644
--- a/app/engine.py
+++ b/app/engine.py
@@ -5,11 +5,10 @@
 from sklearn.linear_model import LogisticRegression
 from nltk.stem.snowball import DutchStemmer
 import joblib
-import warnings
+import os
 import nltk
 import re
 import csv
-import psutil
 
 class TextClassifier:
     _text = 'Text'
@@ -41,35 +40,42 @@ def export_model(self, file):
         joblib.dump(self.model, file)
 
     def preprocessor(self, text):
+        text = str(text)
         text=text.lower()
-        text=re.sub("\\W"," ",text) # remove special chars
-
+
         # stem words
         words=re.split("\\s+",text)
         stemmed_words=[self.stemmer.stem(word=word) for word in words]
         return ' '.join(stemmed_words)
 
-    def load_data(self, csv_file, frac=1):
-        df = pd.read_csv(csv_file, sep=None, engine='python')
+    def load_data(self, input_file, frac=1):
+        _, extension = os.path.splitext(input_file)
+
+        if extension == '.csv':
+            df = pd.read_csv(input_file, sep=None, engine='python')
+        elif extension == '.xlsx':
+            df = pd.read_excel(input_file)
+        else:
+            raise Exception('Could not read input file. Extension should be .csv or .xlsx')
+
+        print(df)
+
         df = df.dropna(
             axis=0,
             how='any',
-            thresh=None,
             subset=[self._text, self._main, self._middle, self._sub],
             inplace=False
         )
 
         # cleanup dataset
-        df = df.drop_duplicates(subset=[self._text], keep='first')
+        #df = df.drop_duplicates(subset=[self._text], keep='first')
 
         # for dev use only a subset (for speed purpose)
-        df = df.sample(frac=frac).reset_index(drop=True)
+        #df = df.sample(frac=frac).reset_index(drop=True)
 
         # construct unique label
         df[self._lbl] = df[self._main] + "|" + df[self._middle] + "|" + df[self._sub]
         number_of_examples = df[self._lbl].value_counts().to_frame()
-        df['is_bigger_than_50'] = df[self._lbl].isin(number_of_examples[number_of_examples[self._lbl]>50].index)
-        df['is_bigger_than_50'].value_counts()
-        df = df[df['is_bigger_than_50'] == True]
+        # The example dataset is not large enough to train a good classification model
         # print(len(self.df),'rows valid')
         return df
@@ -77,7 +83,9 @@ def load_data(self, csv_file, frac=1):
 
     def make_data_sets(self, df, split=0.9, columns=['Middle', 'Sub']):
         texts = df[self._text]
-        labels = df[columns].apply('|'.join, axis=1)
+        labels = df[columns].applymap(lambda x: x.lower().capitalize()).apply('|'.join, axis=1)
+
+        print(labels.value_counts())
 
         train_texts, test_texts, train_labels, test_labels = train_test_split(
             texts, labels, test_size=1-split, stratify=labels)
@@ -85,7 +93,6 @@ def make_data_sets(self, df, split=0.9, columns=['Middle', 'Sub']):
         return texts, labels, train_texts, train_labels, test_texts, test_labels
 
     def fit(self, train_texts, train_labels):
-
         pipeline = Pipeline([
             ('vect', CountVectorizer(preprocessor=self.preprocessor, stop_words=self.stop_words)),
             ('tfidf', TfidfTransformer()),
@@ -119,7 +126,7 @@ def fit(self, train_texts, train_labels):
             'vect__ngram_range': ((1, 1),) # (1,2)
         }
 
-        grid_search = GridSearchCV(pipeline, parameters_slow,verbose=True,n_jobs=psutil.cpu_count(logical=False),cv=5)
+        grid_search = GridSearchCV(pipeline, parameters_slow,verbose=True,n_jobs=1,cv=5)
         grid_search.fit(train_texts, train_labels)
         #print('Best parameters: ')
         #print(grid_search.best_params_)
diff --git a/app/train.py b/app/train.py
index 580b901..72046fb 100644
--- a/app/train.py
+++ b/app/train.py
@@ -7,7 +7,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     optional = parser._action_groups.pop()
     required = parser.add_argument_group('required arguments')
-    required.add_argument('--csv', required=True)
+    required.add_argument('--input-file', required=True)
     optional.add_argument('--columns', default='')
     optional.add_argument('--fract', default=1.0, type=float)
     optional.add_argument('--output-fixtures', const=True, nargs="?", default=True, type=bool)
@@ -70,7 +70,7 @@ def generate_fixtures(categories):
             print("Warning invalid slug: {slug}, length: {length}".format(slug=slug, length=len(slug)))
     return cats.values()
 
-
+
 def train(df, columns, output_validation=False, output_fixtures=True):
     texts, labels, train_texts, train_labels, test_texts, test_labels = classifier.make_data_sets(df, columns=columns)
     colnames = "_".join(columns)
@@ -108,9 +108,9 @@ def train(df, columns, output_validation=False, output_fixtures=True):
     print("Using args: {}".format(args))
 
     classifier = TextClassifier()
-    df = classifier.load_data(csv_file=args.csv, frac=args.fract)
+    df = classifier.load_data(input_file=args.input_file, frac=args.fract)
     if len(df) == 0:
-        print("Failed to load {}".format(args.csv))
+        print("Failed to load {}".format(args.input_file))
         exit(-1)
     else:
         print("{} rows loaded".format(len(df)))
diff --git a/notebook/requirements.txt b/notebook/requirements.txt
index 3375307..1f00bc0 100644
--- a/notebook/requirements.txt
+++ b/notebook/requirements.txt
@@ -71,5 +71,4 @@ wcwidth==0.1.9
 webencodings==0.5.1
 wrapt==1.12.1
 xlrd==1.2.0
-xlsx2csv==0.7.6
 zipp==3.1.0
diff --git a/requirements-train.txt b/requirements-train.txt
deleted file mode 100644
index b3e1048..0000000
--- a/requirements-train.txt
+++ /dev/null
@@ -1,44 +0,0 @@
-asgiref==3.2.10
-astroid==2.4.2
-attrs==19.3.0
-click==7.1.2
-cycler==0.10.0
-dill==0.3.2
-Django==3.0.7
-Flask==1.1.2
-Flask-Cors==3.0.8
-gunicorn==20.0.4
-isort==4.3.21
-itsdangerous==1.1.0
-Jinja2==2.11.2
-joblib==0.15.1
-kiwisolver==1.2.0
-lazy-object-proxy==1.4.3
-MarkupSafe==1.1.1
-matplotlib
-mccabe==0.6.1
-more-itertools==8.4.0
-nltk==3.5
-numpy
-packaging==20.4
-pandas
-pluggy==0.13.1
-py==1.8.2
-pylint==2.5.3
-pyparsing==2.4.7
-pytest==5.4.3
-python-dateutil==2.8.1
-pytz==2020.1
-regex==2020.6.8
-scikit-learn
-scipy
-six==1.15.0
-sklearn==0.0
-sqlparse==0.3.1
-threadpoolctl==2.1.0
-toml==0.10.1
-tqdm==4.46.1
-wcwidth==0.2.4
-Werkzeug==1.0.1
-wrapt==1.12.1
-psutil==5.7.0
diff --git a/requirements.txt b/requirements.txt
index b8ed1b3..76df6b7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,39 +1,53 @@
+asgiref==3.2.10
 astroid==2.4.2
 attrs==19.3.0
 click==7.1.2
+contourpy==1.0.7
+cycler==0.10.0
 dill==0.3.2
+Django==3.0.7
+et-xmlfile==1.1.0
 Flask==1.1.2
 Flask-Cors==3.0.8
+fonttools==4.39.4
 gunicorn==20.0.4
+importlib-resources==5.12.0
 isort==4.3.21
 itsdangerous==1.1.0
 Jinja2==2.11.2
 joblib==0.15.1
+kiwisolver==1.4.4
 lazy-object-proxy==1.4.3
 MarkupSafe==1.1.1
+matplotlib==3.7.1
 mccabe==0.6.1
 more-itertools==8.4.0
 nltk==3.5
-numpy==1.18.5
+numpy==1.24.3
+openpyxl==3.0.10
 packaging==20.4
-pandas==1.0.4
+pandas==1.5.3
+Pillow==9.5.0
 pluggy==0.13.1
+psutil==5.9.5
 py==1.8.2
 pylint==2.5.3
 pyparsing==2.4.7
 pytest==5.4.3
 python-dateutil==2.8.1
 pytz==2020.1
-regex==2020.6.8
-scikit-learn==0.23.1
-scipy==1.4.1
+regex==2023.6.3
+scikit-learn==1.0.2
+scipy==1.10.1
 six==1.15.0
 sklearn==0.0
+sqlparse==0.3.1
 threadpoolctl==2.1.0
 toml==0.10.1
 tqdm==4.46.1
-uWSGI==2.0.19
+uWSGI==2.0.21
 wcwidth==0.2.4
 Werkzeug==1.0.1
 wrapt==1.12.1
-psutil==5.7.0
+xlrd==2.0.1
+zipp==3.15.0
\ No newline at end of file