Files
ai/sandbox/dexorder/conda_manager.py

513 lines
16 KiB
Python

"""
Conda Package Manager
Manages dynamic installation and cleanup of conda packages for user components.
Scans metadata files to determine required packages and syncs the conda environment.
Extra packages (user-installed beyond the base container) are tracked in
``extra_packages.json`` under ``data_dir`` so they can be removed when no
script references them. Packages that are later promoted into the base image
(i.e. appear in ``environment.yml``) are silently evicted from tracking
rather than uninstalled.
"""
import json
import logging
import subprocess
import sys
from pathlib import Path
from typing import Optional, Set
# Filename (stored under data_dir, outside the git repo) for tracking
# user-installed extra packages. The file holds a sorted JSON list of
# package names (see save_extra_packages / load_extra_packages).
EXTRA_PACKAGES_FILENAME = "extra_packages.json"
# Module-level logger used by every function in this module.
log = logging.getLogger(__name__)
# =============================================================================
# Conda Environment Detection
# =============================================================================
def get_conda_env_path() -> Optional[Path]:
    """
    Locate the conda environment this interpreter belongs to.

    A non-empty ``CONDA_PREFIX`` environment variable takes precedence.
    Otherwise the directory containing ``sys.executable`` and then its parent
    are probed for a ``conda-meta`` subdirectory.

    Returns:
        The environment root, or None when no conda environment is detected.
    """
    import os

    prefix = os.getenv("CONDA_PREFIX")
    if prefix:
        return Path(prefix)

    exe_dir = Path(sys.executable).parent
    # A conda-meta directory marks the root of a conda environment.
    probes = (exe_dir, exe_dir.parent)
    return next((root for root in probes if (root / "conda-meta").exists()), None)
def get_conda_executable() -> Optional[Path]:
    """
    Find a ``conda`` (preferred) or ``mamba`` executable for the active env.

    For each tool name, the environment's own ``bin`` directory is checked
    first, then the ``bin`` directory two levels above the environment root
    (the root installation when the env lives at ``<root>/envs/<name>``).

    Returns:
        Path to the first executable found, or None when there is no active
        conda environment or none of the probed locations contain one.
    """
    env_root = get_conda_env_path()
    if not env_root:
        return None

    install_root = env_root.parent.parent
    for tool in ("conda", "mamba"):
        for bin_dir in (env_root / "bin", install_root / "bin"):
            candidate = bin_dir / tool
            if candidate.exists():
                return candidate
    return None
# =============================================================================
# Package Management
# =============================================================================
def get_installed_packages() -> Set[str]:
    """
    List the names of packages installed in the ``dexorder`` conda env.

    Runs ``conda list -n dexorder --json`` with a 30 second timeout.

    Returns:
        Set of package names. An empty set is returned (and the cause logged)
        when conda cannot be found, the command fails or times out, or its
        output cannot be parsed.
    """
    conda_exe = get_conda_executable()
    if not conda_exe:
        log.error("Failed to list conda packages: conda executable not found")
        return set()

    cmd = [str(conda_exe), "list", "-n", "dexorder", "--json"]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        if proc.returncode != 0:
            log.error(f"Failed to list conda packages: {proc.stderr}")
            return set()
        return {entry["name"] for entry in json.loads(proc.stdout)}
    except subprocess.TimeoutExpired:
        log.error("Timeout while listing conda packages")
        return set()
    except Exception as e:
        log.error(f"Error listing conda packages: {e}")
        return set()
def load_extra_packages(data_dir: Path) -> Set[str]:
    """
    Load the set of user-installed extra packages (beyond the base container).

    Args:
        data_dir: Directory containing the tracking file.

    Returns:
        The tracked package names. An empty set is returned when the file does
        not exist. When it cannot be read, is not valid JSON, or does not hold
        a JSON list, the problem is logged and an empty set is returned.
        Non-string list entries are ignored.
    """
    path = data_dir / EXTRA_PACKAGES_FILENAME
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # No tracking file yet: nothing has been installed beyond the base image.
        return set()
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        log.error(f"Failed to load extra packages: {e}")
        return set()

    # save_extra_packages writes a JSON list of names. Anything else (e.g. a
    # bare string, which set() would split into single characters, or a dict,
    # which would contribute its keys) is treated as corrupt.
    if not isinstance(data, list):
        log.error(f"Failed to load extra packages: expected a JSON list in {path}")
        return set()
    return {name for name in data if isinstance(name, str)}
def save_extra_packages(data_dir: Path, packages: Set[str]) -> None:
    """
    Persist the set of user-installed extra packages.

    The names are written as a sorted JSON list to a temporary sibling file,
    which is then renamed over the tracking file. An interrupted write
    therefore cannot leave a truncated tracking file behind; a truncated file
    would make every tracked package un-removable. Failures are logged, not
    raised.

    Args:
        data_dir: Directory containing the tracking file.
        packages: Package names to record.
    """
    path = data_dir / EXTRA_PACKAGES_FILENAME
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        tmp_path.write_text(json.dumps(sorted(packages)), encoding="utf-8")
        # Path.replace (os.replace) overwrites the destination and is an atomic
        # rename on POSIX when both paths are on the same filesystem, which
        # holds here because they share a directory. A stale .tmp left by a
        # failed attempt is simply overwritten next time.
        tmp_path.replace(path)
    except Exception as e:
        log.error(f"Failed to save extra packages: {e}")
def install_packages(packages: list[str], data_dir: Optional[Path] = None) -> dict:
    """
    Install into the ``dexorder`` env (from conda-forge) whichever of
    ``packages`` are not already present.

    Args:
        packages: Package names to make available.
        data_dir: When given, names that were actually installed are added to
            the extra-package tracking file (``extra_packages.json``) so that
            they can be removed later once nothing references them.

    Returns:
        dict with keys:
            - success: bool
            - installed: list[str], packages installed by this call
            - skipped: list[str], packages that were already present
            - failed: list[str], packages whose installation failed
            - error: str, present only on failure
    """
    if not packages:
        return {"success": True, "installed": [], "skipped": [], "failed": []}

    present = get_installed_packages()
    skipped = [name for name in packages if name in present]
    to_install = [name for name in packages if name not in present]

    if not to_install:
        log.info(f"All packages already installed: {skipped}")
        return {"success": True, "installed": [], "skipped": skipped, "failed": []}

    log.info(f"Installing conda packages: {to_install}")

    def failure(error: str) -> dict:
        # Every failure path reports the same shape: nothing installed, all pending names failed.
        return {
            "success": False,
            "installed": [],
            "skipped": skipped,
            "failed": to_install,
            "error": error,
        }

    conda_exe = get_conda_executable()
    if not conda_exe:
        log.error("Failed to install packages: conda executable not found")
        return failure("conda executable not found")

    cmd = [str(conda_exe), "install", "-y", "-n", "dexorder", "-c", "conda-forge"] + to_install
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=300)  # 5 minute timeout
        if proc.returncode != 0:
            log.error(f"Failed to install packages: {proc.stderr}")
            return failure(proc.stderr)

        log.info(f"Successfully installed packages: {to_install}")
        if data_dir:
            tracked = load_extra_packages(data_dir)
            tracked.update(to_install)
            save_extra_packages(data_dir, tracked)
        return {"success": True, "installed": to_install, "skipped": skipped, "failed": []}
    except subprocess.TimeoutExpired:
        log.error("Timeout while installing conda packages")
        return failure("Installation timeout")
    except Exception as e:
        log.error(f"Error installing conda packages: {e}")
        return failure(str(e))
def remove_packages(packages: list[str]) -> dict:
    """
    Uninstall the given packages from the ``dexorder`` conda env.

    Args:
        packages: Package names to uninstall.

    Returns:
        dict with keys:
            - success: bool
            - removed: list[str], the requested names on success, else []
            - error: str, present only on failure
    """
    if not packages:
        return {"success": True, "removed": []}

    log.info(f"Removing conda packages: {packages}")

    def failure(error: str) -> dict:
        return {"success": False, "removed": [], "error": error}

    conda_exe = get_conda_executable()
    if not conda_exe:
        log.error("Failed to remove packages: conda executable not found")
        return failure("conda executable not found")

    cmd = [str(conda_exe), "remove", "-y", "-n", "dexorder"] + packages
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        if proc.returncode != 0:
            log.error(f"Failed to remove packages: {proc.stderr}")
            return failure(proc.stderr)
        log.info(f"Successfully removed packages: {packages}")
        return {"success": True, "removed": packages}
    except subprocess.TimeoutExpired:
        log.error("Timeout while removing conda packages")
        return failure("Removal timeout")
    except Exception as e:
        log.error(f"Error removing conda packages: {e}")
        return failure(str(e))
# =============================================================================
# Metadata Scanning
# =============================================================================
def scan_metadata_packages(data_dir: Path) -> Set[str]:
    """
    Scan all metadata files to find required conda packages.

    Looks for ``<data_dir>/<category>/<item>/metadata.json`` and unions the
    ``conda_packages`` value of each file. That value may be a list of names
    or a single name string. Null or empty values contribute nothing. Any
    other type is logged and ignored, as are non-string list entries.
    Unreadable or malformed metadata files are logged and skipped.

    Args:
        data_dir: Base data directory containing category subdirectories.

    Returns:
        Set of all required package names.

    Raises:
        OSError: if ``data_dir`` (or a category directory) cannot be listed,
            e.g. because it does not exist.
    """
    packages: Set[str] = set()

    for category_dir in data_dir.iterdir():
        if not category_dir.is_dir():
            continue
        for item_dir in category_dir.iterdir():
            if not item_dir.is_dir():
                continue
            metadata_path = item_dir / "metadata.json"
            if not metadata_path.exists():
                continue

            try:
                metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
                declared = metadata.get("conda_packages", [])
            except Exception as e:
                log.error(f"Failed to read metadata from {metadata_path}: {e}")
                continue

            if not declared:
                continue
            # set.update() on a bare string would add its individual characters.
            if isinstance(declared, str):
                declared = [declared]
            if not isinstance(declared, (list, tuple)):
                log.warning(f"Ignoring non-list conda_packages in {metadata_path}: {declared!r}")
                continue

            names = [name for name in declared if isinstance(name, str) and name]
            if names:
                packages.update(names)
                log.debug(f"Found packages in {item_dir.name}: {names}")

    return packages
def _conda_spec_name(spec: str) -> str:
    """Return the bare package name of a conda match spec string, or "" if none."""
    import re

    # Drop an optional "channel::" (or "channel/subdir::") prefix,
    # e.g. "conda-forge::numpy>=1.2" -> "numpy>=1.2".
    spec = spec.rsplit("::", 1)[-1]
    # The name runs up to the first character that cannot occur in a conda
    # package name: a version operator (=, <, >, !, ~), a space before a
    # version, a "[" selector, etc.
    match = re.match(r"\s*([A-Za-z0-9_.\-]+)", spec)
    return match.group(1) if match else ""


def get_base_packages(environment_yml: Path) -> Set[str]:
    """
    Get base package names from environment.yml.

    Only string entries under ``dependencies`` are considered, so nested
    ``pip:`` sections are ignored. Version constraints, channel prefixes and
    bracket selectors are stripped, leaving bare package names.

    Args:
        environment_yml: Path to environment.yml file.

    Returns:
        Set of base package names. An empty set is returned (and the cause
        logged) when the file is missing or cannot be parsed.
    """
    if not environment_yml.exists():
        log.warning(f"environment.yml not found at {environment_yml}")
        return set()
    try:
        import yaml

        with open(environment_yml, encoding="utf-8") as f:
            env_spec = yaml.safe_load(f)
        # An empty file loads as None, and "dependencies:" with no items loads as None too.
        dependencies = (env_spec or {}).get("dependencies") or []
        names = {_conda_spec_name(dep) for dep in dependencies if isinstance(dep, str)}
        names.discard("")
        return names
    except Exception as e:
        log.error(f"Failed to parse environment.yml: {e}")
        return set()
# =============================================================================
# Cleanup and Sync Operations
# =============================================================================
def cleanup_extra_packages(data_dir: Path, environment_yml: Optional[Path] = None) -> dict:
    """
    Remove tracked extra packages that are no longer needed by any script.

    Only packages previously recorded in ``extra_packages.json`` are ever
    considered for removal; base container packages are never touched.
    The tracking file is reconciled as follows:

    * Packages that have since been promoted into the base container image
      (i.e. now appear in ``environment.yml``) are evicted from tracking
      without being uninstalled.
    * Unneeded tracked packages that are no longer installed (e.g. the
      container was rebuilt from a fresh image) are evicted from tracking
      instead of being passed to ``conda remove``. A name that is not
      installed would make the whole removal fail, and would keep failing
      on every later cleanup.

    Args:
        data_dir: Base data directory (tracking file lives here).
        environment_yml: Path to environment.yml for base package reconciliation.

    Returns:
        dict with:
            - success: bool
            - to_remove: list[str], packages identified for removal
            - removed: list[str], packages actually removed
            - error: str (if any)
    """
    required = scan_metadata_packages(data_dir / "src")
    base = get_base_packages(environment_yml) if environment_yml and environment_yml.exists() else set()
    extras = load_extra_packages(data_dir)

    # Packages promoted into the base image are no longer "extra"; evict them from tracking.
    now_base = extras & base
    if now_base:
        log.info(f"Packages promoted to base image, evicting from extra tracking: {now_base}")
        extras -= now_base

    # Only tracked extras that no script references anymore are candidates.
    candidates = extras - required
    if candidates:
        installed = get_installed_packages()
        # get_installed_packages() returns an empty set on any failure. Skip
        # reconciliation in that case rather than dropping every candidate.
        if installed:
            stale = candidates - installed
            if stale:
                log.info(f"Tracked extra packages no longer installed, evicting from extra tracking: {stale}")
                extras -= stale
                candidates -= stale

    to_remove = sorted(candidates)
    result: dict = {"success": True, "to_remove": to_remove, "removed": []}
    if to_remove:
        remove_result = remove_packages(to_remove)
        result["success"] = remove_result["success"]
        result["removed"] = remove_result.get("removed", [])
        if remove_result["success"]:
            extras -= set(to_remove)
        else:
            result["error"] = remove_result.get("error")

    # Persist evictions even when nothing was removed or removal failed.
    save_extra_packages(data_dir, extras)
    return result
def sync_packages(data_dir: Path, environment_yml: Optional[Path] = None) -> dict:
    """
    Sync conda packages with metadata requirements.

    Scans all metadata files and computes the packages still needed (base
    packages plus those referenced by metadata). It then removes tracked
    extra packages (recorded in ``extra_packages.json``, see
    ``install_packages``) that are installed but no longer needed.

    Only tracked extras are ever removal candidates. Every other installed
    package is left alone. That includes the transitive dependencies of base
    packages, which never appear in environment.yml, and packages installed
    by other means. Removing any of them could break the environment,
    because ``conda remove`` also removes packages that depend on what it
    removes.

    Args:
        data_dir: Base data directory (metadata lives under ``data_dir/src``;
            the tracking file lives directly in ``data_dir``).
        environment_yml: Path to environment.yml (optional).

    Returns:
        dict with:
            - success: bool
            - required: list[str], packages required by metadata
            - base: list[str], base packages from environment.yml
            - installed: list[str], currently installed packages
            - to_remove: list[str], packages to be removed (sorted)
            - removed: list[str], packages that were removed
            - error: str (if any)
    """
    log.info("Starting conda package sync")

    # Metadata lives under data_dir/src/category/item/metadata.json
    required_packages = scan_metadata_packages(data_dir / "src")
    log.info(f"Required packages from metadata: {required_packages}")

    base_packages: Set[str] = set()
    if environment_yml and environment_yml.exists():
        base_packages = get_base_packages(environment_yml)
        log.info(f"Base packages from environment.yml: {base_packages}")

    installed_packages = get_installed_packages()
    log.info(f"Currently installed packages: {len(installed_packages)} total")

    extras = load_extra_packages(data_dir)
    protected = base_packages | required_packages
    # Extra safety net: never touch interpreter / packaging tooling even if tracked.
    system_prefixes = ("python", "conda", "pip", "setuptools", "wheel", "_")
    # Intersecting with the installed set keeps names that are no longer
    # installed out of `conda remove`, which would otherwise fail the whole
    # batch. If listing failed (empty set), nothing is removed.
    to_remove = sorted(
        pkg
        for pkg in extras & installed_packages
        if pkg not in protected and not pkg.startswith(system_prefixes)
    )
    log.info(f"Packages to remove: {to_remove}")

    result = {
        "success": True,
        "required": sorted(required_packages),
        "base": sorted(base_packages),
        "installed": sorted(installed_packages),
        "to_remove": to_remove,
        "removed": [],
    }

    if to_remove:
        remove_result = remove_packages(to_remove)
        result["success"] = remove_result["success"]
        result["removed"] = remove_result.get("removed", [])
        if remove_result["success"]:
            # Keep the tracking file in step with what is actually installed.
            save_extra_packages(data_dir, extras - set(to_remove))
        else:
            result["error"] = remove_result.get("error", "Unknown error")

    log.info(f"Conda package sync complete: {len(result['removed'])} packages removed")
    return result