diff --git a/src/models/train.py b/src/models/train.py index 0cd2331..dc11bba 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -137,7 +137,13 @@ def main(): # Try to log model to MLflow (if permissions allow) try: - mlflow.sklearn.log_model(model, "model") + # Create input example for model signature + input_example = X_train.head(1) + mlflow.sklearn.log_model( + model, + artifact_path="model", + input_example=input_example + ) print("\nModel logged to MLflow successfully!") except Exception as e: print(f"\nWarning: Could not log model to MLflow: {e}")