diff --git a/src/models/train.py b/src/models/train.py
index a2f7e50..9cd01be 100644
--- a/src/models/train.py
+++ b/src/models/train.py
@@ -13,14 +13,29 @@ from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
 import pandas as pd
+from urllib.parse import quote
 
-# Configure MLflow
-mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "https://mlflow.sortifal.dev"))
+# Configure MLflow with authentication
+tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "https://mlflow.sortifal.dev")
+username = os.getenv("MLFLOW_TRACKING_USERNAME")
+password = os.getenv("MLFLOW_TRACKING_PASSWORD")
 
-# Set MLflow credentials from environment variables
-if os.getenv("MLFLOW_TRACKING_USERNAME") and os.getenv("MLFLOW_TRACKING_PASSWORD"):
-    os.environ["MLFLOW_TRACKING_USERNAME"] = os.getenv("MLFLOW_TRACKING_USERNAME")
-    os.environ["MLFLOW_TRACKING_PASSWORD"] = os.getenv("MLFLOW_TRACKING_PASSWORD")
-    print("MLflow credentials configured from environment variables")
+# Build authenticated URI if credentials are provided
+if username and password:
+    # Percent-encode both values: a raw '@', ':', '/', '?' or '#' in the
+    # username or password would otherwise change how the URI is parsed.
+    userinfo = f"{quote(username, safe='')}:{quote(password, safe='')}"
+    # Only http(s) URIs get the credentials embedded; anything else is used as-is.
+    if tracking_uri.startswith("https://"):
+        auth_uri = f"https://{userinfo}@{tracking_uri[8:]}"
+    elif tracking_uri.startswith("http://"):
+        auth_uri = f"http://{userinfo}@{tracking_uri[7:]}"
+    else:
+        auth_uri = tracking_uri
+    mlflow.set_tracking_uri(auth_uri)
+    print(f"MLflow credentials configured for {tracking_uri}")
+else:
+    mlflow.set_tracking_uri(tracking_uri)
+    print("MLflow configured without authentication")
 
 # Try to set experiment, but handle auth errors gracefully
 USE_MLFLOW = True