mirror of
https://github.com/immich-app/immich
synced 2025-11-14 17:36:12 +00:00
feat(ml): support multiple urls (#14347)
* support multiple url * update api * styling unnecessary `?.` * update docs, make new url field go first add load balancing section * update tests doc formatting wording wording linting * small styling * `url` -> `urls` * fix tests * update docs * make docusaurus happy --------- Co-authored-by: Alex <alex.tran1502@gmail.com>
This commit is contained in:
parent
411878c0aa
commit
4bf1b84cc2
22 changed files with 202 additions and 73 deletions
|
|
@ -52,7 +52,7 @@ export interface SystemConfig {
|
|||
};
|
||||
machineLearning: {
|
||||
enabled: boolean;
|
||||
url: string;
|
||||
urls: string[];
|
||||
clip: {
|
||||
enabled: boolean;
|
||||
modelName: string;
|
||||
|
|
@ -206,7 +206,7 @@ export const defaults = Object.freeze<SystemConfig>({
|
|||
},
|
||||
machineLearning: {
|
||||
enabled: process.env.IMMICH_MACHINE_LEARNING_ENABLED !== 'false',
|
||||
url: process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003',
|
||||
urls: [process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003'],
|
||||
clip: {
|
||||
enabled: true,
|
||||
modelName: 'ViT-B-32__openai',
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import { vitest } from 'vitest';
|
|||
|
||||
vitest.mock('src/constants', () => ({
|
||||
APP_MEDIA_LOCATION: '/photos',
|
||||
ADDED_IN_PREFIX: 'This property was added in ',
|
||||
DEPRECATED_IN_PREFIX: 'This property was deprecated in ',
|
||||
}));
|
||||
|
||||
describe('StorageCore', () => {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import { ApiProperty } from '@nestjs/swagger';
|
||||
import { Type } from 'class-transformer';
|
||||
import { Exclude, Transform, Type } from 'class-transformer';
|
||||
import {
|
||||
ArrayMinSize,
|
||||
IsBoolean,
|
||||
IsEnum,
|
||||
IsInt,
|
||||
|
|
@ -16,6 +17,7 @@ import {
|
|||
ValidateNested,
|
||||
} from 'class-validator';
|
||||
import { SystemConfig } from 'src/config';
|
||||
import { PropertyLifecycle } from 'src/decorators';
|
||||
import { CLIPConfig, DuplicateDetectionConfig, FacialRecognitionConfig } from 'src/dtos/model-config.dto';
|
||||
import {
|
||||
AudioCodec,
|
||||
|
|
@ -269,9 +271,16 @@ class SystemConfigMachineLearningDto {
|
|||
@ValidateBoolean()
|
||||
enabled!: boolean;
|
||||
|
||||
@IsUrl({ require_tld: false, allow_underscores: true })
|
||||
@PropertyLifecycle({ deprecatedAt: 'v1.122.0' })
|
||||
@Exclude()
|
||||
url?: string;
|
||||
|
||||
@IsUrl({ require_tld: false, allow_underscores: true }, { each: true })
|
||||
@ArrayMinSize(1)
|
||||
@Transform(({ obj, value }) => (obj.url ? [obj.url] : value))
|
||||
@ValidateIf((dto) => dto.enabled)
|
||||
url!: string;
|
||||
@ApiProperty({ type: 'array', items: { type: 'string', format: 'uri' }, minItems: 1 })
|
||||
urls!: string[];
|
||||
|
||||
@Type(() => CLIPConfig)
|
||||
@ValidateNested()
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
|||
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||
|
||||
export interface IMachineLearningRepository {
|
||||
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
||||
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||
encodeImage(urls: string[], imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||
encodeText(urls: string[], text: string, config: ModelOptions): Promise<number[]>;
|
||||
detectFaces(urls: string[], imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
import { MigrationInterface, QueryRunner } from 'typeorm';
|
||||
|
||||
export class RenameMachineLearningUrlToUrls1733339482860 implements MigrationInterface {
|
||||
public async up(queryRunner: QueryRunner): Promise<void> {
|
||||
await queryRunner.query(`
|
||||
UPDATE system_metadata
|
||||
SET value = jsonb_insert(value #- '{machineLearning,url}', '{machineLearning,urls}'::text[], jsonb_build_array(value->'machineLearning'->'url'))
|
||||
WHERE key = 'system-config' AND value->'machineLearning'->'url' IS NOT NULL
|
||||
`);
|
||||
}
|
||||
|
||||
public async down(queryRunner: QueryRunner): Promise<void> {
|
||||
await queryRunner.query(`
|
||||
UPDATE system_metadata
|
||||
SET value = jsonb_insert(value #- '{machineLearning,urls}', '{machineLearning,url}'::text[], to_jsonb(value->'machineLearning'->'urls'->>0))
|
||||
WHERE key = 'system-config' AND value->'machineLearning'->'urls' IS NOT NULL AND jsonb_array_length(value->'machineLearning'->'urls') >= 1
|
||||
`);
|
||||
}
|
||||
}
|
||||
|
|
@ -155,7 +155,7 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
|||
this.emitHandlers[event].push(item);
|
||||
}
|
||||
|
||||
async emit<T extends EmitEvent>(event: T, ...args: ArgsOf<T>): Promise<void> {
|
||||
emit<T extends EmitEvent>(event: T, ...args: ArgsOf<T>): Promise<void> {
|
||||
return this.onEvent({ name: event, args, server: false });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import { Injectable } from '@nestjs/common';
|
||||
import { Inject, Injectable } from '@nestjs/common';
|
||||
import { readFile } from 'node:fs/promises';
|
||||
import { CLIPConfig } from 'src/dtos/model-config.dto';
|
||||
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||
import {
|
||||
ClipTextualResponse,
|
||||
ClipVisualResponse,
|
||||
|
|
@ -13,33 +14,42 @@ import {
|
|||
ModelType,
|
||||
} from 'src/interfaces/machine-learning.interface';
|
||||
|
||||
const errorPrefix = 'Machine learning request';
|
||||
|
||||
@Injectable()
|
||||
export class MachineLearningRepository implements IMachineLearningRepository {
|
||||
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
|
||||
const formData = await this.getFormData(payload, config);
|
||||
|
||||
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
|
||||
(error: Error | any) => {
|
||||
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
|
||||
},
|
||||
);
|
||||
|
||||
if (res.status >= 400) {
|
||||
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
|
||||
}
|
||||
return res.json();
|
||||
constructor(@Inject(ILoggerRepository) private logger: ILoggerRepository) {
|
||||
this.logger.setContext(MachineLearningRepository.name);
|
||||
}
|
||||
|
||||
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
|
||||
private async predict<T>(urls: string[], payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
|
||||
const formData = await this.getFormData(payload, config);
|
||||
for (const url of urls) {
|
||||
try {
|
||||
const response = await fetch(new URL('/predict', url), { method: 'POST', body: formData });
|
||||
if (response.ok) {
|
||||
return response.json();
|
||||
}
|
||||
|
||||
this.logger.warn(
|
||||
`Machine learning request to "${url}" failed with status ${response.status}: ${response.statusText}`,
|
||||
);
|
||||
} catch (error: Error | unknown) {
|
||||
this.logger.warn(
|
||||
`Machine learning request to "${url}" failed: ${error instanceof Error ? error.message : error}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(`Machine learning request '${JSON.stringify(config)}' failed for all URLs`);
|
||||
}
|
||||
|
||||
async detectFaces(urls: string[], imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
|
||||
const request = {
|
||||
[ModelTask.FACIAL_RECOGNITION]: {
|
||||
[ModelType.DETECTION]: { modelName, options: { minScore } },
|
||||
[ModelType.RECOGNITION]: { modelName },
|
||||
},
|
||||
};
|
||||
const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
|
||||
const response = await this.predict<FacialRecognitionResponse>(urls, { imagePath }, request);
|
||||
return {
|
||||
imageHeight: response.imageHeight,
|
||||
imageWidth: response.imageWidth,
|
||||
|
|
@ -47,15 +57,15 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
|||
};
|
||||
}
|
||||
|
||||
async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
|
||||
async encodeImage(urls: string[], imagePath: string, { modelName }: CLIPConfig) {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
|
||||
const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
|
||||
const response = await this.predict<ClipVisualResponse>(urls, { imagePath }, request);
|
||||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
||||
async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
|
||||
async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||
const response = await this.predict<ClipTextualResponse>(url, { text }, request);
|
||||
const response = await this.predict<ClipTextualResponse>(urls, { text }, request);
|
||||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -717,7 +717,7 @@ describe(PersonService.name, () => {
|
|||
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
||||
await sut.handleDetectFaces({ id: assetStub.image.id });
|
||||
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
|
||||
'http://immich-machine-learning:3003',
|
||||
['http://immich-machine-learning:3003'],
|
||||
'/uploads/user-id/thumbs/path.jpg',
|
||||
expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
|
||||
);
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ export class PersonService extends BaseService {
|
|||
}
|
||||
|
||||
const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
|
||||
machineLearning.url,
|
||||
machineLearning.urls,
|
||||
previewFile.path,
|
||||
machineLearning.facialRecognition,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ export class SearchService extends BaseService {
|
|||
const userIds = await this.getUserIdsToSearch(auth);
|
||||
|
||||
const embedding = await this.machineLearningRepository.encodeText(
|
||||
machineLearning.url,
|
||||
machineLearning.urls,
|
||||
dto.query,
|
||||
machineLearning.clip,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -289,7 +289,7 @@ describe(SmartInfoService.name, () => {
|
|||
expect(await sut.handleEncodeClip({ id: assetStub.image.id })).toEqual(JobStatus.SUCCESS);
|
||||
|
||||
expect(machineLearningMock.encodeImage).toHaveBeenCalledWith(
|
||||
'http://immich-machine-learning:3003',
|
||||
['http://immich-machine-learning:3003'],
|
||||
'/uploads/user-id/thumbs/path.jpg',
|
||||
expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
|
||||
);
|
||||
|
|
@ -322,7 +322,7 @@ describe(SmartInfoService.name, () => {
|
|||
|
||||
expect(databaseMock.wait).toHaveBeenCalledWith(512);
|
||||
expect(machineLearningMock.encodeImage).toHaveBeenCalledWith(
|
||||
'http://immich-machine-learning:3003',
|
||||
['http://immich-machine-learning:3003'],
|
||||
'/uploads/user-id/thumbs/path.jpg',
|
||||
expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
|
||||
);
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ export class SmartInfoService extends BaseService {
|
|||
}
|
||||
|
||||
const embedding = await this.machineLearningRepository.encodeImage(
|
||||
machineLearning.url,
|
||||
machineLearning.urls,
|
||||
previewFile.path,
|
||||
machineLearning.clip,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ const updatedConfig = Object.freeze<SystemConfig>({
|
|||
},
|
||||
machineLearning: {
|
||||
enabled: true,
|
||||
url: 'http://immich-machine-learning:3003',
|
||||
urls: ['http://immich-machine-learning:3003'],
|
||||
clip: {
|
||||
enabled: true,
|
||||
modelName: 'ViT-B-32__openai',
|
||||
|
|
@ -330,11 +330,11 @@ describe(SystemConfigService.name, () => {
|
|||
|
||||
it('should allow underscores in the machine learning url', async () => {
|
||||
configMock.getEnv.mockReturnValue(mockEnvData({ configFile: 'immich-config.json' }));
|
||||
const partialConfig = { machineLearning: { url: 'immich_machine_learning' } };
|
||||
const partialConfig = { machineLearning: { urls: ['immich_machine_learning'] } };
|
||||
systemMock.readFile.mockResolvedValue(JSON.stringify(partialConfig));
|
||||
|
||||
const config = await sut.getSystemConfig();
|
||||
expect(config.machineLearning.url).toEqual('immich_machine_learning');
|
||||
expect(config.machineLearning.urls).toEqual(['immich_machine_learning']);
|
||||
});
|
||||
|
||||
const externalDomainTests = [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue