11from tornado import gen , web , locks
22import traceback
33import urllib .parse
4-
54from notebook .base .handlers import IPythonHandler
65import threading
76import json
1110
1211from .pull import GitPuller
1312from .version import __version__
14- from .hookspecs import handle_files
15- from .plugins .zip_puller import ZipSourceGoogleDriveDownloader
16- from .plugins .zip_puller import ZipSourceDropBoxDownloader
17- from .plugins .zip_puller import ZipSourceWebDownloader
13+ from . import hookspecs
1814import pluggy
15+ import nbgitpuller
1916
2017
2118class SyncHandler (IPythonHandler ):
@@ -43,17 +40,50 @@ def emit(self, data):
4340 self .write ('data: {}\n \n ' .format (serialized_data ))
4441 yield self .flush ()
4542
46- def setup_plugins (self , repo ):
43+ def setup_plugins (self , provider ):
4744 pm = pluggy .PluginManager ("nbgitpuller" )
48- pm .add_hookspecs (handle_files )
49- if "drive.google.com" in repo :
50- pm .register (ZipSourceGoogleDriveDownloader ())
51- elif "dropbox.com" in repo :
52- pm .register (ZipSourceDropBoxDownloader ())
53- else :
54- pm .register (ZipSourceWebDownloader ())
45+ pm .add_hookspecs (hookspecs )
46+ pm .load_setuptools_entrypoints ("nbgitpuller" , name = provider )
5547 return pm
5648
49+ def handle_provider_zip (self , provider ):
50+ pm = self .setup_plugins (provider )
51+ req_args = {k : v [0 ].decode () for k , v in self .request .arguments .items ()}
52+ download_q = Queue ()
53+ req_args ["download_q" ] = download_q
54+ hf_args = {"query_line_args" : req_args }
55+ dl_thread = ThreadWithResult (target = pm .hook .handle_files , kwargs = hf_args )
56+ dl_thread .start ()
57+ self .progress_loop (download_q )
58+ dl_thread .join ()
59+ return dl_thread .result
60+
61+ @gen .coroutine
62+ def progress_loop (self , queue ):
63+ while True :
64+ try :
65+ progress = queue .get_nowait ()
66+ except Empty :
67+ yield gen .sleep (0.1 )
68+ continue
69+ if progress is None :
70+ yield gen .sleep (5 )
71+ return
72+ if isinstance (progress , Exception ):
73+ self .emit ({
74+ 'phase' : 'error' ,
75+ 'message' : str (progress ),
76+ 'output' : '\n ' .join ([
77+ line .strip ()
78+ for line in traceback .format_exception (
79+ type (progress ), progress , progress .__traceback__
80+ )
81+ ])
82+ })
83+ return
84+
85+ self .emit ({'output' : progress , 'phase' : 'syncing' })
86+
5787 @web .authenticated
5888 @gen .coroutine
5989 def get (self ):
@@ -69,7 +99,7 @@ def get(self):
6999 try :
70100 repo = self .get_argument ('repo' )
71101 branch = self .get_argument ('branch' , None )
72- compressed = self .get_argument ('compressed ' , "false" )
102+ provider = self .get_argument ('provider ' , None )
73103 depth = self .get_argument ('depth' , None )
74104 if depth :
75105 depth = int (depth )
@@ -82,22 +112,31 @@ def get(self):
82112 # so that all repos are always in scope after cloning. Sometimes
83113 # server_root_dir will include things like `~` and so the path
84114 # must be expanded.
85- repo_parent_dir = os .path .join (os .path .expanduser (self .settings ['server_root_dir' ]),
86- os .getenv ('NBGITPULLER_PARENTPATH' , '' ))
87- repo_dir = os .path .join (repo_parent_dir , self .get_argument ('targetpath' , repo .split ('/' )[- 1 ]))
115+ repo_parent_dir = os .path .join (os .path .expanduser (self .settings ['server_root_dir' ]), os .getenv ('NBGITPULLER_PARENTPATH' , '' ))
116+ nbgitpuller .REPO_PARENT_DIR = repo_parent_dir
117+
118+ repo_dir = os .path .join (
119+ repo_parent_dir ,
120+ self .get_argument ('targetpath' , repo .split ('/' )[- 1 ]))
88121
89122 # We gonna send out event streams!
90123 self .set_header ('content-type' , 'text/event-stream' )
91124 self .set_header ('cache-control' , 'no-cache' )
92125
93- if compressed == 'true' :
94- pm = self .setup_plugins (repo )
95- results = pm .hook .handle_files (repo = repo , repo_parent_dir = repo_parent_dir )[0 ]
126+ # if provider is specified then we are dealing with compressed
127+ # archive and not a git repo
128+ if provider is not None :
129+ pm = self .setup_plugins (provider )
130+ req_args = {k : v [0 ].decode () for k , v in self .request .arguments .items ()}
131+ download_q = Queue ()
132+ req_args ["progress_func" ] = lambda : self .progress_loop (download_q )
133+ req_args ["download_q" ] = download_q
134+ hf_args = {"query_line_args" : req_args }
135+ results = pm .hook .handle_files (** hf_args )
96136 repo_dir = repo_parent_dir + results ["unzip_dir" ]
97137 repo = "file://" + results ["origin_repo_path" ]
98138
99139 gp = GitPuller (repo , repo_dir , branch = branch , depth = depth , parent = self .settings ['nbapp' ])
100-
101140 q = Queue ()
102141
103142 def pull ():
@@ -110,33 +149,11 @@ def pull():
110149 q .put_nowait (e )
111150 raise e
112151 self .gp_thread = threading .Thread (target = pull )
113-
114152 self .gp_thread .start ()
115-
116- while True :
117- try :
118- progress = q .get_nowait ()
119- except Empty :
120- yield gen .sleep (0.5 )
121- continue
122- if progress is None :
123- break
124- if isinstance (progress , Exception ):
125- self .emit ({
126- 'phase' : 'error' ,
127- 'message' : str (progress ),
128- 'output' : '\n ' .join ([
129- line .strip ()
130- for line in traceback .format_exception (
131- type (progress ), progress , progress .__traceback__
132- )
133- ])
134- })
135- return
136-
137- self .emit ({'output' : progress , 'phase' : 'syncing' })
138-
153+ self .progress_loop (q )
154+ yield gen .sleep (3 )
139155 self .emit ({'phase' : 'finished' })
156+
140157 except Exception as e :
141158 self .emit ({
142159 'phase' : 'error' ,
@@ -170,11 +187,10 @@ def initialize(self):
170187 @gen .coroutine
171188 def get (self ):
172189 app_env = os .getenv ('NBGITPULLER_APP' , default = 'notebook' )
173-
174190 repo = self .get_argument ('repo' )
175191 branch = self .get_argument ('branch' , None )
176192 depth = self .get_argument ('depth' , None )
177- compressed = self .get_argument ('compressed ' , "false" )
193+ provider = self .get_argument ('provider ' , None )
178194 urlPath = self .get_argument ('urlpath' , None ) or \
179195 self .get_argument ('urlPath' , None )
180196 subPath = self .get_argument ('subpath' , None ) or \
@@ -195,14 +211,17 @@ def get(self):
195211 else :
196212 path = 'tree/' + path
197213
214+ if provider is not None :
215+ path = "tree/"
216+
198217 self .write (
199218 self .render_template (
200219 'status.html' ,
201220 repo = repo ,
202221 branch = branch ,
203- compressed = compressed ,
204222 path = path ,
205223 depth = depth ,
224+ provider = provider ,
206225 targetpath = targetpath ,
207226 version = __version__
208227 ))
@@ -239,3 +258,10 @@ def get(self):
239258 )
240259
241260 self .redirect (new_url )
261+
262+
263+ class ThreadWithResult (threading .Thread ):
264+ def __init__ (self , group = None , target = None , name = None , args = (), kwargs = {}, * , daemon = None ):
265+ def function ():
266+ self .result = target (* args , ** kwargs )
267+ super ().__init__ (group = group , target = function , name = name , daemon = daemon )
0 commit comments