feat(ml): composable ml (#9973)

* modularize model classes

* various fixes

* expose port

* change response

* round coordinates

* simplify preload

* update server

* simplify interface

simplify

* update tests

* composable endpoint

* cleanup

fixes

remove unnecessary interface

support text input, cleanup

* ew camelcase

* update server

server fixes

fix typing

* ml fixes

update locustfile

fixes

* cleaner response

* better repo response

* update tests

formatting and typing

rename

* undo compose change

* linting

fix type

actually fix typing

* stricter typing

fix detection-only response

no need for defaultdict

* update spec file

update api

linting

* update e2e

* unnecessary dimension

* remove commented code

* remove duplicate code

* remove unused imports

* add batch dim
Mert, 2024-06-06 23:09:47 -04:00, committed by GitHub
parent 7a46f80ddc
commit 2b1b43a7e4
GPG key ID: B5690EEEBB952194
39 changed files with 982 additions and 999 deletions

View file

@@ -1,8 +1,7 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
import { IsEnum, IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { CLIPMode, ModelType } from 'src/interfaces/machine-learning.interface';
import { Optional, ValidateBoolean } from 'src/validation';
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { ValidateBoolean } from 'src/validation';
export class TaskConfig {
@ValidateBoolean()
@@ -13,19 +12,9 @@ export class ModelConfig extends TaskConfig {
@IsString()
@IsNotEmpty()
modelName!: string;
@IsEnum(ModelType)
@Optional()
@ApiProperty({ enumName: 'ModelType', enum: ModelType })
modelType?: ModelType;
}
export class CLIPConfig extends ModelConfig {
@IsEnum(CLIPMode)
@Optional()
@ApiProperty({ enumName: 'CLIPMode', enum: CLIPMode })
mode?: CLIPMode;
}
export class CLIPConfig extends ModelConfig {}
export class DuplicateDetectionConfig extends TaskConfig {
@IsNumber()
@@ -36,7 +25,7 @@ export class DuplicateDetectionConfig extends TaskConfig {
maxDistance!: number;
}
export class RecognitionConfig extends ModelConfig {
export class FacialRecognitionConfig extends ModelConfig {
@IsNumber()
@Min(0)
@Max(1)

View file

@@ -30,7 +30,7 @@ import {
TranscodePolicy,
VideoCodec,
} from 'src/config';
import { CLIPConfig, DuplicateDetectionConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
import { CLIPConfig, DuplicateDetectionConfig, FacialRecognitionConfig } from 'src/dtos/model-config.dto';
import { ConcurrentQueueName, QueueName } from 'src/interfaces/job.interface';
import { ValidateBoolean, validateCronExpression } from 'src/validation';
@@ -270,10 +270,10 @@ class SystemConfigMachineLearningDto {
@IsObject()
duplicateDetection!: DuplicateDetectionConfig;
@Type(() => RecognitionConfig)
@Type(() => FacialRecognitionConfig)
@ValidateNested()
@IsObject()
facialRecognition!: RecognitionConfig;
facialRecognition!: FacialRecognitionConfig;
}
enum MapTheme {

View file

@@ -1,15 +1,5 @@
import { CLIPConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
export const IMachineLearningRepository = 'IMachineLearningRepository';
export interface VisionModelInput {
imagePath: string;
}
export interface TextModelInput {
text: string;
}
export interface BoundingBox {
x1: number;
y1: number;
@@ -17,26 +7,51 @@ export interface BoundingBox {
y2: number;
}
export interface DetectFaceResult {
imageWidth: number;
imageHeight: number;
boundingBox: BoundingBox;
score: number;
embedding: number[];
export enum ModelTask {
FACIAL_RECOGNITION = 'facial-recognition',
SEARCH = 'clip',
}
export enum ModelType {
FACIAL_RECOGNITION = 'facial-recognition',
CLIP = 'clip',
DETECTION = 'detection',
PIPELINE = 'pipeline',
RECOGNITION = 'recognition',
TEXTUAL = 'textual',
VISUAL = 'visual',
}
export enum CLIPMode {
VISION = 'vision',
TEXT = 'text',
export type ModelPayload = { imagePath: string } | { text: string };
type ModelOptions = { modelName: string };
export type FaceDetectionOptions = ModelOptions & { minScore: number };
type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
export type FacialRecognitionRequest = {
[ModelTask.FACIAL_RECOGNITION]: {
[ModelType.DETECTION]: FaceDetectionOptions;
[ModelType.RECOGNITION]: ModelOptions;
};
};
export interface Face {
boundingBox: BoundingBox;
embedding: number[];
score: number;
}
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
export interface IMachineLearningRepository {
encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
}
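
For illustration (not part of the diff), a minimal sketch of how the request and response unions above compose. The model names are the defaults that appear in the service specs further down, and the face values mirror the detectFaceMock in person.service.spec; everything else follows directly from the types defined in this file.

// Sketch only: a request nests per-model-type options under a task key,
// and the matching response is keyed by the same task.
import {
  ClipTextualRequest,
  FacialRecognitionRequest,
  FacialRecognitionResponse,
  ModelTask,
  ModelType,
} from 'src/interfaces/machine-learning.interface';

const textSearch: ClipTextualRequest = {
  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai' } },
};

const facialRecognition: FacialRecognitionRequest = {
  [ModelTask.FACIAL_RECOGNITION]: {
    [ModelType.DETECTION]: { modelName: 'buffalo_l', minScore: 0.7 },
    [ModelType.RECOGNITION]: { modelName: 'buffalo_l' },
  },
};

// The facial recognition response carries the detected faces under the same
// task key, plus the image dimensions contributed by VisualResponse.
const exampleResponse: FacialRecognitionResponse = {
  [ModelTask.FACIAL_RECOGNITION]: [
    { boundingBox: { x1: 100, y1: 100, x2: 200, y2: 200 }, embedding: [1, 2, 3, 4], score: 0.2 },
  ],
  imageHeight: 500,
  imageWidth: 400,
};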

View file

@@ -37,8 +37,6 @@ export interface SearchExploreItem<T> {
items: SearchExploreItemSet<T>;
}
export type Embedding = number[];
export interface SearchAssetIDOptions {
checksum?: Buffer;
deviceAssetId?: string;
@@ -106,7 +104,7 @@ export interface SearchExifOptions {
}
export interface SearchEmbeddingOptions {
embedding: Embedding;
embedding: number[];
userIds: string[];
}
@@ -154,7 +152,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {
export interface AssetDuplicateSearch {
assetId: string;
embedding: Embedding;
embedding: number[];
maxDistance?: number;
type: AssetType;
userIds: string[];

View file

@@ -1,13 +1,16 @@
import { Injectable } from '@nestjs/common';
import { readFile } from 'node:fs/promises';
import { CLIPConfig, ModelConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
import { CLIPConfig } from 'src/dtos/model-config.dto';
import {
CLIPMode,
DetectFaceResult,
ClipTextualResponse,
ClipVisualResponse,
FaceDetectionOptions,
FacialRecognitionResponse,
IMachineLearningRepository,
MachineLearningRequest,
ModelPayload,
ModelTask,
ModelType,
TextModelInput,
VisionModelInput,
} from 'src/interfaces/machine-learning.interface';
import { Instrumentation } from 'src/utils/instrumentation';
@@ -16,8 +19,8 @@ const errorPrefix = 'Machine learning request';
@Instrumentation()
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
const formData = await this.getFormData(input, config);
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
const formData = await this.getFormData(payload, config);
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
(error: Error | any) => {
@@ -26,50 +29,46 @@ export class MachineLearningRepository implements IMachineLearningRepository {
);
if (res.status >= 400) {
const modelType = config.modelType ? ` for ${config.modelType.replace('-', ' ')}` : '';
throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
}
return res.json();
}
detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
const request = {
[ModelTask.FACIAL_RECOGNITION]: {
[ModelType.DETECTION]: { modelName, minScore },
[ModelType.RECOGNITION]: { modelName },
},
};
const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
return {
imageHeight: response.imageHeight,
imageWidth: response.imageWidth,
faces: response[ModelTask.FACIAL_RECOGNITION],
};
}
encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
return this.predict<number[]>(url, input, {
...config,
modelType: ModelType.CLIP,
mode: CLIPMode.VISION,
} as CLIPConfig);
async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
return response[ModelTask.SEARCH];
}
encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
return this.predict<number[]>(url, input, {
...config,
modelType: ModelType.CLIP,
mode: CLIPMode.TEXT,
} as CLIPConfig);
async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
const response = await this.predict<ClipTextualResponse>(url, { text }, request);
return response[ModelTask.SEARCH];
}
private async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
const formData = new FormData();
const { enabled, modelName, modelType, ...options } = config;
if (!enabled) {
throw new Error(`${modelType} is not enabled`);
}
formData.append('entries', JSON.stringify(config));
formData.append('modelName', modelName);
if (modelType) {
formData.append('modelType', modelType);
}
if (options) {
formData.append('options', JSON.stringify(options));
}
if ('imagePath' in input) {
formData.append('image', new Blob([await readFile(input.imagePath)]));
} else if ('text' in input) {
formData.append('text', input.text);
if ('imagePath' in payload) {
formData.append('image', new Blob([await readFile(payload.imagePath)]));
} else if ('text' in payload) {
formData.append('text', payload.text);
} else {
throw new Error('Invalid input');
}
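
As a usage sketch (assumptions: the repository instance comes from Nest dependency injection, and the query text and preview path are placeholders; the URL and model names match the values used in the specs below), the reworked methods take plain strings and option objects rather than the old wrapper inputs:

import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';

// Assumed to be injected elsewhere; declared here only to keep the sketch self-contained.
declare const machineLearningRepository: IMachineLearningRepository;

async function example() {
  const url = 'http://immich-machine-learning:3003';

  // CLIP text embedding for smart search; sends a ClipTextualRequest.
  const textEmbedding = await machineLearningRepository.encodeText(url, 'sunset at the beach', {
    modelName: 'ViT-B-32__openai',
  });

  // Face detection and recognition in one call; sends a FacialRecognitionRequest.
  const { imageHeight, imageWidth, faces } = await machineLearningRepository.detectFaces(url, '/previews/asset-1.jpg', {
    modelName: 'buffalo_l',
    minScore: 0.7,
  });

  return { textEmbedding, imageHeight, imageWidth, faces };
}

On the wire, each call is a multipart POST to /predict: getFormData above serializes the composed request into the entries field and attaches either an image blob or a text field, so one endpoint can serve every task.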

View file

@@ -7,7 +7,7 @@ import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interfac
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { DetectedFaces, IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { IMediaRepository } from 'src/interfaces/media.interface';
import { IMoveRepository } from 'src/interfaces/move.interface';
import { IPersonRepository } from 'src/interfaces/person.interface';
@@ -46,19 +46,21 @@ const responseDto: PersonResponseDto = {
const statistics = { assets: 3 };
const detectFaceMock = {
assetId: 'asset-1',
personId: 'person-1',
boundingBox: {
x1: 100,
y1: 100,
x2: 200,
y2: 200,
},
const detectFaceMock: DetectedFaces = {
faces: [
{
boundingBox: {
x1: 100,
y1: 100,
x2: 200,
y2: 200,
},
embedding: [1, 2, 3, 4],
score: 0.2,
},
],
imageHeight: 500,
imageWidth: 400,
embedding: [1, 2, 3, 4],
score: 0.2,
};
describe(PersonService.name, () => {
@@ -642,21 +644,13 @@ describe(PersonService.name, () => {
it('should handle no results', async () => {
const start = Date.now();
machineLearningMock.detectFaces.mockResolvedValue([]);
machineLearningMock.detectFaces.mockResolvedValue({ imageHeight: 500, imageWidth: 400, faces: [] });
assetMock.getByIds.mockResolvedValue([assetStub.image]);
await sut.handleDetectFaces({ id: assetStub.image.id });
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
{
imagePath: assetStub.image.previewPath,
},
{
enabled: true,
maxDistance: 0.5,
minScore: 0.7,
minFaces: 3,
modelName: 'buffalo_l',
},
assetStub.image.previewPath,
expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
);
expect(personMock.createFaces).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled();
@@ -671,7 +665,7 @@ describe(PersonService.name, () => {
it('should create a face with no person and queue recognition job', async () => {
personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
const face = {

View file

@@ -333,26 +333,28 @@ export class PersonService {
return JobStatus.SKIPPED;
}
const faces = await this.machineLearningRepository.detectFaces(
if (!asset.isVisible) {
return JobStatus.SKIPPED;
}
const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
machineLearning.url,
{ imagePath: asset.previewPath },
asset.previewPath,
machineLearning.facialRecognition,
);
this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
if (faces.length > 0) {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
const mappedFaces = faces.map((face) => ({
assetId: asset.id,
embedding: face.embedding,
imageHeight: face.imageHeight,
imageWidth: face.imageWidth,
imageHeight,
imageWidth,
boundingBoxX1: face.boundingBox.x1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY1: face.boundingBox.y1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
}));

View file

@@ -102,12 +102,7 @@ export class SearchService {
const userIds = await this.getUserIdsToSearch(auth);
const embedding = await this.machineLearning.encodeText(
machineLearning.url,
{ text: dto.query },
machineLearning.clip,
);
const embedding = await this.machineLearning.encodeText(machineLearning.url, dto.query, machineLearning.clip);
const page = dto.page ?? 1;
const size = dto.size || 100;
const { hasNextPage, items } = await this.searchRepository.searchSmart(

View file

@@ -108,8 +108,8 @@ describe(SmartInfoService.name, () => {
expect(machineMock.encodeImage).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
{ imagePath: assetStub.image.previewPath },
{ enabled: true, modelName: 'ViT-B-32__openai' },
assetStub.image.previewPath,
expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
);
expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
});

View file

@@ -93,9 +93,9 @@ export class SmartInfoService {
return JobStatus.FAILED;
}
const clipEmbedding = await this.machineLearning.encodeImage(
const embedding = await this.machineLearning.encodeImage(
machineLearning.url,
{ imagePath: asset.previewPath },
asset.previewPath,
machineLearning.clip,
);
@@ -104,7 +104,7 @@
await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
}
await this.repository.upsert(asset.id, clipEmbedding);
await this.repository.upsert(asset.id, embedding);
return JobStatus.SUCCESS;
}