# To run the code
# export PYTHONPATH=INTERVL_CODE_PATH/internvlchat/:$PYTHONPATH
# python demo_internvl_grounding.py --auto --image_path IMAGE_PATH --exp EXPRESSION
import torch
from transformers import AutoModel, AutoTokenizer
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
import re
import numpy as np
import cv2
import matplotlib.pyplot as plt
import argparse

from internvl.model import load_model_and_tokenizer

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # tie-break: prefer the larger grid when the image is big enough to fill it
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, (orig_width, orig_height, target_width, target_height)
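# Illustrative example (not part of the original gist): a 1000x750 input has an
# aspect ratio of ~1.33; with image_size=448 and max_num=6 the closest allowed
# grid is (3, 2), so the image is resized to 1344x896 and split into six
# 448x448 tiles, plus one 448x448 thumbnail when use_thumbnail=True.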

def load_image(image_file, input_size=448, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images, meta_info = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values, meta_info

def rescale_boxes(boxes, meta_info):
    ow, oh, tw, th = meta_info
    boxes_list = []
    for box in boxes:
        box[0] *= float(ow) / tw
        box[2] *= float(ow) / tw
        box[1] *= float(oh) / th
        box[3] *= float(oh) / th
        box = [int(r) for r in box]
        boxes_list.append(box)
    return boxes_list
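# Note: rescale_boxes maps boxes given in the resized (target_width x target_height)
# space back to the original resolution. It is not used in the __main__ flow below,
# which parses boxes directly in original-image coordinates via meta_info[:2].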

def parse_response_internvl(response, image_size=(448, 448)):
    PATTERN = re.compile(r'\[*\[(.*?), (.*?), (.*?), (.*?)\]\]*')
    response1 = re.findall(PATTERN, response)
    response1 = [(float(response1[i][0]), float(response1[i][1]), float(response1[i][2]), float(response1[i][3])) for i in range(len(response1))]
    final_response = np.array(response1, dtype=np.float32) / 1000
    final_response[:, 0::2] *= image_size[0]
    final_response[:, 1::2] *= image_size[1]
    final_response = np.array(final_response, dtype=np.int32)
    return final_response
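# Assumed response format (exact wrapping can vary by InternVL version): grounding
# answers typically look like
#   <ref>moving elephants</ref><box>[[120, 265, 410, 980]]</box>
# with coordinates normalized to a 0-1000 range. The regex pulls out the four
# numbers; dividing by 1000 and scaling by image_size converts them to pixels.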

def vis_boxes(final_response, img):
    for response in final_response:
        # cv2 expects plain integer tuples for the corner points
        start = tuple(int(v) for v in response[:2])
        end = tuple(int(v) for v in response[2:])
        img = cv2.rectangle(img, start, end, (255, 0, 0), 2)
    plt.imshow(img)
    plt.savefig("internvl3.png")
    plt.show()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='OpenGVLab/InternVL3-8B')
    parser.add_argument('--auto', action='store_true')
    parser.add_argument('--load-in-8bit', action='store_true')
    parser.add_argument('--load-in-4bit', action='store_true')
    parser.add_argument('--image_path', type=str, default='example.png')
    parser.add_argument('--exp', type=str, default='moving elephants')
    args = parser.parse_args()

    model, tokenizer = load_model_and_tokenizer(args)
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    use_thumbnail = model.config.use_thumbnail

    pixel_values, meta_info = load_image(args.image_path)
    generation_config = dict(num_beams=1, max_new_tokens=100, min_new_tokens=1, do_sample=False, temperature=0.0)
    pixel_values = pixel_values.to(torch.bfloat16).cuda()

    prompt = f'Please provide the bounding box coordinate of the region this sentence describes: <ref>{args.exp}</ref>'
    response = model.chat(
        tokenizer=tokenizer,
        pixel_values=pixel_values,
        question=prompt,
        generation_config=generation_config,
        verbose=True
    )

    boxes = parse_response_internvl(response, meta_info[:2])
    img = np.array(Image.open(args.image_path))
    vis_boxes(boxes, img)