Skip to content

Commit 34087aa

Browse files
Merge pull request #168 from geometric-intelligence/pirnn_refactor
Pirnn refactor
2 parents 5f1d3db + fe1dd2e commit 34087aa

36 files changed

+78351
-5099
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ neurometry/datasets/rnn_grid_cells/Dual agent path integration high res/*
88
neurometry/datasets/rnn_grid_cells/Single agent path integration high res/*
99
neurometry/curvature/grid-cells-curvature/models/xu_rnn/results/*
1010
neurometry/curvature/grid-cells-curvature/multi-agent/*
11+
neurometry/neuroai/piRNNs/models/results/*
12+
neurometry/neuroai/piRNNs/multi-agent/*
1113
notebooks/
1214

1315
*viewer*
@@ -36,6 +38,9 @@ neurometry/datasets/rnn_grid_cells/Single agent path integration/*
3638

3739
neurometry/curvature/grid-cells-curvature/models/xu_rnn/logs/*
3840
neurometry/curvature/grid-cells-curvature/models/xu_rnn/wandb/*
41+
neurometry/neuroai/piRNNs/models/logs/*
42+
neurometry/neuroai/piRNNs/models/wandb/*
43+
neurometry/neuroai/piRNNs/models/pretrained/*
3944

4045

4146
# Result files

neurometry/curvature/datasets/gridcells.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import numpy as np
44
import pandas as pd
55

6+
import neurometry.curvature.datasets.structures as structures
7+
68
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
79
import geomstats.backend as gs
810

9-
import neurometry.curvature.datasets.structures as structures
10-
1111

1212
# TODO
1313
def load_grid_cells_synthetic(

neurometry/curvature/datasets/synthetic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,21 @@
33
import logging
44
import os
55

6-
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
7-
import geomstats.backend as gs
86
import numpy as np
97
import pandas as pd
108
import skimage
119
import torch
12-
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
1310
from torch.distributions.multivariate_normal import MultivariateNormal
1411

1512
from neurometry.topology.persistent_homology import (
1613
cohomological_circular_coordinates,
1714
cohomological_toroidal_coordinates,
1815
)
1916

17+
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
18+
import geomstats.backend as gs
19+
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
20+
2021

2122
def load_projected_images(n_scalars=5, n_angles=1000, img_size=128):
2223
"""Load a dataset of 2D images projected into 1D projections.

neurometry/curvature/evaluate.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
import numpy as np
55
import torch
66

7+
from neurometry.curvature.datasets.synthetic import (
8+
get_s1_synthetic_immersion,
9+
get_s2_synthetic_immersion,
10+
get_t2_synthetic_immersion,
11+
)
12+
713
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
814
import geomstats.backend as gs
915
from geomstats.geometry.base import ImmersedSet
1016
from geomstats.geometry.euclidean import Euclidean
1117
from geomstats.geometry.pullback_metric import PullbackMetric
1218
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
1319

14-
from neurometry.curvature.datasets.synthetic import (
15-
get_s1_synthetic_immersion,
16-
get_s2_synthetic_immersion,
17-
get_t2_synthetic_immersion,
18-
)
19-
2020

2121
class NeuralManifoldIntrinsic(ImmersedSet):
2222
def __init__(self, dim, neural_embedding_dim, neural_immersion, equip=True):

neurometry/curvature/grid-cells-curvature/models/xu_rnn/LICENSE

Lines changed: 0 additions & 201 deletions
This file was deleted.

0 commit comments

Comments
 (0)