Adjusted Rand Index 和 Adjusted Mutual Information Score

来源:互联网 发布:淘宝流量突然下降了 编辑:程序博客网 时间:2024/06/16 02:45
import numpy as np


def uniform_labelings_scores(score_func, n_samples, n_clusters_range,
                             fixed_n_classes=None, n_runs=5, seed=42):
    """Score pairs of independent, uniformly random labelings.

    For every cluster count ``k`` in ``n_clusters_range`` and every run, two
    label vectors of length ``n_samples`` are drawn from a single
    ``RandomState(seed)`` stream and passed to ``score_func``.

    Parameters
    ----------
    score_func : callable
        Called as ``score_func(labels_a, labels_b)``; its result is stored
        in a float array.
    n_samples : int
        Length of each random label vector.
    n_clusters_range : iterable of int
        Cluster counts ``k`` to evaluate. Any iterable is accepted (it is
        materialized once), so one-shot iterators such as ``map`` work too.
    fixed_n_classes : int or None, default None
        When given, ``labels_a`` is drawn from ``1..fixed_n_classes``;
        otherwise it is drawn from ``1..k`` like ``labels_b``.
    n_runs : int, default 5
        Number of random pairs scored per value of ``k``.
    seed : int, default 42
        Seed of the random generator.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(n_clusters_range), n_runs)`` with the scores.
    """
    rng = np.random.RandomState(seed)
    # Materialize once: len() is needed below and the caller may pass an
    # iterator (e.g. the result of map() on Python 3).
    n_clusters_range = list(n_clusters_range)
    scores = np.zeros((len(n_clusters_range), n_runs))

    for i, k in enumerate(n_clusters_range):
        # The class count of the first labeling does not change across runs.
        n_classes_a = k if fixed_n_classes is None else fixed_n_classes
        for j in range(n_runs):
            # randint's upper bound is exclusive, hence the "+ 1" to keep the
            # closed range [1, n] that the deprecated random_integers used.
            labels_a = rng.randint(low=1, high=n_classes_a + 1, size=n_samples)
            labels_b = rng.randint(low=1, high=k + 1, size=n_samples)
            scores[i, j] = score_func(labels_a, labels_b)
    return scores


def main():
    """Plot mean and standard deviation of ARI and AMI on random labelings."""
    # Imported here so uniform_labelings_scores can be imported and used
    # with NumPy alone, without the plotting / scikit-learn stack.
    import matplotlib.pyplot as plt
    from sklearn import metrics

    score_funcs = [
        metrics.adjusted_rand_score,
        metrics.adjusted_mutual_info_score,
    ]
    n_samples = 100
    # A concrete integer array (not a one-shot map object): it is used both
    # for the computation and as the x-axis of the plot.
    n_clusters_range = np.linspace(2, 100, 10).astype(int)

    plots = []
    names = []
    for score_func in score_funcs:
        scores = uniform_labelings_scores(
            score_func, n_samples, n_clusters_range, fixed_n_classes=10)
        plots.append(plt.errorbar(
            n_clusters_range, np.mean(scores, axis=1), scores.std(axis=1))[0])
        names.append(score_func.__name__)

    plt.legend(plots, names)
    plt.show()


if __name__ == "__main__":
    main()

0 0