Skip to content

Commit 6f0780f

Browse files
Handle transitive dependencies
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent 5e141ec commit 6f0780f

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

scripts/dependency_manager.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ class DependencyManager:
1515
def __init__(self, pyproject_path="pyproject.toml"):
1616
self.pyproject_path = Path(pyproject_path)
1717
self.dependencies = self._load_dependencies()
18+
19+
# Map of packages that need specific transitive dependency constraints when downgraded
20+
self.transitive_dependencies = {
21+
'pandas': {
22+
# When pandas is downgraded to 1.x, ensure numpy compatibility
23+
'numpy': {
24+
'min_constraint': '>=1.16.5,<2.0.0', # pandas 1.x works with numpy 1.x
25+
'applies_when': lambda version: version.startswith('1.')
26+
}
27+
}
28+
}
1829

1930
def _load_dependencies(self):
2031
"""Load dependencies from pyproject.toml"""
@@ -96,6 +107,24 @@ def _create_flexible_minimum_constraint(self, package_name, min_version):
96107
# Fallback to exact version
97108
return f"{package_name}=={min_version}"
98109

110+
def _get_transitive_dependencies(self, package_name, version, version_type):
111+
"""Get transitive dependencies that need specific constraints based on the main package version"""
112+
transitive_reqs = []
113+
114+
if package_name in self.transitive_dependencies:
115+
transitive_deps = self.transitive_dependencies[package_name]
116+
117+
for dep_name, dep_config in transitive_deps.items():
118+
# Check if this transitive dependency applies for this version
119+
if dep_config['applies_when'](version):
120+
if version_type == "min":
121+
# Use the predefined constraint for minimum versions
122+
constraint = dep_config['min_constraint']
123+
transitive_reqs.append(f"{dep_name}{constraint}")
124+
# For default version_type, we don't add transitive deps as Poetry handles them
125+
126+
return transitive_reqs
127+
99128
def generate_requirements(self, version_type="min", include_optional=False):
100129
"""
101130
Generate requirements for specified version type.
@@ -105,6 +134,7 @@ def generate_requirements(self, version_type="min", include_optional=False):
105134
include_optional: Whether to include optional dependencies
106135
"""
107136
requirements = []
137+
transitive_requirements = []
108138

109139
for name, constraint in self.dependencies.items():
110140
if name == 'python':
@@ -126,8 +156,31 @@ def generate_requirements(self, version_type="min", include_optional=False):
126156
# Create flexible constraint that allows patch updates for compatibility
127157
flexible_constraint = self._create_flexible_minimum_constraint(name, min_version)
128158
requirements.append(flexible_constraint)
159+
160+
# Check if this package needs specific transitive dependencies
161+
transitive_deps = self._get_transitive_dependencies(name, min_version, version_type)
162+
transitive_requirements.extend(transitive_deps)
129163

130-
return requirements
164+
# Combine main requirements with transitive requirements
165+
all_requirements = requirements + transitive_requirements
166+
167+
# Remove duplicates (prefer main requirements over transitive ones)
168+
seen_packages = set()
169+
final_requirements = []
170+
171+
# First add main requirements
172+
for req in requirements:
173+
package_name = req.split('>=')[0].split('==')[0].split('<')[0]
174+
seen_packages.add(package_name)
175+
final_requirements.append(req)
176+
177+
# Then add transitive requirements that don't conflict
178+
for req in transitive_requirements:
179+
package_name = req.split('>=')[0].split('==')[0].split('<')[0]
180+
if package_name not in seen_packages:
181+
final_requirements.append(req)
182+
183+
return final_requirements
131184

132185

133186
def write_requirements_file(self, filename, version_type="min", include_optional=False):
@@ -140,6 +193,7 @@ def write_requirements_file(self, filename, version_type="min", include_optional
140193
f.write(f"# Uses flexible constraints to resolve compatibility conflicts:\n")
141194
f.write(f"# - Common packages (requests, urllib3, pandas): >=min,<next_major\n")
142195
f.write(f"# - Other packages: >=min,<next_minor\n")
196+
f.write(f"# - Includes transitive dependencies (e.g., numpy for pandas)\n")
143197
else:
144198
f.write(f"# {version_type.title()} dependency versions generated from pyproject.toml\n")
145199
for req in sorted(requirements):

0 commit comments

Comments
 (0)