建设机械网站制作,谷歌seo,.net 手机网站源码下载,深圳工程造价信息网官网Java调用Pytorch实现以图搜图
涉及技术栈：1、ElasticSearch环境；2、Python运行环境（如果事先没有pytorch模型，可以用python脚本创建模型）；
1、运行效果 2、创建模型（有则可以跳过）…Java调用Pytorch实现以图搜图
涉及技术栈：1、ElasticSearch环境；2、Python运行环境（如果事先没有pytorch模型，可以用python脚本创建模型）。
1、运行效果
2、创建模型（有则可以跳过）
vi script.py
import torch
import torch.nn as nn
import torchvision.models as modelsclass ImageFeatureExtractor(nn.Module):def __init__(self):super(ImageFeatureExtractor, self).__init__()self.resnet models.resnet50(pretrainedTrue)#最终输出维度1024的向量下文elastic search要设置dims为1024self.resnet.fc nn.Linear(2048, 1024)def forward(self, x):x self.resnet(x)return xif __name__ __main__:model ImageFeatureExtractor()model.eval()#根据模型随便创建一个输入input torch.rand([1, 3, 224, 224])output model(input)#以这种方式保存script torch.jit.trace(model, input)script.save(model.pt)2、java项目pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <scope>provided</scope>
    </dependency>
    <!-- DJL PyTorch engine API. -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.19.0</version>
    </dependency>
    <!-- Native libtorch binaries are published per platform and selected by classifier.
         win-x86_64 is assumed because the project reads images from a Windows path;
         use linux-x86_64 / osx-x86_64 etc. for other hosts. -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-cpu</artifactId>
        <classifier>win-x86_64</classifier>
        <version>1.10.0</version>
        <scope>runtime</scope>
    </dependency>
    <!-- JNI bridge; version is "<pytorch-native version>-<djl version>" and is only needed at runtime. -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-jni</artifactId>
        <version>1.10.0-0.19.0</version>
        <scope>runtime</scope>
    </dependency>
    <!-- No explicit version: presumably managed by the Spring Boot parent/BOM — verify. -->
    <dependency>
        <groupId>org.elasticsearch.client</groupId>
        <artifactId>elasticsearch-rest-high-level-client</artifactId>
    </dependency>
</dependencies>
3、ES创建文档
PUT /isi
{
  "mappings": {
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 1024
      },
      "url": {
        "type": "keyword"
      },
      "user_id": {
        "type": "keyword"
      }
    }
  }
}
4、编写java代码调用模型
ORCUtil.java
package com.topprismcloud.rtm;import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;public class ORCUtil {private static final String INDEX isi;private static final int IMAGE_SIZE 224;private static Model model; // 模型private static PredictorImage, float[] predictor; // predictor.predict(input)相当于python中model(input)static {try {model Model.newInstance(model);// 这里的model.pt是上面代码展示的那种方式保存的model.load(ORCUtil.class.getClassLoader().getResourceAsStream(model.pt));Transform resize new Resize(IMAGE_SIZE);Transform toTensor new ToTensor();Transform normalize new Normalize(new float[] { 0.485f, 0.456f, 0.406f },new float[] { 0.229f, 0.224f, 0.225f });// Translator处理输入Image转为tensor、输出转为float[]TranslatorImage, float[] translator new TranslatorImage, float[]() {Overridepublic NDList processInput(TranslatorContext ctx, Image input) throws Exception {NDManager ndManager ctx.getNDManager();System.out.println(input: input.getWidth() , input.getHeight());NDArray transform normalize.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));System.out.println(transform.getShape());NDList list new NDList();list.add(transform);return list;}Overridepublic float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {return ndList.get(0).toFloatArray();}};predictor new Predictor(model, translator, Device.cpu(), true);} catch (Exception e) {e.printStackTrace();}}public static void upload() throws Exception {HttpHost hostnew HttpHost(14.20.30.16, 9200, HttpHost.DEFAULT_SCHEME_NAME);RestClientBuilder builderRestClient.builder(host);CredentialsProvider credentialsProvider new BasicCredentialsProvider();credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(elastic, 123456));builder.setHttpClientConfigCallback(f - f.setDefaultCredentialsProvider(credentialsProvider));RestHighLevelClient client new RestHighLevelClient( builder);// 批量上传请求BulkRequest bulkRequest new BulkRequest(INDEX);File file new File(D:\\001ENV\\nginx-1.24.0\\html\\resource\\new);for (File listFile : file.listFiles()) {
// float[] vector predictor.predict(ImageFactory.getInstance()
// .fromInputStream(Test.class.getClassLoader().getResourceAsStream(new/ listFile.getName())));float[] vector predictor.predict(ImageFactory.getInstance().fromInputStream(new FileInputStream(listFile)));// 构建文档MapString, Object jsonMap new HashMap();jsonMap.put(url, /resource/listFile.getName());jsonMap.put(vector, vector);jsonMap.put(user_id, user123);IndexRequest request new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);bulkRequest.add(request);}client.bulk(bulkRequest, RequestOptions.DEFAULT);client.close();}// 接收待搜索图片的inputstream搜索与其相似的图片public static ListSearchResult search(InputStream input) throws Throwable {float[] vector predictor.predict(ImageFactory.getInstance().fromInputStream(input));System.out.println(Arrays.toString(vector));// 展示k个结果int k 100;// 连接Elasticsearch服务器RestHighLevelClient client new RestHighLevelClient(RestClient.builder(new HttpHost(14.20.30.16, 9200, http)));SearchRequest searchRequest new SearchRequest(INDEX);Script script new Script(ScriptType.INLINE, painless, cosineSimilarity(params.queryVector, doc[vector]),Collections.singletonMap(queryVector, vector));FunctionScoreQueryBuilder functionScoreQueryBuilder QueryBuilders.functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));SearchSourceBuilder searchSourceBuilder new SearchSourceBuilder();searchSourceBuilder.query(functionScoreQueryBuilder).fetchSource(null, vector) // 不返回vector字段太多了没用还耗时.size(k);searchRequest.source(searchSourceBuilder);SearchResponse searchResponse client.search(searchRequest, RequestOptions.DEFAULT);SearchHits hits searchResponse.getHits();ListSearchResult list new ArrayList();for (SearchHit hit : hits) {// 处理搜索结果System.out.println(hit.toString());SearchResult result new SearchResult((String) hit.getSourceAsMap().get(url), hit.getScore());list.add(result);}client.close();return list;}public static void main(String[] args) throws Throwable {ORCUtil.upload();System.out.println(hao);}
}SearchController.java
package com.topprismcloud.rtm;import java.util.List;import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;RestController
CrossOrigin
public class SearchController {PostMapping(search)public ResponseEntity search(MultipartFile file) {try {ListSearchResult list ORCUtil.search(file.getInputStream());return ResponseEntity.ok(list);} catch (Throwable e) {return ResponseEntity.status(400).body(null);}}
}SearchResult.java
package com.topprismcloud.rtm;

import lombok.AllArgsConstructor;
import lombok.Data;

/**
 * One image-search hit. Lombok generates the accessors, equals/hashCode/toString and the
 * {@code (url, score)} constructor.
 */
@Data
@AllArgsConstructor
public class SearchResult {

    /** Path of the matching image, taken from the index document's {@code url} field. */
    private String url;

    /** Cosine similarity between the query and this image's embedding; higher is more similar. */
    private Float score;
}
5、前端
index.html
<!DOCTYPE html>
<html lang="zh">
<head>
    <meta charset="UTF-8">
    <title>以图搜图</title>
    <style>
        body {
            background: url(/img/bg.jpg);
            background-attachment: fixed;
            background-size: 100% 100%;
        }

        body > div {
            width: 1000px;
            margin: 50px auto;
            padding: 10px 20px;
            border: 1px solid lightgray;
            border-radius: 20px;
            box-sizing: border-box;
            background: rgba(255, 255, 255, 0.7);
        }

        .upload {
            display: inline-block;
            width: 300px;
            height: 280px;
            border: 1px dashed lightcoral;
            vertical-align: top;
        }

        .upload .cover {
            width: 200px;
            height: 200px;
            margin: 10px 50px;
            border: 1px solid black;
            box-sizing: border-box;
            text-align: center;
            line-height: 200px;
            position: relative;
        }

        .upload img {
            width: 198px;
            height: 198px;
            position: absolute;
            left: 0;
            top: 0;
        }

        .upload input {
            margin-left: 50px;
        }

        .upload button {
            width: 80px;
            height: 30px;
            margin-left: 110px;
        }

        .result-block {
            display: inline-block;
            margin-left: 40px;
            border: 1px solid lightgray;
            border-radius: 10px;
            min-height: 500px;
            width: 600px;
        }

        .result-block h1 {
            text-align: center;
            margin-top: 100px;
        }

        .result {
            padding: 10px;
            cursor: pointer;
            display: inline-block;
        }

        .result:hover {
            background: rgb(240, 240, 240);
        }

        .result p {
            width: 110px;
            overflow: hidden;
            white-space: nowrap;
            text-overflow: ellipsis;
        }

        .result img {
            width: 160px;
            height: 160px;
        }

        .result .prob {
            color: rgb(37, 147, 60)
        }
    </style>
    <script src="js/jquery-3.6.0.js"></script>
</head>
<body>
<div>
    <div class="upload">
        <div class="cover">请选择图片<img id="image" src=""/></div>
        <input id="file" type="file" accept=".png,.jpg,.jpeg"/>
    </div>
    <div class="result-block"><h1>请选择图片</h1></div>
</div>
<ul id="box"></ul>
<script>
    // Search endpoint; the image is sent as the multipart field "file".
    var SEARCH_URL = "http://10.1.2.240:8081/search"
    // Object URL currently used by the preview <img>; revoked when replaced so the blob can be freed.
    var previewUrl = null

    // Replace the result panel with a single status heading.
    function showMessage(text) {
        $(".result-block").empty().append($("<h1>").text(text))
    }

    // Render search hits. Elements are built with jQuery attr()/text() so server data is never parsed as HTML.
    function renderResults(res) {
        var block = $(".result-block").empty()
        if (!res || res.length === 0) {
            block.append($("<h1>").text("未找到相似图片"))
            return
        }
        for (let item of res) {
            let info = $("<div>")
                .css({"display": "inline-block", "vertical-align": "top"})
                .append($("<p class='prob'>").text("得分" + item.score.toFixed(4)))
            $("<div class='result'>")
                .append($("<img>").attr("src", item.url))
                .append(info)
                .appendTo(block)
        }
    }

    // Create a displayable URL for a local File, or null if the browser lacks the URL API.
    function getObjectURL(file) {
        var api = window.URL || window.webkitURL
        return api ? api.createObjectURL(file) : null
    }

    $("#file").change(function () {
        let f = this.files[0]
        // Nothing selected (e.g. the file dialog was cancelled).
        if (!f) {
            return
        }
        let index = f.name.lastIndexOf(".")
        let ext = f.name.substring(index, f.name.length).toLowerCase() // file extension
        if (ext !== ".png" && ext !== ".jpg" && ext !== ".jpeg") {
            alert("系统仅支持 JPG、PNG、JPEG 格式的图片请您调整格式后重新上传")
            return
        }
        showMessage("正在识别中...")

        var api = window.URL || window.webkitURL
        if (previewUrl && api) {
            api.revokeObjectURL(previewUrl)
        }
        previewUrl = getObjectURL(f)
        $("#image").attr("src", previewUrl)

        let formData = new FormData()
        formData.append("file", f)
        $.ajax({
            url: SEARCH_URL,
            method: "post",
            data: formData,
            processData: false,
            contentType: false,
            success: renderResults,
            // Without this the panel would keep showing "正在识别中..." forever on failure.
            error: function () {
                showMessage("识别失败，请稍后重试")
            }
        })
    })
</script>
</body>
</html>
6、打包后的源代码
以图搜图 Java + HTML 源代码
相关参考文章：Java调用Pytorch模型进行图像识别