以 SEGNET pytorch 為例子, 安裝pytorch with GPU CUDA support 可以給3080 TI 使用
(1) create environment
conda create --name pytorch2 python=3.8
(2) 進入pytorch ,可以使用選擇環境的方式, 做出 install 的command, copy and execute command line
自行配版本, 總是會發生不確定的問題!
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
https://pytorch.org/get-started/locally/
(3) install scimage
pip install scikit-image
(4) vinceecws/SegNet_PyTorch
source code download
https://github.com/vinceecws/SegNet_PyTorch
(4.1) download data set
官方原始的 data set 已無法使用
改到: github.com/lih/CamVid download 二個目錄, CamVid_RGB, CamVid_Label
(5) train.py
建立 Camvid_RGB, Camvid_label, res
將download 到的 training set, images 放到.上述目錄.
(5) 程式相闗小修改
raw_dir = os.path.join(os.getcwd(), 'CamVid_RGB')
lbl_dir = os.path.join(os.getcwd(), 'CamVid_Label')
SegNet.Train.Train(trainloader, os.path.abspath("checkpoint.pth.tar"))
#如果選YES, 要注意 weight 檔名.
(6) 改變存檔條件, segnet.py
原來要做的Epoch 太久,
if (i%5==4):
Train.save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, path)
順利的話, 會在source 的目錄, 產生 weight file.tar
(6) test_segnet.py
這項就不是小改可以run, 原始的file 有點bug.
將 main(args)中, 修改
classes_dir = "d:/Segnet"
classes = np.load('classes.npy')
camvid_raw_dir = "d:/Segnet/camvid_RGB"
camvid_labelled_dir = "d:/Segnet/camvid_label"
weight_fn = "checkpoint.pth.tar"
# segnet_weights.pth.tar
res_dir = "res"
就不用paser. 那段code.
(7) run test_segnet.py
>> python test_segnet.py
結果影像在reg 目錄.
留言列表