feat(ml): support multiple urls (#14347)

* support multiple urls

* update api

* styling

remove unnecessary `?.`

* update docs, make new url field go first

add load balancing section

* update tests

doc formatting

wording

wording

linting

* small styling

* `url` -> `urls`

* fix tests

* update docs

* make docusaurus happy

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
Mert 2024-12-04 15:17:47 -05:00 committed by GitHub
parent 411878c0aa
commit 4bf1b84cc2
22 changed files with 202 additions and 73 deletions

View file

@@ -52,7 +52,7 @@ export interface SystemConfig {
   };
   machineLearning: {
     enabled: boolean;
-    url: string;
+    urls: string[];
     clip: {
       enabled: boolean;
       modelName: string;
@@ -206,7 +206,7 @@ export const defaults = Object.freeze<SystemConfig>({
   },
   machineLearning: {
     enabled: process.env.IMMICH_MACHINE_LEARNING_ENABLED !== 'false',
-    url: process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003',
+    urls: [process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003'],
     clip: {
       enabled: true,
       modelName: 'ViT-B-32__openai',
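For context: the new default keeps honoring the single `IMMICH_MACHINE_LEARNING_URL` environment variable by wrapping it in a one-element array, so existing deployments keep working unchanged. A minimal sketch of how the default resolves (the standalone constant below is illustrative, not code from the PR):

```ts
// Sketch of the new default: a legacy single-URL env var simply becomes
// a one-element array; unset, it falls back to the documented default.
const defaultUrls: string[] = [
  process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003',
];
// IMMICH_MACHINE_LEARNING_URL unset              -> ['http://immich-machine-learning:3003']
// IMMICH_MACHINE_LEARNING_URL=http://ml-a:3003   -> ['http://ml-a:3003']
```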

View file

@@ -3,6 +3,8 @@ import { vitest } from 'vitest';
 vitest.mock('src/constants', () => ({
   APP_MEDIA_LOCATION: '/photos',
   ADDED_IN_PREFIX: 'This property was added in ',
+  DEPRECATED_IN_PREFIX: 'This property was deprecated in ',
 }));

 describe('StorageCore', () => {

View file

@@ -1,6 +1,7 @@
 import { ApiProperty } from '@nestjs/swagger';
-import { Type } from 'class-transformer';
+import { Exclude, Transform, Type } from 'class-transformer';
 import {
+  ArrayMinSize,
   IsBoolean,
   IsEnum,
   IsInt,
@@ -16,6 +17,7 @@ import {
   ValidateNested,
 } from 'class-validator';
 import { SystemConfig } from 'src/config';
+import { PropertyLifecycle } from 'src/decorators';
 import { CLIPConfig, DuplicateDetectionConfig, FacialRecognitionConfig } from 'src/dtos/model-config.dto';
 import {
   AudioCodec,
@@ -269,9 +271,16 @@ class SystemConfigMachineLearningDto {
   @ValidateBoolean()
   enabled!: boolean;

-  @IsUrl({ require_tld: false, allow_underscores: true })
+  @PropertyLifecycle({ deprecatedAt: 'v1.122.0' })
+  @Exclude()
+  url?: string;
+
+  @IsUrl({ require_tld: false, allow_underscores: true }, { each: true })
+  @ArrayMinSize(1)
+  @Transform(({ obj, value }) => (obj.url ? [obj.url] : value))
   @ValidateIf((dto) => dto.enabled)
-  url!: string;
+  @ApiProperty({ type: 'array', items: { type: 'string', format: 'uri' }, minItems: 1 })
+  urls!: string[];

   @Type(() => CLIPConfig)
   @ValidateNested()
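The decorators above are what keep old configs working: `@Exclude()` hides the deprecated `url` key, and the `@Transform` on `urls` folds a legacy `url` value into a one-element array. A runnable sketch of that behavior with class-transformer; the `MlDto` class and inline defaults are illustrative, and it assumes the incoming object already carries a `urls` key (as it does after user config is merged over the defaults):

```ts
import 'reflect-metadata';
import { Exclude, Transform, plainToInstance } from 'class-transformer';

class MlDto {
  @Exclude() // deprecated key: accepted on input, dropped from the instance
  url?: string;

  // If the plain object still carries `url`, fold it into a one-element array.
  @Transform(({ obj, value }) => (obj.url ? [obj.url] : value))
  urls!: string[];
}

const defaults = { urls: ['http://immich-machine-learning:3003'] };

// Legacy-style config: a lone `url` wins and becomes `urls: [url]`.
const legacy = plainToInstance(MlDto, { ...defaults, url: 'http://ml:3003' });
console.log(legacy.urls); // ['http://ml:3003']

// New-style config passes through unchanged.
const modern = plainToInstance(MlDto, { ...defaults, urls: ['http://a:3003', 'http://b:3003'] });
console.log(modern.urls); // ['http://a:3003', 'http://b:3003']
```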

View file

@@ -51,7 +51,7 @@ export type DetectedFaces = { faces: Face[] } & VisualResponse;
 export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;

 export interface IMachineLearningRepository {
-  encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
-  encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
-  detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
+  encodeImage(urls: string[], imagePath: string, config: ModelOptions): Promise<number[]>;
+  encodeText(urls: string[], text: string, config: ModelOptions): Promise<number[]>;
+  detectFaces(urls: string[], imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
 }

View file

@@ -0,0 +1,19 @@
+import { MigrationInterface, QueryRunner } from 'typeorm';
+
+export class RenameMachineLearningUrlToUrls1733339482860 implements MigrationInterface {
+  public async up(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(`
+      UPDATE system_metadata
+      SET value = jsonb_insert(value #- '{machineLearning,url}', '{machineLearning,urls}'::text[], jsonb_build_array(value->'machineLearning'->'url'))
+      WHERE key = 'system-config' AND value->'machineLearning'->'url' IS NOT NULL
+    `);
+  }
+
+  public async down(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(`
+      UPDATE system_metadata
+      SET value = jsonb_insert(value #- '{machineLearning,urls}', '{machineLearning,url}'::text[], to_jsonb(value->'machineLearning'->'urls'->>0))
+      WHERE key = 'system-config' AND value->'machineLearning'->'urls' IS NOT NULL AND jsonb_array_length(value->'machineLearning'->'urls') >= 1
+    `);
+  }
+}
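In plain terms, `up` moves the stored `machineLearning.url` string into a one-element `machineLearning.urls` array inside the JSON config blob, and `down` restores `url` from the first array entry (any extra URLs are dropped on rollback). A small sketch of the same value transformation in TypeScript; the helper names are illustrative:

```ts
type MlSection = { url?: string; urls?: string[] };

// Equivalent of the SQL `up`: remove `url`, insert `urls` = [url].
const up = ({ url, ...rest }: MlSection): MlSection =>
  url === undefined ? rest : { ...rest, urls: [url] };

// Equivalent of the SQL `down`: remove `urls`, restore `url` = urls[0].
// Note that only the first URL survives a rollback.
const down = ({ urls, ...rest }: MlSection): MlSection =>
  urls && urls.length > 0 ? { ...rest, url: urls[0] } : rest;

console.log(up({ url: 'http://ml:3003' }));
// { urls: ['http://ml:3003'] }
console.log(down({ urls: ['http://ml-a:3003', 'http://ml-b:3003'] }));
// { url: 'http://ml-a:3003' }
```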

View file

@@ -155,7 +155,7 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
     this.emitHandlers[event].push(item);
   }

-  async emit<T extends EmitEvent>(event: T, ...args: ArgsOf<T>): Promise<void> {
+  emit<T extends EmitEvent>(event: T, ...args: ArgsOf<T>): Promise<void> {
     return this.onEvent({ name: event, args, server: false });
   }
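The `emit` change is an unrelated cleanup: a method whose body only returns another promise doesn't need the `async` keyword, which would wrap the returned promise in a second one. Roughly (the function names below are illustrative):

```ts
// Both expose Promise<void> to callers; the second simply forwards the
// promise instead of re-wrapping it in the async machinery.
async function emitWrapped(): Promise<void> {
  return onEvent();
}

function emitDirect(): Promise<void> {
  return onEvent();
}

function onEvent(): Promise<void> {
  return Promise.resolve();
}
```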

View file

@@ -1,6 +1,7 @@
-import { Injectable } from '@nestjs/common';
+import { Inject, Injectable } from '@nestjs/common';
 import { readFile } from 'node:fs/promises';
 import { CLIPConfig } from 'src/dtos/model-config.dto';
+import { ILoggerRepository } from 'src/interfaces/logger.interface';
 import {
   ClipTextualResponse,
   ClipVisualResponse,
@@ -13,33 +14,42 @@ import {
   ModelType,
 } from 'src/interfaces/machine-learning.interface';

-const errorPrefix = 'Machine learning request';
-
 @Injectable()
 export class MachineLearningRepository implements IMachineLearningRepository {
-  private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
-    const formData = await this.getFormData(payload, config);
-
-    const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
-      (error: Error | any) => {
-        throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
-      },
-    );
-
-    if (res.status >= 400) {
-      throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
-    }
-
-    return res.json();
+  constructor(@Inject(ILoggerRepository) private logger: ILoggerRepository) {
+    this.logger.setContext(MachineLearningRepository.name);
   }

-  async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
+  private async predict<T>(urls: string[], payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
+    const formData = await this.getFormData(payload, config);
+    for (const url of urls) {
+      try {
+        const response = await fetch(new URL('/predict', url), { method: 'POST', body: formData });
+        if (response.ok) {
+          return response.json();
+        }
+        this.logger.warn(
+          `Machine learning request to "${url}" failed with status ${response.status}: ${response.statusText}`,
+        );
+      } catch (error: Error | unknown) {
+        this.logger.warn(
+          `Machine learning request to "${url}" failed: ${error instanceof Error ? error.message : error}`,
+        );
+      }
+    }
+
+    throw new Error(`Machine learning request '${JSON.stringify(config)}' failed for all URLs`);
+  }
+
+  async detectFaces(urls: string[], imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
     const request = {
       [ModelTask.FACIAL_RECOGNITION]: {
         [ModelType.DETECTION]: { modelName, options: { minScore } },
         [ModelType.RECOGNITION]: { modelName },
       },
     };
-    const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
+    const response = await this.predict<FacialRecognitionResponse>(urls, { imagePath }, request);
     return {
       imageHeight: response.imageHeight,
       imageWidth: response.imageWidth,
@@ -47,15 +57,15 @@ export class MachineLearningRepository implements IMachineLearningRepository {
     };
   }

-  async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
+  async encodeImage(urls: string[], imagePath: string, { modelName }: CLIPConfig) {
     const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
-    const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
+    const response = await this.predict<ClipVisualResponse>(urls, { imagePath }, request);
     return response[ModelTask.SEARCH];
   }

-  async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
+  async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
     const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
-    const response = await this.predict<ClipTextualResponse>(url, { text }, request);
+    const response = await this.predict<ClipTextualResponse>(urls, { text }, request);
     return response[ModelTask.SEARCH];
   }
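The rewritten `predict` gives the multiple-URL support its semantics: servers are tried in the order configured, a non-2xx response or connection error is logged and skipped, and an error is thrown only once every URL has failed. A self-contained sketch of that failover pattern, with illustrative names and endpoint:

```ts
// Try each base URL in order; return the first successful JSON body,
// and throw only when every server has failed.
async function fetchWithFailover<T>(urls: string[], path: string, init?: RequestInit): Promise<T> {
  for (const url of urls) {
    try {
      const response = await fetch(new URL(path, url), init);
      if (response.ok) {
        return (await response.json()) as T;
      }
      console.warn(`Request to "${url}" failed with status ${response.status}: ${response.statusText}`);
    } catch (error) {
      // Connection refused, DNS failure, timeout — fall through to the next URL.
      console.warn(`Request to "${url}" failed: ${error instanceof Error ? error.message : error}`);
    }
  }
  throw new Error(`Request to '${path}' failed for all URLs`);
}

// Usage: 'http://ml-b:3003' is only tried if 'http://ml-a:3003' is
// unreachable or responds with an error status.
// await fetchWithFailover(['http://ml-a:3003', 'http://ml-b:3003'], '/predict', { method: 'POST' });
```

Note this is ordered failover rather than round-robin balancing: spreading load across servers is left to how the URLs are ordered or to an external load balancer.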

View file

@@ -717,7 +717,7 @@ describe(PersonService.name, () => {
       assetMock.getByIds.mockResolvedValue([assetStub.image]);
       await sut.handleDetectFaces({ id: assetStub.image.id });
       expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
-        'http://immich-machine-learning:3003',
+        ['http://immich-machine-learning:3003'],
         '/uploads/user-id/thumbs/path.jpg',
         expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
       );

View file

@@ -297,7 +297,7 @@ export class PersonService extends BaseService {
     }

     const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
-      machineLearning.url,
+      machineLearning.urls,
       previewFile.path,
       machineLearning.facialRecognition,
     );

View file

@@ -86,7 +86,7 @@ export class SearchService extends BaseService {
     const userIds = await this.getUserIdsToSearch(auth);

     const embedding = await this.machineLearningRepository.encodeText(
-      machineLearning.url,
+      machineLearning.urls,
       dto.query,
       machineLearning.clip,
     );

View file

@@ -289,7 +289,7 @@ describe(SmartInfoService.name, () => {
       expect(await sut.handleEncodeClip({ id: assetStub.image.id })).toEqual(JobStatus.SUCCESS);

       expect(machineLearningMock.encodeImage).toHaveBeenCalledWith(
-        'http://immich-machine-learning:3003',
+        ['http://immich-machine-learning:3003'],
         '/uploads/user-id/thumbs/path.jpg',
         expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
       );
@@ -322,7 +322,7 @@ describe(SmartInfoService.name, () => {
       expect(databaseMock.wait).toHaveBeenCalledWith(512);
       expect(machineLearningMock.encodeImage).toHaveBeenCalledWith(
-        'http://immich-machine-learning:3003',
+        ['http://immich-machine-learning:3003'],
         '/uploads/user-id/thumbs/path.jpg',
         expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
       );

View file

@@ -122,7 +122,7 @@ export class SmartInfoService extends BaseService {
     }

     const embedding = await this.machineLearningRepository.encodeImage(
-      machineLearning.url,
+      machineLearning.urls,
       previewFile.path,
       machineLearning.clip,
     );

View file

@@ -85,7 +85,7 @@ const updatedConfig = Object.freeze<SystemConfig>({
   },
   machineLearning: {
     enabled: true,
-    url: 'http://immich-machine-learning:3003',
+    urls: ['http://immich-machine-learning:3003'],
     clip: {
       enabled: true,
       modelName: 'ViT-B-32__openai',
@@ -330,11 +330,11 @@ describe(SystemConfigService.name, () => {
     it('should allow underscores in the machine learning url', async () => {
       configMock.getEnv.mockReturnValue(mockEnvData({ configFile: 'immich-config.json' }));
-      const partialConfig = { machineLearning: { url: 'immich_machine_learning' } };
+      const partialConfig = { machineLearning: { urls: ['immich_machine_learning'] } };
       systemMock.readFile.mockResolvedValue(JSON.stringify(partialConfig));

       const config = await sut.getSystemConfig();

-      expect(config.machineLearning.url).toEqual('immich_machine_learning');
+      expect(config.machineLearning.urls).toEqual(['immich_machine_learning']);
     });

     const externalDomainTests = [