Clean up code
NielsRogge committed Jul 12, 2024
1 parent 85d6df1 commit 476280f
Showing 1 changed file with 40 additions and 36 deletions.
packages/tasks/src/model-libraries-snippets.ts (76 changes: 40 additions & 36 deletions)
@@ -83,42 +83,46 @@ retriever = BM25HF.load_from_hub("${model.id}")`,
];

export const depth_anything_v2 = (model: ModelData): string[] => {
-	let encoder: string;
-	let modelConfig: { encoder: string; features: number; out_channels: number[] };
-
-	if (model.id === "depth-anything/Depth-Anything-V2-Small") {
-		encoder = "vits";
-		modelConfig = { encoder: "vits", features: 64, out_channels: [48, 96, 192, 384] };
-	} else if (model.id === "depth-anything/Depth-Anything-V2-Base") {
-		encoder = "vitb";
-		modelConfig = { encoder: "vitb", features: 128, out_channels: [96, 192, 384, 768] };
-	} else if (model.id === "depth-anything/Depth-Anything-V2-Large") {
-		encoder = "vitl";
-		modelConfig = { encoder: "vitl", features: 256, out_channels: [256, 512, 1024, 1024] };
-	} else {
-		throw new Error(`Unsupported model ID: ${model.id}`);
-	}
-
-	return [
-		`# Install from https://github.com/DepthAnything/Depth-Anything-V2
-# Load the model and infer depth from an image
-import cv2
-import torch
-from depth_anything_v2.dpt import DepthAnythingV2
-# instantiate the model
-encoder = "${encoder}";
-model = DepthAnythingV2(${JSON.stringify(modelConfig)});
-# load the weights
-filepath = hf_hub_download(repo_id="${model.id}", filename=f"depth_anything_v2_${encoder}.pth", repo_type="model");
-state_dict = torch.load(filepath, map_location="cpu");
-model.load_state_dict(state_dict).eval();
-raw_img = cv2.imread("your/image/path");
-depth = model.infer_image(raw_img); # HxW raw depth map in numpy`,
-	];
+	let encoder: string;
+	let features: number;
+	let out_channels: number[];
+
+	if (model.id === "depth-anything/Depth-Anything-V2-Small") {
+		encoder = "vits";
+		features = 64;
+		out_channels = [48, 96, 192, 384];
+	} else if (model.id === "depth-anything/Depth-Anything-V2-Base") {
+		encoder = "vitb";
+		features = 128;
+		out_channels = [96, 192, 384, 768];
+	} else if (model.id === "depth-anything/Depth-Anything-V2-Large") {
+		encoder = "vitl";
+		features = 256;
+		out_channels = [256, 512, 1024, 1024];
+	} else {
throw new Error("Unsupported model ID");
+	}
+
+	return [`
+# Install from https://github.com/DepthAnything/Depth-Anything-V2
+# Load the model and infer depth from an image
+import cv2
+import torch
+from depth_anything_v2.dpt import DepthAnythingV2
+from huggingface_hub import hf_hub_download
+# instantiate the model
+model = DepthAnythingV2(encoder="${encoder}", features=${features}, out_channels=${JSON.stringify(out_channels)})
+# load the weights
+filepath = hf_hub_download(repo_id="${model.id}", filename="depth_anything_v2_${encoder}.pth", repo_type="model")
+state_dict = torch.load(filepath, map_location="cpu")
+model.load_state_dict(state_dict)
+model.eval()
+raw_img = cv2.imread("your/image/path")
+depth = model.infer_image(raw_img) # HxW raw depth map in numpy
+`];
};

const diffusers_default = (model: ModelData) => [
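For context, a minimal TypeScript sketch of how the refactored snippet generator above might be called. This is not part of the commit; the ModelData import path is hypothetical, and only the id field is exercised here:

import type { ModelData } from "./model-data"; // hypothetical path; only `id` is used below

// One of the three checkpoints the switch above supports; any other ID throws.
const model = { id: "depth-anything/Depth-Anything-V2-Base" } as ModelData;

// depth_anything_v2 returns a one-element array holding the Python snippet,
// here rendered with encoder="vitb", features=128, out_channels=[96, 192, 384, 768].
const [pythonSnippet] = depth_anything_v2(model);
console.log(pythonSnippet);

Passing a repo ID other than the three Depth-Anything-V2 checkpoints falls through to the else branch and throws, so callers are expected to gate on the supported model IDs first.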
