diff --git a/cnpy.cpp b/cnpy.cpp index 2d28578..9074e0e 100644 --- a/cnpy.cpp +++ b/cnpy.cpp @@ -277,13 +277,22 @@ cnpy::npz_t cnpy::npz_load(std::string fname) { } cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) { - FILE* fp = fopen(fname.c_str(),"rb"); + struct AutoCloser + { + FILE * fp; + AutoCloser (FILE * ptr) : fp(ptr) {} + ~AutoCloser (void) + { + if (fp != nullptr) + fclose(fp); + } + } closer(fopen(fname.c_str(), "rb")); - if(!fp) throw std::runtime_error("npz_load: Unable to open file "+fname); + if(!closer.fp) throw std::runtime_error("npz_load: Unable to open file "+fname); while(1) { std::vector local_header(30); - size_t header_res = fread(&local_header[0],sizeof(char),30,fp); + size_t header_res = fread(&local_header[0],sizeof(char),30,closer.fp); if(header_res != 30) throw std::runtime_error("npz_load: failed fread"); @@ -293,46 +302,51 @@ cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) { //read in the variable name uint16_t name_len = *(uint16_t*) &local_header[26]; std::string vname(name_len,' '); - size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp); + size_t vname_res = fread(&vname[0],sizeof(char),name_len,closer.fp); if(vname_res != name_len) throw std::runtime_error("npz_load: failed fread"); vname.erase(vname.end()-4,vname.end()); //erase the lagging .npy //read in the extra field uint16_t extra_field_len = *(uint16_t*) &local_header[28]; - fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field + fseek(closer.fp,extra_field_len,SEEK_CUR); //skip past the extra field uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); if(vname == varname) { - NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp,compr_bytes,uncompr_bytes); - fclose(fp); + NpyArray array = (compr_method == 0) ? load_the_npy_file(closer.fp) : load_the_npz_array(closer.fp,compr_bytes,uncompr_bytes); return array; } else { //skip past the data uint32_t size = *(uint32_t*) &local_header[22]; - fseek(fp,size,SEEK_CUR); + fseek(closer.fp,size,SEEK_CUR); } } - fclose(fp); - //if we get here, we haven't found the variable in the file throw std::runtime_error("npz_load: Variable name "+varname+" not found in "+fname); } cnpy::NpyArray cnpy::npy_load(std::string fname) { - FILE* fp = fopen(fname.c_str(), "rb"); + struct AutoCloser + { + FILE * fp; + AutoCloser (FILE * ptr) : fp(ptr) {} + ~AutoCloser (void) + { + if (fp != nullptr) + fclose(fp); + } + } closer(fopen(fname.c_str(), "rb")); - if(!fp) throw std::runtime_error("npy_load: Unable to open file "+fname); + if(!closer.fp) throw std::runtime_error("npy_load: Unable to open file "+fname); - NpyArray arr = load_the_npy_file(fp); + NpyArray arr = load_the_npy_file(closer.fp); - fclose(fp); return arr; }