11from tornado import gen , web , locks
22import traceback
33import urllib .parse
4-
54from notebook .base .handlers import IPythonHandler
65import threading
76import json
1110
1211from .pull import GitPuller
1312from .version import __version__
14- from .hookspecs import handle_files
15- from .plugins .zip_puller import ZipSourceGoogleDriveDownloader
16- from .plugins .zip_puller import ZipSourceDropBoxDownloader
17- from .plugins .zip_puller import ZipSourceWebDownloader
13+ from . import hookspecs
1814import pluggy
15+ import nbgitpuller
1916
2017
2118class SyncHandler (IPythonHandler ):
@@ -43,17 +40,38 @@ def emit(self, data):
4340 self .write ('data: {}\n \n ' .format (serialized_data ))
4441 yield self .flush ()
4542
46- def setup_plugins (self , repo ):
43+ def setup_plugins (self , provider ):
4744 pm = pluggy .PluginManager ("nbgitpuller" )
48- pm .add_hookspecs (handle_files )
49- if "drive.google.com" in repo :
50- pm .register (ZipSourceGoogleDriveDownloader ())
51- elif "dropbox.com" in repo :
52- pm .register (ZipSourceDropBoxDownloader ())
53- else :
54- pm .register (ZipSourceWebDownloader ())
45+ pm .add_hookspecs (hookspecs )
46+ pm .load_setuptools_entrypoints ("nbgitpuller" , name = provider )
5547 return pm
5648
49+ @gen .coroutine
50+ def progress_loop (self , queue ):
51+ while True :
52+ try :
53+ progress = queue .get_nowait ()
54+ except Empty :
55+ yield gen .sleep (0.1 )
56+ continue
57+ if progress is None :
58+ yield gen .sleep (5 )
59+ return
60+ if isinstance (progress , Exception ):
61+ self .emit ({
62+ 'phase' : 'error' ,
63+ 'message' : str (progress ),
64+ 'output' : '\n ' .join ([
65+ line .strip ()
66+ for line in traceback .format_exception (
67+ type (progress ), progress , progress .__traceback__
68+ )
69+ ])
70+ })
71+ return
72+
73+ self .emit ({'output' : progress , 'phase' : 'syncing' })
74+
5775 @web .authenticated
5876 @gen .coroutine
5977 def get (self ):
@@ -69,7 +87,7 @@ def get(self):
6987 try :
7088 repo = self .get_argument ('repo' )
7189 branch = self .get_argument ('branch' , None )
72- compressed = self .get_argument ('compressed ' , "false" )
90+ provider = self .get_argument ('provider ' , None )
7391 depth = self .get_argument ('depth' , None )
7492 if depth :
7593 depth = int (depth )
@@ -82,22 +100,31 @@ def get(self):
82100 # so that all repos are always in scope after cloning. Sometimes
83101 # server_root_dir will include things like `~` and so the path
84102 # must be expanded.
85- repo_parent_dir = os .path .join (os .path .expanduser (self .settings ['server_root_dir' ]),
86- os .getenv ('NBGITPULLER_PARENTPATH' , '' ))
87- repo_dir = os .path .join (repo_parent_dir , self .get_argument ('targetpath' , repo .split ('/' )[- 1 ]))
103+ repo_parent_dir = os .path .join (os .path .expanduser (self .settings ['server_root_dir' ]), os .getenv ('NBGITPULLER_PARENTPATH' , '' ))
104+ nbgitpuller .REPO_PARENT_DIR = repo_parent_dir
105+
106+ repo_dir = os .path .join (
107+ repo_parent_dir ,
108+ self .get_argument ('targetpath' , repo .split ('/' )[- 1 ]))
88109
89110 # We gonna send out event streams!
90111 self .set_header ('content-type' , 'text/event-stream' )
91112 self .set_header ('cache-control' , 'no-cache' )
92113
93- if compressed == 'true' :
94- pm = self .setup_plugins (repo )
95- results = pm .hook .handle_files (repo = repo , repo_parent_dir = repo_parent_dir )[0 ]
114+ # if provider is specified then we are dealing with compressed
115+ # archive and not a git repo
116+ if provider is not None :
117+ pm = self .setup_plugins (provider )
118+ req_args = {k : v [0 ].decode () for k , v in self .request .arguments .items ()}
119+ download_q = Queue ()
120+ req_args ["progress_func" ] = lambda : self .progress_loop (download_q )
121+ req_args ["download_q" ] = download_q
122+ hf_args = {"query_line_args" : req_args }
123+ results = pm .hook .handle_files (** hf_args )
96124 repo_dir = repo_parent_dir + results ["unzip_dir" ]
97125 repo = "file://" + results ["origin_repo_path" ]
98126
99127 gp = GitPuller (repo , repo_dir , branch = branch , depth = depth , parent = self .settings ['nbapp' ])
100-
101128 q = Queue ()
102129
103130 def pull ():
@@ -110,33 +137,11 @@ def pull():
110137 q .put_nowait (e )
111138 raise e
112139 self .gp_thread = threading .Thread (target = pull )
113-
114140 self .gp_thread .start ()
115-
116- while True :
117- try :
118- progress = q .get_nowait ()
119- except Empty :
120- yield gen .sleep (0.5 )
121- continue
122- if progress is None :
123- break
124- if isinstance (progress , Exception ):
125- self .emit ({
126- 'phase' : 'error' ,
127- 'message' : str (progress ),
128- 'output' : '\n ' .join ([
129- line .strip ()
130- for line in traceback .format_exception (
131- type (progress ), progress , progress .__traceback__
132- )
133- ])
134- })
135- return
136-
137- self .emit ({'output' : progress , 'phase' : 'syncing' })
138-
141+ self .progress_loop (q )
142+ yield gen .sleep (3 )
139143 self .emit ({'phase' : 'finished' })
144+
140145 except Exception as e :
141146 self .emit ({
142147 'phase' : 'error' ,
@@ -170,11 +175,10 @@ def initialize(self):
170175 @gen .coroutine
171176 def get (self ):
172177 app_env = os .getenv ('NBGITPULLER_APP' , default = 'notebook' )
173-
174178 repo = self .get_argument ('repo' )
175179 branch = self .get_argument ('branch' , None )
176180 depth = self .get_argument ('depth' , None )
177- compressed = self .get_argument ('compressed ' , "false" )
181+ provider = self .get_argument ('provider ' , None )
178182 urlPath = self .get_argument ('urlpath' , None ) or \
179183 self .get_argument ('urlPath' , None )
180184 subPath = self .get_argument ('subpath' , None ) or \
@@ -195,14 +199,17 @@ def get(self):
195199 else :
196200 path = 'tree/' + path
197201
202+ if provider is not None :
203+ path = "tree/"
204+
198205 self .write (
199206 self .render_template (
200207 'status.html' ,
201208 repo = repo ,
202209 branch = branch ,
203- compressed = compressed ,
204210 path = path ,
205211 depth = depth ,
212+ provider = provider ,
206213 targetpath = targetpath ,
207214 version = __version__
208215 ))
@@ -239,3 +246,10 @@ def get(self):
239246 )
240247
241248 self .redirect (new_url )
249+
250+
251+ class ThreadWithResult (threading .Thread ):
252+ def __init__ (self , group = None , target = None , name = None , args = (), kwargs = {}, * , daemon = None ):
253+ def function ():
254+ self .result = target (* args , ** kwargs )
255+ super ().__init__ (group = group , target = function , name = name , daemon = daemon )
0 commit comments