
Segmentation annotation: the parameters of the Sam class's set_video_list method do not match the original SAM2 repository, and problems in forward_sam_multi_stage in sam_tools.py #2

@FelixxLuo

Description


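Hello, and thank you for open-sourcing this project. I ran into some errors while trying to use the segmentation annotation tool, and while tracking them down I found a few problems in the code that I would like to discuss here.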

1. Number of parameters of the Sam class's set_video_list method

Four arguments are passed here (https://github.com/InternRobotics/RoboInter/blob/main/RoboInterTools/tools/sam.py#L96), but the original SAM2 repository expects only three (https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_video_predictor.py#L42).
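A quick way to confirm this kind of arity mismatch without loading any model is to bind the arguments against the method's signature with `inspect.signature(...).bind(...)`, which raises `TypeError` if the call could not succeed. The sketch below uses a hypothetical `StubVideoPredictor` that only mirrors the three parameters described above; it is not the real SAM2 class, so check the signature of `init_state` in the SAM2 version you actually have installed.

```python
import inspect


class StubVideoPredictor:
    """Hypothetical stand-in for SAM2's video predictor (not the real class).

    init_state mirrors the three parameters described in this issue:
    the video path plus the two offload flags.
    """

    def init_state(self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False):
        return {"video_path": video_path}


def accepts_call(func, *args, **kwargs):
    """Return True if func could be called with these arguments (nothing is executed)."""
    try:
        inspect.signature(func).bind(*args, **kwargs)
        return True
    except TypeError:
        return False


predictor = StubVideoPredictor()
# Three arguments fit the signature.
print(accepts_call(predictor.init_state, "demo.mp4", True, True))  # True
# A fourth positional argument (e.g. an extra video_list) does not.
print(accepts_call(predictor.init_state, "demo.mp4", [], True, True))  # False
```

Because `bind` only checks the argument list, this can run against the real predictor object as well, without triggering any video loading.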

2. Problems in the forward_sam_multi_stage function in sam_tools.py

forward_sam_multi_stage in sam_tools.py calls the set_video_list method (https://github.com/InternRobotics/RoboInter/blob/main/RoboInterTools/tools/sam_tools.py#L97), so it cannot run successfully either.

3. Discussion and a possible solution (just my limited view, for reference only)

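forward_sam_multi_stage splits the video into segments depending on whether the direction is forward or backward, but the original SAM2 repository can only load a video by reading it from video_path (https://github.com/facebookresearch/sam2/blob/main/sam2/utils/misc.py#L172), so the segmented video_part cannot be passed to sam_model directly.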
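The forward/backward segmentation itself is plain array slicing. The sketch below isolates it as a hypothetical helper `select_frames` (not a function in RoboInter) that reproduces the slicing in forward_sam_multi_stage: in every case the prompted frame `select_frame` ends up at index 0 of the segment, which is presumably why the model is later prompted at frame 0. Unlike the original code, which leaves the video unchanged for an unknown direction, this sketch raises an error.

```python
import numpy as np


def select_frames(video, select_frame, direction, is_video=True):
    """Hypothetical helper mirroring the slicing in forward_sam_multi_stage.

    video: array of frames, indexed along axis 0 (e.g. shape (T, H, W, C)).
    Returns the frames to track, with the prompted frame at index 0.
    """
    if not is_video:
        # Single-image mode: keep only the prompted frame.
        return video[select_frame:select_frame + 1]
    if direction == "forward":
        # Prompted frame first, then every later frame.
        return video[select_frame:]
    if direction == "backward":
        # Prompted frame first, then every earlier frame in reverse order.
        return video[:select_frame + 1][::-1]
    # Design choice of this sketch: fail loudly instead of silently using the whole video.
    raise ValueError(f"unknown direction: {direction!r}")


frames = np.arange(6)  # stand-in for 6 frames
print(select_frames(frames, 2, "forward"))   # [2 3 4 5]
print(select_frames(frames, 2, "backward"))  # [2 1 0]
```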

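So my approach is to modify the Sam class's set_video_list method to drop the video_list parameter, and to modify forward_sam_multi_stage in sam_tools.py so that it saves video_part to a temporary .mp4 file and then reads that file back.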
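The "write each part to a temporary file, then load it from its path" idea can also be built on the standard library's `tempfile.TemporaryDirectory`, which deletes the directory exactly once when the `with` block exits, even if an exception is raised. This swaps in a different technique from the fixed `_cache` folder proposed here. In the sketch, `write_part` and `consume_path` are hypothetical callbacks: for video frames they would be a `cv2.VideoWriter` loop and `model_sam.set_video_list(path)` plus inference, respectively.

```python
import os
import tempfile


def run_parts_with_temp_files(parts, write_part, consume_path):
    """Write each part to its own temporary file, consume it, then clean up once.

    write_part(part, path): hypothetical callback that serializes one part.
    consume_path(path): hypothetical callback that loads the file and returns a result.
    """
    results = []
    # The directory and every file in it are removed when the block exits,
    # i.e. only after all parts have been processed.
    with tempfile.TemporaryDirectory(suffix="_cache") as temp_dir:
        for i, part in enumerate(parts):
            # Keep the .mp4 suffix so a path-based loader sees the expected extension.
            part_path = os.path.join(temp_dir, f"part_{i}.mp4")
            write_part(part, part_path)
            results.append(consume_path(part_path))
    return results
```

Using the system temporary directory also avoids writing cache folders next to the input videos, which may be read-only.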

Modify the Sam class's set_video_list method, removing the video_list parameter:

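def set_video_list(self, video_path):
    # Append the extension only if the path does not already end in ".mp4"
    if not video_path.endswith(".mp4"):
        video_path = video_path + ".mp4"
    # SAM2 loads the video itself from video_path
    self.inference_state = self.predictor.init_state(video_path, offload_video_to_cpu=True, offload_state_to_cpu=True)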
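To test the one-argument set_video_list without downloading SAM2 weights, the predictor can be injected and replaced by a fake. The sketch below uses a hypothetical `SamSketch` class, not the real RoboInter `Sam` class; it only assumes the predictor exposes `init_state(video_path, offload_video_to_cpu=..., offload_state_to_cpu=...)` as in the code above. The extension check here is also made case-insensitive, which is a choice of this sketch.

```python
import os


class SamSketch:
    """Hypothetical minimal stand-in for RoboInter's Sam class.

    `predictor` is any object with an init_state(video_path, offload_video_to_cpu=...,
    offload_state_to_cpu=...) method, e.g. a SAM2 video predictor or a test fake.
    """

    def __init__(self, predictor):
        self.predictor = predictor
        self.inference_state = None

    def set_video_list(self, video_path):
        # Append ".mp4" only when the path has a different (or no) extension.
        if os.path.splitext(video_path)[1].lower() != ".mp4":
            video_path = video_path + ".mp4"
        # The predictor loads the frames itself from the path.
        self.inference_state = self.predictor.init_state(
            video_path,
            offload_video_to_cpu=True,
            offload_state_to_cpu=True,
        )


class FakePredictor:
    """Test double that just records the arguments it was called with."""

    def init_state(self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False):
        return (video_path, offload_video_to_cpu, offload_state_to_cpu)


sam = SamSketch(FakePredictor())
sam.set_video_list("clip")
print(sam.inference_state)  # ('clip.mp4', True, True)
```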

Modify forward_sam_multi_stage in sam_tools.py to save video_part to a temporary .mp4 file and then read that file (os and shutil must be imported at the top of the file):

import os
import shutil

import cv2
import numpy as np

# extract_frames is the frame-loading helper already used by sam_tools.py in RoboInter


def forward_sam_multi_stage(model_config, model_sam):
    video_path = model_config["video_path"]
    is_video = model_config["is_video"]
    select_frame = model_config["select_frame"]
    direction = model_config["direction"]

    # ---------------------- modification start ----------------------
    # Cache directory for the temporary per-part .mp4 files (next to the input video)
    temp_save_dir = video_path.rsplit(".", 1)[0] + "_cache"
    os.makedirs(temp_save_dir, exist_ok=True)
    # ---------------------- modification end ------------------------

    if 'ann_human' in video_path:
        video_path = model_config['origin_video_path']
    video = extract_frames(video_path)

    # Keep only the frames to track; the prompted frame ends up at index 0
    if not is_video:
        video = video[select_frame:select_frame + 1]
    elif direction == "forward":
        video = video[select_frame:]
    elif direction == "backward":
        video = video[:select_frame + 1][::-1]

    positive_points_dict = model_config["positive_points"][select_frame]
    negative_points_dict = model_config["negative_points"][select_frame]
    labels_dict = model_config["labels"][select_frame]

    # Iterate over the keys of positive_points_dict so all three lists stay aligned per object
    positive_points = [np.array(positive_points_dict[obj_idx]) for obj_idx in positive_points_dict.keys()]
    negative_points = [np.array(negative_points_dict[obj_idx]) for obj_idx in positive_points_dict.keys()]
    labels = [labels_dict[obj_idx] for obj_idx in positive_points_dict.keys()]

    for i in range(len(positive_points)):
        if len(positive_points[i]) == 0:
            raise ValueError("No positive points in the frame")
        if len(negative_points[i]) != 0:
            positive_points[i] = np.concatenate([positive_points[i], negative_points[i]], axis=0)

    # If the video is longer than 800 frames, split it into len(video) // 800 + 1 strided parts
    num_parts = int(len(video) / 800) + 1
    masks_all = np.zeros((len(positive_points), len(video), 1, video[0].shape[0], video[0].shape[1]))
    ind_all = np.arange(len(video))

    try:
        for i in range(num_parts):
            video_part = video[i::num_parts]
            ind_part = ind_all[i::num_parts]

            # Every part must contain the first (prompted) frame
            video_part = np.concatenate([video[:1], video_part], axis=0)
            ind_part = np.concatenate([np.zeros_like(ind_part[:1]), ind_part], axis=0)

            # ---------------------- modification start ----------------------
            # Write this part to a temporary .mp4 so SAM2 can load it from a path
            video_part_path = os.path.join(temp_save_dir, f"part_{i}.mp4")
            height, width, _ = video_part[0].shape
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(video_part_path, fourcc, 20, (width, height))
            if not out.isOpened():
                raise RuntimeError(f"Could not open {video_part_path} for writing")

            # OpenCV expects BGR uint8 frames; if extract_frames returns RGB,
            # convert each frame first with cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            for frame in video_part:
                out.write(frame)
            out.release()

            model_sam.set_video_list(video_part_path)
            masks_all[:, ind_part] = model_sam(positive_points, labels, 0, list(positive_points_dict.keys()))
    finally:
        # Delete the cache only after all parts are processed: removing it inside the
        # loop leaves no directory for the next part's VideoWriter, which then fails to open
        shutil.rmtree(temp_save_dir, ignore_errors=True)
    # ---------------------- modification end ------------------------

    return masks_all
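The long-video handling above splits the frames with a stride rather than into contiguous chunks: part i takes every num_parts-th frame starting at i, and frame 0 (the prompted frame) is prepended to every part so each part can be prompted at its local index 0. Since num_parts = len(video) // 800 + 1, each part holds at most 800 strided frames plus the shared first frame. The hypothetical helpers `interleaved_parts` and `reassemble` below (not part of RoboInter) isolate the index bookkeeping, using one scalar per frame in place of the real per-object mask arrays.

```python
import numpy as np


def interleaved_parts(num_frames, max_part_len=800):
    """Return one index array per part, each starting with frame 0."""
    num_parts = num_frames // max_part_len + 1
    all_idx = np.arange(num_frames)
    parts = []
    for i in range(num_parts):
        strided = all_idx[i::num_parts]
        # Prepend the prompted frame; for part 0 this duplicates index 0, which is harmless.
        parts.append(np.concatenate([[0], strided]))
    return parts


def reassemble(parts, part_outputs, num_frames):
    """Scatter per-part outputs back to their original frame positions."""
    out = np.zeros(num_frames, dtype=np.asarray(part_outputs[0]).dtype)
    for idx, values in zip(parts, part_outputs):
        # Same pattern as masks_all[:, ind_part] = ... in the code above.
        out[idx] = values
    return out


parts = interleaved_parts(10, max_part_len=4)
print([p.tolist() for p in parts])  # [[0, 0, 3, 6, 9], [0, 1, 4, 7], [0, 2, 5, 8]]
```

Every frame appears in exactly one part (apart from the shared frame 0), so the scatter in `reassemble` fills the whole output.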
