From 6c0a6e13b6e90640d32d5a3821dc10503857056b Mon Sep 17 00:00:00 2001 From: Corneliu Cofaru Date: Mon, 17 Sep 2018 17:54:53 +0200 Subject: [PATCH] Another round of bugfixes --- src/conceptnet.jl | 32 +++++++++++++++++--------------- test/runtests.jl | 4 ++-- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/conceptnet.jl b/src/conceptnet.jl index 6a9c175..0ee084b 100644 --- a/src/conceptnet.jl +++ b/src/conceptnet.jl @@ -10,7 +10,7 @@ function init(::Type{ConceptNet}) DataDeps.unpack, :multi), ("English", - "https://conceptnet.s3.amazonaws.com/downloads/2017/numberbatch/numberbatch-17.06.txt.gz", + "https://conceptnet.s3.amazonaws.com/downloads/2017/numberbatch/numberbatch-en-17.06.txt.gz", "numberbatch-en-17.06.txt", "72faf0a487c61b9a6a8c9ff0a1440d2f4936bb19102bddf27a833c2567620f2d", DataDeps.unpack, @@ -47,15 +47,10 @@ end function _load_embeddings(::Type{<:ConceptNet}, embedding_file, max_vocab_size, keep_words) local LL, indexed_words, index - if any(endswith.(embedding_file, [".h5", ".hdf5"])) - LL, indexed_words = _load_hdf5_embeddings(embedding_file, - max_vocab_size=max_vocab_size, - keep_words=keep_words) - else - LL, indexed_words = _load_txt_embeddings(embedding_file, - max_vocab_size=max_vocab_size, - keep_words=keep_words) - end + _loader = ifelse(any(endswith.(embedding_file, [".h5", ".hdf5"])), + _load_hdf5_embeddings, + _load_txt_embeddings) + LL, indexed_words = _loader(embedding_file, max_vocab_size, keep_words) return LL, indexed_words end @@ -69,6 +64,7 @@ end function _load_txt_embeddings(file::AbstractString, max_vocab_size, keep_words) + local LL, indexed_words open(file, "r") do fid vocab_size, vector_size = map(x->parse(Int,x), split(readline(fid))) max_stored_vocab_size = _get_vocab_size(vocab_size, max_vocab_size) @@ -84,18 +80,23 @@ function _load_txt_embeddings(file::AbstractString, max_vocab_size, keep_words) return word, embedding end + # TODO Improve performance of this bit cnt = 0 + indices = Int[] for (index, row) in enumerate(data) word, embedding = _parseline(row) + LL[:, index] = embedding + indexed_words[index] = word if length(keep_words)==0 || word in keep_words - LL[:, index] = embedding - idexed_words[index] = word + push!(indices, index) cnt+=1 - if cnt > max_stored_vocab_size + if cnt > max_stored_vocab_size-1 break end end end + LL = LL[:, indices] + indexed_words = indexed_words[indices] end return LL, indexed_words end @@ -106,14 +107,15 @@ function _load_hdf5_embeddings(file::AbstractString, max_vocab_size, keep_words) payload = h5open(read, file)["mat"] words = payload["axis1"] vectors = payload["block0_values"] - max_vocab_size = _get_vocab_size(length(words), max_vocab_size) + max_stored_vocab_size = _get_vocab_size(length(words), max_vocab_size) indices = Int[] + cnt = 0 for (index, word) in enumerate(words) if length(keep_words)==0 || word in keep_words push!(indices, index) cnt+=1 - if cnt > max_stored_vocab_size + if cnt > max_stored_vocab_size-1 break end end diff --git a/test/runtests.jl b/test/runtests.jl index 917a8c1..b7fa198 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -191,9 +191,9 @@ end @test embs_mini.vocab == embs_full.vocab[1:100] embs_specific = load_embeddings(ConceptNet{:compressed}; - keep_words=Set(["red", "green", "blue"])) + keep_words=Set(["/c/en/red", "/c/en/green", "/c/en/blue"])) @test size(embs_specific.embeddings) == (300, 3) - @test Set(embs_specific.vocab) == Set(["red", "green", "blue"]) + @test Set(embs_specific.vocab) == Set(["/c/en/red", "/c/en/green", "/c/en/blue"]) end end