def calc_shannon(dataset):
    """Return the Shannon entropy of the class labels in *dataset*.

    Each record in *dataset* is a sequence whose last element is the
    class label; all other elements are ignored here.

    :param dataset: list of records, label in the last position
    :return: entropy in bits (0.0 for a pure or empty-label-free set)
    """
    from math import log  # local import keeps the function self-contained

    n = len(dataset)
    # Count occurrences of each label.
    cnt = {}
    for record in dataset:
        label = record[-1]
        cnt[label] = cnt.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the label distribution.
    shannon = 0.0
    for label in cnt:
        prob = float(cnt[label]) / n
        shannon -= prob * log(prob, 2)
    return shannon
划分数据集
把数据集 dataset 按照在 axis 分量的取值是否是 value 进行划分
为了避免污染数据集,因此新创建了局部变量
对于每一条记录,如果取值是 value 则在此处分开,并拼接这个属性之后的属性
追加到答案中,最终返回
def split_dataset(dataset, axis, value):
    """Return the records whose feature at *axis* equals *value*,
    with that feature column removed.

    The input *dataset* is not modified: each matching record is
    rebuilt from fresh list slices.

    :param dataset: list of records (feature values + trailing label)
    :param axis: index of the feature to split on
    :param value: feature value selecting which records to keep
    :return: new list of records, each one element shorter
    """
    sub_dataset = []
    for record in dataset:
        if record[axis] == value:
            # Splice out the axis column: everything before + after it.
            reduced = record[:axis]
            reduced.extend(record[axis + 1:])
            sub_dataset.append(reduced)
    return sub_dataset
选择最优划分
这一步中对每一个属性进行划分,选择其中可以让熵增最大的那一个属性的下标并返回
由于数据集的特点,因此除了最后一项是标签之外其他都是属性
如果熵增变大,则更新答案,最后返回结果
def get_best_feat(dataset):
    """Return the index of the feature whose split maximises
    information gain, or -1 when no split improves on the base entropy.

    The last element of every record is the class label; all preceding
    elements are candidate features.

    :param dataset: non-empty list of records (features + trailing label)
    :return: best feature index, or -1
    """
    num_feats = len(dataset[0]) - 1  # last column is the label, not a feature
    base_entropy = calc_shannon(dataset)
    best_gain, best_idx = 0.0, -1
    for i in range(num_feats):
        values = {record[i] for record in dataset}
        # Weighted entropy of the partition induced by feature i.
        split_entropy = 0.0
        for value in values:
            subset = split_dataset(dataset, i, value)
            prob = len(subset) / len(dataset)
            split_entropy += prob * calc_shannon(subset)
        gain = base_entropy - split_entropy
        if gain > best_gain:
            best_gain, best_idx = gain, i
    return best_idx
def majority_count(feat):
    """Return the most frequent value in *feat* (majority vote).

    Ties are broken by first appearance in counting order, matching a
    stable descending sort on counts.

    :param feat: non-empty iterable of hashable values (e.g. labels)
    :return: the value with the highest count
    """
    counter = {}
    for f in feat:
        counter[f] = counter.get(f, 0) + 1
    # sorted() is stable, so the first entry is the earliest-seen value
    # among those sharing the maximum count.
    sorted_counter = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    return sorted_counter[0][0]
def create_tree(dataset, labels):
    """Recursively build an ID3 decision tree.

    The tree is a nested dict: {feature_name: {feature_value: subtree}},
    where a leaf is a bare class label.

    :param dataset: list of records (feature values + trailing label)
    :param labels: feature names, parallel to the feature columns;
        the caller's list is NOT modified (a local copy is used)
    :return: nested-dict tree, or a single label for a leaf
    """
    class_list = [record[-1] for record in dataset]
    # Stop: every record carries the same label -> pure leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Stop: only the label column remains -> majority-vote leaf.
    if len(dataset[0]) == 1:
        return majority_count(class_list)
    # Fix: work on a copy so the caller's label list is not mutated
    # (the original del(labels[feat]) corrupted the caller's list).
    labels = list(labels)
    feat = get_best_feat(dataset)
    label = labels[feat]
    tree = {label: {}}
    del labels[feat]
    # Branch on every observed value of the chosen feature.
    for value in {record[feat] for record in dataset}:
        sub_labels = labels[:]
        tree[label][value] = create_tree(
            split_dataset(dataset, feat, value), sub_labels)
    return tree