from turtle import down
import sys
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import argparse
def parse_args() :
parser = argparse.ArgumentParser(description="MNIST") # parser를 만듬, description에서는 이 프로그램의 설명이 담겨 있음
parser.add_argument("--mode", dest = "mode", help="train / eval / test", default=None, type=str) # 받아들일 인수를 추가 함
parser.add_argument("--download", dest = "download", help="download MNIST dataset", default=False, type=bool)
parser.add_argument("--output_dir", dest = "output_dir", help="output directory", default="./output", type=str)
parser.add_argument("--checkpoint", dest = "checkpoint", help="checkpoint trained model", default="None", type=str)
if len(sys.argv) == 1: # main.py만 있다는 소리
parser.print_help()
sys.exit()
args = parser.parse_args()
return args
def get_data(download=None, root="./mnist_dataset"):
    """Build the MNIST train / eval / test datasets.

    Every image goes through Resize([32, 32]), ToTensor(), and
    Normalize(mean=0.5, std=1.0). The std of 1.0 means values are shifted by
    0.5 but not rescaled.

    Args:
        download: whether torchvision should fetch MNIST into *root* if it is
            missing. ``None`` (the default) falls back to ``args.download`` from
            the module-level ``args`` bound in the ``__main__`` block, so the
            original call ``get_data()`` behaves as before. If no such global
            exists (e.g. the module was imported), no download is attempted.
        root: directory that holds (or will receive) the MNIST files.

    Returns:
        (train_dataset, eval_dataset, test_dataset).
        NOTE(review): eval and test are both built with ``train=False``, so they
        contain the same images. Consider carving the eval set out of the
        training split instead.
    """
    if download is None:
        # Use the CLI options parsed in __main__ when available, instead of
        # raising NameError when this module is imported.
        cli_args = globals().get("args")
        download = getattr(cli_args, "download", False)

    transform = transforms.Compose([
        transforms.Resize([32, 32]),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (1.0,)),
    ])

    def _load(train):
        # One MNIST split sharing the transform, root and download policy.
        return MNIST(root=root, transform=transform, train=train, download=download)

    train_dataset = _load(True)
    eval_dataset = _load(False)
    test_dataset = _load(False)
    return train_dataset, eval_dataset, test_dataset
def main():
    """Run the script: report the torch version and compute device, then load MNIST.

    Prints ``torch.__version__``, then "GPU" or "CPU" depending on
    ``torch.cuda.is_available()``, selects the matching ``torch.device``, and
    builds the three datasets via ``get_data()``.
    """
    print(torch.__version__)
    has_cuda = torch.cuda.is_available()
    print("GPU" if has_cuda else "CPU")
    device = torch.device("cuda" if has_cuda else "cpu")
    # Get MNIST Dataset
    train_set, eval_set, test_set = get_data()
if __name__ == "__main__" :
    # Script entry point. parse_args() prints the help text and exits when no
    # CLI arguments are given.
    # `args` is bound at module scope on purpose: get_data() reads it as a
    # global, so this assignment must happen before main() is called.
    args = parse_args()
    main()