Nauč se Python > Kurzy > Datový kurz PyLadies > Strojové učení - klasifikační úlohy > Úvod klasifikace

Klasifikace #

Zatím jsme se zabývali jen regresními úlohami. Učení s učitelem ale zahrnuje dvě hlavní skupiny úloh - regresní úlohy a klasifikační úlohy.

Zatímco u regresních úloh je výstupem modelu spojitá hodnota (float), v klasifikačních úlohách představuje výstup modelu indikátor třídy (label).

Držme se našeho rybího trhu a ukažme si to na příkladu. Úloha predikovat váhu ryby byla regresní úloha, predikovali jsme spojitou hodnotu. Pokud budeme chtít predikovat druh ryby (Perch - okoun, Roach - plotice, Pike - štika, ...), jedná se o predikci kategorické hodnoty, tedy o klasifikaci.

Klasifikační úlohy mají trochu jiné vlastnosti a logiku, než úlohy regresní, proto existují modely přímo určené na takové úlohy. Říká se jim klasifikátory.

Zkusíme se ale nejdřív podívat na úlohu klasifikace z pohledu, který už známe, tedy z pohledu krajiny.

data

In [1]:
# načeteme si data 
import pandas as pd 
import numpy as np 
np.random.seed(2020)  # nastavení náhodného klasifikátoru

data = pd.read_csv("static/fish_data.csv", index_col=0)
data
Out[1]:
Species Weight Length1 Length2 Length3 Height Width ID
0 Bream 242.0 23.2 25.4 30.0 11.5200 4.0200 0
1 Bream 290.0 24.0 26.3 31.2 12.4800 4.3056 1
2 Bream 340.0 23.9 26.5 31.1 12.3778 4.6961 2
3 Bream 363.0 26.3 29.0 33.5 12.7300 4.4555 3
4 Bream 430.0 26.5 29.0 34.0 12.4440 5.1340 4
... ... ... ... ... ... ... ... ...
153 Smelt 9.8 11.4 12.0 13.2 2.2044 1.1484 153
154 Smelt 12.2 11.5 12.2 13.4 2.0904 1.3936 154
155 Smelt 13.4 11.7 12.4 13.5 2.4300 1.2690 155
156 Smelt 12.2 12.1 13.0 13.8 2.2770 1.2558 156
158 Smelt 19.9 13.8 15.0 16.2 2.9322 1.8792 158

123 rows × 8 columns

Úkol 1: #

Nejčastějším druhem ryby je Perch (okoun). Naším cílem je vytvořit klasifikátor, který pro zadané míry (váha, různé délky a šířky) vrátí informaci, zda se jedná o okouna nebo jiný druh. (Máme tedy pro jednoduchost jen dvě třídy, Perch a ostatní.)

  • Uměla bys tuto úlohu napasovat na krajinu? Co by mohly být souřadnice a co nadmořská výška?

  • Pokud ses úspěšně poprala s předchozím dotazem, můžeš na klasifikaci použít některý z regresních modelů (ano, asi to nebude ideální, když jde o klasifikaci, ale zkusme nejdříve to, co již umíme). Co ale bude hodnota odezvy a jak ji budeme interpretovat?

Klasifikační modely #

Přinášíme opět nějakou základní nabídku klasifikačních modelů:

Úkol 2: #

Vyberete si jeden model a zkuste natrénovat na ryby.

Nejprve připravíme data obdobně jako v minulé hodině. Jako sloupeček odezvy použijeme True pro okouny a False pro ostatní ryby. Sloupeček Species pak už nebudeme potřebovat, stejně tak můžeme vypustit sloupeček ID.

In [2]:
# připravme data
y = data["Species"] == "Perch"
y = y.astype(int)
X = data.drop(columns=["ID", "Species"])

Dalším krokem je rozdělení na trénovací a validační data. Nezapomeňme na stratifikaci.

In [3]:
# rozdělme na trénovací a validační množinu
from sklearn.model_selection import train_test_split 
X_train_raw, X_test_raw, y_train, y_test =  train_test_split(X, y, stratify=y)

Data přeškálujeme.

In [4]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

Jako model zvolíme rozhodovací strom. Neboj se zkusit jiný klasifikátor dle své volby.

In [5]:
# vezměme klasifikátor 
# můžeš změnit 
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
In [6]:
# natrénujte
model.fit(X_train, y_train);

Máme natrénovaný model, jdeme se podívat, jak funguje na validačních datech.

In [7]:
# ohodnoťme validační množinu 
pred = model.predict(X_test) 
In [8]:
print("Skutečná třída:  Predikce:")
for true, predicted in zip(y_test, pred):
    print(f"{true:<15}  {predicted:<10} {'OK' if true == predicted else 'X'}")

print(f"Počet chyb: {sum(y_test != pred)}")
Skutečná třída:  Predikce:
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
1                1          OK
0                0          OK
1                1          OK
0                0          OK
0                0          OK
1                0          X
0                0          OK
1                0          X
0                0          OK
1                0          X
0                0          OK
0                0          OK
0                0          OK
1                1          OK
0                0          OK
0                1          X
0                1          X
0                0          OK
0                0          OK
1                1          OK
1                0          X
1                1          OK
1                1          OK
1                1          OK
0                0          OK
Počet chyb: 6

Úkol 3: #

  • Asi je jasné, že regresní metriky se nám na klasifikační úlohy moc nehodí. Co bys použila jako metriku pro klasifikační úlohu?

Úkol 4: #

  • Jedna z možností je porovnávat procento úspěšně klasifikovaných vzorů. V našem případě, to bude:
In [9]:
print(f"Úspěšnost: {100*sum(y_test == pred)/len(y_test):.2f} %")
Úspěšnost: 80.65 %

Úspěšnost není úplně špatná, poznat druh ryby podle rozměrů není jendoduchá úloha.

Představ si ale, že budeme mít v datovou množinu se 100 rybami, 95 z nich bude okounů (typu Perch). Bude ti klasifikátor, který bude mít toto procento úspěšnosti (stejné jako vyšlo nám), připadat dobrý nebo ne? Proč?

Úkol 5: #

Nejprve projdeme klasifikační metriky. Pokud studuješ sama, nastuduj si kapitolu o klasifikačních metrikách a pak se vrať k tomuto cvičení.

Vyber si metriku pro naši úlohu a zkus najít, co nejlepší klasifikátor. Pak si načti testovací množinu a podívej se, jaké tvůj klasifikátor dává výsledky.

In [10]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier 
from sklearn.svm import SVC

# zkus naučit různé modely a vyber nejlepší
models = {}

# KNeigbors 
for N in 1, 3, 5, 7:
    models[("nearest neighbors", N)] = KNeighborsClassifier(n_neighbors=N, weights="distance")
    
# tree
for d in range(3, 20):
    models[("tree", d)] = DecisionTreeClassifier(max_depth=d, class_weight='balanced')
    
# random forest
for N in range(1, 100):
    models[("random forest", N)] = RandomForestClassifier(n_estimators=N, class_weight='balanced')
    
# SVC
for C in range(-2, 10):
    models[("SVC", 10**C)] = SVC(C=10**C, class_weight='balanced')

Vytvořili jsme si slušnou zásobu modelů, uložili jsme je do slovníku. Každý model máme pro různé hodnoty příslušného hyper-parametru.

In [11]:
models
Out[11]:
{('nearest neighbors',
  1): KNeighborsClassifier(n_neighbors=1, weights='distance'),
 ('nearest neighbors',
  3): KNeighborsClassifier(n_neighbors=3, weights='distance'),
 ('nearest neighbors', 5): KNeighborsClassifier(weights='distance'),
 ('nearest neighbors',
  7): KNeighborsClassifier(n_neighbors=7, weights='distance'),
 ('tree', 3): DecisionTreeClassifier(class_weight='balanced', max_depth=3),
 ('tree', 4): DecisionTreeClassifier(class_weight='balanced', max_depth=4),
 ('tree', 5): DecisionTreeClassifier(class_weight='balanced', max_depth=5),
 ('tree', 6): DecisionTreeClassifier(class_weight='balanced', max_depth=6),
 ('tree', 7): DecisionTreeClassifier(class_weight='balanced', max_depth=7),
 ('tree', 8): DecisionTreeClassifier(class_weight='balanced', max_depth=8),
 ('tree', 9): DecisionTreeClassifier(class_weight='balanced', max_depth=9),
 ('tree', 10): DecisionTreeClassifier(class_weight='balanced', max_depth=10),
 ('tree', 11): DecisionTreeClassifier(class_weight='balanced', max_depth=11),
 ('tree', 12): DecisionTreeClassifier(class_weight='balanced', max_depth=12),
 ('tree', 13): DecisionTreeClassifier(class_weight='balanced', max_depth=13),
 ('tree', 14): DecisionTreeClassifier(class_weight='balanced', max_depth=14),
 ('tree', 15): DecisionTreeClassifier(class_weight='balanced', max_depth=15),
 ('tree', 16): DecisionTreeClassifier(class_weight='balanced', max_depth=16),
 ('tree', 17): DecisionTreeClassifier(class_weight='balanced', max_depth=17),
 ('tree', 18): DecisionTreeClassifier(class_weight='balanced', max_depth=18),
 ('tree', 19): DecisionTreeClassifier(class_weight='balanced', max_depth=19),
 ('random forest',
  1): RandomForestClassifier(class_weight='balanced', n_estimators=1),
 ('random forest',
  2): RandomForestClassifier(class_weight='balanced', n_estimators=2),
 ('random forest',
  3): RandomForestClassifier(class_weight='balanced', n_estimators=3),
 ('random forest',
  4): RandomForestClassifier(class_weight='balanced', n_estimators=4),
 ('random forest',
  5): RandomForestClassifier(class_weight='balanced', n_estimators=5),
 ('random forest',
  6): RandomForestClassifier(class_weight='balanced', n_estimators=6),
 ('random forest',
  7): RandomForestClassifier(class_weight='balanced', n_estimators=7),
 ('random forest',
  8): RandomForestClassifier(class_weight='balanced', n_estimators=8),
 ('random forest',
  9): RandomForestClassifier(class_weight='balanced', n_estimators=9),
 ('random forest',
  10): RandomForestClassifier(class_weight='balanced', n_estimators=10),
 ('random forest',
  11): RandomForestClassifier(class_weight='balanced', n_estimators=11),
 ('random forest',
  12): RandomForestClassifier(class_weight='balanced', n_estimators=12),
 ('random forest',
  13): RandomForestClassifier(class_weight='balanced', n_estimators=13),
 ('random forest',
  14): RandomForestClassifier(class_weight='balanced', n_estimators=14),
 ('random forest',
  15): RandomForestClassifier(class_weight='balanced', n_estimators=15),
 ('random forest',
  16): RandomForestClassifier(class_weight='balanced', n_estimators=16),
 ('random forest',
  17): RandomForestClassifier(class_weight='balanced', n_estimators=17),
 ('random forest',
  18): RandomForestClassifier(class_weight='balanced', n_estimators=18),
 ('random forest',
  19): RandomForestClassifier(class_weight='balanced', n_estimators=19),
 ('random forest',
  20): RandomForestClassifier(class_weight='balanced', n_estimators=20),
 ('random forest',
  21): RandomForestClassifier(class_weight='balanced', n_estimators=21),
 ('random forest',
  22): RandomForestClassifier(class_weight='balanced', n_estimators=22),
 ('random forest',
  23): RandomForestClassifier(class_weight='balanced', n_estimators=23),
 ('random forest',
  24): RandomForestClassifier(class_weight='balanced', n_estimators=24),
 ('random forest',
  25): RandomForestClassifier(class_weight='balanced', n_estimators=25),
 ('random forest',
  26): RandomForestClassifier(class_weight='balanced', n_estimators=26),
 ('random forest',
  27): RandomForestClassifier(class_weight='balanced', n_estimators=27),
 ('random forest',
  28): RandomForestClassifier(class_weight='balanced', n_estimators=28),
 ('random forest',
  29): RandomForestClassifier(class_weight='balanced', n_estimators=29),
 ('random forest',
  30): RandomForestClassifier(class_weight='balanced', n_estimators=30),
 ('random forest',
  31): RandomForestClassifier(class_weight='balanced', n_estimators=31),
 ('random forest',
  32): RandomForestClassifier(class_weight='balanced', n_estimators=32),
 ('random forest',
  33): RandomForestClassifier(class_weight='balanced', n_estimators=33),
 ('random forest',
  34): RandomForestClassifier(class_weight='balanced', n_estimators=34),
 ('random forest',
  35): RandomForestClassifier(class_weight='balanced', n_estimators=35),
 ('random forest',
  36): RandomForestClassifier(class_weight='balanced', n_estimators=36),
 ('random forest',
  37): RandomForestClassifier(class_weight='balanced', n_estimators=37),
 ('random forest',
  38): RandomForestClassifier(class_weight='balanced', n_estimators=38),
 ('random forest',
  39): RandomForestClassifier(class_weight='balanced', n_estimators=39),
 ('random forest',
  40): RandomForestClassifier(class_weight='balanced', n_estimators=40),
 ('random forest',
  41): RandomForestClassifier(class_weight='balanced', n_estimators=41),
 ('random forest',
  42): RandomForestClassifier(class_weight='balanced', n_estimators=42),
 ('random forest',
  43): RandomForestClassifier(class_weight='balanced', n_estimators=43),
 ('random forest',
  44): RandomForestClassifier(class_weight='balanced', n_estimators=44),
 ('random forest',
  45): RandomForestClassifier(class_weight='balanced', n_estimators=45),
 ('random forest',
  46): RandomForestClassifier(class_weight='balanced', n_estimators=46),
 ('random forest',
  47): RandomForestClassifier(class_weight='balanced', n_estimators=47),
 ('random forest',
  48): RandomForestClassifier(class_weight='balanced', n_estimators=48),
 ('random forest',
  49): RandomForestClassifier(class_weight='balanced', n_estimators=49),
 ('random forest',
  50): RandomForestClassifier(class_weight='balanced', n_estimators=50),
 ('random forest',
  51): RandomForestClassifier(class_weight='balanced', n_estimators=51),
 ('random forest',
  52): RandomForestClassifier(class_weight='balanced', n_estimators=52),
 ('random forest',
  53): RandomForestClassifier(class_weight='balanced', n_estimators=53),
 ('random forest',
  54): RandomForestClassifier(class_weight='balanced', n_estimators=54),
 ('random forest',
  55): RandomForestClassifier(class_weight='balanced', n_estimators=55),
 ('random forest',
  56): RandomForestClassifier(class_weight='balanced', n_estimators=56),
 ('random forest',
  57): RandomForestClassifier(class_weight='balanced', n_estimators=57),
 ('random forest',
  58): RandomForestClassifier(class_weight='balanced', n_estimators=58),
 ('random forest',
  59): RandomForestClassifier(class_weight='balanced', n_estimators=59),
 ('random forest',
  60): RandomForestClassifier(class_weight='balanced', n_estimators=60),
 ('random forest',
  61): RandomForestClassifier(class_weight='balanced', n_estimators=61),
 ('random forest',
  62): RandomForestClassifier(class_weight='balanced', n_estimators=62),
 ('random forest',
  63): RandomForestClassifier(class_weight='balanced', n_estimators=63),
 ('random forest',
  64): RandomForestClassifier(class_weight='balanced', n_estimators=64),
 ('random forest',
  65): RandomForestClassifier(class_weight='balanced', n_estimators=65),
 ('random forest',
  66): RandomForestClassifier(class_weight='balanced', n_estimators=66),
 ('random forest',
  67): RandomForestClassifier(class_weight='balanced', n_estimators=67),
 ('random forest',
  68): RandomForestClassifier(class_weight='balanced', n_estimators=68),
 ('random forest',
  69): RandomForestClassifier(class_weight='balanced', n_estimators=69),
 ('random forest',
  70): RandomForestClassifier(class_weight='balanced', n_estimators=70),
 ('random forest',
  71): RandomForestClassifier(class_weight='balanced', n_estimators=71),
 ('random forest',
  72): RandomForestClassifier(class_weight='balanced', n_estimators=72),
 ('random forest',
  73): RandomForestClassifier(class_weight='balanced', n_estimators=73),
 ('random forest',
  74): RandomForestClassifier(class_weight='balanced', n_estimators=74),
 ('random forest',
  75): RandomForestClassifier(class_weight='balanced', n_estimators=75),
 ('random forest',
  76): RandomForestClassifier(class_weight='balanced', n_estimators=76),
 ('random forest',
  77): RandomForestClassifier(class_weight='balanced', n_estimators=77),
 ('random forest',
  78): RandomForestClassifier(class_weight='balanced', n_estimators=78),
 ('random forest',
  79): RandomForestClassifier(class_weight='balanced', n_estimators=79),
 ('random forest',
  80): RandomForestClassifier(class_weight='balanced', n_estimators=80),
 ('random forest',
  81): RandomForestClassifier(class_weight='balanced', n_estimators=81),
 ('random forest',
  82): RandomForestClassifier(class_weight='balanced', n_estimators=82),
 ('random forest',
  83): RandomForestClassifier(class_weight='balanced', n_estimators=83),
 ('random forest',
  84): RandomForestClassifier(class_weight='balanced', n_estimators=84),
 ('random forest',
  85): RandomForestClassifier(class_weight='balanced', n_estimators=85),
 ('random forest',
  86): RandomForestClassifier(class_weight='balanced', n_estimators=86),
 ('random forest',
  87): RandomForestClassifier(class_weight='balanced', n_estimators=87),
 ('random forest',
  88): RandomForestClassifier(class_weight='balanced', n_estimators=88),
 ('random forest',
  89): RandomForestClassifier(class_weight='balanced', n_estimators=89),
 ('random forest',
  90): RandomForestClassifier(class_weight='balanced', n_estimators=90),
 ('random forest',
  91): RandomForestClassifier(class_weight='balanced', n_estimators=91),
 ('random forest',
  92): RandomForestClassifier(class_weight='balanced', n_estimators=92),
 ('random forest',
  93): RandomForestClassifier(class_weight='balanced', n_estimators=93),
 ('random forest',
  94): RandomForestClassifier(class_weight='balanced', n_estimators=94),
 ('random forest',
  95): RandomForestClassifier(class_weight='balanced', n_estimators=95),
 ('random forest',
  96): RandomForestClassifier(class_weight='balanced', n_estimators=96),
 ('random forest',
  97): RandomForestClassifier(class_weight='balanced', n_estimators=97),
 ('random forest',
  98): RandomForestClassifier(class_weight='balanced', n_estimators=98),
 ('random forest',
  99): RandomForestClassifier(class_weight='balanced', n_estimators=99),
 ('SVC', 0.01): SVC(C=0.01, class_weight='balanced'),
 ('SVC', 0.1): SVC(C=0.1, class_weight='balanced'),
 ('SVC', 1): SVC(C=1, class_weight='balanced'),
 ('SVC', 10): SVC(C=10, class_weight='balanced'),
 ('SVC', 100): SVC(C=100, class_weight='balanced'),
 ('SVC', 1000): SVC(C=1000, class_weight='balanced'),
 ('SVC', 10000): SVC(C=10000, class_weight='balanced'),
 ('SVC', 100000): SVC(C=100000, class_weight='balanced'),
 ('SVC', 1000000): SVC(C=1000000, class_weight='balanced'),
 ('SVC', 10000000): SVC(C=10000000, class_weight='balanced'),
 ('SVC', 100000000): SVC(C=100000000, class_weight='balanced'),
 ('SVC', 1000000000): SVC(C=1000000000, class_weight='balanced')}

Obdobně jako v minulé hodině vytvoříme funci, která ohodnotí model a vrátí hodnoty vybrané metriky na trénovací a validační množině. Hodnoty vrací ve slovníku (což nám pak umožní snadnější vytvoření dataframu s výsledky).

In [12]:
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

def train_and_eval(X_train, X_test, y_train, y_test, model):
    model.fit(X_train, y_train)
    y_pred_test = model.predict(X_test)
    y_pred_train = model.predict(X_train)
    return {"train": f1_score(y_train, y_pred_train), # metriku můžeš vyměnit za nějakou svojí
            "test": f1_score(y_test, y_pred_test)
           }
In [13]:
results = []
for name, model in models.items():
    res = train_and_eval(X_train, X_test, y_train, y_test, model)
    res["model"] = name[0]
    res["param"] = name[1]
    results.append(res)
    
df_results = pd.DataFrame(results)
df_results
Out[13]:
train test model param
0 1.000000 0.818182 nearest neighbors 1.000000e+00
1 1.000000 0.818182 nearest neighbors 3.000000e+00
2 1.000000 0.869565 nearest neighbors 5.000000e+00
3 1.000000 0.818182 nearest neighbors 7.000000e+00
4 0.704545 0.666667 tree 3.000000e+00
... ... ... ... ...
127 1.000000 1.000000 SVC 1.000000e+05
128 1.000000 1.000000 SVC 1.000000e+06
129 1.000000 1.000000 SVC 1.000000e+07
130 1.000000 1.000000 SVC 1.000000e+08
131 1.000000 1.000000 SVC 1.000000e+09

132 rows × 4 columns

Závislost úspěsnosti modelu (dle zvolené metriky) na hodnotě příslušného hyperparametru si zobrazíme v grafu.

In [14]:
import seaborn as sns
import matplotlib.pyplot as plt

def zobraz_model(model_name, ax, logx=False):
    sns.lineplot(x="param", y="train", data=df_results[df_results["model"]==model_name], label="train", ax=ax)
    sns.lineplot(x="param", y="test", data=df_results[df_results["model"]==model_name], label="test", ax=ax)
    ax.set_title(model_name.capitalize())
    if logx:
        ax.set(xscale="log")
    
fig, axs = plt.subplots(ncols=4, figsize=(16,4))
zobraz_model("nearest neighbors", axs[0])
zobraz_model("tree", axs[1])
zobraz_model("random forest", axs[2])
zobraz_model("SVC", axs[3], logx=True)
No description has been provided for this image

Úkol 6: #

Vyber si model, který se na validační množině jeví jako nejlepší. Vyzkoušej jej na testovací data.

In [15]:
# načtení data
test_data = pd.read_csv("static/fish_data_test.csv", index_col=0)
y_real_test = test_data["Species"] == "Perch"
y_real_test = y_real_test.astype(int)
X_real_test = test_data.drop(columns=["ID", "Species"])
X_real_test = scaler.transform(X_real_test)
In [16]:
# predikce
model = models[("SVC", 10**4)]
test_pred = model.predict(X_real_test)
In [17]:
# zkus přidat zvolenou metriku
print(f"Skutečná třída:  Predikce:")
for true, predicted in zip(y_real_test, test_pred):
    print(f"{true:<15}  {predicted:<10} {'OK' if true == predicted else 'X'}")

print(f"Počet chyb: {sum(y_real_test != test_pred)}")
print(f"Úspěšnost: {100*sum(y_real_test == test_pred)/len(y_real_test):.2f} %")
Skutečná třída:  Predikce:
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                1          X
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
1                0          X
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
0                0          OK
0                0          OK
0                1          X
Počet chyb: 3
Úspěšnost: 91.67 %

Toto je stránka lekce z kurzu, který probíhá nebo proběhl naživo s instruktorem.