fork download
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3.  
  4. def k_means_clustering(data, k, max_iterations=100):
  5. """
  6. Perform K-means clustering on the given data.
  7.  
  8. Parameters:
  9. data (numpy.ndarray): The data to be clustered.
  10. k (int): The number of clusters.
  11. max_iterations (int): Maximum number of iterations for convergence.
  12.  
  13. Returns:
  14. tuple: A tuple containing centroids and labels.
  15. centroids (numpy.ndarray): The final centroids of the clusters.
  16. labels (numpy.ndarray): The labels for each data point.
  17. """
  18. centroids = data[np.random.choice(range(len(data)), k, replace=False)]
  19. for _ in range(max_iterations):
  20. labels = np.argmin(np.linalg.norm(data - centroids[:, np.newaxis], axis=2), axis=0)
  21. new_centroids = np.array([data[labels == i].mean(axis=0) for i in range(k)])
  22. if np.all(centroids == new_centroids):
  23. break
  24. centroids = new_centroids
  25. return centroids, labels
  26.  
  27. np.random.seed(42)
  28. data = np.concatenate([np.random.normal(loc=i, scale=1, size=(100, 2)) for i in range(3)])
  29. k = 3
  30. centroids, labels = k_means_clustering(data, k)
  31.  
  32. plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', alpha=0.7)
  33. plt.scatter(centroids[:, 0], centroids[:, 1], c='red', marker='X', s=200, label='Centroids')
  34. plt.title('K-means Clustering')
  35. plt.legend()
  36. plt.show()
  37.  
Success #stdin #stdout 0.68s 55980KB
stdin
mushroom
stdout
Standard output is empty