IFD:EAI SoS21/course material/Session 4: Programming the Classifier Part1: Difference between revisions

From Medien Wiki
</syntaxhighlight>
Our classifier is written in a class called "KMeans" because it most closely resembles a KMeans classifier.
So inside "KMeans.cpp" you will find the guts of our classifier. You can see that when the KMeans classifier gets instantiated (i.e. the constructor is called), we feed it the points and the class labels, and it immediately calculates the centroid for each class.


<syntaxhighlight lang="c++">
// ... (lines omitted in this excerpt)
{
    _n_classes = n_classes;
    for(int c=0; c < n_classes; c++) // traverse once for every class
    {
        _centroids.push_back(Point2D(0,0));
        // ... (lines omitted in this excerpt)
        _centroids[c] = _centroids[c]/numPoints;
    }
}
</syntaxhighlight>
To do so, it traverses the vector of points once for every class. In each pass it looks only at the points of the current class, sums them up, counts them, and finally divides the sum by that count. We end up with one centroid per class, stored in a vector called "_centroids", which is a private member variable of our KMeans class.
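The per-class averaging described here can be sketched as a free function. This is only an illustration under assumptions: the "Point2D" struct below is a simplified stand-in for the course's class, and the name "computeCentroids" and the guard against empty classes are hypothetical additions, not taken from "KMeans.cpp".

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for the course's Point2D class (assumption).
struct Point2D {
    float x;
    float y;
};

// Computes one centroid per class: for each class, sum the points that carry
// that label, count them, and divide the sum by the count.
std::vector<Point2D> computeCentroids(const std::vector<Point2D>& points,
                                      const std::vector<int>& labels,
                                      int n_classes)
{
    assert(points.size() == labels.size()); // one label per point
    std::vector<Point2D> centroids;
    for (int c = 0; c < n_classes; c++) // traverse once for every class
    {
        Point2D sum{0.0f, 0.0f};
        int numPoints = 0;
        for (std::size_t i = 0; i < points.size(); i++)
        {
            if (labels[i] == c) // only look at points of class c
            {
                sum.x += points[i].x;
                sum.y += points[i].y;
                numPoints++;
            }
        }
        if (numPoints > 0) // assumption: an empty class keeps (0,0) as centroid
        {
            sum.x /= numPoints;
            sum.y /= numPoints;
        }
        centroids.push_back(sum);
    }
    return centroids;
}
```

With the points (0,0) and (2,0) in class 0 and (10,10) and (12,14) in class 1, this sketch yields the centroids (1,0) and (11,12).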
Now, when we classify, we already know the centroids and only need to calculate the distance from our new point to every centroid. The new point gets the label of the closest one.
<syntaxhighlight lang="c++">
int KMeans::classify(Point2D newPoint)
{
    float min_distance = std::numeric_limits<float>::max(); // start with the largest possible float (needs #include <limits>)
    int class_label = -1; // and an invalid class label
    for(int c=0; c<_n_classes ; c++)
    {
        float distance = _centroids[c].getDistance(newPoint);
        if(distance < min_distance)
        {
            min_distance = distance;
            class_label = c;
        }
    }
    return class_label;
}
</syntaxhighlight>
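To try the nearest-centroid step outside the course project, here is a minimal, self-contained sketch. It rests on assumptions: the "Point2D" struct is a simplified stand-in, "getDistance" is assumed to return the Euclidean distance (its real implementation is not shown on this page), and "classifyNearest" is a hypothetical free-function version of "KMeans::classify" that takes the centroids as an argument.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical, simplified stand-in for the course's Point2D class (assumption).
struct Point2D {
    float x;
    float y;

    // Assumed behaviour of getDistance: Euclidean distance to another point.
    float getDistance(const Point2D& other) const
    {
        return std::hypot(x - other.x, y - other.y);
    }
};

// Returns the index of the centroid closest to newPoint,
// or -1 if the list of centroids is empty.
int classifyNearest(const std::vector<Point2D>& centroids, const Point2D& newPoint)
{
    float min_distance = std::numeric_limits<float>::max();
    int class_label = -1;
    for (std::size_t c = 0; c < centroids.size(); c++)
    {
        float distance = centroids[c].getDistance(newPoint);
        if (distance < min_distance) // remember the closest centroid so far
        {
            min_distance = distance;
            class_label = static_cast<int>(c);
        }
    }
    return class_label;
}
```

With centroids at (1,0) and (11,12), the point (2,1) is assigned to class 0 and the point (9,9) to class 1, because each is closer to that centroid.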