From 93cddbffe60bb70247fcd135d1553a063bb316dd Mon Sep 17 00:00:00 2001 From: Matthieu Heitz Date: Thu, 15 Sep 2022 16:11:34 -0700 Subject: [PATCH] Avoid computing PCA when using a custom cost matrix --- wot/ot/ot_model.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/wot/ot/ot_model.py b/wot/ot/ot_model.py index 43a5f30..15ec235 100644 --- a/wot/ot/ot_model.py +++ b/wot/ot/ot_model.py @@ -291,23 +291,21 @@ def compute_single_transport_map(self, config): logger.info('No cells at {}'.format(t1)) return None - local_pca = config.pop('local_pca', None) - eigenvals = None - if local_pca is not None and local_pca > 0: - # pca, mean = wot.ot.get_pca(local_pca, p0.X, p1.X) - # p0_x = wot.ot.pca_transform(pca, mean, p0.X) - # p1_x = wot.ot.pca_transform(pca, mean, p1.X) - p0_x, p1_x, pca, mean = wot.ot.compute_pca(p0.X, p1.X, local_pca) - eigenvals = np.diag(pca.singular_values_) - else: - p0_x = p0.X - p1_x = p1.X - - #Check if we need to calculate a cost matrix + # Check if we need to calculate a cost matrix if config['C'] is None: - C = OTModel.compute_default_cost_matrix(p0_x, p1_x, eigenvals) - config['C'] = C - + local_pca = config.pop('local_pca', None) + eigenvals = None + if local_pca is not None and local_pca > 0: + # pca, mean = wot.ot.get_pca(local_pca, p0.X, p1.X) + # p0_x = wot.ot.pca_transform(pca, mean, p0.X) + # p1_x = wot.ot.pca_transform(pca, mean, p1.X) + p0_x, p1_x, pca, mean = wot.ot.compute_pca(p0.X, p1.X, local_pca) + eigenvals = np.diag(pca.singular_values_) + else: + p0_x = p0.X + p1_x = p1.X + config['C'] = OTModel.compute_default_cost_matrix(p0_x, p1_x, eigenvals) + C = config['C'] delta_days = t1 - t0