Muyun99's wiki Muyun99's wiki
首页
学术搬砖
学习笔记
生活杂谈
wiki搬运
资源收藏
关于
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

Muyun99

努力成为一个善良的人
首页
学术搬砖
学习笔记
生活杂谈
wiki搬运
资源收藏
关于
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • 常见 bug 修复

  • 环境配置

  • 常用库的常见用法

    • pandas 库常用用法
    • glob 库常用用法
    • sklearn 库常用用法
      • sklearn 库常用用法
    • matplotlib 库常用用法--散点图绘制
    • 图像处理库的常用用法
    • ubuntu 系统常用命令
    • numpy 库常用用法
    • matplotlib 库常用用法
    • tmux 常用用法
    • Docker 常用用法
    • zsh 相关用法
    • 常见的专有名词
    • sharelatex 部署
    • typecho 部署
    • cpp STL 常用用法
    • Tensorboard 常用用法
  • wiki搬运
  • 常用库的常见用法
Muyun99
2021-08-11

sklearn 库常用用法

# sklearn 库常用用法

# train_test_split 方法

使用场景示例:我想要对 trainval 文件分成 train 文件 和 valid 文件

import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold

raw_train_df = pd.read_csv(cfg.path_raw_train_csv)
raw_train_df.columns = ['filename', 'label']

train_data, valid_data = train_test_split(
            raw_train_df, shuffle=True, test_size=cfg.size_valid, random_state=cfg.seed_random)
        train_data.to_csv(os.path.join(cfg.path_save_trainval_csv, f'train.csv'), index=False)
        valid_data.to_csv(os.path.join(cfg.path_save_trainval_csv, f'valid.csv'), index=False)
        print(f'train:{train_data.shape[0]}, valid:{valid_data.shape[0]}')
1
2
3
4
5
6
7
8
9
10
11

# StratifiedKFold 用法

使用示例:我想要将一个 train 文件和 valid 文件分成多个 fold

import pandas as pd
from sklearn.model_selection import StratifiedKFold

raw_train_df = pd.read_csv(cfg.path_raw_train_csv)
raw_train_df.columns = ['filename', 'label']

skf = StratifiedKFold(
    n_splits=cfg.num_KFold,
    random_state=cfg.seed_random,
    shuffle=True
)
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(x, y)):
    fold_train = raw_train_df.iloc[train_idx]
    fold_valid = raw_train_df.iloc[val_idx]
    fold_train.to_csv(os.path.join(cfg.path_save_trainval_csv, f'train_fold{fold_idx}.csv'), index=False)
    fold_valid.to_csv(os.path.join(cfg.path_save_trainval_csv, f'valid_fold{fold_idx}.csv'), index=False)
    print(f'train_fold{fold_idx}: {fold_train.shape[0]}, valid_fold{fold_idx}: {fold_valid.shape[0]}')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
上次更新: 2021/08/17, 18:07:06
glob 库常用用法
matplotlib 库常用用法--散点图绘制

← glob 库常用用法 matplotlib 库常用用法--散点图绘制→

最近更新
01
Structured Knowledge Distillation for Semantic Segmentation
06-03
02
README 美化
05-20
03
常见 Tricks 代码片段
05-12
更多文章>
Theme by Vdoing | Copyright © 2021-2023 Muyun99 | MIT License
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式
×