MLOps/src/models/train.py

25 lines
723 B
Python

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
# Module-level MLflow configuration, applied once at import time.
# NOTE(review): an explicit set_tracking_uri() call presumably takes
# precedence over the MLFLOW_TRACKING_URI environment variable -- confirm
# this is intended for non-local deployments.
TRACKING_URI = "http://localhost:5000"
EXPERIMENT_NAME = "csgo-match-prediction"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
def train_model(
    X_train,
    y_train,
    params,
    X_test=None,
    y_test=None,
    *,
    run_name="rf-v1",
    data_version="v1.0.0",
):
    """Fit a random forest inside an MLflow run and log it.

    The run records ``params``, a ``data_version`` parameter, the fitted
    model under the artifact path ``"model"`` and, when a hold-out set is
    supplied, its score as the ``accuracy`` metric.

    Args:
        X_train: Training features, passed to ``RandomForestClassifier.fit``.
        y_train: Training labels, passed to ``RandomForestClassifier.fit``.
        params: Mapping of keyword arguments for ``RandomForestClassifier``;
            also logged verbatim as the run's parameters.
        X_test: Optional hold-out features used for the ``accuracy`` metric.
        y_test: Optional hold-out labels; required whenever ``X_test`` is
            given.
        run_name: Name given to the MLflow run.
        data_version: Value logged as the ``data_version`` parameter.

    Returns:
        The fitted ``RandomForestClassifier``.

    Raises:
        ValueError: If only one of ``X_test`` / ``y_test`` is provided.
    """
    # Validate before opening a run so a bad call leaves no half-logged run.
    if (X_test is None) != (y_test is None):
        raise ValueError("X_test and y_test must be provided together")

    with mlflow.start_run(run_name=run_name):
        # Log params
        mlflow.log_params(params)
        mlflow.log_param("data_version", data_version)

        # Train
        model = RandomForestClassifier(**params)
        model.fit(X_train, y_train)

        # Log metrics. The evaluation data used to be read from undefined
        # names (NameError on every call); it is now an explicit argument.
        # Without a hold-out set the metric is skipped rather than being
        # computed on the training data, which would be misleading.
        if X_test is not None:
            accuracy = model.score(X_test, y_test)
            mlflow.log_metric("accuracy", accuracy)

        # Log model
        mlflow.sklearn.log_model(model, "model")

    return model