import { consts } from 'globalsData';

// Classification algorithms and the form-field names that count as their
// hyperparameters. An entry without a `hyperParameters` list (here 'all')
// contributes no hyperparameters when looked up.
const classificationAlgorithmHyperParameters = [
  {
    algorithm: 'logisticRegression',
    hyperParameters: [
      'regulationParameter',
      'lrsolver',
      'penalty',
      'tolerance',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'linearDiscriminantAnalysis',
    hyperParameters: [
      'ldaSolver',
      'tolerance',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'supportVectorMachine',
    hyperParameters: [
      'regulationParameter',
      'gamma',
      'kernel',
      'tolerance',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'randomForest',
    hyperParameters: [
      'numberOfEstimators',
      'rfCriterion',
      'rfMaxFeatures',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'kNearestNeighborsClassifier',
    hyperParameters: [
      'numberOfNearestNeighbors',
      'p',
      'kNNalgorithm',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'baggingClassifier',
    hyperParameters: [
      'bcNumberOfEstimators',
      'bcMaxFeatures',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'adaBoostClassifier',
    hyperParameters: [
      'adcAlgorithm',
      'adcNumberOfEstimators',
      'learningRate',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'gradientBoostingClassifier',
    hyperParameters: [
      'gbcCriterion',
      'loss',
      'gbcMaxFeatures',
      'numberOfEstimators',
      'learningRateGbc',
      'tolerance',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'XGBClassifier',
    hyperParameters: [
      'booster',
      'importanceType',
      'learningRateXgb',
      'numberOfEstimators',
      'modelMetrics',
    ],
  },
  {
    algorithm: 'all',
  },
];

// Regression algorithms and the form-field names that count as their
// hyperparameters. Entries without a `hyperParameters` list
// ('linearRegression', 'all') contribute no hyperparameters when looked up.
const regressionAlgorithmHyperParameters = [
  {
    algorithm: 'linearRegression',
  },
  {
    algorithm: 'ridgeRegression',
    hyperParameters: [
      'alpha',
      'solver',
      'tol',
    ],
  },
  {
    algorithm: 'lassoRegression',
    hyperParameters: [
      'alpha',
      'selection',
      'tol',
    ],
  },
  {
    algorithm: 'randomForestRegressor',
    hyperParameters: [
      'rfCriterion',
      'rfMax_features',
      'rfN_estimators',
    ],
  },
  {
    algorithm: 'kNNRegressor',
    hyperParameters: [
      'kNNalgorithm',
      'n_neighbors',
      'p',
      'weights',
    ],
  },
  {
    algorithm: 'baggingRegressor',
    hyperParameters: [
      'brMax_features',
      'brN_estimators',
    ],
  },
  {
    algorithm: 'adaBoostRegressor',
    hyperParameters: [
      'learning_rate',
      'adaLoss',
      'adaN_estimators',
    ],
  },
  {
    algorithm: 'gradientBoostingRegressor',
    hyperParameters: [
      'gbrCriterion',
      'gbrLoss',
      'gbrMaxFeatures',
      'gbrN_estimators',
      'tol',
    ],
  },
  {
    algorithm: 'XGBRegressor',
    hyperParameters: [
      'booster',
      'importance_type',
      'learning_rate',
      'xgbN_estimators',
    ],
  },
  {
    algorithm: 'all',
  },
];

// Hyperparameters for Timeseries starts
// Time-series algorithms and the form-field names that count as their
// hyperparameters. The 'all' entry has no `hyperParameters` list and so
// contributes no hyperparameters when looked up.
const timeseriesAlgorithmHyperParameters = [
  {
    algorithm: 'exponentialSmoothing',
    hyperParameters: [
      'trend',
      'dampedTrend',
      'seasonal',
      'seasonalPeriods',
      'freq',
    ],
  },
  {
    algorithm: 'arima',
    hyperParameters: [
      'arimaSeasonal',
      'seasonalDifferencing',
    ],
  },
  {
    algorithm: 'prophet',
    hyperParameters: [
      'prophetSeasonalityMode',
      'prophetYearlySeasonality',
      'prophetWeeklySeasonality',
      'prophetDailySeasonality',
    ],
  },
  {
    algorithm: 'neuralProphet',
    hyperParameters: [
      'npSeasonalityMode',
      'npYearlySeasonality',
      'npWeeklySeasonality',
      'npDailySeasonality',
    ],
  },
  {
    algorithm: 'all',
  },
];
// Hyperparameters for Timeseries ends

// Hyperparameters for Clustering starts
// Clustering algorithms and the form-field names that count as their
// hyperparameters. The 'auto_n_cluster' entry has no `hyperParameters` list
// and so contributes no hyperparameters when looked up.
const clusteringAlgorithmHyperParameters = [
  {
    algorithm: 'agglomerativeClustering',
    hyperParameters: [
      'agglomative_n_cluster',
      'agglomerativeMetricListAlgorithm',
      'agglomerativeLinkageListAlgorithm',
    ],
  },
  {
    algorithm: 'kmeans',
    hyperParameters: [
      'kmeansInitListAlgorithm',
      'kmeans_algorithm',
      'tol',
      'kmeans_n_cluster',
    ],
  },
  {
    algorithm: 'kmodes',
    hyperParameters: [
      'kmod_n_cluster',
      'kmodesInitList',
    ],
  },
  {
    algorithm: 'auto_n_cluster',
  },
];
// Hyperparameters for Clustering ends

/**
 * Selects the algorithm/hyperparameter table that belongs to an accelerator.
 *
 * The accelerator is compared by strict equality against the `consts`
 * identifiers, in the order classification, regression, timeseries,
 * clustering.
 *
 * @param {*} accelerator - Accelerator identifier to look up.
 * @returns {Array} The matching module-level table (returned by reference,
 *   not copied), or a fresh empty array when nothing matches.
 */
const getHyperParametersList = (accelerator) => {
  switch (accelerator) {
    case consts.CLASSIFICATION:
      return classificationAlgorithmHyperParameters;
    case consts.REGRESSION:
      return regressionAlgorithmHyperParameters;
    case consts.TIMESERIES:
      return timeseriesAlgorithmHyperParameters;
    case consts.CLUSTERING:
      return clusteringAlgorithmHyperParameters;
    default:
      return [];
  }
};

/**
 * Extracts from the form data only those fields that are registered as
 * hyperparameters of the selected algorithm for the given accelerator.
 *
 * @param {*} accelerator - Accelerator identifier, forwarded to
 *   `getHyperParametersList` to pick the lookup table.
 * @param {Object} formData - Form values; `formData.algorithm` selects the
 *   table entry, the remaining fields are candidates for extraction.
 * @returns {Object|undefined} A new object holding the matching
 *   key/value pairs in `formData` key order; an empty object when the
 *   algorithm is unknown or has no hyperparameters registered; `undefined`
 *   when either argument is falsy.
 */
const getAlgorithmHyperParameters = (accelerator, formData) => {
  // Kept as `undefined` (not `{}`) so existing callers that test for a
  // missing result keep working.
  if (!accelerator || !formData) return undefined;

  const { algorithm } = formData;
  const allowedKeys = getHyperParametersList(accelerator).find((item) => item.algorithm === algorithm)?.hyperParameters;

  // Nothing registered for this algorithm: no need to scan the form at all.
  if (!allowedKeys?.length) return {};

  // Object.keys restricts the scan to formData's own fields; a for...in loop
  // would also pick up enumerable properties inherited from the prototype.
  const selectedEntries = Object.keys(formData)
    .filter((key) => allowedKeys.includes(key))
    .map((key) => [key, formData[key]]);

  return Object.fromEntries(selectedEntries);
};

export default getAlgorithmHyperParameters;
