diff --git a/dvc.yaml b/dvc.yaml
new file mode 100644
index 0000000..a66a999
--- /dev/null
+++ b/dvc.yaml
@@ -0,0 +1,32 @@
+stages:
+  preprocess:
+    cmd: python src/data/preprocess.py
+    deps:
+      - src/data/preprocess.py
+      - data/raw
+    params:
+      - preprocess.test_size
+      - preprocess.random_state
+    outs:
+      - data/processed/features.csv
+      - data/processed/train.csv
+      - data/processed/test.csv
+    metrics:
+      - data/processed/data_metrics.json:
+          cache: false
+
+  train:
+    cmd: python src/models/train.py
+    deps:
+      - src/models/train.py
+      - data/processed/train.csv
+      - data/processed/test.csv
+    params:
+      - train.n_estimators
+      - train.max_depth
+      - train.random_state
+    outs:
+      - models/model.pkl
+    metrics:
+      - models/metrics.json:
+          cache: false
diff --git a/params.yaml b/params.yaml
new file mode 100644
index 0000000..953c388
--- /dev/null
+++ b/params.yaml
@@ -0,0 +1,8 @@
+preprocess:
+  test_size: 0.2
+  random_state: 42
+
+train:
+  n_estimators: 100
+  max_depth: 10
+  random_state: 42
diff --git a/rapports/Rapport2.pdf b/rapports/Rapport2.pdf
new file mode 100644
index 0000000..fd6f2cf
Binary files /dev/null and b/rapports/Rapport2.pdf differ
diff --git a/rapports/Rapport2.typ b/rapports/Rapport2.typ
new file mode 100644
index 0000000..8106128
--- /dev/null
+++ b/rapports/Rapport2.typ
@@ -0,0 +1,445 @@
+// Using native Typst table instead of tablex for compatibility
+
+#set document(title: "CS:GO Project - MLOps Pipeline", author: "MLOps Team")
+#set page(margin: 2cm, numbering: "1")
+#set text(size: 11pt)
+#set heading(numbering: "1.1")
+
+#align(center)[
+  #text(18pt, weight: "bold")[CS:GO Esports Intelligence Platform Project]
+  #v(0.5cm)
+  #text(14pt)[MLOps Pipeline and Monitoring Strategy]
+  #v(0.3cm)
+  #line(length: 100%)
+  #v(0.5cm)
+
+  #grid(
+    columns: (1fr, 1fr),
+    [*Team: Paul Roost, Axelle Desthombes, Alexis Bruneteau*], [*Date:* #datetime.today().display()]
+  )
+
+  #v(0.2cm)
+  *Dataset:* CS:GO Professional Matches (Kaggle - 25K+ matches) \
+  *Goal:* Predicting match outcomes and optimizing esports strategies
+]
+
+#v(1cm)
+
+= Workshop 1: Pipeline for the Capstone Project
+
+== Overall Pipeline Architecture
+
+#figure(
+  image("images/pipeline2.svg", width: 60%),
+  caption: [Complete architecture of the CS:GO MLOps pipeline]
+)
+
+== Pipeline Stages in Detail
+
+=== Data Collection and Ingestion
+
+*Data sources:*
+- *HLTV.org*: historical results, team rankings
+- *Steam API*: real-time player data
+- *Tournament APIs*: schedules, competition formats
+
+*Automated ingestion pipeline with Apache Airflow:*
+
+```python
+from datetime import datetime
+
+from airflow.decorators import dag
+from airflow.operators.python import PythonOperator
+
+@dag(schedule_interval="@hourly", start_date=datetime(2024, 1, 1))
+def csgo_data_ingestion():
+
+    extract_hltv_matches = PythonOperator(
+        task_id='extract_hltv',
+        python_callable=scrape_hltv_matches
+    )
+
+    validate_data = PythonOperator(
+        task_id='validate_raw_data',
+        python_callable=validate_match_schema
+    )
+
+    store_s3 = PythonOperator(
+        task_id='store_to_s3',
+        python_callable=upload_to_s3
+    )
+
+    extract_hltv_matches >> validate_data >> store_s3
+
+csgo_data_ingestion()
+```
+
+=== Multi-Level Feature Engineering
+
+#table(
+  columns: (2fr, 3fr),
+  stroke: 0.5pt,
+  [*Category*], [*Features*],
+  [*Team-level*], [
+    • `recent_form_10_matches` - recent win/loss ratio \
+    • `map_pool_strength` - win rate per map \
+    • `clutch_success_rate` - clutch performance \
+    • `eco_round_conversion` - economy management
+  ],
+  [*Context*], [
+    • `tournament_tier` - event prestige \
+    • `prize_pool_amount` - pressure factor \
+    • `head_to_head_record` - direct head-to-head history \
+    • `current_game_patch` - game-meta version
+  ],
+  [*Live*], [
+    • `current_score_difference` - current score \
+    • `momentum_last_5_rounds` - recent momentum \
+    • `economy_advantage` - economy advantage
+  ]
+)
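+
+To make these feature definitions concrete, here is a minimal sketch of how two of the team-level features could be derived from raw match results. The column names `team_1`, `team_2`, `_map`, and `match_winner`, and chronologically sorted rows, are assumptions about the Kaggle schema rather than something this pipeline guarantees:
+
+```python
+import pandas as pd
+
+def recent_form(results: pd.DataFrame, team: str, n: int = 10) -> float:
+    """Win ratio over a team's last n matches (recent_form_10_matches)."""
+    played = results[(results['team_1'] == team) | (results['team_2'] == team)].tail(n)
+    wins = ((played['team_1'] == team) & (played['match_winner'] == 1)).sum() \
+         + ((played['team_2'] == team) & (played['match_winner'] == 2)).sum()
+    return wins / max(len(played), 1)
+
+def map_pool_strength(results: pd.DataFrame, team: str) -> pd.Series:
+    """Win rate per map for one team (map_pool_strength)."""
+    played = results[(results['team_1'] == team) | (results['team_2'] == team)].copy()
+    played['won'] = ((played['team_1'] == team) & (played['match_winner'] == 1)) | \
+                    ((played['team_2'] == team) & (played['match_winner'] == 2))
+    return played.groupby('_map')['won'].mean()
+```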
+
+=== Multi-Target Training
+
+Multi-task learning architecture with PyTorch:
+
+```python
+import torch.nn as nn
+
+class CSGOPredictor(nn.Module):
+    def __init__(self, input_dim):
+        super().__init__()
+        self.shared_layers = nn.Sequential(
+            nn.Linear(input_dim, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128)
+        )
+
+        # Task-specific heads
+        self.match_winner = nn.Linear(128, 2)  # binary classification
+        self.final_score = nn.Linear(128, 2)   # score regression
+        self.total_maps = nn.Linear(128, 4)    # number of maps
+
+    def forward(self, x):
+        shared_repr = self.shared_layers(x)
+        return {
+            'match_winner': self.match_winner(shared_repr),
+            'final_score': self.final_score(shared_repr),
+            'total_maps': self.total_maps(shared_repr)
+        }
+```
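+
+The report does not show how the three heads are optimized jointly; a plausible objective is a weighted sum of per-head losses, where the weights below are illustrative assumptions to be tuned on validation data:
+
+```python
+import torch.nn as nn
+
+ce_winner = nn.CrossEntropyLoss()  # match_winner: 2 logits
+mse_score = nn.MSELoss()           # final_score: regression on both teams' scores
+ce_maps = nn.CrossEntropyLoss()    # total_maps: 1-4 maps encoded as 4 classes
+
+def multi_task_loss(outputs, targets, w=(1.0, 0.3, 0.3)):
+    """Weighted sum of the per-head losses from CSGOPredictor.forward()."""
+    return (w[0] * ce_winner(outputs['match_winner'], targets['winner'])
+            + w[1] * mse_score(outputs['final_score'], targets['score'].float())
+            + w[2] * ce_maps(outputs['total_maps'], targets['maps']))
+```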
+
+== Automation and Control Points
+
+=== Automation Strategy
+
+#table(
+  columns: (2fr, 1fr, 3fr),
+  stroke: 0.5pt,
+  [*Stage*], [*Status*], [*Rationale*],
+  [*Data ingestion*], [AUTO], [New matches every day; data goes stale quickly],
+  [*Feature engineering*], [AUTO], [Features depend on real-time data],
+  [*Model retraining*], [AUTO], [The game meta evolves (patches, roster transfers)],
+  [*Deployment*], [AUTO], [Avoids human error, enables fast rollback],
+  [*Model selection*], [MANUAL], [Complex business decisions that require expertise]
+)
+
+=== Critical Control Points
+
+*Data validation:*
+```python
+def validate_match_data(df):
+    """Validation before feature engineering"""
+    checks = [
+        ('schema_compliance', validate_schema(df)),
+        ('completeness', check_missing_values(df, threshold=0.05)),
+        ('consistency', validate_team_names(df)),
+        ('freshness', check_data_age(df, max_hours=24)),
+        ('volume', validate_daily_match_count(df, min_matches=50))
+    ]
+
+    for check_name, result in checks:
+        if not result.passed:
+            raise DataValidationError(f"{check_name} failed")
+```
+
+*Performance validation:*
+```python
+def validate_model_performance(model, validation_data):
+    """Validation before deployment"""
+    metrics = evaluate_model(model, validation_data)
+
+    # Minimum thresholds
+    assert metrics['accuracy'] > 0.65, "Insufficient accuracy"
+    assert metrics['roi_betting'] > 1.05, "Betting ROI not profitable"
+    assert metrics['upset_detection'] > 0.20, "Upset detection too weak"
+
+    return True
+```
+
+=== Technical Challenges and Solutions
+
+*Challenge 1: Extreme Concept Drift*
+
+Game patches significantly change strategies and balance, which can quickly degrade existing models.
+
+*Solution:* automated drift detection plus emergency retraining
+```python
+from scipy.stats import ks_2samp
+
+def detect_meta_shift(recent_matches, baseline):
+    """Detect post-patch meta changes"""
+    map_rates = calculate_map_win_rates(recent_matches)
+    baseline_rates = baseline['map_win_rates']
+
+    for map_name in map_rates:
+        ks_stat, p_value = ks_2samp(map_rates[map_name],
+                                    baseline_rates[map_name])
+        if p_value < 0.01:  # significant drift
+            return True
+    return False
+```
+
+*Challenge 2: Cold Start Problem*
+
+New teams and changed rosters come with too little match history to train on.
+
+*Solution:* transfer learning via player embeddings
+```python
+def handle_cold_start_team(roster, player_db):
+    """Predict through player similarity"""
+    team_embedding = [player_db.get_embedding(p.id) for p in roster]
+    similar_teams = find_similar_teams(team_embedding, top_k=5)
+    return weighted_prediction_from_similar(similar_teams)
+```
+
+#pagebreak()
+
+= Workshop 2: Experiments and Monitoring
+
+== Experiment Tracking with MLflow
+
+=== Configuration and Structured Logging
+
+```python
+import mlflow
+
+mlflow.set_tracking_uri("http://mlflow-server:5000")
+mlflow.set_experiment("csgo-match-prediction")
+
+def train_and_log_experiment(config):
+    with mlflow.start_run(run_name=f"csgo-v{config.version}"):
+
+        # Hyperparameters
+        mlflow.log_params({
+            "model_type": config.model_type,
+            "learning_rate": config.lr,
+            "batch_size": config.batch_size,
+            "data_version": config.data_version
+        })
+
+        # Per-epoch metrics
+        for epoch in range(config.epochs):
+            train_loss = train_one_epoch(model, train_loader)
+            val_metrics = evaluate_model(model, val_loader)
+
+            mlflow.log_metrics({
+                "train_loss": train_loss,
+                "val_accuracy": val_metrics['accuracy'],
+                "betting_roi": val_metrics['roi'],
+                "upset_detection": val_metrics['upset_rate']
+            }, step=epoch)
+
+        # Final artifacts
+        mlflow.pytorch.log_model(model, "model")
+        mlflow.log_artifacts("evaluation_plots/")
+```
+
+=== Tracked Metrics
+
+#table(
+  columns: (2fr, 3fr),
+  stroke: 0.5pt,
+  [*Category*], [*Metrics*],
+  [*ML performance*], [
+    • Accuracy, Precision, Recall, F1-Score \
+    • ROC-AUC, Calibration Error \
+    • Performance per segment (tournament tier)
+  ],
+  [*Business*], [
+    • Betting ROI, Profit/Loss \
+    • Sharpe Ratio, Upset Detection Rate \
+    • User Engagement, Revenue Impact
+  ],
+  [*Computational*], [
+    • Training Time, Inference Latency \
+    • Model Size, Memory Usage \
+    • API Response Time
+  ]
+)
+
+== End-to-End Monitoring Strategy
+
+=== Multi-Level Monitoring Metrics
+
+*Data quality monitoring:*
+```python
+class DataMonitoring:
+    def monitor_data_quality(self, new_batch):
+        metrics = {}
+
+        # Volume and coverage
+        metrics['daily_match_count'] = len(new_batch)
+        metrics['team_coverage'] = new_batch['team_name'].nunique()
+
+        # Quality
+        metrics['missing_rate'] = new_batch.isnull().mean().mean()
+        metrics['duplicates'] = new_batch.duplicated().sum()
+
+        # Distribution drift (self.baseline holds the reference statistics)
+        for col in ['team_ranking', 'match_duration']:
+            drift = calculate_drift_score(new_batch[col], self.baseline[col])
+            metrics[f'{col}_drift'] = drift
+
+        return metrics
+```
+
+*Model performance monitoring:*
+```python
+def monitor_model_performance(predictions, actuals):
+    """Real-time performance monitoring"""
+    rolling_metrics = {}
+
+    # Rolling windows
+    for window in [1, 7, 30]:  # days
+        recent = get_recent_data(window)
+        rolling_metrics[f'accuracy_{window}d'] = accuracy_score(
+            recent['actual'], recent['predicted']
+        )
+        rolling_metrics[f'roi_{window}d'] = calculate_roi(
+            recent['predictions'], recent['outcomes']
+        )
+
+    return rolling_metrics
+```
+
+=== Intelligent Alerting System
+
+#table(
+  columns: (1fr, 2fr, 2fr),
+  stroke: 0.5pt,
+  [*Severity*], [*Thresholds*], [*Actions*],
+  [*CRITICAL*], [
+    • 7-day accuracy \< 60% \
+    • 7-day ROI \< 100% \
+    • API errors \> 5%
+  ], [
+    • PagerDuty + Slack \
+    • Email on-call team \
+    • Automatic rollback
+  ],
+  [*WARNING*], [
+    • Accuracy trending ↓ \
+    • Concept drift p \< 0.05 \
+    • Latency \> 300 ms
+  ], [
+    • Slack \#alerts \
+    • Email ML team \
+    • Investigation required
+  ],
+  [*INFO*], [
+    • New tournaments \
+    • Performance updates \
+    • System health
+  ], [
+    • Slack \#monitoring \
+    • Dashboard updates
+  ]
+)
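+
+A small dispatcher can tie these thresholds to the rolling metrics computed by `monitor_model_performance` above; the rule set below simply mirrors the table and is a sketch, not the deployed logic:
+
+```python
+def evaluate_alerts(rolling_metrics, api_error_rate, latency_p95_ms):
+    """Map current monitoring metrics to an alert severity."""
+    if (rolling_metrics['accuracy_7d'] < 0.60
+            or rolling_metrics['roi_7d'] < 1.00
+            or api_error_rate > 0.05):
+        return {'severity': 'CRITICAL', 'metrics': rolling_metrics}
+    if (rolling_metrics['accuracy_7d'] < rolling_metrics['accuracy_30d']
+            or latency_p95_ms > 300):
+        return {'severity': 'WARNING', 'metrics': rolling_metrics}
+    return {'severity': 'INFO', 'metrics': rolling_metrics}
+```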
+
+=== Dashboards and Reports
+
+*Real-time dashboard (Grafana):*
+
+- *Model Performance*: accuracy, ROI, calibration trends
+- *Data Pipeline Health*: volume, freshness, quality scores
+- *API Performance*: P95 latency, request rate, error rate
+- *Business Metrics*: revenue impact, user engagement
+
+*Automated weekly reports:*
+
+```python
+class WeeklyReportGenerator:
+    def generate_performance_report(self, week_start, week_end):
+        sections = [
+            self.executive_summary(),   # key KPIs
+            self.model_performance(),   # detailed analysis
+            self.business_impact(),     # value generated
+            self.technical_health(),    # infrastructure
+            self.recommendations()      # recommended actions
+        ]
+        return self.compile_html_report(sections)
+```
+
+== Production Monitoring Architecture
+
+=== Multi-Channel Alerting
+
+```python
+class AlertManager:
+    def __init__(self):
+        self.channels = {
+            'slack': SlackNotifier(SLACK_WEBHOOK),
+            'email': EmailNotifier(EMAIL_CONFIG),
+            'pagerduty': PagerDutyNotifier(PAGERDUTY_KEY)
+        }
+
+    def send_alert(self, alert):
+        if alert['severity'] == 'CRITICAL':
+            # Critical alerts fan out to every channel
+            self.channels['pagerduty'].send(alert)
+            self.channels['slack'].send_critical(alert)
+            self.channels['email'].send_oncall(alert)
+        elif alert['severity'] == 'WARNING':
+            # Warnings go to Slack and email only
+            self.channels['slack'].send_warning(alert)
+            self.channels['email'].send_team(alert)
+```
+
+=== Incident Runbooks
+
+*Critical alert: accuracy < 60%*
+
+1. *Immediate actions (0-15 min)*
+   - Check the quality of recent data
+   - Identify meta/tournament changes
+   - Roll back if accuracy < 55%
+
+2. *Investigation (15-60 min)*
+   - Drift analysis on recent data
+   - Compare predictions against actual results
+   - Validate the feature pipeline
+
+3. *Resolution (1-4 h)*
+   - Emergency retraining if drift is detected
+   - Fix the pipeline if it is a data-quality issue
+   - Roll back if it is an infrastructure issue
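+
+The automatic rollback referenced in the runbook is not part of this diff; assuming models are promoted through the MLflow Model Registry under a hypothetical name `csgo-match-prediction`, a sketch could look like:
+
+```python
+from mlflow.tracking import MlflowClient
+
+def rollback_to_previous_version(model_name="csgo-match-prediction"):
+    """Demote the current Production version and re-promote the previous one."""
+    client = MlflowClient()
+    versions = sorted(client.search_model_versions(f"name='{model_name}'"),
+                      key=lambda v: int(v.version), reverse=True)
+    current, previous = versions[0], versions[1]  # assumes >= 2 registered versions
+    client.transition_model_version_stage(model_name, previous.version, "Production")
+    client.transition_model_version_stage(model_name, current.version, "Archived")
+```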
+
+= Conclusion
+
+The MLOps architecture developed for this CS:GO project has several important characteristics:
+
+*Robust production architecture:*
+- Multi-task learning that serves varied predictions according to business needs
+- Real-time serving that meets latency constraints
+- Handling of the concept drift driven by the evolving game meta
+- End-to-end monitoring of data, models, and business metrics
+
+*Measuring business value:*
+- ROI tracking for the betting and fantasy-league applications
+- User-engagement metrics to optimize retention
+- Revenue impact to justify the investments
+
+*Operational reliability:*
+- Automatic rollback when performance degrades
+- Multi-channel alerting for fast reaction
+- Documented incident-resolution procedures
+- Business-continuity plan for critical events
+
+This work demonstrates the application of modern MLOps principles to a specialized domain, with an emphasis on creating business value and on operational reliability.
+
+#align(center)[
+  #line(length: 50%)
+  #v(0.3cm)
+  *MLOps Team - CS:GO Intelligence Platform Project*
+]
\ No newline at end of file
diff --git a/rapports/images/pipeline2.svg b/rapports/images/pipeline2.svg
new file mode 100644
index 0000000..1f7bdab
--- /dev/null
+++ b/rapports/images/pipeline2.svg
@@ -0,0 +1,3 @@

+[SVG markup omitted: flowchart of the full MLOps pipeline - Data Sources (HLTV.org, Steam API, Tournament Feeds) → INGEST → Data Lake (Amazon S3) and Feature Store (Redis) → PROCESS (Airflow + Spark ETL) → VALIDATE (Great Expectations quality gates) → TRAIN (multi-target ML) → TRACK (MLflow) → REGISTER (model store, A/B testing) → BUILD/DEPLOY/RELEASE (Gitea Actions CI/CD) → SERVE (FastAPI + ECS real-time API) → MONITOR (Prometheus + Grafana) → ALERT (PagerDuty + Slack) → end users (fantasy sports, betting analytics), with a FEEDBACK → RETRAIN loop]
\ No newline at end of file
diff --git a/src/data/preprocess.py b/src/data/preprocess.py
new file mode 100644
index 0000000..21d70d5
--- /dev/null
+++ b/src/data/preprocess.py
@@ -0,0 +1,106 @@
+"""
+Data preprocessing pipeline for CSGO match prediction.
+Loads raw data, performs feature engineering, and splits into train/test sets.
+"""
+import pandas as pd
+import yaml
+import json
+from pathlib import Path
+from sklearn.model_selection import train_test_split
+
+def load_params():
+    """Load parameters from params.yaml"""
+    with open("params.yaml") as f:
+        params = yaml.safe_load(f)
+    return params["preprocess"]
+
+def load_raw_data():
+    """Load raw CSGO match data"""
+    results = pd.read_csv("data/raw/results.csv")
+    return results
+
+def engineer_features(df):
+    """Create features for match prediction"""
+    # Basic features from results
+    features = df[[
+        'result_1', 'result_2', 'starting_ct',
+        'ct_1', 't_2', 't_1', 'ct_2',
+        'rank_1', 'rank_2', 'map_wins_1', 'map_wins_2'
+    ]].copy()
+
+    # Engineered features
+    features['rank_diff'] = features['rank_1'] - features['rank_2']
+    features['map_wins_diff'] = features['map_wins_1'] - features['map_wins_2']
+    features['total_rounds'] = features['result_1'] + features['result_2']
+    features['round_diff'] = features['result_1'] - features['result_2']
+
+    # Target: match_winner (1 or 2) -> convert to 0 or 1
+    target = df['match_winner'] - 1
+
+    return features, target
+
+def save_metrics(X_train, X_test, y_train, y_test):
+    """Save dataset metrics"""
+    metrics = {
+        "n_samples": len(X_train) + len(X_test),
+        "n_train": len(X_train),
+        "n_test": len(X_test),
+        "n_features": X_train.shape[1],
+        "class_balance_train": {
+            "class_0": int((y_train == 0).sum()),
+            "class_1": int((y_train == 1).sum())
+        }
+    }
+
+    Path("data/processed").mkdir(parents=True, exist_ok=True)
+    with open("data/processed/data_metrics.json", "w") as f:
+        json.dump(metrics, f, indent=2)
+
+def main():
+    """Main preprocessing pipeline"""
+    print("Loading parameters...")
+    params = load_params()
+
+    print("Loading raw data...")
+    df = load_raw_data()
+    print(f"Loaded {len(df)} matches")
+
+    print("Engineering features...")
+    X, y = engineer_features(df)
+    print(f"Created {X.shape[1]} features")
+
+    print("Splitting data...")
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y,
+        test_size=params["test_size"],
+        random_state=params["random_state"],
+        stratify=y
+    )
+
+    print("Saving processed data...")
+    Path("data/processed").mkdir(parents=True, exist_ok=True)
+
+    # Save full features
+    full_features = X.copy()
+    full_features['target'] = y
+    full_features.to_csv("data/processed/features.csv", index=False)
+
+    # Save train set
+    train_data = X_train.copy()
+    train_data['target'] = y_train
+    train_data.to_csv("data/processed/train.csv", index=False)
+
+    # Save test set
+    test_data = X_test.copy()
+    test_data['target'] = y_test
+    test_data.to_csv("data/processed/test.csv", index=False)
+
+    # Save metrics
+    save_metrics(X_train, X_test, y_train, y_test)
+
+    print("Preprocessing completed successfully!")
+    print(f"Train set: {len(X_train)} samples")
+    print(f"Test set: {len(X_test)} samples")
+
+if __name__ == "__main__":
+    main()
diff --git a/src/models/train.py b/src/models/train.py
index 55ee145..a001a56 100644
--- a/src/models/train.py
+++ b/src/models/train.py
@@ -1,40 +1,145 @@
+"""
+Model training pipeline for CSGO match prediction.
+Trains a Random Forest classifier and logs results to MLflow.
+""" import mlflow import mlflow.sklearn +import yaml +import json +import pickle +from pathlib import Path from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score import pandas as pd +# Configure MLflow mlflow.set_tracking_uri("https://mlflow.sortifal.dev") mlflow.set_experiment("csgo-match-prediction") -def train_model(X_train, y_train, X_test, y_test, params): - with mlflow.start_run(run_name="rf-v1"): - # Log params +def load_params(): + """Load training parameters from params.yaml""" + with open("params.yaml") as f: + params = yaml.safe_load(f) + return params["train"] + +def load_data(): + """Load preprocessed training and test data""" + train_df = pd.read_csv("data/processed/train.csv") + test_df = pd.read_csv("data/processed/test.csv") + + X_train = train_df.drop('target', axis=1) + y_train = train_df['target'] + X_test = test_df.drop('target', axis=1) + y_test = test_df['target'] + + return X_train, y_train, X_test, y_test + +def train_model(X_train, y_train, params): + """Train Random Forest classifier""" + print("Training Random Forest model...") + model = RandomForestClassifier( + n_estimators=params["n_estimators"], + max_depth=params["max_depth"], + random_state=params["random_state"], + n_jobs=-1 + ) + model.fit(X_train, y_train) + return model + +def evaluate_model(model, X_test, y_test): + """Evaluate model and return metrics""" + print("Evaluating model...") + y_pred = model.predict(X_test) + y_pred_proba = model.predict_proba(X_test)[:, 1] + + metrics = { + "accuracy": float(accuracy_score(y_test, y_pred)), + "precision": float(precision_score(y_test, y_pred)), + "recall": float(recall_score(y_test, y_pred)), + "f1_score": float(f1_score(y_test, y_pred)), + "roc_auc": float(roc_auc_score(y_test, y_pred_proba)) + } + + return metrics + +def save_model(model, metrics): + """Save model and metrics locally""" + Path("models").mkdir(parents=True, exist_ok=True) + + # Save model as pickle + with open("models/model.pkl", "wb") as f: + pickle.dump(model, f) + + # Save metrics as JSON + with open("models/metrics.json", "w") as f: + json.dump(metrics, f, indent=2) + + print(f"Model saved to models/model.pkl") + print(f"Metrics saved to models/metrics.json") + +def main(): + """Main training pipeline with MLflow tracking""" + print("=" * 60) + print("CSGO Match Prediction - Model Training") + print("=" * 60) + + # Load parameters and data + params = load_params() + X_train, y_train, X_test, y_test = load_data() + + print(f"\nDataset info:") + print(f" Training samples: {len(X_train)}") + print(f" Test samples: {len(X_test)}") + print(f" Features: {X_train.shape[1]}") + + # Start MLflow run + with mlflow.start_run(run_name="random-forest-csgo"): + # Log parameters mlflow.log_params(params) - mlflow.log_param("data_version", "v1.0.0") - - # Train - model = RandomForestClassifier(**params) - model.fit(X_train, y_train) - - # Log metrics - accuracy = model.score(X_test, y_test) - mlflow.log_metric("accuracy", accuracy) - - # Log model - # mlflow.sklearn.log_model(model, "model") # Commented out due to server permission issue - - return model + mlflow.log_param("n_features", X_train.shape[1]) + mlflow.log_param("n_train_samples", len(X_train)) + mlflow.log_param("n_test_samples", len(X_test)) + + # Train model + model = train_model(X_train, y_train, params) + + # Evaluate model + metrics = evaluate_model(model, X_test, y_test) + + # Log 
+        mlflow.log_metrics(metrics)
+
+        # Log feature importance
+        feature_importance = dict(zip(X_train.columns, model.feature_importances_))
+        top_features = sorted(feature_importance.items(), key=lambda x: x[1], reverse=True)[:5]
+        print("\nTop 5 most important features:")
+        for feat, importance in top_features:
+            print(f"  {feat}: {importance:.4f}")
+            mlflow.log_metric(f"importance_{feat}", importance)
+
+        # Try to log model to MLflow (if permissions allow)
+        try:
+            mlflow.sklearn.log_model(model, "model")
+            print("\nModel logged to MLflow successfully!")
+        except Exception as e:
+            print(f"\nWarning: Could not log model to MLflow: {e}")
+            print("Model will only be saved locally.")
+
+        # Save model and metrics locally
+        save_model(model, metrics)
+
+        # Print results
+        print("\n" + "=" * 60)
+        print("Training Results:")
+        print("=" * 60)
+        for metric, value in metrics.items():
+            print(f"  {metric}: {value:.4f}")
+        print("=" * 60)
+
+        print(f"\nMLflow run ID: {mlflow.active_run().info.run_id}")
+        print(f"View run at: {mlflow.get_tracking_uri()}")
+
+    print("\nTraining pipeline completed successfully!")

 if __name__ == "__main__":
-    # Load data (example with results.csv)
-    df = pd.read_csv("/home/paul/ING3/MLOps/data/raw/results.csv")
-    # Select numeric columns for features
-    numeric_cols = ['result_1', 'result_2', 'starting_ct', 'ct_1', 't_2', 't_1', 'ct_2', 'rank_1', 'rank_2', 'map_wins_1', 'map_wins_2']
-    X = df[numeric_cols]
-    y = df['match_winner'] - 1  # 0 or 1
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-
-    params = {"n_estimators": 100, "max_depth": 10}
-    model = train_model(X_train, y_train, X_test, y_test, params)
-    print("Training completed and logged to MLflow.")
\ No newline at end of file
+    main()
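
Together with dvc.yaml, these two stages make the pipeline reproducible via `dvc repro`. As a usage sketch, the saved artifact can be loaded for a one-off prediction; the columns mirror what engineer_features() in preprocess.py produces, and the values below are made up:

```python
import pickle
import pandas as pd

# Load the model produced by src/models/train.py
with open("models/model.pkl", "rb") as f:
    model = pickle.load(f)

# One illustrative match, columns in the order engineer_features() creates them
match = pd.DataFrame([{
    'result_1': 16, 'result_2': 9, 'starting_ct': 1,
    'ct_1': 10, 't_2': 5, 't_1': 6, 'ct_2': 4,
    'rank_1': 3, 'rank_2': 17, 'map_wins_1': 1, 'map_wins_2': 0,
    'rank_diff': -14, 'map_wins_diff': 1, 'total_rounds': 25, 'round_diff': 7,
}])

# predict_proba()[:, 1] is P(class 1), i.e. team 2 winning (target = match_winner - 1)
print(model.predict_proba(match)[:, 1])
```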