# 25 lines, 723 B, Python
import os

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier

# Respect MLFLOW_TRACKING_URI when the environment provides one (e.g. CI or a
# remote server); otherwise fall back to the local tracking server. Calling
# set_tracking_uri with a literal would override the env var unconditionally.
_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")

mlflow.set_tracking_uri(_TRACKING_URI)
mlflow.set_experiment("csgo-match-prediction")


def train_model(
    X_train,
    y_train,
    params,
    X_test=None,
    y_test=None,
    *,
    run_name="rf-v1",
    data_version="v1.0.0",
):
    """Train a RandomForestClassifier inside an MLflow run and log it.

    Logs ``params`` and a ``data_version`` param, fits the model, optionally
    evaluates it on held-out data, and logs the fitted model under the
    artifact path ``"model"``.

    Args:
        X_train: Training features, passed to ``RandomForestClassifier.fit``.
        y_train: Training labels.
        params: Keyword arguments for ``RandomForestClassifier``; also logged
            via ``mlflow.log_params``.
        X_test: Optional held-out features. When supplied together with
            ``y_test``, the model's score on them is logged as ``"accuracy"``.
        y_test: Optional held-out labels (see ``X_test``).
        run_name: Name of the MLflow run (default ``"rf-v1"``).
        data_version: Value logged as the ``"data_version"`` param
            (default ``"v1.0.0"``).

    Returns:
        The fitted ``RandomForestClassifier``.

    Raises:
        ValueError: If only one of ``X_test`` / ``y_test`` is supplied.
    """
    # Validate before opening a run so a bad call doesn't leave an empty run.
    if (X_test is None) != (y_test is None):
        raise ValueError("X_test and y_test must be provided together")

    with mlflow.start_run(run_name=run_name):
        # Log params
        mlflow.log_params(params)
        mlflow.log_param("data_version", data_version)

        # Train
        model = RandomForestClassifier(**params)
        model.fit(X_train, y_train)

        # Log metrics (only when held-out data was given; scoring on the
        # training set would overstate accuracy).
        if X_test is not None:
            # For classifiers, score() returns mean accuracy.
            accuracy = model.score(X_test, y_test)
            mlflow.log_metric("accuracy", accuracy)

        # Log model
        mlflow.sklearn.log_model(model, "model")

    return model