diff --git a/.gitignore b/.gitignore index 871e4b8..f79af86 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ data/ logging_accuracy/ +__pycache__/ diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..74ca340 --- /dev/null +++ b/environment.yml @@ -0,0 +1,106 @@ +name: scratch +channels: + - pytorch + - anaconda + - conda-forge + - defaults +dependencies: + - bitarray #=2.3.6 + - blas + - brotlipy + - bzip2 + - ca-certificates + - certifi + - cffi + - charset-normalizer + - click + - colorama + - cryptography + - cvxopt + - dataclasses + - dsdp + - ffmpeg + - fftw + - filelock + - freetype + - giflib + - glpk + - gmp + - gnutls + - gsl + - huggingface_hub + - idna + - importlib-metadata + - importlib_metadata + - joblib + - jpeg + - lame + - lcms2 + - ld_impl_linux-64 + - libblas + - libcblas + - libffi + - libidn2 + - liblapack + - libopus + - libpng + - libtasn1 + - libtiff + - libunistring + - libuv + - libvpx # + - libwebp + - libwebp-base + - lz4-c + - metis + - mpfr + - ncurses + - nettle + - numpy + - numpy-base + - olefile + - openh264 + - openssl + - packaging + - pandas + - pillow + - pip + - pycparser + - pyopenssl + - pyparsing + - pysocks + - python + - python-dateutil + - python_abi + - pytorch + - pytorch-mutex + - pytz + - pyyaml + - readline + - regex + - requests + - sacremoses + - scikit-learn + - scipy + - setuptools + - six + - sqlite + - suitesparse + - tbb + - threadpoolctl + - tk + - tokenizers + - torchaudio # + - torchvision + - tqdm + - transformers + - typing-extensions + - typing_extensions + - urllib3 + - wheel + - x264 # + - xz + - yaml + - zipp + - zlib + - zstd diff --git a/helper.py b/helper.py index 6e669ca..44cc062 100644 --- a/helper.py +++ b/helper.py @@ -121,6 +121,7 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): net=None X=None Y=None + NUM_WORKERS=0 if data_type=='gaussian': ''' @@ -140,13 +141,13 @@ def get_dataset(data_dir, data_type,net_type, device, 
alpha, beta, batch_size): Y = u_traindata.data p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model @@ -169,13 +170,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 2) @@ -194,13 +195,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=pos_size, \ - shuffle=True, 
num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 2) @@ -231,13 +232,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data.reshape((u_traindata.data.shape[0], -1)) p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 784) @@ -268,13 +269,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data.reshape((u_traindata.data.shape[0], -1)) p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 784) @@ -304,13 +305,13 @@ def 
get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data.reshape((u_traindata.data.shape[0], -1)) p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 784) @@ -340,13 +341,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data.reshape((u_traindata.data.shape[0], -1)) p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=int(batch_size*(1-beta)/beta), \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=int(batch_size*(1-beta)/beta), \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 3072) @@ -376,13 +377,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data.reshape((u_traindata.data.shape[0], -1)) p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, 
num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=int(batch_size*(1-beta)/beta), \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=int(batch_size*(1-beta)/beta), \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 3072) @@ -411,13 +412,13 @@ def get_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_size): Y = u_traindata.data.reshape((u_traindata.data.shape[0], -1)) p_trainloader = torch.utils.data.DataLoader(p_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=int(batch_size*(1-beta)/beta), \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) p_validloader = torch.utils.data.DataLoader(p_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=int(batch_size*(1-beta)/beta), \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = p_data.shape[-1]) @@ -461,6 +462,7 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_trainloader=None u_validloader=None net=None + NUM_WORKERS=0 if data_type=='gaussian': ''' @@ -478,9 +480,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, 
batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model @@ -500,9 +502,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_validdata = get_PNDataSplits(toy_testdata, unlabeled_size=pos_size*2) u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=pos_size*2, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=pos_size*2, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 2) @@ -519,9 +521,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=pos_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 2) @@ -550,9 +552,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 784) @@ -581,9 +583,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, 
num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 784) @@ -610,9 +612,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_validdata = get_PNDataSplits(testdata, pos_size=int(500*alpha), neg_size=int(500*alpha), data_type='cifar') u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 3072) @@ -640,9 +642,9 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz u_trainloader = torch.utils.data.DataLoader(u_traindata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) u_validloader = torch.utils.data.DataLoader(u_validdata, batch_size=batch_size, \ - shuffle=True, num_workers=2) + shuffle=True, num_workers=NUM_WORKERS) ## Initialize model net = get_model(net_type, input_dim = 3072) @@ -673,4 +675,4 @@ def get_PN_dataset(data_dir, data_type,net_type, device, alpha, beta, batch_siz net = get_model(net_type) net = net.to(device) - return u_trainloader, u_validloader, net \ No newline at end of file + return u_trainloader, u_validloader, net diff --git a/train_PU.py b/train_PU.py index b8ae0f6..bd8e0fa 100644 --- a/train_PU.py +++ b/train_PU.py @@ -46,13 +46,14 @@ torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) +if torch.backends.mps.is_available(): torch.mps.manual_seed(args.seed) np.random.seed(args.seed) random.seed(args.seed) print(args) net_type = args.net_type -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu') train_method = args.train_method data_type = args.data_type ## Train set for positive and unlabeled @@ -104,6 +105,10 @@ net =
torch.nn.DataParallel(net) cudnn.benchmark = True +if device.startswith('mps'): + net = torch.nn.DataParallel(net) + # NOTE(review): torch.backends.mps has no 'benchmark' flag (that is a cudnn-only knob); setting one would be a silent no-op + criterion = nn.CrossEntropyLoss() if optimizer_str=="SGD": @@ -287,4 +292,4 @@ else: print(TiCE_estimate(X,Y,data_type)) -outfile.close() \ No newline at end of file +outfile.close()