mirror of https://github.com/immich-app/immich, synced 2025-11-14 17:36:12 +00:00
feat(ml): composable ml (#9973)
* modularize model classes
* various fixes
* expose port
* change response
* round coordinates
* simplify preload
* update server
* simplify interface, simplify
* update tests
* composable endpoint
* cleanup fixes, remove unnecessary interface, support text input, cleanup
* ew camelcase
* update server, server fixes, fix typing
* ml fixes, update locustfile, fixes
* cleaner response
* better repo response
* update tests, formatting and typing, rename
* undo compose change
* linting, fix type, actually fix typing
* stricter typing, fix detection-only response, no need for defaultdict
* update spec file, update api, linting
* update e2e
* unnecessary dimension
* remove commented code
* remove duplicate code
* remove unused imports
* add batch dim
parent 7a46f80ddc
commit 2b1b43a7e4
39 changed files with 982 additions and 999 deletions
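In rough terms (a sketch, not text taken from the diff): instead of describing one model per request with flat form fields such as modelName, modelType and options, the server now sends a single composable tree keyed by task and model type, so one call can cover face detection and recognition together, or CLIP visual or textual encoding. The keys below mirror the enum string values added in the interface changes further down; the model names and threshold are values that appear in the test hunks.

// Illustrative only; keys follow the new ModelTask/ModelType string values,
// model names and minScore are placeholders seen elsewhere in this diff.
const entries = {
  'facial-recognition': {
    detection: { modelName: 'buffalo_l', minScore: 0.7 },
    recognition: { modelName: 'buffalo_l' },
  },
};
// A CLIP search request uses the same shape with a single entry, e.g.
// { clip: { visual: { modelName: 'ViT-B-32__openai' } } } or { clip: { textual: { ... } } }.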
@@ -1,8 +1,7 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
import { IsEnum, IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { CLIPMode, ModelType } from 'src/interfaces/machine-learning.interface';
import { Optional, ValidateBoolean } from 'src/validation';
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { ValidateBoolean } from 'src/validation';

export class TaskConfig {
  @ValidateBoolean()

@@ -13,19 +12,9 @@ export class ModelConfig extends TaskConfig {
  @IsString()
  @IsNotEmpty()
  modelName!: string;

  @IsEnum(ModelType)
  @Optional()
  @ApiProperty({ enumName: 'ModelType', enum: ModelType })
  modelType?: ModelType;
}

export class CLIPConfig extends ModelConfig {
  @IsEnum(CLIPMode)
  @Optional()
  @ApiProperty({ enumName: 'CLIPMode', enum: CLIPMode })
  mode?: CLIPMode;
}
export class CLIPConfig extends ModelConfig {}

export class DuplicateDetectionConfig extends TaskConfig {
  @IsNumber()

@@ -36,7 +25,7 @@ export class DuplicateDetectionConfig extends TaskConfig {
  maxDistance!: number;
}

export class RecognitionConfig extends ModelConfig {
export class FacialRecognitionConfig extends ModelConfig {
  @IsNumber()
  @Min(0)
  @Max(1)
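For orientation, a hedged sketch of the plain-object shapes these DTOs now validate: CLIPConfig shrinks to the shared enabled/modelName fields (mode and modelType are gone), RecognitionConfig is renamed FacialRecognitionConfig, and DuplicateDetectionConfig keeps maxDistance. Field names come from this diff and the test hunks below; the values are placeholders, and FacialRecognitionConfig has further numeric fields that are cut off in the hunk above.

// Sketch only: approximate objects accepted by the DTOs above; values are placeholders.
const clip = { enabled: true, modelName: 'ViT-B-32__openai' }; // CLIPConfig: no `mode`, no `modelType`
const duplicateDetection = { enabled: true, maxDistance: 0.01 }; // DuplicateDetectionConfig
const facialRecognition = { enabled: true, modelName: 'buffalo_l', minScore: 0.7 }; // FacialRecognitionConfig (ex RecognitionConfig)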
@@ -30,7 +30,7 @@ import {
  TranscodePolicy,
  VideoCodec,
} from 'src/config';
import { CLIPConfig, DuplicateDetectionConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
import { CLIPConfig, DuplicateDetectionConfig, FacialRecognitionConfig } from 'src/dtos/model-config.dto';
import { ConcurrentQueueName, QueueName } from 'src/interfaces/job.interface';
import { ValidateBoolean, validateCronExpression } from 'src/validation';

@@ -270,10 +270,10 @@ class SystemConfigMachineLearningDto {
  @IsObject()
  duplicateDetection!: DuplicateDetectionConfig;

  @Type(() => RecognitionConfig)
  @Type(() => FacialRecognitionConfig)
  @ValidateNested()
  @IsObject()
  facialRecognition!: RecognitionConfig;
  facialRecognition!: FacialRecognitionConfig;
}

enum MapTheme {
@@ -1,15 +1,5 @@
import { CLIPConfig, RecognitionConfig } from 'src/dtos/model-config.dto';

export const IMachineLearningRepository = 'IMachineLearningRepository';

export interface VisionModelInput {
  imagePath: string;
}

export interface TextModelInput {
  text: string;
}

export interface BoundingBox {
  x1: number;
  y1: number;

@@ -17,26 +7,51 @@ export interface BoundingBox {
  y2: number;
}

export interface DetectFaceResult {
  imageWidth: number;
  imageHeight: number;
  boundingBox: BoundingBox;
  score: number;
  embedding: number[];
export enum ModelTask {
  FACIAL_RECOGNITION = 'facial-recognition',
  SEARCH = 'clip',
}

export enum ModelType {
  FACIAL_RECOGNITION = 'facial-recognition',
  CLIP = 'clip',
  DETECTION = 'detection',
  PIPELINE = 'pipeline',
  RECOGNITION = 'recognition',
  TEXTUAL = 'textual',
  VISUAL = 'visual',
}

export enum CLIPMode {
  VISION = 'vision',
  TEXT = 'text',
export type ModelPayload = { imagePath: string } | { text: string };

type ModelOptions = { modelName: string };

export type FaceDetectionOptions = ModelOptions & { minScore: number };

type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;

export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };

export type FacialRecognitionRequest = {
  [ModelTask.FACIAL_RECOGNITION]: {
    [ModelType.DETECTION]: FaceDetectionOptions;
    [ModelType.RECOGNITION]: ModelOptions;
  };
};

export interface Face {
  boundingBox: BoundingBox;
  embedding: number[];
  score: number;
}

export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;

export interface IMachineLearningRepository {
  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
  encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
  encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
  detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
}
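A usage sketch (not part of the diff) of how the new composable types fit together; the model names are the ones used in the test hunks, everything else follows the type definitions above.

// Illustrative sketch: building composable requests and consuming a response.
import {
  ClipVisualRequest,
  FacialRecognitionRequest,
  FacialRecognitionResponse,
  ModelTask,
  ModelType,
} from 'src/interfaces/machine-learning.interface';

// A CLIP image-encoding request selects the SEARCH task with a VISUAL model.
const clipRequest: ClipVisualRequest = {
  [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName: 'ViT-B-32__openai' } },
};

// A facial-recognition request composes a detector and a recognizer in one call.
const faceRequest: FacialRecognitionRequest = {
  [ModelTask.FACIAL_RECOGNITION]: {
    [ModelType.DETECTION]: { modelName: 'buffalo_l', minScore: 0.7 },
    [ModelType.RECOGNITION]: { modelName: 'buffalo_l' },
  },
};

// Responses are keyed by task, with image dimensions alongside the facial results.
const summarize = (response: FacialRecognitionResponse) =>
  response[ModelTask.FACIAL_RECOGNITION].map((face) => ({
    ...face,
    area: response.imageWidth * response.imageHeight, // dimensions are per asset, not per face
  }));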
@@ -37,8 +37,6 @@ export interface SearchExploreItem<T> {
  items: SearchExploreItemSet<T>;
}

export type Embedding = number[];

export interface SearchAssetIDOptions {
  checksum?: Buffer;
  deviceAssetId?: string;

@@ -106,7 +104,7 @@ export interface SearchExifOptions {
}

export interface SearchEmbeddingOptions {
  embedding: Embedding;
  embedding: number[];
  userIds: string[];
}

@@ -154,7 +152,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {

export interface AssetDuplicateSearch {
  assetId: string;
  embedding: Embedding;
  embedding: number[];
  maxDistance?: number;
  type: AssetType;
  userIds: string[];
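A small sketch of the knock-on effect of dropping the Embedding alias: search options now carry plain number[] vectors. SearchEmbeddingOptions is the interface shown above (assumed to be imported from the search interface module); the vector and user id are placeholders.

// Sketch only: shapes follow the interface above; values are placeholders.
const smartSearch: SearchEmbeddingOptions = {
  embedding: [0.01, 0.02, 0.03],
  userIds: ['user-1'],
};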
@@ -1,13 +1,16 @@
import { Injectable } from '@nestjs/common';
import { readFile } from 'node:fs/promises';
import { CLIPConfig, ModelConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
import { CLIPConfig } from 'src/dtos/model-config.dto';
import {
  CLIPMode,
  DetectFaceResult,
  ClipTextualResponse,
  ClipVisualResponse,
  FaceDetectionOptions,
  FacialRecognitionResponse,
  IMachineLearningRepository,
  MachineLearningRequest,
  ModelPayload,
  ModelTask,
  ModelType,
  TextModelInput,
  VisionModelInput,
} from 'src/interfaces/machine-learning.interface';
import { Instrumentation } from 'src/utils/instrumentation';

@@ -16,8 +19,8 @@ const errorPrefix = 'Machine learning request';
@Instrumentation()
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
  private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
    const formData = await this.getFormData(input, config);
  private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
    const formData = await this.getFormData(payload, config);

    const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
      (error: Error | any) => {

@@ -26,50 +29,46 @@ export class MachineLearningRepository implements IMachineLearningRepository {
    );

    if (res.status >= 400) {
      const modelType = config.modelType ? ` for ${config.modelType.replace('-', ' ')}` : '';
      throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
      throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
    }
    return res.json();
  }

  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
    return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
  async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
    const request = {
      [ModelTask.FACIAL_RECOGNITION]: {
        [ModelType.DETECTION]: { modelName, minScore },
        [ModelType.RECOGNITION]: { modelName },
      },
    };
    const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
    return {
      imageHeight: response.imageHeight,
      imageWidth: response.imageWidth,
      faces: response[ModelTask.FACIAL_RECOGNITION],
    };
  }

  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
    return this.predict<number[]>(url, input, {
      ...config,
      modelType: ModelType.CLIP,
      mode: CLIPMode.VISION,
    } as CLIPConfig);
  async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
    const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
    const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
    return response[ModelTask.SEARCH];
  }

  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
    return this.predict<number[]>(url, input, {
      ...config,
      modelType: ModelType.CLIP,
      mode: CLIPMode.TEXT,
    } as CLIPConfig);
  async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
    const response = await this.predict<ClipTextualResponse>(url, { text }, request);
    return response[ModelTask.SEARCH];
  }

  private async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
  private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
    const formData = new FormData();
    const { enabled, modelName, modelType, ...options } = config;
    if (!enabled) {
      throw new Error(`${modelType} is not enabled`);
    }
    formData.append('entries', JSON.stringify(config));

    formData.append('modelName', modelName);
    if (modelType) {
      formData.append('modelType', modelType);
    }
    if (options) {
      formData.append('options', JSON.stringify(options));
    }
    if ('imagePath' in input) {
      formData.append('image', new Blob([await readFile(input.imagePath)]));
    } else if ('text' in input) {
      formData.append('text', input.text);
    if ('imagePath' in payload) {
      formData.append('image', new Blob([await readFile(payload.imagePath)]));
    } else if ('text' in payload) {
      formData.append('text', payload.text);
    } else {
      throw new Error('Invalid input');
    }
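For reference, a minimal sketch of the round trip the repository above performs, assuming a reachable ML service: the URL is the one used in the tests, the image path is a placeholder, and the 'entries'/'image'/'text' part names and the /predict route come from getFormData and predict.

// Sketch only, not the repository implementation itself.
import { readFile } from 'node:fs/promises';

async function predictFaces(imagePath: string) {
  // Composable request: detection and recognition for the facial-recognition task in one call.
  const entries = {
    'facial-recognition': {
      detection: { modelName: 'buffalo_l', minScore: 0.7 },
      recognition: { modelName: 'buffalo_l' },
    },
  };

  const formData = new FormData();
  formData.append('entries', JSON.stringify(entries)); // the whole task/model tree as one JSON field
  formData.append('image', new Blob([await readFile(imagePath)])); // or formData.append('text', query) for CLIP text

  const res = await fetch(new URL('/predict', 'http://immich-machine-learning:3003'), {
    method: 'POST',
    body: formData,
  });
  if (res.status >= 400) {
    throw new Error(`Machine learning request failed with status ${res.status}: ${res.statusText}`);
  }
  // Shape per the response types above, e.g. { 'facial-recognition': Face[], imageWidth, imageHeight }.
  return res.json();
}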
@@ -7,7 +7,7 @@ import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interfac
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { DetectedFaces, IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { IMediaRepository } from 'src/interfaces/media.interface';
import { IMoveRepository } from 'src/interfaces/move.interface';
import { IPersonRepository } from 'src/interfaces/person.interface';

@@ -46,19 +46,21 @@ const responseDto: PersonResponseDto = {

const statistics = { assets: 3 };

const detectFaceMock = {
  assetId: 'asset-1',
  personId: 'person-1',
  boundingBox: {
    x1: 100,
    y1: 100,
    x2: 200,
    y2: 200,
  },
const detectFaceMock: DetectedFaces = {
  faces: [
    {
      boundingBox: {
        x1: 100,
        y1: 100,
        x2: 200,
        y2: 200,
      },
      embedding: [1, 2, 3, 4],
      score: 0.2,
    },
  ],
  imageHeight: 500,
  imageWidth: 400,
  embedding: [1, 2, 3, 4],
  score: 0.2,
};

describe(PersonService.name, () => {

@@ -642,21 +644,13 @@ describe(PersonService.name, () => {
  it('should handle no results', async () => {
    const start = Date.now();

    machineLearningMock.detectFaces.mockResolvedValue([]);
    machineLearningMock.detectFaces.mockResolvedValue({ imageHeight: 500, imageWidth: 400, faces: [] });
    assetMock.getByIds.mockResolvedValue([assetStub.image]);
    await sut.handleDetectFaces({ id: assetStub.image.id });
    expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
      'http://immich-machine-learning:3003',
      {
        imagePath: assetStub.image.previewPath,
      },
      {
        enabled: true,
        maxDistance: 0.5,
        minScore: 0.7,
        minFaces: 3,
        modelName: 'buffalo_l',
      },
      assetStub.image.previewPath,
      expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
    );
    expect(personMock.createFaces).not.toHaveBeenCalled();
    expect(jobMock.queue).not.toHaveBeenCalled();

@@ -671,7 +665,7 @@ describe(PersonService.name, () => {

  it('should create a face with no person and queue recognition job', async () => {
    personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
    machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
    machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
    searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
    assetMock.getByIds.mockResolvedValue([assetStub.image]);
    const face = {
@@ -333,26 +333,28 @@ export class PersonService {
      return JobStatus.SKIPPED;
    }

    const faces = await this.machineLearningRepository.detectFaces(
    if (!asset.isVisible) {
      return JobStatus.SKIPPED;
    }

    const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
      machineLearning.url,
      { imagePath: asset.previewPath },
      asset.previewPath,
      machineLearning.facialRecognition,
    );

    this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
    this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));

    if (faces.length > 0) {
      await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });

      const mappedFaces = faces.map((face) => ({
        assetId: asset.id,
        embedding: face.embedding,
        imageHeight: face.imageHeight,
        imageWidth: face.imageWidth,
        imageHeight,
        imageWidth,
        boundingBoxX1: face.boundingBox.x1,
        boundingBoxX2: face.boundingBox.x2,
        boundingBoxY1: face.boundingBox.y1,
        boundingBoxX2: face.boundingBox.x2,
        boundingBoxY2: face.boundingBox.y2,
      }));
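A condensed sketch of the consuming side shown above: the image dimensions now arrive once per asset instead of once per face, so they are destructured from the response and shared across the mapped rows. mapDetectedFaces is a hypothetical helper, not code from the PR; IMachineLearningRepository and FaceDetectionOptions are the types defined earlier in this diff.

// Sketch only: mirrors the mapping in the hunk above; parameter shapes are assumed.
async function mapDetectedFaces(
  repo: IMachineLearningRepository,
  url: string,
  asset: { id: string; previewPath: string },
  config: FaceDetectionOptions,
) {
  const { imageHeight, imageWidth, faces } = await repo.detectFaces(url, asset.previewPath, config);
  return faces.map((face) => ({
    assetId: asset.id,
    embedding: face.embedding,
    imageHeight, // shared per asset, no longer repeated on every face
    imageWidth,
    boundingBoxX1: face.boundingBox.x1,
    boundingBoxY1: face.boundingBox.y1,
    boundingBoxX2: face.boundingBox.x2,
    boundingBoxY2: face.boundingBox.y2,
  }));
}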
@@ -102,12 +102,7 @@ export class SearchService {

    const userIds = await this.getUserIdsToSearch(auth);

    const embedding = await this.machineLearning.encodeText(
      machineLearning.url,
      { text: dto.query },
      machineLearning.clip,
    );

    const embedding = await this.machineLearning.encodeText(machineLearning.url, dto.query, machineLearning.clip);
    const page = dto.page ?? 1;
    const size = dto.size || 100;
    const { hasNextPage, items } = await this.searchRepository.searchSmart(
@@ -108,8 +108,8 @@ describe(SmartInfoService.name, () => {

    expect(machineMock.encodeImage).toHaveBeenCalledWith(
      'http://immich-machine-learning:3003',
      { imagePath: assetStub.image.previewPath },
      { enabled: true, modelName: 'ViT-B-32__openai' },
      assetStub.image.previewPath,
      expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
    );
    expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
  });
@@ -93,9 +93,9 @@ export class SmartInfoService {
      return JobStatus.FAILED;
    }

    const clipEmbedding = await this.machineLearning.encodeImage(
    const embedding = await this.machineLearning.encodeImage(
      machineLearning.url,
      { imagePath: asset.previewPath },
      asset.previewPath,
      machineLearning.clip,
    );

@@ -104,7 +104,7 @@ export class SmartInfoService {
      await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
    }

    await this.repository.upsert(asset.id, clipEmbedding);
    await this.repository.upsert(asset.id, embedding);

    return JobStatus.SUCCESS;
  }