# Uses PyTorch

from turtle import down
import sys
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import argparse 

def _str_to_bool(value):
    """Convert a command-line token such as "true" or "False" into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Args:
        value: The raw command-line string (a bool is passed through as is).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as False, matching what bool("") used to give.
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the MNIST script.

    Options:
        --mode:        "train" / "eval" / "test" (default: None).
        --download:    whether to download the MNIST dataset. Accepts an
                       explicit value (``--download True`` / ``--download False``)
                       or the bare flag ``--download`` (meaning True).
                       Default: False.
        --output_dir:  output directory (default: "./output").
        --checkpoint:  checkpoint of a trained model. The default is the
                       *string* "None", kept unchanged for compatibility.

    When the script is started without any argument, the help text is printed
    and the process exits.

    Returns:
        argparse.Namespace with ``mode``, ``download``, ``output_dir`` and
        ``checkpoint`` attributes.
    """
    # Create the parser; ``description`` holds a short description of the program.
    parser = argparse.ArgumentParser(description="MNIST")
    # Register the arguments the program accepts.
    parser.add_argument("--mode", dest="mode", help="train / eval / test", default=None, type=str)
    # nargs="?" + const=True lets a bare ``--download`` mean True, while an
    # explicit value is converted by _str_to_bool ("False" really is False).
    parser.add_argument("--download", dest="download", help="download MNIST dataset",
                        default=False, type=_str_to_bool, nargs="?", const=True)
    parser.add_argument("--output_dir", dest="output_dir", help="output directory", default="./output", type=str)
    parser.add_argument("--checkpoint", dest="checkpoint", help="checkpoint trained model", default="None", type=str)

    # Only the script name is present, i.e. no arguments were given.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit()
    return parser.parse_args()


def get_data(download=None, download_root="./mnist_dataset"):
    """Build the MNIST train / eval / test datasets.

    Every image goes through the same transform pipeline: resize to 32x32,
    convert to a tensor, then ``Normalize`` with mean 0.5 and std 1.0.

    Args:
        download: Whether torchvision may download the dataset when it is
            missing. When left as ``None`` the value is read from the
            module-level ``args.download`` set by the script entry point
            (the original behaviour), so ``get_data()`` keeps working.
        download_root: Directory in which the dataset is stored.

    Returns:
        Tuple ``(train_dataset, eval_dataset, test_dataset)``. The eval and
        test datasets are both built from the ``train=False`` split.
    """
    if download is None:
        # Backward-compatible path: fall back to the parsed CLI namespace that
        # the ``__main__`` guard stores in the module-level ``args``.
        download = args.download

    my_transform = transforms.Compose([
        transforms.Resize([32, 32]),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (1.0,)),
    ])

    def build(train):
        # One place for the arguments shared by all three datasets.
        return MNIST(root=download_root,
                     transform=my_transform,
                     train=train,
                     download=download)

    train_dataset = build(True)
    eval_dataset = build(False)
    test_dataset = build(False)

    return train_dataset, eval_dataset, test_dataset

def main():
    """Print the torch version and the selected device, then load MNIST.

    Prints "GPU" and selects the CUDA device when CUDA is available;
    otherwise prints "CPU" and selects the CPU device. The datasets are then
    obtained from ``get_data()``.
    """
    print(torch.__version__)

    # Choose the compute device once, based on CUDA availability.
    cuda_ok = torch.cuda.is_available()
    print("GPU" if cuda_ok else "CPU")
    device = torch.device("cuda" if cuda_ok else "cpu")

    # Get MNIST Dataset
    datasets = get_data()
    train_dataset, eval_dataset, test_dataset = datasets

# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__" :
    # ``args`` is intentionally bound at module level: get_data() reads
    # ``args.download`` as a global instead of receiving it as an argument.
    args = parse_args()
    main()