Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions cnpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,22 @@ cnpy::npz_t cnpy::npz_load(std::string fname) {
}

cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) {
FILE* fp = fopen(fname.c_str(),"rb");
struct AutoCloser
{
FILE * fp;
AutoCloser (FILE * ptr) : fp(ptr) {}
~AutoCloser (void)
{
if (fp != nullptr)
fclose(fp);
}
} closer(fopen(fname.c_str(), "rb"));

if(!fp) throw std::runtime_error("npz_load: Unable to open file "+fname);
if(!closer.fp) throw std::runtime_error("npz_load: Unable to open file "+fname);

while(1) {
std::vector<char> local_header(30);
size_t header_res = fread(&local_header[0],sizeof(char),30,fp);
size_t header_res = fread(&local_header[0],sizeof(char),30,closer.fp);
if(header_res != 30)
throw std::runtime_error("npz_load: failed fread");

Expand All @@ -293,46 +302,51 @@ cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) {
//read in the variable name
uint16_t name_len = *(uint16_t*) &local_header[26];
std::string vname(name_len,' ');
size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp);
size_t vname_res = fread(&vname[0],sizeof(char),name_len,closer.fp);
if(vname_res != name_len)
throw std::runtime_error("npz_load: failed fread");
vname.erase(vname.end()-4,vname.end()); //erase the lagging .npy

//read in the extra field
uint16_t extra_field_len = *(uint16_t*) &local_header[28];
fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field
fseek(closer.fp,extra_field_len,SEEK_CUR); //skip past the extra field

uint16_t compr_method = *reinterpret_cast<uint16_t*>(&local_header[0]+8);
uint32_t compr_bytes = *reinterpret_cast<uint32_t*>(&local_header[0]+18);
uint32_t uncompr_bytes = *reinterpret_cast<uint32_t*>(&local_header[0]+22);

if(vname == varname) {
NpyArray array = (compr_method == 0) ? load_the_npy_file(fp) : load_the_npz_array(fp,compr_bytes,uncompr_bytes);
fclose(fp);
NpyArray array = (compr_method == 0) ? load_the_npy_file(closer.fp) : load_the_npz_array(closer.fp,compr_bytes,uncompr_bytes);
return array;
}
else {
//skip past the data
uint32_t size = *(uint32_t*) &local_header[22];
fseek(fp,size,SEEK_CUR);
fseek(closer.fp,size,SEEK_CUR);
}
}

fclose(fp);

//if we get here, we haven't found the variable in the file
throw std::runtime_error("npz_load: Variable name "+varname+" not found in "+fname);
}

cnpy::NpyArray cnpy::npy_load(std::string fname) {

FILE* fp = fopen(fname.c_str(), "rb");
struct AutoCloser
{
FILE * fp;
AutoCloser (FILE * ptr) : fp(ptr) {}
~AutoCloser (void)
{
if (fp != nullptr)
fclose(fp);
}
} closer(fopen(fname.c_str(), "rb"));

if(!fp) throw std::runtime_error("npy_load: Unable to open file "+fname);
if(!closer.fp) throw std::runtime_error("npy_load: Unable to open file "+fname);

NpyArray arr = load_the_npy_file(fp);
NpyArray arr = load_the_npy_file(closer.fp);

fclose(fp);
return arr;
}

Expand Down